# Listing metadata: 280 lines, 12 KiB, Python
import enum
|
||
import attrs
|
||
from typing import List, Optional, Union
|
||
import numpy as np
|
||
import uproot
|
||
from pathlib import Path
|
||
from tqdm import tqdm
|
||
import os
|
||
import joblib
|
||
import click
|
||
|
||
@enum.unique
class ParticleType(enum.Enum):
    """Particle categories; each value is the integer code written to the output."""

    # Values 0-6 line up with the position of the corresponding flag in the
    # list handed to JetSet.jud_type.
    NeutralHadron = 0
    Photon = 1
    Electron = 2
    Muon = 3
    Pion = 4
    ChargedKaon = 5
    Proton = 6
    # Fallback category.
    Others = 7
|
||
|
||
@attrs.define
class ParticleBase:
    """Features of a single particle inside a jet.

    Instances are created by the JetSet builders; the comments below describe
    how those builders fill each field.
    """

    particle_id: int = attrs.field()  # id assigned by the builder (position inside the jet in build_jetset_fast)
    part_charge: int = attrs.field()  # charge of the particle
    part_energy: float = attrs.field()  # energy of the particle
    part_px: float = attrs.field()  # x-component of the momentum vector
    part_py: float = attrs.field()  # y-component of the momentum vector
    part_pz: float = attrs.field()  # z-component of the momentum vector
    log_energy: float = attrs.field()  # np.log(part_energy), i.e. natural logarithm
    log_pt: float = attrs.field()  # np.log(part_pt), i.e. natural logarithm
    part_deta: float = attrs.field()  # value of the part_deta branch (presumably eta relative to the jet axis; confirm)
    part_dphi: float = attrs.field()  # value of the part_dphi branch (presumably phi relative to the jet axis; confirm)
    part_logptrel: float = attrs.field()  # np.log(pt(particle) / pt(jet))
    part_logerel: float = attrs.field()  # np.log(energy(particle) / energy(jet))
    part_deltaR: float = attrs.field()  # np.hypot(part_deta, part_dphi)
    part_d0: float = attrs.field()  # tanh of the part_d0val branch
    part_dz: float = attrs.field()  # tanh of the part_dzval branch
    particle_type: ParticleType = attrs.field()  # type of the particle as an enum
    particle_pid: int = attrs.field()  # integer index returned by JetSet.jud_type
    jet_type: str = attrs.field()  # "bbbar jet" or "b jet", set by the builder

    def properties_concatenated(self, selected_attrs: Optional[List[str]] = None, attr_sep: str = ',') -> str:
        """Serialise the chosen attributes into one delimited string.

        Args:
            selected_attrs: Attribute names to emit, in order. ``None`` means
                every field except ``jet_type``, in declaration order.
            attr_sep: Separator placed between attribute values.

        Returns:
            The ``str()`` of each value joined by ``attr_sep``; a
            ``ParticleType`` is written as its integer value.

        Raises:
            AttributeError: If a requested name is not an attribute.
        """
        if selected_attrs is None:
            # attrs.fields() accepts only a class and raises TypeError for an
            # instance, so the fields are looked up on the type.
            selected_attrs = [
                field.name for field in attrs.fields(type(self)) if field.name != "jet_type"
            ]
        values = (getattr(self, name) for name in selected_attrs)
        return attr_sep.join(
            str(value.value if isinstance(value, ParticleType) else value) for value in values
        )
