# blob: c331ad3628cf53f473e71bbaba6f2484daf6b5d7 [file] [log] [blame] [raw]
"""
MCP工具集成模块的全面单元测试
重点测试异常场景、错误处理和边界条件
"""
import asyncio
import pytest
import json
import logging
from unittest.mock import Mock, AsyncMock, patch, MagicMock, call
from typing import Dict, List, Optional, Any
from src.claude_agent.mcp.integration import MCPToolManager, MCPToolIntegration
class TestMCPToolManager:
    """Tests for the core behaviour of MCPToolManager."""

    @pytest.fixture
    def tool_manager(self):
        """Provide a fresh MCPToolManager for every test."""
        return MCPToolManager()

    def test_initialization(self, tool_manager):
        """A new manager has no tools or sessions and owns a logging.Logger."""
        assert tool_manager.tools == {}
        assert tool_manager.sessions == {}
        assert isinstance(tool_manager.logger, logging.Logger)

    @pytest.mark.asyncio
    async def test_load_mcp_server_success(self, tool_manager):
        """A successful load returns True and registers the session and its tools."""
        server_name = "test_server"
        command = "python"
        args = ["-m", "test_mcp_server"]
        # Fake MCP session that advertises a single tool.
        mock_session = AsyncMock()
        mock_tool = Mock()
        mock_tool.name = "test_tool"
        mock_tool.description = "测试工具"
        mock_result = Mock()
        mock_result.tools = [mock_tool]
        mock_session.initialize.return_value = None
        mock_session.list_tools.return_value = mock_result
        with patch('src.claude_agent.mcp.integration.StdioServerParameters'):
            with patch('src.claude_agent.mcp.integration.stdio_client') as mock_client:
                with patch('src.claude_agent.mcp.integration.ClientSession') as mock_session_cls:
                    # Make both patched factories usable as async context managers.
                    mock_client.return_value.__aenter__ = AsyncMock(return_value=(Mock(), Mock()))
                    mock_client.return_value.__aexit__ = AsyncMock(return_value=None)
                    mock_session_cls.return_value.__aenter__ = AsyncMock(return_value=mock_session)
                    mock_session_cls.return_value.__aexit__ = AsyncMock(return_value=None)
                    with patch.object(tool_manager.logger, 'info') as mock_log:
                        result = await tool_manager.load_mcp_server(server_name, command, args)
        assert result is True
        assert server_name in tool_manager.sessions
        # Tools are registered under the "<server>.<tool>" key.
        assert f"{server_name}.test_tool" in tool_manager.tools
        mock_log.assert_called()

    @pytest.mark.asyncio
    async def test_load_mcp_server_connection_failure(self, tool_manager):
        """A failing transport yields False, registers nothing and logs the error."""
        server_name = "failing_server"
        command = "nonexistent_command"
        error_message = "连接失败"
        with patch('src.claude_agent.mcp.integration.stdio_client', side_effect=Exception(error_message)):
            with patch.object(tool_manager.logger, 'error') as mock_log:
                result = await tool_manager.load_mcp_server(server_name, command)
        assert result is False
        assert server_name not in tool_manager.sessions
        mock_log.assert_called_once()
        assert error_message in str(mock_log.call_args)

    @pytest.mark.asyncio
    async def test_load_mcp_server_initialization_failure(self, tool_manager):
        """An exception from session.initialize() yields False and one error log."""
        server_name = "init_failing_server"
        command = "python"
        mock_session = AsyncMock()
        mock_session.initialize.side_effect = Exception("初始化失败")
        with patch('src.claude_agent.mcp.integration.stdio_client') as mock_client:
            with patch('src.claude_agent.mcp.integration.ClientSession') as mock_session_cls:
                mock_client.return_value.__aenter__ = AsyncMock(return_value=(Mock(), Mock()))
                mock_client.return_value.__aexit__ = AsyncMock(return_value=None)
                mock_session_cls.return_value.__aenter__ = AsyncMock(return_value=mock_session)
                mock_session_cls.return_value.__aexit__ = AsyncMock(return_value=None)
                with patch.object(tool_manager.logger, 'error') as mock_log:
                    result = await tool_manager.load_mcp_server(server_name, command)
        assert result is False
        mock_log.assert_called_once()

    @pytest.mark.asyncio
    async def test_load_mcp_server_with_default_args(self, tool_manager):
        """Omitting ``args`` passes an empty argument list to StdioServerParameters."""
        server_name = "default_args_server"
        command = "python"
        mock_session = AsyncMock()
        mock_session.initialize.return_value = None
        mock_session.list_tools.return_value = Mock(tools=[])
        with patch('src.claude_agent.mcp.integration.StdioServerParameters') as mock_params:
            with patch('src.claude_agent.mcp.integration.stdio_client') as mock_client:
                with patch('src.claude_agent.mcp.integration.ClientSession') as mock_session_cls:
                    mock_client.return_value.__aenter__ = AsyncMock(return_value=(Mock(), Mock()))
                    mock_client.return_value.__aexit__ = AsyncMock(return_value=None)
                    mock_session_cls.return_value.__aenter__ = AsyncMock(return_value=mock_session)
                    mock_session_cls.return_value.__aexit__ = AsyncMock(return_value=None)
                    result = await tool_manager.load_mcp_server(server_name, command)
        # An empty tool list is still a successful load.
        assert result is True
        mock_params.assert_called_once()
        assert mock_params.call_args.kwargs['args'] == []

    @pytest.mark.asyncio
    async def test_call_tool_success(self, tool_manager):
        """call_tool forwards to the owning session and returns its result."""
        tool_name = "test_server.test_tool"
        arguments = {"param1": "value1", "param2": 42}
        expected_result = {"output": "工具执行结果"}
        mock_session = AsyncMock()
        mock_session.call_tool.return_value = expected_result
        mock_tool_info = Mock()
        mock_tool_info.name = "test_tool"
        tool_manager.tools[tool_name] = {
            'server': 'test_server',
            'tool_info': mock_tool_info,
            'session': mock_session,
        }
        with patch.object(tool_manager.logger, 'info') as mock_log:
            result = await tool_manager.call_tool(tool_name, arguments)
        assert result == expected_result
        # The session is addressed by the bare tool name, not the namespaced key,
        # and the coroutine must actually be awaited.
        mock_session.call_tool.assert_awaited_once_with("test_tool", arguments)
        mock_log.assert_called_once()

    @pytest.mark.asyncio
    async def test_call_tool_nonexistent_tool(self, tool_manager):
        """Calling an unknown tool returns None and logs a specific error."""
        tool_name = "nonexistent_tool"
        arguments = {"param": "value"}
        with patch.object(tool_manager.logger, 'error') as mock_log:
            result = await tool_manager.call_tool(tool_name, arguments)
        assert result is None
        mock_log.assert_called_once_with(f"工具不存在: {tool_name}")

    @pytest.mark.asyncio
    async def test_call_tool_execution_failure(self, tool_manager):
        """An exception from the session is logged and turned into None."""
        tool_name = "test_server.failing_tool"
        arguments = {"param": "value"}
        error_message = "工具执行异常"
        mock_session = AsyncMock()
        mock_session.call_tool.side_effect = Exception(error_message)
        mock_tool_info = Mock()
        mock_tool_info.name = "failing_tool"
        tool_manager.tools[tool_name] = {
            'server': 'test_server',
            'tool_info': mock_tool_info,
            'session': mock_session,
        }
        with patch.object(tool_manager.logger, 'error') as mock_log:
            result = await tool_manager.call_tool(tool_name, arguments)
        assert result is None
        mock_log.assert_called_once()
        assert error_message in str(mock_log.call_args)

    @pytest.mark.asyncio
    async def test_call_tool_with_invalid_arguments(self, tool_manager):
        """A TypeError raised for unserialisable arguments is handled like any failure."""
        tool_name = "test_server.test_tool"
        invalid_arguments = {"invalid": object()}  # object() is not serialisable
        mock_session = AsyncMock()
        mock_session.call_tool.side_effect = TypeError("参数序列化失败")
        mock_tool_info = Mock()
        mock_tool_info.name = "test_tool"
        tool_manager.tools[tool_name] = {
            'server': 'test_server',
            'tool_info': mock_tool_info,
            'session': mock_session,
        }
        with patch.object(tool_manager.logger, 'error') as mock_log:
            result = await tool_manager.call_tool(tool_name, invalid_arguments)
        assert result is None
        mock_log.assert_called_once()

    def test_get_available_tools_empty(self, tool_manager):
        """With no registered tools the list is empty."""
        assert tool_manager.get_available_tools() == []

    def test_get_available_tools_with_tools(self, tool_manager):
        """Every registered tool key is reported exactly once."""
        tool_manager.tools = {
            "server1.tool1": {"mock": "data1"},
            "server2.tool2": {"mock": "data2"},
            "server1.tool3": {"mock": "data3"},
        }
        result = tool_manager.get_available_tools()
        expected = ["server1.tool1", "server2.tool2", "server1.tool3"]
        # Order-independent, but duplicates or omissions would fail.
        assert sorted(result) == sorted(expected)

    def test_get_tool_info_existing_tool(self, tool_manager):
        """get_tool_info exposes name, description and inputSchema of the tool."""
        tool_name = "server1.test_tool"
        mock_tool_def = Mock()
        mock_tool_def.name = "test_tool"
        mock_tool_def.description = "测试工具描述"
        mock_tool_def.inputSchema = {"type": "object", "properties": {}}
        tool_manager.tools[tool_name] = {'tool_info': mock_tool_def}
        result = tool_manager.get_tool_info(tool_name)
        assert result['name'] == "test_tool"
        assert result['description'] == "测试工具描述"
        assert result['inputSchema'] == {"type": "object", "properties": {}}

    def test_get_tool_info_nonexistent_tool(self, tool_manager):
        """Unknown tools yield None rather than raising."""
        assert tool_manager.get_tool_info("nonexistent_tool") is None

    @pytest.mark.asyncio
    async def test_shutdown_with_sessions(self, tool_manager):
        """shutdown clears sessions and tools, logging once per session."""
        tool_manager.sessions = {
            "server1": Mock(),
            "server2": Mock(),
        }
        tool_manager.tools = {
            "server1.tool1": {"mock": "data"},
            "server2.tool2": {"mock": "data"},
        }
        with patch.object(tool_manager.logger, 'info') as mock_info_log:
            with patch.object(tool_manager.logger, 'error') as mock_error_log:
                await tool_manager.shutdown()
        assert tool_manager.sessions == {}
        assert tool_manager.tools == {}
        # One info log per session and no errors on the happy path.
        assert mock_info_log.call_count == 2
        mock_error_log.assert_not_called()

    @pytest.mark.asyncio
    async def test_shutdown_with_session_close_error(self, tool_manager):
        """A failure while closing a session is logged and cleanup still happens."""
        tool_manager.sessions = {"error_server": Mock()}
        tool_manager.tools = {"error_server.tool": {"mock": "data"}}
        # Inject a failure on the per-session close path.
        # NOTE(review): assumes shutdown() calls logger.info inside its per-session
        # try block (the same mechanism this file's shutdown-exception test uses) — confirm.
        with patch.object(tool_manager.logger, 'info', side_effect=Exception("关闭失败")):
            with patch.object(tool_manager.logger, 'error') as mock_error_log:
                await tool_manager.shutdown()
        # The error is reported rather than propagated...
        mock_error_log.assert_called_once()
        # ...and state is still cleared.
        assert tool_manager.sessions == {}
        assert tool_manager.tools == {}
