PyTorch – 多元正态分布参数的最大似然估计问题

为了了解PyTorch的工作原理,我尝试对多元正态分布中的一些参数进行最大似然估计。然而,对于与协方差相关的参数,这种方法似乎不起作用。

所以我的问题是:为什么这段代码不起作用?

输出:

mu: tensor([0.4168, 0.1580]), mu_hat: tensor([0.4127, 0.1454], requires_grad=True)sigma: tensor([1.1917, 1.7290]), sigma_hat: tensor([1., 1.], grad_fn=<ExpBackward>)rho: tensor([0.3589]), rho_hat: tensor([0.], grad_fn=<TanhBackward>)>>> torch.__version__'1.0.0.dev20181127'

换句话说,为什么log_sigmaatanh_rho的估计值没有从初始值移动?


回答:

您创建协方差矩阵的方式不是可反向传播的

def make_covariance_matrix(sigma, rho):    return torch.tensor([[sigma[0]**2, rho * torch.prod(sigma)],                         [rho * torch.prod(sigma), sigma[1]**2]])

当从(多个)张量创建一个新张量时,只有输入张量的值会被保留。输入张量的所有附加信息都被去除,因此从这一点开始,所有与您的参数的图连接都被切断,因此反向传播无法通过。

这里有一个简短的例子来说明这一点:

输出:

Original parameter 1:tensor([ 0.8913]) TrueOriginal parameter 2:tensor([ 0.4785]) TrueNew tensor form params:tensor([ 0.8913,  0.4785]) False

如您所见,从参数param1param2创建的张量不跟踪param1param2的梯度。

所以您可以使用以下保持图连接可反向传播的代码:

def make_covariance_matrix(sigma, rho):    conv = torch.cat([(sigma[0]**2).view(-1), rho * torch.prod(sigma), rho * torch.prod(sigma), (sigma[1]**2).view(-1)])    return conv.view(2, 2)

使用torch.cat将值连接到一个扁平张量。然后使用view()将它们调整到正确的形状。
这会产生与您的函数相同的矩阵输出,但它保持了与您的参数log_sigmaatanh_rho的连接。

这是更改make_covariance_matrix前后的输出。如您所见,现在您可以优化参数,并且值确实发生了变化:

Before:mu: tensor([ 0.1191,  0.7215]), mu_hat: tensor([ 0.,  0.])sigma: tensor([ 1.4222,  1.0949]), sigma_hat: tensor([ 1.,  1.])rho: tensor([ 0.2558]), rho_hat: tensor([ 0.])After:mu: tensor([ 0.1191,  0.7215]), mu_hat: tensor([ 0.0712,  0.7781])sigma: tensor([ 1.4222,  1.0949]), sigma_hat: tensor([ 1.4410,  1.0807])rho: tensor([ 0.2558]), rho_hat: tensor([ 0.2235])

希望这对您有帮助!

Related Posts

在使用k近邻算法时,有没有办法获取被使用的“邻居”?

我想找到一种方法来确定在我的knn算法中实际使用了哪些…

Theano在Google Colab上无法启用GPU支持

我在尝试使用Theano库训练一个模型。由于我的电脑内…

准确性评分似乎有误

这里是代码: from sklearn.metrics…

Keras Functional API: “错误检查输入时:期望input_1具有4个维度,但得到形状为(X, Y)的数组”

我在尝试使用Keras的fit_generator来训…

如何使用sklearn.datasets.make_classification在指定范围内生成合成数据?

我想为分类问题创建合成数据。我使用了sklearn.d…

如何处理预测时不在训练集中的标签

已关闭。 此问题与编程或软件开发无关。目前不接受回答。…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注