我目前正在Tensorflow中实现http://www.aclweb.org/anthology/P15-1061中的方法。
我按照论文的第2.5节实现了成对排名损失函数,如下所示:
s_theta_y = tf.gather(tf.reshape(s_theta, [-1]), y_true_index)s_theta_c_temp = tf.reshape(tf.gather(tf.reshape(s_theta, [-1]), y_neg_index), [-1, classes_size])s_theta_c = tf.reduce_max(s_theta_c_temp, reduction_indices=[1])
我不得不使用tf.gather而不是tf.gather_nd,因为后者尚未实现梯度下降。我还必须将所有索引转换为与扁平矩阵匹配的正确索引。
如果tf.gather_nd实现了梯度下降,我的代码将如下所示:
s_theta_y = tf.gather_nd(s_theta, y_t_index)s_theta_c_temp = tf.gather_nd(s_theta, y_neg_index)s_theta_c = tf.reduce_max(s_theta_c_temp, reduction_indices=[1])
s_theta是论文中为每个类别标签计算的得分。y_true_index包含真实类别的索引,用于计算s_theta_y。y_neg_index是所有负类别的索引,其维度为#class-1或#class,这取决于关系是否被分类为其他类别。
然而,一些句子被分类为其他类别,因此s_theta_y不存在,我们不应在计算中考虑它。为了处理这种情况,我使用了一个常数因子0来取消该项,并且为了保持负类别的相同维度向量,我只是复制了一个随机的索引值,因为最终我们只关心所有负类别中的最大值(而不是索引)。
有没有更有效的方法来计算损失函数中的这些项?我觉得使用tf.gather并进行多次reshape操作非常慢。
回答:
显然,gather_nd是你想要的,但在那之前没有实现梯度的情况下,我不会犹豫使用你的reshape()解决方案,因为reshape()几乎是免费的。
reshape()操作的C++实现看起来像是在做很多工作,但实际上它只是在对形状信息进行快速错误检查。第90行的CopyFrom似乎可能很昂贵,但实际上它只是一个指针复制(CopyFrom调用了CopyFromInternal,它复制了指针)。
这完全合理:底层缓冲区只是一个以行主序排列的数字平面数组,这种排序不依赖于形状信息。出于同样的原因,像tf.transpose()这样的操作通常确实需要复制数据。