| """ |
| SSHOUT API客户端单元测试 |
| 测试基于SSHOUT API二进制协议的客户端功能 |
| """ |
| |
| import pytest |
| import asyncio |
| import struct |
| from datetime import datetime |
| from unittest.mock import AsyncMock, MagicMock, patch, Mock |
| |
| import sys |
| import os |
| sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../../../src')) |
| |
| from claude_agent.sshout.api_client import ( |
| SSHOUTApiClient, SSHOUTApiIntegration, |
| SSHOUTPacketType, SSHOUTMessageType, SSHOUTErrorCode, |
| SSHOUTMessage, SSHOUTUser |
| ) |
| |
| |
class TestSSHOUTPacketTypes:
    """Pin the numeric codes of the SSHOUT API protocol enums.

    The values are compared against fixed integers because they travel on the
    wire; changing any of them breaks compatibility with the server.
    """

    def test_client_to_server_packet_types(self):
        """Client -> server packet type codes."""
        assert SSHOUTPacketType.HELLO.value == 1
        assert SSHOUTPacketType.GET_ONLINE_USER.value == 2
        assert SSHOUTPacketType.SEND_MESSAGE.value == 3
        assert SSHOUTPacketType.GET_MOTD.value == 4

    def test_server_to_client_packet_types(self):
        """Server -> client packet type codes (all >= 128)."""
        assert SSHOUTPacketType.PASS.value == 128
        assert SSHOUTPacketType.ONLINE_USERS_INFO.value == 129
        assert SSHOUTPacketType.RECEIVE_MESSAGE.value == 130
        assert SSHOUTPacketType.USER_STATE_CHANGE.value == 131
        assert SSHOUTPacketType.ERROR.value == 132
        assert SSHOUTPacketType.MOTD.value == 133

    def test_message_types(self):
        """Message payload type codes."""
        assert SSHOUTMessageType.PLAIN.value == 1
        assert SSHOUTMessageType.RICH.value == 2
        assert SSHOUTMessageType.IMAGE.value == 3

    def test_error_codes(self):
        """Error code values."""
        assert SSHOUTErrorCode.SERVER_CLOSED.value == 1
        assert SSHOUTErrorCode.LOCAL_PACKET_CORRUPT.value == 2
        assert SSHOUTErrorCode.LOCAL_PACKET_TOO_LARGE.value == 3
        assert SSHOUTErrorCode.OUT_OF_MEMORY.value == 4
        assert SSHOUTErrorCode.INTERNAL_ERROR.value == 5
        assert SSHOUTErrorCode.USER_NOT_FOUND.value == 6
        assert SSHOUTErrorCode.MOTD_NOT_AVAILABLE.value == 7
| |
| |
class LegacySSHOUTApiClientTests:
    """Legacy SSHOUTApiClient tests -- deliberately NOT collected by pytest.

    This class was originally also named ``TestSSHOUTApiClient``. A second
    class with that name appears later in this module and rebinds the name,
    so pytest only ever saw the later class and none of these tests ran.

    The cases here appear to target an older client surface
    (``recent_messages``, ``online_users``, ``motd``, ``ssh_client``,
    ``SSHOUTErrorCode.MESSAGE_TOO_LONG``) and build payloads with
    little-endian uint16 length prefixes, whereas
    ``TestSSHOUTBinaryPacketHandling`` in this module uses big-endian integers
    and one-byte string lengths.

    NOTE(review): port whatever is still relevant to the current API, then
    delete this class. ``__test__ = False`` keeps it out of collection until
    then, instead of it silently shadowing (or being shadowed by) a live class.
    """

    # Explicitly opt out of pytest collection (see class docstring).
    __test__ = False

    def setup_method(self):
        """Create an unconnected client."""
        self.client = SSHOUTApiClient(
            hostname="test.server.com",
            port=22333,
            username="testuser",
            key_path="/path/to/key"
        )

    def test_client_initialization(self):
        """Constructor stores its arguments and starts disconnected."""
        assert self.client.hostname == "test.server.com"
        assert self.client.port == 22333
        assert self.client.username == "testuser"
        assert self.client.key_path == "/path/to/key"
        assert self.client.timeout == 10
        assert not self.client.connected
        assert self.client.stdin is None
        assert self.client.stdout is None
        assert self.client.stderr is None

    def test_is_claude_mention(self):
        """@Claude mention detection."""
        # Positive cases
        assert self.client._is_claude_mention("@Claude help me")
        assert self.client._is_claude_mention("@claude what's up")
        assert self.client._is_claude_mention("Claude: how are you?")
        assert self.client._is_claude_mention("claude, tell me a joke")
        assert self.client._is_claude_mention("Hey @Claude")
        assert self.client._is_claude_mention("Claude:你好")  # full-width (Chinese) colon
        assert self.client._is_claude_mention("claude,help")  # full-width (Chinese) comma

        # Negative cases
        assert not self.client._is_claude_mention("Hello world")
        assert not self.client._is_claude_mention("This is a test")
        assert not self.client._is_claude_mention("includes Claude but not mention")
        assert not self.client._is_claude_mention("claudeius was a Roman emperor")

    def test_add_callbacks(self):
        """Message and mention callbacks are appended to their lists."""
        callback1 = Mock()
        callback2 = Mock()
        mention_callback = Mock()

        # Message callbacks
        self.client.add_message_callback(callback1)
        self.client.add_message_callback(callback2)
        assert len(self.client.message_callbacks) == 2

        # Mention callback
        self.client.add_mention_callback(mention_callback)
        assert len(self.client.mention_callbacks) == 1

    def test_get_recent_messages(self):
        """get_recent_messages() on empty and non-empty history."""
        # No messages yet
        messages = self.client.get_recent_messages()
        assert messages == []

        # Add one message
        test_message = SSHOUTMessage(
            message_type=SSHOUTMessageType.PLAIN,
            from_user="testuser",
            to_user="",
            content="Test message",
            timestamp=datetime.now()
        )
        self.client.recent_messages.append(test_message)

        messages = self.client.get_recent_messages(count=1)
        assert len(messages) == 1
        assert messages[0].content == "Test message"

    async def test_connect_without_ssh_client(self):
        """connect() fails when no SSH client is available."""
        with patch.object(self.client, 'ssh_client', None):
            result = await self.client.connect()
            assert result is False

    async def test_disconnect_when_not_connected(self):
        """disconnect() on an unconnected client does not raise."""
        self.client.connected = False
        await self.client.disconnect()  # must not raise

    @patch('paramiko.SSHClient')
    async def test_connect_ssh_failure(self, mock_ssh_class):
        """connect() reports failure when the SSH connect raises."""
        mock_ssh = Mock()
        mock_ssh.connect.side_effect = Exception("Connection failed")
        mock_ssh_class.return_value = mock_ssh

        self.client.ssh_client = mock_ssh
        result = await self.client.connect()
        assert result is False
        assert not self.client.connected

    async def test_send_packet_not_connected(self):
        """_send_packet() raises when not connected."""
        self.client.connected = False
        with pytest.raises(Exception):
            await self.client._send_packet(SSHOUTPacketType.HELLO, b"test")

    async def test_process_receive_message(self):
        """A received message is stored and dispatched to callbacks."""
        # Test data
        test_content = "Hello @Claude"
        test_from = "testuser"
        test_to = ""

        # Build the payload (simplified, little-endian uint16 prefixes)
        data = struct.pack('<H', SSHOUTMessageType.PLAIN.value)
        data += struct.pack('<H', len(test_from)) + test_from.encode('utf-8')
        data += struct.pack('<H', len(test_to)) + test_to.encode('utf-8')
        data += struct.pack('<H', len(test_content)) + test_content.encode('utf-8')

        # Register callbacks
        message_callback = Mock()
        mention_callback = Mock()
        self.client.add_message_callback(message_callback)
        self.client.add_mention_callback(mention_callback)

        # Process the message
        await self.client._process_receive_message(data)

        # The message is stored in recent_messages
        assert len(self.client.recent_messages) == 1
        message = self.client.recent_messages[0]
        assert message.content == test_content
        assert message.from_user == test_from

        # Both callbacks fire
        message_callback.assert_called_once()
        mention_callback.assert_called_once()  # content contains @Claude

    async def test_process_user_state_change(self):
        """A user-online state change adds the user to online_users."""
        # Build a "user came online" payload
        username = "newuser"
        data = struct.pack('<B', 1)  # 1 means online
        data += struct.pack('<H', len(username)) + username.encode('utf-8')

        await self.client._process_user_state_change(data)

        # The user is now listed as online
        assert username in [user.username for user in self.client.online_users]

    async def test_process_error(self):
        """Processing an ERROR payload does not raise."""
        error_code = SSHOUTErrorCode.MESSAGE_TOO_LONG.value
        error_message = "Message too long"

        data = struct.pack('<H', error_code)
        data += struct.pack('<H', len(error_message)) + error_message.encode('utf-8')

        # Must not raise
        await self.client._process_error(data)

    async def test_process_motd(self):
        """Processing a MOTD payload stores the text on the client."""
        motd_text = "Welcome to SSHOUT server"
        data = struct.pack('<H', len(motd_text)) + motd_text.encode('utf-8')

        await self.client._process_motd(data)
        assert self.client.motd == motd_text
