blob: 03894786cff2bfcbd95ed730bf33fefe827b3526 [file] [log] [blame] [raw]
"""
核心Agent模块补充测试
提高core/agent.py的测试覆盖率
V2.2 重构 - 基于 claude-agent-sdk 的简化 API
"""
import pytest
from unittest.mock import Mock, AsyncMock, patch, MagicMock
import asyncio
from datetime import datetime
import json
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "src"))
from claude_agent.core.agent import AgentCore, ThinkingMode
class TestAgentCoreExtended:
    """Extended AgentCore tests: history management, modes, streaming, serialization, fallbacks."""

    @pytest.fixture
    def agent(self):
        """Provide a fresh, default-configured AgentCore for each test."""
        return AgentCore()

    def test_agent_initialization(self, agent):
        """A new agent has a mode set, empty history/tools and a 100-entry history cap."""
        assert agent.thinking_mode is not None
        assert agent.conversation_history == []
        assert agent.max_history == 100
        assert agent.mcp_tools == {}

    def test_set_thinking_mode(self, agent):
        """set_thinking_mode switches between YOLO and INTERACTIVE."""
        agent.set_thinking_mode(ThinkingMode.YOLO)
        assert agent.thinking_mode == ThinkingMode.YOLO
        agent.set_thinking_mode(ThinkingMode.INTERACTIVE)
        assert agent.thinking_mode == ThinkingMode.INTERACTIVE

    def test_clear_history(self, agent):
        """clear_history removes every stored message."""
        # Seed the history directly so the test does not depend on the query pipeline.
        agent.conversation_history.append({"role": "user", "content": "Hello"})
        agent.conversation_history.append({"role": "assistant", "content": "Hi there!"})
        assert len(agent.conversation_history) == 2
        agent.clear_history()
        assert len(agent.conversation_history) == 0

    def test_get_conversation_history(self, agent):
        """get_conversation_history returns the messages in order, as a separate list object."""
        agent.conversation_history.append({"role": "user", "content": "Hello"})
        agent.conversation_history.append({"role": "assistant", "content": "Hi there!"})
        history = agent.get_conversation_history()
        assert len(history) == 2
        assert history[0]["role"] == "user"
        assert history[1]["role"] == "assistant"
        # The caller must get a copy, not the agent's internal list.
        assert history is not agent.conversation_history

    @pytest.mark.asyncio
    async def test_process_user_input_interactive(self, agent):
        """Interactive mode returns a non-empty reply and records the user message first."""
        agent.set_thinking_mode(ThinkingMode.INTERACTIVE)
        with patch.object(agent, '_interactive_process', return_value="Hello! How can I help you?"):
            response = await agent.process_user_input("Hello")
            assert isinstance(response, str)
            assert len(response) > 0
            # At least the user message must have been appended to the history.
            assert len(agent.conversation_history) >= 1
            assert agent.conversation_history[0]["role"] == "user"
            assert agent.conversation_history[0]["content"] == "Hello"

    @pytest.mark.asyncio
    async def test_process_user_input_yolo_mode(self, agent):
        """YOLO mode returns a non-empty reply and records the user message."""
        agent.set_thinking_mode(ThinkingMode.YOLO)
        with patch.object(agent, '_yolo_process', return_value="YOLO response"):
            response = await agent.process_user_input("Solve this complex problem")
            assert isinstance(response, str)
            assert len(response) > 0
            assert len(agent.conversation_history) >= 1
            assert agent.conversation_history[0]["role"] == "user"

    def test_add_mcp_tool(self, agent):
        """add_mcp_tool registers the tool under the given name."""
        mock_tool = Mock()
        tool_name = "test_tool"
        agent.add_mcp_tool(tool_name, mock_tool)
        assert tool_name in agent.mcp_tools
        assert agent.mcp_tools[tool_name] == mock_tool

    def test_clean_conversation_history(self, agent):
        """clean_conversation_history drops oversized messages and reports how many it removed."""
        agent.conversation_history.extend([
            {"role": "user", "content": "Normal message"},
            {"role": "assistant", "content": "A" * 10000},  # oversized message
            {"role": "user", "content": "Another normal message"},
        ])
        original_count = len(agent.conversation_history)
        cleaned_count = agent.clean_conversation_history()
        assert cleaned_count > 0  # at least the oversized entry should be removed
        assert len(agent.conversation_history) < original_count

    def test_trim_conversation_history(self, agent):
        """trim_conversation_history enforces max_history and keeps the newest messages."""
        agent.max_history = 5
        for i in range(10):
            agent.conversation_history.append({"role": "user", "content": f"Message {i}"})
        agent.trim_conversation_history()
        assert len(agent.conversation_history) <= 5
        # The most recent message must survive trimming.
        assert agent.conversation_history[-1]["content"] == "Message 9"

    @pytest.mark.asyncio
    async def test_create_streaming_response(self, agent):
        """create_streaming_response yields chunks produced by _interactive_process_for_streaming."""
        with patch.object(
            agent, '_interactive_process_for_streaming', return_value="Complete response"
        ) as mock_process:
            response_generator = agent.create_streaming_response("Hello")
            assert hasattr(response_generator, '__aiter__')
            # An async generator's body only runs when it is iterated. Without consuming it,
            # the patched method is never reached and this test could not fail.
            chunks = [chunk async for chunk in response_generator]
        assert len(chunks) >= 1
        mock_process.assert_awaited()

    @pytest.mark.asyncio
    async def test_interactive_process_error_handling(self, agent):
        """An API failure in _interactive_process yields a non-empty fallback string."""
        with patch.object(agent, '_make_query_call', side_effect=Exception("API Error")):
            response = await agent._interactive_process("Hello")
            assert isinstance(response, str)
            assert len(response) > 0

    def test_fallback_response(self, agent):
        """_get_fallback_response produces an apologetic, non-empty message."""
        error_details = "API调用失败"
        user_input = "Hello"
        response = agent._get_fallback_response(user_input, error_details)
        assert isinstance(response, str)
        assert len(response) > 0
        assert "抱歉" in response or "暂时" in response or "技术问题" in response

    def test_to_dict_serialization(self, agent):
        """to_dict exposes thinking_mode (as its string value), history and model."""
        agent.conversation_history.append({"role": "user", "content": "Hello"})
        agent.set_thinking_mode(ThinkingMode.YOLO)
        serialized = agent.to_dict()
        assert "thinking_mode" in serialized
        assert "conversation_history" in serialized
        assert "model" in serialized
        assert serialized["thinking_mode"] == "yolo"
        assert len(serialized["conversation_history"]) == 1

    def test_from_dict_deserialization(self, agent):
        """AgentCore.from_dict restores thinking mode and conversation history."""
        agent_data = {
            "thinking_mode": "yolo",
            "conversation_history": [{"role": "user", "content": "Hello"}],
            "model": "claude-sonnet-4-5-20250929",
            "created_at": 1234567890,
            "tasks": [],
        }
        restored_agent = AgentCore.from_dict(agent_data)
        assert restored_agent.thinking_mode == ThinkingMode.YOLO
        assert len(restored_agent.conversation_history) == 1
        assert restored_agent.conversation_history[0]["content"] == "Hello"

    def test_get_memory_summary(self, agent):
        """get_memory_summary reports the message count and the current mode."""
        agent.conversation_history.extend([
            {"role": "user", "content": "Hello"},
            {"role": "assistant", "content": "Hi there!"},
            {"role": "user", "content": "How are you?"},
            {"role": "assistant", "content": "I'm doing well!"},
        ])
        summary = agent.get_memory_summary()
        assert "conversation_count" in summary
        assert "thinking_mode" in summary
        assert summary["conversation_count"] == 4
        assert summary["thinking_mode"] == "interactive"  # default mode

    @pytest.mark.asyncio
    async def test_yolo_process_error_handling(self, agent):
        """A failing query in YOLO mode still yields a non-empty string reply."""
        with patch('claude_agent.core.agent.query', side_effect=Exception("YOLO Error")):
            response = await agent._yolo_process("Complex task")
            assert isinstance(response, str)
            assert len(response) > 0

    def test_conversation_history_management(self, agent):
        """trim_conversation_history returns the number of removed messages."""
        for i in range(5):
            agent.conversation_history.append({"role": "user", "content": f"Short message {i}"})
        assert len(agent.conversation_history) == 5
        agent.max_history = 3
        removed = agent.trim_conversation_history()
        assert removed == 2
        assert len(agent.conversation_history) == 3

    def test_clean_polluted_history(self, agent):
        """Messages mentioning CLAUDE.md or exceeding the length limit are cleaned out."""
        agent.conversation_history.extend([
            {"role": "user", "content": "Normal message"},
            {"role": "assistant", "content": "CLAUDE.md contains " + "B" * 3000},  # mentions CLAUDE.md
            {"role": "user", "content": "Another normal message"},
            {"role": "assistant", "content": "x" * 6000},  # overlong message
        ])
        original_count = len(agent.conversation_history)
        cleaned_count = agent.clean_conversation_history()
        assert cleaned_count >= 2  # both polluted messages should be removed
        assert len(agent.conversation_history) < original_count

    def test_task_dataclass_functionality(self):
        """Task has sensible defaults and accepts nested subtasks."""
        from claude_agent.core.agent import Task

        task = Task("test_id", "Test description")
        assert task.id == "test_id"
        assert task.description == "Test description"
        assert task.status == "pending"
        assert task.result is None
        assert task.subtasks == []

        subtask = Task("sub_id", "Subtask description")
        task_with_subtasks = Task("parent_id", "Parent task", subtasks=[subtask])
        assert len(task_with_subtasks.subtasks) == 1
        assert task_with_subtasks.subtasks[0].id == "sub_id"

    def test_thinking_mode_enum(self):
        """ThinkingMode members carry their lowercase string values."""
        assert ThinkingMode.INTERACTIVE.value == "interactive"
        assert ThinkingMode.YOLO.value == "yolo"

    @pytest.mark.asyncio
    async def test_full_streaming_workflow(self, agent):
        """End-to-end: the streaming generator is consumed and routes through the mocked processor."""
        with patch.object(
            agent, '_interactive_process_for_streaming', return_value="Mocked streaming response"
        ) as mock_process:
            async_generator = agent.create_streaming_response("Test input")
            assert hasattr(async_generator, '__aiter__')
            # Drain the generator so the streaming pipeline actually executes.
            chunks = [chunk async for chunk in async_generator]
        assert len(chunks) >= 1
        mock_process.assert_awaited()

    def test_edge_cases_and_error_conditions(self, agent):
        """Cleaning, trimming, summarising and serializing an empty agent are all no-ops."""
        agent.conversation_history = []
        cleaned = agent.clean_conversation_history()
        assert cleaned == 0

        trimmed = agent.trim_conversation_history()
        assert trimmed == 0

        summary = agent.get_memory_summary()
        assert summary["conversation_count"] == 0

        serialized = agent.to_dict()
        assert len(serialized["conversation_history"]) == 0

    def test_mcp_tools_management(self, agent):
        """Multiple tools can be registered, and re-registering a name replaces the tool."""
        tool1 = Mock()
        tool2 = Mock()
        agent.add_mcp_tool("tool1", tool1)
        agent.add_mcp_tool("tool2", tool2)
        assert len(agent.mcp_tools) == 2
        assert agent.mcp_tools["tool1"] == tool1
        assert agent.mcp_tools["tool2"] == tool2

        new_tool1 = Mock()
        agent.add_mcp_tool("tool1", new_tool1)
        assert agent.mcp_tools["tool1"] == new_tool1
