| """ |
| Utils helpers模块单元测试 |
| 测试工具包模块的各种辅助功能 |
| """ |
| |
import json
import logging
import os
import tempfile
from pathlib import Path
from unittest.mock import mock_open, patch

import pytest
| |
| import sys |
| sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../../../src')) |
| |
| from claude_agent.utils.helpers import ( |
| ConfigManager, Logger, DataValidator, |
| TextUtils, FileUtils |
| ) |
| |
| |
class TestConfigManager:
    """Tests for ConfigManager (construction, defaults, get/set/update, save errors)."""

    def test_init_default_path(self):
        """With no argument, the configured file is named ``config.json``."""
        config_manager = ConfigManager()
        assert config_manager.config_path.name == "config.json"

    def test_init_custom_path(self, tmp_path):
        """An explicit path argument is what ends up in ``config_path``.

        Uses pytest's per-test ``tmp_path`` rather than a fixed ``/tmp`` file so
        runs never share state, and compares ``Path`` objects so the check does
        not depend on the platform's path separator.
        """
        custom_path = str(tmp_path / "test_config.json")
        config_manager = ConfigManager(custom_path)
        assert config_manager.config_path == Path(custom_path)

    def test_get_default_config(self):
        """The default config contains ``default_mode`` ('interactive') and ``theme``."""
        config_manager = ConfigManager()
        default_config = config_manager.get_default_config()

        assert "default_mode" in default_config
        assert "theme" in default_config
        assert default_config["default_mode"] == "interactive"

    def test_get_and_set_config(self, tmp_path):
        """get() reads loaded values or falls back to a default; set() updates ``config``."""
        config_file = tmp_path / "config.json"
        config_file.write_text(json.dumps({"test_key": "test_value"}), encoding="utf-8")

        config_manager = ConfigManager(str(config_file))

        # get(): a key loaded from the file, and a missing key with an explicit default
        assert config_manager.get("test_key") == "test_value"
        assert config_manager.get("nonexistent", "default") == "default"

        # set(): the new value is visible in the in-memory config dict
        config_manager.set("new_key", "new_value")
        assert config_manager.config["new_key"] == "new_value"

    def test_update_config(self, tmp_path):
        """update() applies every key/value pair of the given mapping."""
        config_file = tmp_path / "config.json"
        config_file.write_text(json.dumps({}), encoding="utf-8")

        config_manager = ConfigManager(str(config_file))
        config_manager.update({"key1": "value1", "key2": "value2"})

        assert config_manager.config["key1"] == "value1"
        assert config_manager.config["key2"] == "value2"

    def test_load_config_file_not_exists(self, tmp_path):
        """A missing config file results in the default configuration being loaded."""
        non_existent_path = str(tmp_path / "non_existent.json")
        config_manager = ConfigManager(non_existent_path)

        assert "default_mode" in config_manager.config

    def test_save_config_exception_handling(self):
        """save_config() logs (rather than raises) when writing the file fails.

        Coverage target at time of writing: the error branch of save_config
        (helpers.py lines 43-44).
        """
        # Construct without touching the real filesystem.
        with patch('pathlib.Path.mkdir'), \
                patch('pathlib.Path.exists', return_value=True):
            config_manager = ConfigManager('/test/path.json')

        # Simulate an OS-level failure when the file is opened for writing.
        with patch('builtins.open', side_effect=OSError("Permission denied")), \
                patch('logging.error') as mock_log:
            # Must swallow the exception and log it instead of propagating.
            config_manager.save_config()
            mock_log.assert_called()
| |
| |
class TestLogger:
    """Tests for Logger.setup_logger (name and level handling)."""

    def test_setup_logger_basic(self):
        """Without an explicit level the logger gets the requested name and INFO level."""
        logger = Logger.setup_logger("test_logger")

        assert logger.name == "test_logger"
        assert logger.level == logging.INFO

    def test_setup_logger_with_level(self):
        """A level name such as "DEBUG" is applied to the returned logger."""
        logger = Logger.setup_logger("test_logger_debug", "DEBUG")

        assert logger.name == "test_logger_debug"
        assert logger.level == logging.DEBUG

    def test_setup_logger_invalid_level_coverage(self):
        """An unrecognised level name falls back to INFO.

        Coverage target at time of writing: helpers.py lines 82-83.
        """
        logger = Logger.setup_logger("test_invalid", "INVALID_LEVEL")
        assert logger.level == logging.INFO
