目录

1. configs.py

功能概述

关键代码与细节

2. evaluate.py

功能概述

关键代码与细节

3. generate.py

功能概述

关键代码与细节

4. grpo.py

功能概述

关键代码与细节

5. rewards.py

功能概述

关键代码与细节

6. sft.py

功能概述

关键代码与细节

安装

训练模型

评估模型

复现DeepSeek的评估结果

MATH-500

GPQA Diamond

数据生成流程


技术实现与细节

以下是对提供的代码文件的详细剖析,结合代码内容和项目背景,分析其功能、实现细节和应用场景。

1. configs.py

功能概述

configs.py 文件定义了两种配置类:GRPOConfigSFTConfig,分别用于 GRPO(Group Relative Policy Optimization)训练和 SFT(Supervised Fine-Tuning)训练。这些配置类继承自 trl(Transformers Reinforcement Learning)库中的基础配置类,并添加了一些额外的参数。

关键代码与细节
  • GRPOConfig 和 SFTConfig

    @dataclass
    class GRPOConfig(trl.GRPOConfig):
        benchmarks: list[str] = field(
            default_factory=lambda: [], metadata={"help": "The benchmarks to run after training."}
        )
        callbacks: list[str] = field(
            default_factory=lambda: [], metadata={"help": "The callbacks to run during training."}
        )
        system_prompt: Optional[str] = field(
            default=None, metadata={"help": "The optional system prompt to use for benchmarking."}
        )
        hub_model_revision: Optional[str] = field(
            default="main", metadata={"help": "The Hub model branch to push the model to."}
        )
        overwrite_hub_revision: bool = field(default=False, metadata={"help": "Whether to overwrite the Hub revision."})
        push_to_hub_revision: bool = field(default=False, metadata={"help": "Whether to push to a Hub revision/branch."})
    • 继承关系GRPOConfigSFTConfig 继承自 trl.GRPOConfigtrl.SFTConfig,扩展了这些类的功能。

    • 新增参数

      • benchmarks:训练后运行的基准测试列表。

      • callbacks:训练过程中运行的回调函数列表。

      • system_prompt:用于基准测试的系统提示。

      • hub_model_revision:推送模型到 Hugging Face Hub 的分支。

      • overwrite_hub_revisionpush_to_hub_revision:控制是否覆盖或推送模型版本。

  • 应用场景

    • 这些配置类用于定义训练和评估的参数,支持用户自定义训练流程中的各种设置,如基准测试、回调函数和模型版本管理。

2. evaluate.py

功能概述

evaluate.py 文件定义了自定义的评估任务,用于在 LightEval 框架中评估模型的性能。这些任务包括数学推理、问答等。

关键代码与细节
  • 评估指标

    latex_gold_metric = multilingual_extractive_match_metric(
        language=Language.ENGLISH,
        fallback_mode="first_match",
        precision=5,
        gold_extraction_target=(LatexExtractionConfig(),),
        pred_extraction_target=(ExprExtractionConfig(), LatexExtractionConfig(boxed_match_priority=0)),
        aggregation_function=max,
    )
    • multilingual_extractive_match_metric:一个多语言的提取匹配指标,用于评估模型生成的内容是否与参考答案匹配。

    • gold_extraction_targetpred_extraction_target:定义了从参考答案和模型生成内容中提取信息的配置。

  • 提示函数

    def prompt_fn(line, task_name: str = None):
        return Doc(
            task_name=task_name,
            query=line["problem"],
            choices=[line["solution"]],
            gold_index=0,
        )
    • prompt_fn:生成评估任务的提示,用于数学推理任务。

    • aime_prompt_fngpqa_prompt_fn:分别为 AIME 和 GPQA 任务生成提示。

  • 任务定义

    aime24 = LightevalTaskConfig(
        name="aime24",
        suite=["custom"],
        prompt_function=aime_prompt_fn,
        hf_repo="HuggingFaceH4/aime_2024",
        hf_subset="default",
        hf_avail_splits=["train"],
        evaluation_splits=["train"],
        few_shots_split=None,
        few_shots_select=None,
        generation_size=32768,
        metric=[expr_gold_metric],
        version=1,
    )
    • LightevalTaskConfig:定义了一个评估任务的配置,包括任务名称、提示函数、数据集、评估指标等。

    • TASKS_TABLE:将所有定义的任务存储在一个列表中,便于管理和运行。

  • 应用场景

    • 该文件用于定义和运行模型的评估任务,支持多种数学推理和问答任务,帮助用户评估模型在不同领域的性能。

