| """ |
| SSHOUT API客户端连接和协议测试 |
| 专注于网络连接、握手协议和错误处理的测试覆盖 |
| """ |
| |
| import pytest |
| import asyncio |
| import struct |
| from unittest.mock import Mock, AsyncMock, patch, MagicMock |
| import paramiko |
| |
| from src.claude_agent.sshout.api_client import ( |
| SSHOUTApiClient, SSHOUTPacketType, SSHOUTMessageType, SSHOUTMessage |
| ) |
| from datetime import datetime |
| |
# Apply the asyncio marker (from the pytest-asyncio plugin) to every test in
# this module, so the ``async def`` tests below run without per-test decorators.
pytestmark = pytest.mark.asyncio
| |
| |
class TestSSHOUTConnectionAndProtocol:
    """Tests for connecting, disconnecting and the SSHOUT API handshake."""

    def setup_method(self):
        """Build a fresh client with a fixed configuration before each test."""
        self.mock_key_path = "/tmp/test_key"
        self.client = SSHOUTApiClient(
            hostname="test.example.com",
            port=22333,
            username="testuser",
            key_path=self.mock_key_path,
            timeout=10,
        )

    @staticmethod
    def _make_ssh_instance(mock_ssh_client, channel):
        """Make ``mock_ssh_client()`` return an SSH mock whose exec_command stdout exposes *channel*."""
        mock_ssh_instance = Mock()
        mock_ssh_client.return_value = mock_ssh_instance
        mock_stdout = Mock()
        mock_stdout.channel = channel
        # exec_command returns (stdin, stdout, stderr); only stdout.channel matters here.
        mock_ssh_instance.exec_command.return_value = (Mock(), mock_stdout, Mock())
        return mock_ssh_instance

    @staticmethod
    def _fake_create_task(coro, *args, **kwargs):
        """Stand-in for ``asyncio.create_task`` that schedules nothing.

        The coroutine object was already created by the caller; closing it
        prevents "coroutine ... was never awaited" RuntimeWarnings. A ``Mock``
        is returned in place of the task.
        """
        close = getattr(coro, "close", None)
        if callable(close):
            close()
        return Mock()

    @patch('paramiko.SSHClient')
    @patch('paramiko.ECDSAKey.from_private_key_file')
    async def test_connect_success_full_flow(self, mock_key_loader, mock_ssh_client):
        """A successful connect loads the key, opens the 'api' command and handshakes."""
        mock_private_key = Mock()
        mock_key_loader.return_value = mock_private_key
        mock_ssh_instance = self._make_ssh_instance(mock_ssh_client, Mock())

        # Handshake succeeds, background tasks are not actually scheduled and the
        # online-user query is stubbed, so only the connect path is exercised.
        with patch.object(self.client, '_handshake', return_value=True) as mock_handshake, \
                patch('asyncio.create_task', side_effect=self._fake_create_task), \
                patch.object(self.client, '_get_online_users'):
            result = await self.client.connect()

        mock_key_loader.assert_called_once_with(self.mock_key_path)
        mock_ssh_instance.connect.assert_called_once_with(
            hostname="test.example.com",
            port=22333,
            username="testuser",
            pkey=mock_private_key,
            timeout=10,
        )
        mock_ssh_instance.exec_command.assert_called_once_with('api')
        mock_handshake.assert_called_once()
        assert result is True
        assert self.client.connected is True

    @patch('paramiko.ECDSAKey.from_private_key_file')
    async def test_connect_key_loading_failure(self, mock_key_loader):
        """connect() returns False and stays disconnected when the key cannot be loaded."""
        mock_key_loader.side_effect = Exception("密钥文件不存在")

        result = await self.client.connect()

        assert result is False
        assert self.client.connected is False

    @patch('paramiko.SSHClient')
    @patch('paramiko.ECDSAKey.from_private_key_file')
    async def test_connect_ssh_connection_failure(self, mock_key_loader, mock_ssh_client):
        """connect() returns False when SSH authentication is rejected."""
        mock_key_loader.return_value = Mock()
        mock_ssh_instance = self._make_ssh_instance(mock_ssh_client, Mock())
        mock_ssh_instance.connect.side_effect = paramiko.AuthenticationException("认证失败")

        result = await self.client.connect()

        assert result is False
        assert self.client.connected is False

    @patch('paramiko.SSHClient')
    @patch('paramiko.ECDSAKey.from_private_key_file')
    async def test_connect_channel_creation_failure(self, mock_key_loader, mock_ssh_client):
        """connect() returns False when exec_command yields no usable channel."""
        mock_key_loader.return_value = Mock()
        # stdout.channel is None, i.e. no channel was obtained.
        self._make_ssh_instance(mock_ssh_client, None)

        result = await self.client.connect()

        assert result is False
        assert self.client.connected is False

    @patch('paramiko.SSHClient')
    @patch('paramiko.ECDSAKey.from_private_key_file')
    async def test_connect_handshake_failure(self, mock_key_loader, mock_ssh_client):
        """connect() returns False when the API handshake fails."""
        mock_key_loader.return_value = Mock()
        self._make_ssh_instance(mock_ssh_client, Mock())

        with patch.object(self.client, '_handshake', return_value=False):
            result = await self.client.connect()

        assert result is False
        assert self.client.connected is False

    async def test_disconnect_with_client(self):
        """disconnect() closes the SSH client and clears the connection state."""
        mock_client = Mock()
        self.client.client = mock_client
        self.client.connected = True

        await self.client.disconnect()

        mock_client.close.assert_called_once()
        assert self.client.client is None
        assert self.client.connected is False

    async def test_disconnect_without_client(self):
        """disconnect() is a safe no-op when there is no client."""
        self.client.client = None
        self.client.connected = False

        # Must not raise.
        await self.client.disconnect()

        assert self.client.client is None
        assert self.client.connected is False

    async def test_disconnect_exception_handling(self):
        """disconnect() swallows errors from close() and still marks the client disconnected."""
        mock_client = Mock()
        mock_client.close.side_effect = Exception("断开连接失败")
        self.client.client = mock_client
        self.client.connected = True

        # Must not raise.
        await self.client.disconnect()

        assert self.client.connected is False
        # Whether ``client`` is reset to None here is implementation-dependent.

    async def test_handshake_success(self):
        """A valid PASS reply completes the handshake and records the username it carries."""
        with patch.object(self.client, '_send_packet') as mock_send, \
                patch.object(self.client, '_receive_packet') as mock_receive:
            # PASS payload as built here: magic, big-endian u16 version 1,
            # one length byte (8), then the 8-byte name.
            pass_data = b"SSHOUT" + struct.pack(">H", 1) + b"\x08testuser"
            mock_receive.return_value = (SSHOUTPacketType.PASS, pass_data)

            result = await self.client._handshake()

            # The client must have sent HELLO with magic + version 1.
            hello_data = b"SSHOUT" + struct.pack(">H", 1)
            mock_send.assert_called_once_with(SSHOUTPacketType.HELLO, hello_data)

        assert result is True
        assert self.client.my_username == "testuser"

    async def test_handshake_wrong_packet_type(self):
        """The handshake fails when the reply is not a PASS packet."""
        with patch.object(self.client, '_send_packet'), \
                patch.object(self.client, '_receive_packet') as mock_receive:
            mock_receive.return_value = (SSHOUTPacketType.HELLO, b"wrong")

            result = await self.client._handshake()

        assert result is False

    async def test_handshake_invalid_pass_packet_length(self):
        """The handshake fails when the PASS payload is too short."""
        with patch.object(self.client, '_send_packet'), \
                patch.object(self.client, '_receive_packet') as mock_receive:
            mock_receive.return_value = (SSHOUTPacketType.PASS, b"short")

            result = await self.client._handshake()

        assert result is False

    async def test_handshake_invalid_magic(self):
        """The handshake fails when the PASS payload carries the wrong magic."""
        with patch.object(self.client, '_send_packet'), \
                patch.object(self.client, '_receive_packet') as mock_receive:
            invalid_data = b"WRONG!" + struct.pack(">H", 1)
            mock_receive.return_value = (SSHOUTPacketType.PASS, invalid_data)

            result = await self.client._handshake()

        assert result is False

    async def test_handshake_invalid_version(self):
        """The handshake fails when the PASS payload announces an unsupported version."""
        with patch.object(self.client, '_send_packet'), \
                patch.object(self.client, '_receive_packet') as mock_receive:
            invalid_data = b"SSHOUT" + struct.pack(">H", 2)  # version 2
            mock_receive.return_value = (SSHOUTPacketType.PASS, invalid_data)

            result = await self.client._handshake()

        assert result is False

    async def test_handshake_exception_handling(self):
        """A transport error during the handshake yields False instead of raising."""
        with patch.object(self.client, '_send_packet', side_effect=Exception("网络错误")):
            result = await self.client._handshake()

        assert result is False
