260 lines
12 KiB
Python
260 lines
12 KiB
Python
#!/usr/bin/env python
|
||
# -*- encoding: utf-8 -*-
|
||
'''
|
||
@file :data_struct.py
|
||
@Description: :
|
||
@Date :2024/06/04 09:08:46
|
||
@Author :hotwa
|
||
@version :1.0
|
||
'''
|
||
|
||
import attrs
|
||
from typing import List
|
||
import numpy as np
|
||
import uproot
|
||
from pathlib import Path
|
||
from tqdm import tqdm
|
||
import os
|
||
import random
|
||
import pickle
|
||
@attrs.define
class ParticleBase:
    """A single particle inside a jet, with kinematic and identification features."""

    particle_id: int = attrs.field()      # unique id for each particle
    part_charge: int = attrs.field()      # electric charge of the particle
    part_energy: float = attrs.field()    # energy of the particle
    part_px: float = attrs.field()        # x-component of the momentum vector
    part_py: float = attrs.field()        # y-component of the momentum vector
    part_pz: float = attrs.field()        # z-component of the momentum vector
    log_energy: float = attrs.field()     # log of part_energy (builder uses np.log, i.e. natural log)
    log_pt: float = attrs.field()         # log of part_pt (builder uses np.log, i.e. natural log)
    part_deta: float = attrs.field()      # pseudorapidity
    part_dphi: float = attrs.field()      # azimuthal angle
    part_logptrel: float = attrs.field()  # log(pt(particle) / pt(jet))
    part_logerel: float = attrs.field()   # log(energy(particle) / energy(jet))
    part_deltaR: float = attrs.field()    # distance between the particle and the jet
    part_d0: float = attrs.field()        # tanh(d0)
    part_dz: float = attrs.field()        # tanh(z0)
    # particle_type: str = attrs.field()  # type of the particle (e.g. charged kaon, charged pion,
    #                                     # proton, electron, muon, neutral hadron, photon, others)
    particle_pid: int = attrs.field()     # pid of the particle (e.g. 0,1,2,3,4,5,6,7)

    def properties_concatenated(self) -> str:
        """Join all attribute values into one space-separated string."""
        return ' '.join(str(value) for value in attrs.astuple(self))

    def attributes_to_float_list(self) -> list:
        """Return attribute values in field order, promoting int values to float."""
        raw_values = (getattr(self, f.name) for f in attrs.fields(self.__class__))
        return [float(v) if isinstance(v, int) else v for v in raw_values]
@attrs.define
class Jet:
    """A jet: summary kinematics plus the list of its constituent particles."""

    jet_energy: float = attrs.field()  # energy of the jet
    jet_pt: float = attrs.field()      # transverse momentum of the jet
    jet_eta: float = attrs.field()     # pseudorapidity of the jet
    label: str = attrs.field()         # sample type, e.g. 'bb' or 'bbbar'
    particles: List[ParticleBase] = attrs.field(factory=list)  # particles in the jet

    def __len__(self):
        """Number of particles in the jet."""
        return len(self.particles)

    def particles_concatenated(self) -> str:
        """Join every particle's space-separated attributes, with ',' between particles."""
        per_particle = (p.properties_concatenated() for p in self.particles)
        return ','.join(per_particle)

    def particles_attribute_list(self) -> list:
        """Return one float-list per particle (see ParticleBase.attributes_to_float_list)."""
        return [p.attributes_to_float_list() for p in self.particles]
@attrs.define
class JetSet:
    """A collection of jets, typically all jets read from one ROOT file."""

    jets: List[Jet]  # the jets in this set

    def __len__(self):
        """Number of jets in the set."""
        return len(self.jets)
