TensorFlow模型在大数据集上耗时呈指数增长

我正在使用TensorFlow for Poets来检测服装图片中的特征。我已经训练了4个不同的模型(袖子、形状、长度和下摆)。现在我将图片URL传递给每个模型并存储结果。由于我的数据量很大(10万张图片),因此使用Spark来广播4个模型一次,并将图片RDD传递给模型以检测特征。这耗时呈指数增长。从每张图片3秒开始,执行时间不断增加。当脚本已经处理了10000张图片时,执行时间达到了每张图片8秒。我对TensorFlow还不熟,如果能得到任何让执行时间变为线性的建议,将不胜感激。

def getLabelDresses(file_name):    resultDict = {}    t = read_tensor_from_image_file(file_name,                              input_height=input_height,                              input_width=input_width,                              input_mean=input_mean,                              input_std=input_std)    input_name = "import/" + input_layer    output_name = "import/" + output_layer    with tf.Graph().as_default() as g:        graph_def = tf.GraphDef()        graph_def.ParseFromString(model_data_hemline.value)        tf.import_graph_def(graph_def)        input_operation_hemline = g.get_operation_by_name(input_name);        output_operation_hemline = g.get_operation_by_name(output_name);        with tf.Session() as sess:            results = sess.run(output_operation_hemline.outputs[0],{input_operation_hemline.outputs[0]: t})        results = np.squeeze(results)        top_k = results.argsort()[-1:][::-1]        labels = load_labels(label_file_hemline)        resultDict['hemline'] = labels[top_k[0]]    with tf.Graph().as_default() as g:        graph_def = tf.GraphDef()        graph_def.ParseFromString(model_data_shape.value)        tf.import_graph_def(graph_def)        input_operation_shape = g.get_operation_by_name(input_name);        output_operation_shape = g.get_operation_by_name(output_name);        with tf.Session() as sess:            results = sess.run(output_operation_shape.outputs[0],{input_operation_shape.outputs[0]: t})        results = np.squeeze(results)        top_k = results.argsort()[-1:][::-1]        labels = load_labels(label_file_shape)        resultDict['shape'] = labels[top_k[0]]    with tf.Graph().as_default() as g:        graph_def = tf.GraphDef()        graph_def.ParseFromString(model_data_length.value)        tf.import_graph_def(graph_def)        input_operation_length = g.get_operation_by_name(input_name);        output_operation_length = g.get_operation_by_name(output_name);        with tf.Session() as sess:            results = sess.run(output_operation_length.outputs[0],{input_operation_length.outputs[0]: t})        results = np.squeeze(results)        top_k = results.argsort()[-1:][::-1]        labels = load_labels(label_file_length)        resultDict['length'] = labels[top_k[0]]    with tf.Graph().as_default() as g:        graph_def = tf.GraphDef()        graph_def.ParseFromString(model_data_sleeve.value)        tf.import_graph_def(graph_def)        input_operation_sleeve = g.get_operation_by_name(input_name);        output_operation_sleeve = g.get_operation_by_name(output_name);        with tf.Session() as sess:            results = sess.run(output_operation_sleeve.outputs[0],{input_operation_sleeve.outputs[0]: t})        results = np.squeeze(results)        top_k = results.argsort()[-1:][::-1]        labels = load_labels(label_file_sleeve)        resultDict['sleeve'] = labels[top_k[0]]         return resultDict;model_file_hemline = "/home/ubuntu/amit/feature-sensing/tf_files/attribute/dresses/hemline/retrained_graph_hemline.pb"label_file_hemline = "/home/ubuntu/amit/feature-sensing/tf_files/attribute/dresses/hemline/retrained_labels_hemline.txt"model_file_length = "/home/ubuntu/amit/feature-sensing/tf_files/attribute/dresses/length/retrained_graph_length.pb"label_file_length = "/home/ubuntu/amit/feature-sensing/tf_files/attribute/dresses/length/retrained_labels_length.txt"model_file_shape = "/home/ubuntu/amit/feature-sensing/tf_files/attribute/dresses/shape/retrained_graph_shape.pb"label_file_shape = "/home/ubuntu/amit/feature-sensing/tf_files/attribute/dresses/shape/retrained_labels_shape.txt"model_file_sleeve = "/home/ubuntu/amit/feature-sensing/tf_files/attribute/dresses/sleeve/retrained_graph_sleeve.pb"label_file_sleeve = "/home/ubuntu/amit/feature-sensing/tf_files/attribute/dresses/sleeve/retrained_labels_sleeve.txt"with gfile.FastGFile(model_file_hemline, "rb") as f:    model_data = f.read()    model_data_hemline = sc.broadcast(model_data)with gfile.FastGFile(model_file_length, "rb") as f:    model_data = f.read()    model_data_length = sc.broadcast(model_data)with gfile.FastGFile(model_file_shape, "rb") as f:    model_data = f.read()    model_data_shape = sc.broadcast(model_data)with gfile.FastGFile(model_file_sleeve, "rb") as f:    model_data = f.read()    model_data_sleeve = sc.broadcast(model_data)def calculate(row):    path = "/tmp/"+row.guid    url = row.modelno    print(path, url)    if(url is not None):        import urllib.request        urllib.request.urlretrieve(url, path)        t1=time.time()         result = getLabelDresses(path)        print(time.time()-t1)        print(result)        return row    return rowproduct2.rdd.map(calculate).collect()

回答:

您的代码中每次调用getLabelDresses都会向图中添加操作。

将您的代码拆分为设置(模型加载)部分和执行部分,前者只执行一次,后者对每张图片执行。后者应只包含对Session.run的调用。

另一个选项是在处理下一张图片之前使用tf.reset_default_graph清除图。但这不是首选方案。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注