# Source: bttoxin-pipeline/backend/app/services/concurrency_manager.py
"""并发控制服务 - 使用 Redis 实现任务并发限制"""
import asyncio
import logging
import time
from typing import Optional

import redis.asyncio as aioredis

from ..config import settings
logger = logging.getLogger(__name__)
class ConcurrencyManager:
    """Concurrency control manager - caps concurrent tasks via Redis.

    The running set is a Redis sorted set (job_id -> start timestamp) and
    the wait queue is a Redis list (FIFO).  Both check-and-act sequences
    (acquire a slot, release-and-promote) run as Lua scripts so they are
    atomic even when several worker processes share the same Redis —
    a plain ZCARD-then-ZADD from Python would race and could exceed the
    MAX_CONCURRENT_TASKS limit.
    """

    MAX_CONCURRENT_TASKS = 16
    REDIS_SEMAPHORE_KEY = "bttoxin:concurrency:semaphore"
    REDIS_QUEUE_KEY = "bttoxin:queue:pending"
    REDIS_RUNNING_KEY = "bttoxin:running:tasks"

    # Atomically: if the running set (KEYS[1]) has a free slot, add the
    # job (ARGV[1]) with score ARGV[3] and return 1; otherwise append it
    # to the wait queue (KEYS[2]) and return 0.
    _ACQUIRE_SCRIPT = """
    if redis.call('ZCARD', KEYS[1]) < tonumber(ARGV[2]) then
        redis.call('ZADD', KEYS[1], ARGV[3], ARGV[1])
        return 1
    else
        redis.call('RPUSH', KEYS[2], ARGV[1])
        return 0
    end
    """

    # Atomically: remove the finished job from the running set and, if a
    # job is waiting, pop it and promote it into the running set.
    # Returns the promoted job id, or false (-> None in Python) if none.
    _RELEASE_SCRIPT = """
    redis.call('ZREM', KEYS[1], ARGV[1])
    local nxt = redis.call('LPOP', KEYS[2])
    if nxt then
        redis.call('ZADD', KEYS[1], ARGV[2], nxt)
        return nxt
    end
    return false
    """

    def __init__(self):
        # Lazily-created shared connection; the asyncio lock guards the
        # one-time creation against concurrent first callers.
        self._redis: Optional[aioredis.Redis] = None
        self._lock = asyncio.Lock()

    async def get_redis(self) -> aioredis.Redis:
        """Return the shared Redis connection, creating it on first use.

        Double-checked locking: the unlocked fast path avoids lock
        contention once the connection exists; the locked re-check
        prevents two coroutines from each opening a connection.
        """
        if self._redis is None:
            async with self._lock:
                if self._redis is None:
                    self._redis = await aioredis.from_url(
                        settings.REDIS_URL,
                        encoding="utf-8",
                        decode_responses=True,
                    )
        return self._redis

    async def acquire_slot(self, job_id: str) -> bool:
        """Try to acquire an execution slot for *job_id*.

        If all slots are busy the job is appended to the FIFO wait queue
        instead.  The check-and-insert is a single Lua script, so the
        MAX_CONCURRENT_TASKS limit holds across multiple processes.

        Returns:
            bool: True if a slot was acquired, False if the job was queued.
        """
        redis = await self.get_redis()
        # time.time() (wall clock) is comparable across processes, unlike
        # the per-process event-loop monotonic clock.
        acquired = await redis.eval(
            self._ACQUIRE_SCRIPT,
            2,
            self.REDIS_RUNNING_KEY,
            self.REDIS_QUEUE_KEY,
            job_id,
            self.MAX_CONCURRENT_TASKS,
            time.time(),
        )
        if acquired:
            logger.info(
                "Job %s acquired slot. Running: %d/%d",
                job_id,
                await self.get_running_count(),
                self.MAX_CONCURRENT_TASKS,
            )
            return True
        logger.info(
            "Job %s queued. Position: %d",
            job_id,
            await self.get_queue_position(job_id),
        )
        return False

    async def release_slot(self, job_id: str):
        """Release *job_id*'s slot and promote the next queued job, if any.

        Remove-and-promote runs as one Lua script so a queued job cannot
        be lost between the LPOP and the ZADD.
        """
        redis = await self.get_redis()
        promoted = await redis.eval(
            self._RELEASE_SCRIPT,
            2,
            self.REDIS_RUNNING_KEY,
            self.REDIS_QUEUE_KEY,
            job_id,
            time.time(),
        )
        if promoted:
            logger.info(
                "Job %s promoted from queue. Job %s released.", promoted, job_id
            )

    async def get_queue_position(self, job_id: str) -> int:
        """Return *job_id*'s 1-based position in the wait queue (0 if absent)."""
        redis = await self.get_redis()
        position = await redis.lpos(self.REDIS_QUEUE_KEY, job_id)
        return position + 1 if position is not None else 0

    async def get_running_count(self) -> int:
        """Return the number of currently running tasks."""
        redis = await self.get_redis()
        return await redis.zcard(self.REDIS_RUNNING_KEY)

    async def get_queue_length(self) -> int:
        """Return the length of the wait queue."""
        redis = await self.get_redis()
        return await redis.llen(self.REDIS_QUEUE_KEY)

    async def close(self):
        """Close the Redis connection and forget it (safe to call twice)."""
        if self._redis:
            await self._redis.close()
            self._redis = None
# Process-wide singleton and the lock guarding its lazy creation.
_concurrency_manager: Optional[ConcurrencyManager] = None
_manager_lock = asyncio.Lock()


async def get_concurrency_manager() -> ConcurrencyManager:
    """Return the process-wide ConcurrencyManager, creating it on first use.

    Double-checked locking: the unlocked fast path returns immediately
    once the singleton exists; the locked re-check stops concurrent
    first callers from each building their own instance.
    """
    global _concurrency_manager
    if _concurrency_manager is not None:
        return _concurrency_manager
    async with _manager_lock:
        if _concurrency_manager is None:
            _concurrency_manager = ConcurrencyManager()
    return _concurrency_manager