# blob: 582050f076d46a93a7cece2de30f0c5e07d31b98
"""
SSHOUT API客户端单元测试
测试基于SSHOUT API二进制协议的客户端功能
"""
import pytest
import asyncio
import struct
from datetime import datetime
from unittest.mock import AsyncMock, MagicMock, patch, Mock
import sys
import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../../../src'))
from claude_agent.sshout.api_client import (
SSHOUTApiClient, SSHOUTApiIntegration,
SSHOUTPacketType, SSHOUTMessageType, SSHOUTErrorCode,
SSHOUTMessage, SSHOUTUser
)
class TestSSHOUTPacketTypes:
    """Pin the numeric wire values of the SSHOUT packet-type enum."""

    def test_client_to_server_packet_types(self):
        """Client -> server packet types use the low values starting at 1."""
        expected_values = {
            "HELLO": 1,
            "GET_ONLINE_USER": 2,
            "SEND_MESSAGE": 3,
        }
        for member_name, wire_value in expected_values.items():
            assert getattr(SSHOUTPacketType, member_name).value == wire_value

    def test_server_to_client_packet_types(self):
        """Server -> client packet types use the values starting at 128."""
        expected_values = {
            "PASS": 128,
            "ONLINE_USERS_INFO": 129,
            "RECEIVE_MESSAGE": 130,
            "USER_STATE_CHANGE": 131,
            "ERROR": 132,
            "MOTD": 133,
        }
        for member_name, wire_value in expected_values.items():
            assert getattr(SSHOUTPacketType, member_name).value == wire_value
@pytest.mark.skip(
    reason=(
        "Stale duplicate of TestSSHOUTApiClient: a later class with the original "
        "name shadowed this one, so it never ran. It appears to target an earlier "
        "client API (recent_messages, ssh_client, online_users, _process_motd)."
    )
)
class TestSSHOUTApiClientLegacy:
    """Legacy unit tests for ``SSHOUTApiClient``.

    This class was originally also named ``TestSSHOUTApiClient``. A second class
    with the same name is defined later in this module and rebinds the name, so
    pytest never collected these tests. The class is renamed (removing the silent
    redefinition that linters report as F811) and explicitly skipped, so the dead
    code is visible in test reports. Port the still-relevant cases to the current
    client API, or delete them.
    """

    def setup_method(self):
        """Create a fresh client for each test."""
        self.client = SSHOUTApiClient(
            hostname="test.server.com",
            port=22333,
            username="testuser",
            key_path="/path/to/key"
        )

    def test_client_initialization(self):
        """The constructor stores its arguments and starts disconnected."""
        assert self.client.hostname == "test.server.com"
        assert self.client.port == 22333
        assert self.client.username == "testuser"
        assert self.client.key_path == "/path/to/key"
        assert self.client.timeout == 10
        assert not self.client.connected
        assert self.client.stdin is None
        assert self.client.stdout is None
        assert self.client.stderr is None

    def test_is_claude_mention(self):
        """@Claude mention detection."""
        # Positive cases
        assert self.client._is_claude_mention("@Claude help me")
        assert self.client._is_claude_mention("@claude what's up")
        assert self.client._is_claude_mention("Claude: how are you?")
        assert self.client._is_claude_mention("claude, tell me a joke")
        assert self.client._is_claude_mention("Hey @Claude")
        # Full-width punctuation is written as escapes so it cannot silently
        # degrade to ASCII when the file is re-encoded.
        assert self.client._is_claude_mention("Claude\uff1a你好")  # full-width colon (U+FF1A)
        assert self.client._is_claude_mention("claude\uff0chelp")  # full-width comma (U+FF0C)
        # Negative cases
        assert not self.client._is_claude_mention("Hello world")
        assert not self.client._is_claude_mention("This is a test")
        assert not self.client._is_claude_mention("includes Claude but not mention")
        assert not self.client._is_claude_mention("claudeius was a Roman emperor")

    def test_add_callbacks(self):
        """Message and mention callbacks are stored in their respective lists."""
        callback1 = Mock()
        callback2 = Mock()
        mention_callback = Mock()
        # Message callbacks
        self.client.add_message_callback(callback1)
        self.client.add_message_callback(callback2)
        assert len(self.client.message_callbacks) == 2
        # Mention callback
        self.client.add_mention_callback(mention_callback)
        assert len(self.client.mention_callbacks) == 1

    def test_get_recent_messages(self):
        """get_recent_messages returns stored messages."""
        # No messages yet
        messages = self.client.get_recent_messages()
        assert messages == []
        # Add one message
        test_message = SSHOUTMessage(
            message_type=SSHOUTMessageType.PLAIN,
            from_user="testuser",
            to_user="",
            content="Test message",
            timestamp=datetime.now()
        )
        self.client.recent_messages.append(test_message)
        messages = self.client.get_recent_messages(count=1)
        assert len(messages) == 1
        assert messages[0].content == "Test message"

    @pytest.mark.asyncio
    async def test_connect_without_ssh_client(self):
        """connect() fails when there is no SSH client."""
        with patch.object(self.client, 'ssh_client', None):
            result = await self.client.connect()
            assert result is False

    @pytest.mark.asyncio
    async def test_disconnect_when_not_connected(self):
        """disconnect() while disconnected must not raise."""
        self.client.connected = False
        await self.client.disconnect()

    @pytest.mark.asyncio
    @patch('paramiko.SSHClient')
    async def test_connect_ssh_failure(self, mock_ssh_class):
        """connect() returns False when the SSH connection raises."""
        mock_ssh = Mock()
        mock_ssh.connect.side_effect = Exception("Connection failed")
        mock_ssh_class.return_value = mock_ssh
        self.client.ssh_client = mock_ssh
        result = await self.client.connect()
        assert result is False
        assert not self.client.connected

    @pytest.mark.asyncio
    async def test_send_packet_not_connected(self):
        """_send_packet raises while disconnected."""
        self.client.connected = False
        with pytest.raises(Exception):
            await self.client._send_packet(SSHOUTPacketType.HELLO, b"test")

    @pytest.mark.asyncio
    async def test_process_receive_message(self):
        """A received message is stored and dispatched to the callbacks."""
        test_content = "Hello @Claude"
        test_from = "testuser"
        test_to = ""
        # Build the payload (simplified, little-endian u16 length prefixes)
        data = struct.pack('<H', SSHOUTMessageType.PLAIN.value)
        data += struct.pack('<H', len(test_from)) + test_from.encode('utf-8')
        data += struct.pack('<H', len(test_to)) + test_to.encode('utf-8')
        data += struct.pack('<H', len(test_content)) + test_content.encode('utf-8')
        # Register callbacks
        message_callback = Mock()
        mention_callback = Mock()
        self.client.add_message_callback(message_callback)
        self.client.add_mention_callback(mention_callback)
        # Process the message
        await self.client._process_receive_message(data)
        # The message is recorded in recent_messages
        assert len(self.client.recent_messages) == 1
        message = self.client.recent_messages[0]
        assert message.content == test_content
        assert message.from_user == test_from
        # Both callbacks fire; the mention callback because the text contains @Claude
        message_callback.assert_called_once()
        mention_callback.assert_called_once()

    @pytest.mark.asyncio
    async def test_process_user_state_change(self):
        """A user-online state change adds the user to online_users."""
        username = "newuser"
        data = struct.pack('<B', 1)  # 1 means "came online"
        data += struct.pack('<H', len(username)) + username.encode('utf-8')
        await self.client._process_user_state_change(data)
        assert username in [user.username for user in self.client.online_users]

    @pytest.mark.asyncio
    async def test_process_error(self):
        """Processing an error packet must not raise."""
        error_code = SSHOUTErrorCode.MESSAGE_TOO_LONG.value
        error_message = "Message too long"
        data = struct.pack('<H', error_code)
        data += struct.pack('<H', len(error_message)) + error_message.encode('utf-8')
        await self.client._process_error(data)

    @pytest.mark.asyncio
    async def test_process_motd(self):
        """A MOTD packet updates client.motd."""
        motd_text = "Welcome to SSHOUT server"
        data = struct.pack('<H', len(motd_text)) + motd_text.encode('utf-8')
        await self.client._process_motd(data)
        assert self.client.motd == motd_text
