从零实现 RAG

本章用最少组件实现一个 RAG。目标不是替代 LlamaIndex 或 LangChain,而是理解 RAG 的关键数据流。

最小架构

load documents
  -> split into chunks
  -> embed chunks
  -> retrieve top-k chunks
  -> build prompt with citations
  -> call LLM
  -> return answer and sources

Step 1: 准备文档

documents = [
    {
        "id": "doc-1",
        "title": "过拟合",
        "text": "过拟合是指模型在训练集表现很好,但在测试集或真实数据上表现较差。常见缓解方法包括正则化、增加数据、早停和降低模型复杂度。"
    },
    {
        "id": "doc-2",
        "title": "梯度下降",
        "text": "梯度下降是一种优化方法。它沿着损失函数梯度的反方向更新参数,从而逐步降低损失。"
    }
]

真实项目中要保留 sourceurlsectionupdated_at 等 metadata。

Step 2: 切分 chunk

def split_text(text: str, chunk_size: int = 120, overlap: int = 20) -> list[str]:
    chunks = []
    start = 0
    while start < len(text):
        end = start + chunk_size
        chunks.append(text[start:end])
        start = end - overlap
    return chunks


chunks = []
for doc in documents:
    for i, text in enumerate(split_text(doc["text"])):
        chunks.append({
            "chunk_id": f"{doc['id']}-{i}",
            "doc_id": doc["id"],
            "title": doc["title"],
            "text": text
        })

Step 3: 生成向量

from sentence_transformers import SentenceTransformer

embedder = SentenceTransformer("sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2")

texts = [chunk["text"] for chunk in chunks]
vectors = embedder.encode(texts, normalize_embeddings=True)

Step 4: 检索

import numpy as np


def retrieve(query: str, top_k: int = 3) -> list[dict]:
    query_vector = embedder.encode([query], normalize_embeddings=True)[0]
    scores = vectors @ query_vector
    top_indices = np.argsort(scores)[::-1][:top_k]

    results = []
    for idx in top_indices:
        item = dict(chunks[idx])
        item["score"] = float(scores[idx])
        results.append(item)
    return results

Step 5: 构造上下文

def build_context(results: list[dict]) -> str:
    lines = []
    for i, item in enumerate(results, start=1):
        lines.append(
            f"[{i}] title={item['title']} source={item['chunk_id']}\n{item['text']}"
        )
    return "\n\n".join(lines)

Step 6: 调用模型

from openai import OpenAI

client = OpenAI()


def answer(query: str) -> str:
    results = retrieve(query)
    context = build_context(results)

    prompt = f"""
你是严谨的知识库问答助手。

规则:
1. 仅依据 context 回答。
2. 如果 context 没有答案,回答“资料不足”。
3. 每个关键结论都要标注引用编号,例如 [1]。

context:
{context}

question:
{query}
"""

    response = client.responses.create(
        model="your-model",
        input=prompt
    )
    return response.output_text

完整流程检查

用这个问题测试:

print(answer("什么是过拟合?怎么缓解?"))

理想答案应该:

什么时候引入框架

从零实现能帮你理解核心流程,但真实项目通常需要框架:

需求 可以引入
多格式文档解析 LlamaIndex、LangChain document loaders
复杂索引 LlamaIndex、向量数据库
多步骤工作流 LangGraph、LlamaIndex Workflows
评估和观测 Ragas、DeepEval、LangSmith、Arize Phoenix、Langfuse

常见改进点

Query rewrite

用户问题可能太短或上下文不足,可以先让模型改写查询:

把用户问题改写成适合检索知识库的查询,保留原意,不添加新事实。

向量检索适合语义相似,关键词检索适合精确术语、代码符号、ID、日期。实际系统经常组合 BM25 + embedding。

Rerank

先召回更多 chunk,再用 reranker 排序,可以提升引用质量。

Citation check

生成答案后再检查每个引用是否真的支持对应句子。

下一步