| """ |
| 核心Agent模块补充测试 |
| 提高core/agent.py的测试覆盖率 |
| V2.2 重构 - 基于 claude-agent-sdk 的简化 API |
| """ |
| |
| import pytest |
| from unittest.mock import Mock, AsyncMock, patch, MagicMock |
| import asyncio |
| from datetime import datetime |
| import json |
| |
| import sys |
| from pathlib import Path |
| sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "src")) |
| |
| from claude_agent.core.agent import AgentCore, ThinkingMode |
| |
| |
class TestAgentCoreExtended:
    """Extended AgentCore tests: state, history handling, modes, MCP tools and (de)serialization."""

    @pytest.fixture
    def agent(self):
        """Provide a fresh AgentCore with default settings."""
        return AgentCore()

    def test_agent_initialization(self, agent):
        """A default agent has a thinking mode, empty history/tools and a history limit of 100."""
        assert agent.thinking_mode is not None
        assert agent.conversation_history == []
        assert agent.max_history == 100
        assert agent.mcp_tools == {}

    def test_set_thinking_mode(self, agent):
        """set_thinking_mode switches between YOLO and INTERACTIVE."""
        agent.set_thinking_mode(ThinkingMode.YOLO)
        assert agent.thinking_mode == ThinkingMode.YOLO

        agent.set_thinking_mode(ThinkingMode.INTERACTIVE)
        assert agent.thinking_mode == ThinkingMode.INTERACTIVE

    def test_clear_history(self, agent):
        """clear_history empties the conversation history."""
        # Seed the history directly.
        agent.conversation_history.append({"role": "user", "content": "Hello"})
        agent.conversation_history.append({"role": "assistant", "content": "Hi there!"})
        assert len(agent.conversation_history) == 2

        agent.clear_history()
        assert len(agent.conversation_history) == 0

    def test_get_conversation_history(self, agent):
        """get_conversation_history returns the entries in order, as a copy of the internal list."""
        # Seed the history directly.
        agent.conversation_history.append({"role": "user", "content": "Hello"})
        agent.conversation_history.append({"role": "assistant", "content": "Hi there!"})

        history = agent.get_conversation_history()
        assert len(history) == 2
        assert history[0]["role"] == "user"
        assert history[1]["role"] == "assistant"
        # The returned list must not be the agent's own list object.
        assert history is not agent.conversation_history

    @pytest.mark.asyncio
    async def test_process_user_input_interactive(self, agent):
        """In interactive mode a reply is returned and the user message is recorded first."""
        agent.set_thinking_mode(ThinkingMode.INTERACTIVE)

        with patch.object(agent, '_interactive_process', return_value="Hello! How can I help you?"):
            response = await agent.process_user_input("Hello")

            assert isinstance(response, str)
            assert len(response) > 0
            # At least the user message must have been added to the history.
            assert len(agent.conversation_history) >= 1
            assert agent.conversation_history[0]["role"] == "user"
            assert agent.conversation_history[0]["content"] == "Hello"

    @pytest.mark.asyncio
    async def test_process_user_input_yolo_mode(self, agent):
        """In YOLO mode a reply is returned and the user message is recorded first."""
        agent.set_thinking_mode(ThinkingMode.YOLO)

        with patch.object(agent, '_yolo_process', return_value="YOLO response"):
            response = await agent.process_user_input("Solve this complex problem")

            assert isinstance(response, str)
            assert len(response) > 0
            # The user message must have been added to the history.
            assert len(agent.conversation_history) >= 1
            assert agent.conversation_history[0]["role"] == "user"

    def test_add_mcp_tool(self, agent):
        """add_mcp_tool registers the tool under its name."""
        mock_tool = Mock()
        tool_name = "test_tool"

        agent.add_mcp_tool(tool_name, mock_tool)

        assert tool_name in agent.mcp_tools
        assert agent.mcp_tools[tool_name] == mock_tool

    def test_clean_conversation_history(self, agent):
        """clean_conversation_history removes oversized entries and reports how many it removed."""
        # Mix normal messages with one oversized ("polluted") message.
        agent.conversation_history.extend([
            {"role": "user", "content": "Normal message"},
            {"role": "assistant", "content": "A" * 10000},  # oversized message
            {"role": "user", "content": "Another normal message"}
        ])

        original_count = len(agent.conversation_history)
        cleaned_count = agent.clean_conversation_history()

        assert cleaned_count > 0  # something must have been removed
        assert len(agent.conversation_history) < original_count

    def test_trim_conversation_history(self, agent):
        """trim_conversation_history keeps at most max_history entries, keeping the newest."""
        # Use a small limit so it is easy to exceed.
        agent.max_history = 5

        # Add more messages than the limit allows.
        for i in range(10):
            agent.conversation_history.append({"role": "user", "content": f"Message {i}"})

        agent.trim_conversation_history()

        # The history is trimmed to within the limit...
        assert len(agent.conversation_history) <= 5
        # ...and the most recent message survives.
        assert agent.conversation_history[-1]["content"] == "Message 9"

    @pytest.mark.asyncio
    async def test_create_streaming_response(self, agent):
        """create_streaming_response returns an async iterator that streams the processed reply.

        The generator is consumed inside the ``patch`` context: merely creating an
        async generator never runs its body, so without iterating it the patched
        method would never be reached and nothing would actually be tested.
        """
        with patch.object(agent, '_interactive_process_for_streaming', return_value="Complete response"):
            response_generator = agent.create_streaming_response("Hello")

            # It must be an async iterator.
            assert hasattr(response_generator, '__aiter__')

            chunks = [chunk async for chunk in response_generator]

        assert chunks
        assert "Complete response" in "\n".join(chunks)

    @pytest.mark.asyncio
    async def test_interactive_process_error_handling(self, agent):
        """A failing query in interactive mode yields a non-empty fallback string instead of raising."""
        with patch.object(agent, '_make_query_call', side_effect=Exception("API Error")):
            response = await agent._interactive_process("Hello")

            # A fallback/default reply is expected.
            assert isinstance(response, str)
            assert len(response) > 0

    def test_fallback_response(self, agent):
        """_get_fallback_response produces a non-empty, apologetic message."""
        error_details = "API调用失败"
        user_input = "Hello"

        response = agent._get_fallback_response(user_input, error_details)

        assert isinstance(response, str)
        assert len(response) > 0
        assert "抱歉" in response or "暂时" in response or "技术问题" in response

    def test_to_dict_serialization(self, agent):
        """to_dict exposes the thinking mode (as its value), history and model."""
        # Put some state on the agent.
        agent.conversation_history.append({"role": "user", "content": "Hello"})
        agent.set_thinking_mode(ThinkingMode.YOLO)

        serialized = agent.to_dict()

        assert "thinking_mode" in serialized
        assert "conversation_history" in serialized
        assert "model" in serialized
        assert serialized["thinking_mode"] == "yolo"
        assert len(serialized["conversation_history"]) == 1

    def test_from_dict_deserialization(self, agent):
        """AgentCore.from_dict restores the thinking mode and history."""
        # Serialized agent state.
        agent_data = {
            "thinking_mode": "yolo",
            "conversation_history": [{"role": "user", "content": "Hello"}],
            "model": "claude-sonnet-4-5-20250929",
            "created_at": 1234567890,
            "tasks": []
        }

        restored_agent = AgentCore.from_dict(agent_data)

        assert restored_agent.thinking_mode == ThinkingMode.YOLO
        assert len(restored_agent.conversation_history) == 1
        assert restored_agent.conversation_history[0]["content"] == "Hello"

    def test_get_memory_summary(self, agent):
        """get_memory_summary reports the message count and the current mode value."""
        # Seed some conversation history.
        agent.conversation_history.extend([
            {"role": "user", "content": "Hello"},
            {"role": "assistant", "content": "Hi there!"},
            {"role": "user", "content": "How are you?"},
            {"role": "assistant", "content": "I'm doing well!"}
        ])

        summary = agent.get_memory_summary()

        assert "conversation_count" in summary
        assert "thinking_mode" in summary
        assert summary["conversation_count"] == 4
        assert summary["thinking_mode"] == "interactive"  # default mode

    @pytest.mark.asyncio
    async def test_yolo_process_error_handling(self, agent):
        """A failing SDK query in YOLO mode still yields a non-empty string."""
        with patch('claude_agent.core.agent.query', side_effect=Exception("YOLO Error")):
            response = await agent._yolo_process("Complex task")

            # A reasonable error/fallback reply is expected.
            assert isinstance(response, str)
            assert len(response) > 0

    def test_conversation_history_management(self, agent):
        """trim_conversation_history returns the number of entries it removed."""
        # Add a few short messages.
        for i in range(5):
            agent.conversation_history.append({"role": "user", "content": f"Short message {i}"})

        assert len(agent.conversation_history) == 5

        # Lower the limit and trim.
        agent.max_history = 3
        removed = agent.trim_conversation_history()
        assert removed == 2
        assert len(agent.conversation_history) == 3

    def test_clean_polluted_history(self, agent):
        """Entries mentioning CLAUDE.md and oversized entries are both cleaned out."""
        # Mix normal messages with polluted ones.
        agent.conversation_history.extend([
            {"role": "user", "content": "Normal message"},
            {"role": "assistant", "content": "CLAUDE.md contains " + "B" * 3000},  # mentions CLAUDE.md
            {"role": "user", "content": "Another normal message"},
            {"role": "assistant", "content": "x" * 6000}  # oversized message
        ])

        original_count = len(agent.conversation_history)
        cleaned_count = agent.clean_conversation_history()

        assert cleaned_count >= 2  # both polluted messages should be removed
        assert len(agent.conversation_history) < original_count

    def test_task_dataclass_functionality(self):
        """Task has sensible defaults and accepts nested subtasks."""
        from claude_agent.core.agent import Task

        # Basic task with defaults.
        task = Task("test_id", "Test description")
        assert task.id == "test_id"
        assert task.description == "Test description"
        assert task.status == "pending"
        assert task.result is None
        assert task.subtasks == []

        # Task carrying a subtask.
        subtask = Task("sub_id", "Subtask description")
        task_with_subtasks = Task("parent_id", "Parent task", subtasks=[subtask])
        assert len(task_with_subtasks.subtasks) == 1
        assert task_with_subtasks.subtasks[0].id == "sub_id"

    def test_thinking_mode_enum(self):
        """ThinkingMode members map to their string values."""
        assert ThinkingMode.INTERACTIVE.value == "interactive"
        assert ThinkingMode.YOLO.value == "yolo"

    @pytest.mark.asyncio
    async def test_full_streaming_workflow(self, agent):
        """End to end: streaming reaches the processing step and yields its text.

        As in test_create_streaming_response, the generator must be iterated for
        its body (and therefore the patched method) to run at all.
        """
        with patch.object(
            agent, '_interactive_process_for_streaming', return_value="Mocked streaming response"
        ) as mock_process:
            async_generator = agent.create_streaming_response("Test input")

            # It must be an async iterator.
            assert hasattr(async_generator, '__aiter__')

            chunks = []
            async for chunk in async_generator:
                chunks.append(chunk)

        # The processing step must actually have been awaited while streaming.
        mock_process.assert_awaited()
        assert "Mocked streaming response" in "\n".join(chunks)

    def test_edge_cases_and_error_conditions(self, agent):
        """History helpers, the summary and serialization all cope with an empty history."""
        # Cleaning an empty history removes nothing.
        agent.conversation_history = []
        cleaned = agent.clean_conversation_history()
        assert cleaned == 0

        # Trimming an empty history removes nothing.
        trimmed = agent.trim_conversation_history()
        assert trimmed == 0

        # The summary of an empty history reports zero messages.
        summary = agent.get_memory_summary()
        assert summary["conversation_count"] == 0

        # Serializing an empty agent yields an empty history.
        serialized = agent.to_dict()
        assert len(serialized["conversation_history"]) == 0

    def test_mcp_tools_management(self, agent):
        """Multiple tools can be registered, and re-registering a name replaces the tool."""
        # Register two tools.
        tool1 = Mock()
        tool2 = Mock()

        agent.add_mcp_tool("tool1", tool1)
        agent.add_mcp_tool("tool2", tool2)

        assert len(agent.mcp_tools) == 2
        assert agent.mcp_tools["tool1"] == tool1
        assert agent.mcp_tools["tool2"] == tool2

        # Re-registering an existing name overwrites it.
        new_tool1 = Mock()
        agent.add_mcp_tool("tool1", new_tool1)
        assert agent.mcp_tools["tool1"] == new_tool1
