| """ |
| 工具包模块 |
| 提供各种辅助功能 |
| """ |
| |
| import json |
| import logging |
| import os |
| import sys |
| from typing import Dict, List, Any, Optional |
| from pathlib import Path |
| |
| |
class ConfigManager:
    """JSON-file-backed configuration store.

    The configuration is a single JSON object stored at ``config_path``
    (default ``~/.claude_agent/config.json``). :meth:`set` and :meth:`update`
    persist every change immediately via :meth:`save_config`.
    """

    def __init__(self, config_path: Optional[str] = None):
        """Resolve the config file location and load its contents.

        Args:
            config_path: Path of the JSON config file; ``None`` selects
                ``~/.claude_agent/config.json``. The file's parent directory
                is created if it does not exist.
        """
        if config_path is None:
            self.config_path = Path.home() / ".claude_agent" / "config.json"
        else:
            self.config_path = Path(config_path)

        self.config_path.parent.mkdir(parents=True, exist_ok=True)
        self.config = self.load_config()

    def load_config(self) -> Dict[str, Any]:
        """Load the configuration, layered over :meth:`get_default_config`.

        Keys present in the file override the defaults; keys missing from the
        file (e.g. settings introduced after the file was written) keep their
        default values. A missing file yields the defaults. An unreadable or
        malformed file, or one whose top-level value is not a JSON object, is
        logged as a warning and also yields the defaults.

        Returns:
            A new dict holding the merged configuration.
        """
        config = self.get_default_config()
        if not self.config_path.exists():
            return config

        try:
            with open(self.config_path, 'r', encoding='utf-8') as f:
                loaded = json.load(f)
        except (OSError, ValueError) as e:
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            logging.warning(f"加载配置失败,使用默认配置: {e}")
            return config

        if not isinstance(loaded, dict):
            # A list/string/number root would make get()/set() fail later.
            logging.warning(
                f"加载配置失败,使用默认配置: expected a JSON object, "
                f"got {type(loaded).__name__}"
            )
            return config

        config.update(loaded)
        return config

    def save_config(self):
        """Write the current configuration to disk, replacing the file atomically.

        The JSON is first written to a temporary file next to the config file
        and then moved over it with ``os.replace``. If serialisation fails
        part-way (e.g. a value that is not JSON-serialisable) or an I/O error
        occurs, the previous file is left intact instead of truncated. Errors
        are logged, not raised; the temporary file is removed on failure.
        """
        tmp_path = None
        try:
            # Same directory as the target so os.replace never crosses filesystems;
            # the pid keeps concurrent processes from sharing a temp file.
            tmp_path = self.config_path.with_name(
                f".{self.config_path.name}.{os.getpid()}.tmp"
            )
            with open(tmp_path, 'w', encoding='utf-8') as f:
                json.dump(self.config, f, ensure_ascii=False, indent=2)
            os.replace(tmp_path, self.config_path)
        except (OSError, TypeError, ValueError) as e:
            # TypeError: unserialisable value; ValueError: e.g. circular reference.
            logging.error(f"保存配置失败: {e}")
        finally:
            # After a successful replace the temp file no longer exists.
            if tmp_path is not None and tmp_path.exists():
                try:
                    tmp_path.unlink()
                except OSError:
                    pass

    def get_default_config(self) -> Dict[str, Any]:
        """Return a new dict with the built-in default settings."""
        return {
            "api_key": "",
            "default_mode": "interactive",
            "max_history": 100,
            "mcp_servers": [],
            "log_level": "INFO",
            "theme": "default"
        }

    def get(self, key: str, default: Any = None) -> Any:
        """Return the value for *key*, or *default* if it is not set."""
        return self.config.get(key, default)

    def set(self, key: str, value: Any):
        """Set *key* to *value* and persist the configuration."""
        self.config[key] = value
        self.save_config()

    def update(self, updates: Dict[str, Any]):
        """Merge *updates* into the configuration and persist it."""
        self.config.update(updates)
        self.save_config()
