first add

This commit is contained in:
2024-06-10 09:09:10 +08:00
parent 52d5f402bc
commit 8f9fac8bd8
14 changed files with 2042 additions and 0 deletions

9
.gitignore vendored Executable file
View File

@@ -0,0 +1,9 @@
__pycache__/
*.spec
*.pyc
*.tar.gz
dist/
build/
data/
data_100w/
save/

345
LLM.ipynb Normal file
View File

@@ -0,0 +1,345 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "074cbeba",
"metadata": {},
"outputs": [],
"source": [
"import uproot\n",
"import attrs\n",
"from typing import List\n",
"import numpy as np"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "b49ac144",
"metadata": {},
"outputs": [],
"source": [
"with uproot.open(\"./fast.root\") as f:\n",
" tree = f[\"tree\"]\n",
" a = tree.arrays(library=\"pd\")"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "c2d94275",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Index(['btag', 'ctag', 'gen_match', 'genpart_eta', 'genpart_phi',\n",
" 'genpart_pid', 'genpart_pt', 'is_signal', 'jet_energy', 'jet_eta',\n",
" 'jet_nparticles', 'jet_phi', 'jet_pt', 'part_charge', 'part_d0err',\n",
" 'part_d0val', 'part_deta', 'part_dphi', 'part_dzerr', 'part_dzval',\n",
" 'part_energy', 'part_pid', 'part_pt', 'part_px', 'part_py', 'part_pz',\n",
" 'part_isChargedHadron', 'part_isChargedKaon', 'part_isElectron',\n",
" 'part_isKLong', 'part_isKShort', 'part_isMuon', 'part_isNeutralHadron',\n",
" 'part_isPhoton', 'part_isPi0', 'part_isPion', 'part_isProton',\n",
" 'label_b', 'label_bb', 'label_bbar', 'label_c', 'label_cbar',\n",
" 'label_cc', 'label_d', 'label_dbar', 'label_g', 'label_gg', 'label_s',\n",
" 'label_sbar', 'label_u', 'label_ubar'],\n",
" dtype='object')\n"
]
}
],
"source": [
"print(a.keys())"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "93a76758",
"metadata": {},
"outputs": [],
"source": [
"import attrs\n",
"import uproot\n",
"import numpy as np\n",
"from typing import List\n",
"@attrs.define\n",
"class ParticleBase:\n",
" part_charge: int = attrs.field() # charge of the particle\n",
" part_energy: float = attrs.field() # energy of the particle\n",
" part_px: float = attrs.field() # x-component of the momentum vector\n",
" part_py: float = attrs.field() # y-component of the momentum vector\n",
" part_pz: float = attrs.field() # z-component of the momentum vector\n",
" log_energy: float = attrs.field() # log10(part_energy)\n",
" log_pt: float = attrs.field() # log10(part_pt)\n",
" part_deta: float = attrs.field() # pseudorapidity\n",
" part_dphi: float = attrs.field() # azimuthal angle\n",
" part_logptrel: float = attrs.field() # log10(pt(particle)/pt(jet))\n",
" part_logerel: float = attrs.field() # log10(energy(particle)/energy(jet))\n",
" part_deltaR: float = attrs.field() # distance between the particle and the jet\n",
" part_d0: float = attrs.field() # tanh(d0)\n",
" part_dz: float = attrs.field() # tanh(z0)\n",
" particle_type: str = attrs.field() # type of the particle (e.g. charged kaon, charged pion, proton, electron, muon, neutral hadron, photon, others)\n",
" particle_pid: int = attrs.field() # pid of the particle (e.g. 0,1,2,3,4,5,6,7)\n",
"\n",
"@attrs.define\n",
"class Jet:\n",
" jet_b: float = attrs.field()\n",
" jet_bbar: float = attrs.field()\n",
" jet_energy: float = attrs.field() # energy of the jet\n",
" jet_pt: float = attrs.field() # transverse momentum of the jet\n",
" jet_eta: float = attrs.field() # pseudorapidity of the jet\n",
" particles: List[ParticleBase] = attrs.field(factory=list) # list of particles in the jet\n",
" \n",
" def __len__(self):\n",
" return len(self.particles)\n",
" \n",
"@attrs.define\n",
"class JetSet:\n",
" jets: List[Jet]\n",
" \n",
" def __len__(self):\n",
" return len(self.jets)\n",
" \n",
"def jud_type(jtmp): #这个函数用来判断每个粒子的类型每个粒子可以是electron、muon、pion 等\n",
" particle_dict = {'NeutralHadron':0,'Photon':1, 'Electron':2, 'Muon':3, 'Pion':4,'ChargedKaon':5, 'Proton':6}\n",
" max_element = max(jtmp)\n",
" idx = jtmp.index(max_element)\n",
" items = list(particle_dict.items())\n",
" return items[idx][0], items[idx][1]\n",
" \n",
"with uproot.open(\"./data/data_fast/fast_bb.root\") as f:\n",
" tree = f[\"tree\"]\n",
" a = tree.arrays(library=\"pd\")\n",
"#a里面有很多喷注我们的目标是判断每个喷注 它是 b 还是 bbar. 每个喷注里含了很多粒子。以下以 part 开头的变量都是列表,比如 part_pt它是存储了一个喷注中所有粒子的pt\n",
"#part_energy 存储了一个喷注中所有粒子的energy。\n",
"\n",
"\n",
"jet_list = []\n",
"for j in a.itertuples(): \n",
" part_pt = np.array(j.part_pt)\n",
" jet_pt = np.array(j.jet_pt)\n",
" part_logptrel = np.log(np.divide(part_pt, jet_pt))\n",
" \n",
" part_energy = np.array(j.part_energy)\n",
" jet_energy = np.array(j.jet_energy)\n",
" part_logerel = np.log(np.divide(part_energy, jet_energy))\n",
" \n",
" part_deta = np.array(j.part_deta)\n",
" part_dphi = np.array(j.part_dphi)\n",
" part_deltaR = np.hypot(part_deta, part_dphi)\n",
" \n",
" assert len(j.part_pt) == len(j.part_energy) == len(j.part_deta)\n",
"\n",
" particles = []\n",
" \n",
" \n",
" particle_list = ['part_isNeutralHadron','part_isPhoton', 'part_isElectron', 'part_isMuon', 'part_isPion','part_isChargedKaon', 'part_isProton']\n",
" part_type = []\n",
" part_pid = []\n",
" for pn in range(len(j.part_pt)):\n",
" jtmp = [j.part_isNeutralHadron[pn], j.part_isPhoton[pn], j.part_isElectron[pn], j.part_isMuon[pn], j.part_isPion[pn],\n",
" j.part_isChargedKaon[pn], j.part_isProton[pn]]\n",
" tmp_type, tmp_pid = jud_type(jtmp)\n",
" part_type.append(tmp_type)\n",
" part_pid.append(tmp_pid)\n",
" \n",
" bag = zip(j.part_charge, j.part_energy, j.part_px, j.part_py, j.part_pz, np.log(j.part_energy), \n",
" np.log(j.part_pt), j.part_deta, j.part_dphi, part_logptrel, part_logerel, part_deltaR, \n",
" np.tanh(j.part_d0val), np.tanh(j.part_dzval), part_type, part_pid)\n",
" \n",
" #下边的代码是要对第 j 个喷注中的所有粒子做循环,将每个粒子都 存成 ParticleBase然后 append 到 particles里\n",
"#所以 particles 存储了 第 j 个喷注中所有粒子的信息\n",
" for c, en, px, py, pz, lEn, lPt, eta, phi, ii, jj, kk, d0, dz, ptype, pid in bag:\n",
" particles.append(ParticleBase(\n",
" part_charge=c, \n",
" part_energy=en, \n",
" part_px=px, \n",
" part_py=py,\n",
" part_pz=pz, \n",
" log_energy=lEn, \n",
" log_pt=lPt,\n",
" part_deta=eta, \n",
" part_dphi=phi, \n",
" part_logptrel=ii,\n",
" part_logerel=jj, \n",
" part_deltaR=kk,\n",
" part_d0=d0, \n",
" part_dz=dz, \n",
" particle_type=ptype, # assuming you will set this correctly\n",
" particle_pid=pid # assuming you will set this correctly\n",
" ))\n",
" # add jets jet = 喷注,\n",
" jet = Jet(\n",
" jet_b=j.label_b, #如果此jet是b那么label_b = 1, 否则label_b = 0\n",
" jet_bbar=j.label_bbar, #如果此jet是bbar那么label_bbar = 1\n",
" jet_energy=j.jet_energy, #第 j 个喷注的 energy\n",
" jet_pt=j.jet_pt, # 第 j 个喷注的 pt\n",
" jet_eta=j.jet_eta, # 第 j 个喷注的 eta (是一种角度的表示)\n",
" particles=particles # 第 j 个喷注中所有的 粒子\n",
" )\n",
" jet_list.append(jet)\n",
"\n",
"jet_set1 = JetSet(jets=jet_list)\n",
"\n",
"#如上所说,每个喷注有很多粒子,你最后输入模型的是每个喷注中所有粒子的如下信息\n",
"#log_energy log_pt part_logerel part_logptrel part_deltaR part_charge part_d0 part_dz \n",
"#part_deta part_dphi particle_pid"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c9995622",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 5,
"id": "b94329b2",
"metadata": {
"scrolled": true
},
"outputs": [
{
"data": {
"text/plain": [
"False"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"jet_set1.jets[0].jet_bbar"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "4242ce6c",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1.0\n",
"-1.0\n",
"1.0\n",
"0.0\n",
"-1.0\n",
"1.0\n",
"-1.0\n",
"0.0\n",
"-1.0\n",
"0.0\n",
"0.0\n",
"0.0\n",
"0.0\n",
"1.0\n",
"0.0\n",
"0.0\n",
"0.0\n"
]
}
],
"source": [
"for num in range(len(jet_set1.jets[0].particles)):\n",
" print(jet_set1.jets[0].particles[num].part_charge)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "6f0dc08e",
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"-1.0\n",
"1.0\n",
"-1.0\n",
"-1.0\n",
"0.0\n",
"1.0\n",
"-1.0\n",
"1.0\n",
"0.0\n",
"0.0\n",
"0.0\n",
"0.0\n",
"-1.0\n",
"0.0\n",
"0.0\n"
]
}
],
"source": [
"for num in range(len(jet_set1.jets[1].particles)):\n",
" print(jet_set1.jets[1].particles[num].part_charge)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "e6809f05",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"10000\n"
]
}
],
"source": [
"print(len(jet_set1.jets))"
]
},
{
"cell_type": "markdown",
"id": "67cd5f19",
"metadata": {},
"source": [
"## 与particle-transformer对比的实验设计\n",
"\n",
"使用这些属性进行与particle-transformer准确率对比bb 100w bbbar 100w\n",
"\n",
"log_energy log_pt part_logerel part_logptrel part_deltaR part_charge part_d0 part_dz part_deta part_dphi particle_pid"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.14"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

250
convert_root_to_bin.py.bak Normal file
View File

@@ -0,0 +1,250 @@
import enum
import attrs
from typing import List, Optional, Union
import numpy as np
import uproot
from pathlib import Path
from tqdm import tqdm
import os
import pickle
import joblib
import click
import torch
from torch.utils.data import Dataset
import math
class ByteDataset(Dataset):
    """Dataset that serves whole files as padded byte-id sequences.

    Each sample is the file's bytes mapped to ids 0-255, chunked into
    ``patch_size``-sized patches with 256 as the padding id.  A leading "BOS"
    patch encodes the file extension and a trailing all-padding patch marks
    the end.  The label is derived from the filename prefix.
    """

    def __init__(self, filenames, patch_size: int = 16, patch_length: int = 512):
        self.patch_size = patch_size
        self.patch_length = patch_length
        print(f"Loading {len(filenames)} files for classification")
        self.filenames = []
        # Map label names to integer ids.
        # NOTE(review): the seeded keys 'bb'/'bbbar' never match the
        # "<prefix>.<ext>" keys produced below, so real labels start at 2 —
        # confirm this is intended.
        self.labels = {'bb': 0, 'bbbar': 1}
        for filename in tqdm(filenames):
            file_size = os.path.getsize(filename)
            file_size = math.ceil(file_size / self.patch_size)  # size in patches
            ext = filename.split('.')[-1]
            label = os.path.basename(filename).split('_')[0]  # filename prefix is the label
            label = f"{label}.{ext}"
            # Keep only files that fit: 2 patches are reserved for BOS/EOS padding.
            if file_size <= self.patch_length - 2:
                self.filenames.append((filename, label))
            if label not in self.labels:
                self.labels[label] = len(self.labels)

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, idx):
        filename, label = self.filenames[idx]
        file_bytes = self.read_bytes(filename)
        file_bytes = torch.tensor(file_bytes, dtype=torch.long)
        label = torch.tensor(self.labels[label], dtype=torch.long)
        return file_bytes, label, filename

    def read_bytes(self, filename):
        """Read ``filename`` and return its padded byte-id sequence.

        BUG FIX: this method was defined as ``readbytes`` while __getitem__
        calls ``self.read_bytes`` — every item access raised AttributeError.
        """
        ext = filename.split('.')[-1]
        ext = bytearray(ext, 'utf-8')
        ext = [byte for byte in ext][:self.patch_size]
        with open(filename, 'rb') as f:
            file_bytes = f.read()
        byte_ids = list(file_bytes)  # iterating bytes yields ints 0-255
        if len(byte_ids) % self.patch_size != 0:
            # Pad the final partial patch up to a multiple of patch_size.
            byte_ids = byte_ids + [256] * (self.patch_size - len(byte_ids) % self.patch_size)
        bos_patch = ext + [256] * (self.patch_size - len(ext))
        return bos_patch + byte_ids + [256] * self.patch_size

    # Backward-compatible alias for any external caller using the old name.
    readbytes = read_bytes
# Categories of reconstructed particles; the integer value doubles as the
# model-facing particle id (pid).
ParticleType = enum.Enum(
    "ParticleType",
    [
        ("NeutralHadron", 0),
        ("Photon", 1),
        ("Electron", 2),
        ("Muon", 3),
        ("Pion", 4),
        ("ChargedKaon", 5),
        ("Proton", 6),
        ("Others", 7),
    ],
)
@attrs.define
class ParticleBase:
    """A single reconstructed particle inside a jet, with derived features.

    NOTE(review): the log_* fields are populated by the builders with
    ``np.log`` (natural log), despite older comments saying log10.
    """
    particle_id: int = attrs.field()      # id of the particle (row/particle index, set by the builder)
    part_charge: int = attrs.field()      # electric charge of the particle
    part_energy: float = attrs.field()    # energy of the particle
    part_px: float = attrs.field()        # x-component of the momentum vector
    part_py: float = attrs.field()        # y-component of the momentum vector
    part_pz: float = attrs.field()        # z-component of the momentum vector
    log_energy: float = attrs.field()     # log(part_energy)
    log_pt: float = attrs.field()         # log(part_pt)
    part_deta: float = attrs.field()      # pseudorapidity (relative to the jet)
    part_dphi: float = attrs.field()      # azimuthal angle (relative to the jet)
    part_logptrel: float = attrs.field()  # log(pt(particle) / pt(jet))
    part_logerel: float = attrs.field()   # log(energy(particle) / energy(jet))
    part_deltaR: float = attrs.field()    # hypot(part_deta, part_dphi): distance to the jet axis
    part_d0: float = attrs.field()        # tanh(d0)
    part_dz: float = attrs.field()        # tanh(z0)
    particle_type: ParticleType = attrs.field()  # particle category as an enum
    particle_pid: int = attrs.field()     # numeric id of the category
    jet_type: str = attrs.field()         # type of the enclosing jet (e.g. "b jet", "bbbar jet")

    def _default_attr_names(self) -> List[str]:
        # BUG FIX: attrs.fields() accepts a *class*, not an instance; the
        # original passed `self`, which raises TypeError at runtime.
        return [f.name for f in attrs.fields(type(self)) if f.name != "jet_type"]

    def properties_concatenated(self, selected_attrs: Optional[List[str]] = None, attr_sep: str = ',') -> str:
        """Join the selected attribute values into one attr_sep-separated string."""
        if selected_attrs is None:
            selected_attrs = self._default_attr_names()
        values = [getattr(self, attr) for attr in selected_attrs]
        values = [v.value if isinstance(v, ParticleType) else v for v in values]  # enum -> numeric value
        return attr_sep.join(map(str, values))

    def attributes_to_float_list(self, selected_attrs: Optional[List[str]] = None) -> list:
        """Return the selected attribute values as a list, with ints coerced to float."""
        if selected_attrs is None:
            selected_attrs = self._default_attr_names()
        values = [getattr(self, attr) for attr in selected_attrs]
        values = [v.value if isinstance(v, ParticleType) else v for v in values]  # enum -> numeric value
        return list(map(lambda x: float(x) if isinstance(x, int) else x, values))
@attrs.define
class Jet:
    """A jet: a few scalar observables plus its constituent particles."""
    jet_energy: float = attrs.field()   # energy of the jet
    jet_pt: float = attrs.field()       # transverse momentum of the jet
    jet_eta: float = attrs.field()      # pseudorapidity of the jet
    label: str = attrs.field()          # type of bb or bbbar
    particles: List[ParticleBase] = attrs.field(factory=list)  # constituents of the jet

    def __len__(self):
        # Number of constituent particles.
        return len(self.particles)

    def particles_concatenated(self, selected_attrs: Optional[List[str]] = None, attr_sep: str = ',', part_sep: str = '|') -> str:
        """Serialize every particle; attributes joined by attr_sep, particles by part_sep."""
        return part_sep.join(
            p.properties_concatenated(selected_attrs, attr_sep) for p in self.particles
        )

    def particles_attribute_list(self, selected_attrs: Optional[List[str]] = None) -> list:
        """Per-particle lists of the selected attribute values as floats."""
        return [p.attributes_to_float_list(selected_attrs) for p in self.particles]
@attrs.define
class JetSet:
    """A set of jets of one type ("bb" or "bbbar") built from a ROOT file."""
    jets_type: str = attrs.field()  # type of the jets (e.g. "bb", "bbbar")
    jets: List[Jet] = attrs.field(factory=list)

    def __len__(self):
        # Number of jets in the set.
        return len(self.jets)

    @staticmethod
    def jud_type(jtmp):
        """Return the ParticleType whose is-<type> flag in ``jtmp`` is largest.

        ``jtmp`` holds 7 flags ordered NeutralHadron, Photon, Electron, Muon,
        Pion, ChargedKaon, Proton; ties resolve to the first maximal entry.
        """
        particle_dict = {
            0: ParticleType.NeutralHadron,
            1: ParticleType.Photon,
            2: ParticleType.Electron,
            3: ParticleType.Muon,
            4: ParticleType.Pion,
            5: ParticleType.ChargedKaon,
            6: ParticleType.Proton,
            7: ParticleType.Others,
        }
        max_element = max(jtmp)
        idx = jtmp.index(max_element)
        return particle_dict.get(idx, ParticleType.Others)

    @classmethod
    def build_jetset(cls, root_file: Union[str, Path]) -> "JetSet":
        """Read ``root_file`` (str or Path) and build one Jet per tree entry."""
        print(f"Building JetSet from {root_file}")
        root_path = Path(root_file)
        if 'bbbar' not in root_path.stem and 'bb' not in root_path.stem:
            raise ValueError("Invalid file name, should contain 'bb' or 'bbbar'")
        jets_type = "bbbar" if "bbbar" in root_path.stem else "bb"
        with uproot.open(root_file) as f:
            tree = f["tree"]
            a = tree.arrays(library="pd")
        label = root_path.stem.split('_')[0]
        # BUG FIX: the original evaluated ``root_file.stem`` here, which raises
        # AttributeError when a plain str is passed (allowed by the signature).
        jet_type = "bbbar jet" if "bbbar" in root_path.stem else "b jet"
        particle_list = ['part_isNeutralHadron', 'part_isPhoton', 'part_isElectron', 'part_isMuon', 'part_isPion', 'part_isChargedKaon', 'part_isProton']
        jet_list = []
        for i, j in a.iterrows():
            part_pt = np.array(j['part_pt'])
            jet_pt = np.array(j['jet_pt'])
            part_logptrel = np.log(np.divide(part_pt, jet_pt))  # log(pt_i / pt_jet)
            part_energy = np.array(j['part_energy'])
            jet_energy = np.array(j['jet_energy'])
            part_logerel = np.log(np.divide(part_energy, jet_energy))  # log(E_i / E_jet)
            part_deta = np.array(j['part_deta'])
            part_dphi = np.array(j['part_dphi'])
            part_deltaR = np.hypot(part_deta, part_dphi)  # distance to the jet axis
            assert len(j['part_pt']) == len(j['part_energy']) == len(j['part_deta'])
            particles = []
            for pn in range(len(j['part_pt'])):
                # Compute the category once per particle (the original called
                # jud_type twice for each particle).
                ptype = cls.jud_type([j[t][pn] for t in particle_list])
                # BUG FIX: the original stored the *whole* per-jet arrays
                # (e.g. ``j['part_charge']``) into every particle; each field
                # must be indexed with [pn], as convert_root_to_txt.py does.
                particles.append(ParticleBase(
                    particle_id=i,
                    part_charge=j['part_charge'][pn],
                    part_energy=j['part_energy'][pn],
                    part_px=j['part_px'][pn],
                    part_py=j['part_py'][pn],
                    part_pz=j['part_pz'][pn],
                    log_energy=np.log(j['part_energy'][pn]),
                    log_pt=np.log(j['part_pt'][pn]),
                    part_deta=j['part_deta'][pn],
                    part_dphi=j['part_dphi'][pn],
                    part_logptrel=part_logptrel[pn],
                    part_logerel=part_logerel[pn],
                    part_deltaR=part_deltaR[pn],
                    part_d0=np.tanh(j['part_d0val'][pn]),
                    part_dz=np.tanh(j['part_dzval'][pn]),
                    particle_type=ptype,
                    particle_pid=ptype.value,
                    jet_type=jet_type
                ))
            jet = Jet(
                jet_energy=j['jet_energy'],
                jet_pt=j['jet_pt'],
                jet_eta=j['jet_eta'],
                particles=particles,
                label=label
            )
            jet_list.append(jet)
        return cls(jets=jet_list, jets_type=jets_type)

    def save_to_binary(self, save_dir, selected_attrs: Optional[List[str]] = None, attr_sep: str = ',', part_sep: str = '|'):
        """Pickle each jet's serialized particle string to ``<jets_type>_<n>.bin``."""
        for file_counter, jet in enumerate(tqdm(self.jets)):
            filename = f"{self.jets_type}_{file_counter}.bin"
            filepath = os.path.join(save_dir, filename)
            jet_data = jet.particles_concatenated(selected_attrs, attr_sep, part_sep)
            with open(filepath, 'wb') as file:
                pickle.dump(jet_data, file)
@click.command()
@click.argument('root_dir', type=click.Path(exists=True))
@click.argument('save_dir', type=click.Path())
@click.option('--attr-sep', default=',', help='Separator for attributes.')
@click.option('--part-sep', default='|', help='Separator for particles.')
@click.option('--selected-attrs', default='particle_id,part_charge,part_energy,part_px,part_py,part_pz,log_energy,log_pt,part_deta,part_dphi,part_logptrel,part_logerel,part_deltaR,part_d0,part_dz,particle_type,particle_pid', help='Comma-separated list of selected attributes.')
def main(root_dir, save_dir, attr_sep, part_sep, selected_attrs):
    """CLI entry point: convert every .root file under root_dir into binary jet files."""
    attr_names = selected_attrs.split(',')
    preprocess(root_dir, save_dir, selected_attrs=attr_names, attr_sep=attr_sep, part_sep=part_sep)
def preprocess(root_dir, save_dir, selected_attrs: Optional[List[str]] = None, attr_sep: str = ',', part_sep: str = '|'):
    """Convert all .root files in root_dir in parallel, writing results to save_dir."""
    root_files = list(Path(root_dir).glob('*.root'))
    os.makedirs(save_dir, exist_ok=True)
    tasks = (
        joblib.delayed(process_file)(rf, save_dir, selected_attrs, attr_sep, part_sep)
        for rf in root_files
    )
    joblib.Parallel(n_jobs=-1)(tasks)
def process_file(root_file, save_dir, selected_attrs, attr_sep, part_sep):
    """Build a JetSet from one ROOT file and dump it to binary files."""
    JetSet.build_jetset(root_file).save_to_binary(save_dir, selected_attrs, attr_sep, part_sep)


if __name__ == "__main__":
    main()

279
convert_root_to_txt.py Normal file
View File

@@ -0,0 +1,279 @@
import enum
import attrs
from typing import List, Optional, Union
import numpy as np
import uproot
from pathlib import Path
from tqdm import tqdm
import os
import joblib
import click
# Categories of reconstructed particles; the integer value doubles as the
# model-facing particle id (pid).
ParticleType = enum.Enum(
    "ParticleType",
    [
        ("NeutralHadron", 0),
        ("Photon", 1),
        ("Electron", 2),
        ("Muon", 3),
        ("Pion", 4),
        ("ChargedKaon", 5),
        ("Proton", 6),
        ("Others", 7),
    ],
)
@attrs.define
class ParticleBase:
    """A single reconstructed particle inside a jet, with derived features.

    NOTE(review): the log_* fields are populated by the builders with
    ``np.log`` (natural log), despite older comments saying log10.
    """
    particle_id: int = attrs.field()      # id of the particle (row/particle index, set by the builder)
    part_charge: int = attrs.field()      # electric charge of the particle
    part_energy: float = attrs.field()    # energy of the particle
    part_px: float = attrs.field()        # x-component of the momentum vector
    part_py: float = attrs.field()        # y-component of the momentum vector
    part_pz: float = attrs.field()        # z-component of the momentum vector
    log_energy: float = attrs.field()     # log(part_energy)
    log_pt: float = attrs.field()         # log(part_pt)
    part_deta: float = attrs.field()      # pseudorapidity (relative to the jet)
    part_dphi: float = attrs.field()      # azimuthal angle (relative to the jet)
    part_logptrel: float = attrs.field()  # log(pt(particle) / pt(jet))
    part_logerel: float = attrs.field()   # log(energy(particle) / energy(jet))
    part_deltaR: float = attrs.field()    # hypot(part_deta, part_dphi): distance to the jet axis
    part_d0: float = attrs.field()        # tanh(d0)
    part_dz: float = attrs.field()        # tanh(z0)
    particle_type: ParticleType = attrs.field()  # particle category as an enum
    particle_pid: int = attrs.field()     # numeric id of the category
    jet_type: str = attrs.field()         # type of the enclosing jet (e.g. "b jet", "bbbar jet")

    def properties_concatenated(self, selected_attrs: Optional[List[str]] = None, attr_sep: str = ',') -> str:
        """Join the selected attribute values into one attr_sep-separated string."""
        if selected_attrs is None:
            # BUG FIX: attrs.fields() accepts a *class*, not an instance; the
            # original passed `self`, which raises TypeError at runtime.
            selected_attrs = [field.name for field in attrs.fields(type(self)) if field.name != "jet_type"]
        values = [getattr(self, attr) for attr in selected_attrs]
        values = [v.value if isinstance(v, ParticleType) else v for v in values]  # enum -> numeric value
        return attr_sep.join(map(str, values))
@attrs.define
class Jet:
    """A jet: a few scalar observables plus its constituent particles."""
    jet_energy: float = attrs.field()   # energy of the jet
    jet_pt: float = attrs.field()       # transverse momentum of the jet
    jet_eta: float = attrs.field()      # pseudorapidity of the jet
    label: str = attrs.field()          # type of bb or bbbar
    particles: List[ParticleBase] = attrs.field(factory=list)  # constituents of the jet

    def __len__(self):
        # Number of constituent particles.
        return len(self.particles)

    def particles_concatenated(self, selected_attrs: Optional[List[str]] = None, attr_sep: str = ',', part_sep: str = '|') -> str:
        """Serialize every particle; attributes joined by attr_sep, particles by part_sep."""
        return part_sep.join(
            p.properties_concatenated(selected_attrs, attr_sep) for p in self.particles
        )
@attrs.define
class JetSet:
    """A set of jets of one type ("bb" or "bbbar") built from a ROOT file."""
    jets_type: str = attrs.field()  # type of the jets (e.g. "bb", "bbbar")
    jets: List[Jet] = attrs.field(factory=list)

    def __len__(self):
        # Number of jets in the set.
        return len(self.jets)

    @staticmethod
    def jud_type(jtmp):
        """Return (ParticleType, index) for the largest is-<type> flag in ``jtmp``.

        ``jtmp`` holds 7 flags ordered NeutralHadron, Photon, Electron, Muon,
        Pion, ChargedKaon, Proton; ties resolve to the first maximal entry.
        """
        particle_dict = {
            0: ParticleType.NeutralHadron,
            1: ParticleType.Photon,
            2: ParticleType.Electron,
            3: ParticleType.Muon,
            4: ParticleType.Pion,
            5: ParticleType.ChargedKaon,
            6: ParticleType.Proton,
            7: ParticleType.Others,
        }
        max_element = max(jtmp)
        idx = jtmp.index(max_element)
        return particle_dict.get(idx, ParticleType.Others), idx

    @classmethod
    def build_jetset_fast(cls, root_file: Union[str, Path]) -> "JetSet":
        """Build a JetSet from a fast-simulation ROOT file (itertuples path)."""
        print(f"Building JetSet from {root_file}")
        root_path = Path(root_file)  # normalize: the signature accepts str or Path
        if 'bbbar' not in root_path.stem and 'bb' not in root_path.stem:
            raise ValueError("Invalid file name, should contain 'bb' or 'bbbar'")
        jets_type = "bbbar" if "bbbar" in root_path.stem else "bb"
        print(f"jet type: {jets_type}")
        with uproot.open(root_file) as f:
            tree = f["tree"]
            a = tree.arrays(library="pd")
        if a.empty:
            raise ValueError("DataFrame is empty")
        label = root_path.stem.split('_')[0]
        # BUG FIX: the original evaluated ``root_file.stem`` here, which raises
        # AttributeError when a plain str is passed (allowed by the signature).
        jet_type = "bbbar jet" if "bbbar" in root_path.stem else "b jet"
        jet_list = []
        for j in a.itertuples():
            part_pt = np.array(j.part_pt)
            jet_pt = np.array(j.jet_pt)
            part_logptrel = np.log(np.divide(part_pt, jet_pt))  # log(pt_i / pt_jet)
            part_energy = np.array(j.part_energy)
            jet_energy = np.array(j.jet_energy)
            part_logerel = np.log(np.divide(part_energy, jet_energy))  # log(E_i / E_jet)
            part_deta = np.array(j.part_deta)
            part_dphi = np.array(j.part_dphi)
            part_deltaR = np.hypot(part_deta, part_dphi)  # distance to the jet axis
            assert len(j.part_pt) == len(j.part_energy) == len(j.part_deta)
            particles = []
            part_type = []
            part_pid = []
            for pn in range(len(j.part_pt)):
                jtmp = [j.part_isNeutralHadron[pn], j.part_isPhoton[pn], j.part_isElectron[pn], j.part_isMuon[pn], j.part_isPion[pn],
                        j.part_isChargedKaon[pn], j.part_isProton[pn]]
                tmp_type, tmp_pid = cls.jud_type(jtmp)
                part_type.append(tmp_type)
                part_pid.append(tmp_pid)
            bag = zip(j.part_charge, j.part_energy, j.part_px, j.part_py, j.part_pz, np.log(j.part_energy),
                      np.log(j.part_pt), j.part_deta, j.part_dphi, part_logptrel, part_logerel, part_deltaR,
                      np.tanh(j.part_d0val), np.tanh(j.part_dzval), part_type, part_pid)
            # Loop over all particles of this jet, storing each one as a
            # ParticleBase; ``particles`` then holds the jet's full content.
            for num, (c, en, px, py, pz, lEn, lPt, eta, phi, ii, jj, kk, d0, dz, ptype, pid) in enumerate(bag):
                particles.append(ParticleBase(
                    particle_id=num,
                    part_charge=c,
                    part_energy=en,
                    part_px=px,
                    part_py=py,
                    part_pz=pz,
                    log_energy=lEn,
                    log_pt=lPt,
                    part_deta=eta,
                    part_dphi=phi,
                    part_logptrel=ii,
                    part_logerel=jj,
                    part_deltaR=kk,
                    part_d0=d0,
                    part_dz=dz,
                    particle_type=ptype,
                    particle_pid=pid,
                    jet_type=jet_type
                ))
            jet = Jet(
                jet_energy=j.jet_energy,
                jet_pt=j.jet_pt,
                jet_eta=j.jet_eta,
                particles=particles,
                label=label
            )
            jet_list.append(jet)
        return cls(jets=jet_list, jets_type=jets_type)

    @classmethod
    def build_jetset_full(cls, root_file: Union[str, Path]) -> "JetSet":
        """Build a JetSet from a full-simulation ROOT file (iterrows path)."""
        print(f"Building JetSet from {root_file}")
        root_path = Path(root_file)  # normalize: the signature accepts str or Path
        if 'bbbar' not in root_path.stem and 'bb' not in root_path.stem:
            raise ValueError("Invalid file name, should contain 'bb' or 'bbbar'")
        jets_type = "bbbar" if "bbbar" in root_path.stem else "bb"
        print(f"jet type: {jets_type}")
        with uproot.open(root_file) as f:
            tree = f["tree"]
            a = tree.arrays(library="pd")
        label = root_path.stem.split('_')[0]
        # BUG FIX: same ``root_file.stem`` AttributeError as in the fast path.
        jet_type = "bbbar jet" if "bbbar" in root_path.stem else "b jet"
        jet_list = []
        for i, j in a.iterrows():
            part_pt = np.array(j['part_pt'])
            jet_pt = np.array(j['jet_pt'])
            part_logptrel = np.log(np.divide(part_pt, jet_pt))  # log(pt_i / pt_jet)
            part_energy = np.array(j['part_energy'])
            jet_energy = np.array(j['jet_energy'])
            part_logerel = np.log(np.divide(part_energy, jet_energy))  # log(E_i / E_jet)
            part_deta = np.array(j['part_deta'])
            part_dphi = np.array(j['part_dphi'])
            part_deltaR = np.hypot(part_deta, part_dphi)  # distance to the jet axis
            assert len(j['part_pt']) == len(j['part_energy']) == len(j['part_deta'])
            particles = []
            particle_list = ['part_isNeutralHadron', 'part_isPhoton', 'part_isElectron', 'part_isMuon', 'part_isPion', 'part_isChargedKaon', 'part_isProton']
            part_type = []
            part_pid = []
            for pn in range(len(j['part_pt'])):
                jtmp = []
                for t in particle_list:
                    jtmp.append(j[t][pn])
                tmp_type, tmp_pid = cls.jud_type(jtmp)
                part_type.append(tmp_type)
                part_pid.append(tmp_pid)
            for pn in range(len(part_type)):
                particles.append(ParticleBase(
                    particle_id=i,
                    part_charge=j['part_charge'][pn],
                    part_energy=j['part_energy'][pn],
                    part_px=j['part_px'][pn],
                    part_py=j['part_py'][pn],
                    part_pz=j['part_pz'][pn],
                    log_energy=np.log(j['part_energy'][pn]),
                    log_pt=np.log(j['part_pt'][pn]),
                    part_deta=j['part_deta'][pn],
                    part_dphi=j['part_dphi'][pn],
                    part_logptrel=part_logptrel[pn],
                    part_logerel=part_logerel[pn],
                    part_deltaR=part_deltaR[pn],
                    part_d0=np.tanh(j['part_d0val'][pn]),
                    part_dz=np.tanh(j['part_dzval'][pn]),
                    particle_type=part_type[pn],
                    particle_pid=part_pid[pn],
                    jet_type=jet_type
                ))
            jet = Jet(
                jet_energy=j['jet_energy'],
                jet_pt=j['jet_pt'],
                jet_eta=j['jet_eta'],
                particles=particles,
                label=label
            )
            jet_list.append(jet)
        return cls(jets=jet_list, jets_type=jets_type)

    def save_to_txt(self, save_dir, selected_attrs: Optional[List[str]] = None, attr_sep: str = ',', part_sep: str = '|', prefix: str = ""):
        """Write one txt file per jet: ``<jets_type>_<n>_<prefix>.txt``."""
        for file_counter, jet in enumerate(tqdm(self.jets)):
            filename = f"{self.jets_type}_{file_counter}_{prefix}.txt"
            filepath = os.path.join(save_dir, filename)
            jet_data = jet.particles_concatenated(selected_attrs, attr_sep, part_sep)
            with open(filepath, 'w') as file:
                file.write(jet_data)
@click.command()
@click.argument('root_dir', type=click.Path(exists=True))
@click.argument('save_dir', type=click.Path())
@click.option('--data-type', default='fast', type=click.Choice(['fast', 'full']), help='Type of ROOT data: fast or full.')
@click.option('--attr-sep', default=',', help='Separator for attributes.')
@click.option('--part-sep', default='|', help='Separator for particles.')
@click.option('--selected-attrs', default='particle_id,part_charge,part_energy,part_px,part_py,part_pz,log_energy,log_pt,part_deta,part_dphi,part_logptrel,part_logerel,part_deltaR,part_d0,part_dz,particle_type,particle_pid', help='Comma-separated list of selected attributes.')
def main(root_dir, save_dir, data_type, attr_sep, part_sep, selected_attrs):
    """CLI entry point: convert every .root file under root_dir into per-jet txt files."""
    attr_names = selected_attrs.split(',')
    preprocess(root_dir, save_dir, data_type, selected_attrs=attr_names, attr_sep=attr_sep, part_sep=part_sep)
def preprocess(root_dir, save_dir, data_type, selected_attrs: Optional[List[str]] = None, attr_sep: str = ',', part_sep: str = '|'):
    """Convert all .root files in root_dir in parallel, writing txt files to save_dir."""
    root_files = list(Path(root_dir).glob('*.root'))
    os.makedirs(save_dir, exist_ok=True)
    tasks = (
        joblib.delayed(process_file)(rf, save_dir, data_type, selected_attrs, attr_sep, part_sep, worker_id)
        for worker_id, rf in enumerate(root_files)
    )
    joblib.Parallel(n_jobs=-1)(tasks)
def process_file(root_file, save_dir, data_type, selected_attrs, attr_sep, part_sep, process_id):
    """Convert one ROOT file to per-jet txt files, tagging outputs with the worker id."""
    prefix = f"process_{process_id}"
    builder = JetSet.build_jetset_fast if data_type == 'fast' else JetSet.build_jetset_full
    jet_set = builder(root_file)
    jet_set.save_to_txt(save_dir, selected_attrs, attr_sep, part_sep, prefix)


if __name__ == "__main__":
    main()

193
convert_root_to_txt.py.bak Normal file
View File

@@ -0,0 +1,193 @@
import enum
import attrs
from typing import List, Optional, Union
import numpy as np
import uproot
from pathlib import Path
from tqdm import tqdm
import os
import joblib
import click
# Categories of reconstructed particles; the integer value doubles as the
# model-facing particle id (pid).
ParticleType = enum.Enum(
    "ParticleType",
    [
        ("NeutralHadron", 0),
        ("Photon", 1),
        ("Electron", 2),
        ("Muon", 3),
        ("Pion", 4),
        ("ChargedKaon", 5),
        ("Proton", 6),
        ("Others", 7),
    ],
)
@attrs.define
class ParticleBase:
    """A single reconstructed particle inside a jet, with derived features.

    NOTE(review): the log_* fields are populated by the builders with
    ``np.log`` (natural log), despite older comments saying log10.
    """
    particle_id: int = attrs.field()      # id of the particle (row/particle index, set by the builder)
    part_charge: int = attrs.field()      # electric charge of the particle
    part_energy: float = attrs.field()    # energy of the particle
    part_px: float = attrs.field()        # x-component of the momentum vector
    part_py: float = attrs.field()        # y-component of the momentum vector
    part_pz: float = attrs.field()        # z-component of the momentum vector
    log_energy: float = attrs.field()     # log(part_energy)
    log_pt: float = attrs.field()         # log(part_pt)
    part_deta: float = attrs.field()      # pseudorapidity (relative to the jet)
    part_dphi: float = attrs.field()      # azimuthal angle (relative to the jet)
    part_logptrel: float = attrs.field()  # log(pt(particle) / pt(jet))
    part_logerel: float = attrs.field()   # log(energy(particle) / energy(jet))
    part_deltaR: float = attrs.field()    # hypot(part_deta, part_dphi): distance to the jet axis
    part_d0: float = attrs.field()        # tanh(d0)
    part_dz: float = attrs.field()        # tanh(z0)
    particle_type: ParticleType = attrs.field()  # particle category as an enum
    particle_pid: int = attrs.field()     # numeric id of the category
    jet_type: str = attrs.field()         # type of the enclosing jet (e.g. "b jet", "bbbar jet")

    def properties_concatenated(self, selected_attrs: Optional[List[str]] = None, attr_sep: str = ',') -> str:
        """Join the selected attribute values into one attr_sep-separated string."""
        if selected_attrs is None:
            # BUG FIX: attrs.fields() accepts a *class*, not an instance; the
            # original passed `self`, which raises TypeError at runtime.
            selected_attrs = [field.name for field in attrs.fields(type(self)) if field.name != "jet_type"]
        values = [getattr(self, attr) for attr in selected_attrs]
        values = [v.value if isinstance(v, ParticleType) else v for v in values]  # enum -> numeric value
        return attr_sep.join(map(str, values))
@attrs.define
class Jet:
    """A jet: a few scalar observables plus its constituent particles."""
    jet_energy: float = attrs.field()   # energy of the jet
    jet_pt: float = attrs.field()       # transverse momentum of the jet
    jet_eta: float = attrs.field()      # pseudorapidity of the jet
    label: str = attrs.field()          # type of bb or bbbar
    particles: List[ParticleBase] = attrs.field(factory=list)  # constituents of the jet

    def __len__(self):
        # Number of constituent particles.
        return len(self.particles)

    def particles_concatenated(self, selected_attrs: Optional[List[str]] = None, attr_sep: str = ',', part_sep: str = '|') -> str:
        """Serialize every particle; attributes joined by attr_sep, particles by part_sep."""
        return part_sep.join(
            p.properties_concatenated(selected_attrs, attr_sep) for p in self.particles
        )
@attrs.define
class JetSet:
    """All jets of one type loaded from a single ROOT file."""
    jets_type: str = attrs.field()  # type of the jets (e.g. bb, bbbar)
    jets: List[Jet] = attrs.field(factory=list)

    def __len__(self):
        return len(self.jets)

    @staticmethod
    def jud_type(jtmp):
        """Map the seven part_is* flags to (ParticleType, index).

        The index of the largest entry in ``jtmp`` selects the type; on ties
        the first occurrence wins (list.index semantics).
        """
        particle_dict = {
            0: ParticleType.NeutralHadron,
            1: ParticleType.Photon,
            2: ParticleType.Electron,
            3: ParticleType.Muon,
            4: ParticleType.Pion,
            5: ParticleType.ChargedKaon,
            6: ParticleType.Proton,
            7: ParticleType.Others
        }
        max_element = max(jtmp)
        idx = jtmp.index(max_element)
        return particle_dict.get(idx, ParticleType.Others), idx

    @classmethod
    def build_jetset(cls, root_file: Union[str, Path]) -> "JetSet":
        """Build a JetSet (one Jet per tree row) from a ROOT file.

        Args:
            root_file: path (str or Path) to a ROOT file; its stem must
                contain 'bb' or 'bbbar', which decides the jet type.

        Raises:
            ValueError: if the stem contains neither 'bb' nor 'bbbar'.
        """
        # BUG FIX: normalize to Path once — the original later called
        # root_file.stem directly, which raises AttributeError for str input.
        root_file = Path(root_file)
        print(f"Building JetSet from {root_file}")
        stem = root_file.stem
        if 'bbbar' not in stem and 'bb' not in stem:
            raise ValueError("Invalid file name, should contain 'bb' or 'bbbar'")
        # NOTE(review): stems containing 'bbar' (single b) also land in the
        # "bb" branch — confirm this is intended for files like fast_bbar.root.
        jets_type = "bbbar" if "bbbar" in stem else "bb"
        print(f"jet type: {jets_type}")
        with uproot.open(root_file) as f:
            tree = f["tree"]
            a = tree.arrays(library="pd")
        label = stem.split('_')[0]
        jet_type = "bbbar jet" if "bbbar" in stem else "b jet"
        jet_list = []
        for i, j in a.iterrows():
            # Per-particle kinematics relative to the jet (natural log ratios).
            part_pt = np.array(j['part_pt'])
            jet_pt = np.array(j['jet_pt'])
            part_logptrel = np.log(np.divide(part_pt, jet_pt))
            part_energy = np.array(j['part_energy'])
            jet_energy = np.array(j['jet_energy'])
            part_logerel = np.log(np.divide(part_energy, jet_energy))
            part_deta = np.array(j['part_deta'])
            part_dphi = np.array(j['part_dphi'])
            part_deltaR = np.hypot(part_deta, part_dphi)
            assert len(j['part_pt']) == len(j['part_energy']) == len(j['part_deta'])
            particles = []
            # Classify every particle from its seven one-hot part_is* flags.
            particle_list = ['part_isNeutralHadron', 'part_isPhoton', 'part_isElectron', 'part_isMuon', 'part_isPion', 'part_isChargedKaon', 'part_isProton']
            part_type = []
            part_pid = []
            for pn in range(len(j['part_pt'])):
                jtmp = [j[t][pn] for t in particle_list]
                tmp_type, tmp_pid = cls.jud_type(jtmp)
                part_type.append(tmp_type)
                part_pid.append(tmp_pid)
            for pn in range(len(part_type)):
                particles.append(ParticleBase(
                    particle_id=i,
                    part_charge=j['part_charge'][pn],
                    part_energy=j['part_energy'][pn],
                    part_px=j['part_px'][pn],
                    part_py=j['part_py'][pn],
                    part_pz=j['part_pz'][pn],
                    log_energy=np.log(j['part_energy'][pn]),
                    log_pt=np.log(j['part_pt'][pn]),
                    part_deta=j['part_deta'][pn],
                    part_dphi=j['part_dphi'][pn],
                    part_logptrel=part_logptrel[pn],
                    part_logerel=part_logerel[pn],
                    part_deltaR=part_deltaR[pn],
                    part_d0=np.tanh(j['part_d0val'][pn]),
                    part_dz=np.tanh(j['part_dzval'][pn]),
                    particle_type=part_type[pn],
                    particle_pid=part_pid[pn],
                    jet_type=jet_type
                ))
            jet = Jet(
                jet_energy=j['jet_energy'],
                jet_pt=j['jet_pt'],
                jet_eta=j['jet_eta'],
                particles=particles,
                label=label
            )
            jet_list.append(jet)
        return cls(jets=jet_list, jets_type=jets_type)

    def save_to_txt(self, save_dir, selected_attrs: Optional[List[str]] = None, attr_sep: str = ',', part_sep: str = '|', prefix: str = ""):
        """Write each jet as one txt file named {jets_type}_{index}_{prefix}.txt."""
        for file_counter, jet in enumerate(tqdm(self.jets)):
            filename = f"{self.jets_type}_{file_counter}_{prefix}.txt"
            filepath = os.path.join(save_dir, filename)
            jet_data = jet.particles_concatenated(selected_attrs, attr_sep, part_sep)
            with open(filepath, 'w') as file:
                file.write(jet_data)
@click.command()
@click.argument('root_dir', type=click.Path(exists=True))
@click.argument('save_dir', type=click.Path())
@click.option('--attr-sep', default=',', help='Separator for attributes.')
@click.option('--part-sep', default='|', help='Separator for particles.')
@click.option('--selected-attrs', default='particle_id,part_charge,part_energy,part_px,part_py,part_pz,log_energy,log_pt,part_deta,part_dphi,part_logptrel,part_logerel,part_deltaR,part_d0,part_dz,particle_type,particle_pid', help='Comma-separated list of selected attributes.')
def main(root_dir, save_dir, attr_sep, part_sep, selected_attrs):
    """CLI entry point: convert every ROOT file in root_dir to txt files in save_dir."""
    attr_names = selected_attrs.split(',')
    preprocess(root_dir, save_dir, selected_attrs=attr_names, attr_sep=attr_sep, part_sep=part_sep)
def preprocess(root_dir, save_dir, selected_attrs: Optional[List[str]] = None, attr_sep: str = ',', part_sep: str = '|'):
    """Convert all *.root files in root_dir to txt files in save_dir, one worker per file."""
    os.makedirs(save_dir, exist_ok=True)
    discovered = list(Path(root_dir).glob('*.root'))
    tasks = [
        joblib.delayed(process_file)(rf, save_dir, selected_attrs, attr_sep, part_sep, idx)
        for idx, rf in enumerate(discovered)
    ]
    joblib.Parallel(n_jobs=-1)(tasks)
def process_file(root_file, save_dir, selected_attrs, attr_sep, part_sep, process_id):
    """Worker body: build a JetSet from one ROOT file and dump its jets as txt files."""
    built = JetSet.build_jetset(root_file)
    built.save_to_txt(save_dir, selected_attrs, attr_sep, part_sep, f"process_{process_id}")
# Script entry point: delegate to the click CLI when executed directly.
if __name__ == "__main__":
    main()

260
data_struct.py Normal file
View File

@@ -0,0 +1,260 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
'''
@file :data_struct.py
@Description :data structures for converting ROOT jet trees into pickled/txt datasets
@Date :2024/06/04 09:08:46
@Author :hotwa
@version :1.0
'''
import attrs
from typing import List
import numpy as np
import uproot
from pathlib import Path
from tqdm import tqdm
import os
import random
import pickle
@attrs.define
class ParticleBase:
    """One particle's features as filled by build_jetset() in this module."""
    particle_id: int = attrs.field()  # unique id for each particle (the jet's row index)
    part_charge: int = attrs.field()  # charge of the particle
    part_energy: float = attrs.field()  # energy of the particle
    part_px: float = attrs.field()  # x-component of the momentum vector
    part_py: float = attrs.field()  # y-component of the momentum vector
    part_pz: float = attrs.field()  # z-component of the momentum vector
    log_energy: float = attrs.field()  # natural log of part_energy (build_jetset uses np.log, not log10)
    log_pt: float = attrs.field()  # natural log of part_pt (build_jetset uses np.log, not log10)
    part_deta: float = attrs.field()  # pseudorapidity
    part_dphi: float = attrs.field()  # azimuthal angle
    part_logptrel: float = attrs.field()  # log(pt(particle)/pt(jet)), natural log
    part_logerel: float = attrs.field()  # log(energy(particle)/energy(jet)), natural log
    part_deltaR: float = attrs.field()  # hypot(part_deta, part_dphi): distance between the particle and the jet
    part_d0: float = attrs.field()  # tanh(d0)
    part_dz: float = attrs.field()  # tanh(z0)
    #particle_type: str = attrs.field()  # type of the particle (e.g. charged kaon, charged pion, proton, electron, muon, neutral hadron, photon, others)
    particle_pid: int = attrs.field()  # pid of the particle (e.g. 0,1,2,3,4,5,6,7)
    def properties_concatenated(self) -> str:
        """Return every field value joined into one space-separated string."""
        # Render each attribute value as a string, keeping int/float formatting.
        values_str = map(lambda x: f"{x}", attrs.astuple(self))
        # Join all attribute values into a single string.
        concatenated_str = ' '.join(values_str)
        return concatenated_str
    def attributes_to_float_list(self) -> list:
        """Return every field value as a list, converting ints to floats."""
        attribute_values = []
        for field in attrs.fields(self.__class__):  # iterate over all declared fields
            value = getattr(self, field.name)
            # Convert integer values to float so the list is numerically uniform.
            attribute_values.append(float(value) if isinstance(value, int) else value)
        return attribute_values
@attrs.define
class Jet:
    """A jet: summary kinematics, a label, and its constituent particles."""
    jet_energy: float = attrs.field()  # energy of the jet
    jet_pt: float = attrs.field()  # transverse momentum of the jet
    jet_eta: float = attrs.field()  # pseudorapidity of the jet
    label: str = attrs.field()  # type of bb or bbbar
    particles: List[ParticleBase] = attrs.field(factory=list)  # particles in the jet

    def __len__(self):
        """Particle multiplicity of the jet."""
        return len(self.particles)

    def particles_concatenated(self) -> str:
        """Join every particle's string form with commas."""
        return ','.join(particle.properties_concatenated() for particle in self.particles)

    def particles_attribute_list(self) -> list:
        """Collect each particle's float-list representation."""
        return [particle.attributes_to_float_list() for particle in self.particles]
@attrs.define
class JetSet:
    """Container for the jets loaded from one ROOT file."""
    jets: List[Jet]  # jets built by build_jetset()
    def __len__(self):
        """Number of jets in the set."""
        return len(self.jets)
def jud_type(jtmp):
    """Pick the particle type from the seven part_is* flags.

    Returns a (name, index) pair for the largest entry of ``jtmp``; on ties
    the first occurrence wins (list.index semantics).
    """
    type_names = ('NeutralHadron', 'Photon', 'Electron', 'Muon', 'Pion', 'ChargedKaon', 'Proton')
    winner = jtmp.index(max(jtmp))
    # The original dict mapped names to their own positions, so the returned
    # numeric id equals the winning index.
    return type_names[winner], winner
def build_jetset(root_file):
    """Read a ROOT file and build a JetSet with per-particle features.

    Args:
        root_file: path to a ROOT file whose 'tree' holds the jet arrays.

    Returns:
        JetSet with one Jet per tree row.
    """
    with uproot.open(root_file) as f:
        tree = f["tree"]
        a = tree.arrays(library="pd")  # pd.DataFrame
    label = Path(root_file).stem.split('_')[0]
    jet_list = []
    for i, j in a.iterrows():
        # Per-particle kinematics relative to the jet (natural log ratios).
        part_pt = np.array(j['part_pt'])
        jet_pt = np.array(j['jet_pt'])
        part_logptrel = np.log(np.divide(part_pt, jet_pt))
        part_energy = np.array(j['part_energy'])
        jet_energy = np.array(j['jet_energy'])
        part_logerel = np.log(np.divide(part_energy, jet_energy))
        part_deta = np.array(j['part_deta'])
        part_dphi = np.array(j['part_dphi'])
        part_deltaR = np.hypot(part_deta, part_dphi)
        assert len(j['part_pt']) == len(j['part_energy']) == len(j['part_deta'])
        # Classify every particle from its seven one-hot part_is* flags.
        particles = []
        particle_list = ['part_isNeutralHadron', 'part_isPhoton', 'part_isElectron', 'part_isMuon', 'part_isPion', 'part_isChargedKaon', 'part_isProton']
        part_type = []
        part_pid = []
        for pn in range(len(j['part_pt'])):
            jtmp = [j[t][pn] for t in particle_list]
            tmp_type, tmp_pid = jud_type(jtmp)
            part_type.append(tmp_type)
            part_pid.append(tmp_pid)
        # BUG FIX: the original passed the *whole* per-jet arrays
        # (j['part_charge'], np.log(j['part_energy']), ...) into every
        # ParticleBase; each particle now receives its own scalar values,
        # matching the declared float/int field types.
        for pn, (ii, jj, kk, ptype, pid) in enumerate(zip(part_logptrel, part_logerel, part_deltaR, part_type, part_pid)):
            particles.append(ParticleBase(
                particle_id=i,
                part_charge=j['part_charge'][pn],
                part_energy=j['part_energy'][pn],
                part_px=j['part_px'][pn],
                part_py=j['part_py'][pn],
                part_pz=j['part_pz'][pn],
                log_energy=np.log(j['part_energy'][pn]),
                log_pt=np.log(j['part_pt'][pn]),
                part_deta=j['part_deta'][pn],
                part_dphi=j['part_dphi'][pn],
                part_logptrel=ii,
                part_logerel=jj,
                part_deltaR=kk,
                part_d0=np.tanh(j['part_d0val'][pn]),
                part_dz=np.tanh(j['part_dzval'][pn]),
                #particle_type=ptype,  # ParticleBase has no particle_type field in this module
                particle_pid=pid
            ))
        # BUG FIX: jet_pt/jet_eta were the literal string lists ['jet_pt'] and
        # ['jet_eta'] in the original; use the actual row values.
        jet = Jet(
            jet_energy=j['jet_energy'],
            jet_pt=j['jet_pt'],
            jet_eta=j['jet_eta'],
            particles=particles,
            label=label
        )
        jet_list.append(jet)
    return JetSet(jets=jet_list)
def _export_split(root_files, limit, out_dir, split_name, method, file_counter):
    """Serialize jets from at most `limit` ROOT files into `out_dir`.

    method == 'float_list' pickles each jet's list-of-float-lists (.pkl);
    any other method writes each jet's concatenated string form (.txt).
    Returns the running file counter so train/val filenames stay unique
    within one directory pass.
    """
    os.makedirs(out_dir, exist_ok=True)  # the original assumed the dirs existed
    fmt = 'pkl' if method == 'float_list' else 'txt'
    for root_file in root_files[:limit]:
        print(f"loading {root_file} to {split_name} {fmt}")
        jetset = build_jetset(root_file)
        for jet in tqdm(jetset.jets):
            file_counter += 1
            # NOTE(review): labels from different subdirectories can still
            # collide on filename if their stems share a prefix — confirm.
            filepath = os.path.join(out_dir, f"{jet.label}_{file_counter}.{fmt}")
            if fmt == 'pkl':
                with open(filepath, 'wb') as fh:
                    pickle.dump(jet.particles_attribute_list(), fh)
            else:
                with open(filepath, 'w') as fh:
                    fh.write(jet.particles_concatenated())
    return file_counter


def preprocess(root_dir, method='float_list', train_ratio=0.8,
               train_file_limit=32, val_file_limit=8,
               pkl_root='/data/slow/100w_pkl', txt_root='/data/slow/100w'):
    """Convert ROOT files under bb/bbbar subdirectories of root_dir into
    per-jet dataset files, split into train/val sets.

    The original body repeated four near-identical loops (train/val x pkl/txt);
    they are folded into _export_split. The split ratio, per-split file limits
    and destination roots are now parameters whose defaults are the original
    hard-coded values (backward compatible).

    Args:
        root_dir: directory tree containing 'bb' and/or 'bbbar' subdirectories.
        method: 'float_list' -> pickle output; anything else -> txt output.
        train_ratio: fraction of ROOT files assigned to the training split.
        train_file_limit: max ROOT files processed for the training split.
        val_file_limit: max ROOT files processed for the validation split.
        pkl_root / txt_root: output roots holding 'train' and 'val' subdirs.
    """
    for root, dirs, _ in os.walk(root_dir):
        for dir_name in dirs:
            if dir_name not in ('bb', 'bbbar'):
                continue
            # Collect every .root file below this subdirectory.
            # BUG FIX: the original joined os.path.join(root, dir_name, file)
            # even for files found in *nested* subdirectories; using the
            # walk's own dirpath yields a correct path in every case.
            bb_bbbar_files = []
            for dirpath, _, filenames in os.walk(os.path.join(root, dir_name)):
                for fname in filenames:
                    if fname.endswith('.root'):
                        bb_bbbar_files.append(os.path.join(dirpath, fname))
            if not bb_bbbar_files:
                continue
            random.shuffle(bb_bbbar_files)
            split_index = int(len(bb_bbbar_files) * train_ratio)
            train_files = bb_bbbar_files[:split_index]
            val_files = bb_bbbar_files[split_index:]
            print('len(train_files):', len(train_files))
            print('len(val_files):', len(val_files))
            dest_root = pkl_root if method == 'float_list' else txt_root
            # The counter continues from train into val, as in the original.
            counter = _export_split(train_files, train_file_limit,
                                    os.path.join(dest_root, 'train'), 'train',
                                    method, 0)
            _export_split(val_files, val_file_limit,
                          os.path.join(dest_root, 'val'), 'val',
                          method, counter)
# Script entry point: convert the n2n2higgs dataset to txt files ('str' method).
if __name__ == '__main__':
    root_dir = "/data/particle_raw/slow/data_100w/n2n2higgs"
    preprocess(root_dir,method='str')

203
dataloader.py Normal file
View File

@@ -0,0 +1,203 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
'''
@file :dataloader.py
@Description :data loading, log parsing and train/eval split helpers for particle data
@Date :2024/05/14 16:19:01
@Author :lyzeng
@Email :pylyzeng@gmail.com
@version :1.0
'''
from pathlib import Path
from typing import List, Union
import json
from json import JSONDecodeError
from typing import List, Union
import attrs
import pickle
from tqdm import tqdm
import random
import shutil
from pathlib import Path
import seaborn as sns
import matplotlib.pyplot as plt
@attrs.define
class Particle:
    """One sample loaded from the JSONL exports (see __main__ below)."""
    instruction: str  # instruction text from the JSONL record
    input: str  # input text; written to per-sample txt files in __main__
    output: int  # integer class id; used in the output filename 'class{output}_{n}.txt'
@attrs.define
class ParticleData:
    """A pickled collection of Particle samples plus simple statistics."""
    file: Union[Path, str] = attrs.field(
        converter=attrs.converters.optional(lambda v: Path(v))
    )  # source path, normalized to Path (None allowed)
    # The original attached a no-op converter (lambda data: data); removed.
    data: List[Particle] = attrs.field()  # the loaded Particle records
    max_length: int = attrs.field(init=False)  # UTF-8 byte length of the longest input

    def __len__(self):
        return len(self.data)

    def __attrs_post_init__(self):
        # Byte length (UTF-8), not character count.
        # BUG FIX: max() of an empty sequence raised ValueError for an empty
        # dataset; default=0 keeps empty datasets valid.
        self.max_length = max((len(bytearray(d.input, 'utf-8')) for d in self.data), default=0)

    @classmethod
    def from_file(cls, file_path: Union[Path, str]):
        """Load a pickled list of Particle records from file_path.

        SECURITY NOTE: pickle.load executes arbitrary code during
        deserialization — only load files from trusted sources.
        """
        with open(file_path, 'rb') as f:
            loaded_data = pickle.load(f)
        return cls(file=file_path, data=loaded_data)
@attrs.define
class EpochLog:
    """Train/eval accuracy parsed from one epoch of a training log."""
    epoch: int
    train_acc: float
    eval_acc: float

    @classmethod
    def from_log_lines(cls, log_lines):
        """Parse raw log lines into EpochLog records; incomplete epochs are dropped."""
        parsed = []
        current = {}
        for raw in log_lines:
            if raw.startswith('Epoch'):
                # A new epoch header starts a fresh record.
                current = {'epoch': int(raw.split()[-1])}
                parsed.append(current)
            elif raw.startswith('train_acc:'):
                current['train_acc'] = float(raw.split(':')[-1])
            elif raw.startswith('eval_acc:'):
                current['eval_acc'] = float(raw.split(':')[-1])
        complete = [rec for rec in parsed if 'train_acc' in rec and 'eval_acc' in rec]
        return [cls(**rec) for rec in complete]
@attrs.define
class AccuracyLogger:
    """Holds per-epoch accuracy records and plots them."""
    epochs: List[EpochLog]

    @classmethod
    def from_log_file(cls, log_file: Path):
        """Build a logger by parsing a training log file."""
        lines = log_file.read_text().splitlines()
        return cls(EpochLog.from_log_lines(lines))

    def plot_acc_curve(self, save_path: Path = None):
        """Plot train/eval accuracy vs epoch; save to save_path if given, else show."""
        xs = [rec.epoch for rec in self.epochs]
        train_ys = [rec.train_acc for rec in self.epochs]
        eval_ys = [rec.eval_acc for rec in self.epochs]
        plt.figure(figsize=(10, 6))
        sns.lineplot(x=xs, y=train_ys, label='Train Accuracy')
        sns.lineplot(x=xs, y=eval_ys, label='Eval Accuracy')
        plt.xlabel('Epoch')
        plt.ylabel('Accuracy')
        plt.title('Training and Evaluation Accuracy over Epochs')
        plt.legend()
        plt.grid(True)
        if save_path:
            plt.savefig(save_path)
        else:
            plt.show()
def plot_acc_curve(log_file: Path, save_path: Path = Path('accuracy_curve.png')):
    # Convenience wrapper: parse the log file and write the accuracy plot to save_path.
    accuracy_logger = AccuracyLogger.from_log_file(log_file)
    accuracy_logger.plot_acc_curve(save_path=save_path)
def split_train_eval_data(path: List[Path],
                          ext: str = 'txt', train_dir: str = 'train_data', eval_dir: str = 'eval_dir'):
    """Shuffle all *.ext files found in the given directories and copy them
    into train_dir (80%) and eval_dir (remaining 20%).

    Args:
        path: directories to scan (non-recursive glob per directory).
        ext: file extension without the dot.
        train_dir: destination for the training split (created if missing).
        eval_dir: destination for the evaluation split (created if missing).
    """
    # IMPROVED: mkdir(parents=True, exist_ok=True) replaces the racy
    # exists()-then-mkdir checks and also handles nested destination paths.
    Path(train_dir).mkdir(parents=True, exist_ok=True)
    Path(eval_dir).mkdir(parents=True, exist_ok=True)
    # Gather all matching files across the given directories.
    all_files = [file for folder in path for file in folder.glob(f'*.{ext}')]
    random.shuffle(all_files)
    # 80/20 split point.
    split_idx = int(len(all_files) * 0.8)
    train_files = all_files[:split_idx]
    eval_files = all_files[split_idx:]
    # Copy files into their respective split directories.
    for file in train_files:
        shutil.copy(file, Path(train_dir) / file.name)
    for file in eval_files:
        shutil.copy(file, Path(eval_dir) / file.name)
# Script entry point: parse JSONL exports into Particle records, pickle the
# full/train/eval splits, dump per-sample txt files, then split those files
# into train/eval directories.
if __name__ == '__main__':
    all_file = Path('origin_data').glob('*.jsonl')
    data = []
    for file in all_file:
        # Strip the surrounding array brackets and the opening line.
        t_lines = file.read_text()[1:-1].splitlines()[1:]
        for i, line in enumerate(t_lines):
            # Drop the trailing comma between array elements.
            # IMPROVED: endswith() avoids IndexError on an empty line.
            if line.endswith(','):
                line = line[:-1]
            particle_dict = json.loads(line)
            data.append(Particle(**particle_dict))
    with open('particle_data.pkl', 'wb') as f:
        pickle.dump(data, f)
    print(f"Total {len(data)} particles, train data account for {len(data)*0.8}, eval data account for {len(data)*0.2}")
    with open('particle_data_train.pkl', 'wb') as f:
        pickle.dump(data[:int(len(data)*0.8)], f)
    with open('particle_data_eval.pkl', 'wb') as f:
        # BUG FIX: the original sliced data[int(len(data)*0.2):] — the last
        # 80%, overlapping the train split; the eval split is the last 20%.
        pickle.dump(data[int(len(data)*0.8):], f)
    all_files = ParticleData.from_file('/data/bgptf/particle_data.pkl')
    train_files = ParticleData.from_file('/data/bgptf/particle_data_train.pkl')
    eval_files = ParticleData.from_file('/data/bgptf/particle_data_eval.pkl')
    test_files = ParticleData.from_file('/data/bgptf/particle_test.pkl')
    p = Path('train_file_split')
    if not p.exists():
        p.mkdir()
    # One txt file per training sample, named by its class id.
    for n, particle in tqdm(enumerate(train_files.data), total=len(train_files.data), desc="Writing files"):
        with open(p.joinpath(f'class{particle.output}_{n}.txt'), 'w') as f:
            f.write(particle.input)
    # Source and destination directories for the file-level split.
    source_dir = Path('train_file_split')
    train_dir = Path('bbbar_train_split')
    eval_dir = Path('bbbar_eval_split')
    train_dir.mkdir(parents=True, exist_ok=True)
    eval_dir.mkdir(parents=True, exist_ok=True)
    all_files = list(source_dir.glob('*.txt'))
    # Split ratio: 80% train, 20% eval.
    train_ratio = 0.8
    num_train_files = int(len(all_files) * train_ratio)
    num_eval_files = len(all_files) - num_train_files
    random.shuffle(all_files)
    train_files = all_files[:num_train_files]
    eval_files = all_files[num_train_files:]
    # Copy files into the split directories with progress bars.
    print("正在复制训练集文件...")
    for file in tqdm(train_files, desc="训练集进度"):
        shutil.copy(file, train_dir / file.name)
    print("正在复制验证集文件...")
    for file in tqdm(eval_files, desc="验证集进度"):
        shutil.copy(file, eval_dir / file.name)
    print(f"训练集文件数: {len(train_files)}")
    print(f"验证集文件数: {len(eval_files)}")
...

22
environment.yaml Normal file
View File

@@ -0,0 +1,22 @@
name: bgpt_data
channels:
  - defaults
  - conda-forge
dependencies:
  - python=3.10
  - click
  - joblib
  - pandas
  - pip
  - uproot
  - attrs
  - ipython
  - ipykernel
  - awkward-pandas
  - importlib-metadata
  - dill
  - tqdm
  - pytorch
  - pyInstaller
  # FIX: pip packages must be listed under a "- pip:" sequence item inside
  # dependencies; a bare top-level "pip:" key is not valid for conda env files.
  - pip:
      - pyarmor

13
particle.py Normal file
View File

@@ -0,0 +1,13 @@
# particle.py
import enum
class ParticleType(enum.Enum):
NeutralHadron = 0
Photon = 1
Electron = 2
Muon = 3
Pion = 4
ChargedKaon = 5
Proton = 6
Others = 7

1
passwd.txt Normal file
View File

@@ -0,0 +1 @@
Xiaozhe+521

39
print.txt Normal file
View File

@@ -0,0 +1,39 @@
Building JetSet from data/full_bb.root
属性名称: particle_id, 类型: <class 'int'>, 值: 0
属性名称: part_charge, 类型: <class 'int'>, 值: [1.0, -1.0, 0.0, 0.0, 1.0, -1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -1.0, 0.0, 0.0, -1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
属性名称: part_energy, 类型: <class 'float'>, 值: [15.209561347961426, 13.140800476074219, 8.555957794189453, 6.927829742431641, 6.921298027038574, 6.718546390533447, 5.639996528625488, 2.734797239303589, 2.276695728302002, 2.0461738109588623, 1.7986037731170654, 1.388156771659851, 1.38234281539917, 1.0570266246795654, 0.8934926986694336, 0.8899771571159363, 0.879972517490387, 0.7406238317489624, 0.5272039175033569, 0.3686225712299347, 0.3535390794277191, 0.34719786047935486, 0.34384962916374207, 0.2649906873703003, 0.23999999463558197, 0.18479159474372864, 0.17317134141921997, 0.15941102802753448, 0.12423540651798248, 0.08461850136518478, 0.0843382179737091, 0.02829699218273163, 0.009839793667197227]
属性名称: part_px, 类型: <class 'float'>, 值: [-5.796960830688477, -5.311581611633301, -3.19583797454834, -2.326359510421753, -2.503192901611328, -2.6899032592773438, -1.925751805305481, -0.8692939281463623, -0.7592987418174744, -0.7812060713768005, -0.6791485548019409, -0.3433796167373657, -0.47946590185165405, -0.18832340836524963, -0.35665348172187805, -0.28360629081726074, -0.3264841139316559, -0.3620759844779968, -0.26313358545303345, -0.1359526515007019, -0.123598612844944, -0.06242864206433296, -0.13691917061805725, -0.08844801038503647, -0.05727635696530342, -0.08955521136522293, -0.12885883450508118, -0.10516541451215744, -0.03254406899213791, -0.015558036044239998, -0.02192043326795101, -0.008681708946824074, -0.0013048802502453327]
属性名称: part_py, 类型: <class 'float'>, 值: [-8.48299789428711, -6.931148052215576, -4.749238014221191, -3.9339864253997803, -4.476352691650391, -3.8203539848327637, -3.365067958831787, -1.7807756662368774, -1.501481056213379, -1.1750568151474, -0.9528936743736267, -0.7268454432487488, -0.833030104637146, 0.49239909648895264, -0.4624371826648712, -0.5192604064941406, -0.2649482488632202, -0.40544790029525757, -0.32859694957733154, -0.19921012222766876, -0.22155417501926422, 0.23689919710159302, -0.21035675704479218, -0.1498441994190216, -0.1331370621919632, -0.012309929355978966, 0.06338366121053696, -0.07766649127006531, -0.06364167481660843, 0.03760836273431778, -0.04934240132570267, -0.019126178696751595, 0.004166470840573311]
属性名称: part_pz, 类型: <class 'float'>, 值: [-11.213626861572266, -9.818737030029297, -6.358912944793701, -5.206402778625488, -4.645572185516357, -4.82585334777832, -4.096017360687256, -1.8796172142028809, -1.5337415933609009, -1.481818437576294, -1.3659160137176514, -1.1231404542922974, -0.9935013651847839, -0.916178286075592, -0.6762243509292603, -0.6648273468017578, -0.7730214595794678, -0.5030274987220764, -0.2853204607963562, -0.2787737548351288, -0.2462255209684372, -0.20299455523490906, -0.23502285778522491, -0.19985926151275635, -0.19129541516304016, -0.1611715406179428, -0.09677927196025848, -0.09121418744325638, -0.10161228477954865, -0.07418793439865112, -0.06479009985923767, -0.018961459398269653, -0.008818126283586025]
属性名称: log_energy, 类型: <class 'float'>, 值: [ 2.72192427 2.57572193 2.14662786 1.9355466 1.93460333 1.90487182
1.72988345 1.0060573 0.82272515 0.71597162 0.58701068 0.3279768
0.32377975 0.0554599 -0.11261712 -0.11655948 -0.1278646 -0.30026243
-0.64016787 -0.997982 -1.03976125 -1.05786046 -1.06755084 -1.3280606
-1.42711638 -1.6885266 -1.75347376 -1.83626933 -2.08557707 -2.46960234
-2.47292016 -3.56499976 -4.62132054]
属性名称: log_pt, 类型: <class 'float'>, 值: [ 2.32966824 2.16703303 1.744736 1.51959212 1.63485496 1.54165826
1.35509735 0.68391671 0.52031146 0.34432893 0.15713127 -0.21831242
-0.03961538 -0.64020631 -0.53786329 -0.52481977 -0.86639788 -0.60956524
-0.86519514 -1.42221173 -1.3716092 -1.40655069 -1.38233552 -1.74869444
-1.9314722 -2.40354098 -1.94069632 -2.03457679 -2.63833865 -3.20154184
-2.91891223 -3.86302568 -5.43390179]
属性名称: part_deta, 类型: <class 'float'>, 值: [-0.012650121003389359, 0.009468508884310722, 0.00042291259160265326, 0.01923862285912037, -0.14403046667575836, -0.05278221517801285, -0.036465417593717575, -0.11269791424274445, -0.1397673487663269, -0.040799811482429504, 0.03767024725675583, 0.17913000285625458, -0.05223162844777107, 0.362665593624115, 0.03155886009335518, 0.008969753980636597, 0.4117862284183502, -0.1295829564332962, -0.3228398263454437, 0.03021526150405407, -0.09680869430303574, -0.20235909521579742, -0.12153740972280502, 0.025453981012105942, 0.13331031799316406, 0.38491809368133545, -0.3260197937488556, -0.30641964077949524, 0.19324439764022827, 0.40427282452583313, 0.05876421928405762, -0.14629797637462616, 0.4952174425125122]
属性名称: part_dphi, 类型: <class 'float'>, 值: [-0.004700591322034597, -0.05910219997167587, 0.0024550738744437695, 0.060737334191799164, 0.0848897248506546, -0.018699318170547485, 0.07498598098754883, 0.1406451165676117, 0.1265745460987091, 0.008045745082199574, -0.02444184571504593, 0.15342673659324646, 0.07251019775867462, -2.181525468826294, -0.062189724296331406, 0.09487095475196838, -0.2942989766597748, -0.13417768478393555, -0.08044422417879105, -0.004084283020347357, 0.08590564131736755, -2.2891547679901123, 0.01777080073952675, 0.06153986230492592, 0.18849976360797882, -0.8394244313240051, -1.4331588745117188, -0.3399130403995514, 0.12207411229610443, -2.1545727252960205, 0.17670847475528717, 0.1686646193265915, -2.2433114051818848]
属性名称: part_logptrel, 类型: <class 'float'>, 值: -1.6656613686933
属性名称: part_logerel, 类型: <class 'float'>, 值: -1.6908250129173599
属性名称: part_deltaR, 类型: <class 'float'>, 值: 0.013495225829054495
属性名称: part_d0, 类型: <class 'float'>, 值: [-0.08485882 0.08362859 0. 0. -0.12567137 0.13890936
0. 0.05528848 0. 0. 0. -0.08717197
0. 0. 0. 0. 0. 0.
0.07979553 0. 0. 0.06790032 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. ]
属性名称: part_dz, 类型: <class 'float'>, 值: [ 0.16160717 0.07151473 0. 0. -0.16053929 0.09658169
0. 0.06517879 0. 0. 0. 0.11920792
0. 0. 0. 0. 0. 0.
-0.13735584 0. 0. 0.12359677 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. ]
属性名称: particle_type, 类型: <enum 'ParticleType'>, 值: ParticleType.Electron
属性名称: particle_pid, 类型: <class 'int'>, 值: 2
属性名称: jet_type, 类型: <class 'str'>, 值: b jet

397
test.ipynb Normal file
View File

@@ -0,0 +1,397 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [],
"source": [
"!rm -rf ./save/*"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Building JetSet from data/data_full/full_bbbar.root\n",
"jet type: bbbar\n",
"Building JetSet from data/data_full/full_bb.root\n",
"jet type: bb\n",
"100%|███████████████████████████████████| 24921/24921 [00:06<00:00, 3850.86it/s]\n",
"100%|███████████████████████████████████| 24930/24930 [00:06<00:00, 3754.57it/s]\n"
]
}
],
"source": [
"from pathlib import Path\n",
"if not Path('save/save_full').exists():\n",
" Path('save/save_full').mkdir()\n",
"!python convert_root_to_txt.py ./data/data_full ./save/save_full --data-type full --attr-sep '|' --part-sep ';' --selected-attrs 'part_charge,part_energy,part_px,part_py,part_pz,log_energy,log_pt,part_deta,part_dphi,part_logptrel,part_logerel,part_deltaR,part_d0,part_dz,particle_pid'\n"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Building JetSet from data/data_fast/fast_bbar.root\n",
"jet type: bb\n",
"Building JetSet from data/data_fast/fast_bb.root\n",
"jet type: bb\n",
"joblib.externals.loky.process_executor._RemoteTraceback: \n",
"\"\"\"\n",
"Traceback (most recent call last):\n",
" File \"/home/lingyuzeng/micromamba/envs/bgpt_data/lib/python3.10/site-packages/joblib/externals/loky/process_executor.py\", line 463, in _process_worker\n",
" r = call_item()\n",
" File \"/home/lingyuzeng/micromamba/envs/bgpt_data/lib/python3.10/site-packages/joblib/externals/loky/process_executor.py\", line 291, in __call__\n",
" return self.fn(*self.args, **self.kwargs)\n",
" File \"/home/lingyuzeng/micromamba/envs/bgpt_data/lib/python3.10/site-packages/joblib/parallel.py\", line 598, in __call__\n",
" return [func(*args, **kwargs)\n",
" File \"/home/lingyuzeng/micromamba/envs/bgpt_data/lib/python3.10/site-packages/joblib/parallel.py\", line 598, in <listcomp>\n",
" return [func(*args, **kwargs)\n",
" File \"/home/lingyuzeng/project/bbbar_data_struct/convert_root_to_txt.py\", line 271, in process_file\n",
" jet_set = JetSet.build_jetset_fast(root_file)\n",
" File \"/home/lingyuzeng/project/bbbar_data_struct/convert_root_to_txt.py\", line 102, in build_jetset_fast\n",
" for i, j in a.iterrows():\n",
" File \"/home/lingyuzeng/micromamba/envs/bgpt_data/lib/python3.10/site-packages/pandas/core/frame.py\", line 1541, in iterrows\n",
" for k, v in zip(self.index, self.values):\n",
" File \"/home/lingyuzeng/micromamba/envs/bgpt_data/lib/python3.10/site-packages/pandas/core/frame.py\", line 12651, in values\n",
" return self._mgr.as_array()\n",
" File \"/home/lingyuzeng/micromamba/envs/bgpt_data/lib/python3.10/site-packages/pandas/core/internals/managers.py\", line 1692, in as_array\n",
" arr = self._interleave(dtype=dtype, na_value=na_value)\n",
" File \"/home/lingyuzeng/micromamba/envs/bgpt_data/lib/python3.10/site-packages/pandas/core/internals/managers.py\", line 1733, in _interleave\n",
" arr = blk.get_values(dtype)\n",
" File \"/home/lingyuzeng/micromamba/envs/bgpt_data/lib/python3.10/site-packages/pandas/core/internals/blocks.py\", line 2253, in get_values\n",
" return np.asarray(values).reshape(self.shape)\n",
"ValueError: cannot reshape array of size 0 into shape (1,10000)\n",
"\"\"\"\n",
"\n",
"The above exception was the direct cause of the following exception:\n",
"\n",
"Traceback (most recent call last):\n",
" File \"/home/lingyuzeng/project/bbbar_data_struct/convert_root_to_txt.py\", line 275, in <module>\n",
" main()\n",
" File \"/home/lingyuzeng/micromamba/envs/bgpt_data/lib/python3.10/site-packages/click/core.py\", line 1157, in __call__\n",
" return self.main(*args, **kwargs)\n",
" File \"/home/lingyuzeng/micromamba/envs/bgpt_data/lib/python3.10/site-packages/click/core.py\", line 1078, in main\n",
" rv = self.invoke(ctx)\n",
" File \"/home/lingyuzeng/micromamba/envs/bgpt_data/lib/python3.10/site-packages/click/core.py\", line 1434, in invoke\n",
" return ctx.invoke(self.callback, **ctx.params)\n",
" File \"/home/lingyuzeng/micromamba/envs/bgpt_data/lib/python3.10/site-packages/click/core.py\", line 783, in invoke\n",
" return __callback(*args, **kwargs)\n",
" File \"/home/lingyuzeng/project/bbbar_data_struct/convert_root_to_txt.py\", line 257, in main\n",
" preprocess(root_dir, save_dir, data_type, selected_attrs=selected_attrs, attr_sep=attr_sep, part_sep=part_sep)\n",
" File \"/home/lingyuzeng/project/bbbar_data_struct/convert_root_to_txt.py\", line 262, in preprocess\n",
" joblib.Parallel(n_jobs=-1)(\n",
" File \"/home/lingyuzeng/micromamba/envs/bgpt_data/lib/python3.10/site-packages/joblib/parallel.py\", line 2007, in __call__\n",
" return output if self.return_generator else list(output)\n",
" File \"/home/lingyuzeng/micromamba/envs/bgpt_data/lib/python3.10/site-packages/joblib/parallel.py\", line 1650, in _get_outputs\n",
" yield from self._retrieve()\n",
" File \"/home/lingyuzeng/micromamba/envs/bgpt_data/lib/python3.10/site-packages/joblib/parallel.py\", line 1754, in _retrieve\n",
" self._raise_error_fast()\n",
" File \"/home/lingyuzeng/micromamba/envs/bgpt_data/lib/python3.10/site-packages/joblib/parallel.py\", line 1789, in _raise_error_fast\n",
" error_job.get_result(self.timeout)\n",
" File \"/home/lingyuzeng/micromamba/envs/bgpt_data/lib/python3.10/site-packages/joblib/parallel.py\", line 745, in get_result\n",
" return self._return_or_raise()\n",
" File \"/home/lingyuzeng/micromamba/envs/bgpt_data/lib/python3.10/site-packages/joblib/parallel.py\", line 763, in _return_or_raise\n",
" raise self._result\n",
"ValueError: cannot reshape array of size 0 into shape (1,10000)\n"
]
}
],
"source": [
"from pathlib import Path\n",
"if not Path('save/save_fast').exists():\n",
" Path('save/save_fast').mkdir()\n",
"!python convert_root_to_txt.py ./data/data_fast ./save/save_fast --data-type fast --attr-sep '|' --part-sep ';' --selected-attrs 'part_charge,part_energy,part_px,part_py,part_pz,log_energy,log_pt,part_deta,part_dphi,part_logptrel,part_logerel,part_deltaR,part_d0,part_dz,particle_pid'\n"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"from convert_root_to_txt import *\n",
"from pathlib import Path"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Building JetSet from data/full_bb.root\n"
]
}
],
"source": [
"jet_set = JetSet.build_jetset(Path('data/full_bb.root'))"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'bb'"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"jet_set.jets_type"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Building JetSet from data/full_bb.root\n",
"属性名称: particle_id, 类型: <class 'int'>, 值: 0\n",
"属性名称: part_charge, 类型: <class 'int'>, 值: [1.0, -1.0, 0.0, 0.0, 1.0, -1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -1.0, 0.0, 0.0, -1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]\n",
"属性名称: part_energy, 类型: <class 'float'>, 值: [15.209561347961426, 13.140800476074219, 8.555957794189453, 6.927829742431641, 6.921298027038574, 6.718546390533447, 5.639996528625488, 2.734797239303589, 2.276695728302002, 2.0461738109588623, 1.7986037731170654, 1.388156771659851, 1.38234281539917, 1.0570266246795654, 0.8934926986694336, 0.8899771571159363, 0.879972517490387, 0.7406238317489624, 0.5272039175033569, 0.3686225712299347, 0.3535390794277191, 0.34719786047935486, 0.34384962916374207, 0.2649906873703003, 0.23999999463558197, 0.18479159474372864, 0.17317134141921997, 0.15941102802753448, 0.12423540651798248, 0.08461850136518478, 0.0843382179737091, 0.02829699218273163, 0.009839793667197227]\n",
"属性名称: part_px, 类型: <class 'float'>, 值: [-5.796960830688477, -5.311581611633301, -3.19583797454834, -2.326359510421753, -2.503192901611328, -2.6899032592773438, -1.925751805305481, -0.8692939281463623, -0.7592987418174744, -0.7812060713768005, -0.6791485548019409, -0.3433796167373657, -0.47946590185165405, -0.18832340836524963, -0.35665348172187805, -0.28360629081726074, -0.3264841139316559, -0.3620759844779968, -0.26313358545303345, -0.1359526515007019, -0.123598612844944, -0.06242864206433296, -0.13691917061805725, -0.08844801038503647, -0.05727635696530342, -0.08955521136522293, -0.12885883450508118, -0.10516541451215744, -0.03254406899213791, -0.015558036044239998, -0.02192043326795101, -0.008681708946824074, -0.0013048802502453327]\n",
"属性名称: part_py, 类型: <class 'float'>, 值: [-8.48299789428711, -6.931148052215576, -4.749238014221191, -3.9339864253997803, -4.476352691650391, -3.8203539848327637, -3.365067958831787, -1.7807756662368774, -1.501481056213379, -1.1750568151474, -0.9528936743736267, -0.7268454432487488, -0.833030104637146, 0.49239909648895264, -0.4624371826648712, -0.5192604064941406, -0.2649482488632202, -0.40544790029525757, -0.32859694957733154, -0.19921012222766876, -0.22155417501926422, 0.23689919710159302, -0.21035675704479218, -0.1498441994190216, -0.1331370621919632, -0.012309929355978966, 0.06338366121053696, -0.07766649127006531, -0.06364167481660843, 0.03760836273431778, -0.04934240132570267, -0.019126178696751595, 0.004166470840573311]\n",
"属性名称: part_pz, 类型: <class 'float'>, 值: [-11.213626861572266, -9.818737030029297, -6.358912944793701, -5.206402778625488, -4.645572185516357, -4.82585334777832, -4.096017360687256, -1.8796172142028809, -1.5337415933609009, -1.481818437576294, -1.3659160137176514, -1.1231404542922974, -0.9935013651847839, -0.916178286075592, -0.6762243509292603, -0.6648273468017578, -0.7730214595794678, -0.5030274987220764, -0.2853204607963562, -0.2787737548351288, -0.2462255209684372, -0.20299455523490906, -0.23502285778522491, -0.19985926151275635, -0.19129541516304016, -0.1611715406179428, -0.09677927196025848, -0.09121418744325638, -0.10161228477954865, -0.07418793439865112, -0.06479009985923767, -0.018961459398269653, -0.008818126283586025]\n",
"属性名称: log_energy, 类型: <class 'float'>, 值: [ 2.72192427 2.57572193 2.14662786 1.9355466 1.93460333 1.90487182\n",
" 1.72988345 1.0060573 0.82272515 0.71597162 0.58701068 0.3279768\n",
" 0.32377975 0.0554599 -0.11261712 -0.11655948 -0.1278646 -0.30026243\n",
" -0.64016787 -0.997982 -1.03976125 -1.05786046 -1.06755084 -1.3280606\n",
" -1.42711638 -1.6885266 -1.75347376 -1.83626933 -2.08557707 -2.46960234\n",
" -2.47292016 -3.56499976 -4.62132054]\n",
"属性名称: log_pt, 类型: <class 'float'>, 值: [ 2.32966824 2.16703303 1.744736 1.51959212 1.63485496 1.54165826\n",
" 1.35509735 0.68391671 0.52031146 0.34432893 0.15713127 -0.21831242\n",
" -0.03961538 -0.64020631 -0.53786329 -0.52481977 -0.86639788 -0.60956524\n",
" -0.86519514 -1.42221173 -1.3716092 -1.40655069 -1.38233552 -1.74869444\n",
" -1.9314722 -2.40354098 -1.94069632 -2.03457679 -2.63833865 -3.20154184\n",
" -2.91891223 -3.86302568 -5.43390179]\n",
"属性名称: part_deta, 类型: <class 'float'>, 值: [-0.012650121003389359, 0.009468508884310722, 0.00042291259160265326, 0.01923862285912037, -0.14403046667575836, -0.05278221517801285, -0.036465417593717575, -0.11269791424274445, -0.1397673487663269, -0.040799811482429504, 0.03767024725675583, 0.17913000285625458, -0.05223162844777107, 0.362665593624115, 0.03155886009335518, 0.008969753980636597, 0.4117862284183502, -0.1295829564332962, -0.3228398263454437, 0.03021526150405407, -0.09680869430303574, -0.20235909521579742, -0.12153740972280502, 0.025453981012105942, 0.13331031799316406, 0.38491809368133545, -0.3260197937488556, -0.30641964077949524, 0.19324439764022827, 0.40427282452583313, 0.05876421928405762, -0.14629797637462616, 0.4952174425125122]\n",
"属性名称: part_dphi, 类型: <class 'float'>, 值: [-0.004700591322034597, -0.05910219997167587, 0.0024550738744437695, 0.060737334191799164, 0.0848897248506546, -0.018699318170547485, 0.07498598098754883, 0.1406451165676117, 0.1265745460987091, 0.008045745082199574, -0.02444184571504593, 0.15342673659324646, 0.07251019775867462, -2.181525468826294, -0.062189724296331406, 0.09487095475196838, -0.2942989766597748, -0.13417768478393555, -0.08044422417879105, -0.004084283020347357, 0.08590564131736755, -2.2891547679901123, 0.01777080073952675, 0.06153986230492592, 0.18849976360797882, -0.8394244313240051, -1.4331588745117188, -0.3399130403995514, 0.12207411229610443, -2.1545727252960205, 0.17670847475528717, 0.1686646193265915, -2.2433114051818848]\n",
"属性名称: part_logptrel, 类型: <class 'float'>, 值: -1.6656613686933\n",
"属性名称: part_logerel, 类型: <class 'float'>, 值: -1.6908250129173599\n",
"属性名称: part_deltaR, 类型: <class 'float'>, 值: 0.013495225829054495\n",
"属性名称: part_d0, 类型: <class 'float'>, 值: [-0.08485882 0.08362859 0. 0. -0.12567137 0.13890936\n",
" 0. 0.05528848 0. 0. 0. -0.08717197\n",
" 0. 0. 0. 0. 0. 0.\n",
" 0.07979553 0. 0. 0.06790032 0. 0.\n",
" 0. 0. 0. 0. 0. 0.\n",
" 0. 0. 0. ]\n",
"属性名称: part_dz, 类型: <class 'float'>, 值: [ 0.16160717 0.07151473 0. 0. -0.16053929 0.09658169\n",
" 0. 0.06517879 0. 0. 0. 0.11920792\n",
" 0. 0. 0. 0. 0. 0.\n",
" -0.13735584 0. 0. 0.12359677 0. 0.\n",
" 0. 0. 0. 0. 0. 0.\n",
" 0. 0. 0. ]\n",
"属性名称: particle_type, 类型: <enum 'ParticleType'>, 值: ParticleType.Electron\n",
"属性名称: particle_pid, 类型: <class 'int'>, 值: 2\n",
"属性名称: jet_type, 类型: <class 'str'>, 值: b jet\n"
]
}
],
"source": [
"from convert_root_to_txt import *\n",
"from pathlib import Path\n",
"\n",
"# 构建 jet_set\n",
"jet_set = JetSet.build_jetset(Path('data/full_bb.root'))\n",
"\n",
"# 获取第一个 jet 的第一个 particle\n",
"particle = jet_set.jets[0].particles[0]\n",
"\n",
"# 打印属性名称、类型和值\n",
"for field in attrs.fields(particle.__class__):\n",
" field_name = field.name\n",
" field_type = field.type\n",
" field_value = getattr(particle, field_name)\n",
" print(f\"属性名称: {field_name}, 类型: {field_type}, 值: {field_value}\")\n"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"attrs_list length: 17\n",
"Processing data/full_bbbar.root\n",
"Building JetSet from data/full_bbbar.root\n",
"Processing data/full_bb.root\n",
"Building JetSet from data/full_bb.root\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 24921/24921 [07:08<00:00, 58.13it/s]\n",
"100%|██████████| 24930/24930 [07:20<00:00, 56.56it/s]\n"
]
}
],
"source": [
"from n import *\n",
"root_dir = \"./data\"\n",
"save_dir = \"./save\"\n",
"attrs_list = ['particle_id', 'part_charge', 'part_energy', 'part_px', 'part_py', 'part_pz', 'log_energy' , 'log_pt', 'part_deta', 'part_dphi', 'part_logptrel', 'part_logerel', 'part_deltaR', 'part_d0', 'part_dz' , 'particle_type', 'particle_pid']\n",
"print(f'attrs_list length: {len(attrs_list)}')\n",
"preprocess(root_dir, save_dir, selected_attrs=attrs_list, attr_sep=',', part_sep='|')"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"bb file length: 24921\n",
"bbbar file length: 24930\n"
]
}
],
"source": [
"from pathlib import Path\n",
"\n",
"print(f\"bb file length: {len(list(Path('./save').glob('bb_*.bin')))}\")\n",
"print(f\"bbbar file length: {len(list(Path('./save').glob('bbbar_*.bin')))}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import pickle\n",
"with open('save/bb_0.bin', 'rb') as f:\n",
" res = pickle.load(f)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"str"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"type(res)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"198017"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(res)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"import pickle\n",
"import os\n",
"\n",
"def load_pickled_files(directory):\n",
" files = [f for f in os.listdir(directory) if f.endswith('.bin')]\n",
" data = []\n",
" \n",
" for file in files:\n",
" filepath = os.path.join(directory, file)\n",
" with open(filepath, 'rb') as f:\n",
" data.append(pickle.load(f))\n",
" \n",
" return data\n",
"\n",
"if __name__ == '__main__':\n",
" save_dir = \"./save\"\n",
" data = load_pickled_files(save_dir)\n",
" for i, jet in enumerate(data):\n",
" print(f\"Data from file {i}:\")\n",
" print(jet)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.14"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

17
test.log Normal file
View File

@@ -0,0 +1,17 @@
## full test log
Building JetSet from data/data_full/full_bbbar.root
jet type: bbbar
Building JetSet from data/data_full/full_bb.root
jet type: bb
100%|██████████████████████████████████████████████████████████████████████████████| 24930/24930 [00:06<00:00, 3905.34it/s]
100%|██████████████████████████████████████████████████████████████████████████████| 24921/24921 [00:06<00:00, 3824.30it/s]
## fast test log
Building JetSet from data/data_fast/fast_bbar.root
jet type: bb
Building JetSet from data/data_fast/fast_bb.root
jet type: bb
100%|██████████████████████████████████████████████████████████████████████████████| 10000/10000 [00:01<00:00, 8766.04it/s]
100%|██████████████████████████████████████████████████████████████████████████████| 10000/10000 [00:01<00:00, 8607.16it/s]

14
utils.py Normal file

File diff suppressed because one or more lines are too long