| """ |
| MCP工具集成模块的全面单元测试 |
| 重点测试异常场景、错误处理和边界条件 |
| """ |
| |
| import asyncio |
| import pytest |
| import json |
| import logging |
| from unittest.mock import Mock, AsyncMock, patch, MagicMock, call |
| from typing import Dict, List, Optional, Any |
| |
| from src.claude_agent.mcp.integration import MCPToolManager, MCPToolIntegration |
| |
| |
class TestMCPToolManager:
    """Tests for the core behaviour of ``MCPToolManager``."""

    @pytest.fixture
    def tool_manager(self):
        """Return a fresh tool manager for each test."""
        return MCPToolManager()

    @staticmethod
    def _wire_transport(mock_client, mock_session_cls, session):
        """Make the patched transport and session classes usable with ``async with``.

        Entering ``mock_client(...)`` yields a pair of placeholder streams and
        entering ``mock_session_cls(...)`` yields *session*.
        """
        transport_cm = mock_client.return_value
        transport_cm.__aenter__ = AsyncMock(return_value=(Mock(), Mock()))
        transport_cm.__aexit__ = AsyncMock(return_value=None)

        session_cm = mock_session_cls.return_value
        session_cm.__aenter__ = AsyncMock(return_value=session)
        session_cm.__aexit__ = AsyncMock(return_value=None)

    @staticmethod
    def _register_tool(tool_manager, qualified_name, short_name, session):
        """Insert a fake tool entry for ``test_server`` straight into the registry."""
        tool_info = Mock()
        # ``name`` must be assigned after construction: Mock(name=...) names
        # the mock itself instead of creating a ``name`` attribute.
        tool_info.name = short_name
        tool_manager.tools[qualified_name] = {
            'server': 'test_server',
            'tool_info': tool_info,
            'session': session,
        }

    def test_initialization(self, tool_manager):
        """A new manager has no tools, no sessions, and a standard logger."""
        assert tool_manager.tools == {}
        assert tool_manager.sessions == {}
        assert isinstance(tool_manager.logger, logging.Logger)

    @pytest.mark.asyncio
    async def test_load_mcp_server_success(self, tool_manager):
        """Loading a server registers its session and its namespaced tools."""
        server_name = "test_server"
        command = "python"
        args = ["-m", "test_mcp_server"]

        mock_tool = Mock()
        mock_tool.name = "test_tool"
        mock_tool.description = "测试工具"

        mock_session = AsyncMock()
        mock_session.initialize.return_value = None
        mock_session.list_tools.return_value = Mock(tools=[mock_tool])

        with patch('src.claude_agent.mcp.integration.StdioServerParameters'):
            with patch('src.claude_agent.mcp.integration.stdio_client') as mock_client:
                with patch('src.claude_agent.mcp.integration.ClientSession') as mock_session_cls:
                    self._wire_transport(mock_client, mock_session_cls, mock_session)

                    with patch.object(tool_manager.logger, 'info') as mock_log:
                        result = await tool_manager.load_mcp_server(server_name, command, args)

        assert result is True
        assert server_name in tool_manager.sessions
        assert f"{server_name}.test_tool" in tool_manager.tools
        mock_log.assert_called()

    @pytest.mark.asyncio
    async def test_load_mcp_server_connection_failure(self, tool_manager):
        """A transport failure returns False, stores no session, and logs the error."""
        server_name = "failing_server"
        command = "nonexistent_command"
        error_message = "连接失败"

        with patch('src.claude_agent.mcp.integration.stdio_client', side_effect=Exception(error_message)):
            with patch.object(tool_manager.logger, 'error') as mock_log:
                result = await tool_manager.load_mcp_server(server_name, command)

        assert result is False
        assert server_name not in tool_manager.sessions
        mock_log.assert_called_once()
        assert error_message in str(mock_log.call_args)

    @pytest.mark.asyncio
    async def test_load_mcp_server_initialization_failure(self, tool_manager):
        """A failing ``session.initialize()`` returns False and logs exactly one error."""
        server_name = "init_failing_server"
        command = "python"

        mock_session = AsyncMock()
        mock_session.initialize.side_effect = Exception("初始化失败")

        with patch('src.claude_agent.mcp.integration.stdio_client') as mock_client:
            with patch('src.claude_agent.mcp.integration.ClientSession') as mock_session_cls:
                self._wire_transport(mock_client, mock_session_cls, mock_session)

                with patch.object(tool_manager.logger, 'error') as mock_log:
                    result = await tool_manager.load_mcp_server(server_name, command)

        assert result is False
        mock_log.assert_called_once()

    @pytest.mark.asyncio
    async def test_load_mcp_server_with_default_args(self, tool_manager):
        """Omitting ``args`` passes an empty list to ``StdioServerParameters``."""
        server_name = "default_args_server"
        command = "python"

        mock_session = AsyncMock()
        mock_session.initialize.return_value = None
        mock_session.list_tools.return_value = Mock(tools=[])

        with patch('src.claude_agent.mcp.integration.StdioServerParameters') as mock_params:
            with patch('src.claude_agent.mcp.integration.stdio_client') as mock_client:
                with patch('src.claude_agent.mcp.integration.ClientSession') as mock_session_cls:
                    self._wire_transport(mock_client, mock_session_cls, mock_session)

                    await tool_manager.load_mcp_server(server_name, command)

        # The default argument list must reach the parameters object as [].
        mock_params.assert_called_once()
        _, params_kwargs = mock_params.call_args
        assert params_kwargs['args'] == []

    @pytest.mark.asyncio
    async def test_call_tool_success(self, tool_manager):
        """A registered tool is invoked by its short name and its result returned."""
        tool_name = "test_server.test_tool"
        arguments = {"param1": "value1", "param2": 42}
        expected_result = {"output": "工具执行结果"}

        mock_session = AsyncMock()
        mock_session.call_tool.return_value = expected_result
        self._register_tool(tool_manager, tool_name, "test_tool", mock_session)

        with patch.object(tool_manager.logger, 'info') as mock_log:
            result = await tool_manager.call_tool(tool_name, arguments)

        assert result == expected_result
        mock_session.call_tool.assert_called_once_with("test_tool", arguments)
        mock_log.assert_called_once()

    @pytest.mark.asyncio
    async def test_call_tool_nonexistent_tool(self, tool_manager):
        """Calling an unknown tool returns None and logs a specific error message."""
        tool_name = "nonexistent_tool"
        arguments = {"param": "value"}

        with patch.object(tool_manager.logger, 'error') as mock_log:
            result = await tool_manager.call_tool(tool_name, arguments)

        assert result is None
        mock_log.assert_called_once_with(f"工具不存在: {tool_name}")

    @pytest.mark.asyncio
    async def test_call_tool_execution_failure(self, tool_manager):
        """An exception raised by the session is logged and turned into None."""
        tool_name = "test_server.failing_tool"
        arguments = {"param": "value"}
        error_message = "工具执行异常"

        mock_session = AsyncMock()
        mock_session.call_tool.side_effect = Exception(error_message)
        self._register_tool(tool_manager, tool_name, "failing_tool", mock_session)

        with patch.object(tool_manager.logger, 'error') as mock_log:
            result = await tool_manager.call_tool(tool_name, arguments)

        assert result is None
        mock_log.assert_called_once()
        assert error_message in str(mock_log.call_args)

    @pytest.mark.asyncio
    async def test_call_tool_with_invalid_arguments(self, tool_manager):
        """A TypeError from the session (e.g. unserialisable arguments) yields None."""
        tool_name = "test_server.test_tool"
        invalid_arguments = {"invalid": object()}  # a bare object() is not serialisable

        mock_session = AsyncMock()
        mock_session.call_tool.side_effect = TypeError("参数序列化失败")
        self._register_tool(tool_manager, tool_name, "test_tool", mock_session)

        with patch.object(tool_manager.logger, 'error') as mock_log:
            result = await tool_manager.call_tool(tool_name, invalid_arguments)

        assert result is None
        mock_log.assert_called_once()

    def test_get_available_tools_empty(self, tool_manager):
        """With nothing registered the available-tools list is empty."""
        assert tool_manager.get_available_tools() == []

    def test_get_available_tools_with_tools(self, tool_manager):
        """Every registered tool name is listed exactly once."""
        tool_manager.tools = {
            "server1.tool1": {"mock": "data1"},
            "server2.tool2": {"mock": "data2"},
            "server1.tool3": {"mock": "data3"},
        }

        result = tool_manager.get_available_tools()
        expected = ["server1.tool1", "server2.tool2", "server1.tool3"]

        # Order-insensitive comparison that, unlike a length + membership
        # check, also rejects duplicated or missing names.
        assert sorted(result) == sorted(expected)

    def test_get_tool_info_existing_tool(self, tool_manager):
        """Tool info exposes the name, description, and input schema of the tool."""
        tool_name = "server1.test_tool"

        mock_tool_def = Mock()
        mock_tool_def.name = "test_tool"
        mock_tool_def.description = "测试工具描述"
        mock_tool_def.inputSchema = {"type": "object", "properties": {}}

        tool_manager.tools[tool_name] = {'tool_info': mock_tool_def}

        result = tool_manager.get_tool_info(tool_name)

        assert result['name'] == "test_tool"
        assert result['description'] == "测试工具描述"
        assert result['inputSchema'] == {"type": "object", "properties": {}}

    def test_get_tool_info_nonexistent_tool(self, tool_manager):
        """Requesting info for an unknown tool returns None."""
        assert tool_manager.get_tool_info("nonexistent_tool") is None

    @pytest.mark.asyncio
    async def test_shutdown_with_sessions(self, tool_manager):
        """Shutdown clears sessions and tools and logs once per session."""
        tool_manager.sessions = {
            "server1": Mock(),
            "server2": Mock(),
        }
        tool_manager.tools = {
            "server1.tool1": {"mock": "data"},
            "server2.tool2": {"mock": "data"},
        }

        with patch.object(tool_manager.logger, 'info') as mock_info_log:
            with patch.object(tool_manager.logger, 'error') as mock_error_log:
                await tool_manager.shutdown()

        assert tool_manager.sessions == {}
        assert tool_manager.tools == {}
        # One info record per closed session, and nothing on the error channel.
        assert mock_info_log.call_count == 2
        mock_error_log.assert_not_called()

    @pytest.mark.asyncio
    async def test_shutdown_with_session_close_error(self, tool_manager):
        """State is still cleared when closing a session fails."""
        tool_manager.sessions = {"error_server": Mock()}

        # Making the info log raise is used here to push shutdown() into its
        # error-handling path; without an induced failure this test would only
        # repeat the happy path.
        with patch.object(tool_manager.logger, 'info', side_effect=Exception("close failed")):
            with patch.object(tool_manager.logger, 'error') as mock_error_log:
                await tool_manager.shutdown()

        mock_error_log.assert_called_once()
        # Cleanup must happen even though an error occurred.
        assert tool_manager.sessions == {}
        assert tool_manager.tools == {}