| |
| |
class TestSSHOUTPacketOperations:
    """Tests for packet framing in ``_send_packet`` / ``_receive_packet``."""

    def setup_method(self):
        """Build a fresh client before each test."""
        self.client = SSHOUTApiClient(
            hostname="test.example.com",
            port=22333,
            username="testuser",
            key_path="/tmp/test_key",
        )

    @staticmethod
    def _open_channel():
        """Return a channel mock that reports itself open and ready to read."""
        channel = Mock()
        # A bare Mock's attributes are truthy Mocks, so set these explicitly.
        channel.closed = False
        channel.recv_ready.return_value = True
        return channel

    async def test_send_packet_success(self):
        """_send_packet frames data as length(4, big-endian) + type(1) + payload."""
        mock_channel = Mock()
        self.client.channel = mock_channel

        test_data = b"test data"
        await self.client._send_packet(SSHOUTPacketType.HELLO, test_data)

        # The length field counts the type byte plus the payload.
        packet_length = len(test_data) + 1
        expected_packet = struct.pack(">IB", packet_length, SSHOUTPacketType.HELLO.value) + test_data
        mock_channel.send.assert_called_once_with(expected_packet)

    async def test_send_packet_no_channel(self):
        """_send_packet raises when there is no channel."""
        self.client.channel = None

        with pytest.raises(Exception):
            await self.client._send_packet(SSHOUTPacketType.HELLO, b"test")

    async def test_receive_packet_success(self):
        """_receive_packet reassembles a packet delivered in three recv() chunks."""
        mock_channel = self._open_channel()
        self.client.channel = mock_channel

        test_data = b"test data"
        packet_length = len(test_data) + 1
        length_bytes = struct.pack(">I", packet_length)
        type_bytes = struct.pack("B", SSHOUTPacketType.PASS.value)

        # recv() yields the length header, then the type byte, then the payload.
        mock_channel.recv.side_effect = [length_bytes, type_bytes, test_data]

        packet_type, data = await self.client._receive_packet()

        assert packet_type == SSHOUTPacketType.PASS
        assert data == test_data

    async def test_receive_packet_no_channel(self):
        """_receive_packet raises when there is no channel."""
        self.client.channel = None

        with pytest.raises(Exception):
            await self.client._receive_packet()

    async def test_receive_packet_timeout(self):
        """A failing recv() on an open channel propagates out of _receive_packet."""
        # Same open/readable channel as the success case, so the error can only
        # come from recv() and not from an earlier "channel closed" guard.
        mock_channel = self._open_channel()
        mock_channel.recv.side_effect = Exception("timeout")
        self.client.channel = mock_channel

        with pytest.raises(Exception):
            await self.client._receive_packet()

        mock_channel.recv.assert_called()
