#!/usr/bin/env python
# -*- encoding: utf-8 -*-
'''
@file        : data_struct.py
@Description : Data structures and preprocessing for jet/particle ROOT files:
               reads uproot trees, derives per-particle features relative to
               the jet, and serializes jets to pickle or text splits.
@Date        : 2024/06/04 09:08:46
@Author      : hotwa
@version     : 1.0
'''
import os
import pickle
import random
from pathlib import Path
from typing import List

import attrs
import numpy as np
import uproot
from tqdm import tqdm


@attrs.define
class ParticleBase:
    """A single reconstructed particle with features derived relative to its jet."""
    particle_id: int = attrs.field()      # id of the jet (DataFrame row) this particle belongs to
    part_charge: int = attrs.field()      # electric charge of the particle
    part_energy: float = attrs.field()    # energy of the particle
    part_px: float = attrs.field()        # x-component of the momentum vector
    part_py: float = attrs.field()        # y-component of the momentum vector
    part_pz: float = attrs.field()        # z-component of the momentum vector
    log_energy: float = attrs.field()     # ln(part_energy) -- natural log (np.log), not log10
    log_pt: float = attrs.field()         # ln(part_pt) -- natural log (np.log), not log10
    part_deta: float = attrs.field()      # pseudorapidity offset w.r.t. the jet axis
    part_dphi: float = attrs.field()      # azimuthal-angle offset w.r.t. the jet axis
    part_logptrel: float = attrs.field()  # ln(pt(particle) / pt(jet))
    part_logerel: float = attrs.field()   # ln(energy(particle) / energy(jet))
    part_deltaR: float = attrs.field()    # hypot(deta, dphi): angular distance to the jet axis
    part_d0: float = attrs.field()        # tanh(d0) -- squashed transverse impact parameter
    part_dz: float = attrs.field()        # tanh(z0) -- squashed longitudinal impact parameter
    particle_pid: int = attrs.field()     # particle-type index 0..6 (see jud_type)

    def properties_concatenated(self) -> str:
        """Return all field values joined into one space-separated string."""
        return ' '.join(str(value) for value in attrs.astuple(self))

    def attributes_to_float_list(self) -> list:
        """Return all field values as a list, coercing int fields to float."""
        values = []
        for field in attrs.fields(self.__class__):
            value = getattr(self, field.name)
            values.append(float(value) if isinstance(value, int) else value)
        return values


@attrs.define
class Jet:
    """A jet: its kinematics, a class label, and its constituent particles."""
    jet_energy: float = attrs.field()  # energy of the jet
    jet_pt: float = attrs.field()      # transverse momentum of the jet
    jet_eta: float = attrs.field()     # pseudorapidity of the jet
    label: str = attrs.field()         # class label, e.g. 'bb' or 'bbbar'
    particles: List[ParticleBase] = attrs.field(factory=list)  # constituents of the jet

    def __len__(self):
        return len(self.particles)

    def particles_concatenated(self) -> str:
        """Join every particle's space-separated string with ',' between particles."""
        return ','.join(p.properties_concatenated() for p in self.particles)

    def particles_attribute_list(self) -> list:
        """Return one float-list per particle (list of lists)."""
        return [p.attributes_to_float_list() for p in self.particles]


@attrs.define
class JetSet:
    """A collection of jets read from one ROOT file."""
    jets: List[Jet]

    def __len__(self):
        return len(self.jets)


def jud_type(jtmp):
    """Map a 7-element one-hot-like score list to a (type_name, pid) pair.

    The index of the (first) maximum element selects the particle type, in
    the fixed order NeutralHadron=0, Photon=1, Electron=2, Muon=3, Pion=4,
    ChargedKaon=5, Proton=6.
    """
    particle_dict = {'NeutralHadron': 0, 'Photon': 1, 'Electron': 2, 'Muon': 3,
                     'Pion': 4, 'ChargedKaon': 5, 'Proton': 6}
    idx = jtmp.index(max(jtmp))
    name, pid = list(particle_dict.items())[idx]
    return name, pid


