blob: e2a57447e47aef295fd44f5559aee03449ee5e42 [file] [log] [blame] [raw]
"""
Utils helpers模块单元测试
测试工具包模块的各种辅助功能
"""
import json
import logging
import os
import sys
import tempfile
from pathlib import Path
from unittest.mock import patch, mock_open

import pytest
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../../../src'))
from claude_agent.utils.helpers import (
ConfigManager, Logger, DataValidator,
TextUtils, FileUtils
)
class TestConfigManager:
    """Tests for ConfigManager."""

    def test_init_default_path(self):
        """Without an argument the config path points at a file named config.json."""
        config_manager = ConfigManager()
        assert config_manager.config_path.name == "config.json"

    def test_init_custom_path(self, tmp_path):
        """A caller-supplied path is stored on the instance unchanged."""
        # A per-test directory instead of a fixed /tmp path: a fixed path is
        # not portable (e.g. Windows) and could pick up a stale file left by
        # another run.
        custom_path = str(tmp_path / "test_config.json")
        config_manager = ConfigManager(custom_path)
        assert str(config_manager.config_path) == custom_path

    def test_get_default_config(self):
        """get_default_config exposes the expected keys and the default mode."""
        config_manager = ConfigManager()
        default_config = config_manager.get_default_config()
        assert "default_mode" in default_config
        assert "theme" in default_config
        assert default_config["default_mode"] == "interactive"

    def test_get_and_set_config(self, tmp_path):
        """get() reads loaded values (with a fallback default); set() stores new ones."""
        # tmp_path is cleaned up by pytest, so no manual unlink is needed and
        # nothing leaks if writing the fixture fails.
        config_file = tmp_path / "test_config.json"
        config_file.write_text(json.dumps({"test_key": "test_value"}), encoding="utf-8")
        config_manager = ConfigManager(str(config_file))

        # get
        assert config_manager.get("test_key") == "test_value"
        assert config_manager.get("nonexistent", "default") == "default"

        # set
        config_manager.set("new_key", "new_value")
        assert config_manager.config["new_key"] == "new_value"

    def test_update_config(self, tmp_path):
        """update() applies several key/value pairs at once."""
        config_file = tmp_path / "test_config.json"
        config_file.write_text(json.dumps({}), encoding="utf-8")
        config_manager = ConfigManager(str(config_file))

        updates = {
            "key1": "value1",
            "key2": "value2",
        }
        config_manager.update(updates)
        assert config_manager.config["key1"] == "value1"
        assert config_manager.config["key2"] == "value2"

    def test_load_config_file_not_exists(self, tmp_path):
        """A missing config file falls back to the default configuration."""
        non_existent_path = str(tmp_path / "non_existent.json")
        config_manager = ConfigManager(non_existent_path)
        assert "default_mode" in config_manager.config

    def test_save_config_exception_handling(self):
        """save_config swallows write errors and reports them via logging.error."""
        with patch('pathlib.Path.mkdir'), \
                patch('pathlib.Path.exists', return_value=True):
            config_manager = ConfigManager('/test/path.json')
            # Keep Path.mkdir mocked while saving so the only failure is the
            # simulated write error below, not a real attempt to create /test.
            with patch('builtins.open', side_effect=OSError("Permission denied")), \
                    patch('logging.error') as mock_log:
                # Must not raise; the failure should be logged instead.
                config_manager.save_config()
                mock_log.assert_called()
class TestLogger:
    """Tests for Logger.setup_logger."""

    def test_setup_logger_basic(self):
        """Without an explicit level the logger gets the requested name and INFO level."""
        logger = Logger.setup_logger("test_logger")
        assert logger.name == "test_logger"
        assert logger.level == logging.INFO

    def test_setup_logger_with_level(self):
        """An explicit level name is applied to the logger."""
        logger = Logger.setup_logger("test_logger_debug", "DEBUG")
        assert logger.name == "test_logger_debug"
        assert logger.level == logging.DEBUG

    def test_setup_logger_invalid_level_coverage(self):
        """An unknown level name falls back to INFO."""
        logger = Logger.setup_logger("test_invalid", "INVALID_LEVEL")
        assert logger.level == logging.INFO
