PyTorch – 多元正态分布参数的最大似然估计问题

为了了解PyTorch的工作原理,我尝试对多元正态分布中的一些参数进行最大似然估计。然而,对于与协方差相关的参数,这种方法似乎不起作用。

所以我的问题是:为什么这段代码不起作用?

输出:

mu: tensor([0.4168, 0.1580]), mu_hat: tensor([0.4127, 0.1454], requires_grad=True)sigma: tensor([1.1917, 1.7290]), sigma_hat: tensor([1., 1.], grad_fn=<ExpBackward>)rho: tensor([0.3589]), rho_hat: tensor([0.], grad_fn=<TanhBackward>)>>> torch.__version__'1.0.0.dev20181127'

换句话说,为什么log_sigmaatanh_rho的估计值没有从初始值移动?


回答:

您创建协方差矩阵的方式不是可反向传播的

def make_covariance_matrix(sigma, rho):    return torch.tensor([[sigma[0]**2, rho * torch.prod(sigma)],                         [rho * torch.prod(sigma), sigma[1]**2]])

当从(多个)张量创建一个新张量时,只有输入张量的值会被保留。输入张量的所有附加信息都被去除,因此从这一点开始,所有与您的参数的图连接都被切断,因此反向传播无法通过。

这里有一个简短的例子来说明这一点:

输出:

Original parameter 1:tensor([ 0.8913]) TrueOriginal parameter 2:tensor([ 0.4785]) TrueNew tensor form params:tensor([ 0.8913,  0.4785]) False

如您所见,从参数param1param2创建的张量不跟踪param1param2的梯度。

所以您可以使用以下保持图连接可反向传播的代码:

def make_covariance_matrix(sigma, rho):    conv = torch.cat([(sigma[0]**2).view(-1), rho * torch.prod(sigma), rho * torch.prod(sigma), (sigma[1]**2).view(-1)])    return conv.view(2, 2)

使用torch.cat将值连接到一个扁平张量。然后使用view()将它们调整到正确的形状。
这会产生与您的函数相同的矩阵输出,但它保持了与您的参数log_sigmaatanh_rho的连接。

这是更改make_covariance_matrix前后的输出。如您所见,现在您可以优化参数,并且值确实发生了变化:

Before:mu: tensor([ 0.1191,  0.7215]), mu_hat: tensor([ 0.,  0.])sigma: tensor([ 1.4222,  1.0949]), sigma_hat: tensor([ 1.,  1.])rho: tensor([ 0.2558]), rho_hat: tensor([ 0.])After:mu: tensor([ 0.1191,  0.7215]), mu_hat: tensor([ 0.0712,  0.7781])sigma: tensor([ 1.4222,  1.0949]), sigma_hat: tensor([ 1.4410,  1.0807])rho: tensor([ 0.2558]), rho_hat: tensor([ 0.2235])

希望这对您有帮助!

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注