Files
particle_analyse/dataloader.py
2024-06-10 09:09:10 +08:00

203 lines
6.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
'''
@file :dataloader.py
@Description :Particle data containers, training-log accuracy plotting and train/eval split helpers
@Date :2024/05/14 16:19:01
@Author :lyzeng
@Email :pylyzeng@gmail.com
@version :1.0
'''
from pathlib import Path
from typing import List, Union
import json
from json import JSONDecodeError
from typing import List, Union
import attrs
import pickle
from tqdm import tqdm
import random
import shutil
from pathlib import Path
import seaborn as sns
import matplotlib.pyplot as plt
@attrs.define()
class Particle:
    """Single record with an instruction string, an input string and an integer output."""

    # Instruction text stored with the record.
    instruction: str = attrs.field()
    # Input text of the record.
    input: str = attrs.field()
    # Integer output value of the record.
    output: int = attrs.field()
@attrs.define
class ParticleData:
    """A list of ``Particle`` records together with the file they came from.

    ``max_length`` is not an ``__init__`` argument; it is derived from
    ``data`` right after construction.
    """

    # Source location; converted to ``Path`` unless it is ``None``.
    file: Union[Path, str] = attrs.field(
        converter=attrs.converters.optional(Path)
    )
    # The records themselves, stored as given.
    data: List[Particle] = attrs.field()
    # Largest UTF-8 byte length of any ``input`` in ``data`` (0 when empty).
    max_length: int = attrs.field(init=False)

    def __len__(self) -> int:
        """Return the number of records in ``data``."""
        return len(self.data)

    def __attrs_post_init__(self) -> None:
        """Compute ``max_length`` as the longest ``input`` measured in UTF-8 bytes."""
        # ``default=0`` keeps an empty dataset constructible instead of
        # raising ValueError from max() on an empty sequence.
        self.max_length = max(
            (len(d.input.encode('utf-8')) for d in self.data),
            default=0,
        )

    @classmethod
    def from_file(cls, file_path: Union[Path, str]):
        """Build an instance from a pickle file holding the record list.

        Args:
            file_path: Path of the pickle file to read.

        Returns:
            A new instance whose ``file`` is ``file_path`` and whose ``data``
            is the unpickled object.
        """
        # NOTE: pickle.load can execute arbitrary code while deserialising;
        # only use this on files from a trusted source.
        with open(file_path, 'rb') as f:
            loaded_data = pickle.load(f)
        return cls(file=file_path, data=loaded_data)
@attrs.define
class EpochLog:
    """Train and eval accuracy values recorded for one epoch."""

    epoch: int
    train_acc: float
    eval_acc: float

    @classmethod
    def from_log_lines(cls, log_lines):
        """Parse log lines into a list of ``EpochLog`` entries.

        A line starting with ``Epoch`` opens a new record whose epoch number
        is the last whitespace-separated token of that line. Lines starting
        with ``train_acc:`` or ``eval_acc:`` fill in the most recently opened
        record with the float found after the last colon. Records missing
        either accuracy value are dropped from the result.

        Args:
            log_lines: Iterable of log lines (strings).

        Returns:
            List of ``EpochLog`` instances in the order their ``Epoch`` lines
            appeared.
        """
        records = []
        current = {}
        for text in log_lines:
            if text.startswith('Epoch'):
                current = {'epoch': int(text.split()[-1])}
                records.append(current)
                continue
            for key in ('train_acc', 'eval_acc'):
                if text.startswith(key + ':'):
                    current[key] = float(text.rpartition(':')[2])
                    break

        complete = []
        for record in records:
            if 'train_acc' in record and 'eval_acc' in record:
                complete.append(cls(**record))
        return complete
