环境问题修复中

This commit is contained in:
Ziwei.He 2025-06-29 23:59:06 -07:00
parent c8948e0a44
commit 00bbd902b8
33 changed files with 116 additions and 110 deletions

View File

@ -1,10 +1,10 @@
import asyncio
import numpy as np
import re
from lightrag import LightRAG
from lightrag.base import QueryParam
from custom_rag_processor import QwenEmbedding, DeepSeekCompletion
from CoreRAG.lightrag.lightrag import LightRAG
from CoreRAG.lightrag.base import QueryParam
from CoreRAG.custom_rag_processor import QwenEmbedding, DeepSeekCompletion
from typing import List,Tuple
class MassageAcupointRAG:
def __init__(self, working_dir: str):
self.working_dir = working_dir
@ -16,7 +16,7 @@ class MassageAcupointRAG:
)
@staticmethod
async def async_embed(texts: list[str]) -> np.ndarray:
async def async_embed(texts: List[str]) -> np.ndarray:
async with QwenEmbedding() as embedder:
return await embedder.embed(texts)
@ -26,7 +26,7 @@ class MassageAcupointRAG:
return await completer.complete(prompt, **kwargs)
@staticmethod
def extract_acupoint_list(text: str) -> list[str]:
def extract_acupoint_list(text: str) -> List[str]:
"""
LLM 返回文本中提取穴位名称列表形如 ['肩井穴', ...]
"""
@ -39,7 +39,7 @@ class MassageAcupointRAG:
async def initialize(self):
await self.rag.initialize_storages()
async def query_acupoints(self, user_query: str) -> tuple[str, list[str]]:
async def query_acupoints(self, user_query: str) -> Tuple[str, List[str]]:
"""
执行 RAG 查询并提取穴位列表

Binary file not shown.

View File

@ -4,10 +4,11 @@ import logging
import logging.config
import aiohttp
import numpy as np
from lightrag import LightRAG, QueryParam
from lightrag.kg.shared_storage import initialize_pipeline_status
from lightrag.utils import logger, set_verbose_debug
from CoreRAG.lightrag import LightRAG, QueryParam
from CoreRAG.lightrag.kg.shared_storage import initialize_pipeline_status
from CoreRAG.lightrag.utils import logger, set_verbose_debug
import tiktoken
from typing import List
WORKING_DIR = "./dickens"
# API配置
@ -33,7 +34,7 @@ class QwenEmbedding:
if self.session:
await self.session.close()
async def embed(self, texts: list[str]) -> np.ndarray:
async def embed(self, texts: List[str]) -> np.ndarray:
"""自动分批请求嵌入"""
if not self.session:
raise RuntimeError("Session not initialized. Use async with.")
@ -114,7 +115,7 @@ class TextChunker:
self.max_tokens = max_tokens
self.overlap = overlap
def chunk_text(self, text: str) -> list[str]:
def chunk_text(self, text: str) -> List[str]:
"""将长文本按token分块"""
tokens = self.tokenizer.encode(text)
chunks = []
@ -193,7 +194,7 @@ if not os.path.exists(WORKING_DIR):
os.mkdir(WORKING_DIR)
async def Qwen_embed(texts: list[str]) -> np.ndarray:
async def Qwen_embed(texts: List[str]) -> np.ndarray:
async with QwenEmbedding() as embedder:
return await embedder.embed(texts)
Qwen_embed.embedding_dim = 1024
@ -303,7 +304,7 @@ class CustomRAGProcessor:
await initialize_pipeline_status()
return rag
async def qwen_embed_wrapper(self, texts: list[str]) -> np.ndarray:
async def qwen_embed_wrapper(self, texts: List[str]) -> np.ndarray:
"""Qwen嵌入函数包装器带重试机制"""
async with QwenEmbedding(self.qwen_api_key) as embedder:
return await embedder.embed(texts)

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@ -1,5 +1,5 @@
from __future__ import annotations
from typing import List, Union, Callable, Dict, Tuple
from abc import ABC, abstractmethod
from enum import Enum
import os
@ -70,13 +70,13 @@ class QueryParam:
max_token_for_local_context: int = int(os.getenv("MAX_TOKEN_ENTITY_DESC", "4000"))
"""Maximum number of tokens allocated for entity descriptions in local retrieval."""
hl_keywords: list[str] = field(default_factory=list)
hl_keywords: List[str] = field(default_factory=list)
"""List of high-level keywords to prioritize in retrieval."""
ll_keywords: list[str] = field(default_factory=list)
ll_keywords: List[str] = field(default_factory=list)
"""List of low-level keywords to refine retrieval focus."""
conversation_history: list[dict[str, str]] = field(default_factory=list)
conversation_history: List[Dict[str, str]] = field(default_factory=list)
"""Stores past conversation history to maintain context.
Format: [{"role": "user/assistant", "content": "message"}].
"""
@ -84,16 +84,16 @@ class QueryParam:
history_turns: int = 3
"""Number of complete conversation turns (user-assistant pairs) to consider in the response context."""
ids: list[str] | None = None
ids: Union[List[str], None] = None
"""List of ids to filter the results."""
model_func: Callable[..., object] | None = None
model_func: Union[Callable[..., object], None] = None
"""Optional override for the LLM model function to use for this specific query.
If provided, this will be used instead of the global model function.
This allows using different models for different query modes.
"""
user_prompt: str | None = None
user_prompt: Union[str,None] = None
"""User-provided prompt for the query.
If proivded, this will be use instead of the default vaulue from prompt template.
"""
@ -102,7 +102,7 @@ class QueryParam:
@dataclass
class StorageNameSpace(ABC):
namespace: str
global_config: dict[str, Any]
global_config: Dict[str, Any]
async def initialize(self):
"""Initialize the storage"""
@ -117,7 +117,7 @@ class StorageNameSpace(ABC):
"""Commit the storage operations after indexing"""
@abstractmethod
async def drop(self) -> dict[str, str]:
async def drop(self) -> Dict[str, str]:
"""Drop all data from storage and clean up resources
This abstract method defines the contract for dropping all data from a storage implementation.
@ -151,12 +151,12 @@ class BaseVectorStorage(StorageNameSpace, ABC):
@abstractmethod
async def query(
self, query: str, top_k: int, ids: list[str] | None = None
) -> list[dict[str, Any]]:
self, query: str, top_k: int, ids: Union[List[str],None] = None
) -> List[Dict[str, Any]]:
"""Query the vector storage and retrieve top_k results."""
@abstractmethod
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
async def upsert(self, data: Dict[str, Dict[str, Any]]) -> None:
"""Insert or update vectors in the storage.
Importance notes for in-memory storage:
@ -186,7 +186,7 @@ class BaseVectorStorage(StorageNameSpace, ABC):
"""
@abstractmethod
async def get_by_id(self, id: str) -> dict[str, Any] | None:
async def get_by_id(self, id: str) -> Union[Dict[str, Any],None]:
"""Get vector data by its ID
Args:
@ -198,7 +198,7 @@ class BaseVectorStorage(StorageNameSpace, ABC):
pass
@abstractmethod
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
async def get_by_ids(self, ids: List[str]) -> List[Dict[str, Any]]:
"""Get multiple vector data by their IDs
Args:
@ -210,7 +210,7 @@ class BaseVectorStorage(StorageNameSpace, ABC):
pass
@abstractmethod
async def delete(self, ids: list[str]):
async def delete(self, ids: List[str]):
"""Delete vectors with specified IDs
Importance notes for in-memory storage:
@ -228,11 +228,11 @@ class BaseKVStorage(StorageNameSpace, ABC):
embedding_func: EmbeddingFunc
@abstractmethod
async def get_by_id(self, id: str) -> dict[str, Any] | None:
async def get_by_id(self, id: str) -> Union[Dict[str, Any],None]:
"""Get value by id"""
@abstractmethod
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
async def get_by_ids(self, ids: List[str]) -> List[Dict[str, Any]]:
"""Get values by ids"""
@abstractmethod
@ -249,7 +249,7 @@ class BaseKVStorage(StorageNameSpace, ABC):
"""
@abstractmethod
async def delete(self, ids: list[str]) -> None:
async def delete(self, ids: List[str]) -> None:
"""Delete specific records from storage by their IDs
Importance notes for in-memory storage:
@ -263,7 +263,7 @@ class BaseKVStorage(StorageNameSpace, ABC):
None
"""
async def drop_cache_by_modes(self, modes: list[str] | None = None) -> bool:
async def drop_cache_by_modes(self, modes: Union[List[str],None] = None) -> bool:
"""Delete specific records from storage by cache mode
Importance notes for in-memory storage:
@ -330,7 +330,7 @@ class BaseGraphStorage(StorageNameSpace, ABC):
"""
@abstractmethod
async def get_node(self, node_id: str) -> dict[str, str] | None:
async def get_node(self, node_id: str) -> Union[Dict[str, str],None]:
"""Get node by its ID, returning only node properties.
Args:
@ -343,7 +343,7 @@ class BaseGraphStorage(StorageNameSpace, ABC):
@abstractmethod
async def get_edge(
self, source_node_id: str, target_node_id: str
) -> dict[str, str] | None:
) -> Union[Dict[str, str],None]:
"""Get edge properties between two nodes.
Args:
@ -355,7 +355,7 @@ class BaseGraphStorage(StorageNameSpace, ABC):
"""
@abstractmethod
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
async def get_node_edges(self, source_node_id: str) -> Union[List[Tuple[str, str]],None]:
"""Get all edges connected to a node.
Args:
@ -366,7 +366,7 @@ class BaseGraphStorage(StorageNameSpace, ABC):
or None if the node doesn't exist
"""
async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, dict]:
async def get_nodes_batch(self, node_ids: List[str]) -> Dict[str, Dict]:
"""Get nodes as a batch using UNWIND
Default implementation fetches nodes one by one.
@ -380,7 +380,7 @@ class BaseGraphStorage(StorageNameSpace, ABC):
result[node_id] = node
return result
async def node_degrees_batch(self, node_ids: list[str]) -> dict[str, int]:
async def node_degrees_batch(self, node_ids: List[str]) -> Dict[str, int]:
"""Node degrees as a batch using UNWIND
Default implementation fetches node degrees one by one.
@ -394,8 +394,8 @@ class BaseGraphStorage(StorageNameSpace, ABC):
return result
async def edge_degrees_batch(
self, edge_pairs: list[tuple[str, str]]
) -> dict[tuple[str, str], int]:
self, edge_pairs: List[Tuple[str, str]]
) -> Dict[Tuple[str, str], int]:
"""Edge degrees as a batch using UNWIND also uses node_degrees_batch
Default implementation calculates edge degrees one by one.
@ -409,8 +409,8 @@ class BaseGraphStorage(StorageNameSpace, ABC):
return result
async def get_edges_batch(
self, pairs: list[dict[str, str]]
) -> dict[tuple[str, str], dict]:
self, pairs: List[Dict[str, str]]
) -> Dict[Tuple[str, str], Dict]:
"""Get edges as a batch using UNWIND
Default implementation fetches edges one by one.
@ -427,8 +427,8 @@ class BaseGraphStorage(StorageNameSpace, ABC):
return result
async def get_nodes_edges_batch(
self, node_ids: list[str]
) -> dict[str, list[tuple[str, str]]]:
self, node_ids: List[str]
) -> Dict[str, List[Tuple[str, str]]]:
"""Get nodes edges as a batch using UNWIND
Default implementation fetches node edges one by one.
@ -442,7 +442,7 @@ class BaseGraphStorage(StorageNameSpace, ABC):
return result
@abstractmethod
async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
async def upsert_node(self, node_id: str, node_data: Dict[str, str]) -> None:
"""Insert a new node or update an existing node in the graph.
Importance notes for in-memory storage:
@ -457,7 +457,7 @@ class BaseGraphStorage(StorageNameSpace, ABC):
@abstractmethod
async def upsert_edge(
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
self, source_node_id: str, target_node_id: str, edge_data: Dict[str, str]
) -> None:
"""Insert a new edge or update an existing edge in the graph.
@ -486,7 +486,7 @@ class BaseGraphStorage(StorageNameSpace, ABC):
"""
@abstractmethod
async def remove_nodes(self, nodes: list[str]):
async def remove_nodes(self, nodes: List[str]):
"""Delete multiple nodes
Importance notes:
@ -499,7 +499,7 @@ class BaseGraphStorage(StorageNameSpace, ABC):
"""
@abstractmethod
async def remove_edges(self, edges: list[tuple[str, str]]):
async def remove_edges(self, edges: List[Tuple[str, str]]):
"""Delete multiple edges
Importance notes:
@ -512,7 +512,7 @@ class BaseGraphStorage(StorageNameSpace, ABC):
"""
@abstractmethod
async def get_all_labels(self) -> list[str]:
async def get_all_labels(self) -> List[str]:
"""Get all labels in the graph.
Returns:
@ -564,11 +564,11 @@ class DocProcessingStatus:
"""ISO format timestamp when document was created"""
updated_at: str
"""ISO format timestamp when document was last updated"""
chunks_count: int | None = None
chunks_count: Union[int,None] = None
"""Number of chunks after splitting, used for processing"""
error: str | None = None
error: Union[str,None] = None
"""Error message if failed"""
metadata: dict[str, Any] = field(default_factory=dict)
metadata: Dict[str, Any] = field(default_factory=Dict)
"""Additional metadata"""
@ -577,16 +577,16 @@ class DocStatusStorage(BaseKVStorage, ABC):
"""Base class for document status storage"""
@abstractmethod
async def get_status_counts(self) -> dict[str, int]:
async def get_status_counts(self) -> Dict[str, int]:
"""Get counts of documents in each status"""
@abstractmethod
async def get_docs_by_status(
self, status: DocStatus
) -> dict[str, DocProcessingStatus]:
) -> Dict[str, DocProcessingStatus]:
"""Get all documents with a specific status"""
async def drop_cache_by_modes(self, modes: list[str] | None = None) -> bool:
async def drop_cache_by_modes(self, modes: Union[List[str],None] = None) -> bool:
"""Drop cache is not supported for Doc Status storage"""
return False

