# Source listing: SIME/utils/smiles_svg_show.py
# Last commit: ea218a3a39 "update" by mm644706215, 2025-10-16 17:26:35 +08:00
# 159 lines, 5.6 KiB, Python
#
# Repository-viewer note: this file contains ambiguous Unicode characters that might be
# confused with other characters. If that is intentional, the warning can be ignored.
import argparse
import json
import os
import re
from datetime import datetime
from pathlib import Path

import boto3
from rdkit import Chem
from rdkit.Chem.Draw import rdMolDraw2D
# Object-storage (S3-compatible) configuration.
# Every value can be overridden through an environment variable, so real credentials
# never need to be edited into (and committed with) this file. The brace placeholders
# remain as fall-back defaults.
BUCKET_NAME = os.environ.get("S3_BUCKET_NAME", "{Your_Bucket_Name}")
ACCESS_KEY = os.environ.get("S3_ACCESS_KEY", "{Your_Access_Key}")
SECRET_KEY = os.environ.get("S3_SECRET_KEY", "{Your_Secret_Key}")
ENDPOINT_URL = os.environ.get("S3_ENDPOINT_URL", "{Your_Endpoint_Url}")
# Key prefix ("folder") prepended to generated SVG object names.
S3_SVG_PREFIX = os.environ.get("S3_SVG_PREFIX", "svg_outputs/")
# Render a molecule as SVG text, labelling every atom with its index and optionally
# highlighting a set of atoms.
def mol_to_svg(mol, highlight_atoms=None, size=(400, 400), *, font_size=6,
               highlight_color=(1, 0, 0)):
    """Draw *mol* with RDKit and return the SVG document as a string.

    Args:
        mol: RDKit molecule to draw.
        highlight_atoms: Optional iterable of atom indices to highlight. Any
            iterable (list, set, tuple, generator) is accepted; duplicates are
            ignored.
        size: ``(width, height)`` of the drawing canvas.
        font_size: Value passed to ``MolDraw2D.SetFontSize`` (default 6).
        highlight_color: RGB tuple used for every highlighted atom
            (default ``(1, 0, 0)``, i.e. red).

    Returns:
        The SVG text produced by ``MolDraw2DSVG.GetDrawingText()``.
    """
    drawer = rdMolDraw2D.MolDraw2DSVG(size[0], size[1])
    drawer.SetFontSize(font_size)
    opts = drawer.drawOptions()
    # Show atom indices so the picture can be matched against --atom_idx and the
    # JSON report.
    opts.addAtomIndices = True

    # Materialise the indices exactly once: iterating the argument twice would
    # exhaust a generator before it reaches DrawMolecule. Sorting keeps the output
    # deterministic for set inputs.
    atoms = sorted({int(idx) for idx in highlight_atoms}) if highlight_atoms else []
    color = tuple(highlight_color)
    atom_colors = {idx: color for idx in atoms}

    drawer.DrawMolecule(
        mol,
        highlightAtoms=atoms,
        highlightAtomColors=atom_colors,
    )
    drawer.FinishDrawing()
    return drawer.GetDrawingText()
# Upload an SVG to S3-compatible object storage and return its public URL.
def upload_svg_to_s3(svg_content, object_name,
                     public_base_url="https://pub-389f446a01134875b8c7ced0572758de.r2.dev"):
    """Store *svg_content* under key *object_name* in ``BUCKET_NAME`` and return its URL.

    Credentials and endpoint come from the module-level ``ACCESS_KEY``,
    ``SECRET_KEY`` and ``ENDPOINT_URL`` settings.

    Args:
        svg_content: SVG document as ``str`` (encoded to UTF-8 here) or ``bytes``.
        object_name: Object key inside the bucket, e.g. ``"svg_outputs/x.svg"``.
        public_base_url: Public base URL under which the bucket's objects are
            served. Defaults to the R2.dev URL this function always used.

    Returns:
        ``"<public_base_url>/<object_name>"``. A trailing slash on
        *public_base_url* is tolerated.
    """
    session = boto3.session.Session(
        aws_access_key_id=ACCESS_KEY,
        aws_secret_access_key=SECRET_KEY,
    )
    s3 = session.resource('s3', endpoint_url=ENDPOINT_URL)
    # Encode explicitly so the stored bytes are UTF-8 regardless of how the SDK
    # treats str bodies.
    body = svg_content.encode('utf-8') if isinstance(svg_content, str) else svg_content
    s3.Object(BUCKET_NAME, object_name).put(Body=body, ContentType='image/svg+xml')
    # NOTE(review): the public URL is not derived from BUCKET_NAME; confirm that
    # public_base_url really serves the configured bucket.
    return f"{public_base_url.rstrip('/')}/{object_name}"
