"""Dify chat API client with streaming (SSE) support and an interactive command-line demo."""
import requests
|
||
import json
|
||
import sys
|
||
import time
|
||
import re
|
||
from pathlib import Path
|
||
import os
|
||
import traceback
|
||
sys.path.append("./")
|
||
from tools.yaml_operator import read_yaml
|
||
from tools.log import CustomLogger
|
||
from types import SimpleNamespace
|
||
|
||
class DifyClient:
    """Client for the Dify chat API with support for streaming (SSE) responses.

    Streamed answers are re-packaged into small objects that mimic the shape
    of OpenAI streaming chunks: ``chunk.choices[0].delta.content`` carries
    normal answer text and ``chunk.choices[0].delta.reasoning_content``
    carries text found inside a ``<details><summary>Thinking...</summary>``
    block of the answer.
    """

    def __init__(self, base_url, api_key):
        """Create a client bound to one Dify application.

        Args:
            base_url (str): Base URL of the Dify API; ``/chat-messages`` is
                appended to it when sending a request.
            api_key (str): Dify application API key, sent as a Bearer token.

        Raises:
            ValueError: If ``api_key`` is empty or None.
        """
        self.logger = CustomLogger()

        # Credentials are always supplied by the caller; keys must never be
        # hard-coded in this file.
        self.dify_api_key = api_key
        self.dify_api_base_url = base_url

        if not self.dify_api_key:
            self.logger.log_error("Dify API密钥未提供")
            raise ValueError("Dify API密钥未提供")

        self.headers = {
            'Authorization': f'Bearer {self.dify_api_key}',
            'Content-Type': 'application/json',
            'Accept': 'text/event-stream',  # explicitly accept the SSE format
        }

        # Conversation currently in progress; None until the API reports one.
        self.conversation_id = None

        # Message history (kept for callers; not used by this class itself).
        self.chat_message = []

        # Timeout in seconds handed to requests for every call.
        self.timeout = 30

        # Patterns delimiting the reasoning ("Thinking...") part of an answer.
        self.details_start_pattern = r'<details.*?>\s*<summary>\s*Thinking\.\.\.\s*</summary>'
        self.details_end_pattern = r'</details>'

    def chat_completion(self, query, user_id="user123", inputs=None, conversation_id=None, stream=True, callback=None):
        """Send a chat request to the Dify API.

        Args:
            query (str): The user's query text.
            user_id (str): User identifier sent as ``user``.
            inputs (dict): Extra input parameters; defaults to an empty dict.
            conversation_id (str): Conversation to continue. When falsy, the
                conversation stored on the client (if any) is used; otherwise
                the request starts a new conversation.
            stream (bool): Use streaming (SSE) mode when True, blocking mode
                otherwise.
            callback (callable): Optional. In streaming mode it is invoked as
                ``callback(event_type, event_title)`` for every event whose
                type is not ``message``.

        Returns:
            When ``stream`` is True, a generator of OpenAI-style chunks (see
            ``_stream_response``). When ``stream`` is False, the decoded JSON
            response body. ``None`` if the request fails, the status code is
            not 200, or the body cannot be decoded.
        """
        if inputs is None:
            inputs = {}

        data = {
            "inputs": inputs,
            "query": query,
            "response_mode": "streaming" if stream else "blocking",
            "user": user_id,
        }

        # An explicit conversation id wins over the one remembered by the client.
        active_conversation = conversation_id or self.conversation_id
        if active_conversation:
            data["conversation_id"] = active_conversation

        url = f"{self.dify_api_base_url}/chat-messages"

        self.logger.log_info(f"发送请求到: {url}")
        self.logger.log_info(f"请求数据: {data}")

        try:
            response = requests.post(
                url,
                headers=self.headers,
                json=data,
                stream=stream,
                timeout=self.timeout,
            )

            if response.status_code != 200:
                error_msg = f"API请求失败: 状态码 {response.status_code}, 响应: {response.text}"
                self.logger.log_error(error_msg)
                # Release the connection; matters for stream=True responses.
                response.close()
                return None

            if stream:
                self.logger.log_info(f"成功获取流式响应,状态码: {response.status_code}")

                # For a new conversation, take the id from the response
                # header when the server provides one.
                if not active_conversation and 'X-Conversation-Id' in response.headers:
                    self.conversation_id = response.headers['X-Conversation-Id']
                    self.logger.log_info(f"从响应头获取新会话ID: {self.conversation_id}")

                return self._stream_response(response, callback)

            try:
                response_data = response.json()
            except ValueError as e:
                # Every JSON decode error requests can raise (stdlib json or
                # simplejson based) is a ValueError subclass.
                self.logger.log_error(f"解析响应JSON失败: {str(e)}, 响应内容: {response.text[:200]}")
                return None

            # For a new conversation, remember the id reported in the body.
            if not active_conversation and 'conversation_id' in response_data:
                self.conversation_id = response_data['conversation_id']
                self.logger.log_info(f"新会话ID: {self.conversation_id}")

            return response_data

        except requests.exceptions.Timeout:
            self.logger.log_error(f"请求超时,超时设置: {self.timeout}秒")
            return None
        except requests.exceptions.ConnectionError as e:
            self.logger.log_error(f"连接错误: {str(e)}")
            return None
        except Exception as e:
            self.logger.log_error(f"请求处理过程中出错: {str(e)}")
            self.logger.log_error(traceback.format_exc())
            return None

    @staticmethod
    def _make_chunk(reasoning_content=None, content=None):
        """Build one OpenAI-style streaming chunk holding the given delta fields."""
        delta = SimpleNamespace(reasoning_content=reasoning_content, content=content)
        return SimpleNamespace(choices=[SimpleNamespace(delta=delta)])

    def _stream_response(self, response, callback):
        """Parse an SSE response and yield OpenAI-style chunks.

        Args:
            response: Streaming response object; must provide ``iter_lines()``
                and ``close()``.
            callback: Optional callable invoked as
                ``callback(event_type, event_title)`` for every non-message
                event. Exceptions it raises are logged and do not stop the
                stream.

        Yields:
            Objects where ``choices[0].delta`` has ``reasoning_content`` and
            ``content`` attributes; exactly one of the two is set per chunk.
        """
        try:
            self.logger.log_info("开始处理SSE流式响应并转换为OpenAI格式")

            # True while the answer text is inside the <details> reasoning block.
            in_reasoning = False

            # Iterate raw bytes and decode here: the SSE format is always
            # UTF-8, whereas decode_unicode=True depends on the encoding
            # requests guessed from the headers.
            for raw_line in response.iter_lines():
                if not raw_line:
                    continue
                if isinstance(raw_line, bytes):
                    line = raw_line.decode('utf-8', errors='replace')
                else:
                    line = raw_line

                # Only "data:" fields carry payloads; other SSE fields such
                # as "event: ping" are ignored.
                if not line.startswith('data:'):
                    continue
                data = line[5:].strip()

                try:
                    event_data = json.loads(data)
                except json.JSONDecodeError as e:
                    self.logger.log_error(f"解析JSON失败: {str(e)}, 数据: {data[:200]}")
                    # Keep going with the next line instead of aborting the stream.
                    continue

                if not isinstance(event_data, dict):
                    # Not an event object; nothing to extract from it.
                    continue

                self.logger.log_info(str(event_data))

                event_type = event_data.get('event')

                if event_type != 'message':
                    self.logger.log_info(f"收到事件: {event_type}")
                    try:
                        event_title = (event_data.get("data") or {}).get("title", None)
                        self.logger.log_blue(f"事件title: {event_title}")
                        if callback is not None:
                            callback(event_type, event_title)
                    except Exception as e:
                        self.logger.log_error(f"dify_callback有问题{e}")

                # Remember the conversation id if none is known yet.
                if not self.conversation_id and 'conversation_id' in event_data:
                    self.conversation_id = event_data['conversation_id']
                    self.logger.log_info(f"从流式响应获取新会话ID: {self.conversation_id}")

                if event_type != 'message':
                    continue

                answer = event_data.get('answer', '')
                if not answer:
                    continue

                # The opening marker switches to reasoning mode; the marker
                # itself is stripped from the text.
                if re.search(self.details_start_pattern, answer):
                    in_reasoning = True
                    answer = re.sub(self.details_start_pattern, '', answer)

                if not in_reasoning:
                    yield self._make_chunk(content=answer)
                elif '</details>' in answer:
                    # Text before the closing tag is the tail of the reasoning,
                    # text after it is normal content. This also handles the
                    # opening and closing markers arriving in the same chunk.
                    in_reasoning = False
                    reasoning_tail, content_head = answer.split('</details>', 1)
                    if reasoning_tail:
                        yield self._make_chunk(reasoning_content=reasoning_tail)
                    if content_head:
                        yield self._make_chunk(content=content_head)
                else:
                    yield self._make_chunk(reasoning_content=answer)

        except requests.exceptions.ChunkedEncodingError as e:
            self.logger.log_error(f"分块编码错误: {str(e)}")
        except requests.exceptions.RequestException as e:
            self.logger.log_error(f"请求异常: {str(e)}")
        except Exception as e:
            self.logger.log_error(f"处理流式响应时出错: {str(e)}")
            self.logger.log_error(traceback.format_exc())
        finally:
            # Always release the connection, including when the consumer
            # stops iterating early.
            response.close()

    def reset_conversation(self):
        """Forget the current conversation so the next request starts a new one."""
        self.conversation_id = None
        self.logger.log_info("会话已重置")

    def get_conversation_id(self):
        """Return the current conversation id, or None if there is none.

        Returns:
            str: The current conversation id.
        """
        return self.conversation_id

    def chat(self, human_input):
        """Send one user message in streaming mode.

        Args:
            human_input (str): The user's input text.

        Returns:
            dict: ``{'chat_message': human_input, 'response_generator': <generator>}``
            on success.
            str: A user-facing error message when the request could not be
            made or an unexpected error occurred.
        """
        # Local import: platform.node() is the portable replacement for
        # os.uname()[1], which does not exist on Windows.
        import platform

        try:
            user_id = platform.node() or "user123"
            self.logger.log_info(f"用户输入: {human_input}")

            response_generator = self.chat_completion(user_id=user_id, query=human_input)

            if not response_generator:
                return '我没有理解您的意思,请重新提问。'

            return {
                'chat_message': human_input,
                'response_generator': response_generator,
            }
        except Exception as e:
            self.logger.log_error(f"聊天处理过程中出错: {str(e)}")
            self.logger.log_error(traceback.format_exc())
            return '处理您的请求时出现错误,请稍后再试。'
