# blob: 59e34edfd4642574ad379e2b978d29163bbe37c2 [file] [log] [blame] [raw]
"""
SSHOUT集成功能单元测试
测试消息解析、@Claude检测、响应清理等核心功能
"""
import unittest
from unittest.mock import Mock, patch, MagicMock, AsyncMock
import asyncio
from datetime import datetime, timedelta
import sys
import os
# Make the project's ``src`` directory importable.
# ``project_root`` is the directory four ``dirname`` steps above this file's
# absolute path; its ``src`` subdirectory is placed first on ``sys.path``
# (presumably so the ``claude_agent`` import below resolves to the local tree).
project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
sys.path.insert(0, os.path.join(project_root, 'src'))
from claude_agent.sshout.integration import (
SSHOUTConnection,
SSHOUTMessage,
SSHOUTIntegration
)
class TestSSHOUTMessage(unittest.TestCase):
    """Tests for the SSHOUTMessage data structure."""

    def test_message_creation(self):
        """Every explicitly supplied field is readable back from the message."""
        created_at = datetime.now()
        fields = {
            "timestamp": created_at,
            "username": "testuser",
            "content": "test message",
            "is_mention": True,
        }
        message = SSHOUTMessage(**fields)

        self.assertEqual(message.timestamp, created_at)
        self.assertEqual(message.username, "testuser")
        self.assertEqual(message.content, "test message")
        self.assertTrue(message.is_mention)

    def test_message_defaults(self):
        """``is_mention`` is falsy when the caller does not pass it."""
        message = SSHOUTMessage(
            timestamp=datetime.now(),
            username="testuser",
            content="test message",
        )
        self.assertFalse(message.is_mention)
class TestSSHOUTConnection(unittest.TestCase):
    """Tests for SSHOUTConnection: parsing, mention detection, history, callbacks."""

    def setUp(self):
        """Create an SSHOUTConnection with fixed test parameters (connect() is not called)."""
        params = {
            "hostname": "test.example.com",
            "port": 22,
            "username": "testuser",
            "key_path": "/path/to/key",
        }
        self.connection = SSHOUTConnection(**params)

    def _assert_parses_to(self, line, expected_user, expected_content):
        """Parse ``line`` and check username/content; return the parsed message."""
        parsed = self.connection._parse_message(line)
        self.assertIsNotNone(parsed)
        self.assertEqual(parsed.username, expected_user)
        self.assertEqual(parsed.content, expected_content)
        return parsed

    def test_connection_initialization(self):
        """Constructor arguments are stored and the connection starts idle."""
        conn = self.connection
        expected_attrs = (
            ("hostname", "test.example.com"),
            ("port", 22),
            ("username", "testuser"),
            ("key_path", "/path/to/key"),
        )
        for attr_name, expected_value in expected_attrs:
            self.assertEqual(getattr(conn, attr_name), expected_value)

        self.assertFalse(conn.connected)
        for registry in (conn.message_callbacks, conn.mention_callbacks):
            self.assertEqual(len(registry), 0)

    def test_message_parsing_with_timestamp(self):
        """Lines of the form ``[HH:MM:SS] <user> text`` are parsed."""
        samples = (
            ("[14:30:25] <user1> 大家好!", "user1", "大家好!"),
            ("[09:15:00] <alice> @Claude 你好", "alice", "@Claude 你好"),
            ("[23:59:59] <bob> 测试消息", "bob", "测试消息"),
        )
        for line, user, content in samples:
            with self.subTest(line=line):
                parsed = self._assert_parses_to(line, user, content)
                self.assertIsInstance(parsed.timestamp, datetime)

    def test_message_parsing_without_timestamp(self):
        """Lines without a timestamp (``<user> text`` / ``user: text``) are parsed."""
        samples = (
            ("<user1> 大家好!", "user1", "大家好!"),
            ("<alice> @Claude 你好", "alice", "@Claude 你好"),
            ("bob: 测试消息", "bob", "测试消息"),
        )
        for line, user, content in samples:
            with self.subTest(line=line):
                self._assert_parses_to(line, user, content)

    def test_message_parsing_invalid_format(self):
        """Lines that match no known format parse to ``None``."""
        for line in ("", "plain text without format", "< >", ":", "[invalid] format"):
            with self.subTest(line=line):
                self.assertIsNone(self.connection._parse_message(line))

    def test_claude_mention_detection(self):
        """Contents that address Claude are reported as mentions."""
        mentions = (
            "@Claude 你好",
            "@claude 帮我一下",
            "@CLAUDE 测试",
            "Claude: 你在吗?",
            "claude, 有问题",
            "Hello @Claude!",
            "hey Claude, what's up?",
        )
        for content in mentions:
            with self.subTest(content=content):
                detected = self.connection._is_claude_mention(content)
                self.assertTrue(detected, f"应该检测到 '@Claude' 提及: {content}")

    def test_claude_mention_detection_negative(self):
        """Contents that do not address Claude are not reported as mentions."""
        non_mentions = (
            "普通消息,不包含提及",
            "这是一条正常的聊天消息",
            "Claudia is a nice name",  # similar-looking but different name
            "claude在句子中间但没有标点",  # no qualifying punctuation after the name
            "",
        )
        for content in non_mentions:
            with self.subTest(content=content):
                detected = self.connection._is_claude_mention(content)
                self.assertFalse(detected, f"不应该检测到 '@Claude' 提及: {content}")

    def test_message_history_management(self):
        """``get_recent_messages(n)`` returns the last ``n`` history entries."""
        history = self.connection.message_history
        for index in range(5):
            history.append(
                SSHOUTMessage(
                    timestamp=datetime.now(),
                    username=f"user{index}",
                    content=f"message {index}",
                )
            )
        self.assertEqual(len(self.connection.message_history), 5)

        latest = self.connection.get_recent_messages(3)
        self.assertEqual(len(latest), 3)
        # The newest message comes last.
        self.assertEqual(latest[-1].content, "message 4")

    def test_context_messages_retrieval(self):
        """``get_context_messages`` returns the messages preceding a timestamp."""
        start = datetime.now()
        # One message per second; timedelta keeps the seconds field in range.
        timeline = [
            SSHOUTMessage(
                timestamp=start + timedelta(seconds=offset),
                username=f"user{offset}",
                content=f"message {offset}",
            )
            for offset in range(10)
        ]
        for entry in timeline:
            self.connection.message_history.append(entry)

        # Ask for the three messages before the 8th one (index 7).
        context = self.connection.get_context_messages(timeline[7].timestamp, count=3)

        self.assertEqual(len(context), 3)
        # Expect messages 4, 5 and 6 in chronological order.
        for position, expected_content in enumerate(("message 4", "message 5", "message 6")):
            self.assertEqual(context[position].content, expected_content)

    def test_callback_management(self):
        """Registering a callback grows the matching callback list by one."""
        on_message = Mock()
        self.connection.add_message_callback(on_message)
        self.assertEqual(len(self.connection.message_callbacks), 1)

        on_mention = Mock()
        self.connection.add_mention_callback(on_mention)
        self.assertEqual(len(self.connection.mention_callbacks), 1)
