环境问题修复中

This commit is contained in:
Ziwei.He 2025-06-29 23:59:06 -07:00
parent c8948e0a44
commit 00bbd902b8
33 changed files with 116 additions and 110 deletions

View File

@ -1,10 +1,10 @@
import asyncio import asyncio
import numpy as np import numpy as np
import re import re
from lightrag import LightRAG from CoreRAG.lightrag.lightrag import LightRAG
from lightrag.base import QueryParam from CoreRAG.lightrag.base import QueryParam
from custom_rag_processor import QwenEmbedding, DeepSeekCompletion from CoreRAG.custom_rag_processor import QwenEmbedding, DeepSeekCompletion
from typing import List,Tuple
class MassageAcupointRAG: class MassageAcupointRAG:
def __init__(self, working_dir: str): def __init__(self, working_dir: str):
self.working_dir = working_dir self.working_dir = working_dir
@ -16,7 +16,7 @@ class MassageAcupointRAG:
) )
@staticmethod @staticmethod
async def async_embed(texts: list[str]) -> np.ndarray: async def async_embed(texts: List[str]) -> np.ndarray:
async with QwenEmbedding() as embedder: async with QwenEmbedding() as embedder:
return await embedder.embed(texts) return await embedder.embed(texts)
@ -26,7 +26,7 @@ class MassageAcupointRAG:
return await completer.complete(prompt, **kwargs) return await completer.complete(prompt, **kwargs)
@staticmethod @staticmethod
def extract_acupoint_list(text: str) -> list[str]: def extract_acupoint_list(text: str) -> List[str]:
""" """
LLM 返回文本中提取穴位名称列表形如 ['肩井穴', ...] LLM 返回文本中提取穴位名称列表形如 ['肩井穴', ...]
""" """
@ -39,7 +39,7 @@ class MassageAcupointRAG:
async def initialize(self): async def initialize(self):
await self.rag.initialize_storages() await self.rag.initialize_storages()
async def query_acupoints(self, user_query: str) -> tuple[str, list[str]]: async def query_acupoints(self, user_query: str) -> Tuple[str, List[str]]:
""" """
执行 RAG 查询并提取穴位列表 执行 RAG 查询并提取穴位列表

Binary file not shown.

View File

