我试图用一个封装在 WebAssembly 二进制文件中的自定义函数来覆盖 Keras/TF2.0 的损失函数。以下是相关代码。
@tf.functiondef custom_loss(y_true, y_pred): return tf.constant(instance.exports.loss(y_true, y_pred))
我使用的方式如下
model.compile(optimizer_obj, loss=custom_loss)# 然后 model.fit(...)
我对 TF2.0 的即时执行(eager execution)并不是完全了解,因此关于这方面的任何见解都会很有帮助。
我认为 instance.exports.loss 函数与此错误无关,但是如果您确定其他一切都正常,请告诉我,我会提供更多细节。
这是堆栈跟踪和实际错误:https://pastebin.com/6YtH75ay
回答:
首先,您无需使用 @tf.function
来定义自定义损失函数。
我们可以愉快地(虽然有点无意义地)做如下操作:
def custom_loss(y_true, y_pred): return tf.reduce_mean(y_pred)model.compile(optimizer='adam', loss=custom_loss, metrics=['accuracy'])
只要我们在 custom_loss
内部使用的所有操作 对 TensorFlow 是可微分的
因此,您可以移除 @tf.function
装饰器,但之后我猜您会遇到类似如下的错误信息:
[一些跟踪信息] - 请确保您的所有操作都有定义的梯度(即是可微分的)。常见的没有梯度的操作:K.argmax, K.round, K.eval.
因为 TensorFlow 无法找到 WebAssembly 二进制文件中函数的梯度。损失函数内部的所有内容必须是 TensorFlow 可以理解并计算梯度的,否则它将无法优化以获得更低的损失值。
也许最好的前进方式是使用 TensorFlow 可以计算梯度的操作来复制 instance.exports.loss
内的功能,而不是直接引用它?
“`