| |
| |
class LegacySSHOUTApiIntegrationTests:
    """Legacy SSHOUTApiIntegration test -- deliberately NOT collected by pytest.

    This class was originally also named ``TestSSHOUTApiIntegration``. A later
    class with the same name in this module rebinds it, so pytest never ran
    this test. It constructs ``SSHOUTApiIntegration()`` without an agent and
    stubs ``get_config``, whereas the live integration tests pass an agent and
    stub ``get_sshout_config``.

    NOTE(review): port to the current API or delete. ``__test__ = False``
    keeps it out of collection meanwhile.
    """

    # Explicitly opt out of pytest collection (see class docstring).
    __test__ = False

    def setup_method(self):
        """Create the integration under test."""
        self.integration = SSHOUTApiIntegration()

    async def test_integration_start_stop(self):
        """start() connects the client and stop() disconnects it."""
        # Mocked configuration
        mock_config = {
            'sshout': {
                'server': {
                    'hostname': 'test.server.com',
                    'port': 22333,
                    'username': 'testuser'
                },
                'ssh_key': {
                    'private_key_path': '/path/to/key'
                }
            }
        }

        with patch('claude_agent.sshout.api_client.get_config_manager') as mock_get_config:
            mock_config_manager = Mock()
            mock_config_manager.get_config.return_value = mock_config
            mock_get_config.return_value = mock_config_manager

            with patch.object(self.integration, 'client') as mock_client:
                mock_client.connect.return_value = True

                # Start
                result = await self.integration.start()
                assert result is True
                mock_client.connect.assert_called_once()

                # Stop
                await self.integration.stop()
                mock_client.disconnect.assert_called_once()
| |
| |
class TestSSHOUTMessage:
    """Tests for the SSHOUTMessage value object."""

    def test_message_creation(self):
        """Constructor fields round-trip and is_mention is False when not given."""
        created_at = datetime.now()
        msg = SSHOUTMessage(
            timestamp=created_at,
            from_user="testuser",
            to_user="GLOBAL",
            message_type=SSHOUTMessageType.PLAIN,
            content="Hello, world!",
        )

        expected_fields = {
            "timestamp": created_at,
            "from_user": "testuser",
            "to_user": "GLOBAL",
            "message_type": SSHOUTMessageType.PLAIN,
            "content": "Hello, world!",
        }
        for field_name, field_value in expected_fields.items():
            assert getattr(msg, field_name) == field_value
        assert msg.is_mention is False
| |
| |
class TestSSHOUTUser:
    """Tests for the SSHOUTUser value object."""

    def test_user_creation(self):
        """id, username and hostname are stored as given."""
        fields = {"id": 123, "username": "testuser", "hostname": "test.example.com"}
        user = SSHOUTUser(**fields)

        for name, value in fields.items():
            assert getattr(user, name) == value
| |
| |
class TestSSHOUTApiClient:
    """Connection-free unit tests for SSHOUTApiClient state and helpers."""

    def setup_method(self):
        """Build a fresh, unconnected client before every test."""
        self.client = SSHOUTApiClient(
            hostname="test.example.com",
            port=22333,
            username="testuser",
            key_path="/fake/key/path",
        )

    @staticmethod
    def _global_message(index, timestamp=None):
        """Return a PLAIN GLOBAL message from ``user{index}`` with text ``Message {index}``."""
        return SSHOUTMessage(
            timestamp=datetime.now() if timestamp is None else timestamp,
            from_user=f"user{index}",
            to_user="GLOBAL",
            message_type=SSHOUTMessageType.PLAIN,
            content=f"Message {index}",
        )

    def test_client_initialization(self):
        """Constructor arguments are stored and all runtime state starts empty."""
        c = self.client
        assert (c.hostname, c.port, c.username, c.key_path) == (
            "test.example.com", 22333, "testuser", "/fake/key/path"
        )
        assert c.timeout == 10
        assert c.connected is False
        for collection in (c.message_callbacks, c.mention_callbacks, c.message_history):
            assert len(collection) == 0
        assert c.max_history == 100
        assert c.my_user_id is None
        assert c.my_username is None

    def test_default_mention_patterns(self):
        """Without an explicit list, the default @Claude patterns are used."""
        assert self.client.mention_patterns == [
            "@Claude", "@claude", "@CLAUDE",
            "Claude:", "claude:",
            "Claude,", "claude,",
            "Claude,", "claude,"
        ]

    def test_custom_mention_patterns(self):
        """An explicit mention_patterns list replaces the defaults."""
        patterns = ["@Bot", "Bot:"]
        bot_client = SSHOUTApiClient(
            hostname="test.com",
            port=22333,
            username="user",
            key_path="/fake/key",
            mention_patterns=patterns,
        )
        assert bot_client.mention_patterns == patterns

    def test_add_callbacks(self):
        """Registered callbacks end up in their respective lists."""
        on_message, on_mention = MagicMock(), MagicMock()

        self.client.add_message_callback(on_message)
        self.client.add_mention_callback(on_mention)

        assert on_message in self.client.message_callbacks
        assert on_mention in self.client.mention_callbacks

    def test_claude_mention_detection(self):
        """Table-driven check of _is_claude_mention()."""
        cases = [
            ("@Claude help me", True),
            ("@claude what's up", True),
            ("@CLAUDE test", True),
            ("Claude: please help", True),
            ("claude: hi there", True),
            ("Claude, how are you", True),
            ("claude, hello", True),
            ("Claude,你好", True),
            ("claude,测试", True),
            ("Hello everyone", False),
            ("This is Claude speaking", False),  # bare name mid-sentence must not match
            ("claudebot", False),  # letters directly after the name must not match
            ("@ClaudeAI", False),  # an extended name must not match
        ]

        for text, should_match in cases:
            got = self.client._is_claude_mention(text)
            assert got == should_match, f"Failed for: '{text}', expected {should_match}, got {got}"

    def test_message_history_management(self):
        """History stays within max_history when trimmed oldest-first.

        The trimming is reproduced inline (mirroring what the receive path is
        expected to do) rather than driven through the client.
        """
        history = self.client.message_history
        limit = self.client.max_history

        for i in range(150):  # deliberately more than max_history
            history.append(self._global_message(i))
            if len(history) > limit:
                history.pop(0)

        assert len(history) <= limit

    def test_get_recent_messages(self):
        """get_recent_messages(n) returns the newest n messages, newest last."""
        for i in range(20):
            self.client.message_history.append(self._global_message(i))

        newest_ten = self.client.get_recent_messages(10)

        assert len(newest_ten) == 10
        assert newest_ten[-1].content == "Message 19"

    def test_get_context_messages(self):
        """get_context_messages returns the N messages before a timestamp, oldest first."""
        origin = datetime.now()
        timeline = [
            self._global_message(i, datetime.fromtimestamp(origin.timestamp() + i))
            for i in range(10)
        ]
        self.client.message_history.extend(timeline)

        context = self.client.get_context_messages(timeline[5].timestamp, 3)

        # Messages 2, 3, 4 -- the three immediately before message 5, in time order.
        assert len(context) == 3
        assert [m.content for m in context] == ["Message 2", "Message 3", "Message 4"]

    def test_connection_status_disconnected(self):
        """A never-connected client reports an empty status."""
        assert self.client.get_connection_status() == {
            'connected': False,
            'server': None,
            'message_count': 0,
            'my_user_id': None,
            'my_username': None
        }

    def test_connection_status_connected(self):
        """A connected client reports server, identity and a truncated preview."""
        self.client.connected = True
        self.client.my_user_id = 123
        self.client.my_username = "testuser"
        self.client.message_history.append(SSHOUTMessage(
            timestamp=datetime.now(),
            from_user="other",
            to_user="GLOBAL",
            message_type=SSHOUTMessageType.PLAIN,
            content="Test message content that is quite long and should be truncated",
        ))

        status = self.client.get_connection_status()

        assert status['connected'] is True
        expected_fields = {
            'server': "test.example.com:22333",
            'message_count': 1,
            'my_user_id': 123,
            'my_username': "testuser",
        }
        for key, value in expected_fields.items():
            assert status[key] == value

        previews = status['recent_messages']
        assert len(previews) == 1
        # Preview text is expected to be cut to 50 characters plus "...".
        assert len(previews[0]['content']) <= 53
