Keras是否忽略掩码值的标签?

我正在使用Keras实现一个LSTM模型。我对序列进行了填充以达到一定长度,以便正确地将数据集输入模型。

目前,我的模型如下:

model = tf.keras.Sequential()model.add(Masking(mask_value=0., input_shape=(timesteps, features)))model.add(LSTM(units=100, return_sequences=True, input_shape=(timesteps, features)))model.add(Dropout(0.5))model.add(Dense(1, activation='sigmoid'))

Keras是否会自动跳过损失函数中掩码值的标签?


回答:

是的,如果你的模型使用了掩码,那么目标函数(即损失函数)会自动调整以支持掩码,从而在计算损失时忽略被掩码的样本/时间步。实际上,weighted_masked_objective 是实现这一功能的后台函数:

def weighted_masked_objective(fn):    """为目标函数添加掩码和样本加权支持。    它将一个目标函数 `fn(y_true, y_pred)`    转换为带样本加权和成本掩码的目标函数    `fn(y_true, y_pred, weights, mask)`。    # 参数        fn: 要包装的目标函数,            其签名为 `fn(y_true, y_pred)`。    # 返回        签名为 `fn(y_true, y_pred, weights, mask)` 的函数。    """    if fn is None:        return None    def weighted(y_true, y_pred, weights, mask=None):        """包装函数。        # 参数            y_true: `fn` 的 `y_true` 参数。            y_pred: `fn` 的 `y_pred` 参数。            weights: 权重张量。            mask: 掩码张量。        # 返回            标量张量。        """        # score_array 的维度 >= 2        score_array = fn(y_true, y_pred)        if mask is not None:            # 将掩码转换为 floatX 以避免在 Theano 中进行 float64 上行转换            mask = K.cast(mask, K.floatx())            # 掩码应与 score_array 形状相同            score_array *= mask            # 每批次的损失应与未掩码样本的数量成比例            score_array /= K.mean(mask) + K.epsilon()        # 应用样本加权        if weights is not None:            # 将 score_array 降维至与权重数组相同的维度            ndim = K.ndim(score_array)            weight_ndim = K.ndim(weights)            score_array = K.mean(score_array,                                 axis=list(range(weight_ndim, ndim)))            score_array *= weights            score_array /= K.mean(K.cast(K.not_equal(weights, 0), K.floatx()))        return K.mean(score_array)    return weighted

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注