Files
particle_analyse/convert_root_to_txt.py
2024-06-10 09:09:10 +08:00

280 lines
12 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import enum
import attrs
from typing import List, Optional, Union
import numpy as np
import uproot
from pathlib import Path
from tqdm import tqdm
import os
import joblib
import click
class ParticleType(enum.Enum):
    """Particle species attached to every jet constituent.

    The numeric value of a member is what ends up in the text output:
    ``ParticleBase.properties_concatenated`` replaces the member by ``.value``.
    """

    NeutralHadron = 0
    Photon = 1
    Electron = 2
    Muon = 3
    Pion = 4
    ChargedKaon = 5
    Proton = 6
    Others = 7  # fallback category for anything not covered above
@attrs.define
class ParticleBase:
    """Record of one jet constituent as produced by the ``JetSet`` builders.

    The logarithmic fields are filled with ``np.log`` (natural logarithm) by
    the builders in this file, not ``log10``.
    """

    particle_id: int = attrs.field()  # identifier assigned by the builder
    part_charge: int = attrs.field()  # charge of the particle
    part_energy: float = attrs.field()  # energy of the particle
    part_px: float = attrs.field()  # x-component of the momentum vector
    part_py: float = attrs.field()  # y-component of the momentum vector
    part_pz: float = attrs.field()  # z-component of the momentum vector
    log_energy: float = attrs.field()  # log(part_energy)
    log_pt: float = attrs.field()  # log(part_pt)
    part_deta: float = attrs.field()  # value of the `part_deta` branch (presumably eta relative to the jet - confirm)
    part_dphi: float = attrs.field()  # value of the `part_dphi` branch (presumably phi relative to the jet - confirm)
    part_logptrel: float = attrs.field()  # log(pt(particle) / pt(jet))
    part_logerel: float = attrs.field()  # log(energy(particle) / energy(jet))
    part_deltaR: float = attrs.field()  # hypot(part_deta, part_dphi)
    part_d0: float = attrs.field()  # tanh(part_d0val)
    part_dz: float = attrs.field()  # tanh(part_dzval)
    particle_type: ParticleType = attrs.field()  # type of the particle as an enum
    particle_pid: int = attrs.field()  # index of the winning type flag (see JetSet.jud_type)
    jet_type: str = attrs.field()  # human-readable jet kind ("b jet" / "bbbar jet")

    def properties_concatenated(self, selected_attrs: Optional[List[str]] = None, attr_sep: str = ',') -> str:
        """Serialise the chosen attributes into one separator-joined string.

        Args:
            selected_attrs: Attribute names to emit, in output order. ``None``
                selects every field except ``jet_type``.
            attr_sep: String placed between consecutive values.

        Returns:
            The ``str()`` of each selected value joined by ``attr_sep``;
            ``ParticleType`` members are written as their numeric value.

        Raises:
            AttributeError: If a requested name is not an attribute.
        """
        if selected_attrs is None:
            # attrs.fields() only accepts a class; passing the instance raises TypeError.
            selected_attrs = [field.name for field in attrs.fields(type(self)) if field.name != "jet_type"]
        values = (getattr(self, name) for name in selected_attrs)
        return attr_sep.join(
            str(value.value if isinstance(value, ParticleType) else value) for value in values
        )
@attrs.define
class Jet:
    """A single jet: its summary kinematics plus the particles it contains."""

    jet_energy: float = attrs.field()  # energy of the jet
    jet_pt: float = attrs.field()  # transverse momentum of the jet
    jet_eta: float = attrs.field()  # pseudorapidity of the jet
    label: str = attrs.field()  # label string supplied by the builder
    particles: List[ParticleBase] = attrs.field(factory=list)  # constituents of the jet

    def __len__(self):
        """Return the number of particles stored in the jet."""
        return len(self.particles)

    def particles_concatenated(self, selected_attrs: Optional[List[str]] = None, attr_sep: str = ',', part_sep: str = '|') -> str:
        """Serialise every particle and join the results with ``part_sep``.

        Each particle is rendered by ``properties_concatenated`` using
        ``selected_attrs`` and ``attr_sep``; an empty jet yields ``""``.
        """
        rendered = [
            particle.properties_concatenated(selected_attrs, attr_sep)
            for particle in self.particles
        ]
        return part_sep.join(rendered)
