假设我们有一个使用顺序API编写的模型:
# Hyperparameters read by the model-building code below.
config = {
    'learning_rate': 0.001,
    'lstm_neurons': 32,
    'lstm_activation': 'tanh',
    'dropout_rate': 0.08,
    'batch_size': 128,
    'dense_layers': [
        {'neurons': 32, 'activation': 'relu'},
        {'neurons': 32, 'activation': 'relu'},
    ],
}


def get_model(num_features, output_size):
    """Build and compile the Sequential LSTM model described by ``config``.

    The stack is: ragged Input -> LSTM -> BatchNormalization -> optional
    Dropout, then one Dense -> BatchNormalization -> optional Dropout group
    per entry of ``config['dense_layers']``, and finally a sigmoid Dense
    layer with ``output_size`` units.

    Args:
        num_features: Size of the last input axis; the input shape is
            ``[None, num_features]`` with ``ragged=True``.
        output_size: Number of units in the output layer.

    Returns:
        The model, compiled with MSE loss, an MSE metric and an Adam
        optimizer.
    """
    model = Sequential()
    model.add(Input(shape=[None, num_features], dtype=tf.float32, ragged=True))
    model.add(LSTM(config['lstm_neurons'], activation=config['lstm_activation']))
    model.add(BatchNormalization())
    if 'dropout_rate' in config:
        model.add(Dropout(config['dropout_rate']))

    for layer in config['dense_layers']:
        model.add(Dense(layer['neurons'], activation=layer['activation']))
        model.add(BatchNormalization())
        # Dropout after a dense layer is opt-in per layer entry.
        if 'dropout_rate' in layer:
            model.add(Dropout(layer['dropout_rate']))

    model.add(Dense(output_size, activation='sigmoid'))

    # Take the learning rate from config instead of repeating the literal,
    # so changing config['learning_rate'] actually affects training.
    opt = Adam(learning_rate=config['learning_rate'])
    model.compile(loss='mse', optimizer=opt, metrics=['mse'])

    # summary() prints the table itself and returns None; wrapping it in
    # print() would emit an extra "None" line.
    model.summary()
    return model
当使用分布式训练框架时,我需要将语法转换为使用模型子类化。我查看了文档,但无法弄清楚如何操作。
回答:
下面是一个等价的子类化实现,不过我还没有测试过。
# Subclassed API Model
class MySubClassed(tf.keras.Model):
    """Subclassed counterpart of the Sequential LSTM model driven by ``config``.

    Layer stack: LSTM -> BatchNormalization -> optional Dropout, then one
    Dense -> BatchNormalization -> optional Dropout block for every entry
    in ``config['dense_layers']``, and a final sigmoid Dense output layer.
    """

    def __init__(self, output_size):
        """Create every layer used by ``call``.

        Args:
            output_size: Number of units in the final sigmoid layer.
        """
        super().__init__()

        # Model-wide dropout rate; None means "no default dropout".
        default_rate = config.get('dropout_rate')

        self.lstm = tf.keras.layers.LSTM(
            config['lstm_neurons'], activation=config['lstm_activation'])
        self.bn = tf.keras.layers.BatchNormalization()
        self.lstm_dropout = (
            tf.keras.layers.Dropout(default_rate)
            if default_rate is not None else None
        )

        # One block per configured dense layer. Each block is a list
        # [Dense, BatchNormalization] with a trailing Dropout when a rate
        # applies, so the model follows the length and the per-entry
        # settings of config['dense_layers'] instead of being fixed at two
        # layers built from the last entry.
        blocks = []
        for spec in config['dense_layers']:
            block = [
                tf.keras.layers.Dense(
                    spec['neurons'], activation=spec['activation']),
                tf.keras.layers.BatchNormalization(),
            ]
            # A per-layer rate wins; otherwise fall back to the global one.
            rate = spec.get('dropout_rate', default_rate)
            if rate is not None:
                block.append(tf.keras.layers.Dropout(rate))
            blocks.append(block)
        # Assigned as an attribute so that Keras tracks the nested layers.
        self.hidden_blocks = blocks

        self.out = tf.keras.layers.Dense(output_size, activation='sigmoid')

    def call(self, inputs, training=True, **kwargs):
        """Run the forward pass and return the output layer's activations.

        Args:
            inputs: Batch fed to the LSTM layer.
            training: Accepted as part of the Keras ``call`` signature. It
                is not forwarded explicitly; the sub-layers are invoked
                without it, leaving mode resolution to Keras.
            **kwargs: Accepted and ignored.
        """
        x = self.lstm(inputs)
        x = self.bn(x)
        if self.lstm_dropout is not None:
            x = self.lstm_dropout(x)

        for block in self.hidden_blocks:
            for layer in block:
                x = layer(x)

        return self.out(x)

    # A convenient way to get model summary
    # and plot in subclassed api
    def build_graph(self, raw_shape):
        """Wrap ``call`` in a functional model so it can be summarized.

        Args:
            raw_shape: Size of the last input axis; the symbolic input has
                shape ``(None, raw_shape)`` and is ragged.

        Returns:
            A ``tf.keras.Model`` mapping the symbolic input to
            ``self.call`` applied to it.
        """
        x = tf.keras.layers.Input(shape=(None, raw_shape), ragged=True)
        return tf.keras.Model(inputs=[x], outputs=self.call(x))
构建并编译模型
# Instantiate the subclassed model with a single output unit.
s = MySubClassed(output_size=1)

# Compile with Adam, MSE loss and an MSE metric.
s.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
    loss='mse',
    metrics=['mse'],
)
传入一个张量以创建权重(用于检查)。
# Call the model once on a tensor of ones so that the weights get created.
raw_input = (16, 16, 16)
y = s(tf.ones(shape=raw_input))

# Report how many weight variables exist in total and how many are trainable.
for label, variables in (
        ("weights:", s.weights),
        ("trainable weights:", s.trainable_weights)):
    print(label, len(variables))

# Output:
# weights: 21
# trainable weights: 15
摘要与图示
输出模型摘要并可视化模型结构图。
s.build_graph(16).summary()
Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input_1 (InputLayer)         [(None, None, 16)]        0
_________________________________________________________________
lstm (LSTM)                  (None, 32)                6272
_________________________________________________________________
batch_normalization (BatchNo (None, 32)                128
_________________________________________________________________
dropout (Dropout)            (None, 32)                0
_________________________________________________________________
dense_2 (Dense)              (None, 32)                1056
_________________________________________________________________
batch_normalization_3 (Batch (None, 32)                128
_________________________________________________________________
dropout_1 (Dropout)          (None, 32)                0
_________________________________________________________________
dense_3 (Dense)              (None, 32)                1056
_________________________________________________________________
batch_normalization_4 (Batch (None, 32)                128
_________________________________________________________________
dropout_2 (Dropout)          (None, 32)                0
_________________________________________________________________
dense_4 (Dense)              (None, 1)                 33
=================================================================
Total params: 8,801
Trainable params: 8,609
Non-trainable params: 192