| """ |
| 辅助工具模块的单元测试 |
| """ |
| |
import json
import logging
import tempfile
from pathlib import Path
from unittest.mock import Mock, patch

import pytest

from claude_agent.utils.helpers import (
    ConfigManager, Logger, DataValidator,
    TextUtils, FileUtils
)
| |
| |
class TestConfigManager:
    """Tests for ConfigManager."""

    def test_config_manager_with_custom_path(self):
        """A JSON config file at a caller-supplied path is loaded by ConfigManager."""
        payload = {"test_key": "test_value"}
        with tempfile.NamedTemporaryFile(mode='w', delete=False) as handle:
            json.dump(payload, handle)
            path_on_disk = handle.name

        # delete=False keeps the file after the `with` closes it, so it must
        # be removed explicitly whether or not the assertion passes.
        try:
            manager = ConfigManager(config_path=path_on_disk)
            loaded = manager.get("test_key")
            assert loaded == "test_value"
        finally:
            Path(path_on_disk).unlink()

    def test_config_manager_default_config(self):
        """get_default_config exposes api_key and an "interactive" default_mode."""
        with tempfile.TemporaryDirectory() as workdir:
            # The file is never created: ConfigManager is pointed at a missing path.
            missing_file = Path(workdir).joinpath("nonexistent.json")
            manager = ConfigManager(config_path=str(missing_file))

            defaults = manager.get_default_config()
            for required in ("api_key", "default_mode"):
                assert required in defaults
            assert defaults["default_mode"] == "interactive"

    def test_config_get_set(self):
        """Values stored with set() come back from get(); missing keys use the default."""
        with tempfile.TemporaryDirectory() as workdir:
            manager = ConfigManager(config_path=str(Path(workdir) / "test_config.json"))

            # Round-trip a single key.
            manager.set("test_key", "test_value")
            assert manager.get("test_key") == "test_value"

            # Unknown key falls back to the supplied default.
            assert manager.get("nonexistent", "default") == "default"

    def test_config_update(self):
        """update() applies several keys at once."""
        with tempfile.TemporaryDirectory() as workdir:
            manager = ConfigManager(config_path=str(Path(workdir) / "test_config.json"))

            manager.update({"key1": "value1", "key2": "value2"})

            observed = [manager.get("key1"), manager.get("key2")]
            assert observed == ["value1", "value2"]
| |
| |
class TestLogger:
    """Tests for Logger."""

    def test_setup_logger(self):
        """setup_logger applies the requested name and level and attaches a handler."""
        logger = Logger.setup_logger("test_logger", "DEBUG")

        assert logger.name == "test_logger"
        # Compare against the named constant rather than the magic number 10.
        assert logger.level == logging.DEBUG
        assert logger.handlers, "setup_logger should attach at least one handler"

    def test_setup_logger_default_level(self):
        """Without an explicit level, setup_logger uses INFO."""
        logger = Logger.setup_logger("test_logger_default")

        # Compare against the named constant rather than the magic number 20.
        assert logger.level == logging.INFO
| |
| |
class TestDataValidator:
    """Tests for DataValidator.

    These tests pin the contract that the validators return real booleans:
    results are checked with ``is True`` / ``is False`` instead of ``==``
    (PEP 8, flake8 E712), because ``==`` would also accept ``1`` / ``0``.
    """

    def test_validate_api_key_valid(self):
        """Keys in the expected formats are accepted."""
        valid_keys = [
            "sk-1234567890abcdef",
            "claude-abcdef1234567890",
        ]

        for key in valid_keys:
            assert DataValidator.validate_api_key(key) is True

    def test_validate_api_key_invalid(self):
        """Empty, None, short, wrongly-prefixed and non-string keys are rejected."""
        invalid_keys = [
            "",
            None,
            "short",
            "invalid_prefix_key",
            123,  # not a string
        ]

        for key in invalid_keys:
            assert DataValidator.validate_api_key(key) is False

    def test_validate_task_plan_valid(self):
        """A plan with main_task and a list of described subtasks is accepted."""
        valid_plan = {
            "main_task": "测试任务",
            "subtasks": [
                {
                    "id": "task_1",
                    "description": "子任务1",
                },
            ],
        }

        assert DataValidator.validate_task_plan(valid_plan) is True

    def test_validate_task_plan_invalid(self):
        """Plans missing required structure are rejected."""
        invalid_plans = [
            {},  # missing required keys
            {"main_task": "test"},  # missing subtasks
            {"main_task": "test", "subtasks": "not_list"},  # subtasks is not a list
            {"main_task": "test", "subtasks": [{"id": "1"}]},  # subtask lacks a description
        ]

        for plan in invalid_plans:
            assert DataValidator.validate_task_plan(plan) is False

    def test_validate_tool_call_valid(self):
        """A tool call with both name and arguments is accepted."""
        valid_call = {
            "name": "test_tool",
            "arguments": {"param": "value"},
        }

        assert DataValidator.validate_tool_call(valid_call) is True

    def test_validate_tool_call_invalid(self):
        """Tool calls missing name and/or arguments are rejected."""
        invalid_calls = [
            {},  # missing required keys
            {"name": "test"},  # missing arguments
            {"arguments": {}},  # missing name
        ]

        for call in invalid_calls:
            assert DataValidator.validate_tool_call(call) is False
