blob: 6949a64e3b68585117e90740ae590dc1d86e55fb [file] [log] [blame] [raw]
"""
SSHOUT API客户端连接和协议测试
专注于网络连接、握手协议和错误处理的测试覆盖
"""
import pytest
import asyncio
import struct
from unittest.mock import Mock, AsyncMock, patch, MagicMock
import paramiko
from src.claude_agent.sshout.api_client import (
SSHOUTApiClient, SSHOUTPacketType, SSHOUTMessageType, SSHOUTMessage
)
from datetime import datetime
# Mark every test in this module for pytest-asyncio so each ``async def``
# test is executed inside an event loop.
pytestmark = [pytest.mark.asyncio]
class TestSSHOUTConnectionAndProtocol:
    """Tests for SSHOUT connection setup/teardown and the handshake protocol."""

    def setup_method(self):
        """Create a fresh, unconnected client before each test."""
        self.mock_key_path = "/tmp/test_key"
        self.client = SSHOUTApiClient(
            hostname="test.example.com",
            port=22333,
            username="testuser",
            key_path=self.mock_key_path,
            timeout=10
        )

    @staticmethod
    def _discard_coroutine(coro, *args, **kwargs):
        """Stand-in for ``asyncio.create_task`` that schedules nothing.

        A plain Mock would leave the coroutine object passed to it un-awaited,
        which triggers a "coroutine ... was never awaited" RuntimeWarning.
        Closing it avoids that. A Mock is returned so any code that keeps the
        "task" around (e.g. to cancel it later) still has an object to call.
        """
        if asyncio.iscoroutine(coro):
            coro.close()
        return Mock()

    @patch('paramiko.SSHClient')
    @patch('paramiko.ECDSAKey.from_private_key_file')
    async def test_connect_success_full_flow(self, mock_key_loader, mock_ssh_client):
        """Test the full successful connect flow."""
        # Key loading succeeds and yields a dummy private key.
        mock_private_key = Mock()
        mock_key_loader.return_value = mock_private_key
        # The patched SSHClient class returns a controllable instance.
        mock_ssh_instance = Mock()
        mock_ssh_client.return_value = mock_ssh_instance
        # exec_command returns a (stdin, stdout, stderr) triple; stdout carries
        # the channel the client talks to.
        mock_stdout = Mock()
        mock_stdout.channel = Mock()
        mock_ssh_instance.exec_command.return_value = (Mock(), mock_stdout, Mock())
        # Handshake succeeds.
        with patch.object(self.client, '_handshake', return_value=True) as mock_handshake:
            # Background tasks are not actually started (see _discard_coroutine).
            with patch('asyncio.create_task', side_effect=self._discard_coroutine):
                # Online-user retrieval is stubbed out.
                with patch.object(self.client, '_get_online_users'):
                    result = await self.client.connect()
        # Verify the calls made during connect.
        mock_key_loader.assert_called_once_with(self.mock_key_path)
        mock_ssh_instance.connect.assert_called_once_with(
            hostname="test.example.com",
            port=22333,
            username="testuser",
            pkey=mock_private_key,
            timeout=10
        )
        mock_ssh_instance.exec_command.assert_called_once_with('api')
        mock_handshake.assert_called_once()
        assert result is True
        assert self.client.connected is True

    @patch('paramiko.ECDSAKey.from_private_key_file')
    async def test_connect_key_loading_failure(self, mock_key_loader):
        """Test that a key-loading failure makes connect return False."""
        mock_key_loader.side_effect = Exception("密钥文件不存在")
        result = await self.client.connect()
        assert result is False
        assert self.client.connected is False

    @patch('paramiko.SSHClient')
    @patch('paramiko.ECDSAKey.from_private_key_file')
    async def test_connect_ssh_connection_failure(self, mock_key_loader, mock_ssh_client):
        """Test that an SSH authentication failure makes connect return False."""
        # Key loading succeeds.
        mock_private_key = Mock()
        mock_key_loader.return_value = mock_private_key
        # The SSH connect call fails authentication.
        mock_ssh_instance = Mock()
        mock_ssh_client.return_value = mock_ssh_instance
        mock_ssh_instance.connect.side_effect = paramiko.AuthenticationException("认证失败")
        result = await self.client.connect()
        assert result is False
        assert self.client.connected is False

    @patch('paramiko.SSHClient')
    @patch('paramiko.ECDSAKey.from_private_key_file')
    async def test_connect_channel_creation_failure(self, mock_key_loader, mock_ssh_client):
        """Test that a missing channel from exec_command makes connect return False."""
        # Key loading and SSH connection succeed.
        mock_private_key = Mock()
        mock_key_loader.return_value = mock_private_key
        mock_ssh_instance = Mock()
        mock_ssh_client.return_value = mock_ssh_instance
        # exec_command returns a stdout without a usable channel.
        mock_stdout = Mock()
        mock_stdout.channel = None  # invalid channel
        mock_ssh_instance.exec_command.return_value = (Mock(), mock_stdout, Mock())
        result = await self.client.connect()
        assert result is False
        assert self.client.connected is False

    @patch('paramiko.SSHClient')
    @patch('paramiko.ECDSAKey.from_private_key_file')
    async def test_connect_handshake_failure(self, mock_key_loader, mock_ssh_client):
        """Test that a failed handshake makes connect return False."""
        # Key loading and SSH connection succeed.
        mock_private_key = Mock()
        mock_key_loader.return_value = mock_private_key
        mock_ssh_instance = Mock()
        mock_ssh_client.return_value = mock_ssh_instance
        # exec_command succeeds and provides a channel.
        mock_stdout = Mock()
        mock_stdout.channel = Mock()
        mock_ssh_instance.exec_command.return_value = (Mock(), mock_stdout, Mock())
        # Handshake fails.
        with patch.object(self.client, '_handshake', return_value=False):
            result = await self.client.connect()
        assert result is False
        assert self.client.connected is False

    async def test_disconnect_with_client(self):
        """Test that disconnect closes an existing client and clears state."""
        mock_client = Mock()
        self.client.client = mock_client
        self.client.connected = True
        await self.client.disconnect()
        mock_client.close.assert_called_once()
        assert self.client.client is None
        assert self.client.connected is False

    async def test_disconnect_without_client(self):
        """Test that disconnect is a no-op when there is no client."""
        self.client.client = None
        self.client.connected = False
        # Disconnecting must not raise.
        await self.client.disconnect()
        assert self.client.client is None
        assert self.client.connected is False

    async def test_disconnect_exception_handling(self):
        """Test that an exception from client.close() is handled during disconnect."""
        mock_client = Mock()
        mock_client.close.side_effect = Exception("断开连接失败")
        self.client.client = mock_client
        self.client.connected = True
        # Disconnecting must not raise.
        await self.client.disconnect()
        # Even on error, the connected flag is cleared.
        assert self.client.connected is False
        # Whether ``client`` is reset to None here depends on the implementation.

    async def test_handshake_success(self):
        """Test a successful handshake: HELLO is sent and a valid PASS is accepted."""
        with patch.object(self.client, '_send_packet') as mock_send:
            with patch.object(self.client, '_receive_packet') as mock_receive:
                # PASS payload: magic + version + 1-byte name length + name.
                pass_data = b"SSHOUT" + struct.pack(">H", 1) + b"\x08testuser"
                mock_receive.return_value = (SSHOUTPacketType.PASS, pass_data)
                result = await self.client._handshake()
                # A HELLO packet (magic + version) must have been sent.
                hello_data = b"SSHOUT" + struct.pack(">H", 1)
                mock_send.assert_called_once_with(SSHOUTPacketType.HELLO, hello_data)
                assert result is True
                assert self.client.my_username == "testuser"

    async def test_handshake_wrong_packet_type(self):
        """Test that a non-PASS reply fails the handshake."""
        with patch.object(self.client, '_send_packet'):
            with patch.object(self.client, '_receive_packet') as mock_receive:
                mock_receive.return_value = (SSHOUTPacketType.HELLO, b"wrong")
                result = await self.client._handshake()
                assert result is False

    async def test_handshake_invalid_pass_packet_length(self):
        """Test that a too-short PASS packet fails the handshake."""
        with patch.object(self.client, '_send_packet'):
            with patch.object(self.client, '_receive_packet') as mock_receive:
                mock_receive.return_value = (SSHOUTPacketType.PASS, b"short")
                result = await self.client._handshake()
                assert result is False

    async def test_handshake_invalid_magic(self):
        """Test that a PASS packet with the wrong magic fails the handshake."""
        with patch.object(self.client, '_send_packet'):
            with patch.object(self.client, '_receive_packet') as mock_receive:
                invalid_data = b"WRONG!" + struct.pack(">H", 1)
                mock_receive.return_value = (SSHOUTPacketType.PASS, invalid_data)
                result = await self.client._handshake()
                assert result is False

    async def test_handshake_invalid_version(self):
        """Test that a PASS packet with an unsupported version fails the handshake."""
        with patch.object(self.client, '_send_packet'):
            with patch.object(self.client, '_receive_packet') as mock_receive:
                invalid_data = b"SSHOUT" + struct.pack(">H", 2)  # version 2
                mock_receive.return_value = (SSHOUTPacketType.PASS, invalid_data)
                result = await self.client._handshake()
                assert result is False

    async def test_handshake_exception_handling(self):
        """Test that an exception while sending HELLO fails the handshake."""
        with patch.object(self.client, '_send_packet', side_effect=Exception("网络错误")):
            result = await self.client._handshake()
            assert result is False
