在Tensorflow的DNN分类器估计器函数中,如何使用加权交叉熵损失函数?

我目前正在处理一个二项分类算法,数据极度倾斜(90%负样本/10%正样本),使用tf.estimator.DNNClassifier。由于我训练的所有模型都倾向于将所有样本标记为负样本,我需要实现一个加权损失函数。

我查看了许多不同的问题,其中许多都很有启发性。然而,我无法得到一个实际的端到端答案,关于如何真正实现这些函数。这个这个讨论是最好的。

我的问题是:我想使用tf.nn.weighted_cross_entropy_with_logits(),但我不知道应该在代码的哪里插入它。

我有一个构建特征列的函数:

def construct_feature_columns(input_features):  return set([tf.feature_column.numeric_column(my_feature)              for my_feature in input_features])

一个定义tf.estimator.DNNClassifier以及其他参数的函数,如优化器和输入函数:

def train_nn_classifier_model(    learning_rate,    steps,    batch_size,    hidden_units,    training_examples,    training_targets,    validation_examples,    validation_targets):    dnn_classifier = tf.estimator.DNNClassifier(        feature_columns=construct_feature_columns(training_examples),        hidden_units=hidden_units,        optimizer=my_optimizer)

训练函数:

dnn_classifier.train(input_fn=training_input_fn, steps=steps_per_period)

预测函数,用于在训练时计算误差:

training_probabilities = dnn_classifier.predict(input_fn=predict_training_input_fn)

优化器:

  my_optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate)  my_optimizer = tf.contrib.estimator.clip_gradients_by_norm(my_optimizer, 5.0)

输入函数(用于训练输入、预测训练输入和验证输入):

  training_input_fn = lambda: my_input_fn(      training_examples,       training_targets['True/False'],       batch_size=batch_size)

我应该在哪里插入tf.nn.weighted_cross_entropy_with_logits,以便我的模型使用这个函数计算损失?

另外,如何在交叉熵函数中调用targets (A Tensor of the same type and shape as logits)?它是training_targets数据框,还是input function的输出,其中包含training_targets作为输入?

具体来说,logits是什么?因为对我来说,它们应该是来自函数的预测:

training_probabilities = dnn_classifier.predict(input_fn=predict_training_input_fn)

但这对我来说似乎不太合理。我尝试了许多不同的方法来实现它,但没有一个有效。


回答:

我很遗憾地告诉您一个坏消息,DNN分类器不支持自定义损失函数:

损失是通过使用softmax交叉熵计算的。

这是文档中唯一提到损失(函数)的部分,我找不到任何帖子讨论直接更改DNNClassifier的有效解决方案。相反,看起来您需要构建自己的自定义估计器

Related Posts

为什么我们在K-means聚类方法中使用kmeans.fit函数?

我在一个视频中使用K-means聚类技术,但我不明白为…

如何获取Keras中ImageDataGenerator的.flow_from_directory函数扫描的类名?

我想制作一个用户友好的GUI图像分类器,用户只需指向数…

如何查看每个词的tf-idf得分

我试图了解文档中每个词的tf-idf得分。然而,它只返…

如何修复 ‘ValueError: Found input variables with inconsistent numbers of samples: [32979, 21602]’?

我在制作一个用于情感分析的逻辑回归模型时遇到了这个问题…

如何向神经网络输入两个不同大小的输入?

我想向神经网络输入两个数据集。第一个数据集(元素)具有…

逻辑回归与机器学习有何关联

我们正在开会讨论聘请一位我们信任的顾问来做机器学习。一…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注