| |
| |
class TestAgentStreamingFeatures:
    """Tests for AgentCore's streaming paths: chunking, error handling, timeout, retry and YOLO."""

    @pytest.fixture
    def agent(self):
        """Provide a fresh AgentCore with default settings."""
        return AgentCore()

    @pytest.mark.asyncio
    async def test_create_streaming_response_chunks(self, agent):
        """A long multi-line reply is split into several chunks without losing content."""
        # Mocked complete reply.
        long_response = "这是一个很长的响应。\n" * 50 + "包含多行内容\n" + "#标题\n" + "```python\ncode\n```"

        with patch.object(agent, '_interactive_process_for_streaming', return_value=long_response):
            chunks = []
            async for chunk in agent.create_streaming_response("测试输入"):
                chunks.append(chunk)

            # More than one chunk was produced...
            assert len(chunks) > 1
            # ...and joining them still contains the original content.
            combined = "\n".join(chunks)
            assert "这是一个很长的响应" in combined
            assert "#标题" in combined

    @pytest.mark.asyncio
    async def test_create_streaming_response_markdown_preservation(self, agent):
        """Chunking keeps Markdown constructs (headings, fences, lists, quotes) intact."""
        markdown_response = """# 标题
这是正常文本
```python
def hello():
print("Hello")
```
- 列表项1
- 列表项2
> 引用内容
"""

        with patch.object(agent, '_interactive_process_for_streaming', return_value=markdown_response):
            chunks = []
            async for chunk in agent.create_streaming_response("Markdown测试"):
                chunks.append(chunk)

            combined = "\n".join(chunks)
            # Each Markdown element must survive chunking.
            assert "# 标题" in combined
            assert "```python" in combined
            assert "- 列表项1" in combined
            assert "> 引用内容" in combined

    @pytest.mark.asyncio
    async def test_create_streaming_response_error_handling(self, agent):
        """A failure during processing is turned into a single error chunk."""
        with patch.object(agent, '_interactive_process_for_streaming', side_effect=Exception("Stream error")):
            chunks = []
            async for chunk in agent.create_streaming_response("错误测试"):
                chunks.append(chunk)

            # The error is reported as exactly one chunk.
            assert len(chunks) == 1
            assert "错误" in chunks[0]

    @pytest.mark.asyncio
    async def test_interactive_process_for_streaming_success(self, agent):
        """On a successful query, the extracted response text is returned."""
        with patch.object(agent, '_make_query_call') as mock_query:
            async def mock_stream():
                yield Mock()
            mock_query.return_value = mock_stream()

            with patch.object(agent, '_extract_response_text', return_value="Hello World!"):
                result = await agent._interactive_process_for_streaming("测试")
                assert result == "Hello World!"

    @pytest.mark.asyncio
    async def test_interactive_process_for_streaming_timeout(self, agent):
        """A timeout from asyncio.wait_for yields a user-facing timeout/fallback message."""
        def raise_timeout(*args, **kwargs):
            # Close the coroutine handed to wait_for; otherwise it is garbage-collected
            # un-awaited and Python emits "coroutine ... was never awaited".
            awaitable = args[0] if args else kwargs.get("fut")
            if asyncio.iscoroutine(awaitable):
                awaitable.close()
            raise asyncio.TimeoutError()

        with patch('asyncio.wait_for', side_effect=raise_timeout):
            result = await agent._interactive_process_for_streaming("测试")

        assert "超时" in result or "暂时" in result or "技术问题" in result

    @pytest.mark.asyncio
    async def test_interactive_process_for_streaming_retry(self, agent):
        """A failed first query is retried, and the successful second attempt's text is returned."""
        # The first call fails and the second one succeeds.
        call_count = 0

        async def mock_failing_then_success(*args, **kwargs):
            nonlocal call_count
            call_count += 1
            if call_count == 1:
                raise Exception("First attempt fails")

            async def success_stream():
                yield Mock()

            return success_stream()

        with patch.object(agent, '_make_query_call', side_effect=mock_failing_then_success):
            with patch.object(agent, '_extract_response_text', return_value="Success on retry"):
                result = await agent._interactive_process_for_streaming("测试")

        assert "Success on retry" in result
        # The retry really happened: one failing attempt, then one successful attempt.
        assert call_count == 2

    @pytest.mark.asyncio
    async def test_yolo_process_for_streaming(self, agent):
        """The YOLO streaming path returns the extracted response text."""
        with patch.object(agent, '_extract_response_text', return_value="YOLO流式响应测试结果"):
            with patch('claude_agent.core.agent.query') as mock_query:
                async def mock_stream():
                    yield Mock()
                mock_query.return_value = mock_stream()

                result = await agent._yolo_process_for_streaming("复杂任务")
                assert "YOLO流式响应测试结果" in result

    @pytest.mark.asyncio
    async def test_make_query_call(self, agent):
        """_make_query_call returns an async stream over the SDK query's messages."""
        with patch('claude_agent.core.agent.query') as mock_query:
            async def mock_response():
                yield Mock()

            mock_query.return_value = mock_response()

            result_stream = await agent._make_query_call("test input")
            messages = []
            async for msg in result_stream:
                messages.append(msg)

            assert len(messages) == 1
