I ran into this error while trying to build an AI with PyTorch:
RuntimeError: gather_out_cpu(): Expected dtype int64 for index
Here is my function:

    def learn(self, batch_state, batch_next_state, batch_reward, batch_action):
        outputs = self.model(batch_state).gather(1, batch_action.unsqueeze(1)).squeeze(1)
        next_outputs = self.model(batch_next_state).detach().max(1)[0]
        target = self.gamma * next_outputs + batch_reward
        td_loss = F.smooth_l1_loss(outputs, target)
        self.optimizer.zero_grad()
        td_loss.backward(retain_variables = True)
        self.optimizer.step()
Answer:
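torch.gather requires its index tensor to have dtype int64, so you need to convert the batch_action tensor to that dtype before passing it to gather.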
    def learn(...):
        batch_action = batch_action.type(torch.int64)
        outputs = ...
        ...

    # or
    outputs = self.model(batch_state).gather(1, batch_action.type(torch.int64).unsqueeze(1)).squeeze(1)
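
To see the fix in isolation, here is a minimal sketch. The q_values tensor is a made-up stand-in for the output of self.model(batch_state), and the int32 batch_action is only an assumption about how the actions were stored in the original code. It uses .long(), which is shorthand for .type(torch.int64).

```python
import torch

# Hypothetical stand-in for self.model(batch_state): 2 states, 3 actions each.
q_values = torch.tensor([[0.1, 0.9, 0.3],
                         [0.7, 0.2, 0.5]])

# Assumed for illustration: the actions were stored with a non-int64 dtype.
batch_action = torch.tensor([1, 0], dtype=torch.int32)

# Cast to int64 first, then add a column dimension so the index has shape (2, 1).
index = batch_action.long().unsqueeze(1)

# Pick, for each row, the Q-value of the action that was taken.
outputs = q_values.gather(1, index).squeeze(1)

print(index.dtype)    # dtype of the index actually passed to gather
print(outputs.shape)  # one selected Q-value per state
```

If the actions are created as int64 in the first place (for example with torch.tensor(actions, dtype=torch.int64) when the batch is built), the cast inside learn is not needed.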