Memory error when visualizing BERT transformer results with LIME

Situation: I am using the LIME package, following this tutorial, to visualize the results of a Hugging Face transformer model I built.

Problem: My code is set up and runs fine until I create LIME's explainer object. At that point I get a memory error.

Question: What am I doing wrong? Why am I getting a memory error?

Code: Here is my code (you should be able to copy and paste it into Google Colab and run it end to end):

```python
########################## LOAD PACKAGES ##########################
# Install new packages in our environment
!pip install lime
!pip install wget
!pip install transformers

# Import general libraries
import sklearn
import sklearn.ensemble
import sklearn.metrics
import sklearn.model_selection  # needed for train_test_split below
import numpy as np
import pandas as pd

# Import libraries specific to this notebook
import lime
import wget
import os
from transformers import FeatureExtractionPipeline, BertModel, BertTokenizer, BertConfig
from lime.lime_text import LimeTextExplainer

# Let the notebook know to plot inline
%matplotlib inline

########################## LOAD DATA ##########################
# Get URL
url = 'https://nyu-mll.github.io/CoLA/cola_public_1.1.zip'

# Download the file (if we haven't already)
if not os.path.exists('./cola_public_1.1.zip'):
    wget.download(url, './cola_public_1.1.zip')

# Unzip the dataset (if we haven't already)
if not os.path.exists('./cola_public/'):
    !unzip cola_public_1.1.zip

# Load the dataset into a pandas dataframe
df_cola = pd.read_csv("./cola_public/raw/in_domain_train.tsv", delimiter='\t',
                      header=None, names=['sentence_source', 'label',
                                          'label_notes', 'sentence'])

# Only look at the first 50 observations for debugging
df_cola = df_cola.head(50)

###################### TRAIN TEST SPLIT ######################
# Apply the train test split
x_train, x_test, y_train, y_test = sklearn.model_selection.train_test_split(
    df_cola.sentence, df_cola.label, test_size=0.2, random_state=42)

###################### CREATE LIME CLASSIFIER ######################
# Create a function to extract vectors from a single sentence
def vector_extractor(sentence):
    # Create a basic BERT model, config and tokenizer for the pipeline
    # (note: all of these are rebuilt on every call of this function)
    configuration = BertConfig()
    configuration.max_len = 64
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased',
                                              do_lower_case=True,
                                              max_length=64,
                                              pad_to_max_length=True)
    model = BertModel.from_pretrained('bert-base-uncased', config=configuration)

    # Create the pipeline (device=0 means the first GPU)
    vector_extractor = FeatureExtractionPipeline(model=model,
                                                 tokenizer=tokenizer,
                                                 device=0)

    # The pipeline gives us all tokens in the final layer - we want the CLS token
    vector = vector_extractor(sentence)
    vector = vector[0][0]

    # Return the vector
    return vector

# Adjust the format of our sentences (from pandas series to python list)
x_train = x_train.values.tolist()
x_test = x_test.values.tolist()

# First we vectorize our train features for the classifier
x_train_vectorized = [vector_extractor(x) for x in x_train]

# Create and fit the random forest classifier
rf = sklearn.ensemble.RandomForestClassifier(n_estimators=100)
rf.fit(x_train_vectorized, y_train)

# Define the lime_classifier function
def lime_classifier(sentences):
    # Turn all the sentences into vectors
    vectors = [vector_extractor(x) for x in sentences]

    # Get predictions for all sentences
    predictions = rf.predict_proba(vectors)

    # Return the probabilities as a 2D array
    return predictions

########################### APPLY LIME ###########################
# Create the general explainer object
explainer = LimeTextExplainer()

# "Fit" the explainer object to a specific observation
exp = explainer.explain_instance(x_test[1],
                                 lime_classifier,
                                 num_features=6)
```
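One likely contributor worth checking first: `vector_extractor` rebuilds the tokenizer, the BERT model and the pipeline on every call, while `LimeTextExplainer.explain_instance` passes `classifier_fn` a single list of `num_samples` perturbed sentences (5000 by default). `lime_classifier` therefore loads BERT onto the GPU thousands of times. Below is a minimal sketch that builds the pipeline once and caches it with `functools.lru_cache`, assuming the same `bert-base-uncased` model and GPU `device=0` as the question; `get_feature_pipeline` is a hypothetical helper name, not part of transformers or LIME.

```python
from functools import lru_cache


@lru_cache(maxsize=None)
def get_feature_pipeline(model_name="bert-base-uncased", device=0):
    # Hypothetical helper: builds tokenizer, model and pipeline only once;
    # lru_cache returns the same pipeline object on later calls with these arguments.
    # transformers is imported lazily, so merely defining this function loads nothing.
    from transformers import BertModel, BertTokenizer, FeatureExtractionPipeline

    tokenizer = BertTokenizer.from_pretrained(model_name, do_lower_case=True)
    model = BertModel.from_pretrained(model_name)
    return FeatureExtractionPipeline(model=model, tokenizer=tokenizer, device=device)


def vector_extractor(sentence):
    # Same output as the question's version: the last-layer CLS-token vector,
    # but the cached pipeline is reused instead of being rebuilt per sentence
    extractor = get_feature_pipeline()
    return extractor(sentence)[0][0]
```

With this version, `lime_classifier` can stay unchanged; each of the perturbed sentences still goes through BERT one at a time, but the model itself is loaded only once per session.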

Answer:

I eventually solved the problem by re-implementing an approach similar to the one in this GitHub issue: https://github.com/marcotcr/lime/issues/409

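My code is now quite different from the code above, so if you run into a similar problem, the GitHub issue is probably the best place to look for guidance.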
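The answer does not spell out the re-implementation, so the following is not the code from that issue, only a related idea sketched under stated assumptions. LIME hands `classifier_fn` all perturbed sentences in one list and expects back an array of shape `(n_samples, n_classes)`, so the sentences can be embedded in fixed-size batches by a model that is loaded once. `make_lime_classifier` and `embed_batch` are hypothetical names: `embed_batch` stands for any function mapping a list of strings to one embedding row per string, and `clf` for a fitted estimator with `predict_proba`, such as the random forest in the question.

```python
import numpy as np


def make_lime_classifier(embed_batch, clf, batch_size=32):
    """Build a classifier_fn suitable for LimeTextExplainer.explain_instance.

    embed_batch: hypothetical callable, list[str] -> 2D array (one row per string),
                 ideally backed by a model that was loaded a single time.
    clf:         fitted estimator exposing predict_proba.
    """
    def classifier_fn(texts):
        texts = list(texts)
        chunks = []
        # Embed the perturbed texts batch by batch to bound peak memory use
        for start in range(0, len(texts), batch_size):
            chunks.append(np.asarray(embed_batch(texts[start:start + batch_size])))
        features = np.vstack(chunks)
        # LIME expects probabilities of shape (n_samples, n_classes)
        return clf.predict_proba(features)

    return classifier_fn
```

It could then be used as `explainer.explain_instance(x_test[1], make_lime_classifier(embed_batch, rf), num_features=6, num_samples=1000)`. Lowering `num_samples` from its default of 5000 also cuts the number of sentences that must be embedded, at the cost of a noisier explanation. The classifier has to be trained on embeddings produced by the same `embed_batch`, otherwise its predictions are meaningless.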