| |
class TestSSHOUTApiIntegration:
    """Tests for SSHOUTApiIntegration configuration validation and helpers."""

    def setup_method(self):
        """Build an integration from a complete mocked SSHOUT config.

        ``get_config_manager`` and ``os.path.exists`` are patched only while the
        integration is constructed, so the key-file existence check passes.
        """
        self.mock_agent = MagicMock()

        # Mocked config manager
        self.mock_config_manager = MagicMock()
        self.mock_config = {
            'connection_mode': 'api',
            'mention_patterns': ['@Claude'],
            'server': {
                'hostname': 'test.example.com',
                'port': 22333,
                'username': 'testuser'
            },
            'ssh_key': {
                'private_key_path': '/fake/key/path',
                'timeout': 10
            },
            'message': {
                'max_history': 100,
                'context_count': 5,
                'max_reply_length': 200
            }
        }
        self.mock_config_manager.get_sshout_config.return_value = self.mock_config
        self.mock_config_manager.get.return_value = 5

        with patch('claude_agent.sshout.api_client.get_config_manager') as mock_get_config:
            mock_get_config.return_value = self.mock_config_manager

            # Skip the real key-file existence check
            with patch('claude_agent.sshout.api_client.os.path.exists') as mock_exists:
                mock_exists.return_value = True
                self.integration = SSHOUTApiIntegration(self.mock_agent)

    def test_integration_initialization(self):
        """The agent is stored and no client exists before connecting."""
        assert self.integration.agent == self.mock_agent
        assert self.integration.client is None

    def test_config_validation_missing_section(self):
        """A config without the ssh_key section is rejected."""
        incomplete_config = {'server': {}}  # ssh_key section missing
        self.mock_config_manager.get_sshout_config.return_value = incomplete_config

        with patch('claude_agent.sshout.api_client.get_config_manager') as mock_get_config:
            mock_get_config.return_value = self.mock_config_manager

            with pytest.raises(ValueError, match="SSHOUT配置缺少必需的段落: ssh_key"):
                SSHOUTApiIntegration(self.mock_agent)

    def test_config_validation_missing_server_key(self):
        """A server section without port/username is rejected (port reported first)."""
        incomplete_config = {
            'server': {'hostname': 'test.com'},  # port and username missing
            'ssh_key': {'private_key_path': '/fake/key'}
        }
        self.mock_config_manager.get_sshout_config.return_value = incomplete_config

        with patch('claude_agent.sshout.api_client.get_config_manager') as mock_get_config:
            mock_get_config.return_value = self.mock_config_manager

            with pytest.raises(ValueError, match="SSHOUT服务器配置缺少必需的键: port"):
                SSHOUTApiIntegration(self.mock_agent)

    def test_config_validation_missing_key_path(self):
        """An ssh_key section without private_key_path is rejected."""
        incomplete_config = {
            'server': {
                'hostname': 'test.com',
                'port': 22333,
                'username': 'user'
            },
            'ssh_key': {}  # private_key_path missing
        }
        self.mock_config_manager.get_sshout_config.return_value = incomplete_config

        with patch('claude_agent.sshout.api_client.get_config_manager') as mock_get_config:
            mock_get_config.return_value = self.mock_config_manager

            with pytest.raises(ValueError, match="SSHOUT配置缺少SSH私钥路径"):
                SSHOUTApiIntegration(self.mock_agent)

    def test_config_validation_key_file_not_exists(self):
        """A configured but missing key file raises FileNotFoundError."""
        with patch('claude_agent.sshout.api_client.get_config_manager') as mock_get_config:
            mock_get_config.return_value = self.mock_config_manager

            with patch('claude_agent.sshout.api_client.os.path.exists') as mock_exists:
                mock_exists.return_value = False

                with pytest.raises(FileNotFoundError, match="SSH私钥文件不存在"):
                    SSHOUTApiIntegration(self.mock_agent)

    def test_response_cleaning(self):
        """Markdown is stripped, line breaks are kept and long replies are truncated."""
        # Max reply length read from the config manager
        self.mock_config_manager.get.return_value = 50

        test_cases = [
            ("**bold text**", "bold text"),
            ("*italic text*", "italic text"),
            ("`code text`", "code text"),
            ("Line 1\nLine 2\n\nLine 3", "Line 1\nLine 2\n\nLine 3"),  # line breaks kept
            ("Multiple spaces", "Multiple spaces"),  # in-line spaces kept
            ("This is a very long response that exceeds the maximum length limit and should be truncated",
             "This is a very long response that exceeds the max..."),
        ]

        for input_text, expected in test_cases:
            result = self.integration._clean_response_for_sshout(input_text)
            if len(input_text) > 50:
                assert result.endswith("...")
                assert len(result) <= 53  # 50 + "..."
            else:
                assert result == expected

    def test_connection_status_no_client(self):
        """Without a client the status is the disconnected default."""
        status = self.integration.get_connection_status()

        expected_status = {
            'connected': False,
            'server': None,
            'message_count': 0,
            'api_version': '1.0'
        }

        assert status == expected_status

    @pytest.mark.asyncio
    async def test_send_message_no_client(self):
        """send_message() returns False when there is no client."""
        result = await self.integration.send_message("test message")
        assert result is False

    def test_message_callback_registration(self):
        """The integration's two handlers can be registered on a client.

        NOTE(review): registration is performed by hand here (mirroring what
        the connection path presumably does), so this verifies that the handler
        methods exist and are passed through unchanged; it does not exercise the
        integration's own connect logic.
        """
        mock_client = MagicMock()
        self.integration.client = mock_client

        # A fresh mock has recorded no registrations yet.
        mock_client.add_message_callback.assert_not_called()
        mock_client.add_mention_callback.assert_not_called()

        self.integration.client.add_message_callback(self.integration._on_message_received)
        self.integration.client.add_mention_callback(self.integration._on_claude_mentioned)

        # Pin the exact handlers, not merely that something was registered.
        mock_client.add_message_callback.assert_called_once_with(
            self.integration._on_message_received
        )
        mock_client.add_mention_callback.assert_called_once_with(
            self.integration._on_claude_mentioned
        )
