微调模型时内存不足

我目前正在尝试微调来自https://huggingface.co/dima806/bird_sounds_classificationWav2Vec2模型。但我在Google Colab的免费版本上使用的内存超过了限制。

以下是我的代码:

from transformers import TrainingArguments, Trainer# Load model with ignore_mismatched_sizes=Truemodel = Wav2Vec2ForSequenceClassification.from_pretrained(    "dima806/bird_sounds_classification",    num_labels=len(label2id),    ignore_mismatched_sizes=True)# Set up training with gradient accumulationbatch_size = 1  # Reduce batch size to manage memoryaccumulation_steps = 4  # Accumulate gradients over 4 stepstraining_args = TrainingArguments(    output_dir="./results",    evaluation_strategy="epoch",    learning_rate=2e-5,    per_device_train_batch_size=batch_size,    per_device_eval_batch_size=batch_size,    gradient_accumulation_steps=accumulation_steps,  # Gradient accumulation    num_train_epochs=3,    weight_decay=0.01,    fp16=True,  # Enable mixed precision training)trainer = Trainer(    model=model,    args=training_args,    train_dataset=train_dataset,    eval_dataset=val_dataset,    tokenizer=feature_extractor,)# Train the modeltrainer.train()

为什么内存使用量会超过12.7GB?我的数据集只有20个项目。我该如何解决这个问题?


回答:

声音输入过长,重新采样音频成小段后,问题得到了解决。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注