| """ |
| SSHOUT集成功能单元测试 |
| 测试消息解析、@Claude检测、响应清理等核心功能 |
| """ |
| |
| import unittest |
| from unittest.mock import Mock, patch, MagicMock, AsyncMock |
| import asyncio |
| from datetime import datetime, timedelta |
| import sys |
| import os |
| |
# Make the project's ``src`` directory importable. The project root is the
# directory three levels above the one containing this test file;
# ``abspath`` normalises the ``..`` components lexically.
project_root = os.path.abspath(
    os.path.join(
        os.path.dirname(os.path.abspath(__file__)),
        os.pardir, os.pardir, os.pardir,
    )
)
sys.path.insert(0, os.path.join(project_root, 'src'))
| |
| from claude_agent.sshout.integration import ( |
| SSHOUTConnection, |
| SSHOUTMessage, |
| SSHOUTIntegration |
| ) |
| |
| |
class TestSSHOUTMessage(unittest.TestCase):
    """Unit tests for the SSHOUTMessage data container."""

    def test_message_creation(self):
        """Every constructor argument is stored on the attribute of the same name."""
        now = datetime.now()
        fields = dict(
            timestamp=now,
            username="testuser",
            content="test message",
            is_mention=True,
        )
        message = SSHOUTMessage(**fields)

        self.assertEqual(message.timestamp, now)
        self.assertEqual(message.username, fields["username"])
        self.assertEqual(message.content, fields["content"])
        self.assertTrue(message.is_mention)

    def test_message_defaults(self):
        """is_mention is falsy when the caller does not pass it."""
        message = SSHOUTMessage(
            timestamp=datetime.now(),
            username="testuser",
            content="test message",
        )
        self.assertFalse(message.is_mention)
| |
| |
class TestSSHOUTConnection(unittest.TestCase):
    """Unit tests for SSHOUTConnection: parsing, mention detection, history and callbacks."""

    def setUp(self):
        """Create a fresh, never-connected SSHOUTConnection for every test."""
        self.connection = SSHOUTConnection(
            hostname="test.example.com",
            port=22,
            username="testuser",
            key_path="/path/to/key",
        )

    def _parse_ok(self, raw):
        """Parse *raw* with the connection and fail the test if nothing is returned."""
        parsed = self.connection._parse_message(raw)
        self.assertIsNotNone(parsed)
        return parsed

    def test_connection_initialization(self):
        """Constructor arguments are stored and the connection starts idle with no callbacks."""
        conn = self.connection
        expected_attrs = {
            "hostname": "test.example.com",
            "port": 22,
            "username": "testuser",
            "key_path": "/path/to/key",
        }
        for attr_name, attr_value in expected_attrs.items():
            self.assertEqual(getattr(conn, attr_name), attr_value)
        self.assertFalse(conn.connected)
        self.assertEqual(len(conn.message_callbacks), 0)
        self.assertEqual(len(conn.mention_callbacks), 0)

    def test_message_parsing_with_timestamp(self):
        """'[HH:MM:SS] <user> text' lines yield the user, the text and a datetime timestamp."""
        cases = (
            ("[14:30:25] <user1> 大家好!", "user1", "大家好!"),
            ("[09:15:00] <alice> @Claude 你好", "alice", "@Claude 你好"),
            ("[23:59:59] <bob> 测试消息", "bob", "测试消息"),
        )
        for raw, user, text in cases:
            with self.subTest(line=raw):
                parsed = self._parse_ok(raw)
                self.assertEqual(parsed.username, user)
                self.assertEqual(parsed.content, text)
                self.assertIsInstance(parsed.timestamp, datetime)

    def test_message_parsing_without_timestamp(self):
        """'<user> text' and 'user: text' lines parse without a leading timestamp."""
        cases = (
            ("<user1> 大家好!", "user1", "大家好!"),
            ("<alice> @Claude 你好", "alice", "@Claude 你好"),
            ("bob: 测试消息", "bob", "测试消息"),
        )
        for raw, user, text in cases:
            with self.subTest(line=raw):
                parsed = self._parse_ok(raw)
                self.assertEqual(parsed.username, user)
                self.assertEqual(parsed.content, text)

    def test_message_parsing_invalid_format(self):
        """Lines matching no known chat format are rejected with None."""
        rejected = ("", "plain text without format", "< >", ":", "[invalid] format")
        for raw in rejected:
            with self.subTest(line=raw):
                self.assertIsNone(self.connection._parse_message(raw))

    def test_claude_mention_detection(self):
        """Case-insensitive '@Claude', 'Claude:' and 'Claude,' forms count as mentions."""
        mentions = (
            "@Claude 你好",
            "@claude 帮我一下",
            "@CLAUDE 测试",
            "Claude: 你在吗?",
            "claude, 有问题",
            "Hello @Claude!",
            "hey Claude, what's up?",
        )
        for content in mentions:
            with self.subTest(content=content):
                self.assertTrue(
                    self.connection._is_claude_mention(content),
                    f"应该检测到 '@Claude' 提及: {content}",
                )

    def test_claude_mention_detection_negative(self):
        """Plain chat, look-alike names and unpunctuated 'claude' are not mentions."""
        non_mentions = (
            "普通消息,不包含提及",
            "这是一条正常的聊天消息",
            "Claudia is a nice name",  # similar but different name
            "claude在句子中间但没有标点",  # no suitable punctuation after the name
            "",
        )
        for content in non_mentions:
            with self.subTest(content=content):
                self.assertFalse(
                    self.connection._is_claude_mention(content),
                    f"不应该检测到 '@Claude' 提及: {content}",
                )

    def test_message_history_management(self):
        """get_recent_messages(n) returns n entries with the newest message last."""
        history = self.connection.message_history
        for idx in range(5):
            history.append(SSHOUTMessage(
                timestamp=datetime.now(),
                username=f"user{idx}",
                content=f"message {idx}",
            ))

        self.assertEqual(len(history), 5)

        latest = self.connection.get_recent_messages(3)
        self.assertEqual(len(latest), 3)
        self.assertEqual(latest[-1].content, "message 4")  # newest message

    def test_context_messages_retrieval(self):
        """get_context_messages returns the `count` messages preceding a timestamp, in time order."""
        start = datetime.now()
        created = []
        for idx in range(10):
            # timedelta keeps the seconds field valid even when it rolls over a minute.
            msg = SSHOUTMessage(
                timestamp=start + timedelta(seconds=idx),
                username=f"user{idx}",
                content=f"message {idx}",
            )
            created.append(msg)
            self.connection.message_history.append(msg)

        context = self.connection.get_context_messages(created[7].timestamp, count=3)

        self.assertEqual(len(context), 3)
        # Messages 4, 5 and 6 directly precede message 7, oldest first.
        self.assertEqual(
            [m.content for m in context],
            ["message 4", "message 5", "message 6"],
        )

    def test_callback_management(self):
        """Registering a callback appends it to the matching callback list."""
        self.connection.add_message_callback(Mock())
        self.assertEqual(len(self.connection.message_callbacks), 1)

        self.connection.add_mention_callback(Mock())
        self.assertEqual(len(self.connection.mention_callbacks), 1)