@pytest.mark.skip(
    reason=(
        "Stale duplicate of TestSSHOUTApiIntegration: a later class with the "
        "original name shadowed this one, so it never ran. It constructs "
        "SSHOUTApiIntegration() without the agent argument the current tests pass."
    )
)
class TestSSHOUTApiIntegrationLegacy:
    """Legacy integration test for ``SSHOUTApiIntegration``.

    This class was originally also named ``TestSSHOUTApiIntegration``. A later
    class with the same name rebinds it, so pytest never collected this test. The
    class is renamed and explicitly skipped so the dead code stays visible in test
    reports. Port or delete it once the intended integration API is confirmed.
    """

    def setup_method(self):
        """Create the integration under test."""
        self.integration = SSHOUTApiIntegration()

    @pytest.mark.asyncio
    async def test_integration_start_stop(self):
        """start() connects the client and stop() disconnects it."""
        # Mocked configuration
        mock_config = {
            'sshout': {
                'server': {
                    'hostname': 'test.server.com',
                    'port': 22333,
                    'username': 'testuser'
                },
                'ssh_key': {
                    'private_key_path': '/path/to/key'
                }
            }
        }
        with patch('claude_agent.sshout.api_client.get_config_manager') as mock_get_config:
            mock_config_manager = Mock()
            mock_config_manager.get_config.return_value = mock_config
            mock_get_config.return_value = mock_config_manager
            with patch.object(self.integration, 'client') as mock_client:
                mock_client.connect.return_value = True
                # Start
                result = await self.integration.start()
                assert result is True
                mock_client.connect.assert_called_once()
                # Stop
                await self.integration.stop()
                mock_client.disconnect.assert_called_once()


class TestSSHOUTProtocolConstants:
    """Pin the wire values of the SSHOUT protocol enums.

    These assertions used to sit inside the shadowed legacy integration class,
    so they never ran. They live in their own class here so pytest collects them.
    """

    def test_get_motd_packet_type(self):
        """GET_MOTD has wire value 4."""
        assert SSHOUTPacketType.GET_MOTD.value == 4

    def test_server_to_client_packet_types(self):
        """Server -> client packet types."""
        assert SSHOUTPacketType.PASS.value == 128
        assert SSHOUTPacketType.ONLINE_USERS_INFO.value == 129
        assert SSHOUTPacketType.RECEIVE_MESSAGE.value == 130
        assert SSHOUTPacketType.USER_STATE_CHANGE.value == 131
        assert SSHOUTPacketType.ERROR.value == 132
        assert SSHOUTPacketType.MOTD.value == 133

    def test_message_types(self):
        """Message types."""
        assert SSHOUTMessageType.PLAIN.value == 1
        assert SSHOUTMessageType.RICH.value == 2
        assert SSHOUTMessageType.IMAGE.value == 3

    def test_error_codes(self):
        """Error codes."""
        assert SSHOUTErrorCode.SERVER_CLOSED.value == 1
        assert SSHOUTErrorCode.LOCAL_PACKET_CORRUPT.value == 2
        assert SSHOUTErrorCode.LOCAL_PACKET_TOO_LARGE.value == 3
        assert SSHOUTErrorCode.OUT_OF_MEMORY.value == 4
        assert SSHOUTErrorCode.INTERNAL_ERROR.value == 5
        assert SSHOUTErrorCode.USER_NOT_FOUND.value == 6
        assert SSHOUTErrorCode.MOTD_NOT_AVAILABLE.value == 7
class TestSSHOUTMessage:
    """Check the SSHOUTMessage data structure."""

    def test_message_creation(self):
        """Fields passed to the constructor are exposed unchanged; is_mention defaults to False."""
        fields = {
            "timestamp": datetime.now(),
            "from_user": "testuser",
            "to_user": "GLOBAL",
            "message_type": SSHOUTMessageType.PLAIN,
            "content": "Hello, world!",
        }
        message = SSHOUTMessage(**fields)

        for attribute, expected in fields.items():
            assert getattr(message, attribute) == expected
        assert message.is_mention is False
class TestSSHOUTUser:
    """Check the SSHOUTUser data structure."""

    def test_user_creation(self):
        """Constructor arguments are exposed as attributes."""
        fields = {
            "id": 123,
            "username": "testuser",
            "hostname": "test.example.com",
        }
        user = SSHOUTUser(**fields)

        for attribute, expected in fields.items():
            assert getattr(user, attribute) == expected
