131 lines
3.8 KiB
Python
131 lines
3.8 KiB
Python
"""
FastAPI service built on the parallel broad-spectrum antimicrobial prediction API.
"""
|
||
|
||
from typing import List, Optional, Union

import pandas as pd
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field

from .broad_spectrum_api import (
    BroadSpectrumResult,
    MoleculeInput,
    ParallelBroadSpectrumPredictor,
    PredictionConfig,
)
|
||
|
||
|
||
# Data models


class MoleculeInfo(BaseModel):
    # One molecule given in structured form.
    smiles: str  # SMILES string for the compound
    # Optional identifier; the /predict handler substitutes "mol<N>" (1-based
    # position) when this is missing or empty.
    chem_id: Union[str, None] = None
||
|
||
|
||
class MoleculeInputRequest(BaseModel):
    # Request body for /predict. Molecules can be sent either as a list of
    # MoleculeInfo objects (``molecules``, read first by the handler) or as
    # bare SMILES plus optional IDs (``smiles`` / ``chem_id``), where each of
    # the latter accepts a single string or a list of strings.

    # --- molecule input ---
    molecules: Union[List[MoleculeInfo], None] = None
    smiles: Union[str, List[str], None] = None
    chem_id: Union[str, List[str], None] = None

    # --- prediction options ---
    # NOTE(review): the /predict handler in this module does not currently read
    # aggregate_scores, app_threshold or min_nkill — confirm where they belong.
    aggregate_scores: bool = Field(
        default=False,
        description="Whether to aggregate predictions",
    )
    app_threshold: float = Field(
        default=0.04374140128493309,
        description="Threshold for growth inhibition",
    )
    min_nkill: int = Field(
        default=10,
        description="Minimum strains for broad spectrum",
    )

    # --- execution options (used only when /predict builds the predictor lazily) ---
    batch_size: int = Field(
        default=100,
        description="Batch size for processing",
    )
    n_workers: Union[int, None] = Field(
        default=None,
        description="Number of worker processes",
    )
|
||
|
||
|
||
class PredictionResponse(BaseModel):
    # Response body for /predict: the list returned by the predictor's
    # predict_batch (presumably one entry per submitted molecule — verify).
    results: List[BroadSpectrumResult] = Field(...)
|
||
|
||
|
||
# Create the FastAPI application; title/description/version feed the
# generated OpenAPI metadata.
app = FastAPI(
    title="Antimicrobial Prediction API",
    description="API for predicting antimicrobial activity of compounds using MolE and XGBoost",
    version="1.0.0",
)
|
||
|
||
|
||
# Shared predictor instance. It starts as None; the startup hook replaces it,
# and the /predict handler also creates one lazily if it is still unset.
# (A real deployment may need a more elaborate initialization.)
predictor: "Optional[ParallelBroadSpectrumPredictor]" = None
|
||
|
||
|
||
@app.on_event("startup")
async def startup_event():
    """Build the shared predictor from a default PredictionConfig at app start-up."""
    global predictor
    # NOTE(review): recent FastAPI releases deprecate ``on_event`` in favour of
    # lifespan handlers; consider migrating when upgrading.
    predictor = ParallelBroadSpectrumPredictor(PredictionConfig())
|
||
|
||
|
||
def _as_list(value: Union[str, List[str]]) -> List[str]:
    """Return *value* as a list: wrap a single string, copy a list of strings."""
    return [value] if isinstance(value, str) else list(value)


def _build_molecules(input_data: MoleculeInputRequest) -> List[MoleculeInput]:
    """Translate the request body into a list of MoleculeInput objects.

    ``molecules`` takes precedence over ``smiles``/``chem_id``. Molecules without
    an identifier are named ``mol<N>`` by their 1-based position.

    Raises:
        HTTPException: 422 when no molecules are supplied, or when ``chem_id``
            is given but its length differs from that of ``smiles``.
    """
    if input_data.molecules:
        # Structured input: fall back to a positional ID when chem_id is missing/empty.
        return [
            MoleculeInput(smiles=m.smiles, chem_id=m.chem_id or f"mol{i + 1}")
            for i, m in enumerate(input_data.molecules)
        ]

    if not input_data.smiles:
        raise HTTPException(
            status_code=422,
            detail="Either 'molecules' or 'smiles' must be provided",
        )

    smiles_list = _as_list(input_data.smiles)
    if input_data.chem_id:
        chem_ids = _as_list(input_data.chem_id)
        # zip() would silently drop the unmatched tail, losing molecules;
        # reject the request instead.
        if len(chem_ids) != len(smiles_list):
            raise HTTPException(
                status_code=422,
                detail=(
                    f"'chem_id' has {len(chem_ids)} entries but 'smiles' has "
                    f"{len(smiles_list)}; they must match one-to-one"
                ),
            )
    else:
        chem_ids = [f"mol{i + 1}" for i in range(len(smiles_list))]

    return [
        MoleculeInput(smiles=smiles, chem_id=chem_id)
        for smiles, chem_id in zip(smiles_list, chem_ids)
    ]


@app.post("/predict", response_model=PredictionResponse)
async def predict_antimicrobial(input_data: MoleculeInputRequest):
    """Predict the antimicrobial activity of the submitted compounds.

    Args:
        input_data: Request body carrying the molecules, either as ``molecules``
            or as ``smiles`` with optional ``chem_id`` values.

    Returns:
        A PredictionResponse wrapping the list returned by ``predict_batch``.

    Raises:
        HTTPException: 422 if no molecules were supplied or ``chem_id`` and
            ``smiles`` differ in length; 500 if the prediction itself fails.
    """
    global predictor

    if predictor is None:
        # Fallback when the startup hook has not run. NOTE(review): the
        # batch_size / n_workers of whichever request triggers this become
        # the configuration for every later request.
        predictor = ParallelBroadSpectrumPredictor(
            PredictionConfig(
                batch_size=input_data.batch_size,
                n_workers=input_data.n_workers,
            )
        )

    molecules = _build_molecules(input_data)

    # NOTE(review): aggregate_scores, app_threshold and min_nkill from the
    # request are not forwarded anywhere — confirm whether PredictionConfig or
    # predict_batch should receive them.
    # NOTE(review): predict_batch is called synchronously, so it blocks the
    # event loop (including /health) for its duration; offloading it to a
    # thread needs the predictor's thread-safety confirmed first.
    try:
        results = predictor.predict_batch(molecules)
    except Exception as e:
        # Surface the failure as a proper HTTP 500 and keep the original cause.
        raise HTTPException(
            status_code=500, detail=f"Prediction failed: {str(e)}"
        ) from e
    return PredictionResponse(results=results)
|
||
|
||
|
||
@app.get("/health")
async def health_check():
    """Health-check endpoint; unconditionally reports the service as healthy."""
    return dict(status="healthy")
|
||
|
||
|
||
@app.get("/")
async def root():
    """Root endpoint that returns basic information about the API."""
    info = dict(
        message="Antimicrobial Prediction API",
        version="1.0.0",
        description="API for predicting antimicrobial activity of compounds",
    )
    return info