blob: b7e66f243a3a5fe0f26cb3653a4b7cbfac7beb0e [file] [log] [blame] [raw]
"""
工具包模块
提供各种辅助功能
"""
import json
import logging
import os
import sys
from typing import Dict, List, Any, Optional
from pathlib import Path
class ConfigManager:
    """Manage a JSON configuration file persisted on disk.

    The configuration is loaded once at construction time; every change made
    through :meth:`set` or :meth:`update` is written straight back to disk.
    """

    def __init__(self, config_path: Optional[str] = None):
        """Create the manager and load the configuration.

        Args:
            config_path: Path to the JSON config file. Defaults to
                ``~/.claude_agent/config.json``. Its parent directory is
                created if it does not exist yet.
        """
        if config_path is None:
            self.config_path = Path.home() / ".claude_agent" / "config.json"
        else:
            self.config_path = Path(config_path)
        self.config_path.parent.mkdir(parents=True, exist_ok=True)
        self.config = self.load_config()

    def load_config(self) -> Dict[str, Any]:
        """Load the configuration, filling in defaults for missing keys.

        Returns:
            The default configuration overlaid with the values stored in the
            config file. If the file does not exist, the defaults are returned.
            If it cannot be read, is not valid JSON, or its top-level value is
            not a JSON object, a warning is logged and the defaults are
            returned.
        """
        config = self.get_default_config()
        if not self.config_path.exists():
            return config
        try:
            with open(self.config_path, 'r', encoding='utf-8') as f:
                loaded = json.load(f)
            # A non-object root (list, string, number) would break get()/set().
            if not isinstance(loaded, dict):
                raise ValueError(
                    f"config root must be a JSON object, got {type(loaded).__name__}"
                )
        except Exception as e:
            logging.warning(f"加载配置失败,使用默认配置: {e}")
            return config
        # Stored values win; defaults only fill keys the file does not define.
        config.update(loaded)
        return config

    def save_config(self):
        """Persist the current configuration to disk (best effort).

        The JSON is first written to a sibling ``<name>.tmp`` file and then
        moved over the real file with ``os.replace``. A failure while
        serializing or writing therefore never leaves a truncated config
        behind. Errors are logged, not raised.
        """
        tmp_path = self.config_path.with_name(self.config_path.name + ".tmp")
        try:
            with open(tmp_path, 'w', encoding='utf-8') as f:
                json.dump(self.config, f, ensure_ascii=False, indent=2)
            os.replace(tmp_path, self.config_path)
        except Exception as e:
            logging.error(f"保存配置失败: {e}")
            # Drop the partial temp file; the previous config file is untouched.
            try:
                tmp_path.unlink()
            except OSError:
                pass

    def get_default_config(self) -> Dict[str, Any]:
        """Return a fresh dict holding the built-in default configuration."""
        return {
            "api_key": "",
            "default_mode": "interactive",
            "max_history": 100,
            "mcp_servers": [],
            "log_level": "INFO",
            "theme": "default"
        }

    def get(self, key: str, default: Any = None) -> Any:
        """Return the value for *key*, or *default* if it is not set."""
        return self.config.get(key, default)

    def set(self, key: str, value: Any):
        """Set *key* to *value* and immediately save the configuration."""
        self.config[key] = value
        self.save_config()

    def update(self, updates: Dict[str, Any]):
        """Merge *updates* into the configuration and save it once."""
        self.config.update(updates)
        self.save_config()
