blob: 8619472b2416065d3faded4fd99f5fcd631d5afb [file] [log] [blame] [raw]
"""
MCP工具集成模块
支持动态加载和调用MCP工具
"""
import asyncio
import json
import logging
from typing import Dict, List, Optional, Any, Callable
from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client
class MCPToolManager:
"""MCP工具管理器"""
def __init__(self):
self.tools: Dict[str, Any] = {}
self.sessions: Dict[str, ClientSession] = {}
self.logger = logging.getLogger(__name__)
async def load_mcp_server(self, server_name: str, command: str, args: List[str] = None) -> bool:
"""加载MCP服务器"""
if args is None:
args = []
try:
server_params = StdioServerParameters(
command=command,
args=args,
env=None
)
async with stdio_client(server_params) as (read, write):
async with ClientSession(read, write) as session:
# 初始化会话
await session.initialize()
# 列出可用工具
result = await session.list_tools()
tools = result.tools
self.sessions[server_name] = session
# 注册工具
for tool in tools:
tool_key = f"{server_name}.{tool.name}"
self.tools[tool_key] = {
'server': server_name,
'tool_info': tool,
'session': session
}
self.logger.info(f"已注册MCP工具: {tool_key}")
return True
except Exception as e:
self.logger.error(f"加载MCP服务器失败 {server_name}: {e}")
return False
async def call_tool(self, tool_name: str, arguments: Dict[str, Any]) -> Optional[Any]:
"""调用MCP工具"""
if tool_name not in self.tools:
self.logger.error(f"工具不存在: {tool_name}")
return None
try:
tool_info = self.tools[tool_name]
session = tool_info['session']
tool_def = tool_info['tool_info']
# 调用工具
result = await session.call_tool(tool_def.name, arguments)
self.logger.info(f"成功调用工具: {tool_name}")
return result
except Exception as e:
self.logger.error(f"调用工具失败 {tool_name}: {e}")
return None
def get_available_tools(self) -> List[str]:
"""获取可用工具列表"""
return list(self.tools.keys())
def get_tool_info(self, tool_name: str) -> Optional[Dict]:
"""获取工具信息"""
if tool_name not in self.tools:
return None
tool_info = self.tools[tool_name]
tool_def = tool_info['tool_info']
return {
'name': tool_def.name,
'description': tool_def.description,
'inputSchema': tool_def.inputSchema
}
async def shutdown(self):
"""关闭所有MCP会话"""
for server_name, session in self.sessions.items():
try:
# 注意:这里可能需要根据实际的ClientSession API调整
# await session.close()
self.logger.info(f"已关闭MCP会话: {server_name}")
except Exception as e:
self.logger.error(f"关闭MCP会话失败 {server_name}: {e}")
self.sessions.clear()
self.tools.clear()
class MCPToolIntegration:
"""MCP工具集成到Agent的桥接类"""
def __init__(self, agent_core):
self.agent = agent_core
self.tool_manager = MCPToolManager()
self.logger = logging.getLogger(__name__)
async def setup_default_tools(self):
"""设置默认的MCP工具"""
# 这里可以配置默认的MCP工具服务器
# 示例:加载文件系统工具
# await self.tool_manager.load_mcp_server(
# "filesystem",
# "python",
# ["-m", "mcp_filesystem"]
# )
pass
async def enhance_agent_with_tools(self):
"""为Agent增强MCP工具能力"""
# 获取所有可用工具
available_tools = self.tool_manager.get_available_tools()
# 创建工具调用函数
async def call_mcp_tool(tool_name: str, arguments: Dict[str, Any]):
return await self.tool_manager.call_tool(tool_name, arguments)
# 将工具注册到Agent
self.agent.add_mcp_tool("mcp_caller", call_mcp_tool)
# 更新Agent的系统提示,让它知道有哪些工具可用
tool_descriptions = []
for tool_name in available_tools:
tool_info = self.tool_manager.get_tool_info(tool_name)
if tool_info:
tool_descriptions.append(f"- {tool_name}: {tool_info['description']}")
if tool_descriptions:
tools_prompt = f"""
你现在可以使用以下MCP工具:
{chr(10).join(tool_descriptions)}
要调用工具,请在响应中明确说明需要调用哪个工具以及参数。
"""
# 这里可以将工具信息添加到Agent的系统提示中
self.logger.info(f"已为Agent集成{len(available_tools)}个MCP工具")
async def process_tool_calls_in_response(self, agent_response: str) -> str:
"""处理Agent响应中的工具调用"""
# 这里可以解析Agent响应中的工具调用请求
# 并实际执行这些工具调用,然后将结果反馈给Agent
# 简化实现:检查是否包含工具调用标记
if "CALL_TOOL:" in agent_response:
# 提取工具调用信息
tool_calls = self._extract_tool_calls(agent_response)
# 执行工具调用
results = []
for tool_call in tool_calls:
result = await self.tool_manager.call_tool(
tool_call['name'],
tool_call['arguments']
)
results.append({
'tool': tool_call['name'],
'result': result
})
# 将结果添加到响应中
if results:
results_text = "\n\n工具调用结果:\n"
for result in results:
results_text += f"- {result['tool']}: {result['result']}\n"
agent_response += results_text
return agent_response
def _extract_tool_calls(self, response: str) -> List[Dict]:
"""从响应中提取工具调用"""
# 简化实现,实际应该更完善
tool_calls = []
lines = response.split('\n')
for line in lines:
if line.startswith("CALL_TOOL:"):
try:
tool_call_str = line[len("CALL_TOOL:"):].strip()
tool_call = json.loads(tool_call_str)
tool_calls.append(tool_call)
except json.JSONDecodeError:
self.logger.warning(f"无法解析工具调用: {line}")
return tool_calls
async def shutdown(self):
"""关闭MCP集成"""
await self.tool_manager.shutdown()