251 lines
10 KiB
Python
251 lines
10 KiB
Python
import enum
|
|
import attrs
|
|
from typing import List, Optional, Union
|
|
import numpy as np
|
|
import uproot
|
|
from pathlib import Path
|
|
from tqdm import tqdm
|
|
import os
|
|
import pickle
|
|
import joblib
|
|
import click
|
|
import torch
|
|
from torch.utils.data import Dataset
|
|
import math
|
|
|
|
class ByteDataset(Dataset):
    """Byte-level classification dataset over a list of files.

    Each item is the file's raw bytes laid out as a flat list of tokens that
    is a multiple of ``patch_size`` long:

    * a BOS patch holding the UTF-8 bytes of the file extension (truncated
      to one patch and padded with 256),
    * the file content, padded with 256 up to a patch boundary,
    * an EOS patch made entirely of 256.

    Token 256 lies outside the byte range 0-255 and serves as padding/special.
    """

    def __init__(self, filenames, patch_size: int = 16, patch_length: int = 512):
        """Index ``filenames`` and build the label -> class-id mapping.

        Args:
            filenames: paths (strings) of the files to include.
            patch_size: number of tokens per patch.
            patch_length: maximum number of patches per item, including the
                BOS and EOS patches; longer files are skipped.
        """
        self.patch_size = patch_size
        self.patch_length = patch_length
        print(f"Loading {len(filenames)} files for classification")
        self.filenames = []
        # Map each label string to an integer class id.
        self.labels = {'bb': 0, 'bbbar': 1}

        for filename in tqdm(filenames):
            # Number of content patches once the file is padded to a patch boundary.
            n_patches = math.ceil(os.path.getsize(filename) / self.patch_size)
            ext = filename.split('.')[-1]
            # Use the file-name prefix (text before the first '_') as the class name.
            prefix = os.path.basename(filename).split('_')[0]
            label = f"{prefix}.{ext}"
            # NOTE(review): labels look like "bb.bin", so they never match the
            # pre-seeded 'bb'/'bbbar' keys and receive ids 2, 3, ... instead;
            # confirm the intended class ids before training a 2-way classifier.

            # Two patches are reserved for the BOS and EOS patches.
            if n_patches > self.patch_length - 2:
                continue
            self.filenames.append((filename, label))
            if label not in self.labels:
                self.labels[label] = len(self.labels)

    def __len__(self):
        """Number of files that passed the length filter."""
        return len(self.filenames)

    def __getitem__(self, idx):
        """Return ``(tokens, label_id, filename)`` for item ``idx``.

        ``tokens`` and ``label_id`` are ``torch.long`` tensors.
        """
        filename, label = self.filenames[idx]
        file_bytes = self.read_bytes(filename)
        file_bytes = torch.tensor(file_bytes, dtype=torch.long)
        label = torch.tensor(self.labels[label], dtype=torch.long)
        return file_bytes, label, filename

    def read_bytes(self, filename):
        """Return the patch-aligned token list (BOS + content + EOS) for ``filename``."""
        ext = filename.split('.')[-1]
        # BOS patch: extension bytes, truncated to one patch and padded with 256.
        ext_tokens = list(ext.encode('utf-8'))[:self.patch_size]
        bos_patch = ext_tokens + [256] * (self.patch_size - len(ext_tokens))

        with open(filename, 'rb') as f:
            content = list(f.read())

        # Pad the content up to the next patch boundary (no-op if already aligned).
        remainder = len(content) % self.patch_size
        if remainder:
            content += [256] * (self.patch_size - remainder)

        eos_patch = [256] * self.patch_size
        return bos_patch + content + eos_patch

    # Backward-compatible alias for the original method name.
    readbytes = read_bytes
|
|
|
|
class ParticleType(enum.Enum):
    """Particle species assigned to each jet constituent.

    Values 0-6 follow the order of the ``part_is*`` flag branches that
    ``JetSet.build_jetset`` reads; ``Others`` (7) is the fallback species.
    """

    # Neutral species
    NeutralHadron = 0
    Photon = 1
    # Leptons
    Electron = 2
    Muon = 3
    # Charged hadrons
    Pion = 4
    ChargedKaon = 5
    Proton = 6
    # Fallback
    Others = 7