class TestDataValidator:
    """Tests for DataValidator."""

    @pytest.mark.parametrize("key", [
        "sk-1234567890abcdef",
        "sk-very-long-api-key-with-many-characters",
        "sk-123",
    ])
    def test_validate_api_key_valid(self, key):
        """Well-formed sk- keys are accepted."""
        assert DataValidator.validate_api_key(key) is True

    # None is intentionally not listed: this suite does not pin down how
    # validate_api_key treats None (unlike validate_tool_call, tested below).
    @pytest.mark.parametrize("key", ["", "sk-", "invalid-key", "sk"])
    def test_validate_api_key_invalid(self, key):
        """Empty, prefix-only and non-sk- keys are rejected."""
        assert DataValidator.validate_api_key(key) is False

    @pytest.mark.parametrize("plan", [
        {"steps": ["step1", "step2"], "goal": "test goal"},
        {"steps": [], "goal": "empty steps"},
        {"steps": ["only step"], "goal": "single step"},
    ])
    def test_validate_task_plan_valid(self, plan):
        """Plans with a steps list and a non-empty goal are accepted."""
        assert DataValidator.validate_task_plan(plan) is True

    def test_validate_task_plan_invalid_structure(self):
        """A plan that is not a dict is rejected."""
        invalid_plan = "not a dict"
        assert DataValidator.validate_task_plan(invalid_plan) is False

    # As above, None is intentionally not listed for validate_task_plan.
    @pytest.mark.parametrize("plan", [
        {},
        {"steps": []},
        {"goal": "no steps"},
        {"steps": "not a list", "goal": "invalid steps"},
        {"steps": [], "goal": ""},
    ], ids=["empty", "missing-goal", "missing-steps", "steps-not-list", "empty-goal"])
    def test_validate_task_plan_invalid(self, plan):
        """Plans missing a field, with a non-list steps value, or an empty goal are rejected."""
        assert DataValidator.validate_task_plan(plan) is False

    @pytest.mark.parametrize("call", [
        {"name": "test_tool", "args": {}},
        {"name": "tool_with_args", "args": {"param": "value"}},
    ])
    def test_validate_tool_call_valid(self, call):
        """Calls with a name and an args dict are accepted."""
        assert DataValidator.validate_tool_call(call) is True

    def test_validate_tool_call_missing_name(self):
        """A tool call without a name field is rejected."""
        invalid_tool_call = {"args": {"param": "value"}}
        assert DataValidator.validate_tool_call(invalid_tool_call) is False

    def test_validate_tool_call_empty_name(self):
        """A tool call whose name is empty is rejected."""
        invalid_tool_call = {"name": "", "args": {}}
        assert DataValidator.validate_tool_call(invalid_tool_call) is False

    @pytest.mark.parametrize("call", [
        {},
        {"name": "tool"},  # missing args
        None,
        "not_a_dict",      # non-dict input
        123,               # non-dict input
        [],                # non-dict input
    ])
    def test_validate_tool_call_invalid(self, call):
        """Incomplete calls and non-dict inputs (including None) are rejected."""
        assert DataValidator.validate_tool_call(call) is False
class TestTextUtils:
    """Tests for the TextUtils string helpers."""

    def test_truncate_text_basic(self):
        """Long input is shortened and ends with the default '...' suffix."""
        truncated = TextUtils.truncate_text(
            "This is a long text that needs to be truncated", 20
        )
        # Upper bound leaves room for the 3-character suffix on top of the limit.
        assert len(truncated) <= 20 + len("...")
        assert truncated.endswith("...")
        assert truncated.startswith("This is a long")

    def test_truncate_text_short(self):
        """Input shorter than the limit comes back unchanged."""
        short = "Short text"
        assert TextUtils.truncate_text(short, 20) == short

    def test_truncate_text_custom_suffix(self):
        """A caller-supplied suffix is appended instead of the default."""
        outcome = TextUtils.truncate_text(
            "Long text for truncation testing", 10, suffix="---"
        )
        assert outcome.endswith("---")

    def test_clean_json_string(self):
        """A ```json fenced block is reduced to the bare JSON text."""
        fenced = '```json\n{"key": "value"}\n```'
        assert TextUtils.clean_json_string(fenced) == '{"key": "value"}'

    def test_clean_json_string_multiple_blocks(self):
        """With several fenced blocks, the JSON survives and no fence markers remain."""
        fenced = (
            'Some text ```json\n{"key": "value"}\n``` '
            'more text ```\nother block\n```'
        )
        cleaned = TextUtils.clean_json_string(fenced)
        assert '{"key": "value"}' in cleaned
        assert '```' not in cleaned

    def test_extract_code_blocks_python(self):
        """A python-tagged block is extracted with its full body."""
        document = (
            "\n"
            "Here is some code:\n"
            "```python\n"
            'print("hello")\n'
            "def test():\n"
            "pass\n"
            "```\n"
            "And more text.\n"
        )
        found = TextUtils.extract_code_blocks(document, "python")
        assert len(found) == 1
        body = found[0]
        assert 'print("hello")' in body
        assert 'def test():' in body

    def test_extract_code_blocks_empty_result(self):
        """Plain text without any fences yields an empty list."""
        plain = "This is just plain text without any code blocks."
        assert TextUtils.extract_code_blocks(plain) == []

    def test_extract_code_blocks_all(self):
        """Without a language filter, every fenced block is returned."""
        document = (
            "\n"
            "```python\n"
            'print("python")\n'
            "```\n"
            "```javascript\n"
            'console.log("js");\n'
            "```\n"
        )
        assert len(TextUtils.extract_code_blocks(document)) == 2

    def test_clean_json_string_edge_case(self):
        """Unfenced JSON padded with spaces and newlines keeps its JSON content."""
        padded = ' {"key": "value"} \n\n'
        assert TextUtils.clean_json_string(padded).strip() == '{"key": "value"}'

    def test_format_error_message(self):
        """The formatted message names the exception type, its text and the context."""
        message = TextUtils.format_error_message(
            ValueError("Test error message"), "test context"
        )
        for fragment in ("ValueError", "Test error message", "test context"):
            assert fragment in message

    def test_format_error_message_no_context(self):
        """Without a context the message still names the exception type and text."""
        message = TextUtils.format_error_message(RuntimeError("Runtime error"))
        for fragment in ("RuntimeError", "Runtime error"):
            assert fragment in message
