| """ |
| SSHOUT API客户端模块 |
| 基于SSHOUT API v1二进制包协议的客户端实现 |
| 参考: references/sshout/API.htm |
| """ |
| |
| import asyncio |
| import paramiko |
| import struct |
| import logging |
| import time |
| import os |
| from typing import Optional, List, Dict, Callable, Tuple, Any |
| from dataclasses import dataclass |
| from datetime import datetime |
| from enum import IntEnum |
| |
| from ..utils.config import get_config_manager |
| |
| |
class SSHOUTPacketType(IntEnum):
    """Packet type byte of the SSHOUT API v1 wire protocol.

    Codes 0x01-0x04 are sent by the client; codes from 0x80 upward are sent
    by the server.
    """

    # ---- client -> server ----
    HELLO = 0x01
    GET_ONLINE_USER = 0x02
    SEND_MESSAGE = 0x03
    GET_MOTD = 0x04

    # ---- server -> client ----
    PASS = 0x80
    ONLINE_USERS_INFO = 0x81
    RECEIVE_MESSAGE = 0x82
    USER_STATE_CHANGE = 0x83
    ERROR = 0x84
    MOTD = 0x85
| |
| |
class SSHOUTMessageType(IntEnum):
    """Value of the message-type byte inside SEND_MESSAGE / RECEIVE_MESSAGE packets."""

    PLAIN = 0x1
    RICH = 0x2
    IMAGE = 0x3
| |
| |
class SSHOUTErrorCode(IntEnum):
    """Error code carried in the first four bytes of an ERROR packet payload."""

    SERVER_CLOSED = 0x1
    LOCAL_PACKET_CORRUPT = 0x2
    LOCAL_PACKET_TOO_LARGE = 0x3
    OUT_OF_MEMORY = 0x4
    INTERNAL_ERROR = 0x5
    USER_NOT_FOUND = 0x6
    MOTD_NOT_AVAILABLE = 0x7
| |
| |
@dataclass
class SSHOUTMessage:
    """One chat message decoded from a RECEIVE_MESSAGE packet.

    ``is_mention`` defaults to False and is set by the client after it runs
    its mention detection on ``content``.
    """

    # Server-supplied send time, converted with datetime.fromtimestamp().
    timestamp: datetime
    # Sender user name.
    from_user: str
    # Recipient user name ("GLOBAL" is used by this module for broadcasts).
    to_user: str
    # Kind of payload carried in ``content``.
    message_type: SSHOUTMessageType
    # Message body decoded as UTF-8.
    content: str
    # True when the message was recognised as mentioning Claude.
    is_mention: bool = False
| |
| |
@dataclass
class SSHOUTUser:
    """One entry of the ONLINE_USERS_INFO packet."""

    # Server-assigned numeric id of the session.
    id: int
    # User name of the session.
    username: str
    # Host name reported for the session.
    hostname: str
