1. 代码修改
models/broad_spectrum_predictor.py:
✅ 新增 StrainPrediction dataclass(单个菌株预测结果)
✅ 更新 BroadSpectrumResult 添加 strain_predictions 字段(pandas.DataFrame 类型)
✅ 添加 to_strain_predictions_list() 方法(类型安全转换)
✅ 新增 _prepare_strain_level_predictions() 方法
✅ 修改 predict_batch() 方法支持 include_strain_predictions 参数

utils/mole_predictor.py:
✅ 添加 include_strain_predictions 参数到所有函数
✅ 添加命令行参数 --include-strain-predictions
✅ 实现菌株级别数据与聚合结果的合并逻辑
✅ 更新所有函数签名和文档字符串

2. 测试验证
✅ 测试基本功能(仅聚合结果): test_3.csv → 3 行输出
✅ 测试菌株级别预测功能: test_3.csv → 120 行输出(3 × 40)
✅ 验证输出格式正确性
✅ 验证每个分子都有完整的 40 个菌株预测
✅ 验证革兰染色信息正确(18 个阴性菌 + 22 个阳性菌)

3. 文档更新
README.md:
✅ 更新命令行使用示例
✅ 添加 Python API 使用示例(包含菌株预测)
✅ 添加详细的输出格式说明
✅ 添加 40 种菌株列表概览
✅ 添加数据使用场景示例(强化学习、筛选、可视化)

Data/mole/README.md:
✅ 新增"菌株级别预测详情"章节
✅ 完整的 40 种菌株列表(分革兰阴性/阳性)
✅ 数据访问方式示例(CSV 读取、Python API)
✅ 强化学习应用场景(状态表示、奖励函数设计)
✅ 数据可视化代码示例
✅ 性能和存储建议
This commit is contained in:
@@ -65,9 +65,21 @@ class MoleculeInput:
|
||||
chem_id: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
class StrainPrediction:
    """Prediction result for a single (compound, strain) pair.

    The two probability fields are complementary:
    ``no_growth_probability = 1 - antimicrobial_predictive_probability``.
    """

    pred_id: str  # combined key, formatted as "chem_id:strain_name"
    chem_id: str  # compound identifier part of pred_id
    strain_name: str  # strain name part of pred_id
    antimicrobial_predictive_probability: float  # predicted probability
    no_growth_probability: float  # 1 - antimicrobial_predictive_probability
    growth_inhibition: int  # binarized result (0/1)
    gram_stain: Optional[str] = None  # Gram stain type, when available
|
||||
|
||||
|
||||
@dataclass
|
||||
class BroadSpectrumResult:
|
||||
"""广谱抗菌预测结果"""
|
||||
"""广谱抗菌预测结果(聚合结果)"""
|
||||
chem_id: str
|
||||
apscore_total: float
|
||||
apscore_gnegative: float
|
||||
@@ -76,9 +88,10 @@ class BroadSpectrumResult:
|
||||
ginhib_gnegative: int
|
||||
ginhib_gpositive: int
|
||||
broad_spectrum: int
|
||||
strain_predictions: Optional[pd.DataFrame] = None # 菌株级别预测(40行)
|
||||
|
||||
def to_dict(self) -> Dict[str, Union[str, float, int]]:
|
||||
"""转换为字典格式"""
|
||||
"""转换为字典格式(仅聚合字段)"""
|
||||
return {
|
||||
'chem_id': self.chem_id,
|
||||
'apscore_total': self.apscore_total,
|
||||
@@ -89,6 +102,31 @@ class BroadSpectrumResult:
|
||||
'ginhib_gpositive': self.ginhib_gpositive,
|
||||
'broad_spectrum': self.broad_spectrum
|
||||
}
|
||||
|
||||
def to_strain_predictions_list(self) -> List['StrainPrediction']:
    """Convert the strain-level DataFrame into ``StrainPrediction`` objects.

    Intended for callers that want typed objects instead of DataFrame rows.
    Numeric fields are coerced to built-in ``float``/``int`` so they match
    the ``StrainPrediction`` annotations, and a missing or NaN
    ``gram_stain`` is normalised to ``None``.

    Returns:
        One ``StrainPrediction`` per row of ``self.strain_predictions``, in
        row order; an empty list when no strain-level data is attached.
    """
    strain_df = self.strain_predictions
    if strain_df is None or strain_df.empty:
        return []

    strain_list = []
    # to_dict("records") hands back each column's own value per row, whereas
    # iterrows() builds a Series per row and does not preserve dtypes.
    for record in strain_df.to_dict("records"):
        # The gram_stain column is optional; treat absent/NaN as "unknown".
        gram_stain = record.get("gram_stain")
        if pd.isna(gram_stain):
            gram_stain = None

        strain_list.append(
            StrainPrediction(
                pred_id=record["pred_id"],
                chem_id=record["chem_id"],
                strain_name=record["strain_name"],
                antimicrobial_predictive_probability=float(
                    record["antimicrobial_predictive_probability"]
                ),
                no_growth_probability=float(record["no_growth_probability"]),
                growth_inhibition=int(record["growth_inhibition"]),
                gram_stain=gram_stain,
            )
        )

    return strain_list
