first add
This commit is contained in:
9
.gitignore
vendored
Executable file
9
.gitignore
vendored
Executable file
@@ -0,0 +1,9 @@
|
||||
__pycache__/
|
||||
*.spec
|
||||
*.pyc
|
||||
*.tar.gz
|
||||
dist/
|
||||
build/
|
||||
data/
|
||||
data_100w/
|
||||
save/
|
||||
345
LLM.ipynb
Normal file
345
LLM.ipynb
Normal file
@@ -0,0 +1,345 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "074cbeba",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import uproot\n",
|
||||
"import attrs\n",
|
||||
"from typing import List\n",
|
||||
"import numpy as np"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "b49ac144",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"with uproot.open(\"./fast.root\") as f:\n",
|
||||
" tree = f[\"tree\"]\n",
|
||||
" a = tree.arrays(library=\"pd\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "c2d94275",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Index(['btag', 'ctag', 'gen_match', 'genpart_eta', 'genpart_phi',\n",
|
||||
" 'genpart_pid', 'genpart_pt', 'is_signal', 'jet_energy', 'jet_eta',\n",
|
||||
" 'jet_nparticles', 'jet_phi', 'jet_pt', 'part_charge', 'part_d0err',\n",
|
||||
" 'part_d0val', 'part_deta', 'part_dphi', 'part_dzerr', 'part_dzval',\n",
|
||||
" 'part_energy', 'part_pid', 'part_pt', 'part_px', 'part_py', 'part_pz',\n",
|
||||
" 'part_isChargedHadron', 'part_isChargedKaon', 'part_isElectron',\n",
|
||||
" 'part_isKLong', 'part_isKShort', 'part_isMuon', 'part_isNeutralHadron',\n",
|
||||
" 'part_isPhoton', 'part_isPi0', 'part_isPion', 'part_isProton',\n",
|
||||
" 'label_b', 'label_bb', 'label_bbar', 'label_c', 'label_cbar',\n",
|
||||
" 'label_cc', 'label_d', 'label_dbar', 'label_g', 'label_gg', 'label_s',\n",
|
||||
" 'label_sbar', 'label_u', 'label_ubar'],\n",
|
||||
" dtype='object')\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"print(a.keys())"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"id": "93a76758",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import attrs\n",
|
||||
"import uproot\n",
|
||||
"import numpy as np\n",
|
||||
"from typing import List\n",
|
||||
"@attrs.define\n",
|
||||
"class ParticleBase:\n",
|
||||
" part_charge: int = attrs.field() # charge of the particle\n",
|
||||
" part_energy: float = attrs.field() # energy of the particle\n",
|
||||
" part_px: float = attrs.field() # x-component of the momentum vector\n",
|
||||
" part_py: float = attrs.field() # y-component of the momentum vector\n",
|
||||
" part_pz: float = attrs.field() # z-component of the momentum vector\n",
|
||||
" log_energy: float = attrs.field() # log10(part_energy)\n",
|
||||
" log_pt: float = attrs.field() # log10(part_pt)\n",
|
||||
" part_deta: float = attrs.field() # pseudorapidity\n",
|
||||
" part_dphi: float = attrs.field() # azimuthal angle\n",
|
||||
" part_logptrel: float = attrs.field() # log10(pt(particle)/pt(jet))\n",
|
||||
" part_logerel: float = attrs.field() # log10(energy(particle)/energy(jet))\n",
|
||||
" part_deltaR: float = attrs.field() # distance between the particle and the jet\n",
|
||||
" part_d0: float = attrs.field() # tanh(d0)\n",
|
||||
" part_dz: float = attrs.field() # tanh(z0)\n",
|
||||
" particle_type: str = attrs.field() # type of the particle (e.g. charged kaon, charged pion, proton, electron, muon, neutral hadron, photon, others)\n",
|
||||
" particle_pid: int = attrs.field() # pid of the particle (e.g. 0,1,2,3,4,5,6,7)\n",
|
||||
"\n",
|
||||
"@attrs.define\n",
|
||||
"class Jet:\n",
|
||||
" jet_b: float = attrs.field()\n",
|
||||
" jet_bbar: float = attrs.field()\n",
|
||||
" jet_energy: float = attrs.field() # energy of the jet\n",
|
||||
" jet_pt: float = attrs.field() # transverse momentum of the jet\n",
|
||||
" jet_eta: float = attrs.field() # pseudorapidity of the jet\n",
|
||||
" particles: List[ParticleBase] = attrs.field(factory=list) # list of particles in the jet\n",
|
||||
" \n",
|
||||
" def __len__(self):\n",
|
||||
" return len(self.particles)\n",
|
||||
" \n",
|
||||
"@attrs.define\n",
|
||||
"class JetSet:\n",
|
||||
" jets: List[Jet]\n",
|
||||
" \n",
|
||||
" def __len__(self):\n",
|
||||
" return len(self.jets)\n",
|
||||
" \n",
|
||||
"def jud_type(jtmp): #这个函数用来判断每个粒子的类型,每个粒子可以是electron、muon、pion 等\n",
|
||||
" particle_dict = {'NeutralHadron':0,'Photon':1, 'Electron':2, 'Muon':3, 'Pion':4,'ChargedKaon':5, 'Proton':6}\n",
|
||||
" max_element = max(jtmp)\n",
|
||||
" idx = jtmp.index(max_element)\n",
|
||||
" items = list(particle_dict.items())\n",
|
||||
" return items[idx][0], items[idx][1]\n",
|
||||
" \n",
|
||||
"with uproot.open(\"./data/data_fast/fast_bb.root\") as f:\n",
|
||||
" tree = f[\"tree\"]\n",
|
||||
" a = tree.arrays(library=\"pd\")\n",
|
||||
"#a里面有很多喷注,我们的目标是判断每个喷注 它是 b 还是 bbar. 每个喷注里含了很多粒子。以下以 part 开头的变量都是列表,比如 part_pt,它是存储了一个喷注中所有粒子的pt\n",
|
||||
"#part_energy 存储了一个喷注中所有粒子的energy。\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"jet_list = []\n",
|
||||
"for j in a.itertuples(): \n",
|
||||
" part_pt = np.array(j.part_pt)\n",
|
||||
" jet_pt = np.array(j.jet_pt)\n",
|
||||
" part_logptrel = np.log(np.divide(part_pt, jet_pt))\n",
|
||||
" \n",
|
||||
" part_energy = np.array(j.part_energy)\n",
|
||||
" jet_energy = np.array(j.jet_energy)\n",
|
||||
" part_logerel = np.log(np.divide(part_energy, jet_energy))\n",
|
||||
" \n",
|
||||
" part_deta = np.array(j.part_deta)\n",
|
||||
" part_dphi = np.array(j.part_dphi)\n",
|
||||
" part_deltaR = np.hypot(part_deta, part_dphi)\n",
|
||||
" \n",
|
||||
" assert len(j.part_pt) == len(j.part_energy) == len(j.part_deta)\n",
|
||||
"\n",
|
||||
" particles = []\n",
|
||||
" \n",
|
||||
" \n",
|
||||
" particle_list = ['part_isNeutralHadron','part_isPhoton', 'part_isElectron', 'part_isMuon', 'part_isPion','part_isChargedKaon', 'part_isProton']\n",
|
||||
" part_type = []\n",
|
||||
" part_pid = []\n",
|
||||
" for pn in range(len(j.part_pt)):\n",
|
||||
" jtmp = [j.part_isNeutralHadron[pn], j.part_isPhoton[pn], j.part_isElectron[pn], j.part_isMuon[pn], j.part_isPion[pn],\n",
|
||||
" j.part_isChargedKaon[pn], j.part_isProton[pn]]\n",
|
||||
" tmp_type, tmp_pid = jud_type(jtmp)\n",
|
||||
" part_type.append(tmp_type)\n",
|
||||
" part_pid.append(tmp_pid)\n",
|
||||
" \n",
|
||||
" bag = zip(j.part_charge, j.part_energy, j.part_px, j.part_py, j.part_pz, np.log(j.part_energy), \n",
|
||||
" np.log(j.part_pt), j.part_deta, j.part_dphi, part_logptrel, part_logerel, part_deltaR, \n",
|
||||
" np.tanh(j.part_d0val), np.tanh(j.part_dzval), part_type, part_pid)\n",
|
||||
" \n",
|
||||
" #下边的代码是要对第 j 个喷注中的所有粒子做循环,将每个粒子都 存成 ParticleBase,然后 append 到 particles里,\n",
|
||||
" #所以 partices 存储了 第 j 个喷注中所有粒子的信息\n",
|
||||
" for c, en, px, py, pz, lEn, lPt, eta, phi, ii, jj, kk, d0, dz, ptype, pid in bag:\n",
|
||||
" particles.append(ParticleBase(\n",
|
||||
" part_charge=c, \n",
|
||||
" part_energy=en, \n",
|
||||
" part_px=px, \n",
|
||||
" part_py=py,\n",
|
||||
" part_pz=pz, \n",
|
||||
" log_energy=lEn, \n",
|
||||
" log_pt=lPt,\n",
|
||||
" part_deta=eta, \n",
|
||||
" part_dphi=phi, \n",
|
||||
" part_logptrel=ii,\n",
|
||||
" part_logerel=jj, \n",
|
||||
" part_deltaR=kk,\n",
|
||||
" part_d0=d0, \n",
|
||||
" part_dz=dz, \n",
|
||||
" particle_type=ptype, # assuming you will set this correctly\n",
|
||||
" particle_pid=pid # assuming you will set this correctly\n",
|
||||
" ))\n",
|
||||
" # add jets jet = 喷注,\n",
|
||||
" jet = Jet(\n",
|
||||
" jet_b=j.label_b, #如果此jet是b,那么label_b = 1, 否则label_b = 0\n",
|
||||
" jet_bbar=j.label_bbar, #如果此jet是bbar,那么label_bbar = 1\n",
|
||||
" jet_energy=j.jet_energy, #第 j 个喷注的 energy\n",
|
||||
" jet_pt=j.jet_pt, # 第 j 个喷注的 pt\n",
|
||||
" jet_eta=j.jet_eta, # 第 j 个喷注的 eta (是一种角度的表示)\n",
|
||||
" particles=particles # 第 j 个喷注中所有的 粒子\n",
|
||||
" )\n",
|
||||
" jet_list.append(jet)\n",
|
||||
"\n",
|
||||
"jet_set1 = JetSet(jets=jet_list)\n",
|
||||
"\n",
|
||||
"#如上所说,每个喷注有很多粒子,你最后输入模型的是每个喷注中所有粒子的如下信息\n",
|
||||
"#log_energy log_pt part_logerel part_logptrel part_deltaR part_charge part_d0 part_dz \n",
|
||||
"#part_deta part_dphi particle_pid"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "c9995622",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"id": "b94329b2",
|
||||
"metadata": {
|
||||
"scrolled": true
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"False"
|
||||
]
|
||||
},
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"jet_set1.jets[0].jet_bbar"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"id": "4242ce6c",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"1.0\n",
|
||||
"-1.0\n",
|
||||
"1.0\n",
|
||||
"0.0\n",
|
||||
"-1.0\n",
|
||||
"1.0\n",
|
||||
"-1.0\n",
|
||||
"0.0\n",
|
||||
"-1.0\n",
|
||||
"0.0\n",
|
||||
"0.0\n",
|
||||
"0.0\n",
|
||||
"0.0\n",
|
||||
"1.0\n",
|
||||
"0.0\n",
|
||||
"0.0\n",
|
||||
"0.0\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"for num in range(len(jet_set1.jets[0].particles)):\n",
|
||||
" print(jet_set1.jets[0].particles[num].part_charge)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"id": "6f0dc08e",
|
||||
"metadata": {
|
||||
"scrolled": true
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"-1.0\n",
|
||||
"1.0\n",
|
||||
"-1.0\n",
|
||||
"-1.0\n",
|
||||
"0.0\n",
|
||||
"1.0\n",
|
||||
"-1.0\n",
|
||||
"1.0\n",
|
||||
"0.0\n",
|
||||
"0.0\n",
|
||||
"0.0\n",
|
||||
"0.0\n",
|
||||
"-1.0\n",
|
||||
"0.0\n",
|
||||
"0.0\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"for num in range(len(jet_set1.jets[1].particles)):\n",
|
||||
" print(jet_set1.jets[1].particles[num].part_charge)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
"id": "e6809f05",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"10000\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"print(len(jet_set1.jets))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "67cd5f19",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 与particle-transformer对比的实验设计\n",
|
||||
"\n",
|
||||
"使用这些属性进行与particle-transformer准确率对比,bb 100w bbbar 100w\n",
|
||||
"\n",
|
||||
"log_energy log_pt part_logerel part_logptrel part_deltaR part_charge part_d0 part_dz part_deta part_dphi particle_pid"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.14"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
250
convert_root_to_bin.py.bak
Normal file
250
convert_root_to_bin.py.bak
Normal file
@@ -0,0 +1,250 @@
|
||||
import enum
|
||||
import attrs
|
||||
from typing import List, Optional, Union
|
||||
import numpy as np
|
||||
import uproot
|
||||
from pathlib import Path
|
||||
from tqdm import tqdm
|
||||
import os
|
||||
import pickle
|
||||
import joblib
|
||||
import click
|
||||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
import math
|
||||
|
||||
class ByteDataset(Dataset):
|
||||
def __init__(self, filenames, patch_size: int = 16, patch_length: int = 512):
|
||||
self.patch_size = patch_size
|
||||
self.patch_length = patch_length
|
||||
print(f"Loading {len(filenames)} files for classification")
|
||||
self.filenames = []
|
||||
self.labels = {'bb': 0, 'bbbar': 1} # 映射标签到整数
|
||||
|
||||
for filename in tqdm(filenames):
|
||||
file_size = os.path.getsize(filename)
|
||||
file_size = math.ceil(file_size / self.patch_size)
|
||||
ext = filename.split('.')[-1]
|
||||
label = os.path.basename(filename).split('_')[0] # 使用前缀部分作为标签
|
||||
label = f"{label}.{ext}"
|
||||
|
||||
if file_size <= self.patch_length - 2:
|
||||
self.filenames.append((filename, label))
|
||||
if label not in self.labels:
|
||||
self.labels[label] = len(self.labels)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.filenames)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
filename, label = self.filenames[idx]
|
||||
file_bytes = self.read_bytes(filename)
|
||||
|
||||
file_bytes = torch.tensor(file_bytes, dtype=torch.long)
|
||||
label = torch.tensor(self.labels[label], dtype=torch.long)
|
||||
|
||||
return file_bytes, label, filename
|
||||
|
||||
def readbytes(self, filename):
|
||||
ext = filename.split('.')[-1]
|
||||
ext = bytearray(ext, 'utf-8')
|
||||
ext = [byte for byte in ext][:self.patch_size]
|
||||
|
||||
with open(filename, 'rb') as f:
|
||||
file_bytes = f.read()
|
||||
|
||||
bytes = []
|
||||
for byte in file_bytes:
|
||||
bytes.append(byte)
|
||||
|
||||
if len(bytes) % self.patch_size != 0:
|
||||
bytes = bytes + [256] * (self.patch_size - len(bytes) % self.patch_size)
|
||||
|
||||
bos_patch = ext + [256] * (self.patch_size - len(ext))
|
||||
bytes = bos_patch + bytes + [256] * self.patch_size
|
||||
|
||||
return bytes
|
||||
|
||||
class ParticleType(enum.Enum):
|
||||
NeutralHadron = 0
|
||||
Photon = 1
|
||||
Electron = 2
|
||||
Muon = 3
|
||||
Pion = 4
|
||||
ChargedKaon = 5
|
||||
Proton = 6
|
||||
Others = 7
|
||||
|
||||
@attrs.define
|
||||
class ParticleBase:
|
||||
particle_id: int = attrs.field() # unique id for each particle
|
||||
part_charge: int = attrs.field() # charge of the particle
|
||||
part_energy: float = attrs.field() # energy of the particle
|
||||
part_px: float = attrs.field() # x-component of the momentum vector
|
||||
part_py: float = attrs.field() # y-component of the momentum vector
|
||||
part_pz: float = attrs.field() # z-component of the momentum vector
|
||||
log_energy: float = attrs.field() # log10(part_energy)
|
||||
log_pt: float = attrs.field() # log10(part_pt)
|
||||
part_deta: float = attrs.field() # pseudorapidity
|
||||
part_dphi: float = attrs.field() # azimuthal angle
|
||||
part_logptrel: float = attrs.field() # log10(pt(particle)/pt(jet))
|
||||
part_logerel: float = attrs.field() # log10(energy(particle)/energy(jet))
|
||||
part_deltaR: float = attrs.field() # distance between the particle and the jet
|
||||
part_d0: float = attrs.field() # tanh(d0)
|
||||
part_dz: float = attrs.field() # tanh(z0)
|
||||
particle_type: ParticleType = attrs.field() # type of the particle as an enum
|
||||
particle_pid: int = attrs.field() # pid of the particle
|
||||
jet_type: str = attrs.field() # type of the jet (e.g. b, bbbar)
|
||||
|
||||
def properties_concatenated(self, selected_attrs: Optional[List[str]] = None, attr_sep: str = ',') -> str:
|
||||
if selected_attrs is None:
|
||||
selected_attrs = [field.name for field in attrs.fields(self) if field.name != "jet_type"]
|
||||
values = [getattr(self, attr) for attr in selected_attrs]
|
||||
values = [v.value if isinstance(v, ParticleType) else v for v in values] # Convert ParticleType to its numeric value
|
||||
return attr_sep.join(map(str, values))
|
||||
|
||||
def attributes_to_float_list(self, selected_attrs: Optional[List[str]] = None) -> list:
|
||||
if selected_attrs is None:
|
||||
selected_attrs = [field.name for field in attrs.fields(self) if field.name != "jet_type"]
|
||||
values = [getattr(self, attr) for attr in selected_attrs]
|
||||
values = [v.value if isinstance(v, ParticleType) else v for v in values] # Convert ParticleType to its numeric value
|
||||
return list(map(lambda x: float(x) if isinstance(x, int) else x, values))
|
||||
|
||||
@attrs.define
|
||||
class Jet:
|
||||
jet_energy: float = attrs.field() # energy of the jet
|
||||
jet_pt: float = attrs.field() # transverse momentum of the jet
|
||||
jet_eta: float = attrs.field() # pseudorapidity of the jet
|
||||
label: str = attrs.field() # type of bb or bbbar
|
||||
particles: List[ParticleBase] = attrs.field(factory=list) # list of particles in the jet
|
||||
|
||||
def __len__(self):
|
||||
return len(self.particles)
|
||||
|
||||
def particles_concatenated(self, selected_attrs: Optional[List[str]] = None, attr_sep: str = ',', part_sep: str = '|') -> str:
|
||||
return part_sep.join(map(lambda p: p.properties_concatenated(selected_attrs, attr_sep), self.particles))
|
||||
|
||||
def particles_attribute_list(self, selected_attrs: Optional[List[str]] = None) -> list:
|
||||
return list(map(lambda p: p.attributes_to_float_list(selected_attrs), self.particles))
|
||||
|
||||
@attrs.define
|
||||
class JetSet:
|
||||
jets_type: str = attrs.field() # type of the jets (e.g. bb, bbbar)
|
||||
jets: List[Jet] = attrs.field(factory=list)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.jets)
|
||||
|
||||
@staticmethod
|
||||
def jud_type(jtmp):
|
||||
particle_dict = {
|
||||
0: ParticleType.NeutralHadron,
|
||||
1: ParticleType.Photon,
|
||||
2: ParticleType.Electron,
|
||||
3: ParticleType.Muon,
|
||||
4: ParticleType.Pion,
|
||||
5: ParticleType.ChargedKaon,
|
||||
6: ParticleType.Proton,
|
||||
7: ParticleType.Others
|
||||
}
|
||||
max_element = max(jtmp)
|
||||
idx = jtmp.index(max_element)
|
||||
return particle_dict.get(idx, ParticleType.Others)
|
||||
|
||||
@classmethod
|
||||
def build_jetset(cls, root_file: Union[str, Path]) -> "JetSet":
|
||||
print(f"Building JetSet from {root_file}")
|
||||
if not 'bbbar' in Path(root_file).stem and not 'bb' in Path(root_file).stem:
|
||||
raise ValueError("Invalid file name, should contain 'bb' or 'bbbar'")
|
||||
jets_type = "bbbar" if "bbbar" in Path(root_file).stem else "bb"
|
||||
with uproot.open(root_file) as f:
|
||||
tree = f["tree"]
|
||||
a = tree.arrays(library="pd")
|
||||
|
||||
label = Path(root_file).stem.split('_')[0]
|
||||
jet_type = "bbbar jet" if "bbbar" in root_file.stem else "b jet"
|
||||
jet_list = []
|
||||
for i, j in a.iterrows():
|
||||
part_pt = np.array(j['part_pt'])
|
||||
jet_pt = np.array(j['jet_pt'])
|
||||
part_logptrel = np.log(np.divide(part_pt, jet_pt))
|
||||
|
||||
part_energy = np.array(j['part_energy'])
|
||||
jet_energy = np.array(j['jet_energy'])
|
||||
part_logerel = np.log(np.divide(part_energy, jet_energy))
|
||||
|
||||
part_deta = np.array(j['part_deta'])
|
||||
part_dphi = np.array(j['part_dphi'])
|
||||
part_deltaR = np.hypot(part_deta, part_dphi)
|
||||
|
||||
assert len(j['part_pt']) == len(j['part_energy']) == len(j['part_deta'])
|
||||
|
||||
particle_list = ['part_isNeutralHadron', 'part_isPhoton', 'part_isElectron', 'part_isMuon', 'part_isPion', 'part_isChargedKaon', 'part_isProton']
|
||||
|
||||
particles = list(map(lambda pn: ParticleBase(
|
||||
particle_id=i,
|
||||
part_charge=j['part_charge'],
|
||||
part_energy=j['part_energy'],
|
||||
part_px=j['part_px'],
|
||||
part_py=j['part_py'],
|
||||
part_pz=j['part_pz'],
|
||||
log_energy=np.log(j['part_energy']),
|
||||
log_pt=np.log(j['part_pt']),
|
||||
part_deta=j['part_deta'],
|
||||
part_dphi=j['part_dphi'],
|
||||
part_logptrel=part_logptrel[pn],
|
||||
part_logerel=part_logerel[pn],
|
||||
part_deltaR=part_deltaR[pn],
|
||||
part_d0=np.tanh(j['part_d0val']),
|
||||
part_dz=np.tanh(j['part_dzval']),
|
||||
particle_type=cls.jud_type([j[t][pn] for t in particle_list]),
|
||||
particle_pid=cls.jud_type([j[t][pn] for t in particle_list]).value,
|
||||
jet_type=jet_type
|
||||
), range(len(j['part_pt']))))
|
||||
|
||||
jet = Jet(
|
||||
jet_energy=j['jet_energy'],
|
||||
jet_pt=j['jet_pt'],
|
||||
jet_eta=j['jet_eta'],
|
||||
particles=particles,
|
||||
label=label
|
||||
)
|
||||
jet_list.append(jet)
|
||||
|
||||
return cls(jets=jet_list, jets_type=jets_type)
|
||||
|
||||
def save_to_binary(self, save_dir, selected_attrs: Optional[List[str]] = None, attr_sep: str = ',', part_sep: str = '|'):
|
||||
file_counter = 0
|
||||
for jet in tqdm(self.jets):
|
||||
filename = f"{self.jets_type}_{file_counter}.bin"
|
||||
filepath = os.path.join(save_dir, filename)
|
||||
|
||||
jet_data = jet.particles_concatenated(selected_attrs, attr_sep, part_sep)
|
||||
with open(filepath, 'wb') as file:
|
||||
pickle.dump(jet_data, file)
|
||||
|
||||
file_counter += 1
|
||||
|
||||
@click.command()
|
||||
@click.argument('root_dir', type=click.Path(exists=True))
|
||||
@click.argument('save_dir', type=click.Path())
|
||||
@click.option('--attr-sep', default=',', help='Separator for attributes.')
|
||||
@click.option('--part-sep', default='|', help='Separator for particles.')
|
||||
@click.option('--selected-attrs', default='particle_id,part_charge,part_energy,part_px,part_py,part_pz,log_energy,log_pt,part_deta,part_dphi,part_logptrel,part_logerel,part_deltaR,part_d0,part_dz,particle_type,particle_pid', help='Comma-separated list of selected attributes.')
|
||||
def main(root_dir, save_dir, attr_sep, part_sep, selected_attrs):
|
||||
selected_attrs = selected_attrs.split(',')
|
||||
preprocess(root_dir, save_dir, selected_attrs=selected_attrs, attr_sep=attr_sep, part_sep=part_sep)
|
||||
|
||||
def preprocess(root_dir, save_dir, selected_attrs: Optional[List[str]] = None, attr_sep: str = ',', part_sep: str = '|'):
|
||||
root_files = list(Path(root_dir).glob('*.root'))
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
joblib.Parallel(n_jobs=-1)(
|
||||
joblib.delayed(process_file)(root_file, save_dir, selected_attrs, attr_sep, part_sep) for root_file in root_files
|
||||
)
|
||||
|
||||
def process_file(root_file, save_dir, selected_attrs, attr_sep, part_sep):
|
||||
jet_set = JetSet.build_jetset(root_file)
|
||||
jet_set.save_to_binary(save_dir, selected_attrs, attr_sep, part_sep)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
279
convert_root_to_txt.py
Normal file
279
convert_root_to_txt.py
Normal file
@@ -0,0 +1,279 @@
|
||||
import enum
|
||||
import attrs
|
||||
from typing import List, Optional, Union
|
||||
import numpy as np
|
||||
import uproot
|
||||
from pathlib import Path
|
||||
from tqdm import tqdm
|
||||
import os
|
||||
import joblib
|
||||
import click
|
||||
|
||||
class ParticleType(enum.Enum):
|
||||
NeutralHadron = 0
|
||||
Photon = 1
|
||||
Electron = 2
|
||||
Muon = 3
|
||||
Pion = 4
|
||||
ChargedKaon = 5
|
||||
Proton = 6
|
||||
Others = 7
|
||||
|
||||
@attrs.define
|
||||
class ParticleBase:
|
||||
particle_id: int = attrs.field() # unique id for each particle
|
||||
part_charge: int = attrs.field() # charge of the particle
|
||||
part_energy: float = attrs.field() # energy of the particle
|
||||
part_px: float = attrs.field() # x-component of the momentum vector
|
||||
part_py: float = attrs.field() # y-component of the momentum vector
|
||||
part_pz: float = attrs.field() # z-component of the momentum vector
|
||||
log_energy: float = attrs.field() # log10(part_energy)
|
||||
log_pt: float = attrs.field() # log10(part_pt)
|
||||
part_deta: float = attrs.field() # pseudorapidity
|
||||
part_dphi: float = attrs.field() # azimuthal angle
|
||||
part_logptrel: float = attrs.field() # log10(pt(particle)/pt(jet))
|
||||
part_logerel: float = attrs.field() # log10(energy(particle)/energy(jet))
|
||||
part_deltaR: float = attrs.field() # distance between the particle and the jet
|
||||
part_d0: float = attrs.field() # tanh(d0)
|
||||
part_dz: float = attrs.field() # tanh(z0)
|
||||
particle_type: ParticleType = attrs.field() # type of the particle as an enum
|
||||
particle_pid: int = attrs.field() # pid of the particle
|
||||
jet_type: str = attrs.field() # type of the jet (e.g. b, bbbar)
|
||||
|
||||
def properties_concatenated(self, selected_attrs: Optional[List[str]] = None, attr_sep: str = ',') -> str:
|
||||
if selected_attrs is None:
|
||||
selected_attrs = [field.name for field in attrs.fields(self) if field.name != "jet_type"]
|
||||
values = [getattr(self, attr) for attr in selected_attrs]
|
||||
values = [v.value if isinstance(v, ParticleType) else v for v in values] # Convert ParticleType to its numeric value
|
||||
return attr_sep.join(map(str, values))
|
||||
|
||||
@attrs.define
|
||||
class Jet:
|
||||
jet_energy: float = attrs.field() # energy of the jet
|
||||
jet_pt: float = attrs.field() # transverse momentum of the jet
|
||||
jet_eta: float = attrs.field() # pseudorapidity of the jet
|
||||
label: str = attrs.field() # type of bb or bbbar
|
||||
particles: List[ParticleBase] = attrs.field(factory=list) # list of particles in the jet
|
||||
|
||||
def __len__(self):
|
||||
return len(self.particles)
|
||||
|
||||
def particles_concatenated(self, selected_attrs: Optional[List[str]] = None, attr_sep: str = ',', part_sep: str = '|') -> str:
|
||||
return part_sep.join(map(lambda p: p.properties_concatenated(selected_attrs, attr_sep), self.particles))
|
||||
|
||||
@attrs.define
|
||||
class JetSet:
|
||||
jets_type: str = attrs.field() # type of the jets (e.g. bb, bbbar)
|
||||
jets: List[Jet] = attrs.field(factory=list)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.jets)
|
||||
|
||||
@staticmethod
|
||||
def jud_type(jtmp):
|
||||
particle_dict = {
|
||||
0: ParticleType.NeutralHadron,
|
||||
1: ParticleType.Photon,
|
||||
2: ParticleType.Electron,
|
||||
3: ParticleType.Muon,
|
||||
4: ParticleType.Pion,
|
||||
5: ParticleType.ChargedKaon,
|
||||
6: ParticleType.Proton,
|
||||
7: ParticleType.Others
|
||||
}
|
||||
max_element = max(jtmp)
|
||||
idx = jtmp.index(max_element)
|
||||
return particle_dict.get(idx, ParticleType.Others), idx
|
||||
|
||||
@classmethod
|
||||
def build_jetset_fast(cls, root_file: Union[str, Path]) -> "JetSet":
|
||||
print(f"Building JetSet from {root_file}")
|
||||
if not 'bbbar' in Path(root_file).stem and not 'bb' in Path(root_file).stem:
|
||||
raise ValueError("Invalid file name, should contain 'bb' or 'bbbar'")
|
||||
jets_type = "bbbar" if "bbbar" in Path(root_file).stem else "bb"
|
||||
print(f"jet type: {jets_type}")
|
||||
with uproot.open(root_file) as f:
|
||||
tree = f["tree"]
|
||||
a = tree.arrays(library="pd")
|
||||
# print(f"DataFrame structure:\n{a.head()}")
|
||||
if a.empty:
|
||||
raise ValueError("DataFrame is empty")
|
||||
|
||||
label = Path(root_file).stem.split('_')[0]
|
||||
jet_type = "bbbar jet" if "bbbar" in root_file.stem else "b jet"
|
||||
jet_list = []
|
||||
for j in a.itertuples():
|
||||
part_pt = np.array(j.part_pt)
|
||||
jet_pt = np.array(j.jet_pt)
|
||||
part_logptrel = np.log(np.divide(part_pt, jet_pt))
|
||||
|
||||
part_energy = np.array(j.part_energy)
|
||||
jet_energy = np.array(j.jet_energy)
|
||||
part_logerel = np.log(np.divide(part_energy, jet_energy))
|
||||
|
||||
part_deta = np.array(j.part_deta)
|
||||
part_dphi = np.array(j.part_dphi)
|
||||
part_deltaR = np.hypot(part_deta, part_dphi)
|
||||
|
||||
assert len(j.part_pt) == len(j.part_energy) == len(j.part_deta)
|
||||
|
||||
particles = []
|
||||
# particle_list = ['part_isNeutralHadron', 'part_isPhoton', 'part_isElectron', 'part_isMuon', 'part_isPion', 'part_isChargedKaon', 'part_isProton']
|
||||
part_type = []
|
||||
part_pid = []
|
||||
for pn in range(len(j.part_pt)):
|
||||
jtmp = [j.part_isNeutralHadron[pn], j.part_isPhoton[pn], j.part_isElectron[pn], j.part_isMuon[pn], j.part_isPion[pn],
|
||||
j.part_isChargedKaon[pn], j.part_isProton[pn]]
|
||||
tmp_type, tmp_pid = cls.jud_type(jtmp)
|
||||
part_type.append(tmp_type)
|
||||
part_pid.append(tmp_pid)
|
||||
|
||||
bag = zip(j.part_charge, j.part_energy, j.part_px, j.part_py, j.part_pz, np.log(j.part_energy),
|
||||
np.log(j.part_pt), j.part_deta, j.part_dphi, part_logptrel, part_logerel, part_deltaR,
|
||||
np.tanh(j.part_d0val), np.tanh(j.part_dzval), part_type, part_pid)
|
||||
#下边的代码是要对第 j 个喷注中的所有粒子做循环,将每个粒子都 存成 ParticleBase,然后 append 到 particles里,
|
||||
#所以 partices 存储了 第 j 个喷注中所有粒子的信息
|
||||
for num, (c, en, px, py, pz, lEn, lPt, eta, phi, ii, jj, kk, d0, dz, ptype, pid) in enumerate(bag):
|
||||
particles.append(ParticleBase(
|
||||
particle_id=num,
|
||||
part_charge=c,
|
||||
part_energy=en,
|
||||
part_px=px,
|
||||
part_py=py,
|
||||
part_pz=pz,
|
||||
log_energy=lEn,
|
||||
log_pt=lPt,
|
||||
part_deta=eta,
|
||||
part_dphi=phi,
|
||||
part_logptrel=ii,
|
||||
part_logerel=jj,
|
||||
part_deltaR=kk,
|
||||
part_d0=d0,
|
||||
part_dz=dz,
|
||||
particle_type=ptype,
|
||||
particle_pid=pid,
|
||||
jet_type=jet_type
|
||||
))
|
||||
# add jets jet = 喷注,
|
||||
jet = Jet(
|
||||
jet_energy=j.jet_energy,
|
||||
jet_pt=j.jet_pt,
|
||||
jet_eta=j.jet_eta,
|
||||
particles=particles,
|
||||
label=label
|
||||
)
|
||||
jet_list.append(jet)
|
||||
|
||||
return cls(jets=jet_list, jets_type=jets_type)
|
||||
|
||||
@classmethod
|
||||
def build_jetset_full(cls, root_file: Union[str, Path]) -> "JetSet":
|
||||
print(f"Building JetSet from {root_file}")
|
||||
if not 'bbbar' in Path(root_file).stem and not 'bb' in Path(root_file).stem:
|
||||
raise ValueError("Invalid file name, should contain 'bb' or 'bbbar'")
|
||||
jets_type = "bbbar" if "bbbar" in Path(root_file).stem else "bb"
|
||||
print(f"jet type: {jets_type}")
|
||||
with uproot.open(root_file) as f:
|
||||
tree = f["tree"]
|
||||
a = tree.arrays(library="pd")
|
||||
|
||||
label = Path(root_file).stem.split('_')[0]
|
||||
jet_type = "bbbar jet" if "bbbar" in root_file.stem else "b jet"
|
||||
jet_list = []
|
||||
for i, j in a.iterrows():
|
||||
part_pt = np.array(j['part_pt'])
|
||||
jet_pt = np.array(j['jet_pt'])
|
||||
part_logptrel = np.log(np.divide(part_pt, jet_pt))
|
||||
|
||||
part_energy = np.array(j['part_energy'])
|
||||
jet_energy = np.array(j['jet_energy'])
|
||||
part_logerel = np.log(np.divide(part_energy, jet_energy))
|
||||
|
||||
part_deta = np.array(j['part_deta'])
|
||||
part_dphi = np.array(j['part_dphi'])
|
||||
part_deltaR = np.hypot(part_deta, part_dphi)
|
||||
|
||||
assert len(j['part_pt']) == len(j['part_energy']) == len(j['part_deta'])
|
||||
particles = []
|
||||
particle_list = ['part_isNeutralHadron', 'part_isPhoton', 'part_isElectron', 'part_isMuon', 'part_isPion', 'part_isChargedKaon', 'part_isProton']
|
||||
part_type = []
|
||||
part_pid = []
|
||||
for pn in range(len(j['part_pt'])):
|
||||
jtmp = []
|
||||
for t in particle_list:
|
||||
jtmp.append(j[t][pn])
|
||||
tmp_type, tmp_pid = cls.jud_type(jtmp)
|
||||
part_type.append(tmp_type)
|
||||
part_pid.append(tmp_pid)
|
||||
|
||||
for pn in range(len(part_type)):
|
||||
particles.append(ParticleBase(
|
||||
particle_id=i,
|
||||
part_charge=j['part_charge'][pn],
|
||||
part_energy=j['part_energy'][pn],
|
||||
part_px=j['part_px'][pn],
|
||||
part_py=j['part_py'][pn],
|
||||
part_pz=j['part_pz'][pn],
|
||||
log_energy=np.log(j['part_energy'][pn]),
|
||||
log_pt=np.log(j['part_pt'][pn]),
|
||||
part_deta=j['part_deta'][pn],
|
||||
part_dphi=j['part_dphi'][pn],
|
||||
part_logptrel=part_logptrel[pn],
|
||||
part_logerel=part_logerel[pn],
|
||||
part_deltaR=part_deltaR[pn],
|
||||
part_d0=np.tanh(j['part_d0val'][pn]),
|
||||
part_dz=np.tanh(j['part_dzval'][pn]),
|
||||
particle_type=part_type[pn],
|
||||
particle_pid=part_pid[pn],
|
||||
jet_type=jet_type
|
||||
))
|
||||
|
||||
jet = Jet(
|
||||
jet_energy=j['jet_energy'],
|
||||
jet_pt=j['jet_pt'],
|
||||
jet_eta=j['jet_eta'],
|
||||
particles=particles,
|
||||
label=label
|
||||
)
|
||||
jet_list.append(jet)
|
||||
|
||||
return cls(jets=jet_list, jets_type=jets_type)
|
||||
|
||||
def save_to_txt(self, save_dir, selected_attrs: Optional[List[str]] = None, attr_sep: str = ',', part_sep: str = '|', prefix: str = ""):
|
||||
for file_counter, jet in enumerate(tqdm(self.jets)):
|
||||
filename = f"{self.jets_type}_{file_counter}_{prefix}.txt"
|
||||
filepath = os.path.join(save_dir, filename)
|
||||
|
||||
jet_data = jet.particles_concatenated(selected_attrs, attr_sep, part_sep)
|
||||
with open(filepath, 'w') as file:
|
||||
file.write(jet_data)
|
||||
|
||||
@click.command()
@click.argument('root_dir', type=click.Path(exists=True))
@click.argument('save_dir', type=click.Path())
@click.option('--data-type', default='fast', type=click.Choice(['fast', 'full']), help='Type of ROOT data: fast or full.')
@click.option('--attr-sep', default=',', help='Separator for attributes.')
@click.option('--part-sep', default='|', help='Separator for particles.')
@click.option('--selected-attrs', default='particle_id,part_charge,part_energy,part_px,part_py,part_pz,log_energy,log_pt,part_deta,part_dphi,part_logptrel,part_logerel,part_deltaR,part_d0,part_dz,particle_type,particle_pid', help='Comma-separated list of selected attributes.')
def main(root_dir, save_dir, data_type, attr_sep, part_sep, selected_attrs):
    """CLI entry point: convert all ROOT files in ROOT_DIR to per-jet text files in SAVE_DIR."""
    # click delivers --selected-attrs as one comma-separated string; split it
    # into the attribute-name list that preprocess/save_to_txt expect.
    selected_attrs = selected_attrs.split(',')
    preprocess(root_dir, save_dir, data_type, selected_attrs=selected_attrs, attr_sep=attr_sep, part_sep=part_sep)