class TestSSHOUTApiClient:
    """Unit tests for SSHOUTApiClient behavior that needs no live connection."""

    def setup_method(self):
        """Create a fresh client for each test."""
        self.client = SSHOUTApiClient(
            hostname="test.example.com",
            port=22333,
            username="testuser",
            key_path="/fake/key/path"
        )

    def test_client_initialization(self):
        """The constructor stores its arguments and initializes empty state."""
        assert self.client.hostname == "test.example.com"
        assert self.client.port == 22333
        assert self.client.username == "testuser"
        assert self.client.key_path == "/fake/key/path"
        assert self.client.timeout == 10
        assert self.client.connected is False
        assert len(self.client.message_callbacks) == 0
        assert len(self.client.mention_callbacks) == 0
        assert len(self.client.message_history) == 0
        assert self.client.max_history == 100
        assert self.client.my_user_id is None
        assert self.client.my_username is None

    def test_default_mention_patterns(self):
        """Default @Claude mention patterns."""
        # The last pair uses the full-width comma (U+FF0C). It is written as an
        # escape so it cannot silently degrade to an ASCII "," when the file is
        # re-encoded; that degradation is what made this list contain the ASCII
        # pair twice.
        expected_patterns = [
            "@Claude", "@claude", "@CLAUDE",
            "Claude:", "claude:",
            "Claude,", "claude,",
            "Claude\uff0c", "claude\uff0c",
        ]
        assert self.client.mention_patterns == expected_patterns

    def test_custom_mention_patterns(self):
        """Custom mention patterns passed to the constructor replace the defaults."""
        custom_patterns = ["@Bot", "Bot:"]
        client = SSHOUTApiClient(
            hostname="test.com",
            port=22333,
            username="user",
            key_path="/fake/key",
            mention_patterns=custom_patterns
        )
        assert client.mention_patterns == custom_patterns

    def test_add_callbacks(self):
        """Registered callbacks appear in their respective lists."""
        message_callback = MagicMock()
        mention_callback = MagicMock()
        self.client.add_message_callback(message_callback)
        self.client.add_mention_callback(mention_callback)
        assert message_callback in self.client.message_callbacks
        assert mention_callback in self.client.mention_callbacks

    def test_claude_mention_detection(self):
        """@Claude mention detection across positive and negative cases."""
        test_cases = [
            ("@Claude help me", True),
            ("@claude what's up", True),
            ("@CLAUDE test", True),
            ("Claude: please help", True),
            ("claude: hi there", True),
            ("Claude, how are you", True),
            ("claude, hello", True),
            ("Claude\uff0c你好", True),  # full-width comma (U+FF0C)
            ("claude\uff0c测试", True),  # full-width comma (U+FF0C)
            ("Hello everyone", False),
            ("This is Claude speaking", False),  # name mid-sentence is not a mention
            ("claudebot", False),  # must not match when letters follow directly
            ("@ClaudeAI", False),  # must not match a longer name
        ]
        for content, expected in test_cases:
            result = self.client._is_claude_mention(content)
            assert result == expected, f"Failed for: '{content}', expected {expected}, got {result}"

    def test_message_history_management(self):
        """History trimming keeps at most max_history messages."""
        # NOTE(review): this replicates the trimming logic from
        # _process_receive_message inside the test itself, so it does not
        # exercise the client's own trimming code.
        for i in range(150):  # more than max_history
            message = SSHOUTMessage(
                timestamp=datetime.now(),
                from_user=f"user{i}",
                to_user="GLOBAL",
                message_type=SSHOUTMessageType.PLAIN,
                content=f"Message {i}"
            )
            self.client.message_history.append(message)
            # Simulated history limit
            if len(self.client.message_history) > self.client.max_history:
                self.client.message_history.pop(0)
        assert len(self.client.message_history) <= self.client.max_history

    def test_get_recent_messages(self):
        """get_recent_messages(n) returns the last n messages, newest last."""
        for i in range(20):
            message = SSHOUTMessage(
                timestamp=datetime.now(),
                from_user=f"user{i}",
                to_user="GLOBAL",
                message_type=SSHOUTMessageType.PLAIN,
                content=f"Message {i}"
            )
            self.client.message_history.append(message)
        recent = self.client.get_recent_messages(10)
        assert len(recent) == 10
        assert recent[-1].content == "Message 19"  # newest message

    def test_get_context_messages(self):
        """get_context_messages returns the messages just before a timestamp."""
        # Messages with strictly increasing timestamps, one second apart
        messages = []
        base_time = datetime.now()
        for i in range(10):
            timestamp = datetime.fromtimestamp(base_time.timestamp() + i)
            message = SSHOUTMessage(
                timestamp=timestamp,
                from_user=f"user{i}",
                to_user="GLOBAL",
                message_type=SSHOUTMessageType.PLAIN,
                content=f"Message {i}"
            )
            messages.append(message)
            self.client.message_history.append(message)
        # Up to 3 messages of context before message 5
        before_time = messages[5].timestamp
        context = self.client.get_context_messages(before_time, 3)
        assert len(context) == 3
        # Expect messages 2, 3, 4 (the three just before message 5), oldest first
        assert context[0].content == "Message 2"
        assert context[1].content == "Message 3"
        assert context[2].content == "Message 4"

    def test_connection_status_disconnected(self):
        """Status dictionary while disconnected."""
        status = self.client.get_connection_status()
        expected_status = {
            'connected': False,
            'server': None,
            'message_count': 0,
            'my_user_id': None,
            'my_username': None
        }
        assert status == expected_status

    def test_connection_status_connected(self):
        """Status dictionary while connected, including truncated recent messages."""
        # Simulate a connected client
        self.client.connected = True
        self.client.my_user_id = 123
        self.client.my_username = "testuser"
        message = SSHOUTMessage(
            timestamp=datetime.now(),
            from_user="other",
            to_user="GLOBAL",
            message_type=SSHOUTMessageType.PLAIN,
            content="Test message content that is quite long and should be truncated"
        )
        self.client.message_history.append(message)
        status = self.client.get_connection_status()
        assert status['connected'] is True
        assert status['server'] == "test.example.com:22333"
        assert status['message_count'] == 1
        assert status['my_user_id'] == 123
        assert status['my_username'] == "testuser"
        assert len(status['recent_messages']) == 1
        # Content is truncated for the status view
        recent_msg = status['recent_messages'][0]
        assert len(recent_msg['content']) <= 53  # 50 + "..."
class TestSSHOUTApiIntegration:
    """Tests for SSHOUTApiIntegration: config validation, response cleaning and status."""

    def setup_method(self):
        """Build an integration backed by a mocked agent and config manager."""
        self.mock_agent = MagicMock()

        # Config-manager stub that returns a complete, valid SSHOUT config.
        self.mock_config_manager = MagicMock()
        self.mock_config = {
            'connection_mode': 'api',
            'mention_patterns': ['@Claude'],
            'server': {
                'hostname': 'test.example.com',
                'port': 22333,
                'username': 'testuser'
            },
            'ssh_key': {
                'private_key_path': '/fake/key/path',
                'timeout': 10
            },
            'message': {
                'max_history': 100,
                'context_count': 5,
                'max_reply_length': 200
            }
        }
        self.mock_config_manager.get_sshout_config.return_value = self.mock_config
        self.mock_config_manager.get.return_value = 5

        # The config lookup and the key-file existence check are patched only
        # while the integration is being constructed.
        with patch('claude_agent.sshout.api_client.get_config_manager') as mock_get_config, \
                patch('claude_agent.sshout.api_client.os.path.exists') as mock_exists:
            mock_get_config.return_value = self.mock_config_manager
            mock_exists.return_value = True
            self.integration = SSHOUTApiIntegration(self.mock_agent)

    def _construct(self):
        """Build a new SSHOUTApiIntegration while get_config_manager returns the stub."""
        with patch('claude_agent.sshout.api_client.get_config_manager') as mock_get_config:
            mock_get_config.return_value = self.mock_config_manager
            return SSHOUTApiIntegration(self.mock_agent)

    def test_integration_initialization(self):
        """The integration keeps the agent and starts without a client."""
        assert self.integration.agent == self.mock_agent
        assert self.integration.client is None

    def test_config_validation_missing_section(self):
        """Validation: a missing required section raises ValueError."""
        # 'server' is present but the 'ssh_key' section is missing.
        self.mock_config_manager.get_sshout_config.return_value = {'server': {}}
        with pytest.raises(ValueError, match="SSHOUT配置缺少必需的段落: ssh_key"):
            self._construct()

    def test_config_validation_missing_server_key(self):
        """Validation: a missing server key raises ValueError."""
        self.mock_config_manager.get_sshout_config.return_value = {
            'server': {'hostname': 'test.com'},  # port and username missing
            'ssh_key': {'private_key_path': '/fake/key'}
        }
        with pytest.raises(ValueError, match="SSHOUT服务器配置缺少必需的键: port"):
            self._construct()

    def test_config_validation_missing_key_path(self):
        """Validation: a missing SSH private-key path raises ValueError."""
        self.mock_config_manager.get_sshout_config.return_value = {
            'server': {
                'hostname': 'test.com',
                'port': 22333,
                'username': 'user'
            },
            'ssh_key': {}  # private_key_path missing
        }
        with pytest.raises(ValueError, match="SSHOUT配置缺少SSH私钥路径"):
            self._construct()

    def test_config_validation_key_file_not_exists(self):
        """Validation: a private-key path that does not exist raises FileNotFoundError."""
        with patch('claude_agent.sshout.api_client.os.path.exists', return_value=False):
            with pytest.raises(FileNotFoundError, match="SSH私钥文件不存在"):
                self._construct()

    def test_response_cleaning(self):
        """Markdown is stripped, line breaks kept, and long replies truncated."""
        max_len = 50
        # The cleaner presumably reads its length limit via config_manager.get().
        self.mock_config_manager.get.return_value = max_len

        cases = [
            ("**bold text**", "bold text"),
            ("*italic text*", "italic text"),
            ("`code text`", "code text"),
            ("Line 1\nLine 2\n\nLine 3", "Line 1\nLine 2\n\nLine 3"),  # line breaks kept
            # NOTE(review): input equals expected here; it presumably once held a
            # run of spaces ("inline spaces are preserved") — confirm.
            ("Multiple spaces", "Multiple spaces"),
            ("This is a very long response that exceeds the maximum length limit and should be truncated",
             "This is a very long response that exceeds the max..."),
        ]

        for raw, expected in cases:
            cleaned = self.integration._clean_response_for_sshout(raw)
            if len(raw) <= max_len:
                assert cleaned == expected
                continue
            # Over-long input: only the truncation marker and length bound are checked.
            assert cleaned.endswith("...")
            assert len(cleaned) <= max_len + 3  # limit plus "..."

    def test_connection_status_no_client(self):
        """Status without a client."""
        assert self.integration.get_connection_status() == {
            'connected': False,
            'server': None,
            'message_count': 0,
            'api_version': '1.0'
        }

    @pytest.mark.asyncio
    async def test_send_message_no_client(self):
        """send_message returns False when there is no client."""
        assert await self.integration.send_message("test message") is False

    def test_message_callback_registration(self):
        """Callback registration on the client."""
        # NOTE(review): the test performs the registration itself, so it only
        # checks the mock, not the integration's own wiring during connect.
        fake_client = MagicMock()
        self.integration.client = fake_client

        # Nothing is registered just by assigning the client.
        assert fake_client.add_message_callback.call_count == 0
        assert fake_client.add_mention_callback.call_count == 0

        # Simulate the registration that happens while connecting.
        self.integration.client.add_message_callback(self.integration._on_message_received)
        self.integration.client.add_mention_callback(self.integration._on_claude_mentioned)

        assert fake_client.add_message_callback.called
        assert fake_client.add_mention_callback.called
