# -*- coding: utf-8 -*-
# file: trainer.py
# time: 14:40 06/04/2024
# author: YANG, HENG <hy345@exeter.ac.uk> (杨恒)
# github: https://github.com/yangheng95
# huggingface: https://huggingface.co/yangheng
# google scholar: https://scholar.google.com/citations?user=NPq5a_0AAAAJ&hl=en
# Copyright (C) 2019-2025. All Rights Reserved.
"""
Native training utilities.
This module provides a native PyTorch training framework for genomic models,
including automatic mixed precision training, early stopping, metric tracking,
and model checkpointing.
"""
import contextlib
import os
import tempfile
from typing import Dict, Any, Optional, Union

import autocuda
import numpy as np
import torch
from torch.cuda.amp import GradScaler
from tqdm import tqdm

from .base_trainer import BaseTrainer
from ..misc.utils import fprint
class Trainer(BaseTrainer):
    """
    Native PyTorch trainer for genomic models.

    This trainer provides a complete training framework with automatic mixed
    precision, early stopping, metric tracking, and model checkpointing using
    native PyTorch without distributed training dependencies.

    Attributes:
        device: Device to run training on (CPU or GPU)
        fast_dtype: Data type for mixed precision training
        scaler: Gradient scaler for mixed precision training

    Example:
        >>> trainer = Trainer(
        ...     model=model,
        ...     train_dataset=train_dataset,
        ...     eval_dataset=eval_dataset,
        ...     epochs=10,
        ...     batch_size=32,
        ...     optimizer=optimizer
        ... )
        >>> metrics = trainer.train()
    """

    # Accepted spellings of the ``autocast`` setting and the dtype each maps to.
    _AUTOCAST_DTYPES = {
        "float32": torch.float32,
        "fp32": torch.float32,
        "float16": torch.float16,
        "fp16": torch.float16,
        "bfloat16": torch.bfloat16,
        "bf16": torch.bfloat16,
    }

    def __init__(
        self,
        model: torch.nn.Module,
        device: Optional[Union[torch.device, str]] = None,
        **kwargs,
    ):
        """
        Initialize the native trainer.

        Args:
            model (torch.nn.Module): The model to be trained
            device (Optional[Union[torch.device, str]]): Device to run training
                on. When falsy, the device is chosen by ``autocuda.auto_cuda()``.
            **kwargs: Additional keyword arguments passed to BaseTrainer
        """
        # The device must be set before the parent constructor runs, because
        # _setup_training_components() reads self.device.
        # NOTE(review): presumably BaseTrainer.__init__ calls
        # _setup_training_components() - confirm against base_trainer.py.
        resolved = device if device else autocuda.auto_cuda()
        self.device = torch.device(resolved) if isinstance(resolved, str) else resolved
        super().__init__(model, **kwargs)

    def _setup_training_components(self) -> None:
        """
        Set up native training-specific components.

        Resolves the mixed precision dtype from ``self.autocast`` (unknown
        values fall back to float16), creates the gradient scaler, and moves
        the model to ``self.device``.
        """
        self.fast_dtype = self._AUTOCAST_DTYPES.get(self.autocast, torch.float16)
        # Loss scaling is only needed for float16 on CUDA. A disabled scaler
        # turns scale()/unscale_()/update() into no-ops and makes step() call
        # optimizer.step() directly, so the training loop can use it
        # unconditionally.
        self.scaler = GradScaler(
            enabled=self.fast_dtype == torch.float16 and self.device.type == "cuda"
        )
        self.model.to(self.device)

    def _autocast_context(self):
        """
        Return the context manager to wrap forward passes in.

        Returns:
            A ``torch.autocast`` context for reduced-precision dtypes, or a
            no-op context when training in float32.
        """
        if self.fast_dtype and self.fast_dtype != torch.float32:
            return torch.autocast(device_type=self.device.type, dtype=self.fast_dtype)
        return contextlib.nullcontext()

    def _prepare_batch(self, batch: Any) -> Any:
        """
        Prepare a batch for model input by moving it to the training device.

        Args:
            batch: Input batch. Objects exposing ``.to`` are moved directly;
                plain dicts are moved value by value; anything else is
                returned unchanged.

        Returns:
            Batch moved to the appropriate device
        """
        if hasattr(batch, "to"):
            return batch.to(self.device)
        if isinstance(batch, dict):
            return {key: self._prepare_batch(value) for key, value in batch.items()}
        return batch

    def _predict_batch(self, batch: Any) -> Dict[str, torch.Tensor]:
        """
        Generate predictions for a batch using the model.

        Args:
            batch: Input batch

        Returns:
            Dict[str, torch.Tensor]: Dictionary containing predictions
        """
        with self._autocast_context():
            return self.model.predict(batch)

    def _train_epoch(self, epoch: int) -> float:
        """
        Train the model for one epoch using native PyTorch.

        Args:
            epoch (int): Current epoch number

        Returns:
            float: Average training loss for the epoch (NaN if the training
            loader yields no batches)
        """
        self.model.train()
        train_loss = []
        # Accumulate the unscaled loss so the reported value is not affected
        # by the gradient-accumulation division below.
        loss_accumulator = 0.0
        steps_since_update = 0
        # Guard against a zero/None setting, which would break the modulo tests.
        accum_steps = max(1, int(self.gradient_accumulation_steps or 1))
        num_batches = len(self.train_loader)
        max_grad_norm = getattr(self, "max_grad_norm", None)

        train_it = tqdm(self.train_loader, desc=f"Epoch {epoch + 1}/{self.epochs} Loss")
        # Start the epoch with clean gradients.
        self.optimizer.zero_grad()
        for step, batch in enumerate(train_it):
            batch = self._prepare_batch(batch)

            # Forward pass with optional mixed precision
            with self._autocast_context():
                outputs = self.model(**batch)
            loss = self._compute_loss(outputs)

            # Record the raw loss for display before it is scaled.
            loss_accumulator += loss.item()
            steps_since_update += 1

            # Scale loss for gradient accumulation
            if accum_steps > 1:
                loss = loss / accum_steps

            # scale() returns the loss unchanged when the scaler is disabled.
            self.scaler.scale(loss).backward()

            is_update_step = (step + 1) % accum_steps == 0 or (step + 1) == num_batches
            if is_update_step:
                if max_grad_norm is not None:
                    # Gradients must be unscaled before clipping by norm.
                    self.scaler.unscale_(self.optimizer)
                    torch.nn.utils.clip_grad_norm_(
                        self.model.parameters(), max_grad_norm
                    )
                self.scaler.step(self.optimizer)
                self.scaler.update()

                # After the parameter update, log the mean loss of this
                # accumulation window and reset the accumulators.
                train_loss.append(loss_accumulator / steps_since_update)
                train_it.set_description(
                    f"Epoch {epoch + 1}/{self.epochs} Loss: {np.nanmean(train_loss):.4f}"
                )
                loss_accumulator = 0.0
                steps_since_update = 0
                # Clear gradients ready for the next accumulation window.
                self.optimizer.zero_grad()
            elif (step + 1) % max(1, accum_steps // 5) == 0:
                # Refresh the progress bar periodically while accumulating,
                # without updating parameters.
                running = np.nanmean(train_loss) if train_loss else 0
                train_it.set_description(
                    f"Epoch {epoch + 1}/{self.epochs} Loss: {running:.4f} (Accumulating {steps_since_update}/{accum_steps})"
                )

        if not train_loss:
            # Empty loader: avoid numpy's "mean of empty slice" warning.
            return float("nan")
        return float(np.nanmean(train_loss))

    def _compute_loss(self, outputs: Dict[str, torch.Tensor]) -> torch.Tensor:
        """
        Compute loss from model outputs.

        Uses ``outputs["loss"]`` when present; otherwise falls back to
        ``self.model.loss_function`` and then ``self.model.model.loss_function``,
        called with ``outputs["logits"]`` and ``outputs["labels"]``.

        Args:
            outputs (Dict[str, torch.Tensor]): Model outputs

        Returns:
            torch.Tensor: Computed loss

        Raises:
            ValueError: If no loss function is available
        """
        if "loss" in outputs and outputs["loss"] is not None:
            return outputs["loss"]

        # Try the model's own loss function, then that of a wrapped inner model.
        for owner in (self.model, getattr(self.model, "model", None)):
            loss_fn = getattr(owner, "loss_function", None)
            if callable(loss_fn):
                return loss_fn(outputs["logits"], outputs["labels"])

        raise ValueError(
            "The model does not have a loss function defined. "
            "Please provide a loss function or ensure the model has one."
        )

    def _evaluate_split(
        self, loader: Any, desc: str, split: str, no_label_message: str
    ) -> Dict[str, Any]:
        """
        Run inference over a data loader and compute metrics for one split.

        Args:
            loader: Iterable of batches; each batch must contain ``"labels"``
            desc (str): Progress bar description
            split (str): Key under which predictions are stored in
                ``self.predictions``
            no_label_message (str): Key of the placeholder metric returned when
                every label equals -100

        Returns:
            Dict[str, Any]: Merged output of all ``self.compute_metrics``
            functions, or the placeholder metric
        """
        truth = []
        preds = []
        self.model.eval()
        with torch.no_grad():
            for batch in tqdm(loader, desc=desc):
                batch = self._prepare_batch(batch)
                # Labels are removed so they are not passed to predict().
                labels = batch.pop("labels")
                # _predict_batch applies autocast itself when enabled.
                predictions = self._predict_batch(batch)["predictions"]
                truth.append(labels.float().cpu().numpy())
                preds.append(predictions.float().cpu().numpy())
            truth = self._concatenate_outputs(truth)
            preds = self._concatenate_outputs(preds)

        if not np.all(truth == -100):
            metrics = {}
            for metric_func in self.compute_metrics:
                metrics.update(metric_func(truth, preds))
            fprint(metrics)
        else:
            # All labels are the -100 sentinel, so no metric is computed.
            metrics = {no_label_message: 0}

        self.predictions.update({split: {"pred": preds, "true": truth}})
        return metrics

    def evaluate(self) -> Dict[str, Any]:
        """
        Evaluate the model on the validation dataset.

        Returns:
            Dict[str, Any]: Dictionary containing evaluation metrics
        """
        return self._evaluate_split(
            self.eval_loader,
            "Evaluating",
            "valid",
            "Validation set labels may be NaN. No metrics calculated.",
        )

    def test(self) -> Dict[str, Any]:
        """
        Test the model on the test dataset.

        Returns:
            Dict[str, Any]: Dictionary containing test metrics
        """
        return self._evaluate_split(
            self.test_loader,
            "Testing",
            "test",
            "Test set labels may be NaN. No metrics calculated.",
        )

    def unwrap_model(self, model: Optional[torch.nn.Module] = None) -> torch.nn.Module:
        """
        Unwrap the model from any distributed training wrappers.

        Args:
            model (Optional[torch.nn.Module]): Model to unwrap (default: None,
                uses self.model)

        Returns:
            torch.nn.Module: ``model.module`` if that attribute exists (e.g.
            DataParallel), otherwise the model itself
        """
        if model is None:
            model = self.model
        # For the native trainer no unwrapping is typically needed.
        return getattr(model, "module", model)

    def _load_state_dict(self) -> None:
        """
        Load model state dictionary from temporary file.

        Does nothing when no checkpoint path is set or the file does not exist.
        """
        path = getattr(self, "_model_state_dict_path", None)
        if path and os.path.exists(path):
            # Load onto CPU first, then move the model back to the device.
            self.unwrap_model().load_state_dict(torch.load(path, map_location="cpu"))
            self.unwrap_model().to(self.device)

    def _save_state_dict(self) -> None:
        """
        Save model state dictionary to temporary file.
        """
        # Best-effort removal of a stale checkpoint; failure is only reported.
        try:
            if os.path.exists(self._model_state_dict_path):
                os.remove(self._model_state_dict_path)
        except OSError as e:
            fprint(
                f"Failed to remove the temporary checkpoint file {self._model_state_dict_path}: {e}"
            )
        torch.save(self.unwrap_model().state_dict(), self._model_state_dict_path)

    def save_model(self, path_to_save: str, overwrite: bool = False, **kwargs) -> None:
        """
        Save the trained model.

        Delegates to the unwrapped model's ``save`` method.

        Args:
            path_to_save (str): Path to save the model
            overwrite (bool): Whether to overwrite existing files (default: False)
            **kwargs: Additional keyword arguments
        """
        self.unwrap_model().save(path_to_save, overwrite, **kwargs)