Source code for omnigenbench.src.trainer.hf_trainer

# -*- coding: utf-8 -*-
# file: hf_trainer.py
# time: 14:40 06/04/2024
# author: YANG, HENG <hy345@exeter.ac.uk> (杨恒)
# github: https://github.com/yangheng95
# huggingface: https://huggingface.co/yangheng
# google scholar: https://scholar.google.com/citations?user=NPq5a_0AAAAJ&hl=en
# Copyright (C) 2019-2025. All Rights Reserved.
"""
HuggingFace trainer integration for genomic models.

This module provides HuggingFace trainer wrappers for genomic models,
enabling seamless integration with the HuggingFace training ecosystem
while maintaining OmniGenome-specific functionality.
"""

from typing import Dict, Any, Optional, Union
import torch
from transformers import Trainer
from transformers import TrainingArguments

from .base_trainer import BaseTrainer
from ... import __name__ as omnigenbench_name
from ... import __version__ as omnigenbench_version


class HFTrainer(BaseTrainer):
    """
    HuggingFace trainer wrapper for genomic models.

    This class extends the OmniGenome BaseTrainer and delegates training,
    evaluation and prediction to a HuggingFace ``Trainer`` instance while
    keeping OmniGenome-specific metadata on the wrapper.

    Attributes:
        hf_trainer: The underlying HuggingFace Trainer instance
        training_args: HuggingFace TrainingArguments instance
        metadata: Dictionary containing OmniGenome library information

    Example:
        >>> from transformers import TrainingArguments
        >>> training_args = TrainingArguments(
        ...     output_dir="./output",
        ...     num_train_epochs=3,
        ...     per_device_train_batch_size=16,
        ... )
        >>> trainer = HFTrainer(
        ...     model=model,
        ...     train_dataset=train_dataset,
        ...     eval_dataset=eval_dataset,
        ...     training_args=training_args
        ... )
        >>> metrics = trainer.train()
    """

    def __init__(
        self,
        model: torch.nn.Module,
        training_args: Optional[TrainingArguments] = None,
        **kwargs,
    ):
        """
        Initialize the HuggingFace trainer wrapper.

        Args:
            model (torch.nn.Module): The model to be trained
            training_args (Optional[TrainingArguments]): HuggingFace training
                arguments. When omitted, defaults are built from the optional
                ``epochs`` (default 3) and ``batch_size`` (default 8) entries
                of ``kwargs``.
            **kwargs: Additional keyword arguments passed to BaseTrainer
        """
        # Extract training arguments or create default ones
        if training_args is None:
            batch_size = kwargs.get("batch_size", 8)
            # The evaluation-schedule keyword is named `eval_strategy` in newer
            # transformers releases and `evaluation_strategy` in older ones.
            # Use whichever one the installed TrainingArguments declares so the
            # defaults can be built on either.
            declared_fields = getattr(TrainingArguments, "__dataclass_fields__", {})
            strategy_key = (
                "eval_strategy"
                if "eval_strategy" in declared_fields
                else "evaluation_strategy"
            )
            training_args = TrainingArguments(
                output_dir="./output",
                num_train_epochs=kwargs.get("epochs", 3),
                per_device_train_batch_size=batch_size,
                per_device_eval_batch_size=batch_size,
                logging_dir="./logs",
                save_strategy="epoch",
                load_best_model_at_end=True,
                metric_for_best_model="eval_loss",
                greater_is_better=False,
                **{strategy_key: "epoch"},
            )

        # Assigned before the base initializer runs.
        # NOTE(review): presumably BaseTrainer.__init__ invokes
        # _setup_training_components(), which reads self.training_args - confirm.
        self.training_args = training_args

        # Initialize base trainer
        super().__init__(model, **kwargs)

        # Store metadata
        self.metadata = {
            "library_name": omnigenbench_name,
            "omnigenbench_version": omnigenbench_version,
        }

    def _setup_training_components(self) -> None:
        """
        Set up HuggingFace training-specific components.

        Builds ``self.hf_trainer`` from the model, the training arguments and
        the datasets held by the train/eval loaders. The metric adapter is only
        registered when ``self.compute_metrics`` is set.
        """
        # Initialize HuggingFace Trainer
        self.hf_trainer = Trainer(
            model=self.model,
            args=self.training_args,
            train_dataset=self.train_loader.dataset if self.train_loader else None,
            eval_dataset=self.eval_loader.dataset if self.eval_loader else None,
            compute_metrics=self._compute_hf_metrics if self.compute_metrics else None,
        )

    def _compute_hf_metrics(self, eval_pred):
        """
        Compute metrics for HuggingFace trainer.

        This method adapts OmniGenome metrics to work with HuggingFace trainer.

        Args:
            eval_pred: ``(predictions, labels)`` pair from HuggingFace trainer

        Returns:
            Dict[str, float]: Merged results of every configured metric function
        """
        predictions, labels = eval_pred

        # Convert to numpy arrays when the objects expose a ``numpy()`` method
        if hasattr(predictions, "numpy"):
            predictions = predictions.numpy()
        if hasattr(labels, "numpy"):
            labels = labels.numpy()

        # Compute metrics using OmniGenome metric functions
        metrics = {}
        for metric_func in self.compute_metrics or ():
            metrics.update(metric_func(labels, predictions))

        return metrics

    def _prepare_batch(self, batch: Any) -> Any:
        """
        Prepare a batch for model input.

        For HuggingFace trainer, batch preparation is handled internally, so
        the batch is returned unchanged.

        Args:
            batch: Input batch

        Returns:
            The input batch
        """
        return batch

    def _predict_batch(self, batch: Any) -> Dict[str, torch.Tensor]:
        """
        Generate predictions for a batch using the model.

        Args:
            batch: Input batch

        Returns:
            Dict[str, torch.Tensor]: Whatever ``self.model.predict(batch)`` returns
        """
        return self.model.predict(batch)

    def _train_epoch(self, epoch: int) -> float:
        """
        Train the model for one epoch.

        This method is not used in HuggingFace trainer as training is handled
        by the HuggingFace Trainer.train() method.

        Args:
            epoch (int): Current epoch number (ignored)

        Returns:
            float: Always ``0.0``
        """
        # This method is not used in HuggingFace trainer
        return 0.0

    def train(self, path_to_save: Optional[str] = None, **kwargs) -> Dict[str, Any]:
        """
        Train the model using HuggingFace Trainer.

        Args:
            path_to_save (Optional[str]): Path to save the trained model. When
                given it also replaces ``training_args.output_dir``.
            **kwargs: Additional keyword arguments forwarded to ``save_model``

        Returns:
            Dict[str, Any]: Metrics grouped under the keys ``"train"``,
            ``"valid"`` and ``"test"``; a group is ``{}`` when the matching
            loader is missing.
        """
        # Set output directory if path_to_save is provided
        if path_to_save:
            self.training_args.output_dir = path_to_save

        # Train using HuggingFace Trainer
        train_result = self.hf_trainer.train()

        # Combine all metrics; evaluate()/test() return {} without a loader
        all_metrics = {
            "train": train_result.metrics,
            "valid": self.evaluate(),
            "test": self.test(),
        }

        # Save model if requested
        if path_to_save:
            self.save_model(path_to_save, **kwargs)

        return all_metrics

    def evaluate(self) -> Dict[str, Any]:
        """
        Evaluate the model on the validation dataset.

        Returns:
            Dict[str, Any]: Dictionary containing evaluation metrics, or ``{}``
            when no evaluation loader is available
        """
        if self.eval_loader:
            return self.hf_trainer.evaluate()
        return {}

    def test(self) -> Dict[str, Any]:
        """
        Test the model on the test dataset.

        Returns:
            Dict[str, Any]: Dictionary containing test metrics. ``{}`` is
            returned when there is no test loader or no metric function is
            configured; a single placeholder entry is returned when the test
            set has no usable labels (missing, or all equal to ``-100``).
        """
        if not self.test_loader:
            return {}

        # Use HuggingFace trainer's predict method
        predictions = self.hf_trainer.predict(self.test_loader.dataset)

        # Extract predictions and labels
        pred_labels = predictions.predictions
        true_labels = predictions.label_ids

        # label_ids is None when the dataset carries no labels; treat it the
        # same as a label array filled with the -100 ignore value instead of
        # failing on the tensor conversion.
        labels_unusable = true_labels is None or bool(
            torch.all(torch.tensor(true_labels) == -100)
        )
        if labels_unusable:
            return {"Test labels may be NaN. No metrics calculated.": 0}

        # Compute metrics; tolerate compute_metrics being None or empty
        test_metrics = {}
        for metric_func in self.compute_metrics or ():
            test_metrics.update(metric_func(true_labels, pred_labels))
        return test_metrics

    def save_model(self, path_to_save: str, overwrite: bool = False, **kwargs) -> None:
        """
        Save the trained model.

        Args:
            path_to_save (str): Path to save the model
            overwrite (bool): Whether to overwrite existing files (default:
                False); only forwarded to the model's own ``save`` method
            **kwargs: Additional keyword arguments forwarded to ``model.save``
        """
        # Save using HuggingFace trainer
        self.hf_trainer.save_model(path_to_save)

        # Also save using OmniGenome model's save method if available
        if hasattr(self.model, "save"):
            self.model.save(path_to_save, overwrite, **kwargs)

    def get_model(self, **kwargs) -> torch.nn.Module:
        """
        Get the trained model.

        Args:
            **kwargs: Additional keyword arguments (unused)

        Returns:
            torch.nn.Module: The model held by the HuggingFace trainer
        """
        return self.hf_trainer.model


class HFTrainingArguments(TrainingArguments):
    """
    TrainingArguments subclass tagged with OmniGenome library information.

    Construction is delegated unchanged to HuggingFace ``TrainingArguments``;
    afterwards a ``metadata`` dictionary naming the library and its version is
    attached to the instance.

    Attributes:
        metadata: Dictionary containing OmniGenome library information

    Example:
        >>> training_args = HFTrainingArguments(
        ...     output_dir="./output",
        ...     num_train_epochs=3,
        ...     per_device_train_batch_size=16,
        ... )
        >>> trainer = HFTrainer(model=model, training_args=training_args)
    """

    def __init__(self, *args, **kwargs):
        """
        Build the training arguments and attach library metadata.

        Args:
            *args: Positional arguments forwarded to TrainingArguments
            **kwargs: Keyword arguments forwarded to TrainingArguments
        """
        super().__init__(*args, **kwargs)
        # Attached after parent initialization, so the parent never sees it.
        self.metadata = dict(
            library_name=omnigenbench_name,
            omnigenbench_version=omnigenbench_version,
        )