值错误:无法为形状为(5, 15)的值提供给张量’one_hot:0’,其形状为'(5, 15, 2)’

这是代码

num_epochs = 100total_series_length = 50000truncated_backprop_length = 15state_size = 4num_classes = 2echo_step = 3batch_size = 5num_batches = total_series_length//batch_size//truncated_backprop_length

生成数据{…}

batchX_placeholder = tf.placeholder(tf.int32, [batch_size, truncated_backprop_length])batchY_placeholder = tf.placeholder(tf.int32, [batch_size, truncated_backprop_length])#以及一个用于RNN状态的,5,4 init_state = tf.placeholder(tf.float32, [batch_size, state_size])batchX_placeholder = tf.one_hot(batchX_placeholder, num_classes)inputs_series = tf.unstack(batchX_placeholder, axis=1)cell = tf.contrib.rnn.BasicRNNCell(state_size)rnn_outputs, final_state = tf.contrib.rnn.static_rnn(cell, inputs_series, initial_state=init_state)

一些优化代码{….}然后创建图形

#第3步 训练网络with tf.Session() as sess:    #我们每次都必须这样做,它应该只是知道    #我们已经初始化了这些变量。v2版的家伙们,v2..    sess.run(tf.initialize_all_variables())    #交互模式    plt.ion()    #初始化图形    plt.figure()    #显示图形    plt.show()    #显示损失减少    loss_list = []    for epoch_idx in range(num_epochs):        #在每个epoch生成数据,批次在epoch中运行        x,y = generateData()        #初始化一个空的隐藏状态        _current_state = np.zeros((batch_size, state_size))        print("新数据,epoch", epoch_idx)        #每个批次        for batch_idx in range(num_batches):            #每个批次的开始和结束点            #因为权重在每个时间层重复            #这些层不会被展开到时间的开始,            #那将是太高的计算成本,因此在有限的时间步内被截断            start_idx = batch_idx * truncated_backprop_length            end_idx = start_idx + truncated_backprop_length            batchX = x[:,start_idx:end_idx]            batchY = y[:,start_idx:end_idx]            #运行计算图,给它我们之前计算的值            _total_loss, _train_step, _final_state, _predictions_series = sess.run(                [total_loss, train_step, final_state, predictions],                feed_dict={                    batchX_placeholder:batchX,                    batchY_placeholder:batchY,                    init_state:_current_state                })            loss_list.append(_total_loss)            if batch_idx%100 == 0:                print("步骤",batch_idx, "损失", _total_loss)                plot(loss_list, _predictions_series, batchX, batchY)plt.ioff()plt.show()

这是错误信息:

ValueError                                Traceback (most recent call last)<ipython-input-9-7c3d1289d16b> in <module>()     40                     batchX_placeholder:batchX,     41                     batchY_placeholder:batchY,---> 42                     init_state:_current_state     43                 })     44 /home/pranshu_44/anaconda3/lib/python3.5/site-packages/tensorflow/python/client/session.py in run(self, fetches, feed_dict, options, run_metadata)    765     try:    766       result = self._run(None, fetches, feed_dict, options_ptr,--> 767                          run_metadata_ptr)    768       if run_metadata:    769         proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)/home/pranshu_44/anaconda3/lib/python3.5/site-packages/tensorflow/python/client/session.py in _run(self, handle, fetches, feed_dict, options, run_metadata)    942                 'Cannot feed value of shape %r for Tensor %r, '    943                 'which has shape %r'--> 944                 % (np_val.shape, subfeed_t.name, str(subfeed_t.get_shape())))    945           if not self.graph.is_feedable(subfeed_t):    946             raise ValueError('Tensor %s may not be fed.' % subfeed_t)ValueError: Cannot feed value of shape (5, 15) for Tensor 'one_hot:0', which has shape '(5, 15, 2)'

我查看了文档,但那一点也不helpful,如果有其他简单的方法也会很有帮助


回答:

您将占位符变量转换为one-hot表示,但没有将您在训练期间实际输入网络的数据转换。尝试在输入之前将batchX转换为one-hot表示。这段代码将矩阵转换为其one-hot表示:

# 假设y包含从0到N-1的值,其中N是类别数batchX = (np.arange(max(batchX.flatten())+1) == batchX[:,:,None]).astype(int)

Related Posts

L1-L2正则化的不同系数

我想对网络的权重同时应用L1和L2正则化。然而,我找不…

使用scikit-learn的无监督方法将列表分类成不同组别,有没有办法?

我有一系列实例,每个实例都有一份列表,代表它所遵循的不…

f1_score metric in lightgbm

我想使用自定义指标f1_score来训练一个lgb模型…

通过相关系数矩阵进行特征选择

我在测试不同的算法时,如逻辑回归、高斯朴素贝叶斯、随机…

可以将机器学习库用于流式输入和输出吗?

已关闭。此问题需要更加聚焦。目前不接受回答。 想要改进…

在TensorFlow中,queue.dequeue_up_to()方法的用途是什么?

我对这个方法感到非常困惑,特别是当我发现这个令人费解的…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注