例如,模型A接收一张图像作为输入并对其进行分类。假设它的输出维度为3,例如(0, 1, 0)。根据输出结果,图像会被传递到模型X、Y或Z中。实际上,我在寻找一个能够执行这种条件分支的层。
回答:
在使用Tensorflow时,tf.case()
可能是你需要的,查看文档获取更多信息。
这里有一个例子,假设默认使用模型Z,并且假设这些条件是互斥的:
condition_for_X = conditionX(output_A) # e.g. tf.less(tf.reduce_sum(output_A), -1.) condition_for_Y = conditionX(output_A) # e.g. tf.greater_equal(tf.reduce_sum(output_A), 1.) def use_X(): return model_X(output_A)def use_Y(): return model_Y(output_A)def use_Z(): return model_Z(output_A)result = tf.case({condition_for_X: use_X, condition_for_Y : use_Y}, default=use_Z, exclusive=True)