From c8948e0a440d829bf9e9254e49643ec8a5d6687d Mon Sep 17 00:00:00 2001
From: "ziwei.he"
Date: Mon, 30 Jun 2025 10:59:38 +0800
Subject: [PATCH] Add the missing library
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 CoreRAG/custom_rag_processor.py | 373 ++++++++++++++++++++++++++++++++
 1 file changed, 373 insertions(+)
 create mode 100644 CoreRAG/custom_rag_processor.py

diff --git a/CoreRAG/custom_rag_processor.py b/CoreRAG/custom_rag_processor.py
new file mode 100644
index 0000000..563f970
--- /dev/null
+++ b/CoreRAG/custom_rag_processor.py
@@ -0,0 +1,373 @@
+import os
+import asyncio
+import logging
+import logging.config
+
+import aiohttp
+import numpy as np
+import tiktoken
+from lightrag import LightRAG, QueryParam
+from lightrag.kg.shared_storage import initialize_pipeline_status
+from lightrag.utils import logger, set_verbose_debug
+
+WORKING_DIR = "./dickens"
+
+# API configuration
+# DEEPSEEK_API_BASE = "https://api.deepseek.com/chat/completions"
+DEEPSEEK_API_BASE = "https://dashscope.aliyuncs.com/compatible-mode/v1/chat/completions"
+QWEN_API_BASE = "https://dashscope.aliyuncs.com/compatible-mode/v1/embeddings"
+QWEN_EMBEDDING_MODEL = "text-embedding-v4"
+# DEEPSEEK_CHAT_MODEL = "deepseek-reasoner"
+DEEPSEEK_CHAT_MODEL = "deepseek-r1"
+EMBEDDING_BATCH_SIZE = 10
+MAX_RETRIES = 4
+
+
+class QwenEmbedding:
+    """Async client for the Qwen (DashScope) embedding API."""
+
+    def __init__(self, api_key: str | None = None):
+        # The key comes from the environment (QWEN_API_KEY) unless passed explicitly.
+        self.api_key = api_key or os.getenv("QWEN_API_KEY")
+        self.session = None
+
+    async def __aenter__(self):
+        self.session = aiohttp.ClientSession()
+        return self
+
+    async def __aexit__(self, exc_type, exc, tb):
+        if self.session:
+            await self.session.close()
+
+    async def embed(self, texts: list[str]) -> np.ndarray:
+        """Request embeddings, automatically splitting the input into batches."""
+        if not self.session:
+            raise RuntimeError("Session not initialized. Use async with.")
+
+        headers = {
+            "Authorization": f"Bearer {self.api_key}",
+            "Content-Type": "application/json",
+        }
+
+        all_embeddings = []
+
+        for i in range(0, len(texts), EMBEDDING_BATCH_SIZE):
+            batch = texts[i:i + EMBEDDING_BATCH_SIZE]
+            payload = {
+                "model": QWEN_EMBEDDING_MODEL,
+                "input": batch,
+            }
+
+            async with self.session.post(QWEN_API_BASE, headers=headers, json=payload) as response:
+                if response.status != 200:
+                    error = await response.text()
+                    raise RuntimeError(f"Qwen embedding API error: {error}")
+                data = await response.json()
+                batch_embeddings = [item["embedding"] for item in data["data"]]
+                all_embeddings.extend(batch_embeddings)
+
+        return np.array(all_embeddings)
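+
+# Minimal usage sketch for the client above (not part of the pipeline),
+# assuming QWEN_API_KEY is set in the environment:
+#
+#   async def _embed_demo():
+#       async with QwenEmbedding() as embedder:
+#           vectors = await embedder.embed(["hello", "world"])
+#           print(vectors.shape)  # (2, 1024) for text-embedding-v4's default dim
+#
+#   asyncio.run(_embed_demo())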
+
+
+class DeepSeekCompletion:
+    """Async client for the DeepSeek chat-completion API."""
+
+    def __init__(self, api_key: str | None = None):
+        # The key comes from the environment (DEEPSEEK_API_KEY) unless passed explicitly.
+        self.api_key = api_key or os.getenv("DEEPSEEK_API_KEY")
+        self.session = None
+
+    async def __aenter__(self):
+        self.session = aiohttp.ClientSession()
+        return self
+
+    async def __aexit__(self, exc_type, exc, tb):
+        if self.session:
+            await self.session.close()
+
+    async def complete(self, prompt: str, **kwargs) -> str:
+        """Return the model's completion for a prompt."""
+        if not self.session:
+            raise RuntimeError("Session not initialized. Use async with.")
+
+        headers = {
+            "Authorization": f"Bearer {self.api_key}",
+            "Content-Type": "application/json",
+        }
+
+        payload = {
+            "model": DEEPSEEK_CHAT_MODEL,
+            "messages": [{"role": "user", "content": prompt}],
+            "temperature": kwargs.get("temperature", 0.7),
+            "max_tokens": kwargs.get("max_tokens", 2000),
+        }
+
+        async with self.session.post(
+            DEEPSEEK_API_BASE,
+            headers=headers,
+            json=payload,
+        ) as response:
+            if response.status != 200:
+                error = await response.text()
+                raise RuntimeError(f"DeepSeek API error: {error}")
+
+            data = await response.json()
+            return data["choices"][0]["message"]["content"]
+
+
+class TextChunker:
+    """Token-based text chunker."""
+
+    def __init__(self, max_tokens: int = 200, overlap: int = 50):
+        self.tokenizer = tiktoken.get_encoding("cl100k_base")
+        self.max_tokens = max_tokens
+        self.overlap = overlap
+
+    def chunk_text(self, text: str) -> list[str]:
+        """Split long text into overlapping token-based chunks."""
+        tokens = self.tokenizer.encode(text)
+        chunks = []
+        start = 0
+        while start < len(tokens):
+            end = min(start + self.max_tokens, len(tokens))
+            chunks.append(self.tokenizer.decode(tokens[start:end]))
+            start += self.max_tokens - self.overlap
+        return chunks
+
+
+def configure_logging():
+    """Configure logging for the application"""
+
+    # Reset any existing handlers to ensure clean configuration
+    for logger_name in ["uvicorn", "uvicorn.access", "uvicorn.error", "lightrag"]:
+        logger_instance = logging.getLogger(logger_name)
+        logger_instance.handlers = []
+        logger_instance.filters = []
+
+    # Get log directory path from environment variable or use current directory
+    log_dir = os.getenv("LOG_DIR", os.getcwd())
+    log_file_path = os.path.abspath(os.path.join(log_dir, "lightrag_demo.log"))
+
+    print(f"\nLightRAG demo log file: {log_file_path}\n")
+    os.makedirs(log_dir, exist_ok=True)
+
+    # Get log file max size and backup count from environment variables
+    log_max_bytes = int(os.getenv("LOG_MAX_BYTES", 10485760))  # Default 10MB
+    log_backup_count = int(os.getenv("LOG_BACKUP_COUNT", 5))  # Default 5 backups
+
+    logging.config.dictConfig(
+        {
+            "version": 1,
+            "disable_existing_loggers": False,
+            "formatters": {
+                "default": {
+                    "format": "%(levelname)s: %(message)s",
+                },
+                "detailed": {
+                    "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
+                },
+            },
+            "handlers": {
+                "console": {
+                    "formatter": "default",
+                    "class": "logging.StreamHandler",
+                    "stream": "ext://sys.stderr",
+                },
+                "file": {
+                    "formatter": "detailed",
+                    "class": "logging.handlers.RotatingFileHandler",
+                    "filename": log_file_path,
+                    "maxBytes": log_max_bytes,
+                    "backupCount": log_backup_count,
+                    "encoding": "utf-8",
+                },
+            },
+            "loggers": {
+                "lightrag": {
+                    "handlers": ["console", "file"],
+                    "level": "INFO",
+                    "propagate": False,
+                },
+            },
+        }
+    )
+
+    # Set the logger level to INFO
+    logger.setLevel(logging.INFO)
+    # Enable verbose debug if needed
+    set_verbose_debug(os.getenv("VERBOSE_DEBUG", "false").lower() == "true")
+
+
+if not os.path.exists(WORKING_DIR):
+    os.mkdir(WORKING_DIR)
+
+
+async def Qwen_embed(texts: list[str]) -> np.ndarray:
+    async with QwenEmbedding() as embedder:
+        return await embedder.embed(texts)
+
+
+Qwen_embed.embedding_dim = 1024
+
+
+async def deepseek_complete(prompt: str, **kwargs) -> str:
+    for _ in range(3):
+        try:
+            async with DeepSeekCompletion() as completer:
+                return await completer.complete(prompt, **kwargs)
+        except Exception as e:
+            print(f"[Retry] DeepSeek Error: {e}")
+            await asyncio.sleep(1)
+    raise RuntimeError("DeepSeek failed after 3 retries.")
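+
+
+# Usage sketch (hypothetical driver code), assuming DEEPSEEK_API_KEY is set:
+#
+#   reply = asyncio.run(deepseek_complete("Say hello in one sentence."))
+#   print(reply)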
+
+
+async def initialize_rag():
+    rag = LightRAG(
+        working_dir=WORKING_DIR,
+        embedding_func=Qwen_embed,
+        llm_model_func=deepseek_complete,
+    )
+
+    await rag.initialize_storages()
+    await initialize_pipeline_status()
+
+    return rag
+
+
+async def process_document():
+    # Remove stale data from previous runs
+    files_to_delete = [
+        "graph_chunk_entity_relation.graphml",
+        "kv_store_doc_status.json",
+        "kv_store_full_docs.json",
+        "kv_store_text_chunks.json",
+        "vdb_chunks.json",
+        "vdb_entities.json",
+        "vdb_relationships.json",
+    ]
+    for file in files_to_delete:
+        file_path = os.path.join(WORKING_DIR, file)
+        if os.path.exists(file_path):
+            os.remove(file_path)
+            print(f"Deleted old file: {file_path}")
+
+    # Initialize RAG
+    rag = await initialize_rag()
+
+    # Sanity-check the embedding function
+    test_text = ["This is a test string for embedding."]
+    embedding = await rag.embedding_func(test_text)
+    embedding_dim = embedding.shape[1]
+    print("\n=======================")
+    print("Test embedding function")
+    print("========================")
+    print(f"Test text: {test_text}")
+    print(f"Detected embedding dimension: {embedding_dim}\n\n")
+
+    # Read the book text
+    with open("./book.txt", "r", encoding="utf-8") as f:
+        content = f.read()
+
+    # Skip the Project Gutenberg license header
+    start_marker = "A Christmas Carol"
+    start_index = content.find(start_marker)
+    if start_index != -1:
+        content = content[start_index:]
+
+    # Chunking: chunk size = 500 tokens, overlap = 50
+    chunker = TextChunker(max_tokens=500, overlap=50)
+    chunks = chunker.chunk_text(content)
+    print(f"Total chunks to insert: {len(chunks)}")
+
+    # Embed at most 10 chunks per request, in batches
+    for i in range(0, len(chunks), EMBEDDING_BATCH_SIZE):
+        batch = chunks[i:i + EMBEDDING_BATCH_SIZE]
+        await rag.ainsert(batch)
+        print(f">> Inserted chunk batch {i // EMBEDDING_BATCH_SIZE + 1}")
+    return rag
+
+
+class CustomRAGProcessor:
+    """Custom RAG document processor."""
+
+    def __init__(
+        self,
+        working_dir: str = WORKING_DIR,
+        qwen_api_key: str | None = None,
+        deepseek_api_key: str | None = None,
+    ):
+        self.working_dir = working_dir
+        # Fall back to environment variables when no keys are passed in
+        self.qwen_api_key = qwen_api_key or os.getenv("QWEN_API_KEY")
+        self.deepseek_api_key = deepseek_api_key or os.getenv("DEEPSEEK_API_KEY")
+        self.chunker = TextChunker()
+
+        # Make sure the working directory exists
+        os.makedirs(self.working_dir, exist_ok=True)
+
+    async def initialize_rag(self) -> LightRAG:
+        """Initialize a LightRAG instance."""
+        rag = LightRAG(
+            working_dir=self.working_dir,
+            embedding_func=self.qwen_embed_wrapper,
+            llm_model_func=self.deepseek_complete_wrapper,
+            llm_model_name=DEEPSEEK_CHAT_MODEL,
+            chunk_token_size=200,
+            chunk_overlap_token_size=50,
+        )
+
+        await rag.initialize_storages()
+        await initialize_pipeline_status()
+        return rag
+
+    async def qwen_embed_wrapper(self, texts: list[str]) -> np.ndarray:
+        """Wrapper around the Qwen embedding client."""
+        async with QwenEmbedding(self.qwen_api_key) as embedder:
+            return await embedder.embed(texts)
+
+    async def deepseek_complete_wrapper(self, prompt: str, **kwargs) -> str:
+        """Wrapper around the DeepSeek completion client, with retries."""
+        for attempt in range(MAX_RETRIES):
+            try:
+                async with DeepSeekCompletion(self.deepseek_api_key) as completer:
+                    return await completer.complete(prompt, **kwargs)
+            except Exception as e:
+                if attempt == MAX_RETRIES - 1:
+                    raise
+                logging.warning(f"Attempt {attempt + 1} failed: {str(e)}")
+                await asyncio.sleep(1)
+
+    async def process_text(self, text: str, rag: LightRAG) -> None:
+        """Chunk text and insert it into the RAG store."""
+        chunks = self.chunker.chunk_text(text)
+        total_batches = (len(chunks) + EMBEDDING_BATCH_SIZE - 1) // EMBEDDING_BATCH_SIZE
+
+        for i in range(0, len(chunks), EMBEDDING_BATCH_SIZE):
+            batch = chunks[i:i + EMBEDDING_BATCH_SIZE]
+            await rag.ainsert(batch)
+            logging.info(f"Inserted batch {i // EMBEDDING_BATCH_SIZE + 1}/{total_batches}")
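+
+    # Chunking sketch: with the default max_tokens=200 and overlap=50, the
+    # window advances 150 tokens per step, so a 500-token text produces chunks
+    # starting at tokens 0, 150, 300, and 450 -- four chunks, each sharing
+    # 50 tokens with its neighbor.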
+
+    async def process_file(self, file_path: str, cleanup: bool = True) -> LightRAG:
+        """
+        Process a single file.
+        :param file_path: path to the file
+        :param cleanup: whether to delete data from previous runs
+        :return: an initialized LightRAG instance
+        """
+        if cleanup:
+            await self.cleanup_previous_data()
+
+        rag = await self.initialize_rag()
+
+        with open(file_path, "r", encoding="utf-8") as f:
+            content = f.read()
+
+        # Skip the Project Gutenberg license header (if present)
+        if "Project Gutenberg" in content:
+            start_marker = "A Christmas Carol"  # or another start marker
+            start_index = content.find(start_marker)
+            if start_index != -1:
+                content = content[start_index:]
+
+        await self.process_text(content, rag)
+        return rag
+
+    async def cleanup_previous_data(self) -> None:
+        """Delete data files left over from previous runs."""
+        files_to_delete = [
+            "graph_chunk_entity_relation.graphml",
+            "kv_store_doc_status.json",
+            "kv_store_full_docs.json",
+            "kv_store_text_chunks.json",
+            "vdb_chunks.json",
+            "vdb_entities.json",
+            "vdb_relationships.json",
+        ]
+
+        for filename in files_to_delete:
+            file_path = os.path.join(self.working_dir, filename)
+            if os.path.exists(file_path):
+                os.remove(file_path)
+                logging.info(f"Deleted previous data file: {file_path}")
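+
+
+# End-to-end sketch (not part of the original pipeline): process a local file
+# and run one query. Assumes ./book.txt exists and QWEN_API_KEY /
+# DEEPSEEK_API_KEY are set in the environment.
+if __name__ == "__main__":
+    async def main():
+        configure_logging()
+        processor = CustomRAGProcessor()
+        rag = await processor.process_file("./book.txt")
+        answer = await rag.aquery(
+            "What are the top themes in this story?",
+            param=QueryParam(mode="hybrid"),
+        )
+        print(answer)
+
+    asyncio.run(main())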