# File: rdscripts/scripts/gen_sdf_parallel.py
# Snapshot: 2025-07-31 13:18:32 +08:00 (171 lines, 7.6 KiB, Python)
#
# Note: the repository viewer warned that this file contains ambiguous Unicode
# characters, i.e. characters that might be confused with other characters.
# If their use is intentional, the warning can safely be ignored.
import traceback
from concurrent.futures import ProcessPoolExecutor, as_completed
from pathlib import Path

import pandas as pd
from rdkit import Chem
from rdkit.Chem import AllChem
from rich.console import Console
from rich.markup import escape
from rich.panel import Panel
from rich.table import Table
# Shared rich console used for the progress lines, summary table and panels
# printed by this script.
console = Console()
def is_valid_sdf(sdf_path):
    """Report whether ``sdf_path`` holds at least one parseable molecule.

    Records are read with ``sanitize=False``, so RDKit sanitization is not
    required to pass. Any exception raised while opening or reading the file
    is swallowed and reported as ``False``.
    """
    try:
        supplier = Chem.SDMolSupplier(str(sdf_path), sanitize=False)
        # Read every record (the supplier yields None for unparseable ones).
        readable = sum(1 for record in supplier if record is not None)
    except Exception:
        return False
    return readable > 0
def _write_sdf_atomically(mol, sdf_path):
    """Write ``mol`` as a single-record SDF at ``sdf_path``.

    The record is written to ``<name>.tmp`` next to the target and then renamed
    over it. An interrupted write therefore never leaves a truncated ``.sdf``
    that a later run could mistake for a finished result. If anything fails,
    the temporary file is removed and the exception re-raised.
    """
    sdf_path = Path(sdf_path)
    tmp_path = sdf_path.with_name(sdf_path.name + ".tmp")
    try:
        writer = Chem.SDWriter(str(tmp_path))
        try:
            writer.write(mol)
        finally:
            # Close even when write() raises so the file handle is released.
            writer.close()
        # Same directory as the target; on POSIX this rename is atomic.
        tmp_path.replace(sdf_path)
    except BaseException:
        tmp_path.unlink(missing_ok=True)
        raise


def smiles_to_3d_sdf(identifier, smiles, props, sdf_path, max_attempts=10):
    """Generate a 3D conformer for one SMILES string and save it as SDF.

    The molecule is parsed, explicit hydrogens are added, and a conformer is
    embedded with ETKDGv3 and then UFF-optimized. Embedding (plus optimization
    and writing) is retried up to ``max_attempts`` times.

    Args:
        identifier: Returned unchanged so the caller can match results.
        smiles: SMILES string. Non-string values (e.g. the NaN pandas uses
            for empty CSV cells) and blank strings count as parse failures.
        props: Optional mapping written onto the molecule as SD properties;
            keys and values are converted with ``str``.
        sdf_path: Destination ``.sdf`` path.
        max_attempts: Maximum number of embedding attempts.

    Returns:
        ``(identifier, success, message)``. Failure messages are
        "SMILES解析失败", contain "3D" for embedding failures, or start with
        "其它异常" for unexpected errors. The batch driver classifies results
        by these substrings.
    """
    try:
        # Guard before RDKit: NaN would raise a type error (reported as
        # "other"), and "" parses to an empty molecule instead of None.
        if not isinstance(smiles, str) or not smiles.strip():
            return identifier, False, "SMILES解析失败"
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            return identifier, False, "SMILES解析失败"
        mol = Chem.AddHs(mol)
        params = AllChem.ETKDGv3()
        last_error = ""
        for attempt in range(max_attempts):
            try:
                if AllChem.EmbedMolecule(mol, params) != 0:
                    continue
                AllChem.UFFOptimizeMolecule(mol)
                if props:
                    for key, value in props.items():
                        mol.SetProp(str(key), str(value))
                _write_sdf_atomically(mol, sdf_path)
                return identifier, True, f"成功(第{attempt+1}次)"
            except Exception as e:
                # Remember the latest error and keep retrying.
                last_error = f"3D生成异常: {e}"
        return identifier, False, last_error or f"3D构象生成失败(已重试{max_attempts}次)"
    except Exception:
        return identifier, False, "其它异常: " + traceback.format_exc(limit=1)
def smiles_to_3d_sdf_tuple(args):
    """Spread one task tuple into the positional arguments of ``smiles_to_3d_sdf``.

    The batch driver submits a single tuple per molecule,
    ``(identifier, smiles, props, sdf_path, max_attempts)``, to the process
    pool. This adapter unpacks it and returns the call's result unchanged.
    """
    positional = tuple(args)
    return smiles_to_3d_sdf(*positional)
