Trainers¶
Note
You are viewing the API reference documentation.
This page provides detailed API documentation for trainer classes. For a comprehensive tutorial-style guide with complete examples, see the ../TRAINER_GUIDE.
Quick Start: See below for minimal examples
Detailed Tutorial: ../TRAINER_GUIDE (complete guide with 10 sections)
Design Philosophy: Architecture & Design Philosophy (understanding BaseTrainer abstraction)
Overview¶
OmniGenBench provides three powerful trainer implementations for different training scenarios. All trainers inherit from a unified BaseTrainer abstract class, ensuring consistent API and functionality.
Trainer Comparison¶
| Trainer | Best For | Key Features | Requirements |
|---|---|---|---|
| Trainer | Single-GPU training | Lightweight, easy debugging, native PyTorch | PyTorch 2.5+ |
| AccelerateTrainer | Multi-GPU distributed training | Zero-config distributed, DeepSpeed/FSDP support | HuggingFace Accelerate |
| HFTrainer | HuggingFace ecosystem integration | Full HF features, callbacks, logging | HuggingFace Transformers |
Quick Start¶
Native Trainer (Single-GPU):
from omnigenbench import Trainer, ModelHub
model = ModelHub.load("yangheng/OmniGenome-186M")
trainer = Trainer(
model=model,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
epochs=10,
batch_size=32,
autocast="fp16", # Mixed precision
device="cuda:0",
)
metrics = trainer.train()
Accelerate Trainer (Multi-GPU):
from omnigenbench import AccelerateTrainer
trainer = AccelerateTrainer(
model=model,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
epochs=10,
batch_size=32, # Per-device batch size
autocast="fp16",
)
metrics = trainer.train()
# Run with: accelerate launch train.py
HuggingFace Trainer:
from omnigenbench import HFTrainer
from transformers import TrainingArguments
training_args = TrainingArguments(
output_dir="./results",
num_train_epochs=10,
per_device_train_batch_size=32,
fp16=True,
)
trainer = HFTrainer(
model=model,
training_args=training_args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
)
metrics = trainer.train()
Common Features¶
All trainers support:
✅ Mixed Precision Training: FP16, BF16 for faster training
✅ Early Stopping: Automatic training termination when metrics plateau
✅ Gradient Accumulation: Simulate larger batch sizes
✅ Gradient Clipping: Prevent exploding gradients
✅ Checkpoint Management: Automatic saving and loading
✅ Flexible Optimizers: Support any PyTorch optimizer
✅ Learning Rate Scheduling: Built-in and custom schedulers
✅ Evaluation Metrics: Custom metric computation
✅ Reproducibility: Seed control for consistent results
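To make the gradient clipping item concrete, here is a minimal pure-Python sketch of clipping by global L2 norm, the idea behind the `max_grad_norm` setting. The function name `clip_by_global_norm` is hypothetical and plain floats stand in for gradient tensors; it illustrates the arithmetic only and is not the trainers' actual code path.

```python
import math


def clip_by_global_norm(grads, max_norm):
    """Rescale `grads` so that their combined L2 norm does not exceed `max_norm`.

    Returns the (possibly rescaled) gradients and the norm measured before clipping.
    """
    total_norm = math.sqrt(sum(g * g for g in grads))
    if total_norm > max_norm:
        # Every gradient is scaled by the same factor, so the direction is preserved.
        scale = max_norm / total_norm
        grads = [g * scale for g in grads]
    return list(grads), total_norm


clipped, norm_before = clip_by_global_norm([3.0, 4.0], max_norm=1.0)
print(norm_before)  # norm of the unclipped gradients
print(clipped)      # same direction, norm reduced to max_norm
```

Gradients whose norm is already below `max_norm` pass through unchanged, which is why a moderate value such as 1.0 rarely interferes with healthy training steps.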
Base Trainer¶
Abstract base class defining the common interface for all trainers.
Base trainer class.
This module provides the abstract base class for all trainers in the OmniGenome framework. It defines the common interface and shared functionality that all trainer implementations should provide.
- class omnigenbench.src.trainer.base_trainer.BaseTrainer(model: Module, train_dataset: Dataset | None = None, eval_dataset: Dataset | None = None, test_dataset: Dataset | None = None, epochs: int = 3, learning_rate: float = 2e-05, batch_size: int = 8, patience: int = -1, max_grad_norm: float = 1.0, gradient_accumulation_steps: int = 1, optimizer: Optimizer | None = None, loss_fn: Module | None = None, compute_metrics: List | str | None = None, seed: int = 42, autocast: str = 'float16', **kwargs)[source]
Bases: ABC
Abstract base class for all trainers in the OmniGenome framework.
This class defines the common interface and shared functionality that all trainer implementations should provide. It includes methods for training, evaluation, testing, and model management.
- Variables:
model – The model to be trained
train_loader – DataLoader for training data
eval_loader – DataLoader for validation data
test_loader – DataLoader for test data
epochs – Number of training epochs
batch_size – Batch size for training
patience – Early stopping patience
gradient_accumulation_steps – Number of steps for gradient accumulation
optimizer – Optimizer for training
loss_fn – Loss function
compute_metrics – List of metric computation functions
seed – Random seed for reproducibility
metrics – Dictionary to store training metrics
predictions – Dictionary to store model predictions
metadata – Dictionary containing environment and training metadata
trial_name – Name of the current training trial
_optimization_direction – Optimization direction (‘larger_is_better’ or ‘smaller_is_better’)
Example
>>> class MyTrainer(BaseTrainer):
...     def _setup_training_components(self):
...         # Implementation specific setup
...         pass
...     def _train_epoch(self, epoch):
...         # Implementation specific training loop
...         pass
>>> trainer = MyTrainer(model=model, train_dataset=train_dataset)
>>> metrics = trainer.train()
- evaluate() Dict[str, Any][source]
Evaluate the model on the validation dataset.
- Returns:
Dict[str, Any] – Dictionary containing evaluation metrics
- get_model(**kwargs) Module[source]
Get the trained model.
- Parameters:
**kwargs – Additional keyword arguments
- Returns:
torch.nn.Module – The trained model
- predict(data_loader: DataLoader) Dict[str, Any][source]
Generate predictions using the trained model.
- Parameters:
data_loader (DataLoader) – DataLoader for prediction data
- Returns:
Dict[str, Any] – Dictionary containing predictions
- save_model(path: str, overwrite: bool = False, **kwargs) None[source]
Save the trained model.
- Parameters:
path (str) – Path to save the model
overwrite (bool) – Whether to overwrite existing files (default: False)
**kwargs – Additional keyword arguments
- test() Dict[str, Any][source]
Test the model on the test dataset.
- Returns:
Dict[str, Any] – Dictionary containing test metrics
- train(path_to_save: str | None = None, **kwargs) Dict[str, Any][source]
Train the model.
This method implements the main training loop including early stopping, model checkpointing, and metric tracking.
- Parameters:
path_to_save (Optional[str]) – Path to save the trained model
**kwargs – Additional keyword arguments
- Returns:
Dict[str, Any] – Training metrics and results
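The early stopping that `train()` performs can be pictured with the following sketch. It is an illustration under stated assumptions, not the library's implementation: the class `EarlyStopper` is hypothetical, the direction strings mirror `_optimization_direction`, and a non-positive `patience` (such as the default `-1`) is assumed to disable early stopping.

```python
class EarlyStopper:
    """Track a validation metric and signal when it has stopped improving."""

    def __init__(self, patience=-1, direction="larger_is_better"):
        self.patience = patience    # assumed: values <= 0 disable early stopping
        self.direction = direction  # "larger_is_better" or "smaller_is_better"
        self.best = None
        self.bad_epochs = 0

    def should_stop(self, value):
        """Record one epoch's metric and return True once patience is exhausted."""
        improved = (
            self.best is None
            or (self.direction == "larger_is_better" and value > self.best)
            or (self.direction == "smaller_is_better" and value < self.best)
        )
        if improved:
            self.best = value
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.patience > 0 and self.bad_epochs >= self.patience


stopper = EarlyStopper(patience=2)
for epoch, accuracy in enumerate([0.50, 0.60, 0.60, 0.59]):
    if stopper.should_stop(accuracy):
        print(f"stopping after epoch {epoch}, best={stopper.best}")
        break
```

A tie counts as "no improvement" in this sketch; a real trainer may also apply a minimum-improvement threshold.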
- class omnigenbench.src.trainer.base_trainer.MetricsDict[source]
Bases: dict
A custom dictionary class for storing training metrics with enhanced readability.
This class extends the built-in dict to provide a more readable string representation of metrics, making it easier to understand training progress and results.
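As a rough picture of what "a more readable string representation" can mean for a `dict` subclass, consider the sketch below. `ReadableMetrics` is a hypothetical stand-in; the real `MetricsDict` may format its output quite differently.

```python
import json


class ReadableMetrics(dict):
    """A dict whose str() is indented JSON; all normal dict behaviour is kept."""

    def __str__(self):
        # default=str keeps the call from failing on values JSON cannot encode.
        return json.dumps(self, indent=2, sort_keys=True, default=str)


metrics = ReadableMetrics({"best_valid": {"eval_accuracy": 0.91}, "epochs": 3})
print(metrics)                                 # indented, key-sorted output
print(metrics["best_valid"]["eval_accuracy"])  # regular dict access still works
```

Because only `__str__` is overridden, code that indexes, iterates, or serializes the object treats it exactly like a plain dictionary.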
Key Components¶
MetricsDict
Enhanced dictionary for training metrics with formatted display:
metrics = trainer.train()
print(metrics) # Automatically formatted output
# Access metrics
best_accuracy = metrics['best_valid']['eval_accuracy']
final_loss = metrics['train_metrics_history'][-1]['train_loss']
Optimization Direction
Automatically infers whether to minimize or maximize metrics:
trainer = Trainer(
    optimization_metric="loss", # Minimized
    # or: optimization_metric="accuracy", # Maximized
)
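One plausible way to infer the direction from a metric name is a simple keyword heuristic, sketched below. The function `infer_direction` and its keyword list are assumptions made for illustration; only the two return strings come from the `_optimization_direction` values documented above.

```python
def infer_direction(metric_name):
    """Guess whether a metric should be minimized or maximized from its name."""
    lowered = metric_name.lower()
    # Hypothetical keyword list: names containing these are treated as "lower is better".
    minimize_markers = ("loss", "error", "mse", "mae", "perplexity")
    if any(marker in lowered for marker in minimize_markers):
        return "smaller_is_better"
    return "larger_is_better"


for name in ("loss", "eval_accuracy", "RMSE", "f1_score"):
    print(name, "->", infer_direction(name))
```

A name-based heuristic can misfire on unusual metric names, so checking the inferred direction for custom metrics is a sensible precaution.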
Trainer (Native PyTorch)¶
Native PyTorch trainer for single-GPU training.
Native training utilities.
This module provides a native PyTorch training framework for genomic models, including automatic mixed precision training, early stopping, metric tracking, and model checkpointing.
- class omnigenbench.src.trainer.trainer.Trainer(model: Module, device: str | device | None = None, **kwargs)[source]
Bases: BaseTrainer
Native PyTorch trainer for genomic models.
This trainer provides a complete training framework with automatic mixed precision, early stopping, metric tracking, and model checkpointing using native PyTorch without distributed training dependencies.
- Variables:
device – Device to run training on (CPU or GPU)
fast_dtype – Data type for mixed precision training
scaler – Gradient scaler for mixed precision training
Example
>>> trainer = Trainer(
...     model=model,
...     train_dataset=train_dataset,
...     eval_dataset=eval_dataset,
...     epochs=10,
...     batch_size=32,
...     optimizer=optimizer
... )
>>> metrics = trainer.train()
- evaluate() Dict[str, Any][source]
Evaluate the model on the validation dataset.
- Returns:
Dict[str, Any] – Dictionary containing evaluation metrics
- save_model(path_to_save: str, overwrite: bool = False, **kwargs) None[source]
Save the trained model.
- Parameters:
path_to_save (str) – Path to save the model
overwrite (bool) – Whether to overwrite existing files (default: False)
**kwargs – Additional keyword arguments
- test() Dict[str, Any][source]
Test the model on the test dataset.
- Returns:
Dict[str, Any] – Dictionary containing test metrics
- unwrap_model(model: Module | None = None) Module[source]
Unwrap the model from any distributed training wrappers.
- Parameters:
model (Optional[torch.nn.Module]) – Model to unwrap (default: None, uses self.model)
- Returns:
torch.nn.Module – The unwrapped model
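Wrappers such as `torch.nn.DataParallel` and `DistributedDataParallel` expose the wrapped model through a `.module` attribute, so unwrapping amounts to following that attribute. The sketch below shows the idea with stand-in classes so that it runs without PyTorch; the helper `unwrap` and both classes are hypothetical and are not the library's `unwrap_model`.

```python
def unwrap(model):
    """Follow `.module` links until the innermost model is reached."""
    while hasattr(model, "module"):
        model = model.module
    return model


class TinyModel:
    """Stand-in for a bare model."""


class Wrapper:
    """Stand-in for a wrapper that stores the wrapped model as `.module`."""

    def __init__(self, module):
        self.module = module


inner = TinyModel()
wrapped = Wrapper(Wrapper(inner))  # two layers of wrapping
print(unwrap(wrapped) is inner)    # the innermost model is recovered
print(unwrap(inner) is inner)      # an unwrapped model is returned as-is
```

One caveat of this duck-typed approach: a model that happens to own a submodule literally named `module` would also be "unwrapped", so production code usually checks the wrapper type explicitly.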
Features¶
Automatic Mixed Precision: Using torch.cuda.amp.GradScaler
Device Management: Auto-detects CUDA/CPU
Simple and Fast: Minimal overhead, easy debugging
Customizable: Easy to extend and modify
Example¶
from omnigenbench import Trainer
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
# Configure optimizer and scheduler
optimizer = AdamW(
model.parameters(),
lr=5e-5,
weight_decay=0.01
)
scheduler = CosineAnnealingLR(
optimizer,
T_max=10000
)
# Create trainer
trainer = Trainer(
model=model,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
optimizer=optimizer,
lr_scheduler=scheduler,
# Training configuration
epochs=10,
batch_size=32,
gradient_accumulation_steps=2,
max_grad_norm=1.0,
# Mixed precision
autocast="fp16", # or "bf16" for Ampere+ GPUs
# Early stopping
patience=3,
delta=0.001,
# Checkpointing
save_dir="./checkpoints",
save_steps=1000,
save_total_limit=3,
# Device
device="cuda:0",
seed=42,
)
# Train
metrics = trainer.train()
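# Evaluate on the validation set (evaluate() takes no arguments)
eval_metrics = trainer.evaluate()
# To score a held-out test set, pass test_dataset=... to Trainer and call trainer.test()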
AccelerateTrainer (Distributed Training)¶
Distributed trainer using HuggingFace Accelerate for multi-GPU training.
Accelerate-based distributed training utilities.
This module provides HuggingFace Accelerate-based distributed training framework for genomic models, including automatic mixed precision training, distributed training support, early stopping, and model checkpointing.
- class omnigenbench.src.trainer.accelerate_trainer.AccelerateTrainer(model: Module, **kwargs)[source]
Bases: BaseTrainer
HuggingFace Accelerate-based distributed trainer for genomic models.
This trainer provides distributed training capabilities with automatic mixed precision, gradient accumulation, and early stopping. It supports both single and multi-GPU training with seamless integration with HuggingFace Accelerate.
- Variables:
accelerator – HuggingFace Accelerate instance for distributed training
early_stop_flag – Tensor for coordinating early stopping across processes
Example
>>> trainer = AccelerateTrainer(
...     model=model,
...     train_dataset=train_dataset,
...     eval_dataset=eval_dataset,
...     epochs=10,
...     batch_size=32,
...     optimizer=optimizer
... )
>>> metrics = trainer.train()
- evaluate() Dict[str, Any][source]
Evaluate the model on the validation dataset.
This method runs the model in evaluation mode and computes metrics on the validation dataset. It handles distributed evaluation and gathers results from all processes.
- Returns:
Dict[str, Any] – Dictionary containing evaluation metrics
- save_model(path_to_save: str, overwrite: bool = False, **kwargs) None[source]
Save the trained model.
- Parameters:
path_to_save (str) – Path to save the model
overwrite (bool) – Whether to overwrite existing files (default: False)
**kwargs – Additional keyword arguments for model saving
- test() Dict[str, Any][source]
Test the model on the test dataset.
This method runs the model in evaluation mode and computes metrics on the test dataset. It handles distributed testing and gathers results from all processes.
- Returns:
Dict[str, Any] – Dictionary containing test metrics
- train(path_to_save: str | None = None, **kwargs) Dict[str, Any][source]
Train the model using distributed training.
This method performs the complete training loop with validation, early stopping, and model checkpointing. It handles distributed training across multiple GPUs and processes.
- Parameters:
path_to_save (Optional[str]) – Path to save the trained model
**kwargs – Additional keyword arguments for model saving
- Returns:
Dict[str, Any] – Dictionary containing training metrics
Features¶
Zero-Config Distributed: Automatic multi-GPU detection and setup
DeepSpeed Integration: Support for ZeRO Stage 1/2/3
FSDP Support: Fully Sharded Data Parallel
Flexible Backends: DDP, DeepSpeed, FSDP, and more
Gradient Synchronization: Gradients are synchronized across devices automatically, including when gradient accumulation is enabled
Distributed Training Modes¶
Multi-GPU (DDP):
# Auto-detects all GPUs
trainer = AccelerateTrainer(
model=model,
train_dataset=train_dataset,
batch_size=32, # Per-device batch size
)
# Launch
accelerate launch train.py
DeepSpeed:
# Create ds_config.json
{
"train_batch_size": "auto",
"gradient_accumulation_steps": "auto",
"zero_optimization": {
"stage": 2
},
"fp16": {
"enabled": true
}
}
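# Launch (shell). --config_file expects an Accelerate YAML config, so the DeepSpeed JSON is passed separately
accelerate launch --use_deepspeed --deepspeed_config_file ds_config.json train.py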
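If the DeepSpeed settings are easier to manage from Python, the same configuration can be written to disk with the standard library. This is a small sketch that reuses the values shown above; the helper name `write_ds_config` is hypothetical and nothing here is specific to OmniGenBench.

```python
import json
from pathlib import Path

# Same settings as the ds_config.json shown above.
DS_CONFIG = {
    "train_batch_size": "auto",
    "gradient_accumulation_steps": "auto",
    "zero_optimization": {"stage": 2},
    "fp16": {"enabled": True},
}


def write_ds_config(path="ds_config.json", config=None):
    """Serialize a DeepSpeed config dictionary to `path` and return the path."""
    target = Path(path)
    payload = DS_CONFIG if config is None else config
    target.write_text(json.dumps(payload, indent=2) + "\n")
    return target
```

Calling `write_ds_config()` before launching keeps the config under version control as code, which can help when the ZeRO stage changes between experiments.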
FSDP (Large Models):
# Configure once
accelerate config # Select FSDP, FULL_SHARD
# Enable gradient checkpointing
model.gradient_checkpointing_enable()
trainer = AccelerateTrainer(
model=model,
train_dataset=train_dataset,
batch_size=16,
)
# Launch
accelerate launch train.py
Example¶
from omnigenbench import AccelerateTrainer
trainer = AccelerateTrainer(
model=model,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
epochs=10,
batch_size=16, # Per-device batch size
gradient_accumulation_steps=4,
autocast="fp16",
save_dir="./checkpoints",
eval_steps=500,
save_steps=500,
)
# Train
metrics = trainer.train()
# Only print on main process
if trainer.accelerator.is_main_process:
print(metrics)
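Because `batch_size` is per device, the number of samples contributing to each optimizer step depends on the process count and the accumulation steps. The arithmetic is sketched below; the function name is hypothetical and the figure of 4 processes is an assumed example, not something the configuration above fixes.

```python
def effective_batch_size(per_device_batch_size, gradient_accumulation_steps, num_processes):
    """Samples that contribute to a single optimizer step across all processes."""
    return per_device_batch_size * gradient_accumulation_steps * num_processes


# Settings from the example above, launched on an assumed 4 GPUs.
print(effective_batch_size(16, 4, 4))  # 256
```

When the GPU count changes, adjusting `gradient_accumulation_steps` is a common way to keep this product, and therefore the learning-rate tuning, roughly constant.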
HFTrainer (HuggingFace Integration)¶
Wrapper for HuggingFace Trainer with OmniGenome metadata support.
HuggingFace trainer integration for genomic models.
This module provides HuggingFace trainer wrappers for genomic models, enabling seamless integration with the HuggingFace training ecosystem while maintaining OmniGenome-specific functionality.
- class omnigenbench.src.trainer.hf_trainer.HFTrainer(model: Module, training_args: TrainingArguments | None = None, **kwargs)[source]
Bases: BaseTrainer
HuggingFace trainer wrapper for genomic models.
This class extends the OmniGenome BaseTrainer to integrate with HuggingFace Trainer while maintaining OmniGenome-specific metadata and functionality. It provides seamless integration with the HuggingFace training ecosystem.
- Variables:
hf_trainer – The underlying HuggingFace Trainer instance
training_args – HuggingFace TrainingArguments instance
metadata – Dictionary containing OmniGenome library information
Example
>>> from transformers import TrainingArguments
>>> training_args = TrainingArguments(
...     output_dir="./output",
...     num_train_epochs=3,
...     per_device_train_batch_size=16,
... )
>>> trainer = HFTrainer(
...     model=model,
...     train_dataset=train_dataset,
...     eval_dataset=eval_dataset,
...     training_args=training_args
... )
>>> metrics = trainer.train()
- evaluate() Dict[str, Any][source]
Evaluate the model on the validation dataset.
- Returns:
Dict[str, Any] – Dictionary containing evaluation metrics
- get_model(**kwargs) Module[source]
Get the trained model.
- Parameters:
**kwargs – Additional keyword arguments
- Returns:
torch.nn.Module – The trained model
- save_model(path_to_save: str, overwrite: bool = False, **kwargs) None[source]
Save the trained model.
- Parameters:
path_to_save (str) – Path to save the model
overwrite (bool) – Whether to overwrite existing files (default: False)
**kwargs – Additional keyword arguments
- test() Dict[str, Any][source]
Test the model on the test dataset.
- Returns:
Dict[str, Any] – Dictionary containing test metrics
- train(path_to_save: str | None = None, **kwargs) Dict[str, Any][source]
Train the model using HuggingFace Trainer.
- Parameters:
path_to_save (Optional[str]) – Path to save the trained model
**kwargs – Additional keyword arguments
- Returns:
Dict[str, Any] – Training metrics and results
- class omnigenbench.src.trainer.hf_trainer.HFTrainingArguments(*args, **kwargs)[source]
Bases: TrainingArguments
HuggingFace training arguments wrapper for genomic models.
This class extends the HuggingFace TrainingArguments to include OmniGenome-specific metadata while maintaining full compatibility with the HuggingFace training ecosystem.
- Variables:
metadata – Dictionary containing OmniGenome library information
Example
>>> training_args = HFTrainingArguments(
...     output_dir="./output",
...     num_train_epochs=3,
...     per_device_train_batch_size=16,
... )
>>> trainer = HFTrainer(model=model, training_args=training_args)
Features¶
Full HF Ecosystem: Seamless integration with Transformers
Rich Callbacks: Built-in callbacks for WandB, TensorBoard, etc.
Advanced Logging: Comprehensive training logs
Checkpoint Management: Automatic best model tracking
Custom Metrics: Easy metric computation
Example¶
from omnigenbench import HFTrainer
from transformers import TrainingArguments
import numpy as np
from sklearn.metrics import accuracy_score, f1_score
# Define metrics
def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=-1)
    return {
        "accuracy": accuracy_score(labels, predictions),
        "f1": f1_score(labels, predictions, average='weighted'),
    }
# Configure training
training_args = TrainingArguments(
output_dir="./results",
# Training
num_train_epochs=10,
per_device_train_batch_size=32,
per_device_eval_batch_size=64,
learning_rate=5e-5,
weight_decay=0.01,
warmup_steps=500,
# Evaluation
evaluation_strategy="steps", # named eval_strategy in recent Transformers releases
eval_steps=500,
# Saving
save_strategy="steps",
save_steps=500,
save_total_limit=3,
load_best_model_at_end=True,
metric_for_best_model="eval_accuracy",
# Logging
logging_dir="./logs",
logging_steps=100,
report_to=["tensorboard", "wandb"],
# Mixed precision
fp16=True,
# Other
seed=42,
)
# Create trainer
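trainer = HFTrainer(
    model=model,
    training_args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    test_dataset=test_dataset, # needed for trainer.test() below
    compute_metrics=compute_metrics,
)
# Train
train_result = trainer.train()
# Evaluate on the validation set, then on the test set (neither method takes arguments)
eval_results = trainer.evaluate()
test_results = trainer.test()
# Save
trainer.save_model("./final_model")
# save_metrics is a method of the wrapped transformers.Trainer
trainer.hf_trainer.save_metrics("test", test_results)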
Advanced Usage¶
Custom Trainer¶
Extend BaseTrainer for custom training logic:
from omnigenbench.src.trainer.base_trainer import BaseTrainer
import torch
class CustomTrainer(BaseTrainer):
    def _setup_training_components(self):
        """Set up device placement and the AMP gradient scaler."""
        self.device = torch.device("cuda")
        self.model.to(self.device)
        self.scaler = torch.cuda.amp.GradScaler()
    def _prepare_batch(self, batch):
        """Move a dict-style batch of tensors to the training device."""
        return {k: v.to(self.device) for k, v in batch.items()}
    def _train_epoch(self, epoch):
        """Run one mixed-precision epoch and return the mean training loss."""
        self.model.train()
        total_loss = 0.0
        for batch in self.train_loader:
            batch = self._prepare_batch(batch)
            # Forward pass under autocast; backward and step go through the scaler
            with torch.cuda.amp.autocast():
                outputs = self.model(**batch)
                loss = outputs.loss
            self.scaler.scale(loss).backward()
            self.scaler.step(self.optimizer)
            self.scaler.update()
            self.optimizer.zero_grad()
            total_loss += loss.item()
        return total_loss / len(self.train_loader)
Multi-Task Learning¶
Train on multiple tasks simultaneously:
from omnigenbench import Trainer
class MultiTaskTrainer(Trainer):
    def __init__(self, model, task_datasets, task_weights=None, **kwargs):
        self.task_datasets = task_datasets
        # Default to equal weighting when no per-task weights are given
        self.task_weights = task_weights or {t: 1.0 for t in task_datasets}
        super().__init__(model, **kwargs)
    def _train_epoch(self, epoch):
        self.model.train()
        total_loss = 0
        for task_name, dataset in self.task_datasets.items():
            task_weight = self.task_weights[task_name]
            # ... train on each task
        return total_loss
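The elided per-task step typically ends by combining the task losses into one scalar using the task weights. A minimal sketch of that combination follows; the function `combine_task_losses` is hypothetical, and plain floats stand in for loss tensors (the same expression also works on PyTorch tensors).

```python
def combine_task_losses(task_losses, task_weights=None):
    """Return the weighted sum of per-task losses.

    `task_losses` maps task name to loss; a missing `task_weights` means equal
    weights of 1.0, matching the default used by MultiTaskTrainer above.
    """
    weights = task_weights or {name: 1.0 for name in task_losses}
    return sum(weights[name] * loss for name, loss in task_losses.items())


losses = {"promoter": 2.0, "splice_site": 4.0}  # example task names are made up
print(combine_task_losses(losses))                                         # equal weights
print(combine_task_losses(losses, {"promoter": 0.5, "splice_site": 0.25}))  # custom weights
```

Weights are usually tuned so that no single task dominates the gradient simply because its loss has a larger scale.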
Gradient Checkpointing¶
Save memory for large models:
# Enable gradient checkpointing
model.gradient_checkpointing_enable()
trainer = AccelerateTrainer(
model=model,
batch_size=64, # Can use larger batches
)
Best Practices¶
✅ Recommended
Set random seed for reproducibility
Use mixed precision (FP16/BF16) for faster training
Enable early stopping to prevent overfitting
Save checkpoints regularly
Monitor both training and validation metrics
Use gradient accumulation for large effective batch sizes
❌ Avoid
Training without validation set
Extremely large learning rates
Not saving checkpoints (risk losing progress)
Ignoring GPU memory constraints
Training without monitoring metrics
Performance Optimization¶
# 1. Efficient data loading
trainer = Trainer(
num_workers=8,
pin_memory=True,
prefetch_factor=2,
)
# 2. PyTorch 2.0 compilation
model = torch.compile(model)
# 3. Gradient checkpointing
model.gradient_checkpointing_enable()
# 4. Mixed precision
trainer = Trainer(autocast="bf16") # More stable than fp16
# 5. DeepSpeed for large-scale training
# See AccelerateTrainer documentation
See Also¶
../TRAINER_GUIDE - Comprehensive guide with examples
User Guide - Basic usage examples
datasets - Dataset documentation
models - Model documentation