Python + Theano:逻辑回归权重不更新

我已经广泛对比了现有的教程,但还是无法找出为什么我的权重不更新。以下是返回更新列表的函数:

def get_updates(cost, params, learning_rate):updates = []for param in params:    updates.append((param, param - learning_rate * T.grad(cost, param)))return updates

这个函数是在顶层定义的,不在任何类中。这是每个参数的标准梯度下降。这里的’params’参数是以mlp.params的形式输入的,它只是各层参数列表的拼接列表。为了找出我的成本函数为什么不减少,我移除了除了逻辑回归层之外的所有层。以下是MLP构造函数中mlp.params的定义。它遵循每个层的定义及其各自的参数列表。

self.params = []for layer in self.layers:    self.params += layer.params

以下是训练函数,我在每个epoch中对每个小批量数据调用它:

train = theano.function([minibatch_index], cost,                    updates=get_updates(cost, mlp.params, learning_rate),                    givens= {                        x: train_set_x[minibatch_index * batch_size : (minibatch_index + 1) * batch_size],                        y: train_set_y[minibatch_index * batch_size : (minibatch_index + 1) * batch_size]                    })

如果您需要更多细节,整个文件在这里可用:http://pastebin.com/EeNmXfGD

我不知道有多少人使用Theano(似乎并不多);如果你读到这里,谢谢你。

已修复:我确定不能使用平均平方误差作为成本函数。将其替换为负对数似然后,它照常工作。


回答:

这种行为是由几个原因引起的,但归根结底是成本函数没有正确计算。在你的实现中,逻辑回归层的输出是每个输入数字的预测类别(通过argmax操作获得),你计算它与预期预测之间的平方差。

这会导致相对于模型中任何参数的梯度为0,因为argmax的输出(预测类别)相对于argmax的输入(类别概率)的梯度将为0。

相反,逻辑回归应该输出类别的概率:

def output(self, input):    input = input.flatten(2)    self.p_y_given_x = T.nnet.softmax(T.dot(input, self.W) + self.b)    return self.p_y_given_x

然后在MLP类中,你计算成本。你可以使用期望类别概率与模型计算的概率之间的均方误差,但人们倾向于使用预期类的负对数似然,你可以在MLP类中这样实现它:

def neg_log_likelihood(self, x, y):    p_y_given_x = self.output(x)    return -T.mean(T.log(p_y_given_x)[T.arange(y.shape[0]), y])

然后你可以使用这个函数来计算你的成本,模型就会训练:

cost = mlp.neg_log_likelihood(x_, y)

另外几点:

  • 在第215行,当你打印你的成本时,你将其格式化为整数值,但它是一个浮点值;这会在监控中丢失精度。
  • 如你在逻辑回归类中所做的那样,将所有权重初始化为0通常是不推荐的。权重应该在初始值上有所不同,以帮助打破对称性

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注