| |
| |
class SSHOUTApiClient:
    """Asynchronous client for the SSHOUT API v1 binary packet protocol.

    The client logs in over SSH with a private key, runs the ``api`` command
    on the server and then exchanges packets on that channel. Every packet is
    a 4-byte big-endian length (counting the 1-byte type plus the payload),
    a 1-byte ``SSHOUTPacketType`` and the payload.

    The SSH connect and command start are run in a worker thread, and reads
    poll ``recv_ready()``, so the event loop is not blocked waiting on the
    network.

    Received chat messages are appended to a bounded ``message_history`` and
    passed to the registered message callbacks. Messages that match
    ``mention_patterns`` and were not sent by this client's own user are also
    passed to the mention callbacks.
    """

    # HELLO / PASS magic and the protocol version this client speaks.
    _MAGIC = b"SSHOUT"
    _API_VERSION = 1

    def __init__(self, hostname: str, port: int, username: str, key_path: str,
                 mention_patterns: Optional[List[str]] = None, timeout: int = 10):
        """Store connection settings; no network I/O happens here.

        Args:
            hostname: SSH host of the SSHOUT server.
            port: SSH port of the SSHOUT server.
            username: SSH login user name.
            key_path: Path of the private key file used to authenticate.
            mention_patterns: Patterns that mark a message as mentioning
                Claude; a built-in list is used when empty or None.
            timeout: SSH connect timeout in seconds.
        """
        self.hostname = hostname
        self.port = port
        self.username = username
        self.key_path = key_path
        self.timeout = timeout
        # Maximum silence (seconds) tolerated while waiting for packet bytes.
        self.read_timeout = 10.0

        self.client: Optional[paramiko.SSHClient] = None
        self.channel: Optional[paramiko.Channel] = None
        self.stdin = None
        self.stdout = None
        self.stderr = None
        self.connected = False

        self.logger = logging.getLogger(__name__)

        # Callbacks may be plain functions or coroutine functions.
        self.message_callbacks: List[Callable[[SSHOUTMessage], None]] = []
        self.mention_callbacks: List[Callable[[SSHOUTMessage], None]] = []

        # Bounded history of received messages, oldest first.
        self.message_history: List[SSHOUTMessage] = []
        self.max_history = 100

        # Identity reported by the server (PASS / ONLINE_USERS_INFO packets).
        self.my_user_id: Optional[int] = None
        self.my_username: Optional[str] = None

        # Result of the last GET_ONLINE_USER request.
        self.online_users: List[SSHOUTUser] = []

        # Strong references to spawned tasks: the event loop only keeps weak
        # references, so an unreferenced task can be garbage-collected.
        self._tasks = set()
        self._listener_task: Optional[asyncio.Task] = None
        self._keep_alive_task: Optional[asyncio.Task] = None

        # Default mention patterns (ASCII and full-width punctuation variants).
        self.mention_patterns = mention_patterns or [
            "@Claude", "@claude", "@CLAUDE",
            "Claude:", "claude:",
            "Claude：", "claude：",
            "Claude,", "claude,",
            "Claude，", "claude，",
        ]

    # ------------------------------------------------------------------
    # Connection management
    # ------------------------------------------------------------------

    async def connect(self) -> bool:
        """Open the SSH session, start the ``api`` frontend and handshake.

        After a successful handshake the online-user list is fetched, and
        then the listener and keep-alive background tasks are started.

        Returns:
            True when connected. False on any failure, in which case all
            resources opened so far are released.
        """
        try:
            self.logger.info(f"🔌 连接到SSHOUT API服务器 {self.hostname}:{self.port}")

            private_key = self._load_private_key()
            if private_key is None:
                return False

            client = paramiko.SSHClient()
            client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
            self.client = client

            # paramiko calls block; run them in a worker thread so the event
            # loop keeps running during the (up to `timeout` s) connect.
            loop = asyncio.get_running_loop()
            await loop.run_in_executor(None, lambda: client.connect(
                hostname=self.hostname,
                port=self.port,
                username=self.username,
                pkey=private_key,
                timeout=self.timeout,
            ))

            # Start the API frontend via an SSH exec request.
            self.stdin, self.stdout, self.stderr = await loop.run_in_executor(
                None, client.exec_command, 'api')

            # stdout's channel carries the binary API data in both directions.
            self.channel = self.stdout.channel
            if not self.channel:
                self.logger.error("❌ 无法创建API数据通道")
                await self.disconnect()
                return False

            # Short timeout so an unexpected blocking recv cannot stall the loop.
            self.channel.settimeout(0.1)
            self.logger.info("🔗 API frontend通道已建立")

            if not await self._handshake():
                self.logger.error("❌ API握手失败")
                await self.disconnect()
                return False

            self.connected = True
            self.logger.info("✅ SSHOUT API连接成功建立")

            # Fetch the user list BEFORE starting the listener: both read the
            # same channel, and two concurrent readers would interleave bytes
            # and corrupt the packet framing.
            await self._get_online_users()

            self._listener_task = self._spawn(self._message_listener())
            self._keep_alive_task = self._spawn(self._keep_alive())
            return True

        except Exception as e:
            self.logger.error(f"❌ SSHOUT API连接失败: {e}")
            await self.disconnect()
            return False

    def _load_private_key(self):
        """Load the private key at ``key_path``.

        ECDSA is tried first (the originally supported type); Ed25519 and RSA
        are tried as fallbacks, skipping any class this paramiko lacks.

        Returns:
            The loaded key object, or None if no supported type parsed the file.
        """
        candidates = (("ECDSAKey", "ECDSA"), ("Ed25519Key", "Ed25519"), ("RSAKey", "RSA"))
        last_error: Optional[Exception] = None
        for class_name, label in candidates:
            key_class = getattr(paramiko, class_name, None)
            if key_class is None:
                continue
            try:
                key = key_class.from_private_key_file(self.key_path)
            except Exception as e:
                last_error = e
                continue
            self.logger.info(f"🔑 成功加载{label}私钥")
            return key
        self.logger.error(f"❌ 加载私钥失败: {last_error}")
        return None

    async def disconnect(self):
        """Stop background tasks and close the API channel and SSH session.

        Safe to call repeatedly and on a partially opened connection. Each
        resource is closed independently; cleanup errors are logged only.
        """
        self.connected = False

        # Cancel the listener / keep-alive loops, but never the task that is
        # currently running this coroutine.
        current = asyncio.current_task()
        for task in (self._listener_task, self._keep_alive_task):
            if task is not None and task is not current and not task.done():
                task.cancel()
        self._listener_task = None
        self._keep_alive_task = None

        resources = [self.stdin, self.stdout, self.stderr, self.channel, self.client]
        self.stdin = self.stdout = self.stderr = None
        self.channel = None
        self.client = None

        for resource in resources:
            if resource is None:
                continue
            try:
                resource.close()
            except Exception as e:
                self.logger.error(f"❌ 断开连接时出错: {e}")

        self.logger.info("🔌 SSHOUT API连接已断开")

    def _spawn(self, coro) -> asyncio.Task:
        """Schedule ``coro`` as a task and keep a strong reference to it."""
        task = asyncio.create_task(coro)
        self._tasks.add(task)
        task.add_done_callback(self._on_task_done)
        return task

    def _on_task_done(self, task: asyncio.Task):
        """Release the reference to a finished task and log its failure, if any."""
        self._tasks.discard(task)
        if task.cancelled():
            return
        exc = task.exception()
        if exc is not None:
            self.logger.error(f"❌ 后台任务异常: {exc}")

    # ------------------------------------------------------------------
    # Packet I/O
    # ------------------------------------------------------------------

    async def _handshake(self) -> bool:
        """Send HELLO and validate the server's PASS reply.

        HELLO payload: 6-byte magic ``b"SSHOUT"`` + 2-byte big-endian version.
        PASS payload: the same magic and version, optionally followed by a
        1-byte user-name length and the user name (stored in ``my_username``).

        Returns:
            True if the server accepted the handshake.
        """
        try:
            hello_data = self._MAGIC + struct.pack(">H", self._API_VERSION)
            await self._send_packet(SSHOUTPacketType.HELLO, hello_data)

            packet_type, packet_data = await self._receive_packet()
            if packet_type != SSHOUTPacketType.PASS:
                if packet_type == SSHOUTPacketType.ERROR:
                    # Log the server's reason before failing.
                    await self._process_error(packet_data)
                self.logger.error(f"❌ 握手失败: 期望PASS包,收到{packet_type}")
                return False

            if len(packet_data) < 8:
                self.logger.error("❌ PASS包格式错误")
                return False

            if packet_data[:6] != self._MAGIC:
                self.logger.error("❌ PASS包magic不匹配")
                return False

            version = struct.unpack_from(">H", packet_data, 6)[0]
            if version != self._API_VERSION:
                self.logger.error(f"❌ API版本不匹配: {version}")
                return False

            if len(packet_data) > 8:
                username_length = packet_data[8]
                if len(packet_data) >= 9 + username_length:
                    self.my_username = packet_data[9:9 + username_length].decode(
                        'utf-8', errors='replace')
                    self.logger.info(f"✅ 握手成功,用户名: {self.my_username}")

            return True

        except Exception as e:
            self.logger.error(f"❌ 握手过程出错: {e}")
            return False

    async def _send_packet(self, packet_type: SSHOUTPacketType, data: bytes):
        """Frame ``data`` as one packet of ``packet_type`` and write it.

        Raises:
            Exception: if the channel is not connected, or the write fails.
        """
        if not self.channel:
            raise Exception("Channel未连接")

        # The length field counts the type byte plus the payload.
        packet_length = len(data) + 1
        packet = struct.pack(">IB", packet_length, int(packet_type)) + data

        # Channel.send() may write only part of the buffer; sendall() keeps
        # writing until everything is sent (or raises), preserving framing.
        self.channel.sendall(packet)
        self.logger.debug(f"📤 发送包: type={packet_type.name}, length={packet_length}")

    async def _receive_packet(self) -> Tuple[SSHOUTPacketType, bytes]:
        """Read the next packet whose type is known to this client.

        The whole payload is read before the type byte is interpreted, so a
        packet of an unknown type is consumed and skipped instead of leaving
        unread bytes that would desynchronize every following packet.

        Returns:
            ``(packet_type, payload)``.

        Raises:
            Exception: channel not connected/closed, read timeout, or a
                length field smaller than 1 (too small for the type byte).
        """
        if not self.channel:
            raise Exception("Channel未连接")

        while True:
            length_data = await self._read_exact(4)
            packet_length = struct.unpack(">I", length_data)[0]
            if packet_length < 1:
                raise Exception(f"无效的包长度: {packet_length}")

            raw_type = (await self._read_exact(1))[0]
            data_length = packet_length - 1
            data = await self._read_exact(data_length) if data_length > 0 else b""

            try:
                packet_type = SSHOUTPacketType(raw_type)
            except ValueError:
                self.logger.warning(f"⚠️ 跳过未知包类型: {raw_type}, length={packet_length}")
                continue

            self.logger.debug(f"📥 接收包: type={packet_type.name}, length={packet_length}")
            return packet_type, data

    async def _read_exact(self, length: int) -> bytes:
        """Read exactly ``length`` bytes without blocking the event loop.

        Polls ``recv_ready()`` every 10 ms. The ``read_timeout`` deadline is
        restarted whenever data arrives, so it bounds silence, not the total
        read time.

        Raises:
            Exception: "连接已关闭" when the channel is closed or reached EOF,
                "读取数据超时" when no data arrives within ``read_timeout``.
        """
        chunks = []
        received = 0
        deadline = time.monotonic() + self.read_timeout

        while received < length:
            channel = self.channel
            if not channel or channel.closed:
                raise Exception("连接已关闭")

            if not channel.recv_ready():
                # EOF with an empty buffer means no more bytes will ever come.
                if getattr(channel, "eof_received", False) and not channel.recv_ready():
                    raise Exception("连接已关闭")
                if time.monotonic() > deadline:
                    raise Exception("读取数据超时")
                await asyncio.sleep(0.01)
                continue

            try:
                chunk = channel.recv(length - received)
            except Exception as e:
                if "closed" in str(e).lower():
                    raise Exception("连接已关闭")
                raise
            if not chunk:
                raise Exception("连接已关闭")

            chunks.append(chunk)
            received += len(chunk)
            deadline = time.monotonic() + self.read_timeout

        return b"".join(chunks)

    # ------------------------------------------------------------------
    # Requests
    # ------------------------------------------------------------------

    async def _get_online_users(self):
        """Request the online-user list and store it.

        Packets of other types that arrive before ONLINE_USERS_INFO are sent
        to their normal handlers instead of being dropped. Results go to
        ``my_user_id`` and ``online_users``. Errors are logged, not raised.
        """
        try:
            await self._send_packet(SSHOUTPacketType.GET_ONLINE_USER, b"")

            max_attempts = 10
            for _ in range(max_attempts):
                packet_type, packet_data = await self._receive_packet()
                if packet_type == SSHOUTPacketType.ONLINE_USERS_INFO:
                    break
                await self._dispatch_packet(packet_type, packet_data)
            else:
                self.logger.error("❌ 获取在线用户失败: 超过最大尝试次数")
                return

            self._parse_online_users(packet_data)

        except Exception as e:
            self.logger.error(f"❌ 获取在线用户失败: {e}")

    def _parse_online_users(self, packet_data: bytes):
        """Decode an ONLINE_USERS_INFO payload into ``my_user_id``/``online_users``.

        Layout: 2-byte own id, 2-byte count, then per user a 2-byte id,
        1-byte name length + name, 1-byte host length + host. Parsing stops
        at the first truncated entry.
        """
        if len(packet_data) < 4:
            self.logger.error("❌ 在线用户信息包格式错误")
            return

        self.my_user_id, user_count = struct.unpack_from(">HH", packet_data, 0)
        self.logger.info(f"👥 在线用户数量: {user_count}, 我的ID: {self.my_user_id}")

        users: List[SSHOUTUser] = []
        size = len(packet_data)
        offset = 4
        for _ in range(user_count):
            if offset + 4 > size:
                break

            user_id = struct.unpack_from(">H", packet_data, offset)[0]
            username_length = packet_data[offset + 2]
            username_end = offset + 3 + username_length
            if username_end + 1 > size:
                break
            username = packet_data[offset + 3:username_end].decode('utf-8', errors='replace')

            hostname_length = packet_data[username_end]
            hostname_end = username_end + 1 + hostname_length
            if hostname_end > size:
                break
            hostname = packet_data[username_end + 1:hostname_end].decode('utf-8', errors='replace')

            users.append(SSHOUTUser(id=user_id, username=username, hostname=hostname))
            self.logger.debug(f"👤 用户: {user_id} - {username}@{hostname}")
            offset = hostname_end

        self.online_users = users

    async def send_message(self, to_user: str, message: str,
                           message_type: SSHOUTMessageType = SSHOUTMessageType.PLAIN) -> bool:
        """Send a SEND_MESSAGE packet.

        Payload: 1-byte recipient length + recipient, 1-byte message type,
        4-byte message length + message (all text UTF-8).

        Args:
            to_user: Recipient user name ("GLOBAL" for a broadcast).
            message: Message text.
            message_type: Message type; any int-compatible value is accepted.

        Returns:
            True if the packet was written, False otherwise (error is logged).
        """
        if not self.connected or not self.channel:
            self.logger.error("❌ 未连接到SSHOUT API服务器")
            return False

        try:
            to_user_bytes = to_user.encode('utf-8')
            if len(to_user_bytes) > 255:
                # The recipient length field is a single byte.
                self.logger.error(f"❌ 发送消息失败: 接收者名称过长 ({len(to_user_bytes)} 字节)")
                return False
            message_bytes = message.encode('utf-8')

            data = b"".join((
                struct.pack("B", len(to_user_bytes)),
                to_user_bytes,
                struct.pack(">BI", int(message_type), len(message_bytes)),
                message_bytes,
            ))

            await self._send_packet(SSHOUTPacketType.SEND_MESSAGE, data)
            self.logger.info(f"📤 发送消息到{to_user}: {message}")
            return True

        except Exception as e:
            self.logger.error(f"❌ 发送消息失败: {e}")
            return False

    async def send_global_message(self, message: str,
                                  message_type: SSHOUTMessageType = SSHOUTMessageType.PLAIN) -> bool:
        """Broadcast ``message`` to the "GLOBAL" recipient."""
        return await self.send_message("GLOBAL", message, message_type)

    # ------------------------------------------------------------------
    # Background tasks
    # ------------------------------------------------------------------

    async def _message_listener(self):
        """Read and dispatch packets until the connection is lost or closed.

        Errors whose text indicates a closed connection stop the loop;
        other errors are logged and the loop retries after one second.
        """
        self.logger.info("👂 启动SSHOUT API消息监听...")

        while self.connected:
            try:
                channel = self.channel
                if not channel or channel.closed:
                    self.logger.error("❌ 检测到连接断开,停止消息监听")
                    self.connected = False
                    break

                if not channel.recv_ready():
                    await asyncio.sleep(0.1)
                    continue

                packet_type, packet_data = await self._receive_packet()
                await self._dispatch_packet(packet_type, packet_data)

            except Exception as e:
                if self.connected:
                    self.logger.error(f"❌ 消息监听错误: {e}")
                    if "连接已关闭" in str(e) or "closed" in str(e).lower():
                        self.logger.error("❌ 连接断开,停止消息监听")
                        self.connected = False
                        break
                    await asyncio.sleep(1)

        self.logger.info("👂 SSHOUT API消息监听已停止")

    async def _dispatch_packet(self, packet_type: SSHOUTPacketType, packet_data: bytes):
        """Route a server packet to its handler; unhandled types are logged at debug."""
        handlers = {
            SSHOUTPacketType.RECEIVE_MESSAGE: self._process_receive_message,
            SSHOUTPacketType.USER_STATE_CHANGE: self._process_user_state_change,
            SSHOUTPacketType.ERROR: self._process_error,
            SSHOUTPacketType.MOTD: self._process_motd,
        }
        handler = handlers.get(packet_type)
        if handler is None:
            self.logger.debug(f"🔍 收到未处理的包类型: {packet_type}")
            return
        await handler(packet_data)

    async def _keep_alive(self):
        """Every 30 s, check the channel and send an SSH ignore packet.

        Clears ``connected`` and stops when the channel or the SSH transport
        is found closed, or when the check itself fails.
        """
        self.logger.debug("💓 启动连接保活任务...")

        while self.connected:
            try:
                await asyncio.sleep(30)

                if not self.connected:
                    break

                if not self.channel or self.channel.closed:
                    self.logger.error("❌ 检测到连接断开(保活检查)")
                    self.connected = False
                    break

                client = self.client
                if client is not None:
                    transport = client.get_transport()
                    if transport and transport.is_active():
                        transport.send_ignore()
                        self.logger.debug("💓 发送SSH保活包")
                    else:
                        self.logger.error("❌ SSH传输层连接断开")
                        self.connected = False
                        break

            except Exception as e:
                if self.connected:
                    self.logger.error(f"❌ 保活任务错误: {e}")
                    self.connected = False
                break

        self.logger.debug("💓 连接保活任务已停止")

    # ------------------------------------------------------------------
    # Packet handlers
    # ------------------------------------------------------------------

    def _parse_receive_message(self, data: bytes) -> Optional[SSHOUTMessage]:
        """Decode a RECEIVE_MESSAGE payload; return None if it is truncated.

        Layout: 8-byte timestamp, 1-byte sender length + sender, 1-byte
        recipient length + recipient, 1-byte message type, 4-byte message
        length + message. An unknown message type raises ValueError.
        """
        size = len(data)
        # 8 (time) + 1 + 1 (empty names) + 1 (type) + 4 (length) = 15 bytes.
        if size < 15:
            return None

        offset = 0
        timestamp_int = struct.unpack_from(">Q", data, offset)[0]
        offset += 8

        from_user_length = data[offset]
        offset += 1
        # Sender, then at least the recipient length byte.
        if offset + from_user_length + 1 > size:
            return None
        from_user = data[offset:offset + from_user_length].decode('utf-8', errors='replace')
        offset += from_user_length

        to_user_length = data[offset]
        offset += 1
        # Recipient, then the type byte and the 4-byte message length.
        if offset + to_user_length + 5 > size:
            return None
        to_user = data[offset:offset + to_user_length].decode('utf-8', errors='replace')
        offset += to_user_length

        message_type = SSHOUTMessageType(data[offset])
        offset += 1

        message_length = struct.unpack_from(">I", data, offset)[0]
        offset += 4
        if offset + message_length > size:
            return None
        content = data[offset:offset + message_length].decode('utf-8', errors='replace')

        return SSHOUTMessage(
            timestamp=datetime.fromtimestamp(timestamp_int),
            from_user=from_user,
            to_user=to_user,
            message_type=message_type,
            content=content,
        )

    def _is_own_message(self, message: SSHOUTMessage) -> bool:
        """True if ``message`` was sent by the user name this session logged in as."""
        return bool(self.my_username) and message.from_user == self.my_username

    def _invoke_callbacks(self, callbacks, message: SSHOUTMessage, error_prefix: str):
        """Call each callback with ``message``; coroutine functions become tasks."""
        for callback in list(callbacks):
            try:
                if asyncio.iscoroutinefunction(callback):
                    self._spawn(callback(message))
                else:
                    callback(message)
            except Exception as e:
                self.logger.error(f"{error_prefix}: {e}")

    async def _process_receive_message(self, data: bytes):
        """Decode a chat message, record it and notify callbacks.

        ``is_mention`` is decided before any callback runs, so message
        callbacks see the final value. Messages from this client's own user
        never count as mentions, so the bot cannot trigger itself.
        """
        try:
            message = self._parse_receive_message(data)
            if message is None:
                self.logger.error("❌ 接收消息包格式错误")
                return

            message.is_mention = (not self._is_own_message(message)
                                  and self._is_claude_mention(message.content))

            self.message_history.append(message)
            overflow = len(self.message_history) - self.max_history
            if overflow > 0:
                del self.message_history[:overflow]

            self._invoke_callbacks(self.message_callbacks, message, "❌ 消息回调错误")

            if message.is_mention:
                self.logger.info(f"📢 检测到@Claude提及: {message.from_user}: {message.content}")
                self._invoke_callbacks(self.mention_callbacks, message, "❌ 提及回调错误")

        except Exception as e:
            self.logger.error(f"❌ 处理接收消息错误: {e}")

    async def _process_user_state_change(self, data: bytes):
        """Log a USER_STATE_CHANGE packet (1-byte state, 1-byte name length, name).

        State 1 is logged as online; any other value as offline.
        """
        try:
            if len(data) < 2:
                return

            state = data[0]
            username_length = data[1]
            username = data[2:2 + username_length].decode('utf-8', errors='replace')

            status = "上线" if state == 1 else "下线"
            self.logger.info(f"👤 用户状态变化: {username} {status}")

        except Exception as e:
            self.logger.error(f"❌ 处理用户状态变化错误: {e}")

    async def _process_error(self, data: bytes):
        """Log an ERROR packet (4-byte code, 4-byte text length, text)."""
        try:
            if len(data) < 8:
                return

            error_code, message_length = struct.unpack_from(">II", data, 0)
            error_message = data[8:8 + message_length].decode('utf-8', errors='replace')
            try:
                code_name = SSHOUTErrorCode(error_code).name
            except ValueError:
                code_name = "UNKNOWN"

            self.logger.error(f"❌ 服务器错误 [{error_code} {code_name}]: {error_message}")

        except Exception as e:
            self.logger.error(f"❌ 处理错误包错误: {e}")

    async def _process_motd(self, data: bytes):
        """Log a MOTD packet (4-byte text length, text)."""
        try:
            if len(data) < 4:
                return

            message_length = struct.unpack_from(">I", data, 0)[0]
            motd_message = data[4:4 + message_length].decode('utf-8', errors='replace')

            self.logger.info(f"📢 每日消息: {motd_message}")

        except Exception as e:
            self.logger.error(f"❌ 处理MOTD包错误: {e}")

    def _is_claude_mention(self, content: str) -> bool:
        """Return True if ``content`` matches any mention pattern (case-insensitive).

        * ``@name`` patterns must not be followed by an ASCII letter/digit,
          so they may sit directly against CJK text.
        * Patterns ending in ``:``/``：``/``,``/``，`` match literally.
        * Any other pattern must end on a word boundary.

        Empty patterns are ignored; they would otherwise match every message.
        """
        import re

        for pattern in self.mention_patterns:
            if not pattern:
                continue
            escaped = re.escape(pattern)
            if pattern.startswith('@'):
                regex_pattern = escaped + r'(?![a-zA-Z0-9])'
            elif pattern.endswith((':', '：', ',', '，')):
                regex_pattern = escaped
            else:
                regex_pattern = escaped + r'\b'

            if re.search(regex_pattern, content, re.IGNORECASE):
                return True

        return False

    # ------------------------------------------------------------------
    # Public helpers
    # ------------------------------------------------------------------

    def add_message_callback(self, callback: Callable[[SSHOUTMessage], None]):
        """Register a callback (sync or async) for every received message."""
        self.message_callbacks.append(callback)

    def add_mention_callback(self, callback: Callable[[SSHOUTMessage], None]):
        """Register a callback (sync or async) for messages mentioning Claude."""
        self.mention_callbacks.append(callback)

    def get_recent_messages(self, count: int = 10) -> List[SSHOUTMessage]:
        """Return up to ``count`` newest messages, oldest first ([] if count <= 0)."""
        if count <= 0:
            # history[-0:] would return the whole list.
            return []
        return self.message_history[-count:]

    def get_context_messages(self, before_time: datetime, count: int = 5) -> List[SSHOUTMessage]:
        """Return up to ``count`` newest messages strictly older than ``before_time``.

        The result is in chronological order; empty when ``count <= 0``.
        """
        if count <= 0:
            return []
        context_messages = []
        for msg in reversed(self.message_history):
            if msg.timestamp < before_time:
                context_messages.append(msg)
                if len(context_messages) >= count:
                    break
        context_messages.reverse()
        return context_messages

    def get_connection_status(self) -> Dict[str, Any]:
        """Return a status summary; ``recent_messages`` holds up to 3 truncated entries."""
        if not self.connected:
            return {
                'connected': False,
                'server': None,
                'message_count': 0,
                'my_user_id': None,
                'my_username': None,
                'recent_messages': [],
            }

        return {
            'connected': self.connected,
            'server': f"{self.hostname}:{self.port}",
            'message_count': len(self.message_history),
            'my_user_id': self.my_user_id,
            'my_username': self.my_username,
            'recent_messages': [
                {
                    'timestamp': msg.timestamp.strftime('%H:%M:%S'),
                    'from_user': msg.from_user,
                    'content': msg.content[:50] + '...' if len(msg.content) > 50 else msg.content,
                }
                for msg in self.get_recent_messages(3)
            ],
        }