|
||
|
||
@attrs.define
class Jet:
    """A jet: three jet-level quantities, a label and its particles."""

    jet_energy: float = attrs.field()  # energy of the jet
    jet_pt: float = attrs.field()  # transverse momentum of the jet
    jet_eta: float = attrs.field()  # pseudorapidity of the jet
    label: str = attrs.field()  # label string supplied by the builder
    particles: List[ParticleBase] = attrs.field(factory=list)  # particles belonging to this jet

    def __len__(self):
        """Number of particles in the jet."""
        return len(self.particles)

    def particles_concatenated(self, selected_attrs: Optional[List[str]] = None, attr_sep: str = ',', part_sep: str = '|') -> str:
        """Serialise every particle and join the results with ``part_sep``.

        ``selected_attrs`` and ``attr_sep`` are forwarded unchanged to
        ``ParticleBase.properties_concatenated`` for each particle.
        """
        encoded_particles = (
            particle.properties_concatenated(selected_attrs, attr_sep)
            for particle in self.particles
        )
        return part_sep.join(encoded_particles)
|
||
|
||
@attrs.define
class JetSet:
    """All jets read from one ROOT file, with builders and a text exporter."""

    jets_type: str = attrs.field()  # "bbbar" or "bb", derived from the file name
    jets: List[Jet] = attrs.field(factory=list)  # one Jet per row of the tree

    def __len__(self):
        """Number of jets in the set."""
        return len(self.jets)

    @staticmethod
    def jud_type(jtmp):
        """Pick the particle type from a sequence of per-type flags.

        Args:
            jtmp: Flags ordered like ``ParticleType`` values 0..6. Any
                iterable is accepted (list, tuple, numpy array, ...).

        Returns:
            ``(particle_type, index)``. The index is the position of the
            first maximal flag. When the input is empty or no flag is set
            (the maximum is falsy) the result is
            ``(ParticleType.Others, 7)`` instead of silently reporting
            index 0 (NeutralHadron).
        """
        flags = list(jtmp)
        if not flags:
            return ParticleType.Others, ParticleType.Others.value
        max_element = max(flags)
        if not max_element:
            # No flag is set: this is the case the Others category exists for.
            return ParticleType.Others, ParticleType.Others.value
        idx = flags.index(max_element)
        try:
            return ParticleType(idx), idx
        except ValueError:
            # More flags than known types: keep the index, report Others.
            return ParticleType.Others, idx

    @staticmethod
    def _load_frame(root_file):
        """Validate the file name and read the ``tree`` object into pandas.

        Returns:
            ``(frame, jets_type, label, jet_type)`` where ``label`` is the
            first ``_``-separated token of the file stem.

        Raises:
            ValueError: If the file stem does not contain ``bb``.
        """
        print(f"Building JetSet from {root_file}")
        # Normalise once so that both str and Path arguments work everywhere.
        stem = Path(root_file).stem
        # 'bbbar' contains 'bb', so a single membership test covers both names.
        if "bb" not in stem:
            raise ValueError("Invalid file name, should contain 'bb' or 'bbbar'")
        is_bbbar = "bbbar" in stem
        jets_type = "bbbar" if is_bbbar else "bb"
        print(f"jet type: {jets_type}")
        with uproot.open(root_file) as f:
            frame = f["tree"].arrays(library="pd")
        label = stem.split('_')[0]
        jet_type = "bbbar jet" if is_bbbar else "b jet"
        return frame, jets_type, label, jet_type

    @classmethod
    def _build_jet(cls, column, label, jet_type):
        """Build one ``Jet`` from one row of the tree.

        Args:
            column: Callable mapping a branch name to that branch's value
                for the current row.
            label: Label stored on the jet.
            jet_type: Jet type string stored on every particle.

        Raises:
            ValueError: If ``part_pt``, ``part_energy`` and ``part_deta``
                do not have the same length.
        """
        flag_branches = (
            "part_isNeutralHadron", "part_isPhoton", "part_isElectron", "part_isMuon",
            "part_isPion", "part_isChargedKaon", "part_isProton",
        )

        raw_pt = column("part_pt")
        raw_energy = column("part_energy")
        raw_deta = column("part_deta")
        raw_dphi = column("part_dphi")
        jet_pt = column("jet_pt")
        jet_energy = column("jet_energy")

        if not len(raw_pt) == len(raw_energy) == len(raw_deta):
            raise ValueError(
                "part_pt, part_energy and part_deta must have the same length, got "
                f"{len(raw_pt)}, {len(raw_energy)} and {len(raw_deta)}"
            )

        # Quantities relative to the jet, computed for all particles at once.
        part_logptrel = np.log(np.divide(np.array(raw_pt), np.array(jet_pt)))
        part_logerel = np.log(np.divide(np.array(raw_energy), np.array(jet_energy)))
        part_deltaR = np.hypot(np.array(raw_deta), np.array(raw_dphi))

        # One (type, index) pair per particle, derived from its flags.
        part_type = []
        part_pid = []
        for flags in zip(*(column(name) for name in flag_branches)):
            tmp_type, tmp_pid = cls.jud_type(flags)
            part_type.append(tmp_type)
            part_pid.append(tmp_pid)

        features = zip(
            column("part_charge"), raw_energy, column("part_px"), column("part_py"), column("part_pz"),
            np.log(raw_energy), np.log(raw_pt), raw_deta, raw_dphi,
            part_logptrel, part_logerel, part_deltaR,
            np.tanh(column("part_d0val")), np.tanh(column("part_dzval")),
            part_type, part_pid,
        )
        # Turn every particle of this jet into a ParticleBase; `particles`
        # ends up holding all particles of the jet.
        particles = [
            ParticleBase(
                particle_id=index,
                part_charge=charge,
                part_energy=energy,
                part_px=px,
                part_py=py,
                part_pz=pz,
                log_energy=log_energy,
                log_pt=log_pt,
                part_deta=deta,
                part_dphi=dphi,
                part_logptrel=logptrel,
                part_logerel=logerel,
                part_deltaR=delta_r,
                part_d0=d0,
                part_dz=dz,
                particle_type=ptype,
                particle_pid=pid,
                jet_type=jet_type,
            )
            for index, (charge, energy, px, py, pz, log_energy, log_pt, deta, dphi,
                        logptrel, logerel, delta_r, d0, dz, ptype, pid) in enumerate(features)
        ]

        return Jet(
            jet_energy=jet_energy,
            jet_pt=jet_pt,
            jet_eta=column("jet_eta"),
            particles=particles,
            label=label,
        )

    @classmethod
    def build_jetset_fast(cls, root_file: Union[str, Path]) -> "JetSet":
        """Build a JetSet from ``root_file``, iterating rows with ``itertuples``.

        Args:
            root_file: Path of the ROOT file; its stem must contain ``bb``.

        Raises:
            ValueError: If the file name is invalid, the tree is empty, or a
                row has inconsistent per-particle lengths.
        """
        frame, jets_type, label, jet_type = cls._load_frame(root_file)
        if frame.empty:
            raise ValueError("DataFrame is empty")

        jet_list = []
        for row in frame.itertuples():
            jet_list.append(cls._build_jet(lambda name: getattr(row, name), label, jet_type))
        return cls(jets=jet_list, jets_type=jets_type)

    @classmethod
    def build_jetset_full(cls, root_file: Union[str, Path]) -> "JetSet":
        """Build a JetSet from ``root_file``, iterating rows with ``iterrows``.

        Unlike ``build_jetset_fast`` an empty tree is not an error: it yields
        a JetSet without jets. ``particle_id`` is the index of the particle
        inside its jet, matching ``build_jetset_fast``.

        Args:
            root_file: Path of the ROOT file; its stem must contain ``bb``.

        Raises:
            ValueError: If the file name is invalid or a row has
                inconsistent per-particle lengths.
        """
        frame, jets_type, label, jet_type = cls._load_frame(root_file)

        jet_list = []
        for _, row in frame.iterrows():
            jet_list.append(cls._build_jet(lambda name: row[name], label, jet_type))
        return cls(jets=jet_list, jets_type=jets_type)

    def save_to_txt(self, save_dir, selected_attrs: Optional[List[str]] = None, attr_sep: str = ',', part_sep: str = '|', prefix: str = ""):
        """Write one text file per jet into ``save_dir``.

        Files are named ``<jets_type>_<jet index>_<prefix>.txt`` and contain
        the output of ``Jet.particles_concatenated``. ``save_dir`` must
        already exist.
        """
        for file_counter, jet in enumerate(tqdm(self.jets)):
            filename = f"{self.jets_type}_{file_counter}_{prefix}.txt"
            filepath = os.path.join(save_dir, filename)

            jet_data = jet.particles_concatenated(selected_attrs, attr_sep, part_sep)
            with open(filepath, 'w', encoding='utf-8') as file:
                file.write(jet_data)