| |
| |
class TestSSHOUTBinaryPacketHandling:
    """Round-trip the SSHOUT wire layouts with ``struct``.

    These tests build and parse payloads locally. They document the expected
    layout (big-endian integers, one-byte string lengths) rather than calling
    the client's own encoder/decoder.
    """

    def test_hello_packet_format(self):
        """HELLO payload: 6-byte magic ``SSHOUT`` + big-endian uint16 version."""
        payload = b"SSHOUT" + struct.pack(">H", 1)

        magic, version = struct.unpack(">6sH", payload)
        assert magic == b"SSHOUT"
        assert version == 1
        assert len(payload) == 8

    def test_pass_packet_format(self):
        """PASS payload: magic + uint16 version + uint8 name length + name."""
        name = "testuser".encode('utf-8')
        payload = b"".join([
            b"SSHOUT",
            struct.pack(">H", 1),
            struct.pack("B", len(name)),
            name,
        ])

        assert payload[:6] == b"SSHOUT"
        (version,) = struct.unpack_from(">H", payload, 6)
        assert version == 1
        name_length = payload[8]
        assert name_length == len(name)
        assert payload[9:9 + name_length] == name

    def test_send_message_packet_format(self):
        """SEND_MESSAGE payload: uint8 len + recipient, uint8 type, uint32 len + body."""
        recipient = "GLOBAL".encode('utf-8')
        body = "Hello, world!".encode('utf-8')
        kind = SSHOUTMessageType.PLAIN

        payload = b"".join([
            struct.pack("B", len(recipient)),
            recipient,
            struct.pack("B", kind.value),
            struct.pack(">I", len(body)),
            body,
        ])

        cursor = 0
        assert payload[cursor] == len(recipient)
        cursor += 1

        assert payload[cursor:cursor + len(recipient)] == recipient
        cursor += len(recipient)

        assert payload[cursor] == kind.value
        cursor += 1

        (body_length,) = struct.unpack_from(">I", payload, cursor)
        assert body_length == len(body)
        cursor += 4

        assert payload[cursor:cursor + len(body)] == body

    def test_receive_message_packet_parsing(self):
        """RECEIVE_MESSAGE payload parses back into the values it was built from.

        Layout: uint64 time, uint8 len + sender, uint8 len + recipient,
        uint8 message type, uint32 len + body.
        """
        import time

        sent_at = int(time.time())
        sender, recipient, body = "sender", "GLOBAL", "Test message"
        kind = SSHOUTMessageType.PLAIN

        sender_b = sender.encode('utf-8')
        recipient_b = recipient.encode('utf-8')
        body_b = body.encode('utf-8')

        packet = b"".join([
            struct.pack(">Q", sent_at),
            struct.pack("B", len(sender_b)), sender_b,
            struct.pack("B", len(recipient_b)), recipient_b,
            struct.pack("B", kind.value),
            struct.pack(">I", len(body_b)), body_b,
        ])

        position = 0

        def take(size):
            """Return the next ``size`` bytes of the packet and advance the cursor."""
            nonlocal position
            chunk = packet[position:position + size]
            position += size
            return chunk

        (got_time,) = struct.unpack(">Q", take(8))
        sender_length = take(1)[0]
        got_sender = take(sender_length).decode('utf-8')
        recipient_length = take(1)[0]
        got_recipient = take(recipient_length).decode('utf-8')
        got_kind = SSHOUTMessageType(take(1)[0])
        (body_length,) = struct.unpack(">I", take(4))
        got_body = take(body_length).decode('utf-8')

        assert got_time == sent_at
        assert got_sender == sender
        assert got_recipient == recipient
        assert got_kind == kind
        assert got_body == body