# Detect atom valence errors by attempting a full RDKit sanitization.
def find_valence_error_atom(mol):
    """Sanitize *mol* (in place) and return the index of an atom with a valence error.

    Returns:
        ``None`` when ``Chem.SanitizeMol`` succeeds. On
        ``Chem.AtomValenceException``, the index parsed from the ``"atom # N"``
        fragment of the exception message, or ``None`` if the message has no
        such fragment.

    Raises:
        Any other exception raised by ``Chem.SanitizeMol`` (e.g. kekulization
        problems) propagates unchanged.
    """
    try:
        Chem.SanitizeMol(mol)
    except Chem.AtomValenceException as exc:
        found = re.search(r'atom # (\d+)', str(exc))
        return int(found.group(1)) if found else None
    return None
# Helpers for saving and loading JSON files.
def save_json(data, filename):
    """Serialize *data* as JSON (2-space indent) and write it to *filename* in UTF-8.

    Non-ASCII characters are written literally rather than escaped. Any existing
    file at *filename* is overwritten.
    """
    text = json.dumps(data, indent=2, ensure_ascii=False)
    target = Path(filename)
    target.write_text(text, encoding='utf-8')
def load_json(filename):
    """Read *filename* as UTF-8 text and return the decoded JSON value."""
    with open(filename, encoding='utf-8') as handle:
        return json.load(handle)
# Collect diagnostic details about one atom: charge, radicals, hydrogens and bonds.
def get_atom_status(mol, atom_idx):
    """Summarise the state of atom *atom_idx* in *mol* as a JSON-friendly dict.

    Side effect: refreshes *mol*'s property cache non-strictly before reading the
    hydrogen counts.

    Returns:
        A dict with, in this order:
        ``explicit_connections`` (atom degree), ``formal_charge``,
        ``radical_electrons``, ``implicit_hydrogens``, ``explicit_hydrogens`` and
        ``connections_detail``. The last is a list with one
        ``{"connected_to": "#<idx> (<symbol>)", "bond_type": <str of bond type>}``
        entry per bond of the atom.
    """
    atom = mol.GetAtomWithIdx(atom_idx)
    # strict=False so the cache update does not reject bad valences; main() calls
    # this even when sanitization of the molecule failed.
    mol.UpdatePropertyCache(strict=False)

    def _describe_bond(bond):
        neighbour = bond.GetOtherAtomIdx(atom_idx)
        symbol = mol.GetAtomWithIdx(neighbour).GetSymbol()
        return {
            "connected_to": f"#{neighbour} ({symbol})",
            "bond_type": str(bond.GetBondType()),
        }

    details = [_describe_bond(bond) for bond in atom.GetBonds()]
    return dict(
        explicit_connections=atom.GetDegree(),
        formal_charge=atom.GetFormalCharge(),
        radical_electrons=atom.GetNumRadicalElectrons(),
        implicit_hydrogens=atom.GetNumImplicitHs(),
        explicit_hydrogens=atom.GetNumExplicitHs(),
        connections_detail=details,
    )