|
||
|
||
@click.command()
@click.argument('root_dir', type=click.Path(exists=True))
@click.argument('save_dir', type=click.Path())
@click.option('--data-type', default='fast', type=click.Choice(['fast', 'full']), help='Type of ROOT data: fast or full.')
@click.option('--attr-sep', default=',', help='Separator for attributes.')
@click.option('--part-sep', default='|', help='Separator for particles.')
@click.option('--selected-attrs', default='particle_id,part_charge,part_energy,part_px,part_py,part_pz,log_energy,log_pt,part_deta,part_dphi,part_logptrel,part_logerel,part_deltaR,part_d0,part_dz,particle_type,particle_pid', help='Comma-separated list of selected attributes.')
def main(root_dir, save_dir, data_type, attr_sep, part_sep, selected_attrs):
    """Convert every *.root file in ROOT_DIR into per-jet text files in SAVE_DIR."""
    # Trim whitespace and skip empty entries so that "a, b" or a trailing
    # comma does not become an invalid attribute name later on.
    names = [name.strip() for name in selected_attrs.split(',')]
    names = [name for name in names if name]
    if not names:
        raise click.BadParameter(
            "expected at least one attribute name", param_hint="--selected-attrs"
        )
    preprocess(root_dir, save_dir, data_type, selected_attrs=names, attr_sep=attr_sep, part_sep=part_sep)