| |
| |
class TestStreamInputBuffer:
    """Tests for the StreamInputBuffer container and the InputCompletionSignal enum."""

    def test_stream_input_buffer_creation(self):
        """Default construction yields an empty, incomplete buffer with the standard limits."""
        from claude_agent.core.agent import StreamInputBuffer

        buf = StreamInputBuffer()

        expected_defaults = {
            "buffer": "",
            "last_update_time": 0.0,
            "timeout_seconds": 3.0,
            "max_buffer_size": 10000,
            "delimiter": "\n\n",
        }
        for attr_name, expected in expected_defaults.items():
            assert getattr(buf, attr_name) == expected

        # Identity checks for the "not finished yet" state.
        assert buf.is_complete is False
        assert buf.completion_signal is None

    def test_stream_input_buffer_with_params(self):
        """Keyword arguments override the corresponding defaults."""
        from claude_agent.core.agent import StreamInputBuffer

        overrides = dict(
            buffer="initial",
            timeout_seconds=5.0,
            max_buffer_size=5000,
            delimiter="---",
        )
        buf = StreamInputBuffer(**overrides)

        for attr_name, expected in overrides.items():
            assert getattr(buf, attr_name) == expected

    def test_input_completion_signal_enum(self):
        """Each InputCompletionSignal member carries its lowercase string value."""
        from claude_agent.core.agent import InputCompletionSignal as Signal

        pairs = [
            (Signal.TIMEOUT, "timeout"),
            (Signal.DELIMITER, "delimiter"),
            (Signal.EXPLICIT, "explicit"),
            (Signal.MAX_LENGTH, "max_length"),
        ]
        for member, raw_value in pairs:
            assert member.value == raw_value
