first add

This commit is contained in:
2024-06-10 09:09:10 +08:00
parent 52d5f402bc
commit 8f9fac8bd8
14 changed files with 2042 additions and 0 deletions

260
data_struct.py Normal file
View File

@@ -0,0 +1,260 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
'''
@file :data_struct.py
@Description: :
@Date :2024/06/04 09:08:46
@Author :hotwa
@version :1.0
'''
import attrs
from typing import List
import numpy as np
import uproot
from pathlib import Path
from tqdm import tqdm
import os
import random
import pickle
@attrs.define
class ParticleBase:
    """Feature record for a single jet constituent particle."""
    particle_id: int = attrs.field()      # unique id for each particle
    part_charge: int = attrs.field()      # electric charge of the particle
    part_energy: float = attrs.field()    # particle energy
    part_px: float = attrs.field()        # momentum x-component
    part_py: float = attrs.field()        # momentum y-component
    part_pz: float = attrs.field()        # momentum z-component
    log_energy: float = attrs.field()     # log(part_energy)
    log_pt: float = attrs.field()         # log(part_pt)
    part_deta: float = attrs.field()      # pseudorapidity relative to the jet axis
    part_dphi: float = attrs.field()      # azimuthal angle relative to the jet axis
    part_logptrel: float = attrs.field()  # log(pt(particle) / pt(jet))
    part_logerel: float = attrs.field()   # log(energy(particle) / energy(jet))
    part_deltaR: float = attrs.field()    # distance between the particle and the jet
    part_d0: float = attrs.field()        # tanh(d0)
    part_dz: float = attrs.field()        # tanh(z0)
    #particle_type: str = attrs.field()   # type of the particle (e.g. charged kaon, charged pion, proton, electron, muon, neutral hadron, photon, others)
    particle_pid: int = attrs.field()     # pid of the particle (0..7 category code)

    def properties_concatenated(self) -> str:
        """Return every field value joined into one space-separated string."""
        return ' '.join(str(value) for value in attrs.astuple(self))

    def attributes_to_float_list(self) -> list:
        """Return the field values as a list, promoting ints to floats."""
        raw_values = (getattr(self, f.name) for f in attrs.fields(type(self)))
        return [float(v) if isinstance(v, int) else v for v in raw_values]
@attrs.define
class Jet:
    """A jet: its kinematics, a label, and the list of constituent particles."""
    jet_energy: float = attrs.field()  # energy of the jet
    jet_pt: float = attrs.field()      # transverse momentum of the jet
    jet_eta: float = attrs.field()     # pseudorapidity of the jet
    label: str = attrs.field()         # jet type, e.g. 'bb' or 'bbbar'
    particles: List[ParticleBase] = attrs.field(factory=list)  # constituents

    def __len__(self):
        """Number of constituent particles."""
        return len(self.particles)

    def particles_concatenated(self) -> str:
        """Join each particle's space-separated feature string with commas."""
        return ','.join(p.properties_concatenated() for p in self.particles)

    def particles_attribute_list(self) -> list:
        """Per-particle feature value lists (ints promoted to floats)."""
        return [p.attributes_to_float_list() for p in self.particles]
@attrs.define
class JetSet:
    """A collection of jets, typically built from one ROOT file."""
    jets: List[Jet]  # all jets in this set

    def __len__(self):
        """Number of jets in the set."""
        return len(self.jets)
def jud_type(jtmp):
    """Resolve a list of per-type scores to a (type name, pid) pair.

    The position of the largest score in ``jtmp`` selects an entry from the
    fixed particle-type table; ties pick the first maximum, matching
    ``list.index`` semantics.
    """
    particle_dict = {'NeutralHadron':0,'Photon':1, 'Electron':2, 'Muon':3, 'Pion':4,'ChargedKaon':5, 'Proton':6}
    # Position of the first maximum selects the particle category.
    winner = jtmp.index(max(jtmp))
    name = list(particle_dict)[winner]
    return name, particle_dict[name]
