ValueError: 无效的十进制整数字面值在Alexnet上

嘿,我在运行Alexnet特征提取代码时遇到了一个错误。我使用这个GitHub链接创建了alexnet.pb文件。我使用Tensorboard进行了检查,图形显示正常。

我想使用这个模型从fc7/relu层提取特征,并将其输入到另一个模型中。我使用以下代码创建图形:

data = 0model_dir = 'model'images_dir = 'images_alexnet/train/' + str(data) + '/'list_images = [images_dir+f for f in os.listdir(images_dir) if re.search('jpeg|JPEG', f)]list_images.sort()def create_graph():    with gfile.FastGFile(os.path.join(model_dir, 'alexnet.pb'), 'rb') as f:        graph_def = tf.GraphDef()        graph_def.ParseFromString(f.read())        _ = tf.import_graph_def(graph_def, name='')create_graph()

然后我使用以下代码输入input并提取fc7/relu层:

def extract_features(image_paths, verbose=False):            feature_dimension = 4096    features = np.empty((len(image_paths), feature_dimension))    with tf.Session() as sess:        flattened_tensor = sess.graph.get_tensor_by_name('fc7/relu:0')        for i, image_path in enumerate(image_paths):            if verbose:                print('正在处理 %s...' % (image_path))            if not gfile.Exists(image_path):                tf.logging.fatal('文件不存在 %s', image)            image_data = gfile.FastGFile(image_path, 'rb').read()            feature = sess.run(flattened_tensor, {'input:0': image_data})            features[i, :] = np.squeeze(feature)    return features

但我得到了这个错误:

ValueError: invalid literal for int() with base 10: b'\xff\xd8\xff\xe0\x00\x10JFIF\x00\x01\x01\x00\x00\x01\x00\x01\x00\x00\xff\xdb\x00C\x00\x08\x06\x06\x07\x06\x05\x08\x07\x07\x07\t\t\x08\n\x0c\x14\r\x0c\x0b\x0b\x0c\x19\x12\x13\x0f\x14\x1d\x1a\x1f\x1e\

看起来我在输入图形时做错了。我使用Tensorboard查看图形,似乎占位符的dtypeuint8。我该如何解决这个问题?

完整错误信息:

  File "C:\ProgramData\Anaconda3\lib\site-packages\spyder\utils\site\sitecustomize.py", line 710, in runfile    execfile(filename, namespace)  File "C:\ProgramData\Anaconda3\lib\site-packages\spyder\utils\site\sitecustomize.py", line 101, in execfile    exec(compile(f.read(), filename, 'exec'), namespace)  File "C:/Users/Hermon Jay/Documents/Python/diabetic_retinopathy_temp6_transfer_learning/feature_extraction_alexnet.py", line 49, in <module>    features = extract_features(list_images)  File "C:/Users/Hermon Jay/Documents/Python/diabetic_retinopathy_temp6_transfer_learning/feature_extraction_alexnet.py", line 44, in extract_features    feature = sess.run(flattened_tensor, {'input:0': image_data})  File "C:\ProgramData\Anaconda3\lib\site-packages\tensorflow\python\client\session.py", line 889, in run    run_metadata_ptr)  File "C:\ProgramData\Anaconda3\lib\site-packages\tensorflow\python\client\session.py", line 1089, in _run    np_val = np.asarray(subfeed_val, dtype=subfeed_dtype)  File "C:\ProgramData\Anaconda3\lib\site-packages\numpy\core\numeric.py", line 531, in asarray    return array(a, dtype, copy=False, order=order)ValueError: invalid literal for int() with base 10: b'\xff\xd8\xff\xe0\x00\x10JFIF\x00\x01\x01\x00\x00\x01\x00\x01\x00\x00\xff\xdb\x00C\x00\x08\x06\x06\x07\x06\x05\x08\x07\x07\x07\t\t\x08\n\x0c\x14\r\x0c\x0b\x0b\x0c\x19\x12\x13\x0f\x14\x1d\x1a\x1f\x1e\

回答:

这一行代码:

image_data = gfile.FastGFile(image_path, 'rb').read()

正在将image_path处的文件读取为字节数组。然而,input占位符期望的是一个类型为uint8的四维数组。例如,你可以参考你提供的链接中的下一个教程,10 AlexNet Transfer Learning;其中get_batch函数使用额外的图形和操作如tf.image.decode_jpeg来生成批次;然后将该图形的结果作为输入输入到主网络图形中。

例如,如果你的所有图像都能装入内存,你可以这样做(否则你需要像教程中那样进行批处理):

def read_images(image_paths):    with tf.Graph().as_default(), tf.Session() as sess:        file_name = tf.placeholder(tf.string)        jpeg_data = tf.read_file(jpeg_name)        decoded_image = tf.image.decode_jpeg(jpeg_data, channels=3)        images = []        for path in image_paths:            images.append(sess.run(decoded_image, feed_dict={file_name: path}))        return imagesdef extract_features(image_paths):    images = read_images(image_paths)    with tf.Session() as sess:        flattened_tensor = sess.graph.get_tensor_by_name('fc7/relu:0')        return sess.run(flattened_tensor, {'input:0': images})

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注