@attrs.define
class JetSet:
    """All jets read from one ROOT file, with builders and a text exporter."""

    jets_type: str = attrs.field()  # "bb" or "bbbar", derived from the file name
    jets: List[Jet] = attrs.field(factory=list)  # jets in the order they were read

    def __len__(self):
        """Return the number of jets in the set."""
        return len(self.jets)

    @staticmethod
    def jud_type(jtmp):
        """Pick the particle type from a sequence of per-type flags.

        Args:
            jtmp: Flags ordered as NeutralHadron, Photon, Electron, Muon,
                Pion, ChargedKaon, Proton.

        Returns:
            ``(ParticleType, index)`` where ``index`` is the position of the
            first largest flag. When no flag is set (the maximum is falsy) the
            result is ``(ParticleType.Others, 7)``; previously such particles
            were silently labelled NeutralHadron because index 0 won the tie.

        Raises:
            ValueError: If ``jtmp`` is empty.
        """
        flags = list(jtmp)
        max_element = max(flags)
        if not max_element:
            return ParticleType.Others, ParticleType.Others.value
        idx = flags.index(max_element)
        try:
            return ParticleType(idx), idx
        except ValueError:
            # More flags than known types: fall back to the catch-all member.
            return ParticleType.Others, idx

    @staticmethod
    def _describe_file(root_file):
        """Validate the file name and derive ``(jets_type, jet_type, label)``.

        Accepts both ``str`` and ``Path``. Raises ``ValueError`` when the file
        stem contains neither 'bb' nor 'bbbar'.
        """
        print(f"Building JetSet from {root_file}")
        stem = Path(root_file).stem
        # 'bbbar' contains 'bb', so a single membership test covers both names.
        if 'bb' not in stem:
            raise ValueError("Invalid file name, should contain 'bb' or 'bbbar'")
        is_bbbar = "bbbar" in stem
        jets_type = "bbbar" if is_bbbar else "bb"
        print(f"jet type: {jets_type}")
        jet_type = "bbbar jet" if is_bbbar else "b jet"
        label = stem.split('_')[0]
        return jets_type, jet_type, label

    @classmethod
    def _classify_particles(cls, flag_columns, n_particles):
        """Run ``jud_type`` on each particle; return ``(types, pids)`` lists."""
        part_type = []
        part_pid = []
        for pn in range(n_particles):
            tmp_type, tmp_pid = cls.jud_type([column[pn] for column in flag_columns])
            part_type.append(tmp_type)
            part_pid.append(tmp_pid)
        return part_type, part_pid

    @classmethod
    def build_jetset_fast(cls, root_file: Union[str, Path]) -> "JetSet":
        """Build a JetSet from a ROOT file, iterating rows with ``itertuples``.

        Args:
            root_file: Path to a ROOT file holding a ``tree`` object; the file
                stem must contain 'bb' or 'bbbar'.

        Returns:
            A JetSet with one ``Jet`` per row of the tree.

        Raises:
            ValueError: If the file name is invalid or the tree is empty.
        """
        jets_type, jet_type, label = cls._describe_file(root_file)
        flag_names = ('part_isNeutralHadron', 'part_isPhoton', 'part_isElectron', 'part_isMuon',
                      'part_isPion', 'part_isChargedKaon', 'part_isProton')
        with uproot.open(root_file) as f:
            a = f["tree"].arrays(library="pd")
        if a.empty:
            raise ValueError("DataFrame is empty")

        jet_list = []
        for j in a.itertuples():
            part_logptrel = np.log(np.divide(np.array(j.part_pt), np.array(j.jet_pt)))
            part_logerel = np.log(np.divide(np.array(j.part_energy), np.array(j.jet_energy)))
            part_deltaR = np.hypot(np.array(j.part_deta), np.array(j.part_dphi))
            assert len(j.part_pt) == len(j.part_energy) == len(j.part_deta)

            part_type, part_pid = cls._classify_particles(
                [getattr(j, name) for name in flag_names], len(j.part_pt)
            )
            bag = zip(j.part_charge, j.part_energy, j.part_px, j.part_py, j.part_pz, np.log(j.part_energy),
                      np.log(j.part_pt), j.part_deta, j.part_dphi, part_logptrel, part_logerel, part_deltaR,
                      np.tanh(j.part_d0val), np.tanh(j.part_dzval), part_type, part_pid)
            # Walk over every particle of this jet and store each one as a
            # ParticleBase, so `particles` holds the full content of the jet.
            particles = [
                ParticleBase(
                    particle_id=num,
                    part_charge=c,
                    part_energy=en,
                    part_px=px,
                    part_py=py,
                    part_pz=pz,
                    log_energy=lEn,
                    log_pt=lPt,
                    part_deta=eta,
                    part_dphi=phi,
                    part_logptrel=logptrel,
                    part_logerel=logerel,
                    part_deltaR=delta_r,
                    part_d0=d0,
                    part_dz=dz,
                    particle_type=ptype,
                    particle_pid=pid,
                    jet_type=jet_type,
                )
                for num, (c, en, px, py, pz, lEn, lPt, eta, phi, logptrel, logerel, delta_r,
                          d0, dz, ptype, pid) in enumerate(bag)
            ]
            jet_list.append(Jet(
                jet_energy=j.jet_energy,
                jet_pt=j.jet_pt,
                jet_eta=j.jet_eta,
                particles=particles,
                label=label,
            ))
        return cls(jets=jet_list, jets_type=jets_type)

    @classmethod
    def build_jetset_full(cls, root_file: Union[str, Path]) -> "JetSet":
        """Build a JetSet from a ROOT file, iterating rows with ``iterrows``.

        Same output layout as ``build_jetset_fast``; ``particle_id`` is the
        index of the particle inside its jet (it used to be the jet row index,
        which gave every particle of a jet the same id). Unlike the fast
        builder, an empty tree yields an empty JetSet instead of an error.

        Args:
            root_file: Path to a ROOT file holding a ``tree`` object; the file
                stem must contain 'bb' or 'bbbar'.

        Raises:
            ValueError: If the file name is invalid.
        """
        jets_type, jet_type, label = cls._describe_file(root_file)
        particle_list = ['part_isNeutralHadron', 'part_isPhoton', 'part_isElectron', 'part_isMuon',
                         'part_isPion', 'part_isChargedKaon', 'part_isProton']
        with uproot.open(root_file) as f:
            a = f["tree"].arrays(library="pd")

        jet_list = []
        for _, j in a.iterrows():
            part_logptrel = np.log(np.divide(np.array(j['part_pt']), np.array(j['jet_pt'])))
            part_logerel = np.log(np.divide(np.array(j['part_energy']), np.array(j['jet_energy'])))
            part_deltaR = np.hypot(np.array(j['part_deta']), np.array(j['part_dphi']))
            assert len(j['part_pt']) == len(j['part_energy']) == len(j['part_deta'])

            part_type, part_pid = cls._classify_particles(
                [j[name] for name in particle_list], len(j['part_pt'])
            )
            particles = [
                ParticleBase(
                    particle_id=pn,
                    part_charge=j['part_charge'][pn],
                    part_energy=j['part_energy'][pn],
                    part_px=j['part_px'][pn],
                    part_py=j['part_py'][pn],
                    part_pz=j['part_pz'][pn],
                    log_energy=np.log(j['part_energy'][pn]),
                    log_pt=np.log(j['part_pt'][pn]),
                    part_deta=j['part_deta'][pn],
                    part_dphi=j['part_dphi'][pn],
                    part_logptrel=part_logptrel[pn],
                    part_logerel=part_logerel[pn],
                    part_deltaR=part_deltaR[pn],
                    part_d0=np.tanh(j['part_d0val'][pn]),
                    part_dz=np.tanh(j['part_dzval'][pn]),
                    particle_type=part_type[pn],
                    particle_pid=part_pid[pn],
                    jet_type=jet_type,
                )
                for pn in range(len(part_type))
            ]
            jet_list.append(Jet(
                jet_energy=j['jet_energy'],
                jet_pt=j['jet_pt'],
                jet_eta=j['jet_eta'],
                particles=particles,
                label=label,
            ))
        return cls(jets=jet_list, jets_type=jets_type)

    def save_to_txt(self, save_dir, selected_attrs: Optional[List[str]] = None, attr_sep: str = ',', part_sep: str = '|', prefix: str = ""):
        """Write one text file per jet into ``save_dir``.

        Files are named ``{jets_type}_{index}_{prefix}.txt`` and contain the
        output of ``Jet.particles_concatenated``. ``save_dir`` must already
        exist; existing files with the same name are overwritten.
        """
        for file_counter, jet in enumerate(tqdm(self.jets)):
            filename = f"{self.jets_type}_{file_counter}_{prefix}.txt"
            filepath = os.path.join(save_dir, filename)
            jet_data = jet.particles_concatenated(selected_attrs, attr_sep, part_sep)
            with open(filepath, 'w') as file:
                file.write(jet_data)
