This commit is contained in:
mm644706215
2025-10-16 17:26:35 +08:00
parent b1d437a06d
commit ea218a3a39
49 changed files with 694742 additions and 2 deletions

159
utils/smiles_svg_show.py Normal file
View File

@@ -0,0 +1,159 @@
import argparse
import json
import re
from datetime import datetime
from pathlib import Path
from rdkit import Chem
from rdkit.Chem.Draw import rdMolDraw2D
import boto3
# Object storage (S3-compatible) settings; edit these as needed.
# The "{...}" values are placeholders and have to be replaced with real
# settings before an upload can succeed.
BUCKET_NAME = "{Your_Bucket_Name}"
ACCESS_KEY = "{Your_Access_Key}"
SECRET_KEY = "{Your_Secret_Key}"
ENDPOINT_URL = "{Your_Endpoint_Url}"
# Key prefix placed in front of the SVG file name to form the object name.
S3_SVG_PREFIX = "svg_outputs/"
# Render a molecule to SVG, optionally highlighting atoms
def mol_to_svg(mol, highlight_atoms=None, size=(400, 400)):
    """Draw ``mol`` as an SVG document and return the SVG text.

    Atom indices are shown on the drawing, and every atom listed in
    ``highlight_atoms`` is highlighted with the colour ``(1, 0, 0)`` (red).

    Args:
        mol: Molecule passed to ``MolDraw2DSVG.DrawMolecule``.
        highlight_atoms: Atom indices to highlight. ``None`` or an empty
            collection means nothing is highlighted.
        size: ``(width, height)`` of the SVG canvas.

    Returns:
        The drawing text produced by the SVG drawer.
    """
    width, height = size[0], size[1]
    canvas = rdMolDraw2D.MolDraw2DSVG(width, height)
    canvas.SetFontSize(6)
    canvas.drawOptions().addAtomIndices = True

    # Falsy input (None, empty list) collapses to "no highlighted atoms".
    selected = highlight_atoms or []
    colour_by_atom = {atom_idx: (1, 0, 0) for atom_idx in selected}

    canvas.DrawMolecule(
        mol,
        highlightAtoms=selected,
        highlightAtomColors=colour_by_atom,
    )
    canvas.FinishDrawing()
    return canvas.GetDrawingText()
# Upload to S3-compatible object storage and return the public URL
def upload_svg_to_s3(
    svg_content,
    object_name,
    public_base_url="https://pub-389f446a01134875b8c7ced0572758de.r2.dev",
):
    """Store ``svg_content`` in the configured bucket and return its URL.

    The connection uses the module-level ``ACCESS_KEY``, ``SECRET_KEY``,
    ``ENDPOINT_URL`` and ``BUCKET_NAME`` settings. The object is written
    with the content type ``image/svg+xml``.

    Args:
        svg_content: SVG document to upload.
        object_name: Key of the object inside the bucket.
        public_base_url: Base of the public URL under which the bucket is
            served. The default keeps the previously hard-coded R2.dev
            address, so existing callers get the same URL as before.

    Returns:
        ``public_base_url`` joined with ``object_name``.
    """
    session = boto3.session.Session(
        aws_access_key_id=ACCESS_KEY,
        aws_secret_access_key=SECRET_KEY,
    )
    s3 = session.resource('s3', endpoint_url=ENDPOINT_URL)
    s3.Object(BUCKET_NAME, object_name).put(
        Body=svg_content,
        ContentType='image/svg+xml',
    )
    # Strip a trailing slash so a caller-supplied base never yields "//".
    return f"{public_base_url.rstrip('/')}/{object_name}"
