""" FastAPI服务,基于并行广谱抗菌预测API """ from typing import List, Union, Optional from fastapi import FastAPI from pydantic import BaseModel, Field import pandas as pd from .broad_spectrum_api import ( ParallelBroadSpectrumPredictor, PredictionConfig, MoleculeInput, BroadSpectrumResult ) # 数据模型 class MoleculeInfo(BaseModel): smiles: str chem_id: Optional[str] = None class MoleculeInputRequest(BaseModel): molecules: Optional[List[MoleculeInfo]] = None smiles: Optional[Union[str, List[str]]] = None chem_id: Optional[Union[str, List[str]]] = None aggregate_scores: bool = Field(False, description="Whether to aggregate predictions") app_threshold: float = Field(0.04374140128493309, description="Threshold for growth inhibition") min_nkill: int = Field(10, description="Minimum strains for broad spectrum") batch_size: int = Field(100, description="Batch size for processing") n_workers: Optional[int] = Field(None, description="Number of worker processes") class PredictionResponse(BaseModel): results: List[BroadSpectrumResult] # 创建FastAPI应用 app = FastAPI(title="Antimicrobial Prediction API", description="API for predicting antimicrobial activity of compounds using MolE and XGBoost", version="1.0.0") # 初始化预测器(在实际应用中可能需要更复杂的初始化) predictor = None @app.on_event("startup") async def startup_event(): """应用启动时初始化预测器""" global predictor config = PredictionConfig() predictor = ParallelBroadSpectrumPredictor(config) @app.post("/predict", response_model=PredictionResponse) async def predict_antimicrobial(input_data: MoleculeInputRequest): """ 预测化合物的抗菌活性 Args: input_data: 包含分子信息的请求数据 Returns: 预测结果列表 """ global predictor if not predictor: # 如果预测器未初始化,则创建一个 config = PredictionConfig( batch_size=input_data.batch_size, n_workers=input_data.n_workers ) predictor = ParallelBroadSpectrumPredictor(config) # 处理输入数据 if input_data.molecules: # 直接使用MoleculeInput对象列表 molecules = [ MoleculeInput(smiles=m.smiles, chem_id=m.chem_id or f"mol{i+1}") for i, m in enumerate(input_data.molecules) ] elif input_data.smiles: # 从SMILES字符串创建分子列表 if isinstance(input_data.smiles, str): smiles_list = [input_data.smiles] else: smiles_list = input_data.smiles if input_data.chem_id: if isinstance(input_data.chem_id, str): chem_ids = [input_data.chem_id] else: chem_ids = input_data.chem_id else: chem_ids = [f"mol{i+1}" for i in range(len(smiles_list))] molecules = [ MoleculeInput(smiles=smiles, chem_id=chem_id) for smiles, chem_id in zip(smiles_list, chem_ids) ] else: raise ValueError("Either 'molecules' or 'smiles' must be provided") # 执行预测 try: results = predictor.predict_batch(molecules) return PredictionResponse(results=results) except Exception as e: raise RuntimeError(f"Prediction failed: {str(e)}") @app.get("/health") async def health_check(): """ 健康检查端点 """ return {"status": "healthy"} @app.get("/") async def root(): """ 根路径,提供API信息 """ return { "message": "Antimicrobial Prediction API", "version": "1.0.0", "description": "API for predicting antimicrobial activity of compounds" }