检索增强生成(RAG)
RAG(Retrieval-Augmented Generation)是一种强大的技术,它让大语言模型能够基于你的私有数据回答问题,而不仅仅依赖训练时学到的知识。
什么是 RAG
RAG 的核心思想:
- 检索(Retrieval) - 从知识库中找到相关文档
- 增强(Augmented) - 将检索到的文档作为上下文
- 生成(Generation) - LLM 基于上下文生成回答
用户问题 → 检索相关文档 → 构建上下文 → LLM 生成回答
为什么需要 RAG
| 问题 | 纯 LLM | RAG |
|---|---|---|
| 知识时效性 | 只知道训练数据 | 可接入最新数据 |
| 私有数据 | 无法访问 | 可以查询 |
| 幻觉问题 | 可能胡说 | 有据可查 |
| 可追溯性 | 无法引用来源 | 可引用原文 |
RAG 基础架构
┌─────────────┐
│ 文档库 │
└──────┬──────┘
│
┌──────▼──────┐
│ 文档加载器 │
└──────┬──────┘
│
┌──────▼──────┐
│ 文本分割器 │
└──────┬──────┘
│
┌──────▼──────┐
│ 向量化存储 │
└──────┬──────┘
│
┌──────▼──────┐
用户问题 ──────────►│ 检索器 │
└──────┬──────┘
│
┌──────▼──────┐
│ LLM │
└──────┬──────┘
│
┌──────▼──────┐
│ 回答 │
└─────────────┘
文档加载
加载文本文件
from langchain_community.document_loaders import TextLoader
loader = TextLoader("document.txt", encoding="utf-8")
documents = loader.load()
print(f"加载了 {len(documents)} 个文档")
print(f"内容预览: {documents[0].page_content[:200]}...")
加载 PDF 文件
from langchain_community.document_loaders import PyPDFLoader
loader = PyPDFLoader("document.pdf")
documents = loader.load()
# 每页是一个文档
for doc in documents:
print(f"页码: {doc.metadata['page']}")
print(f"内容: {doc.page_content[:100]}...")
加载网页
from langchain_community.document_loaders import WebBaseLoader
loader = WebBaseLoader("https://example.com/article")
documents = loader.load()
加载目录
from langchain_community.document_loaders import DirectoryLoader
# 加载目录下所有 txt 文件
loader = DirectoryLoader(
"./docs/",
glob="**/*.txt",
loader_cls=TextLoader
)
documents = loader.load()
文本分割
大文档需要分割成小块才能有效检索:
RecursiveCharacterTextSplitter
最常用的分割器:
from langchain_text_splitters import RecursiveCharacterTextSplitter
splitter = RecursiveCharacterTextSplitter(
chunk_size=500, # 每块最大字符数
chunk_overlap=50, # 块之间重叠的字符数
separators=["\n\n", "\n", "。", ",", " ", ""]
)
chunks = splitter.split_documents(documents)
print(f"分割成 {len(chunks)} 个文本块")
按 Token 分割
from langchain_text_splitters import TokenTextSplitter
splitter = TokenTextSplitter(
chunk_size=200, # 每块最大 token 数
chunk_overlap=20
)
chunks = splitter.split_documents(documents)
Markdown 分割
from langchain_text_splitters import MarkdownHeaderTextSplitter
headers_to_split_on = [
("#", "Header 1"),
("##", "Header 2"),
("###", "Header 3"),
]
splitter = MarkdownHeaderTextSplitter(headers_to_split_on)
chunks = splitter.split_text(markdown_text)
向量化存储
使用 OpenAI Embeddings
from langchain_openai import OpenAIEmbeddings
from langchain_community.vectorstores import Chroma
# 创建嵌入模型
embeddings = OpenAIEmbeddings()
# 创建向量存储
vectorstore = Chroma.from_documents(
documents=chunks,
embedding=embeddings,
persist_directory="./chroma_db" # 持久化存储
)
# 简单检索测试
results = vectorstore.similarity_search("搜索词", k=3)
for doc in results:
print(doc.page_content[:100])
使用 FAISS
from langchain_community.vectorstores import FAISS
# 创建 FAISS 向量存储
vectorstore = FAISS.from_documents(
documents=chunks,
embedding=embeddings
)
# 保存到本地
vectorstore.save_local("faiss_index")
# 加载
loaded_vectorstore = FAISS.load_local(
"faiss_index",
embeddings,
allow_dangerous_deserialization=True
)
使用国内模型的 Embeddings
# 智谱 AI
from langchain_zhipuai import ZhipuAIEmbeddings
embeddings = ZhipuAIEmbeddings()
# 本地模型 (HuggingFace)
from langchain_huggingface import HuggingFaceEmbeddings
embeddings = HuggingFaceEmbeddings(
model_name="sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
)
创建检索器
基础检索器
# 从向量存储创建检索器
retriever = vectorstore.as_retriever(
search_type="similarity", # 相似度搜索
search_kwargs={"k": 4} # 返回 4 个结果
)
# 检索文档
docs = retriever.invoke("我的问题是什么?")
MMR 检索器
最大边际相关性(MMR)检索,平衡相关性和多样性:
retriever = vectorstore.as_retriever(
search_type="mmr",
search_kwargs={
"k": 4,
"fetch_k": 10, # 先获取 10 个,再选出 4 个最多样的
"lambda_mult": 0.5 # 多样性权重
}
)
带分数的检索
# 获取带相似度分数的结果
results_with_scores = vectorstore.similarity_search_with_score(
"搜索词",
k=4
)
for doc, score in results_with_scores:
print(f"分数: {score:.4f}")
print(f"内容: {doc.page_content[:100]}...")
构建 RAG 链
基础 RAG 链
from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
# 创建组件
llm = ChatOpenAI(model="gpt-3.5-turbo")
prompt = ChatPromptTemplate.from_template("""
基于以下上下文回答问题。如果上下文中没有相关信息,请说"我无法从提供的文档中找到答案"。
上下文:
{context}
问题:{question}
回答:""")
def format_docs(docs):
return "\n\n".join(doc.page_content for doc in docs)
# 构建 RAG 链
rag_chain = (
{"context": retriever | format_docs, "question": RunnablePassthrough()}
| prompt
| llm
| StrOutputParser()
)
# 使用
answer = rag_chain.invoke("公司的退款政策是什么?")
print(answer)
带来源引用的 RAG
from langchain_core.runnables import RunnableParallel
# 创建并行链,同时返回答案和来源
rag_chain_with_sources = RunnableParallel(
{"context": retriever, "question": RunnablePassthrough()}
).assign(
answer=lambda x: (
prompt.format(context=format_docs(x["context"]), question=x["question"])
| llm
| StrOutputParser()
).invoke({})
)
# 或者更清晰的写法
def create_rag_with_sources(retriever, llm):
def get_answer(inputs):
context = inputs["context"]
question = inputs["question"]
formatted_prompt = prompt.invoke({
"context": format_docs(context),
"question": question
})
answer = llm.invoke(formatted_prompt)
return {
"answer": answer.content,
"sources": [doc.metadata.get("source", "未知") for doc in context]
}
return (
{"context": retriever, "question": RunnablePassthrough()}
| get_answer
)
rag_chain = create_rag_with_sources(retriever, llm)
result = rag_chain.invoke("问题")
print(f"回答:{result['answer']}")
print(f"来源:{result['sources']}")
高级 RAG 技术
查询转换
优化用户查询:
from langchain_core.prompts import ChatPromptTemplate
# 查询重写
rewrite_prompt = ChatPromptTemplate.from_template("""
将以下用户问题重写为更适合搜索的形式,保持原意:
原问题:{question}
重写后的问题:""")
rewrite_chain = rewrite_prompt | llm | StrOutputParser()
# 在 RAG 链中使用
enhanced_rag = (
{"question": rewrite_chain}
| {"context": lambda x: retriever.invoke(x["question"]), "question": lambda x: x["question"]}
| prompt
| llm
| StrOutputParser()
)
多查询检索
生成多个查询变体提高召回率:
from langchain.retrievers.multi_query import MultiQueryRetriever
multi_retriever = MultiQueryRetriever.from_llm(
retriever=vectorstore.as_retriever(),
llm=llm
)
# 会自动生成多个查询变体
docs = multi_retriever.invoke("问题")
上下文压缩
压缩检索到的文档,只保留相关部分:
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import LLMChainExtractor
compressor = LLMChainExtractor.from_llm(llm)
compression_retriever = ContextualCompressionRetriever(
base_compressor=compressor,
base_retriever=retriever
)
# 返回压缩后的文档
compressed_docs = compression_retriever.invoke("问题")
完整示例:文档问答系统
"""
文档问答系统 - RAG 完整示例
"""
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_community.vectorstores import Chroma
from langchain_community.document_loaders import TextLoader, DirectoryLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from dotenv import load_dotenv
import os
load_dotenv()
class DocumentQA:
"""文档问答系统"""
def __init__(self, persist_directory: str = "./vectorstore"):
self.persist_directory = persist_directory
self.embeddings = OpenAIEmbeddings()
self.llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)
self.splitter = RecursiveCharacterTextSplitter(
chunk_size=500,
chunk_overlap=50
)
self.prompt = ChatPromptTemplate.from_template("""
你是一个文档问答助手。请根据以下文档内容回答用户的问题。
要求:
1. 只使用提供的文档内容回答
2. 如果文档中没有相关信息,请明确说明
3. 回答要简洁准确
4. 适当引用原文
文档内容:
{context}
用户问题:{question}
回答:""")
self.vectorstore = None
self.retriever = None
self.chain = None
# 尝试加载已有的向量存储
self._load_vectorstore()
def _load_vectorstore(self):
"""加载已有的向量存储"""
if os.path.exists(self.persist_directory):
try:
self.vectorstore = Chroma(
persist_directory=self.persist_directory,
embedding_function=self.embeddings
)
self._setup_chain()
print(f"已加载向量存储,包含 {self.vectorstore._collection.count()} 个文档块")
except Exception as e:
print(f"加载向量存储失败: {e}")
def _setup_chain(self):
"""设置检索链"""
self.retriever = self.vectorstore.as_retriever(
search_type="mmr",
search_kwargs={"k": 4}
)
def format_docs(docs):
return "\n\n---\n\n".join(
f"[来源: {doc.metadata.get('source', '未知')}]\n{doc.page_content}"
for doc in docs
)
self.chain = (
{"context": self.retriever | format_docs, "question": RunnablePassthrough()}
| self.prompt
| self.llm
| StrOutputParser()
)
def add_documents(self, file_path: str):
"""添加文档"""
print(f"正在加载文档: {file_path}")
# 判断是文件还是目录
if os.path.isdir(file_path):
loader = DirectoryLoader(
file_path,
glob="**/*.txt",
loader_cls=TextLoader,
loader_kwargs={"encoding": "utf-8"}
)
else:
loader = TextLoader(file_path, encoding="utf-8")
documents = loader.load()
print(f"加载了 {len(documents)} 个文档")
# 分割文档
chunks = self.splitter.split_documents(documents)
print(f"分割成 {len(chunks)} 个文本块")
# 添加到向量存储
if self.vectorstore is None:
self.vectorstore = Chroma.from_documents(
documents=chunks,
embedding=self.embeddings,
persist_directory=self.persist_directory
)
else:
self.vectorstore.add_documents(chunks)
self._setup_chain()
print("文档已添加到知识库")
def query(self, question: str) -> dict:
"""查询问题"""
if self.chain is None:
return {"answer": "请先添加文档到知识库", "sources": []}
# 获取相关文档
docs = self.retriever.invoke(question)
# 获取回答
answer = self.chain.invoke(question)
return {
"answer": answer,
"sources": [
{
"content": doc.page_content[:200] + "...",
"source": doc.metadata.get("source", "未知")
}
for doc in docs
]
}
def clear(self):
"""清除向量存储"""
if os.path.exists(self.persist_directory):
import shutil
shutil.rmtree(self.persist_directory)
self.vectorstore = None
self.retriever = None
self.chain = None
print("知识库已清除")
def main():
qa = DocumentQA()
print("=" * 50)
print(" 文档问答系统")
print(" 命令: add <文件路径> - 添加文档")
print(" clear - 清除知识库")
print(" quit - 退出")
print("=" * 50)
while True:
user_input = input("\n你:").strip()
if not user_input:
continue
if user_input.lower() == "quit":
print("再见!")
break
if user_input.lower() == "clear":
qa.clear()
continue
if user_input.lower().startswith("add "):
file_path = user_input[4:].strip()
qa.add_documents(file_path)
continue
# 查询
result = qa.query(user_input)
print(f"\n回答:{result['answer']}")
if result['sources']:
print("\n参考来源:")
for i, source in enumerate(result['sources'], 1):
print(f" {i}. {source['source']}")
if __name__ == "__main__":
main()
小结
本章介绍了:
✅ RAG 的概念和优势
✅ 文档加载器的使用
✅ 文本分割策略
✅ 向量化存储和检索
✅ 构建 RAG 链
✅ 高级 RAG 技术
✅ 完整的文档问答系统
下一步
进阶篇到此结束!接下来进入高级篇,学习如何构建自主决策的 Agent 智能体。
练习
- 构建一个 PDF 文档问答系统
- 实现多文档对比分析
- 添加对话历史支持
- 尝试不同的分割策略和检索参数