Source code for omnigenbench.src.trainer.base_trainer

# -*- coding: utf-8 -*-
# file: base_trainer.py
# time: 15:00 15/07/2025
# author: YANG, HENG <hy345@exeter.ac.uk> (杨恒)
# github: https://github.com/yangheng95
# huggingface: https://huggingface.co/yangheng
# google scholar: https://scholar.google.com/citations?user=NPq5a_0AAAAJ&hl=en
# Copyright (C) 2019-2025. All Rights Reserved.
"""
Base trainer class.

This module provides the abstract base class for all trainers in the OmniGenome
framework. It defines the common interface and shared functionality that all
trainer implementations should provide.
"""

import os
import warnings
import tempfile
import numpy as np
from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Union, Any
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
import torch

from ..misc.utils import env_meta_info, fprint, seed_everything


class MetricsDict(dict):
    """
    A custom dictionary class for storing training metrics with enhanced readability.

    This class extends the built-in dict to provide a more readable string
    representation of metrics, making it easier to understand training progress
    and results. Only the textual representation differs from ``dict``; storage
    and lookup behave exactly like a regular dictionary.

    The representation looks at these keys when they are present:
        - ``"best_valid"``: a dict of metric name -> value.
        - ``"train"``, ``"valid"``, ``"test"``: lists with one metrics dict per
          evaluation; the latest entry (and the first, if there are several)
          is displayed.
    """

    @staticmethod
    def _format_metric(key, value, width, with_std=True, keep_raw=True):
        """
        Render a single metric as one line of the summary.

        Args:
            key: Metric name, left-aligned and dot-padded to ``width``.
            value: Metric value. Scalars (Python or NumPy numbers) are printed
                with six decimals; non-empty sequences of numbers are reduced
                to their mean (and standard deviation if ``with_std``).
            width (int): Width of the dot-padded name column.
            with_std (bool): Append ``± std`` for numeric sequences.
            keep_raw (bool): Print values of any other kind verbatim. When
                False such values are skipped.

        Returns:
            Optional[str]: The formatted line, or None if the value is skipped.
        """
        # NumPy scalars are accepted explicitly: np.float32 / np.int64 are not
        # subclasses of float / int, yet metric functions commonly return them.
        numeric_types = (int, float, np.integer, np.floating)
        label = f" {key:.<{width}}"

        if isinstance(value, numeric_types):
            return f"{label} {value:.6f}"

        if (
            isinstance(value, (list, tuple, np.ndarray))
            and len(value) > 0
            and isinstance(value[0], numeric_types)
        ):
            mean_val = np.mean(value)
            if with_std:
                std_val = np.std(value)
                return f"{label} {mean_val:.6f} ± {std_val:.6f}"
            return f"{label} {mean_val:.6f}"

        return f"{label} {value}" if keep_raw else None

    def __repr__(self) -> str:
        """
        Return a formatted, human-readable string representation of the metrics.

        Returns:
            str: A formatted string showing metrics organized by stage
                (train/valid/test) and best validation metrics, or
                ``"MetricsDict(empty)"`` when nothing has been recorded.
        """
        if not self:
            return "MetricsDict(empty)"

        lines = ["=" * 80]
        lines.append("Training Metrics Summary".center(80))
        lines.append("=" * 80)

        # Display best validation metrics first if available
        if "best_valid" in self:
            lines.append("\n[Best Validation Metrics]")
            lines.append("-" * 80)
            best_metrics = self["best_valid"]
            if isinstance(best_metrics, dict):
                for key, value in best_metrics.items():
                    lines.append(self._format_metric(key, value, 40))

        # Display metrics for each stage (train, valid, test)
        for stage in ["train", "valid", "test"]:
            if stage not in self or not self[stage]:
                continue

            lines.append(f"\n[{stage.capitalize()} Metrics History]")
            lines.append("-" * 80)
            metrics_list = self[stage]
            if not isinstance(metrics_list, list) or len(metrics_list) == 0:
                continue

            # Show the number of epochs/evaluations
            lines.append(f" Total evaluations: {len(metrics_list)}")

            # Show the latest metrics
            latest_metrics = metrics_list[-1]
            if isinstance(latest_metrics, dict):
                lines.append(f" Latest (Epoch {len(metrics_list)}):")
                for key, value in latest_metrics.items():
                    lines.append(self._format_metric(key, value, 38))

            # Show the first entry too so the trend is visible. Only numeric
            # values are listed here, and sequences are reduced to their mean.
            if len(metrics_list) > 1:
                lines.append(" First (Epoch 1):")
                first_metrics = metrics_list[0]
                if isinstance(first_metrics, dict):
                    for key, value in first_metrics.items():
                        line = self._format_metric(
                            key, value, 38, with_std=False, keep_raw=False
                        )
                        if line is not None:
                            lines.append(line)

        lines.append("=" * 80)
        return "\n".join(lines)

    def __str__(self) -> str:
        """Return the same formatted representation as __repr__."""
        return self.__repr__()
