DeepSeek-R1-Distill-Qwen-1.5B 模型部署文档(gpu部署)
DeepSeek-R1-Distill-Qwen-1.5B 是一个基于 Qwen-1.5B 模型的知识蒸馏版本,适用于多种自然语言处理任务,如文本生成、问答、对话系统等。本文档将指导您如何部署该模型。
·
DeepSeek-R1-Distill-Qwen-1.5B 模型部署文档
1. 简介
DeepSeek-R1-Distill-Qwen-1.5B 是一个基于 Qwen-1.5B 模型的知识蒸馏版本,适用于多种自然语言处理任务,如文本生成、问答、对话系统等。本文档将指导您如何部署该模型。
2. 环境准备
2.1 硬件要求
- GPU: 至少 16GB 显存
- CPU: 8核以上
- 内存: 32GB 以上
- 存储: 至少 10GB 可用空间
2.2 软件要求
- 操作系统: Linux (推荐 Ubuntu 20.04)
- Python: 3.8 或以上
- CUDA: 11.1 或以上
- cuDNN: 8.0 或以上
2.3 依赖库安装
pip install torch==1.10.0+cu111 -f https://download.pytorch.org/whl/torch_stable.html
pip install transformers==4.12.0
pip install sentencepiece
pip install flask
3. 模型下载
3.1 从 Hugging Face 下载模型
from transformers import AutoModelForCausalLM, AutoTokenizer
model_name = "DeepSeek/R1-Distill-Qwen-1.5B"
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
model.save_pretrained("./deepseek-r1-distill-qwen-1.5b")
tokenizer.save_pretrained("./deepseek-r1-distill-qwen-1.5b")
3.2 从本地加载模型
from transformers import AutoModelForCausalLM, AutoTokenizer
model = AutoModelForCausalLM.from_pretrained("./deepseek-r1-distill-qwen-1.5b")
tokenizer = AutoTokenizer.from_pretrained("./deepseek-r1-distill-qwen-1.5b")
4. 模型部署(重要流程,步骤3可以忽略)
4.1 使用 Flask 部署 API
from flask import Flask, request, jsonify
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
app = Flask(__name__)
model = AutoModelForCausalLM.from_pretrained("../model/DeepSeek-R1-Distill-Qwen-1.5B")
tokenizer = AutoTokenizer.from_pretrained("../model/DeepSeek-R1-Distill-Qwen-1.5B")
# 将模型移动到 GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
@app.route('/generate', methods=['POST'])
def generate():
try:
data = request.json
input_text = data.get('input_text', '')
max_length = data.get('max_length', 50)
# 将输入数据移动到 GPU
inputs = tokenizer(input_text, return_tensors="pt").to(device)
# 模型推理
outputs = model.generate(inputs['input_ids'], max_length=max_length)
# 解码输出
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
return jsonify({'generated_text': generated_text})
except Exception as e:
return jsonify({'error': str(e)}), 500
if __name__ == '__main__':
app.run(host='0.0.0.0', port=5000)
4.2 启动服务
python app.py
4.3 测试 API
curl -H "Content-Type: application/json" -d '{"input_text": "你好", "max_length": 50}' -X POST http://localhost:5000/generate
python调用
import requests
import json
import argparse
def generate_text(input_text, max_length):
# 定义请求数据
url = "http://localhost:5000/generate"
headers = {"Content-Type": "application/json"}
data = {"input_text": input_text, "max_length": max_length}
# 将数据转换为 JSON 字符串
data_str = json.dumps(data, ensure_ascii=False) # 确保 JSON 中的中文字符不被转义
# 发送 POST 请求并流式接收响应
response = requests.post(url, headers=headers, data=data_str.encode('utf-8'), stream=True)
# 检查请求是否成功
if response.status_code == 200:
# 流式处理响应
for line in response.iter_lines(decode_unicode=True):
if line:
try:
# 将字符串解析为 JSON 对象
json_data = json.loads(line)
# 提取 generated_text 并打印
generated_text = json_data.get("generated_text", "")
print(generated_text)
except json.JSONDecodeError:
print(f"无法解析JSON对象: {line}")
else:
print(f"请求失败,状态码: {response.status_code}")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="处理输入文本和最大长度参数")
parser.add_argument("--input_text", type=str, required=True, help="输入的文本")
parser.add_argument("--max_length", type=int, required=True, help="最大长度")
args = parser.parse_args()
generate_text(args.input_text, args.max_length)
python app.py --input_text "介绍一下你自己" --max_length 50
5. 参考文档
更多推荐
所有评论(0)