blob: e1ad1a54541b7aa3a7e151b5495096d69c932f16 [file] [log] [blame] [raw]
"""
配置管理模块
支持TOML配置文件的加载和管理
支持环境变量覆盖配置(Docker友好)
"""
import os
import sys
import json
from typing import Dict, Any, Optional, List
from pathlib import Path
# 使用tomli库读取TOML文件(Python 3.11+包含在标准库中)
if sys.version_info >= (3, 11):
import tomllib
else:
import tomli as tomllib
class ConfigManager:
"""配置管理器"""
def __init__(self, config_name: str = "default"):
"""
初始化配置管理器
Args:
config_name: 配置文件名 (不含.toml扩展名)
"""
self.config_name = config_name
self._config: Dict[str, Any] = {}
self._project_root = self._get_project_root()
self._config_dir = os.path.join(self._project_root, "configs")
# 加载配置
self._load_config()
def _get_project_root(self) -> str:
"""获取项目根目录"""
current_file = os.path.abspath(__file__)
# 从 src/claude_agent/utils/config.py 回到项目根目录
return os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(current_file))))
def _load_config(self):
"""加载配置文件"""
config_file = os.path.join(self._config_dir, f"{self.config_name}.toml")
if not os.path.exists(config_file):
raise FileNotFoundError(f"配置文件不存在: {config_file}")
try:
with open(config_file, 'rb') as f:
self._config = tomllib.load(f)
except Exception as e:
raise ValueError(f"配置文件格式错误: {e}")
def get(self, key_path: str, default: Any = None) -> Any:
"""
获取配置值
Args:
key_path: 配置键路径,使用点号分隔,如 'sshout.server.hostname'
default: 默认值
Returns:
配置值
"""
keys = key_path.split('.')
value = self._config
try:
for key in keys:
value = value[key]
return value
except (KeyError, TypeError):
return default
def get_section(self, section: str) -> Dict[str, Any]:
"""
获取配置段落
Args:
section: 段落名称,如 'sshout'
Returns:
配置段落字典
"""
return self._config.get(section, {})
def get_sshout_config(self) -> Dict[str, Any]:
"""获取SSHOUT配置"""
sshout_config = self.get_section('sshout')
# 处理SSH密钥路径,转换为绝对路径
if 'ssh_key' in sshout_config and 'private_key_path' in sshout_config['ssh_key']:
key_path = sshout_config['ssh_key']['private_key_path']
if not os.path.isabs(key_path):
# 相对路径转换为绝对路径
sshout_config['ssh_key']['private_key_path'] = os.path.join(
self._project_root, key_path
)
return sshout_config
def get_agent_config(self) -> Dict[str, Any]:
"""获取Agent配置,支持环境变量覆盖"""
agent_config = self.get_section('agent').copy()
# 环境变量覆盖(优先级最高)
# ANTHROPIC_API_KEY - Claude API密钥(通过os.environ自动使用)
# ANTHROPIC_BASE_URL - Claude API基础URL
if 'ANTHROPIC_BASE_URL' in os.environ:
agent_config['api_base_url'] = os.environ['ANTHROPIC_BASE_URL']
return agent_config
def get_cli_config(self) -> Dict[str, Any]:
"""获取CLI配置"""
return self.get_section('cli')
def get_mcp_config(self) -> Dict[str, Any]:
"""获取MCP配置"""
return self.get_section('mcp')
def get_telegram_config(self) -> Dict[str, Any]:
"""获取Telegram配置,支持环境变量覆盖"""
telegram_config = self.get_section('telegram').copy()
# 环境变量覆盖(优先级最高)
if 'TELEGRAM_BOT_TOKEN' in os.environ:
telegram_config['bot_token'] = os.environ['TELEGRAM_BOT_TOKEN']
if 'TELEGRAM_ALLOWED_USERS' in os.environ:
try:
users = json.loads(os.environ['TELEGRAM_ALLOWED_USERS'])
telegram_config['allowed_users'] = users
except json.JSONDecodeError:
pass # 保持配置文件中的值
if 'TELEGRAM_ALLOWED_GROUPS' in os.environ:
try:
groups = json.loads(os.environ['TELEGRAM_ALLOWED_GROUPS'])
telegram_config['allowed_groups'] = groups
except json.JSONDecodeError:
pass # 保持配置文件中的值
# 处理临时目录路径,转换为绝对路径
if 'files' in telegram_config and 'temp_dir' in telegram_config['files']:
temp_dir = telegram_config['files']['temp_dir']
if not os.path.isabs(temp_dir):
# 相对路径转换为绝对路径
telegram_config['files']['temp_dir'] = os.path.join(
self._project_root, temp_dir
)
return telegram_config
def get_logging_config(self) -> Dict[str, Any]:
"""获取日志配置"""
return self.get_section('logging')
def get_webhook_config(self) -> Dict[str, Any]:
"""获取Webhook配置,支持环境变量覆盖"""
webhook_config = self.get_section('webhook').copy()
# 环境变量覆盖(优先级最高)
if 'WEBHOOK_SERVER_URL' in os.environ:
webhook_config['server_url'] = os.environ['WEBHOOK_SERVER_URL']
if 'WEBHOOK_AUTH_TOKEN' in os.environ:
webhook_config['auth_token'] = os.environ['WEBHOOK_AUTH_TOKEN']
if 'WEBHOOK_HOST' in os.environ and 'server' in webhook_config:
webhook_config['server']['host'] = os.environ['WEBHOOK_HOST']
if 'WEBHOOK_PORT' in os.environ and 'server' in webhook_config:
try:
webhook_config['server']['port'] = int(os.environ['WEBHOOK_PORT'])
except ValueError:
pass # 保持配置文件中的值
return webhook_config
def reload(self):
"""重新加载配置"""
self._load_config()
@property
def project_root(self) -> str:
"""获取项目根目录路径"""
return self._project_root
# 全局配置实例
_config_manager: Optional[ConfigManager] = None
def get_config_manager(config_name: str = None) -> ConfigManager:
"""
获取全局配置管理器实例
Args:
config_name: 配置文件名,如果为None则使用已有实例或默认配置
Returns:
ConfigManager实例
"""
global _config_manager
if _config_manager is None or config_name is not None:
config_name = config_name or os.getenv('CLAUDE_CONFIG', 'default')
_config_manager = ConfigManager(config_name)
return _config_manager
def get_config(key_path: str, default: Any = None) -> Any:
"""
获取配置值的便捷函数
Args:
key_path: 配置键路径
default: 默认值
Returns:
配置值
"""
return get_config_manager().get(key_path, default)