AI-Stock-Predictor/Train Stock Model LSTM Neural Network.ipynb

1243 lines
156 KiB
Plaintext
Raw Normal View History

2024-10-08 01:24:55 +00:00
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "b29811c3",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Open</th>\n",
" <th>High</th>\n",
" <th>Low</th>\n",
" <th>Close</th>\n",
" <th>Adj Close</th>\n",
" <th>Volume</th>\n",
" <th>Date</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>4608</th>\n",
" <td>39.80</td>\n",
" <td>40.22</td>\n",
" <td>39.77</td>\n",
" <td>40.09</td>\n",
" <td>40.09</td>\n",
" <td>27456000</td>\n",
" <td>24-09-13</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4609</th>\n",
" <td>40.00</td>\n",
" <td>40.23</td>\n",
" <td>39.58</td>\n",
" <td>39.89</td>\n",
" <td>39.89</td>\n",
" <td>12162500</td>\n",
" <td>24-09-16</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4610</th>\n",
" <td>39.70</td>\n",
" <td>40.08</td>\n",
" <td>39.31</td>\n",
" <td>39.49</td>\n",
" <td>39.49</td>\n",
" <td>17906900</td>\n",
" <td>24-09-17</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4611</th>\n",
" <td>39.76</td>\n",
" <td>40.99</td>\n",
" <td>39.02</td>\n",
" <td>39.06</td>\n",
" <td>39.06</td>\n",
" <td>41241400</td>\n",
" <td>24-09-18</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4612</th>\n",
" <td>40.10</td>\n",
" <td>40.25</td>\n",
" <td>39.26</td>\n",
" <td>39.72</td>\n",
" <td>39.72</td>\n",
" <td>22277500</td>\n",
" <td>24-09-19</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Open High Low Close Adj Close Volume Date\n",
"4608 39.80 40.22 39.77 40.09 40.09 27456000 24-09-13\n",
"4609 40.00 40.23 39.58 39.89 39.89 12162500 24-09-16\n",
"4610 39.70 40.08 39.31 39.49 39.49 17906900 24-09-17\n",
"4611 39.76 40.99 39.02 39.06 39.06 41241400 24-09-18\n",
"4612 40.10 40.25 39.26 39.72 39.72 22277500 24-09-19"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"import pandas as pd\n",
"import pandas_ta as ta\n",
"import IPython\n",
"\n",
"# Uncomment for Interactive Graphs\n",
"#%matplotlib widget\n",
"\n",
"name = \"GDX\"\n",
"data = pd.read_csv(\"data/\"+name + \".csv\")\n",
"data.tail(5)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "68e700e1",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>index</th>\n",
" <th>Open</th>\n",
" <th>High</th>\n",
" <th>Low</th>\n",
" <th>Close</th>\n",
" <th>Adj Close</th>\n",
" <th>Volume</th>\n",
" <th>Date</th>\n",
" <th>Diff_Close</th>\n",
" <th>RSI</th>\n",
" <th>EMAF</th>\n",
" <th>EMAM</th>\n",
" <th>EMAS</th>\n",
" <th>BBL_5_2.0</th>\n",
" <th>BBM_5_2.0</th>\n",
" <th>BBU_5_2.0</th>\n",
" <th>BBB_5_2.0</th>\n",
" <th>BBP_5_2.0</th>\n",
" <th>STOCHk_14_3_3</th>\n",
" <th>STOCHd_14_3_3</th>\n",
" <th>Target1</th>\n",
" <th>Target2</th>\n",
" <th>Target3</th>\n",
" <th>Target4</th>\n",
" <th>Target5</th>\n",
" <th>Target6</th>\n",
" <th>Target7</th>\n",
" <th>Target8</th>\n",
" <th>Target9</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>4403</th>\n",
" <td>4602</td>\n",
" <td>37.66</td>\n",
" <td>37.83</td>\n",
" <td>37.28</td>\n",
" <td>37.33</td>\n",
" <td>37.33</td>\n",
" <td>16260000</td>\n",
" <td>24-09-05</td>\n",
" <td>-0.33</td>\n",
" <td>47.133189</td>\n",
" <td>37.965111</td>\n",
" <td>37.071605</td>\n",
" <td>33.759355</td>\n",
" <td>36.218919</td>\n",
" <td>37.802</td>\n",
" <td>39.385081</td>\n",
" <td>8.375644</td>\n",
" <td>0.350924</td>\n",
" <td>15.729597</td>\n",
" <td>32.100843</td>\n",
" <td>-0.85</td>\n",
" <td>0.15</td>\n",
" <td>0.46</td>\n",
" <td>0.49</td>\n",
" <td>1.15</td>\n",
" <td>0.29</td>\n",
" <td>-0.11</td>\n",
" <td>-0.21</td>\n",
" <td>-0.70</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4404</th>\n",
" <td>4603</td>\n",
" <td>37.17</td>\n",
" <td>37.43</td>\n",
" <td>36.22</td>\n",
" <td>36.32</td>\n",
" <td>36.32</td>\n",
" <td>19939200</td>\n",
" <td>24-09-06</td>\n",
" <td>-0.85</td>\n",
" <td>41.326931</td>\n",
" <td>37.791941</td>\n",
" <td>37.042130</td>\n",
" <td>33.784834</td>\n",
" <td>35.778392</td>\n",
" <td>37.290</td>\n",
" <td>38.801608</td>\n",
" <td>8.107312</td>\n",
" <td>0.179150</td>\n",
" <td>9.755741</td>\n",
" <td>18.352188</td>\n",
" <td>0.15</td>\n",
" <td>0.46</td>\n",
" <td>0.49</td>\n",
" <td>1.15</td>\n",
" <td>0.29</td>\n",
" <td>-0.11</td>\n",
" <td>-0.21</td>\n",
" <td>-0.70</td>\n",
" <td>-0.38</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" index Open High Low Close Adj Close Volume Date \\\n",
"4403 4602 37.66 37.83 37.28 37.33 37.33 16260000 24-09-05 \n",
"4404 4603 37.17 37.43 36.22 36.32 36.32 19939200 24-09-06 \n",
"\n",
" Diff_Close RSI EMAF EMAM EMAS BBL_5_2.0 \\\n",
"4403 -0.33 47.133189 37.965111 37.071605 33.759355 36.218919 \n",
"4404 -0.85 41.326931 37.791941 37.042130 33.784834 35.778392 \n",
"\n",
" BBM_5_2.0 BBU_5_2.0 BBB_5_2.0 BBP_5_2.0 STOCHk_14_3_3 \\\n",
"4403 37.802 39.385081 8.375644 0.350924 15.729597 \n",
"4404 37.290 38.801608 8.107312 0.179150 9.755741 \n",
"\n",
" STOCHd_14_3_3 Target1 Target2 Target3 Target4 Target5 Target6 \\\n",
"4403 32.100843 -0.85 0.15 0.46 0.49 1.15 0.29 \n",
"4404 18.352188 0.15 0.46 0.49 1.15 0.29 -0.11 \n",
"\n",
" Target7 Target8 Target9 \n",
"4403 -0.11 -0.21 -0.70 \n",
"4404 -0.21 -0.70 -0.38 "
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Adding indicators\n",
"data['Diff_Close'] = data['Adj Close']-data.Open\n",
"data['RSI']=ta.rsi(data.Close, length=14)\n",
"data['EMAF']=ta.ema(data.Close, length=18)\n",
"data['EMAM']=ta.ema(data.Close, length=50)\n",
"data['EMAS']=ta.ema(data.Close, length=200)\n",
"data.ta.bbands(append=True)\n",
"data.ta.stoch(append=True)\n",
"\n",
"\n",
"data['Target1'] = data['Diff_Close'].shift(-1)\n",
"data['Target2'] = data['Diff_Close'].shift(-2)\n",
"data['Target3'] = data['Diff_Close'].shift(-3)\n",
"data['Target4'] = data['Diff_Close'].shift(-4)\n",
"data['Target5'] = data['Diff_Close'].shift(-5)\n",
"data['Target6'] = data['Diff_Close'].shift(-6)\n",
"data['Target7'] = data['Diff_Close'].shift(-7)\n",
"data['Target8'] = data['Diff_Close'].shift(-8)\n",
"data['Target9'] = data['Diff_Close'].shift(-9)\n",
"\n",
"\n",
"data.dropna(inplace=True)\n",
"data.reset_index(inplace = True)\n",
"pd.set_option('display.max_columns', None)\n",
"\n",
"data.tail(2)\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "a2b0e972",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>index</th>\n",
" <th>Open</th>\n",
" <th>High</th>\n",
" <th>Low</th>\n",
" <th>Adj Close</th>\n",
" <th>Diff_Close</th>\n",
" <th>RSI</th>\n",
" <th>EMAF</th>\n",
" <th>EMAM</th>\n",
" <th>EMAS</th>\n",
" <th>BBL_5_2.0</th>\n",
" <th>BBM_5_2.0</th>\n",
" <th>BBU_5_2.0</th>\n",
" <th>BBB_5_2.0</th>\n",
" <th>BBP_5_2.0</th>\n",
" <th>STOCHk_14_3_3</th>\n",
" <th>STOCHd_14_3_3</th>\n",
" <th>Target1</th>\n",
" <th>Target2</th>\n",
" <th>Target3</th>\n",
" <th>Target4</th>\n",
" <th>Target5</th>\n",
" <th>Target6</th>\n",
" <th>Target7</th>\n",
" <th>Target8</th>\n",
" <th>Target9</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>199</td>\n",
" <td>38.70</td>\n",
" <td>38.70</td>\n",
" <td>37.97</td>\n",
" <td>32.99</td>\n",
" <td>-5.71</td>\n",
" <td>42.823849</td>\n",
" <td>39.130266</td>\n",
" <td>39.242898</td>\n",
" <td>38.151400</td>\n",
" <td>36.894752</td>\n",
" <td>37.894</td>\n",
" <td>38.893248</td>\n",
" <td>5.273910</td>\n",
" <td>0.568051</td>\n",
" <td>26.968411</td>\n",
" <td>24.878202</td>\n",
" <td>-4.50</td>\n",
" <td>-6.22</td>\n",
" <td>-4.61</td>\n",
" <td>-4.51</td>\n",
" <td>-5.38</td>\n",
" <td>-4.97</td>\n",
" <td>-5.43</td>\n",
" <td>-4.69</td>\n",
" <td>-5.66</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>200</td>\n",
" <td>38.02</td>\n",
" <td>38.73</td>\n",
" <td>37.93</td>\n",
" <td>33.52</td>\n",
" <td>-4.50</td>\n",
" <td>46.633750</td>\n",
" <td>39.079712</td>\n",
" <td>39.219647</td>\n",
" <td>38.156361</td>\n",
" <td>37.773181</td>\n",
" <td>38.240</td>\n",
" <td>38.706819</td>\n",
" <td>2.441522</td>\n",
" <td>0.939142</td>\n",
" <td>29.608675</td>\n",
" <td>27.927078</td>\n",
" <td>-6.22</td>\n",
" <td>-4.61</td>\n",
" <td>-4.51</td>\n",
" <td>-5.38</td>\n",
" <td>-4.97</td>\n",
" <td>-5.43</td>\n",
" <td>-4.69</td>\n",
" <td>-5.66</td>\n",
" <td>-5.36</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>201</td>\n",
" <td>38.44</td>\n",
" <td>38.75</td>\n",
" <td>37.06</td>\n",
" <td>32.22</td>\n",
" <td>-6.22</td>\n",
" <td>39.735192</td>\n",
" <td>38.876584</td>\n",
" <td>39.138484</td>\n",
" <td>38.146348</td>\n",
" <td>37.049243</td>\n",
" <td>38.054</td>\n",
" <td>39.058757</td>\n",
" <td>5.280689</td>\n",
" <td>0.050140</td>\n",
" <td>23.950967</td>\n",
" <td>26.842684</td>\n",
" <td>-4.61</td>\n",
" <td>-4.51</td>\n",
" <td>-5.38</td>\n",
" <td>-4.97</td>\n",
" <td>-5.43</td>\n",
" <td>-4.69</td>\n",
" <td>-5.66</td>\n",
" <td>-5.36</td>\n",
" <td>-5.46</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>202</td>\n",
" <td>37.01</td>\n",
" <td>37.48</td>\n",
" <td>36.46</td>\n",
" <td>32.40</td>\n",
" <td>-4.61</td>\n",
" <td>41.049977</td>\n",
" <td>38.716944</td>\n",
" <td>39.068740</td>\n",
" <td>38.138523</td>\n",
" <td>36.764084</td>\n",
" <td>37.908</td>\n",
" <td>39.051916</td>\n",
" <td>6.035223</td>\n",
" <td>0.260472</td>\n",
" <td>20.792079</td>\n",
" <td>24.783907</td>\n",
" <td>-4.51</td>\n",
" <td>-5.38</td>\n",
" <td>-4.97</td>\n",
" <td>-5.43</td>\n",
" <td>-4.69</td>\n",
" <td>-5.66</td>\n",
" <td>-5.36</td>\n",
" <td>-5.46</td>\n",
" <td>-5.66</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>203</td>\n",
" <td>37.68</td>\n",
" <td>38.44</td>\n",
" <td>37.68</td>\n",
" <td>33.17</td>\n",
" <td>-4.51</td>\n",
" <td>46.333706</td>\n",
" <td>38.666739</td>\n",
" <td>39.036240</td>\n",
" <td>38.139533</td>\n",
" <td>36.773123</td>\n",
" <td>37.886</td>\n",
" <td>38.998877</td>\n",
" <td>5.874874</td>\n",
" <td>0.659047</td>\n",
" <td>20.107004</td>\n",
" <td>21.616683</td>\n",
" <td>-5.38</td>\n",
" <td>-4.97</td>\n",
" <td>-5.43</td>\n",
" <td>-4.69</td>\n",
" <td>-5.66</td>\n",
" <td>-5.36</td>\n",
" <td>-5.46</td>\n",
" <td>-5.66</td>\n",
" <td>-5.34</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" index Open High Low Adj Close Diff_Close RSI EMAF \\\n",
"0 199 38.70 38.70 37.97 32.99 -5.71 42.823849 39.130266 \n",
"1 200 38.02 38.73 37.93 33.52 -4.50 46.633750 39.079712 \n",
"2 201 38.44 38.75 37.06 32.22 -6.22 39.735192 38.876584 \n",
"3 202 37.01 37.48 36.46 32.40 -4.61 41.049977 38.716944 \n",
"4 203 37.68 38.44 37.68 33.17 -4.51 46.333706 38.666739 \n",
"\n",
" EMAM EMAS BBL_5_2.0 BBM_5_2.0 BBU_5_2.0 BBB_5_2.0 \\\n",
"0 39.242898 38.151400 36.894752 37.894 38.893248 5.273910 \n",
"1 39.219647 38.156361 37.773181 38.240 38.706819 2.441522 \n",
"2 39.138484 38.146348 37.049243 38.054 39.058757 5.280689 \n",
"3 39.068740 38.138523 36.764084 37.908 39.051916 6.035223 \n",
"4 39.036240 38.139533 36.773123 37.886 38.998877 5.874874 \n",
"\n",
" BBP_5_2.0 STOCHk_14_3_3 STOCHd_14_3_3 Target1 Target2 Target3 \\\n",
"0 0.568051 26.968411 24.878202 -4.50 -6.22 -4.61 \n",
"1 0.939142 29.608675 27.927078 -6.22 -4.61 -4.51 \n",
"2 0.050140 23.950967 26.842684 -4.61 -4.51 -5.38 \n",
"3 0.260472 20.792079 24.783907 -4.51 -5.38 -4.97 \n",
"4 0.659047 20.107004 21.616683 -5.38 -4.97 -5.43 \n",
"\n",
" Target4 Target5 Target6 Target7 Target8 Target9 \n",
"0 -4.51 -5.38 -4.97 -5.43 -4.69 -5.66 \n",
"1 -5.38 -4.97 -5.43 -4.69 -5.66 -5.36 \n",
"2 -4.97 -5.43 -4.69 -5.66 -5.36 -5.46 \n",
"3 -5.43 -4.69 -5.66 -5.36 -5.46 -5.66 \n",
"4 -4.69 -5.66 -5.36 -5.46 -5.66 -5.34 "
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"data.drop(['Volume', 'Close', 'Date'], axis=1, inplace=True)\n",
"data_set = data\n",
"data_set.head()"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "b9d38e4c",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>index</th>\n",
" <th>Open</th>\n",
" <th>High</th>\n",
" <th>Low</th>\n",
" <th>Adj Close</th>\n",
" <th>Diff_Close</th>\n",
" <th>RSI</th>\n",
" <th>EMAF</th>\n",
" <th>EMAM</th>\n",
" <th>EMAS</th>\n",
" <th>BBL_5_2.0</th>\n",
" <th>BBM_5_2.0</th>\n",
" <th>BBU_5_2.0</th>\n",
" <th>BBB_5_2.0</th>\n",
" <th>BBP_5_2.0</th>\n",
" <th>STOCHk_14_3_3</th>\n",
" <th>STOCHd_14_3_3</th>\n",
" <th>Target1</th>\n",
" <th>Target2</th>\n",
" <th>Target3</th>\n",
" <th>Target4</th>\n",
" <th>Target5</th>\n",
" <th>Target6</th>\n",
" <th>Target7</th>\n",
" <th>Target8</th>\n",
" <th>Target9</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>4400</th>\n",
" <td>4599</td>\n",
" <td>-0.032629</td>\n",
" <td>-0.035886</td>\n",
" <td>-0.032090</td>\n",
" <td>0.134366</td>\n",
" <td>0.357534</td>\n",
" <td>0.135110</td>\n",
" <td>-0.011064</td>\n",
" <td>-0.023456</td>\n",
" <td>-0.152467</td>\n",
" <td>-0.007310</td>\n",
" <td>-0.014559</td>\n",
" <td>-0.050959</td>\n",
" <td>-0.884847</td>\n",
" <td>-0.358144</td>\n",
" <td>0.330323</td>\n",
" <td>0.488259</td>\n",
" <td>0.279452</td>\n",
" <td>0.383562</td>\n",
" <td>0.335616</td>\n",
" <td>0.264384</td>\n",
" <td>0.401370</td>\n",
" <td>0.443836</td>\n",
" <td>0.447945</td>\n",
" <td>0.538356</td>\n",
" <td>0.420548</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4401</th>\n",
" <td>4600</td>\n",
" <td>-0.061550</td>\n",
" <td>-0.068073</td>\n",
" <td>-0.085821</td>\n",
" <td>0.077857</td>\n",
" <td>0.279452</td>\n",
" <td>-0.115170</td>\n",
" <td>-0.015294</td>\n",
" <td>-0.023108</td>\n",
" <td>-0.150777</td>\n",
" <td>-0.047258</td>\n",
" <td>-0.029421</td>\n",
" <td>-0.041646</td>\n",
" <td>-0.762153</td>\n",
" <td>-0.905362</td>\n",
" <td>0.020945</td>\n",
" <td>0.291154</td>\n",
" <td>0.383562</td>\n",
" <td>0.335616</td>\n",
" <td>0.264384</td>\n",
" <td>0.401370</td>\n",
" <td>0.443836</td>\n",
" <td>0.447945</td>\n",
" <td>0.538356</td>\n",
" <td>0.420548</td>\n",
" <td>0.365753</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4402</th>\n",
" <td>4601</td>\n",
" <td>-0.103077</td>\n",
" <td>-0.095819</td>\n",
" <td>-0.092537</td>\n",
" <td>0.062788</td>\n",
" <td>0.383562</td>\n",
" <td>-0.172337</td>\n",
" <td>-0.020589</td>\n",
" <td>-0.023368</td>\n",
" <td>-0.149273</td>\n",
" <td>-0.073525</td>\n",
" <td>-0.046937</td>\n",
" <td>-0.050252</td>\n",
" <td>-0.716935</td>\n",
" <td>-0.709322</td>\n",
" <td>-0.415359</td>\n",
" <td>-0.022089</td>\n",
" <td>0.335616</td>\n",
" <td>0.264384</td>\n",
" <td>0.401370</td>\n",
" <td>0.443836</td>\n",
" <td>0.447945</td>\n",
" <td>0.538356</td>\n",
" <td>0.420548</td>\n",
" <td>0.365753</td>\n",
" <td>0.352055</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4403</th>\n",
" <td>4602</td>\n",
" <td>-0.074527</td>\n",
" <td>-0.078431</td>\n",
" <td>-0.071642</td>\n",
" <td>0.080368</td>\n",
" <td>0.335616</td>\n",
" <td>-0.090165</td>\n",
" <td>-0.023565</td>\n",
" <td>-0.022924</td>\n",
" <td>-0.147587</td>\n",
" <td>-0.082629</td>\n",
" <td>-0.054747</td>\n",
" <td>-0.056542</td>\n",
" <td>-0.709344</td>\n",
" <td>-0.299248</td>\n",
" <td>-0.697161</td>\n",
" <td>-0.372758</td>\n",
" <td>0.264384</td>\n",
" <td>0.401370</td>\n",
" <td>0.443836</td>\n",
" <td>0.447945</td>\n",
" <td>0.538356</td>\n",
" <td>0.420548</td>\n",
" <td>0.365753</td>\n",
" <td>0.352055</td>\n",
" <td>0.284932</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4404</th>\n",
" <td>4603</td>\n",
" <td>-0.092696</td>\n",
" <td>-0.093230</td>\n",
" <td>-0.111194</td>\n",
" <td>0.038091</td>\n",
" <td>0.264384</td>\n",
" <td>-0.247492</td>\n",
" <td>-0.030463</td>\n",
" <td>-0.024164</td>\n",
" <td>-0.146390</td>\n",
" <td>-0.099513</td>\n",
" <td>-0.074158</td>\n",
" <td>-0.077581</td>\n",
" <td>-0.718980</td>\n",
" <td>-0.643250</td>\n",
" <td>-0.818783</td>\n",
" <td>-0.659348</td>\n",
" <td>0.401370</td>\n",
" <td>0.443836</td>\n",
" <td>0.447945</td>\n",
" <td>0.538356</td>\n",
" <td>0.420548</td>\n",
" <td>0.365753</td>\n",
" <td>0.352055</td>\n",
" <td>0.284932</td>\n",
" <td>0.328767</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" index Open High Low Adj Close Diff_Close RSI \\\n",
"4400 4599 -0.032629 -0.035886 -0.032090 0.134366 0.357534 0.135110 \n",
"4401 4600 -0.061550 -0.068073 -0.085821 0.077857 0.279452 -0.115170 \n",
"4402 4601 -0.103077 -0.095819 -0.092537 0.062788 0.383562 -0.172337 \n",
"4403 4602 -0.074527 -0.078431 -0.071642 0.080368 0.335616 -0.090165 \n",
"4404 4603 -0.092696 -0.093230 -0.111194 0.038091 0.264384 -0.247492 \n",
"\n",
" EMAF EMAM EMAS BBL_5_2.0 BBM_5_2.0 BBU_5_2.0 \\\n",
"4400 -0.011064 -0.023456 -0.152467 -0.007310 -0.014559 -0.050959 \n",
"4401 -0.015294 -0.023108 -0.150777 -0.047258 -0.029421 -0.041646 \n",
"4402 -0.020589 -0.023368 -0.149273 -0.073525 -0.046937 -0.050252 \n",
"4403 -0.023565 -0.022924 -0.147587 -0.082629 -0.054747 -0.056542 \n",
"4404 -0.030463 -0.024164 -0.146390 -0.099513 -0.074158 -0.077581 \n",
"\n",
" BBB_5_2.0 BBP_5_2.0 STOCHk_14_3_3 STOCHd_14_3_3 Target1 Target2 \\\n",
"4400 -0.884847 -0.358144 0.330323 0.488259 0.279452 0.383562 \n",
"4401 -0.762153 -0.905362 0.020945 0.291154 0.383562 0.335616 \n",
"4402 -0.716935 -0.709322 -0.415359 -0.022089 0.335616 0.264384 \n",
"4403 -0.709344 -0.299248 -0.697161 -0.372758 0.264384 0.401370 \n",
"4404 -0.718980 -0.643250 -0.818783 -0.659348 0.401370 0.443836 \n",
"\n",
" Target3 Target4 Target5 Target6 Target7 Target8 Target9 \n",
"4400 0.335616 0.264384 0.401370 0.443836 0.447945 0.538356 0.420548 \n",
"4401 0.264384 0.401370 0.443836 0.447945 0.538356 0.420548 0.365753 \n",
"4402 0.401370 0.443836 0.447945 0.538356 0.420548 0.365753 0.352055 \n",
"4403 0.443836 0.447945 0.538356 0.420548 0.365753 0.352055 0.284932 \n",
"4404 0.447945 0.538356 0.420548 0.365753 0.352055 0.284932 0.328767 "
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.preprocessing import MinMaxScaler\n",
"sc = MinMaxScaler(feature_range=(-1,1))\n",
"\n",
"df_scaled = sc.fit_transform(data_set.to_numpy())\n",
"data_set_scaled_pd = pd.DataFrame(df_scaled, columns=data_set.columns.tolist())\n",
"\n",
"\n",
"\n",
"data_set_scaled_pd['index'] = data_set['index']\n",
"\n",
"data_set_scaled_pd.tail()"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "99ca74fd",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Length of Data 4405\n",
"X Shape: (4385, 20, 16)\n",
"Y Shape: (4385, 9)\n"
]
}
],
"source": [
"X = []\n",
"backcandles = 20\n",
"\n",
"\n",
"data_set_scaled = data_set_scaled_pd.to_numpy()\n",
"\n",
"print(\"Length of Data\", data_set_scaled.shape[0])\n",
"\n",
"features = [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16]\n",
"#features = [5]\n",
"feature_count = len(features)\n",
"\n",
"it = 0\n",
"for j in features:\n",
" X.append([])\n",
" for i in range(backcandles, data_set_scaled.shape[0]):\n",
" X[it].append(data_set_scaled[i-backcandles:i, j])\n",
" it += 1\n",
"\n",
"\n",
"\n",
"\n",
"\n",
"\n",
"#move axis from 0 to position 2\n",
"X=np.moveaxis(X, [0], [2])\n",
"\n",
"X = np.array(X)\n",
"\n",
"yi = np.array(data_set_scaled[backcandles:, -9:])\n",
"y=yi\n",
"\n",
"print(\"X Shape:\", X.shape)\n",
"print(\"Y Shape:\", y.shape)\n"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "a2a87918",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"4385\n",
"(4385, 20, 16)\n",
"(0, 20, 16)\n",
"(4385, 9)\n",
"(0, 9)\n"
]
}
],
"source": [
"# split data into train test sets\n",
"splitlimit = int(len(X)*1)\n",
"print(splitlimit)\n",
"X_train, X_test = X[:splitlimit], X[splitlimit:]\n",
"y_train, y_test = y[:splitlimit], y[splitlimit:]\n",
"print(X_train.shape)\n",
"print(X_test.shape)\n",
"print(y_train.shape)\n",
"print(y_test.shape)\n"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "9867161a",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2024-10-07 20:29:09.720660: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n",
"2024-10-07 20:29:09.737019: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n",
"2024-10-07 20:29:09.741871: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n",
"2024-10-07 20:29:09.754590: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
"To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
"2024-10-07 20:29:10.373903: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"(4385, 20, 16)\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"WARNING: All log messages before absl::InitializeLog() is called are written to STDERR\n",
"I0000 00:00:1728347350.828378 332220 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n",
"I0000 00:00:1728347350.874002 332220 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n",
"I0000 00:00:1728347350.878920 332220 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n",
"I0000 00:00:1728347350.883919 332220 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n",
"I0000 00:00:1728347350.887319 332220 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n",
"I0000 00:00:1728347350.891213 332220 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n",
"I0000 00:00:1728347351.030925 332220 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n",
"I0000 00:00:1728347351.032497 332220 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n",
"I0000 00:00:1728347351.033977 332220 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n",
"2024-10-07 20:29:11.035404: I tensorflow/core/common_runtime/gpu/gpu_device.cc:2021] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 5325 MB memory: -> device: 0, name: NVIDIA GeForce RTX 3050, pci bus id: 0000:2d:00.0, compute capability: 8.6\n",
"/home/brickman/miniconda3/envs/stock/lib/python3.10/site-packages/keras/src/layers/rnn/rnn.py:204: UserWarning: Do not pass an `input_shape`/`input_dim` argument to a layer. When using Sequential models, prefer using an `Input(shape)` object as the first layer in the model instead.\n",
" super().__init__(**kwargs)\n"
]
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">Model: \"sequential\"</span>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1mModel: \"sequential\"\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
"┃<span style=\"font-weight: bold\"> Layer (type) </span>┃<span style=\"font-weight: bold\"> Output Shape </span>┃<span style=\"font-weight: bold\"> Param # </span>┃\n",
"┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
"│ lstm (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">LSTM</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">150</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">100,200</span> │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
"│ dense (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Dense</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">32</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">4,832</span> │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
"│ dense_layer (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Dense</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">9</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">297</span> │\n",
"└─────────────────────────────────┴────────────────────────┴───────────────┘\n",
"</pre>\n"
],
"text/plain": [
"┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
"┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\n",
"┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
"│ lstm (\u001b[38;5;33mLSTM\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m150\u001b[0m) │ \u001b[38;5;34m100,200\u001b[0m │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
"│ dense (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m32\u001b[0m) │ \u001b[38;5;34m4,832\u001b[0m │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
"│ dense_layer (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m9\u001b[0m) │ \u001b[38;5;34m297\u001b[0m │\n",
"└─────────────────────────────────┴────────────────────────┴───────────────┘\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Total params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">105,329</span> (411.44 KB)\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m Total params: \u001b[0m\u001b[38;5;34m105,329\u001b[0m (411.44 KB)\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Trainable params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">105,329</span> (411.44 KB)\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m105,329\u001b[0m (411.44 KB)\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Non-trainable params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> (0.00 B)\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"2024-10-07 20:29:12.885047: I external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:531] Loaded cuDNN version 8907\n"
]
}
],
"source": [
"# Build and train the LSTM regressor on windows of shape (backcandles, feature_count).\n",
"from keras.models import Sequential\n",
"from keras.layers import LSTM\n",
"from keras.layers import Dropout\n",
"from keras.layers import Dense\n",
"from keras.layers import TimeDistributed\n",
"\n",
"import tensorflow as tf\n",
"import keras\n",
"from keras import optimizers\n",
"from keras.callbacks import History\n",
"from keras.models import Model\n",
"from keras.layers import Dense, Dropout, LSTM, Input, Activation, concatenate\n",
"import numpy as np\n",
"\n",
"from tensorflow.keras.models import Sequential\n",
"from tensorflow.keras.optimizers import Adam\n",
"from tensorflow.keras import layers\n",
"\n",
"# np.random.seed only seeds NumPy; Keras weight init, dropout and fit() shuffling draw\n",
"# from the backend RNG. set_random_seed seeds Python, NumPy and the backend together.\n",
"# (Bit-exact GPU runs additionally need tf.config.experimental.enable_op_determinism().)\n",
"keras.utils.set_random_seed(10)\n",
"print(X_train.shape)\n",
"assert X_train.shape[1:] == (backcandles, feature_count), \"X_train windows do not match the model input shape\"\n",
"\n",
"# Output width follows the target columns instead of a hard-coded 9.\n",
"n_outputs = y_train.shape[1]\n",
"\n",
"# tanh on the output bounds predictions to [-1, 1]; this presumes the targets were\n",
"# scaled into that range upstream - TODO confirm against the scaler.\n",
"model = Sequential([layers.Input(shape=(backcandles, feature_count)),\n",
"                    layers.LSTM(150, activation='tanh'),\n",
"                    layers.Dense(32, activation='relu'),\n",
"                    layers.Dense(n_outputs, name='dense_layer', activation='tanh')])\n",
"\n",
"model.compile(loss='mse',\n",
"              optimizer=Adam(learning_rate=0.001),\n",
"              metrics=['mean_absolute_error'])\n",
"\n",
"model.summary()\n",
"\n",
"epochs = 2400\n",
"# validation_split holds out the *last* 20% of X_train/y_train (selected before shuffling),\n",
"# so with time-ordered windows this acts as a chronological validation set.\n",
"history = model.fit(x=X_train, y=y_train, batch_size=15, epochs=epochs, shuffle=True,\n",
"                    validation_split=0.2, verbose=0)\n"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "7e2b61f1-032f-43b5-a457-d6835076b7f8",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<contextlib.ExitStack at 0x7cb97803ba30>"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA6AAAAHBCAYAAAB38tZnAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAADH+klEQVR4nOzdeVhU1RsH8O+wDjsIyKIguOOuYIZGahrm0qplZi65lFGZmr/S1LQ0TTMjyyVNM9PUSisrMtHULMncK0WzRHGBEFQQkf3+/jjMvjAzDMwA38/zzONw59x73wHkznvPOe+RSZIkgYiIiIiIiKiaOdg6ACIiIiIiIqofmIASERERERFRjWACSkRERERERDWCCSgRERERERHVCCagREREREREVCOYgBIREREREVGNYAJKRERERERENYIJKBEREREREdUIJqBERERERERUI5iA1kIymcykx969e6t0njlz5kAmk1m07969e60Sg70bPXo0IiIiDL5+9epVuLi44PHHHzfYJi8vD+7u7njggQdMPu+6desgk8lw/vx5k2NRJ5PJMGfOHJPPp3DlyhXMmTMHx48f13mtKr8v1lJSUoLg4GDIZDJ8+eWXNo2FiOo+Xo/tB6/HKra8HkdERGDQoEE2OTfVHk62DoDMl5KSovH13LlzsWfPHvz0008a29u0aVOl84wbNw733XefRft26dIFKSkpVY6htgsMDMQDDzyAr7/+GtevX4efn59Om82bN+P27dsYO3Zslc41a9YsvPjii1U6RmWuXLmC119/HREREejUqZPGa1X5fbGW7777Dv/99x8AYM2aNRgyZIhN4yGiuo3X49qD12Mi+8EEtBa68847Nb4ODAyEg4ODznZtBQUFcHd3N/k8jRs3RuPGjS2K0dvbu9J46ouxY8di69at2LhxI55//nmd19euXYugoCAMHDiwSudp1qxZlfavqqr8vljLmjVr4OLigp49e2Lnzp24dOmSzWPSp6ysDKWlpXB1dbV1KERUBbwe1y68HhPZBw7BraN69eqFdu3a4eeff0b37t3h7u6OMWPGAAC2bNmC+Ph4hISEwM3NDVFRUZg2bRpu3bqlcQx9QzgUQyt27NiBLl26wM3NDa1bt8batWs12ukb8jN69Gh4enrin3/+wYABA+Dp6YmwsDC89NJLKCoq0tj/0qVLGDJkCLy8vODr64vhw4fj0KFDkMlkWLdundH3fvXqVSQkJKBNmzbw9PREw4YNcc8992D//v0a7c6fPw+ZTIbFixdjyZIliIyMhKenJ2JjY/Hbb7/pHHfdunVo1aoVXF1dERUVhfXr1xuNQ6Ffv35o3LgxPv74Y53XUlNTcfDgQYwcORJOTk5ITk7Ggw8+iMaNG0Mul6N58+Z45plnkJ2dXel59A35ycvLw/jx4+Hv7w9PT0/cd999+Pvvv3X2/eeff/DUU0+hRYsWcHd3R6NGjXD//ffjzz//VLbZu3cvunbtCgB46qmnlEPLFEOH9P2+lJeXY9GiRWjdujVcXV3RsGFDjBw5EpcuXdJop/h9PXToEOLi4uDu7o6mTZvirbfeQnl5eaXvHRB3g3fs2IH7778f//vf/1BeXm7wd+Wzzz5DbGwsPD094enpiU6dOmHNmjUabXbs2IE+ffrAx8cH7u7uiIqKwoIFCzRi7tWrl86xtX8Oit+zRYsWYd68eYiMjISrqyv27NmDwsJCvPTSS+jUqRN8fHzQoEEDxMbG4ptvvtE5bnl5Od5//3106tQJbm5u8PX1xZ133ont27cDEB+sGjRogIKCAp1977nnHrRt29aE7yIRWRuvx7weA/XrelyZwsJCTJ8+HZGRkXBxcUGjRo3w3HPP4caNGxrtfvrpJ/Tq1Qv+/v5wc3NDeHg4Bg8erHGdW7FiBTp27AhPT094eXmhdevWePXVV60SJ1UfJqB1WEZGBp588kk88cQTSEpKQkJCAgDg7NmzGDBgANasWYMdO3
Zg0qRJ+Pzzz3H//febdNwTJ07gpZdewuTJk/HNN9+gQ4cOGDt2LH7++edK9y0pKcEDDzyAPn364JtvvsGYMWPw7rvvYuHChco2t27dQu/evbFnzx4sXLgQn3/+OYKCgjB06FCT4rt27RoAYPbs2fj+++/x8ccfo2nTpujVq5feOTDLli1DcnIyEhMTsXHjRty6dQsDBgxAbm6uss26devw1FNPISoqClu3bsXMmTMxd+5cnWFW+jg4OGD06NE4evQoTpw4ofGa4iKo+DDy77//IjY2FitWrMDOnTvx2muv4eDBg7jrrrtQUlJi0vtXkCQJDz30ED799FO89NJL+Oqrr3DnnXeif//+Om2vXLkCf39/vPXWW9ixYweWLVsGJycndOvWDWfOnAEghnEp4p05cyZSUlKQkpKCcePGGYzh2WefxSuvvIJ7770X27dvx9y5c7Fjxw50795d5yKemZmJ4cOH48knn8T27dvRv39/TJ8+HRs2bDDp/a5btw5lZWUYM2YM+vbtiyZNmmDt2rWQJEmj3WuvvYbhw4cjNDQU69atw1dffYVRo0bhwoULyjZr1qzBgAEDUF5ejpUrV+Lbb7/FxIkTdS7U5li6dCl++uknLF68GD/88ANat26NoqIiXLt2DVOnTsXXX3+NTZs24a677sIjjzyi84Fq9OjRePHFF9G1a1ds2bIFmzdvxgMPPKCcd/Tiiy/i+vXr+OyzzzT2O3XqFPbs2YPnnnvO4tiJqGp4Peb1uD5dj035XixevBgjRozA999/jylTpuCTTz7BPffco7wBcv78eQwcOBAuLi5Yu3YtduzYgbfeegseHh4oLi4GIIZMJyQkoGfPnvjqq6/w9ddfY/LkyTo3cMgOSVTrjRo1SvLw8NDY1rNnTwmAtHv3bqP7lpeXSyUlJdK+ffskANKJEyeUr82ePVvS/hVp0qSJJJfLpQsXLii33b59W2rQoIH0zDPPKLft2bNHAiDt2bNHI04A0ueff65xzAEDBkitWrVSfr1s2TIJgPTDDz9otHvmmWckANLHH39s9D1pKy0tlUpKSqQ+ffpIDz/8sHJ7WlqaBEBq3769VFpaqtz++++/SwCkTZs2SZIkSWVlZVJoaKjUpUsXqby8XNnu/PnzkrOzs9SkSZNKYzh37pwkk8mkiRMnKreVlJRIwcHBUo8ePfTuo/jZXLhwQQIgffPNN8rXPv74YwmAlJaWptw2atQojVh++OEHCYD03nvvaRz3zTfflABIs2fPNhhvaWmpVFxcLLVo0UKaPHmycvuhQ4cM/gy0f19SU1MlAFJCQoJGu4MHD0oApFdffVW5TfH7evDgQY22bdq0kfr162cwToXy8nKpefPmUqNGjZQ/S0U86v8Hzp07Jzk6OkrDhw83eKybN29K3t7e0l133aXx89bWs2dPqWfPnjrbtX8Oit+zZs2aScXFxUbfh+J3dezYsVLnzp2V23/++WcJgDRjxgyj+/fs2VPq1KmTxrZnn31W8vb2lm7evGl0XyKqOl6PjeP1uO5fj5s0aSINHDjQ4Os7duyQAEiLFi3S2L5lyxYJgLRq1SpJkiTpyy+/lABIx48fN3is559/XvL19a00JrI/7AGtw/z8/HDPPffobD937hyeeOIJBAcHw9HREc7OzujZsycAMQSlMp06dUJ4eLjya7lcjpYtW2r0IBkik8l07ux26NBBY999+/bBy8tLZwL9sGHDKj2+wsqVK9GlSxfI5XI4OTnB2dkZu3fv1vv+Bg4cCEdHR414AChjOnPmDK5cuYInnnhCY0hLkyZN0L17d5PiiYyMRO/evbFx40blnbsffvgBmZmZyrutAJCVlYUJEyYgLCxMGXeTJk0AmPazUbdnzx4AwPDhwzW2P/HEEzptS0tLMX/+fLRp0wYuLi5wcnKCi4sLzp49a/Z5tc8/evRoje133HEHoqKisHv3bo3twcHBuOOOOzS2af9uGLJv3z78888/GDVqlPJnqRiWpD
4cLTk5GWVlZUZ7Aw8cOIC8vDwkJCRYtYrgAw88AGdnZ53tX3zxBXr06AFPT0/lz3zNmjUa3/cffvgBACrtxXzxxRd
"text/plain": [
"<Figure size 1100x500 with 2 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Learning curves: mean absolute error (left) and MSE loss (right) per epoch.\n",
"mae = history.history['mean_absolute_error']\n",
"val_mae = history.history['val_mean_absolute_error']\n",
"\n",
"loss = history.history['loss']\n",
"val_loss = history.history['val_loss']\n",
"\n",
"# Derive the x-axis from the recorded history rather than the global `epochs`,\n",
"# so the plot stays correct if training stopped early or `epochs` was changed.\n",
"epochs_range = range(1, len(loss) + 1)\n",
"\n",
"fig, (ax_mae, ax_loss) = plt.subplots(1, 2, figsize=(11, 5))\n",
"\n",
"# MAE is an error (lower is better), not an accuracy, so label it as such.\n",
"ax_mae.plot(epochs_range, mae, label='Training MAE')\n",
"ax_mae.plot(epochs_range, val_mae, label='Validation MAE')\n",
"ax_mae.set(title='Training and Validation MAE', xlabel='Epoch', ylabel='Mean absolute error')\n",
"ax_mae.legend(loc='upper right')\n",
"\n",
"ax_loss.plot(epochs_range, loss, label='Training Loss')\n",
"ax_loss.plot(epochs_range, val_loss, label='Validation Loss')\n",
"ax_loss.set(title='Training and Validation Loss', xlabel='Epoch', ylabel='Loss')\n",
"ax_loss.legend(loc='upper right')\n",
"\n",
"# plt.show() instead of plt.ion() avoids echoing a <contextlib.ExitStack> repr.\n",
"plt.show()\n"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "08324ede",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[1m138/138\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 2ms/step\n",
"Pridicted: [-0.4036709 -0.46461192 -0.3590175 -0.28324017 -0.34823442 -0.43334588\n",
" -0.42882568 -0.41436017 -0.403928 ] Real value: [-0.40136986 -0.46164384 -0.35342466 -0.27945205 -0.34931507 -0.43013699\n",
" -0.42739726 -0.41369863 -0.39726027]\n",
"Pridicted: [-0.46557218 -0.35770795 -0.2813072 -0.3510212 -0.42975435 -0.43653348\n",
" -0.40678245 -0.4031321 -0.41870612] Real value: [-0.46164384 -0.35342466 -0.27945205 -0.34931507 -0.43013699 -0.42739726\n",
" -0.41369863 -0.39726027 -0.41780822]\n",
"Pridicted: [-0.35209465 -0.28051785 -0.34810475 -0.43110624 -0.43116555 -0.41247454\n",
" -0.39174688 -0.42331287 -0.42454508] Real value: [-0.35342466 -0.27945205 -0.34931507 -0.43013699 -0.42739726 -0.41369863\n",
" -0.39726027 -0.41780822 -0.42328767]\n",
"Pridicted: [-0.28172114 -0.34396863 -0.4297893 -0.4320941 -0.41340354 -0.39519808\n",
" -0.41816622 -0.42309865 -0.3611399 ] Real value: [-0.27945205 -0.34931507 -0.43013699 -0.42739726 -0.41369863 -0.39726027\n",
" -0.41780822 -0.42328767 -0.36027397]\n",
"Pridicted: [-0.34929985 -0.42562777 -0.433488 -0.41463438 -0.39713734 -0.42165747\n",
" -0.42167172 -0.3633375 -0.41996527] Real value: [-0.34931507 -0.43013699 -0.42739726 -0.41369863 -0.39726027 -0.41780822\n",
" -0.42328767 -0.36027397 -0.42054795]\n",
"Pridicted: [-0.42858496 -0.4267884 -0.40962404 -0.39073083 -0.42235476 -0.4238721\n",
" -0.3564185 -0.42203316 -0.3270994 ] Real value: [-0.43013699 -0.42739726 -0.41369863 -0.39726027 -0.41780822 -0.42328767\n",
" -0.36027397 -0.42054795 -0.32739726]\n",
"Pridicted: [-0.43434742 -0.40459284 -0.38885665 -0.4140891 -0.42647222 -0.36434662\n",
" -0.4164863 -0.32749146 -0.4512253 ] Real value: [-0.42739726 -0.41369863 -0.39726027 -0.41780822 -0.42328767 -0.36027397\n",
" -0.42054795 -0.32739726 -0.45205479]\n",
"Pridicted: [-0.41408408 -0.3860175 -0.4147621 -0.42086872 -0.36401656 -0.41890422\n",
" -0.32657167 -0.45092446 -0.3550458 ] Real value: [-0.41369863 -0.39726027 -0.41780822 -0.42328767 -0.36027397 -0.42054795\n",
" -0.32739726 -0.45205479 -0.35342466]\n",
"Pridicted: [-0.39207783 -0.4143812 -0.42082974 -0.36102867 -0.42305246 -0.32819805\n",
" -0.44679096 -0.35819703 -0.18476132] Real value: [-0.39726027 -0.41780822 -0.42328767 -0.36027397 -0.42054795 -0.32739726\n",
" -0.45205479 -0.35342466 -0.18767123]\n",
"Pridicted: [-0.41665655 -0.41635478 -0.36310774 -0.41913277 -0.33059776 -0.4508693\n",
" -0.35423303 -0.18578593 -0.30750462] Real value: [-0.41780822 -0.42328767 -0.36027397 -0.42054795 -0.32739726 -0.45205479\n",
" -0.35342466 -0.18767123 -0.31643836]\n",
"\n",
"[0.40136986 0.44383562 0.44794521 0.53835616 0.42054795 0.36575342\n",
" 0.35205479 0.28493151 0.32876712]\n"
]
}
],
"source": [
"# Sanity check on the *training* windows: this shows in-sample fit, not generalisation.\n",
"y_pred = model.predict(X_train)\n",
"for i in range(min(10, len(y_pred))):\n",
"    print(\"Predicted:\", y_pred[i], \"Real value:\", y_train[i])\n",
"print()\n",
"print(y_train[-1])"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "67f5e31a",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<contextlib.ExitStack at 0x7cb970beb0d0>"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAzoAAAGsCAYAAAAVEdLDAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAB3FElEQVR4nO3dd3hUVeLG8e9MeiGhhwAJvYTeey+hSQ09AiqIKKDIrj/FiopiWduiICiiSGgauvTeBOm91wChQwLpZX5/XAmygCaQ5CaT9/M88+DcO+XNLEvy5px7jsVms9kQERERERGxI1azA4iIiIiIiKQ3FR0REREREbE7KjoiIiIiImJ3VHRERERERMTuqOiIiIiIiIjdUdERERERERG7o6IjIiIiIiJ2x9HsAKmRnJzMhQsXyJUrFxaLxew4IiIiIiJiEpvNxq1btyhcuDBW68PHbbJF0blw4QJ+fn5mxxARERERkSwiLCyMokWLPvR8tig6uXLlAowvxsvLy+Q0IiIiIiJilsjISPz8/FI6wsNki6JzZ7qal5eXio6IiIiIiPzjJS1ajEBEREREROyOio6IiIiIiNgdFR0REREREbE72eIaHRERERGR7CYpKYmEhASzY2Q7Tk5OODg4PPbrqOiIiIiIiKQjm83GxYsXuXnzptlRsq3cuXNTqFChx9pDU0VHRERERCQd3Sk5BQsWxN3dXRvep4HNZiM6OprLly8D4Ovr+8ivpaIjIiIiIpJOkpKSUkpOvnz5zI6TLbm5uQFw+fJlChYs+MjT2LQYgYiIiIhIOrlzTY67u7vJSbK3O5/f41zjpKIjIiIiIpLONF3t8aTH56eiIyIiIiIidkdFR0RERERE7I6KjoiIiIiIZIjTp09jsVjYvXt3pr+3io6IiIiIiNgdFR0RkfSSGAURh81OISIiki7i4+PNjvBYVHRERNLDpbWwqDz8FgB73gSbzexEIiKSVdhsxi/DMvuWxu9FzZo1Y9iwYYwcOZL8+fPTunVrDh48SPv27fH09MTHx4d+/fpx9erVlOcsXbqURo0akTt3bvLly8cTTzzBiRMn0vsTfCTaMFRE5HEkxcO+d+Dgx8Cf31AOfABx16DW12B9tE3ORETEjiRFw2zPzH/fnrfB0SNNT/npp594/vnn2bRpE9evX6dp06Y8++yzfP7558TExPDqq6/Ss2dPVq9eDUBUVBQjR46kcuXKREVF8fbbb9O1a1d2796N1WrumIqKjojIo4o8ApuD4foO436pQeBdCXa+DMe/hfgbUH8qODibm1NERCSVSpcuzSeffALA22+/TY0aNfjwww9Tzv/www/4+flx9OhRypYtS1BQ0D3Pnzx5MgULFuTgwYNUqlQpU7P/LxUdEZG0stngxPewY4TxWzrnvFD3O/DrZpx3KwSbn4SzsyAhAhr/mubfqImIiB1xcDdGV8x43zSqVatWyn/v2LGDNWvW4Ol5/2jUiRMnKFu2LCdOnOCtt95iy5YtXL16leTkZADOnj2roiMikq3EXoU/noVz84z7Pi2h/k/gXuTuY4r1Aidv2NANwpfC6kBotgic85gSWURETGaxZJtfeHl43M2ZnJxMx44d+fjjj+97nK+vLwAdO3bEz8+P7777jsKFC5OcnEylSpWyxEIGKjoiIqkVvgK2DICYcLA6QdWxUP5lsDxgDnLhttBiJaztAFc3w8qm0HwZuPlmfm4REZFHUKNGDUJDQylevDiOjvfXhmvXrnHo0CEmTpxI48aNAdi4cWNmx3worbomIvJPkmJhx0hYE2iUHK8AaPMHBPzrwSXnjgINoNU6cC0EN/fBikZw+2Tm5RYREXkMQ4cO5fr16/Tp04c//viDkydPsnz5cp555hmSkpLIkycP+fLlY9KkSRw/fpzVq1czcuRIs2OnUNEREfk7Nw/Asrpw5Avjfpmh0HY75KmWuufnqQKtN4JnSaPkLG9olB4REZEsrnDhwmzatImkpC
TatGlDpUqVeOmll/D29sZqtWK1Wpk5cyY7duygUqVKvPzyy3z66admx05hsdmy/mYPkZGReHt7ExERgZeXl9lxRCQnsNng6Dew+xVjRMelANT7AYo88WivFxNuXKsTsR+cckOz34wRHxERsSuxsbGcOnWKEiVK4OrqanacbOvvPsfUdgON6IiI/K+YS7DuCdgx3Cg5vu2g/b5HLzlgXJvTej3krw8JN2F1a7iwLN0ii4iIyL1UdERE/ur8b7C4MlxYDFYXqDnOGH1x83n813bOAy1WgG8bY1nq9R3hzKzHf10RERG5j4qOiAhAYgxsG2aM5MRdgdxVjGtxyg0zlgVNL44e0GQB+PeC5ATY1AeOfZt+ry8iIiKAlpcWEYEbu2FTX4g8ZNwv9zJU+xAcMmhutYMzNAgxRniOfwvbnof461BhVPqWKhERkRxMRUdEci5bMhz+EvaMguR4Yxno+j+Bb2DGv7fVAWqPB5d8cOAD2PMGxF2D6p/+/ZLVIiIikioqOiKSM0VfMDb/vLjSuF+0M9T5HlzzZ14GiwWqjjHKzs6RcPhzY2Snzndg1T/PIiIij0PfSUUk5wmbC1sHGaXCwR1qfgGlnjVv2lj5l41pbFsHwckfIf4mNJyRcVPnREREcgDNjxCRnCPhNmx9FjZ0M0pO3prQbieUHmz+tTEln4LGocZKb+fmwdr2kBBpbiYREZFsTEVHRHKGa9tgaQ048T1ggQqvQevN4FXO7GR3Fe0MzZeAYy64tAZWtYDYK2anEhERSXfFixfnyy+/zND3UNEREfuWnAQHxsLyBnDrGLgXhZarodpYY/WzrManObRaAy754foOWNkYosLMTiUiIpLtqOiIiP2KOgurW8Ce18GWCP49oP1e8GlmdrK/l7cmtNpglLLII7CiofGniIhIFhIfH292hL+loiMi9unMLFhcBS6vB0dPqPcjNJxlXPSfHXiXh9abjKl10WGwopExwiMiIpJBmjVrxrBhwxg2bBi5c+cmX758vPnmm9hsNsCYbjZmzBieeuopvL29efbZZwHYvHkzTZo0wc3NDT8/P1588UWioqJSXvfy5ct07NgRNzc3SpQoQUhISKZ8PVp1TUTsS0IkbB8Op6Ya9/PVNTbnzFXK3FyPwsPfGNlZ0xZu7ISVzaHpgqw/IiUiIvew2WxEJ0Rn+vu6O7ljSeNiOz/99BMDBw5k69atbN++ncGDB1OsWLGUUvPpp5/y1ltv8eabbwKwb98+2rRpw/vvv8/kyZO5cuVKSlmaMmUKAE899RRhYWGsXr0aZ2dnXnzxRS5fvpy+X+wDWGx3KloWFhkZibe3NxEREXh5eZkdR0Syqiu/w+ZgiDplbLpZ8U2o9CZYncxO9ngSImFdZ7i81liVrdEsY+ECERHJcmJjYzl16hQlSpTA1dXYJiAqPgrPsZ6ZnuX2qNt4OHuk+vHNmjXj8uXLHDhwIKUgvfbaayxYsICDBw9SvHhxqlevzty5c1Oe079/f9zc3Jg4cWLKsY0bN9K0aVOioqI4e/Ys5cqVY8uWLdStWxeAw4cPExAQwBdffMGIESMemOVBn+Mdqe0GmromItlfciLsHf3nhfunwKM4tFoPVd7N/iUHwMnLWI2tSCdIjoMNQXDyJ7NTiYiIHapXr949o0D169fn2LFjJCUlAVCrVq17Hr9jxw5+/PFHPD09U25t2rQhOTmZU6dOcejQIRwdHe95Xvny5cmdO3eGfy2PNHVt/PjxfPrpp4SHh1OxYkW+/PJLGjdu/MDHrl27lubNm993/NChQ5QvX/5R3l5E5K7bJ2Hzk3D1d+N+8X5Qaxw4e5ubK705uBr77GwdBKd+gi1PQfwNKD/C7GQiIvIP3J3cuT3qtinvm948PO4dIUpOTua5557jxRdfvO+x/v7+HDliLKaT1il06SHNRWfWrFmMGDGC8ePH07BhQyZOnEi7du04ePAg/v7+D33ekSNH7hlaKlCgwKMlFhEBsNng1M+wfR
gk3gInb6g9AYr3MTtZxrE6Qr0fwDkvHPkCdr4Mcdegynvmb3gqIiIPZbFY0jSFzExbtmy5736ZMmVwcHB44ONr1Kj
"text/plain": [
"<Figure size 1000x500 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Compare model output with the target for the last training window\n",
"# (presumably the most recent one if X_train is time-ordered).\n",
"fig, ax = plt.subplots(figsize=(10, 5))\n",
"ax.plot(y_train[-1], color='orange', label='real')\n",
"ax.plot(y_pred[-1], color='green', label='pred')\n",
"ax.set(title='Last training window: real vs predicted', xlabel='Output index', ylabel='Scaled target')\n",
"ax.legend()\n",
"# plt.show() instead of plt.ion() avoids echoing a <contextlib.ExitStack> repr.\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"id": "b76d089d-356a-4f9c-9a94-3d00671e8f4d",
"metadata": {},
"source": [
"# Save Model"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "71f7fd19-5a06-40ae-bc25-fab6e2aa08bc",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"\n",
"# Create the output folder first so saving does not fail on a fresh checkout\n",
"# where models/ does not exist yet.\n",
"os.makedirs(\"models\", exist_ok=True)\n",
"model.save(os.path.join(\"models\", name + \".keras\"))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.14"
}
},
"nbformat": 4,
"nbformat_minor": 5
}