# Command-line entry point: converts the ROOT files found in ROOT_DIR into
# per-jet text files under SAVE_DIR. Deliberately documented with comments
# rather than a docstring, because click would show a docstring in `--help`.
@click.command()
@click.argument('root_dir', type=click.Path(exists=True))
@click.argument('save_dir', type=click.Path())
@click.option('--data-type', default='fast', type=click.Choice(['fast', 'full']), help='Type of ROOT data: fast or full.')
@click.option('--attr-sep', default=',', help='Separator for attributes.')
@click.option('--part-sep', default='|', help='Separator for particles.')
@click.option('--selected-attrs', default='particle_id,part_charge,part_energy,part_px,part_py,part_pz,log_energy,log_pt,part_deta,part_dphi,part_logptrel,part_logerel,part_deltaR,part_d0,part_dz,particle_type,particle_pid', help='Comma-separated list of selected attributes.')
def main(root_dir, save_dir, data_type, attr_sep, part_sep, selected_attrs):
    # The option arrives as one comma-separated string; the pipeline wants a list.
    attr_names = selected_attrs.split(',')
    preprocess(
        root_dir,
        save_dir,
        data_type,
        selected_attrs=attr_names,
        attr_sep=attr_sep,
        part_sep=part_sep,
    )
def preprocess(root_dir, save_dir, data_type, selected_attrs: Optional[List[str]] = None, attr_sep: str = ',', part_sep: str = '|'):
    """Convert every ``*.root`` file directly inside ``root_dir`` in parallel.

    Args:
        root_dir: Directory searched (non-recursively) for ``*.root`` files.
        save_dir: Output directory; created if it does not exist.
        data_type: ``'fast'`` selects ``JetSet.build_jetset_fast``; any other
            value selects ``JetSet.build_jetset_full``.
        selected_attrs: Particle attributes to export (``None`` = all but
            ``jet_type``).
        attr_sep: Separator between attributes of one particle.
        part_sep: Separator between particles of one jet.
    """
    # Path.glob yields files in arbitrary order; sorting keeps the process_{i}
    # suffix embedded in the output file names stable across runs and machines.
    root_files = sorted(Path(root_dir).glob('*.root'))
    os.makedirs(save_dir, exist_ok=True)
    # For single-process debugging, call process_file on root_files[0] with
    # process id 0 directly instead of going through joblib.
    jobs = (
        joblib.delayed(process_file)(root_file, save_dir, data_type, selected_attrs, attr_sep, part_sep, i)
        for i, root_file in enumerate(root_files)
    )
    joblib.Parallel(n_jobs=-1)(jobs)
def process_file(root_file, save_dir, data_type, selected_attrs, attr_sep, part_sep, process_id):
    """Convert a single ROOT file and write its jets as text files.

    ``data_type == 'fast'`` uses ``JetSet.build_jetset_fast``; every other
    value uses ``JetSet.build_jetset_full``. ``process_id`` is embedded in the
    output file names as the ``process_{process_id}`` prefix argument.
    """
    prefix = f"process_{process_id}"
    build = JetSet.build_jetset_fast if data_type == 'fast' else JetSet.build_jetset_full
    build(root_file).save_to_txt(save_dir, selected_attrs, attr_sep, part_sep, prefix)
# Run the click command only when the module is executed as a script.
if __name__ == "__main__":
    main()