| """ |
| MCP集成模块的单元测试 |
| """ |
| |
| import pytest |
| from unittest.mock import AsyncMock, Mock, patch |
| from claude_agent.mcp.integration import MCPToolManager, MCPToolIntegration |
| |
| |
class TestMCPToolManager:
    """Tests for the MCPToolManager class."""

    @pytest.fixture
    def tool_manager(self):
        """Provide a fresh MCPToolManager for each test."""
        return MCPToolManager()

    @pytest.mark.asyncio
    async def test_load_mcp_server_success(self, tool_manager):
        """Loading a server registers its tools under '<server>.<tool>' keys."""
        with patch('claude_agent.mcp.integration.stdio_client') as mock_stdio_client, \
                patch('claude_agent.mcp.integration.ClientSession') as mock_session_class:

            # Fake tool metadata reported by the session's list_tools().
            # `name` is assigned after construction because Mock(name=...) is
            # reserved for the mock's own repr name.
            mock_tool = Mock()
            mock_tool.name = "test_tool"
            mock_tool.description = "测试工具"

            mock_session = AsyncMock()
            mock_session.initialize.return_value = None
            mock_session.list_tools.return_value = Mock(tools=[mock_tool])

            # Both patched objects are entered as async context managers;
            # route their __aenter__ results to the fakes above.
            mock_session_class.return_value.__aenter__.return_value = mock_session
            mock_stdio_client.return_value.__aenter__.return_value = (Mock(), Mock())

            result = await tool_manager.load_mcp_server("test_server", "python")

            # `is True` (not `== True`) requires an actual bool, not merely a truthy value.
            assert result is True
            assert "test_server.test_tool" in tool_manager.tools

    @pytest.mark.asyncio
    async def test_load_mcp_server_failure(self, tool_manager):
        """A connection error yields False and leaves no tools registered."""
        with patch('claude_agent.mcp.integration.stdio_client') as mock_stdio_client:
            mock_stdio_client.side_effect = Exception("连接失败")

            result = await tool_manager.load_mcp_server("test_server", "python")

        assert result is False
        assert not tool_manager.tools

    @pytest.mark.asyncio
    async def test_call_tool_success(self, tool_manager):
        """call_tool forwards the bare tool name and arguments to the owning session."""
        # Register a tool by hand instead of going through load_mcp_server.
        mock_session = AsyncMock()
        mock_session.call_tool.return_value = "工具结果"

        mock_tool_info = Mock()
        mock_tool_info.name = "test_tool"

        tool_manager.tools["test_server.test_tool"] = {
            'server': 'test_server',
            'tool_info': mock_tool_info,
            'session': mock_session,
        }

        result = await tool_manager.call_tool("test_server.test_tool", {"param": "value"})

        assert result == "工具结果"
        # assert_awaited_* (not assert_called_*) proves the coroutine was awaited,
        # not merely created.
        mock_session.call_tool.assert_awaited_once_with("test_tool", {"param": "value"})

    @pytest.mark.asyncio
    async def test_call_tool_not_found(self, tool_manager):
        """Calling an unregistered tool returns None."""
        result = await tool_manager.call_tool("nonexistent_tool", {})
        assert result is None

    def test_get_available_tools(self, tool_manager):
        """get_available_tools lists every registered tool key."""
        tool_manager.tools = {
            "server1.tool1": {},
            "server2.tool2": {},
        }

        tools = tool_manager.get_available_tools()
        # Compare as sets: ordering of the returned names is not part of the contract here.
        assert set(tools) == {"server1.tool1", "server2.tool2"}

    def test_get_tool_info(self, tool_manager):
        """get_tool_info exposes name, description and inputSchema of a tool."""
        mock_tool_info = Mock()
        mock_tool_info.name = "test_tool"
        mock_tool_info.description = "测试工具"
        mock_tool_info.inputSchema = {"type": "object"}

        tool_manager.tools["test_tool"] = {
            'tool_info': mock_tool_info,
        }

        info = tool_manager.get_tool_info("test_tool")

        assert info["name"] == "test_tool"
        assert info["description"] == "测试工具"
        assert info["inputSchema"] == {"type": "object"}

    def test_get_tool_info_not_found(self, tool_manager):
        """get_tool_info returns None for an unknown tool."""
        info = tool_manager.get_tool_info("nonexistent")
        assert info is None