@ -4,10 +4,11 @@ import logging
import logging.config import logging.config
import aiohttp import aiohttp
import numpy as np import numpy as np
from lightrag import LightRAG, QueryParam from CoreRAG.lightrag import LightRAG, QueryParam
from lightrag.kg.shared_storage import initialize_pipeline_status from CoreRAG.lightrag.kg.shared_storage import initialize_pipeline_status
from lightrag.utils import logger, set_verbose_debug from CoreRAG.lightrag.utils import logger, set_verbose_debug
import tiktoken import tiktoken
from typing import List
WORKING_DIR = "./dickens" WORKING_DIR = "./dickens"
# API配置 # API配置
@ -33,7 +34,7 @@ class QwenEmbedding:
if self.session: if self.session:
await self.session.close() await self.session.close()
async def embed(self, texts: list[str]) -> np.ndarray: async def embed(self, texts: List[str]) -> np.ndarray:
"""自动分批请求嵌入""" """自动分批请求嵌入"""
if not self.session: if not self.session:
raise RuntimeError("Session not initialized. Use async with.") raise RuntimeError("Session not initialized. Use async with.")
@ -114,7 +115,7 @@ class TextChunker:
self.max_tokens = max_tokens self.max_tokens = max_tokens
self.overlap = overlap self.overlap = overlap
def chunk_text(self, text: str) -> list[str]: def chunk_text(self, text: str) -> List[str]:
"""将长文本按token分块""" """将长文本按token分块"""
tokens = self.tokenizer.encode(text) tokens = self.tokenizer.encode(text)
chunks = [] chunks = []
@ -193,7 +194,7 @@ if not os.path.exists(WORKING_DIR):
os.mkdir(WORKING_DIR) os.mkdir(WORKING_DIR)
async def Qwen_embed(texts: list[str]) -> np.ndarray: async def Qwen_embed(texts: List[str]) -> np.ndarray:
async with QwenEmbedding() as embedder: async with QwenEmbedding() as embedder:
return await embedder.embed(texts) return await embedder.embed(texts)
Qwen_embed.embedding_dim = 1024 Qwen_embed.embedding_dim = 1024
@ -303,7 +304,7 @@ class CustomRAGProcessor:
await initialize_pipeline_status() await initialize_pipeline_status()
return rag return rag
async def qwen_embed_wrapper(self, texts: list[str]) -> np.ndarray: async def qwen_embed_wrapper(self, texts: List[str]) -> np.ndarray:
"""Qwen嵌入函数包装器带重试机制""" """Qwen嵌入函数包装器带重试机制"""
async with QwenEmbedding(self.qwen_api_key) as embedder: async with QwenEmbedding(self.qwen_api_key) as embedder:
return await embedder.embed(texts) return await embedder.embed(texts)

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@ -1,5 +1,5 @@
from __future__ import annotations from __future__ import annotations
from typing import List, Union, Callable, Dict, Tuple
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from enum import Enum from enum import Enum
import os import os
@ -70,13 +70,13 @@ class QueryParam:
max_token_for_local_context: int = int(os.getenv("MAX_TOKEN_ENTITY_DESC", "4000")) max_token_for_local_context: int = int(os.getenv("MAX_TOKEN_ENTITY_DESC", "4000"))
"""Maximum number of tokens allocated for entity descriptions in local retrieval.""" """Maximum number of tokens allocated for entity descriptions in local retrieval."""
hl_keywords: list[str] = field(default_factory=list) hl_keywords: List[str] = field(default_factory=list)
"""List of high-level keywords to prioritize in retrieval.""" """List of high-level keywords to prioritize in retrieval."""
ll_keywords: list[str] = field(default_factory=list) ll_keywords: List[str] = field(default_factory=list)
"""List of low-level keywords to refine retrieval focus.""" """List of low-level keywords to refine retrieval focus."""
conversation_history: list[dict[str, str]] = field(default_factory=list) conversation_history: List[Dict[str, str]] = field(default_factory=list)
"""Stores past conversation history to maintain context. """Stores past conversation history to maintain context.
Format: [{"role": "user/assistant", "content": "message"}]. Format: [{"role": "user/assistant", "content": "message"}].
""" """
@ -84,16 +84,16 @@ class QueryParam:
history_turns: int = 3 history_turns: int = 3
"""Number of complete conversation turns (user-assistant pairs) to consider in the response context.""" """Number of complete conversation turns (user-assistant pairs) to consider in the response context."""
ids: list[str] | None = None ids: Union[List[str], None] = None
"""List of ids to filter the results.""" """List of ids to filter the results."""
model_func: Callable[..., object] | None = None model_func: Union[Callable[..., object], None] = None
"""Optional override for the LLM model function to use for this specific query. """Optional override for the LLM model function to use for this specific query.
If provided, this will be used instead of the global model function. If provided, this will be used instead of the global model function.
This allows using different models for different query modes. This allows using different models for different query modes.
""" """
user_prompt: str | None = None user_prompt: Union[str,None] = None
"""User-provided prompt for the query. """User-provided prompt for the query.
If proivded, this will be use instead of the default vaulue from prompt template. If proivded, this will be use instead of the default vaulue from prompt template.
""" """
@ -102,7 +102,7 @@ class QueryParam:
@dataclass @dataclass
class StorageNameSpace(ABC): class StorageNameSpace(ABC):
namespace: str namespace: str
global_config: dict[str, Any] global_config: Dict[str, Any]
async def initialize(self): async def initialize(self):
"""Initialize the storage""" """Initialize the storage"""
@ -117,7 +117,7 @@ class StorageNameSpace(ABC):
"""Commit the storage operations after indexing""" """Commit the storage operations after indexing"""
@abstractmethod @abstractmethod
async def drop(self) -> dict[str, str]: async def drop(self) -> Dict[str, str]:
"""Drop all data from storage and clean up resources """Drop all data from storage and clean up resources
This abstract method defines the contract for dropping all data from a storage implementation. This abstract method defines the contract for dropping all data from a storage implementation.
@ -151,12 +151,12 @@ class BaseVectorStorage(StorageNameSpace, ABC):
@abstractmethod @abstractmethod
async def query( async def query(
self, query: str, top_k: int, ids: list[str] | None = None self, query: str, top_k: int, ids: Union[List[str],None] = None
) -> list[dict[str, Any]]: ) -> List[Dict[str, Any]]:
"""Query the vector storage and retrieve top_k results.""" """Query the vector storage and retrieve top_k results."""
@abstractmethod @abstractmethod
async def upsert(self, data: dict[str, dict[str, Any]]) -> None: async def upsert(self, data: Dict[str, Dict[str, Any]]) -> None:
"""Insert or update vectors in the storage. """Insert or update vectors in the storage.
Importance notes for in-memory storage: Importance notes for in-memory storage:
@ -186,7 +186,7 @@ class BaseVectorStorage(StorageNameSpace, ABC):
""" """
@abstractmethod @abstractmethod
async def get_by_id(self, id: str) -> dict[str, Any] | None: async def get_by_id(self, id: str) -> Union[Dict[str, Any],None]:
"""Get vector data by its ID """Get vector data by its ID
Args: Args:
@ -198,7 +198,7 @@ class BaseVectorStorage(StorageNameSpace, ABC):
pass pass
@abstractmethod @abstractmethod
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]: async def get_by_ids(self, ids: List[str]) -> List[Dict[str, Any]]:
"""Get multiple vector data by their IDs """Get multiple vector data by their IDs
Args: Args:
@ -210,7 +210,7 @@ class BaseVectorStorage(StorageNameSpace, ABC):
pass pass
@abstractmethod @abstractmethod
async def delete(self, ids: list[str]): async def delete(self, ids: List[str]):
"""Delete vectors with specified IDs """Delete vectors with specified IDs
Importance notes for in-memory storage: Importance notes for in-memory storage:
@ -228,11 +228,11 @@ class BaseKVStorage(StorageNameSpace, ABC):
embedding_func: EmbeddingFunc embedding_func: EmbeddingFunc
@abstractmethod @abstractmethod
async def get_by_id(self, id: str) -> dict[str, Any] | None: async def get_by_id(self, id: str) -> Union[Dict[str, Any],None]:
"""Get value by id""" """Get value by id"""
@abstractmethod @abstractmethod
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]: async def get_by_ids(self, ids: List[str]) -> List[Dict[str, Any]]:
"""Get values by ids""" """Get values by ids"""
@abstractmethod @abstractmethod
@ -249,7 +249,7 @@ class BaseKVStorage(StorageNameSpace, ABC):
""" """
@abstractmethod @abstractmethod
async def delete(self, ids: list[str]) -> None: async def delete(self, ids: List[str]) -> None:
"""Delete specific records from storage by their IDs """Delete specific records from storage by their IDs
Importance notes for in-memory storage: Importance notes for in-memory storage:
@ -263,7 +263,7 @@ class BaseKVStorage(StorageNameSpace, ABC):
None None
""" """
async def drop_cache_by_modes(self, modes: list[str] | None = None) -> bool: async def drop_cache_by_modes(self, modes: Union[List[str],None] = None) -> bool:
"""Delete specific records from storage by cache mode """Delete specific records from storage by cache mode
Importance notes for in-memory storage: Importance notes for in-memory storage:
@ -330,7 +330,7 @@ class BaseGraphStorage(StorageNameSpace, ABC):
""" """
@abstractmethod @abstractmethod
async def get_node(self, node_id: str) -> dict[str, str] | None: async def get_node(self, node_id: str) -> Union[Dict[str, str],None]:
"""Get node by its ID, returning only node properties. """Get node by its ID, returning only node properties.
Args: Args:
@ -343,7 +343,7 @@ class BaseGraphStorage(StorageNameSpace, ABC):
@abstractmethod @abstractmethod
async def get_edge( async def get_edge(
self, source_node_id: str, target_node_id: str self, source_node_id: str, target_node_id: str
) -> dict[str, str] | None: ) -> Union[Dict[str, str],None]:
"""Get edge properties between two nodes. """Get edge properties between two nodes.
Args: Args:
@ -355,7 +355,7 @@ class BaseGraphStorage(StorageNameSpace, ABC):
""" """
@abstractmethod @abstractmethod
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None: async def get_node_edges(self, source_node_id: str) -> Union[List[Tuple[str, str]],None]:
"""Get all edges connected to a node. """Get all edges connected to a node.
Args: Args:
@ -366,7 +366,7 @@ class BaseGraphStorage(StorageNameSpace, ABC):
or None if the node doesn't exist or None if the node doesn't exist
""" """
async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, dict]: async def get_nodes_batch(self, node_ids: List[str]) -> Dict[str, Dict]:
"""Get nodes as a batch using UNWIND """Get nodes as a batch using UNWIND
Default implementation fetches nodes one by one. Default implementation fetches nodes one by one.
@ -380,7 +380,7 @@ class BaseGraphStorage(StorageNameSpace, ABC):
result[node_id] = node result[node_id] = node
return result return result
async def node_degrees_batch(self, node_ids: list[str]) -> dict[str, int]: async def node_degrees_batch(self, node_ids: List[str]) -> Dict[str, int]:
"""Node degrees as a batch using UNWIND """Node degrees as a batch using UNWIND
Default implementation fetches node degrees one by one. Default implementation fetches node degrees one by one.
@ -394,8 +394,8 @@ class BaseGraphStorage(StorageNameSpace, ABC):
return result return result
async def edge_degrees_batch( async def edge_degrees_batch(
self, edge_pairs: list[tuple[str, str]] self, edge_pairs: List[Tuple[str, str]]
) -> dict[tuple[str, str], int]: ) -> Dict[Tuple[str, str], int]:
"""Edge degrees as a batch using UNWIND also uses node_degrees_batch """Edge degrees as a batch using UNWIND also uses node_degrees_batch
Default implementation calculates edge degrees one by one. Default implementation calculates edge degrees one by one.
@ -409,8 +409,8 @@ class BaseGraphStorage(StorageNameSpace, ABC):
return result return result
async def get_edges_batch( async def get_edges_batch(
self, pairs: list[dict[str, str]] self, pairs: List[Dict[str, str]]
) -> dict[tuple[str, str], dict]: ) -> Dict[Tuple[str, str], Dict]:
"""Get edges as a batch using UNWIND """Get edges as a batch using UNWIND
Default implementation fetches edges one by one. Default implementation fetches edges one by one.
@ -427,8 +427,8 @@ class BaseGraphStorage(StorageNameSpace, ABC):
return result return result
async def get_nodes_edges_batch( async def get_nodes_edges_batch(
self, node_ids: list[str] self, node_ids: List[str]
) -> dict[str, list[tuple[str, str]]]: ) -> Dict[str, List[Tuple[str, str]]]:
"""Get nodes edges as a batch using UNWIND """Get nodes edges as a batch using UNWIND
Default implementation fetches node edges one by one. Default implementation fetches node edges one by one.
@ -442,7 +442,7 @@ class BaseGraphStorage(StorageNameSpace, ABC):
return result return result
@abstractmethod @abstractmethod
async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None: async def upsert_node(self, node_id: str, node_data: Dict[str, str]) -> None:
"""Insert a new node or update an existing node in the graph. """Insert a new node or update an existing node in the graph.
Importance notes for in-memory storage: Importance notes for in-memory storage:
@ -457,7 +457,7 @@ class BaseGraphStorage(StorageNameSpace, ABC):
@abstractmethod @abstractmethod
async def upsert_edge( async def upsert_edge(
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str] self, source_node_id: str, target_node_id: str, edge_data: Dict[str, str]
) -> None: ) -> None:
"""Insert a new edge or update an existing edge in the graph. """Insert a new edge or update an existing edge in the graph.
@ -486,7 +486,7 @@ class BaseGraphStorage(StorageNameSpace, ABC):
""" """
@abstractmethod @abstractmethod
async def remove_nodes(self, nodes: list[str]): async def remove_nodes(self, nodes: List[str]):
"""Delete multiple nodes """Delete multiple nodes
Importance notes: Importance notes:
@ -499,7 +499,7 @@ class BaseGraphStorage(StorageNameSpace, ABC):
""" """
@abstractmethod @abstractmethod
async def remove_edges(self, edges: list[tuple[str, str]]): async def remove_edges(self, edges: List[Tuple[str, str]]):
"""Delete multiple edges """Delete multiple edges
Importance notes: Importance notes:
@ -512,7 +512,7 @@ class BaseGraphStorage(StorageNameSpace, ABC):
""" """
@abstractmethod @abstractmethod
async def get_all_labels(self) -> list[str]: async def get_all_labels(self) -> List[str]:
"""Get all labels in the graph. """Get all labels in the graph.
Returns: Returns:
@ -564,11 +564,11 @@ class DocProcessingStatus:
"""ISO format timestamp when document was created""" """ISO format timestamp when document was created"""
updated_at: str updated_at: str
"""ISO format timestamp when document was last updated""" """ISO format timestamp when document was last updated"""
chunks_count: int | None = None chunks_count: Union[int,None] = None
"""Number of chunks after splitting, used for processing""" """Number of chunks after splitting, used for processing"""
error: str | None = None error: Union[str,None] = None
"""Error message if failed""" """Error message if failed"""
metadata: dict[str, Any] = field(default_factory=dict) metadata: Dict[str, Any] = field(default_factory=Dict)
"""Additional metadata""" """Additional metadata"""
@ -577,16 +577,16 @@ class DocStatusStorage(BaseKVStorage, ABC):
"""Base class for document status storage""" """Base class for document status storage"""
@abstractmethod @abstractmethod
async def get_status_counts(self) -> dict[str, int]: async def get_status_counts(self) -> Dict[str, int]:
"""Get counts of documents in each status""" """Get counts of documents in each status"""
@abstractmethod @abstractmethod
async def get_docs_by_status( async def get_docs_by_status(
self, status: DocStatus self, status: DocStatus
) -> dict[str, DocProcessingStatus]: ) -> Dict[str, DocProcessingStatus]:
"""Get all documents with a specific status""" """Get all documents with a specific status"""
async def drop_cache_by_modes(self, modes: list[str] | None = None) -> bool: async def drop_cache_by_modes(self, modes: Union[List[str],None] = None) -> bool:
"""Drop cache is not supported for Doc Status storage""" """Drop cache is not supported for Doc Status storage"""
return False return False

