我在使用sentence_transformers库编码大量文档(超过一百万)时遇到了问题。
给定一个非常相似的语料库字符串列表。当我执行以下操作时:
from sentence_transformers import SentenceTransformerembedder = SentenceTransformer('msmarco-distilbert-base-v2')corpus_embeddings = embedder.encode(corpus, convert_to_tensor=False)
几个小时后,进程似乎卡住了,因为它永远不会完成,并且在检查进程查看器时没有任何运行的进程。
由于我怀疑这是内存问题(GPU板没有足够的内存在一单步中处理所有数据),我尝试将语料库分成批次,将它们转换为NumPy数组,然后将它们连接成一个单一的矩阵,如下所示:
from itertools import zip_longestfrom sentence_transformers import SentenceTransformer, utilimport torchfrom loguru import loggerimport globfrom natsort import natsorteddef grouper(iterable, n, fillvalue=np.nan): args = [iter(iterable)] * n return zip_longest(*args, fillvalue=fillvalue)embedder = SentenceTransformer('msmarco-distilbert-base-v2')for j, e in enumerate(list(grouper(corpus, 3))): try:# print('------------------') for i in filter(lambda v: v==v, e): corpus_embeddings=embedder.encode(i, convert_to_tensor=False) torch.save(corpus_embeddings, f'/Users/user/Downloads/embeddings_part_{j}.npy') except TypeError: print(j, e) logger.debug("TypeError in batch {batch_num}", batch_num=j)l = []for e in natsorted(glob.glob("/Users/user/Downloads/*.npy")): l.append(torch.load(e)) corpus_embeddings = np.vstack(l)corpus_embeddings
然而,上述程序似乎不起作用。原因是当我尝试使用语料库的一个小样本时,无论是否使用批处理方法,所得到的矩阵都是不同的,例如:
不使用批处理方法:
array([[-0.6828216 , -0.26541945, 0.31026787, ..., 0.19941986, 0.02366139, 0.4489861 ], [-0.45781 , -0.02955275, 1.0897563 , ..., -0.20077021, -0.37821707, 0.2248317 ], [ 0.8532193 , -0.13642257, -0.8872398 , ..., -0.57482916, 0.12760726, -0.66986346], ..., [-0.04036704, 0.06745373, -0.6010259 , ..., -0.08174597, -0.18513843, -0.64744204], [-0.30782765, -0.04935509, -0.11624689, ..., 0.10423593, -0.14073376, -0.09206307], [-0.77139395, -0.08119706, 0.43753916, ..., 0.1653319 , 0.06861683, -0.16276269]], dtype=float32)
使用批处理方法:
array([[ 0.8532191 , -0.13642241, -0.8872397 , ..., -0.5748289 , 0.12760736, -0.6698637 ], [ 0.3679317 , -0.21968201, 0.9932826 , ..., -0.86282325, -0.04683857, 0.18995859], [ 0.23026675, 0.69587034, -0.8116473 , ..., 0.23903558, 0.413471 , -0.23438476], ..., [ 0.923319 , 0.4152724 , -0.3153545 , ..., -0.6863369 , 0.01149149, -0.51300013], [-0.30782777, -0.04935484, -0.11624689, ..., 0.10423636, -0.1407339 , -0.09206269], [-0.77139413, -0.08119693, 0.43753892, ..., 0.16533189, 0.06861652, -0.16276267]], dtype=float32)
上述批处理程序的正确方法是什么?
更新
在检查上述批处理程序后,我发现当我将上述代码的批处理大小设置为1
时((enumerate(list(grouper(corpus, 1))))
),我能够得到相同的矩阵输出,无论是否使用批处理。因此,我的疑问是,如何正确地将编码器应用于大量文档?
回答:
此行的代码在进行编码之前按文本长度对输入进行排序。我不知道为什么这样做。
所以,要么注释掉这些行,要么像这样将它们复制到你的代码中:
length_sorted_idx = np.argsort([-embedder._text_length(sen) for sen in corpus])corpus_sorted = [corpus[idx] for idx in length_sorted_idx]
然后使用corpus_sorted
进行编码,并使用length_sorted_idx
将输出映射回原位置。
或者,你可以一次一个地进行编码,这样就不需要关心哪个输出来自哪个文本了。