first add
This commit is contained in:
193
convert_root_to_txt.py.bak
Normal file
193
convert_root_to_txt.py.bak
Normal file
@@ -0,0 +1,193 @@
|
||||
import enum
|
||||
import attrs
|
||||
from typing import List, Optional, Union
|
||||
import numpy as np
|
||||
import uproot
|
||||
from pathlib import Path
|
||||
from tqdm import tqdm
|
||||
import os
|
||||
import joblib
|
||||
import click
|
||||
|
||||
class ParticleType(enum.Enum):
|
||||
NeutralHadron = 0
|
||||
Photon = 1
|
||||
Electron = 2
|
||||
Muon = 3
|
||||
Pion = 4
|
||||
ChargedKaon = 5
|
||||
Proton = 6
|
||||
Others = 7
|
||||
|
||||
@attrs.define
|
||||
class ParticleBase:
|
||||
particle_id: int = attrs.field() # unique id for each particle
|
||||
part_charge: int = attrs.field() # charge of the particle
|
||||
part_energy: float = attrs.field() # energy of the particle
|
||||
part_px: float = attrs.field() # x-component of the momentum vector
|
||||
part_py: float = attrs.field() # y-component of the momentum vector
|
||||
part_pz: float = attrs.field() # z-component of the momentum vector
|
||||
log_energy: float = attrs.field() # log10(part_energy)
|
||||
log_pt: float = attrs.field() # log10(part_pt)
|
||||
part_deta: float = attrs.field() # pseudorapidity
|
||||
part_dphi: float = attrs.field() # azimuthal angle
|
||||
part_logptrel: float = attrs.field() # log10(pt(particle)/pt(jet))
|
||||
part_logerel: float = attrs.field() # log10(energy(particle)/energy(jet))
|
||||
part_deltaR: float = attrs.field() # distance between the particle and the jet
|
||||
part_d0: float = attrs.field() # tanh(d0)
|
||||
part_dz: float = attrs.field() # tanh(z0)
|
||||
particle_type: ParticleType = attrs.field() # type of the particle as an enum
|
||||
particle_pid: int = attrs.field() # pid of the particle
|
||||
jet_type: str = attrs.field() # type of the jet (e.g. b, bbbar)
|
||||
|
||||
def properties_concatenated(self, selected_attrs: Optional[List[str]] = None, attr_sep: str = ',') -> str:
|
||||
if selected_attrs is None:
|
||||
selected_attrs = [field.name for field in attrs.fields(self) if field.name != "jet_type"]
|
||||
values = [getattr(self, attr) for attr in selected_attrs]
|
||||
values = [v.value if isinstance(v, ParticleType) else v for v in values] # Convert ParticleType to its numeric value
|
||||
return attr_sep.join(map(str, values))
|
||||
|
||||
@attrs.define
|
||||
class Jet:
|
||||
jet_energy: float = attrs.field() # energy of the jet
|
||||
jet_pt: float = attrs.field() # transverse momentum of the jet
|
||||
jet_eta: float = attrs.field() # pseudorapidity of the jet
|
||||
label: str = attrs.field() # type of bb or bbbar
|
||||
particles: List[ParticleBase] = attrs.field(factory=list) # list of particles in the jet
|
||||
|
||||
def __len__(self):
|
||||
return len(self.particles)
|
||||
|
||||
def particles_concatenated(self, selected_attrs: Optional[List[str]] = None, attr_sep: str = ',', part_sep: str = '|') -> str:
|
||||
return part_sep.join(map(lambda p: p.properties_concatenated(selected_attrs, attr_sep), self.particles))
|
||||
|
||||
@attrs.define
|
||||
class JetSet:
|
||||
jets_type: str = attrs.field() # type of the jets (e.g. bb, bbbar)
|
||||
jets: List[Jet] = attrs.field(factory=list)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.jets)
|
||||
|
||||
@staticmethod
|
||||
def jud_type(jtmp):
|
||||
particle_dict = {
|
||||
0: ParticleType.NeutralHadron,
|
||||
1: ParticleType.Photon,
|
||||
2: ParticleType.Electron,
|
||||
3: ParticleType.Muon,
|
||||
4: ParticleType.Pion,
|
||||
5: ParticleType.ChargedKaon,
|
||||
6: ParticleType.Proton,
|
||||
7: ParticleType.Others
|
||||
}
|
||||
max_element = max(jtmp)
|
||||
idx = jtmp.index(max_element)
|
||||
return particle_dict.get(idx, ParticleType.Others), idx
|
||||
|
||||
@classmethod
|
||||
def build_jetset(cls, root_file: Union[str, Path]) -> "JetSet":
|
||||
print(f"Building JetSet from {root_file}")
|
||||
if not 'bbbar' in Path(root_file).stem and not 'bb' in Path(root_file).stem:
|
||||
raise ValueError("Invalid file name, should contain 'bb' or 'bbbar'")
|
||||
jets_type = "bbbar" if "bbbar" in Path(root_file).stem else "bb"
|
||||
print(f"jet type: {jets_type}")
|
||||
with uproot.open(root_file) as f:
|
||||
tree = f["tree"]
|
||||
a = tree.arrays(library="pd")
|
||||
|
||||
label = Path(root_file).stem.split('_')[0]
|
||||
jet_type = "bbbar jet" if "bbbar" in root_file.stem else "b jet"
|
||||
jet_list = []
|
||||
for i, j in a.iterrows():
|
||||
part_pt = np.array(j['part_pt'])
|
||||
jet_pt = np.array(j['jet_pt'])
|
||||
part_logptrel = np.log(np.divide(part_pt, jet_pt))
|
||||
|
||||
part_energy = np.array(j['part_energy'])
|
||||
jet_energy = np.array(j['jet_energy'])
|
||||
part_logerel = np.log(np.divide(part_energy, jet_energy))
|
||||
|
||||
part_deta = np.array(j['part_deta'])
|
||||
part_dphi = np.array(j['part_dphi'])
|
||||
part_deltaR = np.hypot(part_deta, part_dphi)
|
||||
|
||||
assert len(j['part_pt']) == len(j['part_energy']) == len(j['part_deta'])
|
||||
particles = []
|
||||
particle_list = ['part_isNeutralHadron', 'part_isPhoton', 'part_isElectron', 'part_isMuon', 'part_isPion', 'part_isChargedKaon', 'part_isProton']
|
||||
part_type = []
|
||||
part_pid = []
|
||||
for pn in range(len(j['part_pt'])):
|
||||
jtmp = []
|
||||
for t in particle_list:
|
||||
jtmp.append(j[t][pn])
|
||||
tmp_type, tmp_pid = cls.jud_type(jtmp)
|
||||
part_type.append(tmp_type)
|
||||
part_pid.append(tmp_pid)
|
||||
|
||||
for pn in range(len(part_type)):
|
||||
particles.append(ParticleBase(
|
||||
particle_id=i,
|
||||
part_charge=j['part_charge'][pn],
|
||||
part_energy=j['part_energy'][pn],
|
||||
part_px=j['part_px'][pn],
|
||||
part_py=j['part_py'][pn],
|
||||
part_pz=j['part_pz'][pn],
|
||||
log_energy=np.log(j['part_energy'][pn]),
|
||||
log_pt=np.log(j['part_pt'][pn]),
|
||||
part_deta=j['part_deta'][pn],
|
||||
part_dphi=j['part_dphi'][pn],
|
||||
part_logptrel=part_logptrel[pn],
|
||||
part_logerel=part_logerel[pn],
|
||||
part_deltaR=part_deltaR[pn],
|
||||
part_d0=np.tanh(j['part_d0val'][pn]),
|
||||
part_dz=np.tanh(j['part_dzval'][pn]),
|
||||
particle_type=part_type[pn],
|
||||
particle_pid=part_pid[pn],
|
||||
jet_type=jet_type
|
||||
))
|
||||
|
||||
jet = Jet(
|
||||
jet_energy=j['jet_energy'],
|
||||
jet_pt=j['jet_pt'],
|
||||
jet_eta=j['jet_eta'],
|
||||
particles=particles,
|
||||
label=label
|
||||
)
|
||||
jet_list.append(jet)
|
||||
|
||||
return cls(jets=jet_list, jets_type=jets_type)
|
||||
|
||||
def save_to_txt(self, save_dir, selected_attrs: Optional[List[str]] = None, attr_sep: str = ',', part_sep: str = '|', prefix: str = ""):
|
||||
for file_counter, jet in enumerate(tqdm(self.jets)):
|
||||
filename = f"{self.jets_type}_{file_counter}_{prefix}.txt"
|
||||
filepath = os.path.join(save_dir, filename)
|
||||
|
||||
jet_data = jet.particles_concatenated(selected_attrs, attr_sep, part_sep)
|
||||
with open(filepath, 'w') as file:
|
||||
file.write(jet_data)
|
||||
|
||||
@click.command()
|
||||
@click.argument('root_dir', type=click.Path(exists=True))
|
||||
@click.argument('save_dir', type=click.Path())
|
||||
@click.option('--attr-sep', default=',', help='Separator for attributes.')
|
||||
@click.option('--part-sep', default='|', help='Separator for particles.')
|
||||
@click.option('--selected-attrs', default='particle_id,part_charge,part_energy,part_px,part_py,part_pz,log_energy,log_pt,part_deta,part_dphi,part_logptrel,part_logerel,part_deltaR,part_d0,part_dz,particle_type,particle_pid', help='Comma-separated list of selected attributes.')
|
||||
def main(root_dir, save_dir, attr_sep, part_sep, selected_attrs):
|
||||
selected_attrs = selected_attrs.split(',')
|
||||
preprocess(root_dir, save_dir, selected_attrs=selected_attrs, attr_sep=attr_sep, part_sep=part_sep)
|
||||
|
||||
def preprocess(root_dir, save_dir, selected_attrs: Optional[List[str]] = None, attr_sep: str = ',', part_sep: str = '|'):
|
||||
root_files = list(Path(root_dir).glob('*.root'))
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
joblib.Parallel(n_jobs=-1)(
|
||||
joblib.delayed(process_file)(root_file, save_dir, selected_attrs, attr_sep, part_sep, i) for i, root_file in enumerate(root_files)
|
||||
)
|
||||
|
||||
def process_file(root_file, save_dir, selected_attrs, attr_sep, part_sep, process_id):
|
||||
prefix = f"process_{process_id}"
|
||||
jet_set = JetSet.build_jetset(root_file)
|
||||
jet_set.save_to_txt(save_dir, selected_attrs, attr_sep, part_sep, prefix)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user