class TestSSHOUTIntegration(unittest.TestCase):
    """Tests for the SSHOUTIntegration class."""

    def setUp(self):
        """Create an SSHOUTIntegration around a mock agent.

        NOTE(review): the configuration manager is not patched here, so this
        presumably loads the project's real SSHOUT configuration -- confirm
        that this is intended for a unit test.
        """
        self.mock_agent = Mock()
        self.integration = SSHOUTIntegration(self.mock_agent)

    def test_integration_initialization(self):
        """The agent is stored, no connection exists, and the config has the expected sections."""
        self.assertEqual(self.integration.agent, self.mock_agent)
        self.assertIsNone(self.integration.connection)

        # Top-level sections of the configuration structure.
        sshout_config = self.integration.sshout_config
        for section in ('server', 'ssh_key', 'message'):
            self.assertIn(section, sshout_config)

        # Server settings.
        server_config = sshout_config['server']
        for key in ('hostname', 'port', 'username'):
            self.assertIn(key, server_config)

        # SSH key settings.
        self.assertIn('private_key_path', sshout_config['ssh_key'])

    def test_response_cleaning(self):
        """Markdown markers are stripped and whitespace is normalised."""
        test_cases = [
            ("**粗体文本**", "粗体文本"),
            ("*斜体文本*", "斜体文本"),
            ("`代码文本`", "代码文本"),
            ("**粗体** 和 *斜体* 混合", "粗体 和 斜体 混合"),
            ("单个\n换行", "单个\n换行"),  # a single newline is kept
            ("两个\n\n换行", "两个\n\n换行"),  # two newlines are kept
            ("三个\n\n\n换行", "三个\n\n换行"),  # collapsed to two newlines
            ("多个\n\n\n\n\n换行", "多个\n\n换行"),  # collapsed to two newlines
            (" 多个 空格 测试 ", "多个 空格 测试"),  # inner spaces kept, outer spaces stripped
            ("行尾空格 \n 行首空格", "行尾空格\n 行首空格"),  # leading spaces kept, trailing removed
        ]
        for input_text, expected in test_cases:
            with self.subTest(input_text=input_text):
                result = self.integration._clean_response_for_sshout(input_text)
                self.assertEqual(result, expected)

    def test_response_length_limit(self):
        """Replies are only truncated when a length limit is configured."""
        long_text = "这是一个很长的文本。" * 50

        # With the default configuration the text must come back untruncated.
        result = self.integration._clean_response_for_sshout(long_text)
        self.assertEqual(len(result), len(long_text))
        self.assertFalse(result.endswith("..."))

        # Override ``get`` only for the duration of the check: the config
        # manager was not created by this test, so a permanent replacement
        # could leak the fake limit into other tests.
        with patch.object(self.integration.config_manager, 'get', return_value=50):
            result_limited = self.integration._clean_response_for_sshout(long_text)

        self.assertLessEqual(len(result_limited), 53)  # 50 + "..."
        self.assertTrue(result_limited.endswith("..."))

    def test_connection_status_disconnected(self):
        """Without a connection the status reports disconnected and empty."""
        status = self.integration.get_connection_status()
        self.assertFalse(status['connected'])
        self.assertIsNone(status['server'])
        self.assertEqual(status['message_count'], 0)

    @patch('claude_agent.sshout.integration.SSHOUTConnection')
    def test_connection_status_connected(self, mock_connection_class):
        """With a connected mock the status reflects its host, port and history."""
        mock_connection = Mock()
        mock_connection.connected = True
        mock_connection.hostname = "test.example.com"
        mock_connection.port = 22333
        mock_connection.message_history = [Mock() for _ in range(5)]

        # One recent message, with a canned formatted timestamp.
        mock_msg = Mock()
        mock_msg.timestamp.strftime.return_value = "14:30:25"
        mock_msg.username = "testuser"
        mock_msg.content = "test message"
        mock_connection.get_recent_messages.return_value = [mock_msg]

        self.integration.connection = mock_connection
        status = self.integration.get_connection_status()

        self.assertTrue(status['connected'])
        self.assertEqual(status['server'], "test.example.com:22333")
        self.assertEqual(status['message_count'], 5)
        self.assertEqual(len(status['recent_messages']), 1)
