I am using TensorFlow for Poets to detect features in clothing images. I have trained 4 different models (sleeve, shape, length, and hemline). I now pass each image URL to every model and store the results. Because my dataset is large (100,000 images), I use Spark to broadcast the 4 models once and pass an RDD of images to the models for feature detection. The time this takes keeps growing: starting at 3 seconds per image, the execution time rises steadily, and by the time the script has processed 10,000 images it is at 8 seconds per image. I am new to TensorFlow, so any advice on making the execution time linear would be greatly appreciated.
def getLabelDresses(file_name):
    resultDict = {}
    t = read_tensor_from_image_file(file_name,
                                    input_height=input_height,
                                    input_width=input_width,
                                    input_mean=input_mean,
                                    input_std=input_std)
    input_name = "import/" + input_layer
    output_name = "import/" + output_layer

    with tf.Graph().as_default() as g:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(model_data_hemline.value)
        tf.import_graph_def(graph_def)
        input_operation_hemline = g.get_operation_by_name(input_name)
        output_operation_hemline = g.get_operation_by_name(output_name)
        with tf.Session() as sess:
            results = sess.run(output_operation_hemline.outputs[0],
                               {input_operation_hemline.outputs[0]: t})
            results = np.squeeze(results)
            top_k = results.argsort()[-1:][::-1]
            labels = load_labels(label_file_hemline)
            resultDict['hemline'] = labels[top_k[0]]

    with tf.Graph().as_default() as g:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(model_data_shape.value)
        tf.import_graph_def(graph_def)
        input_operation_shape = g.get_operation_by_name(input_name)
        output_operation_shape = g.get_operation_by_name(output_name)
        with tf.Session() as sess:
            results = sess.run(output_operation_shape.outputs[0],
                               {input_operation_shape.outputs[0]: t})
            results = np.squeeze(results)
            top_k = results.argsort()[-1:][::-1]
            labels = load_labels(label_file_shape)
            resultDict['shape'] = labels[top_k[0]]

    with tf.Graph().as_default() as g:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(model_data_length.value)
        tf.import_graph_def(graph_def)
        input_operation_length = g.get_operation_by_name(input_name)
        output_operation_length = g.get_operation_by_name(output_name)
        with tf.Session() as sess:
            results = sess.run(output_operation_length.outputs[0],
                               {input_operation_length.outputs[0]: t})
            results = np.squeeze(results)
            top_k = results.argsort()[-1:][::-1]
            labels = load_labels(label_file_length)
            resultDict['length'] = labels[top_k[0]]

    with tf.Graph().as_default() as g:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(model_data_sleeve.value)
        tf.import_graph_def(graph_def)
        input_operation_sleeve = g.get_operation_by_name(input_name)
        output_operation_sleeve = g.get_operation_by_name(output_name)
        with tf.Session() as sess:
            results = sess.run(output_operation_sleeve.outputs[0],
                               {input_operation_sleeve.outputs[0]: t})
            results = np.squeeze(results)
            top_k = results.argsort()[-1:][::-1]
            labels = load_labels(label_file_sleeve)
            resultDict['sleeve'] = labels[top_k[0]]

    return resultDict

model_file_hemline = "/home/ubuntu/amit/feature-sensing/tf_files/attribute/dresses/hemline/retrained_graph_hemline.pb"
label_file_hemline = "/home/ubuntu/amit/feature-sensing/tf_files/attribute/dresses/hemline/retrained_labels_hemline.txt"
model_file_length = "/home/ubuntu/amit/feature-sensing/tf_files/attribute/dresses/length/retrained_graph_length.pb"
label_file_length = "/home/ubuntu/amit/feature-sensing/tf_files/attribute/dresses/length/retrained_labels_length.txt"
model_file_shape = "/home/ubuntu/amit/feature-sensing/tf_files/attribute/dresses/shape/retrained_graph_shape.pb"
label_file_shape = "/home/ubuntu/amit/feature-sensing/tf_files/attribute/dresses/shape/retrained_labels_shape.txt"
model_file_sleeve = "/home/ubuntu/amit/feature-sensing/tf_files/attribute/dresses/sleeve/retrained_graph_sleeve.pb"
label_file_sleeve = "/home/ubuntu/amit/feature-sensing/tf_files/attribute/dresses/sleeve/retrained_labels_sleeve.txt"

with gfile.FastGFile(model_file_hemline, "rb") as f:
    model_data = f.read()
    model_data_hemline = sc.broadcast(model_data)

with gfile.FastGFile(model_file_length, "rb") as f:
    model_data = f.read()
    model_data_length = sc.broadcast(model_data)

with gfile.FastGFile(model_file_shape, "rb") as f:
    model_data = f.read()
    model_data_shape = sc.broadcast(model_data)

with gfile.FastGFile(model_file_sleeve, "rb") as f:
    model_data = f.read()
    model_data_sleeve = sc.broadcast(model_data)

def calculate(row):
    path = "/tmp/" + row.guid
    url = row.modelno
    print(path, url)
    if url is not None:
        import urllib.request
        urllib.request.urlretrieve(url, path)
        t1 = time.time()
        result = getLabelDresses(path)
        print(time.time() - t1)
        print(result)
        return row
    return row

product2.rdd.map(calculate).collect()
Answer:

Every call to getLabelDresses in your code adds operations to the graph.
Split your code into a setup part (model loading) that is executed only once, and an execution part that is executed for every image. The latter should contain only the call to Session.run.
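A minimal sketch of that split, assuming the same helpers, variables, and broadcasts as in the question (read_tensor_from_image_file, load_labels, input_layer, output_layer, input_height/input_width/input_mean/input_std, the model_data_* broadcasts, the label_file_* paths, and the product2 DataFrame). One way to run the setup once per worker is to switch from rdd.map to rdd.mapPartitions, so the four graphs and sessions are built once per partition and each image only pays for download, preprocessing, and Session.run; this is an illustration of the idea, not a drop-in replacement:

import urllib.request

import numpy as np
import tensorflow as tf

def build_session(model_bytes):
    # Setup: import one frozen graph and keep a long-lived session for it.
    graph = tf.Graph()
    with graph.as_default():
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(model_bytes)
        tf.import_graph_def(graph_def)
    sess = tf.Session(graph=graph)
    input_op = graph.get_operation_by_name("import/" + input_layer)
    output_op = graph.get_operation_by_name("import/" + output_layer)
    return sess, input_op, output_op

def label_partition(rows):
    # Setup runs once per Spark partition, not once per image.
    models = {
        'hemline': (build_session(model_data_hemline.value), load_labels(label_file_hemline)),
        'shape':   (build_session(model_data_shape.value),   load_labels(label_file_shape)),
        'length':  (build_session(model_data_length.value),  load_labels(label_file_length)),
        'sleeve':  (build_session(model_data_sleeve.value),  load_labels(label_file_sleeve)),
    }

    # Execution: per image, only download, preprocess, and Session.run.
    for row in rows:
        if row.modelno is None:
            continue
        path = "/tmp/" + row.guid
        urllib.request.urlretrieve(row.modelno, path)
        t = read_tensor_from_image_file(path,
                                        input_height=input_height,
                                        input_width=input_width,
                                        input_mean=input_mean,
                                        input_std=input_std)
        result = {'guid': row.guid}
        for attr, ((sess, input_op, output_op), labels) in models.items():
            scores = np.squeeze(sess.run(output_op.outputs[0],
                                         {input_op.outputs[0]: t}))
            result[attr] = labels[scores.argsort()[-1]]
        yield result

labels_rdd = product2.rdd.mapPartitions(label_partition)

If the script uses the TensorFlow for Poets version of read_tensor_from_image_file, note that it also builds its own ops on every call and may need the same once-only treatment for the per-image cost to stay flat.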
Another option is to clear the graph with tf.reset_default_graph before processing the next image, but this is not the preferred solution.
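For reference, that alternative would look roughly like the hypothetical helper below (classify_once is an illustrative name; input_layer and output_layer are the same assumed variables as in the question). It keeps operations from piling up across images, but it still re-imports the model for every single image, which is why the first approach is preferred:

import tensorflow as tf

def classify_once(model_bytes, image_tensor):
    # Discard whatever previous images left in the default graph, then
    # import the frozen model into the now-empty default graph.
    tf.reset_default_graph()
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(model_bytes)
    tf.import_graph_def(graph_def)
    g = tf.get_default_graph()
    input_op = g.get_operation_by_name("import/" + input_layer)
    output_op = g.get_operation_by_name("import/" + output_layer)
    with tf.Session() as sess:
        return sess.run(output_op.outputs[0], {input_op.outputs[0]: image_tensor})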