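I'm building an LSTM model and training it on a TSLA dataset I found on Kaggle. My question is: when I call model.predict, does the prediction give the next day's stock price? Is this a one-step prediction? When I print the output of model.predict I get a huge list, so I use NumPy's argmax function to get a single number. Here is the code: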
import tensorflow as tf
from tensorflow.keras.layers import LSTM, Dense, Dropout, Input, GlobalMaxPooling1D
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.preprocessing import StandardScaler
from tensorflow.keras.optimizers import Adam

df = pd.read_csv('TSLA.csv')
series = df['Close'].values.reshape(-1, 1)

# Fit the scaler on the first half only, then standardize the whole series
scaler = StandardScaler()
scaler.fit(series[:len(series)//2])
series = scaler.transform(series).flatten()

# Sliding windows: T past values -> the value on the next day
X = []
Y = []
T = 10
D = 1
for t in range(len(series) - T):
    X.append(series[t:t+T])
    Y.append(series[t+T])

X = np.array(X).reshape(-1, T, D)
Y = np.array(Y)
N = len(X)
print(X.shape, Y.shape)

model = tf.keras.Sequential([
    Input(shape=(T, D)),
    LSTM(50),
    Dense(100, activation='relu'),
    Dropout(0.25),
    Dense(1)
])
# 'lr' is deprecated in current Keras; the argument is 'learning_rate'
model.compile(optimizer=Adam(learning_rate=0.01), loss='mse')

# First half of the windows for training, second half for validation
r = model.fit(X[:-N//2], Y[:-N//2],
              validation_data=(X[-N//2:], Y[-N//2:]),
              epochs=200)

plt.plot(r.history['loss'])
plt.plot(r.history['val_loss'])
plt.show()

preds = model.predict(X)
outs = preds[:,0]
print(outs)
print(np.argmax(outs))
Answer:
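Yes, each output is a one-step (next-day) prediction, but argmax is the wrong tool here. The 90 values are next-day predictions, one for each input window in X. When you run: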
preds = model.predict(X)
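you get a next-day value for every one of those 90 windows. Because X is passed in whole, this covers both the training half and the validation half, not only the training set. This line: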
print(np.argmax(outs))
is meaningless: it returns the index of the largest prediction, not a price.
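To make the alignment concrete, here is a minimal NumPy sketch that reuses the question's windowing loop on a toy series (the numbers below are made up, not TSLA data). It shows which day each prediction refers to, what argmax actually returns, and how to undo the standardization. The `mean`/`std` values are hypothetical stand-ins for the fitted scaler's `mean_` and `scale_`; with the real code you would simply call `scaler.inverse_transform(preds)`.

```python
import numpy as np

def make_windows(series, T):
    """Same sliding-window construction as in the question's code."""
    X, Y = [], []
    for t in range(len(series) - T):
        X.append(series[t:t + T])  # T consecutive past values
        Y.append(series[t + T])    # the value on the following day
    return np.array(X).reshape(-1, T, 1), np.array(Y)

# Toy stand-in for the standardized closing prices (not real TSLA data).
series = np.arange(15, dtype=float)
T = 10
X, Y = make_windows(series, T)
print(X.shape, Y.shape)  # (5, 10, 1) (5,)
# Row i of X covers days i..i+T-1 and its target Y[i] is day i+T,
# so model.predict(X)[i, 0] is the model's guess for day i+T.

# Made-up model outputs standing in for preds[:, 0] (one per row of X).
outs = np.array([-0.5, 1.2, 0.3, 0.8, 0.1])
print(np.argmax(outs))  # 1 -> a row index, not a price

# Undo the standardization. mean/std are hypothetical stand-ins for
# scaler.mean_ and scaler.scale_ from sklearn's StandardScaler.
mean, std = 40.0, 10.0
prices = outs * std + mean  # inverse of (x - mean) / std
print(prices)

# The last row of X targets the final day (its window ends one day earlier).
# To forecast the day after the data ends, build one more window from the
# last T values; this is what you would pass to model.predict.
last_window = series[-T:].reshape(1, T, 1)
print(last_window[0, :, 0])
```

In the real script, `outs[-1]` is the model's estimate for the last day that is already in the CSV; only a prediction on a window like `last_window` looks past the end of the data.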
By the way, you can fetch stock prices directly in Python; you don't need a CSV file.
pip install pandas-datareader
from pandas_datareader import data as wb
ticker = wb.DataReader('TSLA', start='2015-1-1', data_source='yahoo')
print(ticker)
                  High         Low  ...      Volume   Adj Close
Date                                ...
2015-01-02   44.650002   42.652000  ...  23822000.0   43.862000
2015-01-05   43.299999   41.431999  ...  26842500.0   42.018002
2015-01-06   42.840000   40.841999  ...  31309500.0   42.256001
2015-01-07   42.956001   41.956001  ...  14842000.0   42.189999
2015-01-08   42.759998   42.001999  ...  17212500.0   42.124001
...                ...         ...  ...         ...         ...
2020-10-13  448.890015  436.600006  ...  34463700.0  446.649994
2020-10-14  465.899994  447.350006  ...  48045400.0  461.299988
2020-10-15  456.570007  442.500000  ...  35672400.0  448.880005
2020-10-16  455.950012  438.850006  ...  32620000.0  439.670013
2020-10-19  447.000000  437.649994  ...   9422697.0  442.840607
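The DataFrame returned this way can take the place of TSLA.csv in the question's code (its Open and Close columns are hidden behind `...` in the printout above). A minimal sketch, using a small hand-made DataFrame with made-up prices instead of a real download, of extracting the same Close column and reshaping it exactly as the question does:

```python
import numpy as np
import pandas as pd

# Hand-made stand-in for the DataFrame returned by wb.DataReader
# (made-up prices, not real TSLA data).
ticker = pd.DataFrame(
    {"Close": [10.0, 11.0, 12.0, 13.0, 14.0]},
    index=pd.to_datetime(["2015-01-02", "2015-01-05", "2015-01-06",
                          "2015-01-07", "2015-01-08"]),
)
ticker.index.name = "Date"

# Same reshaping the question applies to df['Close'] read from the CSV.
series = ticker["Close"].values.reshape(-1, 1)
print(series.shape)  # (5, 1)
```

From here on, the scaling, windowing and training steps stay the same.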