blob: d2a4233080e64de280afe234561b2b83a8bdf5f1 [file] [log] [blame] [raw]
"""
上下文管理器
负责管理用户和群组的对话历史,确保上下文不串台
"""
import json
import logging
import threading
from collections import defaultdict, deque
from typing import Dict, List, Any, Union, Optional
from datetime import datetime
from .interfaces import IContextManager
logger = logging.getLogger(__name__)
class ContextManager(IContextManager):
"""上下文管理器实现"""
def __init__(self, max_history_per_chat: int = 50, storage_dir: str = None, bot_id: str = None):
"""
初始化上下文管理器
Args:
max_history_per_chat: 每个聊天保留的最大消息数量
storage_dir: 存储目录路径(保留兼容性,但不使用)
bot_id: Bot ID(保留兼容性,但不使用)
"""
self.max_history_per_chat = max_history_per_chat
self._contexts: Dict[str, deque] = defaultdict(lambda: deque(maxlen=self.max_history_per_chat))
self._lock = threading.RLock()
def add_message(self, chat_id: Union[int, str], user_id: int, message: str, is_bot: bool = False, user_info: dict = None):
"""
添加消息到上下文
Args:
chat_id: 聊天ID
user_id: 用户ID
message: 消息内容
is_bot: 是否为机器人消息
user_info: 用户信息字典,包含username, first_name, last_name等
"""
with self._lock:
chat_key = str(chat_id)
message_data = {
'user_id': user_id,
'message': message,
'is_bot': is_bot,
'timestamp': datetime.now().isoformat(),
'chat_id': chat_id
}
# 添加用户信息到消息数据中
if user_info:
message_data['user_info'] = user_info
self._contexts[chat_key].append(message_data)
def get_context(self, chat_id: Union[int, str], limit: Optional[int] = 10) -> List[Dict[str, Any]]:
"""
获取聊天上下文
Args:
chat_id: 聊天ID
limit: 获取的消息数量限制,None表示获取所有消息
Returns:
消息列表,按时间顺序排列
"""
with self._lock:
chat_key = str(chat_id)
context = list(self._contexts[chat_key])
# 如果limit为None,返回所有消息
if limit is None:
return context
# 返回最近的limit条消息
return context[-limit:] if len(context) > limit else context
def clear_context(self, chat_id: Union[int, str]):
"""
清空指定聊天的上下文
Args:
chat_id: 聊天ID
"""
with self._lock:
chat_key = str(chat_id)
self._contexts[chat_key].clear()
def get_chat_count(self) -> int:
"""获取活跃聊天数量"""
with self._lock:
return len(self._contexts)
def get_message_count(self, chat_id: Union[int, str]) -> int:
"""获取指定聊天的消息数量"""
with self._lock:
chat_key = str(chat_id)
return len(self._contexts[chat_key])
def export_context(self, chat_id: Union[int, str]) -> str:
"""
导出聊天上下文为JSON格式
Args:
chat_id: 聊天ID
Returns:
JSON格式的上下文数据
"""
with self._lock:
context = self.get_context(chat_id, limit=None)
return json.dumps(context, ensure_ascii=False, indent=2)
def import_context(self, chat_id: Union[int, str], context_json: str):
"""
从JSON导入聊天上下文
Args:
chat_id: 聊天ID
context_json: JSON格式的上下文数据
"""
with self._lock:
try:
context_data = json.loads(context_json)
chat_key = str(chat_id)
self._contexts[chat_key].clear()
for message_data in context_data:
self._contexts[chat_key].append(message_data)
except (json.JSONDecodeError, TypeError) as e:
raise ValueError(f"无效的上下文JSON数据: {e}")
def get_user_message_count(self, chat_id: Union[int, str], user_id: int) -> int:
"""
获取指定用户在某个聊天中的消息数量
Args:
chat_id: 聊天ID
user_id: 用户ID
Returns:
用户消息数量
"""
with self._lock:
context = self.get_context(chat_id, limit=None)
return sum(1 for msg in context if msg['user_id'] == user_id and not msg['is_bot'])
def cleanup_old_chats(self, max_inactive_days: int = 7):
"""
清理长时间不活跃的聊天记录
Args:
max_inactive_days: 最大不活跃天数
"""
from datetime import timedelta
cutoff_time = datetime.now() - timedelta(days=max_inactive_days)
with self._lock:
chats_to_remove = []
for chat_key, messages in self._contexts.items():
if not messages:
chats_to_remove.append(chat_key)
continue
# 检查最后一条消息的时间
last_message = messages[-1]
try:
last_time = datetime.fromisoformat(last_message['timestamp'])
if last_time < cutoff_time:
chats_to_remove.append(chat_key)
except (ValueError, KeyError):
# 如果时间戳格式有问题,也删除
chats_to_remove.append(chat_key)
for chat_key in chats_to_remove:
del self._contexts[chat_key]
if chats_to_remove:
logger.info(f"清理了 {len(chats_to_remove)} 个不活跃的聊天记录")