Add missing libraries

This commit is contained in:
ziwei.he 2025-06-30 10:59:38 +08:00
parent c18f4306be
commit c8948e0a44


@@ -0,0 +1,373 @@
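# LightRAG demo pipeline: Qwen (DashScope) embeddings + DeepSeek-R1 chat completions,
# token-based chunking with tiktoken, and rotating-file logging.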
import os
import asyncio
import logging
import logging.config
import aiohttp
import numpy as np
from lightrag import LightRAG, QueryParam
from lightrag.kg.shared_storage import initialize_pipeline_status
from lightrag.utils import logger, set_verbose_debug
import tiktoken
WORKING_DIR = "./dickens"
# API configuration
# DEEPSEEK_API_BASE = "https://api.deepseek.com/chat/completions"
DEEPSEEK_API_BASE = "https://dashscope.aliyuncs.com/compatible-mode/v1/chat/completions"
QWEN_API_BASE = "https://dashscope.aliyuncs.com/compatible-mode/v1/embeddings"
QWEN_EMBEDDING_MODEL = "text-embedding-v4"
# DEEPSEEK_CHAT_MODEL = "deepseek-reasoner"
DEEPSEEK_CHAT_MODEL = "deepseek-r1"
EMBEDDING_BATCH_SIZE = 10
MAX_RETRIES = 4
class QwenEmbedding:
    """Qwen embedding client for the DashScope OpenAI-compatible endpoint."""
    def __init__(self, api_key: str = None):
        # Allow an explicit key to be passed in; fall back to the demo default.
        self.api_key = api_key or 'sk-36930e681f094274964ffe6c51d62078'
        self.session = None

    async def __aenter__(self):
        self.session = aiohttp.ClientSession()
        return self

    async def __aexit__(self, exc_type, exc, tb):
        if self.session:
            await self.session.close()

    async def embed(self, texts: list[str]) -> np.ndarray:
        """Request embeddings, batching the input automatically."""
        if not self.session:
            raise RuntimeError("Session not initialized. Use async with.")
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json"
        }
        all_embeddings = []
        BATCH_SIZE = 10
        for i in range(0, len(texts), BATCH_SIZE):
            batch = texts[i:i + BATCH_SIZE]
            payload = {
                "model": QWEN_EMBEDDING_MODEL,
                "input": batch
            }
            async with self.session.post(QWEN_API_BASE, headers=headers, json=payload) as response:
                if response.status != 200:
                    error = await response.text()
                    raise RuntimeError(f"Qwen embedding API error: {error}")
                data = await response.json()
                batch_embeddings = [item["embedding"] for item in data["data"]]
                all_embeddings.extend(batch_embeddings)
        return np.array(all_embeddings)
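# Usage sketch (illustrative only): the client is an async context manager, so the
# aiohttp session is opened and closed around the batched embedding call, e.g.
#   async with QwenEmbedding() as embedder:
#       vectors = await embedder.embed(["some text", "more text"])  # -> np.ndarray of shape (n, dim)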
class DeepSeekCompletion:
    """DeepSeek chat-completion client."""
    def __init__(self, api_key: str = None):
        # self.api_key = 'sk-5a2809dc6cb545618f5c565ca546597e'
        self.api_key = api_key or 'sk-36930e681f094274964ffe6c51d62078'
        self.session = None

    async def __aenter__(self):
        self.session = aiohttp.ClientSession()
        return self

    async def __aexit__(self, exc_type, exc, tb):
        if self.session:
            await self.session.close()

    async def complete(self, prompt: str, **kwargs) -> str:
        """Return the model's completion for a prompt."""
        if not self.session:
            raise RuntimeError("Session not initialized. Use async with.")
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json"
        }
        payload = {
            "model": DEEPSEEK_CHAT_MODEL,
            "messages": [{"role": "user", "content": prompt}],
            "temperature": kwargs.get("temperature", 0.7),
            "max_tokens": kwargs.get("max_tokens", 2000)
        }
        async with self.session.post(
            DEEPSEEK_API_BASE,
            headers=headers,
            json=payload
        ) as response:
            if response.status != 200:
                error = await response.text()
                raise RuntimeError(f"DeepSeek API error: {error}")
            data = await response.json()
            return data["choices"][0]["message"]["content"]
class TextChunker:
    """Token-based text chunker."""
    def __init__(self, max_tokens: int = 200, overlap: int = 50):
        self.tokenizer = tiktoken.get_encoding("cl100k_base")
        self.max_tokens = max_tokens
        self.overlap = overlap

    def chunk_text(self, text: str) -> list[str]:
        """Split long text into overlapping token-based chunks."""
        tokens = self.tokenizer.encode(text)
        chunks = []
        start = 0
        while start < len(tokens):
            end = min(start + self.max_tokens, len(tokens))
            chunk = self.tokenizer.decode(tokens[start:end])
            chunks.append(chunk)
            start += self.max_tokens - self.overlap
        return chunks
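# Usage sketch (illustrative only): chunks are cl100k_base token windows of at most
# max_tokens tokens, and consecutive chunks share `overlap` tokens, e.g.
#   chunker = TextChunker(max_tokens=200, overlap=50)
#   chunks = chunker.chunk_text(long_text)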
def configure_logging():
    """Configure logging for the application"""
    # Reset any existing handlers to ensure clean configuration
    for logger_name in ["uvicorn", "uvicorn.access", "uvicorn.error", "lightrag"]:
        logger_instance = logging.getLogger(logger_name)
        logger_instance.handlers = []
        logger_instance.filters = []
    # Get log directory path from environment variable or use current directory
    log_dir = os.getenv("LOG_DIR", os.getcwd())
    log_file_path = os.path.abspath(os.path.join(log_dir, "lightrag_demo.log"))
    print(f"\nLightRAG demo log file: {log_file_path}\n")
    os.makedirs(os.path.dirname(log_file_path), exist_ok=True)
    # Get log file max size and backup count from environment variables
    log_max_bytes = int(os.getenv("LOG_MAX_BYTES", 10485760))  # Default 10MB
    log_backup_count = int(os.getenv("LOG_BACKUP_COUNT", 5))  # Default 5 backups
    logging.config.dictConfig(
        {
            "version": 1,
            "disable_existing_loggers": False,
            "formatters": {
                "default": {
                    "format": "%(levelname)s: %(message)s",
                },
                "detailed": {
                    "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
                },
            },
            "handlers": {
                "console": {
                    "formatter": "default",
                    "class": "logging.StreamHandler",
                    "stream": "ext://sys.stderr",
                },
                "file": {
                    "formatter": "detailed",
                    "class": "logging.handlers.RotatingFileHandler",
                    "filename": log_file_path,
                    "maxBytes": log_max_bytes,
                    "backupCount": log_backup_count,
                    "encoding": "utf-8",
                },
            },
            "loggers": {
                "lightrag": {
                    "handlers": ["console", "file"],
                    "level": "INFO",
                    "propagate": False,
                },
            },
        }
    )
    # Set the logger level to INFO
    logger.setLevel(logging.INFO)
    # Enable verbose debug if needed
    set_verbose_debug(os.getenv("VERBOSE_DEBUG", "false").lower() == "true")


if not os.path.exists(WORKING_DIR):
    os.mkdir(WORKING_DIR)
async def Qwen_embed(texts: list[str]) -> np.ndarray:
    async with QwenEmbedding() as embedder:
        return await embedder.embed(texts)


Qwen_embed.embedding_dim = 1024


async def deepseek_complete(prompt: str, **kwargs) -> str:
    for _ in range(3):
        try:
            async with DeepSeekCompletion() as completer:
                return await completer.complete(prompt, **kwargs)
        except Exception as e:
            print(f"[Retry] DeepSeek Error: {e}")
            await asyncio.sleep(1)
    raise RuntimeError("DeepSeek failed after 3 retries.")


async def initialize_rag():
    rag = LightRAG(
        working_dir=WORKING_DIR,
        embedding_func=Qwen_embed,
        llm_model_func=deepseek_complete,
    )
    await rag.initialize_storages()
    await initialize_pipeline_status()
    return rag
async def process_document():
    # Clean up old data
    files_to_delete = [
        "graph_chunk_entity_relation.graphml",
        "kv_store_doc_status.json",
        "kv_store_full_docs.json",
        "kv_store_text_chunks.json",
        "vdb_chunks.json",
        "vdb_entities.json",
        "vdb_relationships.json",
    ]
    for file in files_to_delete:
        file_path = os.path.join(WORKING_DIR, file)
        if os.path.exists(file_path):
            os.remove(file_path)
            print(f"Deleting old file: {file_path}")
    # Initialize RAG
    rag = await initialize_rag()
    # Embedding smoke test
    test_text = ["This is a test string for embedding."]
    embedding = await rag.embedding_func(test_text)
    embedding_dim = embedding.shape[1]
    print("\n=======================")
    print("Test embedding function")
    print("========================")
    print(f"Test text: {test_text}")
    print(f"Detected embedding dimension: {embedding_dim}\n\n")
    # Read the book text
    with open("./book.txt", "r", encoding="utf-8") as f:
        content = f.read()
    # Skip the Project Gutenberg copyright header
    start_marker = "A Christmas Carol"
    start_index = content.find(start_marker)
    if start_index != -1:
        content = content[start_index:]
    # Chunking: chunk size = 500 tokens, overlap = 50
    chunks = TextChunker(max_tokens=500, overlap=50).chunk_text(content)
    print(f"Total chunks to insert: {len(chunks)}")
    # Embed at most 10 chunks per call, in batches
    BATCH_SIZE = 10
    for i in range(0, len(chunks), BATCH_SIZE):
        batch = chunks[i:i + BATCH_SIZE]
        await rag.ainsert(batch)
        print(f">> Inserted chunk batch {i // BATCH_SIZE + 1}")
    return rag
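# Querying sketch (illustrative, not part of the original flow): the imported QueryParam
# can be used against the instance returned above, e.g.
#   rag = await process_document()
#   answer = await rag.aquery("What does Scrooge learn?", param=QueryParam(mode="hybrid"))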
class CustomRAGProcessor:
    """Custom RAG document processor."""
    def __init__(
        self,
        working_dir: str = WORKING_DIR,
        qwen_api_key: str = "sk-c1f2de78c13a455b806cf32648e36e25",
        deepseek_api_key: str = "sk-c1f2de78c13a455b806cf32648e36e25"
    ):
        self.working_dir = working_dir
        self.qwen_api_key = qwen_api_key or os.getenv("QWEN_API_KEY")
        self.deepseek_api_key = deepseek_api_key or os.getenv("DEEPSEEK_API_KEY")
        self.chunker = TextChunker()
        # Ensure the working directory exists
        os.makedirs(self.working_dir, exist_ok=True)

    async def initialize_rag(self) -> LightRAG:
        """Initialize the RAG instance."""
        rag = LightRAG(
            working_dir=self.working_dir,
            embedding_func=self.qwen_embed_wrapper,
            llm_model_func=self.deepseek_complete_wrapper,
            llm_model_name=DEEPSEEK_CHAT_MODEL,
            chunk_token_size=200,
            chunk_overlap_token_size=50
        )
        await rag.initialize_storages()
        await initialize_pipeline_status()
        return rag

    async def qwen_embed_wrapper(self, texts: list[str]) -> np.ndarray:
        """Wrapper around the Qwen embedding client."""
        async with QwenEmbedding(self.qwen_api_key) as embedder:
            return await embedder.embed(texts)

    async def deepseek_complete_wrapper(self, prompt: str, **kwargs) -> str:
        """Wrapper around the DeepSeek completion client, with retries."""
        for attempt in range(MAX_RETRIES):
            try:
                async with DeepSeekCompletion(self.deepseek_api_key) as completer:
                    return await completer.complete(prompt, **kwargs)
            except Exception as e:
                if attempt == MAX_RETRIES - 1:
                    raise
                logging.warning(f"Attempt {attempt + 1} failed: {str(e)}")
                await asyncio.sleep(1)

    async def process_text(self, text: str, rag: LightRAG) -> None:
        """Chunk text content and insert it into the RAG store."""
        chunks = self.chunker.chunk_text(text)
        total_batches = (len(chunks) + EMBEDDING_BATCH_SIZE - 1) // EMBEDDING_BATCH_SIZE
        for i in range(0, len(chunks), EMBEDDING_BATCH_SIZE):
            batch = chunks[i:i + EMBEDDING_BATCH_SIZE]
            await rag.ainsert(batch)
            logging.info(f"Inserted batch {i // EMBEDDING_BATCH_SIZE + 1}/{total_batches}")

    async def process_file(self, file_path: str, cleanup: bool = True) -> LightRAG:
        """
        Process a single file.
        :param file_path: path to the file
        :param cleanup: whether to clean up old data first
        :return: an initialized RAG instance
        """
        if cleanup:
            await self.cleanup_previous_data()
        rag = await self.initialize_rag()
        with open(file_path, "r", encoding="utf-8") as f:
            content = f.read()
        # Skip the Project Gutenberg boilerplate if present
        if "Project Gutenberg" in content:
            start_marker = "A Christmas Carol"  # or another start marker
            start_index = content.find(start_marker)
            if start_index != -1:
                content = content[start_index:]
        await self.process_text(content, rag)
        return rag

    async def cleanup_previous_data(self) -> None:
        """Clean up data from previous runs."""
        files_to_delete = [
            "graph_chunk_entity_relation.graphml",
            "kv_store_doc_status.json",
            "kv_store_full_docs.json",
            "kv_store_text_chunks.json",
            "vdb_chunks.json",
            "vdb_entities.json",
            "vdb_relationships.json",
        ]
        for filename in files_to_delete:
            file_path = os.path.join(self.working_dir, filename)
            if os.path.exists(file_path):
                os.remove(file_path)
                logging.info(f"Deleted previous data file: {file_path}")
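# End-to-end sketch for CustomRAGProcessor (illustrative; the file defines no entry point for it):
#   async def main():
#       processor = CustomRAGProcessor()
#       rag = await processor.process_file("./book.txt", cleanup=True)
#       print(await rag.aquery("Summarize the story.", param=QueryParam(mode="hybrid")))
#   asyncio.run(main())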