blob: 5aabc0c3a6f20d7446bd3ed25c81e39dd27ab5e7 [file] [log] [blame] [raw]
"""
SSHOUT API客户端模块
基于SSHOUT API v1二进制包协议的客户端实现
参考: references/sshout/API.htm
"""
import asyncio
import paramiko
import struct
import logging
import time
import os
from typing import Optional, List, Dict, Callable, Tuple, Any
from dataclasses import dataclass
from datetime import datetime
from enum import IntEnum
from ..utils.config import get_config_manager
class SSHOUTPacketType(IntEnum):
    """Packet type codes of the SSHOUT API v1 protocol.

    The client-to-server codes are 0x01-0x04; the server-to-client codes
    start at 0x80 (128).
    """

    # ---- client -> server ----
    HELLO = 0x01
    GET_ONLINE_USER = 0x02
    SEND_MESSAGE = 0x03
    GET_MOTD = 0x04

    # ---- server -> client ----
    PASS = 0x80
    ONLINE_USERS_INFO = 0x81
    RECEIVE_MESSAGE = 0x82
    USER_STATE_CHANGE = 0x83
    ERROR = 0x84
    MOTD = 0x85
class SSHOUTMessageType(IntEnum):
    """Body type byte carried in SEND_MESSAGE and RECEIVE_MESSAGE packets."""

    PLAIN = 0x01
    RICH = 0x02
    IMAGE = 0x03
class SSHOUTErrorCode(IntEnum):
    """Numeric error codes of the SSHOUT API.

    These are the values expected in the 32-bit code field of an ERROR packet.
    """

    SERVER_CLOSED = 0x01
    LOCAL_PACKET_CORRUPT = 0x02
    LOCAL_PACKET_TOO_LARGE = 0x03
    OUT_OF_MEMORY = 0x04
    INTERNAL_ERROR = 0x05
    USER_NOT_FOUND = 0x06
    MOTD_NOT_AVAILABLE = 0x07
@dataclass
class SSHOUTMessage:
    """A single chat message decoded from a RECEIVE_MESSAGE packet.

    Attributes:
        timestamp: Send time converted from the packet's integer Unix time.
        from_user: User name of the sender.
        to_user: User name of the recipient.
        message_type: Body type of ``content``.
        content: Message body decoded as UTF-8 text.
        is_mention: True once the message is recognised as addressed to Claude.
    """

    timestamp: datetime
    from_user: str
    to_user: str
    message_type: SSHOUTMessageType
    content: str
    is_mention: bool = False  # defaults to "not a mention"
@dataclass
class SSHOUTUser:
    """One entry of the online-user list.

    Attributes:
        id: Numeric user id assigned by the server.
        username: SSHOUT user name.
        hostname: Host name reported for that user.
    """

    id: int
    username: str
    hostname: str
