环境问题修复中
This commit is contained in:
parent
c8948e0a44
commit
00bbd902b8
@ -1,10 +1,10 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import re
|
import re
|
||||||
from lightrag import LightRAG
|
from CoreRAG.lightrag.lightrag import LightRAG
|
||||||
from lightrag.base import QueryParam
|
from CoreRAG.lightrag.base import QueryParam
|
||||||
from custom_rag_processor import QwenEmbedding, DeepSeekCompletion
|
from CoreRAG.custom_rag_processor import QwenEmbedding, DeepSeekCompletion
|
||||||
|
from typing import List,Tuple
|
||||||
class MassageAcupointRAG:
|
class MassageAcupointRAG:
|
||||||
def __init__(self, working_dir: str):
|
def __init__(self, working_dir: str):
|
||||||
self.working_dir = working_dir
|
self.working_dir = working_dir
|
||||||
@ -16,7 +16,7 @@ class MassageAcupointRAG:
|
|||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def async_embed(texts: list[str]) -> np.ndarray:
|
async def async_embed(texts: List[str]) -> np.ndarray:
|
||||||
async with QwenEmbedding() as embedder:
|
async with QwenEmbedding() as embedder:
|
||||||
return await embedder.embed(texts)
|
return await embedder.embed(texts)
|
||||||
|
|
||||||
@ -26,7 +26,7 @@ class MassageAcupointRAG:
|
|||||||
return await completer.complete(prompt, **kwargs)
|
return await completer.complete(prompt, **kwargs)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def extract_acupoint_list(text: str) -> list[str]:
|
def extract_acupoint_list(text: str) -> List[str]:
|
||||||
"""
|
"""
|
||||||
从 LLM 返回文本中提取穴位名称列表(形如 ['肩井穴', ...])
|
从 LLM 返回文本中提取穴位名称列表(形如 ['肩井穴', ...])
|
||||||
"""
|
"""
|
||||||
@ -39,7 +39,7 @@ class MassageAcupointRAG:
|
|||||||
async def initialize(self):
|
async def initialize(self):
|
||||||
await self.rag.initialize_storages()
|
await self.rag.initialize_storages()
|
||||||
|
|
||||||
async def query_acupoints(self, user_query: str) -> tuple[str, list[str]]:
|
async def query_acupoints(self, user_query: str) -> Tuple[str, List[str]]:
|
||||||
"""
|
"""
|
||||||
执行 RAG 查询并提取穴位列表
|
执行 RAG 查询并提取穴位列表
|
||||||
|
|
||||||
|
BIN
CoreRAG/__pycache__/MassageAcupointRAG.cpython-38.pyc
Normal file
BIN
CoreRAG/__pycache__/MassageAcupointRAG.cpython-38.pyc
Normal file
Binary file not shown.
BIN
CoreRAG/__pycache__/custom_rag_processor.cpython-38.pyc
Normal file
BIN
CoreRAG/__pycache__/custom_rag_processor.cpython-38.pyc
Normal file
Binary file not shown.
@ -4,10 +4,11 @@ import logging
|
|||||||
import logging.config
|
import logging.config
|
||||||
import aiohttp
|
import aiohttp
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from lightrag import LightRAG, QueryParam
|
from CoreRAG.lightrag import LightRAG, QueryParam
|
||||||
from lightrag.kg.shared_storage import initialize_pipeline_status
|
from CoreRAG.lightrag.kg.shared_storage import initialize_pipeline_status
|
||||||
from lightrag.utils import logger, set_verbose_debug
|
from CoreRAG.lightrag.utils import logger, set_verbose_debug
|
||||||
import tiktoken
|
import tiktoken
|
||||||
|
from typing import List
|
||||||
WORKING_DIR = "./dickens"
|
WORKING_DIR = "./dickens"
|
||||||
|
|
||||||
# API配置
|
# API配置
|
||||||
@ -33,7 +34,7 @@ class QwenEmbedding:
|
|||||||
if self.session:
|
if self.session:
|
||||||
await self.session.close()
|
await self.session.close()
|
||||||
|
|
||||||
async def embed(self, texts: list[str]) -> np.ndarray:
|
async def embed(self, texts: List[str]) -> np.ndarray:
|
||||||
"""自动分批请求嵌入"""
|
"""自动分批请求嵌入"""
|
||||||
if not self.session:
|
if not self.session:
|
||||||
raise RuntimeError("Session not initialized. Use async with.")
|
raise RuntimeError("Session not initialized. Use async with.")
|
||||||
@ -114,7 +115,7 @@ class TextChunker:
|
|||||||
self.max_tokens = max_tokens
|
self.max_tokens = max_tokens
|
||||||
self.overlap = overlap
|
self.overlap = overlap
|
||||||
|
|
||||||
def chunk_text(self, text: str) -> list[str]:
|
def chunk_text(self, text: str) -> List[str]:
|
||||||
"""将长文本按token分块"""
|
"""将长文本按token分块"""
|
||||||
tokens = self.tokenizer.encode(text)
|
tokens = self.tokenizer.encode(text)
|
||||||
chunks = []
|
chunks = []
|
||||||
@ -193,7 +194,7 @@ if not os.path.exists(WORKING_DIR):
|
|||||||
os.mkdir(WORKING_DIR)
|
os.mkdir(WORKING_DIR)
|
||||||
|
|
||||||
|
|
||||||
async def Qwen_embed(texts: list[str]) -> np.ndarray:
|
async def Qwen_embed(texts: List[str]) -> np.ndarray:
|
||||||
async with QwenEmbedding() as embedder:
|
async with QwenEmbedding() as embedder:
|
||||||
return await embedder.embed(texts)
|
return await embedder.embed(texts)
|
||||||
Qwen_embed.embedding_dim = 1024
|
Qwen_embed.embedding_dim = 1024
|
||||||
@ -303,7 +304,7 @@ class CustomRAGProcessor:
|
|||||||
await initialize_pipeline_status()
|
await initialize_pipeline_status()
|
||||||
return rag
|
return rag
|
||||||
|
|
||||||
async def qwen_embed_wrapper(self, texts: list[str]) -> np.ndarray:
|
async def qwen_embed_wrapper(self, texts: List[str]) -> np.ndarray:
|
||||||
"""Qwen嵌入函数包装器,带重试机制"""
|
"""Qwen嵌入函数包装器,带重试机制"""
|
||||||
async with QwenEmbedding(self.qwen_api_key) as embedder:
|
async with QwenEmbedding(self.qwen_api_key) as embedder:
|
||||||
return await embedder.embed(texts)
|
return await embedder.embed(texts)
|
||||||
|
BIN
CoreRAG/lightrag/__pycache__/__init__.cpython-38.pyc
Normal file
BIN
CoreRAG/lightrag/__pycache__/__init__.cpython-38.pyc
Normal file
Binary file not shown.
BIN
CoreRAG/lightrag/__pycache__/base.cpython-38.pyc
Normal file
BIN
CoreRAG/lightrag/__pycache__/base.cpython-38.pyc
Normal file
Binary file not shown.
BIN
CoreRAG/lightrag/__pycache__/constants.cpython-38.pyc
Normal file
BIN
CoreRAG/lightrag/__pycache__/constants.cpython-38.pyc
Normal file
Binary file not shown.
BIN
CoreRAG/lightrag/__pycache__/lightrag.cpython-38.pyc
Normal file
BIN
CoreRAG/lightrag/__pycache__/lightrag.cpython-38.pyc
Normal file
Binary file not shown.
BIN
CoreRAG/lightrag/__pycache__/namespace.cpython-38.pyc
Normal file
BIN
CoreRAG/lightrag/__pycache__/namespace.cpython-38.pyc
Normal file
Binary file not shown.
BIN
CoreRAG/lightrag/__pycache__/operate.cpython-38.pyc
Normal file
BIN
CoreRAG/lightrag/__pycache__/operate.cpython-38.pyc
Normal file
Binary file not shown.
BIN
CoreRAG/lightrag/__pycache__/prompt.cpython-38.pyc
Normal file
BIN
CoreRAG/lightrag/__pycache__/prompt.cpython-38.pyc
Normal file
Binary file not shown.
BIN
CoreRAG/lightrag/__pycache__/types.cpython-38.pyc
Normal file
BIN
CoreRAG/lightrag/__pycache__/types.cpython-38.pyc
Normal file
Binary file not shown.
BIN
CoreRAG/lightrag/__pycache__/utils.cpython-38.pyc
Normal file
BIN
CoreRAG/lightrag/__pycache__/utils.cpython-38.pyc
Normal file
Binary file not shown.
@ -1,5 +1,5 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
from typing import List, Union, Callable, Dict, Tuple
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
import os
|
import os
|
||||||
@ -70,13 +70,13 @@ class QueryParam:
|
|||||||
max_token_for_local_context: int = int(os.getenv("MAX_TOKEN_ENTITY_DESC", "4000"))
|
max_token_for_local_context: int = int(os.getenv("MAX_TOKEN_ENTITY_DESC", "4000"))
|
||||||
"""Maximum number of tokens allocated for entity descriptions in local retrieval."""
|
"""Maximum number of tokens allocated for entity descriptions in local retrieval."""
|
||||||
|
|
||||||
hl_keywords: list[str] = field(default_factory=list)
|
hl_keywords: List[str] = field(default_factory=list)
|
||||||
"""List of high-level keywords to prioritize in retrieval."""
|
"""List of high-level keywords to prioritize in retrieval."""
|
||||||
|
|
||||||
ll_keywords: list[str] = field(default_factory=list)
|
ll_keywords: List[str] = field(default_factory=list)
|
||||||
"""List of low-level keywords to refine retrieval focus."""
|
"""List of low-level keywords to refine retrieval focus."""
|
||||||
|
|
||||||
conversation_history: list[dict[str, str]] = field(default_factory=list)
|
conversation_history: List[Dict[str, str]] = field(default_factory=list)
|
||||||
"""Stores past conversation history to maintain context.
|
"""Stores past conversation history to maintain context.
|
||||||
Format: [{"role": "user/assistant", "content": "message"}].
|
Format: [{"role": "user/assistant", "content": "message"}].
|
||||||
"""
|
"""
|
||||||
@ -84,16 +84,16 @@ class QueryParam:
|
|||||||
history_turns: int = 3
|
history_turns: int = 3
|
||||||
"""Number of complete conversation turns (user-assistant pairs) to consider in the response context."""
|
"""Number of complete conversation turns (user-assistant pairs) to consider in the response context."""
|
||||||
|
|
||||||
ids: list[str] | None = None
|
ids: Union[List[str], None] = None
|
||||||
"""List of ids to filter the results."""
|
"""List of ids to filter the results."""
|
||||||
|
|
||||||
model_func: Callable[..., object] | None = None
|
model_func: Union[Callable[..., object], None] = None
|
||||||
"""Optional override for the LLM model function to use for this specific query.
|
"""Optional override for the LLM model function to use for this specific query.
|
||||||
If provided, this will be used instead of the global model function.
|
If provided, this will be used instead of the global model function.
|
||||||
This allows using different models for different query modes.
|
This allows using different models for different query modes.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
user_prompt: str | None = None
|
user_prompt: Union[str,None] = None
|
||||||
"""User-provided prompt for the query.
|
"""User-provided prompt for the query.
|
||||||
If proivded, this will be use instead of the default vaulue from prompt template.
|
If proivded, this will be use instead of the default vaulue from prompt template.
|
||||||
"""
|
"""
|
||||||
@ -102,7 +102,7 @@ class QueryParam:
|
|||||||
@dataclass
|
@dataclass
|
||||||
class StorageNameSpace(ABC):
|
class StorageNameSpace(ABC):
|
||||||
namespace: str
|
namespace: str
|
||||||
global_config: dict[str, Any]
|
global_config: Dict[str, Any]
|
||||||
|
|
||||||
async def initialize(self):
|
async def initialize(self):
|
||||||
"""Initialize the storage"""
|
"""Initialize the storage"""
|
||||||
@ -117,7 +117,7 @@ class StorageNameSpace(ABC):
|
|||||||
"""Commit the storage operations after indexing"""
|
"""Commit the storage operations after indexing"""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def drop(self) -> dict[str, str]:
|
async def drop(self) -> Dict[str, str]:
|
||||||
"""Drop all data from storage and clean up resources
|
"""Drop all data from storage and clean up resources
|
||||||
|
|
||||||
This abstract method defines the contract for dropping all data from a storage implementation.
|
This abstract method defines the contract for dropping all data from a storage implementation.
|
||||||
@ -151,12 +151,12 @@ class BaseVectorStorage(StorageNameSpace, ABC):
|
|||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def query(
|
async def query(
|
||||||
self, query: str, top_k: int, ids: list[str] | None = None
|
self, query: str, top_k: int, ids: Union[List[str],None] = None
|
||||||
) -> list[dict[str, Any]]:
|
) -> List[Dict[str, Any]]:
|
||||||
"""Query the vector storage and retrieve top_k results."""
|
"""Query the vector storage and retrieve top_k results."""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
|
async def upsert(self, data: Dict[str, Dict[str, Any]]) -> None:
|
||||||
"""Insert or update vectors in the storage.
|
"""Insert or update vectors in the storage.
|
||||||
|
|
||||||
Importance notes for in-memory storage:
|
Importance notes for in-memory storage:
|
||||||
@ -186,7 +186,7 @@ class BaseVectorStorage(StorageNameSpace, ABC):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def get_by_id(self, id: str) -> dict[str, Any] | None:
|
async def get_by_id(self, id: str) -> Union[Dict[str, Any],None]:
|
||||||
"""Get vector data by its ID
|
"""Get vector data by its ID
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -198,7 +198,7 @@ class BaseVectorStorage(StorageNameSpace, ABC):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
|
async def get_by_ids(self, ids: List[str]) -> List[Dict[str, Any]]:
|
||||||
"""Get multiple vector data by their IDs
|
"""Get multiple vector data by their IDs
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -210,7 +210,7 @@ class BaseVectorStorage(StorageNameSpace, ABC):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def delete(self, ids: list[str]):
|
async def delete(self, ids: List[str]):
|
||||||
"""Delete vectors with specified IDs
|
"""Delete vectors with specified IDs
|
||||||
|
|
||||||
Importance notes for in-memory storage:
|
Importance notes for in-memory storage:
|
||||||
@ -228,11 +228,11 @@ class BaseKVStorage(StorageNameSpace, ABC):
|
|||||||
embedding_func: EmbeddingFunc
|
embedding_func: EmbeddingFunc
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def get_by_id(self, id: str) -> dict[str, Any] | None:
|
async def get_by_id(self, id: str) -> Union[Dict[str, Any],None]:
|
||||||
"""Get value by id"""
|
"""Get value by id"""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
|
async def get_by_ids(self, ids: List[str]) -> List[Dict[str, Any]]:
|
||||||
"""Get values by ids"""
|
"""Get values by ids"""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
@ -249,7 +249,7 @@ class BaseKVStorage(StorageNameSpace, ABC):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def delete(self, ids: list[str]) -> None:
|
async def delete(self, ids: List[str]) -> None:
|
||||||
"""Delete specific records from storage by their IDs
|
"""Delete specific records from storage by their IDs
|
||||||
|
|
||||||
Importance notes for in-memory storage:
|
Importance notes for in-memory storage:
|
||||||
@ -263,7 +263,7 @@ class BaseKVStorage(StorageNameSpace, ABC):
|
|||||||
None
|
None
|
||||||
"""
|
"""
|
||||||
|
|
||||||
async def drop_cache_by_modes(self, modes: list[str] | None = None) -> bool:
|
async def drop_cache_by_modes(self, modes: Union[List[str],None] = None) -> bool:
|
||||||
"""Delete specific records from storage by cache mode
|
"""Delete specific records from storage by cache mode
|
||||||
|
|
||||||
Importance notes for in-memory storage:
|
Importance notes for in-memory storage:
|
||||||
@ -330,7 +330,7 @@ class BaseGraphStorage(StorageNameSpace, ABC):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def get_node(self, node_id: str) -> dict[str, str] | None:
|
async def get_node(self, node_id: str) -> Union[Dict[str, str],None]:
|
||||||
"""Get node by its ID, returning only node properties.
|
"""Get node by its ID, returning only node properties.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -343,7 +343,7 @@ class BaseGraphStorage(StorageNameSpace, ABC):
|
|||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def get_edge(
|
async def get_edge(
|
||||||
self, source_node_id: str, target_node_id: str
|
self, source_node_id: str, target_node_id: str
|
||||||
) -> dict[str, str] | None:
|
) -> Union[Dict[str, str],None]:
|
||||||
"""Get edge properties between two nodes.
|
"""Get edge properties between two nodes.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -355,7 +355,7 @@ class BaseGraphStorage(StorageNameSpace, ABC):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
|
async def get_node_edges(self, source_node_id: str) -> Union[List[Tuple[str, str]],None]:
|
||||||
"""Get all edges connected to a node.
|
"""Get all edges connected to a node.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -366,7 +366,7 @@ class BaseGraphStorage(StorageNameSpace, ABC):
|
|||||||
or None if the node doesn't exist
|
or None if the node doesn't exist
|
||||||
"""
|
"""
|
||||||
|
|
||||||
async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, dict]:
|
async def get_nodes_batch(self, node_ids: List[str]) -> Dict[str, Dict]:
|
||||||
"""Get nodes as a batch using UNWIND
|
"""Get nodes as a batch using UNWIND
|
||||||
|
|
||||||
Default implementation fetches nodes one by one.
|
Default implementation fetches nodes one by one.
|
||||||
@ -380,7 +380,7 @@ class BaseGraphStorage(StorageNameSpace, ABC):
|
|||||||
result[node_id] = node
|
result[node_id] = node
|
||||||
return result
|
return result
|
||||||
|
|
||||||
async def node_degrees_batch(self, node_ids: list[str]) -> dict[str, int]:
|
async def node_degrees_batch(self, node_ids: List[str]) -> Dict[str, int]:
|
||||||
"""Node degrees as a batch using UNWIND
|
"""Node degrees as a batch using UNWIND
|
||||||
|
|
||||||
Default implementation fetches node degrees one by one.
|
Default implementation fetches node degrees one by one.
|
||||||
@ -394,8 +394,8 @@ class BaseGraphStorage(StorageNameSpace, ABC):
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
async def edge_degrees_batch(
|
async def edge_degrees_batch(
|
||||||
self, edge_pairs: list[tuple[str, str]]
|
self, edge_pairs: List[Tuple[str, str]]
|
||||||
) -> dict[tuple[str, str], int]:
|
) -> Dict[Tuple[str, str], int]:
|
||||||
"""Edge degrees as a batch using UNWIND also uses node_degrees_batch
|
"""Edge degrees as a batch using UNWIND also uses node_degrees_batch
|
||||||
|
|
||||||
Default implementation calculates edge degrees one by one.
|
Default implementation calculates edge degrees one by one.
|
||||||
@ -409,8 +409,8 @@ class BaseGraphStorage(StorageNameSpace, ABC):
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
async def get_edges_batch(
|
async def get_edges_batch(
|
||||||
self, pairs: list[dict[str, str]]
|
self, pairs: List[Dict[str, str]]
|
||||||
) -> dict[tuple[str, str], dict]:
|
) -> Dict[Tuple[str, str], Dict]:
|
||||||
"""Get edges as a batch using UNWIND
|
"""Get edges as a batch using UNWIND
|
||||||
|
|
||||||
Default implementation fetches edges one by one.
|
Default implementation fetches edges one by one.
|
||||||
@ -427,8 +427,8 @@ class BaseGraphStorage(StorageNameSpace, ABC):
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
async def get_nodes_edges_batch(
|
async def get_nodes_edges_batch(
|
||||||
self, node_ids: list[str]
|
self, node_ids: List[str]
|
||||||
) -> dict[str, list[tuple[str, str]]]:
|
) -> Dict[str, List[Tuple[str, str]]]:
|
||||||
"""Get nodes edges as a batch using UNWIND
|
"""Get nodes edges as a batch using UNWIND
|
||||||
|
|
||||||
Default implementation fetches node edges one by one.
|
Default implementation fetches node edges one by one.
|
||||||
@ -442,7 +442,7 @@ class BaseGraphStorage(StorageNameSpace, ABC):
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
|
async def upsert_node(self, node_id: str, node_data: Dict[str, str]) -> None:
|
||||||
"""Insert a new node or update an existing node in the graph.
|
"""Insert a new node or update an existing node in the graph.
|
||||||
|
|
||||||
Importance notes for in-memory storage:
|
Importance notes for in-memory storage:
|
||||||
@ -457,7 +457,7 @@ class BaseGraphStorage(StorageNameSpace, ABC):
|
|||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def upsert_edge(
|
async def upsert_edge(
|
||||||
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
|
self, source_node_id: str, target_node_id: str, edge_data: Dict[str, str]
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Insert a new edge or update an existing edge in the graph.
|
"""Insert a new edge or update an existing edge in the graph.
|
||||||
|
|
||||||
@ -486,7 +486,7 @@ class BaseGraphStorage(StorageNameSpace, ABC):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def remove_nodes(self, nodes: list[str]):
|
async def remove_nodes(self, nodes: List[str]):
|
||||||
"""Delete multiple nodes
|
"""Delete multiple nodes
|
||||||
|
|
||||||
Importance notes:
|
Importance notes:
|
||||||
@ -499,7 +499,7 @@ class BaseGraphStorage(StorageNameSpace, ABC):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def remove_edges(self, edges: list[tuple[str, str]]):
|
async def remove_edges(self, edges: List[Tuple[str, str]]):
|
||||||
"""Delete multiple edges
|
"""Delete multiple edges
|
||||||
|
|
||||||
Importance notes:
|
Importance notes:
|
||||||
@ -512,7 +512,7 @@ class BaseGraphStorage(StorageNameSpace, ABC):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def get_all_labels(self) -> list[str]:
|
async def get_all_labels(self) -> List[str]:
|
||||||
"""Get all labels in the graph.
|
"""Get all labels in the graph.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@ -564,11 +564,11 @@ class DocProcessingStatus:
|
|||||||
"""ISO format timestamp when document was created"""
|
"""ISO format timestamp when document was created"""
|
||||||
updated_at: str
|
updated_at: str
|
||||||
"""ISO format timestamp when document was last updated"""
|
"""ISO format timestamp when document was last updated"""
|
||||||
chunks_count: int | None = None
|
chunks_count: Union[int,None] = None
|
||||||
"""Number of chunks after splitting, used for processing"""
|
"""Number of chunks after splitting, used for processing"""
|
||||||
error: str | None = None
|
error: Union[str,None] = None
|
||||||
"""Error message if failed"""
|
"""Error message if failed"""
|
||||||
metadata: dict[str, Any] = field(default_factory=dict)
|
metadata: Dict[str, Any] = field(default_factory=Dict)
|
||||||
"""Additional metadata"""
|
"""Additional metadata"""
|
||||||
|
|
||||||
|
|
||||||
@ -577,16 +577,16 @@ class DocStatusStorage(BaseKVStorage, ABC):
|
|||||||
"""Base class for document status storage"""
|
"""Base class for document status storage"""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def get_status_counts(self) -> dict[str, int]:
|
async def get_status_counts(self) -> Dict[str, int]:
|
||||||
"""Get counts of documents in each status"""
|
"""Get counts of documents in each status"""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def get_docs_by_status(
|
async def get_docs_by_status(
|
||||||
self, status: DocStatus
|
self, status: DocStatus
|
||||||
) -> dict[str, DocProcessingStatus]:
|
) -> Dict[str, DocProcessingStatus]:
|
||||||
"""Get all documents with a specific status"""
|
"""Get all documents with a specific status"""
|
||||||
|
|
||||||
async def drop_cache_by_modes(self, modes: list[str] | None = None) -> bool:
|
async def drop_cache_by_modes(self, modes: Union[List[str],None] = None) -> bool:
|
||||||
"""Drop cache is not supported for Doc Status storage"""
|
"""Drop cache is not supported for Doc Status storage"""
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
@ -1,3 +1,5 @@
|
|||||||
|
from typing import List
|
||||||
|
from typing import Dict
|
||||||
STORAGE_IMPLEMENTATIONS = {
|
STORAGE_IMPLEMENTATIONS = {
|
||||||
"KV_STORAGE": {
|
"KV_STORAGE": {
|
||||||
"implementations": [
|
"implementations": [
|
||||||
@ -45,7 +47,7 @@ STORAGE_IMPLEMENTATIONS = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
# Storage implementation environment variable without default value
|
# Storage implementation environment variable without default value
|
||||||
STORAGE_ENV_REQUIREMENTS: dict[str, list[str]] = {
|
STORAGE_ENV_REQUIREMENTS: Dict[str, List[str]] = {
|
||||||
# KV Storage Implementations
|
# KV Storage Implementations
|
||||||
"JsonKVStorage": [],
|
"JsonKVStorage": [],
|
||||||
"MongoKVStorage": [],
|
"MongoKVStorage": [],
|
||||||
|
BIN
CoreRAG/lightrag/kg/__pycache__/__init__.cpython-38.pyc
Normal file
BIN
CoreRAG/lightrag/kg/__pycache__/__init__.cpython-38.pyc
Normal file
Binary file not shown.
BIN
CoreRAG/lightrag/kg/__pycache__/json_kv_impl.cpython-38.pyc
Normal file
BIN
CoreRAG/lightrag/kg/__pycache__/json_kv_impl.cpython-38.pyc
Normal file
Binary file not shown.
Binary file not shown.
BIN
CoreRAG/lightrag/kg/__pycache__/networkx_impl.cpython-38.pyc
Normal file
BIN
CoreRAG/lightrag/kg/__pycache__/networkx_impl.cpython-38.pyc
Normal file
Binary file not shown.
BIN
CoreRAG/lightrag/kg/__pycache__/shared_storage.cpython-38.pyc
Normal file
BIN
CoreRAG/lightrag/kg/__pycache__/shared_storage.cpython-38.pyc
Normal file
Binary file not shown.
@ -1,11 +1,11 @@
|
|||||||
import os
|
import os
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, final
|
from typing import Any, final,Dict,List,Union,Set
|
||||||
|
|
||||||
from lightrag.base import (
|
from ..base import (
|
||||||
BaseKVStorage,
|
BaseKVStorage,
|
||||||
)
|
)
|
||||||
from lightrag.utils import (
|
from ..utils import (
|
||||||
load_json,
|
load_json,
|
||||||
logger,
|
logger,
|
||||||
write_json,
|
write_json,
|
||||||
@ -85,7 +85,7 @@ class JsonKVStorage(BaseKVStorage):
|
|||||||
write_json(data_dict, self._file_name)
|
write_json(data_dict, self._file_name)
|
||||||
await clear_all_update_flags(self.namespace)
|
await clear_all_update_flags(self.namespace)
|
||||||
|
|
||||||
async def get_all(self) -> dict[str, Any]:
|
async def get_all(self) -> Dict[str, Any]:
|
||||||
"""Get all data from storage
|
"""Get all data from storage
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@ -94,11 +94,11 @@ class JsonKVStorage(BaseKVStorage):
|
|||||||
async with self._storage_lock:
|
async with self._storage_lock:
|
||||||
return dict(self._data)
|
return dict(self._data)
|
||||||
|
|
||||||
async def get_by_id(self, id: str) -> dict[str, Any] | None:
|
async def get_by_id(self, id: str) -> Union[Dict[str, Any] , None]:
|
||||||
async with self._storage_lock:
|
async with self._storage_lock:
|
||||||
return self._data.get(id)
|
return self._data.get(id)
|
||||||
|
|
||||||
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
|
async def get_by_ids(self, ids: List[str]) -> List[Dict[str, Any]]:
|
||||||
async with self._storage_lock:
|
async with self._storage_lock:
|
||||||
return [
|
return [
|
||||||
(
|
(
|
||||||
@ -109,11 +109,11 @@ class JsonKVStorage(BaseKVStorage):
|
|||||||
for id in ids
|
for id in ids
|
||||||
]
|
]
|
||||||
|
|
||||||
async def filter_keys(self, keys: set[str]) -> set[str]:
|
async def filter_keys(self, keys: Set[str]) -> Set[str]:
|
||||||
async with self._storage_lock:
|
async with self._storage_lock:
|
||||||
return set(keys) - set(self._data.keys())
|
return Set(keys) - Set(self._data.keys())
|
||||||
|
|
||||||
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
|
async def upsert(self, data: Dict[str, Dict[str, Any]]) -> None:
|
||||||
"""
|
"""
|
||||||
Importance notes for in-memory storage:
|
Importance notes for in-memory storage:
|
||||||
1. Changes will be persisted to disk during the next index_done_callback
|
1. Changes will be persisted to disk during the next index_done_callback
|
||||||
@ -126,7 +126,7 @@ class JsonKVStorage(BaseKVStorage):
|
|||||||
self._data.update(data)
|
self._data.update(data)
|
||||||
await set_all_update_flags(self.namespace)
|
await set_all_update_flags(self.namespace)
|
||||||
|
|
||||||
async def delete(self, ids: list[str]) -> None:
|
async def delete(self, ids: List[str]) -> None:
|
||||||
"""Delete specific records from storage by their IDs
|
"""Delete specific records from storage by their IDs
|
||||||
|
|
||||||
Importance notes for in-memory storage:
|
Importance notes for in-memory storage:
|
||||||
@ -149,7 +149,7 @@ class JsonKVStorage(BaseKVStorage):
|
|||||||
if any_deleted:
|
if any_deleted:
|
||||||
await set_all_update_flags(self.namespace)
|
await set_all_update_flags(self.namespace)
|
||||||
|
|
||||||
async def drop_cache_by_modes(self, modes: list[str] | None = None) -> bool:
|
async def drop_cache_by_modes(self, modes: Union[List[str] , None] = None) -> bool:
|
||||||
"""Delete specific records from storage by by cache mode
|
"""Delete specific records from storage by by cache mode
|
||||||
|
|
||||||
Importance notes for in-memory storage:
|
Importance notes for in-memory storage:
|
||||||
@ -172,7 +172,7 @@ class JsonKVStorage(BaseKVStorage):
|
|||||||
except Exception:
|
except Exception:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
async def drop(self) -> dict[str, str]:
|
async def drop(self) -> Dict[str, str]:
|
||||||
"""Drop all data from storage and clean up resources
|
"""Drop all data from storage and clean up resources
|
||||||
This action will persistent the data to disk immediately.
|
This action will persistent the data to disk immediately.
|
||||||
|
|
||||||
|
@ -5,15 +5,15 @@ from dataclasses import dataclass
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import time
|
import time
|
||||||
|
|
||||||
from lightrag.utils import (
|
from ..utils import (
|
||||||
logger,
|
logger,
|
||||||
compute_mdhash_id,
|
compute_mdhash_id,
|
||||||
)
|
)
|
||||||
import pipmaster as pm
|
import pipmaster as pm
|
||||||
from lightrag.base import BaseVectorStorage
|
from ..base import BaseVectorStorage
|
||||||
|
|
||||||
if not pm.is_installed("nano-vectordb"):
|
# if not pm.is_installed("nano-vectordb"):
|
||||||
pm.install("nano-vectordb")
|
# pm.install("nano-vectordb")
|
||||||
|
|
||||||
from nano_vectordb import NanoVectorDB
|
from nano_vectordb import NanoVectorDB
|
||||||
from .shared_storage import (
|
from .shared_storage import (
|
||||||
|
@ -1,10 +1,10 @@
|
|||||||
import os
|
import os
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import final
|
from typing import final,Dict,List,Union,Tuple
|
||||||
|
|
||||||
from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
|
from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
|
||||||
from lightrag.utils import logger
|
from ..utils import logger
|
||||||
from lightrag.base import BaseGraphStorage
|
from ..base import BaseGraphStorage
|
||||||
|
|
||||||
import pipmaster as pm
|
import pipmaster as pm
|
||||||
|
|
||||||
@ -98,7 +98,7 @@ class NetworkXStorage(BaseGraphStorage):
|
|||||||
graph = await self._get_graph()
|
graph = await self._get_graph()
|
||||||
return graph.has_edge(source_node_id, target_node_id)
|
return graph.has_edge(source_node_id, target_node_id)
|
||||||
|
|
||||||
async def get_node(self, node_id: str) -> dict[str, str] | None:
|
async def get_node(self, node_id: str) -> Union[Dict[str, str],None]:
|
||||||
graph = await self._get_graph()
|
graph = await self._get_graph()
|
||||||
return graph.nodes.get(node_id)
|
return graph.nodes.get(node_id)
|
||||||
|
|
||||||
@ -112,17 +112,17 @@ class NetworkXStorage(BaseGraphStorage):
|
|||||||
|
|
||||||
async def get_edge(
|
async def get_edge(
|
||||||
self, source_node_id: str, target_node_id: str
|
self, source_node_id: str, target_node_id: str
|
||||||
) -> dict[str, str] | None:
|
) -> Union[Dict[str, str],None]:
|
||||||
graph = await self._get_graph()
|
graph = await self._get_graph()
|
||||||
return graph.edges.get((source_node_id, target_node_id))
|
return graph.edges.get((source_node_id, target_node_id))
|
||||||
|
|
||||||
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
|
async def get_node_edges(self, source_node_id: str) -> Union[List[Tuple[str, str]],None]:
|
||||||
graph = await self._get_graph()
|
graph = await self._get_graph()
|
||||||
if graph.has_node(source_node_id):
|
if graph.has_node(source_node_id):
|
||||||
return list(graph.edges(source_node_id))
|
return list(graph.edges(source_node_id))
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
|
async def upsert_node(self, node_id: str, node_data: Dict[str, str]) -> None:
|
||||||
"""
|
"""
|
||||||
Importance notes:
|
Importance notes:
|
||||||
1. Changes will be persisted to disk during the next index_done_callback
|
1. Changes will be persisted to disk during the next index_done_callback
|
||||||
@ -133,7 +133,7 @@ class NetworkXStorage(BaseGraphStorage):
|
|||||||
graph.add_node(node_id, **node_data)
|
graph.add_node(node_id, **node_data)
|
||||||
|
|
||||||
async def upsert_edge(
|
async def upsert_edge(
|
||||||
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
|
self, source_node_id: str, target_node_id: str, edge_data: Dict[str, str]
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Importance notes:
|
Importance notes:
|
||||||
@ -158,7 +158,7 @@ class NetworkXStorage(BaseGraphStorage):
|
|||||||
else:
|
else:
|
||||||
logger.warning(f"Node {node_id} not found in the graph for deletion.")
|
logger.warning(f"Node {node_id} not found in the graph for deletion.")
|
||||||
|
|
||||||
async def remove_nodes(self, nodes: list[str]):
|
async def remove_nodes(self, nodes: List[str]):
|
||||||
"""Delete multiple nodes
|
"""Delete multiple nodes
|
||||||
|
|
||||||
Importance notes:
|
Importance notes:
|
||||||
@ -174,7 +174,7 @@ class NetworkXStorage(BaseGraphStorage):
|
|||||||
if graph.has_node(node):
|
if graph.has_node(node):
|
||||||
graph.remove_node(node)
|
graph.remove_node(node)
|
||||||
|
|
||||||
async def remove_edges(self, edges: list[tuple[str, str]]):
|
async def remove_edges(self, edges: List[Tuple[str, str]]):
|
||||||
"""Delete multiple edges
|
"""Delete multiple edges
|
||||||
|
|
||||||
Importance notes:
|
Importance notes:
|
||||||
@ -190,7 +190,7 @@ class NetworkXStorage(BaseGraphStorage):
|
|||||||
if graph.has_edge(source, target):
|
if graph.has_edge(source, target):
|
||||||
graph.remove_edge(source, target)
|
graph.remove_edge(source, target)
|
||||||
|
|
||||||
async def get_all_labels(self) -> list[str]:
|
async def get_all_labels(self) -> List[str]:
|
||||||
"""
|
"""
|
||||||
Get all node labels in the graph
|
Get all node labels in the graph
|
||||||
Returns:
|
Returns:
|
||||||
@ -389,7 +389,7 @@ class NetworkXStorage(BaseGraphStorage):
|
|||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
async def drop(self) -> dict[str, str]:
|
async def drop(self) -> Dict[str, str]:
|
||||||
"""Drop all graph data from storage and clean up resources
|
"""Drop all graph data from storage and clean up resources
|
||||||
|
|
||||||
This method will:
|
This method will:
|
||||||
|
@ -1,5 +1,4 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import traceback
|
import traceback
|
||||||
import asyncio
|
import asyncio
|
||||||
import configparser
|
import configparser
|
||||||
@ -21,18 +20,18 @@ from typing import (
|
|||||||
List,
|
List,
|
||||||
Dict,
|
Dict,
|
||||||
)
|
)
|
||||||
from lightrag.constants import (
|
from .constants import (
|
||||||
DEFAULT_MAX_TOKEN_SUMMARY,
|
DEFAULT_MAX_TOKEN_SUMMARY,
|
||||||
DEFAULT_FORCE_LLM_SUMMARY_ON_MERGE,
|
DEFAULT_FORCE_LLM_SUMMARY_ON_MERGE,
|
||||||
)
|
)
|
||||||
from lightrag.utils import get_env_value
|
from .utils import get_env_value
|
||||||
|
|
||||||
from lightrag.kg import (
|
from .kg import (
|
||||||
STORAGES,
|
STORAGES,
|
||||||
verify_storage_implementation,
|
verify_storage_implementation,
|
||||||
)
|
)
|
||||||
|
|
||||||
from lightrag.kg.shared_storage import (
|
from .kg.shared_storage import (
|
||||||
get_namespace_data,
|
get_namespace_data,
|
||||||
get_pipeline_status_lock,
|
get_pipeline_status_lock,
|
||||||
)
|
)
|
||||||
@ -199,7 +198,7 @@ class LightRAG:
|
|||||||
)
|
)
|
||||||
"""Maximum number of concurrent embedding function calls."""
|
"""Maximum number of concurrent embedding function calls."""
|
||||||
|
|
||||||
embedding_cache_config: dict[str, Any] = field(
|
embedding_cache_config: Dict[str, Any] = field(
|
||||||
default_factory=lambda: {
|
default_factory=lambda: {
|
||||||
"enabled": False,
|
"enabled": False,
|
||||||
"similarity_threshold": 0.95,
|
"similarity_threshold": 0.95,
|
||||||
@ -283,7 +282,7 @@ class LightRAG:
|
|||||||
_storages_status: StoragesStatus = field(default=StoragesStatus.NOT_CREATED)
|
_storages_status: StoragesStatus = field(default=StoragesStatus.NOT_CREATED)
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
from lightrag.kg.shared_storage import (
|
from .kg.shared_storage import (
|
||||||
initialize_share_data,
|
initialize_share_data,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -1,18 +1,18 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional,List,Dict
|
||||||
|
|
||||||
|
|
||||||
class GPTKeywordExtractionFormat(BaseModel):
|
class GPTKeywordExtractionFormat(BaseModel):
|
||||||
high_level_keywords: list[str]
|
high_level_keywords: List[str]
|
||||||
low_level_keywords: list[str]
|
low_level_keywords: List[str]
|
||||||
|
|
||||||
|
|
||||||
class KnowledgeGraphNode(BaseModel):
|
class KnowledgeGraphNode(BaseModel):
|
||||||
id: str
|
id: str
|
||||||
labels: list[str]
|
labels: List[str]
|
||||||
properties: dict[str, Any] # anything else goes here
|
properties: Dict[str, Any] # anything else goes here
|
||||||
|
|
||||||
|
|
||||||
class KnowledgeGraphEdge(BaseModel):
|
class KnowledgeGraphEdge(BaseModel):
|
||||||
@ -20,10 +20,10 @@ class KnowledgeGraphEdge(BaseModel):
|
|||||||
type: Optional[str]
|
type: Optional[str]
|
||||||
source: str # id of source node
|
source: str # id of source node
|
||||||
target: str # id of target node
|
target: str # id of target node
|
||||||
properties: dict[str, Any] # anything else goes here
|
properties: Dict[str, Any] # anything else goes here
|
||||||
|
|
||||||
|
|
||||||
class KnowledgeGraph(BaseModel):
|
class KnowledgeGraph(BaseModel):
|
||||||
nodes: list[KnowledgeGraphNode] = []
|
nodes: List[KnowledgeGraphNode] = []
|
||||||
edges: list[KnowledgeGraphEdge] = []
|
edges: List[KnowledgeGraphEdge] = []
|
||||||
is_truncated: bool = False
|
is_truncated: bool = False
|
||||||
|
@ -14,9 +14,9 @@ from functools import wraps
|
|||||||
from hashlib import md5
|
from hashlib import md5
|
||||||
from typing import Any, Protocol, Callable, TYPE_CHECKING, List
|
from typing import Any, Protocol, Callable, TYPE_CHECKING, List
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from lightrag.prompt import PROMPTS
|
from .prompt import PROMPTS
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
from lightrag.constants import (
|
from .constants import (
|
||||||
DEFAULT_LOG_MAX_BYTES,
|
DEFAULT_LOG_MAX_BYTES,
|
||||||
DEFAULT_LOG_BACKUP_COUNT,
|
DEFAULT_LOG_BACKUP_COUNT,
|
||||||
DEFAULT_LOG_FILENAME,
|
DEFAULT_LOG_FILENAME,
|
||||||
@ -1696,7 +1696,7 @@ def check_storage_env_vars(storage_name: str) -> None:
|
|||||||
Raises:
|
Raises:
|
||||||
ValueError: If required environment variables are missing
|
ValueError: If required environment variables are missing
|
||||||
"""
|
"""
|
||||||
from lightrag.kg import STORAGE_ENV_REQUIREMENTS
|
from .kg import STORAGE_ENV_REQUIREMENTS
|
||||||
|
|
||||||
required_vars = STORAGE_ENV_REQUIREMENTS.get(storage_name, [])
|
required_vars = STORAGE_ENV_REQUIREMENTS.get(storage_name, [])
|
||||||
missing_vars = [var for var in required_vars if var not in os.environ]
|
missing_vars = [var for var in required_vars if var not in os.environ]
|
||||||
|
@ -20,7 +20,7 @@ class autoGenerator():
|
|||||||
|
|
||||||
async def query(self):
|
async def query(self):
|
||||||
self.rag_client = MassageAcupointRAG(
|
self.rag_client = MassageAcupointRAG(
|
||||||
working_dir="C:/Users/ZIWEI/Documents/work/向量化/CoreRAG/Massage_10216"
|
working_dir="CoreRAG/Massage_10216"
|
||||||
)
|
)
|
||||||
|
|
||||||
await self.rag_client.initialize()
|
await self.rag_client.initialize()
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
|
from typing import List
|
||||||
# ======================= 参数配置 ======================= #
|
# ======================= 参数配置 ======================= #
|
||||||
GRID_SIZE = (40, 130) # 行x列
|
GRID_SIZE = (40, 130) # 行x列
|
||||||
TOTAL_STEPS = 50
|
TOTAL_STEPS = 50
|
||||||
@ -166,14 +166,14 @@ class Agent:
|
|||||||
|
|
||||||
# ======================= 可视化模块 ======================= #
|
# ======================= 可视化模块 ======================= #
|
||||||
class Visualizer:
|
class Visualizer:
|
||||||
def __init__(self, grid_size, full_path_ref: list[tuple]):
|
def __init__(self, grid_size, full_path_ref: List[tuple]):
|
||||||
self.grid_size = grid_size
|
self.grid_size = grid_size
|
||||||
self.full_path_ref = full_path_ref
|
self.full_path_ref = full_path_ref
|
||||||
plt.ion()
|
plt.ion()
|
||||||
self.fig, self.ax = plt.subplots(figsize=(4, 10)) # 调整大小
|
self.fig, self.ax = plt.subplots(figsize=(4, 10)) # 调整大小
|
||||||
self.im = self.ax.imshow(np.zeros(grid_size).T, origin='lower', cmap='viridis')
|
self.im = self.ax.imshow(np.zeros(grid_size).T, origin='lower', cmap='viridis')
|
||||||
self.colorbar = plt.colorbar(self.im, ax=self.ax, label='Height')
|
self.colorbar = plt.colorbar(self.im, ax=self.ax, label='Height')
|
||||||
self._artists: list = []
|
self._artists: List = []
|
||||||
|
|
||||||
# 保持横纵比一致,避免变形
|
# 保持横纵比一致,避免变形
|
||||||
self.ax.set_aspect('equal')
|
self.ax.set_aspect('equal')
|
||||||
|
BIN
scripts/__pycache__/APF_global_demo.cpython-38.pyc
Normal file
BIN
scripts/__pycache__/APF_global_demo.cpython-38.pyc
Normal file
Binary file not shown.
BIN
scripts/__pycache__/planner.cpython-38.pyc
Normal file
BIN
scripts/__pycache__/planner.cpython-38.pyc
Normal file
Binary file not shown.
BIN
scripts/__pycache__/sorter.cpython-38.pyc
Normal file
BIN
scripts/__pycache__/sorter.cpython-38.pyc
Normal file
Binary file not shown.
@ -1,8 +1,11 @@
|
|||||||
try:from scripts.APF_global_demo import GaussianSchedule,GaussianPathSchedule,TimedGaussianSchedule,FieldScheduler,Agent,Visualizer
|
try:from scripts.APF_global_demo import GaussianSchedule,GaussianPathSchedule,TimedGaussianSchedule,FieldScheduler,Agent,Visualizer
|
||||||
except:from APF_global_demo import GaussianSchedule,GaussianPathSchedule,TimedGaussianSchedule,FieldScheduler,Agent,Visualizer
|
except:from APF_global_demo import GaussianSchedule,GaussianPathSchedule,TimedGaussianSchedule,FieldScheduler,Agent,Visualizer
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
try:from scripts.sorter import sorter
|
try:from scripts.sorter import sorter
|
||||||
except:from sorter import sorter
|
except:from sorter import sorter
|
||||||
|
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import json
|
import json
|
||||||
import time
|
import time
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
import re
|
import re
|
||||||
import json
|
import json
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from typing import List
|
||||||
|
|
||||||
class sorter:
|
class sorter:
|
||||||
''' 重点穴位按摩排列器 '''
|
''' 重点穴位按摩排列器 '''
|
||||||
@ -18,7 +19,7 @@ class sorter:
|
|||||||
self.massage_side = massage_side # 默认为双边'both'
|
self.massage_side = massage_side # 默认为双边'both'
|
||||||
else:
|
else:
|
||||||
raise ValueError("未指定按摩在左侧、右侧或两侧")
|
raise ValueError("未指定按摩在左侧、右侧或两侧")
|
||||||
def _extract_acupoints(self,respnse_from_llm:str)->list[str]:
|
def _extract_acupoints(self,respnse_from_llm:str)->List[str]:
|
||||||
pattern = r"[0-9]+\.\s*([\u4e00-\u9fa5]{2,5}穴)"
|
pattern = r"[0-9]+\.\s*([\u4e00-\u9fa5]{2,5}穴)"
|
||||||
matches = re.findall(pattern, respnse_from_llm)
|
matches = re.findall(pattern, respnse_from_llm)
|
||||||
# 去重 & 排除空值
|
# 去重 & 排除空值
|
||||||
@ -26,7 +27,7 @@ class sorter:
|
|||||||
print(unique_names)
|
print(unique_names)
|
||||||
return unique_names
|
return unique_names
|
||||||
|
|
||||||
def sort_acupoints(self,respnse_from_llm:list[str])->list[str]:
|
def sort_acupoints(self,respnse_from_llm:List[str])->List[str]:
|
||||||
if self.body_part == 'back':
|
if self.body_part == 'back':
|
||||||
allowed_names = []
|
allowed_names = []
|
||||||
if self.body_part == 'shoulder':
|
if self.body_part == 'shoulder':
|
||||||
@ -38,7 +39,7 @@ class sorter:
|
|||||||
"志室右","肓门右","胃仓右","意舍右","阳纲右","胞肓右","气海右俞","大肠右俞","小肠右俞",
|
"志室右","肓门右","胃仓右","意舍右","阳纲右","胞肓右","气海右俞","大肠右俞","小肠右俞",
|
||||||
"中膂右俞","肾俞右","关元右俞","膀胱右俞","白环右俞","秩边右","京门右"]
|
"中膂右俞","肾俞右","关元右俞","膀胱右俞","白环右俞","秩边右","京门右"]
|
||||||
|
|
||||||
def __filter_acupoints(acupoints:list[str],allowed_names: list[str])->list[str]:
|
def __filter_acupoints(acupoints:List[str],allowed_names: List[str])->List[str]:
|
||||||
acupoints_cleaned = []
|
acupoints_cleaned = []
|
||||||
matched_keys = []
|
matched_keys = []
|
||||||
for acupoint in acupoints:
|
for acupoint in acupoints:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user