| """ |
| 配置管理模块 |
| 支持TOML配置文件的加载和管理 |
| 支持环境变量覆盖配置(Docker友好) |
| """ |
| |
| import os |
| import sys |
| import json |
| from typing import Dict, Any, Optional, List |
| from pathlib import Path |
| |
| # 使用tomli库读取TOML文件(Python 3.11+包含在标准库中) |
| if sys.version_info >= (3, 11): |
| import tomllib |
| else: |
| import tomli as tomllib |
| |
| |
| class ConfigManager: |
| """配置管理器""" |
| |
| def __init__(self, config_name: str = "default"): |
| """ |
| 初始化配置管理器 |
| |
| Args: |
| config_name: 配置文件名 (不含.toml扩展名) |
| """ |
| self.config_name = config_name |
| self._config: Dict[str, Any] = {} |
| self._project_root = self._get_project_root() |
| self._config_dir = os.path.join(self._project_root, "configs") |
| |
| # 加载配置 |
| self._load_config() |
| |
| def _get_project_root(self) -> str: |
| """获取项目根目录""" |
| current_file = os.path.abspath(__file__) |
| # 从 src/claude_agent/utils/config.py 回到项目根目录 |
| return os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(current_file)))) |
| |
| def _load_config(self): |
| """加载配置文件""" |
| config_file = os.path.join(self._config_dir, f"{self.config_name}.toml") |
| |
| if not os.path.exists(config_file): |
| raise FileNotFoundError(f"配置文件不存在: {config_file}") |
| |
| try: |
| with open(config_file, 'rb') as f: |
| self._config = tomllib.load(f) |
| except Exception as e: |
| raise ValueError(f"配置文件格式错误: {e}") |
| |
| def get(self, key_path: str, default: Any = None) -> Any: |
| """ |
| 获取配置值 |
| |
| Args: |
| key_path: 配置键路径,使用点号分隔,如 'sshout.server.hostname' |
| default: 默认值 |
| |
| Returns: |
| 配置值 |
| """ |
| keys = key_path.split('.') |
| value = self._config |
| |
| try: |
| for key in keys: |
| value = value[key] |
| return value |
| except (KeyError, TypeError): |
| return default |
| |
| def get_section(self, section: str) -> Dict[str, Any]: |
| """ |
| 获取配置段落 |
| |
| Args: |
| section: 段落名称,如 'sshout' |
| |
| Returns: |
| 配置段落字典 |
| """ |
| return self._config.get(section, {}) |
| |
| def get_sshout_config(self) -> Dict[str, Any]: |
| """获取SSHOUT配置""" |
| sshout_config = self.get_section('sshout') |
| |
| # 处理SSH密钥路径,转换为绝对路径 |
| if 'ssh_key' in sshout_config and 'private_key_path' in sshout_config['ssh_key']: |
| key_path = sshout_config['ssh_key']['private_key_path'] |
| if not os.path.isabs(key_path): |
| # 相对路径转换为绝对路径 |
| sshout_config['ssh_key']['private_key_path'] = os.path.join( |
| self._project_root, key_path |
| ) |
| |
| return sshout_config |
| |
| def get_agent_config(self) -> Dict[str, Any]: |
| """获取Agent配置,支持环境变量覆盖""" |
| agent_config = self.get_section('agent').copy() |
| |
| # 环境变量覆盖(优先级最高) |
| # ANTHROPIC_API_KEY - Claude API密钥(通过os.environ自动使用) |
| # ANTHROPIC_BASE_URL - Claude API基础URL |
| if 'ANTHROPIC_BASE_URL' in os.environ: |
| agent_config['api_base_url'] = os.environ['ANTHROPIC_BASE_URL'] |
| |
| return agent_config |
| |
| def get_cli_config(self) -> Dict[str, Any]: |
| """获取CLI配置""" |
| return self.get_section('cli') |
| |
| def get_mcp_config(self) -> Dict[str, Any]: |
| """获取MCP配置""" |
| return self.get_section('mcp') |
| |
| def get_telegram_config(self) -> Dict[str, Any]: |
| """获取Telegram配置,支持环境变量覆盖""" |
| telegram_config = self.get_section('telegram').copy() |
| |
| # 环境变量覆盖(优先级最高) |
| if 'TELEGRAM_BOT_TOKEN' in os.environ: |
| telegram_config['bot_token'] = os.environ['TELEGRAM_BOT_TOKEN'] |
| |
| if 'TELEGRAM_ALLOWED_USERS' in os.environ: |
| try: |
| users = json.loads(os.environ['TELEGRAM_ALLOWED_USERS']) |
| telegram_config['allowed_users'] = users |
| except json.JSONDecodeError: |
| pass # 保持配置文件中的值 |
| |
| if 'TELEGRAM_ALLOWED_GROUPS' in os.environ: |
| try: |
| groups = json.loads(os.environ['TELEGRAM_ALLOWED_GROUPS']) |
| telegram_config['allowed_groups'] = groups |
| except json.JSONDecodeError: |
| pass # 保持配置文件中的值 |
| |
| # 处理临时目录路径,转换为绝对路径 |
| if 'files' in telegram_config and 'temp_dir' in telegram_config['files']: |
| temp_dir = telegram_config['files']['temp_dir'] |
| if not os.path.isabs(temp_dir): |
| # 相对路径转换为绝对路径 |
| telegram_config['files']['temp_dir'] = os.path.join( |
| self._project_root, temp_dir |
| ) |
| |
| return telegram_config |
| |
| def get_logging_config(self) -> Dict[str, Any]: |
| """获取日志配置""" |
| return self.get_section('logging') |
| |
| def get_webhook_config(self) -> Dict[str, Any]: |
| """获取Webhook配置,支持环境变量覆盖""" |
| webhook_config = self.get_section('webhook').copy() |
| |
| # 环境变量覆盖(优先级最高) |
| if 'WEBHOOK_SERVER_URL' in os.environ: |
| webhook_config['server_url'] = os.environ['WEBHOOK_SERVER_URL'] |
| |
| if 'WEBHOOK_AUTH_TOKEN' in os.environ: |
| webhook_config['auth_token'] = os.environ['WEBHOOK_AUTH_TOKEN'] |
| |
| if 'WEBHOOK_HOST' in os.environ and 'server' in webhook_config: |
| webhook_config['server']['host'] = os.environ['WEBHOOK_HOST'] |
| |
| if 'WEBHOOK_PORT' in os.environ and 'server' in webhook_config: |
| try: |
| webhook_config['server']['port'] = int(os.environ['WEBHOOK_PORT']) |
| except ValueError: |
| pass # 保持配置文件中的值 |
| |
| return webhook_config |
| |
| def reload(self): |
| """重新加载配置""" |
| self._load_config() |
| |
| @property |
| def project_root(self) -> str: |
| """获取项目根目录路径""" |
| return self._project_root |
| |
| |
| # 全局配置实例 |
| _config_manager: Optional[ConfigManager] = None |
| |
| |
| def get_config_manager(config_name: str = None) -> ConfigManager: |
| """ |
| 获取全局配置管理器实例 |
| |
| Args: |
| config_name: 配置文件名,如果为None则使用已有实例或默认配置 |
| |
| Returns: |
| ConfigManager实例 |
| """ |
| global _config_manager |
| |
| if _config_manager is None or config_name is not None: |
| config_name = config_name or os.getenv('CLAUDE_CONFIG', 'default') |
| _config_manager = ConfigManager(config_name) |
| |
| return _config_manager |
| |
| |
| def get_config(key_path: str, default: Any = None) -> Any: |
| """ |
| 获取配置值的便捷函数 |
| |
| Args: |
| key_path: 配置键路径 |
| default: 默认值 |
| |
| Returns: |
| 配置值 |
| """ |
| return get_config_manager().get(key_path, default) |