| """ |
| Webhook服务器 |
| 使用FastAPI实现Bot消息广播服务器 |
| """ |
| |
import asyncio
import logging
import secrets
from contextlib import asynccontextmanager
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Set

import aiohttp
from fastapi import BackgroundTasks, Depends, FastAPI, HTTPException
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer

from .models import (
    BotMessage,
    BotRegistration,
    MessageBroadcast,
    WebhookConfig,
    WebhookResponse,
)
| |
| |
# Module-level logger, named after this module so log records can be filtered per module.
logger = logging.getLogger(name=__name__)
| |
| |
class BotRegistry:
    """In-memory registry of bots, their group subscriptions and webhook endpoints.

    The stored ``BotRegistration`` is the source of truth; the subscription and
    endpoint dictionaries are indexes derived from it and are kept consistent on
    register / re-register / unregister.
    """

    def __init__(self):
        # bot_username -> registration record
        self._bots: Dict[str, BotRegistration] = {}
        # group_id -> usernames of the bots subscribed to that group
        self._group_subscriptions: Dict[int, Set[str]] = {}
        # bot_username -> webhook endpoint URL
        self._bot_endpoints: Dict[str, str] = {}

    def _drop_indexes(self, bot_username: str) -> None:
        """Remove ``bot_username`` from the subscription and endpoint indexes.

        The groups to clean up are taken from the currently stored registration
        (if any); groups left with no subscribers are deleted entirely.
        """
        previous = self._bots.get(bot_username)
        if previous is not None:
            for group_id in previous.subscribed_groups:
                members = self._group_subscriptions.get(group_id)
                if members is None:
                    continue
                members.discard(bot_username)
                if not members:
                    del self._group_subscriptions[group_id]
        self._bot_endpoints.pop(bot_username, None)

    def register_bot(self, registration: BotRegistration) -> bool:
        """Register a bot, replacing any previous registration under the same username.

        Args:
            registration: the bot's registration record.

        Returns:
            True on success, False if an exception occurred (it is logged).
        """
        try:
            bot_username = registration.bot_username

            # A re-registration replaces the old record: clear the old indexes first,
            # otherwise the bot would keep receiving messages for groups it dropped.
            self._drop_indexes(bot_username)
            self._bots[bot_username] = registration

            # Update group subscriptions
            for group_id in registration.subscribed_groups:
                self._group_subscriptions.setdefault(group_id, set()).add(bot_username)

            # Update the endpoint mapping
            if registration.webhook_endpoint:
                self._bot_endpoints[bot_username] = registration.webhook_endpoint

            logger.info(f"Bot {bot_username} 注册成功,订阅群组: {registration.subscribed_groups}")
            return True

        except Exception as e:
            logger.error(f"Bot注册失败: {e}")
            return False

    def unregister_bot(self, bot_username: str) -> bool:
        """Unregister a bot and remove all of its index entries.

        Returns:
            True if the bot was registered and has been removed, False if it was
            unknown or an exception occurred (it is logged).
        """
        try:
            if bot_username not in self._bots:
                return False

            self._drop_indexes(bot_username)
            del self._bots[bot_username]

            logger.info(f"Bot {bot_username} 注销成功")
            return True

        except Exception as e:
            logger.error(f"Bot注销失败: {e}")
            return False

    def get_subscribers_for_group(self, group_id: int, exclude_bots: Optional[List[str]] = None) -> List[str]:
        """Return the usernames subscribed to ``group_id``, minus ``exclude_bots``.

        The order of the returned list is unspecified (it comes from a set).
        """
        excluded = set(exclude_bots) if exclude_bots else set()
        return [bot for bot in self._group_subscriptions.get(group_id, ()) if bot not in excluded]

    def get_bot_endpoint(self, bot_username: str) -> Optional[str]:
        """Return the bot's webhook endpoint, or None if it has none."""
        return self._bot_endpoints.get(bot_username)

    def update_bot_last_seen(self, bot_username: str):
        """Set ``last_seen`` on the bot's registration to now (no-op for unknown bots)."""
        if bot_username in self._bots:
            self._bots[bot_username].last_seen = datetime.now()

    def get_bot_info(self, bot_username: str) -> Optional[BotRegistration]:
        """Return the bot's registration record, or None if it is not registered."""
        return self._bots.get(bot_username)

    def get_all_bots(self) -> List[BotRegistration]:
        """Return a list of all registration records."""
        return list(self._bots.values())
