# Source blob: fd9af7fa7959b75c12660dbac74284cecfd8cc9c (repository-viewer header; not part of the module)
"""
MCP集成模块的单元测试
"""
import pytest
from unittest.mock import AsyncMock, Mock, patch
from claude_agent.mcp.integration import MCPToolManager, MCPToolIntegration
class TestMCPToolManager:
    """Tests for the MCPToolManager class."""

    @pytest.fixture
    def tool_manager(self):
        """Provide a fresh MCPToolManager instance for each test."""
        return MCPToolManager()

    @pytest.mark.asyncio
    async def test_load_mcp_server_success(self, tool_manager):
        """A successful load returns True and registers ``test_server.test_tool``."""
        with patch('claude_agent.mcp.integration.stdio_client') as mock_stdio_client, \
                patch('claude_agent.mcp.integration.ClientSession') as mock_session_class:
            # Fake tool description returned by the mocked session.
            # ``name`` is assigned after construction because ``Mock(name=...)``
            # names the mock itself instead of creating a ``name`` attribute.
            mock_tool = Mock()
            mock_tool.name = "test_tool"
            mock_tool.description = "测试工具"

            mock_session = AsyncMock()
            mock_session.initialize.return_value = None
            mock_session.list_tools.return_value = Mock(tools=[mock_tool])

            # Both patched objects are entered as async context managers.
            mock_session_class.return_value.__aenter__.return_value = mock_session
            mock_stdio_client.return_value.__aenter__.return_value = (Mock(), Mock())

            result = await tool_manager.load_mcp_server("test_server", "python")

            # Identity check: a merely truthy value (e.g. 1) must not pass.
            assert result is True
            assert "test_server.test_tool" in tool_manager.tools

    @pytest.mark.asyncio
    async def test_load_mcp_server_failure(self, tool_manager):
        """A failing stdio client yields False and registers no tools."""
        with patch('claude_agent.mcp.integration.stdio_client') as mock_stdio_client:
            mock_stdio_client.side_effect = Exception("连接失败")

            result = await tool_manager.load_mcp_server("test_server", "python")

            # Identity check: a merely falsy value (e.g. None or 0) must not pass.
            assert result is False
            assert not tool_manager.tools

    @pytest.mark.asyncio
    async def test_call_tool_success(self, tool_manager):
        """Calling a registered tool forwards to its session and returns the result."""
        # Register a tool entry by hand instead of loading a server.
        mock_session = AsyncMock()
        mock_session.call_tool.return_value = "工具结果"

        mock_tool_info = Mock()
        mock_tool_info.name = "test_tool"

        tool_manager.tools["test_server.test_tool"] = {
            'server': 'test_server',
            'tool_info': mock_tool_info,
            'session': mock_session,
        }

        result = await tool_manager.call_tool("test_server.test_tool", {"param": "value"})

        assert result == "工具结果"
        # ``call_tool`` is an AsyncMock attribute: check that it was actually
        # awaited, not merely called, so a missing ``await`` is detected.
        mock_session.call_tool.assert_awaited_once_with("test_tool", {"param": "value"})

    @pytest.mark.asyncio
    async def test_call_tool_not_found(self, tool_manager):
        """Calling an unregistered tool returns None."""
        result = await tool_manager.call_tool("nonexistent_tool", {})

        assert result is None

    def test_get_available_tools(self, tool_manager):
        """get_available_tools returns the registered tool keys."""
        tool_manager.tools = {
            "server1.tool1": {},
            "server2.tool2": {},
        }

        tools = tool_manager.get_available_tools()

        # Compared as a set, so the order of the returned names is not checked.
        assert set(tools) == {"server1.tool1", "server2.tool2"}

    def test_get_tool_info(self, tool_manager):
        """get_tool_info exposes name, description and inputSchema of a tool."""
        mock_tool_info = Mock()
        mock_tool_info.name = "test_tool"
        mock_tool_info.description = "测试工具"
        mock_tool_info.inputSchema = {"type": "object"}

        tool_manager.tools["test_tool"] = {
            'tool_info': mock_tool_info,
        }

        info = tool_manager.get_tool_info("test_tool")

        assert info["name"] == "test_tool"
        assert info["description"] == "测试工具"
        assert info["inputSchema"] == {"type": "object"}

    def test_get_tool_info_not_found(self, tool_manager):
        """get_tool_info returns None for an unknown tool."""
        info = tool_manager.get_tool_info("nonexistent")

        assert info is None
