TensorFlow.js中用于颜色预测的最佳模型类型?

我在创建一个颜色预测器时遇到了一个问题。我成功地让模型工作了,但预测总是落在2.5到5.5的中位数范围内。模型应该输出0到8的数值,对应每种颜色,并且我为每种颜色准备了等量的训练数据点。是否有更好的模型可以使用,使其能预测0或7?我认为它不会,因为它认为这些是某种异常值。

这是我的模型

const model = tf.sequential();const hidden = tf.layers.dense({  units: 3,  inputShape: [3] //每个输入有3个值r, g, 和 b});const output = tf.layers.dense({  units: 1 //只有一个输出(对应rgb值的颜色)    });model.add(hidden);model.add(output);model.compile({  activation: 'sigmoid',  loss: "meanSquaredError",  optimizer: tf.train.sgd(0.005)});

这个模型适合我的问题吗?


回答:

这个模型缺乏非线性,因为没有激活函数。给定一个rgb输入,模型应该预测8个可能值中最可能的颜色。这是一个分类问题。问题中定义的模型在进行回归,即试图根据输入预测一个数值。

对于分类问题,最后一层应该预测概率。在这种情况下,最后一层通常使用softmax激活函数。损失函数应为categoricalCrossentropybinaryCrossEntropy(如果只有两种颜色需要预测)。

考虑以下预测三种颜色类别的模型:红色,绿色和蓝色

const model = tf.sequential();model.add(tf.layers.dense({units: 10, inputShape: [3], activation: 'sigmoid' }));model.add(tf.layers.dense({units: 10, activation: 'sigmoid' }));model.add(tf.layers.dense({units: 3, activation: 'softmax' }));model.compile({ loss: 'categoricalCrossentropy', optimizer: 'adam' });const xs = tf.tensor([  [255, 23, 34],  [255, 23, 43],  [12, 255, 56],  [13, 255, 56],  [12, 23, 255],  [12, 56, 255]]);// Labelsconst label = ['red', 'red', 'green', 'green', 'blue', 'blue']const setLabel = Array.from(new Set(label))const ys = tf.oneHot(tf.tensor1d(label.map((a) => setLabel.findIndex(e => e === a)), 'int32'), 3)// Train the model using the data.  model.fit(xs, ys, {epochs: 100}).then((loss) => {  const t = model.predict(xs);  pred = t.argMax(1).dataSync(); // get the class of highest probability  labelsPred = Array.from(pred).map(e => setLabel[e])  console.log(labelsPred)}).catch((e) => {  console.log(e.message);})
<html>  <head>    <!-- Load TensorFlow.js -->    <script src="https://cdn.jsdelivr.net/npm/@tensorflow/[email protected]/dist/tf.min.js"> </script>  </head>  <body>  </body></html>

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注