blob: 53937b1fe440498126a565009c4f747fb626dbb7 [file] [log] [blame] [raw]
"""
Streaming message sender.

Implements sending a Telegram message and then editing it in place as
streamed content arrives.
"""
import asyncio
import inspect
import logging
import re
import time
from typing import AsyncGenerator, AsyncIterable, Callable, Iterable, Optional, Union

from telegram import Message

from .interfaces import ITelegramClient, IStreamMessageSender
# Module-scoped logger, named after this module.
logger = logging.getLogger(name=__name__)
class StreamMessageSender(IStreamMessageSender):
    """Streaming message sender.

    Sends an initial placeholder message and then edits it in place as text
    is produced, so the user sees the answer grow. Intermediate edits are
    throttled to at most one per ``update_interval`` seconds.
    """

    # Upper bound on the text put into one message. Telegram caps message text
    # at 4096 characters; the slack leaves room for the "still generating"
    # suffix and the "..." truncation marker.
    _MAX_TEXT_LENGTH = 4000
    # Suffix appended to intermediate (non-final) updates.
    _PENDING_SUFFIX = "\n\n⏳ 继续生成中..."
    # Shown when there is nothing displayable yet.
    _EMPTY_PLACEHOLDER = "🤔 正在处理..."

    # Fenced code blocks (```...```) are kept verbatim by _clean_markdown.
    _CODE_BLOCK_RE = re.compile(r"```[\s\S]*?```")
    # Non-empty inline code spans (`...`) are kept verbatim. Empty "``" pairs
    # are not matched, so half of an unterminated ``` fence is not mistaken
    # for inline code; excluding NUL keeps a span from swallowing a placeholder.
    _INLINE_CODE_RE = re.compile(r"`[^`\x00]+`")
    # Placeholder standing in for a protected span: NUL + index + NUL. It
    # contains no Markdown characters, so deleting those cannot damage it.
    _PLACEHOLDER_RE = re.compile(r"\x00(\d+)\x00")
    _EXCESS_NEWLINES_RE = re.compile(r"\n{3,}")
    # Characters deleted from non-code text. Stray backticks left over after
    # protecting code would be unbalanced, so they are deleted as well.
    _MARKDOWN_DELETE_TABLE = str.maketrans("", "", "*[]()_`")

    def __init__(
        self,
        telegram_client: ITelegramClient,
        update_interval: float = 1.0,
        webhook_broadcast_callback: Optional[Callable] = None
    ):
        """
        Initialize the streaming message sender.

        Args:
            telegram_client: Client used to send and edit messages.
            update_interval: Minimum number of seconds between intermediate
                edits of a streamed message.
            webhook_broadcast_callback: Optional webhook broadcast callback;
                stored on the instance but not used by this class.
        """
        self.telegram_client = telegram_client
        self.update_interval = update_interval
        self.webhook_broadcast_callback = webhook_broadcast_callback

    async def _send_message_with_markdown(self, chat_id, text, reply_to_message_id=None):
        """Send ``text`` to ``chat_id`` with Markdown parsing enabled."""
        return await self.telegram_client.send_message(
            chat_id=chat_id,
            text=text,
            parse_mode='Markdown',
            reply_to_message_id=reply_to_message_id
        )

    async def _edit_message_with_markdown(self, text, chat_id, message_id):
        """Replace the text of an existing message, parsing it as Markdown."""
        return await self.telegram_client.edit_message_text(
            text=text,
            chat_id=chat_id,
            message_id=message_id,
            parse_mode='Markdown'
        )

    async def _edit_message_plain(self, text, chat_id, message_id):
        """Replace the text of an existing message with no Markdown parsing.

        Used as a fallback when Markdown may be what made an edit fail, and
        for error notices whose text (exception messages) is not Markdown-safe.
        """
        return await self.telegram_client.edit_message_text(
            text=text,
            chat_id=chat_id,
            message_id=message_id,
            parse_mode=None
        )

    async def _report_error(self, chat_id, message_id, error_text: str) -> None:
        """Best-effort: show ``error_text`` in the message; never raises."""
        try:
            await self._edit_message_plain(
                text=error_text[: self._MAX_TEXT_LENGTH],
                chat_id=chat_id,
                message_id=message_id
            )
        except Exception:
            logger.exception("💥 错误提示消息发送失败")

    @staticmethod
    async def _iterate_sync(iterable: Iterable):
        """Adapt a synchronous iterable to an async iterator.

        Yields control to the event loop between items so other tasks can run
        while the iterable is being consumed.
        """
        for item in iterable:
            yield item
            await asyncio.sleep(0)

    async def send_streaming_message(
        self,
        chat_id: Union[int, str],
        message_generator: Callable,
        initial_text: str = "⌨️ User is typing...",
        reply_to_message_id: Optional[int] = None
    ) -> Message:
        """
        Send a message whose text is streamed in through repeated edits.

        Args:
            chat_id: Target chat ID.
            message_generator: Content source. If callable, it is called with
                no arguments and its result is awaited first when awaitable
                (e.g. an ``async def`` function). The resulting value may be
                an async iterable or a synchronous iterator of text chunks,
                or any other value, which is rendered with ``str``. A
                non-callable is used as that value directly.
            initial_text: Text of the placeholder message sent before any
                content is available.
            reply_to_message_id: ID of the message to reply to, if any.

        Returns:
            The object returned by the last successful edit if it carries a
            ``message_id``, otherwise the initial message.

        Raises:
            Whatever sending the initial message raises. Errors while
            producing or rendering the content are logged and shown in the
            message instead of being raised.
        """
        logger.debug("send_streaming_message 开始,聊天ID: %s", chat_id)
        message = await self._send_message_with_markdown(
            chat_id=chat_id,
            text=initial_text,
            reply_to_message_id=reply_to_message_id
        )
        message_id = message.message_id
        logger.debug("初始消息发送成功,消息ID: %s", message_id)

        final_message = None
        try:
            source = message_generator() if callable(message_generator) else message_generator
            logger.debug("内容来源类型: %s", type(source))
            if inspect.isawaitable(source):
                # A plain ``async def`` returns a coroutine; its result is the content.
                source = await source
            if hasattr(source, '__aiter__'):
                final_message = await self._handle_async_generator_instance(
                    chat_id, message_id, source
                )
            elif hasattr(source, '__next__'):
                # Synchronous generator/iterator: stream it like an async one
                # instead of rendering "<generator object ...>".
                final_message = await self._handle_async_generator_instance(
                    chat_id, message_id, self._iterate_sync(source)
                )
            else:
                final_message = await self._handle_sync_result(chat_id, message_id, source)
            logger.debug("流式消息处理完成")
        except Exception as e:
            logger.exception("💥 流式消息发送错误: %s: %s", type(e).__name__, e)
            await self._report_error(chat_id, message_id, f"❌ 消息处理出错: {str(e)}")

        # The client's edit call is not guaranteed (from this file) to return
        # a Message; python-telegram-bot, for instance, may return True. Only
        # use the edit result when it actually carries a message_id.
        result_message = final_message if hasattr(final_message, 'message_id') else message
        logger.debug("send_streaming_message 返回消息ID: %s", result_message.message_id)
        return result_message

    async def _handle_async_generator_instance(
        self,
        chat_id: Union[int, str],
        message_id: int,
        generator: AsyncIterable
    ):
        """Consume an async iterable of text chunks, streaming them into a message.

        Each chunk is converted with ``str`` and appended to the accumulated
        text. Intermediate edits (with a "still generating" suffix) are sent
        at most once per ``update_interval`` seconds; their failures are
        logged and skipped. Once the iterable is exhausted, the full text is
        written once more without the suffix, falling back to plain text if
        the Markdown edit fails.

        Args:
            chat_id: Chat containing the message.
            message_id: ID of the message to edit.
            generator: Async iterable yielding text chunks.

        Returns:
            The result of the final edit, or None if no text was produced,
            the text was already displayed, or the final edit failed.

        Raises:
            Any exception raised while iterating ``generator``.
        """
        parts: list[str] = []
        chunk_count = 0
        # -inf guarantees the first chunk triggers an edit immediately.
        last_update_time = float("-inf")
        last_sent_text: Optional[str] = None

        async for chunk in generator:
            chunk_count += 1
            parts.append(str(chunk))
            logger.debug("收到异步chunk #%d: %d 字符", chunk_count, len(parts[-1]))

            now = time.monotonic()
            if now - last_update_time < self.update_interval:
                continue
            formatted_text = self._format_message_text("".join(parts))
            if formatted_text == last_sent_text:
                # Telegram rejects edits that leave the text unchanged
                # ("message is not modified"); wait for new content instead.
                continue
            # Advance the throttle clock before the call, so a failing edit
            # (e.g. flood control) is not retried on every following chunk.
            last_update_time = now
            try:
                await self._edit_message_with_markdown(
                    text=formatted_text,
                    chat_id=chat_id,
                    message_id=message_id
                )
                last_sent_text = formatted_text
            except Exception as e:
                logger.warning(
                    "⚠️ 消息更新失败: %s: %s", type(e).__name__, e, exc_info=True
                )

        logger.debug("异步生成器迭代完成,总chunks: %d", chunk_count)
        accumulated_text = "".join(parts)
        if not accumulated_text:
            logger.warning("⚠️ 累积文本为空,跳过最终消息更新")
            return None

        final_text = self._format_message_text(accumulated_text, final=True)
        if final_text == last_sent_text:
            # Already on screen; re-sending would be rejected as "not modified".
            return None
        try:
            return await self._edit_message_with_markdown(
                text=final_text,
                chat_id=chat_id,
                message_id=message_id
            )
        except Exception as e:
            logger.error("💥 最终消息发送失败: %s: %s", type(e).__name__, e)

        # The Markdown edit may have been rejected (e.g. an entity Telegram
        # cannot parse); retrying the same Markdown would fail again, so send
        # the raw text as plain text.
        plain_text = accumulated_text.strip()
        if not plain_text:
            return None
        if len(plain_text) > self._MAX_TEXT_LENGTH:
            plain_text = plain_text[: self._MAX_TEXT_LENGTH] + "..."
        try:
            final_message = await self._edit_message_plain(
                text=plain_text,
                chat_id=chat_id,
                message_id=message_id
            )
        except Exception as e2:
            logger.error("💥 简化最终消息也发送失败: %s", e2)
            return None
        logger.debug("简化最终消息发送成功")
        return final_message

    async def _handle_sync_result(
        self,
        chat_id: Union[int, str],
        message_id: int,
        result
    ):
        """Show a single, already-computed value (rendered with ``str``) as the final text.

        Returns:
            The result of the edit, or None if it failed; in that case a
            plain-text error notice is attempted instead.
        """
        try:
            return await self._edit_message_with_markdown(
                text=self._format_message_text(str(result), final=True),
                chat_id=chat_id,
                message_id=message_id
            )
        except Exception as e:
            logger.error(f"同步结果处理错误: {e}")
            # Plain text: the exception message may contain Markdown markers.
            await self._report_error(chat_id, message_id, f"❌ 处理错误: {str(e)}")
            return None

    async def _handle_async_generator(
        self,
        chat_id: Union[int, str],
        message_id: int,
        generator_func: Callable
    ):
        """Stream the async iterable returned by ``generator_func()`` into a message.

        Thin wrapper over _handle_async_generator_instance; returns its result.
        """
        return await self._handle_async_generator_instance(
            chat_id, message_id, generator_func()
        )

    async def _handle_sync_generator(
        self,
        chat_id: Union[int, str],
        message_id: int,
        generator_func: Callable
    ):
        """Stream the output of a synchronous generator function into a message.

        ``generator_func`` is called with no arguments. An iterable result
        other than str/bytes is streamed chunk by chunk; any other result (or
        a non-callable ``generator_func``) is shown with ``str``. If producing
        the content raises, an error notice replaces the message text.
        """
        try:
            if not callable(generator_func):
                return await self._handle_sync_result(chat_id, message_id, generator_func)
            result = generator_func()
            if isinstance(result, (str, bytes)) or not hasattr(result, '__iter__'):
                return await self._handle_sync_result(chat_id, message_id, result)
            return await self._handle_async_generator_instance(
                chat_id, message_id, self._iterate_sync(result)
            )
        except Exception as e:
            logger.error(f"生成器处理错误: {e}")
            return await self._edit_message_with_markdown(
                text=self._format_message_text(f"❌ 处理错误: {str(e)}", final=True),
                chat_id=chat_id,
                message_id=message_id
            )

    def _format_message_text(self, text: str, final: bool = False) -> str:
        """
        Prepare accumulated text for display in a Markdown message.

        Args:
            text: Raw accumulated text.
            final: True for the final edit; intermediate edits get a
                "still generating" suffix.

        Returns:
            Markdown-cleaned text of at most ``_MAX_TEXT_LENGTH`` characters
            plus a "..." truncation marker (and the suffix when not final),
            or a placeholder when nothing displayable remains.
        """
        body = text.strip()
        # Truncate *before* cleaning. Cleaning only deletes characters, so the
        # result still fits, and a code fence cut in half by the truncation is
        # neutralised by _clean_markdown instead of left unbalanced.
        if len(body) > self._MAX_TEXT_LENGTH:
            body = body[: self._MAX_TEXT_LENGTH] + "..."
        cleaned = self._clean_markdown(body)
        # Cleaning can empty the text (e.g. "***"), and Telegram rejects
        # empty message text.
        if not cleaned.strip():
            return self._EMPTY_PLACEHOLDER
        return cleaned if final else cleaned + self._PENDING_SUFFIX

    def _clean_markdown(self, text: str) -> str:
        """
        Make text safe for Telegram's legacy Markdown by deleting risky markup.

        Complete fenced code blocks and non-empty inline code spans are kept
        verbatim. Everywhere else the characters ``* [ ] ( ) _`` and any
        leftover (unbalanced) backticks are deleted, and runs of three or more
        newlines are collapsed to two.

        Args:
            text: Text that may contain Markdown.

        Returns:
            The cleaned text; it is never longer than ``text``.
        """
        protected: list[str] = []

        def protect(match) -> str:
            protected.append(match.group(0))
            return f"\x00{len(protected) - 1}\x00"

        # NUL never belongs in a chat message; dropping it guarantees that the
        # only NUL-delimited tokens left are our own placeholders.
        text = text.replace("\x00", "")
        text = self._CODE_BLOCK_RE.sub(protect, text)
        text = self._INLINE_CODE_RE.sub(protect, text)
        text = text.translate(self._MARKDOWN_DELETE_TABLE)
        text = self._EXCESS_NEWLINES_RE.sub("\n\n", text)
        return self._PLACEHOLDER_RE.sub(lambda m: protected[int(m.group(1))], text)

    async def send_typing_action(self, chat_id: Union[int, str]):
        """
        Send a "typing" chat action to ``chat_id``.

        Currently a no-op placeholder.

        Args:
            chat_id: Chat ID.
        """
        # TODO: sending the action needs a Bot instance (planned to come via
        # dependency injection); until then this intentionally does nothing.
        return None

    async def send_chunked_message(
        self,
        chat_id: Union[int, str],
        text: str,
        chunk_size: int = 4000,
        reply_to_message_id: Optional[int] = None
    ) -> list[Message]:
        """
        Send a long text as consecutive plain-text messages.

        Args:
            chat_id: Chat ID.
            text: Text to send, split into consecutive slices of
                ``chunk_size`` characters. Empty text sends nothing.
            chunk_size: Maximum number of characters per message (>= 1).
            reply_to_message_id: Message to reply to; only the first chunk
                replies to it.

        Returns:
            The sent messages, in order.

        Raises:
            ValueError: If ``chunk_size`` is less than 1.
        """
        if chunk_size < 1:
            raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")
        chunks = [text[start:start + chunk_size] for start in range(0, len(text), chunk_size)]
        messages: list[Message] = []
        for index, chunk in enumerate(chunks):
            message = await self.telegram_client.send_message(
                chat_id=chat_id,
                text=chunk,
                reply_to_message_id=reply_to_message_id if index == 0 else None
            )
            messages.append(message)
            # Short pause between sends so messages are not fired too quickly.
            if index < len(chunks) - 1:
                await asyncio.sleep(0.1)
        return messages