class TestSSHOUTBinaryPacketHandling:
    """Round-trip checks of the SSHOUT binary payload layouts built with struct."""

    def test_hello_packet_format(self):
        """HELLO payload: 6-byte magic followed by a big-endian u16 version."""
        payload = b"SSHOUT" + struct.pack(">H", 1)

        assert payload[:6] == b"SSHOUT"
        (version,) = struct.unpack_from(">H", payload, 6)
        assert version == 1
        assert len(payload) == 8

    def test_pass_packet_format(self):
        """PASS payload: magic, big-endian u16 version, u8 name length, name."""
        name_bytes = "testuser".encode('utf-8')
        payload = b"".join([
            b"SSHOUT",
            struct.pack(">H", 1),
            struct.pack("B", len(name_bytes)),
            name_bytes,
        ])

        assert payload[:6] == b"SSHOUT"
        (version,) = struct.unpack_from(">H", payload, 6)
        assert version == 1
        assert payload[8] == len(name_bytes)
        assert payload[9:9 + len(name_bytes)] == name_bytes

    def test_send_message_packet_format(self):
        """SEND_MESSAGE payload: u8 recipient length, recipient, u8 message type,
        big-endian u32 body length, body."""
        recipient = "GLOBAL".encode('utf-8')
        body = "Hello, world!".encode('utf-8')
        kind = SSHOUTMessageType.PLAIN

        packet = b"".join([
            struct.pack("B", len(recipient)),
            recipient,
            struct.pack("B", kind.value),
            struct.pack(">I", len(body)),
            body,
        ])

        cursor = 0
        assert packet[cursor] == len(recipient)
        cursor += 1
        assert packet[cursor:cursor + len(recipient)] == recipient
        cursor += len(recipient)
        assert packet[cursor] == kind.value
        cursor += 1
        (body_len,) = struct.unpack_from(">I", packet, cursor)
        assert body_len == len(body)
        cursor += 4
        assert packet[cursor:cursor + len(body)] == body

    def test_receive_message_packet_parsing(self):
        """RECEIVE_MESSAGE payload can be built and parsed back field by field."""
        import time

        timestamp = int(time.time())
        sender = "sender"
        recipient = "GLOBAL"
        kind = SSHOUTMessageType.PLAIN
        text = "Test message"

        sender_raw = sender.encode('utf-8')
        recipient_raw = recipient.encode('utf-8')
        text_raw = text.encode('utf-8')

        # Build the payload.
        packet = b"".join([
            struct.pack(">Q", timestamp),          # timestamp
            struct.pack("B", len(sender_raw)),     # from_user length
            sender_raw,                            # from_user
            struct.pack("B", len(recipient_raw)),  # to_user length
            recipient_raw,                         # to_user
            struct.pack("B", kind.value),          # message type
            struct.pack(">I", len(text_raw)),      # message length
            text_raw,                              # message
        ])

        # Parse it back.
        pos = 0
        (parsed_timestamp,) = struct.unpack_from(">Q", packet, pos)
        pos += 8
        sender_len = packet[pos]
        pos += 1
        parsed_sender = packet[pos:pos + sender_len].decode('utf-8')
        pos += sender_len
        recipient_len = packet[pos]
        pos += 1
        parsed_recipient = packet[pos:pos + recipient_len].decode('utf-8')
        pos += recipient_len
        parsed_kind = SSHOUTMessageType(packet[pos])
        pos += 1
        (text_len,) = struct.unpack_from(">I", packet, pos)
        pos += 4
        parsed_text = packet[pos:pos + text_len].decode('utf-8')

        assert parsed_timestamp == timestamp
        assert parsed_sender == sender
        assert parsed_recipient == recipient
        assert parsed_kind == kind
        assert parsed_text == text