| |
| |
class TestMCPToolIntegration:
    """Tests for the bridge that exposes MCP tools to an agent."""

    @pytest.fixture
    def mock_agent(self):
        """Build a stand-in agent that exposes ``add_mcp_tool``."""
        fake_agent = Mock()
        fake_agent.add_mcp_tool = Mock()
        return fake_agent

    @pytest.fixture
    def integration(self, mock_agent):
        """Build the integration under test around the stand-in agent."""
        return MCPToolIntegration(mock_agent)

    def test_initialization(self, integration, mock_agent):
        """The integration keeps the agent and owns a manager and a logger."""
        assert integration.agent == mock_agent
        assert isinstance(integration.tool_manager, MCPToolManager)
        assert isinstance(integration.logger, logging.Logger)

    @pytest.mark.asyncio
    async def test_setup_default_tools_placeholder(self, integration):
        """``setup_default_tools`` completes without raising."""
        # Smoke test only: there is nothing to assert beyond normal completion.
        await integration.setup_default_tools()

    @pytest.mark.asyncio
    async def test_enhance_agent_with_tools_no_tools(self, integration):
        """With no tools available the ``mcp_caller`` entry is still registered."""
        manager = integration.tool_manager
        with patch.object(manager, 'get_available_tools', return_value=[]), \
                patch.object(integration.logger, 'info'):
            await integration.enhance_agent_with_tools()

        register = integration.agent.add_mcp_tool
        register.assert_called_once()
        positional = register.call_args[0]
        assert positional[0] == "mcp_caller"

    @pytest.mark.asyncio
    async def test_enhance_agent_with_tools_with_tools(self, integration):
        """With tools available, one registration happens and the count is logged."""
        tool_names = ["server1.tool1", "server2.tool2"]
        described = {
            'name': 'tool1',
            'description': '工具1描述',
            'inputSchema': {},
        }

        manager = integration.tool_manager
        with patch.object(manager, 'get_available_tools', return_value=tool_names), \
                patch.object(manager, 'get_tool_info', return_value=described), \
                patch.object(integration.logger, 'info') as info_log:
            await integration.enhance_agent_with_tools()

        integration.agent.add_mcp_tool.assert_called_once()
        info_log.assert_called_once()
        assert "2个MCP工具" in str(info_log.call_args)

    @pytest.mark.asyncio
    async def test_enhance_agent_with_tools_partial_tool_info(self, integration):
        """A tool whose info lookup returns None does not prevent registration."""
        tool_names = ["server1.tool1", "server2.broken_tool"]
        known_info = {
            "server1.tool1": {'name': 'tool1', 'description': '正常工具', 'inputSchema': {}},
        }

        manager = integration.tool_manager
        # Any name missing from known_info resolves to None, i.e. a failed lookup.
        with patch.object(manager, 'get_available_tools', return_value=tool_names), \
                patch.object(manager, 'get_tool_info',
                             side_effect=lambda tool_name: known_info.get(tool_name)), \
                patch.object(integration.logger, 'info'):
            await integration.enhance_agent_with_tools()

        integration.agent.add_mcp_tool.assert_called_once()

    @pytest.mark.asyncio
    async def test_mcp_caller_function(self, integration):
        """The registered caller forwards to ``tool_manager.call_tool``."""
        payload = {"success": True, "data": "结果"}

        with patch.object(integration.tool_manager, 'call_tool', return_value=payload) as call_tool:
            await integration.enhance_agent_with_tools()

            # The callable is the second positional argument of add_mcp_tool.
            registered = integration.agent.add_mcp_tool.call_args[0]
            caller = registered[1]

            outcome = await caller("test_tool", {"param": "value"})

            assert outcome == payload
            call_tool.assert_called_once_with("test_tool", {"param": "value"})

    @pytest.mark.asyncio
    async def test_process_tool_calls_no_calls(self, integration):
        """A response without tool calls is returned unchanged."""
        plain_reply = "这是一个普通的回复,没有工具调用。"

        processed = await integration.process_tool_calls_in_response(plain_reply)

        assert processed == plain_reply

    @pytest.mark.asyncio
    async def test_process_tool_calls_with_calls(self, integration):
        """A single embedded tool call has its result added to the output."""
        reply = (
            '这是回复内容。\n'
            'CALL_TOOL:{"name": "test_tool", "arguments": {"param": "value"}}\n'
            '更多内容。'
        )
        tool_output = {"output": "工具执行结果"}

        with patch.object(integration.tool_manager, 'call_tool', return_value=tool_output):
            processed = await integration.process_tool_calls_in_response(reply)

        assert "工具调用结果:" in processed
        assert "test_tool" in processed
        assert str(tool_output) in processed

    @pytest.mark.asyncio
    async def test_process_tool_calls_with_multiple_calls(self, integration):
        """Each of several embedded tool calls contributes its result."""
        reply = (
            '开始处理。\n'
            'CALL_TOOL:{"name": "tool1", "arguments": {"param1": "value1"}}\n'
            '中间内容。\n'
            'CALL_TOOL:{"name": "tool2", "arguments": {"param2": "value2"}}\n'
            '结束。'
        )
        outputs_in_order = [
            {"output": "工具1结果"},
            {"output": "工具2结果"},
        ]

        with patch.object(integration.tool_manager, 'call_tool', side_effect=outputs_in_order):
            processed = await integration.process_tool_calls_in_response(reply)

        assert "工具调用结果:" in processed
        for fragment in ("tool1", "tool2", "工具1结果", "工具2结果"):
            assert fragment in processed

    @pytest.mark.asyncio
    async def test_process_tool_calls_with_failed_calls(self, integration):
        """A tool call that yields None is reported as ``<name>: None``."""
        reply = 'CALL_TOOL:{"name": "failing_tool", "arguments": {}}'

        # A None return value stands in for a failed tool call.
        with patch.object(integration.tool_manager, 'call_tool', return_value=None):
            processed = await integration.process_tool_calls_in_response(reply)

        assert "工具调用结果:" in processed
        assert "failing_tool: None" in processed

    def test_extract_tool_calls_valid_json(self, integration):
        """Well-formed CALL_TOOL lines are parsed in order of appearance."""
        text = (
            '普通文本\n'
            'CALL_TOOL:{"name": "tool1", "arguments": {"param": "value"}}\n'
            '更多文本\n'
            'CALL_TOOL:{"name": "tool2", "arguments": {"count": 5}}'
        )

        extracted = integration._extract_tool_calls(text)

        assert len(extracted) == 2
        first, second = extracted
        assert first == {"name": "tool1", "arguments": {"param": "value"}}
        assert second == {"name": "tool2", "arguments": {"count": 5}}

    def test_extract_tool_calls_invalid_json(self, integration):
        """A line with invalid JSON is skipped with one warning; valid lines survive."""
        text = (
            '普通文本\n'
            'CALL_TOOL:{"name": "tool1", "invalid": json}\n'
            'CALL_TOOL:{"name": "tool2", "arguments": {"valid": true}}'
        )

        with patch.object(integration.logger, 'warning') as warn_log:
            extracted = integration._extract_tool_calls(text)

        # Only the syntactically valid call is kept.
        assert len(extracted) == 1
        assert extracted[0] == {"name": "tool2", "arguments": {"valid": True}}
        warn_log.assert_called_once()

    def test_extract_tool_calls_no_calls(self, integration):
        """Text without a ``CALL_TOOL:`` line yields an empty list."""
        text = "这是一个普通响应,没有CALL_TOOL标记。"

        assert integration._extract_tool_calls(text) == []

    def test_extract_tool_calls_malformed_lines(self, integration):
        """CALL_TOOL lines with no payload are skipped, one warning each."""
        text = (
            'CALL_TOOL:\n'
            'CALL_TOOL:\n'
            'CALL_TOOL:{"valid": "call"}'
        )

        with patch.object(integration.logger, 'warning') as warn_log:
            extracted = integration._extract_tool_calls(text)

        assert len(extracted) == 1
        assert extracted[0] == {"valid": "call"}
        # Two payload-less lines, hence two warnings.
        assert warn_log.call_count == 2

    @pytest.mark.asyncio
    async def test_shutdown(self, integration):
        """Shutting down the integration shuts down its tool manager once."""
        with patch.object(integration.tool_manager, 'shutdown') as manager_shutdown:
            await integration.shutdown()
            manager_shutdown.assert_called_once()
