Files
particle_analyse/data_struct.py
2024-06-10 09:09:10 +08:00

260 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
'''
@file :data_struct.py
@Description: data structures and preprocessing for particle/jet ROOT samples
@Date :2024/06/04 09:08:46
@Author :hotwa
@version :1.0
'''
import attrs
from typing import List
import numpy as np
import uproot
from pathlib import Path
from tqdm import tqdm
import os
import random
import pickle
@attrs.define
class ParticleBase:
    """One particle inside a jet, with raw kinematics and derived features.

    Values are filled by ``build_jetset``; note that the log-features are
    computed there with ``np.log`` (natural log), not log10 as some older
    notes suggested.
    """
    particle_id: int = attrs.field()      # id shared by all particles of one event row
    part_charge: int = attrs.field()      # electric charge
    part_energy: float = attrs.field()    # particle energy
    part_px: float = attrs.field()        # momentum x-component
    part_py: float = attrs.field()        # momentum y-component
    part_pz: float = attrs.field()        # momentum z-component
    log_energy: float = attrs.field()     # log(part_energy)
    log_pt: float = attrs.field()         # log(part_pt)
    part_deta: float = attrs.field()      # pseudorapidity offset w.r.t. the jet
    part_dphi: float = attrs.field()      # azimuthal-angle offset w.r.t. the jet
    part_logptrel: float = attrs.field()  # log(pt(particle) / pt(jet))
    part_logerel: float = attrs.field()   # log(E(particle) / E(jet))
    part_deltaR: float = attrs.field()    # hypot(deta, dphi): distance to the jet axis
    part_d0: float = attrs.field()        # tanh(d0 impact parameter)
    part_dz: float = attrs.field()        # tanh(z0 impact parameter)
    particle_pid: int = attrs.field()     # species id 0..6 (see jud_type)

    def properties_concatenated(self) -> str:
        """Return every field value joined into one space-separated string."""
        return ' '.join(str(value) for value in attrs.astuple(self))

    def attributes_to_float_list(self) -> list:
        """Return the field values in declaration order, ints promoted to float."""
        raw = (getattr(self, f.name) for f in attrs.fields(self.__class__))
        return [float(v) if isinstance(v, int) else v for v in raw]
@attrs.define
class Jet:
    """A jet together with its constituent particles."""
    jet_energy: float = attrs.field()  # total jet energy
    jet_pt: float = attrs.field()      # jet transverse momentum
    jet_eta: float = attrs.field()     # jet pseudorapidity
    label: str = attrs.field()         # sample label, e.g. 'bb' or 'bbbar'
    particles: List[ParticleBase] = attrs.field(factory=list)  # constituents

    def __len__(self) -> int:
        """Number of constituent particles."""
        return len(self.particles)

    def particles_concatenated(self) -> str:
        """Comma-join each particle's space-separated property string."""
        per_particle = (p.properties_concatenated() for p in self.particles)
        return ','.join(per_particle)

    def particles_attribute_list(self) -> list:
        """Return one list of float feature values per particle."""
        return [p.attributes_to_float_list() for p in self.particles]
@attrs.define
class JetSet:
    """A plain container for a collection of jets."""
    jets: List[Jet]  # all jets read from one ROOT file

    def __len__(self) -> int:
        """Number of jets in the set."""
        return len(self.jets)
def jud_type(jtmp):
    """Resolve a list of particle-type flags to a ``(name, pid)`` pair.

    ``jtmp`` holds one flag per species in the fixed order below; the index
    of the (first) largest flag selects the species. Returns the species
    name and its integer pid.
    """
    name_to_pid = {'NeutralHadron': 0, 'Photon': 1, 'Electron': 2, 'Muon': 3,
                   'Pion': 4, 'ChargedKaon': 5, 'Proton': 6}
    winner = jtmp.index(max(jtmp))
    name, pid = list(name_to_pid.items())[winner]
    return name, pid
