手写最基础的大模型推理
·
代码样例
手写一个 最基础的大模型推理示例,但是为了方便测试和节省显存,我们先用一个 小模型,比如 HuggingFace 上的 distilgpt2(比 GPT-2 小很多,适合 CPU/GPU 本地跑)。
下面我给你一个 完整 Python 示例,不用任何复杂框架,也能做简单的文本生成。
1️⃣ 安装依赖
pip install torch transformers
2️⃣ 基础推理代码
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
# 选择小模型
model_name = "distilgpt2"
# 加载 tokenizer 和模型
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
# 如果有 GPU,可切换到 CUDA
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
# 输入 prompt
prompt = "Hello, I am a small language model. I can"
# 编码输入
inputs = tokenizer(prompt, return_tensors="pt").to(device)
# 推理生成
# max_length 控制输出长度
outputs = model.generate(**inputs, max_length=50, do_sample=True, temperature=0.7, top_k=50, top_p=0.95, no_repeat_ngram_size=2)
# 解码输出
text = tokenizer.decode(outputs[0], skip_special_tokens=True)
print("生成结果:")
print(text)
3️⃣ 代码说明
-
Tokenizer
- 把文字转换成模型可以理解的 token。
-
Model
AutoModelForCausalLM用于自回归文本生成。- 小模型
distilgpt2显存需求很低,几百 MB 就够。
-
生成参数
max_length=50:生成长度上限。do_sample=True:随机采样生成多样文本。temperature=0.7:温度越高,生成越随机。
| 参数 | 作用 | 效果 |
|---|---|---|
max_length |
限制生成长度 | 防止无限生成 |
do_sample |
是否随机采样 | True → 多样化,False → 贪心固定 |
temperature |
调整采样随机性 | <1 保守,>1 随机 |
top_k |
只采样前 K 个 token | 限制输出范围,降低低概率 token 出现 |
top_p |
只采样累积概率 ≥ p 的 token | 动态选择高概率集合 |
no_repeat_ngram_size |
避免 n-gram 重复 | 防止生成重复短语 |
-
GPU/CPU 切换
- 用
torch.cuda.is_available()自动检测是否有 GPU。
- 用
✅ 这个示例可以在 本地 CPU 或 GPU 上运行,快速实现 大模型推理的最基础流程。
运行效果
Warning: You are sending unauthenticated requests to the HF Hub. Please set a HF_TOKEN to enable higher rate limits and faster downloads.
Loading weights: 100%|█████████████████████████████████████████████████████████████████████████████████████████████| 76/76 [00:00<00:00, 1691.30it/s, Materializing param=transformer.wte.weight]
GPT2LMHeadModel LOAD REPORT from: distilgpt2
Key | Status | |
-------------------------------------------+------------+--+-
transformer.h.{0, 1, 2, 3, 4, 5}.attn.bias | UNEXPECTED | |
Notes:
- UNEXPECTED :can be ignored when loading from different task/architecture; not ok if you expect identical arch.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
生成结果:
Hello, I am a small language model. I can write code with a few different things like:
1. Create an object that can be modified
2. Use a custom function
3. Write an anonymous function in the body
4. Add an optional parameter to the function (or an argument)
5. Make a class that accepts a value and a function that returns a string
6. Get a method that uses a single argument
7. Return a new value
8. Convert an empty function into a .hf file
9. Delete the empty method
10. Replace the .hsf with the following code:

更多推荐


所有评论(0)