def _classify_failure(msg):
    """Bucket a failure message from ``smiles_to_3d_sdf`` as "smiles", "3d" or "other".

    A SMILES parse failure is checked first, then any message mentioning "3D".
    The live log and the final summary both use this rule, so they agree.
    """
    if "SMILES解析失败" in msg:
        return "smiles"
    if "3D" in msg:
        return "3d"
    return "other"


def _collect_tasks(df, output_dir, smiles_col, id_col, max_attempts):
    """Turn CSV rows into task tuples for ``smiles_to_3d_sdf_tuple``.

    - Rows whose ``<identifier>.sdf`` already exists and is readable are skipped.
    - Unreadable SDFs are deleted so they get regenerated.
    - A row mapping to the same output file as an earlier row is ignored,
      since two workers writing one file concurrently could corrupt it.

    Returns:
        ``(tasks, skipped, duplicates)``. ``tasks`` holds
        ``(identifier, smiles, props, sdf_file, max_attempts)`` tuples. The
        other two are lists of identifiers.
    """
    tasks, skipped, duplicates = [], [], []
    claimed = set()
    # Same per-row dicts as iterrows() + row.to_dict(), without building a
    # Series for every row.
    for props in df.to_dict(orient="records"):
        smiles = props[smiles_col]
        identifier = props[id_col]
        sdf_file = output_dir / f"{identifier}.sdf"
        if sdf_file in claimed:
            duplicates.append(identifier)
            continue
        claimed.add(sdf_file)
        if sdf_file.exists():
            if is_valid_sdf(sdf_file):
                # SDF exists and is readable: nothing to do.
                skipped.append(identifier)
                continue
            # SDF exists but cannot be read: treat it as corrupt and regenerate.
            try:
                sdf_file.unlink()
                console.print(f"[red]⚡发现损坏SDF文件 {escape(sdf_file.name)},已删除,准备重新生成[/red]")
            except Exception as e:
                console.print(f"[bold magenta]❗无法删除损坏SDF: {escape(sdf_file.name)}, {escape(str(e))}[/]")
        tasks.append((identifier, smiles, props, sdf_file, max_attempts))
    return tasks, skipped, duplicates


def _run_tasks(tasks, n_jobs):
    """Run tasks in a process pool, printing one line per finished molecule.

    Returns ``(identifier, success, message)`` tuples in completion order.
    Sometimes a future raises instead of returning, for example
    ``BrokenProcessPool`` after a worker process died. That task is then
    recorded as an "other" failure, so the remaining results and the
    summary are not lost.
    """
    results = []
    if not tasks:
        return results
    with ProcessPoolExecutor(max_workers=n_jobs) as executor:
        future_to_identifier = {
            executor.submit(smiles_to_3d_sdf_tuple, task): task[0] for task in tasks
        }
        for future in as_completed(future_to_identifier):
            try:
                identifier, success, msg = future.result()
            except Exception as exc:
                identifier = future_to_identifier[future]
                success, msg = False, f"其它异常: {type(exc).__name__}: {exc}"
            results.append((identifier, success, msg))
            # Escape so "[...]" in identifiers/messages is shown literally
            # instead of being parsed as rich markup (hidden text or MarkupError).
            label = escape(f"[{identifier}]")
            text = escape(msg)
            if success:
                console.print(f"[bold green]✅ {label} 处理成功。[/][dim]{text}[/]")
                continue
            kind = _classify_failure(msg)
            if kind == "smiles":
                console.print(f"[bold red]❌ {label} SMILES解析失败: {text}[/]")
            elif kind == "3d":
                console.print(f"[yellow]⚠️ {label} 3D构象生成失败: {text}[/]")
            else:
                console.print(f"[magenta]❗ {label} 其它错误: {text}[/]")
    return results