def build_jetset(root_file):
    """Read one ROOT file and build a :class:`JetSet`.

    Parameters
    ----------
    root_file : str | Path
        Path to a ROOT file with a ``tree`` TTree. The sample label
        (e.g. 'bb' or 'bbbar') is taken from the file-name stem before
        the first underscore.

    Returns
    -------
    JetSet
        One Jet per tree entry, each holding per-particle ParticleBase
        records.
    """
    with uproot.open(root_file) as f:
        tree = f["tree"]
        a = tree.arrays(library="pd")  # one DataFrame row per jet/event
    label = Path(root_file).stem.split('_')[0]
    particle_list = ['part_isNeutralHadron', 'part_isPhoton', 'part_isElectron',
                     'part_isMuon', 'part_isPion', 'part_isChargedKaon', 'part_isProton']
    jet_list = []
    for i, j in a.iterrows():
        part_pt = np.array(j['part_pt'])
        jet_pt = np.array(j['jet_pt'])
        # NOTE: np.log is the natural log (older comments said log10).
        part_logptrel = np.log(np.divide(part_pt, jet_pt))
        part_energy = np.array(j['part_energy'])
        jet_energy = np.array(j['jet_energy'])
        part_logerel = np.log(np.divide(part_energy, jet_energy))
        part_deta = np.array(j['part_deta'])
        part_dphi = np.array(j['part_dphi'])
        part_deltaR = np.hypot(part_deta, part_dphi)
        assert len(j['part_pt']) == len(j['part_energy']) == len(j['part_deta'])
        # Classify every particle from its one-hot is* flags.
        part_type, part_pid = [], []
        for pn in range(len(j['part_pt'])):
            tmp_type, tmp_pid = jud_type([j[t][pn] for t in particle_list])
            part_type.append(tmp_type)
            part_pid.append(tmp_pid)
        particles = []
        for pn, (ii, jj, kk, ptype, pid) in enumerate(
                zip(part_logptrel, part_logerel, part_deltaR, part_type, part_pid)):
            particles.append(ParticleBase(
                particle_id=i,  # event-row index, shared by all particles of the event
                # BUGFIX: index the event-level arrays with pn — the original
                # stored the whole array in every per-particle scalar field.
                part_charge=j['part_charge'][pn],
                part_energy=j['part_energy'][pn],
                part_px=j['part_px'][pn],
                part_py=j['part_py'][pn],
                part_pz=j['part_pz'][pn],
                log_energy=np.log(j['part_energy'][pn]),
                log_pt=np.log(j['part_pt'][pn]),
                part_deta=j['part_deta'][pn],
                part_dphi=j['part_dphi'][pn],
                part_logptrel=ii,
                part_logerel=jj,
                part_deltaR=kk,
                part_d0=np.tanh(j['part_d0val'][pn]),
                part_dz=np.tanh(j['part_dzval'][pn]),
                particle_pid=pid,
            ))
        jet = Jet(
            jet_energy=j['jet_energy'],
            # BUGFIX: the original passed the literal lists ['jet_pt'] and
            # ['jet_eta'] instead of reading the row values.
            jet_pt=j['jet_pt'],
            jet_eta=j['jet_eta'],
            particles=particles,
            label=label,
        )
        jet_list.append(jet)
    return JetSet(jets=jet_list)
def _export_jets(files, limit, out_dir, split_name, method, file_counter):
    """Convert up to ``limit`` ROOT files into one output file per jet.

    ``method == 'float_list'`` pickles each jet's float-feature lists;
    any other method writes the jet's concatenated string to a .txt file.
    Returns the updated running counter used to keep file names unique.
    """
    suffix = 'pkl' if method == 'float_list' else 'txt'
    for path in files[:limit]:
        print(f"loading {path} to {split_name} {suffix}")
        jetset = build_jetset(path)
        for jet in tqdm(jetset.jets):
            file_counter += 1
            out_path = os.path.join(out_dir, f"{jet.label}_{file_counter}.{suffix}")
            if method == 'float_list':
                with open(out_path, 'wb') as fh:
                    pickle.dump(jet.particles_attribute_list(), fh)
            else:
                with open(out_path, 'w') as fh:
                    fh.write(jet.particles_concatenated())
    return file_counter


def preprocess(root_dir, method='float_list', train_ratio=0.8,
               train_file_limit=32, val_file_limit=8):
    """Split the bb/bbbar ROOT files under ``root_dir`` and export per-jet files.

    Walks ``root_dir`` for 'bb'/'bbbar' sample directories, shuffles the
    .root files found inside each, splits them ``train_ratio`` / rest into
    train/val, and writes at most ``train_file_limit`` / ``val_file_limit``
    source files per split. Output goes under the hardcoded
    '/data/slow/100w_pkl' (pickle) or '/data/slow/100w' (text) trees.

    Parameters
    ----------
    root_dir : str
        Directory tree containing 'bb' and/or 'bbbar' subdirectories.
    method : str
        'float_list' for pickled float lists, anything else for text dumps.
    train_ratio : float
        Fraction of shuffled files assigned to the training split.
    train_file_limit, val_file_limit : int
        Maximum number of input ROOT files processed per split.
    """
    for root, dirs, _ in os.walk(root_dir):
        for dir_name in dirs:
            if dir_name not in ('bb', 'bbbar'):
                continue
            sample_files = []
            # BUGFIX: use each walk's own directory when building paths —
            # the original joined root/dir_name/file even for files found
            # in deeper subdirectories.
            for sub_root, _, files in os.walk(os.path.join(root, dir_name)):
                sample_files.extend(os.path.join(sub_root, f)
                                    for f in files if f.endswith('.root'))
            if not sample_files:
                continue
            random.shuffle(sample_files)
            split_index = int(len(sample_files) * train_ratio)
            train_files = sample_files[:split_index]
            val_files = sample_files[split_index:]
            print('len(train_files):', len(train_files))
            print('len(val_files):', len(val_files))
            base = '/data/slow/100w_pkl' if method == 'float_list' else '/data/slow/100w'
            # The counter keeps running across train and val, matching the
            # original numbering within one sample directory.
            counter = _export_jets(train_files, train_file_limit,
                                   os.path.join(base, 'train'), 'train', method, 0)
            _export_jets(val_files, val_file_limit,
                         os.path.join(base, 'val'), 'val', method, counter)
if __name__ == '__main__':
    # Export every bb/bbbar ROOT file under this tree as per-jet .txt dumps
    # ('str' selects the text method rather than pickled float lists).
    preprocess("/data/particle_raw/slow/data_100w/n2n2higgs", method='str')