| |
| |
class TestAgentAdvancedFeatures:
    """Tests for custom construction, YOLO processing/degradation, buffer wiring and fallbacks."""

    @pytest.fixture
    def agent(self):
        """Provide a fresh AgentCore with default settings."""
        return AgentCore()

    def test_agent_initialization_with_custom_params(self):
        """Constructor keyword arguments are reflected on the instance."""
        settings = {"api_key": "test_key", "model": "custom-model", "max_history": 50}
        custom_agent = AgentCore(**settings)

        assert custom_agent.model == settings["model"]
        assert custom_agent.max_history == settings["max_history"]

    @pytest.mark.asyncio
    async def test_yolo_process_analysis_stage(self, agent):
        """_yolo_process returns the text extracted from the SDK query stream."""
        async def single_message_stream():
            yield Mock()

        with patch.object(agent, '_extract_response_text', return_value="这是一个复杂的需求分析"), \
                patch('claude_agent.core.agent.query') as mock_query:
            mock_query.return_value = single_message_stream()
            result = await agent._yolo_process("复杂任务请求")

        assert "这是一个复杂的需求分析" in result

    @pytest.mark.asyncio
    async def test_yolo_process_degradation_to_interactive(self, agent):
        """When YOLO processing fails, the agent ends up returning a fallback reply."""
        failing_query = patch('claude_agent.core.agent.query', side_effect=Exception("YOLO processing failed"))

        with failing_query:
            result = await agent._yolo_process("测试任务")

        # YOLO fails, degrades to interactive (which fails too), so a fallback text comes back.
        assert "技术问题" in result or "失败" in result

    def test_stream_input_buffer_integration(self, agent):
        """A completed StreamInputBuffer can be attached to the agent and read back."""
        from claude_agent.core.agent import StreamInputBuffer, InputCompletionSignal

        completed_buffer = StreamInputBuffer(
            buffer="测试输入",
            timeout_seconds=5.0,
            is_complete=True,
            completion_signal=InputCompletionSignal.DELIMITER,
        )
        agent.stream_input_buffer = completed_buffer

        attached = agent.stream_input_buffer
        assert attached.buffer == "测试输入"
        assert attached.is_complete is True
        assert attached.completion_signal == InputCompletionSignal.DELIMITER

    def test_fallback_response_variations(self, agent):
        """The fallback text depends on the error category: quota/API, timeout, or generic."""
        def fallback_for(error_text):
            return agent._get_fallback_response("测试", error_text)

        # Quota exhaustion / API-availability errors.
        for quota_error in ("预扣费额度失败", "403", "剩余额度", "API Error", "/login"):
            reply = fallback_for(quota_error)
            assert "API服务暂时不可用" in reply or "API配额不足" in reply

        # Timeout errors in several spellings.
        for timeout_error in ("超时", "timeout", "TIMEOUT"):
            reply = fallback_for(timeout_error)
            assert "请求超时" in reply or "稍后重试" in reply

        # Anything else is a generic technical problem that echoes the error text.
        generic_error = "未知错误"
        reply = fallback_for(generic_error)
        assert "技术问题" in reply and generic_error in reply