我有这个函数:
def sampling(x): zeros = x*0 samples = tf.random.categorical(tf.math.log(x), 1) samples = tf.squeeze(tf.one_hot(samples, depth=2), axis=1) return zeros+samples
我从这个层调用它:
x = layers.Lambda(sampling, name="lambda")(x)
但我需要更改sampling函数中的depth变量,所以我需要这样的东西:
def sampling(x, depth):
但是,如何让它与Lambda层一起工作呢?
非常感谢
回答:
在Lambda层中使用lambda函数…
def sampling(x, depth): zeros = x*0 samples = tf.random.categorical(tf.math.log(x), 1) samples = tf.squeeze(tf.one_hot(samples, depth=depth), axis=1) return zeros+samples
使用方法:
Lambda(lambda t: sampling(t, depth=3), name="lambda")(x)