我将一个用于时间序列分析的Tensorflow代码转换成了Pytorch,结果发现性能差异非常大,实际上Pytorch的层完全无法处理季节性问题。我感觉肯定是遗漏了某些重要的东西。
请帮助找出Pytorch代码中的不足之处,导致学习效果不理想。我注意到当遇到季节变化时,损失值会出现大幅波动,并且无法学习到这些变化。使用相同的层、节点和其他所有设置,我本以为性能应该相近。
# tensorflow codewindow_size = 20batch_size = 32shuffle_buffer_size = 1000def windowed_dataset(series, window_size, batch_size, shuffle_buffer): dataset = tf.data.Dataset.from_tensor_slices(series) dataset = dataset.window(window_size + 1, shift=1, drop_remainder=True) dataset = dataset.flat_map(lambda window: window.batch(window_size + 1)) dataset = dataset.shuffle(shuffle_buffer).map(lambda window: (window[:-1], window[-1])) dataset = dataset.batch(batch_size).prefetch(1) return datasetdataset = windowed_dataset(x_train, window_size, batch_size, shuffle_buffer_size)model = tf.keras.models.Sequential([ tf.keras.layers.Dense(100, input_shape=[window_size], activation="relu"), tf.keras.layers.Dense(10, activation="relu"), tf.keras.layers.Dense(1)])model.compile(loss="mse", optimizer=tf.keras.optimizers.SGD(lr=1e-6, momentum=0.9))model.fit(dataset,epochs=100,verbose=0)forecast = []for time in range(len(series) - window_size): forecast.append(model.predict(series[time:time + window_size][np.newaxis]))forecast = forecast[split_time-window_size:]results = np.array(forecast)[:, 0, 0]plt.figure(figsize=(10, 6))plot_series(time_valid, x_valid)plot_series(time_valid, results)tf.keras.metrics.mean_absolute_error(x_valid, results).numpy()
# pytorch codewindow_size = 20batch_size = 32shuffle_buffer_size = 1000class tsdataset(Dataset): def __init__(self, series, window_size): self.series = series self.window_size = window_size self.dataset, self.labels = self.preprocess() def preprocess(self): series = self.series final, labels = [], [] for i in range(len(series)-self.window_size): final.append(np.array(series[i:i+window_size])) labels.append(np.array(series[i+window_size])) return torch.from_numpy(np.array(final)), torch.from_numpy(np.array(labels)) def __getitem__(self,index): # print(self.dataset[index], self.labels[index], index) return self.dataset[index], self.labels[index] def __len__(self): return len(self.dataset)train_dataset = tsdataset(x_train, window_size)train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)class tspredictor(nn.Module): def __init__(self, window_size, out1, out2, out3): super(tspredictor, self).__init__() self.l1 = nn.Linear(window_size, out1) self.l2 = nn.Linear(out1, out2) self.l3 = nn.Linear(out2, out3) def forward(self,seq): l1 = F.relu(self.l1(seq)) l2 = F.relu(self.l2(l1)) l3 = self.l3(l2) return l3model = tspredictor(20, 100,10,1)loss_function = nn.MSELoss()optimizer = optim.SGD(model.parameters(), lr=1e-6, momentum=0.9)for epoch in range(100): for t,l in train_dataloader: model.zero_grad() tag_scores = model(t) loss = loss_function(tag_scores, l) loss.backward() optimizer.step() # print("Epoch is {}, loss is {}".format(epoch, loss.data))forecast = []for time in range(len(series) - window_size): prediction = model(torch.from_numpy(series[time:time + window_size][np.newaxis])) forecast.append(prediction)forecast = forecast[split_time-window_size:]results = np.array(forecast)plt.figure(figsize=(10, 6))plot_series(time_valid, x_valid)plot_series(time_valid, results)
要生成数据,可以使用以下代码:
def plot_series(time, series, format="-", start=0, end=None): plt.plot(time[start:end], series[start:end], format) plt.xlabel("Time") plt.ylabel("Value") plt.grid(False)def trend(time, slope=0): return slope * timedef seasonal_pattern(season_time): """Just an arbitrary pattern, you can change it if you wish""" return np.where(season_time < 0.1, np.cos(season_time * 6 * np.pi), 2 / np.exp(9 * season_time))def seasonality(time, period, amplitude=1, phase=0): """Repeats the same pattern at each period""" season_time = ((time + phase) % period) / period return amplitude * seasonal_pattern(season_time)def noise(time, noise_level=1, seed=None): rnd = np.random.RandomState(seed) return rnd.randn(len(time)) * noise_leveltime = np.arange(10 * 365 + 1, dtype="float32")baseline = 10series = trend(time, 0.1) baseline = 10amplitude = 40slope = 0.005noise_level = 3# Create the seriesseries = baseline + trend(time, slope) + seasonality(time, period=365, amplitude=amplitude)# Update with noiseseries += noise(time, noise_level, seed=51)split_time = 3000time_train = time[:split_time]x_train = series[:split_time]time_valid = time[split_time:]x_valid = series[split_time:]
回答:
损失函数中存在广播问题。将损失函数更改为以下内容可以解决这个问题:
loss = loss_function(tag_scores, l.view(-1,1))