"""并发控制服务 - 使用 Redis 实现任务并发限制""" import asyncio from typing import Optional import redis.asyncio as aioredis from ..config import settings import logging logger = logging.getLogger(__name__) class ConcurrencyManager: """并发控制管理器 - 使用 Redis Semaphore 实现 16 并发限制""" MAX_CONCURRENT_TASKS = 16 REDIS_SEMAPHORE_KEY = "bttoxin:concurrency:semaphore" REDIS_QUEUE_KEY = "bttoxin:queue:pending" REDIS_RUNNING_KEY = "bttoxin:running:tasks" def __init__(self): self._redis: Optional[aioredis.Redis] = None self._lock = asyncio.Lock() async def get_redis(self) -> aioredis.Redis: """获取 Redis 连接""" if self._redis is None: async with self._lock: if self._redis is None: self._redis = await aioredis.from_url( settings.REDIS_URL, encoding="utf-8", decode_responses=True ) return self._redis async def acquire_slot(self, job_id: str) -> bool: """ 尝试获取执行槽位 Returns: bool: 是否成功获取槽位 """ redis = await self.get_redis() # 使用 Redis Sorted Set 实现信号量 # 检查当前运行的任务数 running_count = await redis.zcard(self.REDIS_RUNNING_KEY) if running_count < self.MAX_CONCURRENT_TASKS: # 有可用槽位,加入运行队列 await redis.zadd( self.REDIS_RUNNING_KEY, {job_id: asyncio.get_event_loop().time()} ) logger.info(f"Job {job_id} acquired slot. Running: {running_count + 1}/{self.MAX_CONCURRENT_TASKS}") return True else: # 没有可用槽位,加入等待队列 await redis.rpush(self.REDIS_QUEUE_KEY, job_id) logger.info(f"Job {job_id} queued. Position: {await self.get_queue_position(job_id)}") return False async def release_slot(self, job_id: str): """释放执行槽位""" redis = await self.get_redis() # 从运行队列移除 await redis.zrem(self.REDIS_RUNNING_KEY, job_id) # 检查是否有等待的任务 next_job = await redis.lpop(self.REDIS_QUEUE_KEY) if next_job: # 将下一个任务加入运行队列 await redis.zadd( self.REDIS_RUNNING_KEY, {next_job: asyncio.get_event_loop().time()} ) logger.info(f"Job {next_job} promoted from queue. Job {job_id} released.") async def get_queue_position(self, job_id: str) -> int: """获取任务在队列中的位置""" redis = await self.get_redis() position = await redis.lpos(self.REDIS_QUEUE_KEY, job_id) return position + 1 if position is not None else 0 async def get_running_count(self) -> int: """获取当前运行的任务数""" redis = await self.get_redis() return await redis.zcard(self.REDIS_RUNNING_KEY) async def get_queue_length(self) -> int: """获取等待队列长度""" redis = await self.get_redis() return await redis.llen(self.REDIS_QUEUE_KEY) async def close(self): """关闭 Redis 连接""" if self._redis: await self._redis.close() self._redis = None # 全局单例 _concurrency_manager: Optional[ConcurrencyManager] = None _manager_lock = asyncio.Lock() async def get_concurrency_manager() -> ConcurrencyManager: """获取并发管理器单例""" global _concurrency_manager if _concurrency_manager is None: async with _manager_lock: if _concurrency_manager is None: _concurrency_manager = ConcurrencyManager() return _concurrency_manager