|
|
|
|
@attrs.define
class ParticleBase:
    """One constituent particle of a jet, with kinematic and identification features.

    Field order matters: it is the positional ``__init__`` order generated by
    attrs and the default attribute order used by the serialisation helpers.
    """

    particle_id: int = attrs.field()  # identifier supplied by the builder; intended to be unique per particle
    part_charge: int = attrs.field()  # charge of the particle
    part_energy: float = attrs.field()  # energy of the particle
    part_px: float = attrs.field()  # x-component of the momentum vector
    part_py: float = attrs.field()  # y-component of the momentum vector
    part_pz: float = attrs.field()  # z-component of the momentum vector
    log_energy: float = attrs.field()  # log(part_energy); the builder in this file uses the natural log (np.log)
    log_pt: float = attrs.field()  # log(part_pt); natural log in the builder
    part_deta: float = attrs.field()  # pseudorapidity difference (combined with part_dphi into part_deltaR)
    part_dphi: float = attrs.field()  # azimuthal-angle difference
    part_logptrel: float = attrs.field()  # log(pt(particle) / pt(jet)), natural log in the builder
    part_logerel: float = attrs.field()  # log(energy(particle) / energy(jet)), natural log in the builder
    part_deltaR: float = attrs.field()  # hypot(part_deta, part_dphi): distance between the particle and the jet
    part_d0: float = attrs.field()  # tanh(d0)
    part_dz: float = attrs.field()  # tanh(dz)
    particle_type: ParticleType = attrs.field()  # species of the particle as an enum
    particle_pid: int = attrs.field()  # integer id of the species (the builder stores particle_type.value)
    jet_type: str = attrs.field()  # type of the parent jet (e.g. "b jet", "bbbar jet"); excluded from default output

    def _selected_values(self, selected_attrs: Optional[List[str]]) -> list:
        """Collect the values of ``selected_attrs``, with ParticleType members replaced by their int value.

        When ``selected_attrs`` is None, every field except ``jet_type`` is used,
        in declaration order.
        """
        if selected_attrs is None:
            # attrs.fields() requires the class; passing the instance raises TypeError.
            selected_attrs = [field.name for field in attrs.fields(type(self)) if field.name != "jet_type"]
        raw = (getattr(self, name) for name in selected_attrs)
        return [v.value if isinstance(v, ParticleType) else v for v in raw]

    def properties_concatenated(self, selected_attrs: Optional[List[str]] = None, attr_sep: str = ',') -> str:
        """Return the selected attribute values as ``str`` joined by ``attr_sep``.

        Args:
            selected_attrs: attribute names to emit, in order; None means every
                field except ``jet_type``.
            attr_sep: separator placed between values.
        """
        return attr_sep.join(map(str, self._selected_values(selected_attrs)))

    def attributes_to_float_list(self, selected_attrs: Optional[List[str]] = None) -> list:
        """Return the selected attribute values as a list.

        Integer values (Python or numpy integers, including the ParticleType
        value) are converted to ``float``; other values are returned unchanged.

        Args:
            selected_attrs: attribute names to emit, in order; None means every
                field except ``jet_type``.
        """
        return [
            float(v) if isinstance(v, (int, np.integer)) else v
            for v in self._selected_values(selected_attrs)
        ]
|
|
|
|
@attrs.define
class Jet:
    """A single jet: jet-level kinematics plus the list of its constituent particles."""

    jet_energy: float = attrs.field()  # total energy of the jet
    jet_pt: float = attrs.field()  # transverse momentum of the jet
    jet_eta: float = attrs.field()  # pseudorapidity of the jet
    label: str = attrs.field()  # jet class label (bb or bbbar)
    particles: List[ParticleBase] = attrs.field(factory=list)  # constituent particles

    def __len__(self):
        """Number of constituent particles in the jet."""
        return len(self.particles)

    def particles_concatenated(self, selected_attrs: Optional[List[str]] = None, attr_sep: str = ',', part_sep: str = '|') -> str:
        """Serialise all particles into one string.

        Each particle becomes ``ParticleBase.properties_concatenated(selected_attrs, attr_sep)``;
        particles are then joined with ``part_sep``.
        """
        rendered = (
            particle.properties_concatenated(selected_attrs, attr_sep)
            for particle in self.particles
        )
        return part_sep.join(rendered)

    def particles_attribute_list(self, selected_attrs: Optional[List[str]] = None) -> list:
        """Return one ``ParticleBase.attributes_to_float_list(selected_attrs)`` list per particle."""
        return [particle.attributes_to_float_list(selected_attrs) for particle in self.particles]