|
||||
|
||||
def preprocess(root_dir, save_dir, data_type, selected_attrs: Optional[List[str]] = None, attr_sep: str = ',', part_sep: str = '|'):
    """Convert every ``*.root`` file under *root_dir* into per-jet text files in *save_dir*.

    Files are processed in parallel, one joblib worker per ROOT file; the
    worker index is threaded through so output names never collide.
    """
    discovered = list(Path(root_dir).glob('*.root'))
    os.makedirs(save_dir, exist_ok=True)
    # process_file(discovered[0], save_dir, data_type, selected_attrs, attr_sep, part_sep, 0)  # single-process debug run
    tasks = [
        joblib.delayed(process_file)(root_file, save_dir, data_type, selected_attrs, attr_sep, part_sep, worker_id)
        for worker_id, root_file in enumerate(discovered)
    ]
    joblib.Parallel(n_jobs=-1)(tasks)
|
||||
|
||||
def process_file(root_file, save_dir, data_type, selected_attrs, attr_sep, part_sep, process_id):
    """Build a JetSet from one ROOT file and dump it as text files.

    *process_id* is baked into the output file names so parallel workers
    never overwrite each other's files.
    """
    builder = JetSet.build_jetset_fast if data_type == 'fast' else JetSet.build_jetset_full
    jet_set = builder(root_file)
    jet_set.save_to_txt(save_dir, selected_attrs, attr_sep, part_sep, f"process_{process_id}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
193
convert_root_to_txt.py.bak
Normal file
193
convert_root_to_txt.py.bak
Normal file
@@ -0,0 +1,193 @@
|
||||
import enum
|
||||
import attrs
|
||||
from typing import List, Optional, Union
|
||||
import numpy as np
|
||||
import uproot
|
||||
from pathlib import Path
|
||||
from tqdm import tqdm
|
||||
import os
|
||||
import joblib
|
||||
import click
|
||||
|
||||
# Closed set of particle categories selected from the per-particle
# "part_is*" flags of the input tree; the numeric value of each member is
# the index used when the category is serialised as an integer.
ParticleType = enum.Enum(
    "ParticleType",
    [
        ("NeutralHadron", 0),
        ("Photon", 1),
        ("Electron", 2),
        ("Muon", 3),
        ("Pion", 4),
        ("ChargedKaon", 5),
        ("Proton", 6),
        ("Others", 7),
    ],
)
|
||||
|
||||
@attrs.define
class ParticleBase:
    """One jet constituent plus its derived kinematic features.

    Field values are filled in by ``JetSet.build_jetset``; all ``log*``
    fields hold natural logarithms (``np.log``), not log10 as the original
    comments claimed.
    """

    particle_id: int = attrs.field()        # row index of the owning jet (same for all particles of a jet)
    part_charge: int = attrs.field()        # electric charge
    part_energy: float = attrs.field()      # particle energy
    part_px: float = attrs.field()          # momentum x-component
    part_py: float = attrs.field()          # momentum y-component
    part_pz: float = attrs.field()          # momentum z-component
    log_energy: float = attrs.field()       # ln(part_energy)
    log_pt: float = attrs.field()           # ln(part_pt)
    part_deta: float = attrs.field()        # pseudorapidity offset from the jet axis
    part_dphi: float = attrs.field()        # azimuthal offset from the jet axis
    part_logptrel: float = attrs.field()    # ln(pt(particle) / pt(jet))
    part_logerel: float = attrs.field()     # ln(energy(particle) / energy(jet))
    part_deltaR: float = attrs.field()      # hypot(part_deta, part_dphi): angular distance to the jet
    part_d0: float = attrs.field()          # tanh(d0 impact parameter)
    part_dz: float = attrs.field()          # tanh(dz impact parameter)
    particle_type: ParticleType = attrs.field()  # particle category (enum member)
    particle_pid: int = attrs.field()       # integer index of the category
    jet_type: str = attrs.field()           # human-readable jet label (e.g. "b jet")

    def properties_concatenated(self, selected_attrs: Optional[List[str]] = None, attr_sep: str = ',') -> str:
        """Return the selected attribute values joined by *attr_sep*.

        When *selected_attrs* is None, every field except ``jet_type`` is
        serialised, in declaration order.
        """
        if selected_attrs is None:
            # BUG FIX: attrs.fields() requires the *class*, not an instance —
            # the original passed ``self``, which raises TypeError whenever the
            # default attribute list is requested.
            selected_attrs = [field.name for field in attrs.fields(type(self)) if field.name != "jet_type"]
        values = [getattr(self, attr) for attr in selected_attrs]
        # Serialise enum members by their numeric value so output stays numeric.
        values = [v.value if isinstance(v, ParticleType) else v for v in values]
        return attr_sep.join(map(str, values))
|
||||
|
||||
@attrs.define
class Jet:
    """A jet record: global kinematics, a sample label, and its constituents."""

    jet_energy: float = attrs.field()   # jet energy
    jet_pt: float = attrs.field()       # jet transverse momentum
    jet_eta: float = attrs.field()      # jet pseudorapidity
    label: str = attrs.field()          # sample label parsed from the file name ("bb" / "bbbar")
    particles: List[ParticleBase] = attrs.field(factory=list)  # constituents of this jet

    def __len__(self):
        """Number of constituent particles."""
        return len(self.particles)

    def particles_concatenated(self, selected_attrs: Optional[List[str]] = None, attr_sep: str = ',', part_sep: str = '|') -> str:
        """Serialise all particles: attributes joined by *attr_sep*, particles by *part_sep*."""
        rendered = [p.properties_concatenated(selected_attrs, attr_sep) for p in self.particles]
        return part_sep.join(rendered)
|
||||
|
||||
@attrs.define
class JetSet:
    """A collection of jets of one sample type, loaded from a ROOT file."""

    jets_type: str = attrs.field()               # sample type parsed from the file name ("bb" or "bbbar")
    jets: List[Jet] = attrs.field(factory=list)  # jets in this set

    def __len__(self):
        """Number of jets in the set."""
        return len(self.jets)

    @staticmethod
    def jud_type(jtmp):
        """Map a list of seven ``part_is*`` flags to a particle category.

        The index of the (first) maximal flag selects the category; both the
        enum member and the raw index are returned.
        """
        particle_dict = {
            0: ParticleType.NeutralHadron,
            1: ParticleType.Photon,
            2: ParticleType.Electron,
            3: ParticleType.Muon,
            4: ParticleType.Pion,
            5: ParticleType.ChargedKaon,
            6: ParticleType.Proton,
            7: ParticleType.Others,
        }
        idx = jtmp.index(max(jtmp))
        return particle_dict.get(idx, ParticleType.Others), idx

    @classmethod
    def build_jetset(cls, root_file: Union[str, Path]) -> "JetSet":
        """Read *root_file* and build a JetSet with per-particle features.

        Raises ValueError when the file name does not identify the sample type.
        """
        print(f"Building JetSet from {root_file}")
        # BUG FIX: the original later read ``root_file.stem`` directly, which
        # crashes when *root_file* is passed as a plain string even though the
        # signature allows it. Normalise once up front.
        stem = Path(root_file).stem
        # 'bbbar' contains 'bb', so a single 'bb' check covers both cases.
        if 'bb' not in stem:
            raise ValueError("Invalid file name, should contain 'bb' or 'bbbar'")
        jets_type = "bbbar" if "bbbar" in stem else "bb"
        print(f"jet type: {jets_type}")
        with uproot.open(root_file) as f:
            tree = f["tree"]
            a = tree.arrays(library="pd")

        # Sample label is the first underscore-separated token of the file name.
        label = stem.split('_')[0]
        jet_type = "bbbar jet" if "bbbar" in stem else "b jet"
        jet_list = []
        for i, j in a.iterrows():
            part_pt = np.array(j['part_pt'])
            jet_pt = np.array(j['jet_pt'])
            part_logptrel = np.log(np.divide(part_pt, jet_pt))

            part_energy = np.array(j['part_energy'])
            jet_energy = np.array(j['jet_energy'])
            part_logerel = np.log(np.divide(part_energy, jet_energy))

            part_deta = np.array(j['part_deta'])
            part_dphi = np.array(j['part_dphi'])
            part_deltaR = np.hypot(part_deta, part_dphi)

            # All per-particle branches must be aligned.
            assert len(j['part_pt']) == len(j['part_energy']) == len(j['part_deta'])

            particle_list = ['part_isNeutralHadron', 'part_isPhoton', 'part_isElectron', 'part_isMuon', 'part_isPion', 'part_isChargedKaon', 'part_isProton']
            part_type = []
            part_pid = []
            for pn in range(len(j['part_pt'])):
                flags = [j[t][pn] for t in particle_list]
                tmp_type, tmp_pid = cls.jud_type(flags)
                part_type.append(tmp_type)
                part_pid.append(tmp_pid)

            particles = []
            for pn in range(len(part_type)):
                particles.append(ParticleBase(
                    particle_id=i,
                    part_charge=j['part_charge'][pn],
                    part_energy=j['part_energy'][pn],
                    part_px=j['part_px'][pn],
                    part_py=j['part_py'][pn],
                    part_pz=j['part_pz'][pn],
                    log_energy=np.log(j['part_energy'][pn]),
                    log_pt=np.log(j['part_pt'][pn]),
                    part_deta=j['part_deta'][pn],
                    part_dphi=j['part_dphi'][pn],
                    part_logptrel=part_logptrel[pn],
                    part_logerel=part_logerel[pn],
                    part_deltaR=part_deltaR[pn],
                    part_d0=np.tanh(j['part_d0val'][pn]),
                    part_dz=np.tanh(j['part_dzval'][pn]),
                    particle_type=part_type[pn],
                    particle_pid=part_pid[pn],
                    jet_type=jet_type
                ))

            jet_list.append(Jet(
                jet_energy=j['jet_energy'],
                jet_pt=j['jet_pt'],
                jet_eta=j['jet_eta'],
                particles=particles,
                label=label
            ))

        return cls(jets=jet_list, jets_type=jets_type)

    def save_to_txt(self, save_dir, selected_attrs: Optional[List[str]] = None, attr_sep: str = ',', part_sep: str = '|', prefix: str = ""):
        """Write one ``{jets_type}_{index}_{prefix}.txt`` file per jet into *save_dir*."""
        for file_counter, jet in enumerate(tqdm(self.jets)):
            filename = f"{self.jets_type}_{file_counter}_{prefix}.txt"
            filepath = os.path.join(save_dir, filename)
            jet_data = jet.particles_concatenated(selected_attrs, attr_sep, part_sep)
            with open(filepath, 'w') as file:
                file.write(jet_data)
|
||||
|
||||
@click.command()
@click.argument('root_dir', type=click.Path(exists=True))
@click.argument('save_dir', type=click.Path())
@click.option('--attr-sep', default=',', help='Separator for attributes.')
@click.option('--part-sep', default='|', help='Separator for particles.')
@click.option('--selected-attrs', default='particle_id,part_charge,part_energy,part_px,part_py,part_pz,log_energy,log_pt,part_deta,part_dphi,part_logptrel,part_logerel,part_deltaR,part_d0,part_dz,particle_type,particle_pid', help='Comma-separated list of selected attributes.')
def main(root_dir, save_dir, attr_sep, part_sep, selected_attrs):
    """CLI entry point: convert all ROOT files in ROOT_DIR to per-jet text files in SAVE_DIR."""
    # click delivers --selected-attrs as one comma-separated string; split it
    # into the attribute-name list that preprocess/save_to_txt expect.
    selected_attrs = selected_attrs.split(',')
    preprocess(root_dir, save_dir, selected_attrs=selected_attrs, attr_sep=attr_sep, part_sep=part_sep)
|
||||
|
||||
def preprocess(root_dir, save_dir, selected_attrs: Optional[List[str]] = None, attr_sep: str = ',', part_sep: str = '|'):
    """Convert every ``*.root`` file under *root_dir* into per-jet text files in *save_dir*.

    One joblib worker per ROOT file; the worker index keeps output names unique.
    """
    discovered = list(Path(root_dir).glob('*.root'))
    os.makedirs(save_dir, exist_ok=True)
    tasks = [
        joblib.delayed(process_file)(root_file, save_dir, selected_attrs, attr_sep, part_sep, worker_id)
        for worker_id, root_file in enumerate(discovered)
    ]
    joblib.Parallel(n_jobs=-1)(tasks)
|
||||
|
||||
def process_file(root_file, save_dir, selected_attrs, attr_sep, part_sep, process_id):
    """Build a JetSet from one ROOT file and dump it as text files.

    *process_id* is baked into the file-name prefix so parallel workers
    never overwrite each other's output.
    """
    jet_set = JetSet.build_jetset(root_file)
    jet_set.save_to_txt(save_dir, selected_attrs, attr_sep, part_sep, prefix=f"process_{process_id}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
260
data_struct.py
Normal file
260
data_struct.py
Normal file
@@ -0,0 +1,260 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
'''
|
||||
@file :data_struct.py
|
||||
@Description: :
|
||||
@Date :2024/06/04 09:08:46
|
||||
@Author :hotwa
|
||||
@version :1.0
|
||||
'''
|
||||
|
||||
import attrs
|
||||
from typing import List
|
||||
import numpy as np
|
||||
import uproot
|
||||
from pathlib import Path
|
||||
from tqdm import tqdm
|
||||
import os
|
||||
import random
|
||||
import pickle
|
||||
@attrs.define
class ParticleBase:
    """One jet constituent with derived kinematic features.

    Field values are filled in by ``build_jetset``; note that the ``log*``
    fields are natural logarithms (``np.log``), not log10.
    """

    particle_id: int = attrs.field()  # row index of the owning jet
    part_charge: int = attrs.field()  # electric charge
    part_energy: float = attrs.field()  # particle energy
    part_px: float = attrs.field()  # momentum x-component
    part_py: float = attrs.field()  # momentum y-component
    part_pz: float = attrs.field()  # momentum z-component
    log_energy: float = attrs.field()  # ln(part_energy)
    log_pt: float = attrs.field()  # ln(part_pt)
    part_deta: float = attrs.field()  # pseudorapidity offset from the jet axis
    part_dphi: float = attrs.field()  # azimuthal offset from the jet axis
    part_logptrel: float = attrs.field()  # ln(pt(particle)/pt(jet))
    part_logerel: float = attrs.field()  # ln(energy(particle)/energy(jet))
    part_deltaR: float = attrs.field()  # hypot(part_deta, part_dphi): angular distance to the jet
    part_d0: float = attrs.field()  # tanh(d0)
    part_dz: float = attrs.field()  # tanh(z0)
    #particle_type: str = attrs.field() # type of the particle (e.g. charged kaon, charged pion, proton, electron, muon, neutral hadron, photon, others)
    particle_pid: int = attrs.field()  # integer category index of the particle (0..6, see jud_type)

    def properties_concatenated(self) -> str:
        """Return all attribute values joined into one space-separated string."""
        # Render every attribute value as a string, preserving int/float formatting.
        values_str = map(lambda x: f"{x}", attrs.astuple(self))
        # Join all attribute values into a single string.
        concatenated_str = ' '.join(values_str)
        return concatenated_str

    def attributes_to_float_list(self) -> list:
        """Return all attribute values as a list, with plain ints promoted to floats."""
        attribute_values = []
        for field in attrs.fields(self.__class__):  # all fields declared on the class
            value = getattr(self, field.name)
            # Promote ints to float so downstream consumers see a uniform numeric type.
            attribute_values.append(float(value) if isinstance(value, int) else value)
        return attribute_values
|
||||
|
||||
@attrs.define
class Jet:
    """A jet record: global kinematics, a sample label, and its constituents."""

    jet_energy: float = attrs.field()  # energy of the jet
    jet_pt: float = attrs.field()  # transverse momentum of the jet
    jet_eta: float = attrs.field()  # pseudorapidity of the jet
    label: str = attrs.field()  # sample label, e.g. "bb" or "bbbar"
    particles: List[ParticleBase] = attrs.field(factory=list)  # list of particles in the jet

    def __len__(self):
        """Number of constituent particles."""
        return len(self.particles)

    def particles_concatenated(self) -> str:
        """Join every particle's serialised attributes with commas."""
        # Concatenate all particle attribute strings into one string.
        concatenated_str = ','.join([particle.properties_concatenated() for particle in self.particles])
        return concatenated_str

    def particles_attribute_list(self) -> list:
        """Return a nested list: one float list per particle."""
        attribute_list = [particle.attributes_to_float_list() for particle in self.particles]
        return attribute_list
|
||||
|
||||
@attrs.define
class JetSet:
    """A plain container for a list of jets."""

    jets: List[Jet]  # jets loaded from one ROOT file

    def __len__(self):
        """Number of jets in the set."""
        return len(self.jets)
|
||||
|
||||
def jud_type(jtmp):
    """Classify a particle from its seven ``part_is*`` flags.

    Returns ``(name, index)`` of the first maximal flag, where the names
    follow the fixed order NeutralHadron..Proton.
    """
    names = ('NeutralHadron', 'Photon', 'Electron', 'Muon', 'Pion', 'ChargedKaon', 'Proton')
    idx = jtmp.index(max(jtmp))
    return names[idx], idx
|
||||
|
||||
def build_jetset(root_file):
    """Read one ROOT file into a JetSet.

    NOTE(review): each ParticleBase here stores the *whole per-jet arrays* for
    the raw branches (part_charge, part_energy, ...) while the derived
    features (logptrel, logerel, deltaR, pid) are per-particle scalars from
    the zip below — this mirrors the original behaviour (cf. print.txt) and
    is preserved as-is.
    """
    with uproot.open(root_file) as f:
        tree = f["tree"]
        a = tree.arrays(library="pd")  # pd.DataFrame
    # Sample label is the first underscore-separated token of the file name.
    label = Path(root_file).stem.split('_')[0]
    jet_list = []
    for i, j in a.iterrows():
        part_pt = np.array(j['part_pt'])
        jet_pt = np.array(j['jet_pt'])
        part_logptrel = np.log(np.divide(part_pt, jet_pt))

        part_energy = np.array(j['part_energy'])
        jet_energy = np.array(j['jet_energy'])
        part_logerel = np.log(np.divide(part_energy, jet_energy))

        part_deta = np.array(j['part_deta'])
        part_dphi = np.array(j['part_dphi'])
        part_deltaR = np.hypot(part_deta, part_dphi)

        # All per-particle branches must be aligned.
        assert len(j['part_pt']) == len(j['part_energy']) == len(j['part_deta'])

        particle_list = ['part_isNeutralHadron', 'part_isPhoton', 'part_isElectron', 'part_isMuon', 'part_isPion', 'part_isChargedKaon', 'part_isProton']
        part_type = []
        part_pid = []
        for pn in range(len(j['part_pt'])):
            jtmp = [j[t][pn] for t in particle_list]
            tmp_type, tmp_pid = jud_type(jtmp)
            part_type.append(tmp_type)
            part_pid.append(tmp_pid)

        particles = []
        for ii, jj, kk, ptype, pid in zip(part_logptrel, part_logerel, part_deltaR, part_type, part_pid):
            particles.append(ParticleBase(
                particle_id=i,
                part_charge=j['part_charge'],
                part_energy=j['part_energy'],
                part_px=j['part_px'],
                part_py=j['part_py'],
                part_pz=j['part_pz'],
                log_energy=np.log(j['part_energy']),
                log_pt=np.log(j['part_pt']),
                part_deta=j['part_deta'],
                part_dphi=j['part_dphi'],
                part_logptrel=ii,
                part_logerel=jj,
                part_deltaR=kk,
                part_d0=np.tanh(j['part_d0val']),
                part_dz=np.tanh(j['part_dzval']),
                particle_pid=pid
            ))

        # BUG FIX: the original passed the literal lists ['jet_pt'] and
        # ['jet_eta'] instead of the row values j['jet_pt'] / j['jet_eta'].
        jet = Jet(
            jet_energy=j['jet_energy'],
            jet_pt=j['jet_pt'],
            jet_eta=j['jet_eta'],
            particles=particles,
            label=label
        )
        jet_list.append(jet)

    return JetSet(jets=jet_list)
|
||||
|
||||
def _collect_root_files(sample_dir):
    """Recursively collect all ``.root`` files under *sample_dir*."""
    collected = []
    for dirpath, _, filenames in os.walk(sample_dir):
        for filename in filenames:
            if filename.endswith('.root'):
                # BUG FIX: the original joined root/dir_name/filename even for
                # files found in nested subdirectories, producing paths that
                # do not exist; use the walk's own dirpath instead.
                collected.append(os.path.join(dirpath, filename))
    return collected


def _export_pkl(root_files, file_limit, out_dir, split_name, file_counter):
    """Dump every jet of up to *file_limit* ROOT files as pickled float lists.

    Returns the updated per-jet *file_counter* so train/val numbering stays
    contiguous, matching the original behaviour.
    """
    for root_file in root_files[:file_limit]:
        print(f"loading{root_file} to {split_name} pkl")
        jetset = build_jetset(root_file)
        for jet in tqdm(jetset.jets):
            file_counter += 1
            filepath = os.path.join(out_dir, f"{jet.label}_{file_counter}.pkl")
            with open(filepath, 'wb') as fh:
                pickle.dump(jet.particles_attribute_list(), fh)
    return file_counter


def _export_txt(root_files, file_limit, out_dir, split_name, file_counter):
    """Dump every jet of up to *file_limit* ROOT files as concatenated strings."""
    for root_file in root_files[:file_limit]:
        print(f"loading{root_file}to {split_name} txt")
        jetset = build_jetset(root_file)
        for jet in tqdm(jetset.jets):
            file_counter += 1
            filepath = os.path.join(out_dir, f"{jet.label}_{file_counter}.txt")
            with open(filepath, 'w') as fh:
                fh.write(jet.particles_concatenated())
    return file_counter


def preprocess(root_dir, method='float_list'):
    """Export per-jet train/val files for every 'bb'/'bbbar' directory under *root_dir*.

    ``method == 'float_list'`` writes pickled float lists under
    ``/data/slow/100w_pkl/{train,val}``; any other method writes concatenated
    strings under ``/data/slow/100w/{train,val}``. 80% of each directory's
    ROOT files go to train, 20% to val; at most 32 train files and 8 val
    files are consumed per directory. The split is random (random.shuffle).
    """
    train_ratio = 0.8
    train_file_limit = 32  # cap on ROOT files consumed for the train split
    val_file_limit = 8     # cap on ROOT files consumed for the val split
    for root, dirs, files in os.walk(root_dir):
        for dir_name in dirs:
            if dir_name not in ('bb', 'bbbar'):
                continue
            bb_bbbar_files = _collect_root_files(os.path.join(root, dir_name))
            if not bb_bbbar_files:
                continue
            random.shuffle(bb_bbbar_files)
            split_index = int(len(bb_bbbar_files) * train_ratio)
            train_files = bb_bbbar_files[:split_index]
            val_files = bb_bbbar_files[split_index:]
            print('len(train_files):', len(train_files))
            print('len(val_files):', len(val_files))
            # The per-jet counter continues from train into val, as before.
            if method == "float_list":
                counter = _export_pkl(train_files, train_file_limit, '/data/slow/100w_pkl/train', 'train', 0)
                _export_pkl(val_files, val_file_limit, '/data/slow/100w_pkl/val', 'val', counter)
            else:
                counter = _export_txt(train_files, train_file_limit, '/data/slow/100w/train', 'train', 0)
                _export_txt(val_files, val_file_limit, '/data/slow/100w/val', 'val', counter)
|
||||
|
||||
if __name__ == '__main__':
|
||||
root_dir = "/data/particle_raw/slow/data_100w/n2n2higgs"
|
||||
preprocess(root_dir,method='str')
|
||||
|
||||
203
dataloader.py
Normal file
203
dataloader.py
Normal file
@@ -0,0 +1,203 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
'''
|
||||
@file :test.py
|
||||
@Description: : test unit for ParticleData class
|
||||
@Date :2024/05/14 16:19:01
|
||||
@Author :lyzeng
|
||||
@Email :pylyzeng@gmail.com
|
||||
@version :1.0
|
||||
'''
|
||||
from pathlib import Path
|
||||
from typing import List, Union
|
||||
import json
|
||||
from json import JSONDecodeError
|
||||
from typing import List, Union
|
||||
import attrs
|
||||
import pickle
|
||||
from tqdm import tqdm
|
||||
import random
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
import seaborn as sns
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
@attrs.define
class Particle:
    """One record loaded from the origin_data JSONL files (keys: instruction, input, output)."""

    instruction: str  # instruction text of the record
    input: str        # input string of the record
    output: int       # integer class label of the record
|
||||
|
||||
@attrs.define
class ParticleData:
    """A list of Particle records loaded from a pickle file, plus bookkeeping."""

    file: Union[Path, str] = attrs.field(
        # Normalise to Path; None passes through untouched.
        converter=attrs.converters.optional(lambda v: Path(v))
    )
    # The records themselves. The original attached a pointless identity
    # converter (lambda data: data); it has been removed.
    data: List[Particle] = attrs.field()
    max_length: int = attrs.field(init=False)  # longest record input, in UTF-8 bytes

    def __len__(self):
        """Number of records."""
        return len(self.data)

    def __attrs_post_init__(self):
        # ROBUSTNESS FIX: max() over an empty dataset raised ValueError;
        # an empty dataset now yields max_length == 0.
        self.max_length = max((len(bytearray(d.input, 'utf-8')) for d in self.data), default=0)

    @classmethod
    def from_file(cls, file_path: Union[Path, str]):
        """Load a pickled list of Particle records from *file_path*."""
        with open(file_path, 'rb') as f:
            loaded_data = pickle.load(f)
        return cls(file=file_path, data=loaded_data)
|
||||
|
||||
|
||||
@attrs.define
class EpochLog:
    """Accuracy figures for one training epoch, parsed from a text log."""

    epoch: int        # epoch number
    train_acc: float  # training accuracy for the epoch
    eval_acc: float   # evaluation accuracy for the epoch

    @classmethod
    def from_log_lines(cls, log_lines):
        """Parse lines like ``Epoch N`` / ``train_acc: X`` / ``eval_acc: Y``.

        Accuracy lines are attached to the most recent ``Epoch`` line; epochs
        missing either accuracy value are dropped from the result.
        """
        epochs = []
        epoch_data = {}
        for line in log_lines:
            if line.startswith('Epoch'):
                # Start a new record; it is appended immediately and filled
                # in by the accuracy lines that follow it.
                epoch = int(line.split()[-1])
                epoch_data = {'epoch': epoch}
                epochs.append(epoch_data)
            elif line.startswith('train_acc:'):
                epoch_data['train_acc'] = float(line.split(':')[-1])
            elif line.startswith('eval_acc:'):
                epoch_data['eval_acc'] = float(line.split(':')[-1])
        return [cls(**epoch) for epoch in epochs if 'train_acc' in epoch and 'eval_acc' in epoch]
|
||||
|
||||
@attrs.define
class AccuracyLogger:
    """Collects EpochLog records from a log file and plots accuracy curves."""

    epochs: List[EpochLog]  # parsed per-epoch accuracy records

    @classmethod
    def from_log_file(cls, log_file: Path):
        """Parse *log_file* into an AccuracyLogger."""
        log_lines = log_file.read_text().splitlines()
        epoch_logs = EpochLog.from_log_lines(log_lines)
        return cls(epoch_logs)

    def plot_acc_curve(self, save_path: Path = None):
        """Plot train/eval accuracy vs. epoch.

        Saves the figure to *save_path* when given, otherwise shows it
        interactively.
        """
        epochs = [e.epoch for e in self.epochs]
        train_accs = [e.train_acc for e in self.epochs]
        eval_accs = [e.eval_acc for e in self.epochs]

        plt.figure(figsize=(10, 6))
        sns.lineplot(x=epochs, y=train_accs, label='Train Accuracy')
        sns.lineplot(x=epochs, y=eval_accs, label='Eval Accuracy')
        plt.xlabel('Epoch')
        plt.ylabel('Accuracy')
        plt.title('Training and Evaluation Accuracy over Epochs')
        plt.legend()
        plt.grid(True)

        if save_path:
            plt.savefig(save_path)
        else:
            plt.show()
|
||||
|
||||
def plot_acc_curve(log_file: Path, save_path: Path = Path('accuracy_curve.png')):
    """Convenience wrapper: parse *log_file* and write the accuracy plot to *save_path*."""
    AccuracyLogger.from_log_file(log_file).plot_acc_curve(save_path=save_path)
|
||||
|
||||
def split_train_eval_data(path: List[Path],
                          ext: str = 'txt', train_dir: str = 'train_data', eval_dir: str = 'eval_dir'):
    """Randomly copy the ``*.{ext}`` files under each directory in *path*
    into *train_dir* (80%) and *eval_dir* (20%).

    The split is random (``random.shuffle``); seed the RNG beforehand for a
    reproducible split.
    """
    # ROBUSTNESS FIX: mkdir(parents=True, exist_ok=True) also creates missing
    # parent directories and avoids the exists()/mkdir() race of the original.
    Path(train_dir).mkdir(parents=True, exist_ok=True)
    Path(eval_dir).mkdir(parents=True, exist_ok=True)

    # Gather every matching file across all given directories.
    all_files = [file for f in path for file in f.glob(f'*.{ext}')]
    random.shuffle(all_files)

    # 80/20 split point.
    split_idx = int(len(all_files) * 0.8)
    for file in all_files[:split_idx]:
        shutil.copy(file, Path(train_dir) / file.name)
    for file in all_files[split_idx:]:
        shutil.copy(file, Path(eval_dir) / file.name)
|
||||
|
||||
if __name__ == '__main__':
|
||||
all_file = Path('origin_data').glob('*.jsonl')
|
||||
data = []
|
||||
for file in all_file:
|
||||
t_lines = file.read_text()[1:-1].splitlines()[1:]
|
||||
for i, line in enumerate(t_lines):
|
||||
if line[-1] == ',':
|
||||
line = line[:-1]
|
||||
particle_dict = json.loads(line)
|
||||
data.append(Particle(**particle_dict))
|
||||
|
||||
with open('particle_data.pkl', 'wb') as f:
|
||||
pickle.dump(data, f)
|
||||
|
||||
print(f"Total {len(data)} particles, train data account for {len(data)*0.8}, eval data account for {len(data)*0.2}")
|
||||
|
||||
with open('particle_data_train.pkl', 'wb') as f:
|
||||
pickle.dump(data[:int(len(data)*0.8)], f)
|
||||
|
||||
with open('particle_data_eval.pkl', 'wb') as f:
|
||||
pickle.dump(data[int(len(data)*0.2):], f)
|
||||
|
||||
all_files = ParticleData.from_file('/data/bgptf/particle_data.pkl')
|
||||
train_files = ParticleData.from_file('/data/bgptf/particle_data_train.pkl')
|
||||
eval_files = ParticleData.from_file('/data/bgptf/particle_data_eval.pkl')
|
||||
test_files = ParticleData.from_file('/data/bgptf/particle_test.pkl')
|
||||
p = Path('train_file_split')
|
||||
if not p.exists():
|
||||
p.mkdir()
|
||||
for n, particle in tqdm(enumerate(train_files.data), total=len(train_files.data), desc="Writing files"):
|
||||
with open(p.joinpath(f'class{particle.output}_{n}.txt'), 'w') as f:
|
||||
f.write(particle.input)
|
||||
|
||||
# 定义源目录和目标目录
|
||||
source_dir = Path('train_file_split')
|
||||
train_dir = Path('bbbar_train_split')
|
||||
eval_dir = Path('bbbar_eval_split')
|
||||
|
||||
# 创建目标目录
|
||||
train_dir.mkdir(parents=True, exist_ok=True)
|
||||
eval_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 获取所有文件路径
|
||||
all_files = list(source_dir.glob('*.txt'))
|
||||
|
||||
# 设置划分比例,例如 80% 作为训练集,20% 作为验证集
|
||||
train_ratio = 0.8
|
||||
|
||||
# 计算训练集和验证集的文件数量
|
||||
num_train_files = int(len(all_files) * train_ratio)
|
||||
num_eval_files = len(all_files) - num_train_files
|
||||
|
||||
# 随机打乱文件列表
|
||||
random.shuffle(all_files)
|
||||
|
||||
# 划分训练集和验证集
|
||||
train_files = all_files[:num_train_files]
|
||||
eval_files = all_files[num_train_files:]
|
||||
|
||||
# 复制文件到目标目录并显示进度条
|
||||
print("正在复制训练集文件...")
|
||||
for file in tqdm(train_files, desc="训练集进度"):
|
||||
shutil.copy(file, train_dir / file.name)
|
||||
|
||||
print("正在复制验证集文件...")
|
||||
for file in tqdm(eval_files, desc="验证集进度"):
|
||||
shutil.copy(file, eval_dir / file.name)
|
||||
|
||||
print(f"训练集文件数: {len(train_files)}")
|
||||
print(f"验证集文件数: {len(eval_files)}")
|
||||
...
|
||||
22
environment.yaml
Normal file
22
environment.yaml
Normal file
@@ -0,0 +1,22 @@
|
||||
name: bgpt_data
channels:
  - defaults
  - conda-forge
dependencies:
  - python=3.10
  - click
  - joblib
  - pandas
  - pip
  - uproot
  - attrs
  - ipython
  - ipykernel
  - awkward-pandas
  - importlib-metadata
  - dill
  - tqdm
  - pytorch
  - pyInstaller
  # BUG FIX: pip-installed packages must be listed as a "- pip:" entry
  # nested under "dependencies"; a top-level "pip:" key is ignored by conda.
  - pip:
      - pyarmor
|
||||
13
particle.py
Normal file
13
particle.py
Normal file
@@ -0,0 +1,13 @@
|
||||
# particle.py
|
||||
|
||||
import enum
|
||||
|
||||
# Closed set of particle categories; each member's value is the integer
# index used when the category is serialised.
ParticleType = enum.Enum(
    "ParticleType",
    [
        ("NeutralHadron", 0),
        ("Photon", 1),
        ("Electron", 2),
        ("Muon", 3),
        ("Pion", 4),
        ("ChargedKaon", 5),
        ("Proton", 6),
        ("Others", 7),
    ],
)
|
||||
1
passwd.txt
Normal file
1
passwd.txt
Normal file
@@ -0,0 +1 @@
|
||||
Xiaozhe+521
|
||||
39
print.txt
Normal file
39
print.txt
Normal file
@@ -0,0 +1,39 @@
|
||||
Building JetSet from data/full_bb.root
|
||||
属性名称: particle_id, 类型: <class 'int'>, 值: 0
|
||||
属性名称: part_charge, 类型: <class 'int'>, 值: [1.0, -1.0, 0.0, 0.0, 1.0, -1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -1.0, 0.0, 0.0, -1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
|
||||
属性名称: part_energy, 类型: <class 'float'>, 值: [15.209561347961426, 13.140800476074219, 8.555957794189453, 6.927829742431641, 6.921298027038574, 6.718546390533447, 5.639996528625488, 2.734797239303589, 2.276695728302002, 2.0461738109588623, 1.7986037731170654, 1.388156771659851, 1.38234281539917, 1.0570266246795654, 0.8934926986694336, 0.8899771571159363, 0.879972517490387, 0.7406238317489624, 0.5272039175033569, 0.3686225712299347, 0.3535390794277191, 0.34719786047935486, 0.34384962916374207, 0.2649906873703003, 0.23999999463558197, 0.18479159474372864, 0.17317134141921997, 0.15941102802753448, 0.12423540651798248, 0.08461850136518478, 0.0843382179737091, 0.02829699218273163, 0.009839793667197227]
|
||||
属性名称: part_px, 类型: <class 'float'>, 值: [-5.796960830688477, -5.311581611633301, -3.19583797454834, -2.326359510421753, -2.503192901611328, -2.6899032592773438, -1.925751805305481, -0.8692939281463623, -0.7592987418174744, -0.7812060713768005, -0.6791485548019409, -0.3433796167373657, -0.47946590185165405, -0.18832340836524963, -0.35665348172187805, -0.28360629081726074, -0.3264841139316559, -0.3620759844779968, -0.26313358545303345, -0.1359526515007019, -0.123598612844944, -0.06242864206433296, -0.13691917061805725, -0.08844801038503647, -0.05727635696530342, -0.08955521136522293, -0.12885883450508118, -0.10516541451215744, -0.03254406899213791, -0.015558036044239998, -0.02192043326795101, -0.008681708946824074, -0.0013048802502453327]
|
||||
属性名称: part_py, 类型: <class 'float'>, 值: [-8.48299789428711, -6.931148052215576, -4.749238014221191, -3.9339864253997803, -4.476352691650391, -3.8203539848327637, -3.365067958831787, -1.7807756662368774, -1.501481056213379, -1.1750568151474, -0.9528936743736267, -0.7268454432487488, -0.833030104637146, 0.49239909648895264, -0.4624371826648712, -0.5192604064941406, -0.2649482488632202, -0.40544790029525757, -0.32859694957733154, -0.19921012222766876, -0.22155417501926422, 0.23689919710159302, -0.21035675704479218, -0.1498441994190216, -0.1331370621919632, -0.012309929355978966, 0.06338366121053696, -0.07766649127006531, -0.06364167481660843, 0.03760836273431778, -0.04934240132570267, -0.019126178696751595, 0.004166470840573311]
|
||||
属性名称: part_pz, 类型: <class 'float'>, 值: [-11.213626861572266, -9.818737030029297, -6.358912944793701, -5.206402778625488, -4.645572185516357, -4.82585334777832, -4.096017360687256, -1.8796172142028809, -1.5337415933609009, -1.481818437576294, -1.3659160137176514, -1.1231404542922974, -0.9935013651847839, -0.916178286075592, -0.6762243509292603, -0.6648273468017578, -0.7730214595794678, -0.5030274987220764, -0.2853204607963562, -0.2787737548351288, -0.2462255209684372, -0.20299455523490906, -0.23502285778522491, -0.19985926151275635, -0.19129541516304016, -0.1611715406179428, -0.09677927196025848, -0.09121418744325638, -0.10161228477954865, -0.07418793439865112, -0.06479009985923767, -0.018961459398269653, -0.008818126283586025]
|
||||
属性名称: log_energy, 类型: <class 'float'>, 值: [ 2.72192427 2.57572193 2.14662786 1.9355466 1.93460333 1.90487182
|
||||
1.72988345 1.0060573 0.82272515 0.71597162 0.58701068 0.3279768
|
||||
0.32377975 0.0554599 -0.11261712 -0.11655948 -0.1278646 -0.30026243
|
||||
-0.64016787 -0.997982 -1.03976125 -1.05786046 -1.06755084 -1.3280606
|
||||
-1.42711638 -1.6885266 -1.75347376 -1.83626933 -2.08557707 -2.46960234
|
||||
-2.47292016 -3.56499976 -4.62132054]
|
||||
属性名称: log_pt, 类型: <class 'float'>, 值: [ 2.32966824 2.16703303 1.744736 1.51959212 1.63485496 1.54165826
|
||||
1.35509735 0.68391671 0.52031146 0.34432893 0.15713127 -0.21831242
|
||||
-0.03961538 -0.64020631 -0.53786329 -0.52481977 -0.86639788 -0.60956524
|
||||
-0.86519514 -1.42221173 -1.3716092 -1.40655069 -1.38233552 -1.74869444
|
||||
-1.9314722 -2.40354098 -1.94069632 -2.03457679 -2.63833865 -3.20154184
|
||||
-2.91891223 -3.86302568 -5.43390179]
|
||||
属性名称: part_deta, 类型: <class 'float'>, 值: [-0.012650121003389359, 0.009468508884310722, 0.00042291259160265326, 0.01923862285912037, -0.14403046667575836, -0.05278221517801285, -0.036465417593717575, -0.11269791424274445, -0.1397673487663269, -0.040799811482429504, 0.03767024725675583, 0.17913000285625458, -0.05223162844777107, 0.362665593624115, 0.03155886009335518, 0.008969753980636597, 0.4117862284183502, -0.1295829564332962, -0.3228398263454437, 0.03021526150405407, -0.09680869430303574, -0.20235909521579742, -0.12153740972280502, 0.025453981012105942, 0.13331031799316406, 0.38491809368133545, -0.3260197937488556, -0.30641964077949524, 0.19324439764022827, 0.40427282452583313, 0.05876421928405762, -0.14629797637462616, 0.4952174425125122]
|
||||
属性名称: part_dphi, 类型: <class 'float'>, 值: [-0.004700591322034597, -0.05910219997167587, 0.0024550738744437695, 0.060737334191799164, 0.0848897248506546, -0.018699318170547485, 0.07498598098754883, 0.1406451165676117, 0.1265745460987091, 0.008045745082199574, -0.02444184571504593, 0.15342673659324646, 0.07251019775867462, -2.181525468826294, -0.062189724296331406, 0.09487095475196838, -0.2942989766597748, -0.13417768478393555, -0.08044422417879105, -0.004084283020347357, 0.08590564131736755, -2.2891547679901123, 0.01777080073952675, 0.06153986230492592, 0.18849976360797882, -0.8394244313240051, -1.4331588745117188, -0.3399130403995514, 0.12207411229610443, -2.1545727252960205, 0.17670847475528717, 0.1686646193265915, -2.2433114051818848]
|
||||
属性名称: part_logptrel, 类型: <class 'float'>, 值: -1.6656613686933
|
||||
属性名称: part_logerel, 类型: <class 'float'>, 值: -1.6908250129173599
|
||||
属性名称: part_deltaR, 类型: <class 'float'>, 值: 0.013495225829054495
|
||||
属性名称: part_d0, 类型: <class 'float'>, 值: [-0.08485882 0.08362859 0. 0. -0.12567137 0.13890936
|
||||
0. 0.05528848 0. 0. 0. -0.08717197
|
||||
0. 0. 0. 0. 0. 0.
|
||||
0.07979553 0. 0. 0.06790032 0. 0.
|
||||
0. 0. 0. 0. 0. 0.
|
||||
0. 0. 0. ]
|
||||
属性名称: part_dz, 类型: <class 'float'>, 值: [ 0.16160717 0.07151473 0. 0. -0.16053929 0.09658169
|
||||
0. 0.06517879 0. 0. 0. 0.11920792
|
||||
0. 0. 0. 0. 0. 0.
|
||||
-0.13735584 0. 0. 0.12359677 0. 0.
|
||||
0. 0. 0. 0. 0. 0.
|
||||
0. 0. 0. ]
|
||||
属性名称: particle_type, 类型: <enum 'ParticleType'>, 值: ParticleType.Electron
|
||||
属性名称: particle_pid, 类型: <class 'int'>, 值: 2
|
||||
属性名称: jet_type, 类型: <class 'str'>, 值: b jet
|
||||
397
test.ipynb
Normal file
397
test.ipynb
Normal file
@@ -0,0 +1,397 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 29,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!rm -rf ./save/*"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 31,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Building JetSet from data/data_full/full_bbbar.root\n",
|
||||
"jet type: bbbar\n",
|
||||
"Building JetSet from data/data_full/full_bb.root\n",
|
||||
"jet type: bb\n",
|
||||
"100%|███████████████████████████████████| 24921/24921 [00:06<00:00, 3850.86it/s]\n",
|
||||
"100%|███████████████████████████████████| 24930/24930 [00:06<00:00, 3754.57it/s]\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from pathlib import Path\n",
|
||||
"if not Path('save/save_full').exists():\n",
|
||||
" Path('save/save_full').mkdir()\n",
|
||||
"!python convert_root_to_txt.py ./data/data_full ./save/save_full --data-type full --attr-sep '|' --part-sep ';' --selected-attrs 'part_charge,part_energy,part_px,part_py,part_pz,log_energy,log_pt,part_deta,part_dphi,part_logptrel,part_logerel,part_deltaR,part_d0,part_dz,particle_pid'\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 28,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Building JetSet from data/data_fast/fast_bbar.root\n",
|
||||
"jet type: bb\n",
|
||||
"Building JetSet from data/data_fast/fast_bb.root\n",
|
||||
"jet type: bb\n",
|
||||
"joblib.externals.loky.process_executor._RemoteTraceback: \n",
|
||||
"\"\"\"\n",
|
||||
"Traceback (most recent call last):\n",
|
||||
" File \"/home/lingyuzeng/micromamba/envs/bgpt_data/lib/python3.10/site-packages/joblib/externals/loky/process_executor.py\", line 463, in _process_worker\n",
|
||||
" r = call_item()\n",
|
||||
" File \"/home/lingyuzeng/micromamba/envs/bgpt_data/lib/python3.10/site-packages/joblib/externals/loky/process_executor.py\", line 291, in __call__\n",
|
||||
" return self.fn(*self.args, **self.kwargs)\n",
|
||||
" File \"/home/lingyuzeng/micromamba/envs/bgpt_data/lib/python3.10/site-packages/joblib/parallel.py\", line 598, in __call__\n",
|
||||
" return [func(*args, **kwargs)\n",
|
||||
" File \"/home/lingyuzeng/micromamba/envs/bgpt_data/lib/python3.10/site-packages/joblib/parallel.py\", line 598, in <listcomp>\n",
|
||||
" return [func(*args, **kwargs)\n",
|
||||
" File \"/home/lingyuzeng/project/bbbar_data_struct/convert_root_to_txt.py\", line 271, in process_file\n",
|
||||
" jet_set = JetSet.build_jetset_fast(root_file)\n",
|
||||
" File \"/home/lingyuzeng/project/bbbar_data_struct/convert_root_to_txt.py\", line 102, in build_jetset_fast\n",
|
||||
" for i, j in a.iterrows():\n",
|
||||
" File \"/home/lingyuzeng/micromamba/envs/bgpt_data/lib/python3.10/site-packages/pandas/core/frame.py\", line 1541, in iterrows\n",
|
||||
" for k, v in zip(self.index, self.values):\n",
|
||||
" File \"/home/lingyuzeng/micromamba/envs/bgpt_data/lib/python3.10/site-packages/pandas/core/frame.py\", line 12651, in values\n",
|
||||
" return self._mgr.as_array()\n",
|
||||
" File \"/home/lingyuzeng/micromamba/envs/bgpt_data/lib/python3.10/site-packages/pandas/core/internals/managers.py\", line 1692, in as_array\n",
|
||||
" arr = self._interleave(dtype=dtype, na_value=na_value)\n",
|
||||
" File \"/home/lingyuzeng/micromamba/envs/bgpt_data/lib/python3.10/site-packages/pandas/core/internals/managers.py\", line 1733, in _interleave\n",
|
||||
" arr = blk.get_values(dtype)\n",
|
||||
" File \"/home/lingyuzeng/micromamba/envs/bgpt_data/lib/python3.10/site-packages/pandas/core/internals/blocks.py\", line 2253, in get_values\n",
|
||||
" return np.asarray(values).reshape(self.shape)\n",
|
||||
"ValueError: cannot reshape array of size 0 into shape (1,10000)\n",
|
||||
"\"\"\"\n",
|
||||
"\n",
|
||||
"The above exception was the direct cause of the following exception:\n",
|
||||
"\n",
|
||||
"Traceback (most recent call last):\n",
|
||||
" File \"/home/lingyuzeng/project/bbbar_data_struct/convert_root_to_txt.py\", line 275, in <module>\n",
|
||||
" main()\n",
|
||||
" File \"/home/lingyuzeng/micromamba/envs/bgpt_data/lib/python3.10/site-packages/click/core.py\", line 1157, in __call__\n",
|
||||
" return self.main(*args, **kwargs)\n",
|
||||
" File \"/home/lingyuzeng/micromamba/envs/bgpt_data/lib/python3.10/site-packages/click/core.py\", line 1078, in main\n",
|
||||
" rv = self.invoke(ctx)\n",
|
||||
" File \"/home/lingyuzeng/micromamba/envs/bgpt_data/lib/python3.10/site-packages/click/core.py\", line 1434, in invoke\n",
|
||||
" return ctx.invoke(self.callback, **ctx.params)\n",
|
||||
" File \"/home/lingyuzeng/micromamba/envs/bgpt_data/lib/python3.10/site-packages/click/core.py\", line 783, in invoke\n",
|
||||
" return __callback(*args, **kwargs)\n",
|
||||
" File \"/home/lingyuzeng/project/bbbar_data_struct/convert_root_to_txt.py\", line 257, in main\n",
|
||||
" preprocess(root_dir, save_dir, data_type, selected_attrs=selected_attrs, attr_sep=attr_sep, part_sep=part_sep)\n",
|
||||
" File \"/home/lingyuzeng/project/bbbar_data_struct/convert_root_to_txt.py\", line 262, in preprocess\n",
|
||||
" joblib.Parallel(n_jobs=-1)(\n",
|
||||
" File \"/home/lingyuzeng/micromamba/envs/bgpt_data/lib/python3.10/site-packages/joblib/parallel.py\", line 2007, in __call__\n",
|
||||
" return output if self.return_generator else list(output)\n",
|
||||
" File \"/home/lingyuzeng/micromamba/envs/bgpt_data/lib/python3.10/site-packages/joblib/parallel.py\", line 1650, in _get_outputs\n",
|
||||
" yield from self._retrieve()\n",
|
||||
" File \"/home/lingyuzeng/micromamba/envs/bgpt_data/lib/python3.10/site-packages/joblib/parallel.py\", line 1754, in _retrieve\n",
|
||||
" self._raise_error_fast()\n",
|
||||
" File \"/home/lingyuzeng/micromamba/envs/bgpt_data/lib/python3.10/site-packages/joblib/parallel.py\", line 1789, in _raise_error_fast\n",
|
||||
" error_job.get_result(self.timeout)\n",
|
||||
" File \"/home/lingyuzeng/micromamba/envs/bgpt_data/lib/python3.10/site-packages/joblib/parallel.py\", line 745, in get_result\n",
|
||||
" return self._return_or_raise()\n",
|
||||
" File \"/home/lingyuzeng/micromamba/envs/bgpt_data/lib/python3.10/site-packages/joblib/parallel.py\", line 763, in _return_or_raise\n",
|
||||
" raise self._result\n",
|
||||
"ValueError: cannot reshape array of size 0 into shape (1,10000)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from pathlib import Path\n",
|
||||
"if not Path('save/save_fast').exists():\n",
|
||||
" Path('save/save_fast').mkdir()\n",
|
||||
"!python convert_root_to_txt.py ./data/data_fast ./save/save_fast --data-type fast --attr-sep '|' --part-sep ';' --selected-attrs 'part_charge,part_energy,part_px,part_py,part_pz,log_energy,log_pt,part_deta,part_dphi,part_logptrel,part_logerel,part_deltaR,part_d0,part_dz,particle_pid'\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 15,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from convert_root_to_txt import *\n",
|
||||
"from pathlib import Path"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 16,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Building JetSet from data/full_bb.root\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"jet_set = JetSet.build_jetset(Path('data/full_bb.root'))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 17,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"'bb'"
|
||||
]
|
||||
},
|
||||
"execution_count": 17,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"jet_set.jets_type"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 20,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Building JetSet from data/full_bb.root\n",
|
||||
"属性名称: particle_id, 类型: <class 'int'>, 值: 0\n",
|
||||
"属性名称: part_charge, 类型: <class 'int'>, 值: [1.0, -1.0, 0.0, 0.0, 1.0, -1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -1.0, 0.0, 0.0, -1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]\n",
|
||||
"属性名称: part_energy, 类型: <class 'float'>, 值: [15.209561347961426, 13.140800476074219, 8.555957794189453, 6.927829742431641, 6.921298027038574, 6.718546390533447, 5.639996528625488, 2.734797239303589, 2.276695728302002, 2.0461738109588623, 1.7986037731170654, 1.388156771659851, 1.38234281539917, 1.0570266246795654, 0.8934926986694336, 0.8899771571159363, 0.879972517490387, 0.7406238317489624, 0.5272039175033569, 0.3686225712299347, 0.3535390794277191, 0.34719786047935486, 0.34384962916374207, 0.2649906873703003, 0.23999999463558197, 0.18479159474372864, 0.17317134141921997, 0.15941102802753448, 0.12423540651798248, 0.08461850136518478, 0.0843382179737091, 0.02829699218273163, 0.009839793667197227]\n",
|
||||
"属性名称: part_px, 类型: <class 'float'>, 值: [-5.796960830688477, -5.311581611633301, -3.19583797454834, -2.326359510421753, -2.503192901611328, -2.6899032592773438, -1.925751805305481, -0.8692939281463623, -0.7592987418174744, -0.7812060713768005, -0.6791485548019409, -0.3433796167373657, -0.47946590185165405, -0.18832340836524963, -0.35665348172187805, -0.28360629081726074, -0.3264841139316559, -0.3620759844779968, -0.26313358545303345, -0.1359526515007019, -0.123598612844944, -0.06242864206433296, -0.13691917061805725, -0.08844801038503647, -0.05727635696530342, -0.08955521136522293, -0.12885883450508118, -0.10516541451215744, -0.03254406899213791, -0.015558036044239998, -0.02192043326795101, -0.008681708946824074, -0.0013048802502453327]\n",
|
||||
"属性名称: part_py, 类型: <class 'float'>, 值: [-8.48299789428711, -6.931148052215576, -4.749238014221191, -3.9339864253997803, -4.476352691650391, -3.8203539848327637, -3.365067958831787, -1.7807756662368774, -1.501481056213379, -1.1750568151474, -0.9528936743736267, -0.7268454432487488, -0.833030104637146, 0.49239909648895264, -0.4624371826648712, -0.5192604064941406, -0.2649482488632202, -0.40544790029525757, -0.32859694957733154, -0.19921012222766876, -0.22155417501926422, 0.23689919710159302, -0.21035675704479218, -0.1498441994190216, -0.1331370621919632, -0.012309929355978966, 0.06338366121053696, -0.07766649127006531, -0.06364167481660843, 0.03760836273431778, -0.04934240132570267, -0.019126178696751595, 0.004166470840573311]\n",
|
||||
"属性名称: part_pz, 类型: <class 'float'>, 值: [-11.213626861572266, -9.818737030029297, -6.358912944793701, -5.206402778625488, -4.645572185516357, -4.82585334777832, -4.096017360687256, -1.8796172142028809, -1.5337415933609009, -1.481818437576294, -1.3659160137176514, -1.1231404542922974, -0.9935013651847839, -0.916178286075592, -0.6762243509292603, -0.6648273468017578, -0.7730214595794678, -0.5030274987220764, -0.2853204607963562, -0.2787737548351288, -0.2462255209684372, -0.20299455523490906, -0.23502285778522491, -0.19985926151275635, -0.19129541516304016, -0.1611715406179428, -0.09677927196025848, -0.09121418744325638, -0.10161228477954865, -0.07418793439865112, -0.06479009985923767, -0.018961459398269653, -0.008818126283586025]\n",
|
||||
"属性名称: log_energy, 类型: <class 'float'>, 值: [ 2.72192427 2.57572193 2.14662786 1.9355466 1.93460333 1.90487182\n",
|
||||
" 1.72988345 1.0060573 0.82272515 0.71597162 0.58701068 0.3279768\n",
|
||||
" 0.32377975 0.0554599 -0.11261712 -0.11655948 -0.1278646 -0.30026243\n",
|
||||
" -0.64016787 -0.997982 -1.03976125 -1.05786046 -1.06755084 -1.3280606\n",
|
||||
" -1.42711638 -1.6885266 -1.75347376 -1.83626933 -2.08557707 -2.46960234\n",
|
||||
" -2.47292016 -3.56499976 -4.62132054]\n",
|
||||
"属性名称: log_pt, 类型: <class 'float'>, 值: [ 2.32966824 2.16703303 1.744736 1.51959212 1.63485496 1.54165826\n",
|
||||
" 1.35509735 0.68391671 0.52031146 0.34432893 0.15713127 -0.21831242\n",
|
||||
" -0.03961538 -0.64020631 -0.53786329 -0.52481977 -0.86639788 -0.60956524\n",
|
||||
" -0.86519514 -1.42221173 -1.3716092 -1.40655069 -1.38233552 -1.74869444\n",
|
||||
" -1.9314722 -2.40354098 -1.94069632 -2.03457679 -2.63833865 -3.20154184\n",
|
||||
" -2.91891223 -3.86302568 -5.43390179]\n",
|
||||
"属性名称: part_deta, 类型: <class 'float'>, 值: [-0.012650121003389359, 0.009468508884310722, 0.00042291259160265326, 0.01923862285912037, -0.14403046667575836, -0.05278221517801285, -0.036465417593717575, -0.11269791424274445, -0.1397673487663269, -0.040799811482429504, 0.03767024725675583, 0.17913000285625458, -0.05223162844777107, 0.362665593624115, 0.03155886009335518, 0.008969753980636597, 0.4117862284183502, -0.1295829564332962, -0.3228398263454437, 0.03021526150405407, -0.09680869430303574, -0.20235909521579742, -0.12153740972280502, 0.025453981012105942, 0.13331031799316406, 0.38491809368133545, -0.3260197937488556, -0.30641964077949524, 0.19324439764022827, 0.40427282452583313, 0.05876421928405762, -0.14629797637462616, 0.4952174425125122]\n",
|
||||
"属性名称: part_dphi, 类型: <class 'float'>, 值: [-0.004700591322034597, -0.05910219997167587, 0.0024550738744437695, 0.060737334191799164, 0.0848897248506546, -0.018699318170547485, 0.07498598098754883, 0.1406451165676117, 0.1265745460987091, 0.008045745082199574, -0.02444184571504593, 0.15342673659324646, 0.07251019775867462, -2.181525468826294, -0.062189724296331406, 0.09487095475196838, -0.2942989766597748, -0.13417768478393555, -0.08044422417879105, -0.004084283020347357, 0.08590564131736755, -2.2891547679901123, 0.01777080073952675, 0.06153986230492592, 0.18849976360797882, -0.8394244313240051, -1.4331588745117188, -0.3399130403995514, 0.12207411229610443, -2.1545727252960205, 0.17670847475528717, 0.1686646193265915, -2.2433114051818848]\n",
|
||||
"属性名称: part_logptrel, 类型: <class 'float'>, 值: -1.6656613686933\n",
|
||||
"属性名称: part_logerel, 类型: <class 'float'>, 值: -1.6908250129173599\n",
|
||||
"属性名称: part_deltaR, 类型: <class 'float'>, 值: 0.013495225829054495\n",
|
||||
"属性名称: part_d0, 类型: <class 'float'>, 值: [-0.08485882 0.08362859 0. 0. -0.12567137 0.13890936\n",
|
||||
" 0. 0.05528848 0. 0. 0. -0.08717197\n",
|
||||
" 0. 0. 0. 0. 0. 0.\n",
|
||||
" 0.07979553 0. 0. 0.06790032 0. 0.\n",
|
||||
" 0. 0. 0. 0. 0. 0.\n",
|
||||
" 0. 0. 0. ]\n",
|
||||
"属性名称: part_dz, 类型: <class 'float'>, 值: [ 0.16160717 0.07151473 0. 0. -0.16053929 0.09658169\n",
|
||||
" 0. 0.06517879 0. 0. 0. 0.11920792\n",
|
||||
" 0. 0. 0. 0. 0. 0.\n",
|
||||
" -0.13735584 0. 0. 0.12359677 0. 0.\n",
|
||||
" 0. 0. 0. 0. 0. 0.\n",
|
||||
" 0. 0. 0. ]\n",
|
||||
"属性名称: particle_type, 类型: <enum 'ParticleType'>, 值: ParticleType.Electron\n",
|
||||
"属性名称: particle_pid, 类型: <class 'int'>, 值: 2\n",
|
||||
"属性名称: jet_type, 类型: <class 'str'>, 值: b jet\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from convert_root_to_txt import *\n",
|
||||
"from pathlib import Path\n",
|
||||
"\n",
|
||||
"# 构建 jet_set\n",
|
||||
"jet_set = JetSet.build_jetset(Path('data/full_bb.root'))\n",
|
||||
"\n",
|
||||
"# 获取第一个 jet 的第一个 particle\n",
|
||||
"particle = jet_set.jets[0].particles[0]\n",
|
||||
"\n",
|
||||
"# 打印属性名称、类型和值\n",
|
||||
"for field in attrs.fields(particle.__class__):\n",
|
||||
" field_name = field.name\n",
|
||||
" field_type = field.type\n",
|
||||
" field_value = getattr(particle, field_name)\n",
|
||||
" print(f\"属性名称: {field_name}, 类型: {field_type}, 值: {field_value}\")\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"attrs_list length: 17\n",
|
||||
"Processing data/full_bbbar.root\n",
|
||||
"Building JetSet from data/full_bbbar.root\n",
|
||||
"Processing data/full_bb.root\n",
|
||||
"Building JetSet from data/full_bb.root\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"100%|██████████| 24921/24921 [07:08<00:00, 58.13it/s]\n",
|
||||
"100%|██████████| 24930/24930 [07:20<00:00, 56.56it/s]\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from n import *\n",
|
||||
"root_dir = \"./data\"\n",
|
||||
"save_dir = \"./save\"\n",
|
||||
"attrs_list = ['particle_id', 'part_charge', 'part_energy', 'part_px', 'part_py', 'part_pz', 'log_energy' , 'log_pt', 'part_deta', 'part_dphi', 'part_logptrel', 'part_logerel', 'part_deltaR', 'part_d0', 'part_dz' , 'particle_type', 'particle_pid']\n",
|
||||
"print(f'attrs_list length: {len(attrs_list)}')\n",
|
||||
"preprocess(root_dir, save_dir, selected_attrs=attrs_list, attr_sep=',', part_sep='|')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"bb file length: 24921\n",
|
||||
"bbbar file length: 24930\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from pathlib import Path\n",
|
||||
"\n",
|
||||
"print(f\"bb file length: {len(list(Path('./save').glob('bb_*.bin')))}\")\n",
|
||||
"print(f\"bbbar file length: {len(list(Path('./save').glob('bbbar_*.bin')))}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import pickle\n",
|
||||
"with open('save/bb_0.bin', 'rb') as f:\n",
|
||||
" res = pickle.load(f)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"str"
|
||||
]
|
||||
},
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"type(res)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"198017"
|
||||
]
|
||||
},
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"len(res)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import pickle\n",
|
||||
"import os\n",
|
||||
"\n",
|
||||
"def load_pickled_files(directory):\n",
|
||||
" files = [f for f in os.listdir(directory) if f.endswith('.bin')]\n",
|
||||
" data = []\n",
|
||||
" \n",
|
||||
" for file in files:\n",
|
||||
" filepath = os.path.join(directory, file)\n",
|
||||
" with open(filepath, 'rb') as f:\n",
|
||||
" data.append(pickle.load(f))\n",
|
||||
" \n",
|
||||
" return data\n",
|
||||
"\n",
|
||||
"if __name__ == '__main__':\n",
|
||||
" save_dir = \"./save\"\n",
|
||||
" data = load_pickled_files(save_dir)\n",
|
||||
" for i, jet in enumerate(data):\n",
|
||||
" print(f\"Data from file {i}:\")\n",
|
||||
" print(jet)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.14"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
||||
17
test.log
Normal file
17
test.log
Normal file
@@ -0,0 +1,17 @@
|
||||
## full test log
|
||||
|
||||
Building JetSet from data/data_full/full_bbbar.root
|
||||
jet type: bbbar
|
||||
Building JetSet from data/data_full/full_bb.root
|
||||
jet type: bb
|
||||
100%|██████████████████████████████████████████████████████████████████████████████| 24930/24930 [00:06<00:00, 3905.34it/s]
|
||||
100%|██████████████████████████████████████████████████████████████████████████████| 24921/24921 [00:06<00:00, 3824.30it/s]
|
||||
|
||||
## fast test log
|
||||
|
||||
Building JetSet from data/data_fast/fast_bbar.root
|
||||
jet type: bb
|
||||
Building JetSet from data/data_fast/fast_bb.root
|
||||
jet type: bb
|
||||
100%|██████████████████████████████████████████████████████████████████████████████| 10000/10000 [00:01<00:00, 8766.04it/s]
|
||||
100%|██████████████████████████████████████████████████████████████████████████████| 10000/10000 [00:01<00:00, 8607.16it/s]
|
||||
Reference in New Issue
Block a user