def _first_metric_mean(metric_values: Dict[str, Any]) -> float:
    """Return the mean of the first value in ``metric_values`` (scalars count as one-element lists)."""
    first_value = next(iter(metric_values.values()))
    if isinstance(first_value, (list, tuple, np.ndarray)):
        return np.mean(first_value)
    return np.mean([first_value])


def _infer_optimization_direction(
    metrics: Dict[str, Any], prev_metrics: List[Dict[str, Any]]
) -> str:
    """
    Infer the optimization direction based on metric values.

    The direction is decided in three steps:
        1. Metric names are matched (case-insensitively, as substrings) against
           known larger-is-better and smaller-is-better patterns. Names are
           visited in dictionary order and the first name that matches any
           pattern decides.
        2. If no name matches and at least two previous evaluations are
           available, the first metric's value is compared across the earliest
           previous, the latest previous and the current evaluation. A strictly
           monotonic trend decides the direction.
        3. Otherwise the function falls back to 'smaller_is_better'.

    Args:
        metrics (Dict[str, Any]): Current metric values
        prev_metrics (List[Dict[str, Any]]): Previous metric values from
            multiple epochs

    Returns:
        str: Either 'larger_is_better' or 'smaller_is_better'
    """
    # Check if metrics is empty
    if not metrics:
        fprint(
            "Warning: Cannot infer optimization direction from empty metrics. Defaulting to 'smaller_is_better'."
        )
        return "smaller_is_better"

    larger_is_better_metrics = [
        "accuracy",
        "f1",
        "recall",
        "precision",
        "roc_auc",
        "pr_auc",
        "score",
        "auc",
        "balanced_accuracy",
        "matthews_corrcoef",
        "jaccard",
        "dice",
        # R^2 is a goodness-of-fit score: higher means a better fit.
        "r2",
    ]
    smaller_is_better_metrics = [
        "loss",
        "error",
        "mse",
        "mae",
        "rmse",
        "distance",
        "perplexity",
        "cross_entropy",
        "binary_cross_entropy",
        "focal_loss",
        "huber_loss",
    ]

    # Check if any metric name matches known patterns
    for metric_name in metrics:
        metric_name_lower = metric_name.lower()
        if any(pattern in metric_name_lower for pattern in larger_is_better_metrics):
            return "larger_is_better"
        if any(pattern in metric_name_lower for pattern in smaller_is_better_metrics):
            return "smaller_is_better"

    # If no pattern matches, try to infer from metric trends
    if prev_metrics and len(prev_metrics) >= 2:
        fprint(
            "Cannot determine optimization direction from metric names. Attempting inference from trends."
        )
        try:
            # Check that every snapshot involved actually holds a value
            if not prev_metrics[-1].values() or not prev_metrics[0].values():
                fprint("Warning: Empty metric values found. Cannot infer from trends.")
                return "smaller_is_better"

            # Use the first metric of each snapshot for trend analysis
            current_value = _first_metric_mean(metrics)
            prev_value = _first_metric_mean(prev_metrics[-1])
            earlier_value = _first_metric_mean(prev_metrics[0])

            # Check if metrics are consistently increasing or decreasing
            if earlier_value < prev_value < current_value:
                return "larger_is_better"
            if earlier_value > prev_value > current_value:
                return "smaller_is_better"
        except (IndexError, KeyError, TypeError) as e:
            fprint(f"Error inferring optimization direction: {e}")

    # Default to smaller_is_better (common for loss-based metrics)
    fprint(
        "Cannot determine optimization direction. Defaulting to 'smaller_is_better'."
    )
    return "smaller_is_better"
