我有两个类别:正类(1)和负类(0)。
数据集非常不平衡,因此目前我的小批次主要包含0。实际上,许多批次只会包含0。我想尝试对正例和负例设置不同的成本;请看下面的代码。
我的代码问题在于,由于bound_index列表为空,我得到了很多nan
。有什么优雅的方法来解决这个问题吗?
def calc_loss_debug(logits, labels): logits = tf.reshape(logits, [-1]) labels = tf.reshape(labels, [-1]) index_bound = tf.where(tf.equal(labels, tf.constant(1, dtype=tf.float32))) index_unbound = tf.where(tf.equal(labels, tf.constant(0, dtype=tf.float32))) entropies = tf.nn.sigmoid_cross_entropy_with_logits(logits, labels) entropies_bound = tf.gather(entropies, index_bound) entropies_unbound = tf.gather(entropies, index_unbound) loss_bound = tf.reduce_mean(entropies_bound) loss_unbound = tf.reduce_mean(entropies_unbound)
回答:
由于你有0和1的标签,你可以轻松避免使用tf.where
,使用如下结构:
labels = ...entropies = ...labels_complement = tf.constant(1.0, dtype=tf.float32) - labelsentropy_ones = tf.reduce_sum(tf.mul(labels, entropies))entropy_zeros = tf.reduce_sum(tf.mul(labels_complement, entropies))
要获得平均损失,你需要除以批次中0和1的数量,这可以很容易地计算为:
num_ones = tf.reduce_sum(labels)num_zeros = tf.reduce_sum(labels_complement)
当然,当批次中没有1时,你仍然需要防止除以0的情况。我建议使用tf.cond(tf.equal(num_ones, 0), ...)
来处理。