max_length无法解决问答模型问题

我的问题:如何让我的“问答”模型运行,处理一个大(>512字节)的.txt文件?

背景:我正在使用谷歌的BERT词嵌入模型创建一个问答模型。当我导入一个包含几句话的.txt文件时,模型运行良好,但当.txt文件超出模型学习的512字节词语限制时,模型就无法回答我的问题。

我尝试解决问题的方法:我在编码部分设置了max_length,但这似乎没有解决问题(我的尝试代码如下)。

from transformers import AutoTokenizer, AutoModelForQuestionAnsweringimport torchmax_seq_length = 512tokenizer = AutoTokenizer.from_pretrained("henryk/bert-base-multilingual-cased-finetuned-dutch-squad2")model = AutoModelForQuestionAnswering.from_pretrained("henryk/bert-base-multilingual-cased-finetuned-dutch-squad2")f = open("test.txt", "r")text = str(f.read())questions = [    "Wat is de hoofdstad van Nederland?",    "Van welk automerk is een Cayenne?",    "In welk jaar is pindakaas geproduceerd?",]for question in questions:    inputs = tokenizer.encode_plus(question,                                    text,                                    add_special_tokens=True,                                    max_length=max_seq_length,                                   truncation=True,                                   return_tensors="pt")    input_ids = inputs["input_ids"].tolist()[0]    text_tokens = tokenizer.convert_ids_to_tokens(input_ids)    answer_start_scores, answer_end_scores = model(**inputs, return_dict=False)    answer_start = torch.argmax(        answer_start_scores    )  # Get the most likely beginning of answer with the argmax of the score    answer_end = torch.argmax(answer_end_scores) + 1  # Get the most likely end of answer with the argmax of the score    answer = tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(input_ids[answer_start:answer_end]))    print(f"Question: {question}")    print(f"Answer: {answer}\n")

代码结果:

> Question: Wat is de hoofdstad van Nederland?> Answer: [CLS]>> Question: Van welk automerk is een Cayenne?> Answer: [CLS]>> Question: In welk jaar is pindakaas geproduceerd?> Answer: [CLS]

可以看到,模型只返回了[CLS]标记,这是发生在标记器编码部分的。

编辑:我发现解决这个问题的办法是遍历.txt文件,让模型可以通过遍历找到答案。


回答:

编辑:我发现解决这个问题的办法是遍历.txt文件,让模型可以通过遍历找到答案。模型之所以回答[CLS],是因为它在512字节的上下文中找不到答案,它需要进一步查看上下文。

通过创建这样的循环:

with open("sample.txt", "r") as a_file:  for line in a_file:    text = line.strip()    print(text)

可以将遍历后的文本应用到encode_plus中。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注