我正在尝试在 PyTorch 中实现一个简单的线性模型,该模型可以接受 x 数据和 y 数据,然后训练以识别方程 y = mx + b。然而,每当我尝试在训练后测试我的模型时,它认为方程是 y = mx + 2b。我会展示我的代码,希望有人能发现问题。提前感谢任何帮助。
import torchD_in = 500D_out = 500batch=200model=torch.nn.Sequential( torch.nn.Linear(D_in,D_out),)
接下来我创建一些数据并设定一个规则。我们用 3x+4 作为例子。
x_data=torch.rand(batch,D_in)y_data=torch.randn(batch,D_out)for i in range(batch): for j in range(D_in): y_data[i][j]=3*x_data[i][j]+5 # 模型认为 y=mx+c -> y=mx+2c?loss_fn=torch.nn.MSELoss(size_average=False)optimizer=torch.optim.Adam(model.parameters(),lr=0.001)
现在开始训练…
for epoch in range(500): y_pred=model(x_data) loss=loss_fn(y_pred,y_data) optimizer.zero_grad() loss.backward() optimizer.step()
然后我用一个全是 1 的张量/矩阵来测试我的模型。
test_data=torch.ones(batch,D_in) y_pred=model(test_data)
现在,我期望得到 3*1 + 4 = 7,但我的模型却认为是 11。
[[ 10.7286, 11.0499, 10.9448, ..., 11.0812, 10.9387, 10.7516], [ 10.7286, 11.0499, 10.9448, ..., 11.0812, 10.9387, 10.7516], [ 10.7286, 11.0499, 10.9448, ..., 11.0812, 10.9387, 10.7516], ..., [ 10.7286, 11.0499, 10.9448, ..., 11.0812, 10.9387, 10.7516], [ 10.7286, 11.0499, 10.9448, ..., 11.0812, 10.9387, 10.7516], [ 10.7286, 11.0499, 10.9448, ..., 11.0812, 10.9387, 10.7516]])
类似地,如果我将规则改为 y=3x+8,我的模型猜测是 19。所以,我不确定发生了什么。为什么常数项被加了两次?顺便说一下,如果我只将规则设为 y=3x,我的模型正确推断出 3,对于一般情况下的 y=mx,我的模型也正确推断出 m。不知为何,常数项让模型出现了偏差。任何帮助解决这个问题的建议都非常感谢。谢谢!
回答:
你的网络训练时间不够长。它使用一个包含500个特征的向量来描述单个数据点。
你的网络需要将包含500个特征的大输入映射到包含500个值的输出。你的训练数据是随机生成的,不像你的简单示例,所以我认为你只需要训练更长时间来调整权重,以近似从 R^500 到 R^500 的这个函数。
如果我减小输入和输出的维度,并增加批量大小、学习率和训练步骤,我得到了预期的结果:
import torchD_in = 100D_out = 100batch = 512model=torch.nn.Sequential( torch.nn.Linear(D_in,D_out),)x_data=torch.rand(batch,D_in)y_data=torch.randn(batch,D_out)for i in range(batch): for j in range(D_in): y_data[i][j]=3*x_data[i][j]+4 # 模型认为 y=mx+c -> y=mx+2c?loss_fn=torch.nn.MSELoss(size_average=False)optimizer=torch.optim.Adam(model.parameters(),lr=0.01)for epoch in range(10000): y_pred=model(x_data) loss=loss_fn(y_pred,y_data) optimizer.zero_grad() loss.backward() optimizer.step()test_data=torch.ones(batch,D_in)y_pred=model(test_data)print(y_pred)
如果你只想近似 f(x) = 3x + 4
且只有一个输入,你也可以将 D_in
和 D_out
设置为 1。