class SSHOUTApiClient:
    """Async client for the SSHOUT API v1 binary packet protocol.

    The client logs in over SSH with an ECDSA private key, runs the ``api``
    command and exchanges packets on that exec channel.  Every packet is a
    big-endian uint32 length (counting the type byte plus the payload), a
    one-byte SSHOUTPacketType code, then the payload.

    Received chat messages are kept in a bounded history and handed to the
    registered message callbacks; messages that match ``mention_patterns`` are
    additionally handed to the mention callbacks.
    """

    # Idle seconds _read_exact waits for more bytes before raising a timeout.
    READ_TIMEOUT = 10.0
    # Sleep between recv_ready() polls while waiting for data.
    POLL_INTERVAL = 0.01

    def __init__(self, hostname: str, port: int, username: str, key_path: str,
                 mention_patterns: Optional[List[str]] = None, timeout: int = 10):
        """Store connection settings; no network I/O happens here.

        Args:
            hostname: SSHOUT server host name.
            port: SSH port of the server.
            username: SSH login name.
            key_path: Path to the ECDSA private key (``~`` is expanded).
            mention_patterns: Patterns marking a message as addressed to Claude;
                a built-in list is used when None or empty.
            timeout: Connect timeout in seconds passed to paramiko.
        """
        self.hostname = hostname
        self.port = port
        self.username = username
        self.key_path = key_path
        self.timeout = timeout
        self.client: Optional[paramiko.SSHClient] = None
        self.channel: Optional[paramiko.Channel] = None
        self.stdin = None
        self.stdout = None
        self.stderr = None
        self.connected = False
        self.logger = logging.getLogger(__name__)

        # Registered handlers; coroutine functions are scheduled as tasks.
        self.message_callbacks: List[Callable[[SSHOUTMessage], None]] = []
        self.mention_callbacks: List[Callable[[SSHOUTMessage], None]] = []

        # Most recent messages, oldest first, capped at max_history entries.
        self.message_history: List[SSHOUTMessage] = []
        self.max_history = 100

        # Identity assigned by the server (PASS / ONLINE_USERS_INFO packets).
        self.my_user_id: Optional[int] = None
        self.my_username: Optional[str] = None

        # Strong references to running tasks: the event loop keeps only weak
        # references, so an unreferenced task can vanish mid-execution.
        self._worker_tasks: set = set()    # message listener + keep-alive
        self._callback_tasks: set = set()  # async message/mention callbacks

        # Matched case-insensitively by _is_claude_mention.
        self.mention_patterns = mention_patterns or [
            "@Claude", "@claude", "@CLAUDE",
            "Claude:", "claude:",
            "Claude,", "claude,",
            "Claude，", "claude，",  # full-width comma used in Chinese text
        ]

    async def connect(self) -> bool:
        """Open the SSH session, perform the API handshake and start background tasks.

        Returns:
            True when the session is ready, False on any failure.  Every failure
            path releases whatever SSH resources were already opened.
        """
        try:
            if self.client is not None:
                # Reconnecting on the same object: close the old session first.
                await self.disconnect()

            self.logger.info(f"🔌 连接到SSHOUT API服务器 {self.hostname}:{self.port}")

            # Load the key before creating any SSH object so a bad key leaks nothing.
            try:
                private_key = paramiko.ECDSAKey.from_private_key_file(
                    os.path.expanduser(self.key_path))
                self.logger.info("🔑 成功加载ECDSA私钥")
            except Exception as e:
                self.logger.error(f"❌ 加载私钥失败: {e}")
                return False

            # paramiko's connect/exec calls block; run them in a worker thread so
            # the event loop stays responsive while the TCP/SSH handshake runs.
            loop = asyncio.get_running_loop()
            await loop.run_in_executor(None, self._open_api_channel, private_key)
            if self.channel is None:
                self.logger.error("❌ 无法创建API数据通道")
                await self.disconnect()
                return False
            self.logger.info("🔗 API frontend通道已建立")

            if not await self._handshake():
                self.logger.error("❌ API握手失败")
                await self.disconnect()
                return False

            self.connected = True
            self.logger.info("✅ SSHOUT API连接成功建立")

            # Query the user list BEFORE starting the listener: both read the same
            # channel, and two concurrent readers would steal (or split) packets.
            await self._get_online_users()

            self._track_task(self._message_listener(), self._worker_tasks, "消息监听错误")
            self._track_task(self._keep_alive(), self._worker_tasks, "保活任务错误")
            return True
        except Exception as e:
            self.logger.error(f"❌ SSHOUT API连接失败: {e}")
            await self.disconnect()
            return False

    def _open_api_channel(self, private_key) -> None:
        """Blocking helper: SSH-connect and run the ``api`` command.

        Called through ``run_in_executor``.  Sets ``self.client``, the three
        std streams and ``self.channel`` (the channel behind ``stdout``).
        """
        self.client = paramiko.SSHClient()
        # NOTE(review): AutoAddPolicy accepts any host key (no MITM protection);
        # consider loading known_hosts instead.
        self.client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
        self.client.connect(
            hostname=self.hostname,
            port=self.port,
            username=self.username,
            pkey=private_key,
            timeout=self.timeout,
        )
        # The API frontend is requested through an SSH exec command.
        self.stdin, self.stdout, self.stderr = self.client.exec_command('api')
        self.channel = self.stdout.channel
        if self.channel is not None:
            # A short timeout, not true non-blocking mode: recv() is only called
            # after recv_ready() reports buffered data.
            self.channel.settimeout(0.1)

    async def disconnect(self):
        """Stop background tasks and close every SSH resource. Safe to call repeatedly."""
        self.connected = False

        current = asyncio.current_task()
        for task in list(self._worker_tasks):
            if task is not current:
                task.cancel()

        for attr in ('stdin', 'stdout', 'stderr', 'channel', 'client'):
            resource = getattr(self, attr)
            setattr(self, attr, None)
            if resource is None:
                continue
            try:
                resource.close()
            except Exception as e:
                # Keep closing the remaining resources even if one of them fails.
                self.logger.error(f"❌ 断开连接时出错: {e}")

        self.logger.info("🔌 SSHOUT API连接已断开")

    def _track_task(self, coro, registry: set, error_label: str) -> asyncio.Task:
        """Schedule ``coro`` and keep a strong reference to it until it finishes.

        Unexpected exceptions are logged with ``error_label`` instead of ending
        up as "Task exception was never retrieved" warnings.
        """
        task = asyncio.create_task(coro)
        registry.add(task)

        def _finished(done: asyncio.Task) -> None:
            registry.discard(done)
            if not done.cancelled() and done.exception() is not None:
                self.logger.error(f"❌ {error_label}: {done.exception()}")

        task.add_done_callback(_finished)
        return task

    def _channel_is_dead(self) -> bool:
        """True when the channel is gone, closed, or at EOF with nothing left to read."""
        channel = self.channel
        if channel is None or channel.closed:
            return True
        # Read the EOF flag BEFORE recv_ready(): data always precedes EOF, so
        # "EOF already seen and buffer empty" means no more bytes will arrive.
        eof = bool(getattr(channel, 'eof_received', False))
        return eof and not channel.recv_ready()

    async def _handshake(self) -> bool:
        """Exchange HELLO/PASS with the server and record our SSHOUT user name.

        HELLO carries the magic b"SSHOUT" and protocol version 1 (uint16 BE).
        PASS must echo the same magic and version; it may be followed by a
        one-byte length and the UTF-8 user name of this session.

        Returns:
            True if the server accepted the handshake, False otherwise.
        """
        try:
            hello_data = b"SSHOUT" + struct.pack(">H", 1)
            await self._send_packet(SSHOUTPacketType.HELLO, hello_data)

            packet_type, packet_data = await self._receive_packet()
            if packet_type != SSHOUTPacketType.PASS:
                self.logger.error(f"❌ 握手失败: 期望PASS包,收到{packet_type}")
                return False

            if len(packet_data) < 8:
                self.logger.error("❌ PASS包格式错误")
                return False
            if packet_data[:6] != b"SSHOUT":
                self.logger.error("❌ PASS包magic不匹配")
                return False
            version = struct.unpack_from(">H", packet_data, 6)[0]
            if version != 1:
                self.logger.error(f"❌ API版本不匹配: {version}")
                return False

            if len(packet_data) > 8:
                username_length = packet_data[8]
                name_end = 9 + username_length
                if len(packet_data) >= name_end:
                    self.my_username = packet_data[9:name_end].decode('utf-8')

            self.logger.info(f"✅ 握手成功,用户名: {self.my_username}")
            return True
        except Exception as e:
            self.logger.error(f"❌ 握手过程出错: {e}")
            return False

    async def _send_packet(self, packet_type: SSHOUTPacketType, data: bytes):
        """Frame ``data`` as one packet and write all of it to the channel.

        Raises:
            Exception: when no channel is open, or the channel write fails.
        """
        if not self.channel:
            raise Exception("Channel未连接")
        # The length field counts the type byte plus the payload.
        packet_length = len(data) + 1
        packet = struct.pack(">IB", packet_length, int(packet_type)) + data
        # Channel.send() may write only part of the buffer; sendall() keeps
        # writing until the whole packet is out (or raises).
        self.channel.sendall(packet)
        self.logger.debug(f"📤 发送包: type={packet_type.name}, length={packet_length}")

    async def _receive_packet(self) -> Tuple[SSHOUTPacketType, bytes]:
        """Read one complete packet from the channel.

        The whole packet body is consumed before the type byte is interpreted,
        so an unknown type never leaves half a packet in the stream.

        Returns:
            ``(packet_type, payload)``. ``packet_type`` is a SSHOUTPacketType
            member, or the raw int code for a type this client does not know
            (it still compares correctly against enum members).

        Raises:
            Exception: channel closed, read timeout, or a length field too
                small to contain the type byte.
        """
        if not self.channel:
            raise Exception("Channel未连接")

        length_field = await self._read_exact(4)
        packet_length = struct.unpack(">I", length_field)[0]
        if packet_length < 1:
            raise Exception(f"无效的包长度: {packet_length}")

        body = await self._read_exact(packet_length)
        raw_type, data = body[0], body[1:]
        try:
            packet_type = SSHOUTPacketType(raw_type)
        except ValueError:
            packet_type = raw_type

        type_name = getattr(packet_type, 'name', packet_type)
        self.logger.debug(f"📥 接收包: type={type_name}, length={packet_length}")
        return packet_type, data

    async def _read_exact(self, length: int, timeout: Optional[float] = None) -> bytes:
        """Read exactly ``length`` bytes without blocking the event loop.

        Polls ``recv_ready()`` and sleeps POLL_INTERVAL between polls.  The
        timeout is an idle timeout that restarts whenever a chunk arrives.

        Args:
            length: Number of bytes to read.
            timeout: Idle timeout in seconds; READ_TIMEOUT when None.

        Raises:
            Exception("连接已关闭"): channel missing, closed, or at EOF.
            Exception("读取数据超时"): no data arrived within the idle timeout.
        """
        idle_limit = self.READ_TIMEOUT if timeout is None else timeout
        buffer = bytearray()
        deadline = time.monotonic() + idle_limit

        while len(buffer) < length:
            if self._channel_is_dead():
                raise Exception("连接已关闭")
            if not self.channel.recv_ready():
                if time.monotonic() >= deadline:
                    raise Exception("读取数据超时")
                await asyncio.sleep(self.POLL_INTERVAL)
                continue

            try:
                chunk = self.channel.recv(length - len(buffer))
            except Exception as e:
                if "closed" in str(e).lower():
                    raise Exception("连接已关闭")
                raise
            if not chunk:
                raise Exception("连接已关闭")
            buffer += chunk
            deadline = time.monotonic() + idle_limit

        return bytes(buffer)

    async def _get_online_users(self):
        """Request the online-user list, record our user id and log the users.

        Packets that arrive before the ONLINE_USERS_INFO reply (chat messages,
        state changes, errors, MOTD) are dispatched normally, not dropped.
        Errors are logged, never raised.
        """
        try:
            await self._send_packet(SSHOUTPacketType.GET_ONLINE_USER, b"")

            max_attempts = 10
            for _ in range(max_attempts):
                packet_type, packet_data = await self._receive_packet()
                if packet_type == SSHOUTPacketType.ONLINE_USERS_INFO:
                    break
                await self._dispatch_packet(packet_type, packet_data)
            else:
                self.logger.error("❌ 获取在线用户失败: 超过最大尝试次数")
                return

            # Payload: uint16 my_id, uint16 count, then `count` user entries.
            if len(packet_data) < 4:
                self.logger.error("❌ 在线用户信息包格式错误")
                return
            self.my_user_id, user_count = struct.unpack_from(">HH", packet_data, 0)
            self.logger.info(f"👥 在线用户数量: {user_count}, 我的ID: {self.my_user_id}")

            offset = 4
            for _ in range(user_count):
                # Entry: uint16 id, uint8 name length + name, uint8 host length + host.
                if offset + 4 > len(packet_data):
                    break
                user_id = struct.unpack_from(">H", packet_data, offset)[0]
                username_length = packet_data[offset + 2]
                name_start = offset + 3
                username = packet_data[name_start:name_start + username_length].decode('utf-8')

                hostname_offset = name_start + username_length
                if hostname_offset + 1 > len(packet_data):
                    break
                hostname_length = packet_data[hostname_offset]
                host_start = hostname_offset + 1
                hostname = packet_data[host_start:host_start + hostname_length].decode('utf-8')

                self.logger.debug(f"👤 用户: {user_id} - {username}@{hostname}")
                offset = host_start + hostname_length
        except Exception as e:
            self.logger.error(f"❌ 获取在线用户失败: {e}")

    async def send_message(self, to_user: str, message: str,
                           message_type: SSHOUTMessageType = SSHOUTMessageType.PLAIN) -> bool:
        """Send ``message`` to ``to_user``.

        Payload: uint8 recipient length + recipient, uint8 message type,
        uint32 message length + UTF-8 message.

        Returns:
            True if the packet was written, False otherwise (errors are logged).
        """
        if not self.connected or not self.channel:
            self.logger.error("❌ 未连接到SSHOUT API服务器")
            return False

        try:
            to_user_bytes = to_user.encode('utf-8')
            message_bytes = message.encode('utf-8')
            data = b"".join((
                struct.pack("B", len(to_user_bytes)),
                to_user_bytes,
                struct.pack(">BI", int(message_type), len(message_bytes)),
                message_bytes,
            ))
            await self._send_packet(SSHOUTPacketType.SEND_MESSAGE, data)
            self.logger.info(f"📤 发送消息到{to_user}: {message}")
            return True
        except Exception as e:
            self.logger.error(f"❌ 发送消息失败: {e}")
            return False

    async def send_global_message(self, message: str,
                                  message_type: SSHOUTMessageType = SSHOUTMessageType.PLAIN) -> bool:
        """Broadcast ``message`` by addressing it to the user name "GLOBAL"."""
        return await self.send_message("GLOBAL", message, message_type)

    async def _dispatch_packet(self, packet_type, data: bytes):
        """Route a server packet to its handler; unknown types are only logged."""
        handlers = {
            SSHOUTPacketType.RECEIVE_MESSAGE: self._process_receive_message,
            SSHOUTPacketType.USER_STATE_CHANGE: self._process_user_state_change,
            SSHOUTPacketType.ERROR: self._process_error,
            SSHOUTPacketType.MOTD: self._process_motd,
        }
        handler = handlers.get(packet_type)
        if handler is None:
            self.logger.debug(f"🔍 收到未处理的包类型: {packet_type}")
            return
        await handler(data)

    async def _message_listener(self):
        """Background task: read and dispatch packets while ``connected`` is True."""
        self.logger.info("👂 启动SSHOUT API消息监听...")
        try:
            while self.connected:
                try:
                    if self._channel_is_dead():
                        self.logger.error("❌ 检测到连接断开,停止消息监听")
                        self.connected = False
                        break
                    if not self.channel.recv_ready():
                        await asyncio.sleep(0.1)
                        continue

                    packet_type, packet_data = await self._receive_packet()
                    await self._dispatch_packet(packet_type, packet_data)
                except Exception as e:
                    if not self.connected:
                        break
                    self.logger.error(f"❌ 消息监听错误: {e}")
                    if "连接已关闭" in str(e) or "closed" in str(e).lower():
                        self.logger.error("❌ 连接断开,停止消息监听")
                        self.connected = False
                        break
                    await asyncio.sleep(1)
        finally:
            self.logger.info("👂 SSHOUT API消息监听已停止")

    async def _keep_alive(self):
        """Background task: every 30 s verify the session and send an SSH ignore packet."""
        self.logger.debug("💓 启动连接保活任务...")
        try:
            while self.connected:
                try:
                    await asyncio.sleep(30)
                    if not self.connected:
                        break

                    if self._channel_is_dead():
                        self.logger.error("❌ 检测到连接断开(保活检查)")
                        self.connected = False
                        break

                    if self.client:
                        transport = self.client.get_transport()
                        if transport and transport.is_active():
                            transport.send_ignore()
                            self.logger.debug("💓 发送SSH保活包")
                        else:
                            self.logger.error("❌ SSH传输层连接断开")
                            self.connected = False
                            break
                except Exception as e:
                    if self.connected:
                        self.logger.error(f"❌ 保活任务错误: {e}")
                        self.connected = False
                    break
        finally:
            self.logger.debug("💓 连接保活任务已停止")

    async def _process_receive_message(self, data: bytes):
        """Decode a RECEIVE_MESSAGE packet, store it and run the callbacks.

        Payload (big-endian): uint64 Unix time, uint8 length + from_user,
        uint8 length + to_user, uint8 message type, uint32 length + UTF-8 text.
        Decoding errors are logged and the packet is dropped.
        """
        try:
            # Smallest valid payload: 8 + 1 + 1 + 1 + 4 bytes (empty names and text).
            if len(data) < 15:
                self.logger.error("❌ 接收消息包格式错误")
                return

            timestamp = datetime.fromtimestamp(struct.unpack_from(">Q", data, 0)[0])
            offset = 8

            from_user_length = data[offset]
            offset += 1
            from_user = data[offset:offset + from_user_length].decode('utf-8')
            offset += from_user_length

            to_user_length = data[offset]
            offset += 1
            to_user = data[offset:offset + to_user_length].decode('utf-8')
            offset += to_user_length

            message_type = SSHOUTMessageType(data[offset])
            offset += 1

            message_length = struct.unpack_from(">I", data, offset)[0]
            offset += 4
            content = data[offset:offset + message_length].decode('utf-8')

            # Never treat our own messages as mentions (in case the server echoes
            # them back): a reply containing e.g. "Claude," would otherwise
            # trigger another reply, forever.
            is_own = self.my_username is not None and from_user == self.my_username
            message = SSHOUTMessage(
                timestamp=timestamp,
                from_user=from_user,
                to_user=to_user,
                message_type=message_type,
                content=content,
                # Set before any callback runs so every callback sees the final value.
                is_mention=(not is_own) and self._is_claude_mention(content),
            )

            self.message_history.append(message)
            excess = len(self.message_history) - self.max_history
            if excess > 0:
                del self.message_history[:excess]

            self._run_callbacks(self.message_callbacks, message, "消息回调错误")

            if message.is_mention:
                self.logger.info(f"📢 检测到@Claude提及: {message.from_user}: {message.content}")
                self._run_callbacks(self.mention_callbacks, message, "提及回调错误")
        except Exception as e:
            self.logger.error(f"❌ 处理接收消息错误: {e}")

    def _run_callbacks(self, callbacks, message: SSHOUTMessage, error_label: str):
        """Invoke each callback; coroutine functions are scheduled as tracked tasks."""
        for callback in list(callbacks):
            try:
                if asyncio.iscoroutinefunction(callback):
                    self._track_task(callback(message), self._callback_tasks, error_label)
                else:
                    callback(message)
            except Exception as e:
                self.logger.error(f"❌ {error_label}: {e}")

    async def _process_user_state_change(self, data: bytes):
        """Log a USER_STATE_CHANGE packet: uint8 state (1 = online), uint8 length + name."""
        try:
            if len(data) < 2:
                return
            state = data[0]
            username_length = data[1]
            username = data[2:2 + username_length].decode('utf-8')
            status = "上线" if state == 1 else "下线"
            self.logger.info(f"👤 用户状态变化: {username} {status}")
        except Exception as e:
            self.logger.error(f"❌ 处理用户状态变化错误: {e}")

    async def _process_error(self, data: bytes):
        """Log an ERROR packet: uint32 code, uint32 length + UTF-8 message."""
        try:
            if len(data) < 8:
                return
            error_code, message_length = struct.unpack_from(">II", data, 0)
            error_message = data[8:8 + message_length].decode('utf-8')
            try:
                code_name = SSHOUTErrorCode(error_code).name
            except ValueError:
                code_name = "UNKNOWN"
            self.logger.error(f"❌ 服务器错误 [{error_code} {code_name}]: {error_message}")
        except Exception as e:
            self.logger.error(f"❌ 处理错误包错误: {e}")

    async def _process_motd(self, data: bytes):
        """Log a MOTD packet: uint32 length + UTF-8 message."""
        try:
            if len(data) < 4:
                return
            message_length = struct.unpack_from(">I", data, 0)[0]
            motd_message = data[4:4 + message_length].decode('utf-8')
            self.logger.info(f"📢 每日消息: {motd_message}")
        except Exception as e:
            self.logger.error(f"❌ 处理MOTD包错误: {e}")

    def _is_claude_mention(self, content: str) -> bool:
        """Return True if ``content`` matches any mention pattern (case-insensitive).

        - ``@``-patterns must not be followed by an ASCII letter or digit, so
          "@Claudette" does not match but "@Claude你好" does.
        - Patterns ending in ``:`` / ``,`` (ASCII or full-width) are delimited
          by that punctuation and are matched as plain substrings.
        - Any other pattern must end on a word boundary.
        Empty patterns are ignored (they would match every message).
        """
        import re
        for pattern in self.mention_patterns:
            if not pattern:
                continue
            escaped = re.escape(pattern)
            if pattern.startswith('@'):
                regex_pattern = escaped + r'(?![a-zA-Z0-9])'
            elif pattern.endswith((':', ',', '，', '：')):
                regex_pattern = escaped
            else:
                regex_pattern = escaped + r'\b'
            if re.search(regex_pattern, content, re.IGNORECASE):
                return True
        return False

    def add_message_callback(self, callback: Callable[[SSHOUTMessage], None]):
        """Register a callback (sync or async) for every received message."""
        self.message_callbacks.append(callback)

    def add_mention_callback(self, callback: Callable[[SSHOUTMessage], None]):
        """Register a callback (sync or async) for messages that mention Claude."""
        self.mention_callbacks.append(callback)

    def get_recent_messages(self, count: int = 10) -> List[SSHOUTMessage]:
        """Return up to ``count`` newest messages, oldest first; [] when count <= 0."""
        if count <= 0:
            # history[-0:] would return the whole history instead of nothing.
            return []
        return self.message_history[-count:]

    def get_context_messages(self, before_time: datetime, count: int = 5) -> List[SSHOUTMessage]:
        """Return up to ``count`` messages strictly older than ``before_time``, oldest first."""
        context_messages = []
        for msg in reversed(self.message_history):
            if msg.timestamp < before_time:
                context_messages.append(msg)
                if len(context_messages) >= count:
                    break
        return list(reversed(context_messages))

    def get_connection_status(self) -> Dict[str, Any]:
        """Summarise the connection; includes the last 3 messages when connected."""
        if not self.connected:
            return {
                'connected': False,
                'server': None,
                'message_count': 0,
                'my_user_id': None,
                'my_username': None
            }

        return {
            'connected': self.connected,
            'server': f"{self.hostname}:{self.port}",
            'message_count': len(self.message_history),
            'my_user_id': self.my_user_id,
            'my_username': self.my_username,
            'recent_messages': [
                {
                    'timestamp': msg.timestamp.strftime('%H:%M:%S'),
                    'from_user': msg.from_user,
                    'content': msg.content[:50] + '...' if len(msg.content) > 50 else msg.content
                }
                for msg in self.get_recent_messages(3)
            ]
        }