3. generate.py

功能概述

generate.py 文件定义了一个用于生成数据的管道,使用 distilabel 工具从模型中生成合成数据。

关键代码与细节
  • 构建管道

    def build_distilabel_pipeline(
        model: str,
        base_url: str = "http://localhost:8000/v1",
        prompt_column: Optional[str] = None,
        prompt_template: str = "{{ instruction }}",
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        max_new_tokens: int = 8192,
        num_generations: int = 1,
        input_batch_size: int = 64,
        client_replicas: int = 1,
        timeout: int = 900,
        retries: int = 0,
    ) -> Pipeline:
        ...
    • build_distilabel_pipeline:构建一个 distilabel 管道,用于生成数据。

    • 参数

      • model:用于生成数据的模型名称。

      • base_url:模型服务器的 URL。

      • prompt_columnprompt_template:定义提示的列和模板。

      • temperaturetop_p:生成的温度和核采样参数。

      • max_new_tokens:生成的最大新 token 数量。

      • num_generations:每个输入生成的样本数量。

      • input_batch_size:输入的批量大小。

      • client_replicas:客户端副本数量,用于并行处理。

      • timeoutretries:请求超时和重试次数。

  • 主函数

    if __name__ == "__main__":
        parser = argparse.ArgumentParser(description="Run distilabel pipeline for generating responses with DeepSeek R1")
        ...
        args = parser.parse_args()
        ...
        pipeline = build_distilabel_pipeline(
            model=args.model,
            base_url=args.vllm_server_url,
            prompt_template=args.prompt_template,
            prompt_column=args.prompt_column,
            temperature=args.temperature,
            top_p=args.top_p,
            max_new_tokens=args.max_new_tokens,
            num_generations=args.num_generations,
            input_batch_size=args.input_batch_size,
            client_replicas=args.client_replicas,
            timeout=args.timeout,
            retries=args.retries,
        )
        ...
        distiset = pipeline.run(
            dataset=dataset,
            dataset_batch_size=args.input_batch_size * 1000,
            use_cache=False,
        )
        ...
    • 命令行参数:通过 argparse 解析命令行参数,支持用户自定义生成数据的配置。

    • 数据加载:使用 datasets 加载数据集。

    • 管道运行:运行生成管道,生成合成数据并保存到 Hugging Face Hub。

  • 应用场景

    • 该文件用于生成合成数据,支持用户自定义生成配置,适用于模型训练和数据增强。

4. grpo.py

功能概述

grpo.py 文件实现了 GRPO(Group Relative Policy Optimization)训练流程,用于优化模型的策略。

关键代码与细节
  • GRPOScriptArguments

    @dataclass
    class GRPOScriptArguments(ScriptArguments):
        reward_funcs: list[str] = field(
            default_factory=lambda: ["accuracy", "format"],
            metadata={
                "help": "List of reward functions. Possible values: 'accuracy', 'format', 'reasoning_steps', 'cosine', 'repetition_penalty'"
            },
        )
        cosine_min_value_wrong: float = field(
            default=0.0,
            metadata={"help": "Minimum reward for wrong answers"},
        )
        ...
    • reward_funcs:定义奖励函数列表,支持多种奖励函数,如准确率、格式、推理步骤、余弦缩放和重复惩罚。

    • cosine_min_value_wrong 等参数:定义余弦缩放奖励的参数。

  • 主函数

    def main(script_args, training_args, model_args):
        ...
        dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)
        ...
        reward_funcs = [REWARD_FUNCS_REGISTRY[func] for func in script_args.reward_funcs]
        ...
        trainer = GRPOTrainer(
            model=model_args.model_name_or_path,
            reward_funcs=reward_funcs,
            args=training_args,
            train_dataset=dataset[script_args.dataset_train_split],
            eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
            peft_config=get_peft_config(model_args),
            callbacks=get_callbacks(training_args, model_args),
        )
        ...
    • 数据加载:使用 datasets 加载训练和评估数据集。

    • 奖励函数:根据用户指定的奖励函数,加载相应的函数。

    • GRPOTrainer:初始化 GRPO 训练器,设置模型、奖励函数、训练参数等。

    • 训练循环:运行训练循环,支持从断点恢复训练。

  • 应用场景

    • 该文件用于 GRPO 训练,支持多种奖励函数和训练配置,适用于优化模型的策略。