def jud_type(jtmp):
    """Resolve a particle's type from its per-type score list.

    ``jtmp`` holds one score per particle type in the fixed order below
    (matching the ``part_is*`` branch order used by the caller). The entry
    with the highest score wins; on ties, the first occurrence wins.

    Returns a ``(type_name, pid)`` tuple where ``pid`` is the winning
    position in that fixed order.
    """
    type_names = ('NeutralHadron', 'Photon', 'Electron', 'Muon',
                  'Pion', 'ChargedKaon', 'Proton')
    # max() over enumerate keeps the first index among equal maxima,
    # mirroring list.index(max(...)) in behaviour.
    pid, _ = max(enumerate(jtmp), key=lambda pair: pair[1])
    return type_names[pid], pid
def build_jetset(root_file):
    """Read one ROOT file and build a JetSet with one Jet per tree row.

    Parameters
    ----------
    root_file : str or Path
        Path to a ROOT file containing a TTree named "tree". The sample
        label (e.g. 'bb' / 'bbbar') is taken from the file-name stem
        before the first '_'.

    Returns
    -------
    JetSet
        One Jet per DataFrame row, each holding per-particle ParticleBase
        records.

    Raises
    ------
    ValueError
        If a row's particle arrays have inconsistent lengths.
    """
    with uproot.open(root_file) as f:
        tree = f["tree"]
        a = tree.arrays(library="pd")  # pd.DataFrame: one row per jet

    # label comes from the file name, e.g. "bb_123.root" -> "bb"
    label = Path(root_file).stem.split('_')[0]

    # fixed branch order: must match the pid mapping inside jud_type()
    particle_list = ['part_isNeutralHadron', 'part_isPhoton', 'part_isElectron',
                     'part_isMuon', 'part_isPion', 'part_isChargedKaon',
                     'part_isProton']

    jet_list = []
    for i, j in a.iterrows():
        part_pt = np.array(j['part_pt'])
        jet_pt = np.array(j['jet_pt'])
        # NOTE(review): np.log is the natural log although some field
        # names/comments mention log10 — kept as-is to preserve behaviour.
        part_logptrel = np.log(np.divide(part_pt, jet_pt))

        part_energy = np.array(j['part_energy'])
        jet_energy = np.array(j['jet_energy'])
        part_logerel = np.log(np.divide(part_energy, jet_energy))

        part_deta = np.array(j['part_deta'])
        part_dphi = np.array(j['part_dphi'])
        part_deltaR = np.hypot(part_deta, part_dphi)

        # explicit validation instead of assert (stripped under python -O)
        if not (len(j['part_pt']) == len(j['part_energy']) == len(j['part_deta'])):
            raise ValueError(f"inconsistent particle array lengths in row {i}")

        # classify each particle from its one-hot 'part_is*' flags
        part_type = []
        part_pid = []
        for pn in range(len(j['part_pt'])):
            scores = [j[t][pn] for t in particle_list]
            tmp_type, tmp_pid = jud_type(scores)
            part_type.append(tmp_type)
            part_pid.append(tmp_pid)

        # Build the per-particle records. Bug fix: the original passed the
        # whole per-jet arrays (j['part_charge'], j['part_px'], ...) to every
        # particle; index pn selects this particle's scalar value instead.
        particles = []
        for pn, (ii, jj, kk, ptype, pid) in enumerate(
                zip(part_logptrel, part_logerel, part_deltaR,
                    part_type, part_pid)):
            particles.append(ParticleBase(
                particle_id=i,
                part_charge=j['part_charge'][pn],
                part_energy=j['part_energy'][pn],
                part_px=j['part_px'][pn],
                part_py=j['part_py'][pn],
                part_pz=j['part_pz'][pn],
                log_energy=np.log(j['part_energy'][pn]),
                log_pt=np.log(j['part_pt'][pn]),
                part_deta=j['part_deta'][pn],
                part_dphi=j['part_dphi'][pn],
                part_logptrel=ii,
                part_logerel=jj,
                part_deltaR=kk,
                part_d0=np.tanh(j['part_d0val'][pn]),
                part_dz=np.tanh(j['part_dzval'][pn]),
                # particle_type=ptype,
                particle_pid=pid,
            ))

        # Bug fix: the original assigned the literal lists ['jet_pt'] and
        # ['jet_eta'] here — the DataFrame lookup j[...] was missing.
        jet_list.append(Jet(
            jet_energy=j['jet_energy'],
            jet_pt=j['jet_pt'],
            jet_eta=j['jet_eta'],
            particles=particles,
            label=label,
        ))

    return JetSet(jets=jet_list)
