Supervised fine-tuning:修订间差异

来自WHY42
Riguz留言 | 贡献
Riguz留言 | 贡献
第35行: 第35行:
cd LLaMA-Factory
cd LLaMA-Factory
pip install -e ".[torch,metrics]"
pip install -e ".[torch,metrics]"
</syntaxhighlight>
== Web UI ==
<syntaxhighlight lang="bash">
GRADIO_SERVER_PORT=5050 llamafactory-cli webui
</syntaxhighlight>
</syntaxhighlight>



2024年7月24日 (三) 06:27的版本

有监督微调(SFT)是指采用预先训练好的神经网络模型,并针对你自己的专门任务在少量的监督数据上对其进行重新训练的技术。

SFT在大语言模型中的应用有以下重要原因:

  • 任务特定性能提升:预训练语言模型通过大规模的无监督训练学习了语言的统计模式和语义表示。然而它在特定任务下的效果可能并不令人满意。通过在任务特定的有标签数据上进行微调,模型可以进一步学习任务相关的特征和模式,从而提高性能。
  • 领域适应性:预训练语言模型可能在不同领域的数据上表现不一致。通过在特定领域的有标签数据上进行微调,可以使模型更好地适应该领域的特殊术语、结构和语义,提高在该领域任务上的效果。
  • 数据稀缺性:某些任务可能受制于数据的稀缺性,很难获得大规模的标签数据。监督微调可以通过使用有限的标签数据来训练模型,从而在数据有限的情况下取得较好的性能。
  • 防止过拟合:在监督微调过程中,通过使用有标签数据进行有监督训练,可以减少模型在特定任务上的过拟合风险。这是因为监督微调过程中的有标签数据可以提供更具体的任务信号,有助于约束模型的学习,避免过多地拟合预训练过程中的无监督信号。[1]

LLM大语言模型所需SFT数据

为每个示例准备文本输入和标签,以问答形式呈现,如下所示:

问题: 维珍澳大利亚何时开始运营?背景: 维珍澳大利亚,是维珍澳大利亚航空有限公司的交易名称,是一家总部设在澳大利亚的航空公司。它是使用维珍品牌的最大机队规模的航空公司。它于2000年8月31日作为维珍蓝航空公司开始运营,在一条航线上有两架飞机。在2001年9月安捷澳大利亚公司倒闭后,它突然发现自己成为澳大利亚国内市场的一家主要航空公司。此后,该航空公司发展到直接服务于澳大利亚的32个城市,从布里斯班到墨尔本和悉尼的枢纽。
回应: 维珍澳大利亚于2000年8月31日以维珍蓝的名义开始提供服务,在一条航线上使用两架飞机。

问答格式可以处理成多种文件格式, 例如JSONL, Excel File, CSV; 核心是要保持两个独立的字段, 即问题和答案。 可以从公开网络下载指令数据模板, 并尝试替换内容: https://huggingface.co/datasets/BAAI/COIG

选择SFT的超参数

EPOCH 影响比 LR 大,可以根据数据规模适当调整EPOCH大小,例如小数据量可以适当增大epoch,让模型充分收敛。

例如:EPOCH:100条数据时, Epoch为15,1000条数据时, Epoch为10,10000条数据时, Epoch为2。 过高的epoch可能会带来通用NLP能力的遗忘,这里需要您根据实际需求核定,若您只需要下游能力提升,则通用NLP能力的略微下降影响不大。若您非常在乎通用NLP能力,平台侧也提供过来种子数据来尽可能保证通用NLP能力不降低太多。 适当增加global batch_size :如增加accumulate step 32 64,当分布式节点增多时可以进一步增加batch_size,提高吞吐。 学习率(LR, learning Rate): 对于ptuing/lora等peft训练方式,同时可以适当增大LR。

Qwen2 Fine tuning

LLaMA-Factory

Install LLaMA-Factory

git clone --depth 1 https://github.com/hiyouga/LLaMA-Factory.git
cd LLaMA-Factory
pip install -e ".[torch,metrics]"

Web UI

GRADIO_SERVER_PORT=5050 llamafactory-cli webui

Train

Put https://github.com/FreedomIntelligence/GrammarGPT/blob/main/pseudo_data/instruction.json into /data/

Update data/dataset_info.json: add

{
   "grammar": { "file_name": "instruction.json" }
   ...
}

qwen-sft.yaml:

### model
model_name_or_path: Qwen/Qwen2-1.5B-Instruct

### method
stage: sft
do_train: true
finetuning_type: lora
lora_target: all

### dataset
dataset: grammar
template: llama3
cutoff_len: 1024
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16

### output
output_dir: saves/Qwen2-1.5B-Instruct/lora/sft
logging_steps: 10
save_steps: 500
plot_loss: true
overwrite_output_dir: true

### train
per_device_train_batch_size: 1
gradient_accumulation_steps: 8
learning_rate: 1.0e-4
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true
ddp_timeout: 180000000

### eval
val_size: 0.1
per_device_eval_batch_size: 1
eval_strategy: steps
eval_steps: 500

Train:

llamafactory-cli train examples/train_lora/qwen-sft.yaml:

[INFO|tokenization_utils_base.py:2583] 2024-07-23 19:06:55,938 >> Special tokens file saved in saves/Qwen2-1.5B-Instruct/lora/sft/special_tokens_map.json
***** train metrics *****
  epoch                    =     2.9867
  total_flos               =  1675066GF
  train_loss               =     0.5714
  train_runtime            = 0:23:01.86
  train_samples_per_second =      1.954
  train_steps_per_second   =      0.243
Figure saved at: saves/Qwen2-1.5B-Instruct/lora/sft/training_loss.png
07/23/2024 19:06:56 - WARNING - llamafactory.extras.ploting - No metric eval_loss to plot.
07/23/2024 19:06:56 - WARNING - llamafactory.extras.ploting - No metric eval_accuracy to plot.
[INFO|trainer.py:3788] 2024-07-23 19:06:56,956 >>
***** Running Evaluation *****
[INFO|trainer.py:3790] 2024-07-23 19:06:56,956 >>   Num examples = 100
[INFO|trainer.py:3793] 2024-07-23 19:06:56,956 >>   Batch size = 1
100%|████████████████████████████████████████████████████████████████████| 100/100 [00:15<00:00,  6.35it/s]
***** eval metrics *****
  epoch                   =     2.9867
  eval_loss               =     0.5796
  eval_runtime            = 0:00:15.98
  eval_samples_per_second =      6.255
  eval_steps_per_second   =      6.255
[INFO|modelcard.py:449] 2024-07-23 19:07:12,965 >> Dropping the following result as it does not have all the necessary fields:
{'task': {'name': 'Causal Language Modeling', 'type': 'text-generation'}}

Inference