| |
| |
class TestSSHOUTMessageOperations:
    """Tests for ``send_message``."""

    # NOTE(review): the success test passes ("GLOBAL", "Hello world") while the
    # other tests pass ("Hello", "GLOBAL"); one of these orders is wrong.
    # Confirm against the SSHOUTApiClient.send_message signature.

    def setup_method(self):
        """Build a fresh client before each test."""
        self.client = SSHOUTApiClient(
            hostname="test.example.com",
            port=22333,
            username="testuser",
            key_path="/tmp/test_key",
        )

    async def test_send_message_success(self):
        """A connected client hands the message to _send_packet."""
        self.client.connected = True
        self.client.channel = Mock()

        with patch.object(self.client, '_send_packet') as mock_send:
            await self.client.send_message("GLOBAL", "Hello world")

        assert mock_send.called

    async def test_send_message_not_connected(self):
        """send_message returns False when the client is not connected."""
        self.client.connected = False

        result = await self.client.send_message("Hello", "GLOBAL")

        assert result is False

    async def test_send_message_exception_handling(self):
        """A transport error inside _send_packet makes send_message return False."""
        self.client.connected = True
        # Provide a channel as in the success case, so any channel guard passes
        # and the failure really originates from _send_packet.
        self.client.channel = Mock()

        with patch.object(self.client, '_send_packet', side_effect=Exception("网络错误")) as mock_send:
            result = await self.client.send_message("Hello", "GLOBAL")

        assert mock_send.called
        assert result is False
