### 在Keras中根据阈值将中间层的输出设置为0或1

我有一个模型,其中包含“分类”和“回归”类似的部分。我使用乘法层将它们合并。在执行乘法之前,我想根据阈值将分类部分的输出设置为0或1。我尝试使用带有自定义函数的Lambda层,如下所示,但遇到了各种错误,我对这些错误一无所知。逐一解决它们并不能增加我的理解。谁能解释如何定义一个自定义的Lambda层函数来修改这些值?

我当前的Lambda层函数:(由于FailedPreconditionError: Attempting to use uninitialized value lstm_32/bias而无法工作)

def func(x):        a = x.eval(session=tf.Session())    a[x < 0.5] = 0    a[x >= 0.5] = 1    return K.variable(a)

回归部分:

input1 = Input(shape=(1, ))model = Sequential()model.add(Embedding(vocab_size + 1, embedding, input_length=1))model.add(LSTM(hidden, recurrent_dropout=0.1, return_sequences=True))model.add(LSTM(6))model.add(Reshape((3,2)))model.add(Activation('linear'))

分类部分:

input2 = Input(shape=(1, ))model2 = Sequential()model2.add(Embedding(vocab_size + 1, embedding, input_length=1))model2.add(LSTM(hidden, recurrent_dropout=0.1, return_sequences=True))model2.add(LSTM(1))model2.add(Activation('sigmoid'))model2.add(???)  # 需要在这里添加0-1阈值处理

合并两个部分:

reg_head = model(input1)clf_head = model2(input2)    merge_head = multiply(inputs=[clf_head, reg_head])m2 = Model(inputs=[input1, input2], outputs=merge_head)

回答:

func中,你不能对张量进行eval操作。

使用张量的理念是它们在整个模型中从头到尾保持“连接”(他们称之为图)。这种连接允许模型计算梯度。如果你对张量进行评估并尝试使用这些值,你会破坏这种连接。

此外,要获取张量的实际值,你需要输入数据。只有在你调用fitpredict等类似方法时,输入数据才存在。在构建阶段没有数据,只有表示和连接。

一个仅使用张量的可能函数是:

def func(x):    greater = K.greater_equal(x, 0.5) #将返回布尔值    greater = K.cast(greater, dtype=K.floatx()) #将布尔值转换为0和1        return greater 

但要小心!这将不可微分。从现在开始,这些值将在模型中被视为常量。这意味着在此之前的权重在训练期间不会被更新(你将无法通过m2训练分类模型,但你仍然可以通过model2训练它)。如果需要的话,有一些高级的解决方法,请在评论中说明。

Lambda层中使用这个函数:

model.add(Lambda(func, output_shape=yourOutputShapeIfUsingTheano))

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注