大模型LLMs基于Langchain+FAISS+Ollama/Deepseek/Qwen/OpenAI的RAG检索方法以及优化
·
写在前文
基础框架环境:Langchain、Ollama(可以替换为Qwen、ChatGLM、Deepseek)、
向量库:FAISS(可以直接替换为Chroma、Qdrant、FAISS、Milvus、LanceDB....之类)
Chat模型:llm = ChatOllama(model='deepseek-r1:1.5b') # 简化版,需要使用比如OpenAI、DK云端版直接初始化即可...
向量化:embedding=OllamaEmbeddings(model='nomic-embed-text:latest')
所有环境全部默认已经初始化...。
本文主要提供检索方法内容如下
0、初始化环境
0.1、初始化日志以及依赖
from operator import itemgetter
from langchain.retrievers import EnsembleRetriever, MultiQueryRetriever, ContextualCompressionRetriever
from langchain.retrievers.document_compressors import LLMChainExtractor, CrossEncoderReranker, \
DocumentCompressorPipeline
from langchain_community.cross_encoders import HuggingFaceCrossEncoder
from langchain_community.document_compressors import FlashrankRerank
from langchain_community.retrievers import BM25Retriever
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.prompts import ChatPromptTemplate, PromptTemplate
from langchain_core.runnables import RunnablePassthrough, RunnableLambda
from langchain_core.tracers import ConsoleCallbackHandler
import utils, uuid
import numpy as np
from langchain_core.documents import Document
from langchain_community.vectorstores import FAISS
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.chains import SequentialChain
import logging
import os
from langchain_community.document_loaders import TextLoader
from flashrank import Ranker, RerankRequest
# 配置日志
logging.basicConfig(
filename='app.log',
level=logging.INFO,
encoding='utf-8',
format='%(asctime)s - %(message)s'
)
logging.getLogger("langchain").setLevel(logging.DEBUG)
0.2、初始化llm和embedding
"""
初始化LLM和Embedding
"""
# 本文所有参数都是最基础的,其余的比如top_p、top_k....之类采用默认即可。
llm = ChatOpenAI(
base_url='http://localhost:11434/v1',
api_key='ollama',
model='deepseek-r1:1.5b', # deepseek-r1:1.5b、qwen2:0.5b
)
#llm = ChatOllama(model='deepseek-r1:1.5b')
# 需要下载nomic-embed-text,除了这个也可以采用HaggingFace的库或者baichuan、OpenAI库。
embedding = OllamaEmbeddings(model='nomic-embed-text:latest')
#os.environ['BAICHUAN_API_KEY'] = 'sk-xxxxxxx'
#embedding = BaichuanTextEmbeddings()
#embedding = HuggingFaceEmbeddings(model_name='xxxxxx')
#embedding = OpenAIEmbeddings()
0.3、文本处理
数据模板(xxx.txt)
什么是java?
xxxxxxxxxxxxxxxxxxx
什么是java?
xxxxxxxxxxxxxxxxxxx
什么是java?
xxxxxxxxxxxxxxxxxxx
什么是java?
xxxxxxxxxxxxxxxxxxx
文本处理
# ================================文本处理===================================
"""
Step1、加载.txt文档
"""
def load_txt(file_path: str = './data/loader.txt') -> list[Document]:
return TextLoader(file_path=file_path, encoding='utf-8').load()
"""
Step2、文本预处理
只做展示,这里可以去去除不必要的标签,比如HTML标签、MD标签、去除停用词、多余的符号....之类
具体情况根据业务调整
"""
def preprocess_text(text: str):
return text.replace("\n\n\n\n", "[SPLIT]").replace("\n\n\n", "[SPLIT]").replace("\\", "\"").replace("\n", "")
"""
Step3、使用RecursiveCharacterTextSplitter切分文档
"""
def splitter_text(txt_data: list[Document], processed_text: str) -> list[Document]:
# 初始化文本分割器(实际使用时会根据内容自动选择分割粒度)
text_splitter = RecursiveCharacterTextSplitter(
separators=["[SPLIT]"], # 正则表达式匹配连续换行
chunk_size=0, # 足够大的值以保证按分隔符分割。 如果设置过大,系统会自动合并分割后小文本-------这儿可能有问题....
chunk_overlap=0, # 重复字符
keep_separator=False, # 是否保留分隔符
length_function=len,
is_separator_regex=False # 启用正则模式
)
# 先分割文本,再动态生成元数据 --- 主要是为了设置不同的uid 如果直接设置到text_splitter.create_documents中可能会导致uid相同
split_docs = text_splitter.split_text(processed_text)
# 为每个文档块生成独立的 metadata(包括 doc_id)
metadata_list = [{
**txt_data[0].metadata,
"doc_id": str(uuid.uuid4()).replace('-', ''),
} for _ in split_docs]
# 分割原始文档
# 直接构建Document列表
return [
Document(page_content=chunk, metadata=meta)
for chunk, meta in zip(split_docs, metadata_list)
]
0.4、创建/加载/添加向量库
# ================================创建或加载向量库===================================
"""
创建向量库
"""
def create_embed(
faiss_persist_directory: str,
faiss_index_name: str,
documents: list[Document],
embedding,
batch_size: int = 30
) -> FAISS:
"""
向量数据库构建...
"""
print('开始向量数据库构建...')
# 将文档分批次处理
batches = [documents[i:i + batch_size] for i in range(0, len(documents), batch_size)]
# 用第一个批次初始化索引
print(f'用第一批数据({len(batches[0])}条)初始化向量库...')
faiss_vector_store = FAISS.from_documents(
documents=batches[0],
embedding=embedding
)
# 添加剩余批次
for idx, batch in enumerate(batches[1:], start=2):
try:
print(f'添加第{idx}批数据({len(batch)}条)...')
faiss_vector_store.add_documents(batch)
faiss_vector_store.save_local(
folder_path=faiss_persist_directory,
index_name=faiss_index_name
)
store_list()
except Exception as e:
print(f"第{idx}批处理失败: {str(e)}")
raise
return faiss_vector_store
"""
向量库已经存在,则直接添加数据到向量库
"""
def add_embed(
faiss_persist_directory: str,
faiss_index_name: str,
documents: list[Document],
embedding
) -> FAISS:
"""
添加向量数据...
"""
print('向量库已经构建,添加向量数据...')
faiss_vector_store = FAISS.load_local(
faiss_persist_directory,
embeddings=embedding,
index_name=faiss_index_name, # 需与保存时一致
allow_dangerous_deserialization=True
)
# 分批添加文档
count = 1
batch_size = 30
for i in range(0, len(documents), batch_size):
try:
print(f'添加第[{count}]批次向量数据...:{faiss_vector_store.add_documents(documents[i:i + batch_size])}')
faiss_vector_store.save_local(
folder_path=faiss_persist_directory,
index_name=faiss_index_name
)
store_list()
count += 1
except Exception as e:
print(f"第{count}批处理失败: {str(e)}")
raise
return faiss_vector_store
"""
加载已经存在的向量库
"""
def load_embed_store(faiss_persist_directory, faiss_index_name, embedding) -> FAISS:
print("加载向量库...")
return FAISS.load_local(
faiss_persist_directory,
embeddings=embedding,
index_name=faiss_index_name, # 需与保存时一致
allow_dangerous_deserialization=True
)
def add_create_embed(faiss_persist_directory: str, txt_path: str, faiss_index_name: str, embedding) -> FAISS:
txt_data = load_txt(txt_path)
processed_text = preprocess_text(txt_data[0].page_content)
documents = splitter_text(txt_data, processed_text)
if not os.path.exists(faiss_persist_directory):
return create_embed(faiss_persist_directory, faiss_index_name, documents, embedding)
else:
return add_embed(faiss_persist_directory, faiss_index_name, documents, embedding)
# ================================额外操作-方便调试===================================
print(f"当前库有多少数据:{len(list(faiss_vector_store.docstore._dict.values()))}")
# 查看所有文档
def list_store(faiss_vector_store: FAISS):
for idx, doc in enumerate(faiss_vector_store.docstore._dict.values()):
print(f"文档 {idx + 1}:")
print(f"内容: {doc.page_content}")
print(f"元数据: {doc.metadata}\n---\n")
# 通过文档 ID 获取向量
def id_embed(faiss_vector_store: FAISS):
doc_ids = list(faiss_vector_store.docstore._dict.keys()) # 获取所有文档的 ID 列表
vector = faiss_vector_store.index.reconstruct(int(doc_ids[0])) # 获取第一个文档的向量
print(f"文档向量(维度 {len(vector)}):\n{np.round(vector, 4)}") # 保留4位小数
# 批量导出所有向量
def list_embed():
all_vectors = faiss_vector_store.index.reconstruct_n(0, faiss_vector_store.index.ntotal)
print(f"总向量数: {len(all_vectors)}")
print(f"示例向量:\n{all_vectors[0]}")
1、使用search检索
# search和as_retriever是一样的。但是FAISS有这个功能,就写在这里
# 默认是使用similarity"也有"mmr", or "similarity_score_threshold"
# 底层都是调用的是similarity_search、similarity_search_with_relevance_scores、max_marginal_relevance_search方法
def search_retriver(
faiss_vector_store: FAISS,
faiss_query: str,
search_type: str = 'mmr',
k: int = 3,
lambda_mult: int = 0
):
print(
f"使用search检索:{faiss_vector_store.search(query=faiss_query, search_type=search_type, k=k, lambda_mult=lambda_mult)}")
2、无任何优化:采用默认的similarity
# 2、无任何优化:采用默认的similarity --- 和上面search方法类似
def default_retriver(
faiss_query,
search_type: str = 'similarity',
k: int = 3
):
retriever = faiss_vector_store.as_retriever(search_type=search_type, search_kwargs={"k": k})
# print(f"无任何优化检索:{retriever.get_relevant_documents(query=faiss_query)}")
documents = retriever.invoke(faiss_query)
print(f"无任何优化检索:{documents}")
3、最大边际相似性(MMR)
# 3、最大边际相似性(MMR)
def mmr_retriever(
faiss_vector_store,
faiss_query,
k: int = 3,
lambda_mult: int = 0.5
):
mmr_retriever = faiss_vector_store.as_retriever(
search_type='mmr',
search_kwargs={
"k": k,
"lambda_mult": lambda_mult
}
)
# print(f"最大边际相似性检索:{mmr_retriever.get_relevant_documents(query=faiss_query)}")
print(f"最大边际相似性检索:{mmr_retriever.invoke(faiss_query)}")
4、相似性得分similarity_score_threshold
# 4、相似性打分检索来提高精度
def sst_retriever(
faiss_vector_store,
faiss_query,
score_threshold: float = 0.8,
k: int = 3
):
sst_retriever = faiss_vector_store.as_retriever(
search_type='similarity_score_threshold',
search_kwargs={
"score_threshold": score_threshold, # 分数大于0.8;
"k": k # k 返回一条;
},
)
# print(f"相似性打分:{sst_retriever.get_relevant_documents(query=faiss_query)}")
print(f"相似性得分检索:{sst_retriever.invoke(faiss_query)}")
5、BM25混合检索策略:结合BM25关键词检索与向量检索
# 5、混合检索策略:结合关键词与向量检索
# BM25:是一种基于词频和逆文档频率(TF-IDF)的传统检索算法,非常适合关键词匹配
# ### pip install rank_bm25
def bm25_ensemble_retriever(
faiss_query,
faiss_vector_store
):
# (使用公共方法获取文档)
documents = list(faiss_vector_store.docstore._dict.values())
# bm25_retriever = BM25Retriever.from_texts(texts, metadatas=metadatas)
# ensemble_retriever = EnsembleRetriever(
# retrievers=[
# faiss_vector_store.as_retriever(search_kwargs={"k": 3}),
# bm25_retriever
# ],
# weights=[0.6, 0.4]
# )
# 不会返回id
bm25_retriever = BM25Retriever.from_documents(
documents,
k=10, # 返回数量
k1=1.5, # 默认1.2,增大使高频词贡献更高
b=0.8 # 默认0.75,减小以降低文档长度影响
)
ensemble_retriever = EnsembleRetriever(
retrievers=[bm25_retriever, faiss_vector_store.as_retriever(search_kwargs={"k": 3})],
weights=[0.4, 0.6],
)
print(f"混合检索:{ensemble_retriever.invoke(faiss_query)}")
##### 测试BM25检索
def bm25_test(faiss_vector_store, faiss_query):
from langchain.retrievers import BM25Retriever, EnsembleRetriever
documents = list(faiss_vector_store.docstore._dict.values())
# bm25_retriever1 = BM25Retriever.from_texts(
# texts = [txt_data[0].page_content],
# metadatas=[txt_data[0].metadata]
# )
# print(f"BM25Retriever1:{bm25_retriever1.invoke("我只出生了一次,为什么每年都要庆生")}")
bm25_retriever2 = BM25Retriever.from_documents(
documents,
k=20, # 返回数量
k1=1.5, # 默认1.2,增大使高频词贡献更高
b=0.8 # 默认0.75,减小以降低文档长度影响
)
# bm25_retriever2.k = 2 # 设置 BM25 检索器返回的文档数量
print(f"BM25Retriever2:{bm25_retriever2.invoke(faiss_query)}")
6、多重文档检索
# 6、多重文档检索:就是将用户的问题,提前交给LLMs,让LLMs帮我们丰富问题、进行多角度的扩展,然后分别将问题发给LLMs(使用MultiQueryRetriever),取重复的部分即可。
def multi_query_retriever(llm, as_retriever, faiss_query):
prompt = """
你是一个AI语言模型助手。
你的任务是生成给定用户问题的3个不同版本,以便从向量数据库中检索相关文档。
通过生成用户问题的多个视角,你的目标是帮助用户克服基于距离的相似性搜索的一些局限性。
提供这些由换行符分隔的替代问题。
原始问题:{question}
"""
prompt_template = PromptTemplate(input_variables=['question'], template=prompt)
multi_query_retriever = MultiQueryRetriever.from_llm(
prompt=prompt_template,
retriever=as_retriever,
llm=llm,
include_original=True, # 是否包含原始问题
)
# print(f"多重检索:{multi_query_retriever.get_relevant_documents(query=faiss_query)}")
print(f"多重文档检索:{multi_query_retriever.invoke(
faiss_query,
# verbose=True, # 好像没效果...
config={"callbacks": [ConsoleCallbackHandler()]}
)}")
7、上下文压缩方法
上下文压缩方法:先把问题基本检索,然后把提的问题与基础检索出来的数据一起压缩---比如使用文本摘要简化,再次检索,然后通过LLM把其中不相干的内容删掉...
- 核心机制:使用LLMChainExtractor对文档内容进行语义压缩(如摘要提取)
- 工作流程:原始检索 → 用LLM精简文档内容 → 返回压缩后的文本
- 特点:保持原始结果数量,但缩短每个文档的长度(保留关键信息)
- 适用场景:需要减少上下文长度时(如处理长文档)
# 7、上下文压缩方法:先把问题基本检索,然后把提的问题与基础检索出来的数据一起压缩,再次检索,然后通过LLM把其中不相干的内容删掉...
def compression_retriever(llm, faiss_query, k: int = 20):
compressor = LLMChainExtractor.from_llm(llm=llm)
compression_retriever = ContextualCompressionRetriever(
base_retriever=faiss_vector_store.as_retriever(search_kwargs={"k": k}),
base_compressor=compressor
)
# print(f"上下文压缩:{compression_retriever.get_relevant_documents(query=faiss_query)}")
print(f"上下文压缩检索:{compression_retriever.invoke(faiss_query)}")
8、重排序优化
8.1、重排序优化:使用BGE模型重排序模型
# 8.1、重排序优化:使用BGE模型重排序模型
def reorder_retriever(faiss_query, top_n: int = 3, k: int = 20):
"""
核心机制:使用BgeReranker模型对检索结果进行质量重排序
工作流程:原始检索 → 用专用模型重新打分 → 仅保留top_n结果
特点:不改变文档内容,但改变排序和结果数量(从k到top_n)
适用场景:需要提升Top结果准确性的场景
:param top_n: 最终结果数
"""
model_path = 'D://A4Project//LLM//bge-reranker-base'
model = HuggingFaceCrossEncoder(
model_name=model_path,
model_kwargs={'device': 'cpu'}
)
compressor = CrossEncoderReranker(model=model, top_n=top_n)
# compressor = BgeReranker(model=model_path, top_n=top_n) # 这个好像不行
compression_retriever = ContextualCompressionRetriever(
base_retriever=faiss_vector_store.as_retriever(search_kwargs={"k": k}),
base_compressor=compressor
)
print(f"重排序优化:{compression_retriever.invoke(faiss_query)}")
8.2、重排序优化:使用RankLLMRerank重排序模型
# 8.2、重排序优化:使用RankLLMRerank重排序模型
### 应该版本冲突了,官网是这样的,但是实际使用新版本会报错...
### 留作参考。
def reorder_rankllm_retriever(
faiss_query,
top_n: int = 3,
k: int = 3
):
from langchain_community.document_compressors.rankllm_rerank import RankLLMRerank
"""
model:"Unsupported model type. Please use 'vicuna', 'zephyr', or 'gpt'."
gpt_model:
"""
model_path = 'D://A4Project//LLM//bge-reranker-base'
compressor = RankLLMRerank(model=model_path, top_n=top_n)
compression_retriever = ContextualCompressionRetriever(
base_retriever=faiss_vector_store.as_retriever(search_kwargs={"k": k}),
base_compressor=compressor
)
print(f"RankLLMRerank重排序优化:{compression_retriever.invoke(faiss_query)}")
8.3、重排序优化:使用FlashrankRerank重排序模型
# 8.3、重排序优化:使用FlashrankRerank重排序模型
def reorder_flash_rankRerank_retriever(
faiss_query,
llm,
top_n: int = 3,
k: int = 20
):
"""
默认模型 ms-marco-TinyBERT-L-2-v2 (约4MB)
最佳交叉编码器重排序器 ms-marco-MiniLM-L-12-v2 (约34MB)
最佳非交叉编码器重排序器 rank-T5-flan (约110MB)
支持100多种语言的多语言模型 ms-marco-MultiBERT-L-12 (约150MB)
微调的 ce-esci-MiniLM-L12-v2
大型上下文窗口和较快性能的 rank_zephyr_7b_v1_full (约4GB,4比特量化 --- 需要C/C++编译.)
专用阿拉伯语重排序器 miniReranker_arabic_v1
----- 需要手动下载hf模型,放在本地目录; ---- 需要文件的可以后台私信...
解决不能直接下载HF模型文件:
手动设置client = Ranker(cache_dir='本地模型目录', max_length=128, model_name='模型名称')
---- 这是最简单的,官方也给了这个解决,但是网上很多人都是去修改FlashrankRerank或者Ranker.Config的源码,但是实际不需要。可以直接手动设置client即Ranker...
"""
base_retriever = faiss_vector_store.as_retriever(search_kwargs={"k": k})
print(f"没有排序时:{base_retriever.invoke(faiss_query)}")
ranker = Ranker(cache_dir='xxxxx/本地模型目录/LLM//flash_rankRerank//')
compressor = FlashrankRerank(client=ranker, top_n=top_n)
compression_retriever = ContextualCompressionRetriever(
base_retriever=base_retriever,
base_compressor=compressor
)
# 仅重排序
print(f"FlashrankRerank重排序优化:{compression_retriever.invoke(faiss_query)}")
# 使用 FlashRank 进行 QA 重排序
from langchain.chains import RetrievalQA
chain = RetrievalQA.from_chain_type(
llm=llm,
chain_type='stuff',
retriever=compression_retriever,
chain_type_kwargs={
"prompt": PromptTemplate(
template="""综合以下多个来源的信息回答:
来源文档:{context}
问题:{question}
整合回答:""",
input_variables=["context", "question"]
)
},
return_source_documents=True
)
print(f"使用 FlashRank 进行 QA 重排序:{chain.invoke(faiss_query)}")
# 测试flashranker重排序方法
def flash_ranker_test():
ranker = Ranker(cache_dir='xxxxx本地模型目录//LLM//flash_rankRerank//', max_length=128, model_name='rank-T5-flan')
query = "如何加速大语言模型?"
passages = [
{
"id": 1,
"text": "量化技术通过将32位浮点权重转换为8位整数等低精度格式,可减少30-50%的模型体积。这种方法不仅能降低显存占用,还能在边缘设备上提升2-4倍的推理速度。",
"meta": {"additional": "info1"}
},
{
"id": 2,
"text": "基于CUDA和TensorRT的GPU加速方案,通过内核融合和显存优化,可将transformer层的计算效率提升3倍以上。特别是使用FP16混合精度时,能实现计算与显存带宽的平衡。",
"meta": {"additional": "info1"}
},
{
"id": 3,
"text": "分布式计算中的流水线并行(Pipeline Parallelism)可将模型层拆分到不同设备,结合张量并行(Tensor Parallelism),能实现千亿参数模型的分布式训练,吞吐量提升可达线性扩展。",
"meta": {"additional": "info1"}
},
{
"id": 4,
"text": "FlashAttention算法通过分块计算和IO感知优化,将注意力机制的计算复杂度从O(n²)降至O(n),在4096长度序列上可实现7.6倍的加速效果。",
"meta": {"additional": "info1"}
},
{
"id": 5,
"text": "知识蒸馏技术通过让小型学生模型学习教师模型的输出分布,可在保持90%性能的前提下,将175B参数的GPT-3压缩到1/100规模,推理速度提升50倍。",
"meta": {"additional": "info1"}
},
{
"id": 6,
"text": "KV缓存机制通过存储历史键值对,可避免transformer解码时的重复计算。结合动态批处理(Dynamic Batching),在生成任务中能将吞吐量提高3-5倍。",
"meta": {"additional": "info1"}
},
{
"id": 7,
"text": "混合专家模型(MoE)通过条件式计算,每个输入仅激活部分专家网络。如Switch Transformer在1.6万亿参数时,计算量仅相当于稠密模型的250亿参数。",
"meta": {"additional": "info1"}
}
]
rerankrequest = RerankRequest(query=query, passages=passages)
results = ranker.rerank(rerankrequest)
print(results)
9、层次化检索
先粗检索,再精检索;
- 粗检索:范围大、返回数据量大; - 精检索:再利用LLM对粗检索结果精练
9.1、两阶段检索架构
# 9.1、层次化检索:两阶段检索架构
def fine_two_retriever(
faiss_query,
top_n: int = 3,
k: int = 20
):
"""
核心机制:分阶段处理(粗检索 + 精炼)
工作流程:
第一阶段:扩大检索范围(k=20)
第二阶段:对初筛结果进行压缩/重排序
特点:先保证召回率再提升精度,组合更灵活
适用场景:需要平衡召回率和准确率的复杂场景
:param top_n: 最终结果数
:param k: 初筛结果数
"""
model_path = 'xxxxx本地模型位置//LLM//bge-reranker-base'
model = HuggingFaceCrossEncoder(
model_name=model_path,
model_kwargs={'device': 'cpu'}
)
compressor = CrossEncoderReranker(model=model, top_n=top_n)
# compressor = BgeReranker(model=model_path, top_n=top_n)
coarse_retriever = faiss_vector_store.as_retriever(search_kwargs={"k": k})
fine_retriever = ContextualCompressionRetriever(
base_retriever=coarse_retriever,
base_compressor=compressor # 使用上述重排序器
)
print(f"层次化检索:{fine_retriever.invoke(faiss_query)}")
9.2、层次化检索:三阶段检索架构
# 9.2、层次化检索:三阶段检索架构
def fine_three_retriever(
llm,
faiss_query,
top_n: int = 10,
k: int = 20
):
# 组合三者的示例
coarse_retriever = faiss_vector_store.as_retriever(search_kwargs={"k": k})
model_path = 'xxxxx本地模型位置//LLM//bge-reranker-base'
model = HuggingFaceCrossEncoder(
model_name=model_path,
model_kwargs={'device': 'cpu'}
)
reranker = CrossEncoderReranker(model=model, top_n=top_n)
# reranker = BgeReranker(model=model_path, top_n=top_n)
compressor = LLMChainExtractor.from_llm(llm=llm)
# 三阶段处理:
# 1. 粗检索 2. 重排序 3. 内容压缩
pipeline_retriever = ContextualCompressionRetriever(
base_retriever=ContextualCompressionRetriever(
base_retriever=coarse_retriever,
base_compressor=reranker
),
base_compressor=compressor
)
print(f"三阶段检索:{pipeline_retriever.invoke(faiss_query)}")
10、检索综合使用
# 这种用法在工作中没有用过,留作记录,可以这样使用...
# 但是实际上效果只需要检索前使用向量检索即相似性得分即可完成大部分工作检索,顶多再加上BM25混合检索,其中的多查询检索,,,我感觉效果挺一般.。然后为了提高精确度和减少发给LLMs的Tokens的使用,在检索完后对内容进行FlashRank重排序+内容压缩去掉多余额外的内容即可。
def comprehensive_retriever(faiss_query, faiss_vector_store, llm):
"""
向量检索:相似性得分 + 扩展问题/多文档检索 + BM25混合检索 + FlashRankRerank重排序优化 + 内容压缩 + LLM回答
我先获取到检索器,然后再用FlashRank进行重排序优化,返回相似度最高的前3个。然后再将这个数据传递给链。
:return:
"""
# 基础的 向量embedding检索
embedding_retriever = faiss_vector_store.as_retriever(
search_type='similarity_score_threshold',
search_kwargs={"score_threshold": 0.5, "k": 20}
)
# 多查询检索|多查询检索
multi_prompt = """
你是一个AI语言模型助手。
你的任务是生成给定用户问题的3个不同版本,以便从向量数据库中检索相关文档。
通过生成用户问题的多个视角,你的目标是帮助用户克服基于距离的相似性搜索的一些局限性。
提供这些由换行符分隔的替代问题。
原始问题:{question}
"""
multi_prompt_template = PromptTemplate(input_variables=['question'], template=multi_prompt)
multi_query_retriever = MultiQueryRetriever.from_llm(
prompt=multi_prompt_template,
retriever=embedding_retriever,
llm=llm,
include_original=True, # 是否包含原始问题
)
# 初始化BM25检索
# (使用公共方法获取文档)
documents = list(faiss_vector_store.docstore._dict.values())
bm25_retriever = BM25Retriever.from_documents(
documents,
k=20, # 返回数量
k1=1.5, # 默认1.2,增大使高频词贡献更高
b=0.8 # 默认0.75,减小以降低文档长度影响
)
# 混合检索:BM25+embedding的
ensemble_retriever = EnsembleRetriever(
retrievers=[bm25_retriever, multi_query_retriever],
weights=[0.4, 0.6]
)
# 混合检索后 重排序
# 构建压缩管道:重排序 + 内容提取
ranker = Ranker(cache_dir='xxxx本地模型目录//LLM//flash_rankRerank//')
flashrank_rerank = FlashrankRerank(client=ranker, top_n=2)
# 重排序后 压缩上下文
compressor_prompt = """
鉴于以下问题和内容,提取与回答问题相关的背景*按原样*的任何部分。如果上下文都不相关,则返回{no_output_str}。
记住,*不要*编辑提取上下文的部分。
问题: {{question}}
内容: {{context}}
提取相关部分:
"""
compressor_prompt_template = PromptTemplate(
input_variables=['question', 'context'],
template=compressor_prompt.format(no_output_str='NO_OUTPUT'))
compressor = LLMChainExtractor.from_llm(prompt=compressor_prompt_template, llm=llm)
# ### 其实可以使用两次ContextualCompressionRetriever像9.2一样,但是为了记录多种方法,就使用DocumentCompressorPipeline记录----效果一样
pipeline = DocumentCompressorPipeline(transformers=[flashrank_rerank, compressor])
# 最终压缩检索器
compression_retriever = ContextualCompressionRetriever(
base_retriever=ensemble_retriever,
base_compressor=pipeline
)
print(f"综合使用:{compression_retriever.invoke(faiss_query, config={"callbacks": [ConsoleCallbackHandler()]})}")
11、检索+LLM联合使用回复问题
-----为了简化流程,直接使用第二步中的“1、无任何优化:采用默认的similarity”的检索。
其余的,就是将similarity的retriever检索器换成其余的即可.
# 自定义日志回调
class CustomLogger(BaseCallbackHandler):
def on_chain_start(self, serialized, inputs, **kwargs):
print(f"🔍 输入参数: {inputs}")
def on_chain_end(self, outputs, **kwargs):
print(f"✨ 输出结果: {outputs}")
def on_retriever_end(self, documents, **kwargs):
print(f"📄 检索到 {len(documents)} 篇文档:")
for doc in documents:
print(f" - {doc.page_content[:50]}...")
# 向量检索--即基础检索,使用相似性检索similarity方法
def default_retriver(
faiss_query,
search_type: str = 'similarity',
k: int = 3
):
retriever = faiss_vector_store.as_retriever(search_type=search_type, search_kwargs={"k": k})
# print(f"无任何优化检索:{retriever.get_relevant_documents(query=faiss_query)}")
documents = retriever.invoke(faiss_query)
print(f"无任何优化检索:{documents}")
# 对检索出来的文档进行预处理 ---- 只提取Documents列表中的page_content中的内容
context_text = "\n".join([doc.page_content for doc in documents])
prompt = """
请根据以下内容回答问题,内容中如果没有的那就回答“请咨询人工...”,内容中如果有其他不相干的内容,直接删除即可。
内容:{content}
问题:{query}
回答:
"""
prompt_template = ChatPromptTemplate.from_template(prompt)
# 注意:使用RunnablePassthrough()时,会将invoke()中的所有参数传递到目标参数...。
# ### 所以我们在使用RunnablePassthrough时,最好是只传递一个参数即可,如果我们要传递多个参数时,使用 lambda或者itemgetter提取指定的key的值才行。
# """
# 当你使用{"content": RunnablePassthrough(), "query": RunnablePassthrough()}传递参数时,提示词模板实际得到的参数如下:
# 请根据以下内容回答问题,内容中如果没有的那就回答“请咨询人工...”,内容中如果有其他不相干的内容,直接删除即可。
# 内容:{"content": context_text, "query": faiss_query},
# 问题:{"content": context_text, "query": faiss_query},
# 回答:
# """
chain1 = (
{"content": lambda x: x['content'], "query": lambda x: x['query']}
# {"content": RunnablePassthrough(), "query": RunnablePassthrough()} # 这是错误示范...
# {"content": itemgetter("content"), "query": itemgetter("query")}
# {"content": RunnableLambda(lambda x:x['content']) , "query": RunnableLambda(lambda x:x['query']) }
| prompt_template
| llm
)
logging.info(f"无优化的检索+LLM1:{chain1.invoke(
{"content": context_text, "query": faiss_query},
)}")
logging.info(f"{"+" * 100}") # # config={"callbacks": [StdOutCallbackHandler()]} # 添加标准输出回调
# 只检索了,但是没有预处理 --- 所以会携带Document、metadata等标签
chain2 = (
{"content": retriever, "query": RunnablePassthrough()}
| prompt_template
| llm
)
logging.info(f"无优化的检索+LLM2:{chain2.invoke(faiss_query)}")
logging.info(f"{"+" * 100}")
# 只检索了,但是没有预处理 --- 所以会携带Document、metadata等标签
chain3 = (
{
"content": (lambda x: retriever.invoke(x["query"])),
"query": (lambda x: x["query"])
}
| prompt_template
| llm
)
logging.info(f"无优化的检索+LLM3:{chain3.invoke({"query": faiss_query})}")
logging.info(f"{"+" * 100}")
# 处理检索结果的函数(将文档列表转换为字符串)
process_docs = RunnableLambda(
lambda docs: "\n".join([doc.page_content for doc in docs])
)
# 检索后进行预处理,会去掉Document、metadata之类标签
chain4 = (
{
"content": retriever | process_docs, # 先检索再处理文档
"query": RunnablePassthrough() # 直接传递用户原始问题
}
| prompt_template # 组合成完整 prompt
| llm # 传给大模型生成回答
)
logging.info(f"无优化的检索+LLM4:{chain4.invoke(faiss_query)}")
logging.info(f"{"+" * 100}")
# 使用自定义回调
logging.info(f"自定义日志回调:{chain4.invoke(faiss_query, config={"callbacks": [CustomLogger()]})}")
logging.info(f"{"+" * 100}")
12、加餐
如何打印日志?老版本使用LLMChains创建链时,可以使用参数“verbose=True”打印日志,但是使用LCEL(“|”管道连接)时,没法设置“verbose=True”,所以本文展示了两种调试:
12.1、自定义日志回调函数;
# 自定义日志回调
class CustomLogger(BaseCallbackHandler):
def on_chain_start(self, serialized, inputs, **kwargs):
print(f"🔍 输入参数: {inputs}")
def on_chain_end(self, outputs, **kwargs):
print(f"✨ 输出结果: {outputs}")
def on_retriever_end(self, documents, **kwargs):
print(f"📄 检索到 {len(documents)} 篇文档:")
for doc in documents:
print(f" - {doc.page_content[:50]}...")
# 主要时再invoke中添加config参数即可
print(f"自定义日志回调:{chains.invoke(faiss_query, config={"callbacks": [CustomLogger()]})}")
12.2、使用ConsoleCallbackHandler()的回调函数
# 方法一:可以设置chain.invoke(xxx,config={'callbacks': [ConsoleCallbackHandler()]})
retriever.invoke("查询内容",
# verbose=True, # 好像没效果...
config={"callbacks": [ConsoleCallbackHandler()]})
# 方法二:可以设置chain.with_config({"callbacks": [ConsoleCallbackHandler()]})
chain_console_callback = (
llm
| StrOutputParser()
).with_config({"callbacks": [ConsoleCallbackHandler()]}) # 在控制台打印执行日志
chain_console_callback.invoke("你是谁?")
13、直接使用
txt_path = '本地数据位置'
faiss_persist_directory = 'faiss要存储的地址'
faiss_index_name = '数据库的索引名称 --- 最终位置为persist_dir+index_name'
if __name__ == '__main__':
# 添加或者创建 --- 文本数据在add_create_embed方法里面处理。
# faiss_vector_store = add_create_embed(faiss_persist_directory, txt_path, faiss_index_name, embedding)
# 加载
faiss_vector_store = load_embed_store(faiss_persist_directory, faiss_index_name, embedding)
faiss_query = '为什么没人说ABCD型的成语?'
comprehensive_retriever(faiss_query, faiss_vector_store, llm)
search_retriver(faiss_vector_store, faiss_query)
default_retriver(faiss_query)
mmr_retriever(faiss_vector_store, faiss_query)
sst_retriever(faiss_vector_store, faiss_query)
ensemble_as_retriever = faiss_vector_store.as_retriever(search_kwargs={"k": 3})
bm25_ensemble_retriever(faiss_query, faiss_vector_store)
bm25_test(faiss_vector_store, faiss_query)
multi_as_retriever = faiss_vector_store.as_retriever()
multi_query_retriever(llm, multi_as_retriever, faiss_query)
compression_retriever(llm, faiss_query)
reorder_retriever(faiss_query)
reorder_rankllm_retriever(faiss_query, llm)
reorder_flash_rankRerank_retriever(faiss_query, llm)
fine_two_retriever(faiss_query)
fine_three_retriever(llm, faiss_query)
更多推荐
所有评论(0)