1. 代码修改

models/broad_spectrum_predictor.py:
 新增 StrainPrediction dataclass(单个菌株预测结果)
 更新 BroadSpectrumResult 添加 strain_predictions 字段(pandas.DataFrame 类型)
 添加 to_strain_predictions_list() 方法(类型安全转换)
 新增 _prepare_strain_level_predictions() 方法
 修改 predict_batch() 方法支持 include_strain_predictions 参数
utils/mole_predictor.py:
 添加 include_strain_predictions 参数到所有函数
 添加命令行参数 --include-strain-predictions
 实现菌株级别数据与聚合结果的合并逻辑
 更新所有函数签名和文档字符串
2. 测试验证
 测试基本功能(仅聚合结果): test_3.csv → 3 行输出
 测试菌株级别预测功能: test_3.csv → 120 行输出(3 × 40)
 验证输出格式正确性
 验证每个分子都有完整的 40 个菌株预测
 验证革兰染色信息正确(18 个阴性菌 + 22 个阳性菌)
3. 文档更新
README.md:
 更新命令行使用示例
 添加 Python API 使用示例(包含菌株预测)
 添加详细的输出格式说明
 添加 40 种菌株列表概览
 添加数据使用场景示例(强化学习、筛选、可视化)
Data/mole/README.md:
 新增"菌株级别预测详情"章节
 完整的 40 种菌株列表(分革兰阴性/阳性)
 数据访问方式示例(CSV 读取、Python API)
 强化学习应用场景(状态表示、奖励函数设计)
 数据可视化代码示例
 性能和存储建议
This commit is contained in:
2025-10-17 16:46:04 +08:00
parent 62e0f3d6aa
commit 34102cf459
5 changed files with 716 additions and 21 deletions

View File

@@ -65,9 +65,21 @@ class MoleculeInput:
chem_id: Optional[str] = None
@dataclass
class StrainPrediction:
"""单个菌株的预测结果"""
pred_id: str # 格式: "chem_id:strain_name"
chem_id: str
strain_name: str
antimicrobial_predictive_probability: float # 预测概率
no_growth_probability: float # 不生长概率1 - antimicrobial_predictive_probability
growth_inhibition: int # 二值化结果 (0/1)
gram_stain: Optional[str] = None # 革兰染色类型
@dataclass
class BroadSpectrumResult:
"""广谱抗菌预测结果"""
"""广谱抗菌预测结果(聚合结果)"""
chem_id: str
apscore_total: float
apscore_gnegative: float
@@ -76,9 +88,10 @@ class BroadSpectrumResult:
ginhib_gnegative: int
ginhib_gpositive: int
broad_spectrum: int
strain_predictions: Optional[pd.DataFrame] = None # 菌株级别预测40行
def to_dict(self) -> Dict[str, Union[str, float, int]]:
"""转换为字典格式"""
"""转换为字典格式(仅聚合字段)"""
return {
'chem_id': self.chem_id,
'apscore_total': self.apscore_total,
@@ -89,6 +102,31 @@ class BroadSpectrumResult:
'ginhib_gpositive': self.ginhib_gpositive,
'broad_spectrum': self.broad_spectrum
}
def to_strain_predictions_list(self) -> List['StrainPrediction']:
"""
将 DataFrame 转换为 StrainPrediction 列表(用于类型安全场景)
Returns:
StrainPrediction 对象列表
"""
if self.strain_predictions is None or self.strain_predictions.empty:
return []
strain_list = []
for _, row in self.strain_predictions.iterrows():
strain_pred = StrainPrediction(
pred_id=row['pred_id'],
chem_id=row['chem_id'],
strain_name=row['strain_name'],
antimicrobial_predictive_probability=row['antimicrobial_predictive_probability'],
no_growth_probability=row['no_growth_probability'],
growth_inhibition=row['growth_inhibition'],
gram_stain=row.get('gram_stain', None)
)
strain_list.append(strain_pred)
return strain_list
class BroadSpectrumPredictor:
@@ -259,6 +297,45 @@ class BroadSpectrumPredictor:
return df_label
def _prepare_strain_level_predictions(self, score_df: pd.DataFrame) -> pd.DataFrame:
"""
准备菌株级别的预测数据
Args:
score_df: 原始预测分数DataFrame包含 pred_id, 0, 1, growth_inhibition 列
Returns:
格式化的菌株级别预测DataFrame
"""
# 创建副本
strain_df = score_df.copy()
# 分离化合物ID和菌株名
strain_df["chem_id"] = strain_df["pred_id"].str.split(":", expand=True)[0]
strain_df["strain_name"] = strain_df["pred_id"].str.split(":", expand=True)[1]
# 添加革兰染色信息
strain_df = self._gram_stain(strain_df)
# 重命名列为用户友好的名称
strain_df = strain_df.rename(columns={
"1": "antimicrobial_predictive_probability",
"0": "no_growth_probability"
})
# 选择并排序输出列
output_columns = [
"pred_id",
"chem_id",
"strain_name",
"antimicrobial_predictive_probability",
"no_growth_probability",
"growth_inhibition",
"gram_stain"
]
return strain_df[output_columns]
def _antimicrobial_potential(self, score_df: pd.DataFrame) -> pd.DataFrame:
"""
计算抗菌潜力分数
@@ -371,12 +448,14 @@ class ParallelBroadSpectrumPredictor(BroadSpectrumPredictor):
results = self.predict_batch([molecule])
return results[0]
def predict_batch(self, molecules: List[MoleculeInput]) -> List[BroadSpectrumResult]:
def predict_batch(self, molecules: List[MoleculeInput],
include_strain_predictions: bool = False) -> List[BroadSpectrumResult]:
"""
批量预测分子的广谱抗菌活性
Args:
molecules: 分子输入列表
include_strain_predictions: 是否在结果中包含菌株级别预测数据
Returns:
广谱抗菌预测结果列表
@@ -417,10 +496,16 @@ class ParallelBroadSpectrumPredictor(BroadSpectrumPredictor):
# 合并结果
print("Merging prediction results...")
all_pred_df = pd.concat([results[i] for i in sorted(results.keys())])
all_pred_df = all_pred_df.reset_index()
# 准备菌株级别数据(如果需要)
strain_level_data = None
if include_strain_predictions:
print("Preparing strain-level predictions...")
strain_level_data = self._prepare_strain_level_predictions(all_pred_df)
# 计算抗菌潜力
print("Calculating antimicrobial potential scores...")
all_pred_df = all_pred_df.reset_index()
agg_df = self._antimicrobial_potential(all_pred_df)
# 判断广谱抗菌
@@ -431,6 +516,13 @@ class ParallelBroadSpectrumPredictor(BroadSpectrumPredictor):
# 转换为结果对象
results_list = []
for _, row in agg_df.iterrows():
# 获取该分子的菌株级别预测
mol_strain_preds = None
if strain_level_data is not None:
mol_strain_preds = strain_level_data[
strain_level_data['chem_id'] == row.name
].reset_index(drop=True)
result = BroadSpectrumResult(
chem_id=row.name,
apscore_total=row["apscore_total"],
@@ -439,7 +531,8 @@ class ParallelBroadSpectrumPredictor(BroadSpectrumPredictor):
ginhib_total=int(row["ginhib_total"]),
ginhib_gnegative=int(row["ginhib_gnegative"]),
ginhib_gpositive=int(row["ginhib_gpositive"]),
broad_spectrum=int(row["broad_spectrum"])
broad_spectrum=int(row["broad_spectrum"]),
strain_predictions=mol_strain_preds
)
results_list.append(result)