### ValueError: Tensor转换请求的数据类型为int32,但Tensor的数据类型为float32 – LSTM实现(tensorflow 2.0.0)

我在尝试测试不同类型的LSTM实现时,在预测代码中遇到了这个错误。

Tensorflow版本 – ‘2.0.0’


我没有删除这个问题,因为我仍然需要知道哪里出了问题。我是否总是需要担心在输入模型时使用float32作为数据类型?


示例代码

X = list()Y = list()X = [x+1 for x in range(20)]Y = [y * 15 for y in X]X = np.array(X,dtype=int)Y=  np.array(Y,dtype=int)X=array(X).reshape(20, 1, 1)model = Sequential()model.add(LSTM(50, activation='relu', input_shape=(1, 1)))model.add(Dense(1))model.compile(optimizer='adam', loss='mse')print(model.summary())model.fit(X, Y, epochs=2, validation_split=0.2, batch_size=5)test_input = np.array(30,dtype=int)test_input = test_input.reshape((1, 1, 1))test_output = model.predict(test_input)   <---- 这一行有错误

错误:

ValueError Traceback (most recent call last)~\AppData\Local\Continuum\anaconda3\envs\PythonCPU\lib\site-packages\tensorflow_core\python\framework\op_def_library.py in _apply_op_helper(self, op_type_name, name, **keywords)    526                 as_ref=input_arg.is_ref,--> 527                 preferred_dtype=default_dtype)    528           except TypeError as err:~\AppData\Local\Continuum\anaconda3\envs\PythonCPU\lib\site-packages\tensorflow_core\python\framework\ops.py in internal_convert_to_tensor(value, dtype, name, as_ref, preferred_dtype, ctx, accept_composite_tensors)   1270           "Tensor conversion requested dtype %s for Tensor with dtype %s: %r" %-> 1271           (dtype.name, value.dtype.name, value))   1272     return valueValueError: Tensor conversion requested dtype int32 for Tensor with dtype float32: <tf.Tensor 'sequential/lstm/MatMul/ReadVariableOp:0' shape=(1, 200) dtype=float32>During handling of the above exception, another exception occurred:...............TypeError: Input 'b' of 'MatMul' Op has type float32 that does not match type int32 of argument 'a'.

回答:

我尝试了不同数据类型的排列组合,结果发现只需要将所有数组的数据类型全部改为’float32’就能解决这个问题。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注