| |
| |
| # 异步测试需要特殊处理 |
| class TestSSHOUTAsyncOperations: |
| """测试SSHOUT异步操作""" |
| |
| def setup_method(self): |
| """测试设置""" |
| self.client = SSHOUTApiClient( |
| hostname="test.example.com", |
| port=22333, |
| username="testuser", |
| key_path="/fake/key/path" |
| ) |
| |
| @pytest.mark.asyncio |
| async def test_connect_ssh_failure(self): |
| """测试SSH连接失败""" |
| with patch('paramiko.SSHClient') as mock_ssh_class: |
| mock_ssh = MagicMock() |
| mock_ssh_class.return_value = mock_ssh |
| mock_ssh.connect.side_effect = Exception("Connection failed") |
| |
| result = await self.client.connect() |
| assert result is False |
| assert self.client.connected is False |
| |
| @pytest.mark.asyncio |
| async def test_send_message_not_connected(self): |
| """测试未连接时发送消息""" |
| result = await self.client.send_message("GLOBAL", "test message") |
| assert result is False |
| |
| @pytest.mark.asyncio |
| async def test_disconnect_cleanup(self): |
| """测试断开连接清理""" |
| # 模拟已连接状态 |
| self.client.connected = True |
| mock_channel = MagicMock() |
| mock_client = MagicMock() |
| mock_stdin = MagicMock() |
| mock_stdout = MagicMock() |
| mock_stderr = MagicMock() |
| |
| self.client.channel = mock_channel |
| self.client.client = mock_client |
| self.client.stdin = mock_stdin |
| self.client.stdout = mock_stdout |
| self.client.stderr = mock_stderr |
| |
| await self.client.disconnect() |
| |
| assert self.client.connected is False |
| mock_stdin.close.assert_called_once() |
| mock_stdout.close.assert_called_once() |
| mock_stderr.close.assert_called_once() |
| mock_channel.close.assert_called_once() |
| mock_client.close.assert_called_once() |
| assert self.client.stdin is None |
| assert self.client.stdout is None |
| assert self.client.stderr is None |
| assert self.client.channel is None |
| assert self.client.client is None |
| |
| @pytest.mark.asyncio |
| async def test_read_exact_with_timeout(self): |
| """测试精确读取数据超时处理""" |
| # 模拟连接状态 |
| self.client.connected = True |
| mock_channel = MagicMock() |
| mock_channel.closed = False |
| mock_channel.recv_ready.return_value = False # 模拟数据未就绪 |
| self.client.channel = mock_channel |
| |
| # 测试超时 |
| with pytest.raises(Exception, match="读取数据超时"): |
| await self.client._read_exact(10) |
| |
| @pytest.mark.asyncio |
| async def test_read_exact_connection_closed(self): |
| """测试读取数据时连接断开""" |
| # 模拟连接状态 |
| self.client.connected = True |
| mock_channel = MagicMock() |
| mock_channel.closed = True # 模拟连接已关闭 |
| self.client.channel = mock_channel |
| |
| # 测试连接断开检测 |
| with pytest.raises(Exception, match="连接已关闭"): |
| await self.client._read_exact(10) |
| |
| @pytest.mark.asyncio |
| async def test_read_exact_success(self): |
| """测试成功读取指定长度数据""" |
| # 模拟连接状态 |
| self.client.connected = True |
| mock_channel = MagicMock() |
| mock_channel.closed = False |
| mock_channel.recv_ready.return_value = True |
| mock_channel.recv.return_value = b"test_data" |
| self.client.channel = mock_channel |
| |
| result = await self.client._read_exact(9) |
| assert result == b"test_data" |
| |
| @pytest.mark.asyncio |
| async def test_get_online_users_with_user_state_change(self): |
| """测试获取在线用户时处理USER_STATE_CHANGE包""" |
| # 模拟连接状态 |
| self.client.connected = True |
| mock_channel = MagicMock() |
| self.client.channel = mock_channel |
| |
| # 创建模拟包数据 |
| state_change_data = struct.pack("BB", 1, 8) + b"testuser" # 上线,用户名长度,用户名 |
| online_users_data = struct.pack(">HH", 123, 1) # 我的ID,用户数量 |
| |
| # 模拟接收包的顺序:先USER_STATE_CHANGE,后ONLINE_USERS_INFO |
| with patch.object(self.client, '_send_packet') as mock_send, \ |
| patch.object(self.client, '_receive_packet') as mock_receive, \ |
| patch.object(self.client, '_process_user_state_change') as mock_process: |
| |
| mock_receive.side_effect = [ |
| (SSHOUTPacketType.USER_STATE_CHANGE, state_change_data), |
| (SSHOUTPacketType.ONLINE_USERS_INFO, online_users_data) |
| ] |
| |
| await self.client._get_online_users() |
| |
| # 验证发送了GET_ONLINE_USER请求 |
| mock_send.assert_called_once_with(SSHOUTPacketType.GET_ONLINE_USER, b"") |
| |
| # 验证处理了USER_STATE_CHANGE包 |
| mock_process.assert_called_once_with(state_change_data) |
| |
| # 验证设置了用户ID |
| assert self.client.my_user_id == 123 |
| |
| @pytest.mark.asyncio |
| async def test_keep_alive_task(self): |
| """测试保活任务""" |
| # 模拟连接状态 |
| self.client.connected = True |
| mock_client = MagicMock() |
| mock_transport = MagicMock() |
| mock_transport.is_active.return_value = True |
| mock_client.get_transport.return_value = mock_transport |
| self.client.client = mock_client |
| |
| mock_channel = MagicMock() |
| mock_channel.closed = False |
| self.client.channel = mock_channel |
| |
| # 创建一个快速结束的保活任务 |
| async def quick_keep_alive(): |
| await asyncio.sleep(0.1) # 短暂等待 |
| if self.client.connected and hasattr(self.client.client, 'get_transport'): |
| transport = self.client.client.get_transport() |
| if transport and transport.is_active(): |
| transport.send_ignore() |
| self.client.connected = False # 结束循环 |
| |
| # 替换保活方法进行测试 |
| original_keep_alive = self.client._keep_alive |
| self.client._keep_alive = quick_keep_alive |
| |
| # 运行保活任务 |
| await self.client._keep_alive() |
| |
| # 验证保活包发送 |
| mock_transport.send_ignore.assert_called() |
| |
| # 恢复原方法 |
| self.client._keep_alive = original_keep_alive |
| |
| @pytest.mark.asyncio |
| async def test_message_listener_connection_closed_detection(self): |
| """测试消息监听器检测连接断开""" |
| # 模拟连接状态 |
| self.client.connected = True |
| mock_channel = MagicMock() |
| mock_channel.closed = True # 模拟连接已关闭 |
| self.client.channel = mock_channel |
| |
| # 创建一个快速结束的消息监听任务 |
| async def quick_listener(): |
| if not self.client.channel or self.client.channel.closed: |
| self.client.connected = False |
| return |
| |
| # 替换监听方法进行测试 |
| original_listener = self.client._message_listener |
| self.client._message_listener = quick_listener |
| |
| # 运行消息监听任务 |
| await self.client._message_listener() |
| |
| # 验证连接状态被正确设置 |
| assert self.client.connected is False |
| |
| # 恢复原方法 |
| self.client._message_listener = original_listener |
| |
| def test_ssh_object_references_initialization(self): |
| """测试SSH对象引用初始化""" |
| assert self.client.stdin is None |
| assert self.client.stdout is None |
| assert self.client.stderr is None |
| |
| def test_check_connection_state(self): |
| """测试连接状态检查""" |
| # 默认未连接状态 |
| assert not self.client.connected |
| |
| # 模拟连接状态 |
| self.client.connected = True |
| assert self.client.connected |
| |
| # 模拟断开连接 |
| self.client.connected = False |
| assert not self.client.connected |
| |
| def test_format_message_for_display(self): |
| """测试消息格式化显示""" |
| message = SSHOUTMessage( |
| timestamp=datetime.now(), |
| from_user="testuser", |
| to_user="GLOBAL", |
| message_type=SSHOUTMessageType.PLAIN, |
| content="Hello world!" |
| ) |
| |
| # 测试格式化功能(这可能是一个私有方法或在其他方法中实现) |
| formatted = f"[{message.timestamp.strftime('%H:%M:%S')}] {message.from_user}: {message.content}" |
| assert "testuser" in formatted |
| assert "Hello world!" in formatted |
| |
| def test_message_callback_registration(self): |
| """测试消息回调注册""" |
| callback_called = False |
| test_message = None |
| |
| def test_callback(message): |
| nonlocal callback_called, test_message |
| callback_called = True |
| test_message = message |
| |
| # 注册回调 |
| self.client.add_message_callback(test_callback) |
| assert test_callback in self.client.message_callbacks |
| |
| # 模拟触发回调 |
| message = SSHOUTMessage( |
| timestamp=datetime.now(), |
| from_user="testuser", |
| to_user="GLOBAL", |
| message_type=SSHOUTMessageType.PLAIN, |
| content="Test message" |
| ) |
| |
| # 手动触发回调(模拟_process_receive_message中的逻辑) |
| for callback in self.client.message_callbacks: |
| callback(message) |
| |
| assert callback_called |
| assert test_message == message |
| |
| def test_mention_callback_registration(self): |
| """测试@Claude提及回调注册""" |
| mention_called = False |
| test_message = None |
| test_context = None |
| |
| def test_mention_callback(message, context): |
| nonlocal mention_called, test_message, test_context |
| mention_called = True |
| test_message = message |
| test_context = context |
| |
| # 注册@Claude回调 |
| self.client.add_mention_callback(test_mention_callback) |
| assert test_mention_callback in self.client.mention_callbacks |
| |
| # 模拟@Claude消息 |
| message = SSHOUTMessage( |
| timestamp=datetime.now(), |
| from_user="testuser", |
| to_user="GLOBAL", |
| message_type=SSHOUTMessageType.PLAIN, |
| content="@Claude help me" |
| ) |
| |
| # 手动触发@Claude检测和回调 |
| if self.client._is_claude_mention(message.content): |
| context = [] # 模拟上下文 |
| for callback in self.client.mention_callbacks: |
| callback(message, context) |
| |
| assert mention_called |
| assert test_message == message |
| |
| def test_connection_error_handling(self): |
| """测试连接错误处理""" |
| # 测试无效的主机名 |
| invalid_client = SSHOUTApiClient( |
| hostname="invalid.host.name", |
| port=22333, |
| username="sshout", |
| key_path="/tmp/test_key", |
| timeout=1 |
| ) |
| |
| # 模拟连接失败 |
| assert not invalid_client.connected |
| |
| def test_packet_type_validation(self): |
| """测试数据包类型验证""" |
| # 测试有效的数据包类型 |
| valid_types = [ |
| SSHOUTPacketType.HELLO, |
| SSHOUTPacketType.PASS, |
| SSHOUTPacketType.SEND_MESSAGE, |
| SSHOUTPacketType.ONLINE_USERS_INFO |
| ] |
| |
| for packet_type in valid_types: |
| assert isinstance(packet_type.value, int) |
| assert packet_type.value >= 0 |
| |
| def test_message_type_validation(self): |
| """测试消息类型验证""" |
| # 测试有效的消息类型 |
| valid_types = [ |
| SSHOUTMessageType.PLAIN, |
| SSHOUTMessageType.RICH, |
| SSHOUTMessageType.IMAGE |
| ] |
| |
| for msg_type in valid_types: |
| assert isinstance(msg_type.value, int) |
| assert msg_type.value >= 0 |
| |
| def test_message_history_access(self): |
| """测试消息历史访问""" |
| # 添加一些测试消息 |
| for i in range(5): |
| message = SSHOUTMessage( |
| timestamp=datetime.now(), |
| from_user=f"user{i}", |
| to_user="GLOBAL", |
| message_type=SSHOUTMessageType.PLAIN, |
| content=f"Message {i}" |
| ) |
| self.client.message_history.append(message) |
| |
| # 测试历史访问 |
| assert len(self.client.message_history) == 5 |
| assert self.client.message_history[0].content == "Message 0" |
| assert self.client.message_history[-1].content == "Message 4" |
| |
| def test_configuration_validation(self): |
| """测试配置验证""" |
| # 测试默认配置(使用setup方法中的测试配置) |
| assert self.client.hostname == "test.example.com" |
| assert self.client.port == 22333 |
| assert self.client.username == "testuser" |
| assert self.client.timeout == 10 |
| assert self.client.max_history == 100 |
| |
| def test_logging_integration(self): |
| """测试日志集成""" |
| # 验证日志器存在且配置正确 |
| assert self.client.logger is not None |
| assert "sshout.api_client" in self.client.logger.name |
| |
| |
class TestSSHOUTApiClientCoverageEnhancement:
    """Coverage tests targeting error handling and boundary conditions of SSHOUTApiClient."""

    def setup_method(self):
        """Store the connection settings used by the ``api_client`` fixture."""
        self.config = {
            'hostname': 'test.example.com',
            'port': 22333,
            'username': 'testuser',
            'key_path': '/test/key',
            'timeout': 10
        }

    @pytest.fixture
    def api_client(self):
        """Create an unconnected SSHOUTApiClient from ``self.config``."""
        return SSHOUTApiClient(
            hostname=self.config['hostname'],
            port=self.config['port'],
            username=self.config['username'],
            key_path=self.config['key_path'],
            timeout=self.config['timeout']
        )

    @staticmethod
    def _build_receive_message_packet(from_user, to_user, content,
                                      message_type=SSHOUTMessageType.PLAIN,
                                      timestamp=None):
        """Build a RECEIVE_MESSAGE payload in the layout these tests use.

        Layout: timestamp (>Q), from_user length (B), from_user bytes,
        to_user length (B), to_user bytes, message type (B),
        content length (>I), content bytes.

        Args:
            from_user: sender name.
            to_user: recipient name; "" yields a zero-length field.
            content: message text.
            message_type: SSHOUTMessageType member (its ``.value`` is packed).
            timestamp: Unix seconds; defaults to the current time.

        Returns:
            The packed payload as bytes.
        """
        if timestamp is None:
            timestamp = int(datetime.now().timestamp())
        # Length prefixes count encoded bytes, not characters, so non-ASCII
        # text produces a well-formed packet.
        from_bytes = from_user.encode('utf-8')
        to_bytes = to_user.encode('utf-8')
        content_bytes = content.encode('utf-8')
        return b"".join((
            struct.pack(">Q", timestamp),
            struct.pack("B", len(from_bytes)),
            from_bytes,
            struct.pack("B", len(to_bytes)),
            to_bytes,
            # Pack the integer code; packing the enum member itself only works for IntEnum.
            struct.pack("B", message_type.value),
            struct.pack(">I", len(content_bytes)),
            content_bytes,
        ))

    @pytest.mark.asyncio
    async def test_read_exact_connection_closed_exception(self, api_client):
        """_read_exact reports a closed connection when recv returns no data."""
        api_client.channel = Mock()
        api_client.channel.recv.return_value = b""  # empty read means the peer closed

        with pytest.raises(Exception, match="连接已关闭"):
            await api_client._read_exact(10)

    @pytest.mark.asyncio
    async def test_read_exact_closed_exception_handling(self, api_client):
        """_read_exact maps an exception mentioning 'closed' to a closed-connection error."""
        api_client.channel = Mock()
        api_client.channel.recv.side_effect = Exception("Socket is closed")

        with pytest.raises(Exception, match="连接已关闭"):
            await api_client._read_exact(10)

    @pytest.mark.asyncio
    async def test_get_online_users_warning_unexpected_packet(self, api_client):
        """_get_online_users logs a warning when an unexpected packet type arrives."""
        api_client.connected = True
        api_client.logger = Mock()

        with patch.object(api_client, '_send_packet'), \
             patch.object(api_client, '_receive_packet') as mock_receive:

            # 99 is not a packet type the client expects here.
            mock_receive.return_value = (99, b"unexpected")

            await api_client._get_online_users()

            api_client.logger.warning.assert_called()
            call_args = api_client.logger.warning.call_args[0][0]
            assert "获取在线用户时收到意外包类型" in call_args

    @pytest.mark.asyncio
    async def test_get_online_users_max_attempts_exceeded(self, api_client):
        """_get_online_users logs an error after exhausting its attempts."""
        api_client.connected = True
        api_client.logger = Mock()

        with patch.object(api_client, '_send_packet'), \
             patch.object(api_client, '_receive_packet') as mock_receive:

            # Always answer with an unexpected type so every attempt is used up.
            mock_receive.return_value = (99, b"unexpected")

            await api_client._get_online_users()

            api_client.logger.error.assert_called()
            call_args = api_client.logger.error.call_args[0][0]
            assert "获取在线用户失败: 超过最大尝试次数" in call_args

    @pytest.mark.asyncio
    async def test_get_online_users_invalid_packet_data(self, api_client):
        """_get_online_users rejects an online-users payload shorter than 4 bytes."""
        api_client.connected = True
        api_client.logger = Mock()

        with patch.object(api_client, '_send_packet'), \
             patch.object(api_client, '_receive_packet') as mock_receive:

            # Only 2 bytes: too short for the header.
            mock_receive.return_value = (SSHOUTPacketType.ONLINE_USERS_INFO, b"xx")

            await api_client._get_online_users()

            api_client.logger.error.assert_called_with("❌ 在线用户信息包格式错误")

    @pytest.mark.asyncio
    async def test_message_listener_basic_coverage(self, api_client):
        """_message_listener stops and clears ``connected`` when there is no channel."""
        api_client.connected = True
        api_client.logger = Mock()

        # No channel: the listener should treat this as a dropped connection.
        api_client.channel = None
        await api_client._message_listener()

        api_client.logger.error.assert_called_with("❌ 检测到连接断开,停止消息监听")
        assert not api_client.connected

    @pytest.mark.asyncio
    async def test_message_listener_exception_handling(self, api_client):
        """_message_listener logs errors raised while receiving packets."""
        api_client.connected = True
        api_client.logger = Mock()
        api_client.channel = Mock()
        api_client.channel.recv_ready.return_value = True

        call_count = 0

        async def mock_receive_packet():
            # First call fails; the second ends the loop by clearing ``connected``.
            nonlocal call_count
            call_count += 1
            if call_count == 1:
                raise Exception("连接已关闭")
            api_client.connected = False
            return (SSHOUTPacketType.RECEIVE_MESSAGE, b"")

        with patch.object(api_client, '_receive_packet', side_effect=mock_receive_packet), \
             patch('asyncio.sleep'):

            await api_client._message_listener()

            api_client.logger.error.assert_called()
            assert not api_client.connected

    @pytest.mark.asyncio
    async def test_keep_alive_not_connected_break(self, api_client):
        """_keep_alive exits immediately when the client is not connected."""
        api_client.connected = False
        api_client.logger = Mock()

        await api_client._keep_alive()

        api_client.logger.debug.assert_called_with("💓 连接保活任务已停止")

    @pytest.mark.asyncio
    async def test_keep_alive_channel_closed_detection(self, api_client):
        """_keep_alive detects a closed channel and clears ``connected``."""
        api_client.connected = True
        api_client.logger = Mock()
        api_client.channel = Mock()
        api_client.channel.closed = True  # simulate a closed channel

        with patch('asyncio.sleep'):
            await api_client._keep_alive()

        api_client.logger.error.assert_called_with("❌ 检测到连接断开(保活检查)")
        assert not api_client.connected

    @pytest.mark.asyncio
    async def test_keep_alive_transport_inactive(self, api_client):
        """_keep_alive detects an inactive SSH transport and clears ``connected``."""
        api_client.connected = True
        api_client.logger = Mock()
        api_client.channel = Mock()
        api_client.channel.closed = False
        api_client.client = Mock()

        # The transport reports itself as inactive.
        mock_transport = Mock()
        mock_transport.is_active.return_value = False
        api_client.client.get_transport.return_value = mock_transport

        with patch('asyncio.sleep'):
            await api_client._keep_alive()

        api_client.logger.error.assert_called_with("❌ SSH传输层连接断开")
        assert not api_client.connected

    @pytest.mark.asyncio
    async def test_keep_alive_exception_handling(self, api_client):
        """_keep_alive logs unexpected exceptions raised inside its loop."""
        api_client.connected = True
        api_client.logger = Mock()
        api_client.channel = Mock()
        api_client.channel.closed = False

        call_count = 0

        async def mock_sleep(duration):
            # First sleep raises; the next one ends the loop by clearing ``connected``.
            nonlocal call_count
            call_count += 1
            if call_count == 1:
                raise Exception("Keep alive error")
            api_client.connected = False

        with patch('asyncio.sleep', side_effect=mock_sleep):
            await api_client._keep_alive()

        api_client.logger.error.assert_called()
        assert not api_client.connected

    @pytest.mark.asyncio
    async def test_process_receive_message_invalid_data_length(self, api_client):
        """_process_receive_message rejects a payload shorter than the message header."""
        api_client.logger = Mock()

        await api_client._process_receive_message(b"short")

        api_client.logger.error.assert_called_with("❌ 接收消息包格式错误")

    @pytest.mark.asyncio
    async def test_process_receive_message_complete_flow(self, api_client):
        """A well-formed message is decoded and appended to the history."""
        api_client.logger = Mock()
        api_client.message_callbacks = []
        api_client.message_history = []
        api_client.max_history = 100

        from_user = "testuser"
        to_user = "target"
        message = "Hello, this is a test message!"
        packet_data = self._build_receive_message_packet(from_user, to_user, message)

        await api_client._process_receive_message(packet_data)

        assert len(api_client.message_history) == 1
        message_obj = api_client.message_history[0]
        assert message_obj.from_user == from_user
        assert message_obj.to_user == to_user
        assert message_obj.content == message

    @pytest.mark.asyncio
    async def test_process_receive_message_with_callbacks(self, api_client):
        """Registered message callbacks are invoked once per received message."""
        api_client.logger = Mock()
        api_client.message_history = []
        api_client.max_history = 100

        callback_mock = Mock()
        api_client.message_callbacks = [callback_mock]

        packet_data = self._build_receive_message_packet("sender", "", "callback test")

        await api_client._process_receive_message(packet_data)

        callback_mock.assert_called_once()

    @pytest.mark.asyncio
    async def test_process_receive_message_history_limit(self, api_client):
        """The message history never grows beyond ``max_history``."""
        api_client.logger = Mock()
        api_client.message_callbacks = []
        api_client.message_history = []
        api_client.max_history = 2  # small limit so three messages overflow it

        for i in range(3):
            packet_data = self._build_receive_message_packet(f"user{i}", "", f"message {i}")
            await api_client._process_receive_message(packet_data)

        assert len(api_client.message_history) == api_client.max_history

    @pytest.mark.asyncio
    async def test_process_receive_message_exception_handling(self, api_client):
        """A payload whose declared lengths exceed its size is reported as an error."""
        api_client.logger = Mock()
        api_client.message_callbacks = []
        api_client.message_history = []

        # Deliberately malformed: claims a 255-byte sender but supplies only 5 bytes.
        timestamp = int(datetime.now().timestamp())
        packet_data = struct.pack(">Q", timestamp)
        packet_data += struct.pack("B", 255)
        packet_data += b"short"

        await api_client._process_receive_message(packet_data)

        api_client.logger.error.assert_called()
        error_call_args = api_client.logger.error.call_args[0][0]
        assert "处理接收消息错误" in error_call_args

    @pytest.mark.asyncio
    async def test_send_message_basic_functionality(self, api_client):
        """send_message emits a SEND_MESSAGE packet."""
        api_client.connected = True
        api_client.channel = Mock()
        api_client.logger = Mock()

        with patch.object(api_client, '_send_packet') as mock_send:
            await api_client.send_message("target", "Hello")

            mock_send.assert_called_once()
            assert mock_send.call_args[0][0] == SSHOUTPacketType.SEND_MESSAGE

    @pytest.mark.asyncio
    async def test_send_message_not_connected(self, api_client):
        """send_message returns False when the client is not connected."""
        api_client.connected = False
        api_client.logger = Mock()

        result = await api_client.send_message("target", "Hello")

        assert result is False

    @pytest.mark.asyncio
    async def test_send_global_message_functionality(self, api_client):
        """send_global_message delegates to send_message with the GLOBAL recipient."""
        api_client.connected = True
        api_client.logger = Mock()

        with patch.object(api_client, 'send_message') as mock_send:
            await api_client.send_global_message("Global message")

            mock_send.assert_called_once_with("GLOBAL", "Global message", SSHOUTMessageType.PLAIN)

    def test_api_client_initialization_and_basic_properties(self):
        """Constructor arguments are stored and the client starts disconnected."""
        client = SSHOUTApiClient(
            hostname="test.host.com",
            port=22333,
            username="testuser",
            key_path="/test/key/path",
            timeout=30
        )

        assert client.hostname == "test.host.com"
        assert client.port == 22333
        assert client.username == "testuser"
        assert client.key_path == "/test/key/path"
        assert client.timeout == 30
        assert not client.connected
        assert client.message_callbacks == []
        assert client.mention_callbacks == []
