码迷,mamicode.com
首页 > 其他好文 > 详细

第六讲 循环神经网络--LSTM--stock

时间:2020-05-13 23:15:12      阅读:106      评论:0      收藏:0      [点我收藏+]

标签:plot   pyplot   check   time   ESS   repr   sha   sum   class   

  # Notebook-only dependency install (run in a shell instead): pip install tushare
"""Predict Kweichow Moutai (SH600519) prices with a two-layer LSTM.

Pipeline: download daily K-line data with tushare, cache it as CSV, scale one
price column to [0, 1], build 60-step sliding windows, train a stacked LSTM,
then plot the loss curves / predictions and print MSE, RMSE and MAE.
"""
import math
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import tensorflow as tf
import tushare as ts
from sklearn.metrics import mean_absolute_error, mean_squared_error
from sklearn.preprocessing import MinMaxScaler
from tensorflow.keras.layers import Dropout, Dense, LSTM

# Number of past time steps fed to the network for each prediction.
TIME_STEPS = 60
# Rows [0, TRAIN_ROWS) are used for training, the remainder for testing.
TRAIN_ROWS = 3000

# The stock code and the date bounds must be strings; unquoted,
# 2004-01-01 would be parsed as integer arithmetic.
df1 = ts.get_k_data('600519', ktype='D', start='2004-01-01', end='2020-05-12')

datapath1 = "./SH600519.csv"
df1.to_csv(datapath1)


maotai = pd.read_csv("./SH600519.csv")

# These two expressions only render output in a notebook cell.
maotai.head()


maotai.tail()


# Column position 2 of the re-read CSV is the single feature used.
# NOTE(review): presumably the opening price (position 0 is the saved index
# column) - confirm against the CSV header.
training_set = maotai.iloc[0:TRAIN_ROWS, 2:3].values
test_set = maotai.iloc[TRAIN_ROWS:, 2:3].values

# Normalisation: fit the scaler on the training split only, then reuse the
# same min/max for the test split.
sc = MinMaxScaler(feature_range=(0, 1))
training_set_scaled = sc.fit_transform(training_set)
test_set = sc.transform(test_set)

training_set_scaled.shape

test_set.shape

x_train = []
y_train = []

x_test = []
y_test = []


# Each sample is the previous TIME_STEPS values; the label is the next value.
for i in range(TIME_STEPS, len(training_set_scaled)):
    x_train.append(training_set_scaled[i - TIME_STEPS:i, 0])
    y_train.append(training_set_scaled[i, 0])

# Re-seeding before each shuffle applies the same permutation to both lists,
# so every window stays paired with its label.
np.random.seed(7)
np.random.shuffle(x_train)
np.random.seed(7)
np.random.shuffle(y_train)
tf.random.set_seed(7)


x_train, y_train = np.array(x_train), np.array(y_train)

x_train.shape
y_train.shape


# Reshape to (samples, time steps, features) as expected by the LSTM layers.
x_train = np.reshape(x_train, (x_train.shape[0], TIME_STEPS, 1))
for i in range(TIME_STEPS, len(test_set)):
    x_test.append(test_set[i - TIME_STEPS:i, 0])
    y_test.append(test_set[i, 0])

x_test, y_test = np.array(x_test), np.array(y_test)
x_test = np.reshape(x_test, (x_test.shape[0], TIME_STEPS, 1))


model = tf.keras.Sequential([
    LSTM(80, return_sequences=True),
    Dropout(0.2),
    LSTM(100),
    Dropout(0.2),
    Dense(1),
])

# The loss must be given as the Keras identifier string: the bare name
# mean_squared_error is the sklearn function imported above.
model.compile(optimizer=tf.keras.optimizers.Adam(0.0001),
              loss='mean_squared_error')

checkpoint_save_path = "./checkpoint/LSTM_stock.ckpt"

# Resume from previously saved weights when the checkpoint index file exists.
if os.path.exists(checkpoint_save_path + '.index'):
    print('-------------load the model-------------')
    model.load_weights(checkpoint_save_path)

cp_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_save_path,
    save_weights_only=True,
    save_best_only=True,
    monitor='val_loss')

history = model.fit(x_train, y_train, batch_size=64, epochs=24,
                    validation_data=(x_test, y_test), validation_freq=1,
                    callbacks=[cp_callback])

model.summary()


# Dump every trainable variable (name, shape, values) to a text file.
with open("./weights.txt", "w") as f:
    for v in model.trainable_variables:
        f.write(str(v.name) + '\n')
        f.write(str(v.shape) + '\n')
        f.write(str(v.numpy()) + '\n')


loss = history.history['loss']
val_loss = history.history['val_loss']

plt.plot(loss, label='Training Loss')
plt.plot(val_loss, label='Validation Loss')
plt.title('Training and Validation Loss')
plt.legend()
plt.show()


predicted_stock_price = model.predict(x_test)
# Map predictions and ground truth back from [0, 1] to the original scale.
predicted_stock_price = sc.inverse_transform(predicted_stock_price)
# The first TIME_STEPS test rows have no prediction (they only seed windows).
real_stock_price = sc.inverse_transform(test_set[TIME_STEPS:])

plt.plot(real_stock_price, color='red', label='real_stock_price')
plt.plot(predicted_stock_price, color='blue', label='predicted_stock_price')
plt.title('Maotai Stock Price Prediction')
plt.xlabel('Time')
plt.ylabel('Maotai Stock Price')
plt.legend()
plt.show()


mse = mean_squared_error(predicted_stock_price, real_stock_price)
# Reuse mse rather than recomputing it.
rmse = math.sqrt(mse)
mae = mean_absolute_error(predicted_stock_price, real_stock_price)
print('均方误差: %.6f' % mse)
print('均方根误差: %.6f' % rmse)
print('平均绝对误差: %.6f' % mae)

 

第六讲 循环神经网络--LSTM--stock

标签:plot   pyplot   check   time   ESS   repr   sha   sum   class   

原文地址:https://www.cnblogs.com/wbloger/p/12885496.html

(0)
(0)
   
举报
评论 一句话评论(0)
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!