# Chapter 13 · Building a RAG System

RAG (Retrieval-Augmented Generation) lets an LLM answer questions from external knowledge, and it is the core architecture behind most enterprise AI applications. This chapter covers document parsing, chunking strategies, retrieval optimization, and a complete implementation.
## 13.1 RAG Architecture Overview

```
Document sources → parsing → chunking → embeddings → vector store
                                                          ↓
User query → query embedding → similarity search → top-k chunks → LLM-generated answer
```
### Core Components

| Component | Role | Recommended options |
|---|---|---|
| Document parsing | Extract text | PyMuPDF, Unstructured |
| Chunking | Split documents | Recursive splitting, semantic splitting |
| Embeddings | Generate vectors | text-embedding-3-small |
| Vector store | Store / retrieve | FAISS, Chroma, Pinecone |
| LLM generation | Produce answers | GPT-4o / GPT-4o mini |
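The examples in this chapter rely on a handful of packages; a one-line install covering all of them (version pinning left to the reader):

```python
# pip install openai pymupdf chromadb numpy
```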
## 13.2 Document Parsing

### PDF Parsing

```python
import fitz  # PyMuPDF

def parse_pdf(file_path: str) -> list[dict]:
    """Parse a PDF into a list of pages."""
    doc = fitz.open(file_path)
    pages = []
    for i, page in enumerate(doc):
        text = page.get_text()
        if text.strip():  # skip blank pages
            pages.append({
                "page": i + 1,
                "text": text.strip(),
                "source": file_path,
            })
    return pages
```
### Markdown Parsing

```python
def parse_markdown(file_path: str) -> list[dict]:
    """Parse Markdown, splitting into sections at headings."""
    with open(file_path, "r", encoding="utf-8") as f:
        content = f.read()
    sections = []
    current_section = {"title": "", "text": "", "source": file_path}
    for line in content.split("\n"):
        if line.startswith("#"):
            # A heading starts a new section; flush the previous one.
            if current_section["text"].strip():
                sections.append(current_section)
            current_section = {
                "title": line.lstrip("#").strip(),
                "text": line + "\n",
                "source": file_path,
            }
        else:
            current_section["text"] += line + "\n"
    if current_section["text"].strip():
        sections.append(current_section)
    return sections
```
## 13.3 Chunking Strategies

### 13.3.1 Recursive Character Splitting (recommended)

```python
def recursive_split(text: str, chunk_size: int = 500, overlap: int = 50) -> list[str]:
    """Recursive character splitting that tries to keep semantic units intact."""
    separators = ["\n\n", "\n", "。", "!", "?", ". ", " ", ""]

    def _split(text: str, separators: list[str]) -> list[str]:
        if len(text) <= chunk_size:
            return [text]
        if not separators or separators[0] == "":
            # No usable separator left: fall back to hard slicing.
            return [text[i:i + chunk_size] for i in range(0, len(text), chunk_size)]
        sep, rest = separators[0], separators[1:]
        parts = text.split(sep)
        chunks = []
        current = ""
        for part in parts:
            if len(part) > chunk_size:
                # This piece is still too large: recurse with the next, finer separator.
                if current:
                    chunks.append(current)
                    current = ""
                chunks.extend(_split(part, rest))
            elif len(current) + len(part) + len(sep) <= chunk_size:
                current += (sep if current else "") + part
            else:
                if current:
                    chunks.append(current)
                current = part
        if current:
            chunks.append(current)
        return chunks

    chunks = _split(text, separators)
    # Add overlap: prefix each chunk with the tail of the previous chunk.
    if overlap > 0 and len(chunks) > 1:
        overlapped = [chunks[0]]
        for i in range(1, len(chunks)):
            overlapped.append(chunks[i - 1][-overlap:] + chunks[i])
        chunks = overlapped
    return chunks
```
### 13.3.2 Chunking Parameters Compared

| Strategy | Chunk size | Overlap | Best for |
|---|---|---|---|
| Small chunks | 200-300 | 50 | Precise retrieval, FAQs |
| Medium chunks | 500-800 | 100 | General documents, knowledge bases |
| Large chunks | 1000-1500 | 200 | Long-form analysis, reports |
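For instance, the medium preset from the table maps onto `recursive_split` like this (the file name is a placeholder):

```python
# Medium-chunk preset: general documents / knowledge bases.
with open("handbook.txt", encoding="utf-8") as f:  # placeholder file
    text = f.read()
chunks = recursive_split(text, chunk_size=800, overlap=100)
print(f"{len(text)} chars -> {len(chunks)} chunks")
```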
### 13.3.3 Semantic Chunking

Semantic chunking groups adjacent sentences by embedding similarity, producing more coherent chunks at the cost of one embeddings request per document.

```python
from openai import OpenAI
import numpy as np

client = OpenAI()

def semantic_split(text: str, threshold: float = 0.75, max_chunk: int = 1000) -> list[str]:
    """Split text into chunks based on semantic similarity."""
    sentences = [s.strip() for s in text.replace("。", "。\n").split("\n") if s.strip()]
    if len(sentences) <= 1:
        return [text]
    # Embed all sentences in a single request
    response = client.embeddings.create(
        model="text-embedding-3-small",
        input=sentences,
    )
    embeddings = [e.embedding for e in response.data]
    # Group sentences by semantic similarity
    chunks = []
    current_chunk = [sentences[0]]
    current_emb = [embeddings[0]]
    for i in range(1, len(sentences)):
        # Cosine similarity between this sentence and the current chunk's mean embedding
        avg_emb = np.mean(current_emb, axis=0)
        sim = np.dot(embeddings[i], avg_emb) / (
            np.linalg.norm(embeddings[i]) * np.linalg.norm(avg_emb)
        )
        current_text = " ".join(current_chunk) + sentences[i]
        if sim >= threshold and len(current_text) <= max_chunk:
            current_chunk.append(sentences[i])
            current_emb.append(embeddings[i])
        else:
            chunks.append(" ".join(current_chunk))
            current_chunk = [sentences[i]]
            current_emb = [embeddings[i]]
    if current_chunk:
        chunks.append(" ".join(current_chunk))
    return chunks
```
## 13.4 Vector Storage and Retrieval

### ChromaDB Implementation

```python
# pip install chromadb
import chromadb

class RAGVectorStore:
    """Vector store backed by ChromaDB."""

    def __init__(self, collection_name: str = "knowledge_base"):
        self.client = chromadb.PersistentClient(path="./chroma_db")
        self.collection = self.client.get_or_create_collection(
            name=collection_name,
            metadata={"hnsw:space": "cosine"},
        )

    def add_documents(self, chunks: list[dict], embeddings: list[list[float]]):
        """Add document chunks to the vector store."""
        # Offset IDs by the current count so repeated ingests don't collide.
        start = self.collection.count()
        self.collection.add(
            ids=[f"doc_{start + i}" for i in range(len(chunks))],
            embeddings=embeddings,
            documents=[c["text"] for c in chunks],
            metadatas=[{
                "source": c.get("source", ""),
                "page": c.get("page", 0),
                "title": c.get("title", ""),
            } for c in chunks],
        )

    def search(self, query_embedding: list[float], top_k: int = 5) -> list[dict]:
        """Retrieve the most similar chunks."""
        results = self.collection.query(
            query_embeddings=[query_embedding],
            n_results=top_k,
            include=["documents", "metadatas", "distances"],
        )
        docs = []
        for i in range(len(results["documents"][0])):
            docs.append({
                "text": results["documents"][0][i],
                "metadata": results["metadatas"][0][i],
                "distance": results["distances"][0][i],
                "score": 1 - results["distances"][0][i],  # cosine distance -> similarity
            })
        return docs
```
## 13.5 Complete RAG Pipeline

```python
from openai import OpenAI

class RAGPipeline:
    """End-to-end RAG system."""

    def __init__(self):
        self.client = OpenAI()
        self.vector_store = RAGVectorStore()
        self.embedding_model = "text-embedding-3-small"
        self.llm_model = "gpt-4o-mini"

    def ingest(self, file_path: str):
        """Ingest a document into the knowledge base."""
        # 1. Parse the document
        if file_path.endswith(".pdf"):
            pages = parse_pdf(file_path)
        elif file_path.endswith(".md"):
            pages = parse_markdown(file_path)
        else:
            with open(file_path, "r", encoding="utf-8") as f:
                pages = [{"text": f.read(), "source": file_path, "page": 0}]
        # 2. Chunk each page
        all_chunks = []
        for page in pages:
            chunks = recursive_split(page["text"], chunk_size=500, overlap=50)
            for chunk in chunks:
                all_chunks.append({
                    "text": chunk,
                    "source": page["source"],
                    "page": page.get("page", 0),
                })
        print(f"Parsed {len(pages)} pages -> {len(all_chunks)} chunks")
        # 3. Generate embeddings in batches
        batch_size = 100
        all_embeddings = []
        for i in range(0, len(all_chunks), batch_size):
            batch = all_chunks[i:i + batch_size]
            response = self.client.embeddings.create(
                model=self.embedding_model,
                input=[c["text"] for c in batch],
            )
            all_embeddings.extend([e.embedding for e in response.data])
            print(f"Embedding progress: {min(i + batch_size, len(all_chunks))}/{len(all_chunks)}")
        # 4. Write to the vector store
        self.vector_store.add_documents(all_chunks, all_embeddings)
        print(f"Ingested {len(all_chunks)} chunks")

    def query(self, question: str, top_k: int = 5) -> dict:
        """Retrieve relevant chunks and generate an answer."""
        # 1. Embed the query
        response = self.client.embeddings.create(
            model=self.embedding_model,
            input=question,
        )
        query_embedding = response.data[0].embedding
        # 2. Retrieve similar chunks
        docs = self.vector_store.search(query_embedding, top_k=top_k)
        # 3. Build the context block
        context = "\n\n---\n\n".join([
            f"[Source: {d['metadata']['source']}, page {d['metadata']['page']}]\n{d['text']}"
            for d in docs
        ])
        # 4. Generate the answer
        response = self.client.chat.completions.create(
            model=self.llm_model,
            messages=[
                {
                    "role": "system",
                    "content": """Answer the user's question based on the reference material below.
Rules:
1. Answer only from the provided material; do not make things up
2. If the material contains nothing relevant, say so explicitly
3. Cite the specific sources you used
4. Answer in Chinese""",
                },
                {
                    "role": "user",
                    "content": f"Reference material:\n{context}\n\nQuestion: {question}",
                },
            ],
            temperature=0.3,
            max_tokens=1000,
        )
        answer = response.choices[0].message.content
        return {
            "answer": answer,
            "sources": [
                {"text": d["text"][:200], "score": d["score"], "metadata": d["metadata"]}
                for d in docs
            ],
        }
```
### Usage Example

```python
rag = RAGPipeline()

# Ingest documents
rag.ingest("company_manual.pdf")
rag.ingest("faq.md")

# Query
result = rag.query("公司的退款政策是什么?")  # "What is the company's refund policy?"
print(f"Answer: {result['answer']}")
print("\nSources:")
for src in result["sources"]:
    print(f"  [{src['score']:.3f}] {src['text'][:80]}...")
```
## 13.6 Retrieval Optimization Techniques

### 13.6.1 Hybrid Retrieval

```python
def hybrid_retrieve(query: str, semantic_weight: float = 0.7) -> list[dict]:
    """Hybrid semantic + keyword retrieval.

    Assumes the `client` and `vector_store` objects from the sections above,
    plus `all_documents`, a list of all chunk dicts kept in memory.
    """
    # Semantic retrieval
    query_embedding = client.embeddings.create(
        model="text-embedding-3-small",
        input=query,
    ).data[0].embedding
    semantic_results = vector_store.search(query_embedding, top_k=10)
    # Keyword retrieval (character-overlap score as a crude BM25 stand-in)
    keyword_results = []
    for doc in all_documents:
        score = sum(1 for ch in query if ch in doc["text"])
        if score > 0:
            keyword_results.append({**doc, "kw_score": score / len(query)})
    # Merge the two result sets with a weighted score
    combined = {}
    for doc in semantic_results:
        key = doc["text"][:100]
        combined[key] = {"doc": doc, "score": semantic_weight * doc["score"]}
    for doc in keyword_results:
        key = doc["text"][:100]
        if key in combined:
            combined[key]["score"] += (1 - semantic_weight) * doc["kw_score"]
        else:
            combined[key] = {"doc": doc, "score": (1 - semantic_weight) * doc["kw_score"]}
    return sorted(combined.values(), key=lambda x: x["score"], reverse=True)[:5]
```
### 13.6.2 Query Rewriting

```python
def rewrite_query(original_query: str) -> str:
    """Use an LLM to rewrite the query into a more search-friendly form."""
    response = client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[
            {"role": "system", "content": "Rewrite the user's question into a form better suited for search. Output only the rewritten query."},
            {"role": "user", "content": original_query},
        ],
        max_tokens=100,
        temperature=0.3,
    )
    return response.choices[0].message.content
```
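A minimal sketch of where this fits: rewrite the raw user input, then feed the rewritten query into the pipeline from 13.5 (the sample question is a placeholder):

```python
better_query = rewrite_query("为什么我的钱退不回来")  # placeholder user input
result = rag.query(better_query)
```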
### 13.6.3 Re-ranking

```python
import json

def rerank_results(query: str, docs: list[dict], top_k: int = 3) -> list[dict]:
    """Use an LLM to re-rank retrieved documents."""
    docs_text = "\n\n".join([
        f"[{i}] {d['text'][:300]}" for i, d in enumerate(docs)
    ])
    response = client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[
            {"role": "system", "content": 'Rank the documents by relevance to the query. Output JSON like {"ranking": [2, 0, 1]} with document indices from most to least relevant.'},
            {"role": "user", "content": f"Query: {query}\n\nDocuments:\n{docs_text}"},
        ],
        response_format={"type": "json_object"},
        temperature=0.0,
    )
    indices = json.loads(response.choices[0].message.content).get("ranking", [])
    return [docs[i] for i in indices[:top_k] if i < len(docs)]
```
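In practice, re-ranking pairs with a wider first pass: retrieve more candidates than you need, then keep the best few. A sketch using the components above (the query string is a placeholder):

```python
# Over-retrieve, then let the LLM pick the best candidates.
question = "公司的退款政策是什么?"  # placeholder query
query_embedding = client.embeddings.create(
    model="text-embedding-3-small", input=question,
).data[0].embedding
candidates = vector_store.search(query_embedding, top_k=10)
best = rerank_results(question, candidates, top_k=3)
```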
## 13.7 Evaluation Metrics

| Metric | How it's computed | What it measures |
|---|---|---|
| Hit Rate | queries with a relevant hit / total queries | Whether retrieval finds anything relevant |
| MRR | mean of 1 / rank of first correct result | Ranking quality |
| Faithfulness | LLM judges whether the answer is grounded in the documents | Hallucination detection |
| Relevance | LLM judges whether the answer addresses the question | Answer quality |
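The two retrieval metrics are straightforward to compute if you have labeled (query, relevant-document) pairs. A minimal sketch, assuming the `client` and `vector_store` from earlier sections and that each chunk's metadata carries an `id` field you can match against the label (the implementations above would need to store one):

```python
def evaluate_retrieval(labeled: list[tuple[str, str]], top_k: int = 5) -> dict:
    """Compute Hit Rate and MRR over (query, relevant_doc_id) pairs."""
    hits, rr_sum = 0, 0.0
    for query, relevant_id in labeled:
        emb = client.embeddings.create(
            model="text-embedding-3-small", input=query,
        ).data[0].embedding
        results = vector_store.search(emb, top_k=top_k)
        # Positions (0-based) where the labeled document appears, if at all
        ranks = [i for i, d in enumerate(results)
                 if d["metadata"].get("id") == relevant_id]
        if ranks:
            hits += 1
            rr_sum += 1 / (ranks[0] + 1)  # reciprocal rank, 1-indexed
    n = len(labeled)
    return {"hit_rate": hits / n, "mrr": rr_sum / n}
```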
## 13.8 Practical Notes

- Chunk size: too small loses context, too large adds noise; 500-800 tokens is a good starting point
- Overlap matters: a reasonable overlap keeps meaning from being cut mid-sentence
- Keep metadata: always preserve source information for citation and traceability
- Cache embeddings: the embedding of a given text never changes, so always cache it (see the sketch after this list)
- Incremental updates: for large document sets, ingest incrementally instead of rebuilding from scratch
- Hallucination guard: have the system prompt insist on answering only from the provided material
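A minimal embedding-cache sketch, keyed on a hash of the model name and text — the on-disk JSON file and the helper name are illustrative, not part of any library:

```python
import hashlib
import json
import os

CACHE_PATH = "embedding_cache.json"  # illustrative location
_cache = json.load(open(CACHE_PATH)) if os.path.exists(CACHE_PATH) else {}

def embed_cached(text: str, model: str = "text-embedding-3-small") -> list[float]:
    """Return a cached embedding if available, otherwise compute and store it."""
    key = hashlib.sha256(f"{model}:{text}".encode("utf-8")).hexdigest()
    if key not in _cache:
        _cache[key] = client.embeddings.create(model=model, input=text).data[0].embedding
        with open(CACHE_PATH, "w") as f:
            json.dump(_cache, f)
    return _cache[key]
```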
## 13.9 Further Reading

Next chapter: 14 - AI Agent Architecture — tool-chain design, memory systems, and multi-agent collaboration.