Tensorflow神经网络用于二元分类;如何使用占位符

这是我的代码:

我的目标是一个形状为(N,)的向量,仅包含二进制数字

然而,我遇到了编译错误

/Library/Frameworks/Python.framework/Versions/3.6/bin/python3.6 /Users/Lai/Dropbox/PersonalProject/MachineLearningForSports/models/NeuralNetwork.py/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/site-packages/sklearn/cross_validation.py:44: DeprecationWarning: This module was deprecated in version 0.18 in favor of the model_selection module into which all the refactored classes and functions are moved. Also note that the interface of the new CV iterators are different from that of this module. This module will be removed in 0.20.  "This module will be removed in 0.20.", DeprecationWarning)Traceback (most recent call last):  File "/Users/Lai/Dropbox/PersonalProject/MachineLearningForSports/models/NeuralNetwork.py", line 102, in <module>    _, c = sess.run([optimizer,cost],feed_dict = {x:batch_x,y:batch_y})  File "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 766, in run    run_metadata_ptr)  File "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 943, in _run    % (np_val.shape, subfeed_t.name, str(subfeed_t.get_shape())))ValueError: Cannot feed value of shape (100,) for Tensor 'Placeholder_1:0', which has shape '(?, 2)'

由于我的批量大小是100,我认为错误发生在比较我的目标与预测时。tf.placable似乎以N*2的形式进行预测,虽然我并不确定。有什么帮助吗?谢谢


回答:

每当你执行计算图中的动态节点时——这几乎是任何非输入节点——你都需要指定所有依赖变量。可以这样想:如果你有一个数学函数的形式y = f(x) = Ax + b(例如),并且你想评估这个函数,你需要指定x。然而,如果你想评估(即读取)A的值,你不需要指定x,因为A是已知的(至少在这个上下文中)。

因此,你可以评估(通过传递给tf.Session.run(...))你的网络参数,而无需指定输入(如上例中的A)。然而,你不能在不指定输入的情况下评估函数的输出(在例子中,你需要指定x)。

至于你的代码,以下这行因此将不会工作:print(sess.run(pred)),因为你要求会话评估一个函数而未指定其输入。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注