|
||
|
||
def preprocess(root_dir, save_dir, data_type, selected_attrs: Optional[List[str]] = None, attr_sep: str = ',', part_sep: str = '|'):
    """Convert all ``*.root`` files of ``root_dir`` in parallel.

    Args:
        root_dir: Directory searched (non-recursively) for ``*.root`` files.
        save_dir: Output directory; created if it does not exist.
        data_type: ``'fast'`` or anything else (treated as full); forwarded
            to ``process_file``.
        selected_attrs: Attribute names to export, or ``None`` for the default.
        attr_sep: Separator between attributes.
        part_sep: Separator between particles.
    """
    # Directory listing order is arbitrary; sorting keeps the process_<i>
    # prefix assigned to each input file reproducible between runs.
    root_files = sorted(Path(root_dir).glob('*.root'))
    os.makedirs(save_dir, exist_ok=True)
    # For a single-process test run, call process_file(root_files[0], ...) directly.
    joblib.Parallel(n_jobs=-1)(
        joblib.delayed(process_file)(root_file, save_dir, data_type, selected_attrs, attr_sep, part_sep, i)
        for i, root_file in enumerate(root_files)
    )
|
||
|
||
def process_file(root_file, save_dir, data_type, selected_attrs, attr_sep, part_sep, process_id):
    """Build the JetSet of one ROOT file and dump it as text files.

    ``data_type == 'fast'`` selects ``JetSet.build_jetset_fast``; any other
    value selects ``JetSet.build_jetset_full``. Output files carry the
    prefix ``process_<process_id>``.
    """
    build = JetSet.build_jetset_fast if data_type == 'fast' else JetSet.build_jetset_full
    jets = build(root_file)
    jets.save_to_txt(save_dir, selected_attrs, attr_sep, part_sep, f"process_{process_id}")
|
||
|
||
# Run the click command only when the module is executed as a script.
if __name__ == "__main__":
    main()
|