class TestMCPToolIntegration:
    """Tests for the bridge that exposes MCP tools to an Agent."""

    @pytest.fixture
    def mock_agent(self):
        """A stand-in Agent whose add_mcp_tool records registrations."""
        agent = Mock()
        agent.add_mcp_tool = Mock()
        return agent

    @pytest.fixture
    def integration(self, mock_agent):
        """An MCPToolIntegration wired to the mock agent."""
        return MCPToolIntegration(mock_agent)

    def test_initialization(self, integration, mock_agent):
        """The integration keeps the exact agent and builds its own manager/logger."""
        assert integration.agent is mock_agent
        assert isinstance(integration.tool_manager, MCPToolManager)
        assert isinstance(integration.logger, logging.Logger)

    @pytest.mark.asyncio
    async def test_setup_default_tools_placeholder(self, integration):
        """setup_default_tools (currently a placeholder) completes without raising."""
        await integration.setup_default_tools()

    @pytest.mark.asyncio
    async def test_enhance_agent_with_tools_no_tools(self, integration):
        """With no tools, only the generic "mcp_caller" hook is registered."""
        with patch.object(integration.tool_manager, 'get_available_tools', return_value=[]):
            with patch.object(integration.logger, 'info'):
                await integration.enhance_agent_with_tools()
        integration.agent.add_mcp_tool.assert_called_once()
        registered_name = integration.agent.add_mcp_tool.call_args[0][0]
        assert registered_name == "mcp_caller"

    @pytest.mark.asyncio
    async def test_enhance_agent_with_tools_with_tools(self, integration):
        """With tools available, one registration happens and the count is logged."""
        available_tools = ["server1.tool1", "server2.tool2"]
        tool_info = {
            'name': 'tool1',
            'description': '工具1描述',
            'inputSchema': {},
        }
        with patch.object(integration.tool_manager, 'get_available_tools', return_value=available_tools):
            with patch.object(integration.tool_manager, 'get_tool_info', return_value=tool_info):
                with patch.object(integration.logger, 'info') as mock_log:
                    await integration.enhance_agent_with_tools()
        integration.agent.add_mcp_tool.assert_called_once()
        mock_log.assert_called_once()
        assert "2个MCP工具" in str(mock_log.call_args)

    @pytest.mark.asyncio
    async def test_enhance_agent_with_tools_partial_tool_info(self, integration):
        """A tool whose info lookup returns None does not break registration."""
        available_tools = ["server1.tool1", "server2.broken_tool"]

        def fake_get_tool_info(tool_name):
            # Only the first tool has metadata; the other simulates a failed lookup.
            if tool_name == "server1.tool1":
                return {'name': 'tool1', 'description': '正常工具', 'inputSchema': {}}
            return None

        with patch.object(integration.tool_manager, 'get_available_tools', return_value=available_tools):
            with patch.object(integration.tool_manager, 'get_tool_info', side_effect=fake_get_tool_info):
                with patch.object(integration.logger, 'info'):
                    await integration.enhance_agent_with_tools()
        integration.agent.add_mcp_tool.assert_called_once()

    @pytest.mark.asyncio
    async def test_mcp_caller_function(self, integration):
        """The registered caller awaits tool_manager.call_tool and returns its result."""
        expected_result = {"success": True, "data": "结果"}
        # patch.object on an async method yields an AsyncMock.
        with patch.object(integration.tool_manager, 'call_tool', return_value=expected_result) as mock_call:
            await integration.enhance_agent_with_tools()
            mcp_caller_func = integration.agent.add_mcp_tool.call_args[0][1]
            result = await mcp_caller_func("test_tool", {"param": "value"})
        assert result == expected_result
        # Awaited, not merely called: an un-awaited coroutine would pass assert_called_*.
        mock_call.assert_awaited_once_with("test_tool", {"param": "value"})

    @pytest.mark.asyncio
    async def test_process_tool_calls_no_calls(self, integration):
        """A response without CALL_TOOL markers is returned unchanged."""
        agent_response = "这是一个普通的回复,没有工具调用。"
        result = await integration.process_tool_calls_in_response(agent_response)
        assert result == agent_response

    @pytest.mark.asyncio
    async def test_process_tool_calls_with_calls(self, integration):
        """A single CALL_TOOL line is executed and its result appended."""
        agent_response = """这是回复内容。
CALL_TOOL:{"name": "test_tool", "arguments": {"param": "value"}}
更多内容。"""
        tool_result = {"output": "工具执行结果"}
        with patch.object(integration.tool_manager, 'call_tool', return_value=tool_result):
            result = await integration.process_tool_calls_in_response(agent_response)
        assert "工具调用结果:" in result
        assert "test_tool" in result
        assert str(tool_result) in result

    @pytest.mark.asyncio
    async def test_process_tool_calls_with_multiple_calls(self, integration):
        """Each CALL_TOOL line triggers its own call and result entry."""
        agent_response = """开始处理。
CALL_TOOL:{"name": "tool1", "arguments": {"param1": "value1"}}
中间内容。
CALL_TOOL:{"name": "tool2", "arguments": {"param2": "value2"}}
结束。"""
        tool_results = [
            {"output": "工具1结果"},
            {"output": "工具2结果"},
        ]
        with patch.object(integration.tool_manager, 'call_tool', side_effect=tool_results) as mock_call:
            result = await integration.process_tool_calls_in_response(agent_response)
        assert mock_call.await_count == 2
        assert "工具调用结果:" in result
        assert "tool1" in result
        assert "tool2" in result
        assert "工具1结果" in result
        assert "工具2结果" in result

    @pytest.mark.asyncio
    async def test_process_tool_calls_with_failed_calls(self, integration):
        """A call that returns None is still reported, as "<name>: None"."""
        agent_response = "CALL_TOOL:{\"name\": \"failing_tool\", \"arguments\": {}}"
        # call_tool signals failure by returning None.
        with patch.object(integration.tool_manager, 'call_tool', return_value=None):
            result = await integration.process_tool_calls_in_response(agent_response)
        assert "工具调用结果:" in result
        assert "failing_tool: None" in result

    def test_extract_tool_calls_valid_json(self, integration):
        """Every well-formed CALL_TOOL payload is parsed, in order."""
        response = """普通文本
CALL_TOOL:{"name": "tool1", "arguments": {"param": "value"}}
更多文本
CALL_TOOL:{"name": "tool2", "arguments": {"count": 5}}"""
        result = integration._extract_tool_calls(response)
        assert result == [
            {"name": "tool1", "arguments": {"param": "value"}},
            {"name": "tool2", "arguments": {"count": 5}},
        ]

    def test_extract_tool_calls_invalid_json(self, integration):
        """Malformed payloads are skipped with a warning; valid ones survive."""
        response = """普通文本
CALL_TOOL:{"name": "tool1", "invalid": json}
CALL_TOOL:{"name": "tool2", "arguments": {"valid": true}}"""
        with patch.object(integration.logger, 'warning') as mock_log:
            result = integration._extract_tool_calls(response)
        assert result == [{"name": "tool2", "arguments": {"valid": True}}]
        mock_log.assert_called_once()

    def test_extract_tool_calls_no_calls(self, integration):
        """Text without the CALL_TOOL marker yields an empty list."""
        response = "这是一个普通响应,没有CALL_TOOL标记。"
        assert integration._extract_tool_calls(response) == []

    def test_extract_tool_calls_malformed_lines(self, integration):
        """Empty CALL_TOOL payloads each produce one warning and are dropped."""
        response = """CALL_TOOL:
CALL_TOOL:
CALL_TOOL:{"valid": "call"}"""
        with patch.object(integration.logger, 'warning') as mock_log:
            result = integration._extract_tool_calls(response)
        assert result == [{"valid": "call"}]
        # Two empty payloads -> two warnings.
        assert mock_log.call_count == 2

    @pytest.mark.asyncio
    async def test_shutdown(self, integration):
        """shutdown delegates to, and awaits, the tool manager's shutdown."""
        with patch.object(integration.tool_manager, 'shutdown') as mock_shutdown:
            await integration.shutdown()
        # assert_awaited_once also catches a forgotten ``await``.
        mock_shutdown.assert_awaited_once()