| |
| |
class TestSSHOUTIntegration(unittest.TestCase):
    """Unit tests for SSHOUTIntegration that run without a live SSHOUT server.

    The integration is constructed against a mocked config manager and a
    mocked key-file existence check. The tests therefore neither read the
    developer's real SSHOUT configuration nor require an SSH private key to
    exist on disk.
    """

    def setUp(self):
        """Build an integration backed by an in-memory, mocked configuration."""
        self.mock_agent = Mock()

        self.mock_config = {
            'server': {
                'hostname': 'test.example.com',
                'port': 22333,
                'username': 'testuser'
            },
            'ssh_key': {
                'private_key_path': '/fake/key/path'
            },
            'message': {
                'context_count': 5,
                'max_reply_length': 0  # 0 means "no length limit"
            },
            'mention_patterns': ['@Claude', 'Claude:']
        }

        self.mock_config_manager = Mock()
        self.mock_config_manager.get_sshout_config.return_value = self.mock_config
        self.mock_config_manager.get.side_effect = lambda key, default=None: {
            'sshout.message.context_count': 5,
            'sshout.message.max_reply_length': 0  # no length limit by default
        }.get(key, default)

        # Config validation rejects a missing private-key file, so pretend it exists.
        with patch('claude_agent.sshout.integration.get_config_manager',
                   return_value=self.mock_config_manager):
            with patch('claude_agent.sshout.integration.os.path.exists', return_value=True):
                self.integration = SSHOUTIntegration(self.mock_agent)

    def test_integration_initialization(self):
        """A new integration keeps its agent, has no connection and exposes the config sections."""
        self.assertEqual(self.integration.agent, self.mock_agent)
        self.assertIsNone(self.integration.connection)

        # Top-level configuration sections.
        self.assertIn('server', self.integration.sshout_config)
        self.assertIn('ssh_key', self.integration.sshout_config)
        self.assertIn('message', self.integration.sshout_config)

        # Server section.
        server_config = self.integration.sshout_config['server']
        self.assertIn('hostname', server_config)
        self.assertIn('port', server_config)
        self.assertIn('username', server_config)

        # SSH key section.
        ssh_config = self.integration.sshout_config['ssh_key']
        self.assertIn('private_key_path', ssh_config)

    def test_response_cleaning(self):
        """Markdown markers are stripped and blank-line runs are capped at two newlines."""
        test_cases = [
            ("**粗体文本**", "粗体文本"),
            ("*斜体文本*", "斜体文本"),
            ("`代码文本`", "代码文本"),
            ("**粗体** 和 *斜体* 混合", "粗体 和 斜体 混合"),
            ("单个\n换行", "单个\n换行"),  # a single newline is kept
            ("两个\n\n换行", "两个\n\n换行"),  # two newlines are kept
            ("三个\n\n\n换行", "三个\n\n换行"),  # three collapse to two
            ("多个\n\n\n\n\n换行", "多个\n\n换行"),  # many collapse to two
            (" 多个 空格 测试 ", "多个 空格 测试"),  # inner spaces kept, outer ones trimmed
            ("行尾空格 \n 行首空格", "行尾空格\n 行首空格"),  # trailing spaces removed, leading kept
        ]

        for input_text, expected in test_cases:
            with self.subTest(input_text=input_text):
                result = self.integration._clean_response_for_sshout(input_text)
                self.assertEqual(result, expected)

    def test_response_length_limit(self):
        """Replies are untouched when max_reply_length is 0 and truncated with '...' otherwise."""
        # With max_reply_length == 0 (the mocked default) nothing is truncated.
        long_text = "这是一个很长的文本。" * 50  # well over 200 characters
        result = self.integration._clean_response_for_sshout(long_text)

        self.assertEqual(len(result), len(long_text))
        self.assertFalse(result.endswith("..."))

        # With a limit configured the reply is cut and marked with "...".
        # patch.object restores the original ``get`` afterwards, so the limit
        # cannot leak into other tests.
        with patch.object(self.integration.config_manager, 'get', return_value=50):
            result_limited = self.integration._clean_response_for_sshout(long_text)
        self.assertLessEqual(len(result_limited), 53)  # 50 + "..."
        self.assertTrue(result_limited.endswith("..."))

    def test_connection_status_disconnected(self):
        """Without a connection the status reports disconnected, no server and no messages."""
        status = self.integration.get_connection_status()

        self.assertFalse(status['connected'])
        self.assertIsNone(status['server'])
        self.assertEqual(status['message_count'], 0)

    def test_connection_status_connected(self):
        """With a live connection the status reports host:port, history size and recent messages."""
        # The connection is injected directly, so SSHOUTConnection itself need not be patched.
        mock_connection = Mock()
        mock_connection.connected = True
        mock_connection.hostname = "test.example.com"
        mock_connection.port = 22333
        mock_connection.message_history = [Mock() for _ in range(5)]

        # One recent message, whose timestamp formats as a fixed clock time.
        mock_msg = Mock()
        mock_msg.timestamp.strftime.return_value = "14:30:25"
        mock_msg.username = "testuser"
        mock_msg.content = "test message"
        mock_connection.get_recent_messages.return_value = [mock_msg]

        self.integration.connection = mock_connection

        status = self.integration.get_connection_status()

        self.assertTrue(status['connected'])
        self.assertEqual(status['server'], "test.example.com:22333")
        self.assertEqual(status['message_count'], 5)
        self.assertEqual(len(status['recent_messages']), 1)
