Tensorflow 2.0是否线程安全?
更具体地说,在Tensorflow 2.0中(使用Keras API),从不同线程调用同一个模型的fit
/predict
或其他方法是否安全?
我在文档中或在线查找时没有找到明确的答案。
我看到这个问题是2017年提出的,说Keras(虽然问题提到的是Theano后端)是线程安全的,但你必须在调用predict()
之前调用私有方法model._make_predict_function()
(我认为这已经废弃)。然而,我阅读了这篇博客文章,这是2019年的,说它不是线程安全的。
我还找到了这个问题,是2018年提出的,说Tensorflow(Keras之前的版本)是线程安全的,但你必须确保显式使用默认图(我认为这对Tensorflow 2.*无关紧要,因为有急切执行)。当我查找急切执行中的线程安全性时,我看到了这篇文章,文档中提到了急切执行的线程安全性,但它是关于Java的。
为了让事情更加混乱,我在Github上看到今年(2020年)的一个使用Keras的A3C实现,在训练共享的策略/值网络之前使用了锁,暗示Keras不是线程安全的,你必须在训练共享模型之前获取锁。然而,在我看来,这个实现是有缺陷的,因为每个工作者都在创建和使用它自己的独特锁,这违背了拥有锁的目的。我的结论是,要么他的代码不顾“锁”而成功运行,因为Keras是线程安全的,要么他有一个错误。
我做了自己的最终测试,我运行了两个线程,让同一个模型适应不同的输出(对于相同的常量输入),并在训练期间尝试调用predict,看起来是工作的,但我问这个问题是因为我想确保。Tensorflow 2.0/Keras是否有任何不线程安全的情况?
回答:
根据Keras贡献者在这个 GitHub问题中的说法:
Keras模型不能保证是线程安全的。考虑在每个线程中拥有模型的独立副本以进行CPU推理。