如何优化SpaCy训练?

我目前正在训练一个用于多标签文本分类的SpaCy模型。有6个标签:愤怒、期待、厌恶、恐惧、喜悦、悲伤、惊讶和信任。数据集超过20万条。然而,每个epoch需要4个小时。我想知道是否有办法优化训练过程,使其更快进行,可能是我遗漏了某些可以改进模型的步骤。


训练数据

TRAIN_DATA = list(zip(train_texts, [{"cats": cats} for cats in final_train_cats]))[...  {'cats': {'anger': 1,    'anticipation': 0,    'disgust': 0,    'fear': 0,    'joy': 0,    'sadness': 0,    'surprise': 0,    'trust': 0}}), ('mausoleum',  {'cats': {'anger': 1,    'anticipation': 0,    'disgust': 0,    'fear': 0,    'joy': 0,    'sadness': 0,    'surprise': 0,    'trust': 0}}), ...]

训练过程

nlp = spacy.load("en_core_web_sm")category = nlp.create_pipe("textcat", config={"exclusive_classes": True})nlp.add_pipe(category)# add label to text classifiercategory.add_label("trust")category.add_label("fear")category.add_label("disgust")category.add_label("surprise")category.add_label("anticipation")category.add_label("anger")category.add_label("joy")optimizer = nlp.begin_training()losses = {}for i in range(100):    random.shuffle(TRAIN_DATA)    print('...')    for batch in minibatch(TRAIN_DATA, size=8):        texts = [nlp(text) for text, entities in batch]        annotations = [{"cats": entities} for text, entities in batch]        nlp.update(texts, annotations, sgd=optimizer, losses=losses)    print(i, losses)...0 {'parser': 0.0, 'tagger': 27.018985521040854, 'textcat': 0.0, 'ner': 0.0}...1 {'parser': 0.0, 'tagger': 27.01898552104131, 'textcat': 0.0, 'ner': 0.0}...

回答:

“20万条记录的数据集每个epoch需要4个小时”并不能告诉我们太多信息:

  1. 确保你没有耗尽内存(你有吗?)它占用了多少RAM?
  2. 你可能是在单线程下运行的,由于GIL的影响。例如,参考这篇文章,了解如何关闭GIL以多核运行训练。你有多少个核心?
  • texts = [nlp(text) ...]放在内循环for batch in minibatch(TRAIN_DATA, size=8):中看起来有问题,因为你的代码将始终持有GIL,尽管你只需要在处理输入文本时使用它,即parser阶段,而不是在训练时。
  • 重构你的代码,使你首先对所有输入运行nlp()管道,然后保存一些中间表示(数组或其他)。将这段代码与你的训练循环分开,以便训练可以多线程进行。
  1. 我无法评论你对minibatch()参数的选择,但8看起来非常小,这些参数似乎对性能有影响,所以尝试调整它们(/网格搜索几个值)。
  2. 最后,一旦你检查了以上所有内容,找到你能找到的最快的单核/多核设备,并确保有足够的RAM。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注