| |
| |
class TestSSHOUTConnectionExtended(unittest.IsolatedAsyncioTestCase):
    """Extended SSHOUTConnection tests: connect/disconnect, sending, parsing, listener and callbacks."""

    def setUp(self):
        """Create a fresh, unconnected SSHOUTConnection for each test."""
        self.connection = SSHOUTConnection(
            hostname="test.example.com",
            port=22,
            username="testuser",
            key_path="/path/to/key"
        )

    @patch('paramiko.ECDSAKey.from_private_key_file')
    @patch('paramiko.SSHClient')
    async def test_connect_success(self, mock_ssh_client_class, mock_key):
        """connect() loads the ECDSA key, opens the SSH session and configures the shell."""
        # paramiko client and interactive shell mocks.
        mock_client = Mock()
        mock_ssh_client_class.return_value = mock_client
        mock_shell = Mock()
        mock_client.invoke_shell.return_value = mock_shell

        # Private-key mock.
        mock_private_key = Mock()
        mock_key.return_value = mock_private_key

        result = await self.connection.connect()

        self.assertTrue(result)
        self.assertTrue(self.connection.connected)
        mock_client.set_missing_host_key_policy.assert_called_once()
        mock_client.connect.assert_called_once_with(
            hostname="test.example.com",
            port=22,
            username="testuser",
            pkey=mock_private_key,
            timeout=10
        )
        mock_shell.settimeout.assert_called_once_with(0.1)

    @patch('paramiko.ECDSAKey.from_private_key_file')
    async def test_connect_key_load_failure(self, mock_key):
        """A failure while loading the private key makes connect() return False."""
        mock_key.side_effect = Exception("Key load failed")

        result = await self.connection.connect()

        self.assertFalse(result)
        self.assertFalse(self.connection.connected)

    @patch('paramiko.ECDSAKey.from_private_key_file')
    @patch('paramiko.SSHClient')
    async def test_connect_ssh_failure(self, mock_ssh_client_class, mock_key):
        """A failure inside SSHClient.connect makes connect() return False."""
        mock_client = Mock()
        mock_ssh_client_class.return_value = mock_client
        mock_client.connect.side_effect = Exception("SSH connection failed")

        result = await self.connection.connect()

        self.assertFalse(result)
        self.assertFalse(self.connection.connected)

    async def test_disconnect_cleanup(self):
        """disconnect() closes the shell and client and clears both references."""
        # Simulate an established connection.
        self.connection.connected = True
        mock_shell = Mock()
        mock_client = Mock()
        self.connection.shell = mock_shell
        self.connection.client = mock_client

        await self.connection.disconnect()

        self.assertFalse(self.connection.connected)
        mock_shell.close.assert_called_once()
        mock_client.close.assert_called_once()
        self.assertIsNone(self.connection.shell)
        self.assertIsNone(self.connection.client)

    async def test_disconnect_with_exception(self):
        """disconnect() swallows close() errors and still marks the connection as closed."""
        self.connection.connected = True
        self.connection.shell = Mock()
        self.connection.shell.close.side_effect = Exception("Close failed")
        self.connection.client = Mock()

        # Must not raise.
        await self.connection.disconnect()
        self.assertFalse(self.connection.connected)

    async def test_send_message_success(self):
        """send_message() writes the text plus a trailing newline to the shell."""
        # Simulate an established connection.
        self.connection.connected = True
        self.connection.shell = Mock()

        result = await self.connection.send_message("test message")

        self.assertTrue(result)
        self.connection.shell.send.assert_called_once_with("test message\n")

    async def test_send_message_not_connected(self):
        """send_message() returns False while disconnected."""
        result = await self.connection.send_message("test message")
        self.assertFalse(result)

    async def test_send_message_no_shell(self):
        """send_message() returns False when flagged connected but no shell exists."""
        self.connection.connected = True
        self.connection.shell = None

        result = await self.connection.send_message("test message")
        self.assertFalse(result)

    async def test_send_message_exception(self):
        """send_message() returns False if the shell raises while sending."""
        self.connection.connected = True
        self.connection.shell = Mock()
        self.connection.shell.send.side_effect = Exception("Send failed")

        result = await self.connection.send_message("test message")
        self.assertFalse(result)

    def test_clean_ansi_codes(self):
        """Real ANSI escapes and their ESC-less '[...m' remnants are stripped."""
        test_cases = [
            ("\x1b[1;34mblue text\x1b[0m", "blue text"),
            ("[1;34mfake ansi[0m", "fake ansi"),
            ("normal text", "normal text"),
            ("\x1b[Kmixed\x1b[31m text\x1b[0m", "mixed text")
        ]

        for input_text, expected in test_cases:
            with self.subTest(input_text=input_text):
                result = self.connection._clean_ansi_codes(input_text)
                self.assertEqual(result, expected)

    def test_parse_message_complex_formats(self):
        """Colon format and ANSI-decorated lines parse; malformed timestamps and lines do not."""
        test_cases = [
            ("[14:30:25] testuser: message with colon format", "testuser", "message with colon format"),
            ("\x1b[1;34m[14:30:25] <user> \x1b[0mcolored message", "user", "colored message"),
            ("[1;34m[14:30:25][0m user: pseudo ansi", "user", "pseudo ansi"),
            ("[invalid timestamp] <user> message", None, None),
            ("malformed line without proper format", None, None)
        ]

        for line, expected_user, expected_content in test_cases:
            with self.subTest(line=line):
                msg = self.connection._parse_message(line)
                if expected_user is None:
                    self.assertIsNone(msg)
                else:
                    self.assertIsNotNone(msg)
                    self.assertEqual(msg.username, expected_user)
                    self.assertEqual(msg.content, expected_content)

    def test_parse_message_timestamp_parsing(self):
        """The [HH:MM:SS] prefix populates the hour/minute/second of the timestamp."""
        # Valid timestamp.
        msg = self.connection._parse_message("[14:30:25] <user> test")
        self.assertIsNotNone(msg)
        self.assertEqual(msg.timestamp.hour, 14)
        self.assertEqual(msg.timestamp.minute, 30)
        self.assertEqual(msg.timestamp.second, 25)

        # An invalid timestamp matches none of the parse patterns, so the line is rejected.
        msg = self.connection._parse_message("[invalid] <user> test")
        self.assertIsNone(msg)

    def test_parse_message_exception_handling(self):
        """An internal error while parsing yields None instead of propagating."""
        with patch.object(self.connection, '_clean_ansi_codes', side_effect=Exception("Clean failed")):
            msg = self.connection._parse_message("[14:30:25] <user> test")
            self.assertIsNone(msg)

    def test_mention_patterns_advanced(self):
        """Word-boundary and punctuation edge cases of @Claude detection."""
        test_cases = [
            ("@Claude123", False),  # digits right after the name: not a mention
            ("@ClaudeAI", False),  # letters right after the name: not a mention
            ("@Claude!", True),  # '!' is not alphanumeric, so it still counts
            ("@Claude。", True),  # CJK punctuation
            ("Claude: 你好", True),  # colon form
            ("Claude,测试", True),  # comma form
            ("NotClaude:", True),  # matches because it contains 'Claude:' (documents current behaviour)
        ]

        for content, expected in test_cases:
            with self.subTest(content=content):
                result = self.connection._is_claude_mention(content)
                self.assertEqual(result, expected, f"Content: '{content}' expected {expected}")

    async def test_message_listener_loop(self):
        """A received line is stripped and forwarded to _process_line.

        NOTE(review): this test swaps ``_message_listener`` for a local
        stand-in and runs that, so it does not exercise the production
        listener loop. It only checks the stand-in's line splitting against a
        mocked ``_process_line``.
        """
        self.connection.connected = True
        self.connection.shell = Mock()
        # Data is available on the first poll only.
        self.connection.shell.recv_ready.side_effect = [True, False, False]
        self.connection.shell.recv.return_value = b"<user> test message\n"

        process_line_mock = AsyncMock()
        self.connection._process_line = process_line_mock

        async def quick_listener():
            # Simulate one iteration of the listener loop.
            if self.connection.shell and self.connection.shell.recv_ready():
                data = self.connection.shell.recv(1024).decode('utf-8', errors='ignore')
                lines = data.split('\n')
                # The last element is whatever follows the final newline.
                for line in lines[:-1]:
                    if line.strip():
                        await self.connection._process_line(line.strip())
            self.connection.connected = False  # end the loop

        # patch.object restores the attribute even if the awaited call raises.
        with patch.object(self.connection, '_message_listener', quick_listener):
            await self.connection._message_listener()

        process_line_mock.assert_called_once_with("<user> test message")

    async def test_message_listener_exception_handling(self):
        """An error from recv_ready() inside the listener does not escape.

        NOTE(review): like test_message_listener_loop, this runs a local
        stand-in instead of the production ``_message_listener``.
        """
        self.connection.connected = True
        self.connection.shell = Mock()
        self.connection.shell.recv_ready.side_effect = Exception("recv_ready failed")

        async def quick_listener_with_exception():
            try:
                if self.connection.shell and self.connection.shell.recv_ready():
                    pass  # recv_ready() raises before this is reached
            except Exception:
                # Stand-in for the listener's error handling.
                pass
            self.connection.connected = False  # end the loop

        # Must not raise; the attribute is restored when the block exits.
        with patch.object(self.connection, '_message_listener', quick_listener_with_exception):
            await self.connection._message_listener()

    async def test_process_line_with_callbacks(self):
        """Every parsed line reaches message callbacks; only mentions reach mention callbacks."""
        message_callback = Mock()
        mention_callback = AsyncMock()
        self.connection.add_message_callback(message_callback)
        self.connection.add_mention_callback(mention_callback)

        # An ordinary message.
        await self.connection._process_line("<user> normal message")

        self.assertEqual(message_callback.call_count, 1)
        mention_callback.assert_not_called()

        # A message mentioning @Claude.
        await self.connection._process_line("<user> @Claude help me")

        self.assertEqual(message_callback.call_count, 2)
        mention_callback.assert_called_once()

    async def test_process_line_callback_exceptions(self):
        """A raising message callback is invoked and its exception is contained."""
        failing_callback = Mock(side_effect=Exception("Callback failed"))
        self.connection.add_message_callback(failing_callback)

        # Must not raise.
        await self.connection._process_line("<user> test message")

        failing_callback.assert_called_once()

    async def test_process_line_mention_callback_types(self):
        """Both synchronous and asynchronous mention callbacks are invoked."""
        sync_callback = Mock()
        async_callback = AsyncMock()

        self.connection.add_mention_callback(sync_callback)
        self.connection.add_mention_callback(async_callback)

        await self.connection._process_line("<user> @Claude help")

        sync_callback.assert_called_once()
        async_callback.assert_called_once()

    async def test_process_line_exception(self):
        """A parse failure inside _process_line is contained."""
        with patch.object(self.connection, '_parse_message', side_effect=Exception("Parse failed")):
            # Must not raise.
            await self.connection._process_line("<user> test")
