# blob: b86decf00ef317430253eaf7479ab5df4a9e06c2 [file] [log] [blame] [raw]
"""
辅助工具模块的单元测试
"""
import json
import logging
import tempfile
from pathlib import Path
from unittest.mock import Mock, patch

import pytest

from claude_agent.utils.helpers import (
    ConfigManager,
    DataValidator,
    FileUtils,
    Logger,
    TextUtils,
)


class TestConfigManager:
    """Tests for ConfigManager.

    Every test uses pytest's built-in ``tmp_path`` fixture, so config files
    live in a per-test directory that pytest removes automatically; no manual
    unlink / try-finally bookkeeping is required.
    """

    def test_config_manager_with_custom_path(self, tmp_path):
        """A ConfigManager pointed at an existing JSON file exposes its values."""
        config_path = tmp_path / "custom_config.json"
        config_path.write_text(json.dumps({"test_key": "test_value"}), encoding="utf-8")

        config_manager = ConfigManager(config_path=str(config_path))

        assert config_manager.get("test_key") == "test_value"

    def test_config_manager_default_config(self, tmp_path):
        """get_default_config() contains the expected keys and default mode."""
        # The path is deliberately never created.
        config_path = tmp_path / "nonexistent.json"
        config_manager = ConfigManager(config_path=str(config_path))

        default_config = config_manager.get_default_config()

        assert "api_key" in default_config
        assert "default_mode" in default_config
        assert default_config["default_mode"] == "interactive"

    def test_config_get_set(self, tmp_path):
        """set() stores a value that get() returns; get() honours its default."""
        config_path = tmp_path / "test_config.json"
        config_manager = ConfigManager(config_path=str(config_path))

        # Round-trip a single value.
        config_manager.set("test_key", "test_value")
        assert config_manager.get("test_key") == "test_value"

        # A missing key falls back to the supplied default.
        assert config_manager.get("nonexistent", "default") == "default"

    def test_config_update(self, tmp_path):
        """update() applies every key/value pair from the given mapping."""
        config_path = tmp_path / "test_config.json"
        config_manager = ConfigManager(config_path=str(config_path))

        config_manager.update({"key1": "value1", "key2": "value2"})

        assert config_manager.get("key1") == "value1"
        assert config_manager.get("key2") == "value2"


class TestLogger:
    """Tests for Logger."""

    def test_setup_logger(self):
        """setup_logger honours an explicit level name and attaches a handler."""
        logger = Logger.setup_logger("test_logger", "DEBUG")

        assert logger.name == "test_logger"
        # Compare against the named constant rather than the magic number 10.
        assert logger.level == logging.DEBUG
        assert logger.handlers, "setup_logger should attach at least one handler"

    def test_setup_logger_default_level(self):
        """Without an explicit level, setup_logger uses INFO."""
        logger = Logger.setup_logger("test_logger_default")

        # Compare against the named constant rather than the magic number 20.
        assert logger.level == logging.INFO


class TestDataValidator:
    """Tests for DataValidator.

    The validators are expected to return real booleans, so results are
    checked with ``is True`` / ``is False`` rather than ``== True`` (PEP 8).
    Each case is parametrized, so a failure reports exactly which input broke
    and the remaining cases still run.
    """

    @pytest.mark.parametrize(
        "key",
        [
            "sk-1234567890abcdef",
            "claude-abcdef1234567890",
        ],
        ids=["sk_prefix", "claude_prefix"],
    )
    def test_validate_api_key_valid(self, key):
        """Well-formed API keys are accepted."""
        assert DataValidator.validate_api_key(key) is True

    @pytest.mark.parametrize(
        "key",
        [
            "",
            None,
            "short",
            "invalid_prefix_key",
            123,  # not a string at all
        ],
        ids=["empty", "none", "too_short", "bad_prefix", "not_a_string"],
    )
    def test_validate_api_key_invalid(self, key):
        """Empty, missing, malformed or non-string keys are rejected."""
        assert DataValidator.validate_api_key(key) is False

    def test_validate_task_plan_valid(self):
        """A plan with a main task and well-formed subtasks is accepted."""
        valid_plan = {
            "main_task": "测试任务",
            "subtasks": [
                {
                    "id": "task_1",
                    "description": "子任务1",
                },
            ],
        }
        assert DataValidator.validate_task_plan(valid_plan) is True

    @pytest.mark.parametrize(
        "plan",
        [
            {},  # missing every required key
            {"main_task": "test"},  # missing "subtasks"
            {"main_task": "test", "subtasks": "not_list"},  # "subtasks" is not a list
            {"main_task": "test", "subtasks": [{"id": "1"}]},  # subtask lacks "description"
        ],
        ids=["empty", "missing_subtasks", "subtasks_not_list", "subtask_missing_description"],
    )
    def test_validate_task_plan_invalid(self, plan):
        """Plans missing required structure are rejected."""
        assert DataValidator.validate_task_plan(plan) is False

    def test_validate_tool_call_valid(self):
        """A tool call with both a name and arguments is accepted."""
        valid_call = {
            "name": "test_tool",
            "arguments": {"param": "value"},
        }
        assert DataValidator.validate_tool_call(valid_call) is True

    @pytest.mark.parametrize(
        "call",
        [
            {},  # missing every required key
            {"name": "test"},  # missing "arguments"
            {"arguments": {}},  # missing "name"
        ],
        ids=["empty", "missing_arguments", "missing_name"],
    )
    def test_validate_tool_call_invalid(self, call):
        """Tool calls missing "name" or "arguments" are rejected."""
        assert DataValidator.validate_tool_call(call) is False