# Detect atom valence errors
def find_valence_error_atom(mol):
    """Sanitize ``mol`` in place and report the atom that breaks valence rules.

    Args:
        mol: Molecule passed to ``Chem.SanitizeMol`` (modified in place).

    Returns:
        ``None`` when sanitization succeeds. When ``Chem.SanitizeMol`` raises
        ``Chem.AtomValenceException``, the index of the offending atom, or
        ``None`` if that index cannot be determined.

    Raises:
        Any exception from ``Chem.SanitizeMol`` other than
        ``Chem.AtomValenceException`` propagates unchanged.
    """
    try:
        Chem.SanitizeMol(mol)
    except Chem.AtomValenceException as exc:
        # Prefer the structured atom index carried by the exception over
        # parsing its human-readable message, whose wording may change.
        # NOTE(review): relies on ``exc.cause.GetAtomIdx()`` being exposed by
        # the installed RDKit version - confirm; the message regex below is
        # kept as a fallback for versions without it.
        try:
            return int(exc.cause.GetAtomIdx())
        except AttributeError:
            pass
        found = re.search(r'atom # (\d+)', str(exc))
        return int(found.group(1)) if found else None
    return None
# JSON save/load helpers
def save_json(data, filename):
    """Write ``data`` to ``filename`` as pretty-printed UTF-8 JSON.

    Non-ASCII characters are written verbatim (``ensure_ascii=False``) with
    an indent of two spaces. Missing parent directories of ``filename`` are
    created first, so writing into a new output folder does not fail.

    Args:
        data: JSON-serializable object.
        filename: Destination path (``str`` or ``Path``).
    """
    # Serialize before touching the filesystem so that unserializable data
    # does not leave empty directories behind.
    payload = json.dumps(data, ensure_ascii=False, indent=2)
    target = Path(filename)
    target.parent.mkdir(parents=True, exist_ok=True)
    target.write_text(payload, encoding='utf-8')
def load_json(filename):
    """Read ``filename`` as UTF-8 text and return the decoded JSON value.

    Args:
        filename: Path of the JSON file (``str`` or ``Path``).

    Returns:
        The Python object produced by decoding the file's JSON content.
    """
    with Path(filename).open(encoding='utf-8') as handle:
        return json.load(handle)
# Collect detailed state information for one atom
def get_atom_status(mol, atom_idx):
    """Describe the atom at ``atom_idx`` and the bonds attached to it.

    Args:
        mol: Molecule that owns the atom.
        atom_idx: Index of the atom to inspect.

    Returns:
        A dict with the atom's degree (``explicit_connections``), formal
        charge, radical electron count, implicit and explicit hydrogen
        counts, and ``connections_detail``: one entry per bond giving the
        neighbouring atom (index and element symbol) and the bond type.
    """
    target = mol.GetAtomWithIdx(atom_idx)
    # NOTE(review): presumably strict=False keeps the cache update from
    # raising on molecules that failed sanitization - confirm with RDKit docs.
    mol.UpdatePropertyCache(strict=False)

    def describe(bond):
        # Summarize one bond from the point of view of the inspected atom.
        other_idx = bond.GetOtherAtomIdx(atom_idx)
        symbol = mol.GetAtomWithIdx(other_idx).GetSymbol()
        return {
            "connected_to": f"#{other_idx} ({symbol})",
            "bond_type": str(bond.GetBondType()),
        }

    bond_details = [describe(bond) for bond in target.GetBonds()]

    status = {
        "explicit_connections": target.GetDegree(),
        "formal_charge": target.GetFormalCharge(),
        "radical_electrons": target.GetNumRadicalElectrons(),
        "implicit_hydrogens": target.GetNumImplicitHs(),
        "explicit_hydrogens": target.GetNumExplicitHs(),
        "connections_detail": bond_details,
    }
    return status