# Async tests need special handling: the coroutine tests below carry @pytest.mark.asyncio.
class TestSSHOUTAsyncOperations:
    """Tests for SSHOUTApiClient's async operations and callback plumbing."""

    def setup_method(self):
        """Create a fresh client for each test."""
        self.client = SSHOUTApiClient(
            hostname="test.example.com",
            port=22333,
            username="testuser",
            key_path="/fake/key/path"
        )

    @pytest.mark.asyncio
    async def test_connect_ssh_failure(self):
        """connect() returns False and stays disconnected when SSH connect raises."""
        with patch('paramiko.SSHClient') as ssh_client_cls:
            fake_ssh = MagicMock()
            fake_ssh.connect.side_effect = Exception("Connection failed")
            ssh_client_cls.return_value = fake_ssh

            outcome = await self.client.connect()

            assert outcome is False
            assert self.client.connected is False

    @pytest.mark.asyncio
    async def test_send_message_not_connected(self):
        """send_message returns False while disconnected."""
        assert await self.client.send_message("GLOBAL", "test message") is False

    @pytest.mark.asyncio
    async def test_disconnect_cleanup(self):
        """disconnect() closes every handle and clears the references."""
        self.client.connected = True
        handles = {
            name: MagicMock()
            for name in ("channel", "client", "stdin", "stdout", "stderr")
        }
        for name, handle in handles.items():
            setattr(self.client, name, handle)

        await self.client.disconnect()

        assert self.client.connected is False
        closed_in_order = ("stdin", "stdout", "stderr", "channel", "client")
        for name in closed_in_order:
            handles[name].close.assert_called_once()
        for name in closed_in_order:
            assert getattr(self.client, name) is None

    @pytest.mark.asyncio
    async def test_read_exact_with_timeout(self):
        """_read_exact raises a timeout when data never becomes ready."""
        # NOTE(review): if _read_exact waits on self.timeout (10 s by default
        # here), this test is slow; consider lowering self.client.timeout — confirm.
        self.client.connected = True
        channel = MagicMock()
        channel.closed = False
        channel.recv_ready.return_value = False  # data never ready
        self.client.channel = channel

        with pytest.raises(Exception, match="读取数据超时"):
            await self.client._read_exact(10)

    @pytest.mark.asyncio
    async def test_read_exact_connection_closed(self):
        """_read_exact detects a closed channel."""
        self.client.connected = True
        channel = MagicMock()
        channel.closed = True
        self.client.channel = channel

        with pytest.raises(Exception, match="连接已关闭"):
            await self.client._read_exact(10)

    @pytest.mark.asyncio
    async def test_read_exact_success(self):
        """_read_exact returns exactly the requested bytes when available."""
        self.client.connected = True
        channel = MagicMock()
        channel.closed = False
        channel.recv_ready.return_value = True
        channel.recv.return_value = b"test_data"
        self.client.channel = channel

        assert await self.client._read_exact(9) == b"test_data"

    @pytest.mark.asyncio
    async def test_get_online_users_with_user_state_change(self):
        """_get_online_users processes a USER_STATE_CHANGE that arrives before ONLINE_USERS_INFO."""
        self.client.connected = True
        self.client.channel = MagicMock()

        # USER_STATE_CHANGE payload: online flag, name length, name.
        state_change_data = struct.pack("BB", 1, 8) + b"testuser"
        # ONLINE_USERS_INFO payload: my user id, user count.
        online_users_data = struct.pack(">HH", 123, 1)

        with patch.object(self.client, '_send_packet') as mock_send:
            with patch.object(self.client, '_receive_packet') as mock_receive:
                with patch.object(self.client, '_process_user_state_change') as mock_process:
                    # The state change arrives first, then the user list.
                    mock_receive.side_effect = [
                        (SSHOUTPacketType.USER_STATE_CHANGE, state_change_data),
                        (SSHOUTPacketType.ONLINE_USERS_INFO, online_users_data)
                    ]

                    await self.client._get_online_users()

                    # A GET_ONLINE_USER request was sent.
                    mock_send.assert_called_once_with(SSHOUTPacketType.GET_ONLINE_USER, b"")
                    # The interleaved state change was handed to its processor.
                    mock_process.assert_called_once_with(state_change_data)
                    # The user id from ONLINE_USERS_INFO was stored.
                    assert self.client.my_user_id == 123

    @pytest.mark.asyncio
    async def test_keep_alive_task(self):
        """A keep-alive pass sends an SSH ignore packet on an active transport."""
        # NOTE(review): this runs a local stand-in, not SSHOUTApiClient._keep_alive.
        self.client.connected = True
        transport = MagicMock()
        transport.is_active.return_value = True
        ssh_client = MagicMock()
        ssh_client.get_transport.return_value = transport
        self.client.client = ssh_client
        channel = MagicMock()
        channel.closed = False
        self.client.channel = channel

        async def quick_keep_alive():
            await asyncio.sleep(0.1)
            if self.client.connected and hasattr(self.client.client, 'get_transport'):
                active = self.client.client.get_transport()
                if active and active.is_active():
                    active.send_ignore()
            self.client.connected = False  # end the loop

        with patch.object(self.client, '_keep_alive', quick_keep_alive):
            await self.client._keep_alive()
            transport.send_ignore.assert_called()

    @pytest.mark.asyncio
    async def test_message_listener_connection_closed_detection(self):
        """The listener marks the client disconnected when the channel is closed."""
        # NOTE(review): this runs a local stand-in, not SSHOUTApiClient._message_listener.
        self.client.connected = True
        channel = MagicMock()
        channel.closed = True
        self.client.channel = channel

        async def quick_listener():
            if not self.client.channel or self.client.channel.closed:
                self.client.connected = False
                return

        with patch.object(self.client, '_message_listener', quick_listener):
            await self.client._message_listener()
            assert self.client.connected is False

    def test_ssh_object_references_initialization(self):
        """SSH stream references start out as None."""
        for stream in ("stdin", "stdout", "stderr"):
            assert getattr(self.client, stream) is None

    def test_check_connection_state(self):
        """The connected flag can be toggled."""
        assert not self.client.connected
        self.client.connected = True
        assert self.client.connected
        self.client.connected = False
        assert not self.client.connected

    def test_format_message_for_display(self):
        """A display string built from a message contains its sender and content."""
        # NOTE(review): the formatting happens in the test itself, not in the client.
        message = SSHOUTMessage(
            timestamp=datetime.now(),
            from_user="testuser",
            to_user="GLOBAL",
            message_type=SSHOUTMessageType.PLAIN,
            content="Hello world!"
        )
        clock = message.timestamp.strftime('%H:%M:%S')
        formatted = f"[{clock}] {message.from_user}: {message.content}"
        assert "testuser" in formatted
        assert "Hello world!" in formatted

    def test_message_callback_registration(self):
        """A registered message callback receives dispatched messages."""
        received = []

        def record(message):
            received.append(message)

        self.client.add_message_callback(record)
        assert record in self.client.message_callbacks

        message = SSHOUTMessage(
            timestamp=datetime.now(),
            from_user="testuser",
            to_user="GLOBAL",
            message_type=SSHOUTMessageType.PLAIN,
            content="Test message"
        )
        # Dispatch by hand, as _process_receive_message would.
        for callback in self.client.message_callbacks:
            callback(message)

        assert received
        assert received[-1] == message

    def test_mention_callback_registration(self):
        """A registered mention callback receives @Claude messages plus context."""
        calls = []

        def on_mention(message, context):
            calls.append((message, context))

        self.client.add_mention_callback(on_mention)
        assert on_mention in self.client.mention_callbacks

        message = SSHOUTMessage(
            timestamp=datetime.now(),
            from_user="testuser",
            to_user="GLOBAL",
            message_type=SSHOUTMessageType.PLAIN,
            content="@Claude help me"
        )
        # Detect the mention and dispatch by hand with an empty context.
        if self.client._is_claude_mention(message.content):
            for callback in self.client.mention_callbacks:
                callback(message, [])

        assert calls
        assert calls[-1][0] == message

    def test_connection_error_handling(self):
        """A client for an unreachable host starts out disconnected."""
        unreachable = SSHOUTApiClient(
            hostname="invalid.host.name",
            port=22333,
            username="sshout",
            key_path="/tmp/test_key",
            timeout=1
        )
        assert not unreachable.connected

    def test_packet_type_validation(self):
        """Packet type values are non-negative integers."""
        sample = (
            SSHOUTPacketType.HELLO,
            SSHOUTPacketType.PASS,
            SSHOUTPacketType.SEND_MESSAGE,
            SSHOUTPacketType.ONLINE_USERS_INFO,
        )
        assert all(isinstance(kind.value, int) and kind.value >= 0 for kind in sample)

    def test_message_type_validation(self):
        """Message type values are non-negative integers."""
        sample = (
            SSHOUTMessageType.PLAIN,
            SSHOUTMessageType.RICH,
            SSHOUTMessageType.IMAGE,
        )
        assert all(isinstance(kind.value, int) and kind.value >= 0 for kind in sample)

    def test_message_history_access(self):
        """message_history keeps messages in insertion order."""
        for i in range(5):
            self.client.message_history.append(SSHOUTMessage(
                timestamp=datetime.now(),
                from_user=f"user{i}",
                to_user="GLOBAL",
                message_type=SSHOUTMessageType.PLAIN,
                content=f"Message {i}"
            ))
        history = self.client.message_history
        assert len(history) == 5
        assert history[0].content == "Message 0"
        assert history[-1].content == "Message 4"

    def test_configuration_validation(self):
        """Constructor values from setup_method, plus defaults, are visible on the client."""
        expected = {
            "hostname": "test.example.com",
            "port": 22333,
            "username": "testuser",
            "timeout": 10,
            "max_history": 100,
        }
        for attribute, value in expected.items():
            assert getattr(self.client, attribute) == value

    def test_logging_integration(self):
        """The client exposes a logger whose name contains sshout.api_client."""
        logger = self.client.logger
        assert logger is not None
        assert "sshout.api_client" in logger.name