| |
| |
class TestSSHOUTIntegrationExtended(unittest.IsolatedAsyncioTestCase):
    """Extended SSHOUTIntegration tests: config validation, connecting, mention handling, cleaning."""

    def setUp(self):
        """Build an integration against a mocked config manager and key-file check."""
        self.mock_agent = Mock()

        self.mock_config_manager = Mock()

        # Default configuration served by the mocked manager.
        self.mock_config = {
            'server': {
                'hostname': 'test.example.com',
                'port': 22333,
                'username': 'testuser'
            },
            'ssh_key': {
                'private_key_path': '/fake/key/path'
            },
            'message': {
                'context_count': 5,
                'max_reply_length': 0  # 0 means "no length limit"
            },
            'mention_patterns': ['@Claude', 'Claude:']
        }

        self.mock_config_manager.get_sshout_config.return_value = self.mock_config
        self.mock_config_manager.get.side_effect = lambda key, default=None: {
            'sshout.message.context_count': 5,
            'sshout.message.max_reply_length': 0  # no length limit by default
        }.get(key, default)

        # Config validation rejects a missing key file, so pretend it exists.
        with patch('claude_agent.sshout.integration.get_config_manager', return_value=self.mock_config_manager):
            with patch('claude_agent.sshout.integration.os.path.exists', return_value=True):
                self.integration = SSHOUTIntegration(self.mock_agent)

    def test_validate_config_missing_section(self):
        """A config without the ssh_key section is rejected with ValueError."""
        with patch('claude_agent.sshout.integration.get_config_manager') as mock_get_config:
            mock_config_manager = Mock()
            mock_get_config.return_value = mock_config_manager

            # ssh_key section is missing.
            mock_config_manager.get_sshout_config.return_value = {
                'server': {'hostname': 'test', 'port': 22, 'username': 'user'}
            }

            with patch('claude_agent.sshout.integration.os.path.exists', return_value=True):
                with self.assertRaises(ValueError) as cm:
                    SSHOUTIntegration(Mock())

                self.assertIn("SSHOUT配置缺少必需的段落: ssh_key", str(cm.exception))

    def test_validate_config_missing_server_key(self):
        """A server section without 'port' is rejected with ValueError."""
        with patch('claude_agent.sshout.integration.get_config_manager') as mock_get_config:
            mock_config_manager = Mock()
            mock_get_config.return_value = mock_config_manager

            mock_config_manager.get_sshout_config.return_value = {
                'server': {'hostname': 'test', 'username': 'user'},  # port is missing
                'ssh_key': {'private_key_path': '/fake/key'}
            }

            with patch('claude_agent.sshout.integration.os.path.exists', return_value=True):
                with self.assertRaises(ValueError) as cm:
                    SSHOUTIntegration(Mock())

                self.assertIn("SSHOUT服务器配置缺少必需的键: port", str(cm.exception))

    def test_validate_config_missing_key_path(self):
        """An ssh_key section without private_key_path is rejected with ValueError."""
        with patch('claude_agent.sshout.integration.get_config_manager') as mock_get_config:
            mock_config_manager = Mock()
            mock_get_config.return_value = mock_config_manager

            mock_config_manager.get_sshout_config.return_value = {
                'server': {'hostname': 'test', 'port': 22, 'username': 'user'},
                'ssh_key': {}  # private_key_path is missing
            }

            with patch('claude_agent.sshout.integration.os.path.exists', return_value=True):
                with self.assertRaises(ValueError) as cm:
                    SSHOUTIntegration(Mock())

                self.assertIn("SSHOUT配置缺少SSH私钥路径", str(cm.exception))

    def test_validate_config_key_file_not_exists(self):
        """A private_key_path that does not exist on disk raises FileNotFoundError."""
        with patch('claude_agent.sshout.integration.get_config_manager') as mock_get_config:
            mock_config_manager = Mock()
            mock_get_config.return_value = mock_config_manager

            mock_config_manager.get_sshout_config.return_value = {
                'server': {'hostname': 'test', 'port': 22, 'username': 'user'},
                'ssh_key': {'private_key_path': '/nonexistent/key'}
            }

            with patch('claude_agent.sshout.integration.os.path.exists', return_value=False):
                with self.assertRaises(FileNotFoundError) as cm:
                    SSHOUTIntegration(Mock())

                self.assertIn("SSH私钥文件不存在", str(cm.exception))

    @patch('claude_agent.sshout.integration.SSHOUTConnection')
    async def test_connect_to_sshout_success(self, mock_connection_class):
        """A successful connect stores the connection and registers both callbacks."""
        mock_connection = AsyncMock()
        mock_connection.connect.return_value = True
        # Callback registration is synchronous on SSHOUTConnection; plain Mocks
        # avoid AsyncMock children creating coroutines that are never awaited.
        mock_connection.add_message_callback = Mock()
        mock_connection.add_mention_callback = Mock()
        mock_connection_class.return_value = mock_connection

        result = await self.integration.connect_to_sshout()

        self.assertTrue(result)
        self.assertEqual(self.integration.connection, mock_connection)
        mock_connection.add_message_callback.assert_called_once()
        mock_connection.add_mention_callback.assert_called_once()
        mock_connection.connect.assert_called_once()

    @patch('claude_agent.sshout.integration.SSHOUTConnection')
    async def test_connect_to_sshout_failure(self, mock_connection_class):
        """connect_to_sshout() returns False when the connection cannot be established."""
        mock_connection = AsyncMock()
        mock_connection.connect.return_value = False
        # Synchronous registration methods; see test_connect_to_sshout_success.
        mock_connection.add_message_callback = Mock()
        mock_connection.add_mention_callback = Mock()
        mock_connection_class.return_value = mock_connection

        result = await self.integration.connect_to_sshout()

        self.assertFalse(result)

    @patch('claude_agent.sshout.integration.SSHOUTConnection')
    async def test_connect_to_sshout_exception(self, mock_connection_class):
        """An exception while creating the connection yields False instead of raising."""
        mock_connection_class.side_effect = Exception("Connection creation failed")

        result = await self.integration.connect_to_sshout()

        self.assertFalse(result)

    def test_on_message_received(self):
        """Handling an ordinary chat message does not raise."""
        timestamp = datetime.now()
        message = SSHOUTMessage(
            timestamp=timestamp,
            username="testuser",
            content="test message"
        )

        # Must not raise.
        self.integration._on_message_received(message)

    async def test_on_claude_mentioned_success(self):
        """A mention gathers context, asks the agent and sends the reply."""
        mock_connection = Mock()
        context_messages = [
            SSHOUTMessage(datetime.now(), "user1", "context message 1"),
            SSHOUTMessage(datetime.now(), "user2", "context message 2")
        ]
        mock_connection.get_context_messages.return_value = context_messages
        mock_connection.send_message = AsyncMock(return_value=True)
        self.integration.connection = mock_connection

        self.mock_agent.process_user_input = AsyncMock(return_value="Agent response")

        mention_message = SSHOUTMessage(
            timestamp=datetime.now(),
            username="testuser",
            content="@Claude help me"
        )

        await self.integration._on_claude_mentioned(mention_message)

        mock_connection.get_context_messages.assert_called_once()
        self.mock_agent.process_user_input.assert_called_once()
        mock_connection.send_message.assert_called_once()

    async def test_on_claude_mentioned_no_context(self):
        """With no preceding messages the agent prompt omits the context header."""
        mock_connection = Mock()
        mock_connection.get_context_messages.return_value = []  # no context
        mock_connection.send_message = AsyncMock(return_value=True)
        self.integration.connection = mock_connection

        self.mock_agent.process_user_input = AsyncMock(return_value="Agent response")

        mention_message = SSHOUTMessage(
            timestamp=datetime.now(),
            username="testuser",
            content="@Claude help"
        )

        await self.integration._on_claude_mentioned(mention_message)

        self.mock_agent.process_user_input.assert_called_once()
        # The prompt (first positional argument) must not contain the context header.
        call_args = self.mock_agent.process_user_input.call_args[0][0]
        self.assertNotIn("聊天室上下文消息:", call_args)

    async def test_on_claude_mentioned_send_failure(self):
        """A failed reply send is tolerated without raising."""
        mock_connection = Mock()
        mock_connection.get_context_messages.return_value = []
        mock_connection.send_message = AsyncMock(return_value=False)  # send fails
        self.integration.connection = mock_connection

        self.mock_agent.process_user_input = AsyncMock(return_value="Agent response")

        mention_message = SSHOUTMessage(
            timestamp=datetime.now(),
            username="testuser",
            content="@Claude help"
        )

        # Must not raise.
        await self.integration._on_claude_mentioned(mention_message)
        mock_connection.send_message.assert_called_once()

    async def test_on_claude_mentioned_no_connection(self):
        """Without a connection, mention handling stops before the agent is called."""
        self.integration.connection = None
        self.mock_agent.process_user_input = AsyncMock(return_value="Response")

        mention_message = SSHOUTMessage(
            timestamp=datetime.now(),
            username="testuser",
            content="@Claude help"
        )

        # Must not raise; context lookup fails because the connection is None.
        await self.integration._on_claude_mentioned(mention_message)
        # The agent is never reached because fetching context failed first.
        self.mock_agent.process_user_input.assert_not_called()

    async def test_on_claude_mentioned_exception(self):
        """An exception raised by the agent is contained by mention handling."""
        # A working connection is required: with connection=None handling stops
        # before the agent is called (see test_on_claude_mentioned_no_connection),
        # so the agent-failure path would never be exercised.
        mock_connection = Mock()
        mock_connection.get_context_messages.return_value = []
        mock_connection.send_message = AsyncMock(return_value=True)
        self.integration.connection = mock_connection

        self.mock_agent.process_user_input = AsyncMock(side_effect=Exception("Agent failed"))

        mention_message = SSHOUTMessage(
            timestamp=datetime.now(),
            username="testuser",
            content="@Claude help"
        )

        # Must not raise.
        await self.integration._on_claude_mentioned(mention_message)
        # Confirm the failing agent call was actually reached.
        self.mock_agent.process_user_input.assert_called_once()

    def test_clean_response_markdown_removal(self):
        """Bold, italic and inline-code markers are removed; newlines are kept."""
        test_cases = [
            ("**bold** text", "bold text"),
            ("*italic* text", "italic text"),
            ("`code` text", "code text"),
            ("**bold** and *italic* together", "bold and italic together"),
            ("Multiple\n\nlines\nwith spaces", "Multiple\n\nlines\nwith spaces")  # newlines preserved
        ]

        for input_text, expected in test_cases:
            with self.subTest(input_text=input_text):
                result = self.integration._clean_response_for_sshout(input_text)
                self.assertEqual(result, expected)

    def test_clean_response_length_limiting(self):
        """No truncation by default; a configured limit truncates and appends '...'."""
        long_text = "Very long text. " * 20  # well over 200 characters

        # max_reply_length == 0 by default, so nothing is truncated.
        result = self.integration._clean_response_for_sshout(long_text)
        # Trailing whitespace is stripped, so the result may be slightly shorter.
        self.assertGreaterEqual(len(result), len(long_text) - 5)
        self.assertFalse(result.endswith("..."))

        # With a limit the reply is truncated; patch.object restores ``get`` afterwards.
        with patch.object(self.integration.config_manager, 'get', return_value=50):
            result_limited = self.integration._clean_response_for_sshout(long_text)
        self.assertLessEqual(len(result_limited), 53)  # 50 + "..."
        self.assertTrue(result_limited.endswith("..."))

    def test_clean_response_whitespace_normalization(self):
        """Only leading/trailing and end-of-line whitespace is removed."""
        test_text = " Multiple spaces and\n\nnewlines "
        result = self.integration._clean_response_for_sshout(test_text)

        # Newlines and inner spaces are preserved.
        self.assertEqual(result, "Multiple spaces and\n\nnewlines")