class TestSSHOUTPacketOperations:
    """Tests for the low-level packet send/receive helpers."""

    def setup_method(self):
        """Create a fresh client before each test."""
        self.client = SSHOUTApiClient(
            hostname="test.example.com",
            port=22333,
            username="testuser",
            key_path="/tmp/test_key"
        )

    async def test_send_packet_success(self):
        """Test that _send_packet writes one correctly framed packet to the channel."""
        mock_channel = Mock()
        self.client.channel = mock_channel
        test_data = b"test data"
        await self.client._send_packet(SSHOUTPacketType.HELLO, test_data)
        # Expected framing: length (4 bytes, big-endian) + type (1 byte) + payload.
        # The length field counts the type byte plus the payload.
        packet_length = len(test_data) + 1
        expected_packet = struct.pack(">IB", packet_length, SSHOUTPacketType.HELLO.value) + test_data
        mock_channel.send.assert_called_once_with(expected_packet)

    async def test_send_packet_no_channel(self):
        """Test that sending without a channel raises."""
        self.client.channel = None
        with pytest.raises(Exception):
            await self.client._send_packet(SSHOUTPacketType.HELLO, b"test")

    async def test_receive_packet_success(self):
        """Test that _receive_packet decodes the packet type and payload."""
        # An open channel that reports data as ready.
        mock_channel = Mock()
        mock_channel.closed = False
        mock_channel.recv_ready.return_value = True
        self.client.channel = mock_channel
        # Wire data: length (4 bytes) + type (1 byte) + payload.
        test_data = b"test data"
        packet_length = len(test_data) + 1
        length_bytes = struct.pack(">I", packet_length)
        type_bytes = struct.pack("B", SSHOUTPacketType.PASS.value)
        # recv() hands the data back in three reads.
        mock_channel.recv.side_effect = [
            length_bytes,  # first the length
            type_bytes,    # then the type
            test_data      # finally the payload
        ]
        packet_type, data = await self.client._receive_packet()
        assert packet_type == SSHOUTPacketType.PASS
        assert data == test_data

    async def test_receive_packet_no_channel(self):
        """Test that receiving without a channel raises."""
        self.client.channel = None
        with pytest.raises(Exception):
            await self.client._receive_packet()

    async def test_receive_packet_timeout(self):
        """Test that an error raised by channel.recv() propagates out of _receive_packet."""
        mock_channel = Mock()
        # Configure the channel as open and readable, exactly as in the success
        # test. With a bare Mock, ``closed`` and ``recv_ready()`` are truthy Mock
        # objects, so the exception could come from an early "channel closed"
        # check instead of from recv(), and the test would pass vacuously.
        mock_channel.closed = False
        mock_channel.recv_ready.return_value = True
        mock_channel.recv.side_effect = Exception("timeout")
        self.client.channel = mock_channel
        with pytest.raises(Exception):
            await self.client._receive_packet()
        # Prove the failure really came from the read path.
        mock_channel.recv.assert_called()
