Tensorflow tf.squared_difference 输出意外形状

我是新手,如果这个问题看起来很傻,请原谅我,但我实在是在网上找不到相关信息。我在使用 tf.squared_difference 时,输出的形状与预期不符。我期望得到一个形状为 shape=(100, ?) 的张量作为损失,如下代码片段所示

[print("Logits",logits,"#Labels",labels,"LOSS",tf.squared_difference(labels,logits)) for logits, labels in zip(logits_series,labels_series)] 

然而,它产生了形状为 (100,100) 的损失

Logits Tensor(“add_185:0”, shape=(100, 1), dtype=float32) #Labels Tensor(“unstack_29:0”, shape=(100,), dtype=float32) LOSS Tensor(“SquaredDifference_94:0”, shape=(100, 100), dtype=float32) Logits Tensor(“add_186:0”, shape=(100, 1), dtype=float32) #Labels Tensor(“unstack_29:1”, shape=(100,), dtype=float32) LOSS Tensor(“SquaredDifference_95:0”, shape=(100, 100), dtype=float32)

我测试了另一个例子,代码如下,输出的形状符合预期。

myTESTX = tf.placeholder(tf.float32, [100, None])myTESTY = tf.placeholder(tf.float32, [100, 1])print("Test diff X-Y",tf.squared_difference(myTESTX,myTESTY) )print("Test diff Y-X",tf.squared_difference(myTESTY,myTESTX) )

Test diff X-Y Tensor(“SquaredDifference_92:0”, shape=(100, ?), dtype=float32) Test diff Y-X Tensor(“SquaredDifference_93:0”, shape=(100, ?), dtype=float32)

我无法理解为什么这两个代码片段会产生不同的输出形状


回答:

你的第一个例子(使用 logitslabels)与第二个例子(使用 myTESTXmyTESTY)之间存在细微的差异。logits 的形状与 myTESTY 相同,为 (100, 1)。然而,labels 的形状为 (100,)(这不是动态形状),而 myTESTX 的形状为 (100, ?)

在第一个例子中(logitslabels),输入形状分别为 (100,)(100,1),TensorFlow 使用广播。由于两个输入形状都不是动态的,因此你的输出形状是静态的:(100, 100),这是由于广播造成的。

在第二个例子中(myTESTXmyTESTY),输入形状分别为 (100, ?)(100, 1)。第一个输入形状是动态的,因此你的输出形状也是动态的:(100, ?)

作为一个更简单的、说明性的例子,在使用相同广播机制的 numpy 中,可以考虑以下代码:

import numpy as npx = np.arange(10)                # Shape: (10,)y = np.arange(10).reshape(10,1)  # Shape: (10, 1)difference = x-y                 # Shape: (10, 10)

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注