first add
This commit is contained in:
203
dataloader.py
Normal file
203
dataloader.py
Normal file
@@ -0,0 +1,203 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
'''
|
||||
@file        :dataloader.py
|
||||
@Description :data loading and train/eval split utilities for ParticleData
|
||||
@Date :2024/05/14 16:19:01
|
||||
@Author :lyzeng
|
||||
@Email :pylyzeng@gmail.com
|
||||
@version :1.0
|
||||
'''
|
||||
from pathlib import Path
|
||||
from typing import List, Union
|
||||
import json
|
||||
from json import JSONDecodeError
|
||||
from typing import List, Union
|
||||
import attrs
|
||||
import pickle
|
||||
from tqdm import tqdm
|
||||
import random
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
import seaborn as sns
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
@attrs.define
class Particle:
    """One labelled sample: an instruction, an input payload, and an integer class label."""

    # Task instruction text.
    instruction: str
    # Raw input string; its UTF-8 byte length is measured by ParticleData.max_length.
    input: str
    # Integer class label — presumably a particle class id; TODO confirm with the data source.
    output: int
|
||||
|
||||
@attrs.define
class ParticleData:
    """A collection of Particle samples associated with a source file.

    Attributes:
        file: Source path of the data; a str is converted to ``pathlib.Path``.
        data: The list of ``Particle`` records.
        max_length: UTF-8 byte length of the longest ``input`` field,
            computed after initialisation (0 when ``data`` is empty).
    """

    file: Union[Path, str] = attrs.field(
        converter=attrs.converters.optional(lambda v: Path(v))
    )
    # The original converter was a no-op (``lambda data: data``); dropped.
    data: List[Particle] = attrs.field()
    max_length: int = attrs.field(init=False)

    def __len__(self):
        """Number of particles in the dataset."""
        return len(self.data)

    def __attrs_post_init__(self):
        # Longest input measured in UTF-8 bytes. ``default=0`` fixes the
        # ValueError the bare max([...]) raised for an empty dataset.
        self.max_length = max(
            (len(d.input.encode('utf-8')) for d in self.data), default=0
        )

    @classmethod
    def from_file(cls, file_path: Union[Path, str]):
        """Load a pickled list of Particle records from *file_path*.

        NOTE(security): ``pickle.load`` can execute arbitrary code — only
        load files from trusted sources.
        """
        with open(file_path, 'rb') as f:
            loaded_data = pickle.load(f)
        return cls(file=file_path, data=loaded_data)
|
||||
|
||||
|
||||
@attrs.define
class EpochLog:
    """Accuracy record for a single training epoch."""

    epoch: int
    train_acc: float
    eval_acc: float

    @classmethod
    def from_log_lines(cls, log_lines):
        """Parse 'Epoch N' / 'train_acc: x' / 'eval_acc: x' lines into EpochLog objects.

        Accuracy lines seen before the first 'Epoch' marker are discarded,
        and epochs missing either accuracy value are filtered out.
        """
        parsed = []
        current = {}  # scratch dict; absorbs accuracy lines before the first 'Epoch'
        for raw in log_lines:
            if raw.startswith('Epoch'):
                current = {'epoch': int(raw.split()[-1])}
                parsed.append(current)
            elif raw.startswith('train_acc:'):
                current['train_acc'] = float(raw.split(':')[-1])
            elif raw.startswith('eval_acc:'):
                current['eval_acc'] = float(raw.split(':')[-1])
        required = ('epoch', 'train_acc', 'eval_acc')
        return [cls(**rec) for rec in parsed if all(k in rec for k in required)]
|
||||
|
||||
@attrs.define
class AccuracyLogger:
    """Collects per-epoch accuracy logs and plots them."""

    epochs: List[EpochLog]

    @classmethod
    def from_log_file(cls, log_file: Path):
        """Build a logger by parsing a training log file with EpochLog.from_log_lines."""
        log_lines = log_file.read_text().splitlines()
        epoch_logs = EpochLog.from_log_lines(log_lines)
        return cls(epoch_logs)

    def plot_acc_curve(self, save_path: Union[Path, None] = None):
        """Plot train/eval accuracy versus epoch.

        Args:
            save_path: When given, the figure is saved there; otherwise it is
                shown interactively. (Annotation fixed: was ``Path = None``.)
        """
        epochs = [e.epoch for e in self.epochs]
        train_accs = [e.train_acc for e in self.epochs]
        eval_accs = [e.eval_acc for e in self.epochs]

        plt.figure(figsize=(10, 6))
        sns.lineplot(x=epochs, y=train_accs, label='Train Accuracy')
        sns.lineplot(x=epochs, y=eval_accs, label='Eval Accuracy')
        plt.xlabel('Epoch')
        plt.ylabel('Accuracy')
        plt.title('Training and Evaluation Accuracy over Epochs')
        plt.legend()
        plt.grid(True)

        if save_path:
            plt.savefig(save_path)
            # Release the figure — repeated saves would otherwise accumulate
            # open matplotlib figures.
            plt.close()
        else:
            plt.show()
