# Source: macrolactone-toolkit/scripts/batch_process.py
# Retrieved: 2025-11-14 20:34:58 +08:00 (166 lines, 5.1 KiB, Python)
"""
Batch processing script for analyzing all macrolactones in the dataset.
"""
import sys

sys.path.append('..')  # make the project's src/ package importable when run from scripts/

import json
from pathlib import Path
from typing import Optional

import pandas as pd
from rdkit import Chem
from tqdm import tqdm

from src.fragment_cleaver import process_molecule
from src.fragment_dataclass import MoleculeFragments
def batch_process_molecules(csv_path: str, output_base_dir: str,
                            max_molecules: Optional[int] = None):
    """
    Process every molecule in the CSV file and write fragment JSON/CSV output.

    For each input SMILES, runs :func:`process_molecule`, saves a per-molecule
    JSON with all fragments plus one JSON per individual fragment, and finally
    writes aggregate statistics (``processing_stats.json``) and a flat
    ``all_fragments.csv`` table into the output directory.

    Args:
        csv_path: Path to the CSV file containing a ``smiles`` column
            (and optionally an ``IDs`` column used as molecule identifier).
        output_base_dir: Base directory for output; created if missing.
        max_molecules: Maximum number of molecules to process
            (None for all; 0 processes none).

    Returns:
        A pandas DataFrame of all extracted fragments, or None when no
        fragments were extracted at all.
    """
    df = pd.read_csv(csv_path)
    print(f"Loaded {len(df)} molecules from {csv_path}")

    # Explicit None check: `if max_molecules:` would silently ignore an
    # explicit max_molecules=0 and process the whole dataset instead.
    if max_molecules is not None:
        df = df.head(max_molecules)
        print(f"Processing first {max_molecules} molecules")

    output_dir = Path(output_base_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Running statistics collected over the whole batch.
    successful = 0
    failed = 0
    failed_molecules = []  # one dict per failure, kept for the stats JSON
    all_fragments = []     # one dict per fragment, becomes all_fragments.csv

    for idx, row in tqdm(df.iterrows(), total=len(df), desc="Processing molecules"):
        smiles = row['smiles']
        # Fall back to a positional id when the CSV has no 'IDs' column.
        molecule_id = row.get('IDs', f'molecule_{idx}')
        try:
            mol_fragments = process_molecule(smiles, idx)
            if mol_fragments is None or len(mol_fragments.fragments) == 0:
                failed += 1
                failed_molecules.append({
                    'index': idx,
                    'id': molecule_id,
                    'reason': 'No fragments extracted'
                })
                continue

            # One sub-directory per parent molecule.
            mol_output_dir = output_dir / mol_fragments.parent_id
            mol_output_dir.mkdir(parents=True, exist_ok=True)

            # Save the complete fragment set for this molecule...
            mol_fragments_path = mol_output_dir / f"{mol_fragments.parent_id}_all_fragments.json"
            mol_fragments.to_json_file(str(mol_fragments_path))

            # ...then each fragment individually, while collecting a flat
            # record for the aggregate CSV.
            for frag in mol_fragments.fragments:
                frag_path = mol_output_dir / f"{frag.fragment_id}.json"
                frag.to_json_file(str(frag_path))
                all_fragments.append({
                    'parent_id': frag.parent_id,
                    'fragment_id': frag.fragment_id,
                    'cleavage_position': frag.cleavage_position,
                    'fragment_smiles': frag.fragment_smiles,
                    'atom_count': frag.atom_count,
                    'molecular_weight': frag.molecular_weight,
                    'parent_smiles': frag.parent_smiles
                })
            successful += 1
        except Exception as e:
            # Best-effort batch semantics: record the failure and keep going
            # rather than aborting the whole run on one bad molecule.
            failed += 1
            failed_molecules.append({
                'index': idx,
                'id': molecule_id,
                'error': str(e)
            })
            print(f"\nError processing molecule {idx} ({molecule_id}): {e}")

    # Persist aggregate statistics for the run.
    stats = {
        'total_molecules': len(df),
        'successful': successful,
        'failed': failed,
        'total_fragments': len(all_fragments),
        'failed_molecules': failed_molecules
    }
    stats_path = output_dir / 'processing_stats.json'
    with open(stats_path, 'w') as f:
        json.dump(stats, f, indent=2)

    # Save all fragments as CSV for easy analysis.
    if all_fragments:
        fragments_df = pd.DataFrame(all_fragments)
        fragments_csv_path = output_dir / 'all_fragments.csv'
        fragments_df.to_csv(fragments_csv_path, index=False)
        print(f"\n✓ Saved all fragments to: {fragments_csv_path}")

    # Print summary.
    print(f"\n{'='*60}")
    print(f"PROCESSING COMPLETE")
    print(f"{'='*60}")
    print(f"Total molecules: {len(df)}")
    print(f"Successfully processed: {successful}")
    print(f"Failed: {failed}")
    print(f"Total fragments extracted: {len(all_fragments)}")
    print(f"{'='*60}")
    print(f"\nResults saved to: {output_dir}")
    print(f"Statistics saved to: {stats_path}")

    return fragments_df if all_fragments else None
if __name__ == "__main__":
    import argparse

    # Command-line entry point: wire CLI options straight into
    # batch_process_molecules().
    cli = argparse.ArgumentParser(
        description="Batch process macrolactones to extract side chain fragments"
    )
    cli.add_argument('--input', type=str, default='../ring16/temp.csv',
                     help='Input CSV file path')
    cli.add_argument('--output', type=str, default='../output/fragments',
                     help='Output directory path')
    cli.add_argument('--max', type=int, default=None,
                     help='Maximum number of molecules to process (default: all)')

    opts = cli.parse_args()
    batch_process_molecules(csv_path=opts.input,
                            output_base_dir=opts.output,
                            max_molecules=opts.max)