大型语言模型虽然强大,但存在明显的局限性:知识截止、训练成本高、容易产生幻觉。检索增强生成(Retrieval-Augmented Generation,RAG)通过将外部知识检索与生成相结合,有效解决了这些问题。RAG让模型能够在回答问题时参考权威、最新的信息,大大提高了回答的准确性和可靠性。从企业知识库问答到个性化助手,从文档分析到学术研究,RAG正在成为LLM应用的主流架构。本文将深入解析RAG的原理、实现和优化技巧。

RAG是一种结合信息检索和文本生成的AI架构,工作流程如下:
用户查询 → 检索相关文档 → 生成增强回答 → 返回结果
↓
外部知识库核心思想:不是让模型"记住"所有知识,而是给它提供"参考书",让它在回答时查阅。
方法 | 知识来源 | 准确性 | 更新成本 | 幻觉率 | 适用场景 |
|---|---|---|---|---|---|
纯生成式 | 训练数据 | 中 | 高(需重训) | 高 | 通用对话 |
微调模型 | 训练数据 | 中高 | 中 | 中高 | 领域适应 |
RAG | 外部知识库 | 高 | 低 | 低 | 知识密集型任务 |
搜索引擎 | 外部知识库 | 高 | 低 | 极低 | 信息检索 |
知识更新便捷:更新文档即可,无需重新训练模型
回答可解释:可以展示参考来源,建立信任
减少幻觉:基于检索到的事实生成回答
私有数据利用:可以安全使用企业内部数据
成本效益:相比微调,部署和维护成本更低
┌─────────────────────────────────────────────────────────────┐
│ RAG系统架构 │
├─────────────────────────────────────────────────────────────┤
│ │
│ ┌──────────┐ ┌──────────┐ ┌──────────┐ │
│ │ 文档 │ -> │ 向量化 │ -> │ 向量数据库 │ │
│ │ 预处理 │ │ (Embed) │ │ (DB) │ │
│ └──────────┘ └──────────┘ └──────────┘ │
│ │ │
│ ↓ │
│ ┌──────────┐ ┌──────────┐ ┌──────────┐ │
│ │ 用户查询 │ -> │ 查询编码 │ -> │ 相似度 │ │
│ └──────────┘ └──────────┘ │ 检索 │ │
│ └──────────┘ │
│ │ │
│ ↓ │
│ ┌──────────────────────────────────────────────────┐ │
│ │ LLM生成模块 │ │
│ │ 查询 + 检索文档 → 增强提示 → 生成回答 │ │
│ └──────────────────────────────────────────────────┘ │
│ │
└─────────────────────────────────────────────────────────────┘文档处理器:将原始文档切分为适合检索的块
嵌入模型:将文本转换为向量表示
向量数据库:存储和检索向量
检索器:找到与查询最相关的文档
生成器:LLM根据检索结果生成回答
文本嵌入将文本映射为固定维度的向量,语义相似的文本在向量空间中距离更近。
嵌入模型 | 维度 | 特点 | 适用场景 |
|---|---|---|---|
text-embedding-ada-002 | 1536 | OpenAI官方 | 通用场景 |
bge-large-zh | 1024 | 中文优化 | 中文应用 |
e5-large-v2 | 1024 | 多语言 | 跨语言检索 |
BAAI/bge-m3 | 1024 | 多功能 | 混合检索 |
余弦相似度:最常用,衡量向量方向相似性
similarity = cos(θ) = (A·B) / (||A|| × ||B||)欧氏距离:衡量向量空间距离
distance = ||A - B|| = √Σ(Ai - Bi)²点积:简化计算,适合归一化向量
similarity = A·B = ΣAi × Bi策略 | 方法 | 优点 | 缺点 |
|---|---|---|---|
固定长度 | 按字符数切分 | 简单 | 可能破坏语义 |
语义切分 | 按段落/章节 | 保持完整性 | 块大小不均 |
滑动窗口 | 重叠切分 | 避免信息丢失 | 存储冗余 |
递归切分 | 多级切分 | 灵活 | 实现复杂 |
下面实现一个生产级RAG系统。
# RAG检索增强生成实践代码
import torch
import numpy as np
from typing import List, Dict, Optional, Tuple
import chromadb
from chromadb.config import Settings
from sentence_transformers import SentenceTransformer
import openai
from dataclasses import dataclass
from abc import ABC, abstractmethod
import re
import json
from tqdm import tqdm
import matplotlib.pyplot as plt
import pandas as pd
# ==================== 配置 ====================
OPENAI_API_KEY = "your-api-key-here" # 替换为实际API密钥
openai.api_key = OPENAI_API_KEY
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用设备: {device}")
# ==================== 文档处理器 ====================
class DocumentSplitter(ABC):
"""文档切分器基类"""
@abstractmethod
def split(self, text: str) -> List[str]:
pass
class CharacterSplitter(DocumentSplitter):
"""固定字符数切分器"""
def __init__(self, chunk_size: int = 500, overlap: int = 50):
self.chunk_size = chunk_size
self.overlap = overlap
def split(self, text: str) -> List[str]:
chunks = []
start = 0
while start < len(text):
end = start + self.chunk_size
chunk = text[start:end]
chunks.append(chunk)
start = end - self.overlap
return [c for c in chunks if len(c.strip()) > 0]
class SemanticSplitter(DocumentSplitter):
"""语义切分器(基于段落和句子)"""
def __init__(self, max_chunk_size: int = 1000, min_chunk_size: int = 200):
self.max_chunk_size = max_chunk_size
self.min_chunk_size = min_chunk_size
def split(self, text: str) -> List[str]:
# 按段落分割
paragraphs = re.split(r'\n\s*\n', text)
chunks = []
current_chunk = ""
for para in paragraphs:
if len(current_chunk) + len(para) <= self.max_chunk_size:
current_chunk += para + "\n\n"
else:
if len(current_chunk) >= self.min_chunk_size:
chunks.append(current_chunk.strip())
current_chunk = para + "\n\n"
if len(current_chunk) >= self.min_chunk_size:
chunks.append(current_chunk.strip())
return chunks
class RecursiveSplitter(DocumentSplitter):
"""递归切分器(多级切分)"""
def __init__(self, separators: List[str] = None, chunk_size: int = 1000):
self.separators = separators or ["\n\n", "\n", "。", ".", " ", ""]
self.chunk_size = chunk_size
def split(self, text: str) -> List[str]:
return self._recursive_split(text, self.separators)
def _recursive_split(self, text: str, separators: List[str]) -> List[str]:
if len(text) <= self.chunk_size:
return [text]
if not separators:
return [text[i:i+self.chunk_size]
for i in range(0, len(text), self.chunk_size)]
separator = separators[0]
parts = text.split(separator)
chunks = []
current_chunk = ""
for part in parts:
if len(current_chunk) + len(part) <= self.chunk_size:
current_chunk += part + separator
else:
if current_chunk:
chunks.append(current_chunk.strip())
if len(part) > self.chunk_size:
# 递归使用下一个分隔符
chunks.extend(self._recursive_split(part, separators[1:]))
else:
current_chunk = part + separator
if current_chunk:
chunks.append(current_chunk.strip())
return chunks
# ==================== 嵌入模型 ====================
class EmbeddingModel(ABC):
"""嵌入模型基类"""
@abstractmethod
def encode(self, texts: List[str]) -> np.ndarray:
pass
class SentenceTransformerEmbedding(EmbeddingModel):
"""SentenceTransformer嵌入模型"""
def __init__(self, model_name: str = "sentence-transformers/all-MiniLM-L6-v2"):
self.model = SentenceTransformer(model_name, device=device)
self.dimension = self.model.get_sentence_embedding_dimension()
def encode(self, texts: List[str]) -> np.ndarray:
return self.model.encode(texts, show_progress_bar=False)
class OpenAIEmbedding(EmbeddingModel):
"""OpenAI嵌入模型"""
def __init__(self, model: str = "text-embedding-ada-002"):
self.model = model
self.dimension = 1536
def encode(self, texts: List[str]) -> np.ndarray:
embeddings = []
for text in texts:
response = openai.Embedding.create(
input=text,
model=self.model
)
embeddings.append(response['data'][0]['embedding'])
return np.array(embeddings)
# ==================== 向量数据库 ====================
class VectorStore:
"""向量存储接口"""
def __init__(self, embedding_model: EmbeddingModel):
self.embedding_model = embedding_model
self.chroma_client = chromadb.PersistentClient(
path="./chroma_db"
)
self.collection = None
def create_collection(self, name: str):
"""创建集合"""
self.collection = self.chroma_client.get_or_create_collection(
name=name,
metadata={"hnsw:space": "cosine"}
)
def add_documents(self, documents: List[Dict[str, str]]):
"""添加文档到向量库"""
texts = [doc['text'] for doc in documents]
embeddings = self.embedding_model.encode(texts)
ids = [f"doc_{i}" for i in range(len(documents))]
metadatas = [{k: v for k, v in doc.items() if k != 'text'}
for doc in documents]
self.collection.add(
embeddings=embeddings.tolist(),
documents=texts,
metadatas=metadatas,
ids=ids
)
def search(
self,
query: str,
n_results: int = 5,
filters: Optional[Dict] = None
) -> List[Dict]:
"""搜索相关文档"""
query_embedding = self.embedding_model.encode([query])[0]
results = self.collection.query(
query_embeddings=[query_embedding.tolist()],
n_results=n_results,
where=filters
)
return [
{
'text': doc,
'metadata': meta,
'distance': dist
}
for doc, meta, dist in zip(
results['documents'][0],
results['metadatas'][0],
results['distances'][0]
)
]
def delete_collection(self):
"""删除集合"""
if self.collection:
self.chroma_client.delete_collection(self.collection.name)
# ==================== RAG系统 ====================
class RAGSystem:
"""完整的RAG系统"""
def __init__(
self,
embedding_model: EmbeddingModel,
collection_name: str = "rag_knowledge_base"
):
self.embedding_model = embedding_model
self.vector_store = VectorStore(embedding_model)
self.vector_store.create_collection(collection_name)
self.splitter = SemanticSplitter()
def ingest_documents(
self,
documents: List[str],
metadatas: Optional[List[Dict]] = None
):
"""摄取文档到知识库"""
all_chunks = []
chunk_id = 0
for i, doc in enumerate(tqdm(documents, desc="处理文档")):
chunks = self.splitter.split(doc)
for chunk in chunks:
doc_data = {'text': chunk}
if metadatas and i < len(metadatas):
doc_data.update(metadatas[i])
doc_data['chunk_id'] = chunk_id
doc_data['source_doc_id'] = i
all_chunks.append(doc_data)
chunk_id += 1
print(f"切分完成:共{len(documents)}个文档 -> {len(all_chunks)}个块")
# 添加到向量库
self.vector_store.add_documents(all_chunks)
print(f"文档已添加到知识库")
def query(
self,
question: str,
top_k: int = 3,
filters: Optional[Dict] = None
) -> Dict:
"""查询RAG系统"""
# 检索相关文档
retrieved_docs = self.vector_store.search(
query=question,
n_results=top_k,
filters=filters
)
# 构建增强提示
context = "\n\n".join([
f"[文档片段 {i+1}]\n{doc['text']}"
for i, doc in enumerate(retrieved_docs)
])
prompt = f"""基于以下参考文档回答问题。如果文档中没有相关信息,请明确说明。
参考文档:
{context}
问题:{question}
回答:"""
return {
'prompt': prompt,
'retrieved_docs': retrieved_docs,
'context': context
}
def answer(
self,
question: str,
top_k: int = 3,
model: str = "gpt-3.5-turbo"
) -> Dict:
"""生成完整回答"""
# 获取增强提示
rag_result = self.query(question, top_k)
# 调用LLM生成回答
try:
response = openai.ChatCompletion.create(
model=model,
messages=[
{"role": "system", "content": "你是一个有帮助的AI助手。"},
{"role": "user", "content": rag_result['prompt']}
],
temperature=0.7,
max_tokens=1000
)
answer = response.choices[0].message.content
return {
'question': question,
'answer': answer,
'sources': rag_result['retrieved_docs']
}
except Exception as e:
return {
'question': question,
'answer': f"生成回答时出错: {str(e)}",
'sources': rag_result['retrieved_docs']
}
# ==================== 高级RAG技术 ====================
class HybridRAG(RAGSystem):
"""混合检索RAG(向量+关键词)"""
def __init__(
self,
embedding_model: EmbeddingModel,
collection_name: str = "hybrid_rag"
):
super().__init__(embedding_model, collection_name)
self.keyword_index = {} # 简化的关键词索引
def build_keyword_index(self, documents: List[str]):
"""构建关键词索引"""
from collections import defaultdict
import jieba # 中文分词
for i, doc in enumerate(documents):
words = jieba.cut(doc)
for word in words:
if len(word) > 1: # 忽略单字
if word not in self.keyword_index:
self.keyword_index[word] = []
self.keyword_index[word].append(i)
def keyword_search(self, query: str, top_k: int = 5) -> List[int]:
"""关键词搜索"""
import jieba
words = jieba.cut(query)
doc_scores = defaultdict(int)
for word in words:
if word in self.keyword_index:
for doc_id in self.keyword_index[word]:
doc_scores[doc_id] += 1
sorted_docs = sorted(doc_scores.items(), key=lambda x: -x[1])
return [doc_id for doc_id, _ in sorted_docs[:top_k]]
def query(self, question: str, top_k: int = 3, alpha: float = 0.5):
"""混合检索(结合向量和关键词)"""
# 向量检索
vector_results = self.vector_store.search(question, n_results=top_k * 2)
# 关键词检索
keyword_doc_ids = self.keyword_search(question, top_k=top_k * 2)
# 融合结果
fused_scores = {}
for i, doc in enumerate(vector_results):
doc_id = doc['metadata'].get('source_doc_id', i)
fused_scores[doc_id] = fused_scores.get(doc_id, 0) + alpha * (1 - i / top_k)
for i, doc_id in enumerate(keyword_doc_ids):
fused_scores[doc_id] = fused_scores.get(doc_id, 0) + (1 - alpha) * (1 - i / top_k)
# 排序
sorted_docs = sorted(fused_scores.items(), key=lambda x: -x[1])
top_doc_ids = [doc_id for doc_id, _ in sorted_docs[:top_k]]
# 获取文档内容
final_docs = [vector_results[i] for i in range(min(top_k, len(vector_results)))]
return {
'prompt': self._build_prompt(question, final_docs),
'retrieved_docs': final_docs,
'fusion_scores': fused_scores
}
class AdaptiveRAG(RAGSystem):
"""自适应RAG(根据查询复杂度调整检索策略)"""
def __init__(self, embedding_model: EmbeddingModel):
super().__init__(embedding_model, "adaptive_rag")
# 简单的查询复杂度分类器
from sentence_transformers import util
self.complexity_model = embedding_model
def estimate_complexity(self, query: str) -> str:
"""估计查询复杂度"""
# 基于规则的简单判断
if len(query) < 20:
return "simple"
elif any(word in query for word in ["为什么", "如何", "怎样", "原理", "比较"]):
return "complex"
else:
return "medium"
def query(self, question: str, top_k: int = 3):
"""根据复杂度自适应检索"""
complexity = self.estimate_complexity(question)
if complexity == "simple":
# 简单查询:使用较少的检索文档
actual_top_k = max(1, top_k // 2)
elif complexity == "complex":
# 复杂查询:检索更多文档
actual_top_k = top_k * 2
else:
actual_top_k = top_k
return super().query(question, actual_top_k)
# ==================== 评估工具 ====================
class RAGEvaluator:
"""RAG系统评估器"""
@staticmethod
def evaluate_retrieval(
rag_system: RAGSystem,
test_questions: List[Dict[str, str]]
) -> Dict:
"""评估检索质量"""
results = {
'mrr': [], # Mean Reciprocal Rank
'precision_at_k': [],
'recall_at_k': []
}
for item in test_questions:
question = item['question']
relevant_docs = set(item.get('relevant_docs', []))
retrieved = rag_system.vector_store.search(question, n_results=10)
# 计算MRR
for i, doc in enumerate(retrieved):
doc_id = doc['metadata'].get('source_doc_id', -1)
if doc_id in relevant_docs:
results['mrr'].append(1 / (i + 1))
break
else:
results['mrr'].append(0)
# 计算Precision@K和Recall@K
for k in [3, 5, 10]:
retrieved_ids = set([
doc['metadata'].get('source_doc_id', -1)
for doc in retrieved[:k]
])
precision = len(retrieved_ids & relevant_docs) / k
recall = len(retrieved_ids & relevant_docs) / len(relevant_docs) if relevant_docs else 0
results['precision_at_k'].append({'k': k, 'score': precision})
results['recall_at_k'].append({'k': k, 'score': recall})
# 汇总结果
summary = {
'mrr': np.mean(results['mrr']),
'precision_at_3': np.mean([r['score'] for r in results['precision_at_k'] if r['k'] == 3]),
'precision_at_5': np.mean([r['score'] for r in results['precision_at_k'] if r['k'] == 5]),
'recall_at_3': np.mean([r['score'] for r in results['recall_at_k'] if r['k'] == 3]),
'recall_at_5': np.mean([r['score'] for r in results['recall_at_k'] if r['k'] == 5]),
}
return summary
# ==================== 主程序 ====================
def main():
print("="*70)
print("RAG检索增强生成系统演示")
print("="*70)
# 示例文档
sample_documents = [
"""
人工智能(AI)是计算机科学的一个分支,致力于创建能够执行通常需要人类智能的任务的系统。
这些任务包括语言理解、学习、推理、感知和问题解决。现代AI技术主要基于机器学习和深度学习。
""",
"""
机器学习是AI的子领域,使计算机能够从数据中学习并改进,而无需明确编程。
主要的机器学习类型包括监督学习、无监督学习和强化学习。
""",
"""
深度学习是机器学习的分支,使用多层神经网络从大量数据中学习。
它在图像识别、自然语言处理和语音识别等领域取得了突破性成果。
""",
"""
自然语言处理(NLP)是AI的重要应用领域,专注于计算机与人类语言之间的交互。
主要任务包括文本分类、情感分析、机器翻译和问答系统。
""",
"""
计算机视觉是使计算机能够理解和解释视觉信息的技术。
应用包括人脸识别、物体检测、图像分割和自动驾驶中的视觉感知。
"""
]
# 创建RAG系统
print("\n初始化RAG系统...")
embedding_model = SentenceTransformerEmbedding()
rag_system = RAGSystem(embedding_model)
# 摄取文档
print("\n摄取文档到知识库...")
rag_system.ingest_documents(sample_documents)
# 查询示例
questions = [
"什么是深度学习?",
"机器学习和人工智能有什么关系?",
"计算机视觉有哪些应用?",
"自然语言处理的主要任务是什么?"
]
print("\n" + "="*70)
print("RAG查询演示")
print("="*70)
for question in questions:
print(f"\n问题:{question}")
print("-"*70)
result = rag_system.query(question, top_k=2)
print("检索到的相关文档:")
for i, doc in enumerate(result['retrieved_docs'], 1):
print(f"\n[文档 {i}] (相似度: {1-doc['distance']:.3f})")
print(doc['text'][:200] + "...")
print(f"\n生成的增强提示:")
print(result['prompt'][:300] + "...")
# RAG技术对比表
print("\n" + "="*70)
print("RAG技术对比")
print("="*70)
comparison_table = """
| RAG变体 | 特点 | 优势 | 劣势 | 适用场景 |
|---------|------|------|------|----------|
| Naive RAG | 一次检索直接生成 | 简单快速 | 上下文可能不足 | 简单问答 |
| Recursive RAG | 迭代检索细化 | 提高检索质量 | 计算开销大 | 复杂查询 |
| Hybrid RAG | 向量+关键词检索 | 互补优势 | 实现复杂 | 多样化查询 |
| Adaptive RAG | 根据查询调整策略 | 动态优化 | 需要额外训练 | 生产环境 |
| Modular RAG | 模块化可配置 | 灵活性高 | 集成复杂 | 企业应用 |
"""
print(comparison_table)
# 保存配置
config = {
"embedding_model": "sentence-transformers/all-MiniLM-L6-v2",
"vector_db": "chromadb",
"chunk_size": 1000,
"overlap": 200,
"top_k": 3,
"temperature": 0.7
}
with open('C:/Users/PC/Desktop/MD/rag_config.json', 'w', encoding='utf-8') as f:
json.dump(config, f, indent=2, ensure_ascii=False)
print("\n配置已保存到 rag_config.json")
# ==================== 资源表格 ====================
resources = """
## 相关资源
| 资源类型 | 名称 | 链接 |
|---------|------|------|
| 论文 | Retrieval-Augmented Generation for Knowledge-Intensive NLP | https://arxiv.org/abs/2005.11401 |
| 框架 | LangChain | https://python.langchain.com/ |
| 框架 | LlamaIndex | https://www.llamaindex.ai/ |
| 向量数据库 | ChromaDB | https://www.trychroma.com/ |
| 向量数据库 | Pinecone | https://www.pinecone.io/ |
| 向量数据库 | Milvus | https://milvus.io/ |
| 嵌入模型 | Sentence Transformers | https://www.sbert.net/ |
| 嵌入模型 | BGE Embeddings | https://github.com/FlagOpen/FlagEmbedding |
| 教程 | RAG Tutorial | https://github.com/langchain-ai/rag |
## 向量数据库对比
| 数据库 | 开源 | 云服务 | 性能 | 特点 |
|--------|------|--------|------|------|
| ChromaDB | ✓ | - | 中 | 轻量级,易用 |
| Pinecone | - | ✓ | 高 | 托管服务,可扩展 |
| Milvus | ✓ | ✓ | 高 | 分布式,高性能 |
| Weaviate | ✓ | ✓ | 中 | 模块化架构 |
| Qdrant | ✓ | ✓ | 高 | 滤波能力强 |
## 扩展阅读
- **Advanced RAG**: 查询重写、重排序、混合检索
- **GraphRAG**: 结合知识图谱的RAG
- **Self-RAG**: 自我反思的RAG系统
- **RAGAS**: RAG系统评估框架
"""
if __name__ == "__main__":
main()
print(resources)查询重写:将用户查询改写为更适合检索的形式
# 示例:将"如何减肥"改写为"减肥方法 饮食 运动"
def rewrite_query(query: str) -> str:
# 实现查询扩展逻辑
pass混合检索:结合向量和关键词检索
重排序(Reranking):对初步检索结果重新排序
# 使用交叉编码器重排序
from sentence_transformers import CrossEncoder
reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')提示工程:设计更有效的提示模板
上下文窗口管理:智能选择最相关的上下文
引用来源:在回答中标注参考文档
缓存机制:缓存常见问题的答案
批处理:批量处理查询提高吞吐量
异步处理:异步检索和生成
场景 | 推荐方案 | 理由 |
|---|---|---|
需要最新信息 | RAG | 更新文档即可 |
私有数据使用 | RAG | 数据安全可控 |
改变说话风格 | 微调 | 需要模型调整 |
特殊格式输出 | 微调 | 需要训练学习 |
两者结合 | RAG + 微调 | 最佳效果 |
RAG通过将检索与生成结合,为大语言模型提供了访问外部知识的能力,有效解决了知识截止和幻觉问题。随着向量数据库、嵌入模型和LLM的不断发展,RAG将成为知识密集型AI应用的核心架构。
掌握RAG技术,开发者可以构建更准确、可靠、可解释的AI应用,满足企业级应用的需求。