{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Import Packages"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"# import necessary libraries\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"from matplotlib import colors\n",
"import matplotlib.animation as animation\n",
"import json\n",
"import time\n",
"import threading\n",
"import tqdm\n",
"from tqdm import tqdm\n",
"from tqdm import trange\n",
"import datetime"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Create Map\n",
"Create a map for the Q-learning algorithm to solve. You can choose any grid size, but the larger the grid, the more compute training will take. I would suggest a grid of around 8x8 to 12x12."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"pygame 2.1.0 (SDL 2.0.16, Python 3.10.14)\n",
"Hello from the pygame community. https://www.pygame.org/contribute.html\n"
]
}
],
"source": [
"!./map_generator"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Importing Map Array and Displaying Map\n",
"<br>"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# Load the saved map\n",
"with open(\"map_data.json\", \"r\") as f:\n",
"    rewards = np.array(json.load(f))\n",
"\n",
"#rewards[rewards == 1000] = 500\n",
"\n",
"environment_rows = rewards.shape[0]\n",
"environment_columns = rewards.shape[1]"
]
},
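{
"cell_type": "markdown",
"metadata": {},
"source": [
"If `./map_generator` or `map_data.json` is not available, the next cell is a minimal sketch of a fallback that builds a random reward grid directly in NumPy. The reward values used here (-100 for hazards, -1 for free tiles, 1000 for the goal) are assumptions inferred from the colormap bounds used later in this notebook, not necessarily the exact encoding that `map_generator` produces."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional fallback: build a simple random map instead of running ./map_generator.\n",
"# Assumed encoding: -1 = free tile, -100 = hazard, 1000 = goal (single cell).\n",
"def make_random_map(rows=10, columns=10, hazard_fraction=0.2, seed=0):\n",
"    rng = np.random.default_rng(seed)\n",
"    grid = np.full((rows, columns), -1.0)\n",
"    hazards = rng.random((rows, columns)) < hazard_fraction\n",
"    grid[hazards] = -100.0\n",
"    goal = (rng.integers(rows), rng.integers(columns))\n",
"    grid[goal] = 1000.0\n",
"    return grid\n",
"\n",
"# Uncomment to use the fallback map instead of map_data.json:\n",
"# rewards = make_random_map(10, 10)\n",
"# environment_rows, environment_columns = rewards.shape"
]
},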
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "<base64-encoded PNG output (map visualization) truncated in this copy>",
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Define the colormap for the grid values\n",
"cmap = colors.ListedColormap(['black', 'red', (0.5451, 0.2706, 0.0745), 'blue', 'gray', (0,1,0)])\n",
"# Bounds now account for the actual range of values, with small gaps between to handle exact matching\n",
"bounds = [-1000.5, -100.5, -99.5, -49.5, -9, -0.5, 2000.5]\n",
"norm = colors.BoundaryNorm(bounds, cmap.N)\n",
"\n",
"# Create the plot\n",
"plt.imshow(rewards, cmap=cmap, norm=norm)\n",
"\n",
"\n",
"# Display the plot\n",
"plt.title(\"Map Visualization with Goal\")\n",
"plt.show()\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Visualization Functions\n",
"<br>"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"def graph(q_table, save=False, title=\"\"):\n",
"    # Define the colormap for the grid values\n",
"    #fig, ax = plt.subplots(figsize=(8, 8), dpi=200) # Increased figure size and DPI\n",
"\n",
"    cmap = colors.ListedColormap(['black', 'red', (0.5451, 0.2706, 0.0745), 'blue', 'gray', (0,1,0)])\n",
"    # Bounds now account for the actual range of values, with small gaps between to handle exact matching\n",
"    bounds = [-1000.5, -100.5, -99.5, -49.5, -9, -0.5, 1000.5]\n",
"    norm = colors.BoundaryNorm(bounds, cmap.N)\n",
"\n",
"    \n",
"    # Create the plot for rewards\n",
"    plt.imshow(rewards, cmap=cmap, norm=norm)\n",
"    \n",
"    # Calculate the optimal direction from Q-table\n",
"    # Directions: up (0), right (1), down (2), left (3)\n",
"    optimal_directions = np.argmax(q_table, axis=2)\n",
"    \n",
"    # Initialize arrays for arrow direction (dx, dy) at each grid point\n",
"    dx = np.zeros_like(optimal_directions, dtype=float)\n",
"    dy = np.zeros_like(optimal_directions, dtype=float)\n",
"    \n",
"    # Define movement deltas for [up, right, down, left]\n",
"    move_map = {\n",
"        0: (0, -1), # up\n",
"        1: (1, 0), # right\n",
"        2: (0, 1), # down\n",
"        3: (-1, 0), # left\n",
"    }\n",
"\n",
"    # Fill in dx, dy based on optimal directions, but only if the sum of Q-values is not zero\n",
"    for i in range(optimal_directions.shape[0]):\n",
"        for j in range(optimal_directions.shape[1]):\n",
"            if np.sum(q_table[i, j]) != 0: # Check if the Q-values are non-zero\n",
"                direction = optimal_directions[i, j]\n",
"                dx[i, j], dy[i, j] = move_map[direction]\n",
"    \n",
"    # Create a meshgrid for plotting arrows\n",
"    x, y = np.meshgrid(np.arange(optimal_directions.shape[1]), np.arange(optimal_directions.shape[0]))\n",
"    \n",
"    # Plot arrows using quiver, only for non-zero vectors\n",
"    plt.quiver(x, y, dx, dy, angles='xy', scale_units='xy', scale=1, color='black')\n",
"    plt.title(title)\n",
"\n",
"    if save:\n",
"        timestamp = datetime.datetime.now().strftime(\"%Y%m%d_%H%M%S\")\n",
"        filename = f\"images/plot_{timestamp}.png\"\n",
"        plt.savefig(filename, format='png')\n",
"    \n",
"    # Display the plot with arrows\n",
"    plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"def graph_path(path):\n",
"    # Define the colormap for the grid values\n",
"    cmap = colors.ListedColormap(['black', 'red', (0.5451, 0.2706, 0.0745), 'blue', 'gray', (0,1,0)])\n",
"    bounds = [-1000.5, -100.5, -99.5, -49.5, -9, -0.5, 1000.5]\n",
"    norm = colors.BoundaryNorm(bounds, cmap.N)\n",
"\n",
"    # Create the plot for rewards\n",
"    plt.imshow(rewards, cmap=cmap, norm=norm)\n",
"\n",
"    move_map = {\n",
"        0: (0, -1), # up\n",
"        1: (1, 0), # right\n",
"        2: (0, 1), # down\n",
"        3: (-1, 0), # left\n",
"    }\n",
"\n",
"    # Now plot the path taken by the robot\n",
"    path_x = [pos[1] for pos in path]\n",
"    path_y = [pos[0] for pos in path]\n",
"    \n",
"    # Create arrows for the robot's path\n",
"    for i in range(len(path) - 1):\n",
"        start_x, start_y = path_x[i], path_y[i]\n",
"        end_x, end_y = path_x[i + 1], path_y[i + 1]\n",
"        plt.arrow(start_x, start_y, end_x - start_x, end_y - start_y, color='yellow', head_width=0.2)\n",
"\n",
"    # Display the plot with arrows\n",
"    plt.show()\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Q-Learning helper functions\n",
"<br>"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"# define actions\n",
"# we will use numeric (index) to represent the actions\n",
"# 0 = up, 1 = right, 2 = down, 3 = left\n",
"actions = ['up', 'right', 'down', 'left']"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"# because we will end the episode if we reach Goal\n",
|
||
|
"def is_terminal_state(current_row_index, current_column_index):\n",
|
||
|
" if rewards[current_row_index, current_column_index] != np.max(rewards): # it is not terminal if the rewards is -1\n",
|
||
|
" return False\n",
|
||
|
" else:\n",
|
||
|
" return True\n",
|
||
|
"\n",
|
||
|
"# this starting location must not be on the road\n",
|
||
|
"def get_starting_location():\n",
|
||
|
" current_row_index = np.random.randint(environment_rows) # get a random row index\n",
|
||
|
" current_column_index = np.random.randint(environment_columns) # get a random column index\n",
|
||
|
" \n",
|
||
|
" while rewards[current_row_index, current_column_index] != -1: # True if it is terminal\n",
|
||
|
" current_row_index = np.random.randint(environment_rows) # repeat to get another random row index\n",
|
||
|
" current_column_index = np.random.randint(environment_columns) # repeat to get another random row index\n",
|
||
|
" return current_row_index, current_column_index # returns a random starting location that is not terminal\n",
|
||
|
"\n",
|
||
|
"\n",
|
||
|
"# define an epsilon greedy algorithm for deciding the next action\n",
|
||
|
"def get_next_action(current_row_index, current_column_index, epsilon):\n",
|
||
|
" if np.random.random() < epsilon: # choose the action with the highest q_values\n",
|
||
|
" return np.random.randint(4)\n",
|
||
|
" else: # choose a random action\n",
|
||
|
" return np.argmax(q_values[current_row_index, current_column_index])\n",
|
||
|
"\n",
|
||
|
"\n",
|
||
|
"# define a function that will get the next location based on the chosen action\n",
|
||
|
"# refer to how the board is drawn physically, with the rows and columns\n",
|
||
|
"def get_next_location(current_row_index, current_column_index, action_index):\n",
|
||
|
" new_row_index = current_row_index\n",
|
||
|
" new_column_index = current_column_index\n",
|
||
|
" if actions[action_index] == 'up' and current_row_index > 0:\n",
|
||
|
" new_row_index -= 1\n",
|
||
|
" elif actions[action_index] == 'right' and current_column_index < environment_columns - 1:\n",
|
||
|
" new_column_index += 1\n",
|
||
|
" elif actions[action_index] == 'down' and current_row_index < environment_rows - 1:\n",
|
||
|
" new_row_index += 1\n",
|
||
|
" elif actions[action_index] == 'left' and current_column_index > 0:\n",
|
||
|
" new_column_index -= 1\n",
|
||
|
" return new_row_index, new_column_index\n",
|
||
|
"\n",
|
||
|
"\n",
|
||
|
"# Define a function that will get the shortest path that is on the white tiles \n",
|
||
|
"def get_shortest_path(start_row_index, start_column_index):\n",
|
||
|
" i = 0\n",
|
||
|
" if is_terminal_state(start_row_index, start_column_index): # check if it is on Goal or Cliff\n",
|
||
|
" return [] # if yes, there are no available steps\n",
|
||
|
" \n",
|
||
|
" else: #if this is a 'legal' starting location\n",
|
||
|
" current_row_index, current_column_index = start_row_index, start_column_index\n",
|
||
|
" shortest_path = []\n",
|
||
|
" shortest_path.append([current_row_index, current_column_index]) # add the current coordinate to the list\n",
|
||
|
"\n",
|
||
|
" while not is_terminal_state(current_row_index, current_column_index): # repeat until we reach Goal or Cliff\n",
|
||
|
" action_index = get_next_action(current_row_index, current_column_index, 1.) \n",
|
||
|
" # get next coordinate \n",
|
||
|
" \n",
|
||
|
" current_row_index, current_column_index = get_next_location(current_row_index, current_column_index, action_index)\n",
|
||
|
" # update that next coordinate as current coordinate\n",
|
||
|
" \n",
|
||
|
" shortest_path.append([current_row_index, current_column_index]) \n",
|
||
|
" # add the current coordinate to the list\n",
|
||
|
"\n",
|
||
|
" i += 1\n",
|
||
|
" if i > 100:\n",
|
||
|
" return 0;\n",
|
||
|
" return shortest_path"
|
||
|
]
|
||
|
},
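{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick, optional sanity check of the helpers above. It is only a sketch: it assumes `rewards`, `environment_rows`, and `environment_columns` from the earlier cells and does not touch the Q-table."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional sanity check for the helper functions defined above (safe to skip).\n",
"r, c = get_starting_location()\n",
"print(\"random start:\", (r, c), \"reward:\", rewards[r, c]) # should be a regular -1 tile\n",
"print(\"terminal?\", is_terminal_state(r, c)) # should be False for a -1 tile\n",
"print(\"one step up from start:\", get_next_location(r, c, 0)) # action 0 = 'up'; clamped at the top border"
]
},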
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"def q_learn_single(epsilon = 0.9, discount_factor = 0.9, learning_rate = 0.9, epochs = 1000):\n",
"    q_values = np.zeros((environment_rows, environment_columns, 4))\n",
"    \n",
"    for episode in tqdm(range(epochs), desc=\"Training Progress\", unit=\"epochs\", ncols=100): # Adjust `ncols` to shorten the bar\n",
"        row_index, column_index = get_starting_location()\n",
"\n",
"        while not is_terminal_state(row_index, column_index):\n",
"            # choose which action to take (i.e., where to move next)\n",
"            action_index = get_next_action(row_index, column_index, epsilon)\n",
"\n",
"            # perform the chosen action, and transition to the next state / next location\n",
"            old_row_index, old_column_index = row_index, column_index # store the old row and column indexes\n",
"            row_index, column_index = get_next_location(row_index, column_index, action_index)\n",
"\n",
"            # receive the reward for moving to the new state, and calculate the temporal difference\n",
"            reward = rewards[row_index, column_index]\n",
"            old_q_value = q_values[old_row_index, old_column_index, action_index]\n",
"            temporal_difference = reward + (discount_factor * np.max(q_values[row_index, column_index])) - old_q_value\n",
"\n",
"            # update the Q-value for the previous state and action pair\n",
"            new_q_value = old_q_value + (learning_rate * temporal_difference)\n",
"            q_values[old_row_index, old_column_index, action_index] = new_q_value\n",
"\n",
"    print('Training complete!')\n",
"\n",
"    return q_values\n"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"def q_learn_single(epsilon = 0.9, discount_factor = 0.9, learning_rate = 0.9, epochs = 1000):\n",
"    # Initialize the Q-table with zeros for each state-action pair\n",
"    # The shape is (environment_rows, environment_columns, 4) \n",
"    # where 4 represents 4 possible actions (e.g., up, down, left, right)\n",
"    q_values = np.zeros((environment_rows, environment_columns, 4))\n",
"    \n",
"    # Iterate through a number of episodes (i.e., learning cycles)\n",
"    for episode in tqdm(range(epochs), desc=\"Training Progress\", unit=\"epochs\", ncols=100):\n",
"        # Start each episode by selecting a random starting location in the environment\n",
"        row_index, column_index = get_starting_location()\n",
"\n",
"        # Continue taking actions until the agent reaches a terminal state\n",
"        while not is_terminal_state(row_index, column_index):\n",
"            # Choose the next action based on an epsilon-greedy policy\n",
"            # This function should balance exploration (random) vs exploitation (best known action)\n",
"            action_index = get_next_action(row_index, column_index, epsilon)\n",
"\n",
"            # Save the old position before taking the action\n",
"            old_row_index, old_column_index = row_index, column_index\n",
"            \n",
"            # Move to the new state based on the chosen action\n",
"            row_index, column_index = get_next_location(row_index, column_index, action_index)\n",
"\n",
"            # Get the reward for the new state the agent has moved to\n",
"            reward = rewards[row_index, column_index]\n",
"            \n",
"            # Retrieve the Q-value of the old state-action pair\n",
"            old_q_value = q_values[old_row_index, old_column_index, action_index]\n",
"\n",
"            # Calculate the temporal difference: \n",
"            # TD = Reward + Discount * (Max Q-value for the next state) - Old Q-value\n",
"            temporal_difference = reward + (discount_factor * np.max(q_values[row_index, column_index])) - old_q_value\n",
"\n",
"            # Update the Q-value for the previous state-action pair using the learning rate\n",
"            new_q_value = old_q_value + (learning_rate * temporal_difference)\n",
"            q_values[old_row_index, old_column_index, action_index] = new_q_value # Assign updated value\n",
"\n",
"    # After all episodes, print a message indicating the training is complete\n",
"    print('Training complete!')\n",
"\n",
"    # Return the Q-values for further use (e.g., evaluation or exploitation phase)\n",
"    return q_values\n"
]
},
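{
"cell_type": "markdown",
"metadata": {},
"source": [
"For reference, the update implemented above is the standard tabular Q-learning rule\n",
"\n",
"$$Q(s, a) \\leftarrow Q(s, a) + \\alpha \\left[ r + \\gamma \\max_{a'} Q(s', a') - Q(s, a) \\right]$$\n",
"\n",
"where $\\alpha$ is the `learning_rate`, $\\gamma$ is the `discount_factor`, $r$ is the reward of the tile just entered, and the bracketed term is the `temporal_difference` computed in the loop above."
]
},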
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"# worker that runs a batch of training episodes, updating the shared global q_values in place\n",
"\n",
"def run_episodes(epsilon, discount_factor, learning_rate, epochs):\n",
"    for episode in range(epochs):\n",
"        row_index, column_index = get_starting_location()\n",
"        \n",
"        while not is_terminal_state(row_index, column_index):\n",
"            # choose which action to take (i.e., where to move next)\n",
"            action_index = get_next_action(row_index, column_index, epsilon)\n",
"            \n",
"            # perform the chosen action, and transition to the next state / next location\n",
"            old_row_index, old_column_index = row_index, column_index # store the old row and column indexes\n",
"            row_index, column_index = get_next_location(row_index, column_index, action_index)\n",
"            \n",
"            # receive the reward for moving to the new state, and calculate the temporal difference\n",
"            reward = rewards[row_index, column_index]\n",
"            old_q_value = q_values[old_row_index, old_column_index, action_index]\n",
"            temporal_difference = reward + (discount_factor * np.max(q_values[row_index, column_index])) - old_q_value\n",
"            \n",
"            # update the Q-value for the previous state and action pair\n",
"            new_q_value = old_q_value + (learning_rate * temporal_difference)\n",
"            q_values[old_row_index, old_column_index, action_index] = new_q_value\n"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"def q_learn_multi(epsilon=0.9, discount_factor=0.9, learning_rate=0.9, epochs=250, threads=4):\n",
"    # each worker thread runs its own episodes and updates the shared global q_values array in place\n",
"    thread_array = []\n",
"\n",
"    \n",
"    for num in range(threads):\n",
"        thread = threading.Thread(target=run_episodes, args=(epsilon, discount_factor, learning_rate, epochs))\n",
"        thread_array.append(thread)\n",
"        thread.start()\n",
"\n",
"    for thread in thread_array:\n",
"        thread.join()\n",
"    print('Training complete!')\n",
"\n",
"    return q_values\n"
]
},
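{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note that the worker threads above all update the shared `q_values` array without any synchronization. Because each update is a read-modify-write, two threads can interleave on the same state-action pair and one update can be lost. The cell below is a minimal sketch (not used anywhere else in this notebook) of how a `threading.Lock` could guard the update; whether the extra safety is worth the serialization is a separate question, especially under the GIL."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical variant: protect the shared Q-table update with a lock.\n",
"# This is an illustrative sketch only; it is not used by q_learn_multi above.\n",
"q_table_lock = threading.Lock()\n",
"\n",
"def locked_q_update(old_row, old_col, action_index, row, col, discount_factor, learning_rate):\n",
"    with q_table_lock: # only one thread at a time performs the read-modify-write\n",
"        reward = rewards[row, col]\n",
"        old_q_value = q_values[old_row, old_col, action_index]\n",
"        temporal_difference = reward + (discount_factor * np.max(q_values[row, col])) - old_q_value\n",
"        q_values[old_row, old_col, action_index] = old_q_value + (learning_rate * temporal_difference)"
]
},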
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Q-Learning Multi-threaded\n",
"<br>"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Training complete!\n"
]
}
],
"source": [
"q_values = np.zeros((environment_rows, environment_columns, 4))\n",
"\n",
"q_values = q_learn_multi(0.7, 0.6, 0.1, 500, 12)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "<base64-encoded PNG output (Q-table arrow plot, multi-threaded run) truncated in this copy>",
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"graph(q_values, save=True, title=\"multi-thread: epsilon=0.7, discount_factor=0.6\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Q-Learning Single Threaded\n",
"<br>"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Training Progress: 100%|████████████████████████████████████| 1000/1000 [00:37<00:00, 26.43epochs/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Training complete!\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"q_values = np.zeros((environment_rows, environment_columns, 4))\n",
"\n",
"q_values = q_learn_single(0.9, 0.7, 0.1, 1000)"
]
},
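{
"cell_type": "markdown",
"metadata": {},
"source": [
"With a trained Q-table in `q_values`, the greedy path from a random start can be traced with `get_shortest_path` and drawn with `graph_path`. This is a short usage sketch; `graph_path` is not otherwise called in this notebook."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Trace and draw the greedy path from a random non-terminal start (uses the trained q_values above).\n",
"start_row, start_column = get_starting_location()\n",
"path = get_shortest_path(start_row, start_column)\n",
"if path: # empty if the greedy policy fails to reach the goal within 100 steps\n",
"    graph_path(path)\n",
"else:\n",
"    print(\"No path found from\", (start_row, start_column))"
]
},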
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "<base64-encoded PNG output (Q-table arrow plot, single-threaded run) truncated in this copy>",
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"graph(q_values, save=True, title=\"single-thread: epsilon=0.9, discount_factor=0.7\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.14"
}
},
"nbformat": 4,
"nbformat_minor": 4
}