使用人类输入增强预训练模型以更好地确定图像相似性

假设我想构建一个工具,能够在0到1的范围内告诉我两张图像之间的相似度。

enter image description here

我可以下载预训练的图像模型,比如vgg16,移除其最后一层(分类器),使用预训练模型将图像转换为特征向量,最后使用余弦相似度等方法计算向量之间的相似度。

enter image description here

该方法在这里有更详细的解释:https://mc.ai/find-more-like-this-product-using-transfer-learning/

… 但如果我想使用自己的数据集来增强这个工具呢?更具体地说,我有一个数据集,其中人类判断两张图片是相似的(1)还是不相似的(0):

enter image description here

我想知道,您是否有办法将此整合进来以增强相似度评分工具?非常感谢。:)


回答:

您可以使用基于VGG-16的孪生网络。该模型将使用您标记的图像进行训练,相似的图像将以1的相似度进行训练,不同的图像将以0的相似度进行训练。

可以使用三元组损失来最小化相同对象之间的差异性。

这里有一个在keras中的实现:https://medium.com/@prabhnoor0212/siamese-network-keras-31a3a8f37d04

使用您的数据集进行训练的过程:

假设您有一个简单的分类数据集。因此,您有一个苹果类,一个猫类,等等。

  1. 首先,您需要确保同一类别中的图像非常相似。

(如果您将汽车和卡车放在同一类别中,并期望它们有1的相似度分数,这将使您的网络表现不佳)

  1. 一旦您有了数据集,您只需生成图像对,如果两张图像来自同一类别,标签将为1。如果不是,标签将为0。

  2. 现在,只需使用三元组损失和VGG孪生网络训练模型即可。

您可以修改代码示例,甚至可以找到许多共享代码的资源。

enter image description here

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注