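Hey, I'm getting an error when running AlexNet feature-extraction code. I created the alexnet.pb file using this GitHub link. I checked it with TensorBoard and the graph looks fine.
I want to use this model to extract features from the fc7/relu layer and feed them into another model. I create the graph with the following code: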
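    # Imports are assumed here; the original post did not show them.
    import os
    import re

    import numpy as np
    import tensorflow as tf
    from tensorflow.python.platform import gfile

    data = 0
    model_dir = 'model'
    images_dir = 'images_alexnet/train/' + str(data) + '/'
    list_images = [images_dir + f for f in os.listdir(images_dir) if re.search('jpeg|JPEG', f)]
    list_images.sort()

    def create_graph():
        # Load the frozen AlexNet graph into the default graph.
        with gfile.FastGFile(os.path.join(model_dir, 'alexnet.pb'), 'rb') as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())
            _ = tf.import_graph_def(graph_def, name='')

    create_graph()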
Then I feed the input and extract the fc7/relu layer with the following code:
    def extract_features(image_paths, verbose=False):
        feature_dimension = 4096
        features = np.empty((len(image_paths), feature_dimension))
        with tf.Session() as sess:
            flattened_tensor = sess.graph.get_tensor_by_name('fc7/relu:0')
            for i, image_path in enumerate(image_paths):
                if verbose:
                    print('Processing %s...' % (image_path))
                if not gfile.Exists(image_path):
                    tf.logging.fatal('File does not exist %s', image_path)
                # Raw file bytes are fed straight into the placeholder here.
                image_data = gfile.FastGFile(image_path, 'rb').read()
                feature = sess.run(flattened_tensor, {'input:0': image_data})
                features[i, :] = np.squeeze(feature)
        return features
But I get this error:
ValueError: invalid literal for int() with base 10: b'\xff\xd8\xff\xe0\x00\x10JFIF\x00\x01\x01\x00\x00\x01\x00\x01\x00\x00\xff\xdb\x00C\x00\x08\x06\x06\x07\x06\x05\x08\x07\x07\x07\t\t\x08\n\x0c\x14\r\x0c\x0b\x0b\x0c\x19\x12\x13\x0f\x14\x1d\x1a\x1f\x1e\
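It looks like I'm doing something wrong when feeding the graph. I looked at the graph in TensorBoard, and the placeholder's dtype appears to be uint8. How can I fix this?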
Full error message:
    File "C:\ProgramData\Anaconda3\lib\site-packages\spyder\utils\site\sitecustomize.py", line 710, in runfile
      execfile(filename, namespace)
    File "C:\ProgramData\Anaconda3\lib\site-packages\spyder\utils\site\sitecustomize.py", line 101, in execfile
      exec(compile(f.read(), filename, 'exec'), namespace)
    File "C:/Users/Hermon Jay/Documents/Python/diabetic_retinopathy_temp6_transfer_learning/feature_extraction_alexnet.py", line 49, in <module>
      features = extract_features(list_images)
    File "C:/Users/Hermon Jay/Documents/Python/diabetic_retinopathy_temp6_transfer_learning/feature_extraction_alexnet.py", line 44, in extract_features
      feature = sess.run(flattened_tensor, {'input:0': image_data})
    File "C:\ProgramData\Anaconda3\lib\site-packages\tensorflow\python\client\session.py", line 889, in run
      run_metadata_ptr)
    File "C:\ProgramData\Anaconda3\lib\site-packages\tensorflow\python\client\session.py", line 1089, in _run
      np_val = np.asarray(subfeed_val, dtype=subfeed_dtype)
    File "C:\ProgramData\Anaconda3\lib\site-packages\numpy\core\numeric.py", line 531, in asarray
      return array(a, dtype, copy=False, order=order)
    ValueError: invalid literal for int() with base 10: b'\xff\xd8\xff\xe0\x00\x10JFIF\x00\x01\x01\x00\x00\x01\x00\x01\x00\x00\xff\xdb\x00C\x00\x08\x06\x06\x07\x06\x05\x08\x07\x07\x07\t\t\x08\n\x0c\x14\r\x0c\x0b\x0b\x0c\x19\x12\x13\x0f\x14\x1d\x1a\x1f\x1e\
Answer:
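This line:

    image_data = gfile.FastGFile(image_path, 'rb').read()

reads the file at image_path as an array of raw bytes. The input placeholder, however, expects a four-dimensional array of type uint8, that is, decoded pixel values rather than the encoded JPEG file. Take a look, for example, at the next tutorial from the link you provided, 10 AlexNet Transfer Learning: its get_batch function uses an additional graph with ops such as tf.image.decode_jpeg to produce the batches, and the result of that graph is then fed as input to the main network graph.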
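To see where the message comes from: the traceback shows NumPy coercing the fed value to the placeholder's integer dtype, and a bytes object converts to an integer only when it holds ASCII digits. The following is a minimal standard-library sketch of that conversion, not TensorFlow code. `jpeg_header` is a stand-in copied from the start of the error message rather than a real image, and `try_int` is a hypothetical helper name.

```python
# Stand-in for the first bytes of a JPEG file (copied from the error message).
jpeg_header = b'\xff\xd8\xff\xe0\x00\x10JFIF'


def try_int(raw):
    """Return (value, None) on success or (None, error message) on failure."""
    try:
        return int(raw), None
    except ValueError as exc:
        return None, str(exc)


# Bytes made of ASCII digits convert without complaint.
digits_value, digits_error = try_int(b'42')

# Encoded JPEG bytes are not digits, so the conversion fails.
value, message = try_int(jpeg_header)
print(message)
```

The printed message begins with the same "invalid literal for int() with base 10" text as the traceback, which suggests the fix belongs in what is fed (decoded pixels), not in the placeholder.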
If all your images fit in memory, you could do something like this (otherwise you would need to batch them as in the tutorial):
    def read_images(image_paths):
        # Decode the JPEG files in a separate, temporary graph.
        with tf.Graph().as_default(), tf.Session() as sess:
            file_name = tf.placeholder(tf.string)
            jpeg_data = tf.read_file(file_name)
            decoded_image = tf.image.decode_jpeg(jpeg_data, channels=3)
            images = []
            for path in image_paths:
                images.append(sess.run(decoded_image, feed_dict={file_name: path}))
            return images

    def extract_features(image_paths):
        images = read_images(image_paths)
        with tf.Session() as sess:
            flattened_tensor = sess.graph.get_tensor_by_name('fc7/relu:0')
            # Feeding a list works only if every decoded image has the same height and width.
            return sess.run(flattened_tensor, {'input:0': images})
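If the images do not all fit in memory, one option is to walk through the paths in fixed-size chunks and run each chunk separately. Below is a minimal sketch of the chunking step only; `iter_batches` and `batch_size` are hypothetical names, not part of the tutorial or of TensorFlow, and the file names are made up.

```python
def iter_batches(items, batch_size):
    """Yield consecutive slices of `items`, each with at most `batch_size` elements."""
    if batch_size < 1:
        raise ValueError('batch_size must be at least 1')
    for start in range(0, len(items), batch_size):
        yield items[start:start + batch_size]


# Hypothetical file names, used only to show the chunking.
paths = ['a.jpeg', 'b.jpeg', 'c.jpeg', 'd.jpeg', 'e.jpeg']
for batch in iter_batches(paths, 2):
    print(batch)
```

Each batch of paths could then be passed to a function like extract_features above, with the per-batch feature arrays joined afterwards, for instance with np.concatenate.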