我在编写一个图像分类器,并且已经定义了加载器,但遇到了这个错误,我完全不知道原因是什么。
我已经定义了训练加载器,为了更好地解释,我尝试了以下代码:
for ina,lab in train_loader: print(type(ina)) print(type(lab))
结果我得到了:
<class 'torch.Tensor'><class 'tuple'>
现在,为了训练模型,我做了以下操作:
def train_model(model,optimizer,n_epochs,criterion): start_time = time.time() for epoch in range(1,n_epochs-1): epoch_time = time.time() epoch_loss = 0 correct = 0 total = 0 print( "Epoch {}/{}".format(epoch,n_epochs)) model.train() for inputs,labels in train_loader: inputs = inputs.to(device) labels = labels.to(device) optimizer.zero_grad() output = model(inputs) loss = criterion(output,labels) loss.backward() optimizer.step() epoch_loss +=loss.item() _,pred =torch.max(output,1) correct += (pred.cpu()==label.cpu()).sum().item() total +=labels.shape[0] acc = correct/total
然后我得到了以下错误:
Epoch 1/15---------------------------------------------------------------------------AttributeError Traceback (most recent call last)<ipython-input-36-fea243b3636a> in <module>----> 1 train_model(model=arch, optimizer=optim, n_epochs=15, criterion=criterion)<ipython-input-34-b53149a4bac0> in train_model(model, optimizer, n_epochs, criterion) 12 for inputs,labels in train_loader: 13 inputs = inputs.to(device)---> 14 labels = labels.to(device) 15 optimizer.zero_grad() 16 output = model(inputs)AttributeError: 'tuple' object has no attribute 'to'
如果你需要更多信息,请告诉我!谢谢
编辑:标签看起来是这样的。这是一个蜜蜂和黄蜂的图像分类。它还包含了昆虫和非昆虫
(‘wasp’, ‘wasp’, ‘insect’, ‘insect’, ‘wasp’, ‘insect’, ‘insect’, ‘wasp’, ‘wasp’, ‘bee’, ‘insect’, ‘insect’, ‘other’, ‘bee’, ‘other’, ‘wasp’, ‘other’, ‘wasp’, ‘bee’, ‘bee’, ‘wasp’, ‘wasp’, ‘wasp’, ‘wasp’, ‘bee’, ‘wasp’, ‘wasp’, ‘other’, ‘bee’, ‘wasp’, ‘bee’, ‘bee’)(‘wasp’, ‘wasp’, ‘insect’, ‘bee’, ‘other’, ‘wasp’, ‘insect’, ‘wasp’, ‘insect’, ‘insect’, ‘insect’, ‘wasp’, ‘wasp’, ‘insect’, ‘wasp’, ‘wasp’, ‘wasp’, ‘bee’, ‘wasp’, ‘wasp’, ‘insect’, ‘insect’, ‘wasp’, ‘wasp’, ‘bee’, ‘wasp’, ‘insect’, ‘bee’, ‘bee’, ‘insect’, ‘insect’, ‘other’)
回答:
这字面上意味着Python中的tuple类没有名为to
的方法。因为你试图将标签放到设备上,只需使用labels = torch.tensor(labels).to(device)
即可。
如果你不想这样做,你可以通过让DataLoader返回PyTorch张量而不是元组来改变它的工作方式。
编辑
由于标签似乎是字符串,我建议先将它们转换为独热编码向量:
>>> import torch>>> labels_unique = set(labels)>>> keys = {key: value for key, value in zip(labels_unique, range(len(labels_unique)))}>>> labels_onehot = torch.zeros(size=(len(labels), len(keys)))>>> for idx, label in enumerate(labels_onehot):... labels_onehot[idx][keys[label]] = 1...>>> labels_onehot = labels.to(device)
我在这里有点盲目尝试,因为我并不完全了解细节,但确实字符串无法与张量一起使用。