I'm trying to build an image captioning model on Colab using Hugging Face's BLIP-2 model. My code worked fine until last week (November 8), but now it throws an exception.
I install the packages with the following command:

```
!pip install -q git+https://github.com/huggingface/peft.git transformers bitsandbytes datasets
```
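I load the BLIP-2 processor and model with the following code:

```python
from transformers import AutoProcessor, Blip2ForConditionalGeneration

model_name = "Salesforce/blip2-opt-2.7b"
processor = AutoProcessor.from_pretrained(model_name)
model = Blip2ForConditionalGeneration.from_pretrained(model_name, device_map="auto", load_in_8bit=False)
```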
I generate captions with this function:

```python
import torch
from PIL import Image as PILImage

def generate_caption(processor, model, image_path):
    image = PILImage.open(image_path).convert("RGB")
    # image.size is a (width, height) tuple, so pass it as a separate argument
    # instead of concatenating it to a str (which raises TypeError)
    print("image shape:", image.size)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Preprocess the image
    inputs = processor(images=image, return_tensors="pt").to(device)
    print("Input shape:", inputs["pixel_values"].shape)
    print("Device:", device)

    # Additional debugging
    for key, value in inputs.items():
        print(f"Key: {key}, Shape: {value.shape}")

    # Generate caption
    with torch.no_grad():
        generated_ids = model.generate(**inputs)
    caption = processor.decode(generated_ids[0], skip_special_tokens=True)
    return caption
```
And this is the code that calls it:

```python
image_path = "my_image_path.jpg"
caption = generate_caption(processor, model, image_path)
print(f"{image_path}: {caption}")
```
Finally, here is the output and the error from running the code above:

```
image shape: (320, 240)
Input shape: torch.Size([1, 3, 224, 224])
Device: cuda
Key: pixel_values, Shape: torch.Size([1, 3, 224, 224])
---------------------------------------------------------------------------
...
/usr/local/lib/python3.10/dist-packages/transformers/models/blip_2/modeling_blip_2.py in generate(self, pixel_values, input_ids, attention_mask, interpolate_pos_encoding, **generate_kwargs)
   2314         if getattr(self.config, "image_token_index", None) is not None:
   2315             special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds)
-> 2316             inputs_embeds[special_image_mask] = language_model_inputs.flatten()
   2317         else:
   2318             logger.warning_once(

RuntimeError: shape mismatch: value tensor of shape [81920] cannot be broadcast to indexing result of shape [0]
```
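I have searched online and asked various AI models for help, without success. I suspect a package update is responsible, since the code worked fine last week. (I tried reverting my code to the November 8 version, but it still throws the exception.) I also don't understand how the 81920 in the error message is computed.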
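As far as I can tell from the traceback, 81920 is the number of elements in `language_model_inputs.flatten()`: one projected Q-Former embedding per query token, each as wide as the language model's hidden size. The sketch below does that arithmetic, assuming 32 query tokens and an OPT-2.7b hidden size of 2560 (values I believe come from the checkpoint's config; check `model.config.num_query_tokens` and `model.config.text_config.hidden_size` to confirm). The `[0]` side is the number of elements selected by `special_image_mask`, which is consistent with `input_ids` containing no image-token positions at all when no text is passed to the processor.

```python
# Assumed values for Salesforce/blip2-opt-2.7b; verify against model.config
NUM_QUERY_TOKENS = 32   # model.config.num_query_tokens (assumed)
LM_HIDDEN_SIZE = 2560   # model.config.text_config.hidden_size for OPT-2.7b (assumed)


def value_numel(batch_size: int = 1) -> int:
    """Elements in language_model_inputs.flatten(), shape [batch, query_tokens, hidden]."""
    return batch_size * NUM_QUERY_TOKENS * LM_HIDDEN_SIZE


def mask_numel(num_image_tokens: int) -> int:
    """Elements selected by special_image_mask.

    The mask is (input_ids == image_token_index) expanded over the hidden
    dimension, so every image token found in input_ids selects LM_HIDDEN_SIZE slots.
    """
    return num_image_tokens * LM_HIDDEN_SIZE


print(value_numel())   # 81920: the value tensor in the error message
print(mask_numel(0))   # 0: no image tokens in input_ids, so the indexing result is empty
print(mask_numel(32))  # 81920: one image token per query token, and the sizes match
```

Under these assumptions the assignment can only succeed when `input_ids` holds exactly `NUM_QUERY_TOKENS` image tokens per image.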
Answer:
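I ran into the same problem. You need to pass a text prompt to the processor:

```python
prompt = " "
# The float16 cast assumes the model itself was loaded in float16; with the default
# float32 model from the question, omit dtype (or pass dtype=model.dtype) instead.
inputs = processor(images=image, text=prompt, return_tensors="pt").to(device="cuda", dtype=torch.float16)
```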
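Folding that fix into the question's function gives roughly the following. This is an untested sketch: the name `caption_image` and the defaults `prompt=" "` and `max_new_tokens=30` are my own choices, and it assumes (my understanding, not stated in the thread) that passing `text` is what makes recent `Blip2Processor` versions insert the image placeholder tokens into `input_ids`. Like the answer's code, it relies on the processor output's `.to()` accepting a `dtype`; using `model.device` and `model.dtype` keeps the inputs matched to however the model was loaded.

```python
def caption_image(processor, model, image, prompt=" ", max_new_tokens=30):
    """Caption a PIL image with BLIP-2, always passing a text prompt (hypothetical helper)."""
    # Assumption: supplying text lets the processor add the image placeholder
    # tokens to input_ids, so special_image_mask is no longer empty
    inputs = processor(images=image, text=prompt, return_tensors="pt")
    # Move tensors to the model's device; the dtype cast is meant for pixel_values
    inputs = inputs.to(device=model.device, dtype=model.dtype)
    # generate() already runs without gradient tracking, so no torch.no_grad() here
    generated_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
    # strip() drops the leading whitespace left over from the " " prompt
    return processor.decode(generated_ids[0], skip_special_tokens=True).strip()
```

Usage: `caption_image(processor, model, PILImage.open("my_image_path.jpg").convert("RGB"))`.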
Hope this helps.