class TestSSHOUTApiClientCoverageEnhancement:
    """Coverage-oriented tests for SSHOUTApiClient error handling and boundary conditions."""

    def setup_method(self):
        """Store the constructor arguments used by the ``api_client`` fixture."""
        self.config = {
            'hostname': 'test.example.com',
            'port': 22333,
            'username': 'testuser',
            'key_path': '/test/key',
            'timeout': 10
        }

    @pytest.fixture
    def api_client(self):
        """Create an SSHOUTApiClient from ``self.config`` (keyword arguments)."""
        return SSHOUTApiClient(**self.config)

    @staticmethod
    def _pack_short_string(text):
        """Return ``text`` UTF-8 encoded behind a 1-byte length prefix.

        The prefix is the length of the *encoded* bytes. Counting characters
        (``len(text)``) would desynchronise the payload for non-ASCII text.
        """
        encoded = text.encode('utf-8')
        return struct.pack("B", len(encoded)) + encoded

    @classmethod
    def _build_receive_message_packet(cls, from_user, to_user, message, timestamp=None):
        """Build a RECEIVE_MESSAGE payload in the layout these tests feed to the client.

        Layout: 8-byte big-endian timestamp, length-prefixed ``from_user``,
        length-prefixed ``to_user``, 1-byte message type (PLAIN), 4-byte
        big-endian message byte length, message bytes.
        ``timestamp`` defaults to the current time.
        """
        if timestamp is None:
            timestamp = int(datetime.now().timestamp())
        body = message.encode('utf-8')
        return b"".join((
            struct.pack(">Q", timestamp),
            cls._pack_short_string(from_user),
            cls._pack_short_string(to_user),
            struct.pack("B", SSHOUTMessageType.PLAIN),
            struct.pack(">I", len(body)),
            body,
        ))

    @pytest.mark.asyncio
    async def test_read_exact_connection_closed_exception(self, api_client):
        """_read_exact raises a 'connection closed' error when recv returns no data."""
        api_client.channel = Mock()
        api_client.channel.recv.return_value = b""  # empty read signals a closed connection
        with pytest.raises(Exception) as exc_info:
            await api_client._read_exact(10)
        assert "连接已关闭" in str(exc_info.value)

    @pytest.mark.asyncio
    async def test_read_exact_closed_exception_handling(self, api_client):
        """_read_exact maps an exception mentioning 'closed' to a 'connection closed' error."""
        api_client.channel = Mock()
        api_client.channel.recv.side_effect = Exception("Socket is closed")
        with pytest.raises(Exception) as exc_info:
            await api_client._read_exact(10)
        assert "连接已关闭" in str(exc_info.value)

    @pytest.mark.asyncio
    async def test_get_online_users_warning_unexpected_packet(self, api_client):
        """_get_online_users logs a warning when an unexpected packet type arrives."""
        api_client.connected = True
        api_client.logger = Mock()
        with patch.object(api_client, '_send_packet'), \
                patch.object(api_client, '_receive_packet') as mock_receive:
            # 99 is not a packet type the client expects here.
            mock_receive.return_value = (99, b"unexpected")
            await api_client._get_online_users()
        api_client.logger.warning.assert_called()
        call_args = api_client.logger.warning.call_args[0][0]
        assert "获取在线用户时收到意外包类型" in call_args

    @pytest.mark.asyncio
    async def test_get_online_users_max_attempts_exceeded(self, api_client):
        """_get_online_users logs an error once it gives up after repeated unexpected packets."""
        api_client.connected = True
        api_client.logger = Mock()
        with patch.object(api_client, '_send_packet'), \
                patch.object(api_client, '_receive_packet') as mock_receive:
            # Every read returns an unexpected type, so the retry budget is exhausted.
            mock_receive.return_value = (99, b"unexpected")
            await api_client._get_online_users()
        api_client.logger.error.assert_called()
        call_args = api_client.logger.error.call_args[0][0]
        assert "获取在线用户失败: 超过最大尝试次数" in call_args

    @pytest.mark.asyncio
    async def test_get_online_users_invalid_packet_data(self, api_client):
        """_get_online_users reports a format error for an ONLINE_USERS_INFO payload under 4 bytes."""
        api_client.connected = True
        api_client.logger = Mock()
        with patch.object(api_client, '_send_packet'), \
                patch.object(api_client, '_receive_packet') as mock_receive:
            # Only 2 bytes: too short for the header the client parses.
            mock_receive.return_value = (SSHOUTPacketType.ONLINE_USERS_INFO, b"xx")
            await api_client._get_online_users()
        api_client.logger.error.assert_called_with("❌ 在线用户信息包格式错误")

    @pytest.mark.asyncio
    async def test_message_listener_basic_coverage(self, api_client):
        """_message_listener stops and clears ``connected`` when the channel is missing."""
        api_client.connected = True
        api_client.logger = Mock()
        api_client.channel = None
        await api_client._message_listener()
        api_client.logger.error.assert_called_with("❌ 检测到连接断开,停止消息监听")
        assert not api_client.connected

    @pytest.mark.asyncio
    async def test_message_listener_exception_handling(self, api_client):
        """_message_listener logs errors raised by _receive_packet and ends disconnected."""
        api_client.connected = True
        api_client.logger = Mock()
        api_client.channel = Mock()
        api_client.channel.recv_ready.return_value = True
        call_count = 0

        async def mock_receive_packet():
            # First call fails; any later call flips ``connected`` so the loop can end.
            nonlocal call_count
            call_count += 1
            if call_count == 1:
                raise Exception("连接已关闭")
            api_client.connected = False
            return (SSHOUTPacketType.RECEIVE_MESSAGE, b"")

        with patch.object(api_client, '_receive_packet', side_effect=mock_receive_packet), \
                patch('asyncio.sleep'):
            await api_client._message_listener()
        api_client.logger.error.assert_called()
        assert not api_client.connected

    @pytest.mark.asyncio
    async def test_keep_alive_not_connected_break(self, api_client):
        """_keep_alive exits immediately and logs its stop message when not connected."""
        api_client.connected = False
        api_client.logger = Mock()
        await api_client._keep_alive()
        api_client.logger.debug.assert_called_with("💓 连接保活任务已停止")

    @pytest.mark.asyncio
    async def test_keep_alive_channel_closed_detection(self, api_client):
        """_keep_alive detects a closed channel and marks the client disconnected."""
        api_client.connected = True
        api_client.logger = Mock()
        api_client.channel = Mock()
        api_client.channel.closed = True
        with patch('asyncio.sleep'):
            await api_client._keep_alive()
        api_client.logger.error.assert_called_with("❌ 检测到连接断开(保活检查)")
        assert not api_client.connected

    @pytest.mark.asyncio
    async def test_keep_alive_transport_inactive(self, api_client):
        """_keep_alive detects an inactive SSH transport and marks the client disconnected."""
        api_client.connected = True
        api_client.logger = Mock()
        api_client.channel = Mock()
        api_client.channel.closed = False
        api_client.client = Mock()
        mock_transport = Mock()
        mock_transport.is_active.return_value = False
        api_client.client.get_transport.return_value = mock_transport
        with patch('asyncio.sleep'):
            await api_client._keep_alive()
        api_client.logger.error.assert_called_with("❌ SSH传输层连接断开")
        assert not api_client.connected

    @pytest.mark.asyncio
    async def test_keep_alive_exception_handling(self, api_client):
        """_keep_alive logs an exception raised during its sleep and keeps running until disconnected."""
        api_client.connected = True
        api_client.logger = Mock()
        api_client.channel = Mock()
        api_client.channel.closed = False
        call_count = 0

        async def mock_sleep(duration):
            # First sleep raises; the next one ends the loop by disconnecting.
            nonlocal call_count
            call_count += 1
            if call_count == 1:
                raise Exception("Keep alive error")
            api_client.connected = False

        with patch('asyncio.sleep', side_effect=mock_sleep):
            await api_client._keep_alive()
        api_client.logger.error.assert_called()
        assert not api_client.connected

    @pytest.mark.asyncio
    async def test_process_receive_message_invalid_data_length(self, api_client):
        """_process_receive_message rejects a payload that is too short."""
        api_client.logger = Mock()
        await api_client._process_receive_message(b"short")
        api_client.logger.error.assert_called_with("❌ 接收消息包格式错误")

    @pytest.mark.asyncio
    async def test_process_receive_message_complete_flow(self, api_client):
        """A well-formed message is parsed and appended to the history."""
        api_client.logger = Mock()
        api_client.message_callbacks = []
        api_client.message_history = []
        api_client.max_history = 100
        from_user = "testuser"
        to_user = "target"
        message = "Hello, this is a test message!"
        packet_data = self._build_receive_message_packet(from_user, to_user, message)
        await api_client._process_receive_message(packet_data)
        assert len(api_client.message_history) == 1
        message_obj = api_client.message_history[0]
        assert message_obj.from_user == from_user
        assert message_obj.to_user == to_user
        assert message_obj.content == message

    @pytest.mark.asyncio
    async def test_process_receive_message_with_callbacks(self, api_client):
        """Registered message callbacks are invoked for a received message."""
        api_client.logger = Mock()
        api_client.message_history = []
        api_client.max_history = 100
        callback_mock = Mock()
        api_client.message_callbacks = [callback_mock]
        packet_data = self._build_receive_message_packet("sender", "", "callback test")
        await api_client._process_receive_message(packet_data)
        callback_mock.assert_called_once()

    @pytest.mark.asyncio
    async def test_process_receive_message_history_limit(self, api_client):
        """The message history is capped at ``max_history`` entries."""
        api_client.logger = Mock()
        api_client.message_callbacks = []
        api_client.message_history = []
        api_client.max_history = 2
        # Three messages into a two-slot history forces one eviction.
        for i in range(3):
            await api_client._process_receive_message(
                self._build_receive_message_packet(f"user{i}", "", f"message {i}")
            )
        assert len(api_client.message_history) == api_client.max_history
        # Expect FIFO eviction: the oldest message is the one dropped.
        assert [m.from_user for m in api_client.message_history] == ["user1", "user2"]

    @pytest.mark.asyncio
    async def test_process_receive_message_exception_handling(self, api_client):
        """A payload whose declared from_user length overruns the data is logged as an error."""
        api_client.logger = Mock()
        api_client.message_callbacks = []
        api_client.message_history = []
        # Deliberately malformed: declares 255 bytes of from_user but supplies 5.
        # Built by hand because the helper always produces consistent lengths.
        timestamp = int(datetime.now().timestamp())
        packet_data = struct.pack(">Q", timestamp)
        packet_data += struct.pack("B", 255)
        packet_data += b"short"
        await api_client._process_receive_message(packet_data)
        api_client.logger.error.assert_called()
        error_call_args = api_client.logger.error.call_args[0][0]
        assert "处理接收消息错误" in error_call_args

    @pytest.mark.asyncio
    async def test_send_message_basic_functionality(self, api_client):
        """send_message emits a SEND_MESSAGE packet when connected."""
        api_client.connected = True
        api_client.channel = Mock()
        api_client.logger = Mock()
        with patch.object(api_client, '_send_packet') as mock_send:
            await api_client.send_message("target", "Hello")
        mock_send.assert_called_once()
        assert mock_send.call_args[0][0] == SSHOUTPacketType.SEND_MESSAGE

    @pytest.mark.asyncio
    async def test_send_message_not_connected(self, api_client):
        """send_message returns False when the client is not connected."""
        api_client.connected = False
        api_client.logger = Mock()
        result = await api_client.send_message("target", "Hello")
        assert result is False

    @pytest.mark.asyncio
    async def test_send_global_message_functionality(self, api_client):
        """send_global_message delegates to send_message with the GLOBAL target and PLAIN type."""
        api_client.connected = True
        api_client.logger = Mock()
        with patch.object(api_client, 'send_message') as mock_send:
            await api_client.send_global_message("Global message")
        mock_send.assert_called_once_with("GLOBAL", "Global message", SSHOUTMessageType.PLAIN)

    def test_api_client_initialization_and_basic_properties(self):
        """The constructor stores its arguments and starts disconnected with no callbacks."""
        client = SSHOUTApiClient(
            hostname="test.host.com",
            port=22333,
            username="testuser",
            key_path="/test/key/path",
            timeout=30
        )
        assert client.hostname == "test.host.com"
        assert client.port == 22333
        assert client.username == "testuser"
        assert client.key_path == "/test/key/path"
        assert client.timeout == 30
        assert not client.connected
        assert client.message_callbacks == []
        assert client.mention_callbacks == []
