#!/usr/bin/env python
# -*- encoding: utf-8 -*-
'''
@file        : test.py
@Description : test unit for ParticleData class
@Date        : 2024/05/14 16:19:01
@Author      : lyzeng
@Email       : pylyzeng@gmail.com
@version     : 1.0
'''
import json
import pickle
import random
import shutil
from json import JSONDecodeError
from pathlib import Path
from typing import List, Union

import attrs
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm


@attrs.define
class Particle:
    """One training sample: an instruction, its input text, and a class label."""
    instruction: str
    input: str
    output: int


@attrs.define
class ParticleData:
    """A list of Particle records, typically deserialized from a pickle file."""
    # Backing pickle file; a str is converted to Path, None is passed through.
    file: Union[Path, str] = attrs.field(
        converter=attrs.converters.optional(Path)
    )
    data: List[Particle] = attrs.field()
    # Longest UTF-8-encoded `input` (in bytes) across all particles; 0 if empty.
    max_length: int = attrs.field(init=False)

    def __len__(self) -> int:
        return len(self.data)

    def __attrs_post_init__(self):
        # default=0 keeps an empty dataset from crashing max().
        self.max_length = max(
            (len(d.input.encode('utf-8')) for d in self.data), default=0
        )

    @classmethod
    def from_file(cls, file_path: Union[Path, str]) -> 'ParticleData':
        """Load a pickled List[Particle] from *file_path*.

        NOTE(review): pickle.load executes arbitrary code from the file —
        only load pickles produced by this project.
        """
        with open(file_path, 'rb') as f:
            loaded_data = pickle.load(f)
        return cls(file=file_path, data=loaded_data)


@attrs.define
class EpochLog:
    """Accuracy metrics recorded for a single training epoch."""
    epoch: int
    train_acc: float
    eval_acc: float

    @classmethod
    def from_log_lines(cls, log_lines) -> 'List[EpochLog]':
        """Parse raw log lines into EpochLog records.

        Expected line shapes: 'Epoch <n>', 'train_acc: <x>', 'eval_acc: <x>'.
        Epochs missing either accuracy value are silently dropped.
        """
        epochs = []
        epoch_data = {}
        for line in log_lines:
            if line.startswith('Epoch'):
                # Start a new record; accuracy lines below attach to it.
                epoch_data = {'epoch': int(line.split()[-1])}
                epochs.append(epoch_data)
            elif line.startswith('train_acc:'):
                epoch_data['train_acc'] = float(line.split(':')[-1])
            elif line.startswith('eval_acc:'):
                epoch_data['eval_acc'] = float(line.split(':')[-1])
        return [
            cls(**e) for e in epochs
            if 'train_acc' in e and 'eval_acc' in e
        ]


@attrs.define
class AccuracyLogger:
    """Holds per-epoch accuracy logs and plots the train/eval curves."""
    epochs: List[EpochLog]

    @classmethod
    def from_log_file(cls, log_file: Path) -> 'AccuracyLogger':
        """Read *log_file* and parse it into EpochLog records."""
        log_lines = log_file.read_text().splitlines()
        return cls(EpochLog.from_log_lines(log_lines))

    def plot_acc_curve(self, save_path: Path = None):
        """Plot train/eval accuracy vs. epoch; save to *save_path* or show."""
        epochs = [e.epoch for e in self.epochs]
        train_accs = [e.train_acc for e in self.epochs]
        eval_accs = [e.eval_acc for e in self.epochs]
        plt.figure(figsize=(10, 6))
        sns.lineplot(x=epochs, y=train_accs, label='Train Accuracy')
        sns.lineplot(x=epochs, y=eval_accs, label='Eval Accuracy')
        plt.xlabel('Epoch')
        plt.ylabel('Accuracy')
        plt.title('Training and Evaluation Accuracy over Epochs')
        plt.legend()
        plt.grid(True)
        if save_path:
            plt.savefig(save_path)
        else:
            plt.show()


def plot_acc_curve(log_file: Path, save_path: Path = Path('accuracy_curve.png')):
    """Convenience wrapper: parse *log_file* and save the accuracy curve."""
    accuracy_logger = AccuracyLogger.from_log_file(log_file)
    accuracy_logger.plot_acc_curve(save_path=save_path)