|
||||
|
||||
|
||||
class BroadSpectrumPredictor:
|
||||
@@ -259,6 +297,45 @@ class BroadSpectrumPredictor:
|
||||
|
||||
return df_label
|
||||
|
||||
def _prepare_strain_level_predictions(self, score_df: pd.DataFrame) -> pd.DataFrame:
    """Build the per-strain prediction table from raw prediction scores.

    Args:
        score_df: Raw score table with a ``pred_id`` column formatted as
            ``"chem_id:strain_name"``, the two class-probability columns
            labelled ``0`` and ``1`` (string or integer labels), and a
            ``growth_inhibition`` column. The input is not modified.

    Returns:
        DataFrame with the columns ``pred_id``, ``chem_id``,
        ``strain_name``, ``antimicrobial_predictive_probability``,
        ``no_growth_probability``, ``growth_inhibition`` and
        ``gram_stain``, in that order.

    Raises:
        KeyError: If one of the output columns is missing after processing.
    """
    # Work on a copy so the caller's frame is left untouched.
    strain_df = score_df.copy()

    # Split once, at the first ":" only: the compound id is the text before
    # it and everything after it is the strain name, so a strain name that
    # itself contains ":" is kept whole instead of being truncated.
    id_parts = strain_df["pred_id"].str.split(":", n=1, expand=True)
    strain_df["chem_id"] = id_parts[0]
    strain_df["strain_name"] = id_parts[1]

    # Attach Gram stain information (helper is expected to add the
    # "gram_stain" column selected below).
    strain_df = self._gram_stain(strain_df)

    # Give the class-probability columns user-friendly names. Both string
    # and integer labels are accepted; labels that are absent are ignored.
    strain_df = strain_df.rename(columns={
        "1": "antimicrobial_predictive_probability",
        "0": "no_growth_probability",
        1: "antimicrobial_predictive_probability",
        0: "no_growth_probability",
    })

    # Select and order the output columns.
    output_columns = [
        "pred_id",
        "chem_id",
        "strain_name",
        "antimicrobial_predictive_probability",
        "no_growth_probability",
        "growth_inhibition",
        "gram_stain",
    ]

    return strain_df[output_columns]
|
||||
|
||||
def _antimicrobial_potential(self, score_df: pd.DataFrame) -> pd.DataFrame:
|
||||
"""
|
||||
计算抗菌潜力分数
|
||||
@@ -371,12 +448,14 @@ class ParallelBroadSpectrumPredictor(BroadSpectrumPredictor):
|
||||
results = self.predict_batch([molecule])
|
||||
return results[0]
|
||||
|
||||
def predict_batch(self, molecules: List[MoleculeInput]) -> List[BroadSpectrumResult]:
|
||||
def predict_batch(self, molecules: List[MoleculeInput],
|
||||
include_strain_predictions: bool = False) -> List[BroadSpectrumResult]:
|
||||
"""
|
||||
批量预测分子的广谱抗菌活性
|
||||
|
||||
Args:
|
||||
molecules: 分子输入列表
|
||||
include_strain_predictions: 是否在结果中包含菌株级别预测数据
|
||||
|
||||
Returns:
|
||||
广谱抗菌预测结果列表
|
||||
@@ -417,10 +496,16 @@ class ParallelBroadSpectrumPredictor(BroadSpectrumPredictor):
|
||||
# 合并结果
|
||||
print("Merging prediction results...")
|
||||
all_pred_df = pd.concat([results[i] for i in sorted(results.keys())])
|
||||
all_pred_df = all_pred_df.reset_index()
|
||||
|
||||
# 准备菌株级别数据(如果需要)
|
||||
strain_level_data = None
|
||||
if include_strain_predictions:
|
||||
print("Preparing strain-level predictions...")
|
||||
strain_level_data = self._prepare_strain_level_predictions(all_pred_df)
|
||||
|
||||
# 计算抗菌潜力
|
||||
print("Calculating antimicrobial potential scores...")
|
||||
all_pred_df = all_pred_df.reset_index()
|
||||
agg_df = self._antimicrobial_potential(all_pred_df)
|
||||
|
||||
# 判断广谱抗菌
|
||||
@@ -431,6 +516,13 @@ class ParallelBroadSpectrumPredictor(BroadSpectrumPredictor):
|
||||
# 转换为结果对象
|
||||
results_list = []
|
||||
for _, row in agg_df.iterrows():
|
||||
# 获取该分子的菌株级别预测
|
||||
mol_strain_preds = None
|
||||
if strain_level_data is not None:
|
||||
mol_strain_preds = strain_level_data[
|
||||
strain_level_data['chem_id'] == row.name
|
||||
].reset_index(drop=True)
|
||||
|
||||
result = BroadSpectrumResult(
|
||||
chem_id=row.name,
|
||||
apscore_total=row["apscore_total"],
|
||||
@@ -439,7 +531,8 @@ class ParallelBroadSpectrumPredictor(BroadSpectrumPredictor):
|
||||
ginhib_total=int(row["ginhib_total"]),
|
||||
ginhib_gnegative=int(row["ginhib_gnegative"]),
|
||||
ginhib_gpositive=int(row["ginhib_gpositive"]),
|
||||
broad_spectrum=int(row["broad_spectrum"])
|
||||
broad_spectrum=int(row["broad_spectrum"]),
|
||||
strain_predictions=mol_strain_preds
|
||||
)
|
||||
results_list.append(result)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user