blob: 695086a454e53c3ae4c778e244839c947e907313 [file] [log] [blame] [raw]
"""
持久化存储管理器单元测试
"""
import pytest
import json
import tempfile
import shutil
import time
from pathlib import Path
from unittest.mock import patch, Mock
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "src"))
from claude_agent.storage.persistence import PersistenceManager
class TestPersistenceManager:
    """Tests for PersistenceManager, which stores conversations and agent state in JSON files."""

    @pytest.fixture
    def temp_storage_dir(self):
        """Yield a fresh temporary directory path (str) and delete it after the test."""
        temp_dir = tempfile.mkdtemp()
        yield temp_dir
        shutil.rmtree(temp_dir)

    @pytest.fixture
    def persistence_manager(self, temp_storage_dir):
        """Return a PersistenceManager rooted at the per-test temporary directory."""
        return PersistenceManager(temp_storage_dir)

    @pytest.fixture
    def sample_conversation_history(self):
        """Return a short four-message user/assistant conversation."""
        return [
            {"role": "user", "content": "Hello"},
            {"role": "assistant", "content": "Hi! How can I help you?"},
            {"role": "user", "content": "What's the weather like?"},
            {"role": "assistant", "content": "I can help you check the weather."}
        ]

    @pytest.fixture
    def sample_agent_state(self):
        """Return a representative agent-state dict."""
        return {
            "thinking_mode": "interactive",
            "conversation_history": [
                {"role": "user", "content": "Test message"},
                {"role": "assistant", "content": "Test response"}
            ],
            "tasks": [],
            "model": "claude-sonnet-4-20250514",
            "created_at": int(time.time())
        }

    def test_init(self, temp_storage_dir):
        """Construction sets up the storage directory and the two JSON file paths."""
        manager = PersistenceManager(temp_storage_dir)
        assert manager.storage_dir == Path(temp_storage_dir)
        assert manager.storage_dir.exists()
        assert manager.conversation_file == Path(temp_storage_dir) / "conversations.json"
        assert manager.agents_file == Path(temp_storage_dir) / "agents.json"

    def test_save_and_load_conversation_history(self, persistence_manager, sample_conversation_history):
        """A saved conversation history round-trips unchanged."""
        chat_id = "test_chat_123"
        # Save
        result = persistence_manager.save_conversation_history(chat_id, sample_conversation_history)
        assert result is True
        # Load
        loaded_history = persistence_manager.load_conversation_history(chat_id)
        assert loaded_history == sample_conversation_history

    def test_save_and_load_agent_state(self, persistence_manager, sample_agent_state):
        """A saved agent state round-trips its key fields."""
        chat_id = "test_chat_456"
        # Save
        result = persistence_manager.save_agent_state(chat_id, sample_agent_state)
        assert result is True
        # Load
        loaded_state = persistence_manager.load_agent_state(chat_id)
        assert loaded_state is not None
        assert loaded_state["thinking_mode"] == sample_agent_state["thinking_mode"]
        assert loaded_state["conversation_history"] == sample_agent_state["conversation_history"]
        assert loaded_state["model"] == sample_agent_state["model"]

    def test_load_nonexistent_conversation(self, persistence_manager):
        """Loading an unknown chat's history yields an empty list."""
        result = persistence_manager.load_conversation_history("nonexistent_chat")
        assert result == []

    def test_load_nonexistent_agent_state(self, persistence_manager):
        """Loading an unknown chat's agent state yields None."""
        result = persistence_manager.load_agent_state("nonexistent_chat")
        assert result is None

    def test_get_all_chat_ids(self, persistence_manager, sample_conversation_history):
        """get_all_chat_ids returns exactly the chats that were saved (order not significant)."""
        chat_ids = ["chat_1", "chat_2", "chat_3"]
        # Save several conversations
        for chat_id in chat_ids:
            persistence_manager.save_conversation_history(chat_id, sample_conversation_history)
        # Retrieve all chat IDs
        result = persistence_manager.get_all_chat_ids()
        assert set(result) == set(chat_ids)

    def test_conversation_metadata(self, persistence_manager, sample_conversation_history):
        """Saving a conversation records message_count and an int last_updated timestamp."""
        chat_id = "test_chat_metadata"
        persistence_manager.save_conversation_history(chat_id, sample_conversation_history)
        # Inspect the raw on-disk record for metadata
        conversations = persistence_manager._load_json_file(persistence_manager.conversation_file)
        chat_data = conversations[chat_id]
        assert chat_data["message_count"] == len(sample_conversation_history)
        assert "last_updated" in chat_data
        assert isinstance(chat_data["last_updated"], int)

    def test_agent_state_metadata(self, persistence_manager, sample_agent_state):
        """Saving an agent state records an int last_updated timestamp."""
        chat_id = "test_agent_metadata"
        persistence_manager.save_agent_state(chat_id, sample_agent_state)
        # Inspect the raw on-disk record for metadata
        agents = persistence_manager._load_json_file(persistence_manager.agents_file)
        agent_data = agents[chat_id]
        assert "last_updated" in agent_data
        assert isinstance(agent_data["last_updated"], int)

    def test_cleanup_old_data(self, persistence_manager, sample_conversation_history, sample_agent_state):
        """cleanup_old_data drops records older than the cutoff and keeps recent ones."""
        # Timestamp 31 days in the past, i.e. just beyond the 30-day cutoff used below.
        old_time = int(time.time()) - (31 * 24 * 3600)
        # Hand-craft one stale and one fresh record in each file.
        old_conversations = {
            "old_chat": {
                "history": sample_conversation_history,
                "last_updated": old_time,
                "message_count": len(sample_conversation_history)
            },
            "new_chat": {
                "history": sample_conversation_history,
                "last_updated": int(time.time()),
                "message_count": len(sample_conversation_history)
            }
        }
        old_agents = {
            "old_chat": {
                **sample_agent_state,
                "last_updated": old_time
            },
            "new_chat": {
                **sample_agent_state,
                "last_updated": int(time.time())
            }
        }
        # Write the crafted data directly, bypassing the public save methods
        # so the stale timestamps are preserved.
        persistence_manager._save_json_file(persistence_manager.conversation_file, old_conversations)
        persistence_manager._save_json_file(persistence_manager.agents_file, old_agents)
        # Remove data older than 30 days
        cleanup_count = persistence_manager.cleanup_old_data(30)
        # Exact count semantics (chats vs. records) are not pinned here; only that something was removed.
        assert cleanup_count > 0
        # Check what remains
        remaining_conversations = persistence_manager._load_json_file(persistence_manager.conversation_file)
        remaining_agents = persistence_manager._load_json_file(persistence_manager.agents_file)
        assert "old_chat" not in remaining_conversations
        assert "new_chat" in remaining_conversations
        assert "old_chat" not in remaining_agents
        assert "new_chat" in remaining_agents

    def test_get_storage_stats(self, persistence_manager, sample_conversation_history, sample_agent_state):
        """get_storage_stats reports chat, message and agent totals plus file info."""
        # Populate: two conversations, one agent state
        persistence_manager.save_conversation_history("chat1", sample_conversation_history)
        persistence_manager.save_conversation_history("chat2", sample_conversation_history)
        persistence_manager.save_agent_state("chat1", sample_agent_state)
        stats = persistence_manager.get_storage_stats()
        assert "total_chats" in stats
        assert "total_messages" in stats
        assert "total_agents" in stats
        assert "storage_dir" in stats
        assert "file_sizes" in stats
        assert stats["total_chats"] == 2
        assert stats["total_messages"] == len(sample_conversation_history) * 2
        assert stats["total_agents"] == 1

    def test_concurrent_access(self, persistence_manager, sample_conversation_history):
        """Concurrent saves for distinct chats all succeed and none of them is lost."""
        import threading
        thread_count = 5
        expected_ids = {f"chat_{i}" for i in range(thread_count)}
        results = []

        def save_conversation(chat_id):
            result = persistence_manager.save_conversation_history(chat_id, sample_conversation_history)
            # list.append is atomic under the GIL, so no extra lock is needed here.
            results.append(result)

        # Start several threads saving at the same time
        threads = []
        for i in range(thread_count):
            thread = threading.Thread(target=save_conversation, args=[f"chat_{i}"])
            threads.append(thread)
            thread.start()
        # Wait for all threads to finish
        for thread in threads:
            thread.join()
        # Every save reported success; a thread that raised would leave results short.
        assert all(results)
        assert len(results) == thread_count
        # Successful return values alone do not rule out lost updates when every
        # save rewrites the shared conversations file, so verify on-disk state too.
        assert set(persistence_manager.get_all_chat_ids()) == expected_ids
        for chat_id in expected_ids:
            assert persistence_manager.load_conversation_history(chat_id) == sample_conversation_history

    def test_json_file_corruption_handling(self, persistence_manager):
        """A corrupted conversations file is treated as empty instead of raising."""
        # Write malformed JSON into the conversations file
        with open(persistence_manager.conversation_file, 'w') as f:
            f.write("invalid json content {")
        # Loading should return an empty list rather than raise
        result = persistence_manager.load_conversation_history("any_chat")
        assert result == []

    def test_file_permission_error_handling(self, persistence_manager, sample_conversation_history):
        """A PermissionError from open() makes save_conversation_history return False."""
        # Simulate a permission error. NOTE(review): this only intercepts calls that go
        # through builtins.open; confirm the manager does not write via another path.
        with patch('builtins.open', side_effect=PermissionError("Permission denied")):
            result = persistence_manager.save_conversation_history("test", sample_conversation_history)
            assert result is False

    def test_atomic_file_writing(self, persistence_manager, sample_conversation_history):
        """A save leaves no temporary files behind and the data is readable afterwards."""
        chat_id = "atomic_test"
        # First save
        result1 = persistence_manager.save_conversation_history(chat_id, sample_conversation_history)
        assert result1 is True
        # No leftover temp files. Glob for any *.tmp rather than one hard-coded name,
        # so both "conversations.tmp" and "conversations.json.tmp" styles are caught.
        leftover_temp_files = list(persistence_manager.storage_dir.glob("*.tmp"))
        assert leftover_temp_files == []
        # Data was saved correctly
        loaded = persistence_manager.load_conversation_history(chat_id)
        assert loaded == sample_conversation_history

    @patch('claude_agent.storage.persistence.logger')
    def test_logging(self, mock_logger, persistence_manager, sample_conversation_history):
        """Successful save/load operations emit debug-level log calls."""
        # Successful operations should log at debug level
        persistence_manager.save_conversation_history("test", sample_conversation_history)
        persistence_manager.load_conversation_history("test")
        # Only the debug level is checked
        assert mock_logger.debug.called

    def test_empty_conversation_history(self, persistence_manager):
        """An empty history can be saved and round-trips as an empty list."""
        chat_id = "empty_test"
        empty_history = []
        result = persistence_manager.save_conversation_history(chat_id, empty_history)
        assert result is True
        loaded = persistence_manager.load_conversation_history(chat_id)
        assert loaded == empty_history

    def test_large_conversation_history(self, persistence_manager):
        """A 2000-message history round-trips intact."""
        chat_id = "large_test"
        large_history = []
        # 1000 user/assistant exchanges -> 2000 messages
        for i in range(1000):
            large_history.extend([
                {"role": "user", "content": f"Message {i}"},
                {"role": "assistant", "content": f"Response {i}"}
            ])
        result = persistence_manager.save_conversation_history(chat_id, large_history)
        assert result is True
        loaded = persistence_manager.load_conversation_history(chat_id)
        assert len(loaded) == len(large_history)
        assert loaded == large_history

    def test_unicode_content(self, persistence_manager):
        """Non-ASCII text (CJK, accents, emoji) round-trips unchanged."""
        chat_id = "unicode_test"
        unicode_history = [
            {"role": "user", "content": "你好,世界!🌍"},
            {"role": "assistant", "content": "Hello, 世界! こんにちは 🎌"},
            {"role": "user", "content": "Émojis: 🚀🎉🔥💡"}
        ]
        result = persistence_manager.save_conversation_history(chat_id, unicode_history)
        assert result is True
        loaded = persistence_manager.load_conversation_history(chat_id)
        assert loaded == unicode_history

    def test_get_file_size(self, persistence_manager, sample_conversation_history):
        """Storage stats include per-file sizes, non-zero for the written conversations file."""
        # Write some data
        persistence_manager.save_conversation_history("test", sample_conversation_history)
        # Fetch statistics
        stats = persistence_manager.get_storage_stats()
        # Check the file-size section
        assert "file_sizes" in stats
        assert "conversations" in stats["file_sizes"]
        assert "agents" in stats["file_sizes"]
        assert stats["file_sizes"]["conversations"] > 0

    def test_multiple_updates_same_chat(self, persistence_manager):
        """Re-saving the same chat replaces the stored history with the latest one."""
        chat_id = "update_test"
        # First save
        history1 = [{"role": "user", "content": "First message"}]
        persistence_manager.save_conversation_history(chat_id, history1)
        # Second save (update)
        history2 = [
            {"role": "user", "content": "First message"},
            {"role": "assistant", "content": "First response"},
            {"role": "user", "content": "Second message"}
        ]
        persistence_manager.save_conversation_history(chat_id, history2)
        # The latest data wins
        loaded = persistence_manager.load_conversation_history(chat_id)
        assert loaded == history2
        assert len(loaded) == 3