# Command-line entry point.
def main():
    """Diagnose a SMILES string, render it to SVG and write a JSON report.

    Steps:
        1. Parse the CLI arguments.
        2. Parse the SMILES (unsanitized) and check it for a valence error.
        3. Highlight the ``--atom_idx`` atom and/or every ``--smarts`` match.
        4. Write the SVG next to the report (``--no_s3``) or upload it with
           ``upload_svg_to_s3``.
        5. Save the JSON report to ``--output``. Relative paths resolve against
           the current working directory.

    Invalid input is reported through ``parser.error``, which prints the usage
    message and exits with status 2 instead of showing a traceback. This covers:
        - an unparsable SMILES or SMARTS;
        - an out-of-range ``--atom_idx``;
        - a SMARTS match that fails on the molecule.
    """
    parser = argparse.ArgumentParser(description="Process SMILES and optionally highlight atoms using atom index or SMARTS pattern.")
    parser.add_argument('--smiles', type=str, required=True, help='SMILES string of molecule')
    parser.add_argument('--atom_idx', type=int, help='Atom index to highlight')
    parser.add_argument('--smarts', type=str, help='SMARTS pattern to highlight matched atoms')
    parser.add_argument('--output', type=str, default="output.json", help='Output JSON filename')
    parser.add_argument('--no_s3', action='store_true', help='Save SVG locally instead of S3')
    args = parser.parse_args()

    # Parse without sanitizing so molecules with valence problems still load;
    # find_valence_error_atom() performs (and reports on) sanitization below.
    mol = Chem.MolFromSmiles(args.smiles, sanitize=False)
    if mol is None:
        parser.error(f"could not parse SMILES: {args.smiles!r}")

    num_atoms = mol.GetNumAtoms()
    if args.atom_idx is not None and not 0 <= args.atom_idx < num_atoms:
        parser.error(
            f"--atom_idx {args.atom_idx} is out of range for a molecule "
            f"with {num_atoms} atom(s)"
        )

    error_atom_idx = find_valence_error_atom(mol)
    atom_state_info = "OK" if error_atom_idx is None else f"Valence error at atom #{error_atom_idx}"

    highlight_atoms = set()
    if args.atom_idx is not None:
        highlight_atoms.add(args.atom_idx)
    if args.smarts:
        patt = Chem.MolFromSmarts(args.smarts)
        if patt is None:
            parser.error(f"could not parse SMARTS: {args.smarts!r}")
        try:
            matches = mol.GetSubstructMatches(patt)
        except RuntimeError as exc:
            # Presumably raised when the molecule did not sanitize (e.g. ring
            # queries such as [r16] need ring information) — surface it as a
            # usage error rather than a traceback.
            parser.error(
                f"SMARTS matching failed ({exc}); SMARTS matching requires a "
                f"SMILES that sanitizes cleanly"
            )
        for match in matches:
            highlight_atoms.update(match)

    svg_str = mol_to_svg(mol, highlight_atoms=sorted(highlight_atoms))
    # NOTE: second resolution — two runs within the same second produce the same
    # file/object name and the later one overwrites the earlier.
    timestamp = datetime.now().strftime('%Y%m%d%H%M%S')
    svg_filename = f"molecule_{timestamp}.svg"

    output_path = Path(args.output)
    if not output_path.is_absolute():
        output_path = Path.cwd() / output_path

    if args.no_s3:
        # Local mode: the SVG is written next to the JSON report.
        svg_path = output_path.parent / svg_filename
        svg_path.write_text(svg_str, encoding='utf-8')
        svg_location = str(svg_path)
    else:
        object_name = f"{S3_SVG_PREFIX}{svg_filename}"
        svg_location = upload_svg_to_s3(svg_str, object_name)

    output_data = {
        "atom_state": atom_state_info,
        "svg_url": svg_location,
        "svg_filename": svg_filename,
    }
    if args.atom_idx is not None:
        output_data["atom_status_detail"] = get_atom_status(mol, args.atom_idx)
    save_json(output_data, output_path)
    print(f"Results saved to {output_path}")


if __name__ == "__main__":
    main()
"""
# 自动修复键值错误
python smiles_svg_show.py --smiles "O=C1C[C@@H](O)C[C@H](O[C@H]9C[C@@](C)(OC)[C@@H](O)[C@H](C)O9)[C@@H](C)C[C@@H](C)C(=O)/C=C/[C@@H](CC)=C/[C@H](O[C@@H]9O[C@H](C)C[C@@H]([C@H]9O)N(C)C)[N@@](C)O1" --atom_idx 30
python smiles_svg_show.py --smiles "CCC1=C\[C@H](O[C@H]2C[C@@](C)(OC)[C@@H](O)[C@H](C)O2)[C@@H](CC=O)OC(=O)C[C@@H](O)C[C@H](O[C@H]2C[C@@](C)(OC)[C@@H](O)[C@H](C)O2)[C@@H](C)C[C@@H](C)C(=O)\C=C\1" --atom_idx 30
# smarts 匹配要求smiles正确
python smiles_svg_show.py --smiles "CCC1=C\[C@H](O[C@H]2C[C@@](C)(OC)[C@@H](O)[C@H](C)O2)[C@@H](CC=O)OC(=O)C[C@@H](O)C[C@H](O[C@H]2C[C@@](C)(OC)[C@@H](O)[C@H](C)O2)[C@@H](C)C[C@@H](C)C(=O)\C=C\1" --smarts "[r16]([#8][#6](=[#8]))"
"""