我正在尝试使用 tensorflowjs 在 JavaScript 中导入一个 CNN 分类器并进行一些预测。有没有类似于 Keras 的 predict_classes 方法,可以直接给我一个数字来表示模型预测的类别?我可以使用普通的 model.predict 方法,但是返回的是一个张量,我很难遍历它来找到最高的值。
回答:
没有像 predict_classes
这样的方法。但是你可以按照以下方式操作。
const predictClasses = model.predict(input);// 假设你的 predictClasses 看起来像这样 [1,2,3]const yourClass = predictClasses.argMax(-1).dataSync()[0]
API 文档请参考。