| |
| |
class TestSSHOUTAsyncTasks:
    """Tests for the client's background coroutines and online-user query."""

    # Upper bound for coroutines that are expected to finish on their own.
    _LISTENER_TIMEOUT = 5

    def setup_method(self):
        """Build a fresh client before each test."""
        self.client = SSHOUTApiClient(
            hostname="test.example.com",
            port=22333,
            username="testuser",
            key_path="/tmp/test_key",
        )

    async def test_keep_alive_task(self):
        """The keep-alive loop sends nothing within a very short window."""
        self.client.connected = True

        with patch.object(self.client, '_send_packet') as mock_send:
            task = asyncio.create_task(self.client._keep_alive())
            await asyncio.sleep(0.1)  # let the task start and reach its wait
            task.cancel()

            try:
                await task
            except asyncio.CancelledError:
                pass

            # Assumes the keep-alive interval is far longer than 0.1 s
            # (30 s per the original test note), so nothing has been sent yet.
            mock_send.assert_not_called()

    async def test_message_listener_disconnection(self):
        """A receive error ends the listener and marks the client disconnected."""
        self.client.connected = True

        # _receive_packet raising simulates a dropped connection.
        with patch.object(self.client, '_receive_packet', side_effect=Exception("连接断开")):
            # Bound the run: a listener that retried instead of exiting would
            # otherwise hang the entire test session.
            await asyncio.wait_for(
                self.client._message_listener(), timeout=self._LISTENER_TIMEOUT
            )

        assert self.client.connected is False

    async def test_get_online_users_success(self):
        """_get_online_users sends a GET_ONLINE_USER request with an empty payload."""
        with patch.object(self.client, '_send_packet') as mock_send, \
                patch.object(self.client, '_receive_packet') as mock_receive:
            # Canned ONLINE_USERS_INFO reply in case the method reads one;
            # the payload layout here is illustrative only.
            user_data = b"\x02\x04user\x05admin"
            mock_receive.return_value = (SSHOUTPacketType.ONLINE_USERS_INFO, user_data)

            await self.client._get_online_users()

            mock_send.assert_called_once_with(SSHOUTPacketType.GET_ONLINE_USER, b"")

    async def test_get_online_users_exception(self):
        """A transport error inside _get_online_users is swallowed."""
        with patch.object(self.client, '_send_packet', side_effect=Exception("网络错误")):
            # Must not raise.
            await self.client._get_online_users()
| |
| |
if __name__ == '__main__':
    # pytest.main returns an exit code; propagate it so a direct script run
    # reports test failures to the shell instead of always exiting 0.
    raise SystemExit(pytest.main([__file__]))