# Source code for omnigenbench.auto.auto_bench.auto_bench

# -*- coding: utf-8 -*-
# file: auto_bench.py
# time: 11:54 14/04/2024
# author: YANG, HENG <hy345@exeter.ac.uk> (杨恒)
# github: https://github.com/yangheng95
# huggingface: https://huggingface.co/yangheng
# google scholar: https://scholar.google.com/citations?user=NPq5a_0AAAAJ&hl=en
# Copyright (C) 2019-2025. All Rights Reserved.

import os
import time
import warnings

import findfile
import torch
from metric_visualizer import MetricVisualizer

from transformers import TrainingArguments, Trainer as HFTrainer
from ...src.abc.abstract_tokenizer import OmniTokenizer
from ...src.lora.lora_model import OmniLoraModel
from ...src.misc.utils import (
    seed_everything,
    fprint,
    load_module_from_path,
    check_bench_version,
    clean_temp_checkpoint,
)
from ...src.trainer.trainer import Trainer
from ...src.trainer.accelerate_trainer import AccelerateTrainer
from ...src.utility.dataset_hub.dataset_hub import download_benchmark
from ...auto.config.auto_config import AutoConfig
from ... import __version__ as omnigenbench_version


class AutoBench:
    """
    Automated benchmarking framework for evaluating genomic foundation models
    across standardized benchmark suites with reproducible protocols and
    statistical rigor.

    This class orchestrates the complete evaluation pipeline: benchmark dataset
    acquisition, model loading, distributed inference, metric calculation,
    multi-seed averaging, and results visualization. It implements best
    practices for genomic machine learning evaluation, including proper
    cross-validation, ignored label handling, and task-specific metric
    selection.

    **Design Philosophy**:
        AutoBench follows the "Convention over Configuration" principle,
        providing sensible defaults while allowing full customization. By
        default, it uses the ``native`` trainer for single-GPU evaluation
        (optimizing for control and debuggability), while the CLI defaults to
        ``accelerate`` for distributed evaluation (optimizing for throughput).

    **Benchmark Suites Supported**:
        - **RGB**: RNA Genome Benchmarks (12 tasks) - RNA structure and
          function prediction
        - **BEACON**: Broad Evaluation Across Computational geNOmics
          (13 tasks) - Multi-domain RNA
        - **PGB**: Plant Genomics Benchmarks (7 categories) - Plant-specific
          sequence analysis
        - **GUE**: Genomics Understanding Evaluation (36 datasets) - DNA
          general understanding
        - **GB**: Genomics Benchmarks (9 datasets) - Classic DNA
          classification tasks

    **Evaluation Protocol**:
        1. **Dataset Loading**: Automatically downloads benchmark datasets
           from HuggingFace Hub or local cache, validates data format, and
           applies task-specific preprocessing
        2. **Model Initialization**: Loads pre-trained models with proper
           task-specific heads, handling multi-label classification,
           regression, and token-level prediction
        3. **Multi-Seed Evaluation**: Runs independent training/evaluation
           with different random seeds (typically 3-5) to quantify variance
           and ensure statistical significance
        4. **Metric Calculation**: Computes task-appropriate metrics (MCC, F1,
           AUPRC for classification; MSE, Spearman for regression) with proper
           handling of ignored labels
        5. **Result Aggregation**: Calculates mean ± standard deviation across
           seeds, generates visualizations, and serializes results with
           MetricVisualizer

    **Trainer Backend Selection**:
        - ``native`` (Python API default): Pure PyTorch training loop for
          single-GPU evaluation, providing explicit control over training
          dynamics and simplified debugging
        - ``accelerate`` (CLI default): HuggingFace Accelerate for distributed
          evaluation across multiple GPUs, enabling efficient parallel
          inference on large benchmarks
        - ``hf_trainer``: HuggingFace Trainer API integration for users
          familiar with that ecosystem

    Attributes:
        benchmark (str): Local directory of the benchmark suite (the folder
            that contains ``metadata.py``).
        benchmark_name_or_path (str): The benchmark argument as given by the
            caller, with trailing slashes removed.
        is_hub_benchmark (bool): True when ``benchmark_name_or_path`` does not
            exist locally and was therefore downloaded.
        config_or_model (str): HuggingFace Hub identifier or local path to the
            model (or a non-string object passed by the caller).
        model_name (str): Last path component of ``config_or_model``, or the
            class name when a non-string object was given.
        tokenizer: Tokenizer instance for sequence preprocessing. Auto-loaded
            if None.
        autocast (str): Mixed precision mode ('fp16', 'bf16', 'fp32') for
            memory efficiency.
        overwrite (bool): Whether to overwrite existing evaluation results or
            resume from cache.
        trainer (str): Training backend ('native', 'accelerate',
            'hf_trainer').
        cache_dir (str): Cache directory handed to ``download_benchmark``.
        mv_path (str): Path to MetricVisualizer file for result serialization
            and visualization.
        mv (MetricVisualizer): Active visualizer instance for tracking metrics
            across seeds.
        bench_metadata: Benchmark configuration metadata loaded from the
            benchmark's metadata.py.
    """

    def __init__(
        self,
        benchmark,
        config_or_model,
        tokenizer=None,
        **kwargs,
    ):
        """
        Initializes the AutoBench instance.

        Args:
            benchmark (str): The name or path of the benchmark to use. Can be
                a local path or a HuggingFace Hub benchmark name. For hub
                benchmarks, it will be automatically downloaded.
            config_or_model (str): The name or path of the model to evaluate.
            tokenizer: The tokenizer to use. If None, it will be loaded from
                the model path.
            **kwargs: Additional keyword arguments.

                - autocast (str): The autocast precision to use ('fp16',
                  'bf16', etc.). Defaults to 'fp16'.
                - overwrite (bool): Whether to overwrite existing evaluation
                  results. Defaults to False.
                - trainer (str): The trainer to use ('native', 'accelerate',
                  'hf_trainer'). Defaults to 'native'.
                - cache_dir (str): Directory to cache downloaded benchmarks
                  from hub. Defaults to None, which leaves the choice of
                  location to ``download_benchmark``.

        Raises:
            FileNotFoundError: If a downloaded benchmark contains no
                ``metadata.py``.

        Example:
            >>> # Initialize with a local benchmark path
            >>> bench = AutoBench("/path/to/benchmark", "yangheng/OmniGenome-186M")

            >>> # Initialize with a HuggingFace Hub benchmark name (auto-downloads)
            >>> bench = AutoBench("RGB", "yangheng/OmniGenome-186M")

            >>> # Initialize with custom settings
            >>> bench = AutoBench("RGB", "model_name",
            ...                   autocast="bf16", trainer="accelerate")
        """
        self.benchmark_name_or_path = (
            benchmark.rstrip("/") if isinstance(benchmark, str) else benchmark
        )
        self.autocast = kwargs.pop("autocast", "fp16")
        self.overwrite = kwargs.pop("overwrite", False)
        self.trainer = kwargs.pop("trainer", "native")
        self.cache_dir = kwargs.pop("cache_dir", None)

        # Anything that is not an existing local path is treated as a hub name.
        self.is_hub_benchmark = not os.path.exists(self.benchmark_name_or_path)
        if self.is_hub_benchmark:
            fprint(f"Detected HuggingFace Hub benchmark: {self.benchmark_name_or_path}")
            fprint("Downloading benchmark from hub...")
            download_root = download_benchmark(
                self.benchmark_name_or_path,
                cache_dir=self.cache_dir,
                use_hf_api=True,  # Use robust HF Hub API
                force_download=self.overwrite,
            )
            # The benchmark root is wherever metadata.py lives inside the
            # downloaded tree.
            metadata_file = findfile.find_file(download_root, "metadata.py")
            if not metadata_file:
                raise FileNotFoundError(
                    f"No metadata.py found in downloaded benchmark: {download_root}"
                )
            self.benchmark = os.path.dirname(metadata_file)
            fprint(f"Benchmark downloaded to: {self.benchmark}")
        else:
            self.benchmark = self.benchmark_name_or_path
            fprint(f"Using local benchmark: {self.benchmark}")

        self.config_or_model = (
            config_or_model.rstrip("/")
            if isinstance(config_or_model, str)
            else config_or_model
        )
        # Derive the name from the slash-stripped value; otherwise a path such
        # as "models/foo/" yields an empty name that matches every cached .mv.
        self.model_name = self._derive_model_name(config_or_model)
        self.tokenizer = (
            tokenizer.rstrip("/") if isinstance(tokenizer, str) else tokenizer
        )

        os.makedirs("./autobench_evaluations", exist_ok=True)
        time_str = time.strftime("%Y%m%d_%H%M%S", time.localtime())
        # Use benchmark name for mv_name (not full path)
        benchmark_name = os.path.basename(self.benchmark_name_or_path)
        mv_name = f"{benchmark_name}-{self.model_name}"
        self.mv_path = f"./autobench_evaluations/{mv_name}-{time_str}.mv"

        # Resume from a previous record unless the caller asked to overwrite.
        mv_paths = findfile.find_files(
            "./autobench_evaluations",
            and_key=[benchmark_name, self.model_name, ".mv"],
        )
        if mv_paths and not self.overwrite:
            self.mv = MetricVisualizer.load(mv_paths[-1])
            self.mv.summary(round=4)
        else:
            self.mv = MetricVisualizer(self.mv_path)

        # Import benchmark list
        self.bench_metadata = load_module_from_path(
            "bench_metadata", f"{self.benchmark}/metadata.py"
        )
        # Prefer the current version attribute, fall back to the legacy one.
        for version_attr in ("__omnigenbench_version__", "__omnigenome_version__"):
            if hasattr(self.bench_metadata, version_attr):
                bench_version = getattr(self.bench_metadata, version_attr)
                fprint("Benchmark metadata version:", bench_version)
                check_bench_version(bench_version, omnigenbench_version)
                break

        fprint("Loaded benchmarks: ", self.bench_metadata.bench_list)
        self.bench_info()

    @staticmethod
    def _derive_model_name(config_or_model):
        """Return the last path component of a model id, or the class name of an object."""
        if isinstance(config_or_model, str):
            return config_or_model.rstrip("/").split("/")[-1]
        return config_or_model.__class__.__name__

    @staticmethod
    def _select_primary_metric(compute_metrics):
        """Return the first metric function of a list/tuple, or the callable itself."""
        if isinstance(compute_metrics, (list, tuple)):
            return compute_metrics[0]
        return compute_metrics

    @staticmethod
    def _final_test_metrics(metrics):
        """Return the test metrics to log: the last entry of a history list, or the dict itself."""
        test_metrics = metrics["test"]
        if isinstance(test_metrics, (list, tuple)):
            return test_metrics[-1] if test_metrics else {}
        return test_metrics

    @staticmethod
    def _as_loggable(value):
        """Convert a metric value to float when possible; otherwise return it unchanged."""
        try:
            return float(value)
        except (TypeError, ValueError, RuntimeError, OverflowError):
            # Best effort: non-numeric values are logged as they are.
            return value

    @staticmethod
    def _build_dataset(
        dataset_cls,
        data_file,
        tokenizer,
        bench_config,
        max_length,
        shuffle,
        extra_kwargs,
    ):
        """Instantiate one dataset split with the options shared by all splits."""
        return dataset_cls(
            dataset_name_or_path=data_file,
            tokenizer=tokenizer,
            label2id=bench_config["label2id"],
            max_length=max_length,
            structure_in=bench_config.get("structure_in", False),
            max_examples=bench_config.get("max_examples", None),
            shuffle=shuffle,
            drop_long_seq=bench_config.get("drop_long_seq", False),
            **extra_kwargs,
        )

    def _load_task_config(self, bench):
        """Locate and load the AutoConfig instance defined in a task's config module."""
        bench_config_path = findfile.find_file(
            self.benchmark,
            and_key=f"{self.benchmark}.{bench}.config".split("."),
        )
        if not bench_config_path:
            raise FileNotFoundError(
                f"Could not find a config file for task {bench} in {self.benchmark}"
            )
        config_module = load_module_from_path("config", bench_config_path)
        bench_config = None
        # If the module defines several AutoConfig objects the last one (in
        # dir() order) wins.
        for attr_name in dir(config_module):
            candidate = getattr(config_module, attr_name)
            if isinstance(candidate, AutoConfig):
                bench_config = candidate
        if bench_config is None:
            raise ValueError(
                f"Could not find AutoConfig instance in {bench_config_path}"
            )
        fprint(f"Loaded config for {bench} from {bench_config_path}")
        fprint(bench_config)
        return bench_config

    def bench_info(self):
        """
        Prints and returns information about the current benchmark setup.

        Returns:
            str: A string containing benchmark information.

        Example:
            >>> info = bench.bench_info()
            >>> print(info)
        """
        info = "".join(
            (
                f"Benchmark Root: {self.benchmark}\n",
                f"Benchmark List: {self.bench_metadata.bench_list}\n",
                f"Model Name or Path: {self.model_name}\n",
                f"Tokenizer: {self.tokenizer}\n",
                f"Metric Visualizer Path: {self.mv_path}\n",
                f"BenchConfig Details: {self.bench_metadata}\n",
            )
        )
        fprint(info)
        return info

    def run(self, **kwargs):
        """
        Runs the benchmarking process.

        This method iterates through the tasks in the benchmark, loads the
        corresponding configurations, initializes the model, tokenizer, and
        datasets, and then trains and evaluates the model once per seed. Test
        metrics of every finished seed are logged to ``self.mv`` and written
        to ``self.mv_path`` (plus a ``.csv`` sibling). A task whose record
        already holds at least as many values as there are seeds is skipped.

        Args:
            **kwargs: Additional keyword arguments that will override the
                default parameters in the benchmark configuration. Keys with
                special handling:

                - bs_scale (int): Multiplier applied to the configured batch
                  size. Defaults to 1.
                - lora_config (dict): Keyword arguments for ``OmniLoraModel``.
                - lora (bool): LoRA is applied unless this is False and no
                  ``lora_config`` is given. Defaults to True.

        Raises:
            ValueError: If a task config holds no ``AutoConfig`` instance or
                ``config_or_model`` is empty.
            FileNotFoundError: If a task's config file cannot be located.

        Example:
            >>> # Run benchmarking with default settings
            >>> bench.run()

            >>> # Run with custom parameters
            >>> bench.run(learning_rate=1e-4, batch_size=16)
        """
        bs_scale = kwargs.pop("bs_scale", 1)
        bench_list = self.bench_metadata.bench_list
        total = len(bench_list)

        for idx, bench in enumerate(bench_list):
            _kwargs = kwargs.copy()
            clean_temp_checkpoint(1)  # clean temp checkpoint older than 1 day
            fprint(
                ">" * 80,
                f"\nRunning evaluation for task: {bench}",
                "Progress: ",
                idx + 1,
                "/",
                total,
                f"{(idx + 1) * 100 / total}%",
            )
            bench_config = self._load_task_config(bench)

            # Init Tokenizer and Model
            if not self.tokenizer:
                tokenizer = OmniTokenizer.from_pretrained(
                    self.config_or_model,
                    trust_remote_code=bench_config.get("trust_remote_code", True),
                    **bench_config,
                )
            else:
                tokenizer = self.tokenizer

            # Caller kwargs always win over the task config; unknown keys are
            # still stored but flagged with a warning.
            for key, value in _kwargs.items():
                if key in bench_config:
                    fprint(
                        "Override",
                        key,
                        "with",
                        value,
                        "according to the input kwargs",
                    )
                else:
                    warnings.warn(
                        f"kwarg: {key} not found in bench_config while setting {key} = {value}"
                    )
                bench_config.update({key: value})

            # Everything now stored in bench_config must not be forwarded a
            # second time through **_kwargs.
            for key, _value in bench_config.items():
                _kwargs.pop(key, None)

            if not isinstance(bench_config["seeds"], list):
                bench_config["seeds"] = [bench_config["seeds"]]
            random_seeds = bench_config["seeds"]

            # Both values are identical for every seed of this task.
            batch_size = (
                bench_config["batch_size"] if "batch_size" in bench_config else 8
            ) * bs_scale
            record_name = f"{self.benchmark}-{bench}-{self.model_name}".split("/")[-1]

            for seed in random_seeds:
                # check if the record exists (re-read each time because the
                # record grows after every finished seed)
                records = self.mv.transpose()
                if record_name in records and len(
                    list(records[record_name].values())[0]
                ) >= len(random_seeds):
                    continue

                seed_everything(seed)
                if not self.config_or_model:
                    raise ValueError(
                        "config_or_model is not specified. Please provide a valid model name or path."
                    )
                model_cls = bench_config["model_cls"]
                model = model_cls(
                    self.config_or_model,
                    tokenizer=tokenizer,
                    label2id=bench_config.label2id,
                    num_labels=bench_config["num_labels"],
                    trust_remote_code=True,
                    ignore_mismatched_sizes=True,
                )
                fprint(f"\n{model}")

                # `or {}` also covers an explicit lora_config=None.
                lora_config = kwargs.get("lora_config") or {}
                if lora_config or kwargs.get("lora", True):
                    fprint(
                        "Applying LoRA to the model with config:",
                        lora_config or "Default Config",
                    )
                    model = OmniLoraModel(model, **lora_config)

                # Init Trainer
                dataset_cls = bench_config["dataset_cls"]
                max_length = bench_config["max_length"]
                if hasattr(model.config, "max_position_embeddings"):
                    # Never exceed what the model's position embeddings allow.
                    max_length = min(
                        max_length, model.config.max_position_embeddings
                    )

                train_set = self._build_dataset(
                    dataset_cls,
                    bench_config["train_file"],
                    tokenizer,
                    bench_config,
                    max_length,
                    bench_config.get("shuffle", True),
                    _kwargs,
                )
                test_set = self._build_dataset(
                    dataset_cls,
                    bench_config["test_file"],
                    tokenizer,
                    bench_config,
                    max_length,
                    False,
                    _kwargs,
                )
                if "valid_file" in bench_config and bench_config["valid_file"]:
                    valid_set = self._build_dataset(
                        dataset_cls,
                        bench_config["valid_file"],
                        tokenizer,
                        bench_config,
                        max_length,
                        False,
                        _kwargs,
                    )
                else:
                    valid_set = None

                # Only the native/accelerate path creates an optimizer; bind
                # the name here so the cleanup below works for every backend.
                optimizer = None

                if self.trainer == "hf_trainer":
                    # Set up HuggingFace Trainer
                    hf_kwargs = {
                        k: v
                        for k, v in kwargs.items()
                        if hasattr(TrainingArguments, k) and k != "output_dir"
                    }
                    # One value for both train and eval batch size.
                    per_device_batch_size = hf_kwargs.pop("batch_size", batch_size)
                    training_args = TrainingArguments(
                        output_dir=f"autobench_evaluations/{self.model_name}-{bench}",
                        num_train_epochs=hf_kwargs.pop(
                            "num_train_epochs", bench_config["epochs"]
                        ),
                        per_device_train_batch_size=per_device_batch_size,
                        per_device_eval_batch_size=per_device_batch_size,
                        gradient_accumulation_steps=hf_kwargs.pop(
                            "gradient_accumulation_steps", 1
                        ),
                        learning_rate=hf_kwargs.pop("learning_rate", 2e-5),
                        weight_decay=hf_kwargs.pop("weight_decay", 0),
                        eval_strategy=hf_kwargs.pop("eval_strategy", "epoch"),
                        save_strategy=hf_kwargs.pop("save_strategy", "epoch"),
                        fp16=hf_kwargs.pop("fp16", True),
                        remove_unused_columns=False,
                        label_names=["labels"],
                        **hf_kwargs,
                    )
                    # Fall back to the test split when there is no (or an
                    # empty) validation split.
                    if valid_set is None or len(valid_set) == 0:
                        valid_set = test_set

                    compute_metrics = bench_config["compute_metrics"]
                    if (
                        isinstance(compute_metrics, (list, tuple))
                        and len(compute_metrics) > 1
                    ):
                        fprint(
                            "Multiple metrics not supported by HFTrainer, using the first one metric only."
                        )
                    trainer = HFTrainer(
                        model=model,
                        args=training_args,
                        train_dataset=train_set,
                        eval_dataset=valid_set,
                        compute_metrics=self._select_primary_metric(compute_metrics),
                    )

                    # Train and evaluate (the first evaluation is the
                    # pre-training baseline)
                    eval_result = trainer.evaluate(valid_set)
                    print(eval_result)
                    train_result = trainer.train()
                    eval_result = trainer.evaluate()
                    test_result = trainer.evaluate(
                        test_set if len(test_set) else valid_set
                    )
                    metrics = {
                        "train": train_result.metrics,
                        "eval": eval_result,
                        "test": test_result,
                    }
                    fprint(metrics)
                else:
                    optimizer = torch.optim.AdamW(
                        filter(lambda p: p.requires_grad, model.parameters()),
                        lr=(
                            bench_config["learning_rate"]
                            if "learning_rate" in bench_config
                            else 2e-5
                        ),
                        weight_decay=(
                            bench_config["weight_decay"]
                            if "weight_decay" in bench_config
                            else 0
                        ),
                    )
                    if self.trainer == "accelerate":
                        trainer_cls = AccelerateTrainer
                    else:
                        trainer_cls = Trainer
                    fprint(f"Using Trainer: {trainer_cls}")
                    trainer = trainer_cls(
                        model=model,
                        train_dataset=train_set,
                        eval_dataset=valid_set,
                        test_dataset=test_set,
                        batch_size=batch_size,
                        patience=(
                            bench_config["patience"]
                            if "patience" in bench_config
                            else 3
                        ),
                        epochs=bench_config["epochs"],
                        gradient_accumulation_steps=bench_config.get(
                            "gradient_accumulation_steps", 1
                        ),
                        optimizer=optimizer,
                        loss_fn=(
                            bench_config["loss_fn"]
                            if "loss_fn" in bench_config
                            else None
                        ),
                        compute_metrics=bench_config["compute_metrics"],
                        seed=seed,
                        autocast=self.autocast,
                        **_kwargs,
                    )
                    metrics = trainer.train()
                    predictions = trainer.predictions

                    if bench_config.get("save_predictions", False):
                        os.makedirs(f"predictions/{bench}", exist_ok=True)
                        import numpy as np

                        for split in predictions.keys():
                            with open(
                                f"predictions/{bench}/{split}.npy",
                                "wb",
                            ) as f:
                                np.save(f, predictions[split])

                if metrics:
                    # Native trainers report a per-evaluation history (last
                    # entry is final); the HF path reports one dict.
                    for key, value in self._final_test_metrics(metrics).items():
                        self.mv.log(
                            f"{record_name}", f"{key}", self._as_loggable(value)
                        )
                    self.mv.summary(round=4)
                    self.mv.dump(self.mv_path)
                    self.mv.to_csv(self.mv_path.replace(".mv", ".csv"))

                # Drop references before the next seed so GPU memory can be
                # reclaimed.
                del model, trainer, optimizer
                torch.cuda.empty_cache()