class TestMCPToolIntegration:
    """Tests for the MCPToolIntegration class."""

    @pytest.fixture
    def mock_agent(self):
        """Provide a mocked agent exposing an ``add_mcp_tool`` mock."""
        mock = Mock()
        mock.add_mcp_tool = Mock()
        return mock

    @pytest.fixture
    def integration(self, mock_agent):
        """Provide an MCPToolIntegration built around the mocked agent."""
        return MCPToolIntegration(mock_agent)

    @pytest.mark.asyncio
    async def test_setup_default_tools(self, integration):
        """setup_default_tools completes without raising."""
        # The method is currently empty, so this only checks that it does
        # not raise; reaching the end of the test is the success criterion.
        await integration.setup_default_tools()

    @pytest.mark.asyncio
    async def test_enhance_agent_with_tools(self, integration, mock_agent):
        """enhance_agent_with_tools registers one tool named ``mcp_caller``."""
        integration.tool_manager.tools = {
            "test_tool": {
                'tool_info': Mock(description="测试工具"),
            },
        }

        await integration.enhance_agent_with_tools()

        # The agent must have received exactly one tool registration whose
        # first positional argument is the caller name.
        mock_agent.add_mcp_tool.assert_called_once()
        assert mock_agent.add_mcp_tool.call_args.args[0] == "mcp_caller"

    @pytest.mark.asyncio
    async def test_process_tool_calls_in_response_no_calls(self, integration):
        """A response without tool calls is returned unchanged."""
        response = "这是普通的响应,没有工具调用"

        result = await integration.process_tool_calls_in_response(response)

        assert result == response

    @pytest.mark.asyncio
    async def test_process_tool_calls_in_response_with_calls(self, integration):
        """A response containing a tool call gets the tool result included."""
        integration.tool_manager.call_tool = AsyncMock(return_value="工具结果")
        response = 'CALL_TOOL:{"name": "test_tool", "arguments": {"param": "value"}}'

        result = await integration.process_tool_calls_in_response(response)

        assert "工具调用结果:" in result
        assert "test_tool: 工具结果" in result

    def test_extract_tool_calls(self, integration):
        """_extract_tool_calls finds every CALL_TOOL entry, in order."""
        # Two tool calls interleaved with ordinary text, one item per line.
        response = (
            "\n"
            "这里有一些文本\n"
            'CALL_TOOL:{"name": "tool1", "arguments": {"a": 1}}\n'
            "更多文本\n"
            'CALL_TOOL:{"name": "tool2", "arguments": {"b": 2}}\n'
        )

        tool_calls = integration._extract_tool_calls(response)

        assert len(tool_calls) == 2
        assert tool_calls[0]["name"] == "tool1"
        assert tool_calls[0]["arguments"] == {"a": 1}
        assert tool_calls[1]["name"] == "tool2"
        assert tool_calls[1]["arguments"] == {"b": 2}

    def test_extract_tool_calls_invalid_json(self, integration):
        """A CALL_TOOL entry with an invalid JSON payload yields no tool calls."""
        response = 'CALL_TOOL:invalid_json'

        tool_calls = integration._extract_tool_calls(response)

        assert len(tool_calls) == 0

    @pytest.mark.asyncio
    async def test_shutdown(self, integration):
        """shutdown delegates to the tool manager's shutdown coroutine."""
        integration.tool_manager.shutdown = AsyncMock()

        await integration.shutdown()

        # Check that the coroutine was awaited, not merely called, so a
        # missing ``await`` in the implementation is detected.
        integration.tool_manager.shutdown.assert_awaited_once()