TensorFlow model takes exponentially longer on a large dataset

I am using TensorFlow for Poets to detect attributes in clothing images. I have trained four separate models (sleeve, shape, length, and hemline). I now pass each image URL to every model and store the results. Because I have a large amount of data (100,000 images), I use Spark to broadcast the four models once and pass an RDD of images to the models to detect the features. The processing time keeps increasing: it starts at about 3 seconds per image, and by the time the script has processed 10,000 images it has reached about 8 seconds per image. I am new to TensorFlow, and any suggestion on how to make the execution time linear would be greatly appreciated.

def getLabelDresses(file_name):
    resultDict = {}
    t = read_tensor_from_image_file(file_name,
                                    input_height=input_height,
                                    input_width=input_width,
                                    input_mean=input_mean,
                                    input_std=input_std)
    input_name = "import/" + input_layer
    output_name = "import/" + output_layer

    with tf.Graph().as_default() as g:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(model_data_hemline.value)
        tf.import_graph_def(graph_def)
        input_operation_hemline = g.get_operation_by_name(input_name)
        output_operation_hemline = g.get_operation_by_name(output_name)
        with tf.Session() as sess:
            results = sess.run(output_operation_hemline.outputs[0],
                               {input_operation_hemline.outputs[0]: t})
        results = np.squeeze(results)
        top_k = results.argsort()[-1:][::-1]
        labels = load_labels(label_file_hemline)
        resultDict['hemline'] = labels[top_k[0]]

    with tf.Graph().as_default() as g:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(model_data_shape.value)
        tf.import_graph_def(graph_def)
        input_operation_shape = g.get_operation_by_name(input_name)
        output_operation_shape = g.get_operation_by_name(output_name)
        with tf.Session() as sess:
            results = sess.run(output_operation_shape.outputs[0],
                               {input_operation_shape.outputs[0]: t})
        results = np.squeeze(results)
        top_k = results.argsort()[-1:][::-1]
        labels = load_labels(label_file_shape)
        resultDict['shape'] = labels[top_k[0]]

    with tf.Graph().as_default() as g:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(model_data_length.value)
        tf.import_graph_def(graph_def)
        input_operation_length = g.get_operation_by_name(input_name)
        output_operation_length = g.get_operation_by_name(output_name)
        with tf.Session() as sess:
            results = sess.run(output_operation_length.outputs[0],
                               {input_operation_length.outputs[0]: t})
        results = np.squeeze(results)
        top_k = results.argsort()[-1:][::-1]
        labels = load_labels(label_file_length)
        resultDict['length'] = labels[top_k[0]]

    with tf.Graph().as_default() as g:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(model_data_sleeve.value)
        tf.import_graph_def(graph_def)
        input_operation_sleeve = g.get_operation_by_name(input_name)
        output_operation_sleeve = g.get_operation_by_name(output_name)
        with tf.Session() as sess:
            results = sess.run(output_operation_sleeve.outputs[0],
                               {input_operation_sleeve.outputs[0]: t})
        results = np.squeeze(results)
        top_k = results.argsort()[-1:][::-1]
        labels = load_labels(label_file_sleeve)
        resultDict['sleeve'] = labels[top_k[0]]

    return resultDict

model_file_hemline = "/home/ubuntu/amit/feature-sensing/tf_files/attribute/dresses/hemline/retrained_graph_hemline.pb"
label_file_hemline = "/home/ubuntu/amit/feature-sensing/tf_files/attribute/dresses/hemline/retrained_labels_hemline.txt"
model_file_length = "/home/ubuntu/amit/feature-sensing/tf_files/attribute/dresses/length/retrained_graph_length.pb"
label_file_length = "/home/ubuntu/amit/feature-sensing/tf_files/attribute/dresses/length/retrained_labels_length.txt"
model_file_shape = "/home/ubuntu/amit/feature-sensing/tf_files/attribute/dresses/shape/retrained_graph_shape.pb"
label_file_shape = "/home/ubuntu/amit/feature-sensing/tf_files/attribute/dresses/shape/retrained_labels_shape.txt"
model_file_sleeve = "/home/ubuntu/amit/feature-sensing/tf_files/attribute/dresses/sleeve/retrained_graph_sleeve.pb"
label_file_sleeve = "/home/ubuntu/amit/feature-sensing/tf_files/attribute/dresses/sleeve/retrained_labels_sleeve.txt"

with gfile.FastGFile(model_file_hemline, "rb") as f:
    model_data = f.read()
    model_data_hemline = sc.broadcast(model_data)
with gfile.FastGFile(model_file_length, "rb") as f:
    model_data = f.read()
    model_data_length = sc.broadcast(model_data)
with gfile.FastGFile(model_file_shape, "rb") as f:
    model_data = f.read()
    model_data_shape = sc.broadcast(model_data)
with gfile.FastGFile(model_file_sleeve, "rb") as f:
    model_data = f.read()
    model_data_sleeve = sc.broadcast(model_data)

def calculate(row):
    path = "/tmp/" + row.guid
    url = row.modelno
    print(path, url)
    if url is not None:
        import urllib.request
        urllib.request.urlretrieve(url, path)
        t1 = time.time()
        result = getLabelDresses(path)
        print(time.time() - t1)
        print(result)
        return row
    return row

product2.rdd.map(calculate).collect()

Answer:

Each call to getLabelDresses in your code adds operations to the graph.

Split your code into a setup part (model loading) and an execution part: the former is executed only once, while the latter runs for every image and should contain only the call to Session.run.
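A minimal sketch of that split, keeping the TF1-style API from the question and assuming the broadcast model bytes (model_data_hemline and the others), the globals input_layer and output_layer, and the helpers read_tensor_from_image_file and load_labels are available as in your script. In a Spark job the setup would run once per worker (for example inside mapPartitions), not once per image:

import numpy as np
import tensorflow as tf

def make_runner(model_bytes, label_file):
    """Build one graph and one session up front; return a per-image classify function."""
    graph = tf.Graph()
    with graph.as_default():
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(model_bytes)
        tf.import_graph_def(graph_def)
    sess = tf.Session(graph=graph)  # created once, reused for every image
    input_op = graph.get_operation_by_name("import/" + input_layer)
    output_op = graph.get_operation_by_name("import/" + output_layer)
    labels = load_labels(label_file)

    def classify(image_tensor):
        # Execution part: only Session.run, no graph construction.
        results = sess.run(output_op.outputs[0],
                           {input_op.outputs[0]: image_tensor})
        return labels[np.argmax(np.squeeze(results))]

    return classify

# Setup part: executed once.
runners = {
    'hemline': make_runner(model_data_hemline.value, label_file_hemline),
    'shape':   make_runner(model_data_shape.value,   label_file_shape),
    'length':  make_runner(model_data_length.value,  label_file_length),
    'sleeve':  make_runner(model_data_sleeve.value,  label_file_sleeve),
}

# Execution part: runs for every image.
def getLabelDresses(file_name):
    t = read_tensor_from_image_file(file_name,
                                    input_height=input_height,
                                    input_width=input_width,
                                    input_mean=input_mean,
                                    input_std=input_std)
    return {name: classify(t) for name, classify in runners.items()}

Because each graph and session is built once and reused, no operations accumulate and the per-image cost stays roughly constant.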

Another option is to clear the graph with tf.reset_default_graph before processing the next image, but this is not the preferred solution.
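A minimal sketch of that workaround, modeled on the calculate function from the question, with the reset placed before any per-image graph construction happens:

import urllib.request
import tensorflow as tf

def calculate(row):
    # Clear the default graph so ops added while preprocessing the previous
    # image (e.g. by read_tensor_from_image_file) do not accumulate.
    tf.reset_default_graph()
    path = "/tmp/" + row.guid
    url = row.modelno
    if url is not None:
        urllib.request.urlretrieve(url, path)
        result = getLabelDresses(path)
        print(result)
    return row

This keeps the default graph from growing, but the four model graphs are still rebuilt and four sessions reopened for every image, which is why the first approach is preferred.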
