pytorch “尝试第二次通过图进行反向传播”错误与字符级RNN

我正在使用pytorch训练一个字符级的GRU,并将文本分成一定长度的批次。这是训练循环:

for e in range(self.epochs):  self.model.train()  h = self.get_init_state(self.batch_size)    for batch_num in range(self.num_batch_runs):    batch = self.generate_batch(batch_num).to(device)        inp_batch = batch[:-1,:]    tar_batch = batch[1:,:]            self.model.zero_grad()    loss = 0    for i in range(inp_batch.shape[0]):      out, h = self.model(inp_batch[i:i+1,:],h)      loss += loss_fn(out[0],tar_batch[i].view(-1))              loss.backward()    nn.utils.clip_grad_norm_(self.model.parameters(), 5.0)    optimizer.step()        if not (batch_num % 5):      print("epoch: {}, loss: {}".format(e,loss.data.item()/inp_batch.shape[0]))

尽管如此,在第一个批次之后,我仍然遇到了这个错误:

Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.

提前感谢..


回答:

我自己找到了答案,GRU的隐藏状态仍然附着在最后一次批次运行上,因此需要使用以下方法断开连接:

h.detach_()

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注