2025-05-27 15:46:31 +08:00

469 lines
20 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import requests
import json
import sys
import time
import re
from pathlib import Path
import os
import traceback
sys.path.append("./")
from tools.yaml_operator import read_yaml
from tools.log import CustomLogger
from types import SimpleNamespace
class DifyClient:
    """
    A client for the Dify chat API (``POST {base_url}/chat-messages``).

    Streaming (SSE) responses are converted into lightweight objects shaped like
    OpenAI streaming chunks: ``chunk.choices[0].delta.content`` and
    ``chunk.choices[0].delta.reasoning_content``. Text wrapped in a
    ``<details ...><summary>Thinking...</summary> ... </details>`` block is
    reported as reasoning content; all other answer text as normal content.
    """

    def __init__(self, base_url, api_key):
        """
        Initialize the Dify client.

        Args:
            base_url (str): Base URL of the Dify API, e.g. ``https://api.dify.ai/v1``.
                Trailing slashes are stripped so request URLs are well formed.
            api_key (str): Dify application API key, sent as a Bearer token.

        Raises:
            ValueError: If ``api_key`` is empty or None.
        """
        # Initialize the logger.
        self.logger = CustomLogger()
        self.dify_api_key = api_key
        # Strip trailing slashes so f"{base}/chat-messages" never contains "//".
        self.dify_api_base_url = base_url.rstrip('/') if isinstance(base_url, str) else base_url
        # Refuse to construct a client without credentials.
        if not self.dify_api_key:
            self.logger.log_error("Dify API密钥未提供")
            raise ValueError("Dify API密钥未提供")
        # Request headers shared by streaming and blocking calls.
        self.headers = {
            'Authorization': f'Bearer {self.dify_api_key}',
            'Content-Type': 'application/json',
            'Accept': 'text/event-stream',  # explicitly request SSE
        }
        # Conversation id reused by later requests once Dify assigns one.
        self.conversation_id = None
        # Message history; not read or written elsewhere in this class.
        self.chat_message = []
        # Request timeout in seconds (passed to requests as `timeout`).
        self.timeout = 30
        # Markers delimiting the reasoning ("Thinking...") section inside answers.
        # `[^>]*` keeps the match inside the <details ...> tag itself.
        self.details_start_pattern = r'<details[^>]*>\s*<summary>\s*Thinking\.\.\.\s*</summary>'
        self.details_end_pattern = r'</details>'

    def chat_completion(self, query, user_id="user123", inputs=None, conversation_id=None, stream=True, callback=None):
        """
        Send a chat request to the Dify API.

        Args:
            query (str): The user's question.
            user_id (str): End-user identifier, sent as ``user``.
            inputs (dict | None): App input variables; defaults to an empty dict.
            conversation_id (str | None): Conversation to continue. When None, the
                stored ``self.conversation_id`` is used if set; otherwise a new
                conversation is started.
            stream (bool): Use ``response_mode="streaming"`` when True, otherwise
                ``"blocking"``.
            callback (callable | None): Streaming only. Forwarded to
                ``_stream_response`` and called as ``callback(event_type, title)``
                for every non-``message`` event.

        Returns:
            stream=False: the decoded JSON response, or None on failure.
            stream=True: a generator of OpenAI-style chunks, or None if the
            request failed before streaming started.
        """
        if inputs is None:
            inputs = {}
        data = {
            "inputs": inputs,
            "query": query,
            "response_mode": "streaming" if stream else "blocking",
            "user": user_id,
        }
        # An explicit conversation id wins over the stored one.
        if conversation_id:
            data["conversation_id"] = conversation_id
        elif self.conversation_id:
            data["conversation_id"] = self.conversation_id
        url = f"{self.dify_api_base_url}/chat-messages"
        self.logger.log_info(f"发送请求到: {url}")
        self.logger.log_info(f"请求数据: {data}")
        try:
            if stream:
                response = requests.post(
                    url,
                    headers=self.headers,
                    json=data,
                    stream=True,
                    timeout=self.timeout,
                )
                if response.status_code != 200:
                    error_msg = f"API请求失败: 状态码 {response.status_code}, 响应: {response.text}"
                    self.logger.log_error(error_msg)
                    # The stream will never be consumed; release the connection now.
                    response.close()
                    return None
                self.logger.log_info(f"成功获取流式响应,状态码: {response.status_code}")
                # For a new conversation, try to pick the id up from the headers.
                if not conversation_id and not self.conversation_id:
                    header_id = response.headers.get('X-Conversation-Id')
                    if header_id:
                        self.conversation_id = header_id
                        self.logger.log_info(f"从响应头获取新会话ID: {self.conversation_id}")
                # Generator; _stream_response closes the response when it ends.
                return self._stream_response(response, callback)

            response = requests.post(
                url,
                headers=self.headers,
                json=data,
                timeout=self.timeout,
            )
            if response.status_code != 200:
                error_msg = f"API请求失败: 状态码 {response.status_code}, 响应: {response.text}"
                self.logger.log_error(error_msg)
                return None
            try:
                response_data = response.json()
            except ValueError as e:
                # requests may raise json.JSONDecodeError or a simplejson error;
                # both derive from ValueError.
                self.logger.log_error(f"解析响应JSON失败: {str(e)}, 响应内容: {response.text[:200]}")
                return None
            # For a new conversation, remember the id Dify assigned.
            if not conversation_id and not self.conversation_id and 'conversation_id' in response_data:
                self.conversation_id = response_data['conversation_id']
                self.logger.log_info(f"新会话ID: {self.conversation_id}")
            return response_data
        except requests.exceptions.Timeout:
            self.logger.log_error(f"请求超时,超时设置: {self.timeout}")
            return None
        except requests.exceptions.ConnectionError as e:
            self.logger.log_error(f"连接错误: {str(e)}")
            return None
        except Exception as e:
            self.logger.log_error(f"请求处理过程中出错: {str(e)}")
            self.logger.log_error(traceback.format_exc())
            return None

    @staticmethod
    def _make_chunk(reasoning_content=None, content=None):
        """Build one OpenAI-style chunk exposing ``choices[0].delta.{reasoning_content,content}``."""
        delta = SimpleNamespace(reasoning_content=reasoning_content, content=content)
        return SimpleNamespace(choices=[SimpleNamespace(delta=delta)])

    def _split_answer(self, answer, in_reasoning):
        """
        Split one answer fragment into reasoning / normal-content segments.

        Handles text before an opening marker, both markers in the same fragment,
        and multiple reasoning sections in one fragment.

        Args:
            answer (str): The answer fragment from one ``message`` event.
            in_reasoning (bool): Whether the previous fragment ended inside a
                reasoning section.

        Returns:
            tuple[list[tuple[bool, str]], bool]: ``(segments, in_reasoning)``, where
            each segment is ``(is_reasoning, text)`` with non-empty text, and
            ``in_reasoning`` is the state after this fragment.
        """
        segments = []
        rest = answer
        while rest:
            pattern = self.details_end_pattern if in_reasoning else self.details_start_pattern
            match = re.search(pattern, rest)
            if match is None:
                segments.append((in_reasoning, rest))
                break
            if match.start() > 0:
                segments.append((in_reasoning, rest[:match.start()]))
            # Crossing a marker flips between reasoning and normal content.
            in_reasoning = not in_reasoning
            rest = rest[match.end():]
        return segments, in_reasoning

    def _stream_response(self, response, callback=None):
        """
        Parse a Dify SSE stream and yield OpenAI-style chunks.

        Args:
            response: A streaming ``requests`` response whose body is SSE.
            callback (callable | None): If given, called as
                ``callback(event_type, title)`` for every non-``message`` event;
                ``title`` is ``event["data"]["title"]`` when present, else None.
                Exceptions raised by the callback are logged, not propagated.

        Yields:
            SimpleNamespace: chunks with ``choices[0].delta.reasoning_content`` and
            ``choices[0].delta.content``; exactly one of them is a non-empty string.

        The response is always closed when the generator finishes, fails or is
        discarded.
        """
        in_reasoning = False
        try:
            self.logger.log_info("开始处理SSE流式响应并转换为OpenAI格式")
            # SSE streams are always UTF-8; without this, requests may decode a
            # text/* response lacking a charset as ISO-8859-1 and garble CJK text.
            response.encoding = 'utf-8'
            for line in response.iter_lines(decode_unicode=True):
                # Only "data:" fields carry payloads ("event:", comments, blanks skipped).
                if not line or not line.startswith('data:'):
                    continue
                data = line[5:]
                # Per the SSE spec, a single leading space after the colon is dropped.
                if data.startswith(' '):
                    data = data[1:]
                if data == 'event: ping':
                    continue
                try:
                    event_data = json.loads(data)
                except json.JSONDecodeError as e:
                    # Keep processing the stream after a malformed line.
                    self.logger.log_error(f"解析JSON失败: {str(e)}, 数据: {data[:200]}")
                    continue
                if not isinstance(event_data, dict):
                    self.logger.log_info(f"忽略非对象事件: {data[:200]}")
                    continue
                self.logger.log_info(event_data)

                event_type = event_data.get('event')
                if event_type != 'message':
                    self.logger.log_info(f"收到事件: {event_type}")
                    payload = event_data.get('data')
                    event_title = payload.get('title') if isinstance(payload, dict) else None
                    self.logger.log_blue(f"事件title: {event_title}")
                    if callback is not None:
                        try:
                            callback(event_type, event_title)
                        except Exception as e:
                            self.logger.log_error(f"dify_callback有问题{e}")
                    if event_type == 'error':
                        self.logger.log_error(f"Dify返回错误事件: {event_data.get('message')}")

                # Remember the conversation id the first time the stream reports one.
                if not self.conversation_id and event_data.get('conversation_id'):
                    self.conversation_id = event_data['conversation_id']
                    self.logger.log_info(f"从流式响应获取新会话ID: {self.conversation_id}")

                if event_type != 'message':
                    continue
                answer = event_data.get('answer') or ''
                if not answer:
                    continue
                segments, in_reasoning = self._split_answer(answer, in_reasoning)
                for is_reasoning, text in segments:
                    if is_reasoning:
                        yield self._make_chunk(reasoning_content=text)
                    else:
                        yield self._make_chunk(content=text)
        except requests.exceptions.ChunkedEncodingError as e:
            self.logger.log_error(f"分块编码错误: {str(e)}")
        except requests.exceptions.RequestException as e:
            self.logger.log_error(f"请求异常: {str(e)}")
        except Exception as e:
            self.logger.log_error(f"处理流式响应时出错: {str(e)}")
            self.logger.log_error(traceback.format_exc())
        finally:
            # Return the connection to the pool even if the consumer stops early.
            response.close()

    def reset_conversation(self):
        """Forget the current conversation so the next request starts a new one."""
        self.conversation_id = None
        self.logger.log_info("会话已重置")

    def get_conversation_id(self):
        """
        Return the current conversation id.

        Returns:
            str | None: The stored conversation id, or None if none is set.
        """
        return self.conversation_id

    def chat(self, human_input):
        """
        Send one user message as a streaming request.

        Args:
            human_input (str): The user's text.

        Returns:
            dict | str: ``{'chat_message': human_input, 'response_generator': gen}``
            on success, or a user-facing error string on failure.
        """
        import platform  # only needed here; os.uname() does not exist on Windows

        try:
            # platform.node() is the portable equivalent of os.uname()[1]; it returns
            # '' when the host name cannot be determined, and Dify needs a non-empty user.
            user_id = platform.node() or "user123"
            self.logger.log_info(f"用户输入: {human_input}")
            response_generator = self.chat_completion(user_id=user_id, query=human_input)
            if not response_generator:
                return '我没有理解您的意思,请重新提问。'
            return {
                'chat_message': human_input,
                'response_generator': response_generator,
            }
        except Exception as e:
            self.logger.log_error(f"聊天处理过程中出错: {str(e)}")
            self.logger.log_error(traceback.format_exc())
            return '处理您的请求时出现错误,请稍后再试。'
