TensorFlow TypeError: 传递给参数 input 的值的数据类型为 uint8,不在允许的值列表中:float16, float32

我在过去三天里一直尝试训练一个简单的卷积神经网络(CNN)。

首先,我设置了一个输入管道/队列配置,用于从目录树中读取图像并准备批次。

我从这个链接获取了这段代码。因此,我现在有了train_image_batchtrain_label_batch,我需要将它们输入到我的CNN中。

train_image_batch, train_label_batch = tf.train.batch(        [train_image, train_label],        batch_size=BATCH_SIZE        # ,num_threads=1    )

但我不知道该如何操作。我使用了这个链接中提供的CNN代码。

# 输入层input_layer = tf.reshape(train_image_batch, [-1, IMAGE_HEIGHT, IMAGE_WIDTH, NUM_CHANNELS])# 卷积层 #1conv1 = new_conv_layer(input_layer, NUM_CHANNELS, 5, 32, 2) # 池化层 #1pool1 = new_pooling_layer(conv1, 2, 2)

打印input_layer时显示如下

Tensor(“Reshape:0”, shape=(5, 120, 120, 3), dtype=uint8)

下一行代码在执行时崩溃,报出TypeError错误;conv1 = new_conv_layer(…)。new_conv_layer函数的具体内容如下

def new_conv_layer(input,              # 前一层。               num_input_channels, # 前一层中的通道数。               filter_size,        # 每个滤波器的宽度和高度。               num_filters,        # 滤波器的数量。               stride):# 卷积的滤波器权重的形状。# 此格式由TensorFlow API决定。shape = [filter_size, filter_size, num_input_channels, num_filters]# 使用给定形状创建新的权重,即滤波器。weights = tf.Variable(tf.truncated_normal(shape, stddev=0.05))# 为每个滤波器创建新的偏置。biases = tf.Variable(tf.constant(0.05, shape=[num_filters]))# 创建卷积的TensorFlow操作。# 注意所有维度的步长都设置为1。# 第一个和最后一个步长必须始终为1,# 因为第一个是针对图像编号的,# 最后一个是针对输入通道的。# 例如,strides=[1, 2, 2, 1]表示滤波器# 在图像的x轴和y轴上移动2个像素。# padding设置为'SAME',这意味着输入图像# 用零填充,因此输出的大小与输入相同。layer = tf.nn.conv2d(input=input,                     filter=weights,                     strides=[1, stride, stride, 1],                     padding='SAME')# 将偏置值添加到卷积结果中。# 每个滤波器通道都添加一个偏置值。layer += biases# 整流线性单元(ReLU)。# 它对每个输入像素x计算max(x, 0)。# 这为公式增加了一些非线性,允许我们# 学习更复杂的函数。layer = tf.nn.relu(layer)# 注意ReLU通常在池化之前执行,# 但由于relu(max_pool(x)) == max_pool(relu(x)),我们可以# 通过先进行最大池化来节省75%的ReLU操作。# 我们返回结果层和滤波器权重,# 因为我们稍后会绘制权重。return layer, weights

具体来说,它在tf.nn.conv2d处崩溃,报出以下错误

TypeError: 传递给参数 ‘input’ 的值的数据类型为 uint8,不在允许的值列表中:float16, float32


回答:

你的输入管道中的图像类型为 ‘uint8’,你需要将其类型转换为 ‘float32’,你可以在图像jpeg解码器之后执行此操作:

image = tf.image.decode_jpeg(...image = tf.cast(image, tf.float32)

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注