|
||
|
||
|
||
if __name__ == '__main__':
    # Interactive demo: read questions from the terminal and print the
    # streamed answer sentence by sentence.
    #
    # DifyClient requires a base URL and an API key. They are read from the
    # environment so that no credential is stored in the source file.
    base_url = os.environ.get("DIFY_API_BASE_URL", "https://api.dify.ai/v1")
    api_key = os.environ.get("DIFY_API_KEY", "")
    try:
        dify_client = DifyClient(base_url, api_key)
    except ValueError as e:
        print(f"{e},请设置环境变量 DIFY_API_KEY(可选: DIFY_API_BASE_URL)")
        sys.exit(1)

    # Punctuation marks at which accumulated content is flushed to the terminal.
    punctuation_marks_pattern = re.compile(r'[。,,;!?]')

    print("Dify聊天机器人已启动,输入'exit'、'quit'或'q'退出")
    while True:
        human_input = input("请输入问题: ")

        if human_input.lower() in ['exit', 'quit', 'q']:
            print("程序结束")
            break

        time1 = time.time()

        return_dict = dify_client.chat(human_input)

        # chat() returns a plain string when the request failed.
        if isinstance(return_dict, str):
            print(return_dict)
        else:
            response_generator = return_dict.get('response_generator', '')

            full_content = ''
            full_reasoning = ''
            last_index = 0  # start of the content that has not been printed yet
            reasoning_line_open = False  # True while a "[推理]" line is unfinished

            try:
                for response in response_generator:
                    if not (hasattr(response, 'choices') and response.choices):
                        continue

                    delta = response.choices[0].delta
                    reasoning_result = getattr(delta, "reasoning_content", None)
                    content_result = getattr(delta, "content", None)

                    if reasoning_result:
                        full_reasoning += reasoning_result
                        # Print the prefix once per reasoning run instead of
                        # before every streamed fragment.
                        if not reasoning_line_open:
                            print("[推理] ", end="", flush=True)
                            reasoning_line_open = True
                        print(reasoning_result, end="", flush=True)

                    if content_result:
                        if reasoning_line_open:
                            # Finish the reasoning line before normal content starts.
                            print()
                            reasoning_line_open = False

                        full_content += content_result
                        # Only scan the part that has not been printed yet;
                        # marks before last_index were handled already.
                        for match in punctuation_marks_pattern.finditer(full_content, last_index):
                            index = match.start()
                            if index > last_index:
                                accumulated_text = full_content[last_index:index + 1]
                                last_index = index + 1
                                if accumulated_text.strip():
                                    print(accumulated_text.strip())

                if reasoning_line_open:
                    print()

                # Print whatever is left after the last punctuation mark.
                if full_content and last_index < len(full_content):
                    remaining_text = full_content[last_index:]
                    if remaining_text.strip():
                        print(remaining_text.strip())

                if not full_content and not full_reasoning:
                    print("未收到有效回复,请检查API连接或重试。")
            except Exception as e:
                print(f"处理响应时出错: {str(e)}")
                print(traceback.format_exc())

        time2 = time.time()
        print(f"总耗时: {time2-time1:.2f}秒")