|
|
|
|
@attrs.define
class JetSet:
    """All jets read from one ROOT file, with helpers to build and export them."""

    jets_type: str = attrs.field()  # type of the jets (e.g. bb, bbbar)
    jets: List[Jet] = attrs.field(factory=list)  # jets in file order

    def __len__(self):
        """Number of jets in the set."""
        return len(self.jets)

    @staticmethod
    def jud_type(jtmp):
        """Return the ParticleType selected by a sequence of per-species flags.

        ``jtmp`` holds one value per species in the order NeutralHadron, Photon,
        Electron, Muon, Pion, ChargedKaon, Proton (the order ``build_jetset``
        uses). The species with the largest flag wins; ties go to the first one.
        If ``jtmp`` is empty or no flag is positive, ``ParticleType.Others`` is
        returned.
        """
        particle_dict = {member.value: member for member in ParticleType}
        flags = list(jtmp)
        if not flags:
            return ParticleType.Others
        max_element = max(flags)
        # All-zero flags mean none of the listed species matched; without this
        # check max()/index() would pick index 0 and label it NeutralHadron.
        if not max_element > 0:
            return ParticleType.Others
        idx = flags.index(max_element)
        return particle_dict.get(idx, ParticleType.Others)

    @classmethod
    def build_jetset(cls, root_file: Union[str, Path]) -> "JetSet":
        """Read the ``tree`` TTree of ``root_file`` into a JetSet.

        The jet class comes from the file name: names containing ``bbbar``
        yield ``bbbar`` jets, other names containing ``bb`` yield ``bb`` jets.

        Args:
            root_file: path (str or Path) of the ROOT file.

        Returns:
            A JetSet with one Jet per tree entry.

        Raises:
            ValueError: if the file name contains neither 'bb' nor 'bbbar', or if
                a jet's per-particle branches have different lengths.
        """
        print(f"Building JetSet from {root_file}")
        stem = Path(root_file).stem  # works for both str and Path inputs
        # 'bbbar' contains 'bb', so this single test accepts both names.
        if 'bb' not in stem:
            raise ValueError("Invalid file name, should contain 'bb' or 'bbbar'")
        is_bbbar = "bbbar" in stem
        jets_type = "bbbar" if is_bbbar else "bb"
        jet_type = "bbbar jet" if is_bbbar else "b jet"
        label = stem.split('_')[0]

        with uproot.open(root_file) as f:
            tree = f["tree"]
            a = tree.arrays(library="pd")

        # Species flag branches, in ParticleType value order (0..6).
        flag_branches = (
            'part_isNeutralHadron', 'part_isPhoton', 'part_isElectron', 'part_isMuon',
            'part_isPion', 'part_isChargedKaon', 'part_isProton',
        )

        jet_list = []
        for i, j in a.iterrows():
            part_pt = np.asarray(j['part_pt'])
            part_energy = np.asarray(j['part_energy'])
            part_deta = np.asarray(j['part_deta'])
            part_dphi = np.asarray(j['part_dphi'])
            part_charge = np.asarray(j['part_charge'])
            part_px = np.asarray(j['part_px'])
            part_py = np.asarray(j['part_py'])
            part_pz = np.asarray(j['part_pz'])
            part_d0 = np.tanh(np.asarray(j['part_d0val']))
            part_dz = np.tanh(np.asarray(j['part_dzval']))
            flags = [np.asarray(j[name]) for name in flag_branches]

            n_particles = len(part_pt)
            per_particle = [part_energy, part_deta, part_dphi, part_charge,
                            part_px, part_py, part_pz, part_d0, part_dz, *flags]
            if any(len(arr) != n_particles for arr in per_particle):
                raise ValueError(
                    f"Jet {i} in {root_file}: per-particle branches have inconsistent lengths"
                )

            # Derived features, vectorised over the jet's particles (natural log).
            part_logptrel = np.log(np.divide(part_pt, np.asarray(j['jet_pt'])))
            part_logerel = np.log(np.divide(part_energy, np.asarray(j['jet_energy'])))
            part_deltaR = np.hypot(part_deta, part_dphi)
            log_energy = np.log(part_energy)
            log_pt = np.log(part_pt)

            particles = []
            for pn in range(n_particles):
                ptype = cls.jud_type([flag[pn] for flag in flags])
                # Every per-particle array is indexed by pn so each particle
                # gets its own scalar values rather than the whole jet's arrays.
                particles.append(ParticleBase(
                    # NOTE(review): the jet row index, shared by all particles of
                    # this jet; confirm whether a per-particle id was intended.
                    particle_id=i,
                    part_charge=part_charge[pn],
                    part_energy=part_energy[pn],
                    part_px=part_px[pn],
                    part_py=part_py[pn],
                    part_pz=part_pz[pn],
                    log_energy=log_energy[pn],
                    log_pt=log_pt[pn],
                    part_deta=part_deta[pn],
                    part_dphi=part_dphi[pn],
                    part_logptrel=part_logptrel[pn],
                    part_logerel=part_logerel[pn],
                    part_deltaR=part_deltaR[pn],
                    part_d0=part_d0[pn],
                    part_dz=part_dz[pn],
                    particle_type=ptype,
                    particle_pid=ptype.value,
                    jet_type=jet_type,
                ))

            jet_list.append(Jet(
                jet_energy=j['jet_energy'],
                jet_pt=j['jet_pt'],
                jet_eta=j['jet_eta'],
                particles=particles,
                label=label,
            ))

        return cls(jets=jet_list, jets_type=jets_type)

    def save_to_binary(self, save_dir, selected_attrs: Optional[List[str]] = None, attr_sep: str = ',', part_sep: str = '|'):
        """Write each jet to ``<save_dir>/<jets_type>_<index>.bin``.

        Each file contains the pickled string produced by
        ``Jet.particles_concatenated(selected_attrs, attr_sep, part_sep)``.

        NOTE(review): file names depend only on ``jets_type`` and the jet index,
        so two JetSets of the same type saved into the same directory (e.g.
        several 'bb' ROOT files processed into one output folder) overwrite each
        other's files; confirm that only one file per type is expected.
        """
        for file_counter, jet in enumerate(tqdm(self.jets)):
            filepath = os.path.join(save_dir, f"{self.jets_type}_{file_counter}.bin")
            jet_data = jet.particles_concatenated(selected_attrs, attr_sep, part_sep)
            with open(filepath, 'wb') as file:
                pickle.dump(jet_data, file)