| |
| |
class TestSSHOUTApiClientAdvancedCoverage:
    """Advanced coverage tests: full packet flows, callback dispatch and edge cases."""

    def setup_method(self):
        """Store the connection settings used by the ``api_client`` fixture."""
        self.config = {
            'hostname': 'test.example.com',
            'port': 22333,
            'username': 'testuser',
            'key_path': '/test/key',
            'timeout': 10
        }

    @pytest.fixture
    def api_client(self):
        """Create an unconnected SSHOUTApiClient from ``self.config``."""
        return SSHOUTApiClient(
            hostname=self.config['hostname'],
            port=self.config['port'],
            username=self.config['username'],
            key_path=self.config['key_path'],
            timeout=self.config['timeout']
        )

    @staticmethod
    def _build_receive_message_packet(from_user, to_user, content,
                                      message_type=SSHOUTMessageType.PLAIN,
                                      timestamp=None):
        """Build a RECEIVE_MESSAGE payload in the layout these tests use.

        Layout: timestamp (>Q), from_user length (B), from_user bytes,
        to_user length (B), to_user bytes, message type (B),
        content length (>I), content bytes.

        Args:
            from_user: sender name.
            to_user: recipient name; "" yields a zero-length field.
            content: message text.
            message_type: SSHOUTMessageType member (its ``.value`` is packed).
            timestamp: Unix seconds; defaults to the current time.

        Returns:
            The packed payload as bytes.
        """
        if timestamp is None:
            timestamp = int(datetime.now().timestamp())
        # Length prefixes count encoded bytes, not characters.
        from_bytes = from_user.encode('utf-8')
        to_bytes = to_user.encode('utf-8')
        content_bytes = content.encode('utf-8')
        return b"".join((
            struct.pack(">Q", timestamp),
            struct.pack("B", len(from_bytes)),
            from_bytes,
            struct.pack("B", len(to_bytes)),
            to_bytes,
            # Pack the integer code; packing the enum member itself only works for IntEnum.
            struct.pack("B", message_type.value),
            struct.pack(">I", len(content_bytes)),
            content_bytes,
        ))

    @staticmethod
    def _discard_task(coro, *args, **kwargs):
        """Stand-in for ``asyncio.create_task`` while it is patched.

        The patched ``create_task`` never schedules the coroutine it receives.
        Closing it here avoids "coroutine ... was never awaited" RuntimeWarnings.
        Returns a MagicMock, as the unconfigured patch did.
        """
        if asyncio.iscoroutine(coro):
            coro.close()
        return MagicMock()

    @pytest.mark.asyncio
    async def test_get_online_users_complete_flow(self, api_client):
        """A well-formed online-users payload is decoded and each user is logged."""
        api_client.connected = True
        api_client.logger = Mock()

        with patch.object(api_client, '_send_packet'), \
             patch.object(api_client, '_receive_packet') as mock_receive:

            user_id = 123
            username = "testuser"
            hostname = "test.host.com"
            name_bytes = username.encode('utf-8')
            host_bytes = hostname.encode('utf-8')

            # Layout built here: my_user_id (>H), user_count (>H), then per user:
            # id (>H), name length (B), name, host length (B), host.
            packet_data = b"".join((
                struct.pack(">H", user_id),  # my_user_id
                struct.pack(">H", 1),        # user_count
                struct.pack(">H", user_id),
                struct.pack("B", len(name_bytes)),
                name_bytes,
                struct.pack("B", len(host_bytes)),
                host_bytes,
            ))

            mock_receive.return_value = (SSHOUTPacketType.ONLINE_USERS_INFO, packet_data)

            await api_client._get_online_users()

            api_client.logger.debug.assert_called()
            debug_call_args = api_client.logger.debug.call_args[0][0]
            assert username in debug_call_args and hostname in debug_call_args

    @pytest.mark.asyncio
    async def test_get_online_users_exception_handling(self, api_client):
        """_get_online_users logs an error when receiving raises."""
        api_client.connected = True
        api_client.logger = Mock()

        with patch.object(api_client, '_send_packet'), \
             patch.object(api_client, '_receive_packet') as mock_receive:

            mock_receive.side_effect = Exception("Network error")

            await api_client._get_online_users()

            api_client.logger.error.assert_called()
            error_call_args = api_client.logger.error.call_args[0][0]
            assert "获取在线用户失败" in error_call_args

    @pytest.mark.asyncio
    async def test_process_receive_message_async_callback(self, api_client):
        """Async message callbacks are scheduled via asyncio.create_task."""
        api_client.logger = Mock()
        api_client.message_history = []
        api_client.max_history = 100

        async_callback = AsyncMock()
        api_client.message_callbacks = [async_callback]

        packet_data = self._build_receive_message_packet("sender", "", "async callback test")

        with patch('asyncio.create_task', side_effect=self._discard_task) as mock_create_task:
            await api_client._process_receive_message(packet_data)

            mock_create_task.assert_called()

    @pytest.mark.asyncio
    async def test_process_receive_message_sync_callback(self, api_client):
        """Sync message callbacks are called directly."""
        api_client.logger = Mock()
        api_client.message_history = []
        api_client.max_history = 100

        sync_callback = Mock()
        api_client.message_callbacks = [sync_callback]

        packet_data = self._build_receive_message_packet("sender", "", "sync callback test")

        await api_client._process_receive_message(packet_data)

        sync_callback.assert_called()

    @pytest.mark.asyncio
    async def test_process_receive_message_callback_exception(self, api_client):
        """An exception raised by a message callback is logged, not propagated."""
        api_client.logger = Mock()
        api_client.message_history = []
        api_client.max_history = 100

        error_callback = Mock(side_effect=Exception("Callback error"))
        api_client.message_callbacks = [error_callback]

        packet_data = self._build_receive_message_packet("sender", "", "error callback test")

        await api_client._process_receive_message(packet_data)

        api_client.logger.error.assert_called()
        error_call_args = api_client.logger.error.call_args[0][0]
        assert "消息回调错误" in error_call_args

    @pytest.mark.asyncio
    async def test_process_receive_message_claude_mention_detection(self, api_client):
        """A detected @Claude mention is logged at info level."""
        api_client.logger = Mock()
        api_client.message_history = []
        api_client.max_history = 100
        api_client.message_callbacks = []

        with patch.object(api_client, '_is_claude_mention', return_value=True):
            packet_data = self._build_receive_message_packet("user", "", "@Claude hello")

            await api_client._process_receive_message(packet_data)

            api_client.logger.info.assert_called()
            info_call_args = api_client.logger.info.call_args[0][0]
            assert "检测到@Claude提及" in info_call_args

    @pytest.mark.asyncio
    async def test_process_receive_message_mention_async_callback(self, api_client):
        """Async @Claude mention callbacks are scheduled via asyncio.create_task."""
        api_client.logger = Mock()
        api_client.message_history = []
        api_client.max_history = 100
        api_client.message_callbacks = []

        async_mention_callback = AsyncMock()
        api_client.mention_callbacks = [async_mention_callback]

        with patch.object(api_client, '_is_claude_mention', return_value=True):
            packet_data = self._build_receive_message_packet("user", "", "@Claude test")

            with patch('asyncio.create_task', side_effect=self._discard_task) as mock_create_task:
                await api_client._process_receive_message(packet_data)

                mock_create_task.assert_called()

    @pytest.mark.asyncio
    async def test_process_receive_message_mention_sync_callback(self, api_client):
        """Sync @Claude mention callbacks are called directly."""
        api_client.logger = Mock()
        api_client.message_history = []
        api_client.max_history = 100
        api_client.message_callbacks = []

        sync_mention_callback = Mock()
        api_client.mention_callbacks = [sync_mention_callback]

        with patch.object(api_client, '_is_claude_mention', return_value=True):
            packet_data = self._build_receive_message_packet("user", "", "@Claude sync test")

            await api_client._process_receive_message(packet_data)

            sync_mention_callback.assert_called()

    @pytest.mark.asyncio
    async def test_process_receive_message_mention_callback_exception(self, api_client):
        """An exception raised by a mention callback is logged, not propagated."""
        api_client.logger = Mock()
        api_client.message_history = []
        api_client.max_history = 100
        api_client.message_callbacks = []

        error_mention_callback = Mock(side_effect=Exception("Mention callback error"))
        api_client.mention_callbacks = [error_mention_callback]

        with patch.object(api_client, '_is_claude_mention', return_value=True):
            packet_data = self._build_receive_message_packet("user", "", "@Claude error test")

            await api_client._process_receive_message(packet_data)

            api_client.logger.error.assert_called()

    @pytest.mark.asyncio
    async def test_send_message_exception_handling(self, api_client):
        """send_message returns False and logs when sending raises."""
        api_client.connected = True
        api_client.channel = Mock()
        api_client.logger = Mock()

        with patch.object(api_client, '_send_packet', side_effect=Exception("Send error")):
            result = await api_client.send_message("target", "Hello")

            assert result is False
            api_client.logger.error.assert_called()

    @pytest.mark.asyncio
    async def test_process_user_state_change_basic(self, api_client):
        """A user-state-change payload results in an info log entry."""
        api_client.logger = Mock()

        # Layout built here: user id (>H), state (B), name length (B), name.
        username_bytes = "testuser".encode('utf-8')
        packet_data = b"".join((
            struct.pack(">H", 123),  # user id
            struct.pack("B", 1),     # state: online
            struct.pack("B", len(username_bytes)),
            username_bytes,
        ))

        await api_client._process_user_state_change(packet_data)

        api_client.logger.info.assert_called()

    @pytest.mark.asyncio
    async def test_process_user_state_change_invalid_data(self, api_client):
        """Empty or truncated user-state-change payloads do not raise."""
        api_client.logger = Mock()

        # The test passes as long as neither call raises.
        await api_client._process_user_state_change(b"")
        await api_client._process_user_state_change(b"short")

    @pytest.mark.asyncio
    async def test_process_error_basic(self, api_client):
        """An error payload results in an error log entry."""
        api_client.logger = Mock()

        # Layout built here: error code (B), message length (B), message.
        message_bytes = "Server connection lost".encode('utf-8')
        packet_data = b"".join((
            struct.pack("B", 1),  # error code
            struct.pack("B", len(message_bytes)),
            message_bytes,
        ))

        await api_client._process_error(packet_data)

        api_client.logger.error.assert_called()

    @pytest.mark.asyncio
    async def test_process_motd_basic(self, api_client):
        """A MOTD payload results in an info log entry."""
        api_client.logger = Mock()

        # Layout built here: MOTD length (B), MOTD text.
        motd_bytes = "Welcome to SSHOUT server!".encode('utf-8')
        packet_data = struct.pack("B", len(motd_bytes)) + motd_bytes

        await api_client._process_motd(packet_data)

        api_client.logger.info.assert_called()

    def test_channel_property_checks(self, api_client):
        """The ``channel`` attribute can be cleared and reassigned."""
        api_client.channel = None
        assert api_client.channel is None

        api_client.channel = Mock()
        assert api_client.channel is not None

    @pytest.mark.asyncio
    async def test_additional_edge_cases(self, api_client):
        """Packet handlers tolerate empty payloads without raising."""
        api_client.logger = Mock()

        # The test passes as long as none of these calls raises.
        await api_client._process_user_state_change(b"")
        await api_client._process_error(b"")
        await api_client._process_motd(b"")
| |
| |
if __name__ == '__main__':
    # Propagate pytest's exit status so a failed run is a non-zero script exit.
    sys.exit(pytest.main([__file__]))