如何向双向LSTM模型中添加注意力层

我正在开发一个双向LSTM模型,并希望向其中添加一个注意力层。但我不知道如何添加它。

我当前的模型代码是

model = Sequential()model.add(Embedding(max_words, 1152, input_length=max_len, weights=[embeddings]))model.add(BatchNormalization())model.add(Activation('tanh'))model.add(Dropout(0.5))model.add(Bidirectional(LSTM(32)))model.add(BatchNormalization())model.add(Activation('tanh'))model.add(Dropout(0.5))model.add(Dense(1, activation='sigmoid'))model.summary()

模型摘要是

Model: "sequential_1"_________________________________________________________________Layer (type)                 Output Shape              Param #   =================================================================embedding_1 (Embedding)      (None, 1152, 1152)        278396928 _________________________________________________________________batch_normalization_1 (Batch (None, 1152, 1152)        4608      _________________________________________________________________activation_1 (Activation)    (None, 1152, 1152)        0         _________________________________________________________________dropout_1 (Dropout)          (None, 1152, 1152)        0         _________________________________________________________________bidirectional_1 (Bidirection (None, 64)                303360    _________________________________________________________________batch_normalization_2 (Batch (None, 64)                256       _________________________________________________________________activation_2 (Activation)    (None, 64)                0         _________________________________________________________________dropout_2 (Dropout)          (None, 64)                0         _________________________________________________________________dense_1 (Dense)              (None, 1)                 65        =================================================================Total params: 278,705,217Trainable params: 278,702,785Non-trainable params: 2,432

回答:

这可以是一个自定义解决方案,通过自定义层在位置/时间维度上计算注意力

from tensorflow.keras.layers import Layerfrom tensorflow.keras import backend as Kclass Attention(Layer):        def __init__(self, return_sequences=True):        self.return_sequences = return_sequences        super(Attention,self).__init__()            def build(self, input_shape):                self.W=self.add_weight(name="att_weight", shape=(input_shape[-1],1),                               initializer="normal")        self.b=self.add_weight(name="att_bias", shape=(input_shape[1],1),                               initializer="zeros")                super(Attention,self).build(input_shape)            def call(self, x):                e = K.tanh(K.dot(x,self.W)+self.b)        a = K.softmax(e, axis=1)        output = x*a                if self.return_sequences:            return output                return K.sum(output, axis=1)

它被设计为接收3D张量并输出3D张量(return_sequences=True)或2D张量(return_sequences=False)。下面是一个示例

# 生成虚拟数据max_len = 100max_words = 333emb_dim = 126n_sample = 5X = np.random.randint(0,max_words, (n_sample,max_len))Y = np.random.randint(0,2, n_sample)

使用return_sequences=True

model = Sequential()model.add(Embedding(max_words, emb_dim, input_length=max_len))model.add(Bidirectional(LSTM(32, return_sequences=True)))model.add(Attention(return_sequences=True)) # 接收3D并输出3Dmodel.add(LSTM(32))model.add(Dense(1, activation='sigmoid'))model.summary()model.compile('adam', 'binary_crossentropy')model.fit(X,Y, epochs=3)

使用return_sequences=False

model = Sequential()model.add(Embedding(max_words, emb_dim, input_length=max_len))model.add(Bidirectional(LSTM(32, return_sequences=True)))model.add(Attention(return_sequences=False)) # 接收3D并输出2Dmodel.add(Dense(1, activation='sigmoid'))model.summary()model.compile('adam', 'binary_crossentropy')model.fit(X,Y, epochs=3)

您可以轻松地将其集成到您的网络中

这里是运行的笔记本

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注