| |
| |
class TestSSHOUTAsyncMethods(unittest.IsolatedAsyncioTestCase):
    """Async public API of SSHOUTIntegration: send_message and disconnect_from_sshout."""

    def setUp(self):
        """Serve a mocked config and a passing key-file check for the whole test."""
        self.mock_config_manager = Mock()
        self.mock_config = {
            'server': {
                'hostname': 'test.example.com',
                'port': 22333,
                'username': 'testuser'
            },
            'ssh_key': {
                'private_key_path': '/fake/key/path'
            }
        }
        self.mock_config_manager.get_sshout_config.return_value = self.mock_config
        self.mock_config_manager.get.return_value = 5

        # Started here and stopped by cleanup, so every test body runs patched.
        active_patches = (
            patch('claude_agent.sshout.integration.get_config_manager',
                  return_value=self.mock_config_manager),
            patch('claude_agent.sshout.integration.os.path.exists', return_value=True),
        )
        for patcher in active_patches:
            patcher.start()
            self.addCleanup(patcher.stop)

    def _new_integration(self):
        """Return an SSHOUTIntegration built against the patched environment."""
        return SSHOUTIntegration(Mock())

    async def test_send_message_not_connected(self):
        """send_message() returns False before any connection exists."""
        integration = self._new_integration()
        self.assertFalse(await integration.send_message("test message"))

    async def test_send_message_connected(self):
        """send_message() forwards the text to the active connection."""
        integration = self._new_integration()

        live_connection = Mock()
        live_connection.connected = True
        live_connection.send_message = AsyncMock(return_value=True)
        integration.connection = live_connection

        self.assertTrue(await integration.send_message("test message"))
        live_connection.send_message.assert_called_once_with("test message")

    async def test_disconnect_from_sshout(self):
        """disconnect_from_sshout() closes the connection and drops the reference."""
        integration = self._new_integration()

        live_connection = AsyncMock()
        live_connection.disconnect = AsyncMock(return_value=None)
        integration.connection = live_connection

        await integration.disconnect_from_sshout()

        live_connection.disconnect.assert_called_once()
        self.assertIsNone(integration.connection)

    async def test_disconnect_from_sshout_no_connection(self):
        """disconnect_from_sshout() is a harmless no-op without a connection."""
        integration = self._new_integration()

        # Must not raise.
        await integration.disconnect_from_sshout()
        self.assertIsNone(integration.connection)
| |
| |
if __name__ == "__main__":
    # Run every TestCase defined in this module, reporting each test by name.
    unittest.main(verbosity=2)