在TensorFlow.js中加载保存的自定义模型后预测错误

在编译和训练我的自定义模型后,我保存了它并得到了两个文件,例如.bin和.json文件。之后,我在另一个页面加载了该自定义模型,在该页面上,我输入了用于训练该模型的图像,并根据加载的自定义模型对这些图像进行预测。

虽然对一些图像的预测效果很好,但对其他图像的预测结果却错误。

这是我的代码:

        $("#predict-button").click(async function(){        let image= $('#selected-image').get(0);        let image1 = $('#selected-image1').get(0);        console.log('image:::',image);        console.log('image1:::',image1);        let tensorarr = [];        let tensor1 = preprocessImage(image,$("#model-selector").val());        tensorarr.push(tensor1);        let tensor2 = preprocessImage(image1,$("#model-selector").val());        tensorarr.push(tensor2);        let resize_image = [];        let resize;        for(var i=0; i<tensorarr.length; i++)        {            resize = tf.reshape(tensorarr[i], [1, 224, 224, 3],'resize');            console.log('resize:::',resize);            resize_image.push(resize);        }        // Labels        const label = ['Shelf','Rack'];        const setLabel = Array.from(new Set(label));        let ysarr =[];        const ys = tf.oneHot(tf.tensor1d(label.map((a) => setLabel.findIndex(e => e === a)), 'int32'), 10)        console.log('ys:::'+ys);        const y = tf.reshape(ys, [-1]);        y.print();        const d = y.slice([0], [10]);        d.print();        ysarr.push(d);        const e = y.slice([10], [10]);        e.print();        ysarr.push(e);        console.log('ysarr',ysarr);        model.add(tf.layers.conv2d({            inputShape: [224, 224 , 3],            kernelSize: 5,            filters: 8,            strides: 1,            activation: 'relu',            kernelInitializer: 'VarianceScaling'        }));        model.add(tf.layers.maxPooling2d({poolSize: 2, strides: 2}));        model.add(tf.layers.maxPooling2d({poolSize: 2, strides: 2}));        model.add(tf.layers.flatten({}));        model.add(tf.layers.dense({units: 64, activation: 'relu'}));        model.add(tf.layers.dense({units: 10, activation: 'softmax'}));        model.compile({            loss: 'meanSquaredError',            optimizer : 'sgd'        })        console.log('model:::'+model);        // Train the model using the data.        let tesnor_dim =[];        let tensr;        for(var j=0; j<2; j++){            console.log('resize_image',resize_image);            tensr = tf.expandDims(ysarr[j], 0);            tesnor_dim.push(tensr);            console.log('tesnor_dim',tesnor_dim);            console.log('before resize_image[j]',resize_image[j]);            console.log('before tesnor_dim[j]',tesnor_dim[j]);            await model.fit(resize_image[j], tesnor_dim[j], {epochs: 100}).then((loss) => {                console.log('resize_image.get[j]',resize_image[j]);                console.log('tesnor_dim[j]',tesnor_dim[j]);                console.log('loss',loss);                const t = model.predict(resize_image[j]);                console.log('Prediction:::'+t);                pred = t.argMax(1).dataSync(); // get the class of highest probability                const labelsPred = Array.from(pred).map(e => setLabel[e]);                console.log('labelsPred:::'+labelsPred);            }).catch((e) => {                console.log(e.message);            })            }                     const saveResults = model.save('downloads://my-model-1');                console.log(saveResults);               });

回答:

模型给出了错误的预测。该怎么办?

  • 检查模型的准确性。模型的准确性非常低,表明该模型要么不适合解决这个问题,要么需要更改一些参数。

  • 即使准确性很好,模型在预测特定类别时也可能出错。在这种情况下,混淆矩阵将有助于识别被错误预测的类别。一旦这些类别被识别出来,就可以使用更多这些类别的训练数据来提高训练后的准确性。


从问题的模型来看,很明显这是一个分类模型,即给定一张图像,模型将预测该图像所属的类别。

'meanSquaredError' 损失函数并不是分类问题的理想选择。categoricalCrossEntropy 能够获得最佳准确性。即使更改了损失函数,准确性可能仍然达不到预期。此时需要添加更多层,改变模型的其他参数。然后进行训练并比较准确性,这个过程会持续进行…

Related Posts

为什么我们在K-means聚类方法中使用kmeans.fit函数?

我在一个视频中使用K-means聚类技术,但我不明白为…

如何获取Keras中ImageDataGenerator的.flow_from_directory函数扫描的类名?

我想制作一个用户友好的GUI图像分类器,用户只需指向数…

如何查看每个词的tf-idf得分

我试图了解文档中每个词的tf-idf得分。然而,它只返…

如何修复 ‘ValueError: Found input variables with inconsistent numbers of samples: [32979, 21602]’?

我在制作一个用于情感分析的逻辑回归模型时遇到了这个问题…

如何向神经网络输入两个不同大小的输入?

我想向神经网络输入两个数据集。第一个数据集(元素)具有…

逻辑回归与机器学习有何关联

我们正在开会讨论聘请一位我们信任的顾问来做机器学习。一…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注