|
||||
|
||||
def plot_acc_curve(log_file: Path, save_path: Path = Path('accuracy_curve.png')):
    """Parse *log_file* and save the accuracy curve figure to *save_path*."""
    AccuracyLogger.from_log_file(log_file).plot_acc_curve(save_path=save_path)
|
||||
|
||||
def split_train_eval_data(path: List[Path],
                          ext: str = 'txt', train_dir: str = 'train_data',
                          eval_dir: str = 'eval_dir', train_ratio: float = 0.8):
    """Randomly split files found under *path* into train/eval directories.

    Args:
        path: Source directories searched (non-recursively) for files.
        ext: File extension (without the dot) to collect.
        train_dir: Destination directory for training files.
        eval_dir: Destination directory for evaluation files. (Default kept
            as 'eval_dir' for backward compatibility — likely meant 'eval_data'.)
        train_ratio: Fraction of files assigned to the training set
            (generalizes the previously hard-coded 0.8).
    """
    train_path = Path(train_dir)
    eval_path = Path(eval_dir)
    # parents/exist_ok makes directory creation idempotent and race-free.
    train_path.mkdir(parents=True, exist_ok=True)
    eval_path.mkdir(parents=True, exist_ok=True)

    # Collect every matching file across all source directories.
    all_files = [file for d in path for file in d.glob(f'*.{ext}')]
    random.shuffle(all_files)

    # Split point: first train_ratio of the shuffled list goes to training.
    split_idx = int(len(all_files) * train_ratio)
    train_files = all_files[:split_idx]
    eval_files = all_files[split_idx:]

    # Copy files into their destination directories.
    for file in train_files:
        shutil.copy(file, train_path / file.name)
    for file in eval_files:
        shutil.copy(file, eval_path / file.name)
|
||||
|
||||
if __name__ == '__main__':
    # --- Stage 1: parse raw .jsonl files into Particle records. ---
    all_file = Path('origin_data').glob('*.jsonl')
    data = []
    for file in all_file:
        # Files look like a JSON array spread over lines: strip the
        # surrounding brackets and the first line, then parse each element.
        t_lines = file.read_text()[1:-1].splitlines()[1:]
        for line in t_lines:
            # Drop a trailing comma; endswith() also guards against empty
            # lines (the original `line[-1]` indexing raised IndexError on "").
            if line.endswith(','):
                line = line[:-1]
            particle_dict = json.loads(line)
            data.append(Particle(**particle_dict))

    # --- Stage 2: pickle the full set and an 80/20 train/eval split. ---
    with open('particle_data.pkl', 'wb') as f:
        pickle.dump(data, f)

    n_train = int(len(data) * 0.8)
    # Print integer counts that match the actual slices (was float arithmetic).
    print(f"Total {len(data)} particles, train data account for {n_train}, eval data account for {len(data) - n_train}")

    with open('particle_data_train.pkl', 'wb') as f:
        pickle.dump(data[:n_train], f)

    with open('particle_data_eval.pkl', 'wb') as f:
        # BUGFIX: was data[int(len(data)*0.2):], which overlapped the
        # training slice; the eval set is the remaining 20%.
        pickle.dump(data[n_train:], f)

    # --- Stage 3: reload the pickles and dump each train sample to a txt file. ---
    all_files = ParticleData.from_file('/data/bgptf/particle_data.pkl')
    train_files = ParticleData.from_file('/data/bgptf/particle_data_train.pkl')
    eval_files = ParticleData.from_file('/data/bgptf/particle_data_eval.pkl')
    # NOTE(review): test_files is loaded but never used below — confirm intent.
    test_files = ParticleData.from_file('/data/bgptf/particle_test.pkl')
    p = Path('train_file_split')
    if not p.exists():
        p.mkdir()
    for n, particle in tqdm(enumerate(train_files.data), total=len(train_files.data), desc="Writing files"):
        # One file per sample, class label encoded in the filename.
        with open(p.joinpath(f'class{particle.output}_{n}.txt'), 'w') as f:
            f.write(particle.input)

    # --- Stage 4: split the txt files into train/eval directories. ---
    # Source and destination directories.
    source_dir = Path('train_file_split')
    train_dir = Path('bbbar_train_split')
    eval_dir = Path('bbbar_eval_split')

    # Create destination directories.
    train_dir.mkdir(parents=True, exist_ok=True)
    eval_dir.mkdir(parents=True, exist_ok=True)

    # Gather all candidate files.
    all_files = list(source_dir.glob('*.txt'))

    # 80% train / 20% eval.
    train_ratio = 0.8
    num_train_files = int(len(all_files) * train_ratio)

    # Shuffle, then split at the computed index.
    random.shuffle(all_files)
    train_files = all_files[:num_train_files]
    eval_files = all_files[num_train_files:]

    # Copy files to their destinations with progress bars.
    print("正在复制训练集文件...")
    for file in tqdm(train_files, desc="训练集进度"):
        shutil.copy(file, train_dir / file.name)

    print("正在复制验证集文件...")
    for file in tqdm(eval_files, desc="验证集进度"):
        shutil.copy(file, eval_dir / file.name)

    print(f"训练集文件数: {len(train_files)}")
    print(f"验证集文件数: {len(eval_files)}")
|
||||
...
|
||||
Reference in New Issue
Block a user