| |
| |
class Logger:
    """Helper for creating consistently formatted console loggers."""

    @staticmethod
    def setup_logger(name: str, level: str = "INFO") -> logging.Logger:
        """Return logger *name*, set to *level* and writing to stdout.

        Args:
            name: Name passed to ``logging.getLogger``.
            level: A level name such as ``"DEBUG"`` or ``"warning"``
                (case-insensitive, surrounding whitespace ignored) or a
                numeric level such as ``logging.DEBUG``. Unrecognised values
                fall back to INFO.

        Returns:
            The configured logger. A stdout handler using the format
            ``asctime - name - levelname - message`` is attached only when the
            logger has no handlers yet, so repeated calls do not duplicate
            output.
        """
        logger = logging.getLogger(name)

        if isinstance(level, int) and not isinstance(level, bool):
            resolved = level
        else:
            # getattr alone is not enough: the logging module also exposes
            # upper-case names that are not levels (BASIC_FORMAT is a str),
            # which setLevel rejects with ValueError. Accept only ints.
            resolved = getattr(logging, str(level).strip().upper(), None)
            if not isinstance(resolved, int):
                resolved = logging.INFO
        logger.setLevel(resolved)

        if not logger.handlers:
            handler = logging.StreamHandler(sys.stdout)
            formatter = logging.Formatter(
                '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
            )
            handler.setFormatter(formatter)
            logger.addHandler(handler)

        return logger
| |
| |
class DataValidator:
    """Lightweight structural checks for API keys, task plans and tool calls.

    Every check answers True/False rather than raising on malformed input.
    """

    @staticmethod
    def validate_api_key(api_key: str) -> bool:
        """Return True when *api_key* looks like a usable key.

        This is only a loose format check. The value must be a non-empty str
        longer than 5 characters that begins with ``sk-`` or ``claude-``.
        """
        if not (api_key and isinstance(api_key, str)):
            return False
        prefixed = api_key.startswith(('sk-', 'claude-'))
        return prefixed and len(api_key) > 5

    @staticmethod
    def validate_task_plan(plan: Dict[str, Any]) -> bool:
        """Return True when *plan* is a dict with a list ``steps`` and a non-empty str ``goal``."""
        if not isinstance(plan, dict):
            return False
        missing = [field for field in ('steps', 'goal') if field not in plan]
        if missing:
            return False
        if not isinstance(plan['steps'], list):
            return False
        goal = plan['goal']
        return bool(goal) and isinstance(goal, str)

    @staticmethod
    def validate_tool_call(tool_call: Dict[str, Any]) -> bool:
        """Return True when *tool_call* is a dict with ``args`` and a non-empty str ``name``."""
        if not isinstance(tool_call, dict):
            return False
        if not ('name' in tool_call and 'args' in tool_call):
            return False
        tool_name = tool_call['name']
        return bool(tool_name) and isinstance(tool_name, str)
