如何在tensorflow.js上加载/重新训练/保存ssd_inception_v2_coco模型?

机器学习/张量流初学者。

这些已经训练好的模型是否可以加载到tfjs上并在那里重新训练,然后导出到下载文件夹中,还是必须使用Tensorflow Python才能实现?

我看到这个过程在Tensorflow Python的教程中描述得很详细且文档齐全,但遗憾的是,我找不到任何关于在浏览器中使用tfjs重新训练对象检测模型的文档/教程(图像分类有,对象检测没有)。

我知道如何使用npm加载coco-ssd模型,然后可能还能触发将其保存到下载文件夹,但以下这些怎么办:

  • 配置文件(需要修改,因为我只想要一个类,而不是90个)
  • 带注释的图像(包括.jpg、.xml和.csv)
  • labels.pbtxt
  • .record文件

是否有任何方法可以重新训练像ssd_inception_v2_coco这样的ssd模型,或者我只是没有找到正确的谷歌关键词,或者在当前框架状态下这是不可能的?


回答:

您可以通过使用coco-ssd模型作为特征提取器来进行迁移学习。一个迁移学习的例子可以在这里看到。

这里有一个模型,使用特征提取器作为新顺序模型的输入来提取特征。

const loadModel = async () => {  const loadedModel = await tf.loadModel(MODEL_URL)  console.log(loadedModel)  // take whatever layer except last output  loadedModel.layers.forEach(layer => console.log(layer.name))  const layer = loadedModel.getLayer(LAYER_NAME)  return tf.model({ inputs: loadedModel.inputs, outputs: layer.output });}loadModel().then(featureExtractor => {  model = tf.sequential({    layers: [      // Flattens the input to a vector so we can use it in a dense layer. While      // technically a layer, this only performs a reshape (and has no training      // parameters).      // slice so as not to take the batch size      tf.layers.flatten(        { inputShape: featureExtractor.outputs[0].shape.slice(1) }),      // add all the layers of the model to train      tf.layers.dense({        units: UNITS,        activation: 'relu',        kernelInitializer: 'varianceScaling',        useBias: true      }),      // Last Layer. The number of units of the last layer should correspond      // to the number of classes to predict.      tf.layers.dense({        units: NUM_CLASSES,        kernelInitializer: 'varianceScaling',        useBias: false,        activation: 'softmax'      })    ]  });})

为了从coco-ssd的90个类别中检测单个对象,可以简单地对coco-ssd的预测结果进行条件测试。

const image = document.getElementById(id)cocoSsd.load()  .then(model => model.detect(image))  .then(prediction => {if (prediction.class === OBJECT_DETECTED) {  // display it the bbox to the user}})

如果类别在coco-ssd中不存在,则需要构建一个检测器。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注