class Logger:
    """Helpers for creating configured :mod:`logging` loggers."""

    @staticmethod
    def setup_logger(name: str, level: str = "INFO") -> logging.Logger:
        """Return the logger *name* with its level set and a stdout handler.

        Args:
            name: Logger name passed to ``logging.getLogger``.
            level: Level name such as ``"debug"`` or ``"INFO"``. It is
                case-insensitive and surrounding whitespace is ignored. An int
                level such as ``logging.DEBUG`` is also accepted. Anything
                unrecognized falls back to INFO.

        A ``StreamHandler`` writing to ``sys.stdout`` is attached only if the
        logger has no handlers yet, so repeated calls do not duplicate output.
        """
        logger = logging.getLogger(name)
        logger.setLevel(Logger._resolve_level(level))
        if not logger.handlers:
            handler = logging.StreamHandler(sys.stdout)
            formatter = logging.Formatter(
                '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
            )
            handler.setFormatter(formatter)
            logger.addHandler(handler)
        return logger

    @staticmethod
    def _resolve_level(level: Any) -> int:
        """Map a level name or int to a numeric logging level, defaulting to INFO."""
        if isinstance(level, bool):
            return logging.INFO
        if isinstance(level, int):
            return level
        if isinstance(level, str):
            resolved = getattr(logging, level.strip().upper(), None)
            # Only accept numeric level constants. This rejects other module
            # attributes such as logging.BASIC_FORMAT (a str), which setLevel
            # would otherwise reject with ValueError.
            if isinstance(resolved, int) and not isinstance(resolved, bool):
                return resolved
        return logging.INFO
class DataValidator:
    """Lightweight structural checks for API keys, task plans and tool calls.

    Every validator returns ``True`` when its input has the expected shape
    and ``False`` otherwise.
    """

    @staticmethod
    def validate_api_key(api_key: str) -> bool:
        """Return whether *api_key* looks like a usable key.

        A key passes when it is a non-empty str longer than 5 characters that
        starts with ``sk-`` or ``claude-``. This is a format check only.
        """
        if api_key and isinstance(api_key, str):
            has_known_prefix = api_key.startswith(('sk-', 'claude-'))
            return len(api_key) > 5 and has_known_prefix
        return False

    @staticmethod
    def validate_task_plan(plan: Dict[str, Any]) -> bool:
        """Return whether *plan* is a dict with a list ``steps`` and a non-empty str ``goal``.

        The individual entries of ``steps`` are not inspected.
        """
        shaped = isinstance(plan, dict) and not any(
            field not in plan for field in ('steps', 'goal')
        )
        if not shaped:
            return False
        if not isinstance(plan['steps'], list):
            return False
        goal = plan['goal']
        return bool(goal) and isinstance(goal, str)

    @staticmethod
    def validate_tool_call(tool_call: Dict[str, Any]) -> bool:
        """Return whether *tool_call* is a dict with a non-empty str ``name`` and an ``args`` key.

        Only the presence of ``args`` is required; its value is not checked.
        """
        shaped = isinstance(tool_call, dict) and not any(
            field not in tool_call for field in ('name', 'args')
        )
        if not shaped:
            return False
        tool_name = tool_call['name']
        return bool(tool_name) and isinstance(tool_name, str)
