159 lines
5.6 KiB
Python
159 lines
5.6 KiB
Python
import argparse
|
||
import json
|
||
import re
|
||
from datetime import datetime
|
||
from pathlib import Path
|
||
from rdkit import Chem
|
||
from rdkit.Chem.Draw import rdMolDraw2D
|
||
import boto3
|
||
|
||
# 对象存储配置信息(可随时修改)
|
||
# S3-compatible object storage settings. The values below are placeholders
# and must be replaced with real ones before uploading.
BUCKET_NAME = "{Your_Bucket_Name}"
ACCESS_KEY = "{Your_Access_Key}"
SECRET_KEY = "{Your_Secret_Key}"
ENDPOINT_URL = "{Your_Endpoint_Url}"
# Key prefix prepended to the SVG filename to form the uploaded object name.
S3_SVG_PREFIX = "svg_outputs/"
|
||
|
||
# Render the molecule as an SVG image with highlighted atoms
|
||
def mol_to_svg(mol, highlight_atoms=None, size=(400, 400)):
    """Draw *mol* as an SVG document and return the SVG text.

    Args:
        mol: Molecule object handed to the RDKit 2D drawer.
        highlight_atoms: Optional collection of atom indices to highlight;
            each one is coloured red. ``None`` or an empty collection
            disables highlighting.
        size: ``(width, height)`` of the canvas; only the first two items
            are read.

    Returns:
        The drawing text produced by the SVG drawer.
    """
    width, height = size[0], size[1]
    canvas = rdMolDraw2D.MolDraw2DSVG(width, height)
    canvas.SetFontSize(6)
    # Label every atom with its index so highlighted atoms can be identified.
    canvas.drawOptions().addAtomIndices = True

    red = (1, 0, 0)
    atom_colors = {idx: red for idx in highlight_atoms} if highlight_atoms else {}

    canvas.DrawMolecule(
        mol,
        highlightAtoms=highlight_atoms or [],
        highlightAtomColors=atom_colors,
    )
    canvas.FinishDrawing()
    return canvas.GetDrawingText()
|
||
|
||
# Upload to object storage (S3-compatible)
|
||
# Replaces the return value of the original upload_svg_to_s3
|
||
def upload_svg_to_s3(svg_content, object_name):
    """Upload SVG content to the configured bucket and return its public URL.

    Args:
        svg_content: Body of the object to store; sent with content type
            ``image/svg+xml``.
        object_name: Key of the object inside ``BUCKET_NAME``.

    Returns:
        The object name appended to a fixed r2.dev base URL. The base URL is
        hard-coded here and is not derived from ``ENDPOINT_URL`` or
        ``BUCKET_NAME``.
    """
    credentials = {
        "aws_access_key_id": ACCESS_KEY,
        "aws_secret_access_key": SECRET_KEY,
    }
    storage = boto3.session.Session(**credentials).resource(
        's3', endpoint_url=ENDPOINT_URL
    )
    target = storage.Object(BUCKET_NAME, object_name)
    target.put(Body=svg_content, ContentType='image/svg+xml')

    # Return the public R2.dev URL of the uploaded object.
    return f"https://pub-389f446a01134875b8c7ced0572758de.r2.dev/{object_name}"
|
||
|
||
# Detect atom valence errors
|
||
def find_valence_error_atom(mol):
    """Sanitize *mol* and report the atom named in a valence error, if any.

    The molecule is sanitized in place. Only ``Chem.AtomValenceException``
    is handled; any other exception raised by sanitization propagates.

    Returns:
        The atom index parsed from the exception message (the number after
        ``"atom # "``), or ``None`` when sanitization succeeds or the
        message contains no such index.
    """
    try:
        Chem.SanitizeMol(mol)
    except Chem.AtomValenceException as exc:
        # The index is recovered from the exception text.
        found = re.search(r'atom # (\d+)', str(exc))
        return int(found.group(1)) if found else None
    else:
        return None
|
||
|
||
# Helpers for saving and loading JSON
|
||
def save_json(data, filename):
    """Serialize *data* as indented UTF-8 JSON and write it to *filename*.

    Non-ASCII characters are written as-is rather than escaped. The data is
    serialized before the file is opened, so a serialization failure leaves
    any existing file untouched.
    """
    target = Path(filename)
    payload = json.dumps(data, ensure_ascii=False, indent=2)
    with target.open('w', encoding='utf-8') as handle:
        handle.write(payload)
|
||
|
||
def load_json(filename):
    """Read *filename* as UTF-8 text and return the parsed JSON value."""
    raw_text = Path(filename).read_text(encoding='utf-8')
    return json.loads(raw_text)
|
||
|
||
# Collect detailed status information for an atom
|
||
def get_atom_status(mol, atom_idx):
    """Summarize the bonding state of the atom at *atom_idx* in *mol*.

    The molecule's property cache is refreshed non-strictly before any
    hydrogen counts are read.

    Returns:
        A dict with the atom's degree, formal charge, radical electron count,
        implicit and explicit hydrogen counts, and one entry per bond giving
        the neighbouring atom (index and symbol) and the bond type.
    """
    atom = mol.GetAtomWithIdx(atom_idx)
    mol.UpdatePropertyCache(strict=False)

    def describe(bond):
        # Identify the atom at the other end of the bond.
        partner_idx = bond.GetOtherAtomIdx(atom_idx)
        partner_symbol = mol.GetAtomWithIdx(partner_idx).GetSymbol()
        return {
            "connected_to": f"#{partner_idx} ({partner_symbol})",
            "bond_type": str(bond.GetBondType()),
        }

    bond_details = [describe(bond) for bond in atom.GetBonds()]

    status = {}
    status["explicit_connections"] = atom.GetDegree()
    status["formal_charge"] = atom.GetFormalCharge()
    status["radical_electrons"] = atom.GetNumRadicalElectrons()
    status["implicit_hydrogens"] = atom.GetNumImplicitHs()
    status["explicit_hydrogens"] = atom.GetNumExplicitHs()
    status["connections_detail"] = bond_details
    return status