class TestAgentStreamingFeatures:
    """Tests for AgentCore's streaming helpers (chunking, retries, timeouts, YOLO streaming)."""

    @pytest.fixture
    def agent(self):
        """Provide a fresh, default-configured AgentCore for each test."""
        return AgentCore()

    @pytest.mark.asyncio
    async def test_create_streaming_response_chunks(self, agent):
        """A long multi-line reply is split into several chunks without losing content."""
        long_response = "这是一个很长的响应。\n" * 50 + "包含多行内容\n" + "#标题\n" + "```python\ncode\n```"
        with patch.object(agent, '_interactive_process_for_streaming', return_value=long_response):
            chunks = []
            async for chunk in agent.create_streaming_response("测试输入"):
                chunks.append(chunk)
            assert len(chunks) > 1
            # Re-joined chunks must still contain the original text.
            combined = "\n".join(chunks)
            assert "这是一个很长的响应" in combined
            assert "#标题" in combined

    @pytest.mark.asyncio
    async def test_create_streaming_response_markdown_preservation(self, agent):
        """Markdown headings, code fences, lists and quotes survive chunking."""
        markdown_response = """# 标题
这是正常文本
```python
def hello():
    print("Hello")
```
- 列表项1
- 列表项2
> 引用内容
"""
        with patch.object(agent, '_interactive_process_for_streaming', return_value=markdown_response):
            chunks = []
            async for chunk in agent.create_streaming_response("Markdown测试"):
                chunks.append(chunk)
            combined = "\n".join(chunks)
            assert "# 标题" in combined
            assert "```python" in combined
            assert "- 列表项1" in combined
            assert "> 引用内容" in combined

    @pytest.mark.asyncio
    async def test_create_streaming_response_error_handling(self, agent):
        """A processing error is reported as a single error chunk instead of propagating."""
        with patch.object(agent, '_interactive_process_for_streaming', side_effect=Exception("Stream error")):
            chunks = []
            async for chunk in agent.create_streaming_response("错误测试"):
                chunks.append(chunk)
            assert len(chunks) == 1
            assert "错误" in chunks[0]

    @pytest.mark.asyncio
    async def test_interactive_process_for_streaming_success(self, agent):
        """On success, the text extracted from the query stream is returned unchanged."""
        with patch.object(agent, '_make_query_call') as mock_query:
            async def mock_stream():
                yield Mock()

            mock_query.return_value = mock_stream()
            with patch.object(agent, '_extract_response_text', return_value="Hello World!"):
                result = await agent._interactive_process_for_streaming("测试")
                assert result == "Hello World!"

    @pytest.mark.asyncio
    async def test_interactive_process_for_streaming_timeout(self, agent):
        """A timeout from asyncio.wait_for produces a user-facing timeout/fallback message."""

        def _raise_timeout(*args, **kwargs):
            # The awaitable handed to wait_for is never awaited here. Close it so Python does
            # not emit "coroutine ... was never awaited" RuntimeWarnings, which are errors
            # when warnings are escalated (-W error / filterwarnings=error).
            for arg in (*args, *kwargs.values()):
                if asyncio.iscoroutine(arg):
                    arg.close()
            raise asyncio.TimeoutError()

        with patch('asyncio.wait_for', side_effect=_raise_timeout):
            result = await agent._interactive_process_for_streaming("测试")
            assert "超时" in result or "暂时" in result or "技术问题" in result

    @pytest.mark.asyncio
    async def test_interactive_process_for_streaming_retry(self, agent):
        """A failing first query is retried, and the second attempt's text is returned."""
        call_count = 0

        async def mock_failing_then_success(*args, **kwargs):
            # The first call raises; later calls return a one-message async stream.
            nonlocal call_count
            call_count += 1
            if call_count == 1:
                raise Exception("First attempt fails")

            async def success_stream():
                yield Mock()

            return success_stream()

        with patch.object(agent, '_make_query_call', side_effect=mock_failing_then_success):
            with patch.object(agent, '_extract_response_text', return_value="Success on retry"):
                result = await agent._interactive_process_for_streaming("测试")
                assert "Success on retry" in result

    @pytest.mark.asyncio
    async def test_yolo_process_for_streaming(self, agent):
        """The YOLO streaming path returns the text extracted from the query stream."""
        with patch.object(agent, '_extract_response_text', return_value="YOLO流式响应测试结果"):
            with patch('claude_agent.core.agent.query') as mock_query:
                async def mock_stream():
                    yield Mock()

                mock_query.return_value = mock_stream()
                result = await agent._yolo_process_for_streaming("复杂任务")
                assert "YOLO流式响应测试结果" in result

    @pytest.mark.asyncio
    async def test_make_query_call(self, agent):
        """_make_query_call returns an async iterable that yields the underlying query messages."""
        with patch('claude_agent.core.agent.query') as mock_query:
            async def mock_response():
                yield Mock()

            mock_query.return_value = mock_response()
            result_stream = await agent._make_query_call("test input")
            messages = []
            async for msg in result_stream:
                messages.append(msg)
            assert len(messages) == 1