class SSHOUTApiIntegration:
    """Bridge that connects the agent to an SSHOUT chat room.

    Reads the SSHOUT settings through the project config manager, owns one
    SSHOUTApiClient, logs every chat message and answers messages that mention
    Claude by broadcasting the agent's reply.
    """

    def __init__(self, agent_core, config_name: Optional[str] = None):
        """Load and validate the SSHOUT configuration (no network I/O).

        Args:
            agent_core: Object providing ``await process_user_input(prompt)``.
            config_name: Passed to ``get_config_manager`` to select a config.

        Raises:
            ValueError: a required config section or key is missing.
            FileNotFoundError: the configured private key file does not exist.
        """
        self.agent = agent_core
        self.client: Optional[SSHOUTApiClient] = None
        self.logger = logging.getLogger(__name__)

        self.config_manager = get_config_manager(config_name)
        self.sshout_config = self.config_manager.get_sshout_config()

        self._validate_config()

    def _validate_config(self):
        """Check that all required SSHOUT settings are present.

        Raises:
            ValueError: missing ``server`` / ``ssh_key`` section, a missing
                server key, or a missing ``private_key_path``.
            FileNotFoundError: the private key file does not exist (``~`` is
                expanded before checking).
        """
        config = self.sshout_config or {}

        for section in ('server', 'ssh_key'):
            if section not in config:
                raise ValueError(f"SSHOUT配置缺少必需的段落: {section}")

        server_config = config['server']
        for key in ('hostname', 'port', 'username'):
            if key not in server_config:
                raise ValueError(f"SSHOUT服务器配置缺少必需的键: {key}")

        ssh_config = config['ssh_key']
        if 'private_key_path' not in ssh_config:
            raise ValueError("SSHOUT配置缺少SSH私钥路径")

        key_path = os.path.expanduser(ssh_config['private_key_path'])
        if not os.path.exists(key_path):
            raise FileNotFoundError(f"SSH私钥文件不存在: {key_path}")

    async def connect_to_sshout_api(self) -> bool:
        """Create a client from the configuration and connect it.

        A previously created client is disconnected first so its SSH session
        and background tasks are not leaked.

        Returns:
            True on success, False on any failure (errors are logged).
        """
        try:
            if self.client is not None:
                await self.disconnect_from_sshout_api()

            server_config = self.sshout_config['server']
            ssh_config = self.sshout_config['ssh_key']
            mention_patterns = self.sshout_config.get('mention_patterns', [])
            timeout = ssh_config.get('timeout', 10)

            self.client = SSHOUTApiClient(
                hostname=server_config['hostname'],
                port=server_config['port'],
                username=server_config['username'],
                key_path=os.path.expanduser(ssh_config['private_key_path']),
                mention_patterns=mention_patterns,
                timeout=timeout
            )

            self.client.add_message_callback(self._on_message_received)
            self.client.add_mention_callback(self._on_claude_mentioned)

            success = await self.client.connect()
            if success:
                self.logger.info("🎉 SSHOUT API集成已启用")
            return success
        except Exception as e:
            self.logger.error(f"❌ SSHOUT API连接失败: {e}")
            return False

    async def disconnect_from_sshout_api(self):
        """Disconnect and drop the current client, if any."""
        if self.client:
            await self.client.disconnect()
            self.client = None

    def _on_message_received(self, message: SSHOUTMessage):
        """Log every received chat message."""
        self.logger.info(f"💬 [{message.timestamp.strftime('%H:%M:%S')}] "
                         f"{message.from_user} -> {message.to_user}: {message.content}")

    async def _on_claude_mentioned(self, message: SSHOUTMessage):
        """Answer a message that mentions Claude.

        Builds a prompt from up to ``sshout.message.context_count`` earlier
        messages plus the current one, asks the agent, cleans the reply and
        broadcasts it.  Empty replies are not sent.  Errors are logged.
        """
        try:
            self.logger.info(f"🎯 处理@Claude提及: {message.from_user}: {message.content}")

            context_count = self.config_manager.get('sshout.message.context_count', 5)
            context_messages = self.client.get_context_messages(
                before_time=message.timestamp,
                count=context_count
            )

            context_text = ""
            if context_messages:
                history_lines = "".join(
                    f"[{ctx.timestamp.strftime('%H:%M:%S')}] {ctx.from_user}: {ctx.content}\n"
                    for ctx in context_messages
                )
                context_text = "聊天室上下文消息:\n" + history_lines + "\n"

            full_prompt = (
                f"{context_text}当前消息:\n"
                f"[{message.timestamp.strftime('%H:%M:%S')}] {message.from_user}: {message.content}\n"
                f"请基于以上聊天室上下文,简洁地回复{message.from_user}的问题或评论。回复要自然、友好,适合聊天室环境。\n"
                f"重要:请直接回复内容,不要在回复开头添加@{message.from_user}或任何@前缀。"
            )

            response = await self.agent.process_user_input(full_prompt)
            clean_response = self._clean_response_for_sshout(response)
            if not clean_response:
                # Nothing useful to say; don't broadcast an empty line.
                self.logger.warning(f"⚠️ Agent回复为空,未回复 {message.from_user}")
                return

            if self.client:
                success = await self.client.send_global_message(clean_response)
                if success:
                    self.logger.info(f"✅ 已回复 {message.from_user}")
                else:
                    self.logger.error("❌ 回复失败")
        except Exception as e:
            self.logger.error(f"❌ 处理@Claude提及时出错: {e}")

    def _clean_response_for_sshout(self, response: str) -> str:
        """Flatten an agent reply into one plain-text chat line.

        Strips simple markdown (``**bold**``, ``*italic*``, `` `code` ``),
        collapses every whitespace run (including newlines) to one space, and
        truncates to ``sshout.message.max_reply_length`` characters plus "..."
        when that setting is positive (0 or None means unlimited).
        Returns "" for an empty or None response.
        """
        import re

        if not response:
            return ""

        response = re.sub(r'\*\*(.*?)\*\*', r'\1', response)
        response = re.sub(r'\*(.*?)\*', r'\1', response)
        response = re.sub(r'`(.*?)`', r'\1', response)

        response = re.sub(r'\s+', ' ', response).strip()

        max_length = self.config_manager.get('sshout.message.max_reply_length', 0) or 0
        if max_length > 0 and len(response) > max_length:
            response = response[:max_length].rstrip() + "..."

        return response

    def get_connection_status(self) -> Dict[str, Any]:
        """Return the client's status dict plus the API version."""
        if not self.client:
            return {
                'connected': False,
                'server': None,
                'message_count': 0,
                'api_version': '1.0'
            }

        return {
            **self.client.get_connection_status(),
            'api_version': '1.0'
        }

    async def send_message(self, message: str) -> bool:
        """Broadcast ``message``; False when not connected or the send failed."""
        if not self.client or not self.client.connected:
            return False
        return await self.client.send_global_message(message)