@attrs.define
class AccuracyLogger:
    """Holds per-epoch accuracy records and plots them."""

    # Parsed epoch records, in log order.
    epochs: List[EpochLog]

    @classmethod
    def from_log_file(cls, log_file: Union[Path, str]):
        """Create a logger from a text log file.

        Args:
            log_file: Path of the log file; its lines are handed to
                ``EpochLog.from_log_lines``.

        Returns:
            A new ``AccuracyLogger`` holding the parsed epochs.
        """
        # Path() makes plain string paths work as well as Path objects.
        log_lines = Path(log_file).read_text().splitlines()
        epoch_logs = EpochLog.from_log_lines(log_lines)
        return cls(epoch_logs)

    def plot_acc_curve(self, save_path: Union[Path, str, None] = None):
        """Plot train and eval accuracy against epoch number.

        Args:
            save_path: When truthy, the figure is written to this path and
                then closed; otherwise the figure is shown with ``plt.show()``.
        """
        epochs = [e.epoch for e in self.epochs]
        train_accs = [e.train_acc for e in self.epochs]
        eval_accs = [e.eval_acc for e in self.epochs]

        fig = plt.figure(figsize=(10, 6))
        sns.lineplot(x=epochs, y=train_accs, label='Train Accuracy')
        sns.lineplot(x=epochs, y=eval_accs, label='Eval Accuracy')
        plt.xlabel('Epoch')
        plt.ylabel('Accuracy')
        plt.title('Training and Evaluation Accuracy over Epochs')
        plt.legend()
        plt.grid(True)

        if save_path:
            try:
                fig.savefig(save_path)
            finally:
                # pyplot keeps every figure alive until it is closed; release
                # it so repeated calls do not accumulate open figures.
                plt.close(fig)
        else:
            plt.show()
def plot_acc_curve(log_file: Path, save_path: Path = Path('accuracy_curve.png')):
    """Parse ``log_file`` with ``AccuracyLogger`` and plot its accuracy curve.

    Args:
        log_file: Log file passed to ``AccuracyLogger.from_log_file``.
        save_path: Forwarded to ``AccuracyLogger.plot_acc_curve``; defaults to
            ``accuracy_curve.png`` in the current directory.
    """
    AccuracyLogger.from_log_file(log_file).plot_acc_curve(save_path=save_path)
def split_train_eval_data(path: List[Path],
                          ext: str = 'txt', train_dir: str = 'train_data', eval_dir: str = 'eval_dir',
                          train_ratio: float = 0.8, seed: Union[int, None] = None):
    """Randomly split files from several directories into train and eval folders.

    Every regular file matching ``*.{ext}`` directly inside each directory of
    ``path`` is collected, shuffled, and copied into ``train_dir`` or
    ``eval_dir``. Files are copied under their bare file name, so two source
    files with the same name overwrite each other in the destination.

    Args:
        path: Source directories to scan (``Path`` objects or strings).
        ext: File extension to match, without the leading dot.
        train_dir: Destination directory for the training portion.
        eval_dir: Destination directory for the evaluation portion.
        train_ratio: Fraction of files sent to ``train_dir``; the remainder
            goes to ``eval_dir``. Must lie in ``[0, 1]``.
        seed: When given, the shuffle uses a private ``random.Random(seed)``
            so the split is reproducible; otherwise the module-level
            ``random`` state is used.

    Raises:
        ValueError: If ``train_ratio`` is outside ``[0, 1]``.
    """
    if not 0.0 <= train_ratio <= 1.0:
        raise ValueError(f'train_ratio must be within [0, 1], got {train_ratio!r}')

    train_root = Path(train_dir)
    eval_root = Path(eval_dir)
    # parents/exist_ok: nested destinations work and re-runs do not fail.
    train_root.mkdir(parents=True, exist_ok=True)
    eval_root.mkdir(parents=True, exist_ok=True)

    # Sort first so the order handed to the shuffle does not depend on the
    # file system's listing order; skip directories that happen to match.
    all_files = sorted(
        file
        for source in path
        for file in Path(source).glob(f'*.{ext}')
        if file.is_file()
    )
    if seed is None:
        random.shuffle(all_files)
    else:
        random.Random(seed).shuffle(all_files)

    split_idx = int(len(all_files) * train_ratio)
    for file in all_files[:split_idx]:
        shutil.copy(file, train_root / file.name)
    for file in all_files[split_idx:]:
        shutil.copy(file, eval_root / file.name)
