摘要
本文介绍一个面向数据质量保障的RAG(检索增强生成)系统的设计与实现,核心在于通过明确的"系统契约"来定义组件边界与数据流规格,并支持契约的动态演进。项目提供一个可运行的最小化实现,涵盖文档加载、向量检索、契约验证、质量监控与演进管理等多个模块。通过代码与架构图,阐述了如何将数据质量指标(如完整性、新鲜度、相关性)内嵌于系统交互中,利用契约版本化与可观测性工具保障检索结果可靠性,从而构建一个健壮、可维护的RAG应用。
1. 项目概述与设计思路
在构建生产级RAG(Retrieval-Augmented Generation)系统时,面临的核心挑战之一是数据质量问题:来源文档是否完整、准确、及时?检索到的片段是否真正相关?生成答案时如何确保信息不过时?传统的做法往往是在各个组件内部零散地处理这些问题,导致系统耦合度高,问题难以追踪和修复。
本项目提出一种基于系统契约(System Contract) 的设计范式。我们将RAG系统中的关键交互(如文档加载、查询检索、答案生成)抽象为明确的契约接口。每个契约不仅定义了输入输出格式,更关键的是,它内嵌了数据质量规格(Data Quality Specifications),例如:
- 完整性契约:文档解析后必须包含标题、正文、原始URL等字段。
- 新鲜度契约:文档元数据中必须包含
last_updated时间戳,且检索结果应优先考虑近期文档。 - 相关性契约:向量检索返回的片段必须与查询问题有足够高的相似度得分。
这些契约构成了系统各组件之间的"边界"。组件只需遵循契约,而无需了解对方内部实现。更重要的是,契约本身是可版本化和可演进的。当业务需求或数据源变化时,我们可以定义新的契约版本,系统可以平滑迁移,旧版本契约在一段时间内仍被支持,确保了向后兼容性。
项目的核心目标是通过一个轻量级但完整的代码框架,演示如何:
- 定义契约:使用Pydantic模型清晰地定义数据结构和质量规则。
- 实现契约化的组件:构建遵守特定契约的文档加载器、向量索引和检索器。
- 执行契约检查:在关键数据流节点自动验证契约合规性。
- 监控与演进:记录契约验证结果(可观测性),并管理契约版本的生命周期。
以下Mermaid图展示了本项目的核心系统架构与数据流:
图1:系统架构与数据流图。粉色节点(C,I)代表核心的契约检查点,蓝色区域(J-N)代表负责质量保障与契约演进的治理层。
2. 项目结构树
rag_contract_system/
├── pyproject.toml # 项目依赖与配置
├── .env.example # 环境变量示例
├── main.py # 应用主入口
├── core/
│ ├── __init__.py
│ ├── contracts/ # 契约定义目录
│ │ ├── __init__.py
│ │ ├── base.py # 基础契约类
│ │ ├── document_v1.py # 文档契约 v1
│ │ ├── document_v2.py # 文档契约 v2 (演进示例)
│ │ └── retrieval_v1.py # 检索契约 v1
│ ├── processors/ # 数据处理组件
│ │ ├── __init__.py
│ │ ├── document_loader.py # 文档加载器
│ │ └── text_splitter.py # 文本分块器
│ ├── retrieval/ # 检索组件
│ │ ├── __init__.py
│ │ ├── vector_store.py # 简易向量存储
│ │ └── retriever.py # 检索器
│ ├── llm/ # LLM集成
│ │ ├── __init__.py
│ │ └── generator.py # 答案生成器
│ └── governance/ # 治理组件
│ ├── __init__.py
│ ├── validator.py # 契约验证器
│ ├── monitor.py # 监控器(模拟)
│ └── version_manager.py # 契约版本管理器
├── data/ # 示例数据目录
│ └── sample_docs/
│ ├── doc1.pdf
│ └── doc2.html
└── tests/ # 单元测试
├── __init__.py
├── test_contracts.py
└── test_retriever.py
3. 核心代码实现
文件路径:pyproject.toml
[project]
name = "rag-contract-system"
version = "0.1.0"
description = "A RAG system with data quality contracts"
authors = [{name = "Developer"}]
readme = "README.md"
requires-python = ">=3.9"
dependencies = [
"pydantic>=2.0",
"python-dotenv>=1.0",
"openai>=1.0", # 用于嵌入和生成
"tiktoken", # 用于token计数
"PyPDF2>=3.0", # 用于PDF解析
"beautifulsoup4>=4.12", # 用于HTML解析
"numpy>=1.24",
"scikit-learn>=1.3", # 用于简易向量相似度计算
"loguru>=0.7", # 结构化日志
]
[project.optional-dependencies]
dev = ["pytest>=7.0", "black", "mypy"]
文件路径:core/contracts/base.py
"""
基础契约类与元数据定义。
"""
from datetime import datetime
from enum import Enum
from typing import Any, Dict, Optional
from pydantic import BaseModel, Field, ConfigDict
class ContractCompliance(str, Enum):
"""契约合规状态枚举。"""
COMPLIANT = "COMPLIANT"
VIOLATED = "VIOLATED"
WARNING = "WARNING" # 部分违反或质量略低
class QualityMetric(BaseModel):
"""单一质量指标的模型。"""
name: str
value: float
threshold: Optional[float] = None
unit: Optional[str] = None
model_config = ConfigDict(arbitrary_types_allowed=True)
class ContractValidationResult(BaseModel):
"""契约验证结果。"""
contract_id: str # 如 "document.v1"
contract_version: str
data_entity_id: Optional[str] = None # 被验证的数据实体标识(如文档ID)
compliance: ContractCompliance
metrics: Dict[str, QualityMetric] = Field(default_factory=dict)
violations: list[str] = Field(default_factory=list) # 违反的规则描述
timestamp: datetime = Field(default_factory=datetime.utcnow)
def is_compliant(self) -> bool:
return self.compliance == ContractCompliance.COMPLIANT
class BaseContract(BaseModel):
"""
所有系统契约的基类。
定义了契约的标识、版本和通用的验证入口。
"""
contract_id: str # 唯一标识符,如"document_ingestion"
version: str = "1.0.0"
def validate(self, data: Any) -> ContractValidationResult:
"""
验证输入数据是否符合本契约。
子类必须实现具体的验证逻辑。
"""
raise NotImplementedError("Subclasses must implement this method.")
文件路径:core/contracts/document_v1.py
"""
文档摄取契约版本1。
定义了从原始文档解析后,数据块必须满足的质量规格。
"""
import re
from datetime import datetime
from typing import List
from pydantic import field_validator
from .base import BaseContract, ContractCompliance, ContractValidationResult, QualityMetric
class DocumentChunkV1(BaseContract):
"""
契约 v1: 文档块结构。
包含内容、元数据及必须满足的质量规则。
"""
contract_id: str = "document_chunk"
version: str = "1.0.0"
# ----- 数据结构定义 -----
id: str # 块唯一ID
text: str # 块文本内容
source: str # 来源标识,如文件路径、URL
chunk_index: int # 在原文中的顺序
metadata: dict = {} # 扩展元数据
# ----- 内嵌的质量规则(通过Pydantic验证器实现)-----
@field_validator('text')
@classmethod
def validate_text_not_empty(cls, v: str) -> str:
if not v or not v.strip():
raise ValueError('Chunk text cannot be empty.')
# 简单检查是否有过多无意义字符
if len(re.findall(r'[�]', v)) > 5:
raise ValueError('Chunk contains too many unrecognized characters.')
return v.strip()
@field_validator('metadata')
@classmethod
def validate_metadata_has_source_url(cls, v: dict) -> dict:
if 'source_url' not in v and 'file_path' not in v:
raise ValueError('Metadata must contain either source_url or file_path.')
return v
# ----- 显式契约验证方法 -----
def validate_self(self) -> ContractValidationResult:
"""
验证当前文档块实例是否符合契约v1的所有规则。
此方法演示了更复杂的、超出Pydantic内置验证器的检查。
"""
result = ContractValidationResult(
contract_id=self.contract_id,
contract_version=self.version,
data_entity_id=self.id,
compliance=ContractCompliance.COMPLIANT,
violations=[]
)
# 计算并记录质量指标
word_count = len(self.text.split())
result.metrics['word_count'] = QualityMetric(name='word_count', value=word_count, unit='words')
# 规则1: 文本长度应在合理范围内
if word_count < 10:
result.compliance = ContractCompliance.VIOLATED
result.violations.append("Chunk text is too short (less than 10 words).")
elif word_count > 500:
result.metrics['word_count'].threshold = 500
# 超过500词不是致命错误,但记录警告
if result.compliance == ContractCompliance.COMPLIANT:
result.compliance = ContractCompliance.WARNING
result.violations.append("Chunk text is longer than recommended 500 words.")
# 规则2: 检查元数据中是否有更新时间(新鲜度)
last_updated = self.metadata.get('last_updated')
if last_updated:
try:
# 假设last_updated是iso格式字符串或datetime
if isinstance(last_updated, str):
update_time = datetime.fromisoformat(last_updated.replace('Z', '+00:00'))
else:
update_time = last_updated
days_old = (datetime.utcnow() - update_time).days
result.metrics['data_freshness_days'] = QualityMetric(
name='data_freshness', value=days_old, unit='days'
)
if days_old > 365: # 超过1年视为陈旧
if result.compliance == ContractCompliance.COMPLIANT:
result.compliance = ContractCompliance.WARNING
result.violations.append(f"Source data is old ({days_old} days).")
except (ValueError, TypeError):
result.violations.append("Metadata 'last_updated' format is invalid.")
else:
# 没有更新时间信息,记录警告
if result.compliance == ContractCompliance.COMPLIANT:
result.compliance = ContractCompliance.WARNING
result.violations.append("Metadata lacks 'last_updated' field (freshness unknown).")
return result
文件路径:core/contracts/retrieval_v1.py
"""
检索结果契约版本1。
定义了从向量存储返回的检索结果必须满足的质量规格。
"""
from typing import List
from pydantic import Field
from .base import BaseContract, ContractCompliance, ContractValidationResult, QualityMetric
class RetrievedChunkV1(BaseContract):
"""
契约 v1: 检索到的单个文档块,包含相似度得分。
"""
contract_id: str = "retrieved_chunk"
version: str = "1.0.0"
chunk_id: str
text: str
source: str
similarity_score: float = Field(ge=0.0, le=1.0) # 强制范围在0-1
metadata: dict = {}
class RetrievalResultV1(BaseContract):
"""
契约 v1: 一次检索操作的整体结果。
包含查询和返回的块列表,并定义了结果集级别的质量规则。
"""
contract_id: str = "retrieval_result"
version: str = "1.0.0"
query: str
query_id: Optional[str] = None
retrieved_chunks: List[RetrievedChunkV1] = Field(default_factory=list)
top_k: int = 5
def validate_self(self) -> ContractValidationResult:
result = ContractValidationResult(
contract_id=self.contract_id,
contract_version=self.version,
data_entity_id=self.query_id or f"query_{hash(self.query)}",
compliance=ContractCompliance.COMPLIANT,
violations=[]
)
# 指标1: 返回结果数量
retrieved_count = len(self.retrieved_chunks)
result.metrics['retrieved_count'] = QualityMetric(
name='retrieved_count', value=retrieved_count
)
if retrieved_count == 0:
result.compliance = ContractCompliance.VIOLATED
result.violations.append("Retrieval returned zero chunks.")
return result # 无结果,无需进行后续检查
# 规则1: 应返回请求的top_k个结果(除非库存不足)
if retrieved_count < min(3, self.top_k): # 至少返回3个或top_k中较小的那个
result.compliance = ContractCompliance.WARNING
result.violations.append(f"Retrieved only {retrieved_count} chunks, less than expected.")
# 指标2 & 规则2: 平均相似度得分
avg_score = sum(c.similarity_score for c in self.retrieved_chunks) / retrieved_count
result.metrics['avg_similarity'] = QualityMetric(
name='avg_similarity', value=avg_score, threshold=0.6
)
if avg_score < 0.6: # 阈值示例
result.compliance = ContractCompliance.VIOLATED
result.violations.append(f"Average similarity score ({avg_score:.3f}) is below threshold 0.6.")
# 指标3 & 规则3: 得分分布(通过最高与最低得分差衡量)
scores = [c.similarity_score for c in self.retrieved_chunks]
score_range = max(scores) - min(scores)
result.metrics['score_range'] = QualityMetric(name='score_range', value=score_range)
if score_range > 0.5: # 得分差异过大,可能检索不够精确
if result.compliance == ContractCompliance.COMPLIANT:
result.compliance = ContractCompliance.WARNING
result.violations.append(f"Similarity scores vary widely (range: {score_range:.3f}).")
# 规则4: 去重 - 检查是否有来源完全相同的块(简单示例)
source_set = {(c.chunk_id, c.source) for c in self.retrieved_chunks}
if len(source_set) < retrieved_count:
result.violations.append("Potential duplicate chunks in retrieval results.")
return result
文件路径:core/processors/document_loader.py
"""
契约化的文档加载器。
负责加载原始文档,解析成文本和元数据,并包装成符合DocumentChunkV1契约的对象。
"""
import hashlib
from pathlib import Path
from typing import List, Union
import PyPDF2
from bs4 import BeautifulSoup
from loguru import logger
from ..contracts.document_v1 import DocumentChunkV1
class ContractualDocumentLoader:
"""遵守文档契约的加载器。"""
def __init__(self, default_metadata: dict = None):
self.default_metadata = default_metadata or {}
def load_from_file(self, file_path: Union[str, Path]) -> List[DocumentChunkV1]:
"""从文件加载并生成初始文档块(未分块)。"""
path = Path(file_path)
if not path.exists():
raise FileNotFoundError(f"Document file not found: {file_path}")
logger.info(f"Loading document: {path.name}")
raw_text = ""
metadata = {
'file_path': str(path.absolute()),
'file_name': path.name,
'file_size': path.stat().st_size,
'loader': self.__class__.__name__,
**self.default_metadata
}
# 解析不同格式
suffix = path.suffix.lower()
if suffix == '.pdf':
raw_text, metadata = self._parse_pdf(path, metadata)
elif suffix in ['.html', '.htm']:
raw_text, metadata = self._parse_html(path, metadata)
elif suffix == '.txt':
raw_text = path.read_text(encoding='utf-8')
else:
raise ValueError(f"Unsupported file format: {suffix}")
# 创建初始的文档块(整个文档作为一个块,后续会分块)
doc_id = hashlib.md5(f"{path.name}_{path.stat().st_size}".encode()).hexdigest()[:16]
initial_chunk = DocumentChunkV1(
id=f"doc_{doc_id}",
text=raw_text,
source=str(path),
chunk_index=0,
metadata=metadata
)
# **关键:在返回前进行自我验证**
validation_result = initial_chunk.validate_self()
if not validation_result.is_compliant():
logger.warning(
f"Document {path.name} initial validation issues: {validation_result.violations}. "
f"Proceeding with compliance: {validation_result.compliance}"
)
# 将验证结果也存入元数据,供后续追踪
initial_chunk.metadata['initial_validation'] = validation_result.model_dump()
return [initial_chunk]
def _parse_pdf(self, path: Path, metadata: dict) -> tuple[str, dict]:
"""解析PDF文件。"""
text_parts = []
with open(path, 'rb') as f:
reader = PyPDF2.PdfReader(f)
metadata['page_count'] = len(reader.pages)
metadata['pdf_info'] = reader.metadata or {}
for i, page in enumerate(reader.pages):
page_text = page.extract_text()
if page_text.strip():
text_parts.append(f"--- Page {i+1} ---\n{page_text}")
return "\n\n".join(text_parts), metadata
def _parse_html(self, path: Path, metadata: dict) -> tuple[str, dict]:
"""解析HTML文件,提取主体文本。"""
html_content = path.read_text(encoding='utf-8')
soup = BeautifulSoup(html_content, 'html.parser')
# 移除脚本、样式等标签
for element in soup(["script", "style", "header", "footer", "nav"]):
element.decompose()
main_content = soup.find('main') or soup.find('article') or soup.body
if main_content:
raw_text = main_content.get_text(separator='\n', strip=True)
# 提取标题
title_tag = soup.find('title') or soup.find('h1')
if title_tag:
metadata['html_title'] = title_tag.get_text(strip=True)
else:
raw_text = soup.get_text(separator='\n', strip=True)
return raw_text, metadata
文件路径:core/retrieval/vector_store.py
"""
简易的向量存储实现,用于演示。
生产环境应替换为Chroma、Weaviate、Pinecone等专业向量数据库。
"""
import json
from typing import List, Dict, Any, Optional
import numpy as np
from loguru import logger
from sklearn.metrics.pairwise import cosine_similarity
class SimpleVectorStore:
"""内存中的简易向量存储,支持增删查。"""
def __init__(self, dimension: int = 1536): # 默认OpenAI embedding维度
self.dimension = dimension
self.chunks: List[Dict[str, Any]] = [] # 存储块数据
self.embeddings: Optional[np.ndarray] = None # 存储所有向量的矩阵
self.next_id = 0
def add_chunks(self, chunk_data: List[Dict[str, Any]], embeddings: np.ndarray):
"""
添加块及其向量。
chunk_data: 每个元素是包含`id`, `text`, `source`, `metadata`等的字典。
embeddings: 形状为 (n_chunks, dimension) 的numpy数组。
"""
if len(chunk_data) != embeddings.shape[0]:
raise ValueError("Number of chunks must match number of embeddings.")
if embeddings.shape[1] != self.dimension:
raise ValueError(f"Embedding dimension mismatch. Expected {self.dimension}, got {embeddings.shape[1]}.")
for i, chunk in enumerate(chunk_data):
chunk['internal_id'] = self.next_id
self.next_id += 1
self.chunks.append(chunk)
if self.embeddings is None:
self.embeddings = embeddings
else:
self.embeddings = np.vstack([self.embeddings, embeddings])
logger.info(f"Added {len(chunk_data)} chunks. Total chunks: {len(self.chunks)}")
def search(self, query_embedding: np.ndarray, top_k: int = 5, metadata_filter: Optional[dict] = None) -> List[Dict[str, Any]]:
"""
语义搜索。
返回包含`chunk_id`, `text`, `source`, `similarity_score`, `metadata`的字典列表。
"""
if self.embeddings is None or len(self.chunks) == 0:
return []
# 1. 计算余弦相似度
# query_embedding 形状: (1, dimension) -> 调整为 (dimension,)
query_vec = query_embedding.flatten().reshape(1, -1)
similarities = cosine_similarity(query_vec, self.embeddings).flatten() # 形状: (n_chunks,)
# 2. 构建结果列表(包含索引和分数)
scored_results = []
for idx, score in enumerate(similarities):
chunk = self.chunks[idx]
# 应用元数据过滤器(简易实现,仅匹配键值对)
if metadata_filter:
if not all(chunk.get('metadata', {}).get(k) == v for k, v in metadata_filter.items()):
continue
scored_results.append((idx, score, chunk))
# 3. 按分数降序排序,取前top_k个
scored_results.sort(key=lambda x: x[1], reverse=True)
top_results = scored_results[:top_k]
# 4. 格式化为输出
retrieved = []
for idx, score, chunk in top_results:
retrieved.append({
'chunk_id': chunk['id'],
'text': chunk['text'],
'source': chunk['source'],
'similarity_score': float(score), # 转换为Python float
'metadata': chunk.get('metadata', {})
})
return retrieved
def save(self, filepath: str):
"""保存存储状态到文件(演示用,非生产级)。"""
data = {
'dimension': self.dimension,
'chunks': self.chunks,
'embeddings': self.embeddings.tolist() if self.embeddings is not None else None,
'next_id': self.next_id
}
with open(filepath, 'w') as f:
json.dump(data, f, default=str)
def load(self, filepath: str):
"""从文件加载存储状态。"""
with open(filepath, 'r') as f:
data = json.load(f)
self.dimension = data['dimension']
self.chunks = data['chunks']
self.embeddings = np.array(data['embeddings']) if data['embeddings'] else None
self.next_id = data['next_id']
文件路径:core/governance/validator.py
"""
契约验证器。
作为独立的服务组件,在数据流的关键节点验证契约合规性。
"""
from typing import Type, Dict, Any
from loguru import logger
from ..contracts.base import BaseContract, ContractValidationResult
from ..contracts.document_v1 import DocumentChunkV1
from ..contracts.retrieval_v1 import RetrievalResultV1
class ContractValidator:
"""管理各类契约的验证。"""
# 契约类型注册表
_contract_registry: Dict[str, Type[BaseContract]] = {
DocumentChunkV1.__fields__['contract_id'].default: DocumentChunkV1,
RetrievalResultV1.__fields__['contract_id'].default: RetrievalResultV1,
}
def __init__(self, monitor=None): # monitor 用于发送验证结果
self.monitor = monitor
def validate(self, contract_id: str, data: Any) -> ContractValidationResult:
"""
通用验证入口。
contract_id: 要验证的契约ID。
data: 可以是字典(用于创建契约实例)或已经是契约实例。
"""
contract_cls = self._contract_registry.get(contract_id)
if not contract_cls:
return ContractValidationResult(
contract_id=contract_id,
contract_version="unknown",
compliance="VIOLATED",
violations=[f"Unknown contract ID: {contract_id}"]
)
try:
# 如果data不是契约实例,则尝试实例化
if not isinstance(data, BaseContract):
# 假设data是适合该契约模型的字典
contract_instance = contract_cls(**data)
else:
contract_instance = data
# 调用契约实例的验证方法
if hasattr(contract_instance, 'validate_self'):
validation_result = contract_instance.validate_self()
else:
# 对于没有自定义validate_self的契约,仅做基础验证(Pydantic已做)
validation_result = ContractValidationResult(
contract_id=contract_id,
contract_version=contract_instance.version,
compliance="COMPLIANT",
violations=[]
)
except Exception as e:
# 实例化或验证过程中发生异常,视为严重违反
logger.exception(f"Contract validation failed for {contract_id}")
validation_result = ContractValidationResult(
contract_id=contract_id,
contract_version="unknown",
compliance="VIOLATED",
violations=[f"Validation process error: {str(e)}"]
)
# 将验证结果发送到监控系统(如果存在)
if self.monitor:
self.monitor.record_validation(validation_result)
logger.log(
"WARNING" if validation_result.compliance != "COMPLIANT" else "INFO",
f"Contract {contract_id} validation: {validation_result.compliance}. "
f"Violations: {len(validation_result.violations)}"
)
return validation_result
文件路径:core/governance/version_manager.py
"""
契约版本管理器。
处理契约的演进、版本兼容性和迁移。
"""
from typing import Dict, Any, Optional
from datetime import datetime
from loguru import logger
from ..contracts.base import BaseContract
class ContractVersionManager:
"""
管理契约的生命周期。
支持:
1. 注册新版本契约。
2. 查询当前活跃版本。
3. 将旧版本数据升级到新版本(如果支持)。
4. 弃用旧版本。
"""
def __init__(self):
self.contract_versions: Dict[str, Dict[str, Any]] = {} # contract_id -> {version -> details}
self.active_versions: Dict[str, str] = {} # contract_id -> active_version
self.deprecation_schedule: Dict[str, Dict[str, datetime]] = {} # contract_id.version -> deprecation_date
def register_version(
self,
contract_cls: Type[BaseContract],
is_active: bool = False,
migration_func: Optional[callable] = None
):
"""注册一个契约版本。"""
contract_id = contract_cls.__fields__['contract_id'].default
version = contract_cls.version
if contract_id not in self.contract_versions:
self.contract_versions[contract_id] = {}
self.contract_versions[contract_id][version] = {
'class': contract_cls,
'registered_at': datetime.utcnow(),
'migration_func': migration_func # 从上一版本升级的函数
}
if is_active:
self.active_versions[contract_id] = version
logger.info(f"Registered active version {version} for contract '{contract_id}'.")
else:
logger.info(f"Registered version {version} for contract '{contract_id}' (not active).")
def get_active_version(self, contract_id: str) -> Optional[str]:
"""获取指定契约的当前活跃版本号。"""
return self.active_versions.get(contract_id)
def migrate_data(
self, contract_id: str, data: Dict[str, Any], from_version: str, to_version: str
) -> Optional[Dict[str, Any]]:
"""
尝试将数据从一个契约版本迁移到另一个版本。
返回迁移后的数据字典,如果无法迁移则返回None。
"""
if contract_id not in self.contract_versions:
logger.error(f"Unknown contract ID: {contract_id}")
return None
if to_version not in self.contract_versions[contract_id]:
logger.error(f"Target version {to_version} not registered for {contract_id}")
return None
# 简化逻辑:假设我们只支持从 v1 到 v2 的迁移
# 实际项目中这里可能是一系列复杂的转换规则
if from_version == "1.0.0" and to_version == "1.1.0":
logger.info(f"Migrating {contract_id} from {from_version} to {to_version}")
# 示例:为v1.1.0添加一个默认字段
if contract_id == "document_chunk":
data.setdefault('metadata', {}).setdefault('migrated_to_v1.1.0', True)
# 可以调用注册的 migration_func
target_info = self.contract_versions[contract_id][to_version]
if target_info['migration_func']:
return target_info['migration_func'](data)
return data
else:
logger.warning(f"No migration path from {from_version} to {to_version} for {contract_id}")
return None
文件路径:main.py
"""
应用主入口,演示完整的工作流程。
"""
import os
from dotenv import load_dotenv
import numpy as np
from loguru import logger
from core.processors.document_loader import ContractualDocumentLoader
from core.processors.text_splitter import RecursiveCharacterTextSplitter
from core.retrieval.vector_store import SimpleVectorStore
from core.retrieval.retriever import SemanticRetriever
from core.llm.generator import OpenAIGenerator
from core.governance.validator import ContractValidator
from core.governance.monitor import LoggingMonitor
from core.governance.version_manager import ContractVersionManager
# 导入契约以注册
from core.contracts.document_v1 import DocumentChunkV1
from core.contracts.retrieval_v1 import RetrievalResultV1
# 加载环境变量(例如OPENAI_API_KEY)
load_dotenv()
def main():
logger.add("logs/rag_system_{time}.log", rotation="1 day", level="INFO")
# 0. 初始化治理组件
monitor = LoggingMonitor()
validator = ContractValidator(monitor=monitor)
version_manager = ContractVersionManager()
version_manager.register_version(DocumentChunkV1, is_active=True)
version_manager.register_version(RetrievalResultV1, is_active=True)
# 1. 加载文档
logger.info("Step 1: Loading documents...")
loader = ContractualDocumentLoader(default_metadata={'domain': 'tech_docs'})
# 假设我们有一个示例文档目录 `data/sample_docs/`
sample_doc_path = "data/sample_docs/sample_rag_article.txt"
# 为了演示,我们创建一个简单的文本文件
os.makedirs("data/sample_docs", exist_ok=True)
with open(sample_doc_path, 'w', encoding='utf-8') as f:
f.write("""
# Retrieval-Augmented Generation (RAG)
This is a sample article about RAG systems. RAG combines the power of dense vector retrieval with large language models.
It helps to ground LLM responses in factual, up-to-date information from a knowledge base.
## Key Components
1. Document Ingestion: Load and chunk your documents.
2. Vector Embedding: Convert text into numerical vectors.
3. Retrieval: Find relevant chunks for a user query.
4. Generation: LLM synthesizes an answer using retrieved context.
## Data Quality is crucial. Contracts can define rules for completeness, freshness, and relevance.
Last updated: 2024-05-15
""")
initial_chunks = loader.load_from_file(sample_doc_path)
# 2. 文本分块
logger.info("Step 2: Splitting text into chunks...")
text_splitter = RecursiveCharacterTextSplitter(chunk_size=300, chunk_overlap=50)
all_chunks = []
for initial_chunk in initial_chunks:
chunks = text_splitter.split_document(initial_chunk)
# 验证每个分块后的块
for chunk in chunks:
validation_res = validator.validate(DocumentChunkV1.__fields__['contract_id'].default, chunk)
if validation_res.is_compliant() or validation_res.compliance == "WARNING":
all_chunks.append(chunk)
else:
logger.error(f"Chunk {chunk.id} failed validation: {validation_res.violations}")
logger.info(f"Created {len(all_chunks)} validated chunks.")
# 3. 生成嵌入并存入向量存储(这里模拟嵌入)
logger.info("Step 3: Generating embeddings and populating vector store...")
vector_store = SimpleVectorStore(dimension=1536)
# 为简单起见,我们使用随机嵌入来模拟。实际应使用OpenAI/BGE等模型。
chunk_data_for_store = []
simulated_embeddings = []
for i, chunk in enumerate(all_chunks):
chunk_data_for_store.append({
'id': chunk.id,
'text': chunk.text,
'source': chunk.source,
'metadata': chunk.metadata
})
# 模拟一个1536维的随机向量(单位向量以进行有意义的余弦相似度)
random_vec = np.random.randn(1536)
unit_vec = random_vec / np.linalg.norm(random_vec)
simulated_embeddings.append(unit_vec)
vector_store.add_chunks(chunk_data_for_store, np.array(simulated_embeddings))
# 4. 初始化检索器和生成器
retriever = SemanticRetriever(vector_store=vector_store, validator=validator)
# 初始化生成器(需要OPENAI_API_KEY)
generator = OpenAIGenerator(model="gpt-3.5-turbo")
# 5. 执行示例查询
logger.info("Step 4: Performing a sample query...")
sample_queries = [
"What is RAG and why is data quality important?",
"Tell me about the key components of a RAG system.",
]
for query in sample_queries:
logger.info(f"\n--- Query: '{query}' ---")
# 5a. 检索
retrieval_result = retriever.retrieve(query, top_k=3)
# retrieval_result 已由 retriever 内部验证
if retrieval_result and len(retrieval_result.retrieved_chunks) > 0:
# 5b. 生成答案
answer = generator.generate(query, retrieval_result)
logger.info(f"Answer: {answer}\n")
else:
logger.warning("No valid retrieval results for generation.")
logger.success("Demo completed.")
# 可选:保存向量存储状态
# vector_store.save("vector_store_backup.json")
if __name__ == "__main__":
main()
4. 安装依赖与运行步骤
步骤1:环境准备与依赖安装
确保你的Python版本在3.9以上。建议使用虚拟环境。
# 1. 克隆或创建项目目录
mkdir rag_contract_system && cd rag_contract_system
# 2. 创建虚拟环境(可选但推荐)
python -m venv venv
# 激活虚拟环境
# Windows: venv\Scripts\activate
# Linux/Mac: source venv/bin/activate
# 3. 将前面 `pyproject.toml` 文件的内容复制到项目根目录的 `pyproject.toml` 文件中。
# 4. 安装项目依赖
pip install .
# 5. 安装开发依赖(可选,用于运行测试)
pip install .[dev]
# 6. 设置环境变量
# 复制示例环境变量文件并填入你的OpenAI API密钥
cp .env.example .env
# 编辑 .env 文件,设置 `OPENAI_API_KEY=你的密钥`
步骤2:准备示例数据与运行
项目中的main.py会自动在data/sample_docs/目录下创建一个示例文本文件。你也可以将自己的PDF或HTML文档放入该目录,并修改main.py中的文件路径。
运行主程序:
python main.py
程序将执行以下流水线,并在控制台和logs/目录下的日志文件中输出详细过程:
- 加载并验证示例文档。
- 将文档分割成块,并对每块进行契约验证。
- 为文档块生成模拟的向量嵌入,并存入内存向量库。
- 对两个示例查询进行语义检索,并验证检索结果的质量契约。
- (如果配置了有效的
OPENAI_API_KEY)调用OpenAI API生成最终答案。
5. 测试与验证步骤
我们为契约验证和检索器等核心功能编写了简单的单元测试。
文件路径:tests/test_contracts.py
import pytest
from datetime import datetime, timedelta
from core.contracts.document_v1 import DocumentChunkV1
from core.contracts.retrieval_v1 import RetrievedChunkV1, RetrievalResultV1
class TestDocumentChunkV1:
def test_valid_chunk(self):
"""测试一个完全合规的文档块。"""
chunk = DocumentChunkV1(
id="test_123",
text="This is a valid chunk with more than ten words to satisfy the contract rule.",
source="/path/to/doc.pdf",
chunk_index=0,
metadata={"source_url": "http://example.com", "last_updated": "2024-01-01"}
)
result = chunk.validate_self()
assert result.is_compliant() is True
assert "word_count" in result.metrics
def test_chunk_too_short(self):
"""测试文本过短的违规情况。"""
chunk = DocumentChunkV1(
id="test_short",
text="Short.", # 少于10词
source="/path/to/doc.pdf",
chunk_index=0,
metadata={"source_url": "http://example.com"}
)
result = chunk.validate_self()
assert result.is_compliant() is False
assert "too short" in result.violations[0].lower()
def test_chunk_missing_last_updated(self):
"""测试缺少最新更新时间(应产生警告)。"""
chunk = DocumentChunkV1(
id="test_no_date",
text="Valid text here with sufficient length for the contract.",
source="/path/to/doc.pdf",
chunk_index=0,
metadata={"source_url": "http://example.com"} # 无 last_updated
)
result = chunk.validate_self()
assert result.compliance == "WARNING"
assert "lacks 'last_updated'" in result.violations[0]
class TestRetrievalResultV1:
def test_valid_retrieval(self):
"""测试一个高质量的检索结果。"""
chunks = [
RetrievedChunkV1(
chunk_id="c1",
text="Relevant text about RAG.",
source="doc1",
similarity_score=0.85
),
RetrievedChunkV1(
chunk_id="c2",
text="More details on contracts.",
source="doc2",
similarity_score=0.78
),
]
result = RetrievalResultV1(
query="What is RAG?",
retrieved_chunks=chunks,
top_k=5
)
validation = result.validate_self()
assert validation.is_compliant() is True
assert validation.metrics['avg_similarity'].value > 0.6
def test_low_similarity_violation(self):
"""测试平均相似度过低的违规情况。"""
chunks = [
RetrievedChunkV1(chunk_id="c1", text="text", source="doc1", similarity_score=0.3),
RetrievedChunkV1(chunk_id="c2", text="text", source="doc2", similarity_score=0.4),
]
result = RetrievalResultV1(query="test", retrieved_chunks=chunks)
validation = result.validate_self()
assert validation.is_compliant() is False
assert "below threshold" in validation.violations[0]
运行测试:
# 在项目根目录执行
pytest tests/ -v
这将运行所有测试,并验证契约定义的核心逻辑是否按预期工作。
6. 核心流程与契约检查序列
以下序列图详细展示了用户查询触发后,系统内部如何通过契约检查点来保障数据质量。
图2:查询-检索-生成流程中的契约检查序列。验证器在检索后介入,其结果决定流程是否继续至生成阶段。
7. 扩展说明与最佳实践
-
生产环境向量数据库:将
SimpleVectorStore替换为如ChromaDB、Weaviate或Pinecone。这些数据库提供持久化、高性能检索和内置的元数据过滤,与本项目的契约理念(如新鲜度过滤)完美契合。 -
契约演进策略:
- 版本化:每次契约变更(如添加新字段、修改规则)都应提升版本号(遵循语义化版本)。
- 双版本支持:在过渡期,系统应能同时处理新旧契约版本的数据。
ContractVersionManager和迁移函数是此功能的核心。 - 监控与告警:当某旧版本契约的使用频率降至阈值以下,或达到预定的弃用日期后,可发出正式弃用告警,最终停止支持。
-
可观测性深化:
- 将
LoggingMonitor升级为集成Prometheus、Grafana或分布式追踪系统(如OpenTelemetry)。 - 不仅记录验证结果,还可将质量指标(如平均相似度得分、数据新鲜度分布)作为时间序列数据上报,用于绘制仪表盘和设置告警规则(如"过去5分钟平均相关性得分低于0.6")。
- 将
-
安全与合规:可在契约中添加数据安全规格,例如检查文档块是否包含个人身份信息(PII),或确保检索结果来自经过授权的数据源。
本项目提供了一个坚实的起点,通过将"契约"作为一等公民,使RAG系统的数据质量保障从一种事后补救的思维,转变为一种可设计、可验证、可演进的核心架构属性。