使用TensorFlow对实时视频进行分类

我正在使用这个教程来开始学习TensorFlow – TensorFlow for poets

在使用retrain.py脚本训练模型后,我希望使用retrained_graph.pb来对视频进行分类,并在视频运行时实时查看结果。

我所做的是使用opencv来读取我想分类的视频,逐帧读取。即读取一帧,保存它,打开它,分类它,并使用cv2.imshow()在屏幕上显示它和分类结果。

这样确实可以工作,但由于从磁盘读取和写入帧,导致视频出现延迟。

我能否使用训练过程中获得的图形,直接对视频进行分类,而不需要逐帧读取和保存?

这是我使用的代码 –

with tf.Session(graph=graph) as sess:video_capture = cv2.VideoCapture(video_path)i = 0while True:    frame = video_capture.read()[1] # get current frame    frameId = video_capture.get(1) #current frame number    i = i + 1    cv2.imwrite(filename="C:\\video_images\\"+ str(i) +".jpg", img=frame) # write frame image to file    image_data = "C:\\video_images\\" + str(i) + ".jpg"    t = read_tensor_from_image_file(image_data,                                    input_height=input_height,                                    input_width=input_width,                                    input_mean=input_mean,                                    input_std=input_std)    predictions = sess.run(output_operation.outputs[0], {input_operation.outputs[0]: t})    top_k = predictions[0].argsort()[-len(predictions[0]):][::-1]    scores = []    for node_id in top_k:        human_string = label_lines[node_id]        score = predictions[0][node_id]        scores.append([score, human_string])        #print('%s (score = %.5f)' % (human_string, score))    #print("\n\n")    font = cv2.FONT_HERSHEY_SIMPLEX    cv2.putText(frame, scores[0][1] + " - " + repr(round(scores[0][0], 2)), (10, 50), font, 1, (0, 0, 255), 2, cv2.LINE_AA)    cv2.putText(frame, scores[1][1] + " - " + repr(round(scores[1][0], 2)), (10, 100), font, 1, (0, 0, 255), 2, cv2.LINE_AA)    cv2.imshow("image", frame)    cv2.waitKey(1)    os.remove("C:\\video_images\\" + str(i) + ".jpg")video_capture.release()cv2.destroyAllWindows()

谢谢。


回答:

已经解决了这个问题。

我编辑了read_tensor_from_image_file函数如下,并直接用帧代替image_data作为输入。

def read_tensor_from_image_file(file_name,                            input_height=299,                            input_width=299,                            input_mean=0,                            input_std=255):input_name = "file_reader"output_name = "normalized"if type(file_name) is str:    file_reader = tf.read_file(file_name, input_name)    if file_name.endswith(".png"):        image_reader = tf.image.decode_png(file_reader, channels = 3,                                           name='png_reader')    elif file_name.endswith(".gif"):        image_reader = tf.squeeze(tf.image.decode_gif(file_reader,                                                      name='gif_reader'))    elif file_name.endswith(".bmp"):        image_reader = tf.image.decode_bmp(file_reader, name='bmp_reader')    else:        image_reader = tf.image.decode_jpeg(file_reader, channels = 3,                                            name='jpeg_reader')    float_caster = tf.cast(image_reader, tf.float32)    dims_expander = tf.expand_dims(float_caster, 0);    resized = tf.image.resize_bilinear(dims_expander, [input_height,                                                       input_width])    normalized = tf.divide(tf.subtract(resized, [input_mean]),                            [input_std])    sess = tf.Session()    result = sess.run(normalized)elif type(file_name) is np.ndarray:    resized = cv2.resize(file_name, (input_width, input_height),                            interpolation=cv2.INTER_LINEAR)    normalized = (resized - input_mean) / input_std    result = normalized    result = array(result).reshape(1, 224, 224, 3)return result

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注