class TestSSHOUTMessageOperations:
    """Tests for sending chat messages via send_message."""

    def setup_method(self):
        """Create a fresh client before each test."""
        self.client = SSHOUTApiClient(
            hostname="test.example.com",
            port=22333,
            username="testuser",
            key_path="/tmp/test_key"
        )

    # NOTE(review): all tests below call send_message(<target>, <text>), the
    # order the success test uses. Confirm this against the real signature in
    # api_client.py; the failure tests previously passed the arguments swapped.

    async def test_send_message_success(self):
        """Test that a connected client sends the message through _send_packet."""
        self.client.connected = True
        self.client.channel = Mock()
        with patch.object(self.client, '_send_packet') as mock_send:
            await self.client.send_message("GLOBAL", "Hello world")
            # _send_packet is a coroutine function, so patch() substitutes an
            # AsyncMock. Checking it was awaited (not merely called) proves
            # the send coroutine actually ran.
            mock_send.assert_awaited()

    async def test_send_message_not_connected(self):
        """Test that send_message returns False when not connected."""
        self.client.connected = False
        result = await self.client.send_message("GLOBAL", "Hello")
        assert result is False

    async def test_send_message_exception_handling(self):
        """Test that send_message returns False when sending raises."""
        self.client.connected = True
        with patch.object(self.client, '_send_packet', side_effect=Exception("网络错误")):
            result = await self.client.send_message("GLOBAL", "Hello")
            assert result is False
