115 lines
3.8 KiB
Python
115 lines
3.8 KiB
Python
"""并发控制服务 - 使用 Redis 实现任务并发限制"""
|
|
import asyncio
|
|
from typing import Optional
|
|
import redis.asyncio as aioredis
|
|
from ..config import settings
|
|
import logging
|
|
|
|
# Logger for this module; its name follows the module's import path.
logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
class ConcurrencyManager:
    """Limit concurrently running jobs using Redis as shared state.

    Job ids that currently hold a slot live in the sorted set
    ``REDIS_RUNNING_KEY``; the score is the Redis server time (seconds since
    the Unix epoch) at which the slot was granted. Jobs that could not get a
    slot wait in FIFO order in the list ``REDIS_QUEUE_KEY``. This class only
    admits a job to the running set while it holds fewer than
    ``MAX_CONCURRENT_TASKS`` members.

    Every check-then-modify sequence runs as a single Lua script, so it is
    atomic on the Redis server even when several workers or processes share
    the same keys.
    """

    MAX_CONCURRENT_TASKS = 16

    # Not referenced by this class; kept for any external code that may use it.
    REDIS_SEMAPHORE_KEY = "bttoxin:concurrency:semaphore"

    REDIS_QUEUE_KEY = "bttoxin:queue:pending"

    REDIS_RUNNING_KEY = "bttoxin:running:tasks"

    # Atomically grant a slot or enqueue the job.
    # KEYS[1] = running sorted set, KEYS[2] = pending list
    # ARGV[1] = job id, ARGV[2] = slot limit
    # Returns {1, running_count} if the job holds a slot,
    # otherwise {0, 1-based position in the pending list}.
    _ACQUIRE_SCRIPT = """
if redis.call('ZSCORE', KEYS[1], ARGV[1]) then
    return {1, redis.call('ZCARD', KEYS[1])}
end
local running = redis.call('ZCARD', KEYS[1])
if running < tonumber(ARGV[2]) then
    local now = redis.call('TIME')
    redis.call('ZADD', KEYS[1], tonumber(now[1]) + tonumber(now[2]) / 1000000, ARGV[1])
    redis.call('LREM', KEYS[2], 0, ARGV[1])
    return {1, running + 1}
end
local pos = redis.call('LPOS', KEYS[2], ARGV[1])
if not pos then
    redis.call('RPUSH', KEYS[2], ARGV[1])
    pos = redis.call('LLEN', KEYS[2]) - 1
end
return {0, pos + 1}
"""

    # Atomically release the job's slot (or drop it from the pending list)
    # and promote waiting jobs while capacity is available.
    # KEYS[1] = running sorted set, KEYS[2] = pending list
    # ARGV[1] = job id, ARGV[2] = slot limit
    # Returns the list of promoted job ids (possibly empty).
    _RELEASE_SCRIPT = """
redis.call('ZREM', KEYS[1], ARGV[1])
redis.call('LREM', KEYS[2], 0, ARGV[1])
local limit = tonumber(ARGV[2])
local promoted = {}
while redis.call('ZCARD', KEYS[1]) < limit do
    local nxt = redis.call('LPOP', KEYS[2])
    if not nxt then
        break
    end
    local now = redis.call('TIME')
    redis.call('ZADD', KEYS[1], tonumber(now[1]) + tonumber(now[2]) / 1000000, nxt)
    promoted[#promoted + 1] = nxt
end
return promoted
"""

    def __init__(self):
        self._redis: Optional[aioredis.Redis] = None
        self._lock = asyncio.Lock()

    async def get_redis(self) -> aioredis.Redis:
        """Return the shared Redis client, creating it on first use."""
        if self._redis is None:
            async with self._lock:
                # Re-check: another coroutine may have created it meanwhile.
                if self._redis is None:
                    self._redis = await aioredis.from_url(
                        settings.REDIS_URL,
                        encoding="utf-8",
                        decode_responses=True,
                    )
        return self._redis

    async def acquire_slot(self, job_id: str) -> bool:
        """Try to obtain an execution slot for ``job_id``.

        If the job already holds a slot, or a slot is free, the job is
        recorded as running (and removed from the pending list if it was
        waiting there). Otherwise the job is appended to the pending list,
        unless it is already waiting, so repeated calls never duplicate it.

        Args:
            job_id: Identifier of the job requesting a slot.

        Returns:
            bool: True if the job now holds a slot, False if it is queued.
        """
        client = await self.get_redis()
        # Check and modify in one server-side step so concurrent callers
        # cannot both observe a free slot and overshoot the limit.
        granted, value = await client.eval(
            self._ACQUIRE_SCRIPT,
            2,
            self.REDIS_RUNNING_KEY,
            self.REDIS_QUEUE_KEY,
            job_id,
            self.MAX_CONCURRENT_TASKS,
        )
        if granted:
            logger.info(
                "Job %s acquired slot. Running: %s/%s",
                job_id,
                value,
                self.MAX_CONCURRENT_TASKS,
            )
            return True
        logger.info("Job %s queued. Position: %s", job_id, value)
        return False

    async def release_slot(self, job_id: str):
        """Release ``job_id``'s slot and promote waiting jobs.

        The job is removed from the running set and also from the pending
        list, so a job cancelled while waiting is never promoted later.
        Queued jobs are then moved into the running set, in FIFO order, only
        while the running set is below ``MAX_CONCURRENT_TASKS``.

        Args:
            job_id: Identifier of the job giving up its slot.
        """
        client = await self.get_redis()
        promoted = await client.eval(
            self._RELEASE_SCRIPT,
            2,
            self.REDIS_RUNNING_KEY,
            self.REDIS_QUEUE_KEY,
            job_id,
            self.MAX_CONCURRENT_TASKS,
        )
        for next_job in promoted:
            logger.info(
                "Job %s promoted from queue. Job %s released.", next_job, job_id
            )

    async def get_queue_position(self, job_id: str) -> int:
        """Return the job's 1-based position in the pending list, or 0 if absent."""
        client = await self.get_redis()
        position = await client.lpos(self.REDIS_QUEUE_KEY, job_id)
        return position + 1 if position is not None else 0

    async def get_running_count(self) -> int:
        """Return the number of jobs currently in the running set."""
        client = await self.get_redis()
        return await client.zcard(self.REDIS_RUNNING_KEY)

    async def get_queue_length(self) -> int:
        """Return the number of jobs waiting in the pending list."""
        client = await self.get_redis()
        return await client.llen(self.REDIS_QUEUE_KEY)

    async def close(self):
        """Close the Redis client, if one was created, and forget it."""
        client, self._redis = self._redis, None
        if client is None:
            return
        # redis-py >= 5.0.1 deprecates close() in favour of aclose().
        aclose = getattr(client, "aclose", None)
        if aclose is not None:
            await aclose()
        else:
            await client.close()
|
|
|
|
|
|
# 全局单例
|
|
# Lazily created, process-wide ConcurrencyManager instance.
_concurrency_manager: Optional[ConcurrencyManager] = None

# Guards first-time creation of the instance in get_concurrency_manager().
_manager_lock = asyncio.Lock()


async def get_concurrency_manager() -> ConcurrencyManager:
    """Return the shared ConcurrencyManager, constructing it on first call."""
    global _concurrency_manager

    # Fast path: the instance already exists.
    if _concurrency_manager is not None:
        return _concurrency_manager

    async with _manager_lock:
        # Re-check under the lock in case another caller created it first.
        if _concurrency_manager is None:
            _concurrency_manager = ConcurrencyManager()

    return _concurrency_manager
|