class TestSSHOUTConnectionExtended(unittest.IsolatedAsyncioTestCase):
    """Extended tests for SSHOUTConnection (connect, send, parsing, dispatch)."""

    def setUp(self):
        """Create an SSHOUTConnection with fixed test parameters."""
        params = {
            "hostname": "test.example.com",
            "port": 22,
            "username": "testuser",
            "key_path": "/path/to/key",
        }
        self.connection = SSHOUTConnection(**params)

    def _mark_connected(self, shell):
        """Put the connection into the 'connected' state with the given shell."""
        self.connection.connected = True
        self.connection.shell = shell

    @patch('paramiko.ECDSAKey.from_private_key_file')
    @patch('paramiko.SSHClient')
    async def test_connect_success(self, mock_ssh_client_class, mock_key):
        """A successful connect configures the client and the shell."""
        ssh_client = Mock()
        channel = Mock()
        ssh_client.invoke_shell.return_value = channel
        mock_ssh_client_class.return_value = ssh_client

        # The key loader hands back a stand-in private key.
        private_key = Mock()
        mock_key.return_value = private_key

        connected = await self.connection.connect()

        self.assertTrue(connected)
        self.assertTrue(self.connection.connected)
        ssh_client.set_missing_host_key_policy.assert_called_once()
        expected_kwargs = {
            "hostname": "test.example.com",
            "port": 22,
            "username": "testuser",
            "pkey": private_key,
            "timeout": 10,
        }
        ssh_client.connect.assert_called_once_with(**expected_kwargs)
        channel.settimeout.assert_called_once_with(0.1)

    @patch('paramiko.ECDSAKey.from_private_key_file')
    async def test_connect_key_load_failure(self, mock_key):
        """connect() returns False when loading the private key raises."""
        mock_key.side_effect = Exception("Key load failed")

        connected = await self.connection.connect()

        self.assertFalse(connected)
        self.assertFalse(self.connection.connected)

    @patch('paramiko.ECDSAKey.from_private_key_file')
    @patch('paramiko.SSHClient')
    async def test_connect_ssh_failure(self, mock_ssh_client_class, mock_key):
        """connect() returns False when the SSH client fails to connect."""
        ssh_client = Mock()
        ssh_client.connect.side_effect = Exception("SSH connection failed")
        mock_ssh_client_class.return_value = ssh_client

        connected = await self.connection.connect()

        self.assertFalse(connected)
        self.assertFalse(self.connection.connected)

    async def test_disconnect_cleanup(self):
        """disconnect() closes shell and client and clears both references."""
        channel, ssh_client = Mock(), Mock()
        self._mark_connected(channel)
        self.connection.client = ssh_client

        await self.connection.disconnect()

        self.assertFalse(self.connection.connected)
        channel.close.assert_called_once()
        ssh_client.close.assert_called_once()
        self.assertIsNone(self.connection.shell)
        self.assertIsNone(self.connection.client)

    async def test_disconnect_with_exception(self):
        """disconnect() does not propagate an error raised while closing the shell."""
        self._mark_connected(Mock())
        self.connection.shell.close.side_effect = Exception("Close failed")
        self.connection.client = Mock()

        # Must not raise.
        await self.connection.disconnect()

        self.assertFalse(self.connection.connected)

    async def test_send_message_success(self):
        """A connected shell receives the message followed by a newline."""
        self._mark_connected(Mock())

        sent = await self.connection.send_message("test message")

        self.assertTrue(sent)
        self.connection.shell.send.assert_called_once_with("test message\n")

    async def test_send_message_not_connected(self):
        """send_message() returns False on a connection that was never connected."""
        sent = await self.connection.send_message("test message")
        self.assertFalse(sent)

    async def test_send_message_no_shell(self):
        """send_message() returns False when connected but the shell is None."""
        self._mark_connected(None)

        sent = await self.connection.send_message("test message")

        self.assertFalse(sent)

    async def test_send_message_exception(self):
        """send_message() returns False when the shell's send raises."""
        self._mark_connected(Mock())
        self.connection.shell.send.side_effect = Exception("Send failed")

        sent = await self.connection.send_message("test message")

        self.assertFalse(sent)

    def test_clean_ansi_codes(self):
        """Escape sequences, with or without the ESC byte, are removed."""
        samples = (
            ("\x1b[1;34mblue text\x1b[0m", "blue text"),
            ("[1;34mfake ansi[0m", "fake ansi"),
            ("normal text", "normal text"),
            ("\x1b[Kmixed\x1b[31m text\x1b[0m", "mixed text"),
        )
        for raw, cleaned in samples:
            with self.subTest(input_text=raw):
                self.assertEqual(self.connection._clean_ansi_codes(raw), cleaned)

    def test_parse_message_complex_formats(self):
        """Colon format and colour-coded lines parse; malformed lines give None."""
        samples = (
            ("[14:30:25] testuser: message with colon format", "testuser", "message with colon format"),
            ("\x1b[1;34m[14:30:25] <user> \x1b[0mcolored message", "user", "colored message"),
            ("[1;34m[14:30:25][0m user: pseudo ansi", "user", "pseudo ansi"),
            ("[invalid timestamp] <user> message", None, None),
            ("malformed line without proper format", None, None),
        )
        for line, user, content in samples:
            with self.subTest(line=line):
                parsed = self.connection._parse_message(line)
                if user is None:
                    self.assertIsNone(parsed)
                    continue
                self.assertIsNotNone(parsed)
                self.assertEqual(parsed.username, user)
                self.assertEqual(parsed.content, content)

    def test_parse_message_timestamp_parsing(self):
        """A valid ``[HH:MM:SS]`` prefix populates the message timestamp."""
        parsed = self.connection._parse_message("[14:30:25] <user> test")
        self.assertIsNotNone(parsed)
        stamp = parsed.timestamp
        self.assertEqual(stamp.hour, 14)
        self.assertEqual(stamp.minute, 30)
        self.assertEqual(stamp.second, 25)

        # A non-time bracket prefix matches none of the parse patterns.
        rejected = self.connection._parse_message("[invalid] <user> test")
        self.assertIsNone(rejected)

    def test_parse_message_exception_handling(self):
        """An internal failure while parsing yields None instead of raising."""
        failing_cleaner = patch.object(
            self.connection, '_clean_ansi_codes', side_effect=Exception("Clean failed")
        )
        with failing_cleaner:
            parsed = self.connection._parse_message("[14:30:25] <user> test")
            self.assertIsNone(parsed)

    def test_mention_patterns_advanced(self):
        """Boundary cases of mention detection."""
        samples = (
            ("@Claude123", False),  # followed by digits: no match
            ("@ClaudeAI", False),  # followed by letters: no match
            ("@Claude!", True),  # '!' is not alphanumeric: match
            ("@Claude。", True),  # Chinese punctuation: match
            ("Claude: 你好", True),  # colon form
            ("Claude,测试", True),  # comma form
            ("NotClaude:", True),  # matches because it contains "Claude:"
        )
        for content, expected in samples:
            with self.subTest(content=content):
                detected = self.connection._is_claude_mention(content)
                self.assertEqual(detected, expected, f"Content: '{content}' expected {expected}")

    async def test_message_listener_loop(self):
        """One listener pass forwards each complete, non-blank line to _process_line.

        NOTE(review): the coroutine that runs here is the local stand-in, not
        the real ``_message_listener`` implementation.
        """
        shell = Mock()
        # Data is available on the first poll only.
        shell.recv_ready.side_effect = [True, False, False]
        shell.recv.return_value = b"<user> test message\n"
        self._mark_connected(shell)

        line_handler = AsyncMock()
        self.connection._process_line = line_handler

        async def single_pass_listener():
            # Emulate exactly one iteration of the listener loop.
            conn = self.connection
            if conn.shell and conn.shell.recv_ready():
                text = conn.shell.recv(1024).decode('utf-8', errors='ignore')
                complete_lines = text.split('\n')[:-1]
                for raw_line in complete_lines:
                    if raw_line.strip():
                        await conn._process_line(raw_line.strip())
            conn.connected = False  # ends the loop

        # Swap the listener in, run it, verify, then put the original back.
        saved_listener = self.connection._message_listener
        self.connection._message_listener = single_pass_listener
        await self.connection._message_listener()

        line_handler.assert_called_once_with("<user> test message")

        self.connection._message_listener = saved_listener

    async def test_message_listener_exception_handling(self):
        """A listener pass survives ``recv_ready`` raising.

        NOTE(review): the coroutine that runs here is the local stand-in, not
        the real ``_message_listener`` implementation.
        """
        shell = Mock()
        shell.recv_ready.side_effect = Exception("recv_ready failed")
        self._mark_connected(shell)

        async def single_pass_listener_swallowing_errors():
            conn = self.connection
            try:
                if conn.shell and conn.shell.recv_ready():
                    pass  # recv_ready raises before this is reached
            except Exception:
                # Emulates the listener's own error handling.
                pass
            conn.connected = False  # ends the loop

        saved_listener = self.connection._message_listener
        self.connection._message_listener = single_pass_listener_swallowing_errors

        # Must not raise.
        await self.connection._message_listener()

        self.connection._message_listener = saved_listener

    async def test_process_line_with_callbacks(self):
        """Message callbacks fire for every line, mention callbacks only for mentions."""
        on_message = Mock()
        on_mention = AsyncMock()
        self.connection.add_message_callback(on_message)
        self.connection.add_mention_callback(on_mention)

        # An ordinary message reaches only the message callback.
        await self.connection._process_line("<user> normal message")
        self.assertEqual(on_message.call_count, 1)
        on_mention.assert_not_called()

        # A mention reaches both callbacks.
        await self.connection._process_line("<user> @Claude help me")
        self.assertEqual(on_message.call_count, 2)
        on_mention.assert_called_once()

    async def test_process_line_callback_exceptions(self):
        """An exception raised by a message callback is not propagated."""
        broken_callback = Mock(side_effect=Exception("Callback failed"))
        self.connection.add_message_callback(broken_callback)

        # Must not raise.
        await self.connection._process_line("<user> test message")

        broken_callback.assert_called_once()

    async def test_process_line_mention_callback_types(self):
        """Both synchronous and asynchronous mention callbacks are invoked."""
        plain_callback = Mock()
        coroutine_callback = AsyncMock()
        for callback in (plain_callback, coroutine_callback):
            self.connection.add_mention_callback(callback)

        await self.connection._process_line("<user> @Claude help")

        plain_callback.assert_called_once()
        coroutine_callback.assert_called_once()

    async def test_process_line_exception(self):
        """_process_line does not propagate an exception raised by the parser."""
        failing_parser = patch.object(
            self.connection, '_parse_message', side_effect=Exception("Parse failed")
        )
        with failing_parser:
            # Must not raise.
            await self.connection._process_line("<user> test")
