为什么这个tf.placeholder必须是float类型?

为什么x需要是float类型?为什么不能是int类型呢?我传递的是一个int类型的列表啊?

代码:

x = tf.placeholder(tf.float32, shape=[None, 1])  # 为什么必须是float类型?
y = tf.placeholder(tf.int32, shape=[None, 2])
with tf.name_scope("network"):
    layer1 = tf.layers.dense(x, 100, activation=tf.nn.relu, name="hidden_layer")
    output = tf.layers.dense(layer1, 2, name="output_layer")
with tf.name_scope("loss"):
    xentropy = tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=output)
    loss = tf.reduce_mean(xentropy, name="loss")
with tf.name_scope("train"):
    optimizer = tf.train.AdamOptimizer()
    training_op = optimizer.minimize(loss)
with tf.name_scope("eval"):
    with tf.Session() as sess:
        for i in range(1, 50):
            sess.run(tf.global_variables_initializer())
            saver = tf.train.Saver()
            sess.run(training_op, feed_dict={x: np.array(train_data).reshape([-1, 1]), y: label})
            if i % 10 == 0:
                saver.save(sess, "saved_models/testing")
                print "Saved"

当我将其改为tf.int32时,会出现以下错误:

TypeError: Value passed to parameter 'features' has DataType int32 not in list of allowed values: float16, float32, float64

如果需要,我可以提供更多代码。


回答:

这是因为tf.nn.softmax_cross_entropy_with_logits的要求:

logits和labels必须具有相同的形状[batch_size, num_classes]以及相同的dtype(可以是float16, float32, 或float64)。

我认为你可以用整数输入计算损失。然而,大多数情况下,这个损失是通过梯度下降来最小化的——就像你所做的那样——这意味着输入需要编码实数以获得任意更新。

问题在于tf.layers.dense不会改变你的输入类型。所以如果它的输入是整数,它将产生一个整数输出。(至少如果激活函数与整数兼容,比如relu——sigmoid会引发错误)。

你可能想要做的是提供整数输入,然后tf.float32中进行所有计算。要做到这一点,首先将你的输入转换为float类型,然后再提供给dense

layer1 = tf.layers.dense(tf.to_float(x), 100, activation=tf.nn.relu, name="hidden_layer")

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注