5. rewards.py

功能概述

rewards.py 文件定义了多种奖励函数,用于在 GRPO 训练中评估模型生成的内容。

关键代码与细节
  • 奖励函数

    def accuracy_reward(completions, solution, **kwargs):
        ...
        reward = float(verify(answer_parsed, gold_parsed))
        ...
    • accuracy_reward:检查模型生成的内容是否与参考答案一致,返回 1 或 0。

    • format_reward:检查生成内容是否符合特定格式。

    • reasoning_steps_reward:检查生成内容是否包含清晰的推理步骤。

    • cosine_scaled_reward:基于生成内容长度的余弦缩放奖励。

    • repetition_penalty_reward:基于重复 n-gram 的惩罚奖励。

  • 应用场景

    • 这些奖励函数用于 GRPO 训练,帮助模型生成更准确、更符合格式、更具推理性和更少重复的内容。

6. sft.py

功能概述

sft.py 文件实现了 SFT(Supervised Fine-Tuning)训练流程,用于对模型进行有监督微调。

关键代码与细节
  • 主函数

    def main(script_args, training_args, model_args):
        ...
        dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)
        ...
        trainer = SFTTrainer(
            model=model_args.model_name_or_path,
            args=training_args,
            train_dataset=dataset[script_args.dataset_train_split],
            eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
            processing_class=tokenizer,
            peft_config=get_peft_config(model_args),
            callbacks=get_callbacks(training_args, model_args),
        )
        ...
    • 数据加载:使用 datasets 加载训练和评估数据集。

    • SFTTrainer:初始化 SFT 训练器,设置模型、训练参数、分词器等。

    • 训练循环:运行训练循环,支持从断点恢复训练。

  • 应用场景

    • 该文件用于 SFT 训练,支持多种训练配置和回调函数,适用于对模型进行有监督微调。

功能模块

  • 模型训练

    • SFT(Supervised Fine-Tuning):对预训练模型进行微调,使其更好地适应特定任务。例如,在指令微调中,将小样本数据集用于微调,使模型生成更符合人类常识的对话内容。

    • GRPO(Group-Relative Policy Optimization):使用 GRPO 方法对模型进行 RL(强化学习)培训。该方法基于代理与环境之间的交互,通过最大化累积奖励信号来训练策略模型。

  • 模型评估

    • 使用 lighteval 对模型进行评估,lighteval 是一种轻量级的评估工具,支持多种评估任务。例如,在 AIME 2024、MATH-500 和 GPQA Diamond 等任务上对模型进行测试,得到准确率等评估指标,以评估模型的性能。

  • 数据生成

    • 从 smol 蒸馏 R1 模型生成数据:使用轻量级的蒸馏 R1 模型生成数据。该模块通过 Distilabel 来生成合成数据,为模型训练提供更多样化的数据。

    • 从 DeepSeek-R1 生成数据:使用更大的 DeepSeek-R1 模型生成数据。这需要更多的计算资源,但可以生成更高质量的合成数据,以支持更复杂的模型训练和测试。

安装

[!CAUTION]
相关库依赖于CUDA 12.4。如果您看到与段错误相关的错误,请使用nvcc --version仔细检查您的系统正在运行的CUDA版本。

要运行这个项目中的代码,首先,使用例如uv创建一个Python虚拟环境。
要安装uv,请参考UV安装指南

uv venv openr1 --python 3.11 && source openr1/bin/activate && uv pip install --upgrade pip --link-mode=copy

接下来,安装vLLM:

uv pip install vllm==0.7.1 --link-mode=copy

这也会安装PyTorch v2.5.1,使用这个版本非常重要,因为vLLM的二进制文件是针对该版本编译的。然后,您可以通过pip install -e .[LIST OF MODES]安装特定用例的其余依赖项。对于大多数贡献者,我们建议:

GIT_LFS_SKIP_SMUDGE=1 uv pip install -e ".[dev]" --link-mode=copy

接下来,按如下方式登录您的Hugging Face和Weights and Biases账户:

训练模型

