Files
particle_analyse/convert_root_to_bin.py.bak
2024-06-10 09:09:10 +08:00

251 lines
10 KiB
Python

import enum
import attrs
from typing import List, Optional, Union
import numpy as np
import uproot
from pathlib import Path
from tqdm import tqdm
import os
import pickle
import joblib
import click
import torch
from torch.utils.data import Dataset
import math
class ByteDataset(Dataset):
    """Byte-level classification dataset over a list of files.

    Each sample is the raw byte content of one file, framed as patches of
    ``patch_size`` tokens:
      1. A leading patch holding the UTF-8 bytes of the file extension.
      2. The file bytes, padded with ``PAD_TOKEN`` up to a multiple of
         ``patch_size``.
      3. A trailing patch made entirely of ``PAD_TOKEN``.

    Files whose content needs more than ``patch_length - 2`` patches are
    skipped; the two reserved patches are the extension and trailing patches.

    Args:
        filenames: ``str`` paths of the files to load. The extension is the
            text after the last ``'.'`` in the path.
        patch_size: Number of byte tokens per patch.
        patch_length: Maximum number of patches per sample, including the
            extension patch and the trailing padding patch.
    """

    # Padding token; 256 lies outside the 0-255 range of real byte values.
    PAD_TOKEN = 256

    def __init__(self, filenames, patch_size: int = 16, patch_length: int = 512):
        self.patch_size = patch_size
        self.patch_length = patch_length
        print(f"Loading {len(filenames)} files for classification")
        # (path, label) pairs for the files that fit within patch_length.
        self.filenames = []
        # Map label string -> integer class id. New labels get the next free id.
        # NOTE(review): the labels built below have the form "<prefix>.<ext>"
        # (e.g. "bb.bin"), so the seeded keys 'bb' and 'bbbar' never match a
        # sample and ids 0 and 1 stay unused -- confirm whether the extension
        # should be dropped from the label.
        self.labels = {'bb': 0, 'bbbar': 1}
        for filename in tqdm(filenames):
            num_patches = math.ceil(os.path.getsize(filename) / self.patch_size)
            ext = filename.split('.')[-1]
            # Use the file-name prefix (text before the first '_') as the label.
            prefix = os.path.basename(filename).split('_')[0]
            label = f"{prefix}.{ext}"
            # Two patches are reserved for the extension and trailing patches.
            if num_patches <= self.patch_length - 2:
                self.filenames.append((filename, label))
                if label not in self.labels:
                    self.labels[label] = len(self.labels)

    def __len__(self):
        """Return the number of files kept after the size filter."""
        return len(self.filenames)

    def __getitem__(self, idx):
        """Return ``(tokens, label_id, filename)`` for sample ``idx``.

        ``tokens`` is a 1-D ``torch.long`` tensor built by :meth:`read_bytes`.
        ``label_id`` is a 0-d ``torch.long`` tensor looked up in ``self.labels``.
        """
        filename, label = self.filenames[idx]
        tokens = torch.tensor(self.read_bytes(filename), dtype=torch.long)
        label_id = torch.tensor(self.labels[label], dtype=torch.long)
        return tokens, label_id, filename

    def read_bytes(self, filename):
        """Read ``filename`` and return its patch-framed list of int tokens.

        The result is ``extension patch + padded file bytes + padding patch``.
        Its length is always a multiple of ``patch_size``.
        """
        pad = self.PAD_TOKEN
        # Leading patch: the extension's UTF-8 bytes, truncated or padded to one patch.
        ext = list(filename.split('.')[-1].encode('utf-8'))[:self.patch_size]
        bos_patch = ext + [pad] * (self.patch_size - len(ext))
        with open(filename, 'rb') as f:
            tokens = list(f.read())
        remainder = len(tokens) % self.patch_size
        if remainder:
            tokens += [pad] * (self.patch_size - remainder)
        return bos_patch + tokens + [pad] * self.patch_size

    # Backward-compatible alias for the method's original name.
    readbytes = read_bytes
class ParticleType(enum.Enum):
    """Category of a jet constituent.

    Values run from 0 to 7 in declaration order. ``Others`` is the catch-all
    category for particles that fit none of the named kinds.
    """

    NeutralHadron = 0  # neutral hadron
    Photon = 1  # photon
    Electron = 2  # electron
    Muon = 3  # muon
    Pion = 4  # pion
    ChargedKaon = 5  # charged kaon
    Proton = 6  # proton
    Others = 7  # anything not covered above
