ValueError: 无法为形状为(4,)的张量’Placeholder_36:0’提供值,该张量形状为'(?, 4)’

我在尝试实现一个TensorFlow回归模型,我的训练数据形状为train_X=(200,4)和train_Y=(200,)。我遇到了形状错误,这是我的一部分代码,请问有人能指出我哪里做错了么?

df=pd.read_csv(‘all.csv’)

df=df.drop(‘Time’,axis=1)

print(df.describe()) #了解数据集

train_Y=df[“power”]

train_X=df.drop(‘power’,axis=1)

train_X=numpy.asarray(train_X)

train_Y=numpy.asarray(train_Y)

n_samples = train_X.shape[0]

tf图输入

X = tf.placeholder(‘float’,[None,len(train_X[0])])

Y = tf.placeholder(“float”)

设置模型权重

W = tf.Variable(rng.randn(), name=”weight”)

b = tf.Variable(rng.randn(), name=”bias”)

构建线性模型

pred = tf.add(tf.multiply(X, W), b)

均方误差

cost = tf.reduce_sum(tf.pow(pred-Y, 2))/(2*n_samples)

梯度下降

注意,minimize()知道要修改W和b,因为Variable对象默认是

trainable=True

optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)

初始化变量(即分配它们的默认值)

init = tf.global_variables_initializer()

开始训练

with tf.Session() as sess:

# 运行初始化
sess.run(init)
# 拟合所有训练数据
for epoch in range(training_epochs):
    for (x, y) in zip(train_X, train_Y):
        sess.run(optimizer, feed_dict={X: x, Y: y})
    # 每轮显示日志
    if (epoch+1) % display_step == 0:
        c = sess.run(cost, feed_dict={X: train_X, Y:train_Y})
        print("Epoch:", '%04d' % (epoch+1), "cost=", "{:.9f}".format(c), \
            "W=", sess.run(W), "b=", sess.run(b))
print("优化完成!")
training_cost = sess.run(cost, feed_dict={X: train_X, Y: train_Y})
print("训练成本=", training_cost, "W=", sess.run(W), "b=", sess.run(b), '\n')
# 图形显示
plt.plot(train_X, train_Y, 'ro', label='原始数据')
plt.plot(train_X, sess.run(W) * train_X + sess.run(b), label='拟合线')
plt.legend()
plt.show()
enter code here

回答:

我修改了形状,问题解决了

train_y = np.reshape(train_y, (-1, 1))

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注