def build_jetset(root_file):
    """Read a ROOT file and build a :class:`JetSet` with derived features.

    Parameters
    ----------
    root_file : str or Path
        Path to a ROOT file containing a TTree named ``"tree"``.  The jet
        label is taken from the first ``_``-separated token of the file stem
        (e.g. ``bb_0.root`` -> ``'bb'``).

    Returns
    -------
    JetSet
        One :class:`Jet` per tree entry, each holding per-particle
        :class:`ParticleBase` records with derived log/relative features.
    """
    with uproot.open(root_file) as f:
        tree = f["tree"]
        a = tree.arrays(library="pd")  # pd.DataFrame, one row per jet
    label = Path(root_file).stem.split('_')[0]
    particle_list = ['part_isNeutralHadron','part_isPhoton', 'part_isElectron', 'part_isMuon', 'part_isPion','part_isChargedKaon', 'part_isProton']
    jet_list = []
    for i, j in a.iterrows():
        part_pt = np.array(j['part_pt'])
        jet_pt = np.array(j['jet_pt'])
        part_logptrel = np.log(np.divide(part_pt, jet_pt))
        part_energy = np.array(j['part_energy'])
        jet_energy = np.array(j['jet_energy'])
        part_logerel = np.log(np.divide(part_energy, jet_energy))
        part_deta = np.array(j['part_deta'])
        part_dphi = np.array(j['part_dphi'])
        part_deltaR = np.hypot(part_deta, part_dphi)  # dR = sqrt(deta^2 + dphi^2)
        assert len(j['part_pt']) == len(j['part_energy']) == len(j['part_deta'])
        # Classify each particle from its one-hot "is-type" flags.
        part_type = []
        part_pid = []
        for pn in range(len(j['part_pt'])):
            tmp_type, tmp_pid = jud_type([j[t][pn] for t in particle_list])
            part_type.append(tmp_type)
            part_pid.append(tmp_pid)
        particles = []
        for pn, (ii, jj, kk, ptype, pid) in enumerate(
                zip(part_logptrel, part_logerel, part_deltaR, part_type, part_pid)):
            particles.append(ParticleBase(
                particle_id=i,
                # BUG FIX: index the per-jet arrays with pn so each particle
                # stores its own scalar values; previously the entire array
                # was assigned to every particle's scalar fields.
                part_charge=j['part_charge'][pn],
                part_energy=j['part_energy'][pn],
                part_px=j['part_px'][pn],
                part_py=j['part_py'][pn],
                part_pz=j['part_pz'][pn],
                log_energy=np.log(j['part_energy'][pn]),
                log_pt=np.log(j['part_pt'][pn]),
                part_deta=j['part_deta'][pn],
                part_dphi=j['part_dphi'][pn],
                part_logptrel=ii,
                part_logerel=jj,
                part_deltaR=kk,
                part_d0=np.tanh(j['part_d0val'][pn]),
                part_dz=np.tanh(j['part_dzval'][pn]),
                #particle_type=ptype,  # field currently disabled on ParticleBase
                particle_pid=pid
            ))
        jet = Jet(
            jet_energy=j['jet_energy'],
            # BUG FIX: these were the literal lists ['jet_pt'] / ['jet_eta'],
            # not the values read from the tree row.
            jet_pt=j['jet_pt'],
            jet_eta=j['jet_eta'],
            particles=particles,
            label=label
        )
        jet_list.append(jet)
    return JetSet(jets=jet_list)
def _collect_root_files(dir_path):
    """Recursively collect paths of all ``.root`` files under *dir_path*."""
    found = []
    for walk_root, _dirs, walk_files in os.walk(dir_path):
        for fname in walk_files:
            if fname.endswith('.root'):
                # BUG FIX: join with the walk's own root so files in nested
                # subdirectories get a correct path; the original discarded
                # the inner os.walk root and joined against the top directory.
                found.append(os.path.join(walk_root, fname))
    return found


def _export_jetsets(root_files, file_limit, out_dir, use_pickle, counter, tag):
    """Export jets from at most *file_limit* ROOT files, one output per jet.

    Parameters
    ----------
    root_files : list[str]
        Candidate ROOT file paths.
    file_limit : int
        Maximum number of ROOT files to process.
    out_dir : str
        Directory receiving the per-jet output files.
    use_pickle : bool
        True -> pickle each jet's float-list features (.pkl);
        False -> write each jet's concatenated string features (.txt).
    counter : int
        Starting value of the running per-jet file counter.
    tag : str
        Human-readable destination name for progress logging.

    Returns
    -------
    int
        The updated counter, so a subsequent call can continue numbering.
    """
    for root_file in root_files[:file_limit]:
        print(f"loading {root_file} to {tag}")
        jetset = build_jetset(root_file)
        for jet in tqdm(jetset.jets):
            counter += 1
            suffix = '.pkl' if use_pickle else '.txt'
            filepath = os.path.join(out_dir, f"{jet.label}_{counter}{suffix}")
            # BUG FIX: the original reused the loop variable `file` as the
            # open-file handle, shadowing the path being iterated.
            if use_pickle:
                with open(filepath, 'wb') as fh:
                    pickle.dump(jet.particles_attribute_list(), fh)
            else:
                with open(filepath, 'w') as fh:
                    fh.write(jet.particles_concatenated())
    return counter


def preprocess(root_dir, method='float_list'):
    """Convert ROOT samples under *root_dir* into per-jet train/val files.

    Walks *root_dir* looking for 'bb' / 'bbbar' sample directories, shuffles
    their ROOT files into an 80/20 train/validation split, and exports at
    most 32 train and 8 validation files per directory.  With
    ``method='float_list'`` jets are pickled under ``/data/slow/100w_pkl``;
    any other method writes text under ``/data/slow/100w``.  The per-jet
    file counter runs on from train into validation, as before.
    """
    train_ratio = 0.8
    train_file_limit = 32
    val_file_limit = 8
    for root, dirs, _files in os.walk(root_dir):
        for dir_name in dirs:
            if dir_name not in ('bb', 'bbbar'):
                continue
            candidates = _collect_root_files(os.path.join(root, dir_name))
            if not candidates:  # nothing to export from this directory
                continue
            random.shuffle(candidates)
            split_index = int(len(candidates) * train_ratio)
            train_files = candidates[:split_index]
            val_files = candidates[split_index:]
            print('len(train_files):', len(train_files))
            print('len(val_files):', len(val_files))
            use_pickle = (method == 'float_list')
            base_dir = '/data/slow/100w_pkl' if use_pickle else '/data/slow/100w'
            kind = 'pkl' if use_pickle else 'txt'
            counter = _export_jetsets(train_files, train_file_limit,
                                      os.path.join(base_dir, 'train'),
                                      use_pickle, 0, f'train {kind}')
            _export_jetsets(val_files, val_file_limit,
                            os.path.join(base_dir, 'val'),
                            use_pickle, counter, f'val {kind}')
if __name__ == '__main__':
    # Script entry point: export the raw ROOT samples as text files
    # (method 'str' selects the .txt branch rather than pickling).
    preprocess("/data/particle_raw/slow/data_100w/n2n2higgs", method='str')