在Tensorflow中高效计算成对排名损失函数

我目前正在Tensorflow中实现http://www.aclweb.org/anthology/P15-1061中的方法。

我按照论文的第2.5节实现了成对排名损失函数,如下所示:

s_theta_y = tf.gather(tf.reshape(s_theta, [-1]), y_true_index)s_theta_c_temp = tf.reshape(tf.gather(tf.reshape(s_theta, [-1]), y_neg_index), [-1, classes_size])s_theta_c = tf.reduce_max(s_theta_c_temp, reduction_indices=[1])

我不得不使用tf.gather而不是tf.gather_nd,因为后者尚未实现梯度下降。我还必须将所有索引转换为与扁平矩阵匹配的正确索引。

如果tf.gather_nd实现了梯度下降,我的代码将如下所示:

s_theta_y = tf.gather_nd(s_theta, y_t_index)s_theta_c_temp = tf.gather_nd(s_theta, y_neg_index)s_theta_c = tf.reduce_max(s_theta_c_temp, reduction_indices=[1])

s_theta是论文中为每个类别标签计算的得分。y_true_index包含真实类别的索引,用于计算s_theta_y。y_neg_index是所有负类别的索引,其维度为#class-1或#class,这取决于关系是否被分类为其他类别。

然而,一些句子被分类为其他类别,因此s_theta_y不存在,我们不应在计算中考虑它。为了处理这种情况,我使用了一个常数因子0来取消该项,并且为了保持负类别的相同维度向量,我只是复制了一个随机的索引值,因为最终我们只关心所有负类别中的最大值(而不是索引)。

有没有更有效的方法来计算损失函数中的这些项?我觉得使用tf.gather并进行多次reshape操作非常慢。


回答:

显然,gather_nd是你想要的,但在那之前没有实现梯度的情况下,我不会犹豫使用你的reshape()解决方案,因为reshape()几乎是免费的。

reshape()操作的C++实现看起来像是在做很多工作,但实际上它只是在对形状信息进行快速错误检查。第90行的CopyFrom似乎可能很昂贵,但实际上它只是一个指针复制(CopyFrom调用了CopyFromInternal,它复制了指针)。

这完全合理:底层缓冲区只是一个以行主序排列的数字平面数组,这种排序不依赖于形状信息。出于同样的原因,像tf.transpose()这样的操作通常确实需要复制数据。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注