View File

@ -1,3 +1,5 @@
from typing import List
from typing import Dict
STORAGE_IMPLEMENTATIONS = { STORAGE_IMPLEMENTATIONS = {
"KV_STORAGE": { "KV_STORAGE": {
"implementations": [ "implementations": [
@ -45,7 +47,7 @@ STORAGE_IMPLEMENTATIONS = {
} }
# Storage implementation environment variable without default value # Storage implementation environment variable without default value
STORAGE_ENV_REQUIREMENTS: dict[str, list[str]] = { STORAGE_ENV_REQUIREMENTS: Dict[str, List[str]] = {
# KV Storage Implementations # KV Storage Implementations
"JsonKVStorage": [], "JsonKVStorage": [],
"MongoKVStorage": [], "MongoKVStorage": [],

View File

@ -1,11 +1,11 @@
import os import os
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, final from typing import Any, final,Dict,List,Union,Set
from lightrag.base import ( from ..base import (
BaseKVStorage, BaseKVStorage,
) )
from lightrag.utils import ( from ..utils import (
load_json, load_json,
logger, logger,
write_json, write_json,
@ -85,7 +85,7 @@ class JsonKVStorage(BaseKVStorage):
write_json(data_dict, self._file_name) write_json(data_dict, self._file_name)
await clear_all_update_flags(self.namespace) await clear_all_update_flags(self.namespace)
async def get_all(self) -> dict[str, Any]: async def get_all(self) -> Dict[str, Any]:
"""Get all data from storage """Get all data from storage
Returns: Returns:
@ -94,11 +94,11 @@ class JsonKVStorage(BaseKVStorage):
async with self._storage_lock: async with self._storage_lock:
return dict(self._data) return dict(self._data)
async def get_by_id(self, id: str) -> dict[str, Any] | None: async def get_by_id(self, id: str) -> Union[Dict[str, Any] , None]:
async with self._storage_lock: async with self._storage_lock:
return self._data.get(id) return self._data.get(id)
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]: async def get_by_ids(self, ids: List[str]) -> List[Dict[str, Any]]:
async with self._storage_lock: async with self._storage_lock:
return [ return [
( (
@ -109,11 +109,11 @@ class JsonKVStorage(BaseKVStorage):
for id in ids for id in ids
] ]
async def filter_keys(self, keys: set[str]) -> set[str]: async def filter_keys(self, keys: Set[str]) -> Set[str]:
async with self._storage_lock: async with self._storage_lock:
return set(keys) - set(self._data.keys()) return Set(keys) - Set(self._data.keys())
async def upsert(self, data: dict[str, dict[str, Any]]) -> None: async def upsert(self, data: Dict[str, Dict[str, Any]]) -> None:
""" """
Importance notes for in-memory storage: Importance notes for in-memory storage:
1. Changes will be persisted to disk during the next index_done_callback 1. Changes will be persisted to disk during the next index_done_callback
@ -126,7 +126,7 @@ class JsonKVStorage(BaseKVStorage):
self._data.update(data) self._data.update(data)
await set_all_update_flags(self.namespace) await set_all_update_flags(self.namespace)
async def delete(self, ids: list[str]) -> None: async def delete(self, ids: List[str]) -> None:
"""Delete specific records from storage by their IDs """Delete specific records from storage by their IDs
Importance notes for in-memory storage: Importance notes for in-memory storage:
@ -149,7 +149,7 @@ class JsonKVStorage(BaseKVStorage):
if any_deleted: if any_deleted:
await set_all_update_flags(self.namespace) await set_all_update_flags(self.namespace)
async def drop_cache_by_modes(self, modes: list[str] | None = None) -> bool: async def drop_cache_by_modes(self, modes: Union[List[str] , None] = None) -> bool:
"""Delete specific records from storage by by cache mode """Delete specific records from storage by by cache mode
Importance notes for in-memory storage: Importance notes for in-memory storage:
@ -172,7 +172,7 @@ class JsonKVStorage(BaseKVStorage):
except Exception: except Exception:
return False return False
async def drop(self) -> dict[str, str]: async def drop(self) -> Dict[str, str]:
"""Drop all data from storage and clean up resources """Drop all data from storage and clean up resources
This action will persistent the data to disk immediately. This action will persistent the data to disk immediately.

View File

@ -5,15 +5,15 @@ from dataclasses import dataclass
import numpy as np import numpy as np
import time import time
from lightrag.utils import ( from ..utils import (
logger, logger,
compute_mdhash_id, compute_mdhash_id,
) )
import pipmaster as pm import pipmaster as pm
from lightrag.base import BaseVectorStorage from ..base import BaseVectorStorage
if not pm.is_installed("nano-vectordb"): # if not pm.is_installed("nano-vectordb"):
pm.install("nano-vectordb") # pm.install("nano-vectordb")
from nano_vectordb import NanoVectorDB from nano_vectordb import NanoVectorDB
from .shared_storage import ( from .shared_storage import (

View File

@ -1,10 +1,10 @@
import os import os
from dataclasses import dataclass from dataclasses import dataclass
from typing import final from typing import final,Dict,List,Union,Tuple
from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
from lightrag.utils import logger from ..utils import logger
from lightrag.base import BaseGraphStorage from ..base import BaseGraphStorage
import pipmaster as pm import pipmaster as pm
@ -98,7 +98,7 @@ class NetworkXStorage(BaseGraphStorage):
graph = await self._get_graph() graph = await self._get_graph()
return graph.has_edge(source_node_id, target_node_id) return graph.has_edge(source_node_id, target_node_id)
async def get_node(self, node_id: str) -> dict[str, str] | None: async def get_node(self, node_id: str) -> Union[Dict[str, str],None]:
graph = await self._get_graph() graph = await self._get_graph()
return graph.nodes.get(node_id) return graph.nodes.get(node_id)
@ -112,17 +112,17 @@ class NetworkXStorage(BaseGraphStorage):
async def get_edge( async def get_edge(
self, source_node_id: str, target_node_id: str self, source_node_id: str, target_node_id: str
) -> dict[str, str] | None: ) -> Union[Dict[str, str],None]:
graph = await self._get_graph() graph = await self._get_graph()
return graph.edges.get((source_node_id, target_node_id)) return graph.edges.get((source_node_id, target_node_id))
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None: async def get_node_edges(self, source_node_id: str) -> Union[List[Tuple[str, str]],None]:
graph = await self._get_graph() graph = await self._get_graph()
if graph.has_node(source_node_id): if graph.has_node(source_node_id):
return list(graph.edges(source_node_id)) return list(graph.edges(source_node_id))
return None return None
async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None: async def upsert_node(self, node_id: str, node_data: Dict[str, str]) -> None:
""" """
Importance notes: Importance notes:
1. Changes will be persisted to disk during the next index_done_callback 1. Changes will be persisted to disk during the next index_done_callback
@ -133,7 +133,7 @@ class NetworkXStorage(BaseGraphStorage):
graph.add_node(node_id, **node_data) graph.add_node(node_id, **node_data)
async def upsert_edge( async def upsert_edge(
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str] self, source_node_id: str, target_node_id: str, edge_data: Dict[str, str]
) -> None: ) -> None:
""" """
Importance notes: Importance notes:
@ -158,7 +158,7 @@ class NetworkXStorage(BaseGraphStorage):
else: else:
logger.warning(f"Node {node_id} not found in the graph for deletion.") logger.warning(f"Node {node_id} not found in the graph for deletion.")
async def remove_nodes(self, nodes: list[str]): async def remove_nodes(self, nodes: List[str]):
"""Delete multiple nodes """Delete multiple nodes
Importance notes: Importance notes:
@ -174,7 +174,7 @@ class NetworkXStorage(BaseGraphStorage):
if graph.has_node(node): if graph.has_node(node):
graph.remove_node(node) graph.remove_node(node)
async def remove_edges(self, edges: list[tuple[str, str]]): async def remove_edges(self, edges: List[Tuple[str, str]]):
"""Delete multiple edges """Delete multiple edges
Importance notes: Importance notes:
@ -190,7 +190,7 @@ class NetworkXStorage(BaseGraphStorage):
if graph.has_edge(source, target): if graph.has_edge(source, target):
graph.remove_edge(source, target) graph.remove_edge(source, target)
async def get_all_labels(self) -> list[str]: async def get_all_labels(self) -> List[str]:
""" """
Get all node labels in the graph Get all node labels in the graph
Returns: Returns:
@ -389,7 +389,7 @@ class NetworkXStorage(BaseGraphStorage):
return True return True
async def drop(self) -> dict[str, str]: async def drop(self) -> Dict[str, str]:
"""Drop all graph data from storage and clean up resources """Drop all graph data from storage and clean up resources
This method will: This method will:

View File

@ -1,5 +1,4 @@
from __future__ import annotations from __future__ import annotations
import traceback import traceback
import asyncio import asyncio
import configparser import configparser
@ -21,18 +20,18 @@ from typing import (
List, List,
Dict, Dict,
) )
from lightrag.constants import ( from .constants import (
DEFAULT_MAX_TOKEN_SUMMARY, DEFAULT_MAX_TOKEN_SUMMARY,
DEFAULT_FORCE_LLM_SUMMARY_ON_MERGE, DEFAULT_FORCE_LLM_SUMMARY_ON_MERGE,
) )
from lightrag.utils import get_env_value from .utils import get_env_value
from lightrag.kg import ( from .kg import (
STORAGES, STORAGES,
verify_storage_implementation, verify_storage_implementation,
) )
from lightrag.kg.shared_storage import ( from .kg.shared_storage import (
get_namespace_data, get_namespace_data,
get_pipeline_status_lock, get_pipeline_status_lock,
) )
@ -199,7 +198,7 @@ class LightRAG:
) )
"""Maximum number of concurrent embedding function calls.""" """Maximum number of concurrent embedding function calls."""
embedding_cache_config: dict[str, Any] = field( embedding_cache_config: Dict[str, Any] = field(
default_factory=lambda: { default_factory=lambda: {
"enabled": False, "enabled": False,
"similarity_threshold": 0.95, "similarity_threshold": 0.95,
@ -283,7 +282,7 @@ class LightRAG:
_storages_status: StoragesStatus = field(default=StoragesStatus.NOT_CREATED) _storages_status: StoragesStatus = field(default=StoragesStatus.NOT_CREATED)
def __post_init__(self): def __post_init__(self):
from lightrag.kg.shared_storage import ( from .kg.shared_storage import (
initialize_share_data, initialize_share_data,
) )

View File

@ -1,18 +1,18 @@
from __future__ import annotations from __future__ import annotations
from pydantic import BaseModel from pydantic import BaseModel
from typing import Any, Optional from typing import Any, Optional,List,Dict
class GPTKeywordExtractionFormat(BaseModel): class GPTKeywordExtractionFormat(BaseModel):
high_level_keywords: list[str] high_level_keywords: List[str]
low_level_keywords: list[str] low_level_keywords: List[str]
class KnowledgeGraphNode(BaseModel): class KnowledgeGraphNode(BaseModel):
id: str id: str
labels: list[str] labels: List[str]
properties: dict[str, Any] # anything else goes here properties: Dict[str, Any] # anything else goes here
class KnowledgeGraphEdge(BaseModel): class KnowledgeGraphEdge(BaseModel):
@ -20,10 +20,10 @@ class KnowledgeGraphEdge(BaseModel):
type: Optional[str] type: Optional[str]
source: str # id of source node source: str # id of source node
target: str # id of target node target: str # id of target node
properties: dict[str, Any] # anything else goes here properties: Dict[str, Any] # anything else goes here
class KnowledgeGraph(BaseModel): class KnowledgeGraph(BaseModel):
nodes: list[KnowledgeGraphNode] = [] nodes: List[KnowledgeGraphNode] = []
edges: list[KnowledgeGraphEdge] = [] edges: List[KnowledgeGraphEdge] = []
is_truncated: bool = False is_truncated: bool = False

View File

@ -14,9 +14,9 @@ from functools import wraps
from hashlib import md5 from hashlib import md5
from typing import Any, Protocol, Callable, TYPE_CHECKING, List from typing import Any, Protocol, Callable, TYPE_CHECKING, List
import numpy as np import numpy as np
from lightrag.prompt import PROMPTS from .prompt import PROMPTS
from dotenv import load_dotenv from dotenv import load_dotenv
from lightrag.constants import ( from .constants import (
DEFAULT_LOG_MAX_BYTES, DEFAULT_LOG_MAX_BYTES,
DEFAULT_LOG_BACKUP_COUNT, DEFAULT_LOG_BACKUP_COUNT,
DEFAULT_LOG_FILENAME, DEFAULT_LOG_FILENAME,
@ -1696,7 +1696,7 @@ def check_storage_env_vars(storage_name: str) -> None:
Raises: Raises:
ValueError: If required environment variables are missing ValueError: If required environment variables are missing
""" """
from lightrag.kg import STORAGE_ENV_REQUIREMENTS from .kg import STORAGE_ENV_REQUIREMENTS
required_vars = STORAGE_ENV_REQUIREMENTS.get(storage_name, []) required_vars = STORAGE_ENV_REQUIREMENTS.get(storage_name, [])
missing_vars = [var for var in required_vars if var not in os.environ] missing_vars = [var for var in required_vars if var not in os.environ]

View File

@ -20,7 +20,7 @@ class autoGenerator():
async def query(self): async def query(self):
self.rag_client = MassageAcupointRAG( self.rag_client = MassageAcupointRAG(
working_dir="C:/Users/ZIWEI/Documents/work/向量化/CoreRAG/Massage_10216" working_dir="CoreRAG/Massage_10216"
) )
await self.rag_client.initialize() await self.rag_client.initialize()

View File

@ -1,6 +1,6 @@
import numpy as np import numpy as np
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from typing import List
# ======================= 参数配置 ======================= # # ======================= 参数配置 ======================= #
GRID_SIZE = (40, 130) # 行x列 GRID_SIZE = (40, 130) # 行x列
TOTAL_STEPS = 50 TOTAL_STEPS = 50
@ -166,14 +166,14 @@ class Agent:
# ======================= 可视化模块 ======================= # # ======================= 可视化模块 ======================= #
class Visualizer: class Visualizer:
def __init__(self, grid_size, full_path_ref: list[tuple]): def __init__(self, grid_size, full_path_ref: List[tuple]):
self.grid_size = grid_size self.grid_size = grid_size
self.full_path_ref = full_path_ref self.full_path_ref = full_path_ref
plt.ion() plt.ion()
self.fig, self.ax = plt.subplots(figsize=(4, 10)) # 调整大小 self.fig, self.ax = plt.subplots(figsize=(4, 10)) # 调整大小
self.im = self.ax.imshow(np.zeros(grid_size).T, origin='lower', cmap='viridis') self.im = self.ax.imshow(np.zeros(grid_size).T, origin='lower', cmap='viridis')
self.colorbar = plt.colorbar(self.im, ax=self.ax, label='Height') self.colorbar = plt.colorbar(self.im, ax=self.ax, label='Height')
self._artists: list = [] self._artists: List = []
# 保持横纵比一致,避免变形 # 保持横纵比一致,避免变形
self.ax.set_aspect('equal') self.ax.set_aspect('equal')

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@ -1,8 +1,11 @@
try:from scripts.APF_global_demo import GaussianSchedule,GaussianPathSchedule,TimedGaussianSchedule,FieldScheduler,Agent,Visualizer try:from scripts.APF_global_demo import GaussianSchedule,GaussianPathSchedule,TimedGaussianSchedule,FieldScheduler,Agent,Visualizer
except:from APF_global_demo import GaussianSchedule,GaussianPathSchedule,TimedGaussianSchedule,FieldScheduler,Agent,Visualizer except:from APF_global_demo import GaussianSchedule,GaussianPathSchedule,TimedGaussianSchedule,FieldScheduler,Agent,Visualizer
import numpy as np import numpy as np
try:from scripts.sorter import sorter try:from scripts.sorter import sorter
except:from sorter import sorter except:from sorter import sorter
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import json import json
import time import time

View File

@ -1,6 +1,7 @@
import re import re
import json import json
import numpy as np import numpy as np
from typing import List
class sorter: class sorter:
''' 重点穴位按摩排列器 ''' ''' 重点穴位按摩排列器 '''
@ -18,7 +19,7 @@ class sorter:
self.massage_side = massage_side # 默认为双边'both' self.massage_side = massage_side # 默认为双边'both'
else: else:
raise ValueError("未指定按摩在左侧、右侧或两侧") raise ValueError("未指定按摩在左侧、右侧或两侧")
def _extract_acupoints(self,respnse_from_llm:str)->list[str]: def _extract_acupoints(self,respnse_from_llm:str)->List[str]:
pattern = r"[0-9]+\.\s*([\u4e00-\u9fa5]{2,5}穴)" pattern = r"[0-9]+\.\s*([\u4e00-\u9fa5]{2,5}穴)"
matches = re.findall(pattern, respnse_from_llm) matches = re.findall(pattern, respnse_from_llm)
# 去重 & 排除空值 # 去重 & 排除空值
@ -26,7 +27,7 @@ class sorter:
print(unique_names) print(unique_names)
return unique_names return unique_names
def sort_acupoints(self,respnse_from_llm:list[str])->list[str]: def sort_acupoints(self,respnse_from_llm:List[str])->List[str]:
if self.body_part == 'back': if self.body_part == 'back':
allowed_names = [] allowed_names = []
if self.body_part == 'shoulder': if self.body_part == 'shoulder':
@ -38,7 +39,7 @@ class sorter:
"志室右","肓门右","胃仓右","意舍右","阳纲右","胞肓右","气海右俞","大肠右俞","小肠右俞", "志室右","肓门右","胃仓右","意舍右","阳纲右","胞肓右","气海右俞","大肠右俞","小肠右俞",
"中膂右俞","肾俞右","关元右俞","膀胱右俞","白环右俞","秩边右","京门右"] "中膂右俞","肾俞右","关元右俞","膀胱右俞","白环右俞","秩边右","京门右"]
def __filter_acupoints(acupoints:list[str],allowed_names: list[str])->list[str]: def __filter_acupoints(acupoints:List[str],allowed_names: List[str])->List[str]:
acupoints_cleaned = [] acupoints_cleaned = []
matched_keys = [] matched_keys = []
for acupoint in acupoints: for acupoint in acupoints: