first add

2025-10-13 21:05:00 +08:00
parent c7744836e9
commit d71163df00
29 changed files with 144656 additions and 37 deletions

22
backend/Dockerfile Normal file

@@ -0,0 +1,22 @@
FROM python:3.11-slim
WORKDIR /app
RUN apt-get update && apt-get install -y \
gcc \
postgresql-client \
curl \
&& rm -rf /var/lib/apt/lists/*
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt
COPY . .
RUN mkdir -p /data/jobs /data/temp
EXPOSE 8000
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"]

54
backend/alembic/env.py Normal file

@@ -0,0 +1,54 @@
from logging.config import fileConfig
from sqlalchemy import engine_from_config
from sqlalchemy import pool
from alembic import context
from sqlmodel import SQLModel
import os
from app.core.config import settings
from app.models import *  # noqa: F401,F403 - import so all models are registered on the metadata
config = context.config
if config.config_file_name is not None:
fileConfig(config.config_file_name)
target_metadata = SQLModel.metadata
def run_migrations_offline() -> None:
url = os.getenv("DATABASE_URL", settings.DATABASE_URL)
context.configure(
url=url,
target_metadata=target_metadata,
literal_binds=True,
dialect_opts={"paramstyle": "named"},
)
with context.begin_transaction():
context.run_migrations()
def run_migrations_online() -> None:
configuration = config.get_section(config.config_ini_section) or {}
configuration["sqlalchemy.url"] = os.getenv("DATABASE_URL", settings.DATABASE_URL)
connectable = engine_from_config(
configuration,
prefix="sqlalchemy.",
poolclass=pool.NullPool,
)
with connectable.connect() as connection:
context.configure(connection=connection, target_metadata=target_metadata)
with context.begin_transaction():
context.run_migrations()
if context.is_offline_mode():
run_migrations_offline()
else:
run_migrations_online()


@@ -0,0 +1,3 @@
__all__ = ["routes"]


@@ -0,0 +1,5 @@
from . import jobs
__all__ = ["jobs"]


@@ -0,0 +1,281 @@
"""任务管理 API表单+文件上传,使用 Pydantic 校验)"""
from typing import List, Optional
from fastapi import APIRouter, Depends, HTTPException, UploadFile, File, Form, status
from sqlmodel import Session, select, col
import logging
from ...core.database import get_session
from ...models.job import Job, JobRead, JobUpdate, JobStatus, Step, StepRead
from ...schemas.job import JobCreateRequest, JobCreateResponse, JobStatusResponse, FileUploadInfo
from ...services.workspace_service import WorkspaceManager
from ...workers.pipeline import orchestrate_pipeline
logger = logging.getLogger(__name__)
router = APIRouter()
workspace_mgr = WorkspaceManager()
@router.post(
"/",
response_model=JobCreateResponse,
status_code=status.HTTP_201_CREATED,
summary="创建新的分析任务(表单+文件)",
)
async def create_job(
files: List[UploadFile] = File(...),
name: str = Form(...),
description: Optional[str] = Form(None),
sequence_type: str = Form("nucl"),
scaf_suffix: Optional[str] = Form(None),
orfs_suffix: Optional[str] = Form(None),
prot_suffix: Optional[str] = Form(None),
platform: Optional[str] = Form(None),
reads1_suffix: Optional[str] = Form(None),
reads2_suffix: Optional[str] = Form(None),
genome_size: Optional[str] = Form(None),
suffix_len: Optional[int] = Form(None),
short1: Optional[str] = Form(None),
short2: Optional[str] = Form(None),
long: Optional[str] = Form(None),
threads: int = Form(4),
update_db: bool = Form(False),
assemble_only: bool = Form(False),
session: Session = Depends(get_session),
):
"""创建 BtToxin 分析任务:接收表单参数与文件,保存到工作区后触发 Celery。"""
if not files:
raise HTTPException(status_code=400, detail="At least one file must be uploaded")
try:
params = JobCreateRequest(
name=name,
description=description,
sequence_type=sequence_type,
scaf_suffix=scaf_suffix,
orfs_suffix=orfs_suffix,
prot_suffix=prot_suffix,
platform=platform,
reads1_suffix=reads1_suffix,
reads2_suffix=reads2_suffix,
genome_size=genome_size,
suffix_len=suffix_len,
short1=short1,
short2=short2,
long=long,
threads=threads,
update_db=update_db,
assemble_only=assemble_only,
)
except ValueError as e:
raise HTTPException(status_code=422, detail=f"Parameter validation failed: {e}")
job = Job(
name=params.name,
description=params.description,
sequence_type=params.sequence_type.value,
scaf_suffix=params.scaf_suffix or "",
threads=params.threads,
update_db=params.update_db,
status=JobStatus.PENDING,
)
session.add(job)
session.commit()
session.refresh(job)
workspace = workspace_mgr.create_workspace(job.id)
job.workspace_path = str(workspace["root"])
uploaded: List[FileUploadInfo] = []
for f in files:
f.file.seek(0, 2)
size = f.file.tell()
f.file.seek(0)
if size == 0:
continue
dst = workspace["inputs"] / f.filename
with dst.open("wb") as out:
out.write(await f.read())
uploaded.append(
FileUploadInfo(filename=f.filename, size=size, content_type=f.content_type, path=str(dst))
)
job.input_files = [u.model_dump() for u in uploaded]
session.add(job)
session.commit()
session.refresh(job)
cfg = {
"sequence_type": params.sequence_type.value,
"scaf_suffix": params.scaf_suffix,
"orfs_suffix": params.orfs_suffix,
"prot_suffix": params.prot_suffix,
"threads": params.threads,
"update_db": params.update_db,
"assemble_only": params.assemble_only,
"platform": params.platform.value if params.platform else None,
"reads1_suffix": params.reads1_suffix,
"reads2_suffix": params.reads2_suffix,
"suffix_len": params.suffix_len,
"genome_size": params.genome_size,
"short1": params.short1,
"short2": params.short2,
"long": params.long,
}
task = orchestrate_pipeline.delay(job.id, cfg)
job.celery_task_id = task.id
session.add(job)
session.commit()
session.refresh(job)
return JobCreateResponse(
job_id=job.id,
message="任务创建成功,正在执行分析",
uploaded_files=uploaded,
workspace_path=job.workspace_path,
celery_task_id=task.id,
)
@router.get("/", response_model=List[JobStatusResponse])
async def list_jobs(
skip: int = 0,
limit: int = 50,
status: Optional[JobStatus] = None,
session: Session = Depends(get_session),
):
"""获取任务列表"""
statement = select(Job)
if status:
statement = statement.where(Job.status == status)
statement = statement.offset(skip).limit(limit).order_by(col(Job.created_at).desc())
jobs = session.exec(statement).all()
return [
JobStatusResponse(
job_id=j.id,
name=j.name,
status=j.status.value if hasattr(j.status, "value") else j.status,
progress=j.progress,
current_step=j.current_step,
error_message=j.error_message,
created_at=j.created_at.isoformat(),
started_at=j.started_at.isoformat() if j.started_at else None,
completed_at=j.completed_at.isoformat() if j.completed_at else None,
)
for j in jobs
]
@router.get("/{job_id}", response_model=JobStatusResponse)
async def get_job(job_id: str, session: Session = Depends(get_session)):
"""获取任务详情"""
job = session.get(Job, job_id)
if not job:
raise HTTPException(status_code=404, detail="Job not found")
return JobStatusResponse(
job_id=job.id,
name=job.name,
status=job.status.value if hasattr(job.status, "value") else job.status,
progress=job.progress,
current_step=job.current_step,
error_message=job.error_message,
created_at=job.created_at.isoformat(),
started_at=job.started_at.isoformat() if job.started_at else None,
completed_at=job.completed_at.isoformat() if job.completed_at else None,
)
@router.patch("/{job_id}", response_model=JobRead)
async def update_job(
job_id: str, job_update: JobUpdate, session: Session = Depends(get_session)
):
"""更新任务"""
job = session.get(Job, job_id)
if not job:
raise HTTPException(status_code=404, detail="Job not found")
update_data = job_update.model_dump(exclude_unset=True)
for key, value in update_data.items():
setattr(job, key, value)
session.add(job)
session.commit()
session.refresh(job)
return job
@router.delete("/{job_id}")
async def delete_job(job_id: str, session: Session = Depends(get_session)):
"""删除任务"""
job = session.get(Job, job_id)
if not job:
raise HTTPException(status_code=404, detail="Job not found")
workspace_mgr.cleanup_workspace(job_id, keep_results=False)
session.delete(job)
session.commit()
return {"message": "Job deleted successfully"}
@router.get("/{job_id}/steps", response_model=List[StepRead])
async def get_job_steps(job_id: str, session: Session = Depends(get_session)):
"""获取任务步骤列表"""
job = session.get(Job, job_id)
if not job:
raise HTTPException(status_code=404, detail="Job not found")
statement = select(Step).where(Step.job_id == job_id).order_by(Step.step_order)
steps = session.exec(statement).all()
return steps
@router.get("/{job_id}/progress")
async def get_job_progress(job_id: str, session: Session = Depends(get_session)):
"""获取任务进度"""
job = session.get(Job, job_id)
if not job:
raise HTTPException(status_code=404, detail="Job not found")
statement = select(Step).where(Step.job_id == job_id)
steps = session.exec(statement).all()
total_steps = len(steps)
completed_steps = sum(1 for s in steps if s.status == "completed")
failed_steps = sum(1 for s in steps if s.status == "failed")
return {
"job_id": job_id,
"status": job.status,
"progress": job.progress,
"current_step": job.current_step,
"total_steps": total_steps,
"completed_steps": completed_steps,
"failed_steps": failed_steps,
"steps": [
{"name": s.step_name, "status": s.status, "order": s.step_order}
for s in sorted(steps, key=lambda x: x.step_order)
],
}
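As an illustrative smoke test of the endpoints above (not part of this commit), a client call might look like the sketch below; the base URL and the local FASTA file are assumptions.
# Illustrative sketch only: create a job via multipart form upload, then poll its progress.
# Assumes the stack is reachable at http://localhost:8000 and ./genome.fna exists locally.
import httpx

base = "http://localhost:8000/api/v1/jobs"
with open("genome.fna", "rb") as fh:
    resp = httpx.post(
        base + "/",
        data={"name": "demo", "sequence_type": "nucl", "scaf_suffix": ".fna", "threads": "4"},
        files=[("files", ("genome.fna", fh, "application/octet-stream"))],
    )
resp.raise_for_status()
job_id = resp.json()["job_id"]
print(httpx.get(f"{base}/{job_id}/progress").json())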


@@ -1,3 +1,18 @@
from celery import Celery
from .config import settings
celery_app = Celery(
"bttoxin",
broker=settings.get_celery_broker_url(),
backend=settings.get_celery_result_backend(),
)
celery_app.conf.update(
task_track_started=True,
worker_prefetch_multiplier=1,
)
"""Celery 配置"""
from celery import Celery
from ..config import settings


@@ -0,0 +1,96 @@
"""应用配置"""
from typing import Optional
from pydantic_settings import BaseSettings, SettingsConfigDict
class Settings(BaseSettings):
"""应用配置"""
model_config = SettingsConfigDict(
env_file=".env",
env_file_encoding="utf-8",
case_sensitive=False,
)
# ============== Application basics ==============
APP_NAME: str = "BtToxin Pipeline"
APP_VERSION: str = "1.0.0"
DEBUG: bool = False
# API settings
API_V1_PREFIX: str = "/api/v1"
# ============== Database ==============
POSTGRES_HOST: str = "localhost"
POSTGRES_PORT: int = 5432
POSTGRES_USER: str = "bttoxin"
POSTGRES_PASSWORD: str = "bttoxin_password"
POSTGRES_DB: str = "bttoxin_db"
@property
def DATABASE_URL(self) -> str:
"""数据库连接 URL"""
return (
f"postgresql://{self.POSTGRES_USER}:{self.POSTGRES_PASSWORD}"
f"@{self.POSTGRES_HOST}:{self.POSTGRES_PORT}/{self.POSTGRES_DB}"
)
# ============== Redis ==============
REDIS_HOST: str = "localhost"
REDIS_PORT: int = 6379
REDIS_DB: int = 0
REDIS_PASSWORD: Optional[str] = None
@property
def REDIS_URL(self) -> str:
"""Redis 连接 URL"""
if self.REDIS_PASSWORD:
return f"redis://:{self.REDIS_PASSWORD}@{self.REDIS_HOST}:{self.REDIS_PORT}/{self.REDIS_DB}"
return f"redis://{self.REDIS_HOST}:{self.REDIS_PORT}/{self.REDIS_DB}"
# ============== Celery ==============
CELERY_BROKER_URL: Optional[str] = None
CELERY_RESULT_BACKEND: Optional[str] = None
def get_celery_broker_url(self) -> str:
"""获取 Celery Broker URL"""
return self.CELERY_BROKER_URL or self.REDIS_URL
def get_celery_result_backend(self) -> str:
"""获取 Celery Result Backend URL"""
return self.CELERY_RESULT_BACKEND or self.REDIS_URL
# ============== 工作空间配置 ==============
WORKSPACE_BASE_PATH: str = "/data/jobs"
TEMP_BASE_PATH: str = "/data/temp"
MAX_UPLOAD_SIZE_MB: int = 500
# ============== Docker ==============
DOCKER_IMAGE: str = "quay.io/biocontainers/bttoxin_digger:1.0.10--hdfd78af_0"
DOCKER_PLATFORM: str = "linux/amd64"
# ============== S3 ==============
S3_ENDPOINT: Optional[str] = None
S3_ACCESS_KEY: Optional[str] = None
S3_SECRET_KEY: Optional[str] = None
S3_BUCKET: str = "bttoxin-results"
S3_REGION: str = "us-east-1"
# ============== Jobs ==============
DEFAULT_THREADS: int = 4
MAX_THREADS: int = 16
TASK_TIMEOUT_SECONDS: int = 7200  # 2 hours
JOB_RETENTION_DAYS: int = 30
# ============== CORS ==============
CORS_ORIGINS: list = ["http://localhost:3000", "http://localhost:5173"]
# ============== Logging ==============
LOG_LEVEL: str = "INFO"
LOG_FORMAT: str = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
# Global settings instance
settings = Settings()
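A small sketch of how these settings resolve; the environment values below are illustrative only, not part of this commit.
# Illustrative sketch: environment variables override the defaults declared above.
import os
os.environ["POSTGRES_HOST"] = "db"      # e.g. a docker-compose service name (assumption)
os.environ["REDIS_HOST"] = "redis"

from app.core.config import Settings
s = Settings()
print(s.DATABASE_URL)               # postgresql://bttoxin:bttoxin_password@db:5432/bttoxin_db
print(s.get_celery_broker_url())    # redis://redis:6379/0 unless CELERY_BROKER_URL is set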


@@ -0,0 +1,50 @@
"""SQLModel 数据库配置"""
from typing import Generator
from sqlmodel import Session, create_engine
from sqlalchemy.orm import sessionmaker
from .config import settings
# Create the database engine
engine = create_engine(
settings.DATABASE_URL,
echo=settings.DEBUG,
pool_pre_ping=True,
pool_size=10,
max_overflow=20,
)
# Create SessionLocal
SessionLocal = sessionmaker(
autocommit=False,
autoflush=False,
bind=engine,
class_=Session,
)
def get_session() -> Generator[Session, None, None]:
"""
获取数据库会话(依赖注入)
"""
with SessionLocal() as session:
yield session
def init_db() -> None:
"""初始化数据库(创建所有表)"""
from sqlmodel import SQLModel
from ..models.job import Job, Step, JobLog # noqa: F401
SQLModel.metadata.create_all(engine)
print("✓ Database initialized")
def drop_db() -> None:
"""删除所有表(开发用)"""
from sqlmodel import SQLModel
SQLModel.metadata.drop_all(engine)
print("✓ Database dropped")


@@ -0,0 +1,14 @@
import logging
import sys
def setup_logging(level: str = "INFO", fmt: str = "%(asctime)s - %(name)s - %(levelname)s - %(message)s") -> None:
handler = logging.StreamHandler(sys.stdout)
formatter = logging.Formatter(fmt)
handler.setFormatter(formatter)
root = logging.getLogger()
if not root.handlers:
root.addHandler(handler)
root.setLevel(level)


@@ -1,3 +1,44 @@
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from .core.config import settings
from .core.database import init_db
from .api.routes import jobs as jobs_routes
from .core.logging import setup_logging
def create_app() -> FastAPI:
app = FastAPI(
title=settings.APP_NAME,
version=settings.APP_VERSION,
docs_url="/docs",
redoc_url="/redoc",
openapi_url=f"{settings.API_V1_PREFIX}/openapi.json",
)
app.add_middleware(
CORSMiddleware,
allow_origins=settings.CORS_ORIGINS,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
app.include_router(jobs_routes.router, prefix=f"{settings.API_V1_PREFIX}/jobs", tags=["jobs"])
@app.on_event("startup")
def _on_startup() -> None:
setup_logging(settings.LOG_LEVEL, settings.LOG_FORMAT)
init_db()
@app.get("/healthz")
def healthz() -> dict:
return {"status": "ok"}
return app
app = create_app()
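For local development the factory can be served directly; a minimal entrypoint sketch (illustrative only, the Dockerfile CMD does the equivalent inside the container):
# Illustrative dev entrypoint, not part of this commit.
import uvicorn

if __name__ == "__main__":
    uvicorn.run("app.main:app", host="0.0.0.0", port=8000, reload=True)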
"""FastAPI 主应用"""
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware


@@ -0,0 +1,11 @@
from .job import Job, Step, JobLog, JobStatus, StepStatus
__all__ = [
"Job",
"Step",
"JobLog",
"JobStatus",
"StepStatus",
]


@@ -0,0 +1,25 @@
"""基础模型"""
from datetime import datetime
from sqlmodel import SQLModel, Field
from uuid import uuid4
def generate_uuid() -> str:
"""Generate a UUID string."""
return str(uuid4())
class TimestampModel(SQLModel):
"""Timestamp mixin."""
created_at: datetime = Field(
default_factory=datetime.utcnow,
nullable=False,
sa_column_kwargs={"index": True},
)
updated_at: datetime = Field(
default_factory=datetime.utcnow,
nullable=False,
sa_column_kwargs={"onupdate": datetime.utcnow},
)


@@ -1,3 +1,187 @@
"""任务模型(使用 SQLModel"""
from typing import Optional, List
from datetime import datetime
from enum import Enum
from sqlmodel import SQLModel, Field, Relationship, Column
from sqlalchemy import JSON
from .base import TimestampModel, generate_uuid
class JobStatus(str, Enum):
"""任务状态"""
PENDING = "pending"
RUNNING = "running"
COMPLETED = "completed"
FAILED = "failed"
CANCELLED = "cancelled"
class StepStatus(str, Enum):
"""步骤状态"""
PENDING = "pending"
RUNNING = "running"
COMPLETED = "completed"
FAILED = "failed"
SKIPPED = "skipped"
class JobBase(SQLModel):
"""Job 基础字段"""
name: str = Field(max_length=255)
description: Optional[str] = None
sequence_type: str = Field(default="nucl", max_length=20)
scaf_suffix: str = Field(default=".fna", max_length=50)
threads: int = Field(default=4, ge=1, le=32)
update_db: bool = Field(default=False)
class Job(JobBase, TimestampModel, table=True):
"""Job 数据库模型"""
__tablename__ = "jobs"
id: str = Field(
default_factory=generate_uuid,
primary_key=True,
index=True,
)
user_id: Optional[str] = Field(default=None, index=True)
status: JobStatus = Field(
default=JobStatus.PENDING,
sa_column_kwargs={"index": True},
)
input_files: List[dict] = Field(default_factory=list, sa_column=Column(JSON))
workspace_path: Optional[str] = Field(default=None, max_length=500)
result_url: Optional[str] = Field(default=None, max_length=1000)
celery_task_id: Optional[str] = Field(default=None, max_length=100, index=True)
current_step: Optional[str] = Field(default=None, max_length=100)
progress: int = Field(default=0, ge=0, le=100)
error_message: Optional[str] = None
started_at: Optional[datetime] = None
completed_at: Optional[datetime] = None
steps: List["Step"] = Relationship(
back_populates="job",
sa_relationship_kwargs={"cascade": "all, delete-orphan"},
)
logs: List["JobLog"] = Relationship(
back_populates="job",
sa_relationship_kwargs={"cascade": "all, delete-orphan"},
)
class JobCreate(JobBase):
"""Request model for creating a Job."""
pass
class JobRead(JobBase):
"""Response model for reading a Job."""
id: str
user_id: Optional[str]
status: JobStatus
workspace_path: Optional[str]
result_url: Optional[str]
celery_task_id: Optional[str]
current_step: Optional[str]
progress: int
error_message: Optional[str]
started_at: Optional[datetime]
completed_at: Optional[datetime]
created_at: datetime
updated_at: datetime
class JobUpdate(SQLModel):
"""更新 Job 时的请求模型"""
status: Optional[JobStatus] = None
current_step: Optional[str] = None
progress: Optional[int] = None
error_message: Optional[str] = None
started_at: Optional[datetime] = None
completed_at: Optional[datetime] = None
class StepBase(SQLModel):
"""Step 基础字段"""
step_name: str = Field(max_length=100)
step_order: int
class Step(StepBase, table=True):
"""Step 数据库模型"""
__tablename__ = "steps"
id: Optional[int] = Field(default=None, primary_key=True)
job_id: str = Field(foreign_key="jobs.id", index=True)
status: StepStatus = Field(default=StepStatus.PENDING)
celery_task_id: Optional[str] = Field(default=None, max_length=100)
log_file: Optional[str] = Field(default=None, max_length=500)
result_data: Optional[dict] = Field(default=None, sa_column=Column(JSON))
error_message: Optional[str] = None
started_at: Optional[datetime] = None
completed_at: Optional[datetime] = None
duration_seconds: Optional[int] = None
job: "Job" = Relationship(back_populates="steps")
class StepRead(StepBase):
"""读取 Step 时的响应模型"""
id: int
job_id: str
status: StepStatus
celery_task_id: Optional[str]
log_file: Optional[str]
result_data: Optional[dict]
error_message: Optional[str]
started_at: Optional[datetime]
completed_at: Optional[datetime]
duration_seconds: Optional[int]
class JobLogBase(SQLModel):
"""JobLog 基础字段"""
level: str = Field(max_length=20)
message: str
step_name: Optional[str] = Field(default=None, max_length=100)
class JobLog(JobLogBase, table=True):
"""JobLog 数据库模型"""
__tablename__ = "job_logs"
id: Optional[int] = Field(default=None, primary_key=True)
job_id: str = Field(foreign_key="jobs.id", index=True)
metadata: Optional[dict] = Field(default=None, sa_column=Column(JSON))
timestamp: datetime = Field(
default_factory=datetime.utcnow,
sa_column_kwargs={"index": True},
)
job: "Job" = Relationship(back_populates="logs")
class JobLogRead(JobLogBase):
"""读取 JobLog 时的响应模型"""
id: int
job_id: str
metadata: Optional[dict]
timestamp: datetime
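A minimal sketch of how these models fit together, run against an in-memory SQLite database purely for illustration:
# Illustrative sketch: exercise the Job/Step relationship on an in-memory SQLite engine.
from sqlmodel import SQLModel, Session, create_engine

engine = create_engine("sqlite://")
SQLModel.metadata.create_all(engine)

with Session(engine) as session:
    job = Job(name="demo")
    job.steps.append(Step(step_name="run_digger", step_order=1))
    session.add(job)
    session.commit()
    session.refresh(job)
    print(job.id, [s.step_name for s in job.steps])  # generated UUID, ['run_digger']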
"""任务模型"""
from sqlalchemy import Column, String, Integer, DateTime, JSON, Enum, Text
from sqlalchemy.sql import func


@@ -1,3 +1,125 @@
"""任务相关的 Pydantic Schema"""
from typing import Optional, List
from pydantic import BaseModel, Field, field_validator, model_validator
from enum import Enum
class SequenceType(str, Enum):
NUCL = "nucl"
PROT = "prot"
ORFS = "orfs"
READS = "reads"
class PlatformType(str, Enum):
ILLUMINA = "illumina"
PACBIO = "pacbio"
OXFORD = "oxford"
HYBRID = "hybrid"
class JobCreateRequest(BaseModel):
"""创建任务请求(包含各序列类型的参数)"""
# 基本信息
name: str = Field(..., min_length=1, max_length=255, description="任务名称")
description: Optional[str] = Field(None, max_length=1000, description="任务描述")
# 序列类型
sequence_type: SequenceType = Field(default=SequenceType.NUCL, description="输入序列类型")
# nucl
scaf_suffix: Optional[str] = Field(
None, pattern=r"^\.\w+$", description="Genome file suffix (nucl)", examples=[".fna", ".fasta", ".fa"]
)
# orfs
orfs_suffix: Optional[str] = Field(None, pattern=r"^\.\w+$", description="ORF file suffix (orfs)")
# prot
prot_suffix: Optional[str] = Field(None, pattern=r"^\.\w+$", description="Protein file suffix (prot)")
# reads
platform: Optional[PlatformType] = Field(None, description="Sequencing platform (reads)")
reads1_suffix: Optional[str] = Field(None, description="Reads1 suffix (illumina/hybrid)")
reads2_suffix: Optional[str] = Field(None, description="Reads2 suffix (illumina/hybrid)")
genome_size: Optional[str] = Field(
None, pattern=r"^\d+(\.\d+)?[mMgG]?$", description="Estimated genome size (pacbio/oxford)"
)
suffix_len: Optional[int] = Field(None, ge=0, description="Length of the reads filename suffix")
# hybrid requires full filenames
short1: Optional[str] = Field(None, description="Short reads 1 filename (full name)")
short2: Optional[str] = Field(None, description="Short reads 2 filename (full name)")
long: Optional[str] = Field(None, description="Long reads filename (full name)")
# Execution parameters
threads: int = Field(default=4, ge=1, le=32, description="Number of threads")
update_db: bool = Field(default=False, description="Whether to update the toxin database")
assemble_only: bool = Field(default=False, description="Run assembly only")
@field_validator("scaf_suffix", "orfs_suffix", "prot_suffix")
@classmethod
def validate_suffix(cls, v: Optional[str]) -> Optional[str]:
if v is not None and not v.startswith("."):
raise ValueError("文件后缀必须以 . 开头")
return v
@model_validator(mode="after")
def validate_by_type(self):
if self.sequence_type == SequenceType.NUCL:
if not self.scaf_suffix:
self.scaf_suffix = ".fna"
elif self.sequence_type == SequenceType.ORFS:
if not self.orfs_suffix:
self.orfs_suffix = ".ffn"
elif self.sequence_type == SequenceType.PROT:
if not self.prot_suffix:
self.prot_suffix = ".faa"
elif self.sequence_type == SequenceType.READS:
if not self.platform:
raise ValueError("reads 类型必须指定 platform")
if self.platform == PlatformType.ILLUMINA:
if not self.reads1_suffix or not self.reads2_suffix:
raise ValueError("illumina 平台必须指定 reads1_suffix 和 reads2_suffix")
elif self.platform in [PlatformType.PACBIO, PlatformType.OXFORD]:
if not self.reads1_suffix:
raise ValueError(f"{self.platform} 平台必须指定 reads1_suffix")
if not self.genome_size:
raise ValueError(f"{self.platform} 平台必须指定 genome_size")
elif self.platform == PlatformType.HYBRID:
if not all([self.short1, self.short2, self.long]):
raise ValueError("hybrid 平台必须指定 short1, short2, long")
return self
class FileUploadInfo(BaseModel):
filename: str
size: int
content_type: Optional[str] = None
path: str
class JobCreateResponse(BaseModel):
job_id: str
message: str
uploaded_files: List[FileUploadInfo]
workspace_path: str
celery_task_id: Optional[str] = None
warnings: Optional[List[str]] = None
class JobStatusResponse(BaseModel):
job_id: str
name: str
status: str
progress: int
current_step: Optional[str] = None
error_message: Optional[str] = None
created_at: str
started_at: Optional[str] = None
completed_at: Optional[str] = None
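A brief sketch of the validator behaviour with illustrative values (assumes JobCreateRequest is imported from this module):
# Illustrative sketch: the model_validator fills the default suffix for nucl input
# and rejects an illumina reads request that is missing its R2 suffix.
req = JobCreateRequest(name="demo", sequence_type="nucl")
print(req.scaf_suffix)  # ".fna", filled in by validate_by_type

try:
    JobCreateRequest(
        name="demo-reads",
        sequence_type="reads",
        platform="illumina",
        reads1_suffix="_R1.fastq.gz",
    )
except ValueError as err:
    print(err)  # reads1_suffix and reads2_suffix are required for the illumina platform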
"""任务 Schema"""
from pydantic import BaseModel
from typing import Optional, List


@@ -0,0 +1,2 @@


@@ -0,0 +1,56 @@
from __future__ import annotations
from typing import Dict, List, Optional
from pathlib import Path
import docker
from ..core.config import settings
class DockerRunner:
"""通过 docker.sock 在宿主机运行容器的轻量封装。"""
def __init__(self) -> None:
# 依赖 /var/run/docker.sock 已在 docker-compose 挂载至 worker 容器
self.client = docker.from_env()
def run(
self,
image: Optional[str],
command: List[str],
workdir: Path,
mounts: Dict[Path, str],
env: Optional[Dict[str, str]] = None,
platform: Optional[str] = None,
detach: bool = False,
remove: bool = True,
) -> str:
"""
运行容器并返回容器日志(非 detach或容器 IDdetach
mounts: {host_path: container_path}
"""
image_to_use = image or settings.DOCKER_IMAGE
platform_to_use = platform or settings.DOCKER_PLATFORM
volumes = {str(host): {"bind": container, "mode": "rw"} for host, container in mounts.items()}
container = self.client.containers.run(
image=image_to_use,
command=command,
working_dir=str(workdir),
volumes=volumes,
environment=env or {},
platform=platform_to_use,
detach=detach,
remove=remove if not detach else False,
)
if detach:
return container.id
# 同步模式:等待并返回日志
result_bytes = container.logs(stream=False)
return result_bytes.decode("utf-8", errors="ignore")
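A minimal usage sketch for the wrapper; the host path and the echo command are placeholders, not part of this commit.
# Illustrative sketch: run a one-off command in the default image with a host directory
# mounted read-write at /work. Assumes /tmp/bttoxin-demo exists on the Docker host.
from pathlib import Path

runner = DockerRunner()
logs = runner.run(
    image=None,                                   # falls back to settings.DOCKER_IMAGE
    command=["echo", "hello from the container"],
    workdir=Path("/work"),
    mounts={Path("/tmp/bttoxin-demo"): "/work"},
)
print(logs)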


@@ -0,0 +1,45 @@
from pathlib import Path
from typing import List, Dict
import shutil
from fastapi import UploadFile
from ..core.config import settings
class WorkspaceManager:
def __init__(self) -> None:
self.base_path = Path(settings.WORKSPACE_BASE_PATH)
self.base_path.mkdir(parents=True, exist_ok=True)
def create_workspace(self, job_id: str) -> Dict[str, Path]:
root = self.base_path / job_id
input_dir = root / "inputs"
logs_dir = root / "logs"
results_dir = root / "results"
for d in [root, input_dir, logs_dir, results_dir]:
d.mkdir(parents=True, exist_ok=True)
return {"root": root, "inputs": input_dir, "logs": logs_dir, "results": results_dir}
def save_input_files(self, job_id: str, files: List[UploadFile]) -> List[dict]:
ws = self.create_workspace(job_id)
saved: List[dict] = []
for f in files:
dst = ws["inputs"] / f.filename
with dst.open("wb") as out:
out.write(f.file.read())
saved.append({"filename": f.filename, "path": str(dst)})
return saved
def cleanup_workspace(self, job_id: str, keep_results: bool = False) -> None:
root = self.base_path / job_id
if not root.exists():
return
if keep_results:
# Only clean up inputs and logs
for name in ["inputs", "logs"]:
p = root / name
if p.exists():
shutil.rmtree(p, ignore_errors=True)
else:
shutil.rmtree(root, ignore_errors=True)
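Usage is straightforward; the job id below is a placeholder.
# Illustrative sketch: create a workspace, then clean it up while keeping results/.
mgr = WorkspaceManager()
ws = mgr.create_workspace("demo-job-id")
print(ws["inputs"], ws["results"])  # /data/jobs/demo-job-id/inputs ... /results
mgr.cleanup_workspace("demo-job-id", keep_results=True)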


@@ -0,0 +1,339 @@
"""Docker/Podman 容器管理(修正版,支持 arm64/macOS 与 linux/amd64"""
from __future__ import annotations
import os
import subprocess
import logging
import time
from pathlib import Path
from typing import Dict, Any, Optional, List
try:
import docker # type: ignore
except Exception:  # allow a podman fallback when the docker SDK is unavailable
docker = None # type: ignore
from ..core.config import settings
logger = logging.getLogger(__name__)
def _which(cmd: str) -> Optional[str]:
from shutil import which
return which(cmd)
class DockerContainerManager:
"""容器管理器 - 兼容 Docker 与 Podman。
优先尝试 docker SDK若不可用则回落到 podman CLI或 docker CLI
在 arm64 主机上默认以 --platform linux/amd64 运行镜像。
"""
def __init__(
self,
image: str = settings.DOCKER_IMAGE,
platform: str = settings.DOCKER_PLATFORM,
) -> None:
self.image = image
self.platform = platform
self._engine: str = "docker"
self._client = None
# Prefer the docker-py client when available
if docker is not None:
try:
self._client = docker.from_env()
# Probe the daemon
self._client.ping()
self._engine = "docker-sdk"
except Exception as err:
logger.info(f"docker SDK unavailable, falling back to the CLI: {err}")
self._client = None
# CLI fallback: prefer podman, then docker
if self._client is None:
if _which("podman"):
self._engine = "podman-cli"
elif _which("docker"):
self._engine = "docker-cli"
else:
raise RuntimeError("未找到可用的容器引擎(需要 podman 或 docker")
self._ensure_image()
# ----------------------------- 公共方法 -----------------------------
def run_command_in_container(
self,
command: List[str],
volumes: Dict[str, Dict[str, str]],
environment: Optional[Dict[str, str]] = None,
working_dir: str = "/workspace",
name: Optional[str] = None,
detach: bool = False,
remove: bool = True,
) -> Dict[str, Any]:
"""在容器中执行命令,返回执行结果。"""
if self._engine == "docker-sdk" and self._client is not None:
return self._run_with_docker_sdk(
command, volumes, environment, working_dir, name, detach, remove
)
else:
return self._run_with_cli(
command, volumes, environment, working_dir, name, detach, remove
)
def update_database(self, log_dir: Path) -> Dict[str, Any]:
"""在容器中更新 BtToxin_Digger 数据库。"""
cmd = [
"/usr/local/env-execute",
"BtToxin_Digger",
"--update-db",
]
vols = {str(log_dir): {"bind": "/logs", "mode": "rw"}}
result = self.run_command_in_container(
command=cmd, volumes=vols, working_dir="/tmp", name=f"bttoxin_update_db_{int(time.time())}"
)
if result.get("logs"):
(log_dir / "update_db.log").write_text(result["logs"], encoding="utf-8")
return result
def validate_reads_filenames(
self,
input_dir: Path,
platform: str,
reads1_suffix: str,
reads2_suffix: Optional[str] = None,
suffix_len: int = 0,
) -> Dict[str, Any]:
files = list(input_dir.glob("*"))
if platform == "illumina":
r1 = [f for f in files if reads1_suffix and reads1_suffix in f.name]
r2 = [f for f in files if reads2_suffix and reads2_suffix in f.name]
if not r1 or not r2 or len(r1) != len(r2):
return {"valid": False, "error": "Illumina R1/R2 配对数量不匹配或缺失"}
for f1 in r1:
strain = f1.name.replace(reads1_suffix, "")
if not (input_dir / f"{strain}{reads2_suffix}").exists():
return {"valid": False, "error": f"未找到配对文件: {strain}{reads2_suffix}"}
return {
"valid": True,
"strain_count": len(r1),
"suggested_suffix_len": suffix_len or len(reads1_suffix),
}
if platform in ("pacbio", "oxford"):
r = [f for f in files if reads1_suffix and reads1_suffix in f.name]
if not r:
return {"valid": False, "error": f"未找到匹配 {reads1_suffix} 的 reads 文件"}
return {
"valid": True,
"strain_count": len(r),
"suggested_suffix_len": suffix_len or len(reads1_suffix),
}
return {"valid": True}
def run_bttoxin_digger(
self,
input_dir: Path,
output_dir: Path,
log_dir: Path,
sequence_type: str = "nucl",
scaf_suffix: str = ".fna",
threads: int = 4,
**kwargs: Any,
) -> Dict[str, Any]:
"""在容器中运行 BtToxin_Digger 主分析(工作目录挂载到 /workspace"""
command: List[str] = [
"/usr/local/env-execute",
"BtToxin_Digger",
"--SeqPath",
"/data/input",
"--SequenceType",
sequence_type,
"--threads",
str(threads),
]
if sequence_type == "nucl":
command += ["--Scaf_suffix", scaf_suffix]
elif sequence_type == "orfs":
command += ["--orfs_suffix", kwargs.get("orfs_suffix", ".ffn")]
elif sequence_type == "prot":
command += ["--prot_suffix", kwargs.get("prot_suffix", ".faa")]
elif sequence_type == "reads":
platform = kwargs.get("platform", "illumina")
command += ["--platform", platform]
if platform == "illumina":
r1 = kwargs.get("reads1_suffix", "_R1.fastq.gz")
r2 = kwargs.get("reads2_suffix", "_R2.fastq.gz")
sfx = kwargs.get("suffix_len") or len(r1)
v = self.validate_reads_filenames(input_dir, platform, r1, r2, sfx)
if not v.get("valid"):
raise ValueError(f"Reads 文件验证失败: {v.get('error')}")
sfx = v.get("suggested_suffix_len", sfx)
command += ["--reads1", r1, "--reads2", r2, "--suffix_len", str(sfx)]
elif platform in ("pacbio", "oxford"):
r = kwargs.get("reads1_suffix", ".fastq.gz")
gsize = kwargs.get("genome_size", "6.07m")
sfx = kwargs.get("suffix_len") or len(r)
v = self.validate_reads_filenames(input_dir, platform, r, None, sfx)
if not v.get("valid"):
raise ValueError(f"Reads 文件验证失败: {v.get('error')}")
sfx = v.get("suggested_suffix_len", sfx)
command += ["--reads1", r, "--genomeSize", gsize, "--suffix_len", str(sfx)]
elif platform == "hybrid":
short1 = kwargs.get("short1")
short2 = kwargs.get("short2")
long = kwargs.get("long")
if not all([short1, short2, long]):
raise ValueError("hybrid 需要 short1/short2/long 三个完整文件名")
for fn in (short1, short2, long):
if not (input_dir / fn).exists():
raise ValueError(f"文件不存在: {fn}")
command += [
"--short1",
short1,
"--short2",
short2,
"--long",
long,
"--hout",
"/workspace/Results/Assembles/Hybrid",
]
if kwargs.get("assemble_only"):
command.append("--assemble_only")
volumes = {
str(input_dir.resolve()): {"bind": "/data/input", "mode": "ro"},
str(output_dir.resolve()): {"bind": "/workspace", "mode": "rw"},
str(log_dir.resolve()): {"bind": "/data/logs", "mode": "rw"},
}
logger.info("开始 BtToxin_Digger 分析...")
result = self.run_command_in_container(
command=command,
volumes=volumes,
working_dir="/workspace",
name=f"bttoxin_digger_{int(time.time())}",
)
# Save the container logs
logs_path = log_dir / "digger_execution.log"
if result.get("logs"):
logs_path.write_text(result["logs"], encoding="utf-8")
logger.info(f"Container logs saved: {logs_path}")
# Verify the output
results_dir = output_dir / "Results"
if result.get("success") and results_dir.exists():
files = [f for f in results_dir.rglob("*") if f.is_file()]
result["output_files"] = len(files)
else:
result["output_files"] = 0
return result
# ----------------------------- Internal implementation -----------------------------
def _ensure_image(self) -> None:
if self._engine == "docker-sdk" and self._client is not None:
try:
self._client.images.get(self.image)
return
except Exception:
logger.info(f"拉取镜像 {self.image} (platform={self.platform}) ...")
self._client.images.pull(self.image, platform=self.platform)
else:
# CLI mode: try to pull the image first
cli = "podman" if self._engine == "podman-cli" else "docker"
try:
subprocess.run(
[cli, "pull", "--platform", self.platform, self.image],
check=True,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
text=True,
)
except Exception as err:
logger.warning(f"{cli} pull 失败: {err}")
def _run_with_docker_sdk(
self,
command: List[str],
volumes: Dict[str, Dict[str, str]],
environment: Optional[Dict[str, str]],
working_dir: str,
name: Optional[str],
detach: bool,
remove: bool,
) -> Dict[str, Any]:
assert self._client is not None
try:
container = self._client.containers.run(
image=self.image,
command=command,
volumes=volumes,
environment=environment or {},
working_dir=working_dir,
platform=self.platform,
name=name,
detach=detach,
remove=False,  # remove only after the logs have been fetched
stdout=True,
stderr=True,
)
if detach:
return {"success": True, "container_id": container.id, "status": "running"}
exit_info = container.wait()
code = exit_info.get("StatusCode", 1)
logs = container.logs().decode("utf-8", errors="ignore")
if remove:
try:
container.remove()
except Exception:
pass
return {"success": code == 0, "exit_code": code, "logs": logs, "status": "completed" if code == 0 else "failed"}
except Exception as e:
logger.error(f"docker SDK 运行失败: {e}", exc_info=True)
return {"success": False, "error": str(e), "exit_code": -1, "status": "error"}
def _run_with_cli(
self,
command: List[str],
volumes: Dict[str, Dict[str, str]],
environment: Optional[Dict[str, str]],
working_dir: str,
name: Optional[str],
detach: bool,
remove: bool,
) -> Dict[str, Any]:
cli = "podman" if self._engine == "podman-cli" else "docker"
cmd: List[str] = [cli, "run", "--rm" if remove and not detach else ""]
cmd = [c for c in cmd if c]
cmd += ["--platform", self.platform]
if name:
cmd += ["--name", name]
for host, spec in volumes.items():
bind = spec.get("bind")
mode = spec.get("mode", "rw")
cmd += ["-v", f"{host}:{bind}:{mode}"]
for k, v in (environment or {}).items():
cmd += ["-e", f"{k}={v}"]
cmd += ["-w", working_dir, self.image]
cmd += command
try:
if detach:
# Detached run: the CLI path returns a simplified result
p = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True)
return {"success": True, "status": "running", "pid": p.pid}
else:
proc = subprocess.run(cmd, capture_output=True, text=True)
out = (proc.stdout or "") + (proc.stderr or "")
return {"success": proc.returncode == 0, "exit_code": proc.returncode, "logs": out, "status": "completed" if proc.returncode == 0 else "failed"}
except Exception as e:
logger.error(f"{cli} 运行失败: {e}", exc_info=True)
return {"success": False, "error": str(e), "exit_code": -1, "status": "error"}


@@ -0,0 +1,2 @@


@@ -0,0 +1,88 @@
from datetime import datetime
from pathlib import Path
from sqlmodel import Session, select
from . import __init__ # noqa: F401
from ..core.celery_app import celery_app
from ..core.database import SessionLocal
from ..models.job import Job, Step, StepStatus, JobStatus
from .steps.run_digger import run_bttoxin_digger
def _with_session():
return SessionLocal()
def _update_job(session: Session, job: Job, **changes) -> None:
for k, v in changes.items():
setattr(job, k, v)
session.add(job)
session.commit()
session.refresh(job)
def _create_or_get_step(session: Session, job_id: str, step_name: str, order: int) -> Step:
step = session.exec(select(Step).where(Step.job_id == job_id, Step.step_name == step_name)).first()
if step is None:
step = Step(job_id=job_id, step_name=step_name, step_order=order, status=StepStatus.PENDING)
session.add(step)
session.commit()
session.refresh(step)
return step
@celery_app.task(name="pipeline.orchestrate")
def orchestrate_pipeline(job_id: str, config: dict) -> dict:
with _with_session() as session:
job = session.get(Job, job_id)
if job is None:
return {"error": "job not found", "job_id": job_id}
_update_job(
session,
job,
status=JobStatus.RUNNING,
started_at=datetime.utcnow(),
current_step="run_digger",
progress=0,
)
# Step 1: run_digger
step = _create_or_get_step(session, job_id, "run_digger", 1)
step.status = StepStatus.RUNNING
step.started_at = datetime.utcnow()
session.add(step)
session.commit()
session.refresh(step)
try:
ws_root = Path(job.workspace_path or ".")
result = run_bttoxin_digger(
workspace_root=ws_root,
threads=job.threads,
sequence_type=job.sequence_type,
scaf_suffix=job.scaf_suffix,
update_db=job.update_db,
)
step.status = StepStatus.COMPLETED
step.completed_at = datetime.utcnow()
step.result_data = {"results_dir": result["results_dir"]}
session.add(step)
session.commit()
session.refresh(step)
_update_job(session, job, progress=100, status=JobStatus.COMPLETED, completed_at=datetime.utcnow())
return {"job_id": job_id, "ok": True, "step": "run_digger", "result": result}
except Exception as exc:  # top-level catch only; record the failure
step.status = StepStatus.FAILED
step.error_message = str(exc)
step.completed_at = datetime.utcnow()
session.add(step)
session.commit()
_update_job(session, job, status=JobStatus.FAILED, error_message=str(exc))
return {"job_id": job_id, "ok": False, "error": str(exc)}


@@ -0,0 +1,46 @@
from __future__ import annotations
from pathlib import Path
from typing import Dict
from ...services.docker_service import DockerRunner
def run_bttoxin_digger(workspace_root: Path, threads: int, sequence_type: str, scaf_suffix: str, update_db: bool) -> Dict:
runner = DockerRunner()
inputs_dir = workspace_root / "inputs"
results_dir = workspace_root / "results"
results_dir.mkdir(parents=True, exist_ok=True)
# Example command; adjust to the actual BtToxin_Digger CLI
command = [
"BtToxin_Digger",
"--threads", str(threads),
"--seq_type", sequence_type,
"--scaf_suffix", scaf_suffix,
"--input", str(inputs_dir),
"--outdir", str(results_dir),
]
if update_db:
command += ["--update_db"]
logs = runner.run(
image=None,
command=command,
workdir=Path("/work"),
mounts={
workspace_root: "/work",
},
env=None,
platform=None,
detach=False,
remove=True,
)
return {
"command": command,
"logs": logs,
"results_dir": str(results_dir),
}


@@ -11,8 +11,8 @@ flower==2.0.1
# Container management
docker==7.1.0
# Database
sqlalchemy==2.0.36
# Database (SQLModel + PostgreSQL + Alembic)
sqlmodel==0.0.25
alembic==1.14.0
psycopg2-binary==2.9.10
@@ -26,7 +26,7 @@ pandas==2.2.3
# Utilities
pydantic==2.10.4
pydantic-settings==2.6.1
pydantic-settings==2.7.1
python-dotenv==1.0.1
aiofiles==24.1.0