class TestSSHOUTIntegrationExtended(unittest.IsolatedAsyncioTestCase):
    """Extended tests for SSHOUTIntegration, run against a mocked config manager."""

    # Patch targets shared by setUp and the config-validation tests.
    _CONFIG_MANAGER_TARGET = 'claude_agent.sshout.integration.get_config_manager'
    _KEY_EXISTS_TARGET = 'claude_agent.sshout.integration.os.path.exists'

    def setUp(self):
        """Build an SSHOUTIntegration whose configuration comes entirely from mocks."""
        self.mock_agent = Mock()
        self.mock_config_manager = Mock()
        self.mock_config = {
            'server': {
                'hostname': 'test.example.com',
                'port': 22333,
                'username': 'testuser'
            },
            'ssh_key': {
                'private_key_path': '/fake/key/path'
            },
            'message': {
                'context_count': 5,
                'max_reply_length': 0  # 0 = no limit
            },
            'mention_patterns': ['@Claude', 'Claude:']
        }
        self.mock_config_manager.get_sshout_config.return_value = self.mock_config
        self.mock_config_manager.get.side_effect = lambda key, default=None: {
            'sshout.message.context_count': 5,
            'sshout.message.max_reply_length': 0  # 0 = no limit
        }.get(key, default)

        # Pretend the private key file exists while the integration is constructed.
        with patch(self._CONFIG_MANAGER_TARGET, return_value=self.mock_config_manager), \
                patch(self._KEY_EXISTS_TARGET, return_value=True):
            self.integration = SSHOUTIntegration(self.mock_agent)

    # ------------------------------------------------------------------ helpers

    def _assert_config_rejected(self, sshout_config, expected_exception, expected_text,
                                key_exists=True):
        """Assert that constructing SSHOUTIntegration with ``sshout_config`` fails.

        Args:
            sshout_config: Dict returned by the mocked ``get_sshout_config``.
            expected_exception: Exception type the constructor must raise.
            expected_text: Substring that must appear in the exception message.
            key_exists: Value returned by the patched ``os.path.exists``.
        """
        manager = Mock()
        manager.get_sshout_config.return_value = sshout_config
        with patch(self._CONFIG_MANAGER_TARGET, return_value=manager), \
                patch(self._KEY_EXISTS_TARGET, return_value=key_exists):
            with self.assertRaises(expected_exception) as caught:
                SSHOUTIntegration(Mock())
        self.assertIn(expected_text, str(caught.exception))

    def _attach_connection(self, context_messages, send_result):
        """Install and return a mock connection on the integration under test."""
        connection = Mock()
        connection.get_context_messages.return_value = context_messages
        connection.send_message = AsyncMock(return_value=send_result)
        self.integration.connection = connection
        return connection

    @staticmethod
    def _mention(content):
        """Return a message from ``testuser`` with the given content, stamped now."""
        return SSHOUTMessage(
            timestamp=datetime.now(),
            username="testuser",
            content=content
        )

    # ------------------------------------------------------- config validation

    def test_validate_config_missing_section(self):
        """A config without the ``ssh_key`` section raises ValueError."""
        self._assert_config_rejected(
            {'server': {'hostname': 'test', 'port': 22, 'username': 'user'}},
            ValueError,
            "SSHOUT配置缺少必需的段落: ssh_key",
        )

    def test_validate_config_missing_server_key(self):
        """A server section without ``port`` raises ValueError."""
        self._assert_config_rejected(
            {
                'server': {'hostname': 'test', 'username': 'user'},  # no port
                'ssh_key': {'private_key_path': '/fake/key'}
            },
            ValueError,
            "SSHOUT服务器配置缺少必需的键: port",
        )

    def test_validate_config_missing_key_path(self):
        """An ``ssh_key`` section without ``private_key_path`` raises ValueError."""
        self._assert_config_rejected(
            {
                'server': {'hostname': 'test', 'port': 22, 'username': 'user'},
                'ssh_key': {}  # no private_key_path
            },
            ValueError,
            "SSHOUT配置缺少SSH私钥路径",
        )

    def test_validate_config_key_file_not_exists(self):
        """A key path that does not exist on disk raises FileNotFoundError."""
        self._assert_config_rejected(
            {
                'server': {'hostname': 'test', 'port': 22, 'username': 'user'},
                'ssh_key': {'private_key_path': '/nonexistent/key'}
            },
            FileNotFoundError,
            "SSH私钥文件不存在",
            key_exists=False,
        )

    # ------------------------------------------------------------- connecting

    @patch('claude_agent.sshout.integration.SSHOUTConnection')
    async def test_connect_to_sshout_success(self, mock_connection_class):
        """A successful connect stores the connection and registers both callbacks."""
        mock_connection = AsyncMock()
        mock_connection.connect.return_value = True
        mock_connection_class.return_value = mock_connection

        result = await self.integration.connect_to_sshout()

        self.assertTrue(result)
        self.assertEqual(self.integration.connection, mock_connection)
        mock_connection.add_message_callback.assert_called_once()
        mock_connection.add_mention_callback.assert_called_once()
        mock_connection.connect.assert_called_once()

    @patch('claude_agent.sshout.integration.SSHOUTConnection')
    async def test_connect_to_sshout_failure(self, mock_connection_class):
        """connect_to_sshout() returns False when the connection's connect() does."""
        mock_connection = AsyncMock()
        mock_connection.connect.return_value = False
        mock_connection_class.return_value = mock_connection

        result = await self.integration.connect_to_sshout()

        self.assertFalse(result)

    @patch('claude_agent.sshout.integration.SSHOUTConnection')
    async def test_connect_to_sshout_exception(self, mock_connection_class):
        """connect_to_sshout() returns False when creating the connection raises."""
        mock_connection_class.side_effect = Exception("Connection creation failed")

        result = await self.integration.connect_to_sshout()

        self.assertFalse(result)

    # --------------------------------------------------------- message handling

    def test_on_message_received(self):
        """Handling an ordinary message does not raise."""
        message = self._mention("test message")
        self.integration._on_message_received(message)

    async def test_on_claude_mentioned_success(self):
        """A mention fetches context, calls the agent and sends a reply."""
        context_messages = [
            SSHOUTMessage(datetime.now(), "user1", "context message 1"),
            SSHOUTMessage(datetime.now(), "user2", "context message 2")
        ]
        mock_connection = self._attach_connection(context_messages, send_result=True)
        self.mock_agent.process_user_input = AsyncMock(return_value="Agent response")

        await self.integration._on_claude_mentioned(self._mention("@Claude help me"))

        mock_connection.get_context_messages.assert_called_once()
        self.mock_agent.process_user_input.assert_called_once()
        mock_connection.send_message.assert_called_once()

    async def test_on_claude_mentioned_no_context(self):
        """With no context messages the prompt omits the context header."""
        self._attach_connection([], send_result=True)
        self.mock_agent.process_user_input = AsyncMock(return_value="Agent response")

        await self.integration._on_claude_mentioned(self._mention("@Claude help"))

        self.mock_agent.process_user_input.assert_called_once()
        # First positional argument is the prompt handed to the agent.
        prompt = self.mock_agent.process_user_input.call_args[0][0]
        self.assertNotIn("聊天室上下文消息:", prompt)

    async def test_on_claude_mentioned_send_failure(self):
        """A failed send (False result) does not raise."""
        mock_connection = self._attach_connection([], send_result=False)
        self.mock_agent.process_user_input = AsyncMock(return_value="Agent response")

        # Must not raise.
        await self.integration._on_claude_mentioned(self._mention("@Claude help"))

        mock_connection.send_message.assert_called_once()

    async def test_on_claude_mentioned_no_connection(self):
        """Without a connection the handler does not raise and never calls the agent."""
        self.integration.connection = None
        self.mock_agent.process_user_input = AsyncMock(return_value="Response")

        # Must not raise; with connection=None the handler stops before the agent.
        await self.integration._on_claude_mentioned(self._mention("@Claude help"))

        self.mock_agent.process_user_input.assert_not_called()

    async def test_on_claude_mentioned_exception(self):
        """An exception raised by the agent is not propagated by the handler."""
        # A connection must be attached: without one the handler stops before
        # reaching the agent (see test_on_claude_mentioned_no_connection) and
        # the failing-agent path would never be exercised.
        self._attach_connection([], send_result=True)
        self.mock_agent.process_user_input = AsyncMock(side_effect=Exception("Agent failed"))

        # Must not raise.
        await self.integration._on_claude_mentioned(self._mention("@Claude help"))

        # Prove the failing agent call was actually reached.
        self.mock_agent.process_user_input.assert_called_once()

    # --------------------------------------------------------- response cleaning

    def test_clean_response_markdown_removal(self):
        """Markdown emphasis and code markers are stripped; newlines are kept."""
        test_cases = [
            ("**bold** text", "bold text"),
            ("*italic* text", "italic text"),
            ("`code` text", "code text"),
            ("**bold** and *italic* together", "bold and italic together"),
            ("Multiple\n\nlines\nwith spaces", "Multiple\n\nlines\nwith spaces")  # newlines kept
        ]
        for input_text, expected in test_cases:
            with self.subTest(input_text=input_text):
                result = self.integration._clean_response_for_sshout(input_text)
                self.assertEqual(result, expected)

    def test_clean_response_length_limiting(self):
        """Replies are truncated with '...' only when a limit is configured."""
        long_text = "Very long text. " * 20

        # max_reply_length=0 in the mocked config: no truncation expected.
        result = self.integration._clean_response_for_sshout(long_text)
        # Trailing whitespace is stripped by the cleaner, so allow a small difference.
        self.assertGreaterEqual(len(result), len(long_text) - 5)
        self.assertFalse(result.endswith("..."))

        # Scope the fake limit to this check so the mock's normal ``get`` is restored.
        with patch.object(self.integration.config_manager, 'get', return_value=50):
            result_limited = self.integration._clean_response_for_sshout(long_text)

        self.assertLessEqual(len(result_limited), 53)  # 50 + "..."
        self.assertTrue(result_limited.endswith("..."))

    def test_clean_response_whitespace_normalization(self):
        """Leading/trailing whitespace is stripped; inner spaces and newlines stay."""
        test_text = " Multiple spaces and\n\nnewlines "
        result = self.integration._clean_response_for_sshout(test_text)
        self.assertEqual(result, "Multiple spaces and\n\nnewlines")