class TestTextUtils:
    """Tests for TextUtils."""

    def test_truncate_text_short(self):
        """Text shorter than max_length is returned unchanged."""
        text = "短文本"
        assert TextUtils.truncate_text(text, max_length=100) == text

    def test_truncate_text_long(self):
        """Text longer than max_length is cut to fit and ends with an ellipsis."""
        text = "这是一个很长的文本" * 10

        result = TextUtils.truncate_text(text, max_length=20)

        assert len(result) <= 20
        assert result.endswith("...")

    @pytest.mark.parametrize(
        ("input_text", "expected"),
        [
            ("```json\n{\"test\": 1}\n```", '{"test": 1}'),
            ("```\n{\"test\": 2}\n```", '{"test": 2}'),
            ('{"test": 3}', '{"test": 3}'),
            (" \n{\"test\": 4}\n ", '{"test": 4}'),
        ],
        ids=["json_fence", "bare_fence", "plain_json", "surrounding_whitespace"],
    )
    def test_clean_json_string(self, input_text, expected):
        """Markdown fences and surrounding whitespace are stripped from JSON."""
        assert TextUtils.clean_json_string(input_text) == expected

    def test_extract_code_blocks(self):
        """Fenced blocks can be extracted by language or all at once."""
        # Built from explicit lines so the fences always start at column 0.
        # A triple-quoted literal would pick up the method's indentation.
        text = "\n".join([
            "",
            "一些文本",
            "```python",
            'print("hello")',
            "```",
            "更多文本",
            "```javascript",
            'console.log("world")',
            "```",
            "",
        ])

        # Only the Python block is returned when filtering by language.
        python_blocks = TextUtils.extract_code_blocks(text, "python")
        assert len(python_blocks) == 1
        assert 'print("hello")' in python_blocks[0]

        # With no language filter, both blocks are returned, including the JS one.
        all_blocks = TextUtils.extract_code_blocks(text)
        assert len(all_blocks) == 2
        assert any('console.log("world")' in block for block in all_blocks)

    def test_format_error_message(self):
        """Errors render as 'Type: message', optionally prefixed by a context tag."""
        error = ValueError("测试错误")

        # Without context.
        formatted = TextUtils.format_error_message(error)
        assert "ValueError: 测试错误" in formatted

        # With context.
        formatted_with_context = TextUtils.format_error_message(error, "测试上下文")
        assert "[测试上下文]" in formatted_with_context
        assert "ValueError: 测试错误" in formatted_with_context


class TestFileUtils:
    """Tests for FileUtils.

    File-system tests use pytest's ``tmp_path`` fixture. This gives automatic
    cleanup and makes the tests independent of the current working directory.
    """

    def test_ensure_directory(self, tmp_path):
        """ensure_directory creates missing intermediate directories."""
        test_path = tmp_path / "new_dir" / "sub_dir"

        FileUtils.ensure_directory(str(test_path))

        assert test_path.exists()
        assert test_path.is_dir()

    def test_read_write_json_file(self, tmp_path):
        """Data written with write_json_file is read back unchanged."""
        json_path = str(tmp_path / "data.json")
        test_data = {"test": "data", "number": 42}

        FileUtils.write_json_file(json_path, test_data)

        assert FileUtils.read_json_file(json_path) == test_data

    def test_read_json_file_not_found(self, tmp_path):
        """Reading a missing file returns None instead of raising."""
        # The file must be guaranteed absent. A bare relative "nonexistent.json"
        # depends on the CWD and breaks if such a file happens to exist there.
        missing_path = tmp_path / "nonexistent.json"
        assert not missing_path.exists()

        assert FileUtils.read_json_file(str(missing_path)) is None

    def test_get_project_root(self):
        """get_project_root returns a result when a marker file is found."""
        with patch("claude_agent.utils.helpers.Path") as mock_path:
            # Any Path(...) call inside helpers yields this stand-in file path.
            mock_file_path = Mock()
            mock_path.return_value = mock_file_path

            # The file's parent reports that `parent / <name>` exists (the
            # original comment describes this as simulating a found setup.py).
            mock_setup_exists = Mock()
            mock_setup_exists.exists.return_value = True
            mock_current = Mock()
            mock_current.__truediv__ = Mock(return_value=mock_setup_exists)
            # Being its own parent keeps any upward directory walk finite.
            mock_current.parent = mock_current
            mock_file_path.parent = mock_current

            result = FileUtils.get_project_root()

            # NOTE(review): this only proves the call completes. Any Mock is
            # non-None, so consider asserting the specific directory returned
            # once get_project_root's contract is confirmed.
            assert result is not None