| |
| |
class TestMCPEdgeCasesAndErrorScenarios:
    """Edge-case and failure-scenario tests for the MCP integration."""

    @pytest.fixture
    def tool_manager(self):
        """Return a fresh tool manager for each test."""
        return MCPToolManager()

    @pytest.fixture
    def mock_agent(self):
        """Return a bare mock standing in for the agent."""
        return Mock()

    @pytest.fixture
    def integration(self, mock_agent):
        """Return the integration under test, built around the mock agent."""
        return MCPToolIntegration(mock_agent)

    @staticmethod
    def _register_tool(tool_manager, qualified_name, short_name, session):
        """Insert a fake tool entry for ``server1`` straight into the registry."""
        tool_info = Mock()
        # ``name`` must be assigned after construction: Mock(name=...) names
        # the mock itself instead of creating a ``name`` attribute.
        tool_info.name = short_name
        tool_manager.tools[qualified_name] = {
            'server': 'server1',
            'tool_info': tool_info,
            'session': session,
        }

    @pytest.mark.asyncio
    async def test_concurrent_tool_loading(self, tool_manager):
        """Several ``load_mcp_server`` calls can be awaited together.

        ``load_mcp_server`` itself is patched out, so this only checks that
        the concurrent fan-out issues one call per server and gathers the
        results.
        """
        servers = [
            ("server1", "python", ["-m", "server1"]),
            ("server2", "python", ["-m", "server2"]),
            ("server3", "python", ["-m", "server3"]),
        ]

        with patch.object(tool_manager, 'load_mcp_server', return_value=True) as mock_load:
            tasks = [
                tool_manager.load_mcp_server(name, cmd, args)
                for name, cmd, args in servers
            ]
            results = await asyncio.gather(*tasks)

        assert all(results)
        assert mock_load.call_count == 3

    @pytest.mark.asyncio
    async def test_concurrent_tool_calls(self, tool_manager):
        """Concurrent tool calls each return their own session's result, in order."""
        tool_manager.tools = {}
        for i in range(3):
            session = AsyncMock()
            session.call_tool.return_value = {"result": f"工具{i}结果"}
            self._register_tool(tool_manager, f"server1.tool{i}", f"tool{i}", session)

        tasks = [
            tool_manager.call_tool(f"server1.tool{i}", {"param": f"value{i}"})
            for i in range(3)
        ]
        results = await asyncio.gather(*tasks)

        # gather() preserves argument order, so the results can be pinned
        # exactly rather than merely checked for being non-None.
        assert results == [{"result": f"工具{i}结果"} for i in range(3)]

    @pytest.mark.asyncio
    async def test_tool_call_timeout_simulation(self, tool_manager):
        """A timeout raised by the session is logged and turned into None."""
        tool_name = "server1.slow_tool"

        mock_session = AsyncMock()
        mock_session.call_tool.side_effect = asyncio.TimeoutError("工具调用超时")
        self._register_tool(tool_manager, tool_name, "slow_tool", mock_session)

        with patch.object(tool_manager.logger, 'error') as mock_log:
            result = await tool_manager.call_tool(tool_name, {"param": "value"})

        assert result is None
        mock_log.assert_called_once()

    @pytest.mark.asyncio
    async def test_memory_intensive_tool_operations(self, tool_manager):
        """Large argument and result payloads pass through unchanged."""
        tool_name = "server1.memory_tool"

        large_arguments = {"data": ["item"] * 10000}
        large_result = {"output": ["result"] * 10000}

        mock_session = AsyncMock()
        mock_session.call_tool.return_value = large_result
        self._register_tool(tool_manager, tool_name, "memory_tool", mock_session)

        result = await tool_manager.call_tool(tool_name, large_arguments)

        assert result == large_result
        mock_session.call_tool.assert_called_once_with("memory_tool", large_arguments)

    def test_tool_manager_state_consistency(self, tool_manager):
        """Every listed tool's server prefix has a matching session entry."""
        tool_manager.tools = {
            "server1.tool1": {"mock": "data1"},
            "server2.tool2": {"mock": "data2"},
        }
        tool_manager.sessions = {
            "server1": Mock(),
            "server2": Mock(),
        }

        available_tools = tool_manager.get_available_tools()
        assert len(available_tools) == 2

        for tool_name in available_tools:
            # Tool names are "<server>.<tool>"; the prefix identifies the session.
            server_name = tool_name.split('.')[0]
            assert server_name in tool_manager.sessions

    @pytest.mark.asyncio
    async def test_integration_with_broken_agent(self, tool_manager):
        """An agent whose ``add_mcp_tool`` raises is still asked to register."""
        broken_agent = Mock()
        broken_agent.add_mcp_tool.side_effect = Exception("Agent异常")

        integration = MCPToolIntegration(broken_agent)

        with patch.object(integration.tool_manager, 'get_available_tools', return_value=["tool1"]):
            try:
                await integration.enhance_agent_with_tools()
            except Exception:
                # Whether the agent's failure propagates or is handled
                # internally is not pinned by this test; both are accepted.
                pass

        # In either case registration must have been attempted, otherwise the
        # test would pass without exercising the broken agent at all.
        broken_agent.add_mcp_tool.assert_called_once()

    @pytest.mark.asyncio
    async def test_malicious_tool_call_protection(self, integration):
        """Hostile-looking tool names and arguments are processed without crashing."""
        malicious_response = (
            '正常内容\n'
            'CALL_TOOL:{"name": "../../../etc/passwd", "arguments": {"__proto__": {"isAdmin": true}}}\n'
            'CALL_TOOL:{"name": "\'; DROP TABLE users; --", "arguments": {}}'
        )

        with patch.object(integration.tool_manager, 'call_tool', return_value=None) as mock_call:
            with patch.object(integration.logger, 'warning'):
                result = await integration.process_tool_calls_in_response(malicious_response)

        # The calls are expected to be attempted; the patched manager returns
        # None for them. What matters is that processing completes normally.
        mock_call.assert_called()
        assert "工具调用结果:" in result

    def test_extreme_json_parsing(self, integration):
        """Pathological CALL_TOOL payloads always yield a list and never raise."""
        extreme_cases = [
            "CALL_TOOL:" + "{" * 1000 + "}" * 1000,       # extremely deep nesting
            "CALL_TOOL:" + '{"a":' + '"b"' * 1000 + '}',  # very long payload
            "CALL_TOOL:null",                             # null value
            "CALL_TOOL:[]",                               # array instead of object
            "CALL_TOOL:123",                              # number instead of object
        ]

        # One patch covers every case; warnings are silenced, not asserted.
        with patch.object(integration.logger, 'warning'):
            for case in extreme_cases:
                result = integration._extract_tool_calls(case)
                assert isinstance(result, list)

    @pytest.mark.asyncio
    async def test_tool_manager_shutdown_with_exception(self):
        """An exception while closing a session is logged and sessions are cleared."""
        # Uses the module-level MCPToolManager import so that every test in
        # this file exercises the same class object, instead of re-importing
        # it through a second package path.
        tool_manager = MCPToolManager()
        tool_manager.sessions["test_server"] = Mock()

        # Making logger.info raise forces shutdown() into its error branch.
        with patch.object(tool_manager.logger, 'info', side_effect=Exception("日志记录失败")):
            with patch.object(tool_manager.logger, 'error') as mock_log_error:
                await tool_manager.shutdown()

        mock_log_error.assert_called_once()
        assert "关闭MCP会话失败" in mock_log_error.call_args[0][0]

        assert len(tool_manager.sessions) == 0