| |
| |
class TestDataValidator:
    """Tests for DataValidator's API-key, task-plan and tool-call checks."""

    def test_validate_api_key_valid(self):
        """Keys with the ``sk-`` prefix and a non-empty suffix are accepted."""
        valid_keys = [
            "sk-1234567890abcdef",
            "sk-very-long-api-key-with-many-characters",
            "sk-123",
        ]

        for key in valid_keys:
            assert DataValidator.validate_api_key(key) is True, f"expected valid: {key!r}"

    def test_validate_api_key_invalid(self):
        """Empty, prefix-only, wrongly-prefixed and truncated keys are rejected."""
        # None is intentionally absent: this suite does not pin
        # validate_api_key's behaviour for None.
        invalid_keys = [
            "",
            "sk-",
            "invalid-key",
            "sk",
        ]

        for key in invalid_keys:
            assert DataValidator.validate_api_key(key) is False, f"expected invalid: {key!r}"

    def test_validate_task_plan_valid(self):
        """Plans with a list of steps (possibly empty) and a non-empty goal are accepted."""
        valid_plans = [
            {"steps": ["step1", "step2"], "goal": "test goal"},
            {"steps": [], "goal": "empty steps"},
            {"steps": ["only step"], "goal": "single step"},
        ]

        for plan in valid_plans:
            assert DataValidator.validate_task_plan(plan) is True, f"expected valid: {plan!r}"

    def test_validate_task_plan_invalid_structure(self):
        """A non-dict plan is rejected.

        Coverage target at time of writing: helpers.py lines 111-112.
        """
        invalid_plan = "not a dict"
        assert DataValidator.validate_task_plan(invalid_plan) is False

    def test_validate_task_plan_invalid(self):
        """Plans missing keys, with non-list steps, or with an empty goal are rejected."""
        # None is intentionally absent: this suite does not pin
        # validate_task_plan's behaviour for None.
        invalid_plans = [
            {},
            {"steps": []},                                   # missing goal
            {"goal": "no steps"},                            # missing steps
            {"steps": "not a list", "goal": "invalid steps"},
            {"steps": [], "goal": ""},                       # empty goal
        ]

        for plan in invalid_plans:
            assert DataValidator.validate_task_plan(plan) is False, f"expected invalid: {plan!r}"

    def test_validate_tool_call_valid(self):
        """Calls with a non-empty name and an args dict are accepted."""
        valid_calls = [
            {"name": "test_tool", "args": {}},
            {"name": "tool_with_args", "args": {"param": "value"}},
        ]

        for call in valid_calls:
            assert DataValidator.validate_tool_call(call) is True, f"expected valid: {call!r}"

    def test_validate_tool_call_missing_name(self):
        """A call without a ``name`` key is rejected.

        Coverage target at time of writing: helpers.py lines 132-134.
        """
        invalid_tool_call = {"args": {"param": "value"}}
        assert DataValidator.validate_tool_call(invalid_tool_call) is False

    def test_validate_tool_call_empty_name(self):
        """A call whose ``name`` is an empty string is rejected.

        Coverage target at time of writing: helpers.py lines 136-137.
        """
        invalid_tool_call = {"name": "", "args": {}}
        assert DataValidator.validate_tool_call(invalid_tool_call) is False

    def test_validate_tool_call_invalid(self):
        """Empty dicts, calls without args, None and non-dict values are rejected."""
        invalid_calls = [
            {},
            {"name": "tool"},   # missing args
            None,
            "not_a_dict",       # non-dict type
            123,                # non-dict type
            [],                 # non-dict type
        ]

        for call in invalid_calls:
            assert DataValidator.validate_tool_call(call) is False, f"expected invalid: {call!r}"
| |
| |
class TestTextUtils:
    """Tests for TextUtils: truncation, JSON cleaning, code-block extraction, error formatting."""

    def test_truncate_text_basic(self):
        """Long text is cut to at most max_length plus the default "..." suffix."""
        text = "This is a long text that needs to be truncated"
        result = TextUtils.truncate_text(text, 20)

        assert len(result) <= 23  # 20 + len("...")
        assert result.endswith("...")
        assert result.startswith("This is a long")

    def test_truncate_text_short(self):
        """Text shorter than max_length is returned unchanged."""
        text = "Short text"
        result = TextUtils.truncate_text(text, 20)

        assert result == text

    def test_truncate_text_custom_suffix(self):
        """A caller-supplied suffix replaces the default one."""
        text = "Long text for truncation testing"
        result = TextUtils.truncate_text(text, 10, suffix="---")

        assert result.endswith("---")

    def test_clean_json_string(self):
        """A ```json fenced block is reduced to the bare JSON text."""
        dirty_json = '```json\n{"key": "value"}\n```'
        clean_result = TextUtils.clean_json_string(dirty_json)

        assert clean_result == '{"key": "value"}'

    def test_clean_json_string_multiple_blocks(self):
        """With several fenced blocks, the JSON survives and no fence markers remain."""
        dirty_json = 'Some text ```json\n{"key": "value"}\n``` more text ```\nother block\n```'
        clean_result = TextUtils.clean_json_string(dirty_json)

        assert '{"key": "value"}' in clean_result
        assert '```' not in clean_result

    def test_extract_code_blocks_python(self):
        """Filtering by language returns only the python block, with its full body."""
        # Built from explicit literals so the fixture does not inherit the
        # test method's source indentation.
        text = (
            "\n"
            "Here is some code:\n"
            "```python\n"
            'print("hello")\n'
            "def test():\n"
            "    pass\n"
            "```\n"
            "And more text.\n"
        )

        blocks = TextUtils.extract_code_blocks(text, "python")
        assert len(blocks) == 1
        assert 'print("hello")' in blocks[0]
        assert 'def test():' in blocks[0]

    def test_extract_code_blocks_empty_result(self):
        """Text without fences yields an empty list.

        Coverage target at time of writing: helpers.py lines 172-173.
        """
        text_without_code = "This is just plain text without any code blocks."
        result = TextUtils.extract_code_blocks(text_without_code)
        assert result == []

    def test_extract_code_blocks_all(self):
        """Without a language filter every fenced block is returned."""
        text = (
            "\n"
            "```python\n"
            'print("python")\n'
            "```\n"
            "```javascript\n"
            'console.log("js");\n'
            "```\n"
        )

        blocks = TextUtils.extract_code_blocks(text)
        assert len(blocks) == 2
        # Order-independent content checks: both bodies must have been captured.
        assert any('print("python")' in block for block in blocks)
        assert any('console.log("js");' in block for block in blocks)

    def test_clean_json_string_edge_case(self):
        """Surrounding whitespace and newlines do not corrupt the JSON payload."""
        messy_json = ' {"key": "value"} \n\n'
        result = TextUtils.clean_json_string(messy_json)
        assert result.strip() == '{"key": "value"}'

    def test_format_error_message(self):
        """The formatted message includes exception type, message and context."""
        error = ValueError("Test error message")
        formatted = TextUtils.format_error_message(error, "test context")

        assert "ValueError" in formatted
        assert "Test error message" in formatted
        assert "test context" in formatted

    def test_format_error_message_no_context(self):
        """Context is optional; type and message are still included."""
        error = RuntimeError("Runtime error")
        formatted = TextUtils.format_error_message(error)

        assert "RuntimeError" in formatted
        assert "Runtime error" in formatted
