当我们想要使用分布式TensorFlow时,我们会创建一个参数服务器,方法如下:
tf.train.Server.join()
然而,除了终止进程之外,我找不到任何关闭服务器的方法。TensorFlow文档中对join()的描述是
Blocks until the server has shut down.This method currently blocks forever.
这对我来说非常令人困扰,因为我希望创建多个服务器进行计算,并在一切完成后关闭它们。
有没有可能的解决方案?
谢谢
回答:
您可以通过使用session.run(dequeue_op)
代替server.join()
,并在希望该进程终止时让另一个进程向该队列中添加内容,来按需终止参数服务器进程。
因此,对于k
个参数服务器分片,您可以创建k
个队列,每个队列具有唯一的shared_name
属性,并尝试从该队列中dequeue
。当您想要关闭服务器时,您可以遍历所有队列,并向每个队列中enqueue
一个标记。这将导致session.run
解除阻塞,Python进程将运行到末尾并退出,从而关闭服务器。
下面是一个包含2个分片的自包含示例,摘自:https://gist.github.com/yaroslavvb/82a5b5302449530ca5ff59df520c369e
(有关多工作者/多分片的示例,请参见 https://gist.github.com/yaroslavvb/ea1b1bae0a75c4aae593df7eca72d9ca)
import subprocessimport tensorflow as tfimport timeimport sysflags = tf.flagsflags.DEFINE_string("port1", "12222", "port of worker1")flags.DEFINE_string("port2", "12223", "port of worker2")flags.DEFINE_string("task", "", "internal use")FLAGS = flags.FLAGS# setup local cluster from flagshost = "127.0.0.1:"cluster = {"worker": [host+FLAGS.port1, host+FLAGS.port2]}clusterspec = tf.train.ClusterSpec(cluster).as_cluster_def()if __name__=='__main__': if not FLAGS.task: # start servers and run client # launch distributed service def runcmd(cmd): subprocess.Popen(cmd, shell=True, stderr=subprocess.STDOUT) runcmd("python %s --task=0"%(sys.argv[0])) runcmd("python %s --task=1"%(sys.argv[0])) time.sleep(1) # bring down distributed service sess = tf.Session("grpc://"+host+FLAGS.port1) queue0 = tf.FIFOQueue(1, tf.int32, shared_name="queue0") queue1 = tf.FIFOQueue(1, tf.int32, shared_name="queue1") with tf.device("/job:worker/task:0"): add_op0 = tf.add(tf.ones(()), tf.ones(())) with tf.device("/job:worker/task:1"): add_op1 = tf.add(tf.ones(()), tf.ones(())) print("Running computation on server 0") print(sess.run(add_op0)) print("Running computation on server 1") print(sess.run(add_op1)) print("Bringing down server 0") sess.run(queue0.enqueue(1)) print("Bringing down server 1") sess.run(queue1.enqueue(1)) else: # Launch TensorFlow server server = tf.train.Server(clusterspec, config=None, job_name="worker", task_index=int(FLAGS.task)) print("Starting server "+FLAGS.task) sess = tf.Session(server.target) queue = tf.FIFOQueue(1, tf.int32, shared_name="queue"+FLAGS.task) sess.run(queue.dequeue()) print("Terminating server"+FLAGS.task)