Tensorflow 2.0与Keras API是否线程安全?

Tensorflow 2.0是否线程安全?

更具体地说,在Tensorflow 2.0中(使用Keras API),从不同线程调用同一个模型的fit/predict或其他方法是否安全?

我在文档中或在线查找时没有找到明确的答案。

我看到这个问题是2017年提出的,说Keras(虽然问题提到的是Theano后端)是线程安全的,但你必须在调用predict()之前调用私有方法model._make_predict_function()(我认为这已经废弃)。然而,我阅读了这篇博客文章,这是2019年的,说它不是线程安全的。

我还找到了这个问题,是2018年提出的,说Tensorflow(Keras之前的版本)是线程安全的,但你必须确保显式使用默认图(我认为这对Tensorflow 2.*无关紧要,因为有急切执行)。当我查找急切执行中的线程安全性时,我看到了这篇文章,文档中提到了急切执行的线程安全性,但它是关于Java的。

为了让事情更加混乱,我在Github上看到今年(2020年)的一个使用Keras的A3C实现,在训练共享的策略/值网络之前使用了锁,暗示Keras不是线程安全的,你必须在训练共享模型之前获取锁。然而,在我看来,这个实现是有缺陷的,因为每个工作者都在创建和使用它自己的独特锁,这违背了拥有锁的目的。我的结论是,要么他的代码不顾“锁”而成功运行,因为Keras是线程安全的,要么他有一个错误。

我做了自己的最终测试,我运行了两个线程,让同一个模型适应不同的输出(对于相同的常量输入),并在训练期间尝试调用predict,看起来是工作的,但我问这个问题是因为我想确保。Tensorflow 2.0/Keras是否有任何不线程安全的情况?


回答:

根据Keras贡献者在这个 GitHub问题中的说法:

Keras模型不能保证是线程安全的。考虑在每个线程中拥有模型的独立副本以进行CPU推理。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注