| |
| |
class TestFileUtils:
    """Tests for FileUtils: directories, JSON read/write, project-root discovery."""

    def test_ensure_directory(self, tmp_path):
        """Missing nested directories are created."""
        test_dir = tmp_path / "new_directory" / "nested"

        FileUtils.ensure_directory(str(test_dir))
        assert test_dir.exists()
        assert test_dir.is_dir()

    def test_read_json_file_valid(self, tmp_path):
        """A valid JSON file is parsed and returned."""
        test_data = {"key": "value", "number": 42}
        json_file = tmp_path / "data.json"
        json_file.write_text(json.dumps(test_data), encoding="utf-8")

        result = FileUtils.read_json_file(str(json_file))
        assert result == test_data

    def test_read_json_file_not_exists(self, tmp_path):
        """A path that does not exist yields None.

        The path lives under ``tmp_path`` so it is guaranteed absent on every
        platform.
        """
        missing = tmp_path / "nonexistent" / "file.json"
        result = FileUtils.read_json_file(str(missing))
        assert result is None

    def test_read_json_file_invalid_json(self, tmp_path):
        """Malformed JSON content yields None."""
        bad_file = tmp_path / "bad.json"
        bad_file.write_text("invalid json content", encoding="utf-8")

        result = FileUtils.read_json_file(str(bad_file))
        assert result is None

    def test_write_json_file_exception_handling(self, tmp_path):
        """write_json_file logs (rather than raises) when opening the file fails.

        The target sits under ``tmp_path`` so that any directory handling done
        before ``open`` never touches the real filesystem root.
        Coverage target at time of writing: helpers.py lines 211-212.
        """
        test_data = {"key": "value"}
        target = str(tmp_path / "path.json")

        # Simulate an OS-level failure when the file is opened for writing.
        with patch('builtins.open', side_effect=OSError("Permission denied")), \
                patch('logging.error') as mock_log:
            # Must swallow the exception and log it instead of propagating.
            FileUtils.write_json_file(target, test_data)
            mock_log.assert_called()

    def test_write_json_file(self, tmp_path):
        """Written data round-trips through json.loads unchanged."""
        test_data = {"test": "data", "array": [1, 2, 3]}
        target = tmp_path / "output.json"

        FileUtils.write_json_file(str(target), test_data)

        loaded_data = json.loads(target.read_text(encoding="utf-8"))
        assert loaded_data == test_data

    def test_get_project_root_no_markers(self):
        """With no marker files found anywhere, the current directory is returned.

        Coverage target at time of writing: helpers.py line 222.
        """
        with patch('pathlib.Path.exists', return_value=False), \
                patch('pathlib.Path.cwd', return_value=Path('/current/dir')):
            result = FileUtils.get_project_root()
            assert result == Path('/current/dir')

    def test_get_project_root(self):
        """The discovered project root is an existing directory Path."""
        project_root = FileUtils.get_project_root()

        assert isinstance(project_root, Path)
        assert project_root.exists()
        assert project_root.is_dir()

    def test_ensure_directory_with_path_mkdir(self):
        """ensure_directory delegates to Path.mkdir(parents=True, exist_ok=True)."""
        with patch('pathlib.Path.mkdir') as mock_mkdir:
            mock_mkdir.return_value = None

            FileUtils.ensure_directory('/test/path')
            mock_mkdir.assert_called_with(parents=True, exist_ok=True)
| |
| |
if __name__ == '__main__':
    # Forward pytest's exit status so running this file directly reports
    # failures to the shell / CI instead of always exiting 0.
    sys.exit(pytest.main([__file__]))