def split_train_eval_data(path: List[Path], ext: str = 'txt',
                          train_dir: str = 'train_data',
                          eval_dir: str = 'eval_dir'):
    """Randomly split files with extension *ext* found under *path* into
    train (80%) and eval (20%) directories, copying them in place.

    NOTE(review): the split is not seeded, so it differs between runs.
    """
    # Create the target directories (idempotent).
    Path(train_dir).mkdir(exist_ok=True)
    Path(eval_dir).mkdir(exist_ok=True)
    # Collect every matching file from all source directories.
    all_files = [file for f in path for file in f.glob(f'*.{ext}')]
    random.shuffle(all_files)
    # 80/20 split point.
    split_idx = int(len(all_files) * 0.8)
    train_files = all_files[:split_idx]
    eval_files = all_files[split_idx:]
    # Copy files into their respective directories.
    for file in train_files:
        shutil.copy(file, Path(train_dir) / file.name)
    for file in eval_files:
        shutil.copy(file, Path(eval_dir) / file.name)


if __name__ == '__main__':
    # --- Stage 1: parse origin_data/*.jsonl into Particle objects ---------
    all_file = Path('origin_data').glob('*.jsonl')
    data = []
    for file in all_file:
        # The files are JSON arrays spread over lines: strip the surrounding
        # brackets ([1:-1]) and the first line, then parse one object per line.
        t_lines = file.read_text()[1:-1].splitlines()[1:]
        for line in t_lines:
            # endswith() is safe on empty lines, unlike line[-1].
            if line.endswith(','):
                line = line[:-1]
            particle_dict = json.loads(line)
            data.append(Particle(**particle_dict))

    # --- Stage 2: pickle the full set plus disjoint train/eval splits -----
    with open('particle_data.pkl', 'wb') as f:
        pickle.dump(data, f)
    split_idx = int(len(data) * 0.8)
    print(f"Total {len(data)} particles, train data account for "
          f"{split_idx}, eval data account for {len(data) - split_idx}")
    with open('particle_data_train.pkl', 'wb') as f:
        pickle.dump(data[:split_idx], f)
    with open('particle_data_eval.pkl', 'wb') as f:
        # BUGFIX: was data[int(len(data)*0.2):], which overlapped the train
        # split (train/eval shared 60% of the data). Eval is the last 20%.
        pickle.dump(data[split_idx:], f)

    # --- Stage 3: reload the pickles and dump one text file per particle --
    all_files = ParticleData.from_file('/data/bgptf/particle_data.pkl')
    train_files = ParticleData.from_file('/data/bgptf/particle_data_train.pkl')
    eval_files = ParticleData.from_file('/data/bgptf/particle_data_eval.pkl')
    test_files = ParticleData.from_file('/data/bgptf/particle_test.pkl')
    p = Path('train_file_split')
    p.mkdir(exist_ok=True)
    for n, particle in tqdm(enumerate(train_files.data),
                            total=len(train_files.data),
                            desc="Writing files"):
        # File name encodes the class label so the split below keeps labels.
        with open(p.joinpath(f'class{particle.output}_{n}.txt'), 'w') as f:
            f.write(particle.input)

    # --- Stage 4: shuffle the text files into train/eval directories ------
    # Source and target directories.
    source_dir = Path('train_file_split')
    train_dir = Path('bbbar_train_split')
    eval_dir = Path('bbbar_eval_split')
    # Create the target directories.
    train_dir.mkdir(parents=True, exist_ok=True)
    eval_dir.mkdir(parents=True, exist_ok=True)
    # Collect all file paths.
    all_files = list(source_dir.glob('*.txt'))
    # Split ratio: 80% train, 20% eval.
    train_ratio = 0.8
    # Number of files in each split.
    num_train_files = int(len(all_files) * train_ratio)
    num_eval_files = len(all_files) - num_train_files
    # Shuffle before splitting.
    random.shuffle(all_files)
    # Partition into train and eval file lists.
    train_files = all_files[:num_train_files]
    eval_files = all_files[num_train_files:]
    # Copy files to the target directories with a progress bar.
    print("正在复制训练集文件...")
    for file in tqdm(train_files, desc="训练集进度"):
        shutil.copy(file, train_dir / file.name)
    print("正在复制验证集文件...")
    for file in tqdm(eval_files, desc="验证集进度"):
        shutil.copy(file, eval_dir / file.name)
    print(f"训练集文件数: {len(train_files)}")
    print(f"验证集文件数: {len(eval_files)}")
    ...