Files
labweb/models/main.py
2025-12-16 11:39:15 +08:00

109 lines
4.6 KiB
Python

import argparse
import torch
import os
from model import Model, LabelSmoothing, SimpleLossCompute
from trainer import train_epoch, valid_epoch
from datasets import build_dataloaders, build_lazy_dataloaders, DataProcessor
from torch.cuda.amp import GradScaler
class EarlyStopping:
    """Stop training once validation loss stops improving.

    Each call records one epoch's validation loss. When the loss reaches a
    new best value (by more than ``delta``), the model's ``state_dict`` is
    checkpointed to ``save_path`` and the patience counter resets; after
    ``patience`` consecutive non-improving epochs, ``early_stop`` is set.
    """

    def __init__(self, patience=5, delta=0.0, save_path="best_model.pt"):
        # Number of consecutive non-improving epochs tolerated before stopping.
        self.patience = patience
        # Minimum decrease in loss that counts as an improvement.
        self.delta = delta
        # Destination file for the best-so-far model weights.
        self.save_path = save_path
        self.counter = 0
        self.best_loss = None
        self.early_stop = False

    def __call__(self, val_loss, model):
        """Record one epoch's validation loss; checkpoint on improvement."""
        improved = self.best_loss is None or val_loss < self.best_loss - self.delta
        if improved:
            self.best_loss = val_loss
            self.counter = 0
            torch.save(model.state_dict(), self.save_path)
            print(f"? Validation loss improved to {val_loss:.4f}. Model saved to {self.save_path}")
            return
        self.counter += 1
        print(f"?? No improvement. EarlyStopping counter: {self.counter}/{self.patience}")
        if self.counter >= self.patience:
            self.early_stop = True
def main():
    """Train the encoder-decoder model with early stopping and LR scheduling.

    Parses hyperparameters from the command line, builds lazy data loaders
    from the (hard-coded) dataset paths, then runs the train/validate loop
    with ReduceLROnPlateau scheduling and early stopping. The best model
    (lowest validation loss) is checkpointed to ``--save_path``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--device", type=str, default="cuda:0")
    parser.add_argument("--d_model", type=int, default=64)
    parser.add_argument("--heads", type=int, default=8)
    parser.add_argument("--enc_layers", type=int, default=3)
    parser.add_argument("--dec_layers", type=int, default=3)
    parser.add_argument("--dropout", type=float, default=0.1)
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--lr", type=float, default=1e-5)
    parser.add_argument("--epochs", type=int, default=200)
    parser.add_argument("--es_patience", type=int, default=10)
    parser.add_argument("--lr_patience", type=int, default=3)
    parser.add_argument("--factor", type=float, default=0.8)
    parser.add_argument("--save_path", type=str, default="best_model.pt")
    # Fix: mixed precision was a hard-coded local flag; expose it as a CLI
    # option instead. Default False preserves the original behavior.
    parser.add_argument("--use_amp", action="store_true",
                        help="Enable mixed precision training (default off to "
                             "avoid CUDA compatibility issues)")
    args = parser.parse_args()

    device = torch.device(args.device)
    print(f"Using device: {device}")

    # Create DataProcessor instance.
    # NOTE(review): dataset locations are hard-coded absolute paths — consider
    # promoting these to CLI arguments as well.
    dp = DataProcessor(
        ko_count_path="/home/zzhang/gzy/Uncultured/generative_ML/data/gene_ko_protein/KO_count.csv",
        data_path="/home/zzhang/gzy/Uncultured/generative_ML/encoder_decoder/protT5_embedding/new_ko_all.csv",
        embedding_h5_path="/home/zzhang/gzy/Uncultured/ko_pre_train/ProtT5_Encoder/test_genome_ProtT5_embeddings.h5"
    )
    # Only load metadata, not all embedding data.
    dp.load_data()

    # Use lazy loading to build data loaders (embeddings fetched on demand).
    train_loader, train_loader_no_shuffle, valid_loader, test_loader = build_lazy_dataloaders(
        dp, batch_size=args.batch_size, split_ratio=8, shuffle=True
    )
    # Persist the loaders so downstream evaluation scripts reuse the same split.
    torch.save(train_loader, 'train_loader.pth')
    torch.save(train_loader_no_shuffle, 'train_loader_no_shuffle.pth')
    torch.save(test_loader, 'test_loader.pth')
    torch.save(valid_loader, 'valid_loader.pth')

    model = Model(args.heads, args.d_model, len(dp.vo_cab), args.enc_layers, args.dec_layers, args.dropout, dp.vo_cab).to(device)
    # smoothing=0.0 means LabelSmoothing degenerates to plain cross-entropy.
    criterion = LabelSmoothing(tgt_size=len(dp.vo_cab), padding_idx=dp.vo_cab["blank"], smoothing=0.0).to(device)
    loss_compute = SimpleLossCompute(criterion, device)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=args.lr_patience, factor=args.factor)
    scaler = GradScaler()

    # Early stopping on validation loss; checkpoints the best model.
    early_stopping = EarlyStopping(patience=args.es_patience, save_path=args.save_path)

    # Mixed precision training option (default off to avoid CUDA compatibility issues).
    use_mixed_precision = args.use_amp
    print(f"Using mixed precision training: {use_mixed_precision}")

    for epoch in range(1, args.epochs + 1):
        train_loss, t_time = train_epoch(model, train_loader, device, loss_compute, scaler, optimizer, use_amp=use_mixed_precision)
        valid_loss, v_time = valid_epoch(model, valid_loader, device, loss_compute, use_amp=use_mixed_precision)
        # Scheduler reduces LR when validation loss plateaus.
        scheduler.step(valid_loss)
        current_lr = optimizer.param_groups[0]['lr']
        print(f"[Epoch {epoch}] train_loss={train_loss:.4f} ({t_time:.1f}s), "
              f"valid_loss={valid_loss:.4f} ({v_time:.1f}s), lr={current_lr:.6f}")
        # Early stopping check (also saves the best checkpoint).
        early_stopping(valid_loss, model)
        if early_stopping.early_stop:
            print("?? Early stopping triggered.")
            break

    print(f"Training finished. Best model is saved at {args.save_path}")
# Script entry point: run training only when executed directly, not on import.
if __name__ == "__main__":
    main()