#!/usr/bin/env python
# -*- encoding: utf-8 -*-
'''
@file        : test.py
@Description : test unit for ParticleData class
@Date        : 2024/05/14 16:19:01
@Author      : lyzeng
@Email       : pylyzeng@gmail.com
@version     : 1.0
'''
|
||
import json
import pickle
import random
import shutil
from json import JSONDecodeError
from pathlib import Path
from typing import List, Union

import attrs
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
|
||
|
||
@attrs.define
class Particle:
    """One data sample: an instruction, its input text, and an integer label.

    NOTE(review): `output` is used downstream as a class id in file names
    (`class{particle.output}_{n}.txt`) — presumably a category label;
    confirm against the source .jsonl data.
    """

    instruction: str  # task instruction text for the sample
    input: str        # raw input payload; its UTF-8 byte length is measured downstream
    output: int       # class label (see NOTE above)
|
||
|
||
@attrs.define
class ParticleData:
    """A collection of `Particle` records deserialized from a pickle file.

    `max_length` is derived after init: the UTF-8 byte length of the
    longest `input` among the particles (0 when `data` is empty).
    """

    # Source path; plain strings are normalized to `Path`, None passes through.
    file: Union[Path, str] = attrs.field(
        converter=attrs.converters.optional(lambda v: Path(v))
    )
    # The particle records (the original declared a no-op identity converter).
    data: List[Particle] = attrs.field()
    # Longest input size in UTF-8 bytes; computed in __attrs_post_init__.
    max_length: int = attrs.field(init=False)

    def __len__(self) -> int:
        """Number of particles in the collection."""
        return len(self.data)

    def __attrs_post_init__(self):
        # str.encode('utf-8') has the same length as bytearray(s, 'utf-8').
        # default=0 avoids the ValueError the original raised on empty `data`.
        self.max_length = max(
            (len(d.input.encode('utf-8')) for d in self.data), default=0
        )

    @classmethod
    def from_file(cls, file_path: Union[Path, str]) -> 'ParticleData':
        """Deserialize a pickled list of particles from `file_path`.

        SECURITY: `pickle.load` executes arbitrary code embedded in the
        file — only load pickles this pipeline produced itself.
        """
        with open(file_path, 'rb') as f:
            loaded_data = pickle.load(f)
        return cls(file=file_path, data=loaded_data)
|
||
|
||
|
||
@attrs.define
class EpochLog:
    """Accuracy numbers recorded for a single training epoch."""

    epoch: int
    train_acc: float
    eval_acc: float

    @classmethod
    def from_log_lines(cls, log_lines):
        """Parse raw log lines into a list of complete `EpochLog` records.

        A record opens at a line beginning with 'Epoch' and absorbs the
        subsequent 'train_acc:' / 'eval_acc:' lines; records missing
        either accuracy are dropped from the result.
        """
        records = []
        # Accuracy lines seen before the first 'Epoch' line land in this
        # throwaway dict, which is never appended — they are ignored.
        current = {}
        for raw in log_lines:
            if raw.startswith('Epoch'):
                current = {'epoch': int(raw.split()[-1])}
                records.append(current)
            elif raw.startswith('train_acc:'):
                current['train_acc'] = float(raw.split(':')[-1])
            elif raw.startswith('eval_acc:'):
                current['eval_acc'] = float(raw.split(':')[-1])
        complete = [r for r in records
                    if 'train_acc' in r and 'eval_acc' in r]
        return [cls(**r) for r in complete]
|
||
|
||
@attrs.define
class AccuracyLogger:
    """Holds per-epoch accuracy records and plots accuracy curves."""

    epochs: List[EpochLog]

    @classmethod
    def from_log_file(cls, log_file: Path):
        """Build a logger by parsing the text log at `log_file`."""
        parsed = EpochLog.from_log_lines(log_file.read_text().splitlines())
        return cls(parsed)

    def plot_acc_curve(self, save_path: Path = None):
        """Plot train/eval accuracy vs. epoch; save to `save_path`, else show."""
        xs, train_ys, eval_ys = [], [], []
        for log in self.epochs:
            xs.append(log.epoch)
            train_ys.append(log.train_acc)
            eval_ys.append(log.eval_acc)

        plt.figure(figsize=(10, 6))
        sns.lineplot(x=xs, y=train_ys, label='Train Accuracy')
        sns.lineplot(x=xs, y=eval_ys, label='Eval Accuracy')
        plt.xlabel('Epoch')
        plt.ylabel('Accuracy')
        plt.title('Training and Evaluation Accuracy over Epochs')
        plt.legend()
        plt.grid(True)

        if save_path:
            plt.savefig(save_path)
        else:
            plt.show()
|
||
|
||
def plot_acc_curve(log_file: Path, save_path: Path = Path('accuracy_curve.png')):
    """Convenience wrapper: parse `log_file` and save its accuracy plot."""
    AccuracyLogger.from_log_file(log_file).plot_acc_curve(save_path=save_path)