class TextUtils:
    """String helpers: truncation, Markdown code-fence handling, error formatting."""

    @staticmethod
    def truncate_text(text: str, max_length: int = 100, suffix: str = "...") -> str:
        """Shorten *text* to at most *max_length* characters.

        If *text* fits, it is returned unchanged. Otherwise it is cut and
        *suffix* is appended so the result is exactly *max_length* characters.
        When *max_length* is smaller than ``len(suffix)`` there is no room for
        the suffix, so the text is simply cut to *max_length*. A negative
        *max_length* is treated as 0.
        """
        max_length = max(max_length, 0)
        if len(text) <= max_length:
            return text
        if max_length < len(suffix):
            # max_length - len(suffix) would be negative here, which slices from
            # the end and overshoots max_length; cut without a suffix instead.
            return text[:max_length]
        return text[:max_length - len(suffix)] + suffix

    @staticmethod
    def clean_json_string(text: str) -> str:
        """Remove Markdown code-fence markers and surrounding whitespace.

        Each run of three backticks is removed together with any ASCII-letter
        language tag right after it (e.g. ``json``) and one following newline.
        """
        import re
        # `[a-zA-Z]*` may match zero letters, so this single pass also removes
        # bare (closing) fences.
        text = re.sub(r'```[a-zA-Z]*\n?', '', text)
        return text.strip()

    @staticmethod
    def extract_code_blocks(text: str, language: Optional[str] = None) -> List[str]:
        """Return the contents of ```-fenced code blocks in *text*.

        Args:
            text: Text containing Markdown fenced code blocks.
            language: If given, only blocks whose fence tag is exactly this
                string (case-sensitive) are returned. For example, ``"py"``
                does not match a ``python`` block, and regex metacharacters
                such as those in ``"c++"`` are matched literally.

        Returns:
            Non-empty block bodies with surrounding whitespace stripped, in
            order of appearance.
        """
        import re
        if language:
            # Escape the tag and forbid further tag characters right after it,
            # so "py" cannot match "```python" and "c" cannot match "```c++".
            pattern = r'```' + re.escape(language) + r'(?![\w+#-])\s*(.*?)```'
        else:
            pattern = r'```\w*\s*(.*?)```'
        matches = re.findall(pattern, text, re.DOTALL)
        return [match.strip() for match in matches if match.strip()]

    @staticmethod
    def format_error_message(error: Exception, context: str = "") -> str:
        """Format *error* as ``"Type: message"``, prefixed by ``"[context] "`` if given."""
        error_type = type(error).__name__
        error_msg = str(error)
        if context:
            return f"[{context}] {error_type}: {error_msg}"
        return f"{error_type}: {error_msg}"
class FileUtils:
    """Filesystem helpers for directories, JSON files and project discovery."""

    @staticmethod
    def ensure_directory(path: str):
        """Create *path* and any missing parents; do nothing if it already exists."""
        Path(path).mkdir(parents=True, exist_ok=True)

    @staticmethod
    def read_json_file(file_path: str) -> Optional[Dict[str, Any]]:
        """Read and parse a UTF-8 JSON file.

        Returns:
            The parsed JSON value, or ``None`` if the file cannot be opened or
            parsed (the error is logged). The top-level value is returned as
            parsed; it is not checked to be a dict.
        """
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                return json.load(f)
        except Exception as e:
            logging.error(f"读取JSON文件失败 {file_path}: {e}")
            return None

    @staticmethod
    def write_json_file(file_path: str, data: Dict[str, Any]):
        """Write *data* as pretty-printed UTF-8 JSON to *file_path* (best effort).

        The JSON is written to a sibling ``<name>.tmp`` file and then moved
        into place with ``os.replace``. If serialization or writing fails, the
        existing file is left intact and the temp file is removed. Errors are
        logged, not raised.
        """
        tmp_path = None
        try:
            target = Path(file_path)
            tmp_path = target.with_name(target.name + ".tmp")
            with open(tmp_path, 'w', encoding='utf-8') as f:
                json.dump(data, f, ensure_ascii=False, indent=2)
            os.replace(tmp_path, target)
        except Exception as e:
            logging.error(f"写入JSON文件失败 {file_path}: {e}")
            if tmp_path is not None:
                try:
                    tmp_path.unlink()
                except OSError:
                    pass

    @staticmethod
    def get_project_root() -> Path:
        """Return the nearest ancestor directory that looks like a project root.

        Starting from this module's directory (inclusive) and walking up to the
        filesystem root, the first directory containing ``setup.py`` or
        ``pyproject.toml`` is returned. If none is found, the current working
        directory is returned.
        """
        # Make the start absolute: for a relative __file__, Path('.').parent is
        # Path('.') itself, which would stop the upward walk immediately.
        start = Path(os.path.abspath(__file__)).parent
        for candidate in (start, *start.parents):
            if (candidate / "setup.py").exists() or (candidate / "pyproject.toml").exists():
                return candidate
        return Path.cwd()