| """ |
| 上下文管理器 |
| 负责管理用户和群组的对话历史,确保上下文不串台 |
| """ |
| |
| import json |
| import logging |
| import threading |
| from collections import defaultdict, deque |
| from typing import Dict, List, Any, Union, Optional |
| from datetime import datetime |
| |
| from .interfaces import IContextManager |
| |
| |
| logger = logging.getLogger(__name__) |
| |
| |
| class ContextManager(IContextManager): |
| """上下文管理器实现""" |
| |
| def __init__(self, max_history_per_chat: int = 50, storage_dir: str = None, bot_id: str = None): |
| """ |
| 初始化上下文管理器 |
| |
| Args: |
| max_history_per_chat: 每个聊天保留的最大消息数量 |
| storage_dir: 存储目录路径(保留兼容性,但不使用) |
| bot_id: Bot ID(保留兼容性,但不使用) |
| """ |
| self.max_history_per_chat = max_history_per_chat |
| self._contexts: Dict[str, deque] = defaultdict(lambda: deque(maxlen=self.max_history_per_chat)) |
| self._lock = threading.RLock() |
| |
| def add_message(self, chat_id: Union[int, str], user_id: int, message: str, is_bot: bool = False, user_info: dict = None): |
| """ |
| 添加消息到上下文 |
| |
| Args: |
| chat_id: 聊天ID |
| user_id: 用户ID |
| message: 消息内容 |
| is_bot: 是否为机器人消息 |
| user_info: 用户信息字典,包含username, first_name, last_name等 |
| """ |
| with self._lock: |
| chat_key = str(chat_id) |
| message_data = { |
| 'user_id': user_id, |
| 'message': message, |
| 'is_bot': is_bot, |
| 'timestamp': datetime.now().isoformat(), |
| 'chat_id': chat_id |
| } |
| |
| # 添加用户信息到消息数据中 |
| if user_info: |
| message_data['user_info'] = user_info |
| |
| self._contexts[chat_key].append(message_data) |
| |
| def get_context(self, chat_id: Union[int, str], limit: Optional[int] = 10) -> List[Dict[str, Any]]: |
| """ |
| 获取聊天上下文 |
| |
| Args: |
| chat_id: 聊天ID |
| limit: 获取的消息数量限制,None表示获取所有消息 |
| |
| Returns: |
| 消息列表,按时间顺序排列 |
| """ |
| with self._lock: |
| chat_key = str(chat_id) |
| context = list(self._contexts[chat_key]) |
| # 如果limit为None,返回所有消息 |
| if limit is None: |
| return context |
| # 返回最近的limit条消息 |
| return context[-limit:] if len(context) > limit else context |
| |
| def clear_context(self, chat_id: Union[int, str]): |
| """ |
| 清空指定聊天的上下文 |
| |
| Args: |
| chat_id: 聊天ID |
| """ |
| with self._lock: |
| chat_key = str(chat_id) |
| self._contexts[chat_key].clear() |
| |
| def get_chat_count(self) -> int: |
| """获取活跃聊天数量""" |
| with self._lock: |
| return len(self._contexts) |
| |
| def get_message_count(self, chat_id: Union[int, str]) -> int: |
| """获取指定聊天的消息数量""" |
| with self._lock: |
| chat_key = str(chat_id) |
| return len(self._contexts[chat_key]) |
| |
| def export_context(self, chat_id: Union[int, str]) -> str: |
| """ |
| 导出聊天上下文为JSON格式 |
| |
| Args: |
| chat_id: 聊天ID |
| |
| Returns: |
| JSON格式的上下文数据 |
| """ |
| with self._lock: |
| context = self.get_context(chat_id, limit=None) |
| return json.dumps(context, ensure_ascii=False, indent=2) |
| |
| def import_context(self, chat_id: Union[int, str], context_json: str): |
| """ |
| 从JSON导入聊天上下文 |
| |
| Args: |
| chat_id: 聊天ID |
| context_json: JSON格式的上下文数据 |
| """ |
| with self._lock: |
| try: |
| context_data = json.loads(context_json) |
| chat_key = str(chat_id) |
| self._contexts[chat_key].clear() |
| for message_data in context_data: |
| self._contexts[chat_key].append(message_data) |
| except (json.JSONDecodeError, TypeError) as e: |
| raise ValueError(f"无效的上下文JSON数据: {e}") |
| |
| def get_user_message_count(self, chat_id: Union[int, str], user_id: int) -> int: |
| """ |
| 获取指定用户在某个聊天中的消息数量 |
| |
| Args: |
| chat_id: 聊天ID |
| user_id: 用户ID |
| |
| Returns: |
| 用户消息数量 |
| """ |
| with self._lock: |
| context = self.get_context(chat_id, limit=None) |
| return sum(1 for msg in context if msg['user_id'] == user_id and not msg['is_bot']) |
| |
| def cleanup_old_chats(self, max_inactive_days: int = 7): |
| """ |
| 清理长时间不活跃的聊天记录 |
| |
| Args: |
| max_inactive_days: 最大不活跃天数 |
| """ |
| from datetime import timedelta |
| |
| cutoff_time = datetime.now() - timedelta(days=max_inactive_days) |
| |
| with self._lock: |
| chats_to_remove = [] |
| for chat_key, messages in self._contexts.items(): |
| if not messages: |
| chats_to_remove.append(chat_key) |
| continue |
| |
| # 检查最后一条消息的时间 |
| last_message = messages[-1] |
| try: |
| last_time = datetime.fromisoformat(last_message['timestamp']) |
| if last_time < cutoff_time: |
| chats_to_remove.append(chat_key) |
| except (ValueError, KeyError): |
| # 如果时间戳格式有问题,也删除 |
| chats_to_remove.append(chat_key) |
| |
| for chat_key in chats_to_remove: |
| del self._contexts[chat_key] |
| |
| if chats_to_remove: |
| logger.info(f"清理了 {len(chats_to_remove)} 个不活跃的聊天记录") |