class TestSSHOUTAsyncMethods(unittest.IsolatedAsyncioTestCase):
    """Tests for the asynchronous methods of SSHOUTIntegration."""

    def setUp(self):
        """Prepare a mock configuration manager with a minimal valid config."""
        self.mock_config = {
            'server': {
                'hostname': 'test.example.com',
                'port': 22333,
                'username': 'testuser'
            },
            'ssh_key': {
                'private_key_path': '/fake/key/path'
            }
        }
        manager = Mock()
        manager.get_sshout_config.return_value = self.mock_config
        manager.get.return_value = 5
        self.mock_config_manager = manager

    def _config_manager_patch(self):
        """Return a patcher that makes get_config_manager yield the mock manager."""
        return patch(
            'claude_agent.sshout.integration.get_config_manager',
            return_value=self.mock_config_manager,
        )

    @staticmethod
    def _key_file_patch():
        """Return a patcher that makes os.path.exists report True."""
        return patch('claude_agent.sshout.integration.os.path.exists', return_value=True)

    async def test_send_message_not_connected(self):
        """send_message() returns False when no connection has been set."""
        with self._config_manager_patch(), self._key_file_patch():
            subject = SSHOUTIntegration(Mock())
            outcome = await subject.send_message("test message")
            self.assertFalse(outcome)

    async def test_send_message_connected(self):
        """send_message() forwards the text to the connection and returns its result."""
        with self._config_manager_patch(), self._key_file_patch():
            subject = SSHOUTIntegration(Mock())

            # Attach a connected mock connection.
            link = Mock()
            link.connected = True
            link.send_message = AsyncMock(return_value=True)
            subject.connection = link

            outcome = await subject.send_message("test message")

            self.assertTrue(outcome)
            link.send_message.assert_called_once_with("test message")

    async def test_disconnect_from_sshout(self):
        """disconnect_from_sshout() disconnects and drops the connection."""
        with self._config_manager_patch(), self._key_file_patch():
            subject = SSHOUTIntegration(Mock())
            link = AsyncMock()
            link.disconnect = AsyncMock(return_value=None)
            subject.connection = link

            await subject.disconnect_from_sshout()

            link.disconnect.assert_called_once()
            self.assertIsNone(subject.connection)

    async def test_disconnect_from_sshout_no_connection(self):
        """disconnect_from_sshout() is a safe no-op without a connection."""
        with self._config_manager_patch(), self._key_file_patch():
            subject = SSHOUTIntegration(Mock())

            # Must not raise.
            await subject.disconnect_from_sshout()

            self.assertIsNone(subject.connection)
if __name__ == '__main__':
    # Executed as a script: run every test in this module with verbose output.
    unittest.main(verbosity=2)