
R1-Zero与R1的复现之路——从Open-Reasoner-Zero到Open R1:先后涉及规则奖励下的PPO迭代,及SFT+GRPO的复现
前言
根据R1的GitHub可知
类别 | 开源内容 | 未开源内容 |
---|---|---|
模型权重 | R1、R1-Zero 及蒸馏模型权重(MIT 协议) | 原始训练数据 未公开冷启动数据、RL 训练数据集或合成数据的具体内容,仅提供依赖的公开数据集名称(如 AI-MO、NuminaMath-TIR) |
技术文档 | GRPO 算法、奖励系统设计、冷启动流程等技术报告 | 训练代码,比如分布式训练代码细节 |
训练工具 | 合成数据生成脚本、评估基准代码 | 完整 RL 训练框架 |
推理支持 | API 接口、本地部署方案、框架适配指南 | 生产级优化内核 即动态显存管理、生产级批处理等企业级部署工具未开源 |
可以看到,R1并未开源关键的训练数据和训练代码,好在如此文《复现DeepSeek V3——在V3官方代码库对MoE、MLA的推理代码之外,补充我对多token预测MTP训练代码的实现》所说,有个Open R1的开源项目,后来又发现了Open-Reasoner-Zero项目
本文便基于这两个开源项目,从零复现R1-Zero以及正式版的R1,以下是本文的编写过程
- 2.13,完成本文的基本框架,包括Open R1项目的介绍(含OpenR1-Math-220k数据集),以及安装/训练/评估等
- 2.14,要想复现R1,算法层面得先有GRPO的实现,好在TRL库实现了,故在此文的第三部分 好好解读了下TRL包中实现的GRPO源码
- 2.22,在花了很大的心思、力气,把deepseek的GRPO、MLA算法的代码解析通透,比如GRPO与PPO的详细对比,再比如MLA中,图片 公式 代码的一一对应,详见此专栏《火爆全球的DeepSeek系列模型》
开始更新本文的:第一部分 对R1-Zero的复现:详解Open-Reasoner-Zero——纯RL训练Qwen2.5 - ..
第一部分 对R1-Zero的复现:详解Open-Reasoner-Zero——纯RL训练Qwen2.5
他们拿Qwen2.5-{7B, 32B} 作为基础模型[5] 进行实验,并直接开始大规模RL 训练,而未进行任何微调(例如蒸馏或SFT)[6, 7]
- 扩展了标准PPO 算法[4]
- 训练数据包括数万对精心挑选的问题和答案,涵盖STEM、数学和推理任务,专为增强模型在多样化和复杂问题解决场景中的能力而设计
- 受DeepSeek-R1 [2] 的启发,他们设计了类似此文《一文速览火爆全球的推理模型DeepSeek R1:如何通过纯RL训练以比肩甚至超越OpenAI o1(含Kimi K1.5的解读)》此节「1.2.3 训练模板:通过prompt让Zero启动深度思考的推理模式」的提示模板
以引导模型利用推理计算,逐步掌握复杂任务的推理能力,如表1 所示 - 此外,他们基于OpenRLHF [8] 开发了一个高效且易于使用的大规模RL训练框架,通过引入更灵活的训练器,实现GPU 协同生成,并支持卸载和回载训练
1.1 Reasoner-Zero的数据集、奖励函数、RL训练
1.1.1 数据集
他们
- 从各种来源收集公开数据,包括AIME(截至2023年)、MATH、Numina-Math集合[9]、Tulu3 MATH[10]以及其他开源数据集
- 根据来源和问题难度,提取AMC、AIME、Math、奥林匹克竞赛和AoPS论坛的组件作为我们的难度级别提示,以确保适当的难度级别
- 使用程序化的方法合成额外的推理任务以扩充数据集
- 排除了使用基于规则的奖励函数难以评估的问题,例如多项选择和以证明为导向的问题,以确保在训练期间奖励计算的准确性和一致性
- 基于启发式评估问题难度实现了一种基于模型的过滤策略。具体而言,我们使用LLM评估每个问题的通过率,移除通过率过高或为零的样本
- 应用基于N-gram和嵌入相似性的过滤方法来去重样本,并保持数据的多样性
最终整理的数据集包含约57k个样本,涵盖STEM、数学和推理领域
1.1.2 规则奖励:且仅包含准确性奖励,不包括格式奖励
与 DeepSeek-R1-Zero[2,详见此文《1.2 DeepSeek-R1-Zero:直接规则驱动的大规模RL训练,去掉SFT》节中的《1.2.2 规则奖励建模(准确性奖励 + 格式奖励):不用训练专门的偏好奖励模型》] 不同
- 他们采用了一个简单的基于规则的奖励函数,该函数仅检查答案的正确性,而不包含任何额外的格式奖励——个人认为他们这么做的原因是 由于训练时已经提供了带有明确格式的prompt模板,且对模型遵循模板的能力比较自信,故不怕模型回答时 不遵循prompt模板中的格式,马上你会看到最终效果 是否真如此
- 具体来说,该奖励函数设计为在训练期间提取‘‘和‘‘标签之间的内容,并将其与参考答案进行比较
为了在扩展强化学习中保持清晰和简洁,他们实现了一个二元奖励方案:对与参考答案完全匹配的情况给予奖励 1,其他所有情况为 0
且为了在评估中确保严格和一致的评估,他们采用了广泛使用的 Math-Verify1 库,其使用方式如图 3 所示
令人惊讶的是,他们发现,通过他们设计的提示,即使是未对齐的基础模型也能以很高的概率生成格式良好的响应。在早期的训练阶段,基础模型可以快速学习并强化正确的推理和回答格式
且如下图所示,即使是基础模型仅使用简单的基于规则的奖励函数,也能快速采用结构化推理模式
从而他们认为,他们的研究表明 对于训练Reasoner-Zero模型而言,复杂的奖励函数并不是必要的
1.1.3 RL训练:GAE下的PPO迭代
不同于DeepSeek-R1-Zero 中使用的GRPO,他们采用了PPO(Proximal Policy Optimization)算法[4] 作为扩展训练的RL算法
- 具体来说,对于每个问题q(即提示),模型生成一组response
,并根据基于规则的奖励函数接收相应的奖励
,其中n 表示采样轨迹的数量(即每个提示的展开大小)
- 对于时间步t(即token t)的每个response
,令
表示时间t 的状态,该状态由问题和所有先前生成的token组成,
表示该步生成的token
- 使用广义优势估计(GAE)[11] 为每个token计算优势估计
通常,GAE 通过结合多个n 步优势估计,并通过参数λ 控制的指数加权平均,在优势估计中提供了偏差和方差之间的权衡。优势计算为
其中是TD(时间差分)残差,γ 是折扣因子,决定了相对于即时奖励对未来奖励的重视程度
- PPO 算法通过优化以下目标函数更新策略模型参数θ 以最大化期望奖励
并更新价值模型参数ϕ 以最小化价值损失
其中ϵ 是截断参数,πθ 是当前策略,πθold 是更新前的旧策略,Vϕ 是价值函数,是折扣回报
且他们使用精心调整的超参数实例化了PPO 算法:GAE 参数λ = 1.0,折扣因子γ = 1.0,以及截断参数ϵ = 0.2
如果对以上GAE和PPO的任何一切细节、公式、符号不太理解,可以参看此文《ChatGPT技术原理解析:从RL之PPO算法、RLHF到GPT4、instructGPT》
1.1.4 对Reasoner-Zero小结及与ChatGPT-PPO、R1-GRPO的对比
他们认为
- RL 算法关键实现:他们的实证研究表明,原始PPO 在不同模型规模和训练时长下提供了一个非常稳定且稳健的训练过程,而无需额外的修改
且通过广泛的实验,他们发现GAE 参数在PPO 的推理任务中起到了关键作用
具体来说,设置λ = 1.0 和γ = 1.0,虽然在传统RL 场景中通常被认为是次优的,但在扩展RL 训练中达到了理想的平衡 - 最小化奖励函数设计:他们认为他们表明,一个简单的基于规则的奖励函数不仅是足够的,而且是最优的,因为最小化设计不会留下潜在奖励漏洞的空间
且值得注意的是,即使是未对齐的基础模型也能快速适应所需格式,这表明这是一个无需复杂奖励工程的简单任务 - 损失函数:他们在不依赖任何基于KL 的正则化技术(例如,KL shaped rewards and loss)的情况下实现了稳定的训练,这与事实上的RLHF 社区[12] 和Reasoner 模型[13, 2] 不同
这也为进一步的大规模RL 提供了有希望的潜力
这里其实对比的是ChatGPT的RLHF训练方式「更多详看此文《ChatGPT技术原理解析:从RL之PPO算法、RLHF到GPT4、instructGPT》」,即其有对RM函数加KL正则化约束
具体而言,的作用是通过KL散度对比RL在最大化RM的目标下学到的策略
和基线策略
的差距,一开始时,
的初始化值就是
,最终希望
最终迭代结束后,它俩之间的差距不至于太大
- 扩大训练数据:他们发现,扩大数据的数量和多样性对于Reasoner-Zero 的训练至关重要。在有限的学术数据集(如MATH)上训练会导致性能快速达到瓶颈,而他们精心设计的大规模多样化数据集使得在训练集和测试集上都能持续扩展且没有饱和迹象
下面,我再把Reasoner-Zero的RL训练模式,与ChatGPT-PPO、R1-GRPO做下对比,如下表格所示
通用模型 ChatGPT-PPO、微软deepspeed chat |
推理模型 R1-Zero-GRPO |
推理模型 Reasoner-Zero-PPO |
|
奖励函数 可以通过r(x,y)真实获取对应的奖励值 |
排序数据训练偏好奖励模型 |
规则奖励建模(准确性奖励 + 格式奖励):不用训练专门的偏好奖励模型 对于数学问题:“ 7 + 3*7 = ?”,系统通过包含这个问题在内的一系列数学问题-答案集,知道该问题的答案为28,那么只需要检查模型的输出<answer>是否为28 就行了 |
规则奖励:且仅包含准确性奖励,不包括格式奖励 二元奖励方案:对与参考答案完全匹配的情况给予奖励 1,其他所有情况为 0 |
优势函数 | ![]() |
比如 标准差(近似计算) = 0.5 则有 |
|
回报 是需要计算的 |
![]() |
无 |
相当于 |
价值估计 | ![]() ![]() |
无 | |
目标函数 |
|
|
第二部分 Open R1:以Qwen2.5-1.5B为基础,封装各种开源框架
2.1 Open R1分别对外开源的内容
2.1.1 GRPO的实现、数据生成器
Open R1复现了R1正式版完整训练流程的前两个阶段「以Qwen2.5-1.5B为基础,以deepseek-R1的训练过程打造」,并把代码开源了,其GitHub仓库主要包括GRPO的实现、训练与评估代码、用于合成数据的生成器
具体而言,涉及如下
- src/open_r1:自身实现的4个独立脚本,用于训练和评估模型以及生成合成数据的脚本,这4个独立脚本具体如下所示
1) grpo.py:在给定的数据集上使用 GRPO 训练模型
2) sft.py:在数据集上执行模型的简单 SFT
3) evaluate.py:在 R1 基准上评估模型
4) generate.py:使用Distilabel从模型生成合成数据 - 封装了Transformer框架,和RL框架TRL
TRL这个框架我曾在我这篇文章里介绍过,其支持SFT、PPO、GRPO等训练方法
换言之,Open R1并没有再去实现一遍GRPO——也没必要,而是直接用的TRL框架中对GRPO的实现 - 封装了计算图distilabel框架,和MegFlow类似,内部用 networkx 实现 DAG
比如open-r1 用 distilabel 加载目标 LLM、造 QA 数据。 例如用 qwen-7B时:python3 src/open_r1/generate.py --hf-dataset /data/share/NuminaMath-TIR --model Qwen2.5-7B-Instruct --prompt-column problem
- 封装了评测方法 lighteval
evaluation 框架基本模式,都是加载 dataset、运行模型、打满吞吐、打印精度表。推理期间包装不同的 inference repo - 封装了底层推理框架vLLM
2.1.2 对R1训练流程前两个阶段的复现(SFT和GRPO训练)
如下图所示,Open R1分别实现了
- 从 DeepSeek-R1 中提取高质量语料库来复现 R1-Distill 模型
这里有个很重要的问题是,到底如何从R1中提取高质量语料库
其实如Open R1的GitHub所说,从 DeepSeek-R1 提炼出的具有推理轨迹的数据集(例如Bespoke-Stratos-17k)上运行 SFT - 基于DeepSeek V3 创建 R1-Zero 的纯 RL 管道
- 复现R1正式版完整训练流程的前两个阶段(SFT + 规则奖励下的RL)——毕竟完整的R1正式版训练流程有4个阶段呢
阶段一 冷启动SFT 阶段二 规则奖励下的RL R1-Zero模型生成的冷启动数据:微调V3 面向推理的RL:结合三个规则奖励——准确率奖励、格式奖励、语言一致性奖励 阶段三 增强SFT 阶段四 规则+偏好奖励下的RL 来自阶段二模型的60w推理数据
和V3模型的20w非推理数据:微调V3
全场景RL
规则奖励、偏好奖励
我司也会在这个课程《DeepSeek原理与项目实战营》里讲一下这个Open R1的复现思路,及深入解读其源码,以帮助更多人可以更好的用好该Open R1
2.2 Open R1对外开源的OpenR1-Math-220k数据集
2.2.1 Math-220k与现存推理数据集的比较
如此文《一文速览火爆全球的推理模型DeepSeek R1:如何通过纯RL训练以比肩甚至超越OpenAI o1(含Kimi K1.5的解读)》所说
作者还实验了蒸馏——赋予小模型推理能力
- 方法是直接使用 DeepSeek-R1 阶段三中精心挑选的 80 万个样本对开源模型如 Qwen(Qwen, 2024b)和 Llama(AI@Meta,2024)进行了微调
- 这80万样本中包含来自R1 4阶段训练中阶段二模型的60w推理数据
最终,通过这「60w推理数据+20w非推理数据」对小模型做微调,也能让小模型即便不经过专门的RL训练也能获得不俗的推理能力,比如DeepSeek-R1-32B和DeepSeek-R1-70B在大多数基准上明显优于o1-mini
遗憾的是,R1并未开源这60万条推理数据
OpenR1-Math-220k 数据集就是来补上这块空缺的。具体而言,Open R1 团队使用 DeepSeek R1 生成了 80 万条推理轨迹,经过筛选和验证后得到了 22 万条高质量数据
虽然在此之前,目前市面上开源的推理数据集包括:OpenThoughts-114k、Bespoke-Stratos-17k、Dolphin-R1 和 LIMO 等多个推理数据集
那与现有数据集相比,OpenR1-Math-220k数据集有什么新的特点呢
- 80 万条 R1 推理轨迹
使用 DeepSeek R1 为 40 万个问题各生成了两个答案,最终经过筛选后保留了 22 万个具有正确推理轨迹的问题 - 本地运行 512 个 H100
没有依赖 API,而是在计算集群上利用 vLLM 和 SGLang 本地运行生成任务,每天可以生成 18 万条推理轨迹 - 基于 NuminaMath 1.5
专注于数学推理公式,为 NuminaMath 1.5(NuminaMath-CoT 数据集的改进版本)中的问题生成答案 - 自动过滤
Open R1 团队通过数学验证,只保留至少有一个正确答案的问题,还让 Llama3.3-70B-Instruct 作为「判官」,以筛选出更多正确的样本,特别是那些因格式错误而无法通过基于规则的解析器验证的答案 - 性能追平
在 OpenR1-Math-220k 训练出来的 Qwen-7B-Math-Instruct,达到了与 DeepSeek-Distill-Qwen-7B 相当的性能
总之,OpenR1-Math-220k数据集分为如下两个部分
- default(94k 问题):这部分数据在经过监督微调(SFT)后表现最佳
- extended(131k 问题):这部分数据包含额外的 NuminaMath 1.5 数据源,例如 cn_k12,提供了更多的推理公式
研究发现这个子集在经过监督微调后的性能低于默认数据集,作者认为可能是因为 cn_k12 中的问题相对简单
2.2.2 OpenR1-Math-220k数据集的创造过程:生成、过滤、评估
首先,对于数据生成
他们为了构建数据集,OpenR1 团队让 DeepSeek R1 为来自 NuminaMath 1.5 的 40 万个问题生成答案。他们遵循了 DeepSeek 技术报告中推荐的参数设置,并在提示词前添加了以下指令:
Please reason step by step, and put your final answer within \boxed{}.
且为了确保生成过程的高效性,团队将每次生成的 tokens 限制设置为 16k。经过分析发现,只有 75% 的问题能够在 8k tokens 内解决,而大多数剩余问题需要完整的 16k tokens
- 最初,他们使用 vLLM 进行推理,每个 H100 节点每秒可以生成 15 个答案(相关生成脚本已分享在 OpenR1 仓库中)
- 最近,他们又开始尝试使用 SGLang,每个 H100 节点每秒可以生成 25 个答案(速度提升了近两倍),这使得 512 个 H100 节点上每天能生成 30 万个问题的答案
- 为了在后续的过滤和优化过程中提供更大的灵活性,团队为每个问题生成了2个答案 —— 有时甚至生成4个
这样一来,不仅复刻出了类似于 DeepSeek R1 允许进行拒绝采样的方法,还能使数据集能够适用于如 DPO 等偏好优化方法
对应的数据生成脚本在此:huggingface/open-r1/tree/main/slurm
其次,对于数据过滤
即为了确保数据集中只包含高质量且正确的推理结果,Open R1 团队设计了一套数学验证系统,用于自动比对 LLM 生成的复杂数学表达式答案与数据集中的标准答案
- 在这个过程中,OpenR1 团队发现大约 55% 的问题至少有一个正确答案。然而,NuminaMath 1.5 数据集中有很多答案是空的,或者格式不符合验证标准,这都给自动验证带来了困难
- 为了解决这些问题,Open R1 团队先是对 Math-Verify 工具进行了改进,使其能够处理更多不常见的答案格式,再使用 Llama-3.3-70B-Instruct 模型进行二次评估
具体来说,对于那些被 Math-Verify 判定为错误的答案,使用 Llama-3.3-70B-Instruct 模型重新评估,识别实际上正确但因格式问题被错判的答案
最终,他们找回了 2.5 万条被「误判」的数据 - 优化 Math-Verify 工具:对 Math-Verify 工具进行了改进,使其能够处理更多不常见的答案格式
让 Llama-3.3-70B-Instruct 「作判官」的提示词如下:You are a mathematical answer validator. You will be provided with a mathematical problem and you need to compare the answer in the reference solution, and the final answer in a model's solution to determine if they are equivalent, even if formatted differently. PROBLEM: {problem} REFERENCE SOLUTION: {answer} MODEL'S SOLUTION: {generation} Focus ONLY on comparing the final mathematical answer provided by the model while ignoring differences in: - Formatting (e.g., \\boxed{{}} vs plain text) - Multiple choice formatting (e.g., "A" vs full solution) - Order of coordinate pairs or solutions - Equivalent mathematical expressions or notation variations - If the model's answer is nonsense, return "Verdict: AMBIGUOUS" Start with a brief explanation of your comparison (2-3 sentences). Then output your final answer in one of the following formats: - "Verdict: EQUIVALENT" - "Verdict: DIFFERENT" - "Verdict: AMBIGUOUS"
- 对于那些包含多个正确答案的数据行,团队尝试使用奖励模型(RM)作为最终筛选器来选择最佳答案。具体操作如下:
首先,从每个包含多个正确答案的数据行中,去掉(<think>…</think>),提取最终答案;
第二,将问题和提取的答案输入到配置了 vLLM 的 Qwen/Qwen2.5-Math-RM-72B 模型中,获取每个答案的评分
接着,根据模型评分,对每个包含多个正确答案的数据行排名,选择排名最高的答案纳入训练数据集
遗憾的是,消融实验表明,这种方法并没有比随机选择一个正确答案带来更好的模型性能。Open R1 团队的判断是,可能在使用奖励模型评分时,不仅要考虑最终答案,还要包括推理过程
最后,对于效果评估上
Open R1 在 OpenR1-Math-220k 的基础上,对 Qwen2.5-Math-Instruct 进行了 3 轮微调,学习率为 5e-5
- 为了将上下文长度从 4k 扩展到 32k,他们将 RoPE 频率提高到 300k。训练遵循线性学习率调度,其中包含 10% 的预热阶段
- 下表展示了在 lighteval 上 OpenR1-Qwen-7B、DeepSeek-Distill-Qwen-7B 和 OpenThinker-7B 的性能对比,可以看出在数学成绩上,OpenR1-Qwen-7B 和 DeepSeek-Distill-Qwen-7B 差距不是非常明显
第三部分 Open R1的安装、训练、评估
3.1 环境搭建与依赖安装
Open-R1 项目依赖 CUDA 12.1
- 创建虚拟环境并升级 pip
可使用uv
工具创建 Python 虚拟环境,安装uv
后执行如下命令(uv安装文档 https://docs.astral.sh/uv/getting-started/installation/)uv venv openr1 --python 3.11 && source openr1/bin/activate && uv pip install --upgrade pip
- 安装 vLLM:CUDA 12.1 环境下,运行下面的命令
uv pip install vllm==0.6.6.post1`` ``# For CUDA 12.1``pip install vllm==0.6.6.post1 --extra-index-url https://download.pytorch.org/whl/cu121``export LD_LIBRARY_PATH=$(python -c "import site; print(site.getsitepackages()[0] + '/nvidia/nvjitlink/lib')"):$LD_LIBRARY_PATH
- 安装其余依赖:根据具体使用场景,执行
pip install -e.[LIST OF MODES]
对于多数开发者,建议使用
pip install -e ".[dev]" - 登录账号与检查 Git LFS:登录 Hugging Face 和 Weights and Biases 账号,分别执行
检查系统是否安装 Git LFS,若未安装,使用sudo apt-get install git-lfs进行安装huggingface-cli login``wandb login
3.2 模型训练
3.2.1 SFT阶段:可通过Bespoke-Stratos-17k微调Qwen2.5-Math-1.5B-Instruct
Open-R1 支持 DDP 和 DeepSpeed(ZeRO-2、ZeRO-3)两种训练方式,切换时只需调整configs文件夹中加速器 YAML 配置文件路径。以配备 8 块 H100(80GB)显卡的节点为例,训练命令如下:
SFT 的代码在下述代码文件里
accelerate launch --config_file=configs/zero3.yaml src/open_r1/sft.py
简言之
- 可以先下载HuggingFaceH4/Bespoke-Stratos-17k数据集
「数据集地址:https://huggingface.co/datasets/bespokelabs/Bespoke-Stratos-17k」 - 然后再下载Qwen/Qwen2.5-Math-1.5B-Instruct模型
- 之后,基于上面的数据集做对该模型做SFT
具体SFT时,一般涉及以下步骤
- 加载数据集和tokenizer
# Load datasets dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config) # Load tokenizer tokenizer = AutoTokenizer.from_pretrained( model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, use_fast=True ) tokenizer.pad_token = tokenizer.eos_token
- 配置模型参数
- 设置训练器
# Initialize the SFT Trainer trainer = SFTTrainer( model=model_args.model_name_or_path, # 指定模型路径 args=training_args, # 指定训练参数 train_dataset=dataset[script_args.dataset_train_split], # 指定训练数据集 eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None, # 指定测试数据集 processing_class=tokenizer, # 指定tokenizer peft_config=get_peft_config(model_args), callbacks=get_callbacks(training_args, model_args), )
- 训练与保存
过程中,若使用 Slurm 调度系统,可运行
sbatch --output=/path/to/logs/%x-%j.out --err=/path/to/logs/%x-%j.err slurm/sft.slurm {model} {dataset} {accelerator}
3.2.2 GRPO阶段:可通过NuminaMath-TIR数据集对R1-Distill-Qwen-7B做RL训练
首先是数据集、模型权重的下载
- 可以先下载AI-MO/NuminaMath-TIR数据集
数据集地址:https://huggingface.co/datasets/AI-MO/NuminaMath-TIR - 然后再下载deepseek-ai/DeepSeek-R1-Distill-Qwen-7B
初始模型权重地址:https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B/tree/main
下载完成后注意修改recipes/qwen/Qwen2.5-1.5B-Instruct/grpo/confg_full.yaml中的model_name_or_path和dataset_name,以匹配模型和数据集的位置
然后修改num_processes,如果你有8块GPU,需设置为7(因为vllm需要占用一块GPU),以此类推。(如果只有一块GPU,建议直接将use_vllm改为false,然后将num_processes改为1) - 之后使用上面的数据集对该模型做RL训练——RL算法用GRPO
在Open R1中,整个GRPO的流程被封装在TRL库中,用户可以直接调用
accelerate launch --config_file configs/zero3.yaml src/open_r1/grpo.py
接下来,咱们便来具体看下GRPO的训练过程
- 首先是,奖励函数,包括两个:准确率奖励(accuracy_reward)和格式奖励(format_reward)。准确率奖励意味着解题越准确,分数越高;格式奖励意味着输出的格式越标准,分数越高
// 待更 - 加载数据集、设置奖励函数、设置模型
# Load the dataset dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config) # Get reward functions reward_funcs = [reward_funcs_registry[func] for func in script_args.reward_funcs] torch_dtype = ( model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype) ) model_kwargs = dict( revision=model_args.model_revision, trust_remote_code=model_args.trust_remote_code, attn_implementation=model_args.attn_implementation, torch_dtype=torch_dtype, use_cache=False if training_args.gradient_checkpointing else True, ) training_args.model_init_kwargs = model_kwargs
- 设置trainer
# Initialize the GRPO trainer trainer = GRPOTrainer( model=model_args.model_name_or_path, # 指定模型路径 reward_funcs=reward_funcs, # 指定奖励函数 args=training_args, train_dataset=dataset[script_args.dataset_train_split], # 指定训练集 eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None, peft_config=get_peft_config(model_args), # 指定PEFT配置 callbacks=get_callbacks(training_args, model_args), # 指定回调函数 )
- 训练和保存
// 待更
3.3 模型评估
// 待更
第四部分 Open R1的进一步更新
详见此文《Open R1: Update #3》
// 待更
更多推荐
所有评论(0)