class TestStreamInputBuffer:
    """Unit tests for the StreamInputBuffer container and the InputCompletionSignal enum."""

    def test_stream_input_buffer_creation(self):
        """A default-constructed buffer is empty, incomplete, and uses the stock limits."""
        from claude_agent.core.agent import StreamInputBuffer

        fresh = StreamInputBuffer()
        expected_defaults = {
            "buffer": "",
            "last_update_time": 0.0,
            "timeout_seconds": 3.0,
            "max_buffer_size": 10000,
            "delimiter": "\n\n",
        }
        for attr_name, expected in expected_defaults.items():
            assert getattr(fresh, attr_name) == expected
        # Identity checks: these must be the actual singletons, not just falsy values.
        assert fresh.is_complete is False
        assert fresh.completion_signal is None

    def test_stream_input_buffer_with_params(self):
        """Constructor keyword arguments override the corresponding defaults."""
        from claude_agent.core.agent import StreamInputBuffer

        overrides = {"timeout_seconds": 5.0, "max_buffer_size": 5000, "delimiter": "---"}
        customised = StreamInputBuffer(buffer="initial", **overrides)
        assert customised.buffer == "initial"
        for attr_name, expected in overrides.items():
            assert getattr(customised, attr_name) == expected

    def test_input_completion_signal_enum(self):
        """Every completion-signal member maps to its lowercase string value."""
        from claude_agent.core.agent import InputCompletionSignal

        member_values = [
            ("TIMEOUT", "timeout"),
            ("DELIMITER", "delimiter"),
            ("EXPLICIT", "explicit"),
            ("MAX_LENGTH", "max_length"),
        ]
        for member_name, raw_value in member_values:
            assert getattr(InputCompletionSignal, member_name).value == raw_value