class TestSSHOUTAsyncTasks:
    """Tests for the client's background coroutines (keep-alive, listener, user list)."""

    # Upper bound for coroutines that are expected to finish on their own, so a
    # regression that keeps them looping fails the test instead of hanging the suite.
    _MAX_WAIT_SECONDS = 5

    def setup_method(self):
        """Create a fresh client before each test."""
        self.client = SSHOUTApiClient(
            hostname="test.example.com",
            port=22333,
            username="testuser",
            key_path="/tmp/test_key"
        )

    async def test_keep_alive_task(self):
        """Test that the keep-alive task sends nothing within its first 0.1 s."""
        self.client.connected = True
        with patch.object(self.client, '_send_packet') as mock_send:
            # Let the keep-alive loop run briefly, then cancel it.
            task = asyncio.create_task(self.client._keep_alive())
            await asyncio.sleep(0.1)
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass
            # The keep-alive interval is expected to be 30 seconds, so nothing
            # should be sent within 0.1 s.
            mock_send.assert_not_called()

    async def test_message_listener_disconnection(self):
        """Test that the listener exits and clears ``connected`` when receiving fails."""
        self.client.connected = True
        # A receive failure simulates a dropped connection.
        with patch.object(self.client, '_receive_packet', side_effect=Exception("连接断开")):
            # Bounded wait: if the listener never exits, asyncio.TimeoutError
            # fails the test rather than blocking forever.
            await asyncio.wait_for(
                self.client._message_listener(), timeout=self._MAX_WAIT_SECONDS
            )
            assert self.client.connected is False

    async def test_get_online_users_success(self):
        """Test that _get_online_users sends a GET_ONLINE_USER request."""
        with patch.object(self.client, '_send_packet') as mock_send:
            with patch.object(self.client, '_receive_packet') as mock_receive:
                # Simulated ONLINE_USERS_INFO reply with two users: "user", "admin".
                user_data = b"\x02\x04user\x05admin"
                mock_receive.return_value = (SSHOUTPacketType.ONLINE_USERS_INFO, user_data)
                await self.client._get_online_users()
                mock_send.assert_called_once_with(SSHOUTPacketType.GET_ONLINE_USER, b"")

    async def test_get_online_users_exception(self):
        """Test that _get_online_users swallows send errors instead of raising."""
        with patch.object(self.client, '_send_packet', side_effect=Exception("网络错误")):
            # Must not raise.
            await self.client._get_online_users()
if __name__ == '__main__':
    # Run this file's tests directly. Propagate pytest's exit code so shell
    # scripts and CI see failures; previously the return value was discarded
    # and the process always exited with status 0.
    raise SystemExit(pytest.main([__file__]))