class TestSSHOUTApiClientAdvancedCoverage:
    """Advanced coverage tests for SSHOUTApiClient: complex flows, callbacks and edge cases."""

    def setup_method(self):
        """Store the constructor arguments used by the ``api_client`` fixture."""
        self.config = {
            'hostname': 'test.example.com',
            'port': 22333,
            'username': 'testuser',
            'key_path': '/test/key',
            'timeout': 10
        }

    @pytest.fixture
    def api_client(self):
        """Create an SSHOUTApiClient from ``self.config`` (keyword arguments)."""
        return SSHOUTApiClient(**self.config)

    @staticmethod
    def _pack_short_string(text):
        """Return ``text`` UTF-8 encoded behind a 1-byte length prefix.

        The prefix is the length of the *encoded* bytes. Counting characters
        (``len(text)``) would desynchronise the payload for non-ASCII text.
        """
        encoded = text.encode('utf-8')
        return struct.pack("B", len(encoded)) + encoded

    @classmethod
    def _build_receive_message_packet(cls, from_user, to_user, message, timestamp=None):
        """Build a RECEIVE_MESSAGE payload in the layout these tests feed to the client.

        Layout: 8-byte big-endian timestamp, length-prefixed ``from_user``,
        length-prefixed ``to_user``, 1-byte message type (PLAIN), 4-byte
        big-endian message byte length, message bytes.
        ``timestamp`` defaults to the current time.
        """
        if timestamp is None:
            timestamp = int(datetime.now().timestamp())
        body = message.encode('utf-8')
        return b"".join((
            struct.pack(">Q", timestamp),
            cls._pack_short_string(from_user),
            cls._pack_short_string(to_user),
            struct.pack("B", SSHOUTMessageType.PLAIN),
            struct.pack(">I", len(body)),
            body,
        ))

    @staticmethod
    def _discard_coroutine(coro, *args, **kwargs):
        """Stand-in for ``asyncio.create_task`` that closes the coroutine instead of scheduling it.

        With a bare MagicMock in place of create_task, the coroutine from an
        AsyncMock callback is never awaited and Python emits a
        "coroutine ... was never awaited" RuntimeWarning. Closing it avoids that.
        """
        if asyncio.iscoroutine(coro):
            coro.close()
        return MagicMock()

    @pytest.mark.asyncio
    async def test_get_online_users_complete_flow(self, api_client):
        """A valid ONLINE_USERS_INFO payload is parsed and the user is logged at debug level."""
        api_client.connected = True
        api_client.logger = Mock()
        with patch.object(api_client, '_send_packet'), \
                patch.object(api_client, '_receive_packet') as mock_receive:
            user_id = 123
            user_count = 1
            username = "testuser"
            hostname = "test.host.com"
            packet_data = (
                struct.pack(">H", user_id)        # my_user_id
                + struct.pack(">H", user_count)   # user_count
                # Per-user entry: id, length-prefixed username, length-prefixed hostname.
                + struct.pack(">H", user_id)
                + self._pack_short_string(username)
                + self._pack_short_string(hostname)
            )
            mock_receive.return_value = (SSHOUTPacketType.ONLINE_USERS_INFO, packet_data)
            await api_client._get_online_users()
        api_client.logger.debug.assert_called()
        debug_call_args = api_client.logger.debug.call_args[0][0]
        assert username in debug_call_args and hostname in debug_call_args

    @pytest.mark.asyncio
    async def test_get_online_users_exception_handling(self, api_client):
        """An exception from _receive_packet is caught and logged as a failure."""
        api_client.connected = True
        api_client.logger = Mock()
        with patch.object(api_client, '_send_packet'), \
                patch.object(api_client, '_receive_packet') as mock_receive:
            mock_receive.side_effect = Exception("Network error")
            await api_client._get_online_users()
        api_client.logger.error.assert_called()
        error_call_args = api_client.logger.error.call_args[0][0]
        assert "获取在线用户失败" in error_call_args

    @pytest.mark.asyncio
    async def test_process_receive_message_async_callback(self, api_client):
        """Async message callbacks are dispatched through asyncio.create_task."""
        api_client.logger = Mock()
        api_client.message_history = []
        api_client.max_history = 100
        async_callback = AsyncMock()
        api_client.message_callbacks = [async_callback]
        packet_data = self._build_receive_message_packet("sender", "", "async callback test")
        with patch('asyncio.create_task', side_effect=self._discard_coroutine) as mock_create_task:
            await api_client._process_receive_message(packet_data)
        mock_create_task.assert_called()

    @pytest.mark.asyncio
    async def test_process_receive_message_sync_callback(self, api_client):
        """Sync message callbacks are called directly."""
        api_client.logger = Mock()
        api_client.message_history = []
        api_client.max_history = 100
        sync_callback = Mock()
        api_client.message_callbacks = [sync_callback]
        packet_data = self._build_receive_message_packet("sender", "", "sync callback test")
        await api_client._process_receive_message(packet_data)
        sync_callback.assert_called()

    @pytest.mark.asyncio
    async def test_process_receive_message_callback_exception(self, api_client):
        """An exception raised by a message callback is logged as a callback error."""
        api_client.logger = Mock()
        api_client.message_history = []
        api_client.max_history = 100
        error_callback = Mock(side_effect=Exception("Callback error"))
        api_client.message_callbacks = [error_callback]
        packet_data = self._build_receive_message_packet("sender", "", "error callback test")
        await api_client._process_receive_message(packet_data)
        api_client.logger.error.assert_called()
        error_call_args = api_client.logger.error.call_args[0][0]
        assert "消息回调错误" in error_call_args

    @pytest.mark.asyncio
    async def test_process_receive_message_claude_mention_detection(self, api_client):
        """A message flagged by _is_claude_mention produces a mention info log."""
        api_client.logger = Mock()
        api_client.message_history = []
        api_client.max_history = 100
        api_client.message_callbacks = []
        with patch.object(api_client, '_is_claude_mention', return_value=True):
            packet_data = self._build_receive_message_packet("user", "", "@Claude hello")
            await api_client._process_receive_message(packet_data)
        api_client.logger.info.assert_called()
        info_call_args = api_client.logger.info.call_args[0][0]
        assert "检测到@Claude提及" in info_call_args

    @pytest.mark.asyncio
    async def test_process_receive_message_mention_async_callback(self, api_client):
        """Async mention callbacks are dispatched through asyncio.create_task."""
        api_client.logger = Mock()
        api_client.message_history = []
        api_client.max_history = 100
        api_client.message_callbacks = []
        async_mention_callback = AsyncMock()
        api_client.mention_callbacks = [async_mention_callback]
        with patch.object(api_client, '_is_claude_mention', return_value=True):
            packet_data = self._build_receive_message_packet("user", "", "@Claude test")
            with patch('asyncio.create_task', side_effect=self._discard_coroutine) as mock_create_task:
                await api_client._process_receive_message(packet_data)
        mock_create_task.assert_called()

    @pytest.mark.asyncio
    async def test_process_receive_message_mention_sync_callback(self, api_client):
        """Sync mention callbacks are called directly."""
        api_client.logger = Mock()
        api_client.message_history = []
        api_client.max_history = 100
        api_client.message_callbacks = []
        sync_mention_callback = Mock()
        api_client.mention_callbacks = [sync_mention_callback]
        with patch.object(api_client, '_is_claude_mention', return_value=True):
            packet_data = self._build_receive_message_packet("user", "", "@Claude sync test")
            await api_client._process_receive_message(packet_data)
        sync_mention_callback.assert_called()

    @pytest.mark.asyncio
    async def test_process_receive_message_mention_callback_exception(self, api_client):
        """An exception raised by a mention callback is logged rather than propagated."""
        api_client.logger = Mock()
        api_client.message_history = []
        api_client.max_history = 100
        api_client.message_callbacks = []
        error_mention_callback = Mock(side_effect=Exception("Mention callback error"))
        api_client.mention_callbacks = [error_mention_callback]
        with patch.object(api_client, '_is_claude_mention', return_value=True):
            packet_data = self._build_receive_message_packet("user", "", "@Claude error test")
            await api_client._process_receive_message(packet_data)
        api_client.logger.error.assert_called()

    @pytest.mark.asyncio
    async def test_send_message_exception_handling(self, api_client):
        """send_message returns False and logs when _send_packet raises."""
        api_client.connected = True
        api_client.channel = Mock()
        api_client.logger = Mock()
        with patch.object(api_client, '_send_packet', side_effect=Exception("Send error")):
            result = await api_client.send_message("target", "Hello")
        assert result is False
        api_client.logger.error.assert_called()

    @pytest.mark.asyncio
    async def test_process_user_state_change_basic(self, api_client):
        """A user state change payload is logged at info level."""
        api_client.logger = Mock()
        # Layout used here: user id (2 bytes) + state (1 byte) + length-prefixed username.
        user_id = 123
        status = 1  # online
        packet_data = (
            struct.pack(">H", user_id)
            + struct.pack("B", status)
            + self._pack_short_string("testuser")
        )
        await api_client._process_user_state_change(packet_data)
        api_client.logger.info.assert_called()

    @pytest.mark.asyncio
    async def test_process_user_state_change_invalid_data(self, api_client):
        """Empty and truncated user state payloads do not raise."""
        api_client.logger = Mock()
        await api_client._process_user_state_change(b"")
        await api_client._process_user_state_change(b"short")
        # Reaching this point means neither malformed payload escaped as an exception.

    @pytest.mark.asyncio
    async def test_process_error_basic(self, api_client):
        """An ERROR payload is logged at error level."""
        api_client.logger = Mock()
        error_code = 1
        packet_data = (
            struct.pack("B", error_code)
            + self._pack_short_string("Server connection lost")
        )
        await api_client._process_error(packet_data)
        api_client.logger.error.assert_called()

    @pytest.mark.asyncio
    async def test_process_motd_basic(self, api_client):
        """A MOTD payload is logged at info level."""
        api_client.logger = Mock()
        packet_data = self._pack_short_string("Welcome to SSHOUT server!")
        await api_client._process_motd(packet_data)
        api_client.logger.info.assert_called()

    def test_channel_property_checks(self, api_client):
        """The ``channel`` attribute can be cleared and reassigned.

        Note: this only round-trips the attribute; it does not exercise any
        method behaviour that depends on ``channel``.
        """
        api_client.channel = None
        assert api_client.channel is None
        api_client.channel = Mock()
        assert api_client.channel is not None

    @pytest.mark.asyncio
    async def test_additional_edge_cases(self, api_client):
        """The state-change, error and MOTD handlers tolerate empty payloads."""
        api_client.logger = Mock()
        await api_client._process_user_state_change(b"")
        await api_client._process_error(b"")
        await api_client._process_motd(b"")
        # Reaching this point means no handler raised on empty input.
if __name__ == '__main__':
    # Propagate pytest's exit code so running this file directly reports failures
    # to the shell/CI instead of always exiting 0.
    sys.exit(pytest.main([__file__]))