我最近开始使用Torch框架和Lua脚本语言玩弄神经网络。我已经掌握了线性网络的基础知识,所以我尝试了一些更复杂但足够简单的东西:
我的想法是,我有3个输入,我必须选择前两个,进行除法运算,并将结果传递给线性模块。因此,我编写了这个小脚本:
require "nn";require "optim";local N = 3;local input = torch.Tensor{ {1, 2, 3}, {9, 20, 20}, {9, 300, 1},};local output = torch.Tensor(N);for i=1, N do output[i] = 1;endlocal ratioPerceptron = nn.Sequential();ratioPerceptron:add(nn.Narrow(1, 1, 2));ratioPerceptron:add(nn.CDivTable());ratioPerceptron:add(nn.Reshape(N, 1));ratioPerceptron:add(nn.Linear(1, 1));ratioPerceptron:add(nn.Sigmoid());local criterion = nn.BCECriterion();local params, gradParams = ratioPerceptron:getParameters();local optimState = {learningRate = 0.01};local maxIteration = 100000;for i=1, maxIteration do local function f(params) gradParams:zero(); local outputs = ratioPerceptron:forward(input); local loss = criterion:forward(outputs, output); local dloss_doutputs = criterion:backward(outputs, output); ratioPerceptron:backward(input, dloss_doutputs); return loss, gradParams; end optim.sgd(f, params, optimState);end
在训练过程中调用backward时会失败,并显示以下错误:
CDivTable.lua:21: both torch.LongStorage and (null) have no addition operator
但是,如果我从顺序模块中移除CDivTable,并将nn.Reshape和nn.Linear更改为二维输入(因为我们移除了CDivTable,它将二维输入除以产生一维输出),如下所示:
local ratioPerceptron = nn.Sequential();ratioPerceptron:add(nn.Narrow(1, 1, 2));ratioPerceptron:add(nn.Reshape(N, 2));ratioPerceptron:add(nn.Linear(2, 1));ratioPerceptron:add(nn.Sigmoid());
训练可以无错误地完成…是否有其他方法可以对两个选定的输入进行除法运算,并将结果传递给线性模块?
回答:
模块CDivTable
接受一个表作为输入,并将第一个表的元素除以第二个表的元素。在这里,你用单个输入而不是两个输入的表来馈送你的网络。这就是为什么你会遇到带有null
的错误,我认为。Torch无法理解你的输入(由两个向量组成)应该被视为两个向量的表。它只能看到一个大小为2x3
的张量!因此,你必须告诉Torch从输入中创建一个表。因此,你可以使用SplitTable(dim)
模块,它将沿维度dim
将输入拆分为表。
在狭窄模块之后插入这行ratioPerceptron:add(nn.SplitTable(1))
:
local ratioPerceptron = nn.Sequential();ratioPerceptron:add(nn.Narrow(1, 1, 2));ratioPerceptron:add(nn.SplitTable(1))ratioPerceptron:add(nn.CDivTable());ratioPerceptron:add(nn.Reshape(N, 1));ratioPerceptron:add(nn.Linear(1, 1));ratioPerceptron:add(nn.Sigmoid());
此外,当你遇到这样的错误时,我建议你通过添加print
语句来查看你的网络计算了什么:在你添加导致错误的模块之前的那一行插入print(ratioPerceptron:forward(input))
。