View File

@ -1,3 +1,5 @@
from typing import List
from typing import Dict
STORAGE_IMPLEMENTATIONS = {
"KV_STORAGE": {
"implementations": [
@ -45,7 +47,7 @@ STORAGE_IMPLEMENTATIONS = {
}
# Storage implementation environment variable without default value
STORAGE_ENV_REQUIREMENTS: dict[str, list[str]] = {
STORAGE_ENV_REQUIREMENTS: Dict[str, List[str]] = {
# KV Storage Implementations
"JsonKVStorage": [],
"MongoKVStorage": [],

View File

@ -1,11 +1,11 @@
import os
from dataclasses import dataclass
from typing import Any, final
from typing import Any, final,Dict,List,Union,Set
from lightrag.base import (
from ..base import (
BaseKVStorage,
)
from lightrag.utils import (
from ..utils import (
load_json,
logger,
write_json,
@ -85,7 +85,7 @@ class JsonKVStorage(BaseKVStorage):
write_json(data_dict, self._file_name)
await clear_all_update_flags(self.namespace)
async def get_all(self) -> dict[str, Any]:
async def get_all(self) -> Dict[str, Any]:
"""Get all data from storage
Returns:
@ -94,11 +94,11 @@ class JsonKVStorage(BaseKVStorage):
async with self._storage_lock:
return dict(self._data)
async def get_by_id(self, id: str) -> dict[str, Any] | None:
async def get_by_id(self, id: str) -> Union[Dict[str, Any] , None]:
async with self._storage_lock:
return self._data.get(id)
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
async def get_by_ids(self, ids: List[str]) -> List[Dict[str, Any]]:
async with self._storage_lock:
return [
(
@ -109,11 +109,11 @@ class JsonKVStorage(BaseKVStorage):
for id in ids
]
async def filter_keys(self, keys: set[str]) -> set[str]:
async def filter_keys(self, keys: Set[str]) -> Set[str]:
async with self._storage_lock:
return set(keys) - set(self._data.keys())
return Set(keys) - Set(self._data.keys())
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
async def upsert(self, data: Dict[str, Dict[str, Any]]) -> None:
"""
Importance notes for in-memory storage:
1. Changes will be persisted to disk during the next index_done_callback
@ -126,7 +126,7 @@ class JsonKVStorage(BaseKVStorage):
self._data.update(data)
await set_all_update_flags(self.namespace)
async def delete(self, ids: list[str]) -> None:
async def delete(self, ids: List[str]) -> None:
"""Delete specific records from storage by their IDs
Importance notes for in-memory storage:
@ -149,7 +149,7 @@ class JsonKVStorage(BaseKVStorage):
if any_deleted:
await set_all_update_flags(self.namespace)
async def drop_cache_by_modes(self, modes: list[str] | None = None) -> bool:
async def drop_cache_by_modes(self, modes: Union[List[str] , None] = None) -> bool:
"""Delete specific records from storage by by cache mode
Importance notes for in-memory storage:
@ -172,7 +172,7 @@ class JsonKVStorage(BaseKVStorage):
except Exception:
return False
async def drop(self) -> dict[str, str]:
async def drop(self) -> Dict[str, str]:
"""Drop all data from storage and clean up resources
This action will persistent the data to disk immediately.

View File

@ -5,15 +5,15 @@ from dataclasses import dataclass
import numpy as np
import time
from lightrag.utils import (
from ..utils import (
logger,
compute_mdhash_id,
)
import pipmaster as pm
from lightrag.base import BaseVectorStorage
from ..base import BaseVectorStorage
if not pm.is_installed("nano-vectordb"):
pm.install("nano-vectordb")
# if not pm.is_installed("nano-vectordb"):
# pm.install("nano-vectordb")
from nano_vectordb import NanoVectorDB
from .shared_storage import (

View File

@ -1,10 +1,10 @@
import os
from dataclasses import dataclass
from typing import final
from typing import final,Dict,List,Union,Tuple
from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
from lightrag.utils import logger
from lightrag.base import BaseGraphStorage
from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
from ..utils import logger
from ..base import BaseGraphStorage
import pipmaster as pm
@ -98,7 +98,7 @@ class NetworkXStorage(BaseGraphStorage):
graph = await self._get_graph()
return graph.has_edge(source_node_id, target_node_id)
async def get_node(self, node_id: str) -> dict[str, str] | None:
async def get_node(self, node_id: str) -> Union[Dict[str, str],None]:
graph = await self._get_graph()
return graph.nodes.get(node_id)
@ -112,17 +112,17 @@ class NetworkXStorage(BaseGraphStorage):
async def get_edge(
self, source_node_id: str, target_node_id: str
) -> dict[str, str] | None:
) -> Union[Dict[str, str],None]:
graph = await self._get_graph()
return graph.edges.get((source_node_id, target_node_id))
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
async def get_node_edges(self, source_node_id: str) -> Union[List[Tuple[str, str]],None]:
graph = await self._get_graph()
if graph.has_node(source_node_id):
return list(graph.edges(source_node_id))
return None
async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
async def upsert_node(self, node_id: str, node_data: Dict[str, str]) -> None:
"""
Importance notes:
1. Changes will be persisted to disk during the next index_done_callback
@ -133,7 +133,7 @@ class NetworkXStorage(BaseGraphStorage):
graph.add_node(node_id, **node_data)
async def upsert_edge(
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
self, source_node_id: str, target_node_id: str, edge_data: Dict[str, str]
) -> None:
"""
Importance notes:
@ -158,7 +158,7 @@ class NetworkXStorage(BaseGraphStorage):
else:
logger.warning(f"Node {node_id} not found in the graph for deletion.")
async def remove_nodes(self, nodes: list[str]):
async def remove_nodes(self, nodes: List[str]):
"""Delete multiple nodes
Importance notes:
@ -174,7 +174,7 @@ class NetworkXStorage(BaseGraphStorage):
if graph.has_node(node):
graph.remove_node(node)
async def remove_edges(self, edges: list[tuple[str, str]]):
async def remove_edges(self, edges: List[Tuple[str, str]]):
"""Delete multiple edges
Importance notes:
@ -190,7 +190,7 @@ class NetworkXStorage(BaseGraphStorage):
if graph.has_edge(source, target):
graph.remove_edge(source, target)
async def get_all_labels(self) -> list[str]:
async def get_all_labels(self) -> List[str]:
"""
Get all node labels in the graph
Returns:
@ -389,7 +389,7 @@ class NetworkXStorage(BaseGraphStorage):
return True
async def drop(self) -> dict[str, str]:
async def drop(self) -> Dict[str, str]:
"""Drop all graph data from storage and clean up resources
This method will:

View File

@ -1,5 +1,4 @@
from __future__ import annotations
import traceback
import asyncio
import configparser
@ -21,18 +20,18 @@ from typing import (
List,
Dict,
)
from lightrag.constants import (
from .constants import (
DEFAULT_MAX_TOKEN_SUMMARY,
DEFAULT_FORCE_LLM_SUMMARY_ON_MERGE,
)
from lightrag.utils import get_env_value
from .utils import get_env_value
from lightrag.kg import (
from .kg import (
STORAGES,
verify_storage_implementation,
)
from lightrag.kg.shared_storage import (
from .kg.shared_storage import (
get_namespace_data,
get_pipeline_status_lock,
)
@ -199,7 +198,7 @@ class LightRAG:
)
"""Maximum number of concurrent embedding function calls."""
embedding_cache_config: dict[str, Any] = field(
embedding_cache_config: Dict[str, Any] = field(
default_factory=lambda: {
"enabled": False,
"similarity_threshold": 0.95,
@ -283,7 +282,7 @@ class LightRAG:
_storages_status: StoragesStatus = field(default=StoragesStatus.NOT_CREATED)
def __post_init__(self):
from lightrag.kg.shared_storage import (
from .kg.shared_storage import (
initialize_share_data,
)

View File

@ -1,18 +1,18 @@
from __future__ import annotations
from pydantic import BaseModel
from typing import Any, Optional
from typing import Any, Optional,List,Dict
class GPTKeywordExtractionFormat(BaseModel):
high_level_keywords: list[str]
low_level_keywords: list[str]
high_level_keywords: List[str]
low_level_keywords: List[str]
class KnowledgeGraphNode(BaseModel):
id: str
labels: list[str]
properties: dict[str, Any] # anything else goes here
labels: List[str]
properties: Dict[str, Any] # anything else goes here
class KnowledgeGraphEdge(BaseModel):
@ -20,10 +20,10 @@ class KnowledgeGraphEdge(BaseModel):
type: Optional[str]
source: str # id of source node
target: str # id of target node
properties: dict[str, Any] # anything else goes here
properties: Dict[str, Any] # anything else goes here
class KnowledgeGraph(BaseModel):
nodes: list[KnowledgeGraphNode] = []
edges: list[KnowledgeGraphEdge] = []
nodes: List[KnowledgeGraphNode] = []
edges: List[KnowledgeGraphEdge] = []
is_truncated: bool = False

View File

@ -14,9 +14,9 @@ from functools import wraps
from hashlib import md5
from typing import Any, Protocol, Callable, TYPE_CHECKING, List
import numpy as np
from lightrag.prompt import PROMPTS
from .prompt import PROMPTS
from dotenv import load_dotenv
from lightrag.constants import (
from .constants import (
DEFAULT_LOG_MAX_BYTES,
DEFAULT_LOG_BACKUP_COUNT,
DEFAULT_LOG_FILENAME,
@ -1696,7 +1696,7 @@ def check_storage_env_vars(storage_name: str) -> None:
Raises:
ValueError: If required environment variables are missing
"""
from lightrag.kg import STORAGE_ENV_REQUIREMENTS
from .kg import STORAGE_ENV_REQUIREMENTS
required_vars = STORAGE_ENV_REQUIREMENTS.get(storage_name, [])
missing_vars = [var for var in required_vars if var not in os.environ]

View File

@ -20,7 +20,7 @@ class autoGenerator():
async def query(self):
self.rag_client = MassageAcupointRAG(
working_dir="C:/Users/ZIWEI/Documents/work/向量化/CoreRAG/Massage_10216"
working_dir="CoreRAG/Massage_10216"
)
await self.rag_client.initialize()

View File

@ -1,6 +1,6 @@
import numpy as np
import matplotlib.pyplot as plt
from typing import List
# ======================= 参数配置 ======================= #
GRID_SIZE = (40, 130) # 行x列
TOTAL_STEPS = 50
@ -166,14 +166,14 @@ class Agent:
# ======================= 可视化模块 ======================= #
class Visualizer:
def __init__(self, grid_size, full_path_ref: list[tuple]):
def __init__(self, grid_size, full_path_ref: List[tuple]):
self.grid_size = grid_size
self.full_path_ref = full_path_ref
plt.ion()
self.fig, self.ax = plt.subplots(figsize=(4, 10)) # 调整大小
self.im = self.ax.imshow(np.zeros(grid_size).T, origin='lower', cmap='viridis')
self.colorbar = plt.colorbar(self.im, ax=self.ax, label='Height')
self._artists: list = []
self._artists: List = []
# 保持横纵比一致,避免变形
self.ax.set_aspect('equal')

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@ -1,8 +1,11 @@
try:from scripts.APF_global_demo import GaussianSchedule,GaussianPathSchedule,TimedGaussianSchedule,FieldScheduler,Agent,Visualizer
except:from APF_global_demo import GaussianSchedule,GaussianPathSchedule,TimedGaussianSchedule,FieldScheduler,Agent,Visualizer
import numpy as np
try:from scripts.sorter import sorter
except:from sorter import sorter
import matplotlib.pyplot as plt
import json
import time

View File

@ -1,6 +1,7 @@
import re
import json
import numpy as np
from typing import List
class sorter:
''' 重点穴位按摩排列器 '''
@ -18,7 +19,7 @@ class sorter:
self.massage_side = massage_side # 默认为双边'both'
else:
raise ValueError("未指定按摩在左侧、右侧或两侧")
def _extract_acupoints(self,respnse_from_llm:str)->list[str]:
def _extract_acupoints(self,respnse_from_llm:str)->List[str]:
pattern = r"[0-9]+\.\s*([\u4e00-\u9fa5]{2,5}穴)"
matches = re.findall(pattern, respnse_from_llm)
# 去重 & 排除空值
@ -26,7 +27,7 @@ class sorter:
print(unique_names)
return unique_names
def sort_acupoints(self,respnse_from_llm:list[str])->list[str]:
def sort_acupoints(self,respnse_from_llm:List[str])->List[str]:
if self.body_part == 'back':
allowed_names = []
if self.body_part == 'shoulder':
@ -38,7 +39,7 @@ class sorter:
"志室右","肓门右","胃仓右","意舍右","阳纲右","胞肓右","气海右俞","大肠右俞","小肠右俞",
"中膂右俞","肾俞右","关元右俞","膀胱右俞","白环右俞","秩边右","京门右"]
def __filter_acupoints(acupoints:list[str],allowed_names: list[str])->list[str]:
def __filter_acupoints(acupoints:List[str],allowed_names: List[str])->List[str]:
acupoints_cleaned = []
matched_keys = []
for acupoint in acupoints: