# blob: 7db1db660a62995bb3e284ca31343a753205fe57 [file] [log] [blame] [raw]
"""
Webhook服务器
使用FastAPI实现Bot消息广播服务器
"""
import asyncio
import logging
from typing import Dict, List, Set, Optional
from datetime import datetime, timedelta
import aiohttp
from fastapi import FastAPI, HTTPException, Depends, BackgroundTasks
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from contextlib import asynccontextmanager
from .models import (
BotMessage, BotRegistration, WebhookResponse,
MessageBroadcast, WebhookConfig
)
# Module-level logger, named after this module; used by every class below.
logger = logging.getLogger(__name__)
class BotRegistry:
    """In-memory registry of bots, their group subscriptions and endpoints.

    Three indexes are kept in sync with each other:

    * ``_bots``: bot username -> registration record
    * ``_group_subscriptions``: group id -> usernames of subscribed bots
    * ``_bot_endpoints``: bot username -> webhook endpoint
    """

    def __init__(self):
        self._bots: Dict[str, BotRegistration] = {}
        self._group_subscriptions: Dict[int, Set[str]] = {}  # group_id -> bot_usernames
        self._bot_endpoints: Dict[str, str] = {}  # bot_username -> webhook_endpoint

    def _drop_indexes(self, bot_username: str, registration: BotRegistration) -> None:
        """Remove ``bot_username`` from the group and endpoint indexes.

        Only the groups listed in ``registration`` are visited. A group whose
        subscriber set becomes empty is deleted from the index.
        """
        for group_id in registration.subscribed_groups:
            subscribers = self._group_subscriptions.get(group_id)
            if subscribers is None:
                continue
            subscribers.discard(bot_username)
            if not subscribers:
                del self._group_subscriptions[group_id]
        self._bot_endpoints.pop(bot_username, None)

    def register_bot(self, registration: BotRegistration) -> bool:
        """Register a bot, replacing any earlier registration of the same name.

        Args:
            registration: Record providing ``bot_username``,
                ``subscribed_groups`` and ``webhook_endpoint``.

        Returns:
            True on success, False if any step raised (the error is logged).
        """
        try:
            # Read everything that can fail before touching any index, so a
            # malformed registration does not leave the registry half-updated.
            bot_username = registration.bot_username
            group_ids = list(registration.subscribed_groups)
            endpoint = registration.webhook_endpoint

            # A re-registration must not keep subscriptions (or an endpoint)
            # that only the previous registration declared.
            previous = self._bots.get(bot_username)
            if previous is not None:
                self._drop_indexes(bot_username, previous)

            self._bots[bot_username] = registration
            for group_id in group_ids:
                self._group_subscriptions.setdefault(group_id, set()).add(bot_username)
            if endpoint:
                self._bot_endpoints[bot_username] = endpoint

            logger.info(f"Bot {bot_username} 注册成功,订阅群组: {registration.subscribed_groups}")
            return True
        except Exception as e:
            logger.error(f"Bot注册失败: {e}")
            return False

    def unregister_bot(self, bot_username: str) -> bool:
        """Remove a bot and all index entries that refer to it.

        Returns:
            True if the bot was registered and has been removed; False if it
            was unknown or removal raised (the error is logged).
        """
        try:
            registration = self._bots.get(bot_username)
            if registration is None:
                return False
            self._drop_indexes(bot_username, registration)
            del self._bots[bot_username]
            logger.info(f"Bot {bot_username} 注销成功")
            return True
        except Exception as e:
            logger.error(f"Bot注销失败: {e}")
            return False

    def get_subscribers_for_group(
        self, group_id: int, exclude_bots: Optional[List[str]] = None
    ) -> List[str]:
        """Return the usernames subscribed to ``group_id``.

        Args:
            group_id: Group to look up.
            exclude_bots: Usernames to leave out of the result; ``None`` or an
                empty list excludes nobody.

        Returns:
            A new list of usernames (empty when the group has no subscribers).
        """
        excluded = set(exclude_bots or ())
        subscribers = self._group_subscriptions.get(group_id, set())
        return [bot for bot in subscribers if bot not in excluded]

    def get_bot_endpoint(self, bot_username: str) -> Optional[str]:
        """Return the bot's webhook endpoint, or ``None`` if none is recorded."""
        return self._bot_endpoints.get(bot_username)

    def update_bot_last_seen(self, bot_username: str):
        """Set ``last_seen`` of a registered bot to ``datetime.now()``.

        Unknown usernames are ignored.
        """
        registration = self._bots.get(bot_username)
        if registration is not None:
            registration.last_seen = datetime.now()

    def get_bot_info(self, bot_username: str) -> Optional[BotRegistration]:
        """Return the registration record for ``bot_username``, or ``None``."""
        return self._bots.get(bot_username)

    def get_all_bots(self) -> List[BotRegistration]:
        """Return a new list holding every registration record."""
        return list(self._bots.values())
class MessageDistributor:
    """Delivers a bot message to the webhook endpoints of subscribed bots."""

    def __init__(self, bot_registry: BotRegistry, config: WebhookConfig):
        self.bot_registry = bot_registry
        self.config = config
        # Created by start(), released by stop(); None while not running.
        self.session: Optional[aiohttp.ClientSession] = None

    async def start(self):
        """Create the shared HTTP session used for all deliveries.

        Calling this while a live session already exists is a no-op, so a
        repeated start does not leak the earlier session.
        """
        if self.session is not None and not self.session.closed:
            return
        # NOTE(review): by name, `request_timeout` looks like the natural
        # `total` and `connection_timeout` the natural `connect`; the mapping
        # is kept as originally written -- confirm against WebhookConfig.
        timeout = aiohttp.ClientTimeout(
            total=self.config.connection_timeout,
            connect=self.config.request_timeout,
        )
        self.session = aiohttp.ClientSession(timeout=timeout)

    async def stop(self):
        """Close the HTTP session if one is open."""
        if self.session is not None:
            await self.session.close()
            self.session = None

    async def distribute_message(self, broadcast: MessageBroadcast) -> Dict[str, bool]:
        """Send the broadcast's message to every subscribed bot concurrently.

        Target groups default to the message's own group. The sending bot is
        always excluded, in addition to ``broadcast.exclude_bots``.

        Args:
            broadcast: Carries ``bot_message`` plus optional ``target_groups``
                and ``exclude_bots``.

        Returns:
            Mapping of bot username -> whether delivery succeeded. Empty when
            no bot qualifies.

        Raises:
            RuntimeError: If ``start()`` has not been called.
        """
        if not self.session:
            raise RuntimeError("MessageDistributor not started")

        message = broadcast.bot_message
        target_groups = broadcast.target_groups or [message.group_id]

        # Work on a copy so the caller's broadcast object is never mutated.
        exclude_bots = list(broadcast.exclude_bots or [])
        if message.bot_username not in exclude_bots:
            exclude_bots.append(message.bot_username)

        # A bot subscribed to several target groups must be notified once.
        target_bots = set()
        for group_id in target_groups:
            target_bots.update(
                self.bot_registry.get_subscribers_for_group(group_id, exclude_bots)
            )
        logger.info(f"分发消息给 {len(target_bots)} 个Bot: {list(target_bots)}")

        if not target_bots:
            return {}

        # gather() schedules every delivery at once; awaiting the coroutines
        # one after another would let a single slow endpoint delay the rest.
        recipients = list(target_bots)
        outcomes = await asyncio.gather(
            *(self._send_to_bot(bot_username, message) for bot_username in recipients),
            return_exceptions=True,
        )

        results = {}
        for bot_username, outcome in zip(recipients, outcomes):
            if isinstance(outcome, BaseException):
                logger.error(f"发送消息给 {bot_username} 失败: {outcome}")
                results[bot_username] = False
            else:
                results[bot_username] = outcome
        return results

    async def _send_to_bot(self, bot_username: str, message: BotMessage) -> bool:
        """POST ``message`` to one bot's webhook endpoint.

        Returns:
            True only for an HTTP 200 response (the bot's last-seen time is
            then refreshed). False when the bot has no endpoint, the response
            status differs, or the request raises; each case is logged.
        """
        endpoint = self.bot_registry.get_bot_endpoint(bot_username)
        if not endpoint:
            logger.warning(f"Bot {bot_username} 没有配置Webhook端点")
            return False

        try:
            payload = {
                "message": message.model_dump(mode='json'),
                "timestamp": datetime.now().isoformat(),
            }
            headers = {
                "Authorization": f"Bearer {self.config.auth_token}",
                "Content-Type": "application/json",
            }
            async with self.session.post(endpoint, json=payload, headers=headers) as response:
                if response.status != 200:
                    logger.warning(f"发送消息给 {bot_username} 失败: HTTP {response.status}")
                    return False
                logger.debug(f"消息成功发送给 {bot_username}")
                self.bot_registry.update_bot_last_seen(bot_username)
                return True
        except Exception as e:
            # Best-effort delivery: one failing endpoint must not abort the others.
            logger.error(f"发送消息给 {bot_username} 异常: {e}")
            return False
class WebhookServer:
    """FastAPI application that registers bots and broadcasts their messages.

    Every route except ``/health`` requires a bearer token equal to
    ``config.auth_token``.
    """

    def __init__(self, config: WebhookConfig):
        self.config = config
        self.bot_registry = BotRegistry()
        self.message_distributor = MessageDistributor(self.bot_registry, config)
        self.app = FastAPI(
            title="Bot Message Webhook Server",
            description="Telegram Bot间消息广播服务器",
            version="1.0.0",
            lifespan=self._lifespan,
        )
        self.security = HTTPBearer()
        self._setup_routes()

    @asynccontextmanager
    async def _lifespan(self, app: FastAPI):
        """Start the distributor for the app's lifetime and stop it afterwards."""
        await self.message_distributor.start()
        logger.info("Webhook服务器启动完成")
        try:
            yield
        finally:
            # Runs even if an exception is thrown in at the yield, so the
            # distributor's HTTP session is always released.
            await self.message_distributor.stop()
            logger.info("Webhook服务器关闭完成")

    @staticmethod
    def _tokens_match(presented, expected) -> bool:
        """Compare two tokens without leaking timing information.

        Strings are compared with ``hmac.compare_digest``. Any other value
        (for example ``None``) falls back to ``==`` so such inputs are
        rejected or accepted exactly as a plain comparison would.
        """
        import hmac

        if isinstance(presented, str) and isinstance(expected, str):
            return hmac.compare_digest(
                presented.encode("utf-8"), expected.encode("utf-8")
            )
        return presented == expected

    def _verify_token(self, credentials: HTTPAuthorizationCredentials = Depends(HTTPBearer())):
        """Dependency that checks the bearer token against the configured one.

        Returns:
            The presented token string.

        Raises:
            HTTPException: 401 when the token does not match.
        """
        if not self._tokens_match(credentials.credentials, self.config.auth_token):
            raise HTTPException(status_code=401, detail="Invalid authentication token")
        return credentials.credentials

    def _setup_routes(self):
        """Attach all HTTP routes to ``self.app``."""

        @self.app.post("/register", response_model=WebhookResponse)
        async def register_bot(
            registration: BotRegistration,
            token: str = Depends(self._verify_token),
        ):
            """Register a bot."""
            # The token embedded in the registration body must equal the
            # bearer token that authenticated this request.
            if not self._tokens_match(registration.auth_token, token):
                raise HTTPException(status_code=401, detail="Registration token mismatch")

            if not self.bot_registry.register_bot(registration):
                raise HTTPException(status_code=400, detail="Bot registration failed")
            return WebhookResponse(
                success=True,
                message=f"Bot {registration.bot_username} 注册成功",
            )

        @self.app.post("/unregister/{bot_username}", response_model=WebhookResponse)
        async def unregister_bot(
            bot_username: str,
            token: str = Depends(self._verify_token),
        ):
            """Unregister a bot."""
            if not self.bot_registry.unregister_bot(bot_username):
                raise HTTPException(status_code=404, detail="Bot not found")
            return WebhookResponse(
                success=True,
                message=f"Bot {bot_username} 注销成功",
            )

        @self.app.post("/broadcast", response_model=WebhookResponse)
        async def broadcast_message(
            broadcast: MessageBroadcast,
            background_tasks: BackgroundTasks,
            token: str = Depends(self._verify_token),
        ):
            """Broadcast a message."""
            # Delivery runs as a background task, so this response only
            # acknowledges receipt and says nothing about delivery results.
            background_tasks.add_task(self._handle_broadcast, broadcast)
            return WebhookResponse(
                success=True,
                message="消息广播请求已接收,正在处理",
            )

        @self.app.get("/bots", response_model=WebhookResponse)
        async def list_bots(token: str = Depends(self._verify_token)):
            """List the registered bots."""
            bots = self.bot_registry.get_all_bots()
            return WebhookResponse(
                success=True,
                message=f"共有 {len(bots)} 个已注册的Bot",
                data={"bots": [bot.model_dump(mode='json') for bot in bots]},
            )

        @self.app.get("/health")
        async def health_check():
            """Health check."""
            return {"status": "healthy", "timestamp": datetime.now().isoformat()}

    async def _handle_broadcast(self, broadcast: MessageBroadcast):
        """Distribute one broadcast and log how many bots received it.

        Exceptions are logged rather than raised, since this runs as a
        background task.
        """
        try:
            results = await self.message_distributor.distribute_message(broadcast)
            success_count = sum(1 for success in results.values() if success)
            logger.info(f"消息广播完成: {success_count}/{len(results)} 个Bot成功接收")
        except Exception as e:
            logger.error(f"消息广播失败: {e}")

    def start_server(self):
        """Run the app under uvicorn on the configured host and port."""
        import uvicorn

        uvicorn.run(
            self.app,
            host=self.config.server_host,
            port=self.config.server_port,
            log_level="info",
        )

    def get_app(self) -> FastAPI:
        """Return the FastAPI application instance."""
        return self.app