class TestFileUtils:
    """Tests for the FileUtils filesystem helpers."""

    def test_ensure_directory(self, tmp_path):
        """Missing intermediate directories are created."""
        test_dir = tmp_path / "new_directory" / "nested"
        FileUtils.ensure_directory(str(test_dir))
        assert test_dir.exists()
        assert test_dir.is_dir()

    def test_read_json_file_valid(self, tmp_path):
        """A valid JSON file is parsed into the original data."""
        test_data = {"key": "value", "number": 42}
        # tmp_path is cleaned up by pytest, so nothing leaks if a step fails
        # (the old NamedTemporaryFile was created outside try/finally).
        source = tmp_path / "valid.json"
        source.write_text(json.dumps(test_data), encoding="utf-8")
        assert FileUtils.read_json_file(str(source)) == test_data

    def test_read_json_file_not_exists(self, tmp_path):
        """Reading a path that does not exist returns None."""
        # Missing parent directory inside tmp_path: guaranteed absent, unlike
        # a hard-coded absolute path.
        missing = tmp_path / "nonexistent" / "file.json"
        assert FileUtils.read_json_file(str(missing)) is None

    def test_read_json_file_invalid_json(self, tmp_path):
        """Malformed JSON content returns None."""
        broken = tmp_path / "invalid.json"
        broken.write_text("invalid json content", encoding="utf-8")
        assert FileUtils.read_json_file(str(broken)) is None

    def test_write_json_file_exception_handling(self):
        """write_json_file swallows write errors and reports them via logging.error."""
        test_data = {"key": "value"}
        # Make every file open fail.
        with patch('builtins.open', side_effect=OSError("Permission denied")), \
                patch('logging.error') as mock_log:
            # Must not raise; the failure should be logged instead.
            FileUtils.write_json_file('/test/path.json', test_data)
            mock_log.assert_called()

    def test_write_json_file(self, tmp_path):
        """Written data can be read back as identical JSON."""
        test_data = {"test": "data", "array": [1, 2, 3]}
        target = tmp_path / "output.json"
        FileUtils.write_json_file(str(target), test_data)
        # Read back independently of FileUtils to check what landed on disk.
        assert json.loads(target.read_text(encoding="utf-8")) == test_data

    def test_get_project_root_no_markers(self):
        """With no project marker files anywhere, the current directory is returned."""
        with patch('pathlib.Path.exists', return_value=False), \
                patch('pathlib.Path.cwd', return_value=Path('/current/dir')):
            result = FileUtils.get_project_root()
            assert result == Path('/current/dir')

    def test_get_project_root(self):
        """The detected project root is an existing directory."""
        project_root = FileUtils.get_project_root()
        assert isinstance(project_root, Path)
        assert project_root.exists()
        assert project_root.is_dir()

    def test_ensure_directory_with_path_mkdir(self):
        """ensure_directory delegates to Path.mkdir(parents=True, exist_ok=True)."""
        with patch('pathlib.Path.mkdir') as mock_mkdir:
            mock_mkdir.return_value = None
            FileUtils.ensure_directory('/test/path')
            mock_mkdir.assert_called_with(parents=True, exist_ok=True)
if __name__ == '__main__':
    # Propagate pytest's exit status so a failing run is visible to the shell/CI;
    # discarding it would make the script exit 0 even when tests fail.
    sys.exit(pytest.main([__file__]))