# Main program
def main():
    """Run the command-line tool.

    Parses ``--smiles`` without sanitization, records whether
    ``find_valence_error_atom`` reports a valence problem, renders the
    molecule to SVG with the atoms selected by ``--atom_idx`` and/or
    ``--smarts`` highlighted, stores the SVG (locally with ``--no_s3``,
    otherwise through ``upload_svg_to_s3``) and writes a JSON summary to
    ``--output``.

    An unparsable SMILES or SMARTS string and an out-of-range ``--atom_idx``
    end the program through ``parser.error`` instead of failing later with
    an unrelated traceback.
    """
    parser = argparse.ArgumentParser(description="Process SMILES and optionally highlight atoms using atom index or SMARTS pattern.")
    parser.add_argument('--smiles', type=str, required=True, help='SMILES string of molecule')
    parser.add_argument('--atom_idx', type=int, help='Atom index to highlight')
    parser.add_argument('--smarts', type=str, help='SMARTS pattern to highlight matched atoms')
    parser.add_argument('--output', type=str, default="output.json", help='Output JSON filename')
    parser.add_argument('--no_s3', action='store_true', help='Save SVG locally instead of S3')
    args = parser.parse_args()

    # Parse without sanitization so that molecules with valence errors still
    # load; find_valence_error_atom performs the sanitization afterwards.
    mol = Chem.MolFromSmiles(args.smiles, sanitize=False)
    if mol is None:
        parser.error(f"Could not parse SMILES: {args.smiles!r}")

    if args.atom_idx is not None:
        num_atoms = mol.GetNumAtoms()
        if not 0 <= args.atom_idx < num_atoms:
            parser.error(
                f"--atom_idx {args.atom_idx} is out of range for a molecule "
                f"with {num_atoms} atoms"
            )

    patt = None
    if args.smarts:
        patt = Chem.MolFromSmarts(args.smarts)
        if patt is None:
            parser.error(f"Could not parse SMARTS: {args.smarts!r}")

    error_atom_idx = find_valence_error_atom(mol)
    atom_state_info = "OK" if error_atom_idx is None else f"Valence error at atom #{error_atom_idx}"

    highlight_atoms = set()
    if args.atom_idx is not None:
        highlight_atoms.add(args.atom_idx)
    if patt is not None:
        for match in mol.GetSubstructMatches(patt):
            highlight_atoms.update(match)

    svg_str = mol_to_svg(mol, highlight_atoms=sorted(highlight_atoms))

    timestamp = datetime.now().strftime('%Y%m%d%H%M%S')
    svg_filename = f"molecule_{timestamp}.svg"

    output_path = Path(args.output)
    if not output_path.is_absolute():
        output_path = Path.cwd() / output_path
    # Both the local SVG and the JSON summary are written into this
    # directory, so make sure it exists before writing anything.
    output_path.parent.mkdir(parents=True, exist_ok=True)

    if args.no_s3:
        svg_path = output_path.parent / svg_filename
        svg_path.write_text(svg_str, encoding='utf-8')
        svg_location = str(svg_path)
    else:
        object_name = f"{S3_SVG_PREFIX}{svg_filename}"
        svg_location = upload_svg_to_s3(svg_str, object_name)

    output_data = {
        "atom_state": atom_state_info,
        "svg_url": svg_location,
        "svg_filename": svg_filename,
    }
    if args.atom_idx is not None:
        output_data["atom_status_detail"] = get_atom_status(mol, args.atom_idx)

    save_json(output_data, output_path)
    print(f"Results saved to {output_path}")
# Run the command-line tool only when this file is executed as a script.
if __name__ == "__main__":
    main()
"""
# Diagnose bond/valence errors (reports the offending atom; highlights atom 30)
python smiles_svg_show.py --smiles "O=C1C[C@@H](O)C[C@H](O[C@H]9C[C@@](C)(OC)[C@@H](O)[C@H](C)O9)[C@@H](C)C[C@@H](C)C(=O)/C=C/[C@@H](CC)=C/[C@H](O[C@@H]9O[C@H](C)C[C@@H]([C@H]9O)N(C)C)[N@@](C)O1" --atom_idx 30
python smiles_svg_show.py --smiles "CCC1=C\[C@H](O[C@H]2C[C@@](C)(OC)[C@@H](O)[C@H](C)O2)[C@@H](CC=O)OC(=O)C[C@@H](O)C[C@H](O[C@H]2C[C@@](C)(OC)[C@@H](O)[C@H](C)O2)[C@@H](C)C[C@@H](C)C(=O)\C=C\1" --atom_idx 30
# SMARTS matching requires a valid (sanitizable) SMILES
python smiles_svg_show.py --smiles "CCC1=C\[C@H](O[C@H]2C[C@@](C)(OC)[C@@H](O)[C@H](C)O2)[C@@H](CC=O)OC(=O)C[C@@H](O)C[C@H](O[C@H]2C[C@@](C)(OC)[C@@H](O)[C@H](C)O2)[C@@H](C)C[C@@H](C)C(=O)\C=C\1" --smarts "[r16]([#8][#6](=[#8]))"
"""