def preprocess(root_dir, method='float_list'):
    """Walk root_dir and dump every jet to its own train/val output file.

    For each 'bb' or 'bbbar' sub-directory found under root_dir, its .root
    files are shuffled and split 80/20 into train/val, capped at 32 train
    and 8 val files. Each jet is then written to one output file:

    - method == 'float_list': pickled list-of-float-lists under
      /data/slow/100w_pkl/{train,val}/<label>_<n>.pkl
    - any other method: space/comma-separated text under
      /data/slow/100w/{train,val}/<label>_<n>.txt

    Parameters
    ----------
    root_dir : str
        Directory tree containing 'bb' / 'bbbar' sub-directories of .root files.
    method : str
        Output format selector (see above). Default 'float_list'.
    """
    train_ratio = 0.8
    # limits on how many input files feed each split
    train_file_limit = 32
    val_file_limit = 8

    for root, dirs, _ in os.walk(root_dir):
        for dir_name in dirs:
            if dir_name not in ('bb', 'bbbar'):
                continue

            # collect every .root file under this sub-directory
            root_files = []
            for sub_root, _, sub_files in os.walk(os.path.join(root, dir_name)):
                for fname in sub_files:
                    if fname.endswith('.root'):
                        root_files.append(os.path.join(sub_root, fname))
            if not root_files:
                continue

            # NOTE(review): unseeded shuffle — the split is not reproducible
            # across runs; seed random if determinism is required.
            random.shuffle(root_files)
            split_index = int(len(root_files) * train_ratio)
            train_files = root_files[:split_index]
            val_files = root_files[split_index:]
            print('len(train_files):', len(train_files))
            print('len(val_files):', len(val_files))

            if method == "float_list":
                counter = _dump_split(train_files, train_file_limit,
                                      '/data/slow/100w_pkl/train', as_pickle=True)
                _dump_split(val_files, val_file_limit,
                            '/data/slow/100w_pkl/val', as_pickle=True,
                            start_counter=counter)
            else:
                counter = _dump_split(train_files, train_file_limit,
                                      '/data/slow/100w/train', as_pickle=False)
                _dump_split(val_files, val_file_limit,
                            '/data/slow/100w/val', as_pickle=False,
                            start_counter=counter)


def _dump_split(file_paths, file_limit, out_dir, as_pickle, start_counter=0):
    """Dump jets from at most file_limit ROOT files into out_dir.

    Each jet becomes one output file named '<label>_<counter>.pkl' (pickled
    float lists when as_pickle) or '<label>_<counter>.txt' (concatenated
    string otherwise). Returns the final jet counter so numbering can
    continue across splits (matching the original's shared file_counter).
    """
    counter = start_counter
    # Bug fix: the original reused the loop variable name 'file' as the
    # open-file handle inside the loop; distinct names are used here.
    for path in file_paths[:file_limit]:
        print(f"loading {path} to {out_dir}")
        jetset = build_jetset(path)
        for jet in tqdm(jetset.jets):
            counter += 1
            if as_pickle:
                out_path = os.path.join(out_dir, f"{jet.label}_{counter}.pkl")
                with open(out_path, 'wb') as fh:
                    pickle.dump(jet.particles_attribute_list(), fh)
            else:
                out_path = os.path.join(out_dir, f"{jet.label}_{counter}.txt")
                with open(out_path, 'w') as fh:
                    fh.write(jet.particles_concatenated())
    return counter
if __name__ == '__main__':
    # Entry point: convert the raw n2n2higgs sample into per-jet text files.
    root_dir = "/data/particle_raw/slow/data_100w/n2n2higgs"
    preprocess(root_dir, method='str')