使用pytorch进行句子分类的多类别分类(使用nn.LSTM)

我有一个来自这个教程的网络,我希望以句子作为输入(这已经完成),并得到一个单行张量作为结果。

根据教程,这句话“John’s dog likes food”返回了一个一列的张量:

tensor([[-3.0462, -4.0106, -0.6096],[-4.8205, -0.0286, -3.9045],[-3.7876, -4.1355, -0.0394],[-0.0185, -4.7874, -4.6013]])

…以及类别列表:

tag_list[ "name", "verb", "noun"]

每行都有与词相关的标签的概率。(第一个词有[-3.0462, -4.0106, -0.6096]向量,其中最后一个元素对应于得分最高的标签,“noun”)

教程中的数据集看起来像这样:

training_data = [    ("The dog ate the apple".split(), ["DET", "NN", "V", "DET", "NN"]),    ("Everybody read that book".split(), ["NN", "V", "DET", "NN"])]

而我希望我的数据集格式如下:

training_data = [    ("Hello world".split(), ["ONE"]),    ("I am dog".split(), ["TWO"]),    ("It's Britney glitch".split(), ["THREE"])]

参数定义如下:

class LSTMTagger(nn.Module):    def __init__(self, embedding_dim, hidden_dim, vocab_size, tagset_size):        super(LSTMTagger, self).__init__()        self.hidden_dim = hidden_dim        self.word_embeddings = nn.Embedding(vocab_size, embedding_dim)        self.lstm = nn.LSTM(embedding_dim, hidden_dim)        self.hidden2tag = nn.Linear(hidden_dim, tagset_size)    def forward(self, sentence):        embeds      = self.word_embeddings(sentence)        lstm_out, _ = self.lstm(embeds.view(len(sentence), 1, -1))        tag_space   = self.hidden2tag(lstm_out.view(len(sentence), -1))        tag_scores  = F.log_softmax(tag_space, dim=1)        return tag_scores

目前,输入和输出的尺寸不匹配,我得到的错误是:ValueError: Expected input batch_size (2) to match target batch_size (1)。

由于尺寸不匹配,似乎损失函数无法接受输入:

loss        = criterion(tag_scores, targets)

我读到最后一层可以定义为nn.Linear来压缩输出,但我似乎无法得到任何结果。尝试了其他损失函数

如何修改模型,使其对整个句子进行分类,而不是像原始教程中那样对每个词进行分类?


回答:

我通过简单地获取最后的隐藏状态解决了这个问题

tag_space   = self.hidden2tag(lstm_out[-1])

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注