| |
| |
class MessageDistributor:
    """Fan out bot messages to the webhook endpoints of subscribed bots over HTTP."""

    def __init__(self, bot_registry: BotRegistry, config: WebhookConfig):
        self.bot_registry = bot_registry
        self.config = config
        # Shared HTTP session: created by start(), closed by stop().
        self.session: Optional[aiohttp.ClientSession] = None

    async def start(self):
        """Create the shared aiohttp session. It is a no-op if one is already open, so no session leaks."""
        if self.session is not None and not self.session.closed:
            return
        # ``total`` bounds the whole request; ``connect`` bounds only connection setup.
        # NOTE(review): the original passed these two config fields swapped.
        timeout = aiohttp.ClientTimeout(
            total=self.config.request_timeout,
            connect=self.config.connection_timeout,
        )
        self.session = aiohttp.ClientSession(timeout=timeout)

    async def stop(self):
        """Close the shared session, if any."""
        # Detach first so no new send picks up a session that is being closed.
        session, self.session = self.session, None
        if session is not None:
            await session.close()

    async def distribute_message(self, broadcast: MessageBroadcast) -> Dict[str, bool]:
        """Deliver ``broadcast.bot_message`` concurrently to every subscribed bot.

        Target groups default to the message's own group. The sending bot is always
        excluded, together with ``broadcast.exclude_bots``.

        Returns:
            Mapping of bot username -> whether delivery succeeded.

        Raises:
            RuntimeError: if start() has not been called.
        """
        if self.session is None:
            raise RuntimeError("MessageDistributor not started")

        message = broadcast.bot_message
        target_groups = broadcast.target_groups or [message.group_id]
        exclude_bots = list(broadcast.exclude_bots or [])

        # Never echo a message back to the bot that sent it
        if message.bot_username not in exclude_bots:
            exclude_bots.append(message.bot_username)

        # Collect every bot that should receive the message (deduplicated across groups)
        target_bots = set()
        for group_id in target_groups:
            target_bots.update(self.bot_registry.get_subscribers_for_group(group_id, exclude_bots))

        bot_list = list(target_bots)
        logger.info(f"分发消息给 {len(bot_list)} 个Bot: {bot_list}")

        if not bot_list:
            return {}

        # Send to all bots concurrently; exceptions are returned instead of aborting the batch.
        outcomes = await asyncio.gather(
            *(self._send_to_bot(bot_username, message) for bot_username in bot_list),
            return_exceptions=True,
        )

        results: Dict[str, bool] = {}
        for bot_username, outcome in zip(bot_list, outcomes):
            if isinstance(outcome, Exception):
                logger.error(f"发送消息给 {bot_username} 失败: {outcome}")
                results[bot_username] = False
            elif isinstance(outcome, BaseException):
                # Cancellation / interpreter exit must not be swallowed.
                raise outcome
            else:
                results[bot_username] = outcome
        return results

    async def _send_to_bot(self, bot_username: str, message: BotMessage) -> bool:
        """POST ``message`` to one bot's webhook endpoint.

        Returns:
            True if the endpoint answered with any 2xx status; False if the bot has
            no endpoint, answered with another status, or the request raised.
        """
        endpoint = self.bot_registry.get_bot_endpoint(bot_username)
        if not endpoint:
            logger.warning(f"Bot {bot_username} 没有配置Webhook端点")
            return False

        try:
            payload = {
                "message": message.model_dump(mode='json'),
                "timestamp": datetime.now().isoformat(),
            }
            headers = {
                "Authorization": f"Bearer {self.config.auth_token}",
                "Content-Type": "application/json",
            }

            async with self.session.post(endpoint, json=payload, headers=headers) as response:
                # 202/204 are common webhook acknowledgements, so accept any 2xx.
                if 200 <= response.status < 300:
                    logger.debug(f"消息成功发送给 {bot_username}")
                    self.bot_registry.update_bot_last_seen(bot_username)
                    return True
                logger.warning(f"发送消息给 {bot_username} 失败: HTTP {response.status}")
                return False

        except Exception as e:
            logger.error(f"发送消息给 {bot_username} 异常: {e}")
            return False
