总是提示期望dense_Dense1_input的形状为[null,8],但得到的数组形状为[8,1]

我在尝试使用CSV文件中的数据来训练我的神经网络,使用的是tensorflow.js,但始终没有结果。总是出现相同的错误信息(检查时出错:期望dense_Dense1_input的形状为[null,8],但得到的数组形状为[8,1]。)。我知道有类似的问题被问过,但没有找到任何关于数据存储在CSV文件中的解答。

这是我的代码:

const dataLine = tf.tensor([0.352941,0.482412,0,0,0,0.353204,0.047822,0.116667]);columnConfigs = {outcome: {isLabel: true}};const dataset = tf.data.csv('data.csv', {columnConfigs}).map(({xs, ys}) => {return {xs:Object.values(xs), ys:Object.values(ys)}});const model = tf.sequential();model.add(tf.layers.dense({units: 12, inputShape: [8]}));model.add(tf.layers.dense({units: 1, inputShape: [12]}));model.compile({loss: 'meanSquaredError', optimizer: 'sgd'});model.fitDataset(dataset, {    epochs: 100,  });const prediction = model.predict(dataLine);prediction.print();

我还附上了一小部分我正在使用的数据:

pregnancies,glucose,blood_pressure,skin_thickness,insulin,BMI,diabetes_pedigree_function,age,outcome0.058824,0.507538,0.409836,0.151515,0.042553,0.360656,0.191289,0.083333,00.294118,0.442211,0.540984,0.212121,0.027187,0.363636,0.112724,0.15,00.470588,0.884422,0.737705,0.343434,0.35461,0.502235,0.166097,0.616667,10.411765,0.753769,0.540984,0.424242,0.404255,0.517139,0.273271,0.35,00.058824,0.366834,0.409836,0.10101,0,0.342772,0.072588,0,00.411765,0.939698,0.557377,0.393939,0.359338,0.561848,0.075149,0.333333,10,0.502513,0.721311,0.606061,0.130024,0.697466,0.377455,0.166667,00,0.733668,0.672131,0,0,0.603577,0.727156,0.383333,00,0.527638,0.52459,0.414141,0.167849,0.61848,0.040564,0.016667,00.117647,0.422111,0,0,0,0,0.096499,0,0

对此问题提供的任何帮助都将不胜感激,谢谢


回答:

尝试对数据集进行批处理:

const dataset = tf.data.csv('data.csv', {columnConfigs})    .map(({xs, ys}) => {return {xs:Object.values(xs), ys:Object.values(ys)}})    .batch(100)

然后在预测时扩展维度:

const dataLine = tf.tensor([0.352941,0.482412,0,0,0,0.353204,0.047822,0.116667])    .expandDims();...const prediction = model.predict(dataLine);

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注