I trained the following model in Python:
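```python
from tensorflow.keras import regularizers
from tensorflow.keras.layers import GRU, Dense, Flatten
from tensorflow.keras.models import Sequential

input_shape = (22, 13)  # 22 timesteps x 13 features, matching batch_input_shape in model.json
reg = 0.000001

model = Sequential()
model.add(Dense(24, activation='tanh', name='input_dense', input_shape=input_shape))
# regularizers.l2(...) is what ends up serialized as class_name "L2" in model.json
model.add(GRU(24, activation='tanh', recurrent_activation='sigmoid', return_sequences=True,
              kernel_regularizer=regularizers.l2(reg),
              recurrent_regularizer=regularizers.l2(reg),
              reset_after=False))
model.add(Flatten())
model.add(Dense(2, activation='softmax'))
```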
But when I convert this model with `tensorflowjs_converter --input_format keras` and load it in the browser, I get this error:
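Unhandled Rejection (Error): Unknown regularizer: L2. This may be due to one of the following reasons:
- The regularizer is defined in Python, in which case it needs to be ported to TensorFlow.js or your JavaScript code.
- The custom regularizer is defined in JavaScript, but is not registered properly with tf.serialization.registerClass().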
The contents of model.json are:
```json
{
  "format": "layers-model",
  "generatedBy": "keras v2.4.0",
  "convertedBy": "TensorFlow.js Converter v2.3.0",
  "modelTopology": {
    "keras_version": "2.4.0",
    "backend": "tensorflow",
    "model_config": {
      "class_name": "Sequential",
      "config": {
        "name": "sequential",
        "layers": [
          {
            "class_name": "InputLayer",
            "config": {
              "batch_input_shape": [null, 22, 13],
              "dtype": "float32",
              "sparse": false,
              "ragged": false,
              "name": "input_dense_input"
            }
          },
          {
            "class_name": "Dense",
            "config": {
              "name": "input_dense",
              "trainable": true,
              "batch_input_shape": [null, 22, 13],
              "dtype": "float32",
              "units": 24,
              "activation": "tanh",
              "use_bias": true,
              "kernel_initializer": {"class_name": "GlorotUniform", "config": {"seed": null}},
              "bias_initializer": {"class_name": "Zeros", "config": {}},
              "kernel_regularizer": null,
              "bias_regularizer": null,
              "activity_regularizer": null,
              "kernel_constraint": null,
              "bias_constraint": null
            }
          },
          {
            "class_name": "GRU",
            "config": {
              "name": "gru",
              "trainable": true,
              "dtype": "float32",
              "return_sequences": true,
              "return_state": false,
              "go_backwards": false,
              "stateful": false,
              "unroll": false,
              "time_major": false,
              "units": 24,
              "activation": "tanh",
              "recurrent_activation": "sigmoid",
              "use_bias": true,
              "kernel_initializer": {"class_name": "GlorotUniform", "config": {"seed": null}},
              "recurrent_initializer": {"class_name": "Orthogonal", "config": {"gain": 1.0, "seed": null}},
              "bias_initializer": {"class_name": "Zeros", "config": {}},
              "kernel_regularizer": {"class_name": "L2", "config": {"l2": 9.999999974752427e-7}},
              "recurrent_regularizer": {"class_name": "L2", "config": {"l2": 9.999999974752427e-7}},
              "bias_regularizer": null,
              "activity_regularizer": null,
              "kernel_constraint": null,
              "recurrent_constraint": null,
              "bias_constraint": null,
              "dropout": 0.0,
              "recurrent_dropout": 0.0,
              "implementation": 2,
              "reset_after": false
            }
          },
          {
            "class_name": "Flatten",
            "config": {
              "name": "flatten",
              "trainable": true,
              "dtype": "float32",
              "data_format": "channels_last"
            }
          },
          {
            "class_name": "Dense",
            "config": {
              "name": "dense",
              "trainable": true,
              "dtype": "float32",
              "units": 2,
              "activation": "softmax",
              "use_bias": true,
              "kernel_initializer": {"class_name": "GlorotUniform", "config": {"seed": null}},
              "bias_initializer": {"class_name": "Zeros", "config": {}},
              "kernel_regularizer": null,
              "bias_regularizer": null,
              "activity_regularizer": null,
              "kernel_constraint": null,
              "bias_constraint": null
            }
          }
        ]
      }
    },
    "training_config": {
      "loss": "categorical_crossentropy",
      "metrics": ["accuracy"],
      "weighted_metrics": null,
      "loss_weights": null,
      "optimizer_config": {
        "class_name": "Nadam",
        "config": {
          "name": "Nadam",
          "learning_rate": 0.0020000000949949026,
          "decay": 0.004000000189989805,
          "beta_1": 0.8999999761581421,
          "beta_2": 0.9990000128746033,
          "epsilon": 1e-7
        }
      }
    }
  },
  "weightsManifest": [
    {
      "paths": ["group1-shard1of1.bin"],
      "weights": [
        {"name": "dense/kernel", "shape": [528, 2], "dtype": "float32"},
        {"name": "dense/bias", "shape": [2], "dtype": "float32"},
        {"name": "gru/gru_cell/kernel", "shape": [24, 72], "dtype": "float32"},
        {"name": "gru/gru_cell/recurrent_kernel", "shape": [24, 72], "dtype": "float32"},
        {"name": "gru/gru_cell/bias", "shape": [72], "dtype": "float32"},
        {"name": "input_dense/kernel", "shape": [13, 24], "dtype": "float32"},
        {"name": "input_dense/bias", "shape": [24], "dtype": "float32"}
      ]
    }
  ]
}
```
Answer:
Option 1
TensorFlow.js has no `L1` or `L2` classes; they are only interfaces (more information here).
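There is an `L1L2` class, which takes the config and returns the correct regularizer. You can manually replace every `L2` in model.json with `L1L2`. If you do, also add `"l1": 0` to each of those configs: TensorFlow.js' `L1L2` uses an L1 rate of 0.01 when `l1` is missing, which would add an unintended L1 penalty if the model is trained further in the browser.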
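Rather than editing model.json by hand, the replacement can be scripted. The following is a minimal sketch using only the Python standard library; `patch_l2_regularizers` and `patch_model_json` are hypothetical helper names, and it assumes the converter's output is a model.json whose `modelTopology` uses Keras' `"class_name"`/`"config"` layout, as in the file shown in the question. It also sets `l1` to 0 explicitly, because TensorFlow.js' `L1L2` falls back to 0.01 for a missing `l1`.

```python
import json


def patch_l2_regularizers(node):
    """Rewrite Keras "L2" regularizer entries as "L1L2" entries, in place.

    Walks the nested dicts/lists of a parsed model.json topology and
    returns how many regularizer entries were changed.
    """
    changed = 0
    if isinstance(node, dict):
        if node.get("class_name") == "L2" and isinstance(node.get("config"), dict):
            node["class_name"] = "L1L2"
            # TensorFlow.js' L1L2 uses l1 = 0.01 when l1 is missing,
            # so pin it to 0 to keep a pure L2 penalty.
            node["config"] = {"l1": 0.0, "l2": node["config"]["l2"]}
            changed += 1
        for value in node.values():
            changed += patch_l2_regularizers(value)
    elif isinstance(node, list):
        for item in node:
            changed += patch_l2_regularizers(item)
    return changed


def patch_model_json(path):
    """Load the model.json at `path`, patch its topology, and write it back."""
    with open(path, encoding="utf-8") as f:
        model = json.load(f)
    # Only the topology holds regularizers; the weights manifest is left alone.
    changed = patch_l2_regularizers(model["modelTopology"])
    with open(path, "w", encoding="utf-8") as f:
        json.dump(model, f)
    return changed
```

Calling `patch_model_json("model.json")` on the file from the question should report 2 changes (the GRU's `kernel_regularizer` and `recurrent_regularizer`). The weight shards are untouched, since regularizers only affect the training loss and not the stored weights.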
Option 2
Register an `L2` class:
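```javascript
import * as tf from '@tensorflow/tfjs';

class L2 {
  static className = 'L2';

  constructor(config) {
    // Returning an object from a constructor makes `new L2(config)` evaluate
    // to that object, so this yields a built-in L1L2 regularizer instead.
    // l1l2 defaults a missing l1 to 0.01, so set it to 0 for a pure L2 penalty.
    return tf.regularizers.l1l2({l1: 0, l2: config.l2});
  }
}

tf.serialization.registerClass(L2);
// Now load the model
```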