class BaseTrainer(ABC):
    """
    Abstract base class for all trainers in the OmniGenome framework.

    This class defines the common interface and shared functionality that all
    trainer implementations should provide. It includes methods for training,
    evaluation, testing, and model management.

    Subclasses must implement ``_setup_training_components``, ``_train_epoch``,
    ``_prepare_batch`` and ``_predict_batch``.

    Attributes:
        model: The model to be trained
        train_loader: DataLoader for training data
        eval_loader: DataLoader for validation data
        test_loader: DataLoader for test data
        epochs: Number of training epochs
        batch_size: Batch size for training
        patience: Early stopping patience
        max_grad_norm: Gradient-norm value stored for subclasses to use
        gradient_accumulation_steps: Number of steps for gradient accumulation
        optimizer: Optimizer for training
        compute_metrics: List of metric computation functions
        seed: Random seed for reproducibility
        autocast: Mixed precision type requested by the caller
        metrics: Dictionary to store training metrics
        predictions: Dictionary to store model predictions
        metadata: Dictionary containing environment and training metadata
        trial_name: Name of the current training trial
        _optimization_direction: Optimization direction ('larger_is_better' or
            'smaller_is_better')

    Example:
        >>> class MyTrainer(BaseTrainer):
        ...     def _setup_training_components(self):
        ...         # Implementation specific setup
        ...         pass
        ...     def _train_epoch(self, epoch):
        ...         # Implementation specific training loop
        ...         pass
        >>> trainer = MyTrainer(model=model, train_dataset=train_dataset)
        >>> metrics = trainer.train()
    """

    def __init__(
        self,
        model: torch.nn.Module,
        train_dataset: Optional[Dataset] = None,
        eval_dataset: Optional[Dataset] = None,
        test_dataset: Optional[Dataset] = None,
        epochs: int = 3,
        learning_rate: float = 2e-5,
        batch_size: int = 8,
        patience: int = -1,
        max_grad_norm: float = 1.0,
        gradient_accumulation_steps: int = 1,
        optimizer: Optional[torch.optim.Optimizer] = None,
        loss_fn: Optional[torch.nn.Module] = None,
        compute_metrics: Optional[Union[List, str]] = None,
        seed: int = 42,
        autocast: str = "float16",
        **kwargs,
    ):
        """
        Initialize the base trainer.

        Args:
            model (torch.nn.Module): The model to be trained
            train_dataset (Optional[Dataset]): Training dataset
            eval_dataset (Optional[Dataset]): Validation dataset
            test_dataset (Optional[Dataset]): Test dataset
            epochs (int): Number of training epochs (default: 3)
            learning_rate (float): Learning rate of the default Adam optimizer;
                only used when ``optimizer`` is not given (default: 2e-5)
            batch_size (int): Batch size for training (default: 8)
            patience (int): Early stopping patience. Values <= 0 disable early
                stopping by setting the patience to ``epochs`` (default: -1)
            max_grad_norm (float): Stored on the trainer for subclasses
                (default: 1.0)
            gradient_accumulation_steps (int): Gradient accumulation steps (default: 1)
            optimizer (Optional[torch.optim.Optimizer]): Optimizer for training
            loss_fn (Optional[torch.nn.Module]): Loss function; forwarded to
                ``model.set_loss_fn`` when the model provides that method
            compute_metrics (Optional[Union[List, str]]): Metric computation
                functions; a single callable is wrapped into a list
            seed (int): Random seed for reproducibility (default: 42)
            autocast (str): Mixed precision type (default: "float16")
            **kwargs: Additional keyword arguments. Recognized keys:
                ``train_loader``, ``eval_loader`` / ``valid_loader``,
                ``test_loader`` (pre-built loaders) and ``trial_name``.

        Raises:
            ValueError: If neither ``train_dataset`` nor a ``train_loader`` is given.
        """
        self.model = model
        self.epochs = epochs
        self.batch_size = batch_size
        self.patience = patience if patience > 0 else epochs
        self.max_grad_norm = max_grad_norm
        self.gradient_accumulation_steps = gradient_accumulation_steps

        self.optimizer = optimizer
        if not optimizer:
            warnings.warn(
                f"No optimizer provided. Defaulting to Adam optimizer with {learning_rate}."
            )
            self.optimizer = torch.optim.Adam(self.model.parameters(), lr=learning_rate)

        # Set loss function if provided
        if loss_fn is not None:
            if hasattr(self.model, "set_loss_fn"):
                self.model.set_loss_fn(loss_fn)
            else:
                # Previously the loss function was dropped without notice.
                warnings.warn(
                    "loss_fn was provided but the model has no 'set_loss_fn' method; "
                    "the loss function is ignored."
                )

        if not compute_metrics:
            self.compute_metrics = []
        elif isinstance(compute_metrics, (list, tuple)):
            self.compute_metrics = list(compute_metrics)
        else:
            self.compute_metrics = [compute_metrics]
        if not self.compute_metrics:
            warnings.warn(
                "No compute metrics provided. Metrics will not be calculated during training."
            )

        self.seed = seed
        seed_everything(seed)
        self.autocast = autocast

        # Initialize data loaders
        self._setup_data_loaders(
            train_dataset, eval_dataset, test_dataset, batch_size, **kwargs
        )

        # Initialize training components
        self._setup_training_components()

        # Initialize metadata and tracking
        self.metadata = env_meta_info()
        self.metrics = MetricsDict()
        self.predictions = {}
        self._optimization_direction = None
        self.trial_name = kwargs.get("trial_name", self.model.__class__.__name__)

        # A subclass may already have chosen a checkpoint location in
        # _setup_training_components; only create one if it did not.
        if not hasattr(self, "_model_state_dict_path"):
            temp_dir = tempfile.gettempdir()
            os.makedirs(temp_dir, exist_ok=True)
            # mkstemp reserves a unique (empty) file; close the descriptor
            # immediately because torch.save reopens the path itself.
            fd, temp_path = tempfile.mkstemp(suffix=".pt", dir=temp_dir)
            os.close(fd)
            self._model_state_dict_path = temp_path

    def _setup_data_loaders(
        self,
        train_dataset: Optional[Dataset],
        eval_dataset: Optional[Dataset],
        test_dataset: Optional[Dataset],
        batch_size: int,
        **kwargs,
    ) -> None:
        """
        Set up data loaders for training, evaluation, and testing.

        Pre-built loaders passed through ``kwargs`` (``train_loader``,
        ``eval_loader`` / ``valid_loader``, ``test_loader``) take precedence
        over the datasets. They are compared against ``None`` rather than
        tested for truthiness: ``bool(loader)`` calls ``len(loader)``, which is
        0 for an empty loader and raises ``TypeError`` for loaders over unsized
        iterable datasets.

        Args:
            train_dataset (Optional[Dataset]): Training dataset
            eval_dataset (Optional[Dataset]): Validation dataset
            test_dataset (Optional[Dataset]): Test dataset
            batch_size (int): Batch size for data loaders
            **kwargs: Additional keyword arguments

        Raises:
            ValueError: If neither ``train_dataset`` nor ``train_loader`` is given.
        """
        train_loader = kwargs.get("train_loader")
        if train_loader is None:
            if train_dataset is None:
                raise ValueError(
                    "train_dataset must be provided if train_loader is not."
                )
            train_loader = DataLoader(
                train_dataset, batch_size=batch_size, shuffle=True
            )
        self.train_loader = train_loader

        eval_loader = kwargs.get("eval_loader")
        if eval_loader is None:
            eval_loader = kwargs.get("valid_loader")
        if eval_loader is None and eval_dataset is not None:
            eval_loader = DataLoader(eval_dataset, batch_size=batch_size, shuffle=False)
        self.eval_loader = eval_loader

        test_loader = kwargs.get("test_loader")
        if test_loader is None and test_dataset is not None:
            test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
        self.test_loader = test_loader

    @abstractmethod
    def _setup_training_components(self) -> None:
        """
        Set up training-specific components (device, scaler, etc.).

        This method should be implemented by subclasses to initialize
        trainer-specific components like device selection, gradient scalers,
        distributed training setup, etc. It is called from ``__init__`` after
        the data loaders exist and before ``metrics`` / ``metadata`` are created.
        """
        pass

    @abstractmethod
    def _train_epoch(self, epoch: int) -> float:
        """
        Train the model for one epoch.

        Args:
            epoch (int): Current epoch number

        Returns:
            float: Average training loss for the epoch
        """
        pass

    @staticmethod
    def _has_batches(loader: Any) -> bool:
        """Return True if ``loader`` is set and is not known to be empty."""
        if loader is None:
            return False
        try:
            return len(loader) > 0
        except TypeError:
            # Loaders over unsized iterable datasets do not define a length;
            # treat them as usable instead of failing.
            return True

    @staticmethod
    def _primary_metric_value(metrics: Dict[str, Any]) -> float:
        """Return the mean of the first metric value (scalars count as one-element lists)."""
        first_value = next(iter(metrics.values()))
        if isinstance(first_value, (list, tuple, np.ndarray)):
            return np.mean(first_value)
        return np.mean([first_value])

    def _is_metric_better(self, metrics: Dict[str, Any], stage: str = "valid") -> bool:
        """
        Check if the current metrics are better than the best metrics so far.

        The metrics are always appended to ``self.metrics[stage]``. The
        comparison uses the first metric in the dictionary and the cached
        optimization direction (inferred on first use). The reference is
        ``self.metrics["best_valid"]``, which is only ever written for
        ``stage == "valid"`` so that test-set results cannot replace the best
        validation metrics.

        Args:
            metrics (Dict[str, Any]): Current metric values
            stage (str): Stage name ("valid" or "test")

        Returns:
            bool: True if current metrics are better than best metrics

        Raises:
            AssertionError: If stage is not "valid" or "test"
        """
        if stage not in ("valid", "test"):
            # Raised explicitly so the check also holds under ``python -O``.
            raise AssertionError(
                "The metrics stage should be either 'valid' or 'test'."
            )

        # Snapshot the history *before* recording the current metrics so that
        # "previous" never contains the entry being evaluated.
        history = self.metrics.setdefault(stage, [])
        prev_metrics = list(history)
        history.append(metrics)

        # Empty metrics are still tracked above, but cannot be compared.
        if not metrics:
            fprint("Warning: Empty metrics dictionary received. Skipping comparison.")
            return False

        # Initialize best metrics if not present (only with non-empty metrics)
        if "best_valid" not in self.metrics:
            if stage == "valid":
                self.metrics["best_valid"] = metrics
            return True

        # With no history for this stage, fall back to the best metrics.
        if not prev_metrics:
            prev_metrics = [self.metrics.get("best_valid", {})]

        # Determine optimization direction
        if self._optimization_direction is None:
            self._optimization_direction = _infer_optimization_direction(
                metrics, prev_metrics
            )

        # Compare metrics based on optimization direction
        try:
            best_metrics = self.metrics["best_valid"]
            if not best_metrics:
                fprint(
                    "Warning: Best metrics dictionary has no values. Skipping comparison."
                )
                return False

            current_value = self._primary_metric_value(metrics)
            best_value = self._primary_metric_value(best_metrics)
        except (IndexError, KeyError, TypeError) as e:
            fprint(f"Error comparing metrics: {e}")
            return False

        if self._optimization_direction == "larger_is_better":
            is_better = bool(current_value > best_value)
        else:  # smaller_is_better
            is_better = bool(current_value < best_value)

        if is_better and stage == "valid":
            self.metrics["best_valid"] = metrics
        return is_better

    def _run_validation(self) -> Dict[str, Any]:
        """Evaluate on the validation loader, falling back to the test loader if there is none."""
        if self._has_batches(self.eval_loader):
            return self.evaluate()
        return self.test()

    def train(self, path_to_save: Optional[str] = None, **kwargs) -> Dict[str, Any]:
        """
        Train the model.

        This method implements the main training loop including early stopping,
        model checkpointing, and metric tracking. The model is evaluated once
        before training and after every epoch; whenever the metrics improve the
        weights are written to a temporary checkpoint, which is reloaded before
        the final test run.

        Args:
            path_to_save (Optional[str]): Path to save the trained model
            **kwargs: Additional keyword arguments forwarded to ``save_model``

        Returns:
            Dict[str, Any]: Training metrics and results
        """
        seed_everything(self.seed)
        patience_counter = 0

        try:
            # Initial evaluation
            initial_metrics = self._run_validation()
            if self._is_metric_better(initial_metrics, stage="valid"):
                self._save_state_dict()

            # Main training loop
            for epoch in range(self.epochs):
                self._train_epoch(epoch)

                # Evaluate after each epoch
                valid_metrics = self._run_validation()

                # Check for improvement and early stopping
                if self._is_metric_better(valid_metrics, stage="valid"):
                    self._save_state_dict()
                    patience_counter = 0
                else:
                    patience_counter += 1
                    if patience_counter >= self.patience:
                        fprint(f"Early stopping at epoch {epoch + 1}.")
                        break

                # Save epoch checkpoint if requested
                if path_to_save:
                    self._save_epoch_checkpoint(
                        path_to_save, epoch, valid_metrics, **kwargs
                    )

            # Final testing with best model
            if self._has_batches(self.test_loader):
                self._load_state_dict()
                test_metrics = self.test()
                self._is_metric_better(test_metrics, stage="test")

            # Save final model if requested
            if path_to_save:
                self._save_final_model(path_to_save, **kwargs)
        finally:
            # Clean up the temporary checkpoint even if training failed.
            self._remove_state_dict()

        return self.metrics

    def _evaluate_loader(
        self, loader: Any, desc: str, stage: str, no_label_message: str
    ) -> Dict[str, Any]:
        """
        Run the model over ``loader`` and compute metrics.

        Args:
            loader: Iterable of batches; each prepared batch must provide
                ``batch["labels"]``.
            desc (str): Progress-bar description.
            stage (str): Key under which predictions are stored in
                ``self.predictions``.
            no_label_message (str): Key of the placeholder result returned when
                every label equals -100.

        Returns:
            Dict[str, Any]: Merged output of all ``compute_metrics`` functions.
        """
        self.model.eval()
        all_truth = []
        all_preds = []

        with torch.no_grad():
            for batch in tqdm(loader, desc=desc):
                batch = self._prepare_batch(batch)
                output = self._predict_batch(batch)
                predictions = output["predictions"]
                labels = batch["labels"]

                all_truth.append(self._process_labels(labels))
                all_preds.append(self._process_predictions(predictions))

        # Concatenate all predictions and labels
        all_truth = self._concatenate_outputs(all_truth)
        all_preds = self._concatenate_outputs(all_preds)

        # Compute metrics unless every label carries the -100 value
        if not np.all(all_truth == -100):
            stage_metrics = {}
            for metric_func in self.compute_metrics:
                stage_metrics.update(metric_func(all_truth, all_preds))
            fprint(stage_metrics)
        else:
            stage_metrics = {no_label_message: 0}

        # Store predictions
        self.predictions.update({stage: {"pred": all_preds, "true": all_truth}})
        return stage_metrics

    def evaluate(self) -> Dict[str, Any]:
        """
        Evaluate the model on the validation dataset.

        Predictions and labels are stored under ``self.predictions["valid"]``.

        Returns:
            Dict[str, Any]: Dictionary containing evaluation metrics
        """
        return self._evaluate_loader(
            self.eval_loader,
            "Evaluating",
            "valid",
            "Validation labels may be NaN. No metrics calculated.",
        )

    def test(self) -> Dict[str, Any]:
        """
        Test the model on the test dataset.

        Predictions and labels are stored under ``self.predictions["test"]``.

        Returns:
            Dict[str, Any]: Dictionary containing test metrics
        """
        return self._evaluate_loader(
            self.test_loader,
            "Testing",
            "test",
            "Test labels may be NaN. No metrics calculated.",
        )

    def predict(self, data_loader: DataLoader) -> Dict[str, Any]:
        """
        Generate predictions using the trained model.

        When ``data_loader`` is a ``DataLoader`` every batch is prepared and
        predicted in turn and the processed predictions are concatenated. Any
        other input is treated as a single, already prepared batch and passed
        straight to ``_predict_batch`` (the previous behavior).

        Args:
            data_loader (DataLoader): DataLoader for prediction data

        Returns:
            Dict[str, Any]: Dictionary containing predictions. For a DataLoader
                this is ``{"predictions": np.ndarray}``.
        """
        if not isinstance(data_loader, DataLoader):
            return self._predict_batch(data_loader)

        self.model.eval()
        all_preds = []
        with torch.no_grad():
            for batch in tqdm(data_loader, desc="Predicting"):
                batch = self._prepare_batch(batch)
                output = self._predict_batch(batch)
                all_preds.append(self._process_predictions(output["predictions"]))
        return {"predictions": self._concatenate_outputs(all_preds)}

    def get_model(self, **kwargs) -> torch.nn.Module:
        """
        Get the trained model.

        Args:
            **kwargs: Additional keyword arguments (unused)

        Returns:
            torch.nn.Module: The trained model
        """
        return self.model

    def save_model(self, path: str, overwrite: bool = False, **kwargs) -> None:
        """
        Save the trained model.

        Delegates to ``model.save(path, overwrite, **kwargs)`` when the model
        provides a ``save`` method. Otherwise the state dict is written to
        ``f"{path}.pt"``; in that fallback ``overwrite`` is not consulted.

        Args:
            path (str): Path to save the model
            overwrite (bool): Whether to overwrite existing files (default: False)
            **kwargs: Additional keyword arguments
        """
        if hasattr(self.model, "save"):
            self.model.save(path, overwrite, **kwargs)
        else:
            torch.save(self.model.state_dict(), f"{path}.pt")

    # Abstract methods that subclasses need to implement
    @abstractmethod
    def _prepare_batch(self, batch: Any) -> Any:
        """
        Prepare a batch for model input.

        Args:
            batch: Input batch

        Returns:
            Prepared batch
        """
        pass

    @abstractmethod
    def _predict_batch(self, batch: Any) -> Dict[str, torch.Tensor]:
        """
        Generate predictions for a batch.

        Args:
            batch: Input batch

        Returns:
            Dict[str, torch.Tensor]: Dictionary containing predictions; the
                evaluation loops read the ``"predictions"`` entry.
        """
        pass

    def _process_labels(self, labels: torch.Tensor) -> np.ndarray:
        """
        Process labels for metric computation.

        Args:
            labels (torch.Tensor): Raw labels

        Returns:
            np.ndarray: Labels cast to float and moved to CPU memory
        """
        return labels.float().cpu().numpy()

    def _process_predictions(self, predictions: torch.Tensor) -> np.ndarray:
        """
        Process predictions for metric computation.

        Args:
            predictions (torch.Tensor): Raw predictions

        Returns:
            np.ndarray: Predictions cast to float and moved to CPU memory
        """
        return predictions.float().cpu().numpy()

    def _concatenate_outputs(self, outputs: List[np.ndarray]) -> np.ndarray:
        """
        Concatenate list of outputs.

        Arrays with more than one dimension are stacked along the first axis;
        everything else is joined into a flat array. The first array decides
        which rule applies.

        Args:
            outputs (List[np.ndarray]): List of output arrays

        Returns:
            np.ndarray: Concatenated outputs (an empty array for an empty list)
        """
        if not outputs:
            return np.array([])

        if outputs[0].ndim > 1:
            return np.vstack(outputs)
        return np.hstack(outputs)

    def _save_epoch_checkpoint(
        self, path_to_save: str, epoch: int, metrics: Dict[str, Any], **kwargs
    ) -> None:
        """
        Save model checkpoint after each epoch.

        The checkpoint name is ``{path_to_save}_epoch_{epoch + 1}`` followed by
        one ``_seed_{seed}_{metric}_{value}`` suffix per scalar metric.

        Args:
            path_to_save (str): Base path for saving
            epoch (int): Current epoch number (zero-based)
            metrics (Dict[str, Any]): Current metrics
            **kwargs: Additional keyword arguments forwarded to ``save_model``
        """
        checkpoint_path = f"{path_to_save}_epoch_{epoch + 1}"

        if metrics:
            for key, value in metrics.items():
                if isinstance(value, (int, float)):
                    checkpoint_path += f"_seed_{self.seed}_{key}_{value:.4f}"

        self.save_model(checkpoint_path, **kwargs)

    def _save_final_model(self, path_to_save: str, **kwargs) -> None:
        """
        Save the final trained model.

        The file name is ``{path_to_save}_final`` followed by one
        ``_seed_{seed}_{metric}_{value}`` suffix per scalar metric of the most
        recent test evaluation, if any.

        Args:
            path_to_save (str): Base path for saving
            **kwargs: Additional keyword arguments forwarded to ``save_model``
        """
        final_path = f"{path_to_save}_final"

        if self.metrics.get("test") and len(self.metrics["test"]) > 0:
            for key, value in self.metrics["test"][-1].items():
                if isinstance(value, (int, float)):
                    final_path += f"_seed_{self.seed}_{key}_{value:.4f}"

        self.save_model(final_path, **kwargs)

    def _save_state_dict(self) -> None:
        """
        Save model state dictionary to temporary file.

        Any previous checkpoint at the same path is removed first; a failure to
        remove it is reported and the save is attempted anyway.
        """
        try:
            if os.path.exists(self._model_state_dict_path):
                os.remove(self._model_state_dict_path)
        except Exception as e:
            fprint(f"Failed to remove temporary checkpoint file: {e}")

        torch.save(self.model.state_dict(), self._model_state_dict_path)

    def _load_state_dict(self) -> None:
        """
        Load model state dictionary from temporary file.

        Nothing is loaded when no checkpoint has been written yet. This includes
        the zero-byte placeholder created by ``mkstemp`` in ``__init__``, which
        ``torch.load`` cannot read.
        """
        path = getattr(self, "_model_state_dict_path", None)
        if not path or not os.path.exists(path) or os.path.getsize(path) == 0:
            return

        self.model.load_state_dict(torch.load(path, map_location="cpu"))

    def _remove_state_dict(self) -> None:
        """
        Remove temporary state dictionary file.

        Removal is best-effort: failures are reported but never raised.
        """
        if hasattr(self, "_model_state_dict_path"):
            try:
                if os.path.exists(self._model_state_dict_path):
                    os.remove(self._model_state_dict_path)
            except Exception as e:
                fprint(f"Failed to remove temporary checkpoint file: {e}")