class TestAgentAdvancedFeatures:
    """Advanced AgentCore behaviour: custom construction, YOLO paths, input buffers, fallbacks."""

    @pytest.fixture
    def agent(self):
        """Fresh AgentCore built with default settings."""
        return AgentCore()

    def test_agent_initialization_with_custom_params(self):
        """Constructor keyword arguments override the model name and history cap."""
        custom_agent = AgentCore(api_key="test_key", model="custom-model", max_history=50)
        assert (custom_agent.model, custom_agent.max_history) == ("custom-model", 50)

    @pytest.mark.asyncio
    async def test_yolo_process_analysis_stage(self, agent):
        """_yolo_process surfaces the text extracted from the query stream."""
        analysis_text = "这是一个复杂的需求分析"

        async def fake_stream():
            yield Mock()

        with patch.object(agent, '_extract_response_text', return_value=analysis_text), \
                patch('claude_agent.core.agent.query') as query_mock:
            query_mock.return_value = fake_stream()
            outcome = await agent._yolo_process("复杂任务请求")
            assert analysis_text in outcome

    @pytest.mark.asyncio
    async def test_yolo_process_degradation_to_interactive(self, agent):
        """When every query fails, YOLO degrades to interactive mode and returns a fallback reply."""
        failing_query = patch('claude_agent.core.agent.query', side_effect=Exception("YOLO processing failed"))
        with failing_query:
            outcome = await agent._yolo_process("测试任务")
            # Both the YOLO attempt and the interactive fallback fail, so a canned reply is expected.
            assert any(marker in outcome for marker in ("技术问题", "失败"))

    def test_stream_input_buffer_integration(self, agent):
        """A StreamInputBuffer attached to the agent keeps its contents and completion state."""
        from claude_agent.core.agent import StreamInputBuffer, InputCompletionSignal

        agent.stream_input_buffer = StreamInputBuffer(
            buffer="测试输入",
            timeout_seconds=5.0,
            is_complete=True,
            completion_signal=InputCompletionSignal.DELIMITER,
        )
        attached = agent.stream_input_buffer
        assert attached.buffer == "测试输入"
        assert attached.is_complete is True
        assert attached.completion_signal == InputCompletionSignal.DELIMITER

    def test_fallback_response_variations(self, agent):
        """Quota, timeout and generic errors each map to an appropriate fallback message."""
        # Quota / authentication style failures.
        for quota_error in ("预扣费额度失败", "403", "剩余额度", "API Error", "/login"):
            reply = agent._get_fallback_response("测试", quota_error)
            assert any(marker in reply for marker in ("API服务暂时不可用", "API配额不足"))

        # Timeout failures, matched case-insensitively and in Chinese.
        for timeout_error in ("超时", "timeout", "TIMEOUT"):
            reply = agent._get_fallback_response("测试", timeout_error)
            assert any(marker in reply for marker in ("请求超时", "稍后重试"))

        # Anything else gets a generic message that echoes the error text.
        unknown_error = "未知错误"
        reply = agent._get_fallback_response("测试", unknown_error)
        assert "技术问题" in reply
        assert unknown_error in reply