| |
| |
class TestMCPToolIntegration:
    """Tests for the MCPToolIntegration class."""

    @pytest.fixture
    def mock_agent(self):
        """Provide a stand-in agent whose add_mcp_tool calls are recorded."""
        mock = Mock()
        mock.add_mcp_tool = Mock()
        return mock

    @pytest.fixture
    def integration(self, mock_agent):
        """Provide an MCPToolIntegration wrapping the mock agent."""
        return MCPToolIntegration(mock_agent)

    @pytest.mark.asyncio
    async def test_setup_default_tools(self, integration):
        """setup_default_tools completes without raising."""
        # The method is currently empty, so the only check is that it does not raise.
        await integration.setup_default_tools()

    @pytest.mark.asyncio
    async def test_enhance_agent_with_tools(self, integration, mock_agent):
        """enhance_agent_with_tools adds exactly one 'mcp_caller' tool to the agent."""
        integration.tool_manager.tools = {
            "test_tool": {
                'tool_info': Mock(description="测试工具"),
            }
        }

        await integration.enhance_agent_with_tools()

        # The tool was added to the agent once; its first positional argument is "mcp_caller".
        mock_agent.add_mcp_tool.assert_called_once()
        add_call = mock_agent.add_mcp_tool.call_args
        assert add_call.args[0] == "mcp_caller"

    @pytest.mark.asyncio
    async def test_process_tool_calls_in_response_no_calls(self, integration):
        """A response without CALL_TOOL markers is returned unchanged."""
        response = "这是普通的响应,没有工具调用"
        result = await integration.process_tool_calls_in_response(response)
        assert result == response

    @pytest.mark.asyncio
    async def test_process_tool_calls_in_response_with_calls(self, integration):
        """A CALL_TOOL marker triggers the tool and its result is appended to the response."""
        integration.tool_manager.call_tool = AsyncMock(return_value="工具结果")

        response = 'CALL_TOOL:{"name": "test_tool", "arguments": {"param": "value"}}'
        result = await integration.process_tool_calls_in_response(response)

        # One marker in the response -> exactly one awaited tool invocation.
        integration.tool_manager.call_tool.assert_awaited_once()
        assert "工具调用结果:" in result
        assert "test_tool: 工具结果" in result

    def test_extract_tool_calls(self, integration):
        """_extract_tool_calls parses every CALL_TOOL JSON payload, in order."""
        response = '''
这里有一些文本
CALL_TOOL:{"name": "tool1", "arguments": {"a": 1}}
更多文本
CALL_TOOL:{"name": "tool2", "arguments": {"b": 2}}
'''

        tool_calls = integration._extract_tool_calls(response)

        assert len(tool_calls) == 2
        assert tool_calls[0]["name"] == "tool1"
        assert tool_calls[0]["arguments"] == {"a": 1}
        assert tool_calls[1]["name"] == "tool2"
        assert tool_calls[1]["arguments"] == {"b": 2}

    def test_extract_tool_calls_invalid_json(self, integration):
        """A CALL_TOOL marker followed by invalid JSON yields no tool calls."""
        response = 'CALL_TOOL:invalid_json'

        tool_calls = integration._extract_tool_calls(response)

        assert len(tool_calls) == 0

    @pytest.mark.asyncio
    async def test_shutdown(self, integration):
        """shutdown delegates to, and awaits, the tool manager's shutdown."""
        integration.tool_manager.shutdown = AsyncMock()

        await integration.shutdown()

        # assert_awaited_once catches a forgotten `await`; assert_called_once would not.
        integration.tool_manager.shutdown.assert_awaited_once()