def _report(results, skipped, duplicates, output_dir):
    """Print the summary table and failure panels, and maintain failed_smiles.txt.

    failed_smiles.txt gets one ``identifier<TAB>message`` line per failure of
    this run. When nothing failed, a file left by an earlier run is removed so
    it does not list molecules that have since succeeded.
    """
    succeeded = [r for r in results if r[1]]
    failed = [r for r in results if not r[1]]
    by_kind = {"smiles": [], "3d": [], "other": []}
    for record in failed:
        by_kind[_classify_failure(record[2])].append(record)

    table = Table(title="处理结果统计", show_lines=True)
    table.add_column("状态", justify="center", style="cyan")
    table.add_column("数量", justify="center")
    table.add_row("成功", str(len(succeeded)))
    table.add_row("已跳过已存在有效SDF", str(len(skipped)))
    table.add_row("SMILES解析失败", str(len(by_kind["smiles"])))
    table.add_row("3D构象失败", str(len(by_kind["3d"])))
    table.add_row("其它失败", str(len(by_kind["other"])))
    if duplicates:
        table.add_row("重复标识符已忽略", str(len(duplicates)))
    console.print(table)

    fail_file = output_dir / "failed_smiles.txt"
    if not failed:
        fail_file.unlink(missing_ok=True)
        console.print(Panel("[bold green]全部分子处理成功或已跳过![/bold green]", style="green"))
        return

    with open(fail_file, "w", encoding="utf-8") as f:
        for identifier, _, msg in failed:
            # Tracebacks span several lines; flatten so each failure stays on one line.
            flat = " | ".join(part.strip() for part in msg.splitlines() if part.strip())
            f.write(f"{identifier}\t{flat}\n")

    # (bucket, panel text prefix, panel title, panel style)
    panel_specs = (
        ("smiles", "SMILES解析失败", "[bold red]SMILES解析失败分子[/bold red]", "red"),
        ("3d", "3D构象失败", "[bold yellow]3D构象生成失败分子[/bold yellow]", "yellow"),
        ("other", "其它异常", "[bold magenta]其它失败分子[/bold magenta]", "magenta"),
    )
    for kind, prefix, title, style in panel_specs:
        group = by_kind[kind]
        if group:
            # str(): identifiers may be NaN floats for rows with an empty id cell.
            names = escape(", ".join(str(r[0]) for r in group))
            console.print(Panel(f"{prefix}: [yellow]{names}[/yellow]", title=title, style=style))
    console.print(
        Panel(
            f"共 [red]{len(failed)}[/red] 个分子失败,详情见: [bold]{escape(str(fail_file.resolve()))}[/bold]",
            title="[bold red]失败分子统计[/bold red]",
            style="red",
        )
    )


def batch_csv_to_3d_sdf_parallel(csv_path, output_dir, smiles_col, id_col, n_jobs=4, max_attempts=10):
    """Convert every SMILES in a CSV into its own 3D SDF file, in parallel.

    Each row is written to ``<output_dir>/<identifier>.sdf`` with all of its
    CSV columns passed along as SD properties. SDFs that already exist and are
    readable are skipped, so an interrupted run can simply be restarted.

    Args:
        csv_path: Comma-separated input file; every column is read as a string.
        output_dir: Output directory, created if missing.
        smiles_col: Name of the SMILES column.
        id_col: Name of the identifier column (used as the output file name).
        n_jobs: Number of worker processes.
        max_attempts: Embedding attempts per molecule.

    Raises:
        KeyError: If ``smiles_col`` or ``id_col`` is not a column of the CSV.
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    df = pd.read_csv(csv_path, sep=",", dtype=str)
    # Fail up front with a clear message rather than a bare KeyError on row 1.
    missing = [col for col in (smiles_col, id_col) if col not in df.columns]
    if missing:
        raise KeyError(f"CSV缺少必需的列: {missing}")

    tasks, skipped, duplicates = _collect_tasks(df, output_dir, smiles_col, id_col, max_attempts)
    console.rule(f"[bold green]共 {len(df)} 个分子,{len(skipped)} 个已存在且有效,{len(tasks)} 个待处理(使用 {n_jobs} 并行进程)[/]")
    if duplicates:
        console.print(f"[yellow]⚠️ {len(duplicates)} 行的标识符与前面的行重复(输出到同一SDF文件),已忽略[/yellow]")
    results = _run_tasks(tasks, n_jobs)
    _report(results, skipped, duplicates, output_dir)
if __name__ == "__main__":
    import argparse

    # (flag, add_argument keyword arguments) for each CLI option, in the order
    # they appear in --help. Each option feeds the matching parameter of
    # batch_csv_to_3d_sdf_parallel.
    option_specs = (
        ("--csv", {"type": str, "required": True, "help": "csv文件路径"}),
        ("--outdir", {"type": str, "default": "./sdf_files", "help": "SDF输出文件夹"}),
        ("--smiles_col", {"type": str, "default": "canonical_smiles", "help": "SMILES列名"}),
        ("--id_col", {"type": str, "default": "identifier", "help": "标识符列名"}),
        ("--n_jobs", {"type": int, "default": 4, "help": "并行进程数"}),
        ("--max_attempts", {"type": int, "default": 10, "help": "最大尝试次数"}),
    )
    cli = argparse.ArgumentParser()
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)
    opts = cli.parse_args()

    batch_csv_to_3d_sdf_parallel(
        opts.csv,
        opts.outdir,
        opts.smiles_col,
        opts.id_col,
        n_jobs=opts.n_jobs,
        max_attempts=opts.max_attempts,
    )
# Example usage:
# python gen_sdf_parallel.py --csv coconut_data_info.csv --outdir ./sdf_files --n_jobs 8 --max_attempts 10 --smiles_col canonical_smiles --id_col identifier