| |
| |
class TestTextUtils:
    """Tests for TextUtils."""

    def test_truncate_text_short(self):
        """Text within max_length is returned unchanged."""
        original = "短文本"
        assert TextUtils.truncate_text(original, max_length=100) == original

    def test_truncate_text_long(self):
        """Over-long text is cut to at most max_length and ends with "..."."""
        long_text = "这是一个很长的文本" * 10
        truncated = TextUtils.truncate_text(long_text, max_length=20)
        assert len(truncated) <= 20
        assert truncated.endswith("...")

    def test_clean_json_string(self):
        """Markdown code fences and surrounding whitespace are stripped from JSON text."""
        cases = (
            ("```json\n{\"test\": 1}\n```", '{"test": 1}'),
            ("```\n{\"test\": 2}\n```", '{"test": 2}'),
            ('{"test": 3}', '{"test": 3}'),
            (" \n{\"test\": 4}\n ", '{"test": 4}'),
        )

        for raw, cleaned in cases:
            assert TextUtils.clean_json_string(raw) == cleaned

    def test_extract_code_blocks(self):
        """Code blocks can be filtered by language tag or extracted all at once."""
        text = """
一些文本
```python
print("hello")
```
更多文本
```javascript
console.log("world")
```
"""

        # Only the block tagged "python".
        python_only = TextUtils.extract_code_blocks(text, "python")
        assert len(python_only) == 1
        assert 'print("hello")' in python_only[0]

        # Every fenced block, regardless of language tag.
        every_block = TextUtils.extract_code_blocks(text)
        assert len(every_block) == 2

    def test_format_error_message(self):
        """The exception type and message appear, prefixed by the context when given."""
        error = ValueError("测试错误")

        # Without a context argument.
        plain = TextUtils.format_error_message(error)
        assert "ValueError: 测试错误" in plain

        # With a context argument.
        contextual = TextUtils.format_error_message(error, "测试上下文")
        for fragment in ("[测试上下文]", "ValueError: 测试错误"):
            assert fragment in contextual
| |
| |
class TestFileUtils:
    """Tests for FileUtils."""

    def test_ensure_directory(self):
        """ensure_directory creates a nested directory path that did not exist."""
        with tempfile.TemporaryDirectory() as temp_dir:
            test_path = Path(temp_dir) / "new_dir" / "sub_dir"
            FileUtils.ensure_directory(str(test_path))
            assert test_path.exists()
            assert test_path.is_dir()

    def test_read_write_json_file(self):
        """A dict written with write_json_file reads back equal via read_json_file."""
        test_data = {"test": "data", "number": 42}
        # TemporaryDirectory cleans up after itself, so there is no
        # delete=False / manual unlink bookkeeping to get wrong.
        with tempfile.TemporaryDirectory() as temp_dir:
            json_path = str(Path(temp_dir) / "data.json")

            FileUtils.write_json_file(json_path, test_data)
            read_data = FileUtils.read_json_file(json_path)

            assert read_data == test_data

    def test_read_json_file_not_found(self):
        """read_json_file returns None for a path that does not exist."""
        # Build the path inside a fresh temporary directory so it is
        # guaranteed absent. A bare relative "nonexistent.json" would make the
        # result depend on the current working directory.
        with tempfile.TemporaryDirectory() as temp_dir:
            missing_path = str(Path(temp_dir) / "nonexistent.json")
            assert FileUtils.read_json_file(missing_path) is None

    def test_get_project_root(self):
        """get_project_root returns a result when the marker file is reported present.

        ``Path`` is patched inside the helpers module so that ``Path(...)``
        returns ``mock_file_path``, whose ``parent`` is ``mock_current``.
        ``mock_current / <anything>`` yields an object whose ``exists()`` is
        True, and ``mock_current.parent`` is ``mock_current`` itself.
        """
        with patch('claude_agent.utils.helpers.Path') as mock_path:
            # Any `/` join on mock_current reports an existing file (e.g. setup.py).
            mock_setup_exists = Mock()
            mock_setup_exists.exists.return_value = True

            mock_current = Mock()
            mock_current.__truediv__ = Mock(return_value=mock_setup_exists)
            # Self-referential parent prevents an infinite loop when walking upward.
            mock_current.parent = mock_current

            mock_file_path = Mock()
            mock_file_path.parent = mock_current
            mock_path.return_value = mock_file_path

            result = FileUtils.get_project_root()
            # NOTE(review): this only checks that *something* is returned; with
            # Path fully mocked almost any code path yields a non-None Mock.
            # Consider asserting a concrete value once the implementation of
            # get_project_root is pinned down.
            assert result is not None