class TestMCPEdgeCasesAndErrorScenarios:
    """Edge cases and failure scenarios for the MCP integration."""

    @pytest.fixture
    def tool_manager(self):
        """A fresh MCPToolManager."""
        return MCPToolManager()

    @pytest.fixture
    def mock_agent(self):
        """A permissive stand-in Agent."""
        return Mock()

    @pytest.fixture
    def integration(self, mock_agent):
        """An MCPToolIntegration bound to the mock agent."""
        return MCPToolIntegration(mock_agent)

    @pytest.mark.asyncio
    async def test_concurrent_tool_loading(self, tool_manager):
        """Several servers loaded concurrently all succeed and are all registered."""
        servers = [
            ("server1", "python", ["-m", "server1"]),
            ("server2", "python", ["-m", "server2"]),
            ("server3", "python", ["-m", "server3"]),
        ]
        # Mock only the transport layer, so the real load_mcp_server runs.
        # (Patching load_mcp_server itself would only test the mock.)
        mock_session = AsyncMock()
        mock_session.initialize.return_value = None
        mock_session.list_tools.return_value = Mock(tools=[])
        with patch('src.claude_agent.mcp.integration.StdioServerParameters'):
            with patch('src.claude_agent.mcp.integration.stdio_client') as mock_client:
                with patch('src.claude_agent.mcp.integration.ClientSession') as mock_session_cls:
                    mock_client.return_value.__aenter__ = AsyncMock(return_value=(Mock(), Mock()))
                    mock_client.return_value.__aexit__ = AsyncMock(return_value=None)
                    mock_session_cls.return_value.__aenter__ = AsyncMock(return_value=mock_session)
                    mock_session_cls.return_value.__aexit__ = AsyncMock(return_value=None)
                    results = await asyncio.gather(
                        *(tool_manager.load_mcp_server(name, cmd, args) for name, cmd, args in servers)
                    )
        assert results == [True] * len(servers)
        for name, _, _ in servers:
            assert name in tool_manager.sessions

    @pytest.mark.asyncio
    async def test_concurrent_tool_calls(self, tool_manager):
        """Concurrent calls to different tools each get their own session's result."""
        tools = {}
        for i in range(3):
            mock_session = AsyncMock()
            mock_session.call_tool.return_value = {"result": f"工具{i}结果"}
            mock_tool_info = Mock()
            mock_tool_info.name = f"tool{i}"
            tools[f"server1.tool{i}"] = {
                'server': 'server1',
                'tool_info': mock_tool_info,
                'session': mock_session,
            }
        tool_manager.tools = tools
        results = await asyncio.gather(
            *(tool_manager.call_tool(f"server1.tool{i}", {"param": f"value{i}"}) for i in range(3))
        )
        # gather preserves order, so every result must come from its own tool.
        assert results == [{"result": f"工具{i}结果"} for i in range(3)]

    @pytest.mark.asyncio
    async def test_tool_call_timeout_simulation(self, tool_manager):
        """A timeout raised by the session is logged and returned as None."""
        tool_name = "server1.slow_tool"
        mock_session = AsyncMock()
        mock_session.call_tool.side_effect = asyncio.TimeoutError("工具调用超时")
        mock_tool_info = Mock()
        mock_tool_info.name = "slow_tool"
        tool_manager.tools[tool_name] = {
            'server': 'server1',
            'tool_info': mock_tool_info,
            'session': mock_session,
        }
        with patch.object(tool_manager.logger, 'error') as mock_log:
            result = await tool_manager.call_tool(tool_name, {"param": "value"})
        assert result is None
        mock_log.assert_called_once()

    @pytest.mark.asyncio
    async def test_memory_intensive_tool_operations(self, tool_manager):
        """Large argument and result payloads pass through unchanged."""
        tool_name = "server1.memory_tool"
        large_arguments = {"data": ["item"] * 10000}
        large_result = {"output": ["result"] * 10000}
        mock_session = AsyncMock()
        mock_session.call_tool.return_value = large_result
        mock_tool_info = Mock()
        mock_tool_info.name = "memory_tool"
        tool_manager.tools[tool_name] = {
            'server': 'server1',
            'tool_info': mock_tool_info,
            'session': mock_session,
        }
        result = await tool_manager.call_tool(tool_name, large_arguments)
        assert result == large_result
        mock_session.call_tool.assert_awaited_once_with("memory_tool", large_arguments)

    def test_tool_manager_state_consistency(self, tool_manager):
        """Every listed tool's "<server>." prefix names a registered session."""
        tool_manager.tools = {
            "server1.tool1": {"mock": "data1"},
            "server2.tool2": {"mock": "data2"},
        }
        tool_manager.sessions = {
            "server1": Mock(),
            "server2": Mock(),
        }
        available_tools = tool_manager.get_available_tools()
        assert len(available_tools) == 2
        for tool_name in available_tools:
            server_name = tool_name.split('.')[0]
            assert server_name in tool_manager.sessions

    @pytest.mark.asyncio
    async def test_integration_with_broken_agent(self):
        """If the agent's add_mcp_tool raises, only that error may surface."""
        broken_agent = Mock()
        broken_agent.add_mcp_tool.side_effect = Exception("Agent异常")
        integration = MCPToolIntegration(broken_agent)
        with patch.object(integration.tool_manager, 'get_available_tools', return_value=["tool1"]):
            # Whether the failure is swallowed or propagated is not pinned here.
            # If something escapes, though, it must be the agent's own error
            # and not an unrelated crash.
            try:
                await integration.enhance_agent_with_tools()
            except Exception as exc:
                assert "Agent异常" in str(exc)
        # Guard against a vacuous pass: the failing hook must actually be reached.
        broken_agent.add_mcp_tool.assert_called_once()

    @pytest.mark.asyncio
    async def test_malicious_tool_call_protection(self, integration):
        """Hostile-looking tool names/arguments are dispatched without crashing."""
        malicious_response = """正常内容
CALL_TOOL:{"name": "../../../etc/passwd", "arguments": {"__proto__": {"isAdmin": true}}}
CALL_TOOL:{"name": "'; DROP TABLE users; --", "arguments": {}}"""
        # The names resolve to no registered tool, so call_tool reports failure (None).
        with patch.object(integration.tool_manager, 'call_tool', return_value=None) as mock_call:
            with patch.object(integration.logger, 'warning'):
                result = await integration.process_tool_calls_in_response(malicious_response)
        # Both payloads are valid JSON, so both reach call_tool (which rejects unknown names).
        assert mock_call.await_count == 2
        assert "工具调用结果:" in result

    def test_extreme_json_parsing(self, integration):
        """Pathological CALL_TOOL payloads never raise out of _extract_tool_calls."""
        extreme_cases = [
            "CALL_TOOL:" + "{" * 1000 + "}" * 1000,  # 1000 nested bare braces (invalid JSON)
            "CALL_TOOL:" + '{"a":' + '"b"' * 1000 + '}',  # 1000 adjacent strings (invalid JSON)
            "CALL_TOOL:null",  # JSON null
            "CALL_TOOL:[]",  # array instead of object
            "CALL_TOOL:123",  # number instead of object
        ]
        for case in extreme_cases:
            with patch.object(integration.logger, 'warning'):
                result = integration._extract_tool_calls(case)
            assert isinstance(result, list)

    @pytest.mark.asyncio
    async def test_tool_manager_shutdown_with_exception(self):
        """An exception while closing a session is logged and sessions are cleared."""
        # Use the module-level MCPToolManager (imported from src.claude_agent...).
        # Re-importing from a different package path could fail, or could load a
        # second module copy that the patches in this file don't affect.
        tool_manager = MCPToolManager()
        tool_manager.sessions["test_server"] = Mock()
        # Make logger.info raise to drive shutdown()'s exception branch.
        with patch.object(tool_manager.logger, 'info', side_effect=Exception("日志记录失败")):
            with patch.object(tool_manager.logger, 'error') as mock_log_error:
                await tool_manager.shutdown()
        mock_log_error.assert_called_once()
        assert "关闭MCP会话失败" in mock_log_error.call_args[0][0]
        assert len(tool_manager.sessions) == 0