使用训练好的Keras模型对新的CSV数据进行预测

我正在做一个项目,目的是预测房价是否高于或低于其中位价。为此,我使用了来自Kaggle的数据集(https://drive.google.com/file/d/1GfvKA0qznNVknghV4botnNxyH-KvODOC/view)。其中1表示“高于中位价”,0表示“低于中位价”。我编写了以下代码来训练神经网络并将其保存为.h5文件:

import pandas as pdfrom sklearn import preprocessingfrom sklearn.model_selection import train_test_splitfrom keras.models import Sequentialfrom keras.layers import Denseimport h5pydf = pd.read_csv('housepricedata.csv')dataset = df.valuesX = dataset[:,0:10]Y = dataset[:,10]min_max_scaler = preprocessing.MinMaxScaler()X_scale = min_max_scaler.fit_transform(X)X_train, X_val_and_test, Y_train, Y_val_and_test = train_test_split(X_scale, Y, test_size=0.3)X_val, X_test, Y_val, Y_test = train_test_split(X_val_and_test, Y_val_and_test, test_size=0.5)model = Sequential([    Dense(32, activation='relu', input_shape=(10,)),    Dense(32, activation='relu'),    Dense(1, activation='sigmoid'),])model.compile(optimizer='sgd',              loss='binary_crossentropy',              metrics=['accuracy'])hist = model.fit(X_train, Y_train,          batch_size=32, epochs=100,          validation_data=(X_val, Y_val))model.save("house_price.h5")

运行后,.h5文件成功保存到了我的目录中。现在我想使用我训练好的模型对一个新的.csv文件进行预测,判断这些数据是否高于或低于中位价。以下是我在VSCode中要进行预测的csv文件的图像:csv文件图像 如您所见,这个文件不包含1(高于中位价)或0(低于中位价),因为这是我希望它预测的内容。我编写了以下代码来实现这个目的:

它的输出是[[0.00101464]] 我不知道这是什么意思,为什么它只返回一个值,尽管csv文件有4行。有人知道我如何修复这个问题并能够为csv文件中的每一行预测一个1或0吗?谢谢!


回答:

我尽我所能理解你想要的!让我们试试看!这个代码对我来说是有效的

 import tensorflow model = tensorflow.keras.models.load_model("house_price.h5") y_pred=model.predict(X_test)

如果你仍然无法解决问题,请访问以下网站 1:答案1 2:答案2

两种y_pred对我来说产生了相同的输出

这里有一点你需要注意,y_pred不包含0和1,因为你使用了sigmoid函数,它以概率形式确定预测结果,所以如果(y_pred>0.5)表示值为1

  #True表示1  #false表示0  #你可以使用pandas的replace函数或map函数将true转换为1

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注