我们支持使用数据并行分布式训练(DDP)或DeepSpeed(ZeRO-2和ZeRO-3)来训练模型。例如,要在从DeepSeek-R1提炼的带有推理痕迹的数据集(如Bespoke-Stratos-17k)上运行监督微调(SFT),请运行以下命令:

# 通过命令行进行训练
accelerate launch --config_file=recipes/accelerate_configs/zero3.yaml src/open_r1/sft.py \
    --model_name_or_path Qwen/Qwen2.5-1.5B-Instruct \
    --dataset_name HuggingFaceH4/Bespoke-Stratos-17k \
    --learning_rate 2.0e-5 \
    --num_train_epochs 1 \
    --packing \
    --max_seq_length 4096 \
    --per_device_train_batch_size 2 \
    --gradient_accumulation_steps 8 \
    --gradient_checkpointing \
    --bf16 \
    --output_dir data/Qwen2.5-1.5B-Open-R1-Distill

# 通过YAML配置文件进行训练
accelerate launch --config_file recipes/accelerate_configs/zero3.yaml src/open_r1/sft.py \
    --config recipes/Qwen2.5-1.5B-Instruct/sft/config_demo.yaml

目前,支持以下任务:

评估模型

make evaluate MODEL=deepseek-ai/DeepSeek-R1-Distill-Qwen-32B TASK=aime24 PARALLEL=data NUM_GPUS=8

要使用张量并行:

make evaluate MODEL=deepseek-ai/DeepSeek-R1-Distill-Qwen-32B TASK=aime24 PARALLEL=tensor NUM_GPUS=8

复现DeepSeek的评估结果

MATH-500

我们能够在约1 - 3个标准差范围内复现DeepSeek在MATH-500基准测试上报告的结果:

模型 MATH-500(🤗 LightEval) MATH-500(DeepSeek报告值)
DeepSeek-R1-Distill-Qwen-1.5B 81.2 83.9
DeepSeek-R1-Distill-Qwen-7B 91.8 92.8
DeepSeek-R1-Distill-Qwen-14B 94.2 93.9
DeepSeek-R1-Distill-Qwen-32B 95.0 94.3
DeepSeek-R1-Distill-Llama-8B 85.4 89.1
DeepSeek-R1-Distill-Llama-70B 93.4 94.5

要复现这些结果,请使用以下命令:

NUM_GPUS=1 # 对于32B和70B模型,设置为8
MODEL=deepseek-ai/{model_name}
MODEL_ARGS="pretrained=$MODEL,dtype=bfloat16,max_model_length=32768,gpu_memory_utilisation=0.8,tensor_parallel_size=$NUM_GPUS"
OUTPUT_DIR=data/evals/$MODEL

lighteval vllm $MODEL_ARGS "custom|math_500|0|0" \
    --custom-tasks src/open_r1/evaluate.py \
    --use-chat-template \
    --output-dir $OUTPUT_DIR

GPQA Diamond

lighteval vllm $MODEL_ARGS "custom|gpqa:diamond|0|0" \
    --custom-tasks src/open_r1/evaluate.py \
    --use-chat-template \
    --output-dir $OUTPUT_DIR
python scripts/run_benchmarks.py --model-id={model_id}  --benchmarks gpqa
数据生成流程
  • 小模型蒸馏数据生成

    • 使用轻量级蒸馏 R1 模型生成数据。通过 Distilabel 工具,从预定义的提示模板和数据集出发,生成合成数据。

    • 例如,使用 DeepSeek-R1 的蒸馏 Qwen-7B 模型生成数学推理数据,将数据保存到远程数据集中,并可通过华为 MindSpore 加载该数据集以用于训练。

  • 大模型数据生成

    • 使用更大的 DeepSeek-R1 模型生成数据,需要更多的计算资源。通过 Slurm 脚本(如 slurm/generate.slurm)在集群上运行生成任务,可以高效地生成大规模合成数据。

    • 生成过程中,可以通过设置温度(如 0.6)、提示列(如 “problem”)等参数来控制生成数据的质量和多样性。

更多可参照GitHub - huggingface/open-r1: Fully open reproduction of DeepSeek-R1

点击阅读全文
Logo

欢迎加入DeepSeek 技术社区。在这里,你可以找到志同道合的朋友,共同探索AI技术的奥秘。

更多推荐