|
||
|
||
def split_train_eval_data(path: List[Path],
                          ext: str = 'txt', train_dir: str = 'train_data',
                          eval_dir: str = 'eval_dir',
                          train_ratio: float = 0.8):
    """Randomly split `*.{ext}` files from the `path` directories.

    Args:
        path: directories to scan (non-recursively) for matching files.
        ext: file extension to match, without the leading dot.
        train_dir: destination directory for the training share.
        eval_dir: destination directory for the evaluation share.
        train_ratio: fraction of files copied to `train_dir` (the rest go
            to `eval_dir`); default keeps the original hard-coded 80/20.

    Files are copied (not moved); the shuffle is random and unseeded, so
    the split differs between runs.
    """
    # mkdir(parents=True, exist_ok=True) replaces the original
    # exists()-then-mkdir() pair, which was racy and failed for nested paths.
    Path(train_dir).mkdir(parents=True, exist_ok=True)
    Path(eval_dir).mkdir(parents=True, exist_ok=True)

    # Collect and shuffle every matching file across all source directories.
    all_files = [file for d in path for file in d.glob(f'*.{ext}')]
    random.shuffle(all_files)

    # Split at the requested ratio.
    split_idx = int(len(all_files) * train_ratio)
    train_files = all_files[:split_idx]
    eval_files = all_files[split_idx:]

    # Copy each share into its destination directory.
    for file in train_files:
        shutil.copy(file, Path(train_dir) / file.name)
    for file in eval_files:
        shutil.copy(file, Path(eval_dir) / file.name)
|
||
|
||
if __name__ == '__main__':
    # Parse the raw .jsonl exports into Particle records.
    # NOTE(review): the slicing below ([1:-1], then dropping the first
    # line) assumes each file is a pretty-printed JSON array whose element
    # lines carry trailing commas — confirm against the origin_data files.
    all_file = Path('origin_data').glob('*.jsonl')
    data = []
    for file in all_file:
        t_lines = file.read_text()[1:-1].splitlines()[1:]
        for line in t_lines:
            # endswith() is safe on empty lines; the original
            # `line[-1] == ','` raised IndexError on them.
            if line.endswith(','):
                line = line[:-1]
            particle_dict = json.loads(line)
            data.append(Particle(**particle_dict))

    # Serialize the full data set.
    with open('particle_data.pkl', 'wb') as f:
        pickle.dump(data, f)

    print(f"Total {len(data)} particles, train data account for {len(data)*0.8}, eval data account for {len(data)*0.2}")

    # 80/20 split of the serialized records.
    split_idx = int(len(data) * 0.8)
    with open('particle_data_train.pkl', 'wb') as f:
        pickle.dump(data[:split_idx], f)

    # BUGFIX: the eval share previously started at int(len(data)*0.2), so
    # 60% of the samples appeared in BOTH the train and eval pickles.
    with open('particle_data_eval.pkl', 'wb') as f:
        pickle.dump(data[split_idx:], f)

    all_files = ParticleData.from_file('/data/bgptf/particle_data.pkl')
    train_files = ParticleData.from_file('/data/bgptf/particle_data_train.pkl')
    eval_files = ParticleData.from_file('/data/bgptf/particle_data_eval.pkl')
    test_files = ParticleData.from_file('/data/bgptf/particle_test.pkl')

    # Write each training particle's input to its own class-labelled file.
    p = Path('train_file_split')
    p.mkdir(exist_ok=True)
    for n, particle in tqdm(enumerate(train_files.data), total=len(train_files.data), desc="Writing files"):
        with open(p.joinpath(f'class{particle.output}_{n}.txt'), 'w') as f:
            f.write(particle.input)

    # Source and destination directories for the file-level split.
    source_dir = Path('train_file_split')
    train_dir = Path('bbbar_train_split')
    eval_dir = Path('bbbar_eval_split')
    train_dir.mkdir(parents=True, exist_ok=True)
    eval_dir.mkdir(parents=True, exist_ok=True)

    # Shuffle all written files and split 80% train / 20% eval.
    all_files = list(source_dir.glob('*.txt'))
    train_ratio = 0.8
    num_train_files = int(len(all_files) * train_ratio)
    random.shuffle(all_files)
    train_files = all_files[:num_train_files]
    eval_files = all_files[num_train_files:]

    # Copy each share into place, with progress bars.
    print("正在复制训练集文件...")
    for file in tqdm(train_files, desc="训练集进度"):
        shutil.copy(file, train_dir / file.name)

    print("正在复制验证集文件...")
    for file in tqdm(eval_files, desc="验证集进度"):
        shutil.copy(file, eval_dir / file.name)

    print(f"训练集文件数: {len(train_files)}")
    print(f"验证集文件数: {len(eval_files)}")
|
||
... |