在TensorFlow.js中编写自定义的InstantLayerNormalization

我正在尝试在浏览器中实现一个深度学习模型,这需要移植一些自定义层,其中一个是即时层归一化。下面是一段应该能工作的代码,但它有点旧。我遇到了这个错误:

Uncaught (in promise) ReferenceError: initializer is not definedat InstantLayerNormalization.build

<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs/dist/tf.min.js"> </script><script>class InstantLayerNormalization extends tf.layers.Layer{    static className = 'InstantLayerNormalization';    epsilon = 1e-7     gamma;    beta;    constructor(config)     {        super(config);    }    getConfig()     {        const config = super.getConfig();        return config;    }        build(input_shape)    {        let shape = tf.tensor(input_shape);        // initialize gamma        self.gamma = self.add_weight(shape=shape,                                      initializer='ones',                                      trainable=true,                                      name='gamma')        // initialize beta        self.beta = self.add_weight(shape=shape,                            initializer='zeros',                            trainable=true,                            name='beta')    }            call(inputs){        mean = tf.math.reduce_mean(inputs, axis=[-1], keepdims=True)        variance = tf.math.reduce_mean(tf.math.square(inputs - mean), axis=[-1], keepdims=True)        std = tf.math.sqrt(variance + self.epsilon)        outputs = (inputs - mean) / std        outputs = outputs * self.gamma        outputs = outputs + self.beta        return outputs    }    static get className() {        console.log(className);       return className;    }}tf.serialization.registerClass(InstantLayerNormalization);</script>

回答:

继承类tf.layers.Layer的方法没有被正确调用。

  • Python中的self在JavaScript中是this
  • add_weight在JavaScript中应为addWeight
  • 这里addWeight方法的签名。请注意,在JavaScript中,函数参数的解构赋值没有variable=value的格式
// 而不是这样self.gamma = self.add_weight(shape=shape, initializer='ones', trainable=true, name='gamma')// 应该是这样this.gamma = this.addWeight('gamma', shape, undefined, 'ones', undefined, true)

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注