@attrs.define
class ParticleBase:
    """One jet constituent with its kinematic and identification features.

    Provides two ways to export a selected subset of fields:
    ``properties_concatenated`` returns a delimited string, and
    ``attributes_to_float_list`` returns a list of numbers.
    """

    # NOTE(review): described as a unique particle id, but JetSet.build_jetset
    # passes the jet's row index, which all particles of that jet share.
    particle_id: int = attrs.field()  # particle identifier
    part_charge: int = attrs.field()  # charge of the particle
    part_energy: float = attrs.field()  # energy of the particle
    part_px: float = attrs.field()  # x-component of the momentum vector
    part_py: float = attrs.field()  # y-component of the momentum vector
    part_pz: float = attrs.field()  # z-component of the momentum vector
    log_energy: float = attrs.field()  # natural log of part_energy (build_jetset uses np.log)
    log_pt: float = attrs.field()  # natural log of the particle pt
    part_deta: float = attrs.field()  # eta offset (part_deta branch); combined with part_dphi into part_deltaR
    part_dphi: float = attrs.field()  # azimuthal offset (part_dphi branch)
    part_logptrel: float = attrs.field()  # ln(pt(particle) / pt(jet))
    part_logerel: float = attrs.field()  # ln(energy(particle) / energy(jet))
    part_deltaR: float = attrs.field()  # hypot(part_deta, part_dphi): distance between the particle and the jet
    part_d0: float = attrs.field()  # tanh of the part_d0val branch
    part_dz: float = attrs.field()  # tanh of the part_dzval branch
    particle_type: ParticleType = attrs.field()  # type of the particle as an enum
    particle_pid: int = attrs.field()  # numeric value of particle_type
    jet_type: str = attrs.field()  # label of the parent jet (build_jetset passes "b jet" or "bbbar jet")

    def _selected_values(self, selected_attrs: Optional[List[str]]) -> list:
        """Return the values of ``selected_attrs``, in order.

        When ``selected_attrs`` is None, every field except ``jet_type`` is
        used. ``ParticleType`` members are replaced by their numeric value.

        Raises:
            AttributeError: if a selected name is not an attribute.
        """
        if selected_attrs is None:
            # attrs.fields() only accepts a class; passing the instance raises TypeError.
            selected_attrs = [field.name for field in attrs.fields(type(self)) if field.name != "jet_type"]
        values = [getattr(self, name) for name in selected_attrs]
        return [v.value if isinstance(v, ParticleType) else v for v in values]

    def properties_concatenated(self, selected_attrs: Optional[List[str]] = None, attr_sep: str = ',') -> str:
        """Return the selected field values as strings joined by ``attr_sep``."""
        return attr_sep.join(map(str, self._selected_values(selected_attrs)))

    def attributes_to_float_list(self, selected_attrs: Optional[List[str]] = None) -> list:
        """Return the selected field values as a list.

        Python ``int`` values (including ``bool``) are converted to ``float``.
        All other values are returned unchanged.
        """
        return [float(v) if isinstance(v, int) else v for v in self._selected_values(selected_attrs)]
@attrs.define
class Jet:
    """A reconstructed jet and the particles clustered into it."""

    jet_energy: float = attrs.field()  # energy of the jet
    jet_pt: float = attrs.field()  # transverse momentum of the jet
    jet_eta: float = attrs.field()  # pseudorapidity of the jet
    label: str = attrs.field()  # class label of the jet (bb or bbbar)
    particles: List[ParticleBase] = attrs.field(factory=list)  # constituent particles

    def __len__(self):
        """Return the number of constituent particles."""
        return len(self.particles)

    def particles_concatenated(self, selected_attrs: Optional[List[str]] = None, attr_sep: str = ',', part_sep: str = '|') -> str:
        """Serialize every particle and join the results with ``part_sep``.

        Each particle is rendered by ``ParticleBase.properties_concatenated``
        with the given ``selected_attrs`` and ``attr_sep``.
        """
        rendered = (
            particle.properties_concatenated(selected_attrs, attr_sep)
            for particle in self.particles
        )
        return part_sep.join(rendered)

    def particles_attribute_list(self, selected_attrs: Optional[List[str]] = None) -> list:
        """Return one ``attributes_to_float_list`` result per particle."""
        return [particle.attributes_to_float_list(selected_attrs) for particle in self.particles]
