在Theano中如何获取共享变量的形状信息?

似乎使用variable.shape会提示我:

AttributeError: 'SharedVariable' object has no attribute 'shape'

theano.tensor.shape(variable)则会返回shape.0

我非常困惑为什么无法获取到形状信息?当我想获取符号变量的形状信息时,也遇到了同样的问题,实在是太奇怪了。

x = T.matrix('x')   # 数据以光栅化图像的形式呈现
y = T.ivector('y')  # 标签以一维整数向量的形式呈现
layer0_input = x.reshape((batch_size, 1, 28, 28))

在上面的例子中,x(符号变量)已经被重塑为某种形状,如果我无法检索其形状信息而仍然可以为其分配新形状,这对我来说是没有意义的。


回答:

第一个错误可能是由于你试图在数据类型SharedVariable上而不是在实际的共享变量上评估shape属性造成的。

否则,获得shape.0是完全正常的:这是一个表示形状的符号表达式,其在先验上是未知的。一旦你用数据进行评估,你就会看到形状:

import theano
import theano.tensor as T
import numpy as np
s = theano.shared(np.arange(2 * 3 * 5).reshape(2, 3, 5))
print(s.shape)  # 给你shape.0
print(s.shape.eval())  # 给你包含2, 3, 5的数组
a = T.tensor3()
print(a.shape)  # 给你shape.0
print(a.shape.eval({a: np.arange(2 * 3 * 5).reshape(2, 3, 5).astype(theano.config.floatX)}))  # 给你2, 3, 5

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注