在TensorFlow中关闭服务器

当我们想要使用分布式TensorFlow时,我们会创建一个参数服务器,方法如下:

tf.train.Server.join()

然而,除了终止进程之外,我找不到任何关闭服务器的方法。TensorFlow文档中对join()的描述是

Blocks until the server has shut down.This method currently blocks forever.

这对我来说非常令人困扰,因为我希望创建多个服务器进行计算,并在一切完成后关闭它们。

有没有可能的解决方案?

谢谢


回答:

您可以通过使用session.run(dequeue_op)代替server.join(),并在希望该进程终止时让另一个进程向该队列中添加内容,来按需终止参数服务器进程。

因此,对于k个参数服务器分片,您可以创建k个队列,每个队列具有唯一的shared_name属性,并尝试从该队列中dequeue。当您想要关闭服务器时,您可以遍历所有队列,并向每个队列中enqueue一个标记。这将导致session.run解除阻塞,Python进程将运行到末尾并退出,从而关闭服务器。

下面是一个包含2个分片的自包含示例,摘自:https://gist.github.com/yaroslavvb/82a5b5302449530ca5ff59df520c369e

(有关多工作者/多分片的示例,请参见 https://gist.github.com/yaroslavvb/ea1b1bae0a75c4aae593df7eca72d9ca

import subprocessimport tensorflow as tfimport timeimport sysflags = tf.flagsflags.DEFINE_string("port1", "12222", "port of worker1")flags.DEFINE_string("port2", "12223", "port of worker2")flags.DEFINE_string("task", "", "internal use")FLAGS = flags.FLAGS# setup local cluster from flagshost = "127.0.0.1:"cluster = {"worker": [host+FLAGS.port1, host+FLAGS.port2]}clusterspec = tf.train.ClusterSpec(cluster).as_cluster_def()if __name__=='__main__':  if not FLAGS.task:  # start servers and run client      # launch distributed service      def runcmd(cmd): subprocess.Popen(cmd, shell=True, stderr=subprocess.STDOUT)      runcmd("python %s --task=0"%(sys.argv[0]))      runcmd("python %s --task=1"%(sys.argv[0]))      time.sleep(1)      # bring down distributed service      sess = tf.Session("grpc://"+host+FLAGS.port1)      queue0 = tf.FIFOQueue(1, tf.int32, shared_name="queue0")      queue1 = tf.FIFOQueue(1, tf.int32, shared_name="queue1")      with tf.device("/job:worker/task:0"):          add_op0 = tf.add(tf.ones(()), tf.ones(()))      with tf.device("/job:worker/task:1"):          add_op1 = tf.add(tf.ones(()), tf.ones(()))      print("Running computation on server 0")      print(sess.run(add_op0))      print("Running computation on server 1")      print(sess.run(add_op1))      print("Bringing down server 0")      sess.run(queue0.enqueue(1))      print("Bringing down server 1")      sess.run(queue1.enqueue(1))  else: # Launch TensorFlow server    server = tf.train.Server(clusterspec, config=None,                             job_name="worker",                             task_index=int(FLAGS.task))    print("Starting server "+FLAGS.task)    sess = tf.Session(server.target)    queue = tf.FIFOQueue(1, tf.int32, shared_name="queue"+FLAGS.task)    sess.run(queue.dequeue())    print("Terminating server"+FLAGS.task)

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注