@attrs.define
class JetSet:
    """A collection of jets read from one ROOT file, with per-jet export.

    Attributes:
        jets_type: ``"bb"`` or ``"bbbar"``; ``build_jetset`` sets it from the
            file name.
        jets: The jets in the set.
        source: Stem of the ROOT file the set was built from, or ``None`` when
            the set was constructed directly. ``save_to_binary`` puts it into
            the output file names so that sets from different ROOT files of
            the same type do not overwrite each other.
    """

    jets_type: str = attrs.field()  # type of the jets (e.g. bb, bbbar)
    jets: List[Jet] = attrs.field(factory=list)  # jets in this set
    source: Optional[str] = attrs.field(default=None)  # stem of the source ROOT file, if known

    def __len__(self):
        """Return the number of jets in the set."""
        return len(self.jets)

    @staticmethod
    def jud_type(jtmp):
        """Map a sequence of identification flags to a ParticleType.

        ``jtmp[k]`` is the flag for the ParticleType whose value is ``k``. The
        position of the largest flag wins; on ties, the first position wins.

        Returns ParticleType.Others when:
          * the sequence is empty;
          * no flag is positive, so a particle with no flag set is not
            mistaken for index 0 (NeutralHadron);
          * the winning position has no matching ParticleType.
        """
        flags = list(jtmp)
        if not flags:
            return ParticleType.Others
        best = max(flags)
        if not best > 0:  # also treats NaN as "no flag set"
            return ParticleType.Others
        try:
            return ParticleType(flags.index(best))
        except ValueError:
            return ParticleType.Others

    @classmethod
    def build_jetset(cls, root_file: Union[str, Path]) -> "JetSet":
        """Read the ``tree`` TTree of ``root_file`` and build one Jet per row.

        Args:
            root_file: Path (``str`` or ``Path``) to a ROOT file whose stem
                contains ``"bb"`` or ``"bbbar"``. The stem determines
                ``jets_type``, each particle's ``jet_type``, and each jet's
                ``label`` (the stem text before the first ``'_'``).

        Returns:
            A JetSet with one Jet per tree entry and ``source`` set to the
            file stem.

        Raises:
            ValueError: if the stem contains neither ``"bb"`` nor ``"bbbar"``,
                or a row's per-particle branches have mismatched lengths.
        """
        print(f"Building JetSet from {root_file}")
        # Normalize to Path so that str arguments also support .stem.
        stem = Path(root_file).stem
        # 'bbbar' contains 'bb', so this single test accepts both names.
        if 'bb' not in stem:
            raise ValueError("Invalid file name, should contain 'bb' or 'bbbar'")
        jets_type = "bbbar" if "bbbar" in stem else "bb"
        jet_type = "bbbar jet" if "bbbar" in stem else "b jet"
        label = stem.split('_')[0]
        with uproot.open(root_file) as f:
            frame = f["tree"].arrays(library="pd")
        jets = [cls._build_jet(row_idx, row, label, jet_type) for row_idx, row in frame.iterrows()]
        return cls(jets=jets, jets_type=jets_type, source=stem)

    @classmethod
    def _build_jet(cls, row_idx, row, label: str, jet_type: str) -> "Jet":
        """Build one Jet, with one ParticleBase per constituent, from one tree row."""
        # One-hot identification branches, in ParticleType value order (0..6).
        flag_branches = (
            'part_isNeutralHadron', 'part_isPhoton', 'part_isElectron', 'part_isMuon',
            'part_isPion', 'part_isChargedKaon', 'part_isProton',
        )
        part_pt = np.asarray(row['part_pt'])
        part_energy = np.asarray(row['part_energy'])
        part_deta = np.asarray(row['part_deta'])
        part_dphi = np.asarray(row['part_dphi'])
        if not len(part_pt) == len(part_energy) == len(part_deta):
            raise ValueError(f"Mismatched per-particle branch lengths in row {row_idx}")

        jet_pt = row['jet_pt']
        jet_energy = row['jet_energy']
        # Derived per-particle features, computed once per jet as whole arrays.
        part_logptrel = np.log(np.divide(part_pt, np.asarray(jet_pt)))
        part_logerel = np.log(np.divide(part_energy, np.asarray(jet_energy)))
        part_deltaR = np.hypot(part_deta, part_dphi)
        log_energy = np.log(part_energy)
        log_pt = np.log(part_pt)
        part_charge = np.asarray(row['part_charge'])
        part_px = np.asarray(row['part_px'])
        part_py = np.asarray(row['part_py'])
        part_pz = np.asarray(row['part_pz'])
        part_d0 = np.tanh(np.asarray(row['part_d0val']))
        part_dz = np.tanh(np.asarray(row['part_dzval']))
        flags = [np.asarray(row[branch]) for branch in flag_branches]

        particles = []
        # Index every per-particle array at pn, so that each ParticleBase holds
        # its own particle's values rather than the whole jet's arrays.
        for pn in range(len(part_pt)):
            ptype = cls.jud_type([flag[pn] for flag in flags])
            particles.append(
                ParticleBase(
                    # NOTE(review): this is the jet's row index, shared by every
                    # particle of the jet -- confirm this is the intended id.
                    particle_id=row_idx,
                    part_charge=part_charge[pn],
                    part_energy=part_energy[pn],
                    part_px=part_px[pn],
                    part_py=part_py[pn],
                    part_pz=part_pz[pn],
                    log_energy=log_energy[pn],
                    log_pt=log_pt[pn],
                    part_deta=part_deta[pn],
                    part_dphi=part_dphi[pn],
                    part_logptrel=part_logptrel[pn],
                    part_logerel=part_logerel[pn],
                    part_deltaR=part_deltaR[pn],
                    part_d0=part_d0[pn],
                    part_dz=part_dz[pn],
                    particle_type=ptype,
                    particle_pid=ptype.value,
                    jet_type=jet_type,
                )
            )
        return Jet(
            jet_energy=jet_energy,
            jet_pt=jet_pt,
            jet_eta=row['jet_eta'],
            particles=particles,
            label=label,
        )

    def save_to_binary(self, save_dir, selected_attrs: Optional[List[str]] = None, attr_sep: str = ',', part_sep: str = '|'):
        """Write each jet to its own ``.bin`` file in ``save_dir``.

        Each file contains ``pickle.dump`` of the jet's
        ``Jet.particles_concatenated`` string. Files are named
        ``{jets_type}_{index}.bin``, or ``{jets_type}_{source}_{index}.bin``
        when ``source`` is set. ``preprocess`` writes every set into one
        directory in parallel, so the source stem keeps sets of the same type
        from overwriting each other. ``save_dir`` is created if missing.
        """
        os.makedirs(save_dir, exist_ok=True)
        name_stem = self.jets_type if self.source is None else f"{self.jets_type}_{self.source}"
        for index, jet in enumerate(tqdm(self.jets)):
            jet_data = jet.particles_concatenated(selected_attrs, attr_sep, part_sep)
            filepath = os.path.join(save_dir, f"{name_stem}_{index}.bin")
            with open(filepath, 'wb') as file:
                pickle.dump(jet_data, file)
