在从文本中提取特征时,如何检查一个向量化器(例如TfIdfVectorizer或CountVectorizer)是否已经在训练数据上拟合?
特别是,我希望代码能自动判断一个向量化器是否已经拟合。
from sklearn.feature_extraction.text import TfidfVectorizervectorizer = TfidfVectorizer()def vectorize_data(texts): # 如果向量化器尚未拟合 vectorizer.fit_transform(texts) # 否则 vectorizer.transform(texts)
回答:
您可以使用check_is_fitted
,它正是为此目的设计的。
在TfidfVectorizer.transform()
的源码中,您可以查看其使用方式:
def transform(self, raw_documents, copy=True): # 这就是您需要的。 check_is_fitted(self, '_tfidf', 'The tfidf vector is not fitted') X = super(TfidfVectorizer, self).transform(raw_documents) return self._tfidf.transform(X, copy=False)
因此,在您的案例中,您可以这样做:
from sklearn.utils.validation import check_is_fitteddef vectorize_data(texts): try: check_is_fitted(vectorizer, '_tfidf', 'The tfidf vector is not fitted') except NotFittedError: vectorizer.fit(texts) # 在所有情况下向量化器都已拟合,因此只需调用transform() vectorizer.transform(texts)