|
||
|
||
# Main program
|
||
def main():
    """Command-line entry point: render a SMILES molecule and write a JSON report.

    The molecule is parsed without sanitization, checked for a valence
    error, drawn to SVG with the requested atoms highlighted (a single
    ``--atom_idx`` and/or every atom matched by ``--smarts``), and the SVG is
    either uploaded to object storage or, with ``--no_s3``, written next to
    the output JSON file. The JSON report contains the valence check result,
    the SVG location and filename, and, when ``--atom_idx`` is given, the
    detailed status of that atom.

    Invalid input (unparsable SMILES or SMARTS, out-of-range atom index) is
    reported through ``parser.error``, which prints a usage message and
    exits with status 2.
    """
    parser = argparse.ArgumentParser(description="Process SMILES and optionally highlight atoms using atom index or SMARTS pattern.")
    parser.add_argument('--smiles', type=str, required=True, help='SMILES string of molecule')
    parser.add_argument('--atom_idx', type=int, help='Atom index to highlight')
    parser.add_argument('--smarts', type=str, help='SMARTS pattern to highlight matched atoms')
    parser.add_argument('--output', type=str, default="output.json", help='Output JSON filename')
    parser.add_argument('--no_s3', action='store_true', help='Save SVG locally instead of S3')

    args = parser.parse_args()

    # Parse without sanitization so that a molecule with a valence problem
    # can still be loaded and the offending atom located below.
    mol = Chem.MolFromSmiles(args.smiles, sanitize=False)
    if mol is None:
        parser.error(f"could not parse SMILES: {args.smiles!r}")

    # Reject indices that do not address an atom before they reach the
    # drawer or the atom lookup.
    if args.atom_idx is not None:
        num_atoms = mol.GetNumAtoms()
        if not 0 <= args.atom_idx < num_atoms:
            parser.error(
                f"--atom_idx {args.atom_idx} is out of range for a molecule "
                f"with {num_atoms} atoms"
            )

    patt = None
    if args.smarts:
        patt = Chem.MolFromSmarts(args.smarts)
        if patt is None:
            parser.error(f"could not parse SMARTS: {args.smarts!r}")

    error_atom_idx = find_valence_error_atom(mol)
    atom_state_info = "OK" if error_atom_idx is None else f"Valence error at atom #{error_atom_idx}"

    highlight_atoms = set()
    if args.atom_idx is not None:
        highlight_atoms.add(args.atom_idx)
    if patt is not None:
        for match in mol.GetSubstructMatches(patt):
            highlight_atoms.update(match)

    # Sorted so the drawer always receives the indices in the same order.
    svg_str = mol_to_svg(mol, highlight_atoms=sorted(highlight_atoms))

    timestamp = datetime.now().strftime('%Y%m%d%H%M%S')
    svg_filename = f"molecule_{timestamp}.svg"

    output_path = Path(args.output)
    if not output_path.is_absolute():
        output_path = Path.cwd() / output_path
    # Create the destination directory up front so a missing directory is
    # not discovered only after the SVG has already been uploaded.
    output_path.parent.mkdir(parents=True, exist_ok=True)

    if args.no_s3:
        svg_path = output_path.parent / svg_filename
        svg_path.write_text(svg_str, encoding='utf-8')
        svg_location = str(svg_path)
    else:
        object_name = f"{S3_SVG_PREFIX}{svg_filename}"
        svg_location = upload_svg_to_s3(svg_str, object_name)

    output_data = {
        "atom_state": atom_state_info,
        "svg_url": svg_location,
        "svg_filename": svg_filename,
    }

    if args.atom_idx is not None:
        output_data["atom_status_detail"] = get_atom_status(mol, args.atom_idx)

    save_json(output_data, output_path)

    print(f"Results saved to {output_path}")
|
||
|
||
if __name__ == "__main__":
    main()

# Usage examples. This is a raw string because the example SMILES contain
# backslashes (``\[``, ``\C``, ``\1``) that a normal string literal would
# treat as escape sequences.
r"""
# 自动修复键值错误
python smiles_svg_show.py --smiles "O=C1C[C@@H](O)C[C@H](O[C@H]9C[C@@](C)(OC)[C@@H](O)[C@H](C)O9)[C@@H](C)C[C@@H](C)C(=O)/C=C/[C@@H](CC)=C/[C@H](O[C@@H]9O[C@H](C)C[C@@H]([C@H]9O)N(C)C)[N@@](C)O1" --atom_idx 30

python smiles_svg_show.py --smiles "CCC1=C\[C@H](O[C@H]2C[C@@](C)(OC)[C@@H](O)[C@H](C)O2)[C@@H](CC=O)OC(=O)C[C@@H](O)C[C@H](O[C@H]2C[C@@](C)(OC)[C@@H](O)[C@H](C)O2)[C@@H](C)C[C@@H](C)C(=O)\C=C\1" --atom_idx 30
# smarts 匹配,要求smiles正确
python smiles_svg_show.py --smiles "CCC1=C\[C@H](O[C@H]2C[C@@](C)(OC)[C@@H](O)[C@H](C)O2)[C@@H](CC=O)OC(=O)C[C@@H](O)C[C@H](O[C@H]2C[C@@](C)(OC)[C@@H](O)[C@H](C)O2)[C@@H](C)C[C@@H](C)C(=O)\C=C\1" --smarts "[r16]([#8][#6](=[#8]))"
"""