@click.command()
@click.argument('root_dir', type=click.Path(exists=True))
@click.argument('save_dir', type=click.Path())
@click.option('--attr-sep', default=',', help='Separator for attributes.')
@click.option('--part-sep', default='|', help='Separator for particles.')
@click.option('--selected-attrs', default='particle_id,part_charge,part_energy,part_px,part_py,part_pz,log_energy,log_pt,part_deta,part_dphi,part_logptrel,part_logerel,part_deltaR,part_d0,part_dz,particle_type,particle_pid', help='Comma-separated list of selected attributes.')
def main(root_dir, save_dir, attr_sep, part_sep, selected_attrs):
    """Convert every *.root file in ROOT_DIR into per-jet .bin files in SAVE_DIR."""
    # Tolerate spaces and stray commas, e.g. "a, b," -> ["a", "b"]. An empty
    # selection becomes None, which means every field except jet_type.
    names = [name.strip() for name in selected_attrs.split(',') if name.strip()]
    preprocess(root_dir, save_dir, selected_attrs=names or None, attr_sep=attr_sep, part_sep=part_sep)
def preprocess(root_dir, save_dir, selected_attrs: Optional[List[str]] = None, attr_sep: str = ',', part_sep: str = '|', n_jobs: int = -1):
    """Convert every ``*.root`` file directly under ``root_dir`` in parallel.

    Each file is handed to ``process_file`` in a joblib worker, and all
    outputs are written to ``save_dir``, which is created if missing.

    Args:
        root_dir: Directory to search (non-recursively) for ``*.root`` files.
        save_dir: Output directory for the per-jet ``.bin`` files.
        selected_attrs: Particle fields to serialize; None means all fields
            except ``jet_type``.
        attr_sep: Separator between attribute values of one particle.
        part_sep: Separator between particles of one jet.
        n_jobs: joblib worker count. The default, -1, is the previously
            hard-coded value and uses all CPUs.
    """
    # Sort so that files are dispatched in a deterministic order.
    root_files = sorted(Path(root_dir).glob('*.root'))
    os.makedirs(save_dir, exist_ok=True)
    if not root_files:
        print(f"No .root files found in {root_dir}")
        return
    joblib.Parallel(n_jobs=n_jobs)(
        joblib.delayed(process_file)(root_file, save_dir, selected_attrs, attr_sep, part_sep)
        for root_file in root_files
    )
def process_file(root_file, save_dir, selected_attrs, attr_sep, part_sep):
    """Build a JetSet from one ROOT file and write its jets into ``save_dir``.

    The separators and ``selected_attrs`` are passed through unchanged to
    ``JetSet.save_to_binary``.
    """
    JetSet.build_jetset(root_file).save_to_binary(save_dir, selected_attrs, attr_sep, part_sep)
if __name__ == "__main__":
    # Script entry point: run the click command defined above.
    main()