def build_jetset(root_file):
    """Read one ROOT file and build a JetSet with per-particle derived features.

    The jet label (e.g. 'bb' or 'bbbar') is taken from the file-name stem
    before the first underscore.
    """
    with uproot.open(root_file) as f:
        tree = f["tree"]
        a = tree.arrays(library="pd")  # pandas DataFrame, one row per jet
    label = Path(root_file).stem.split('_')[0]

    # Columns holding the one-hot particle-type flags, in jud_type's order.
    type_columns = ['part_isNeutralHadron', 'part_isPhoton', 'part_isElectron',
                    'part_isMuon', 'part_isPion', 'part_isChargedKaon', 'part_isProton']

    jet_list = []
    for i, j in a.iterrows():
        part_pt = np.array(j['part_pt'])
        jet_pt = np.array(j['jet_pt'])
        part_logptrel = np.log(np.divide(part_pt, jet_pt))
        part_energy = np.array(j['part_energy'])
        jet_energy = np.array(j['jet_energy'])
        part_logerel = np.log(np.divide(part_energy, jet_energy))
        part_deta = np.array(j['part_deta'])
        part_dphi = np.array(j['part_dphi'])
        part_deltaR = np.hypot(part_deta, part_dphi)
        assert len(j['part_pt']) == len(j['part_energy']) == len(j['part_deta'])

        particles = []
        for pn, (lptrel, lerel, dr) in enumerate(zip(part_logptrel, part_logerel,
                                                     part_deltaR)):
            _, pid = jud_type([j[col][pn] for col in type_columns])
            particles.append(ParticleBase(
                particle_id=i,
                # BUGFIX: index each per-particle column with [pn]; the
                # original stored the whole per-jet array in every particle.
                part_charge=j['part_charge'][pn],
                part_energy=j['part_energy'][pn],
                part_px=j['part_px'][pn],
                part_py=j['part_py'][pn],
                part_pz=j['part_pz'][pn],
                log_energy=np.log(j['part_energy'][pn]),
                log_pt=np.log(j['part_pt'][pn]),
                part_deta=j['part_deta'][pn],
                part_dphi=j['part_dphi'][pn],
                part_logptrel=lptrel,
                part_logerel=lerel,
                part_deltaR=dr,
                part_d0=np.tanh(j['part_d0val'][pn]),
                part_dz=np.tanh(j['part_dzval'][pn]),
                particle_pid=pid,
            ))

        jet_list.append(Jet(
            jet_energy=j['jet_energy'],
            jet_pt=j['jet_pt'],    # BUGFIX: was the literal list ['jet_pt']
            jet_eta=j['jet_eta'],  # BUGFIX: was the literal list ['jet_eta']
            particles=particles,
            label=label,
        ))
    return JetSet(jets=jet_list)


def _export_jetsets(files, file_limit, out_dir, as_float_list, file_counter):
    """Serialize jets from up to `file_limit` ROOT files into `out_dir`.

    Each jet becomes one output file named '<label>_<counter>.pkl' (pickled
    list of per-particle float lists) when `as_float_list` is true, otherwise
    '<label>_<counter>.txt' (concatenated string).  `file_counter` is a
    running counter shared across calls; the updated value is returned.
    """
    kind = 'pkl' if as_float_list else 'txt'
    for root_file in files[:file_limit]:
        print(f"loading {root_file} to {kind}")
        jetset = build_jetset(root_file)
        for jet in tqdm(jetset.jets):
            file_counter += 1
            if as_float_list:
                out_path = os.path.join(out_dir, f"{jet.label}_{file_counter}.pkl")
                # 'fh', not 'file': the original shadowed its loop variable here
                with open(out_path, 'wb') as fh:
                    pickle.dump(jet.particles_attribute_list(), fh)
            else:
                out_path = os.path.join(out_dir, f"{jet.label}_{file_counter}.txt")
                with open(out_path, 'w') as fh:
                    fh.write(jet.particles_concatenated())
    return file_counter


def preprocess(root_dir, method='float_list',
               pkl_out='/data/slow/100w_pkl', txt_out='/data/slow/100w',
               train_ratio=0.8, train_file_limit=32, val_file_limit=8):
    """Walk `root_dir` for 'bb' / 'bbbar' class directories and export splits.

    ROOT files under each class directory are shuffled and split into train
    and validation sets; at most `train_file_limit` / `val_file_limit` files
    per split are processed.  With method='float_list' jets are pickled under
    `pkl_out`/{train,val}; any other method writes text under
    `txt_out`/{train,val}.  Output directories must already exist.
    """
    for root, dirs, _ in os.walk(root_dir):
        for dir_name in dirs:
            if dir_name not in ('bb', 'bbbar'):
                continue
            class_dir = os.path.join(root, dir_name)
            root_files = []
            for dirpath, _, filenames in os.walk(class_dir):
                for fname in filenames:
                    if fname.endswith('.root'):
                        # BUGFIX: join with the walk's own dirpath; the
                        # original joined root/dir_name even for files found
                        # in nested subdirectories, yielding wrong paths.
                        root_files.append(os.path.join(dirpath, fname))
            if not root_files:
                continue

            random.shuffle(root_files)
            split_index = int(len(root_files) * train_ratio)
            train_files = root_files[:split_index]
            val_files = root_files[split_index:]
            print('len(train_files):', len(train_files))
            print('len(val_files):', len(val_files))

            as_float_list = (method == 'float_list')
            base_dir = pkl_out if as_float_list else txt_out
            # The counter keeps incrementing across the two splits, matching
            # the original numbering scheme.
            counter = _export_jetsets(train_files, train_file_limit,
                                      os.path.join(base_dir, 'train'),
                                      as_float_list, 0)
            _export_jetsets(val_files, val_file_limit,
                            os.path.join(base_dir, 'val'),
                            as_float_list, counter)


if __name__ == '__main__':
    root_dir = "/data/particle_raw/slow/data_100w/n2n2higgs"
    preprocess(root_dir, method='str')