Files
mm644706215 a56e60e9a3 first add
2025-10-16 17:21:48 +08:00

131 lines
3.8 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
FastAPI服务基于并行广谱抗菌预测API
"""
from typing import List, Union, Optional

import pandas as pd
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field

from .broad_spectrum_api import (
    ParallelBroadSpectrumPredictor,
    PredictionConfig,
    MoleculeInput,
    BroadSpectrumResult,
)
# Data models
class MoleculeInfo(BaseModel):
    """One molecule given by its SMILES string, with an optional identifier."""

    smiles: str = Field(...)
    chem_id: Optional[str] = Field(default=None)
class MoleculeInputRequest(BaseModel):
    """Request body for the ``/predict`` endpoint.

    Molecules are supplied either as a list of ``MoleculeInfo`` objects
    (``molecules``) or as a single SMILES string / list of SMILES strings
    (``smiles``) with optional matching identifiers (``chem_id``).
    """

    molecules: Optional[List[MoleculeInfo]] = None
    smiles: Optional[Union[str, List[str]]] = None
    chem_id: Optional[Union[str, List[str]]] = None
    # NOTE(review): the /predict handler in this file does not read
    # aggregate_scores, app_threshold or min_nkill -- confirm whether they
    # should be forwarded to the predictor configuration.
    aggregate_scores: bool = Field(False, description="Whether to aggregate predictions")
    app_threshold: float = Field(0.04374140128493309, description="Threshold for growth inhibition")
    # A strain count cannot be negative.
    min_nkill: int = Field(10, ge=0, description="Minimum strains for broad spectrum")
    # A batch must contain at least one molecule.
    batch_size: int = Field(100, gt=0, description="Batch size for processing")
    n_workers: Optional[int] = Field(None, description="Number of worker processes")
class PredictionResponse(BaseModel):
    """Response body of ``/predict``: the list of results produced by the predictor."""

    results: List[BroadSpectrumResult] = Field(...)
# Create the FastAPI application
app = FastAPI(
    title="Antimicrobial Prediction API",
    description="API for predicting antimicrobial activity of compounds using MolE and XGBoost",
    version="1.0.0",
)

# Shared predictor instance. It is set by the startup hook and is created
# on demand by /predict if still unset (a real deployment may need a more
# elaborate initialisation).
predictor: Optional[ParallelBroadSpectrumPredictor] = None
@app.on_event("startup")
async def startup_event():
    """Create the shared predictor with the default configuration at application start-up."""
    global predictor
    predictor = ParallelBroadSpectrumPredictor(PredictionConfig())
def _as_list(value: Union[str, List[str]]) -> List[str]:
    """Return ``value`` as a list, wrapping a lone string in a one-element list."""
    if isinstance(value, str):
        return [value]
    return list(value)


def _build_molecules(input_data: MoleculeInputRequest) -> List[MoleculeInput]:
    """Convert the request payload into ``MoleculeInput`` objects.

    ``molecules`` takes precedence over ``smiles``/``chem_id``. A molecule
    without an identifier gets a positional default (``mol1``, ``mol2``, ...).

    Raises:
        HTTPException: 422 if neither ``molecules`` nor ``smiles`` is provided,
            or if ``chem_id`` is provided with a different length than ``smiles``.
    """
    if input_data.molecules:
        return [
            MoleculeInput(smiles=m.smiles, chem_id=m.chem_id or f"mol{i}")
            for i, m in enumerate(input_data.molecules, start=1)
        ]

    if not input_data.smiles:
        # A malformed request is a client error, not a server failure (500).
        raise HTTPException(
            status_code=422,
            detail="Either 'molecules' or 'smiles' must be provided",
        )

    smiles_list = _as_list(input_data.smiles)
    if input_data.chem_id:
        chem_ids = _as_list(input_data.chem_id)
        # zip() would silently drop the unmatched molecules, so reject mismatches.
        if len(chem_ids) != len(smiles_list):
            raise HTTPException(
                status_code=422,
                detail=(
                    f"'chem_id' has {len(chem_ids)} entries but 'smiles' has "
                    f"{len(smiles_list)}; they must be the same length"
                ),
            )
    else:
        chem_ids = [f"mol{i}" for i in range(1, len(smiles_list) + 1)]

    return [
        MoleculeInput(smiles=smiles, chem_id=chem_id)
        for smiles, chem_id in zip(smiles_list, chem_ids)
    ]


@app.post("/predict", response_model=PredictionResponse)
async def predict_antimicrobial(input_data: MoleculeInputRequest):
    """Predict the antimicrobial activity of the submitted compounds.

    Args:
        input_data: Request payload with the molecules to score, given either
            as ``molecules`` or as ``smiles`` with optional ``chem_id`` values.

    Returns:
        A ``PredictionResponse`` wrapping the predictor's results.

    Raises:
        HTTPException: 422 for an unusable payload (see ``_build_molecules``).
        RuntimeError: if the predictor fails; the original error is chained.
    """
    global predictor
    if predictor is None:
        # Fallback in case the startup hook did not run: build the predictor
        # from the request's processing settings.
        predictor = ParallelBroadSpectrumPredictor(
            PredictionConfig(
                batch_size=input_data.batch_size,
                n_workers=input_data.n_workers,
            )
        )

    # NOTE(review): aggregate_scores, app_threshold and min_nkill are accepted
    # by the request model but never forwarded to the predictor; batch_size and
    # n_workers only take effect on the fallback path above.
    molecules = _build_molecules(input_data)

    # NOTE(review): predict_batch is called synchronously inside this async
    # handler, so it blocks the event loop for the whole call. Offloading it
    # (e.g. run_in_threadpool) is only safe if the predictor is thread-safe --
    # confirm before changing.
    try:
        results = predictor.predict_batch(molecules)
        return PredictionResponse(results=results)
    except Exception as e:
        raise RuntimeError(f"Prediction failed: {str(e)}") from e
@app.get("/health")
async def health_check():
    """Health-check endpoint; always returns a constant status payload."""
    payload = {"status": "healthy"}
    return payload
@app.get("/")
async def root():
    """Return basic API information (name, version and description) at the root path."""
    return dict(
        message="Antimicrobial Prediction API",
        version="1.0.0",
        description="API for predicting antimicrobial activity of compounds",
    )