if __name__ == '__main__':
    # Interactive CLI. DifyClient requires base_url and api_key, so resolve them
    # first: environment variables win, then the YAML config file
    # (keys dify_api_key / dify_api_base_url) that earlier versions of this script used.
    api_key = os.environ.get("DIFY_API_KEY", "")
    base_url = os.environ.get("DIFY_API_BASE_URL", "")
    if not api_key or not base_url:
        config_path = Path(__file__).resolve().parent.parent / 'config/api_key.yaml'
        try:
            config = read_yaml(config_path)
        except Exception as e:
            print(f"无法加载API密钥: {e}")
            config = None
        if not isinstance(config, dict):
            config = {}
        api_key = api_key or config.get("dify_api_key", "")
        base_url = base_url or config.get("dify_api_base_url", "https://api.dify.ai/v1")
    if not api_key:
        print("未找到Dify API密钥, 请设置环境变量 DIFY_API_KEY 或配置 config/api_key.yaml")
        sys.exit(1)

    dify_client = DifyClient(base_url, api_key)

    # Reply text is printed in segments ending at these punctuation marks.
    punctuation_marks_regex = re.compile(r'[。,]')

    print("Dify聊天机器人已启动输入'exit''quit''q'退出")
    while True:
        try:
            human_input = input("请输入问题: ")
        except (EOFError, KeyboardInterrupt):
            # Treat end of input / Ctrl-C like an explicit quit instead of crashing.
            print("程序结束")
            break
        if human_input.lower() in ['exit', 'quit', 'q']:
            print("程序结束")
            break

        time1 = time.time()
        return_dict = dify_client.chat(human_input)

        if isinstance(return_dict, str):
            # chat() returns a user-facing error string on failure.
            print(return_dict)
        else:
            response_generator = return_dict.get('response_generator', '')
            full_content = ''
            full_reasoning = ''
            last_index = 0
            try:
                for response in response_generator:
                    if not (hasattr(response, 'choices') and response.choices):
                        continue
                    delta = response.choices[0].delta
                    reasoning_result = getattr(delta, "reasoning_content", None)
                    content_result = getattr(delta, "content", None)

                    if reasoning_result:
                        full_reasoning += reasoning_result
                        print(f"[推理] {reasoning_result}", end="", flush=True)

                    if content_result:
                        full_content += content_result
                        # Scan only the not-yet-printed tail (pos=last_index). This
                        # avoids rescanning the whole reply for every chunk, and it
                        # also emits punctuation sitting exactly at last_index.
                        for match in punctuation_marks_regex.finditer(full_content, last_index):
                            accumulated_text = full_content[last_index:match.end()]
                            last_index = match.end()
                            if accumulated_text.strip():
                                print(accumulated_text.strip())

                # Flush whatever follows the last punctuation mark.
                if full_content and last_index < len(full_content):
                    remaining_text = full_content[last_index:]
                    if remaining_text.strip():
                        print(remaining_text.strip())

                if not full_content and not full_reasoning:
                    print("未收到有效回复请检查API连接或重试。")
            except Exception as e:
                print(f"处理响应时出错: {str(e)}")
                print(traceback.format_exc())

        time2 = time.time()
        print(f"总耗时: {time2-time1:.2f}")