| |
| |
class TextUtils:
    """Text-processing helpers."""

    @staticmethod
    def truncate_text(text: str, max_length: int = 100, suffix: str = "...") -> str:
        """Shorten *text* to at most *max_length* characters.

        Text that already fits is returned unchanged. Otherwise it is cut and
        *suffix* is appended so that the result is exactly *max_length*
        characters long. When *max_length* is too small to hold the suffix,
        the text is hard-cut to *max_length* characters with no suffix. This
        keeps the result from ever exceeding *max_length*.
        """
        if len(text) <= max_length:
            return text
        if max_length < len(suffix):
            # max_length - len(suffix) would be negative here and slice from the
            # end, producing a result longer than max_length.
            return text[:max(max_length, 0)]
        return text[:max_length - len(suffix)] + suffix

    @staticmethod
    def clean_json_string(text: str) -> str:
        """Strip all markdown code-fence markers from *text* and trim whitespace.

        Both the opening fences (with an optional alphabetic language tag and
        the following newline) and bare closing fences are removed.
        """
        import re
        text = re.sub(r'```[a-zA-Z]*\n?', '', text)
        text = re.sub(r'```', '', text)
        return text.strip()

    @staticmethod
    def extract_code_blocks(text: str, language: Optional[str] = None) -> List[str]:
        """Return the stripped, non-empty contents of fenced code blocks in *text*.

        Args:
            text: Markdown-like text containing ```` ``` ```` fences.
            language: If given, only blocks whose tag is exactly this string
                (case-sensitive) are returned. For example, ``"py"`` does not
                match a ```` ```python ```` block. Otherwise, blocks with any
                tag or none are returned.
        """
        import re

        # Characters allowed in a language tag (covers c++, c#, objective-c, ...).
        tag_chars = r'[\w+#.-]'
        if language:
            # Escape the tag so names like "c++" match literally, and forbid a
            # further tag character so "py" cannot match the start of "python".
            pattern = r'```' + re.escape(language) + r'(?!' + tag_chars + r')\s*(.*?)```'
        else:
            pattern = r'```' + tag_chars + r'*\s*(.*?)```'

        matches = re.findall(pattern, text, re.DOTALL)
        return [match.strip() for match in matches if match.strip()]

    @staticmethod
    def format_error_message(error: Exception, context: str = "") -> str:
        """Format *error* as ``"Type: message"``, prefixed with ``"[context] "`` when given."""
        error_type = type(error).__name__
        error_msg = str(error)

        if context:
            return f"[{context}] {error_type}: {error_msg}"
        else:
            return f"{error_type}: {error_msg}"
| |
| |
class FileUtils:
    """File-system helpers for directories and JSON files."""

    @staticmethod
    def ensure_directory(path: str):
        """Create directory *path* and any missing parents; no-op if it already exists."""
        Path(path).mkdir(parents=True, exist_ok=True)

    @staticmethod
    def read_json_file(file_path: str) -> Optional[Dict[str, Any]]:
        """Parse the UTF-8 JSON file at *file_path*.

        Returns:
            The parsed JSON value, or ``None`` (after logging an error) when
            the file cannot be read or does not contain valid JSON. The
            top-level value is returned as parsed and is not checked to be an
            object.
        """
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                return json.load(f)
        except (OSError, ValueError) as e:
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            logging.error(f"读取JSON文件失败 {file_path}: {e}")
            return None

    @staticmethod
    def write_json_file(file_path: str, data: Dict[str, Any]):
        """Write *data* as pretty-printed UTF-8 JSON, replacing the file atomically.

        The JSON is written to a temporary file in the same directory and then
        moved over *file_path* with ``os.replace``. A failure part-way (e.g. an
        unserialisable value, or a missing parent directory) therefore leaves
        any existing file intact instead of truncated. Errors are logged, not
        raised; the temporary file is removed on failure.
        """
        tmp_path = None
        try:
            target = Path(file_path)
            # Same directory keeps os.replace on one filesystem; the pid keeps
            # concurrent processes from sharing a temp file.
            tmp_path = target.with_name(f".{target.name}.{os.getpid()}.tmp")
            with open(tmp_path, 'w', encoding='utf-8') as f:
                json.dump(data, f, ensure_ascii=False, indent=2)
            os.replace(tmp_path, target)
        except (OSError, TypeError, ValueError) as e:
            # TypeError: unserialisable value; ValueError: e.g. circular
            # reference or a path with an empty file name.
            logging.error(f"写入JSON文件失败 {file_path}: {e}")
        finally:
            # After a successful replace the temp file no longer exists.
            if tmp_path is not None and tmp_path.exists():
                try:
                    tmp_path.unlink()
                except OSError:
                    pass

    @staticmethod
    def get_project_root() -> Path:
        """Return the nearest directory at or above this module containing setup.py or pyproject.toml.

        Falls back to the current working directory if no such directory exists.
        """
        # abspath makes a relative __file__ absolute so the walk reaches the
        # real ancestors; symlinks are deliberately not resolved.
        here = Path(os.path.abspath(__file__)).parent
        for candidate in (here, *here.parents):
            if (candidate / "setup.py").exists() or (candidate / "pyproject.toml").exists():
                return candidate
        return Path.cwd()