Abstract Classes

OmniGenBench provides a set of abstract base classes that define the core interfaces for datasets, models, metrics, and tokenizers. These abstract classes are designed to be subclassed, allowing users to implement custom logic for new data formats, model architectures, evaluation metrics, or sequence representations.

How to Use Abstract Classes:

  • Start by exploring the abstract base classes for datasets, models, metrics, and tokenizers.

  • To add new functionality, subclass the relevant abstract class and implement the required methods.

  • The package uses these abstract classes as the foundation for all built-in and user-extended components, ensuring consistency and interoperability.

Main Abstract Classes:

  • OmniDataset: Base class for datasets. Subclass to support new data formats or preprocessing logic.

  • OmniModel: Base class for models. Subclass to implement custom architectures or tasks.

  • OmniMetric: Base class for evaluation metrics. Subclass to define new metrics for benchmarking.

  • OmniTokenizer: Base class for tokenizers. Subclass to support new sequence representations.

Refer to the API documentation below for details on each abstract class, including their methods and usage examples.
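
The subclassing workflow can be illustrated with a self-contained sketch. `DatasetBase` below is a hypothetical stand-in for an abstract base such as OmniDataset; the real class is not imported here. It mirrors the documented contract that prepare_input() must be implemented by subclasses and raises NotImplementedError otherwise. `UpperCaseDataset` and its character-code "tokenization" are purely illustrative.

```python
class DatasetBase:
    """Hypothetical stand-in for an abstract base class such as OmniDataset."""

    def __init__(self, records, max_length=None):
        self.records = list(records)
        self.max_length = max_length

    def prepare_input(self, instance, **kwargs):
        # The base class only defines the interface; subclasses supply the logic.
        raise NotImplementedError("prepare_input() must be implemented by subclasses")

    def __len__(self):
        return len(self.records)

    def __getitem__(self, idx):
        # Shared infrastructure: every item is routed through the subclass hook.
        return self.prepare_input(self.records[idx])


class UpperCaseDataset(DatasetBase):
    """Illustrative subclass: normalizes case and 'tokenizes' by character code."""

    def prepare_input(self, instance, **kwargs):
        seq = instance["sequence"].upper()
        if self.max_length is not None:
            seq = seq[: self.max_length]
        return {"input_ids": [ord(ch) for ch in seq], "labels": instance.get("label")}


ds = UpperCaseDataset([{"sequence": "acgu", "label": 1}], max_length=3)
print(ds[0])  # {'input_ids': [65, 67, 71], 'labels': 1}
```

The same pattern applies to the model, metric, and tokenizer bases: shared behavior lives in the base class, and only the task-specific hooks are overridden.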

OmniModel

class omnigenbench.src.abc.abstract_model.OmniModel(config_or_model, tokenizer, *args, **kwargs)

Bases: EmbeddingMixin, Module

Abstract base class providing a unified interface for genomic foundation models in the OmniGenBench framework. This class handles model initialization, forward passes, loss computation, prediction interfaces, and model persistence while maintaining compatibility with HuggingFace’s ecosystem.

Architectural Pattern: This class follows the Template Method pattern, providing common infrastructure while delegating task-specific behavior to subclasses (OmniModelForSequenceClassification, OmniModelForTokenClassification, etc.).

Inherited Capabilities (via EmbeddingMixin):

  • Embedding Generation: batch_encode(), encode(), encode_tokens() for extracting fixed-length sequence representations from genomic sequences

  • Attention Extraction: extract_attention_scores(), batch_extract_attention_scores() for model interpretability and attention weight visualization

  • Similarity Computation: compute_similarity() for sequence comparison and relatedness analysis

  • Visualization Tools: visualize_attention_pattern() for generating attention heatmaps and understanding model focus

All task-specific OmniModel subclasses automatically inherit these capabilities, enabling representation learning and interpretability without additional implementation.
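
This section does not state which similarity measure compute_similarity() uses. Cosine similarity is a common choice for comparing pooled sequence embeddings, so the following dependency-free sketch assumes it. `cosine_similarity` is a hypothetical helper, not part of the OmniGenBench API, and the short vectors stand in for real encode() outputs.

```python
import math

def cosine_similarity(a, b):
    """Cosine similarity between two equal-length embedding vectors."""
    if len(a) != len(b):
        raise ValueError("embeddings must have the same dimensionality")
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0 or norm_b == 0:
        raise ValueError("cannot compare a zero vector")
    return dot / (norm_a * norm_b)

# Toy 2-dimensional "embeddings" standing in for encode() outputs.
print(cosine_similarity([3.0, 4.0], [3.0, 4.0]))  # 1.0 (identical direction)
print(cosine_similarity([1.0, 0.0], [0.0, 1.0]))  # 0.0 (orthogonal)
```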

Design Philosophy: By inheriting from both EmbeddingMixin and torch.nn.Module, this class seamlessly integrates sequence embedding capabilities with PyTorch’s standard training infrastructure, making it compatible with native PyTorch training loops, HuggingFace Trainer, and Accelerate-based distributed training.

Task-Specific Subclasses: Users should instantiate concrete implementations rather than this abstract class directly:

  • OmniModelForSequenceClassification: Sequence-level classification tasks (e.g., promoter identification, functional annotation)

  • OmniModelForMultiLabelSequenceClassification: Multi-label classification (e.g., transcription factor binding and chromatin profile prediction over 919 targets)

  • OmniModelForTokenClassification: Per-nucleotide predictions (e.g., splice site detection, secondary structure annotation)

  • OmniModelForSequenceRegression: Sequence-level continuous predictions (e.g., gene expression levels, binding affinity scores)

  • OmniModelForTokenRegression: Per-nucleotide continuous predictions (e.g., chromatin accessibility profiles, conservation scores)

  • OmniModelForRNADesign: Structure-guided RNA sequence generation (genetic algorithm + masked language model)

  • OmniModelForEmbedding: Representation learning and feature extraction

property device

Returns the actual device of the model parameters rather than a cached value.

static from_pretrained(config_or_model, tokenizer, *args, **kwargs)

Loads a pre-trained model and tokenizer.

Parameters:
  • config_or_model – The name or path of the pre-trained model.

  • tokenizer – The tokenizer to use.

  • args – Additional positional arguments.

  • kwargs – Additional keyword arguments.

Returns:

An instance of OmniModel.

inference(sequence_or_inputs, **kwargs)

This method takes raw sequences or tokenized inputs and returns post-processed, human-readable predictions. Post-processing typically converts logits into class labels or probabilities.

If the model has a dataset_class, this method will use the dataset’s prepare_input method for data preprocessing, allowing for more complex data preparation including custom field handling.

Parameters:
  • sequence_or_inputs – Can be one of:

      • str: A single sequence (e.g., “ATCGATCG”)

      • list: A list of sequences (e.g., [“ATCGATCG”, “GCTAGCTA”])

      • dict: A dictionary with ‘sequence’/’seq’, optionally ‘label’/’labels’, and other custom fields (e.g., {“sequence”: “ATCG”, “label”: 1})

      • BatchEncoding/dict with ‘input_ids’: Already tokenized inputs

  • **kwargs – Additional arguments for tokenization, dataset preparation, or inference. Common options include ‘max_length’, ‘padding’, ‘truncation’, etc.

Returns:

dict – A dictionary containing the processed predictions, typically including ‘predictions’, ‘confidence’, and other human-readable outputs.

Example

>>> # Inference on a single sequence
>>> results = model.inference("ATCGATCG")
>>> print(results['predictions'])  # Class labels
>>>
>>> # Inference on multiple sequences
>>> results = model.inference(["ATCGATCG", "GCTAGCTA"])
>>>
>>> # Inference with dict input (if dataset_class is set)
>>> results = model.inference({"sequence": "ATCGATCG", "label": 1})
>>> print(results['predictions'])
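
The post-processing that inference() performs on top of predict() can be sketched in plain Python for a single-label classifier. This assumes softmax probabilities and argmax labels; `postprocess_logits` is a hypothetical helper, and its output keys only mirror the ‘predictions’/‘confidence’ keys mentioned above.

```python
import math

def softmax(logits):
    # Subtract the max logit for numerical stability before exponentiating.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def postprocess_logits(batch_logits, id2label=None):
    """Convert per-sequence logits into predicted labels and confidence scores."""
    predictions, confidence = [], []
    for logits in batch_logits:
        probs = softmax(logits)
        best = max(range(len(probs)), key=probs.__getitem__)
        predictions.append(id2label[best] if id2label else best)
        confidence.append(probs[best])
    return {"predictions": predictions, "confidence": confidence}

out = postprocess_logits([[0.1, 2.0], [3.0, -1.0]], id2label={0: "negative", 1: "positive"})
print(out["predictions"])  # ['positive', 'negative']
```
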

last_hidden_state_forward(**inputs)

Performs a forward pass to get the last hidden state from the base model. It also handles compatibility with different model architectures by mapping input parameters appropriately.

Parameters:

**inputs – The inputs to the model, compatible with the base model’s forward method. Typically includes ‘input_ids’, ‘attention_mask’, and other model-specific parameters.

Returns:

torch.Tensor – The last hidden state tensor.

Example

>>> import torch
>>> inputs = {
...     'input_ids': torch.tensor([[1, 2, 3, 4]]),
...     'attention_mask': torch.tensor([[1, 1, 1, 1]])
... }
>>> hidden_states = model.last_hidden_state_forward(**inputs)
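
last_hidden_state_forward() is described as mapping inputs onto whatever the base model's forward() accepts. One common way to do this, shown here as an assumption rather than the library's actual logic, is to inspect the target signature and drop unsupported keys. `filter_forward_kwargs` and `toy_forward` are hypothetical.

```python
import inspect

def filter_forward_kwargs(forward, inputs):
    """Keep only the keyword arguments that `forward` can accept."""
    params = inspect.signature(forward).parameters
    # If forward() takes **kwargs, everything can be passed through unchanged.
    if any(p.kind == inspect.Parameter.VAR_KEYWORD for p in params.values()):
        return dict(inputs)
    return {k: v for k, v in inputs.items() if k in params}

def toy_forward(input_ids, attention_mask=None):
    # Stand-in for a base model whose forward() does not accept token_type_ids.
    return len(input_ids)

inputs = {"input_ids": [1, 2, 3, 4], "attention_mask": [1, 1, 1, 1], "token_type_ids": [0, 0, 0, 0]}
print(sorted(filter_forward_kwargs(toy_forward, inputs)))  # ['attention_mask', 'input_ids']
```
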

load(path, **kwargs)

Load a complete model from a saved directory.

This method performs a comprehensive load operation that restores:

  • Model metadata and configuration

  • Loss function (if saved)

  • Model weights with validation

  • Tokenizer

The load process follows a six-step workflow:

  1. Load and validate metadata

  2. Load model configuration

  3. Restore dataset class (if available)

  4. Restore loss function (if available)

  5. Load model weights with validation

  6. Load tokenizer

Parameters:
  • path (str) – Directory path containing the saved model.

  • **kwargs – Additional keyword arguments passed to the loading functions. Common options include:

      • device (str): Device to load the model to (e.g., ‘cpu’, ‘cuda:0’)

      • trust_remote_code (bool): Whether to trust custom code

Returns:

OmniModel – The loaded model instance (self).

Raises:
  • ValueError – If saved model class doesn’t match current class.

  • FileNotFoundError – If required files (metadata.json, pytorch_model.bin) are missing.

Example

>>> # Basic load
>>> loaded_model = model.load("checkpoint")
>>> # Load to specific device
>>> loaded_model = model.load("checkpoint", device="cuda:0")
>>> # Load with custom code trust
>>> loaded_model = model.load("checkpoint", trust_remote_code=True)

Side Effects:
  • Updates all model attributes (weights, config, tokenizer, loss_fn)

  • Prints warnings for configuration differences

  • Prints warnings for weight loading issues (missing/unexpected keys)

  • Prints confirmation messages for restored components

Note

  • Model class must match the saved model class

  • Partial loading is supported (missing keys will be randomly initialized)

  • Loss function restoration is optional (warnings only if fails)

  • Custom models require trust_remote_code=True (enabled by default)
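
The first step of the load workflow (checking required files and validating metadata) can be sketched with the standard library. This assumes metadata.json stores the saved class name under a hypothetical `model_cls` key; the real file layout may differ. The sketch raises the two documented exceptions.

```python
import json
import os
import tempfile

def validate_checkpoint(path, expected_cls_name):
    """Check required files and the saved model class before loading weights."""
    for fname in ("metadata.json", "pytorch_model.bin"):
        if not os.path.exists(os.path.join(path, fname)):
            raise FileNotFoundError(f"{fname} not found in {path}")
    with open(os.path.join(path, "metadata.json"), encoding="utf-8") as f:
        metadata = json.load(f)
    saved_cls = metadata.get("model_cls")  # hypothetical key name
    if saved_cls != expected_cls_name:
        raise ValueError(f"checkpoint holds {saved_cls}, not {expected_cls_name}")
    return metadata

# Build a throwaway checkpoint directory to exercise the check.
ckpt = tempfile.mkdtemp()
with open(os.path.join(ckpt, "metadata.json"), "w", encoding="utf-8") as f:
    json.dump({"model_cls": "OmniModelForSequenceClassification"}, f)
open(os.path.join(ckpt, "pytorch_model.bin"), "wb").close()

print(validate_checkpoint(ckpt, "OmniModelForSequenceClassification")["model_cls"])
# OmniModelForSequenceClassification
```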

loss_function(logits, labels)

Calculates the loss. This method should be implemented by concrete model classes to define how the loss is calculated for their specific task (classification, regression, etc.).

Parameters:
  • logits (torch.Tensor) – The model’s output logits.

  • labels (torch.Tensor) – The ground truth labels.

Returns:

torch.Tensor – The calculated loss.

Raises:

NotImplementedError – If the method is not implemented by the subclass.

Example

>>> # In a classification model
>>> loss = model.loss_function(logits, labels)

model_info()

Prints and returns detailed information about the model.

Returns:

str – A string containing the model information.

predict(sequence_or_inputs, **kwargs)

This method takes raw sequences or tokenized inputs and returns the raw model outputs (logits, hidden states, etc.) without post-processing. It’s useful for getting the model’s direct predictions for further processing.

If the model has a dataset_class, this method will use the dataset’s prepare_input method for data preprocessing, allowing for more complex data preparation including custom field handling.

Parameters:
  • sequence_or_inputs – Can be one of:

      • str: A single sequence (e.g., “ATCGATCG”)

      • list: A list of sequences (e.g., [“ATCGATCG”, “GCTAGCTA”])

      • dict: A dictionary with ‘sequence’/’seq’, optionally ‘label’/’labels’, and other custom fields (e.g., {“sequence”: “ATCG”, “label”: 1})

      • BatchEncoding/dict with ‘input_ids’: Already tokenized inputs

  • **kwargs – Additional arguments for tokenization or dataset preparation. Common options include ‘max_length’, ‘padding’, ‘truncation’, etc.

Returns:

dict – A dictionary containing the raw model outputs, typically including logits, last_hidden_state, and other model-specific outputs.

Example

>>> # Predict on a single sequence
>>> outputs = model.predict("ATCGATCG")
>>>
>>> # Predict on multiple sequences
>>> outputs = model.predict(["ATCGATCG", "GCTAGCTA"])
>>>
>>> # Predict with dict input (if dataset_class is set)
>>> outputs = model.predict({"sequence": "ATCGATCG", "label": 1})
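
predict() and inference() accept several input shapes. The normalization step can be sketched as follows. `normalize_inputs` is a hypothetical helper, and the handling of the ‘seq’ alias follows the parameter description above rather than the library's source. `Mapping` is used instead of `dict` because HuggingFace's BatchEncoding is a UserDict, not a dict subclass.

```python
from collections.abc import Mapping

def normalize_inputs(sequence_or_inputs):
    """Return (records, pretokenized) for the accepted input shapes."""
    if isinstance(sequence_or_inputs, str):
        return [{"sequence": sequence_or_inputs}], False
    if isinstance(sequence_or_inputs, list):
        return [{"sequence": s} for s in sequence_or_inputs], False
    if isinstance(sequence_or_inputs, Mapping):
        if "input_ids" in sequence_or_inputs:
            # Already tokenized (dict or BatchEncoding-like mapping): pass through.
            return sequence_or_inputs, True
        record = dict(sequence_or_inputs)
        if "sequence" not in record and "seq" in record:
            record["sequence"] = record.pop("seq")
        return [record], False
    raise TypeError(f"unsupported input type: {type(sequence_or_inputs).__name__}")

print(normalize_inputs({"seq": "ATCG", "label": 1}))  # ([{'label': 1, 'sequence': 'ATCG'}], False)
```
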

save(path, overwrite=False, dtype=torch.float16, **kwargs)

Save the complete model, tokenizer, and metadata to a directory.

This method performs a comprehensive save operation that includes:

  • Base model configuration files

  • Custom model and dataset class source code (if applicable)

  • Complete metadata, including the loss function and custom attributes

  • Model weights and tokenizer

The save process follows a six-step workflow:

  1. Save base configuration files

  2. Collect comprehensive metadata

  3. Save custom model class source (if user-defined)

  4. Save custom dataset class source (if available)

  5. Save metadata to a JSON file

  6. Save model weights and tokenizer

Parameters:
  • path (str) – Target directory path for saving the model.

  • overwrite (bool, optional) – Whether to overwrite existing directory. If False, a timestamp will be appended to path. Defaults to False.

  • dtype (torch.dtype, optional) – Data type for saving model weights. Model is temporarily converted to this dtype during saving. Defaults to torch.float16.

  • **kwargs – Additional keyword arguments (reserved for future use).

Side Effects:
  • Creates target directory if it doesn’t exist

  • Temporarily moves model to CPU and converts dtype

  • Restores original device and dtype after saving

  • Prints confirmation message on completion

Example

>>> # Basic save
>>> model.save("my_model")
>>> # Save with overwrite
>>> model.save("my_model", overwrite=True)
>>> # Save with specific dtype
>>> model.save("my_model", dtype=torch.float32)

Note

  • Model is set to eval mode before saving

  • Original device and dtype are preserved

  • Custom classes are only saved if defined outside framework packages

  • Failures in non-critical steps (e.g., custom class saving) are logged but don’t stop the process
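
The overwrite=False behavior (appending a timestamp to the target path) can be sketched as below. The timestamp format, the use of UTC, and the `_` separator are assumptions; the documentation only states that a timestamp is appended. `resolve_save_path` is hypothetical.

```python
import time

def resolve_save_path(path, overwrite=False, now=None):
    """Return the directory to save into; append a UTC timestamp unless overwriting."""
    if overwrite:
        return path
    # time.gmtime(None) uses the current time; a fixed `now` makes output reproducible.
    stamp = time.strftime("%Y%m%d_%H%M%S", time.gmtime(now))
    return f"{path}_{stamp}"

print(resolve_save_path("my_model", overwrite=True))  # my_model
print(resolve_save_path("my_model", now=0))          # my_model_19700101_000000
```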

set_loss_fn(loss_function)

Sets a custom loss function for the model. The loss function should be compatible with the model’s output format.

Parameters:

loss_function (callable) – A callable loss function that takes logits and labels as arguments.

Example

>>> import torch.nn as nn
>>> model.set_loss_fn(nn.CrossEntropyLoss())

to(*args, **kwargs)

Moves the model to the specified device and/or dtype and keeps the internal device tracking in sync.

omnigenbench.src.abc.abstract_model.count_parameters(model)

This function iterates through all parameters of a PyTorch model and counts only those that require gradients (i.e., trainable parameters).

Parameters:

model (torch.nn.Module) – A PyTorch model.

Returns:

int – The total number of trainable parameters.

Example

>>> model = OmniModelForSequenceClassification(config, tokenizer)
>>> num_params = count_parameters(model)
>>> print(f"Model has {num_params} trainable parameters")
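
The counting rule above is the usual PyTorch idiom of summing numel() over parameters with requires_grad=True. The sketch below is duck-typed so it runs without PyTorch; with a real torch.nn.Module the same function should work unchanged, since its parameters expose numel() and requires_grad. `count_trainable_parameters`, `FakeParam`, and `FakeModel` are hypothetical names used for illustration only.

```python
def count_trainable_parameters(model):
    """Sum element counts of parameters that require gradients."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

class FakeParam:
    """Test double exposing the two attributes the counter relies on."""
    def __init__(self, shape, requires_grad=True):
        self.shape, self.requires_grad = shape, requires_grad

    def numel(self):
        n = 1
        for dim in self.shape:
            n *= dim
        return n

class FakeModel:
    def parameters(self):
        # A 4x8 weight and its bias (trainable) plus a frozen 10x8 embedding table.
        return [FakeParam((4, 8)), FakeParam((4,)), FakeParam((10, 8), requires_grad=False)]

print(count_trainable_parameters(FakeModel()))  # 36
```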

OmniDataset

class omnigenbench.src.abc.abstract_dataset.OmniDataset(dataset_name_or_path=None, tokenizer=None, max_length=None, **kwargs)

Bases: Dataset

Abstract base class providing a unified interface for genomic datasets in the OmniGenBench framework. This class handles polymorphic data loading from multiple formats, integrated tokenization, label management, and PyTorch DataLoader compatibility.

Design Pattern: This class implements the Strategy pattern for format-specific parsing while maintaining a consistent API. Different file formats (JSON, CSV, FASTA, Parquet, etc.) are handled transparently through pluggable loaders, with tokenization and preprocessing applied uniformly regardless of input format.

Key Features:

  • Format Agnosticism: Supports JSON, CSV, Parquet, FASTA, FASTQ, BED, VCF, and NumPy formats through auto-detection based on file extension. Custom formats can be added by subclassing and implementing format-specific loaders.

  • Integrated Tokenization: Sequences are tokenized within the dataset pipeline for consistency and efficient caching. Tokenization parameters (max_length, padding, truncation) are configured at dataset initialization.

  • Lazy Loading: Large datasets are loaded incrementally to minimize memory footprint. Data is read into memory on-demand during training/inference rather than all at once.

  • Label Management: Automatic bidirectional mapping between string labels and integer indices (label2id/id2label), with support for multi-label scenarios and PyTorch’s -100 ignore convention for masked tokens.

  • RNA Structure Integration: Optional secondary structure prediction via ViennaRNA for structure-aware models. Structures are cached to avoid redundant computation.

  • Sequence Filtering: Optional filtering of sequences exceeding max_length via drop_long_seq parameter, useful for maintaining fixed-length batches without truncation.
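
The difference between truncation (the default) and drop_long_seq=True can be shown on raw strings. This sketch measures length in characters, which is a simplification: in practice the limit may apply to the tokenized length, including special tokens. `apply_length_policy` is hypothetical.

```python
def apply_length_policy(sequences, max_length, drop_long_seq=False):
    """Truncate over-long sequences, or drop them when drop_long_seq is True."""
    if drop_long_seq:
        return [s for s in sequences if len(s) <= max_length]
    return [s[:max_length] for s in sequences]

seqs = ["ATCG", "ATCGATCGAT", "GG"]
print(apply_length_policy(seqs, 6))                      # ['ATCG', 'ATCGAT', 'GG']
print(apply_length_policy(seqs, 6, drop_long_seq=True))  # ['ATCG', 'GG']
```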

Data Format Convention: All input files must contain at minimum a sequence field (or one of its aliases: seq, text, dna, rna). For supervised tasks, a label field (or aliases: labels, target, y) is also required. Additional custom fields are preserved and passed through the pipeline.
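
The field-alias convention can be expressed as a small normalization step applied to each raw record. This sketch assumes the aliases are checked in the order listed above and the first match wins; `normalize_record` is hypothetical.

```python
SEQUENCE_ALIASES = ("sequence", "seq", "text", "dna", "rna")
LABEL_ALIASES = ("label", "labels", "target", "y")

def normalize_record(record, require_label=False):
    """Rename alias fields to 'sequence'/'label' and keep all other fields."""
    out = dict(record)
    seq_key = next((k for k in SEQUENCE_ALIASES if k in out), None)
    if seq_key is None:
        raise KeyError(f"no sequence field among {SEQUENCE_ALIASES}")
    out["sequence"] = out.pop(seq_key)
    label_key = next((k for k in LABEL_ALIASES if k in out), None)
    if label_key is not None:
        out["label"] = out.pop(label_key)
    elif require_label:
        raise KeyError(f"no label field among {LABEL_ALIASES}")
    return out

print(normalize_record({"dna": "ATCG", "y": 1, "species": "human"}))
# {'species': 'human', 'sequence': 'ATCG', 'label': 1}
```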

Supported File Formats:

  • JSON: Line-delimited JSON (.json, .jsonl) with one record per line

  • CSV/TSV: Comma or tab-separated values (.csv, .tsv) with header row

  • Parquet: Apache Parquet format (.parquet) for efficient columnar storage

  • FASTA: Biological sequence format (.fasta, .fa) with optional metadata in headers

  • FASTQ: Sequencing format (.fastq, .fq) with quality scores (quality scores ignored)

  • BED: Genomic interval format (.bed) for position-based features

  • VCF: Variant Call Format (.vcf) for genetic variants (experimental)

  • NumPy: NumPy arrays (.npy, .npz) for pre-computed features
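
Format auto-detection by file extension can be sketched as a lookup table covering the extensions listed above. The returned format names are hypothetical labels, and compressed files (e.g., .gz) are not handled in this sketch.

```python
from pathlib import Path

EXTENSION_TO_FORMAT = {
    ".json": "json", ".jsonl": "json",
    ".csv": "csv", ".tsv": "tsv",
    ".parquet": "parquet",
    ".fasta": "fasta", ".fa": "fasta",
    ".fastq": "fastq", ".fq": "fastq",
    ".bed": "bed",
    ".vcf": "vcf",
    ".npy": "numpy", ".npz": "numpy",
}

def detect_format(path):
    """Map a file path to a loader name based on its (case-insensitive) suffix."""
    suffix = Path(path).suffix.lower()
    try:
        return EXTENSION_TO_FORMAT[suffix]
    except KeyError:
        raise ValueError(f"unsupported file extension: {suffix or '(none)'}") from None

print(detect_format("data/train.JSONL"))  # json
print(detect_format("reads.fq"))          # fastq
```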

Variables:
  • tokenizer – Tokenizer instance for sequence encoding. Must be compatible with the model architecture being used. Can be OmniTokenizer or HuggingFace tokenizer.

  • max_length (int) – Maximum sequence length for tokenization. Sequences exceeding this length are truncated (default) or dropped (if drop_long_seq=True).

  • label2id (dict) – Mapping from string labels to integer indices. Automatically populated during data loading if not provided. Example: {“negative”: 0, “positive”: 1}.

  • id2label (dict) – Inverse mapping from integer indices to string labels. Automatically generated from label2id.

  • shuffle (bool) – Whether to shuffle dataset order on initialization. Default True. Set to False for validation/test sets to maintain reproducible evaluation.

  • structure_in (bool) – Whether to include RNA secondary structure predictions as input features. Requires ViennaRNA installation. Default False. Adds dot-bracket notation as additional input for structure-aware models.

  • drop_long_seq (bool) – Whether to drop sequences longer than max_length instead of truncating them. Default False. When True, sequences exceeding max_length are filtered out during loading.

  • metadata (dict) – Framework metadata including version information and environment details. Automatically populated with Python version, OmniGenBench version, timestamp, etc.

  • rna2structure (RNA2StructureCache) – Persistent cache for RNA structure predictions to avoid redundant ViennaRNA calls. Only created when structure_in=True.

  • data (list) – Internal storage for loaded dataset samples. Each element is a dictionary containing ‘sequence’, ‘label’, and any additional custom fields.
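
Automatic population of label2id and id2label can be sketched as numbering the unique labels. Sorting them first is an assumption made here for reproducibility; the real ordering is not documented. `build_label_maps` is hypothetical.

```python
def build_label_maps(labels, label2id=None):
    """Create label2id (if not given) and its inverse id2label."""
    if label2id is None:
        label2id = {label: idx for idx, label in enumerate(sorted(set(labels)))}
    id2label = {idx: label for label, idx in label2id.items()}
    return label2id, id2label

label2id, id2label = build_label_maps(["positive", "negative", "positive"])
print(label2id)  # {'negative': 0, 'positive': 1}
print(id2label)  # {0: 'negative', 1: 'positive'}
```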

Note

This is an abstract base class. Use task-specific subclasses for actual datasets:

  • OmniDatasetForSequenceClassification: Sequence-level classification

  • OmniDatasetForMultiLabelClassification: Multi-label classification

  • OmniDatasetForTokenClassification: Per-nucleotide classification

  • OmniDatasetForSequenceRegression: Sequence-level regression

  • OmniDatasetForTokenRegression: Per-nucleotide regression

classmethod from_hub(dataset_name_or_path, tokenizer, splits=None, max_length=None, cache_dir=None, **kwargs)

Create OmniDataset instances from HuggingFace Hub or local directory.

This method supports loading datasets from:

  1. The OmniGenBench Hub on HuggingFace (downloaded if needed)

  2. A local directory containing dataset files

Parameters:
  • dataset_name_or_path (str) – Name of the dataset on HuggingFace Hub, or path to local directory.

  • tokenizer – The tokenizer to use for processing sequences.

  • splits (list, optional) – List of splits to create. Defaults to [‘train’, ‘valid’, ‘test’].

  • max_length (int, optional) – Maximum sequence length.

  • cache_dir (str, optional) – Directory to cache the dataset or look for local files.

  • **kwargs – Additional arguments passed to the dataset constructor.

Returns:

dict – Dictionary containing datasets for each split.

Example

>>> from omnigenbench import OmniTokenizer, OmniDatasetForSequenceClassification
>>> tokenizer = OmniTokenizer.from_pretrained("yangheng/OmniGenome-52M")
>>> # Load from HuggingFace Hub
>>> datasets = OmniDatasetForSequenceClassification.from_hub(
...     "translation_efficiency_prediction",
...     tokenizer=tokenizer
... )
>>> # Load from local directory
>>> datasets = OmniDatasetForSequenceClassification.from_hub(
...     "/path/to/local/dataset",
...     tokenizer=tokenizer,
...     cache_dir="/path/to/local/dataset"
... )
>>> train_loader = datasets['train'].get_dataloader(batch_size=16)

classmethod from_huggingface(dataset_name_or_path, tokenizer, splits=None, max_length=None, cache_dir=None, **kwargs)

Create OmniDataset instances from a HuggingFace dataset.

Deprecated since version 0.3.0: from_huggingface is deprecated and will be removed in version 0.4.0. Use from_hub instead, which supports both HuggingFace Hub and local data sources.

Parameters:
  • dataset_name_or_path (str) – Name of the HuggingFace dataset or base URL.

  • tokenizer – The tokenizer to use for processing sequences.

  • splits (list, optional) – List of splits to create. Defaults to [‘train’, ‘valid’, ‘test’].

  • max_length (int, optional) – Maximum sequence length.

  • cache_dir (str, optional) – Directory to cache the dataset.

  • **kwargs – Additional arguments passed to the dataset constructor.

Returns:

dict – Dictionary containing datasets for each split.

Example

>>> from omnigenbench import OmniTokenizer, OmniDatasetForSequenceClassification
>>> tokenizer = OmniTokenizer.from_pretrained("yangheng/OmniGenome-52M")
>>> datasets = OmniDatasetForSequenceClassification.from_huggingface(
...     "translation_efficiency_prediction",
...     tokenizer=tokenizer
... )
>>> train_loader = datasets['train'].get_dataloader(batch_size=16)

get_column(column_name)

Returns all values for a specific column in the dataset.

Parameters:

column_name (str) – The name of the column.

Returns:

list – A list of values from the specified column.

get_dataloader(batch_size=16, shuffle=None, num_workers=0, pin_memory=None, **kwargs)

Creates a PyTorch DataLoader for this dataset.

Parameters:
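  • batch_size (int, optional) – Batch size for the DataLoader. Defaults to 16.

  • shuffle (bool, optional) – Whether to shuffle the data. If None, uses self.shuffle. Defaults to None.

  • num_workers (int, optional) – Number of worker processes for data loading. Defaults to 0 (data is loaded in the main process).

  • pin_memory (bool, optional) – Whether to pin host memory. If None, it is set automatically based on CUDA availability. Defaults to None.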

  • **kwargs – Additional arguments passed to DataLoader.

Returns:

torch.utils.data.DataLoader – A DataLoader for this dataset.

get_inputs_length()

Calculates and returns statistics about sequence and label lengths.

Returns:

dict – A dictionary with length statistics (min, max, avg).
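
The sequence-length part of these statistics can be sketched for plain string sequences (label lengths are omitted here). The key names are assumptions; the documentation only states that min, max, and average lengths are reported. `length_stats` is hypothetical.

```python
def length_stats(sequences):
    """Return min/max/average length of a non-empty collection of sequences."""
    lengths = [len(s) for s in sequences]
    if not lengths:
        raise ValueError("no sequences to summarize")
    return {"min": min(lengths), "max": max(lengths), "avg": sum(lengths) / len(lengths)}

print(length_stats(["ATCG", "ATCGATCG", "ATCGAT"]))  # {'min': 4, 'max': 8, 'avg': 6.0}
```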

get_labels()

Returns the set of unique labels in the dataset.

Returns:

set – The set of unique labels.

info(sections=None, detailed=False, return_dict=False)

Prints dataset information as formatted tables using tabulate.

This method displays dataset_info in a human-readable table format. It’s separate from model/tokenizer metadata and focuses on dataset characteristics.

Parameters:
  • sections (list, optional) – List of section names to print. If None, prints all. Available sections: ‘basic’, ‘statistics’, ‘features’, ‘splits’, ‘preprocessing’, ‘metrics’, ‘citation’, ‘all’

  • detailed (bool, optional) – If True, prints detailed JSON data. Defaults to False.

  • return_dict (bool, optional) – If True, returns the dataset_info dict instead of printing. Defaults to False.

Returns:

dict or None – Returns dataset_info dict if return_dict=True, otherwise None.

Example

>>> dataset.info()  # Print all sections in table format
>>> dataset.info(sections=['basic', 'statistics'])  # Print specific sections
>>> dataset.info(detailed=True)  # Print detailed JSON data
>>> info_dict = dataset.info(return_dict=True)  # Get info as dictionary

load_dataset_name_or_path(dataset_name_or_path, **kwargs)

Loads data from a file or list of files.

Parameters:
  • dataset_name_or_path (str or list) – Path to the data file or a list of paths.

  • **kwargs – Additional keyword arguments, e.g., max_examples.

Returns:

list – A list of examples.

prepare_input(instance, **kwargs)

Prepares a single data instance for the model. Must be implemented by subclasses.

Parameters:
  • instance (dict) – A single raw data record, typically a dictionary with a ‘sequence’ field and, for supervised tasks, a ‘label’ field.

  • **kwargs – Additional keyword arguments for tokenization.

Returns:

dict – A dictionary of tokenized inputs.

print_label_distribution()

Prints the distribution of labels when each label is a scalar (0-dimensional) value. This is useful for classification tasks where each sample has a single label.

sample(n=1) → OmniDataset

Returns a random sample of n items from the dataset.

Parameters:

n (int) – The number of samples to return.

Returns:

OmniDataset – An OmniDataset containing the sampled items.

to(device)

Moves all tensor data in the dataset to the specified device.

Parameters:

device (str or torch.device) – The target device.

Returns:

OmniDataset – The dataset itself.

class omnigenbench.src.abc.abstract_dataset.OmniDict(*args, **kwargs)

Bases: dict

This class extends the standard Python dictionary to provide a convenient method for moving all tensor values to a specific device (CPU/GPU).

to(device)

Moves all tensor values in the dictionary to the specified device.

Parameters:

device (str or torch.device) – The target device (e.g., ‘cuda:0’ or ‘cpu’).

Returns:

OmniDict – The dictionary itself, with tensors moved to the new device.

Example

>>> import torch
>>> from omnigenbench.src.abc.abstract_dataset import OmniDict
>>> data = OmniDict({'input_ids': torch.tensor([1, 2, 3])})
>>> data.to('cuda:0')  # Moves tensors to GPU

omnigenbench.src.abc.abstract_dataset.covert_input_to_tensor(data)

This function traverses through nested dictionaries and lists, converting numerical values to PyTorch tensors while preserving the structure.

Parameters:

data (list or dict) – A list or dictionary containing data samples.

Returns:

list or dict – The data structure with numerical values converted to tensors.

Example

>>> from omnigenbench.src.abc.abstract_dataset import covert_input_to_tensor
>>> data = [{'input_ids': [1, 2, 3], 'labels': [0]}]
>>> tensor_data = covert_input_to_tensor(data)
>>> print(type(tensor_data[0]['input_ids']))  # <class 'torch.Tensor'>

OmniTokenizer

class omnigenbench.src.abc.abstract_tokenizer.OmniTokenizer(base_tokenizer=None, max_length=512, **kwargs)

Bases: object

Abstract base class providing a unified interface for tokenizers in the OmniGenBench framework. This class wraps underlying tokenizers (typically from HuggingFace) and provides genomic-specific preprocessing functionality for biological sequence analysis.

Design Pattern: This class implements the Wrapper pattern (also known as Adapter), providing a consistent API while delegating core tokenization to specialized implementations. It adds genomic-specific preprocessing (RNA-to-DNA conversion, whitespace insertion, sequence normalization) while maintaining compatibility with HuggingFace’s tokenizer ecosystem.

Architecture: The tokenizer stack consists of three layers:

  1. Base Tokenizer: HuggingFace AutoTokenizer or custom implementation providing vocabulary, encoding/decoding primitives, and special token handling.

  2. OmniTokenizer Wrapper (this class): Adds genomic preprocessing, metadata tracking, and unified API across different tokenizer types.

  3. Custom Wrappers: Optional task-specific wrappers (loaded from omnigenome_wrapper.py if present in model directory) for specialized preprocessing logic.

Genomic-Specific Features:

  • Sequence Normalization: Automatic uppercase conversion and nucleotide standardization (converts lowercase to uppercase, handles ambiguous nucleotide codes).

  • RNA/DNA Conversion: Bidirectional U↔T conversion via u2t and t2u flags. Essential when applying models trained on one sequence type to another (e.g., using DNA models for RNA data).

  • Whitespace Injection: Optional character separation for character-level models. Converts “ATCG” → “A T C G” for models trained with spaced sequences.

  • Special Token Handling: Automatic insertion of [CLS], [SEP], [PAD], [MASK] tokens according to model requirements. Handles both BERT-style (e.g., [CLS] seq [SEP]) and GPT-style (seq [EOS]) conventions.

  • K-mer Tokenization: Support for overlapping k-mer segmentation (k=3,4,5,6) for capturing local sequence patterns. Example: “ATCGATCG” with k=3 → “ATC TCG CGA GAT ATC TCG”.

  • Codon-Aware Tokenization: Specialized handling for protein-coding sequences with triplet nucleotide units, preserving reading frame information.

  • Structure-Informed Tokenization: Optional integration with RNA secondary structure for structure-aware models (dot-bracket notation encoding).
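
Overlapping k-mer segmentation and codon (non-overlapping triplet) segmentation differ only in their stride. The following minimal sketch is independent of the real tokenizer classes; `kmers` and `codons` are hypothetical helpers. The first call reproduces the k=3 example above, and the last line checks the 4^k vocabulary bound for k=6.

```python
def kmers(seq, k=3):
    """Overlapping k-mers with stride 1; yields len(seq) - k + 1 tokens."""
    return [seq[i:i + k] for i in range(len(seq) - k + 1)]

def codons(seq, frame=0):
    """Non-overlapping triplets starting at the given reading frame (0, 1 or 2)."""
    seq = seq[frame:]
    return [seq[i:i + 3] for i in range(0, len(seq) - len(seq) % 3, 3)]

print(" ".join(kmers("ATCGATCG", 3)))  # ATC TCG CGA GAT ATC TCG
print(codons("ATGGCCTAA"))              # ['ATG', 'GCC', 'TAA']
print(4 ** 6)                           # 4096 possible 6-mers over {A, C, G, T}
```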

Integration with Model Loading: When loading models via ModelHub or from_pretrained(), the framework first checks for custom tokenizer wrappers (omnigenome_wrapper.py) in the model directory, falling back to standard AutoTokenizer if not found. This enables model-specific preprocessing without modifying core code.

Common Tokenizer Types:

  • OmniSingleNucleotideTokenizer: Character-level tokenization (vocab size ~10)

  • OmniKmersTokenizer: K-mer based tokenization (vocab size 4^k, typically 64-4096)

  • OmniBPETokenizer: Byte-Pair Encoding for learned subword units (vocab size 1000-50000)

Variables:
  • base_tokenizer – Underlying tokenizer instance (e.g., from HuggingFace Transformers). Provides vocabulary, encoding primitives, and special token definitions. Can be any object implementing encode(), decode(), and __call__() methods.

  • max_length (int) – Default maximum sequence length for tokenization. Can be overridden in individual tokenization calls. Sequences longer than this are truncated. Typical values: 512 (short sequences), 2048 (medium), 10000+ (long genomic regions).

  • metadata (dict) – Framework metadata including version information and custom attributes. Automatically populated with tokenizer type, version, timestamp, etc.

  • u2t (bool) – Whether to convert ‘U’ (uracil) to ‘T’ (thymine) for RNA→DNA conversion. Useful when training DNA models on RNA data or applying DNA-trained models to RNA sequences. Default False.

  • t2u (bool) – Whether to convert ‘T’ (thymine) to ‘U’ (uracil) for DNA→RNA conversion. Useful when applying RNA-trained models (e.g., RNA structure predictors) to DNA sequences. Default False.

  • add_whitespace (bool) – Whether to insert spaces between characters for character-level tokenization. Required for some BERT-style models trained on spaced sequences. Example: “ATCG” becomes “A T C G”. Default False.

  • trust_remote_code (bool) – Whether to trust remote code when loading tokenizers from HuggingFace Hub. Default True. Set to False in security-critical environments.

Note

  • Set u2t=True when using DNA models on RNA sequences

  • Set t2u=True when using RNA models on DNA sequences

  • Never set both u2t=True and t2u=True simultaneously (results undefined)

  • add_whitespace should match the training configuration of the model
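
These preprocessing flags can be sketched as a single function applied before the base tokenizer sees the text. This is an illustration, not OmniTokenizer's implementation: where the documentation calls u2t=True together with t2u=True undefined, the sketch simply rejects the combination, and ambiguous-nucleotide handling is omitted. `preprocess` is hypothetical.

```python
def preprocess(seq, u2t=False, t2u=False, add_whitespace=False):
    """Uppercase, optionally convert U<->T, and optionally space characters."""
    if u2t and t2u:
        raise ValueError("u2t and t2u are mutually exclusive")
    seq = seq.strip().upper()
    if u2t:
        seq = seq.replace("U", "T")  # RNA -> DNA alphabet
    elif t2u:
        seq = seq.replace("T", "U")  # DNA -> RNA alphabet
    return " ".join(seq) if add_whitespace else seq

print(preprocess("augc", u2t=True))                       # ATGC
print(preprocess("ATCG", t2u=True, add_whitespace=True))  # A U C G
```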

decode(sequence, **kwargs)

Converts a list of token IDs back into a sequence. Must be implemented by subclasses.

Parameters:
  • sequence (list) – A list of token IDs.

  • **kwargs – Additional arguments.

Returns:

str – The decoded sequence.

Raises:

NotImplementedError – If the method is not implemented by the subclass.

Example

>>> # In a nucleotide tokenizer
>>> sequence = tokenizer.decode([1, 2, 3, 4])
>>> print(sequence)  # "ATCG"

encode(sequence, **kwargs)

Converts a sequence into a list of token IDs. Must be implemented by subclasses.

Parameters:
  • sequence (str) – The input sequence.

  • **kwargs – Additional arguments.

Returns:

list – A list of token IDs.

Raises:

NotImplementedError – If the method is not implemented by the subclass.

Example

>>> # In a nucleotide tokenizer
>>> token_ids = tokenizer.encode("ATCGATCG")
>>> print(token_ids)  # [1, 2, 3, 4, 1, 2, 3, 4]

static from_pretrained(config_or_model, **kwargs)

Loads a tokenizer from a pre-trained model path.

Parameters:
  • config_or_model (str) – The name or path of the pre-trained model.

  • **kwargs – Additional arguments for the tokenizer.

Returns:

OmniTokenizer – An instance of a tokenizer.

Example

>>> # Load from a pre-trained model
>>> tokenizer = OmniTokenizer.from_pretrained("model_name")
>>> # Load with custom parameters
>>> tokenizer = OmniTokenizer.from_pretrained("model_name", trust_remote_code=True)

save_pretrained(save_directory)

Saves the base tokenizer to a directory.

Parameters:

save_directory (str) – The directory to save the tokenizer to.

Example

>>> tokenizer.save_pretrained("./saved_tokenizer")

tokenize(sequence, **kwargs)[source]

Converts a sequence into a list of tokens. Must be implemented by subclasses.

Parameters:
  • sequence (str) – The input sequence.

  • **kwargs – Additional arguments.

Returns:

list – A list of tokens.

Raises:

NotImplementedError – If the method is not implemented by the subclass.

Example

>>> # In a nucleotide tokenizer
>>> tokens = tokenizer.tokenize("ATCGATCG")
>>> print(tokens)  # ['A', 'T', 'C', 'G', 'A', 'T', 'C', 'G']
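The three abstract methods are related: encode() maps the output of tokenize() to IDs, and decode() inverts encode(). The class below is a minimal, hypothetical character-level tokenizer that shows this relationship. `ToyNucleotideTokenizer` stands alone and does not subclass OmniTokenizer. Its vocabulary was chosen to reproduce the IDs in the examples above. A real subclass would also need to handle special tokens, padding, and unknown characters.

```python
class ToyNucleotideTokenizer:
    """Hypothetical character-level tokenizer mirroring the OmniTokenizer method contract."""

    def __init__(self):
        # Vocabulary chosen to match the example IDs: A=1, T=2, C=3, G=4
        self.vocab = {"A": 1, "T": 2, "C": 3, "G": 4}
        self.inverse_vocab = {i: tok for tok, i in self.vocab.items()}

    def tokenize(self, sequence, **kwargs):
        # One token per nucleotide
        return list(sequence)

    def encode(self, sequence, **kwargs):
        # Map each token to its ID; unknown characters raise KeyError in this sketch
        return [self.vocab[tok] for tok in self.tokenize(sequence)]

    def decode(self, sequence, **kwargs):
        # `sequence` is a list of token IDs, as in the decode() signature above
        return "".join(self.inverse_vocab[i] for i in sequence)


tokenizer = ToyNucleotideTokenizer()
print(tokenizer.tokenize("ATCGATCG"))  # ['A', 'T', 'C', 'G', 'A', 'T', 'C', 'G']
print(tokenizer.encode("ATCGATCG"))    # [1, 2, 3, 4, 1, 2, 3, 4]
print(tokenizer.decode([1, 2, 3, 4]))  # ATCG
```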

OmniMetrics

class omnigenbench.src.abc.abstract_metric.OmniMetric(metric_func=None, ignore_y=None, *args, **kwargs)[source]

Bases: object

Abstract base class providing a unified interface for evaluation metrics in the OmniGenBench framework. This class integrates seamlessly with scikit-learn’s metric ecosystem while adding genomics-specific functionality for handling masked labels, multi-task evaluation, and specialized biological metrics.

Design Philosophy: This class follows the Strategy pattern, allowing interchangeable metric implementations while maintaining a consistent compute() interface. All scikit-learn metrics are automatically exposed as attributes for convenient access without explicit imports.

Key Features:

  • Scikit-learn Integration: Automatic exposure of all sklearn.metrics functions as attributes (accuracy_score, f1_score, matthews_corrcoef, etc.), eliminating the need for separate metric imports.

  • Masked Label Handling: Support for PyTorch’s -100 ignore convention via the ignore_y parameter. Labels matching ignore_y are filtered out before metric computation, essential for tasks with variable-length outputs or padded sequences.

  • Flexible Computation: The compute() method accepts various input formats (lists, numpy arrays, torch tensors) and returns standardized dictionary outputs for consistent logging and tracking.

  • Multi-Metric Reporting: Subclasses (ClassificationMetric, RegressionMetric, RankingMetric) compute multiple relevant metrics in a single call, providing comprehensive evaluation without manual orchestration.

  • Custom Metric Support: Easy extensibility through subclassing and implementing custom compute() methods for domain-specific metrics (e.g., Matthews Correlation Coefficient for imbalanced genomic datasets).

Common Genomic Use Cases:

  • Imbalanced Classification: MCC and AUPRC for rare variant detection, where accuracy alone is misleading

  • Multi-Label Prediction: Hamming loss and F1-macro for transcription factor binding site prediction across hundreds of TFs

  • Regression Tasks: Spearman correlation for gene expression prediction, where rank order matters more than absolute values

  • Token-Level Prediction: Per-nucleotide metrics for secondary structure prediction and splice site detection
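The following example shows why accuracy alone is misleading on imbalanced data. It scores a classifier that always predicts the majority class, computing accuracy and the Matthews correlation coefficient from confusion-matrix counts in plain Python. It makes no OmniGenBench or scikit-learn calls. Like `sklearn.metrics.matthews_corrcoef`, it returns 0.0 when the denominator is zero.

```python
import math


def accuracy(y_true, y_pred):
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


def mcc(y_true, y_pred):
    # Confusion-matrix counts for binary labels {0, 1}
    tp = sum(t == 1 and p == 1 for t, p in zip(y_true, y_pred))
    tn = sum(t == 0 and p == 0 for t, p in zip(y_true, y_pred))
    fp = sum(t == 0 and p == 1 for t, p in zip(y_true, y_pred))
    fn = sum(t == 1 and p == 0 for t, p in zip(y_true, y_pred))
    denom = math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
    # Convention: MCC is 0 when any marginal count is zero
    return (tp * tn - fp * fn) / denom if denom else 0.0


# 10 positives (e.g. rare variants) among 100 examples; the model predicts all negatives
y_true = [1] * 10 + [0] * 90
y_pred = [0] * 100
print(accuracy(y_true, y_pred))  # 0.9
print(mcc(y_true, y_pred))       # 0.0
```

The 90% accuracy comes entirely from the majority class, while an MCC of 0 shows the model has no predictive power for the positives.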

Subclass Implementations:

  • ClassificationMetric: Comprehensive classification metrics (accuracy, precision, recall, F1, MCC, AUROC, AUPRC) with automatic threshold selection

  • RegressionMetric: Regression-specific metrics (MSE, MAE, R², Spearman/Pearson correlation) for continuous predictions

  • RankingMetric: Ranking and retrieval metrics (NDCG, MAP, Precision@K) for information retrieval tasks

Variables:
  • metric_func (callable, optional) – A callable metric function from sklearn.metrics. If provided, used as the primary metric computation function. If None, subclasses should implement their own compute() method.

  • ignore_y (any, optional) – A value in the ground truth labels to be ignored during metric computation. Commonly set to -100 (PyTorch’s default ignore index) or None. Labels matching this value are filtered out before metric calculation, useful for masked language modeling, padding, or variable-length sequences.

  • metadata (dict) – Framework metadata including version information, timestamp, and environment details. Automatically populated on initialization.

Note

This is an abstract base class. Use task-specific subclasses for actual evaluation:

  • Use ClassificationMetric for binary/multi-class/multi-label classification

  • Use RegressionMetric for continuous value prediction

  • Use RankingMetric for ranking and retrieval tasks

  • Subclass OmniMetric for custom metrics with specialized compute() implementations

Example

>>> from omnigenbench.src.abc.abstract_metric import OmniMetric
>>> y_true = [0, 1, 1, 0]
>>> y_pred = [0, 1, 0, 0]
>>> # Access scikit-learn metrics directly
>>> metric = OmniMetric()
>>> acc = metric.accuracy_score(y_true, y_pred)
>>>
>>> # Use with ignore_y for masked tokens
>>> metric = OmniMetric(ignore_y=-100)
>>> # Labels of -100 will be filtered before computation

compute(y_true, y_pred) → dict[source]

Computes the metric. This method must be implemented by subclasses.

Parameters:
  • y_true – Ground truth labels.

  • y_pred – Predicted labels.

Returns:

dict – A dictionary mapping the metric name to its computed value.

Raises:

NotImplementedError – If the method is not implemented by the subclass.

Example

>>> # In a classification metric
>>> result = metric.compute(y_true, y_pred)
>>> print(result)  # {'accuracy': 0.85}

static flatten(y_true, y_pred)[source]

Flattens the ground truth and prediction arrays. It handles various input formats and converts them to 1D numpy arrays.

Parameters:
  • y_true – Ground truth labels in any format that can be converted to numpy array.

  • y_pred – Predicted labels in any format that can be converted to numpy array.

Returns:

tuple – A tuple of flattened y_true and y_pred as numpy arrays.

Example

>>> y_true = [[1, 2], [3, 4]]
>>> y_pred = [[1, 2], [3, 4]]
>>> flat_true, flat_pred = OmniMetric.flatten(y_true, y_pred)
>>> print(flat_true.shape)  # (4,)
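For token-level tasks, flattening and ignore_y masking usually go together. Per-nucleotide labels are flattened to 1D, and padded positions (label -100) are dropped before scoring. The NumPy sketch below illustrates that pipeline and assumes array-like inputs. `flatten_and_mask` is a hypothetical helper that mirrors the documented behavior of flatten() rather than reproducing its implementation. Torch tensors would first need `.detach().cpu().numpy()`.

```python
import numpy as np


def flatten_and_mask(y_true, y_pred, ignore_y=-100):
    # Convert nested lists/arrays to 1D arrays (inputs assumed array-like)
    flat_true = np.asarray(y_true).reshape(-1)
    flat_pred = np.asarray(y_pred).reshape(-1)
    # Keep only positions whose label is not the ignore value
    keep = flat_true != ignore_y
    return flat_true[keep], flat_pred[keep]


# Two sequences padded to length 4; padding positions carry label -100
y_true = [[1, 0, 1, -100], [0, 0, -100, -100]]
y_pred = [[1, 1, 1, 0], [0, 0, 1, 1]]
t, p = flatten_and_mask(y_true, y_pred)
print(t)                # [1 0 1 0 0]
print((t == p).mean())  # 0.8
```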

OmniLoRA

This module provides a Low-Rank Adaptation (LoRA) implementation for efficient fine-tuning of large genomic language models. LoRA reduces the number of trainable parameters by freezing the existing model weights and training small low-rank update matrices injected alongside them.
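The savings are easy to quantify. For a frozen weight matrix W₀ of shape d × k, LoRA trains two factors, B (d × r) and A (r × k). Each adapted layer therefore contributes r·(d + k) trainable parameters instead of d·k. The arithmetic below uses an illustrative 768 × 768 projection and rank r = 8. These are example values, not OmniGenBench defaults.

```python
def lora_trainable_params(d: int, k: int, r: int) -> int:
    # B is d x r and A is r x k; the original d x k weight stays frozen
    return r * (d + k)


d, k, r = 768, 768, 8
full = d * k
lora = lora_trainable_params(d, k, r)
print(full)                  # 589824
print(lora)                  # 12288
print(f"{lora / full:.2%}")  # 2.08%
```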

class omnigenbench.src.lora.lora_model.OmniLoraModel(model, **kwargs)[source]

Bases: Module

Wrapper around a LoRA-adapted model.

forward(*args, **kwargs)[source]

Perform a forward pass through the LoRA-adapted model.

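This method delegates the forward computation to the underlying LoRA model, which automatically combines the frozen base weights with the LoRA updates. In each adapted linear layer the computation is h = W₀x + (α/r)·BAx. Here W₀ is the frozen weight, A and B are the trainable low-rank factors, r is the rank, and α is the scaling hyperparameter. Layers without adapters behave exactly as in the base model.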

Parameters:
  • *args – Positional arguments passed to the underlying model’s forward method. Typically includes input tensors (input_ids, attention_mask, etc.).

  • **kwargs – Keyword arguments passed to the underlying model’s forward method. Model-specific parameters like labels, output_hidden_states, etc.

Returns:

Model outputs in the same format as the base model, but incorporating LoRA adaptations. The exact return type depends on the base model architecture (e.g., BaseModelOutput, SequenceClassifierOutput).

Examples

Basic forward pass:

>>> outputs = model(input_ids, attention_mask=attention_mask)

With additional parameters:

>>> outputs = model(
...     input_ids=input_ids,
...     attention_mask=attention_mask,
...     labels=labels,
...     output_hidden_states=True
... )

Note

  • LoRA adaptations are automatically applied during the forward pass

  • No manual intervention needed to combine base and adaptation outputs

  • Maintains full compatibility with the original model’s forward signature
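The per-layer rule h = W₀x + (α/r)·BAx can be illustrated in isolation. The NumPy sketch below is not OmniGenBench's or PEFT's code. The shapes, α, and initialization are illustrative assumptions. B starts at zero, as in the original LoRA formulation, so the adapted layer initially reproduces the base layer exactly. After training, the low-rank update can be merged into a single weight matrix.

```python
import numpy as np


def lora_linear(x, W0, A, B, scaling):
    # Frozen base projection plus the scaled low-rank update B @ A
    return W0 @ x + scaling * (B @ (A @ x))


rng = np.random.default_rng(0)
d_out, d_in, r, alpha = 6, 4, 2, 16
scaling = alpha / r

W0 = rng.normal(size=(d_out, d_in))    # frozen base weight
A = rng.normal(size=(r, d_in)) * 0.01  # trainable down-projection (r x d_in)
B = np.zeros((d_out, r))               # trainable up-projection, zero-initialized
x = rng.normal(size=d_in)

# With B = 0 the adapted layer reproduces the frozen base layer exactly
print(np.allclose(lora_linear(x, W0, A, B, scaling), W0 @ x))  # True

# Once B has been trained (random values stand in here), the update can be merged
B = rng.normal(size=(d_out, r))
W_merged = W0 + scaling * (B @ A)
print(np.allclose(lora_linear(x, W0, A, B, scaling), W_merged @ x))  # True
```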

last_hidden_state_forward(**kwargs)[source]

Perform forward pass and return the last hidden state from the base model.

This method provides access to intermediate representations from the base model while incorporating LoRA adaptations, useful for feature extraction and analysis tasks.

Parameters:

**kwargs – Keyword arguments passed to the base model’s last_hidden_state_forward method.

Returns:

Last hidden state tensor with LoRA adaptations applied. Shape typically [batch_size, sequence_length, hidden_size].

Examples

Get hidden states:

>>> hidden_states = model.last_hidden_state_forward(
...     input_ids=input_ids,
...     attention_mask=attention_mask
... )

Note

  • Hidden states include the effects of LoRA adaptations

  • Useful for feature extraction and representation analysis

  • Maintains compatibility with base model’s hidden state interface
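A common way to reduce the [batch_size, sequence_length, hidden_size] output to one vector per sequence is mean pooling over non-padding positions. The NumPy sketch below assumes an attention mask with 1 for real tokens and 0 for padding. `masked_mean_pool` is a hypothetical helper, and the random array stands in for the tensor that last_hidden_state_forward would return.

```python
import numpy as np


def masked_mean_pool(hidden_states, attention_mask):
    # hidden_states: [batch, seq_len, hidden]; attention_mask: [batch, seq_len]
    mask = attention_mask[..., None].astype(hidden_states.dtype)  # [batch, seq_len, 1]
    summed = (hidden_states * mask).sum(axis=1)                   # [batch, hidden]
    counts = np.clip(mask.sum(axis=1), 1e-9, None)                # avoid division by zero
    return summed / counts


rng = np.random.default_rng(0)
hidden = rng.normal(size=(2, 5, 8))
mask = np.array([[1, 1, 1, 1, 1],
                 [1, 1, 1, 0, 0]])
pooled = masked_mean_pool(hidden, mask)
print(pooled.shape)  # (2, 8)
```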

model_info()[source]

Get detailed information about the underlying base model.

Returns comprehensive information about the model architecture, configuration, and other metadata through the base model’s info interface.

Returns:

Model information from the base model, typically including architecture details, parameter counts, configuration settings, etc.

Examples

Get model information:

>>> info = model.model_info()
>>> print(info)  # Display model architecture and stats

Note

  • Information reflects the base model architecture

  • LoRA-specific details may not be included (use print(model) for LoRA info)

  • Useful for understanding the underlying model structure

predict(*args, **kwargs)[source]

Generate predictions using the LoRA-adapted model through the base model interface.

This method provides access to the base model’s prediction functionality while incorporating LoRA adaptations. It’s particularly useful for inference tasks where the base model has specialized prediction methods.

Parameters:
  • *args – Positional arguments passed to the base model’s predict method.

  • **kwargs – Keyword arguments passed to the base model’s predict method.

Returns:

Predictions from the base model, computed with the LoRA adaptations applied. The format depends on the specific base model’s predict implementation.

Examples

Generate predictions:

>>> predictions = model.predict(input_sequences)

With custom parameters:

>>> predictions = model.predict(
...     sequences,
...     max_length=512,
...     temperature=0.7
... )

Note

  • Delegates to the base model’s predict method while maintaining LoRA adaptations

  • Useful for models with specialized inference interfaces

  • LoRA adaptations are automatically included in the prediction process

save(*args, **kwargs)[source]

Save the LoRA-adapted model using the base model’s save functionality.

This method delegates saving operations to the base model while preserving LoRA adapter information. The exact saving behavior depends on the base model’s implementation.

Parameters:
  • *args – Positional arguments passed to the base model’s save method.

  • **kwargs – Keyword arguments passed to the base model’s save method.

Returns:

Result of the base model’s save operation.

Examples

Save model:

>>> model.save('path/to/save/directory')

Save with custom parameters:

>>> model.save(
...     save_directory='./checkpoints',
...     save_config=True,
...     save_tokenizer=True
... )

Note

  • For saving LoRA adapters specifically, use PEFT’s save_pretrained() method

  • This method saves the complete model state including LoRA adaptations

  • Check base model documentation for specific save parameters

set_loss_fn(fn)[source]

Set a custom loss function for the base model.

This method allows configuration of specialized loss functions through the base model’s interface, useful for custom training objectives.

Parameters:

fn (callable) – Loss function to be used by the base model. Should follow PyTorch loss function conventions.

Returns:

Result of setting the loss function on the base model.

Examples

Set custom loss function:

>>> import torch
>>> custom_loss = torch.nn.CrossEntropyLoss(label_smoothing=0.1)
>>> model.set_loss_fn(custom_loss)

Note

  • Loss function applies to the combined base + LoRA model outputs

  • Delegates to the base model’s loss function setting mechanism

  • LoRA adaptations are included in loss computation automatically
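model_info(), predict(), and set_loss_fn() share one pattern: the wrapper forwards the call to its base model. The plain-Python sketch below shows that delegation in isolation. `ToyBaseModel`, `ToyLoraWrapper`, and their attributes are hypothetical stand-ins, not OmniGenBench classes, and the sketch does not show how the real wrapper locates its base model.

```python
class ToyBaseModel:
    """Hypothetical base model exposing the methods the wrapper delegates to."""

    def __init__(self):
        self.loss_fn = None

    def set_loss_fn(self, fn):
        self.loss_fn = fn
        return fn

    def predict(self, inputs):
        # Stand-in inference: report sequence lengths
        return [len(seq) for seq in inputs]

    def model_info(self):
        return {"name": "toy-base"}


class ToyLoraWrapper:
    """Hypothetical wrapper that forwards calls to the wrapped base model."""

    def __init__(self, base_model):
        self.base_model = base_model

    def set_loss_fn(self, fn):
        return self.base_model.set_loss_fn(fn)

    def predict(self, *args, **kwargs):
        return self.base_model.predict(*args, **kwargs)

    def model_info(self):
        return self.base_model.model_info()


def squared_error(pred, target):
    return (pred - target) ** 2


wrapper = ToyLoraWrapper(ToyBaseModel())
wrapper.set_loss_fn(squared_error)
print(wrapper.base_model.loss_fn is squared_error)  # True
print(wrapper.predict(["ACGT", "AUG"]))             # [4, 3]
print(wrapper.model_info())                         # {'name': 'toy-base'}
```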

to(*args, **kwargs)[source]

Move the LoRA model to a specified device and/or data type while keeping the wrapper’s device and dtype tracking consistent.

This method extends PyTorch’s standard .to() so that, after a device or precision change, both the underlying LoRA model and the wrapper’s tracked device/dtype attributes reflect the new placement.

Parameters:
  • *args – Positional arguments passed to the underlying model’s .to() method. Common usage:

      • device (str or torch.device): Target device ('cuda', 'cpu', 'cuda:0', etc.)

      • dtype (torch.dtype): Target data type (torch.float16, torch.float32, etc.)

  • **kwargs – Keyword arguments passed to the underlying model’s .to() method. Supports all standard PyTorch .to() parameters.

Returns:

OmniLoraModel – Returns self for method chaining, following PyTorch convention.

Side Effects:
  • Updates internal device and dtype tracking attributes

  • Moves all model parameters and buffers to the specified device/dtype

  • Synchronizes device information across all model modules

Examples

Move to GPU:

>>> import torch
>>> model = model.to('cuda')
>>> model = model.to(torch.device('cuda:0'))

Change precision:

>>> model = model.to(torch.float16)  # Convert to half precision
>>> model = model.to(dtype=torch.bfloat16)  # Convert to bfloat16

Combined device and dtype:

>>> model = model.to('cuda', dtype=torch.float16)

Method chaining:

>>> model = OmniLoraModel(base_model).to('cuda').train()

Note

  • Device/dtype information is automatically tracked for internal consistency

  • Exception handling ensures robustness if parameter introspection fails

  • All modules receive updated device information for framework compatibility

  • LoRA adapters maintain the same precision as the base model after transfer

omnigenbench.src.lora.lora_model.auto_lora_model(model, **kwargs)[source]

This function automatically identifies suitable target modules and creates a LoRA-adapted version of the input model. It handles configuration setup and parameter freezing for efficient fine-tuning.

Parameters:
  • model – The base model to adapt with LoRA

  • **kwargs – Additional LoRA configuration parameters

Returns:

The LoRA-adapted model

Raises:

AssertionError – If no target modules are found for LoRA injection
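The parameter-freezing step can be pictured as marking only the adapter weights as trainable. The sketch below works on a plain dict that maps parameter names to a `requires_grad` flag. It assumes that adapter parameters carry `lora_` in their names, following PEFT's naming convention. It also mirrors the AssertionError above by requiring at least one target module. `freeze_all_but_lora` is hypothetical, not the library's function, and it freezes every other parameter, including any task head.

```python
def freeze_all_but_lora(param_flags, target_modules):
    # Mirror the documented failure mode: no targets means nothing to adapt
    assert target_modules, "No target modules found for LoRA injection"
    # Only adapter parameters stay trainable; every other weight is frozen
    return {name: "lora_" in name for name in param_flags}


params = {
    "encoder.layer.0.attention.query.weight": True,
    "encoder.layer.0.attention.query.lora_A.default.weight": True,
    "encoder.layer.0.attention.query.lora_B.default.weight": True,
    "classifier.weight": True,
}
flags = freeze_all_but_lora(params, ["query"])
trainable = [name for name, flag in flags.items() if flag]
print(len(trainable))  # 2
```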

omnigenbench.src.lora.lora_model.find_linear_target_modules(model, keyword_filter=None, use_full_path=True)[source]

This function searches through a model’s modules to identify linear layers that can be adapted using LoRA. It supports filtering by keyword patterns to target specific types of layers.

Parameters:
  • model – The model to search for linear modules

  • keyword_filter (str, list, tuple, optional) – Keywords to filter modules by name

  • use_full_path (bool) – Whether to return full module paths or just names (default: True)

Returns:

list – Sorted list of linear module names that can be targeted for LoRA

Raises:

TypeError – If keyword_filter is not None, str, or a list/tuple of str
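The filtering and naming rules can be sketched without PyTorch by treating the model as a list of (qualified name, layer type) pairs. In this hypothetical version, a module matches when any keyword is a substring of its qualified name. With use_full_path=False, only the last path component is kept and duplicates are removed. `select_linear_targets` is illustrative, not the library's function, and the library's exact matching rule may differ.

```python
def select_linear_targets(modules, keyword_filter=None, use_full_path=True):
    # Normalize keyword_filter, rejecting anything but None, str, or list/tuple of str
    if keyword_filter is None:
        keywords = None
    elif isinstance(keyword_filter, str):
        keywords = [keyword_filter]
    elif isinstance(keyword_filter, (list, tuple)) and all(
        isinstance(k, str) for k in keyword_filter
    ):
        keywords = list(keyword_filter)
    else:
        raise TypeError("keyword_filter must be None, str, or a list/tuple of str")

    selected = set()
    for name, layer_type in modules:
        if layer_type != "Linear":
            continue
        if keywords is not None and not any(k in name for k in keywords):
            continue
        # Full dotted path, or only the final component (e.g. "query")
        selected.add(name if use_full_path else name.split(".")[-1])
    return sorted(selected)


modules = [
    ("encoder.layer.0.attention.query", "Linear"),
    ("encoder.layer.0.attention.key", "Linear"),
    ("encoder.layer.0.attention.value", "Linear"),
    ("encoder.layer.0.output.dense", "Linear"),
    ("encoder.layer.0.output.LayerNorm", "LayerNorm"),
]
print(select_linear_targets(modules, keyword_filter=["query", "value"]))
# ['encoder.layer.0.attention.query', 'encoder.layer.0.attention.value']
print(select_linear_targets(modules, use_full_path=False))
# ['dense', 'key', 'query', 'value']
```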