q-learning-terrain-navigator/q-learning-terrain-navigator.ipynb

{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Import Packages"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"# import necessary libraries\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"from matplotlib import colors\n",
"import matplotlib.animation as animation\n",
"import json\n",
"import time\n",
"import threading\n",
"import tqdm\n",
"from tqdm import tqdm\n",
"from tqdm import trange\n",
"import datetime"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Create Map\n",
"Create a map for the Q-learning algorithm to try. You can choose any grid size, but the larger the grid, the more compute it will take. I would suggest around an 8x8 to 12x12 grid."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"pygame 2.1.0 (SDL 2.0.16, Python 3.10.14)\n",
"Hello from the pygame community. https://www.pygame.org/contribute.html\n"
]
}
],
"source": [
"!./map_generator"
]
},
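{
"cell_type": "markdown",
"metadata": {},
"source": [
"If the bundled `map_generator` executable is not available, the next cell is a minimal, optional sketch that writes a compatible `map_data.json` by hand. The reward values used here (-1 for plain tiles, -100 for cliffs, 1000 for the goal) are assumptions inferred from the rest of this notebook; adjust them to match the maps your own generator produces."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional fallback: write a simple random map to map_data.json without the pygame tool.\n",
"# Assumed reward convention: -1 = plain tile, -100 = cliff/obstacle, 1000 = goal.\n",
"rng = np.random.default_rng(0)\n",
"size = 10                                    # grid size; 8x8 to 12x12 works well\n",
"grid = np.full((size, size), -1, dtype=int)  # start with all plain tiles\n",
"obstacles = rng.random((size, size)) < 0.15  # sprinkle roughly 15% cliffs\n",
"grid[obstacles] = -100\n",
"grid[0, size - 1] = 1000                     # place the goal in the top-right corner\n",
"with open(\"map_data.json\", \"w\") as f:\n",
"    json.dump(grid.tolist(), f)"
]
},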
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Importing Map Array and Displaying Map\n",
"<br>"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# Load the saved map\n",
"with open(\"map_data.json\", \"r\") as f:\n",
" rewards = np.array(json.load(f))\n",
"\n",
"#rewards[rewards == 1000] = 500\n",
"\n",
"environment_rows = rewards.shape[0]\n",
"environment_columns = rewards.shape[1]"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAaEAAAGxCAYAAADLfglZAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAmWUlEQVR4nO3de3RU5b3/8c8QyCRgLiU2CVEIoUWuckuAwz0KRCBgsadViVQQ6kFBJKIUUpAMWBKFysLKbYGCpQjSdRBElyjhIkiBQ0CIHLxQLZAU5QQBEwiYQPL8/nBlfg5JIOhMnlzer7X2H/PsZ+b57pnM/uTZs2ePwxhjBACABfVsFwAAqLsIIQCANYQQAMAaQggAYA0hBACwhhACAFhDCAEArCGEAADWEEIAAGsIoRrmtddek8PhkMPh0AcffFBmvTFGv/zlL+VwOBQfH1+ltWVlZcnhcGjatGkV9vnnP/8ph8OhJ598UpLUvHlzjR49uooqrNjo0aPVvHlzjzZf1nbp0iW5XK5yX8PS1/jEiRM+GdvbTpw4IYfDoddee83dtmfPHrlcLn377bdl+jdv3lxDhw79SWPm5+fr+eefV/fu3RUaGqoGDRooIiJCgwYN0po1a1RYWPiTHv9G4uPjq/z9VVvVt10AfpygoCC9+uqrZd4IO3fu1JdffqmgoKAqr6ljx46KjY3VqlWrNGfOHPn5+ZXps3LlSknS2LFjJUkbNmxQcHBwldZZWb6s7dKlS5o1a5YklXkNExMTtXfvXjVp0sQnY3tbkyZNtHfvXv3iF79wt+3Zs0ezZs3S6NGjFRoa6tXx/vnPf2rQoEHKzc3Vf/3Xf2n69On62c9+pq+//lrvv/++xowZo08//VTPPfecV8eFbxBCNdQDDzyg119/XYsWLfLYUb766qvq0aOH8vPzrdQ1duxYjR8/Xps3by7z325xcbFWrVql2NhYdezYUZLUuXNnG2VWiq3afv7zn+vnP/+5lbF/DKfTqf/4j/+okrGuXr2q4cOH69y5c9q/f7/atGnjsf7+++/XzJkzdejQoSqpBz8dh+NqqBEjRkiS1q5d627Ly8vT+vXrNWbMmHLvM2vWLHXv3l2NGzdWcHCwunTpoldffVXXXsO29HDJhg0b1KFDBwUEBKhFixb6y1/+csO6kpKSFBgY6J7x/NCWLVt06tQpj/quPeRVUlKiP/3pT2rVqpUCAwMVGhqqDh066KWXXnL3Ke/QmSS5XC45HA6PtkWLFqlv374KDw9Xo0aNdOedd2ru3Lm6cuXKDbfl2tri4+Pdh0KvXUoPRZ05c0bjx49X27Ztdcsttyg8PFx33323PvzwQ/fjnDhxwh0ys2bNcj9G6VgVHY5bsWKFOnbsqICAADVu3Fj33XefPv30U48+o0eP1i233KIvvvhCQ4YM0S233KKmTZvq6aefvuEhqilTpigkJETFxcXutokTJ8rhcGjevHnutrNnz6pevXp6+eWX3dvzw+fA5XJpypQpkqSYmJgKDx+/99576tKliwIDA9W6dWutWLHiuvVJ389OP/nkE02fPr1MAJWKjo7W8OHDPdqys7M1cuRIhYeHy+l0qk2bNnrxxRdVUlLi0a+y7xF4DzOhGio4OFi/+c1vtGLFCo0bN07S94FUr149PfDAA1qwYEGZ+5w4cULjxo1Ts2bNJEn79u3TxIkTderUKc2cOdOj7+HDh5WcnCyXy6XIyEi9/vrrmjRpkoqKivTMM89UWFdISIj+8z//U+vWrdOZM2c8/qNfuXKlAgIClJSUVOH9586dK5fLpRkzZqhv3766cuWKPvvss3I/W6iML7/8UklJSYqJiZG/v7+ysrI0Z84cffbZZ5Xa6f3Q4sWLy8wwn332We3YsUOtWrWSJJ07d06SlJqaqsjISF28eFEbNmxQfHy8tm3bpvj4eDVp0kTvvfeeBg0apLFjx+r3v/+9JF139pOenq4//vGPGjFihNLT03X27Fm5XC716NFDmZmZatmypbvvlStXdO+992rs2LF6+umntWvXLj333HMKCQkp8zr/0IABA/TnP/9Z+/fvV48ePSRJW7duVWBgoDIyMtzBsm3bNhljNGDAgHIf5/e//73OnTunl19+WW+++ab7sGLbtm3dfbKysvT0009r2rRpioiI0CuvvKKxY8fql7/8pfr27VthjRkZGZKke++9t8I+1zpz5ox69uypoqIiPffcc2revLneeecdPfPMM/ryyy+1ePFid9+beY/ASwxqlJUrVxpJJjMz0+zYscNIMv/7v/9rjDGma9euZvTo0cYYY9q1a2f69etX4eMUFxebK1eumNmzZ5uwsDBTUlLiXhcdHW0cDoc5fPiwx30GDhxogoODTUFBwXVrLK1r/vz57razZ88ap9NpHnroIY++0dHRZtSoUe7bQ4cONZ06dbru448aNcpER0eXaU9NTTXX+5Mu3eZVq1YZPz8/c+7cues+5rW1XWvevHlGklm2bFmFfa5evWquXLli+vfvb+677z53+5kzZ4wkk5qaWuY+pa/x8ePHjTHGnD9/3gQGBpohQ4Z49MvOzjZOp9MkJSV5bIck8/e//92j75AhQ0yrVq0qrNMYYwoKCoy/v7+ZPXu2McaYf//730aSmTp1qgkMDDTfffedMcaYRx991ERFRbnvd/z4cSPJrFy5ssxzU7oNPxQdHW0CAgLMyZMn3W2XL182jRs3NuPGjbtujYMGDTKS3LWUKikpMVeuXHEvV69eda+bNm2akWT+53/+x+M+jz/+uHE4HObzzz8vd6zrvUf69et33fcXKo/DcTVYv3799Itf/EIrVqzQkSNHlJmZWeGhOEnavn27BgwYoJCQEPn5+alBgwaaOXOmzp49q9zcXI++7dq1c39uUyopKUn5+fn66KOPKlXXDw/Jvf766yosLLxufZLUrVs3ZWVlafz48Xr//fd/8mdbhw4d0r333quwsDD3Nj/88MMqLi7WsWPHfvTjrl27Vn/4wx80Y8YMPfroox7rli5dqi5duiggIED169dXgwYNtG3btjKHzipr7969unz5cpkz9Zo2baq7775b27Zt82h3OBwaNmyYR1uHDh108uTJ647TsGFD9ejRQ1u3bpX0/awjNDRUU6ZMUVFRkXbv3i3p+9lRRbOgyurUqZN7tiFJAQEBuuOOO25YY0VeeuklNWjQwL388G93+/btatu2rbp16+Zxn9GjR8sYo+3bt3v0rex7BN5BCNVgDodDjzzyiFavXq2lS5fqjjvuUJ8+fcrtu3//fiUkJEiSli9frn/84x/KzMzU9OnTJUmXL1/26B8ZGVnmMUrbzp49e8O6xowZoyNHjujAgQOSvj8UFxMTo7vuuuu6901JSdGf//xn7du3T4MHD1ZYWJj69+/vfpybkZ2drT59+ujUqVN66aWX9OGHHyozM1OLFi2SVHabK2vHjh0aPXq0Hn744TJnYM2fP1+PP/64unfvrvXr12vfvn3KzMzUoEGDfvR4pc93eWfLRUVFlXk9GjZsqICAAI82p9Op77777oZjD
RgwQPv27VNBQYG2bt2qu+++W2FhYYqNjdXWrVt1/PhxHT9+/CeHUFhYWJk2p9N5w+eoNLiuDaukpCRlZmYqMzNTXbp08Vh39uzZCp+70vXSzb9H4B2EUA03evRoffPNN1q6dKkeeeSRCvu98cYbatCggd555x3df//96tmzp+Li4irsf/r06QrbytuBlFeXn5+fVqxYoaysLB06dEhjxowpc+LAterXr6/Jkyfro48+0rlz57R27Vrl5OTonnvu0aVLlyR9/19zeR+yf/PNNx63N27cqIKCAr355psaOXKkevfurbi4OPn7+9+w/op8/PHHGj58uPr166fly5eXWb969WrFx8dryZIlSkxMVPfu3RUXF6cLFy786DFLn++vv/66zLqvvvpKt956649+7Gv1799fRUVF2rVrl7Zt26aBAwe62zMyMtyfyfTv399rY96M0no2bdrk0R4eHq64uDjFxcWV+XpCWFhYhc+dJPfzd7PvEXgHIVTD3XbbbZoyZYqGDRumUaNGVdjP4XCofv3
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Define the colormap for the grid values\n",
"cmap = colors.ListedColormap(['black', 'red', (0.5451, 0.2706, 0.0745), 'blue', 'gray', (0,1,0)])\n",
"# Bounds now account for the actual range of values, with small gaps between to handle exact matching\n",
"bounds = [-1000.5, -100.5, -99.5, -49.5, -9, -0.5, 2000.5]\n",
"norm = colors.BoundaryNorm(bounds, cmap.N)\n",
"\n",
"# Create the plot\n",
"plt.imshow(rewards, cmap=cmap, norm=norm)\n",
"\n",
"\n",
"# Display the plot\n",
"plt.title(\"Map Visualization with Goal\")\n",
"plt.show()\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Visualization Functions\n",
"<br>"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"def graph(q_table, save=False, title=\"\"):\n",
" # Define the colormap for the grid values\n",
" #fig, ax = plt.subplots(figsize=(8, 8), dpi=200) # Increased figure size and DPI\n",
"\n",
" cmap = colors.ListedColormap(['black', 'red', (0.5451, 0.2706, 0.0745), 'blue', 'gray', (0,1,0)])\n",
" # Bounds now account for the actual range of values, with small gaps between to handle exact matching\n",
" bounds = [-1000.5, -100.5, -99.5, -49.5, -9, -0.5, 1000.5]\n",
" norm = colors.BoundaryNorm(bounds, cmap.N)\n",
"\n",
" \n",
" # Create the plot for rewards\n",
" plt.imshow(rewards, cmap=cmap, norm=norm)\n",
" \n",
" # Calculate the optimal direction from Q-table\n",
" # Directions: up (0), right (1), down (2), left (3)\n",
" optimal_directions = np.argmax(q_table, axis=2)\n",
" \n",
" # Initialize arrays for arrow direction (dx, dy) at each grid point\n",
" dx = np.zeros_like(optimal_directions, dtype=float)\n",
" dy = np.zeros_like(optimal_directions, dtype=float)\n",
" \n",
" # Define movement deltas for [up, right, down, left]\n",
" move_map = {\n",
" 0: (0, -1), # up\n",
" 1: (1, 0), # right\n",
" 2: (0, 1), # down\n",
" 3: (-1, 0), # left\n",
" }\n",
"\n",
" # Fill in dx, dy based on optimal directions, but only if the sum of Q-values is not zero\n",
" for i in range(optimal_directions.shape[0]):\n",
" for j in range(optimal_directions.shape[1]):\n",
" if np.sum(q_table[i, j]) != 0: # Check if the Q-values are non-zero\n",
" direction = optimal_directions[i, j]\n",
" dx[i, j], dy[i, j] = move_map[direction]\n",
" \n",
" # Create a meshgrid for plotting arrows\n",
" x, y = np.meshgrid(np.arange(optimal_directions.shape[1]), np.arange(optimal_directions.shape[0]))\n",
" \n",
" # Plot arrows using quiver, only for non-zero vectors\n",
" plt.quiver(x, y, dx, dy, angles='xy', scale_units='xy', scale=1, color='black')\n",
" plt.title(title)\n",
"\n",
" if save:\n",
" timestamp = datetime.datetime.now().strftime(\"%Y%m%d_%H%M%S\")\n",
" filename = f\"images/plot_{timestamp}.png\"\n",
" plt.savefig(filename, format='png')\n",
" \n",
" # Display the plot with arrows\n",
" plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"def graph_path(path):\n",
" # Define the colormap for the grid values\n",
" cmap = colors.ListedColormap(['black', 'red', (0.5451, 0.2706, 0.0745), 'blue', 'gray', (0,1,0)])\n",
" bounds = [-1000.5, -100.5, -99.5, -49.5, -9, -0.5, 1000.5]\n",
" norm = colors.BoundaryNorm(bounds, cmap.N)\n",
"\n",
" # Create the plot for rewards\n",
" plt.imshow(rewards, cmap=cmap, norm=norm)\n",
"\n",
" move_map = {\n",
" 0: (0, -1), # up\n",
" 1: (1, 0), # right\n",
" 2: (0, 1), # down\n",
" 3: (-1, 0), # left\n",
" }\n",
"\n",
" # Now plot the path taken by the robot\n",
" path_x = [pos[1] for pos in path]\n",
" path_y = [pos[0] for pos in path]\n",
" \n",
" # Create arrows for the robot's path\n",
" for i in range(len(path) - 1):\n",
" start_x, start_y = path_x[i], path_y[i]\n",
" end_x, end_y = path_x[i + 1], path_y[i + 1]\n",
" plt.arrow(start_x, start_y, end_x - start_x, end_y - start_y, color='yellow', head_width=0.2)\n",
"\n",
" # Display the plot with arrows\n",
" plt.show()\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Q-Learning helper functions\n",
"<br>"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"# define actions\n",
"# we will use numeric (index) to represent the actions\n",
"# 0 = up, 1 = right, 2 = down, 3 = left\n",
"actions = ['up', 'right', 'down', 'left']"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"# because we will end the episode if we reach Goal\n",
"def is_terminal_state(current_row_index, current_column_index):\n",
" if rewards[current_row_index, current_column_index] != np.max(rewards): # it is not terminal if the rewards is -1\n",
" return False\n",
" else:\n",
" return True\n",
"\n",
"# this starting location must not be on the road\n",
"def get_starting_location():\n",
" current_row_index = np.random.randint(environment_rows) # get a random row index\n",
" current_column_index = np.random.randint(environment_columns) # get a random column index\n",
" \n",
" while rewards[current_row_index, current_column_index] != -1: # True if it is terminal\n",
" current_row_index = np.random.randint(environment_rows) # repeat to get another random row index\n",
" current_column_index = np.random.randint(environment_columns) # repeat to get another random row index\n",
" return current_row_index, current_column_index # returns a random starting location that is not terminal\n",
"\n",
"\n",
"# define an epsilon greedy algorithm for deciding the next action\n",
"def get_next_action(current_row_index, current_column_index, epsilon):\n",
" if np.random.random() < epsilon: # choose the action with the highest q_values\n",
" return np.random.randint(4)\n",
" else: # choose a random action\n",
" return np.argmax(q_values[current_row_index, current_column_index])\n",
"\n",
"\n",
"# define a function that will get the next location based on the chosen action\n",
"# refer to how the board is drawn physically, with the rows and columns\n",
"def get_next_location(current_row_index, current_column_index, action_index):\n",
" new_row_index = current_row_index\n",
" new_column_index = current_column_index\n",
" if actions[action_index] == 'up' and current_row_index > 0:\n",
" new_row_index -= 1\n",
" elif actions[action_index] == 'right' and current_column_index < environment_columns - 1:\n",
" new_column_index += 1\n",
" elif actions[action_index] == 'down' and current_row_index < environment_rows - 1:\n",
" new_row_index += 1\n",
" elif actions[action_index] == 'left' and current_column_index > 0:\n",
" new_column_index -= 1\n",
" return new_row_index, new_column_index\n",
"\n",
"\n",
"# Define a function that will get the shortest path that is on the white tiles \n",
"def get_shortest_path(start_row_index, start_column_index):\n",
" i = 0\n",
" if is_terminal_state(start_row_index, start_column_index): # check if it is on Goal or Cliff\n",
" return [] # if yes, there are no available steps\n",
" \n",
" else: #if this is a 'legal' starting location\n",
" current_row_index, current_column_index = start_row_index, start_column_index\n",
" shortest_path = []\n",
" shortest_path.append([current_row_index, current_column_index]) # add the current coordinate to the list\n",
"\n",
" while not is_terminal_state(current_row_index, current_column_index): # repeat until we reach Goal or Cliff\n",
" action_index = get_next_action(current_row_index, current_column_index, 1.) \n",
" # get next coordinate \n",
" \n",
" current_row_index, current_column_index = get_next_location(current_row_index, current_column_index, action_index)\n",
" # update that next coordinate as current coordinate\n",
" \n",
" shortest_path.append([current_row_index, current_column_index]) \n",
" # add the current coordinate to the list\n",
"\n",
" i += 1\n",
" if i > 100:\n",
" return 0;\n",
" return shortest_path"
]
},
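{
"cell_type": "markdown",
"metadata": {},
"source": [
"For reference, the update performed inside the training loops below is the standard tabular Q-learning rule:\n",
"\n",
"$$Q(s, a) \\leftarrow Q(s, a) + \\alpha \\big[ r + \\gamma \\max_{a'} Q(s', a') - Q(s, a) \\big]$$\n",
"\n",
"where $\\alpha$ is the learning rate, $\\gamma$ is the discount factor, $r$ is the reward received for moving into the next state $s'$, and the bracketed term is the `temporal_difference` computed in the code."
]
},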
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"def q_learn_single(epsilon = 0.9, discount_factor = 0.9, learning_rate = 0.9, epochs = 1000,):\n",
" q_values = np.zeros((environment_rows, environment_columns, 4))\n",
" \n",
" for episode in tqdm(range(epochs), desc=\"Training Progress\", unit=\"epochs\", ncols=100): # Adjust `ncols` to shorten the bar\n",
" row_index, column_index = get_starting_location()\n",
"\n",
" while not is_terminal_state(row_index, column_index):\n",
" # choose which action to take (i.e., where to move next)\n",
" action_index = get_next_action(row_index, column_index, epsilon)\n",
"\n",
" # perform the chosen action, and transition to the next state / next location\n",
" old_row_index, old_column_index = row_index, column_index # store the old row and column indexes\n",
" row_index, column_index = get_next_location(row_index, column_index, action_index)\n",
"\n",
" # receive the reward for moving to the new state, and calculate the temporal difference\n",
" reward = rewards[row_index, column_index]\n",
" old_q_value = q_values[old_row_index, old_column_index, action_index]\n",
" temporal_difference = reward + (discount_factor * np.max(q_values[row_index, column_index])) - old_q_value\n",
"\n",
" # update the Q-value for the previous state and action pair\n",
" new_q_value = old_q_value + (learning_rate * temporal_difference)\n",
" q_values[old_row_index, old_column_index, action_index] = new_q_value\n",
"\n",
" print('Training complete!')\n",
"\n",
" return q_values\n"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"def q_learn_single(epsilon = 0.9, discount_factor = 0.9, learning_rate = 0.9, epochs = 1000):\n",
" # Initialize the Q-table with zeros for each state-action pair\n",
" # The shape is (environment_rows, environment_columns, 4) \n",
" # where 4 represents 4 possible actions (e.g., up, down, left, right)\n",
" q_values = np.zeros((environment_rows, environment_columns, 4))\n",
" \n",
" # Iterate through a number of episodes (i.e., learning cycles)\n",
" for episode in tqdm(range(epochs), desc=\"Training Progress\", unit=\"epochs\", ncols=100):\n",
" # Start each episode by selecting a random starting location in the environment\n",
" row_index, column_index = get_starting_location()\n",
"\n",
" # Continue taking actions until the agent reaches a terminal state\n",
" while not is_terminal_state(row_index, column_index):\n",
" # Choose the next action based on an epsilon-greedy policy\n",
" # This function should balance exploration (random) vs exploitation (best known action)\n",
" action_index = get_next_action(row_index, column_index, epsilon)\n",
"\n",
" # Save the old position before taking the action\n",
" old_row_index, old_column_index = row_index, column_index\n",
" \n",
" # Move to the new state based on the chosen action\n",
" row_index, column_index = get_next_location(row_index, column_index, action_index)\n",
"\n",
" # Get the reward for the new state the agent has moved to\n",
" reward = rewards[row_index, column_index]\n",
" \n",
" # Retrieve the Q-value of the old state-action pair\n",
" old_q_value = q_values[old_row_index, old_column_index, action_index]\n",
"\n",
" # Calculate the temporal difference: \n",
" # TD = Reward + Discount * (Max Q-value for the next state) - Old Q-value\n",
" temporal_difference = reward + (discount_factor * np.max(q_values[row_index, column_index])) - old_q_value\n",
"\n",
" # Update the Q-value for the previous state-action pair using the learning rate\n",
" new_q_value = old_q_value + (learning_rate * temporal_difference)\n",
" q_values[old_row_index, old_column_index, action_index] = new_q_value # Assign updated value\n",
"\n",
" # After all episodes, print a message indicating the training is complete\n",
" print('Training complete!')\n",
"\n",
" # Return the Q-values for further use (e.g., evaluation or exploitation phase)\n",
" return q_values\n"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"# single episode\n",
"\n",
"def eposode(epsilon, discount_factor, learning_rate, epochs):\n",
" for episode in range(epochs):\n",
" row_index, column_index = get_starting_location()\n",
" \n",
" while not is_terminal_state(row_index, column_index):\n",
" # choose which action to take (i.e., where to move next)\n",
" action_index = get_next_action(row_index, column_index, epsilon)\n",
" \n",
" # perform the chosen action, and transition to the next state / next location\n",
" old_row_index, old_column_index = row_index, column_index # store the old row and column indexes\n",
" row_index, column_index = get_next_location(row_index, column_index, action_index)\n",
" \n",
" # receive the reward for moving to the new state, and calculate the temporal difference\n",
" reward = rewards[row_index, column_index]\n",
" old_q_value = q_values[old_row_index, old_column_index, action_index]\n",
" temporal_difference = reward + (discount_factor * np.max(q_values[row_index, column_index])) - old_q_value\n",
" \n",
" # update the Q-value for the previous state and action pair\n",
" new_q_value = old_q_value + (learning_rate * temporal_difference)\n",
" q_values[old_row_index, old_column_index, action_index] = new_q_value\n"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"def q_learn_multi(epsilon=0.9, discount_factor=0.9, learning_rate=0.9, epochs=250, threads = 4):\n",
" \n",
" thread_array = []\n",
"\n",
" \n",
" for num in range(threads):\n",
" thread = threading.Thread(target=eposode, args=(epsilon, discount_factor, learning_rate, epochs))\n",
" thread_array.append(thread)\n",
" thread.start()\n",
"\n",
" for thread in thread_array:\n",
" thread.join()\n",
" print('Training complete!')\n",
"\n",
" return q_values\n"
]
},
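{
"cell_type": "markdown",
"metadata": {},
"source": [
"The worker above mutates the shared `q_values` array from several threads, and each Q-update is a read-modify-write, so concurrent threads can occasionally overwrite each other's updates. The next cell is a minimal, optional sketch (not used elsewhere in this notebook) of the same worker with a `threading.Lock` around the update; `q_update_lock` and `run_episodes_locked` are names introduced here. With CPython's GIL there is little true parallelism anyway, so the lock mainly trades a bit of throughput for consistent updates."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch: a lock-protected variant of the episode worker for multi-threaded training.\n",
"q_update_lock = threading.Lock()\n",
"\n",
"def run_episodes_locked(epsilon, discount_factor, learning_rate, epochs):\n",
"    for _ in range(epochs):\n",
"        row_index, column_index = get_starting_location()\n",
"\n",
"        while not is_terminal_state(row_index, column_index):\n",
"            action_index = get_next_action(row_index, column_index, epsilon)\n",
"\n",
"            old_row_index, old_column_index = row_index, column_index\n",
"            row_index, column_index = get_next_location(row_index, column_index, action_index)\n",
"            reward = rewards[row_index, column_index]\n",
"\n",
"            # serialize the read-modify-write so two threads cannot clobber each other's update\n",
"            with q_update_lock:\n",
"                old_q_value = q_values[old_row_index, old_column_index, action_index]\n",
"                temporal_difference = reward + (discount_factor * np.max(q_values[row_index, column_index])) - old_q_value\n",
"                q_values[old_row_index, old_column_index, action_index] = old_q_value + (learning_rate * temporal_difference)"
]
},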
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Q-Learning Multi-threaded\n",
"<br>"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Training complete!\n"
]
}
],
"source": [
"q_values = np.zeros((environment_rows, environment_columns, 4))\n",
"\n",
"q_values = q_learn_multi(0.7, 0.6, 0.1, 500, 12)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAa0AAAGxCAYAAADRQunXAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAABxN0lEQVR4nO3deXhU5dn48e9kkkwWkgBZSEKWyTbDLrK8FBABURQQtVrrvtRKacWF8qu11iroW6Xq29YFl2JbxVpcqqKCK5VFNgWRRQFnErKQEAgJ2fdM5vz+GGbMZIEwmTknk9yf65oLzpk5cz9n8pxzP89zNp2iKApCCCGEHwjQugBCCCFEd0nSEkII4TckaQkhhPAbkrSEEEL4DUlaQggh/IYkLSGEEH5DkpYQQgi/IUlLCCGE35CkJYQQwm/0yaR16623YjQa3eY99thjvPfeex0+u2nTJnQ6HZs2bTrj927fvp1ly5ZRWVnZ4T2j0cill17qWYFV9sorr6DT6cjPz9e6KN2Wn5+PTqfjlVdecc3zh/V49tlnGTZsGAaDgbS0NB5++GFaWlrOuNyyZcvQ6XRdvt544w2vlbGz39YZv695/vnn3dbzbDQ3N/PLX/6ShIQE9Ho9Y8eO9WrZAD766COWLVvm9e/1hjfeeIOxY8cSEhJCYmIiixcvpra2ttvLFxQUcNttt5GYmIjBYGDo0KH8+Mc/PvuCKH3QLbfcoqSmprrNCw8PV2655ZYOn62qqlJ27NihVFVVnfF7n3zySQVQ8vLyOryXmpqqzJs3z8MSq+vll1/ucj16q8bGRmXHjh3KiRMnXPN6+3r88Y9/VHQ6nXL//fcrGzduVJ544gklODhYWbBgwRmXLSwsVHbs2NHhNWrUKCU0NFSpqKjwWjnz8vIUQHn55Zc7xO9rRo4cqUyfPt2jZZ966ikFUJ599lll+/btyv79+71bOEVRFi1apPTG3fJrr72mAMrtt9+ubNiwQXnxxReVqKgo5aKLLurW8t9++60SHR2tTJw4Ufn3v/+tbN68WXnjjTeUn/3sZ2ddlsCzT3N9S2RkJD/60Y80LUNDQwMhISF9smXrLQaDQfO/09k4efIkf/zjH1mwYAGPPfYYADNmzKClpYU//OEPLF68mBEjRnS5fFJSEklJSW7z8vPzOXDgADfccAMDBw70ZfE7jd/ffffdd4SGhnLnnXdqXZSz1tDQQGhoqEfLtra2cu+99zJ79mxeeuklAGbOnElERAQ33HADH3/8MXPmzOlyeUVRuOmmm0hOTmbLli0YDAbXe9dcc81Zl0e14UHncMP+/fu5+uqriYqKYvDgwSxZsgSbzYbFYuGSSy4hIiICo9HIE0884bZ8V0NB3Rne0+l01NXVsWrVKtfwyowZM7q9vLP89957LwBpaWmu72m/3CeffMK4ceMIDQ1l2LBh/POf/+x0PT777DNuu+02YmNjCQsLo6mpCYA333yTyZMnEx4ezoABA7j44ovZs2eP23d8/fXXXHvttRiNRkJDQzEajVx33XUUFBR0KPeXX37J1KlTXV36+++/v1vDU2fSnXLeeuutDBgwgAMHDjBr1izCw8OJjY3lzjvvpL6+3u2z//nPf5g0aRJRUVGEhYWRnp7Obbfd5nq/syGsrvzzn//knHPOISQkhMGDB/PjH/+YQ4cOdVq2nJwc5s6dy4ABA0hOTub//b//5/pb9MQnn3xCY2MjP/vZz9zm/+xnP0NRlE6Hqs/kn//8J4qicPvtt3tcruLiYn76058SERFBVFQU11xzDcePH+/wuc6GBzds2MCMGTOIjo4mNDSUlJQUrrrqKre/ZVNTE4888gjDhw8nJCSE6OhoZs6cyfbt212faWxs5P777yctLY3g4GCGDh3KokWLOgy763S6TofKjEYjt956q2vauU1t3LiRX/3qV8TExBAdHc2VV15JcXGx23IHDhxg8+bNru23/WGEruh0Ov7+97/T0NDgWtZZF5977jnOP/984uLiCA8PZ/To0TzxxBOdbmeffPIJs2bNctXz4cOHs3z5csBRJ5977jlXPOfLuc/r7u/mPFTx7rvvcu655xISEsLDDz/crfXszJdffsmxY8c61OWrr76aAQMGsGbNmtMu/8UXX7B3714WL17slrA8pfoxrZ/+9Kecc845vPPOOyxYsIC//vWv/PrXv+aKK65g3rx5rFmzhgsuuID77ruPd9991ysxd+zYQWhoKHPnzmXHjh3s2LGD559//qy+4/bbb+euu+4C4N1333V9z7hx41yf2bdvH//v//0/fv3rX/P+++8zZswYfv7zn/PFF190+L7bbruNoKAg/vWvf/H2228TFBTEY489xnXXXceIESN46623+Ne//kVNTQ3Tpk3j4MGDrmXz8/Mxm8089dRTfPrppzz++OMcO3aMiRMnUlZW5vrcwYMHmTVrFpWVlbzyyiu8+OKL7Nmzhz/+8Y8dyuPc8LuTFLpbToCWlhbmzp3LrFmzeO+997jzzjv529/+5tbC2rFjB9dccw3p6em88cYbfPjhhzz00EPYbLYzlqW95cuX8/Of/5yRI0fy7rvv8vTTT7N//34mT55MdnZ2h7JddtllzJo1i/fff5/bbruNv/71rzz++ONun2ttbcVms53xZbfbXct89913AIwePdrtuxISEoiJiXG93112u51XXnmFzMxMpk+fflbLOjU0NHDhhRfy2WefsXz5cv7zn/8QHx/frdZufn4+8+bNIzg4mH/+85988skn/OlPfyI8PJzm5mYAbDYbc+bM4X//93+59NJLWbNmDa+88gpTpkzhyJEjgKPVfcUVV/B///d/3HTTTXz44YcsWbKEVatWccEFF/SowXD77bcTFBTE6tWreeKJJ9i0aRM33nij6/01a9aQnp7Oueee69p+z7TDddqxYwdz584lNDTUtey8efMAOHz4MNdffz3/+te/WLduHT//+c958sknWbhwodt3/OMf/2Du3LnY7XZefPFF1q5dy913301RUREADz74ID/5yU9c8ZyvhISEs/7dvvnmG+69917uvvtuPvnkE6666iqgZ3V5zJgxbjGCgoIYNmzYGeuyc/8XERHB3LlzCQkJYcCAAVx66aV8//333fr93Zz1gKKHli5dqgDKn//8Z7f5Y8eOVQDl3Xffdc1raWlRYmNjlSuvvNI1r6vjFxs3blQAZePGja55Z3NMq7Plu3KmY1ohISFKQUGBa15DQ4MyePBgZeHChR3W4+abb3Zb/siRI0pgYKBy1113uc2vqalR4uPjlZ/+9Kddlstmsym1tbVKeHi48vTTT7vmX3PNNUpoaKhy/Phxt88OGzasw3qsWrVK0ev1yqpVq077G5xNOW+55RYFcCuToijKo48+qgDK1q1bFUVRlP/7v/9TAKWysrLLuJ0dd2lfJyoqKpTQ0FBl7ty5HcpsMBiU66+/vkPZ3nrrLbfPzp07VzGbzW7zUlNTFeCMr6VLl7qWWbBggWIwGDpdF5PJpMyePbvLde3Mxx9/rADK8uXLz2q5tl544QUFU
N5//323+QsWLOjw2zq3V6e3335bAZS9e/d2+f2vvvqqAigvvfRSl5/55JNPFEB54okn3Oa/+eabCqCsXLnSNa/9b+qUmprqti0768Edd9zh9rknnnhCAZRjx4655vXkmNYtt9yihIeHn/Yzra2tSktLi/Lqq68qer1eKS8vVxTFsX1ERkYq5513nmK327tcvqtjWmfzu6Wmpip6vV6xWCwdvmf69Ondqsttf1/n9tr2d3SaPXu2YjKZTvubLFy4UAGUyMhI5ec//7ny3//+V/nXv/6lpKamKjExMUpxcfFpl29P9WNa7c+wGz58OPv27XMbEw0MDCQzM7PT4S5fUxSF1tZWt3mBgd37mcaOHUtKSoprOiQkBJPJ1Ol6OFs+Tp9++ik2m42bb77ZrYcREhLC9OnT2bhxo2tebW0t//u//8s777xDfn6+W3nbDoNt3LiRWbNmMWTIENc8vV7PNddc02G44Oabb+bmm28+4zqeTTmdbrjhBrfp66+/ngceeICNGzcydepUJk6cCDh64T/
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"graph(q_values, save=True, title=\"multi-thread: epsilon=0.7, discount_factor=0.6\")"
]
},
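{
"cell_type": "markdown",
"metadata": {},
"source": [
"`get_shortest_path` and `graph_path` are defined above but not exercised elsewhere in this notebook. The next cell is a minimal sketch of how they can be used to trace and plot the greedy route from one cell of the map; the starting coordinates (1, 1) are placeholders, so pick any plain (-1) tile on your own map."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch: follow the greedy policy from a hand-picked start cell and plot the route.\n",
"# (1, 1) is a placeholder start; choose any plain (-1) tile on your map.\n",
"start_row, start_col = 1, 1\n",
"path = get_shortest_path(start_row, start_col)\n",
"if path:\n",
"    print(f\"Path length: {len(path) - 1} steps\")\n",
"    graph_path(path)\n",
"else:\n",
"    print(\"No path from this start (it is a terminal tile or the policy found no route).\")"
]
},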
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Q-Learning Single Threaded\n",
"<br>"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Training Progress: 100%|████████████████████████████████████| 1000/1000 [00:37<00:00, 26.43epochs/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Training complete!\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"q_values = np.zeros((environment_rows, environment_columns, 4))\n",
"\n",
"q_values = q_learn_single(0.9, 0.7, 0.1, 1000)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAbAAAAGxCAYAAAADEuOPAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAB0aUlEQVR4nO3dd3xUZdr4/8/MJJkUkhDSG+kz1IAUWUAELLCsYt191hXb6mN5xMKyxbar6K5i2fWHimXxeb6urgvqqqigqFkQRERBFAXEmYQUEgIppPdM5vz+mOSYSYEwmTmTIdf79ZoXnDbXfSb3Ode579N0iqIoCCGEED5G7+0CCCGEEK6QBCaEEMInSQITQgjhkySBCSGE8EmSwIQQQvgkSWBCCCF8kiQwIYQQPkkSmBBCCJ8kCUwIIYRP8noCu+6660hNTfV4nNTUVK677jq3fudzzz3HP/7xj17jt27dik6n480333RrPE+ZN28e8+bN83YxTsmKFSvQ6XRO44b6epSXl3PdddcRFRVFcHAwM2fOZPPmzQNe/l//+hdnnHEGgYGBREVFceWVV1JcXOz2cvb123pi+/G20tJSVqxYwd69e11a/ptvvmHu3LmEh4ej0+lYtWqVW8sH8Mgjj/DOO++4/XsHq6GhgWXLlpGQkEBgYCCTJ0/mtddeO6XvePfdd5k7dy5hYWGEhIQwfvx41qxZc0rf4fUE9qc//Yn169d7uxgu6S+BCc/77//+b3bu3OntYgxYa2sr5557Lps3b+app57i3XffJTY2lp/+9Kds27btpMs/88wzXHXVVUybNo13332Xxx57jK1btzJnzhyqq6s9Xv7169fzpz/9yeNxtFRaWsqDDz7ocgK7/vrrOXr0KK+99ho7d+7kiiuucG8BGboJ7LLLLuPll1/mgQceYNOmTUyfPp1f/epXrF27dkDLP/roo1x22WVMmDCBN954g/fee49bb72Vtra2UyuIMkykpKQo1157rVu/c/z48crcuXN7jf/kk08UQPn3v//t0vc2NjYOsmSnZu7cuX2uh68Zyuvx7LPPKoDy+eefq+Pa29uVcePGKWeeeeYJl21paVHCw8OVxYsXO43//PPPFUC599573VrWBx54QBkOu4bdu3crgPLSSy+5tLyfn5/yP//zP+4tVA8hISFu32+1tbUp7e3tLi///vvvK4Cydu1ap/Hnn3++kpCQoNhsthMu/9VXXyl6vV557LHHXC5DF4+2wCoqKrjppptITk7GaDQSHR3N7Nmz+c9//qPO01cXok6n47bbbuOf//wnY8eOJTg4mEmTJrFx48ZeMd59912ys7MxGo2kp6fz1FNP9dkF0pe6ujp+97vfkZaWRkBAAImJiSxbtozGxsaTLpuamsqBAwfYtm0bOp0OnU7Xaz3a29u57777SEhIICwsjPPOOw+LxeI0z7x585gwYQKffvops2bNIjg4mOuvv/6Uyvfss89y9tlnExMTQ0hICBMnTuTxxx+nvb3daT5FUXj88cdJSUkhMDCQKVOmsGnTppOu68kMtJxdf9e///3vmEwmjEYj48aN69X10NTUpH5fYGAgo0aNYtq0aaxbt06dZ6B/46qqKm699VYSExMJCAggPT2d++67j9bW1j7LNpA654r169djNpuZOXOmOs7Pz4+rrrqKXbt2ceTIkX6X3b9/P7W1tfzsZz9zGj9z5kxGjRrFW2+95XK53n//fSZPnozRaCQtLY2//vWvfc7XswvRbrfzl7/8BbPZTFBQECNHjiQ7O5unnnrKabkffviBX/3qV8TGxmI0Ghk9ejTXXHON0++/f/9+Lr74YiIiItTuqJdfftnpe/7xj3+g0+koLCx0Gt/VXb9161Z1XNc2tXv3bubMmUNwcDDp6ek8+uij2O12dbnp06cD8Otf/1rdhlesWHHS36yrLDabjeeff15dFhz7vFtvvZVx48YxYsQIYmJiOOecc9i+fXuv72ltbeWhhx5i7NixBAYGEhkZyfz58/n8888BR51sbGzk5ZdfVmN07yIfyO/W9fv885//5Le//S2JiYkYjUby8vJOup79Wb9+PSNGjOAXv/iF0/hf//rXlJaW8uWXX55w+dWrV2M0Grn99ttdLkMXv0F/wwlcffXVfP311zz88MOYTCZqamr4+uuvOX78+EmXff/999m9ezcPPfQQI0aM4PHHH+fSSy/FYrGQnp4OwIcffshll13G2Wefzeuvv47NZuOvf/0rZWVlJ/3+pqYm5s6dS0lJCffeey/Z2dkcOHCA+++/n3379vGf//znhDvI9evX8/Of/5zw8HCee+45AIxGo9M89957L7Nnz+Z///d/qaur46677mLx4sUcPHgQg8Ggznf06FGuuuoq/vCHP/DII4+g1+tPqXyHDh3iyiuvVBPIt99+y8MPP8wPP/zA//t//0+N8+CDD/Lggw9yww038POf/5zi4mJuvPFGOjo6MJvNTmWfN28e27ZtQznJywpO9Xd87733+OSTT3jooYcICQnhueee41e/+hV+fn78/Oc/B2D58uX885//5C9/+QtnnHEGjY2N7N+/f0D1pruWlhbmz5/PoUOHePDBB8nOzmb79u2sXLmSvXv38v777zvNP5A6pygKHR0dA4rv5/fj5rV//37mzJnTa57s7GwADhw4QGJiYp/f09Wt0rN+dY3Lzc2lpaWFwMDAAZWry+bNm7n44ouZOXMmr732Gh0dHTz++OMD2n4ef/xxVqxYwR//+EfOPvts2tvb+eGHH6ipqVHn+fbbbznrrLOIiorioYceIisri6NHj/Lee+/R1taG0WjEYrEwa9YsYmJiePrpp4mMjOTVV1/luuuuo6ysjD/84Q+ntE5djh07xpIlS/jtb3/LAw88wPr167nnnntISEjgmmuuYcqUKbz00kv8+te/5o9//CMXXHABAElJSSf97gsuuICdO3cyc+ZMfv7zn/Pb3/5WnVZVVQXAAw88QFxcHA0NDaxfv5558+axefNmNQHZbDYWLVrE9u3bWbZsGeeccw42m40vvviCw4cPM2vWLHbu3Mk555zD/Pnz1e7bsLAwgFP+3e655x5mzpzJCy+8gF6vJyYmRi3HQBgMBnU73r9/P2PHjnWq3/BjXd6/fz+zZs3q97s+/fRTxo4dy1tvvcWf//xn8vLyiI+P56qrruKhhx4iICBgQGUCPNtPMGLECGXZsmUnnOfaa69VUlJSnMYBSmxsrFJXV6eOO3bsmKLX65WVK1eq46ZPn64kJycrra2t6rj6+nolMjKyVxdIzy7ElStXKnq9Xtm9e7fTfG+++aYCKB988MFJ1+9kXYg/+9nPnMa/8cYbCqDs3LlTHTd37lwFUDZv3uw0r6vl6+joUNrb25VXXnlFMRgMSlVVlaIoilJdXa0EBgYql156qdP8O3bsUIBe63HOOecoBoPhhOt/quUElKCgIOXYsWPqOJvNpowZM0bJzMxUx02YMEG55JJLThi3r26unl2IL7zwggIob7zxhtN8jz32mAIoH3/8sVPZBlLnuv62A
/kUFBSoy/n7+ys333xzr/Xo6gbs2R3T3fHjxxW9Xq/ccMMNTuPz8vLUWKWlpf0u358ZM2YoCQkJSnNzszqurq5OGTVq1Em3nwsvvFCZPHnyCb//nHPOUUaOHKmUl5f3O88VV1yhGI1G5fDhw07jFy1apAQHBys1NTWKoijKSy+91Os3VZQf/x6ffPKJOq5rm/ryyy+d5h03bpyycOFCdXiwXYiAsnTp0hPOY7PZlPb2duXcc8912vZeeeUVBVBefPHFEy7fXxfiQH+3rt/n7LPP7vUdBQUFA67L3X/frKwsp9+xS2lpqQIojzzyyAnXyWg0KqGhoUpERISyevVqZcuWLcp9992nGAwG5corrzzhsj15tAvxzDPP5B//+Ad/+ctf+OKLL3p1aZ3I/PnzCQ0NVYdjY2OJiYmhqKgIgMbGRr766isuueQSp4w9YsQIFi9efNLv37hxIxMmTGDy5MnYbDb1s3DhQqcuCbvd7jR9oEffABdddJHTcNcRStc6dImIiOCcc85xqXzguBrqoos
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"graph(q_values, save=True, title=\"single-thread: epsilon=0.9, discount_factor=0.6\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.14"
}
},
"nbformat": 4,
"nbformat_minor": 4
}