if __name__ == '__main__':
    # --- Step 1: build Particle records from every *.jsonl file in origin_data/ ---
    all_file = Path('origin_data').glob('*.jsonl')
    data = []
    for file in all_file:
        # Drop the first and last character of the file, then skip the first
        # remaining line; each following line is parsed as one JSON object.
        t_lines = file.read_text()[1:-1].splitlines()[1:]
        for line in t_lines:
            line = line.strip()
            if not line:
                # Blank lines used to crash on line[-1]; skip them instead.
                continue
            if line.endswith(','):
                line = line[:-1]
            particle_dict = json.loads(line)
            data.append(Particle(**particle_dict))

    # --- Step 2: persist the full set plus an 80/20 train/eval split ---
    with open('particle_data.pkl', 'wb') as f:
        pickle.dump(data, f)
    print(f"Total {len(data)} particles, train data account for {len(data)*0.8}, eval data account for {len(data)*0.2}")
    # One shared index keeps the two splits disjoint. The eval slice used to
    # start at 20%, which made it overlap most of the training data.
    split_idx = int(len(data) * 0.8)
    with open('particle_data_train.pkl', 'wb') as f:
        pickle.dump(data[:split_idx], f)
    with open('particle_data_eval.pkl', 'wb') as f:
        pickle.dump(data[split_idx:], f)

    # --- Step 3: reload the pickled sets ---
    # NOTE(review): the pickles above are written to the current directory but
    # read back from /data/bgptf -- confirm the script is run from there.
    all_files = ParticleData.from_file('/data/bgptf/particle_data.pkl')
    train_files = ParticleData.from_file('/data/bgptf/particle_data_train.pkl')
    eval_files = ParticleData.from_file('/data/bgptf/particle_data_eval.pkl')
    test_files = ParticleData.from_file('/data/bgptf/particle_test.pkl')

    # --- Step 4: write each training record's input to its own text file ---
    p = Path('train_file_split')
    p.mkdir(parents=True, exist_ok=True)
    for n, particle in tqdm(enumerate(train_files.data), total=len(train_files.data), desc="Writing files"):
        # The file name carries the record's output value and its index.
        with open(p.joinpath(f'class{particle.output}_{n}.txt'), 'w') as f:
            f.write(particle.input)

    # --- Step 5: split the per-record text files into train/eval directories ---
    # Source and destination directories.
    source_dir = Path('train_file_split')
    train_dir = Path('bbbar_train_split')
    eval_dir = Path('bbbar_eval_split')
    # Create the destination directories.
    train_dir.mkdir(parents=True, exist_ok=True)
    eval_dir.mkdir(parents=True, exist_ok=True)
    # Collect every text file; from here on the *_files names hold path lists.
    all_files = list(source_dir.glob('*.txt'))
    # Split ratio: 80% for training, 20% for evaluation.
    train_ratio = 0.8
    # Number of files in each portion.
    num_train_files = int(len(all_files) * train_ratio)
    num_eval_files = len(all_files) - num_train_files
    # Shuffle before slicing so the split is random.
    random.shuffle(all_files)
    train_files = all_files[:num_train_files]
    eval_files = all_files[num_train_files:]
    # Copy the files to their destination with a progress bar.
    print("正在复制训练集文件...")
    for file in tqdm(train_files, desc="训练集进度"):
        shutil.copy(file, train_dir / file.name)
    print("正在复制验证集文件...")
    for file in tqdm(eval_files, desc="验证集进度"):
        shutil.copy(file, eval_dir / file.name)
    print(f"训练集文件数: {len(train_files)}")
    print(f"验证集文件数: {len(eval_files)}")
...