GPflow分类:后验方差的解释

GPflow网站上的多类分类教程中,使用了一个稀疏变分高斯过程(SVGP)模型在一个一维玩具示例上。就像所有其他GPflow模型一样,SVGP模型有一个方法predict_y(self, Xnew),它返回在Xnew点上的未知数据的均值和方差。

从教程中可以清楚地看出,从predict_y解包的第一个参数是三个类别中每一个的后验预测概率(单元格[7][8]),在下图的第二个面板中以彩色线条显示。然而,作者并没有详细解释从predict_y解包的第二个参数,即预测的方差。在回归设置中,我能理解它的解释,因为在这种情况下后验预测分布将是一个高斯分布。


但我无法理解这里的解释是什么。特别是,我想知道如何使用这个度量来构建表示新数据点类别预测不确定性的误差条。


我稍微修改了教程中的代码,以在下图中添加一个额外的面板:第三个面板以黑色显示了从predict_y获得的方差的最大标准差(方差的平方根)。显然,这是一个很好的不确定性度量,而且最高可能值为0.5可能也不是巧合,但我找不到它是如何计算的以及它代表什么。

黑色线条:三个中的最高方差

包含所有代码的完整笔记本在这里

def plot(m):f = plt.figure(figsize=(12,8))a1 = f.add_axes([0.05, 0.05, 0.9, 0.5])av = f.add_axes([0.05, 0.6, 0.9, 0.1])a2 = f.add_axes([0.05, 0.75, 0.9, 0.1])a3 = f.add_axes([0.05, 0.9, 0.9, 0.1])xx = np.linspace(m.X.read_value().min()-0.3, m.X.read_value().max()+0.3, 200).reshape(-1,1)mu, var = m.predict_f(xx)mu, var = mu.copy(), var.copy()p, v = m.predict_y(xx)a3.set_xticks([])a3.set_yticks([])av.set_xticks([])lty = ['-', '--', ':']for i in range(m.likelihood.num_classes):    x = m.X.read_value()[m.Y.read_value().flatten()==i]    points, = a3.plot(x, x*0, '.')    color=points.get_color()    a1.fill_between(xx[:,0], mu[:,i] + 2*np.sqrt(var[:,i]), mu[:,i] - 2*np.sqrt(var[:,i]), alpha = 0.2)    a1.plot(xx, mu[:,i], color=color, lw=2)    a2.plot(xx, p[:,i], '-', color=color, lw=2)av.plot(xx, np.sqrt(np.max(v[:,:], axis = 1)), c = "black", lw=2)for ax in [a1, av, a2, a3]:  ax.set_xlim(xx.min(), xx.max())a2.set_ylim(-0.1, 1.1)a2.set_yticks([0, 1])a2.set_xticks([])plot(m)

回答:

Model.predict_y()调用Likelihood.predict_mean_and_var()。如果你查看后者的函数文档[1],你会发现它所做的只是计算预测分布的均值和方差。即,我们首先计算边缘预测分布q(y) = \int p(y|f) q(f) df,然后我们计算q(y)的均值和方差。

对于高斯分布,均值和方差可以独立指定,并且它们分别作为点预测和不确定性的解释。对于伯努利似然分布,均值和方差完全由单一参数p决定。分布的均值是事件的概率,这已经告诉我们不确定性!方差没有提供更多信息。

然而,你是对的,方差是一个很好的不确定性度量,其中更高意味着更多的不确定性。作为p函数的熵看起来非常相似(尽管两者在边缘附近的行为有所不同):

p = np.linspace(0.001, 1 - 0.001, 1000)[:, None]q = 1 - pplt.plot(p, -p * np.log(p) - q * np.log(q), label='entropy')plt.plot(p, p * q, label='variance')plt.legend()plt.xlabel('probability')

[1] https://github.com/GPflow/GPflow/blob/b8ed8332549a375da8658a1117470ac86d823e7f/gpflow/likelihoods.py#L76

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注