在Tensorflow的DNN分类器估计器函数中,如何使用加权交叉熵损失函数?

我目前正在处理一个二项分类算法,数据极度倾斜(90%负样本/10%正样本),使用tf.estimator.DNNClassifier。由于我训练的所有模型都倾向于将所有样本标记为负样本,我需要实现一个加权损失函数。

我查看了许多不同的问题,其中许多都很有启发性。然而,我无法得到一个实际的端到端答案,关于如何真正实现这些函数。这个这个讨论是最好的。

我的问题是:我想使用tf.nn.weighted_cross_entropy_with_logits(),但我不知道应该在代码的哪里插入它。

我有一个构建特征列的函数:

def construct_feature_columns(input_features):  return set([tf.feature_column.numeric_column(my_feature)              for my_feature in input_features])

一个定义tf.estimator.DNNClassifier以及其他参数的函数,如优化器和输入函数:

def train_nn_classifier_model(    learning_rate,    steps,    batch_size,    hidden_units,    training_examples,    training_targets,    validation_examples,    validation_targets):    dnn_classifier = tf.estimator.DNNClassifier(        feature_columns=construct_feature_columns(training_examples),        hidden_units=hidden_units,        optimizer=my_optimizer)

训练函数:

dnn_classifier.train(input_fn=training_input_fn, steps=steps_per_period)

预测函数,用于在训练时计算误差:

training_probabilities = dnn_classifier.predict(input_fn=predict_training_input_fn)

优化器:

  my_optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate)  my_optimizer = tf.contrib.estimator.clip_gradients_by_norm(my_optimizer, 5.0)

输入函数(用于训练输入、预测训练输入和验证输入):

  training_input_fn = lambda: my_input_fn(      training_examples,       training_targets['True/False'],       batch_size=batch_size)

我应该在哪里插入tf.nn.weighted_cross_entropy_with_logits,以便我的模型使用这个函数计算损失?

另外,如何在交叉熵函数中调用targets (A Tensor of the same type and shape as logits)?它是training_targets数据框,还是input function的输出,其中包含training_targets作为输入?

具体来说,logits是什么?因为对我来说,它们应该是来自函数的预测:

training_probabilities = dnn_classifier.predict(input_fn=predict_training_input_fn)

但这对我来说似乎不太合理。我尝试了许多不同的方法来实现它,但没有一个有效。


回答:

我很遗憾地告诉您一个坏消息,DNN分类器不支持自定义损失函数:

损失是通过使用softmax交叉熵计算的。

这是文档中唯一提到损失(函数)的部分,我找不到任何帖子讨论直接更改DNNClassifier的有效解决方案。相反,看起来您需要构建自己的自定义估计器

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注