| |
| |
class SSHOUTApiIntegration:
    """Bridge that connects an SSHOUTApiClient to the agent core.

    Settings are read from the config manager's SSHOUT section. Messages
    mentioning Claude are turned into a prompt (with recent chat context),
    sent to ``agent_core.process_user_input`` and the cleaned reply is
    broadcast back to the chat room.
    """

    def __init__(self, agent_core, config_name: Optional[str] = None):
        """Load and validate the SSHOUT configuration; no network I/O.

        Raises:
            ValueError: a required config section/key is missing.
            FileNotFoundError: the private key file does not exist.
        """
        self.agent = agent_core
        self.client: Optional[SSHOUTApiClient] = None
        self.logger = logging.getLogger(__name__)

        self.config_manager = get_config_manager(config_name)
        self.sshout_config = self.config_manager.get_sshout_config()

        self._validate_config()

    def _validate_config(self):
        """Check that the server and SSH-key settings needed to connect are present.

        ``~`` in the key path is expanded, the same way it is when connecting.

        Raises:
            ValueError: the config, a section or a required key is missing.
            FileNotFoundError: the private key path is not an existing file.
        """
        config = self.sshout_config
        if config is None:
            raise ValueError("SSHOUT配置缺失")

        for section in ('server', 'ssh_key'):
            if section not in config:
                raise ValueError(f"SSHOUT配置缺少必需的段落: {section}")

        server_config = config['server']
        for key in ('hostname', 'port', 'username'):
            if key not in server_config:
                raise ValueError(f"SSHOUT服务器配置缺少必需的键: {key}")

        ssh_config = config['ssh_key']
        if 'private_key_path' not in ssh_config:
            raise ValueError("SSHOUT配置缺少SSH私钥路径")

        key_path = os.path.expanduser(ssh_config['private_key_path'])
        if not os.path.isfile(key_path):
            raise FileNotFoundError(f"SSH私钥文件不存在: {key_path}")

    async def connect_to_sshout_api(self) -> bool:
        """Create a client from the config, register callbacks and connect.

        Any previous client is disconnected first so reconnecting does not
        leak an SSH session.

        Returns:
            True on success, False on failure (error is logged).
        """
        try:
            if self.client:
                await self.disconnect_from_sshout_api()

            server_config = self.sshout_config['server']
            ssh_config = self.sshout_config['ssh_key']
            mention_patterns = self.sshout_config.get('mention_patterns', [])
            timeout = ssh_config.get('timeout', 10)

            self.client = SSHOUTApiClient(
                hostname=server_config['hostname'],
                port=int(server_config['port']),
                username=server_config['username'],
                key_path=os.path.expanduser(ssh_config['private_key_path']),
                mention_patterns=mention_patterns,
                timeout=timeout,
            )

            self.client.add_message_callback(self._on_message_received)
            self.client.add_mention_callback(self._on_claude_mentioned)

            success = await self.client.connect()
            if success:
                self.logger.info("🎉 SSHOUT API集成已启用")

            return success

        except Exception as e:
            self.logger.error(f"❌ SSHOUT API连接失败: {e}")
            return False

    async def disconnect_from_sshout_api(self):
        """Disconnect and drop the current client, if any."""
        if self.client:
            await self.client.disconnect()
            self.client = None

    def _on_message_received(self, message: "SSHOUTMessage"):
        """Log every received chat message."""
        self.logger.info(f"💬 [{message.timestamp.strftime('%H:%M:%S')}] "
                         f"{message.from_user} -> {message.to_user}: {message.content}")

    async def _on_claude_mentioned(self, message: "SSHOUTMessage"):
        """Answer a message that mentions Claude with a global chat reply.

        The prompt includes up to ``sshout.message.context_count`` (default 5)
        earlier messages. Empty agent replies are not sent. Errors are logged.
        """
        client = self.client
        if client is None:
            return

        try:
            self.logger.info(f"🎯 处理@Claude提及: {message.from_user}: {message.content}")

            context_count = self.config_manager.get('sshout.message.context_count', 5)
            context_messages = client.get_context_messages(
                before_time=message.timestamp,
                count=context_count,
            )

            context_text = ""
            if context_messages:
                lines = [
                    f"[{ctx_msg.timestamp.strftime('%H:%M:%S')}] {ctx_msg.from_user}: {ctx_msg.content}\n"
                    for ctx_msg in context_messages
                ]
                context_text = "聊天室上下文消息:\n" + "".join(lines) + "\n"

            time_str = message.timestamp.strftime('%H:%M:%S')
            full_prompt = (
                f"{context_text}当前消息:\n"
                f"[{time_str}] {message.from_user}: {message.content}\n"
                "\n"
                f"请基于以上聊天室上下文,简洁地回复{message.from_user}的问题或评论。回复要自然、友好,适合聊天室环境。\n"
                f"重要:请直接回复内容,不要在回复开头添加@{message.from_user}或任何@前缀。"
            )

            response = await self.agent.process_user_input(full_prompt)
            clean_response = self._clean_response_for_sshout(response)
            if not clean_response:
                self.logger.warning("⚠️ Agent返回空回复,跳过发送")
                return

            # The agent call may take long; use the client that is current now.
            if self.client:
                success = await self.client.send_global_message(clean_response)
                if success:
                    self.logger.info(f"✅ 已回复 {message.from_user}")
                else:
                    self.logger.error("❌ 回复失败")

        except Exception as e:
            self.logger.error(f"❌ 处理@Claude提及时出错: {e}")

    def _clean_response_for_sshout(self, response: str) -> str:
        """Turn an agent reply into a single plain-text chat line.

        Removes ``**bold**``, ``*italic*`` and `` `code` `` markers and
        collapses all whitespace to single spaces. If
        ``sshout.message.max_reply_length`` is a positive number, the
        stripped text is cut to that many characters and ``...`` is appended.
        A missing or empty reply yields "".
        """
        import re

        if not response:
            return ""
        text = str(response)

        text = re.sub(r'\*\*(.*?)\*\*', r'\1', text)  # bold
        text = re.sub(r'\*(.*?)\*', r'\1', text)      # italic
        text = re.sub(r'`(.*?)`', r'\1', text)        # inline code

        # \s already covers newlines; strip before measuring the length.
        text = re.sub(r'\s+', ' ', text).strip()

        max_length = int(self.config_manager.get('sshout.message.max_reply_length', 0) or 0)
        if max_length > 0 and len(text) > max_length:
            text = text[:max_length].rstrip() + "..."

        return text

    def get_connection_status(self) -> Dict[str, Any]:
        """Return the client's status plus the API version."""
        if not self.client:
            return {
                'connected': False,
                'server': None,
                'message_count': 0,
                'my_user_id': None,
                'my_username': None,
                'recent_messages': [],
                'api_version': '1.0',
            }

        return {
            **self.client.get_connection_status(),
            'api_version': '1.0',
        }

    async def send_message(self, message: str) -> bool:
        """Broadcast ``message`` manually; False when not connected."""
        if not self.client or not self.client.connected:
            return False

        return await self.client.send_global_message(message)