|
|
|
|
@click.command()
@click.argument('root_dir', type=click.Path(exists=True))
@click.argument('save_dir', type=click.Path())
@click.option('--attr-sep', default=',', help='Separator for attributes.')
@click.option('--part-sep', default='|', help='Separator for particles.')
@click.option('--selected-attrs', default='particle_id,part_charge,part_energy,part_px,part_py,part_pz,log_energy,log_pt,part_deta,part_dphi,part_logptrel,part_logerel,part_deltaR,part_d0,part_dz,particle_type,particle_pid', help='Comma-separated list of selected attributes.')
def main(root_dir, save_dir, attr_sep, part_sep, selected_attrs):
    """Convert every *.root file in ROOT_DIR into per-jet .bin files in SAVE_DIR."""
    # Tolerate spaces and stray commas ("a, b," -> ["a", "b"]), which would
    # otherwise become attribute names that getattr() cannot resolve. An empty
    # selection falls back to None: every ParticleBase field except jet_type.
    names = [name.strip() for name in selected_attrs.split(',') if name.strip()]
    preprocess(root_dir, save_dir, selected_attrs=names or None, attr_sep=attr_sep, part_sep=part_sep)
|
|
|
|
def preprocess(root_dir, save_dir, selected_attrs: Optional[List[str]] = None, attr_sep: str = ',', part_sep: str = '|'):
    """Process every ``*.root`` file directly inside ``root_dir`` in parallel.

    ``save_dir`` is created if missing. Each ROOT file is handled by
    ``process_file`` in a joblib worker, using all available cores.
    """
    sources = [path for path in Path(root_dir).glob('*.root')]
    os.makedirs(save_dir, exist_ok=True)

    jobs = (
        joblib.delayed(process_file)(source, save_dir, selected_attrs, attr_sep, part_sep)
        for source in sources
    )
    runner = joblib.Parallel(n_jobs=-1)
    runner(jobs)
|
|
|
|
def process_file(root_file, save_dir, selected_attrs, attr_sep, part_sep):
    """Build the JetSet for one ROOT file and write its jets into ``save_dir``."""
    JetSet.build_jetset(root_file).save_to_binary(
        save_dir, selected_attrs, attr_sep, part_sep
    )
|
|
|
|
if __name__ == "__main__":
    # Script entry point: hand control to the click command.
    main()
|