| |
| |
class WebhookServer:
    """Webhook server: exposes the bot registry and message distributor as a FastAPI app.

    Every route except ``/health`` requires a Bearer token equal to
    ``config.auth_token``.
    """

    def __init__(self, config: WebhookConfig):
        self.config = config
        self.bot_registry = BotRegistry()
        self.message_distributor = MessageDistributor(self.bot_registry, config)
        self.app = FastAPI(
            title="Bot Message Webhook Server",
            description="Telegram Bot间消息广播服务器",
            version="1.0.0",
            lifespan=self._lifespan,
        )
        # Kept as a public attribute for compatibility; routes authenticate via _verify_token.
        self.security = HTTPBearer()
        self._setup_routes()

    @asynccontextmanager
    async def _lifespan(self, app: FastAPI):
        """Start the distributor on startup and always stop it on shutdown."""
        await self.message_distributor.start()
        logger.info("Webhook服务器启动完成")
        try:
            yield
        finally:
            # Release the HTTP session even if the application exits with an error.
            await self.message_distributor.stop()
            logger.info("Webhook服务器关闭完成")

    @staticmethod
    def _tokens_match(supplied: object, expected: object) -> bool:
        """Compare two secret tokens in constant time. Non-string values never match."""
        if not isinstance(supplied, str) or not isinstance(expected, str):
            return False
        # Compare bytes: compare_digest rejects non-ASCII str input with TypeError.
        return secrets.compare_digest(supplied.encode("utf-8"), expected.encode("utf-8"))

    def _verify_token(self, credentials: HTTPAuthorizationCredentials = Depends(HTTPBearer())):
        """FastAPI dependency: require a Bearer token equal to ``config.auth_token``.

        Returns:
            The verified token string.

        Raises:
            HTTPException: 401 if the token does not match.
        """
        if not self._tokens_match(credentials.credentials, self.config.auth_token):
            raise HTTPException(status_code=401, detail="Invalid authentication token")
        return credentials.credentials

    def _setup_routes(self):
        """Register all HTTP routes on ``self.app``."""

        @self.app.post("/register", response_model=WebhookResponse)
        async def register_bot(
            registration: BotRegistration,
            token: str = Depends(self._verify_token)
        ):
            """Register a bot; the token in the body must equal the Bearer token."""
            if not self._tokens_match(registration.auth_token, token):
                raise HTTPException(status_code=401, detail="Registration token mismatch")

            success = self.bot_registry.register_bot(registration)
            if success:
                return WebhookResponse(
                    success=True,
                    message=f"Bot {registration.bot_username} 注册成功"
                )
            raise HTTPException(status_code=400, detail="Bot registration failed")

        @self.app.post("/unregister/{bot_username}", response_model=WebhookResponse)
        async def unregister_bot(
            bot_username: str,
            token: str = Depends(self._verify_token)
        ):
            """Unregister a bot; 404 if it is not registered."""
            success = self.bot_registry.unregister_bot(bot_username)
            if success:
                return WebhookResponse(
                    success=True,
                    message=f"Bot {bot_username} 注销成功"
                )
            raise HTTPException(status_code=404, detail="Bot not found")

        @self.app.post("/broadcast", response_model=WebhookResponse)
        async def broadcast_message(
            broadcast: MessageBroadcast,
            background_tasks: BackgroundTasks,
            token: str = Depends(self._verify_token)
        ):
            """Accept a broadcast and distribute it in a background task."""
            background_tasks.add_task(self._handle_broadcast, broadcast)

            return WebhookResponse(
                success=True,
                message="消息广播请求已接收,正在处理"
            )

        @self.app.get("/bots", response_model=WebhookResponse)
        async def list_bots(token: str = Depends(self._verify_token)):
            """List all registered bots."""
            bots = self.bot_registry.get_all_bots()
            return WebhookResponse(
                success=True,
                message=f"共有 {len(bots)} 个已注册的Bot",
                data={"bots": [bot.model_dump(mode='json') for bot in bots]}
            )

        @self.app.get("/health")
        async def health_check():
            """Unauthenticated liveness probe."""
            return {"status": "healthy", "timestamp": datetime.now().isoformat()}

    async def _handle_broadcast(self, broadcast: MessageBroadcast):
        """Background task: distribute a broadcast and log the delivery summary."""
        try:
            results = await self.message_distributor.distribute_message(broadcast)
            success_count = sum(1 for success in results.values() if success)
            logger.info(f"消息广播完成: {success_count}/{len(results)} 个Bot成功接收")
        except Exception as e:
            # Nobody awaits a background task, so keep the traceback in the log.
            logger.exception(f"消息广播失败: {e}")

    def start_server(self):
        """Run the app with uvicorn on the configured host/port (blocking)."""
        import uvicorn
        uvicorn.run(
            self.app,
            host=self.config.server_host,
            port=self.config.server_port,
            log_level="info"
        )

    def get_app(self) -> FastAPI:
        """Return the FastAPI application instance (e.g. for an external ASGI server)."""
        return self.app