Add the missing libraries

parent c18f4306be
commit c8948e0a44

CoreRAG/custom_rag_processor.py (new file, 373 lines)

@@ -0,0 +1,373 @@
import os
import asyncio
import logging
import logging.config

import aiohttp
import numpy as np
import tiktoken

from lightrag import LightRAG, QueryParam
from lightrag.kg.shared_storage import initialize_pipeline_status
from lightrag.utils import logger, set_verbose_debug

WORKING_DIR = "./dickens"

# API configuration
# DEEPSEEK_API_BASE = "https://api.deepseek.com/chat/completions"
DEEPSEEK_API_BASE = "https://dashscope.aliyuncs.com/compatible-mode/v1/chat/completions"
QWEN_API_BASE = "https://dashscope.aliyuncs.com/compatible-mode/v1/embeddings"
QWEN_EMBEDDING_MODEL = "text-embedding-v4"
# DEEPSEEK_CHAT_MODEL = "deepseek-reasoner"
DEEPSEEK_CHAT_MODEL = "deepseek-r1"
EMBEDDING_BATCH_SIZE = 10
MAX_RETRIES = 4

class QwenEmbedding:
    """Qwen embedding client (DashScope-compatible endpoint)."""

    def __init__(self, api_key: str | None = None):
        # Prefer an explicit key, then the environment; never hard-code secrets.
        self.api_key = api_key or os.getenv("QWEN_API_KEY")
        self.session = None

    async def __aenter__(self):
        self.session = aiohttp.ClientSession()
        return self

    async def __aexit__(self, exc_type, exc, tb):
        if self.session:
            await self.session.close()

    async def embed(self, texts: list[str]) -> np.ndarray:
        """Request embeddings, automatically batching the input."""
        if not self.session:
            raise RuntimeError("Session not initialized. Use async with.")

        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }

        all_embeddings = []

        for i in range(0, len(texts), EMBEDDING_BATCH_SIZE):
            batch = texts[i:i + EMBEDDING_BATCH_SIZE]
            payload = {
                "model": QWEN_EMBEDDING_MODEL,
                "input": batch,
            }

            async with self.session.post(QWEN_API_BASE, headers=headers, json=payload) as response:
                if response.status != 200:
                    error = await response.text()
                    raise RuntimeError(f"Qwen embedding API error: {error}")
                data = await response.json()
                batch_embeddings = [item["embedding"] for item in data["data"]]
                all_embeddings.extend(batch_embeddings)

        return np.array(all_embeddings)

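# Illustrative sketch: embeddings come back as one (len(texts), dim) array;
# with text-embedding-v4 the dimension is 1024 (see Qwen_embed below).
#
#   async with QwenEmbedding() as embedder:
#       vecs = await embedder.embed(["hello", "world"])
#       assert vecs.shape == (2, 1024)
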
class DeepSeekCompletion:
    """DeepSeek chat-completion client."""

    def __init__(self, api_key: str | None = None):
        # Prefer an explicit key, then the environment; never hard-code secrets.
        self.api_key = api_key or os.getenv("DEEPSEEK_API_KEY")
        self.session = None

    async def __aenter__(self):
        self.session = aiohttp.ClientSession()
        return self

    async def __aexit__(self, exc_type, exc, tb):
        if self.session:
            await self.session.close()

    async def complete(self, prompt: str, **kwargs) -> str:
        """Return the model's completion for a prompt."""
        if not self.session:
            raise RuntimeError("Session not initialized. Use async with.")

        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }

        payload = {
            "model": DEEPSEEK_CHAT_MODEL,
            "messages": [{"role": "user", "content": prompt}],
            "temperature": kwargs.get("temperature", 0.7),
            "max_tokens": kwargs.get("max_tokens", 2000),
        }

        async with self.session.post(
            DEEPSEEK_API_BASE,
            headers=headers,
            json=payload,
        ) as response:
            if response.status != 200:
                error = await response.text()
                raise RuntimeError(f"DeepSeek API error: {error}")

            data = await response.json()
            return data["choices"][0]["message"]["content"]

class TextChunker:
    """Token-based text chunker."""

    def __init__(self, max_tokens: int = 200, overlap: int = 50):
        self.tokenizer = tiktoken.get_encoding("cl100k_base")
        self.max_tokens = max_tokens
        self.overlap = overlap

    def chunk_text(self, text: str) -> list[str]:
        """Split long text into overlapping token-window chunks."""
        tokens = self.tokenizer.encode(text)
        chunks = []
        start = 0
        while start < len(tokens):
            end = min(start + self.max_tokens, len(tokens))
            chunk = self.tokenizer.decode(tokens[start:end])
            chunks.append(chunk)
            start += self.max_tokens - self.overlap
        return chunks

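# Illustrative sketch of the chunking arithmetic: the window advances by
# max_tokens - overlap tokens per step, so with the defaults (200/50) chunks
# start at tokens 0, 150, 300, ... and neighboring chunks share 50 tokens.
#
#   chunker = TextChunker()
#   chunks = chunker.chunk_text(long_text)
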
def configure_logging():
    """Configure logging for the application."""

    # Reset any existing handlers to ensure clean configuration
    for logger_name in ["uvicorn", "uvicorn.access", "uvicorn.error", "lightrag"]:
        logger_instance = logging.getLogger(logger_name)
        logger_instance.handlers = []
        logger_instance.filters = []

    # Get log directory path from environment variable or use current directory
    log_dir = os.getenv("LOG_DIR", os.getcwd())
    log_file_path = os.path.abspath(os.path.join(log_dir, "lightrag_demo.log"))

    print(f"\nLightRAG demo log file: {log_file_path}\n")
    os.makedirs(log_dir, exist_ok=True)

    # Get log file max size and backup count from environment variables
    log_max_bytes = int(os.getenv("LOG_MAX_BYTES", 10485760))  # Default 10MB
    log_backup_count = int(os.getenv("LOG_BACKUP_COUNT", 5))  # Default 5 backups

    logging.config.dictConfig(
        {
            "version": 1,
            "disable_existing_loggers": False,
            "formatters": {
                "default": {
                    "format": "%(levelname)s: %(message)s",
                },
                "detailed": {
                    "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
                },
            },
            "handlers": {
                "console": {
                    "formatter": "default",
                    "class": "logging.StreamHandler",
                    "stream": "ext://sys.stderr",
                },
                "file": {
                    "formatter": "detailed",
                    "class": "logging.handlers.RotatingFileHandler",
                    "filename": log_file_path,
                    "maxBytes": log_max_bytes,
                    "backupCount": log_backup_count,
                    "encoding": "utf-8",
                },
            },
            "loggers": {
                "lightrag": {
                    "handlers": ["console", "file"],
                    "level": "INFO",
                    "propagate": False,
                },
            },
        }
    )

    # Set the logger level to INFO
    logger.setLevel(logging.INFO)
    # Enable verbose debug if needed
    set_verbose_debug(os.getenv("VERBOSE_DEBUG", "false").lower() == "true")

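# Note: nothing in this module calls configure_logging() automatically; a
# driver script is expected to call it once at startup (see the sketch at
# the end of the file).
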
if not os.path.exists(WORKING_DIR):
    os.mkdir(WORKING_DIR)

async def Qwen_embed(texts: list[str]) -> np.ndarray:
    async with QwenEmbedding() as embedder:
        return await embedder.embed(texts)

# Attach the embedding dimension so LightRAG can discover it on the function.
Qwen_embed.embedding_dim = 1024

async def deepseek_complete(prompt: str, **kwargs) -> str:
    for _ in range(3):
        try:
            async with DeepSeekCompletion() as completer:
                return await completer.complete(prompt, **kwargs)
        except Exception as e:
            print(f"[Retry] DeepSeek Error: {e}")
            await asyncio.sleep(1)
    raise RuntimeError("DeepSeek failed after 3 retries.")

async def initialize_rag():
    rag = LightRAG(
        working_dir=WORKING_DIR,
        embedding_func=Qwen_embed,
        llm_model_func=deepseek_complete,
    )

    await rag.initialize_storages()
    await initialize_pipeline_status()

    return rag

async def process_document():
    # Clean up data from previous runs
    files_to_delete = [
        "graph_chunk_entity_relation.graphml",
        "kv_store_doc_status.json",
        "kv_store_full_docs.json",
        "kv_store_text_chunks.json",
        "vdb_chunks.json",
        "vdb_entities.json",
        "vdb_relationships.json",
    ]
    for file in files_to_delete:
        file_path = os.path.join(WORKING_DIR, file)
        if os.path.exists(file_path):
            os.remove(file_path)
            print(f"Deleting old file: {file_path}")

    # Initialize RAG
    rag = await initialize_rag()

    # Embedding smoke test
    test_text = ["This is a test string for embedding."]
    embedding = await rag.embedding_func(test_text)
    embedding_dim = embedding.shape[1]
    print("\n=======================")
    print("Test embedding function")
    print("========================")
    print(f"Test text: {test_text}")
    print(f"Detected embedding dimension: {embedding_dim}\n\n")

    # Read the book text
    with open("./book.txt", "r", encoding="utf-8") as f:
        content = f.read()

    # Skip the Project Gutenberg copyright header
    start_marker = "A Christmas Carol"
    start_index = content.find(start_marker)
    if start_index != -1:
        content = content[start_index:]

    # Chunking: chunk size = 500 tokens, overlap = 50
    chunks = TextChunker(max_tokens=500, overlap=50).chunk_text(content)
    print(f"Total chunks to insert: {len(chunks)}")

    # Embed at most 10 chunks per request, inserting in batches
    BATCH_SIZE = 10
    for i in range(0, len(chunks), BATCH_SIZE):
        batch = chunks[i:i + BATCH_SIZE]
        await rag.ainsert(batch)
        print(f">> Inserted chunk batch {i // BATCH_SIZE + 1}")
    return rag

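# Minimal driver sketch for the script-style pipeline above (assumes
# ./book.txt exists and the API keys are available in the environment):
#
#   async def _demo():
#       rag = await process_document()
#       answer = await rag.aquery(
#           "Who is Scrooge?", param=QueryParam(mode="hybrid")
#       )
#       print(answer)
#
#   asyncio.run(_demo())
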
class CustomRAGProcessor:
    """Custom RAG document processor."""

    def __init__(
        self,
        working_dir: str = WORKING_DIR,
        qwen_api_key: str | None = None,
        deepseek_api_key: str | None = None,
    ):
        self.working_dir = working_dir
        # Default to environment variables; a truthy hard-coded default here
        # would shadow them and leak credentials.
        self.qwen_api_key = qwen_api_key or os.getenv("QWEN_API_KEY")
        self.deepseek_api_key = deepseek_api_key or os.getenv("DEEPSEEK_API_KEY")
        self.chunker = TextChunker()

        # Make sure the working directory exists
        os.makedirs(self.working_dir, exist_ok=True)

    async def initialize_rag(self) -> LightRAG:
        """Initialize the RAG instance."""
        rag = LightRAG(
            working_dir=self.working_dir,
            embedding_func=self.qwen_embed_wrapper,
            llm_model_func=self.deepseek_complete_wrapper,
            llm_model_name=DEEPSEEK_CHAT_MODEL,
            chunk_token_size=200,
            chunk_overlap_token_size=50,
        )

        await rag.initialize_storages()
        await initialize_pipeline_status()
        return rag

    async def qwen_embed_wrapper(self, texts: list[str]) -> np.ndarray:
        """Qwen embedding wrapper."""
        async with QwenEmbedding(self.qwen_api_key) as embedder:
            return await embedder.embed(texts)

    async def deepseek_complete_wrapper(self, prompt: str, **kwargs) -> str:
        """DeepSeek completion wrapper with retry logic."""
        for attempt in range(MAX_RETRIES):
            try:
                async with DeepSeekCompletion(self.deepseek_api_key) as completer:
                    return await completer.complete(prompt, **kwargs)
            except Exception as e:
                if attempt == MAX_RETRIES - 1:
                    raise
                logging.warning(f"Attempt {attempt + 1} failed: {str(e)}")
                await asyncio.sleep(1)

    async def process_text(self, text: str, rag: LightRAG) -> None:
        """Chunk text content and insert it into the RAG store."""
        chunks = self.chunker.chunk_text(text)
        total_batches = (len(chunks) + EMBEDDING_BATCH_SIZE - 1) // EMBEDDING_BATCH_SIZE

        for i in range(0, len(chunks), EMBEDDING_BATCH_SIZE):
            batch = chunks[i:i + EMBEDDING_BATCH_SIZE]
            await rag.ainsert(batch)
            logging.info(f"Inserted batch {i // EMBEDDING_BATCH_SIZE + 1}/{total_batches}")

    async def process_file(self, file_path: str, cleanup: bool = True) -> LightRAG:
        """
        Process a single file.

        :param file_path: path to the file
        :param cleanup: whether to delete data from previous runs
        :return: an initialized RAG instance
        """
        if cleanup:
            await self.cleanup_previous_data()

        rag = await self.initialize_rag()

        with open(file_path, "r", encoding="utf-8") as f:
            content = f.read()

        # Skip the Project Gutenberg copyright header if present
        if "Project Gutenberg" in content:
            start_marker = "A Christmas Carol"  # or another start marker
            start_index = content.find(start_marker)
            if start_index != -1:
                content = content[start_index:]

        await self.process_text(content, rag)
        return rag

    async def cleanup_previous_data(self) -> None:
        """Delete data files left over from previous runs."""
        files_to_delete = [
            "graph_chunk_entity_relation.graphml",
            "kv_store_doc_status.json",
            "kv_store_full_docs.json",
            "kv_store_text_chunks.json",
            "vdb_chunks.json",
            "vdb_entities.json",
            "vdb_relationships.json",
        ]

        for filename in files_to_delete:
            file_path = os.path.join(self.working_dir, filename)
            if os.path.exists(file_path):
                os.remove(file_path)
                logging.info(f"Deleted previous data file: {file_path}")
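
# Minimal usage sketch for CustomRAGProcessor (assumes QWEN_API_KEY and
# DEEPSEEK_API_KEY are set in the environment and ./book.txt exists).
if __name__ == "__main__":
    async def main():
        configure_logging()
        processor = CustomRAGProcessor()
        rag = await processor.process_file("./book.txt")
        # "hybrid" combines LightRAG's local and global graph retrieval.
        answer = await rag.aquery(
            "What does Scrooge learn by the end of the story?",
            param=QueryParam(mode="hybrid"),
        )
        print(answer)

    asyncio.run(main())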