PyTorch中LSTM的参数和函数调用

能否请有人解释以下代码:

import torchimport torch.nn as nninput = torch.randn(5, 3, 10)h0 = torch.randn(2, 3, 20)c0 = torch.randn(2, 3, 20)rnn = nn.LSTM(10,20,2)output, (hn, cn) = rnn(input, (h0, c0))print(input)

在调用rnn rnn(input, (h0, c0))时,我们在括号中传递了参数h0和c0。这意味着什么?如果(h0, c0)代表一个单一的值,那么这个值是什么?这里传递的第三个参数又是什么?然而,在rnn = nn.LSTM(10,20,2)这行中,我们在LSTM函数中没有使用括号就传递了参数。谁能解释一下这个函数调用是如何工作的?


回答:

赋值rnn = nn.LSTM(10, 20, 2)创建了一个新的nn.Module实例,使用了nn.LSTM类。其前三个参数分别是input_size(这里是10),hidden_size(这里是20)和num_layers(这里是2)。

另一方面,rnn(input, (h0, c0))实际上是在调用该类实例,即运行__call__方法,这大致相当于该模块的forward函数。nn.LSTM__call__方法接受两个参数:input(形状为(sequnce_length, batch_size, input_size))和一个包含两个张量的元组(h_0, c_0)(在nn.LSTM的基本用例中,这两个张量的形状都是(num_layers, batch_size, hidden_size)

在使用内置函数时,请随时参考PyTorch文档,您将找到参数列表的精确定义(用于初始化类实例的参数)以及输入/输出规格(在使用该模块推理时)。


您可能对这种符号表示法感到困惑,以下是一个小例子,可能会有所帮助:

  • 元组作为输入:

    def fn1(x, p):    a, b = p # 解包输入    return a*x + b>>> fn1(2, (3, 1))>>> 7
  • 元组作为输出

    def fn2(x):    return x, (3*x, x**2) # 实际上输出是一个整数和一个元组的元组 >>> x, (a, b) = fn2(2) # 解包(2, (6, 4))>>> x, a, b(2, 6, 4)

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注