Downstream Models

Classification Models

class omnigenbench.src.model.classification.model.OmniModelForMultiLabelSequenceClassification(config_or_model, tokenizer, *args, **kwargs)

Bases: OmniModelForSequenceClassification

This model is designed for multi-label classification tasks, where a single sequence can be assigned several labels simultaneously. It extends the sequence classification model by using a sigmoid activation instead of softmax, so each label receives an independent probability, and by training with binary cross-entropy loss.

Variables:
  • softmax (torch.nn.Sigmoid) – Sigmoid layer for per-label probability computation. The attribute keeps the parent class's name softmax but applies a sigmoid.

  • loss_fn (torch.nn.BCELoss) – Binary cross-entropy loss function for training.
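The practical difference between sigmoid and softmax, and what the binary cross-entropy objective computes, can be shown in plain Python. This is a minimal sketch of the math only: the helpers `sigmoid` and `bce_loss` are hypothetical and not part of omnigenbench, which works on tensors through the `torch.nn.Sigmoid` and `torch.nn.BCELoss` modules listed above.

```python
import math

def sigmoid(x):
    # Independent per-label probability: labels do not compete as in softmax.
    return 1.0 / (1.0 + math.exp(-x))

def bce_loss(probs, targets):
    # Mean binary cross-entropy over labels (the default reduction="mean"
    # of torch.nn.BCELoss). Probabilities are clamped here to avoid log(0).
    eps = 1e-12
    total = 0.0
    for p, y in zip(probs, targets):
        p = min(max(p, eps), 1.0 - eps)
        total += -(y * math.log(p) + (1 - y) * math.log(1.0 - p))
    return total / len(probs)

logits = [2.0, -1.0, 0.5, -3.0]  # one score per label for a single sequence
targets = [1, 0, 1, 0]           # multi-hot ground truth
probs = [sigmoid(z) for z in logits]
loss = bce_loss(probs, targets)
print([round(p, 3) for p in probs])  # probabilities need not sum to 1
print(round(loss, 4))
```

Because each label is scored on its own, the probabilities for one sequence do not have to sum to 1, which is what allows several labels to be active at once.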

inference(sequence_or_inputs, **kwargs)

Performs multi-label inference with human-readable output. It converts logits to binary labels and provides confidence scores.

Parameters:
  • sequence_or_inputs – A sequence (str), list of sequences, or tokenized inputs (dict/tuple).

  • **kwargs – Additional arguments for tokenization and inference.

Returns:

dict

A dictionary containing:
  • predictions: Human-readable binary labels for each sequence

  • logits: Raw logits from the model

  • confidence: Confidence scores for predictions

  • last_hidden_state: Final hidden states

Example

>>> # Inference on a single sequence
>>> results = model.inference("ATCGATCG")
>>> print(results['predictions'])  # tensor([1, 0, 1, 0])

loss_function(logits, labels)

Calculates the binary cross-entropy loss for multi-label classification.

Parameters:
  • logits (torch.Tensor) – Predicted logits from the model.

  • labels (torch.Tensor) – Ground truth multi-label targets.

Returns:

torch.Tensor – The computed loss value.

Example

>>> loss = model.loss_function(logits, labels)

predict(sequence_or_inputs, **kwargs)

This method takes raw sequences or tokenized inputs and returns multi-label predictions. It applies a threshold to determine which labels are active for each sequence.
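The threshold value is not stated here; the sketch below assumes 0.5, a common default, and uses a hypothetical helper `multi_hot` rather than the library's own code. It shows how per-label logits become a multi-hot prediction vector such as `tensor([1, 0, 1, 0])`.

```python
import math

def multi_hot(logits, threshold=0.5):
    # Map each logit to an independent probability, then mark the label
    # active when its probability reaches the (assumed) threshold.
    probs = [1.0 / (1.0 + math.exp(-z)) for z in logits]
    return [1 if p >= threshold else 0 for p in probs]

print(multi_hot([2.0, -1.0, 0.5, -3.0]))                 # [1, 0, 1, 0]
print(multi_hot([2.0, -1.0, 0.5, -3.0], threshold=0.7))  # [1, 0, 0, 0]
```

With a 0.5 threshold, sigmoid(z) >= 0.5 exactly when z >= 0, so the same labels can be read directly from the sign of the logits.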

Parameters:
  • sequence_or_inputs – A sequence (str), list of sequences, or tokenized inputs (dict/tuple).

  • **kwargs – Additional arguments for tokenization and inference.

Returns:

dict

A dictionary containing:
  • predictions: Multi-label predictions for each sequence

  • logits: Raw logits from the model

  • last_hidden_state: Final hidden states

Example

>>> # Predict on a single sequence
>>> outputs = model.predict("ATCGATCG")
>>> print(outputs['predictions'])  # tensor([1, 0, 1, 0])

class omnigenbench.src.model.classification.model.OmniModelForMultiLabelSequenceClassificationWith2DStructure(config_or_model, tokenizer, *args, **kwargs)

Bases: OmniModelForSequenceClassificationWith2DStructure

inference(sequence_or_inputs, **kwargs)

This method provides processed, human-readable sequence-level predictions. It converts logits to class labels and provides confidence scores.

Parameters:
  • sequence_or_inputs – A sequence (str), list of sequences, or tokenized inputs (dict/tuple).

  • **kwargs – Additional arguments for tokenization and inference.

Returns:

dict

A dictionary containing:
  • predictions: Human-readable class labels for each sequence

  • logits: Raw logits from the model

  • confidence: Confidence scores for predictions

  • last_hidden_state: Final hidden states

Example

>>> # Inference on a single sequence
>>> results = model.inference("ATCGATCG")
>>> print(results['predictions'])  # "positive"
>>> print(results['confidence'])   # 0.95

loss_function(logits, labels)

This method computes the cross-entropy loss between the predicted logits and the ground truth labels.

Parameters:
  • logits (torch.Tensor) – Predicted logits from the model.

  • labels (torch.Tensor) – Ground truth labels.

Returns:

torch.Tensor – The computed loss value.

Example

>>> loss = model.loss_function(logits, labels)

predict(sequence_or_inputs, **kwargs)

This method takes raw sequences or tokenized inputs and returns sequence-level predictions. It processes the inputs through the model and returns the predicted class for each sequence.

Parameters:
  • sequence_or_inputs – A sequence (str), list of sequences, or tokenized inputs (dict/tuple).

  • **kwargs – Additional arguments for tokenization and inference.

Returns:

dict

A dictionary containing:
  • predictions: Predicted class indices for each sequence

  • logits: Raw logits from the model

  • last_hidden_state: Final hidden states

Example

>>> # Predict on a single sequence
>>> outputs = model.predict("ATCGATCG")
>>> print(outputs['predictions'])  # tensor([0])
>>> # Predict on multiple sequences
>>> outputs = model.predict(["ATCGATCG", "GCTAGCTA"])

class omnigenbench.src.model.classification.model.OmniModelForSequenceClassification(config_or_model, tokenizer, *args, **kwargs)

Bases: OmniModel

Model for sequence-level classification tasks in genomic analysis.

This class implements sequence classification, where the entire input sequence is assigned to one of several discrete categories. Common genomic applications include:

  • Promoter vs. non-promoter classification

  • Functional region annotation (enhancer, silencer, insulator)

  • Sequence origin classification (species, cell type)

  • Regulatory element prediction

The model applies pooling over the sequence dimension to create a fixed-length representation, which is then classified via a linear head with softmax activation.
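To illustrate what masked pooling does, the following plain-Python sketch collapses per-token vectors into one fixed-length vector. The mode names "cls", "mean", and "max" are illustrative only and are not necessarily the values accepted by config.pooling_mode; attention-based pooling, which needs learned weights, is left out, and `pool` is a hypothetical helper.

```python
def pool(hidden, mask, mode="mean"):
    """Collapse per-token vectors into one fixed-length vector.

    hidden: list of per-token vectors (all the same length)
    mask:   1 for real tokens, 0 for padding
    """
    valid = [h for h, m in zip(hidden, mask) if m == 1]
    if mode == "cls":
        return list(hidden[0])  # assumes position 0 holds a [CLS]-style token
    if mode == "mean":
        return [sum(col) / len(valid) for col in zip(*valid)]
    if mode == "max":
        return [max(col) for col in zip(*valid)]
    raise ValueError(f"unknown pooling mode: {mode!r}")

hidden = [[1.0, 0.0], [3.0, 2.0], [5.0, -2.0], [9.0, 9.0]]
mask = [1, 1, 1, 0]  # the last position is padding and must not leak in
for mode in ("cls", "mean", "max"):
    print(mode, pool(hidden, mask, mode))
```

Whatever the input length, the pooled vector has the hidden size, which is what lets a single linear head classify sequences of different lengths.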

Key Features:

  • Flexible Pooling: Supports mean, max, cls-token, and attention-based pooling strategies via OmniPooling. Strategy is configurable in model config.

  • Multi-Class Support: Handles binary and multi-class classification through configurable num_labels parameter.

  • Probability Output: Provides both logits and probability distributions via softmax activation for confidence-based predictions.

  • Loss Function: Uses CrossEntropyLoss by default, suitable for single-label classification with mutually exclusive classes.

Variables:
  • pooler (OmniPooling) – Pooling layer for aggregating sequence representations into fixed-length vectors. Pooling strategy determined by config.pooling_mode.

  • softmax (torch.nn.Softmax) – Softmax activation for converting logits to probability distributions over classes.

  • classifier (torch.nn.Linear) – Linear classification head mapping pooled representations to class logits. Output dimension equals num_labels.

  • loss_fn (torch.nn.CrossEntropyLoss) – Loss function for training. Automatically handles class weights if specified in config.

Example

>>> # Basic usage
>>> from omnigenbench import OmniModelForSequenceClassification, OmniTokenizer
>>> tokenizer = OmniTokenizer.from_pretrained("yangheng/OmniGenome-186M")
>>> model = OmniModelForSequenceClassification(
...     "yangheng/OmniGenome-186M",
...     tokenizer=tokenizer,
...     num_labels=2
... )
>>>
>>> # Inference on single sequence
>>> result = model.inference("ATCGATCGATCG")
>>> print(result['predictions'])  # Class index
>>> print(result['confidence'])   # Prediction confidence
>>>
>>> # Batch inference
>>> sequences = ["ATCGATCG", "GCTAGCTA", "TTAACCGG"]
>>> results = model.inference(sequences)
>>> print(results['predictions'])  # Array of class indices

forward(**inputs)

This method performs the forward pass through the model, computing sequence-level logits and applying softmax to produce probability distributions over the label classes.

Parameters:

**inputs – Input tensors including ‘input_ids’, ‘attention_mask’, and optionally ‘labels’.

Returns:

dict

A dictionary containing:
  • logits: Sequence-level classification logits

  • last_hidden_state: Final hidden states from the base model

  • labels: Ground truth labels (if provided)

Example

>>> import torch
>>> outputs = model(
...     input_ids=torch.tensor([[1, 2, 3, 4]]),
...     attention_mask=torch.tensor([[1, 1, 1, 1]]),
...     labels=torch.tensor([0])
... )

inference(sequence_or_inputs, **kwargs)

This method provides processed, human-readable sequence-level predictions. It converts logits to class labels and provides confidence scores.
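One plausible way to turn sequence-level logits into a label and a confidence score is to take the softmax, pick the most probable class, and report its probability. The sketch below assumes that confidence means the top-class probability and that labels come from an id2label mapping; both are assumptions, and `to_prediction` is a hypothetical helper rather than the library's implementation.

```python
import math

def softmax(logits):
    # Subtract the maximum before exponentiating for numerical stability.
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def to_prediction(logits, id2label=None):
    """Return (label, confidence) for one sequence's class logits."""
    probs = softmax(logits)
    idx = max(range(len(probs)), key=probs.__getitem__)  # argmax
    label = id2label[idx] if id2label is not None else idx
    return label, probs[idx]

label, confidence = to_prediction([0.2, 2.5], id2label={0: "negative", 1: "positive"})
print(label, round(confidence, 3))
```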

Parameters:
  • sequence_or_inputs – A sequence (str), list of sequences, or tokenized inputs (dict/tuple).

  • **kwargs – Additional arguments for tokenization and inference.

Returns:

dict

A dictionary containing:
  • predictions: Human-readable class labels for each sequence

  • logits: Raw logits from the model

  • confidence: Confidence scores for predictions

  • last_hidden_state: Final hidden states

Example

>>> # Inference on a single sequence
>>> results = model.inference("ATCGATCG")
>>> print(results['predictions'])  # "positive"
>>> print(results['confidence'])   # 0.95

loss_function(logits, labels)

This method computes the cross-entropy loss between the predicted logits and the ground truth labels.

Parameters:
  • logits (torch.Tensor) – Predicted logits from the model.

  • labels (torch.Tensor) – Ground truth labels.

Returns:

torch.Tensor – The computed loss value.

Example

>>> loss = model.loss_function(logits, labels)

predict(sequence_or_inputs, **kwargs)

This method takes raw sequences or tokenized inputs and returns sequence-level predictions. It processes the inputs through the model and returns the predicted class for each sequence.

Parameters:
  • sequence_or_inputs – A sequence (str), list of sequences, or tokenized inputs (dict/tuple).

  • **kwargs – Additional arguments for tokenization and inference.

Returns:

dict

A dictionary containing:
  • predictions: Predicted class indices for each sequence

  • logits: Raw logits from the model

  • last_hidden_state: Final hidden states

Example

>>> # Predict on a single sequence
>>> outputs = model.predict("ATCGATCG")
>>> print(outputs['predictions'])  # tensor([0])
>>> # Predict on multiple sequences
>>> outputs = model.predict(["ATCGATCG", "GCTAGCTA"])

class omnigenbench.src.model.classification.model.OmniModelForSequenceClassificationWith2DStructure(config_or_model, tokenizer, *args, **kwargs)

Bases: OmniModelForSequenceClassification

forward(**inputs)

This method performs the forward pass through the model, computing sequence-level logits and applying softmax to produce probability distributions over the label classes.

Parameters:

**inputs – Input tensors including ‘input_ids’, ‘attention_mask’, and optionally ‘labels’.

Returns:

dict

A dictionary containing:
  • logits: Sequence-level classification logits

  • last_hidden_state: Final hidden states from the base model

  • labels: Ground truth labels (if provided)

Example

>>> import torch
>>> outputs = model(
...     input_ids=torch.tensor([[1, 2, 3, 4]]),
...     attention_mask=torch.tensor([[1, 1, 1, 1]]),
...     labels=torch.tensor([0])
... )

class omnigenbench.src.model.classification.model.OmniModelForTokenClassification(config_or_model, tokenizer, *args, **kwargs)

Bases: OmniModel

Model for token-level (per-nucleotide) classification tasks in genomic analysis.

This class implements per-token classification, where each nucleotide in the input sequence receives an independent class prediction. Common genomic applications include:

  • Splice site detection (donor/acceptor/none)

  • Secondary structure prediction (helix/sheet/loop/coil)

  • Protein binding site identification (per-nucleotide)

  • Chromatin state annotation (per-position)

  • Base modification detection (m6A, m5C, etc.)

Unlike sequence classification, this model produces outputs of the same length as the input sequence, with each position classified independently.

Key Features:

  • Per-Token Predictions: Each nucleotide receives an independent classification, enabling fine-grained sequence annotation.

  • Variable-Length Output: Output length matches the input sequence length (excluding special tokens), so sequences of different lengths each receive one prediction per position.

  • Special Token Handling: Automatically excludes [CLS], [SEP], [PAD] tokens from predictions to return only biologically relevant positions.

  • Loss Computation: Uses CrossEntropyLoss with automatic padding token masking via PyTorch’s ignore_index=-100 convention.
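The special-token handling can be pictured as a filter over positions. This sketch assumes single-nucleotide tokenization (one token per base) and uses made-up token ids for [CLS], [SEP], and [PAD]; real ids come from the tokenizer, and `strip_special` is a hypothetical helper.

```python
def strip_special(token_ids, per_token_preds, special_ids):
    # Keep predictions only at positions whose token is not a special token.
    return [p for t, p in zip(token_ids, per_token_preds) if t not in special_ids]

# Hypothetical ids: 0 = [PAD], 1 = [CLS], 2 = [SEP]; 5-8 stand for A, T, C, G.
token_ids = [1, 5, 6, 7, 8, 2, 0, 0]
preds     = [0, 1, 0, 2, 0, 0, 0, 0]
print(strip_special(token_ids, preds, {0, 1, 2}))  # [1, 0, 2, 0]
```

With k-mer or subword tokenizers one token can cover several bases, so the mapping back to nucleotides would need extra alignment logic.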

Variables:
  • softmax (torch.nn.Softmax) – Softmax activation for converting per-token logits to probability distributions over classes.

  • classifier (torch.nn.Linear) – Linear classification head applied to each token independently. Maps hidden_size to num_labels for each position.

  • loss_fn (torch.nn.CrossEntropyLoss) – Loss function for training. Automatically ignores padding tokens (label=-100) during loss computation.

Example

>>> # Basic usage
>>> from omnigenbench import OmniModelForTokenClassification, OmniTokenizer
>>> tokenizer = OmniTokenizer.from_pretrained("yangheng/OmniGenome-186M")
>>> model = OmniModelForTokenClassification(
...     "yangheng/OmniGenome-186M",
...     tokenizer=tokenizer,
...     num_labels=3  # e.g., 3 classes: background, donor, acceptor
... )
>>>
>>> # Inference on single sequence
>>> result = model.inference("ATCGATCGATCG")
>>> print(len(result['predictions']))  # Length matches input sequence
>>> print(result['predictions'])       # Per-nucleotide class labels
>>>
>>> # Training example
>>> outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
>>> loss = model.loss_function(outputs['logits'], labels)

forward(**inputs)

Forward pass for token classification.

This method performs the forward pass through the model, computing logits for each token in the input sequence and applying softmax to produce probability distributions.

Parameters:

**inputs – Input tensors including ‘input_ids’, ‘attention_mask’, and optionally ‘labels’.

Returns:

dict

A dictionary containing:
  • logits: Token-level classification logits

  • last_hidden_state: Final hidden states from the base model

  • labels: Ground truth labels (if provided)

Example

>>> import torch
>>> outputs = model(
...     input_ids=torch.tensor([[1, 2, 3, 4]]),
...     attention_mask=torch.tensor([[1, 1, 1, 1]]),
...     labels=torch.tensor([[0, 1, 0, 1]])
... )

inference(sequence_or_inputs, **kwargs)

Performs token-level inference with human-readable output.

This method provides processed, human-readable token-level predictions. It converts logits to class labels and handles special tokens appropriately.

Parameters:
  • sequence_or_inputs – A sequence (str), list of sequences, or tokenized inputs (dict/tuple).

  • **kwargs – Additional arguments for tokenization and inference.

Returns:

dict

A dictionary containing:
  • predictions: Human-readable class labels for each token

  • logits: Raw logits from the model

  • confidence: Confidence scores for predictions

  • last_hidden_state: Final hidden states

Example

>>> # Inference on a single sequence
>>> results = model.inference("ATCGATCG")
>>> print(results['predictions'])  # one class label per nucleotide position

loss_function(logits, labels)

Calculates the cross-entropy loss for token classification.

This method computes the cross-entropy loss between the predicted logits and the ground truth labels, ignoring padding tokens.
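The ignore_index=-100 convention means labelled positions contribute to the average and -100 positions are skipped entirely. The plain-Python sketch below (a hypothetical `token_cross_entropy`, not the library's implementation) reproduces that behaviour for a single sequence.

```python
import math

IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss

def token_cross_entropy(logits, labels):
    """logits: one list of class scores per token; labels: one class id per token."""
    total, count = 0.0, 0
    for scores, y in zip(logits, labels):
        if y == IGNORE_INDEX:
            continue  # padding and special tokens contribute nothing
        m = max(scores)
        log_sum_exp = m + math.log(sum(math.exp(s - m) for s in scores))
        total += log_sum_exp - scores[y]  # -log softmax(scores)[y]
        count += 1
    if count == 0:
        raise ValueError("every position is ignored; nothing to average")
    return total / count  # mean over non-ignored tokens only

logits = [[2.0, 0.0], [0.0, 0.0], [5.0, -5.0]]
labels = [0, 1, IGNORE_INDEX]  # third position is, e.g., a [PAD] token
loss = token_cross_entropy(logits, labels)
print(round(loss, 4))
```

Changing the logits at an ignored position leaves the loss unchanged, which is why padding and special tokens can safely be labelled -100.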

Parameters:
  • logits (torch.Tensor) – Predicted logits from the model.

  • labels (torch.Tensor) – Ground truth labels.

Returns:

torch.Tensor – The computed loss value.

Example

>>> loss = model.loss_function(logits, labels)

predict(sequence_or_inputs, **kwargs)

Performs token-level prediction on raw inputs.

This method takes raw sequences or tokenized inputs and returns token-level predictions. It processes the inputs through the model and returns the predicted class for each token.

Parameters:
  • sequence_or_inputs – A sequence (str), list of sequences, or tokenized inputs (dict/tuple).

  • **kwargs – Additional arguments for tokenization and inference.

Returns:

dict

A dictionary containing:
  • predictions: Predicted class indices for each token

  • logits: Raw logits from the model

  • last_hidden_state: Final hidden states

Example

>>> # Predict on a single sequence
>>> outputs = model.predict("ATCGATCG")
>>> print(outputs['predictions'].shape)  # (seq_len,)
>>> # Predict on multiple sequences
>>> outputs = model.predict(["ATCGATCG", "GCTAGCTA"])

class omnigenbench.src.model.classification.model.OmniModelForTokenClassificationWith2DStructure(config_or_model, tokenizer, *args, **kwargs)

Bases: OmniModelForTokenClassification

forward(**inputs)

Forward pass for token classification.

This method performs the forward pass through the model, computing logits for each token in the input sequence and applying softmax to produce probability distributions.

Parameters:

**inputs – Input tensors including ‘input_ids’, ‘attention_mask’, and optionally ‘labels’.

Returns:

dict

A dictionary containing:
  • logits: Token-level classification logits

  • last_hidden_state: Final hidden states from the base model

  • labels: Ground truth labels (if provided)

Example

>>> import torch
>>> outputs = model(
...     input_ids=torch.tensor([[1, 2, 3, 4]]),
...     attention_mask=torch.tensor([[1, 1, 1, 1]]),
...     labels=torch.tensor([[0, 1, 0, 1]])
... )

Regression Models

Regression models for the OmniGenome framework.

This module provides various regression model implementations for genomic sequence analysis, including token-level regression, sequence-level regression, structural imputation, and matrix regression/classification tasks.

class omnigenbench.src.model.regression.model.OmniModelForMatrixClassification(config_or_model, tokenizer, *args, **kwargs)

Bases: OmniModel

This model performs classification on matrix representations of genomic sequences, useful for tasks like structure classification, contact map classification, or other matrix-based genomic analysis tasks.

Variables:
  • resnet – ResNet backbone for processing matrix inputs

  • classifier – Linear layer for classification output

  • loss_fn – Cross-entropy loss function

forward(**inputs)

Forward pass for matrix classification.

Parameters:

**inputs – Input tensors including matrix representations and labels

Returns:

dict – Dictionary containing logits, last_hidden_state, and labels

inference(sequence_or_inputs, **kwargs)

Perform inference for matrix classification.

Parameters:
  • sequence_or_inputs – Input sequences or pre-processed inputs

  • **kwargs – Additional keyword arguments

Returns:

dict – Dictionary containing predictions, logits, and last_hidden_state

loss_function(logits, labels)

Compute the loss for matrix classification.

Parameters:
  • logits (torch.Tensor) – Model predictions

  • labels (torch.Tensor) – Ground truth labels

Returns:

torch.Tensor – Computed loss value

predict(sequence_or_inputs, **kwargs)

Generate predictions for matrix classification.

Parameters:
  • sequence_or_inputs – Input sequences or pre-processed inputs

  • **kwargs – Additional keyword arguments

Returns:

dict – Dictionary containing predictions, logits, and last_hidden_state

class omnigenbench.src.model.regression.model.OmniModelForMatrixRegression(config_or_model, tokenizer, *args, **kwargs)

Bases: OmniModel

This model performs regression on matrix representations of genomic sequences, useful for tasks like contact map prediction, structure prediction, or other matrix-based genomic analysis tasks.

Variables:
  • resnet – ResNet backbone for processing matrix inputs

  • classifier – Linear layer for regression output

  • loss_fn – Mean squared error loss function

forward(**inputs)

Forward pass for matrix regression.

Parameters:

**inputs – Input tensors including matrix representations and labels

Returns:

dict – Dictionary containing logits, last_hidden_state, and labels

inference(sequence_or_inputs, **kwargs)

Perform inference for matrix regression.

Parameters:
  • sequence_or_inputs – Input sequences or pre-processed inputs

  • **kwargs – Additional keyword arguments

Returns:

dict – Dictionary containing predictions, logits, and last_hidden_state

loss_function(logits, labels)

Compute the loss for matrix regression.

Parameters:
  • logits (torch.Tensor) – Model predictions

  • labels (torch.Tensor) – Ground truth labels

Returns:

torch.Tensor – Computed loss value

predict(sequence_or_inputs, **kwargs)

Generate predictions for matrix regression.

Parameters:
  • sequence_or_inputs – Input sequences or pre-processed inputs

  • **kwargs – Additional keyword arguments

Returns:

dict – Dictionary containing predictions, logits, and last_hidden_state

class omnigenbench.src.model.regression.model.OmniModelForSequenceRegression(config_or_model, tokenizer, *args, **kwargs)

Bases: OmniModel

This model performs regression at the sequence level, predicting a single continuous value for the entire input sequence. It’s useful for tasks like predicting overall expression levels, binding affinities, or other sequence-level properties.

Variables:
  • pooler – OmniPooling layer for sequence-level representation

  • classifier – Linear layer for regression output

  • loss_fn – Mean squared error loss function
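For reference, the mean squared error used as the training objective reduces to a few lines. This is a sketch with a hypothetical `mse_loss` helper over plain lists, mirroring the default reduction="mean" of torch.nn.MSELoss; the target values are made up.

```python
def mse_loss(preds, targets):
    # Mean of squared differences between predicted and true values.
    if len(preds) != len(targets):
        raise ValueError("preds and targets must have the same length")
    return sum((p - t) ** 2 for p, t in zip(preds, targets)) / len(preds)

# One predicted value per sequence, e.g. hypothetical expression levels.
print(mse_loss([0.5, 1.0, 2.0], [0.0, 1.0, 3.0]))
```

Squaring the error penalises large misses much more than small ones, so a single badly predicted sequence can dominate the batch loss.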

forward(**inputs)

Forward pass for sequence-level regression.

Parameters:

**inputs – Input tensors including input_ids, attention_mask, and labels

Returns:

dict – Dictionary containing logits, last_hidden_state, and labels

inference(sequence_or_inputs, **kwargs)

Perform inference for sequence-level regression.

Parameters:
  • sequence_or_inputs – Input sequences or pre-processed inputs

  • **kwargs – Additional keyword arguments

Returns:

dict – Dictionary containing predictions, logits, and last_hidden_state

loss_function(logits, labels)

Compute the loss for sequence-level regression.

Parameters:
  • logits (torch.Tensor) – Model predictions

  • labels (torch.Tensor) – Ground truth labels

Returns:

torch.Tensor – Computed loss value

predict(sequence_or_inputs, **kwargs)

Generate predictions for sequence-level regression.

Parameters:
  • sequence_or_inputs – Input sequences or pre-processed inputs

  • **kwargs – Additional keyword arguments

Returns:

dict – Dictionary containing predictions, logits, and last_hidden_state

class omnigenbench.src.model.regression.model.OmniModelForSequenceRegressionWith2DStructure(config_or_model, tokenizer, *args, **kwargs)

Bases: OmniModelForSequenceRegression

This model extends the basic sequence regression model to incorporate 2D structural information, useful for RNA structure prediction and other structural genomics tasks.

forward(**inputs)

Forward pass for 2D structure-aware sequence regression.

Parameters:

**inputs – Input tensors including input_ids, attention_mask, labels, and structural info

Returns:

dict – Dictionary containing logits, last_hidden_state, and labels

class omnigenbench.src.model.regression.model.OmniModelForStructuralImputation(config_or_model, tokenizer, *args, **kwargs)

Bases: OmniModelForSequenceRegression

This model is specialized for imputing missing structural information in genomic sequences. It extends the sequence regression model with additional embedding capabilities for structural features.

Variables:
  • embedding – Embedding layer for structural features

  • loss_fn – Mean squared error loss function

forward(**inputs)

Forward pass for structural imputation.

Parameters:

**inputs – Input tensors including input_ids, attention_mask, and labels

Returns:

dict – Dictionary containing logits, last_hidden_state, and labels

class omnigenbench.src.model.regression.model.OmniModelForTokenRegression(config_or_model, tokenizer, *args, **kwargs)

Bases: OmniModel

Token-level regression model for genomic sequences.

This model performs regression at the token level, predicting continuous values for each token in the input sequence. It’s useful for tasks like predicting binding affinities, expression levels, or other continuous properties at each position in a genomic sequence.

Variables:
  • classifier – Linear layer for regression output

  • loss_fn – Mean squared error loss function
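Token-level regression applies the same linear layer at every position, so the output holds one value per token. The sketch below assumes a single regression target per position and uses made-up weights; `token_regression_head` is a hypothetical stand-in for the model's classifier layer, which in the library is a PyTorch linear layer operating on tensors.

```python
def token_regression_head(hidden_states, weight, bias):
    """Map each token's hidden vector to one continuous value.

    The same weight vector and bias are shared across all positions,
    so the output has exactly one value per input token.
    """
    return [sum(w * h for w, h in zip(weight, vec)) + bias for vec in hidden_states]

hidden = [[1.0, 2.0], [0.0, 1.0], [3.0, -1.0]]  # 3 tokens, hidden_size = 2
values = token_regression_head(hidden, weight=[0.5, 1.0], bias=0.1)
print(values)
```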

forward(**inputs)

Forward pass for token-level regression.

Parameters:

**inputs – Input tensors including input_ids, attention_mask, and labels

Returns:

dict – Dictionary containing logits, last_hidden_state, and labels

inference(sequence_or_inputs, **kwargs)

Perform inference for token-level regression, excluding special tokens.

Parameters:
  • sequence_or_inputs – Input sequences or pre-processed inputs

  • **kwargs – Additional keyword arguments

Returns:

dict – Dictionary containing predictions, logits, and last_hidden_state

loss_function(logits, labels)

Compute the loss for token-level regression.

Parameters:
  • logits (torch.Tensor) – Model predictions

  • labels (torch.Tensor) – Ground truth labels

Returns:

torch.Tensor – Computed loss value

predict(sequence_or_inputs, **kwargs)

Generate predictions for token-level regression.

Parameters:
  • sequence_or_inputs – Input sequences or pre-processed inputs

  • **kwargs – Additional keyword arguments

Returns:

dict – Dictionary containing predictions, logits, and last_hidden_state

class omnigenbench.src.model.regression.model.OmniModelForTokenRegressionWith2DStructure(config_or_model, tokenizer, *args, **kwargs)

Bases: OmniModelForTokenRegression

This model extends the basic token regression model to incorporate 2D structural information, useful for RNA structure prediction and other structural genomics tasks.

forward(**inputs)

Forward pass for 2D structure-aware token regression.

Parameters:

**inputs – Input tensors including input_ids, attention_mask, labels, and structural info

Returns:

dict – Dictionary containing logits, last_hidden_state, and labels

ResNet implementation for genomic sequence analysis.

This module provides a ResNet architecture adapted for processing genomic sequences and their structural representations. It includes basic blocks, bottleneck blocks, and a complete ResNet implementation optimized for genomic data.

class omnigenbench.src.model.regression.resnet.BasicBlock(inplanes: int, planes: int, stride: int = 1, downsample=None, groups: int = 1, dilation: int = 1, norm_layer: Callable[[...], Module] | None = None)

Bases: Module

This block implements a basic residual unit with two convolutions (a 3x3 followed by a 5x5) and layer normalization, adapted for processing genomic sequence data.

Variables:
  • expansion (int) – Expansion factor for the block (default: 1)

  • conv1 – First 3x3 convolution layer

  • bn1 – First layer normalization

  • conv2 – Second 5x5 convolution layer

  • bn2 – Second layer normalization

  • relu – ReLU activation function

  • drop – Dropout layer

  • downsample – Downsampling layer for residual connection

  • stride – Stride for the convolutions

expansion: int = 1

forward(x: Tensor) → Tensor

Forward pass through the BasicBlock.

Parameters:

x (Tensor) – Input tensor [batch_size, channels, height, width]

Returns:

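Tensor – Output tensor; its shape matches the input when stride is 1 and inplanes equals planes, otherwise downsample must adapt the residual branch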

class omnigenbench.src.model.regression.resnet.Bottleneck(inplanes: int, planes: int, stride: int = 1, downsample: Module | None = None, groups: int = 1, base_width: int = 64, dilation: int = 1, norm_layer: Callable[[...], Module] | None = None)

Bases: Module

This block implements a bottleneck residual connection with three convolutions (1x1, 3x3, 1x1) and is designed for deeper networks. It’s adapted from the original ResNet V1.5 implementation.

Variables:
  • expansion (int) – Expansion factor for the block (default: 4)

  • conv1 – First 1x1 convolution layer

  • bn1 – First batch normalization

  • conv2 – Second 3x3 convolution layer

  • bn2 – Second batch normalization

  • conv3 – Third 1x1 convolution layer

  • bn3 – Third batch normalization

  • relu – ReLU activation function

  • downsample – Downsampling layer for residual connection

  • stride – Stride for the convolutions
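The channel bookkeeping of a bottleneck block can be sketched as below. This assumes the torchvision convention width = int(planes * (base_width / 64.0)) * groups, which this class is not confirmed to follow, although it exposes the same base_width and groups arguments; `bottleneck_channels` and `needs_downsample` are hypothetical helpers.

```python
def bottleneck_channels(inplanes, planes, base_width=64, groups=1, expansion=4):
    """(in, out) channels of each convolution, torchvision-style (assumed)."""
    width = int(planes * (base_width / 64.0)) * groups
    return {
        "conv1": (inplanes, width),            # 1x1 reduces channels
        "conv2": (width, width),               # 3x3 spatial convolution
        "conv3": (width, planes * expansion),  # 1x1 expands channels
    }

def needs_downsample(inplanes, planes, stride=1, expansion=4):
    # The identity shortcut can only be added when the block keeps the shape.
    return stride != 1 or inplanes != planes * expansion

print(bottleneck_channels(256, 64))
print(needs_downsample(256, 64), needs_downsample(64, 64), needs_downsample(256, 64, stride=2))
```

The expansion factor of 4 is why a block with planes=64 outputs 256 channels, and why a downsample layer is needed whenever the incoming channel count differs.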

expansion: int = 4

forward(x: Tensor) → Tensor

Forward pass through the Bottleneck block.

Parameters:

x (Tensor) – Input tensor [batch_size, channels, height, width]

Returns:

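Tensor – Output tensor; its shape matches the input when stride is 1 and inplanes equals planes * expansion, otherwise downsample must adapt the residual branch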

class omnigenbench.src.model.regression.resnet.ResNet(channels, block: Type[BasicBlock | Bottleneck], layers: List[int], zero_init_residual: bool = False, groups: int = 1, width_per_group: int = 1, replace_stride_with_dilation=None, norm_layer=None)

Bases: Module

This ResNet implementation is specifically designed for processing genomic sequences and their structural representations. It uses layer normalization instead of batch normalization and is optimized for genomic data characteristics.

Variables:
  • _norm_layer – Normalization layer type

  • inplanes – Number of input channels for the first layer

  • dilation – Dilation factor for convolutions

  • groups – Number of groups for grouped convolutions

  • base_width – Base width for bottleneck blocks

  • conv1 – Initial convolution layer

  • bn1 – Initial normalization layer

  • relu – ReLU activation function

  • layer1 – First layer of ResNet blocks

  • fc1 – Final fully connected layer

forward(x: Tensor) → Tensor

Forward pass through the ResNet.

Parameters:

x (Tensor) – Input tensor [batch_size, channels, height, width]

Returns:

Tensor – Output tensor after processing through ResNet

omnigenbench.src.model.regression.resnet.conv1x1(in_planes, out_planes, stride=1)

1x1 convolution.

Parameters:
  • in_planes (int) – Number of input channels

  • out_planes (int) – Number of output channels

  • stride (int) – Stride for the convolution (default: 1)

Returns:

nn.Conv2d – 1x1 convolution layer

omnigenbench.src.model.regression.resnet.conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1)

3x3 convolution with padding.

Parameters:
  • in_planes (int) – Number of input channels

  • out_planes (int) – Number of output channels

  • stride (int) – Stride for the convolution (default: 1)

  • groups (int) – Number of groups for grouped convolution (default: 1)

  • dilation (int) – Dilation factor for the convolution (default: 1)

Returns:

nn.Conv2d – 3x3 convolution layer

omnigenbench.src.model.regression.resnet.conv5x5(in_planes, out_planes, stride=1, groups=1, dilation=1)

5x5 convolution with padding.

Parameters:
  • in_planes (int) – Number of input channels

  • out_planes (int) – Number of output channels

  • stride (int) – Stride for the convolution (default: 1)

  • groups (int) – Number of groups for grouped convolution (default: 1)

  • dilation (int) – Dilation factor for the convolution (default: 1)

Returns:

nn.Conv2d – 5x5 convolution layer
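The phrase "with padding" matters because it decides whether a convolution preserves the spatial size of its input. The sketch below uses the output-size formula from the torch.nn.Conv2d documentation and assumes "same" padding of dilation * (kernel - 1) // 2; torchvision's conv3x3 uses padding=dilation, which agrees for a 3x3 kernel, but the exact padding of conv5x5 is not stated here. Both helpers are hypothetical.

```python
def conv_out_size(n, kernel, stride=1, padding=0, dilation=1):
    # Output length along one spatial axis (torch.nn.Conv2d shape formula).
    return (n + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1

def same_padding(kernel, dilation=1):
    # Padding that preserves the input size at stride 1 for odd kernels.
    return dilation * (kernel - 1) // 2

for k in (1, 3, 5):
    p = same_padding(k)
    print(f"{k}x{k}: padding={p}, 64 -> {conv_out_size(64, k, padding=p)}")

print(conv_out_size(64, 3, stride=2, padding=1))  # stride 2 halves the size
```

Under this assumption, mixing 3x3 and 5x5 convolutions in one residual block keeps the feature-map size unchanged at stride 1, so the identity shortcut can be added directly.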

omnigenbench.src.model.regression.resnet.resnet_b16(channels=128, bbn=16)

This function creates a ResNet model built from bbn basic blocks (16 by default), optimized for processing genomic sequences and their structural representations.

Parameters:
  • channels (int) – Number of input channels (default: 128)

  • bbn (int) – Number of basic blocks (default: 16)

Returns:

ResNet – Configured ResNet model

Embedding Models

class omnigenbench.src.model.embedding.model.OmniModelForEmbedding(config_or_model, tokenizer=None, *args, **kwargs)

Bases: Module

This class provides a unified interface for loading pre-trained models and generating embeddings from genomic sequences. It supports various aggregation methods and batch processing for efficient embedding generation.

Variables:
  • tokenizer – The tokenizer for processing input sequences

  • model – The pre-trained model for generating embeddings

  • _device – The device (CPU/GPU) where the model is loaded

Example

>>> from omnigenbench import OmniModelForEmbedding
>>> model = OmniModelForEmbedding("anonymous8/OmniGenome-186M")
>>> sequences = ["ATCGGCTA", "GGCTAGCTA"]
>>> embeddings = model.batch_encode(sequences)
>>> print(f"Embeddings shape: {embeddings.shape}")
Embeddings shape: torch.Size([2, 768])
batch_encode(sequences, batch_size=8, max_length=512, agg='head', require_grad: bool = False, return_on_cpu: bool = True, use_autocast: bool = False, amp_dtype=None)[source]

Batch encode sequences into aggregated (pooled) embeddings.

Parameters:
  • sequences (List[str]) – Input DNA or RNA sequences for encoding.

  • batch_size (int, default=8) – Number of sequences to process per batch.

  • max_length (int, default=512) – Maximum sequence length for tokenization.

  • agg (str, default="head") – Aggregation method for pooling. Options: “head”, “mean”, “tail”.

  • require_grad (bool, default=False) – Whether to preserve gradients for fine-tuning.

  • return_on_cpu (bool, default=True) – Whether to move results to CPU memory.

  • use_autocast (bool, default=False) – Whether to enable mixed precision (CUDA only).

  • amp_dtype (torch.dtype, optional) – Data type for automatic mixed precision.

Returns:

torch.Tensor – Pooled embeddings with shape (num_sequences, hidden_size).

Note

This method maintains backward compatibility with existing code. When require_grad=True, gradients flow through the model for end-to-end training.

Example

>>> sequences = ["ATCGGCTA", "GGCTAGCTA"]
>>> embeddings = model.batch_encode(sequences, batch_size=4, agg="mean")
>>> print(embeddings.shape)
torch.Size([2, 768])
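The three agg options can be pictured with a dependency-free sketch for one sequence. It assumes that “head” takes the first token's vector (typically a [CLS]-style start token), “mean” averages the vectors of non-padded tokens according to the attention mask, and “tail” takes the last non-padded token. These readings and the helper name pool_hidden_states are assumptions for illustration, not the library's code.

```python
from typing import List, Sequence

Vector = List[float]


def pool_hidden_states(
    hidden: Sequence[Vector], mask: Sequence[int], agg: str = "head"
) -> Vector:
    """Hypothetical helper: pool one sequence's token vectors into a single vector.

    hidden: token-level vectors, shape (seq_len, hidden_size)
    mask:   attention mask, 1 for real tokens and 0 for padding
    agg:    "head", "mean", or "tail" (assumed semantics, see lead-in)
    """
    if len(hidden) != len(mask):
        raise ValueError("hidden and mask must have the same length")
    valid = [i for i, m in enumerate(mask) if m == 1]
    if not valid:
        raise ValueError("mask contains no valid tokens")
    if agg == "head":
        # First token, e.g. a [CLS]-style start token.
        return list(hidden[0])
    if agg == "tail":
        # Last non-padded token.
        return list(hidden[valid[-1]])
    if agg == "mean":
        # Average only over non-padded positions.
        size = len(hidden[0])
        return [sum(hidden[i][d] for i in valid) / len(valid) for d in range(size)]
    raise ValueError(f"unknown agg: {agg!r}")


hidden = [[1.0, 0.0], [3.0, 2.0], [5.0, 4.0], [0.0, 0.0]]  # last row is padding
mask = [1, 1, 1, 0]
print(pool_hidden_states(hidden, mask, "head"))  # [1.0, 0.0]
print(pool_hidden_states(hidden, mask, "mean"))  # [3.0, 2.0]
print(pool_hidden_states(hidden, mask, "tail"))  # [5.0, 4.0]
```

Masking matters mainly for “mean”: averaging over padded positions would pull the pooled vector toward the padding embedding.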
batch_encode_tokens(sequences, batch_size=8, max_length=512, use_autocast=False, amp_dtype=None, require_grad: bool = False, return_on_cpu: bool = True)[source]

Encode sequences to token-level embeddings (last_hidden_state).

Parameters:
  • sequences (List[str]) – Input DNA/RNA sequences for token-level encoding

  • batch_size (int, default=8) – Number of sequences to process per batch

  • max_length (int, default=512) – Maximum sequence length for tokenization

  • use_autocast (bool, default=False) – Whether to enable mixed precision (CUDA only)

  • amp_dtype (torch.dtype, optional) – Data type for automatic mixed precision

  • require_grad (bool, default=False) – Preserve gradient computation graph for fine-tuning

  • return_on_cpu (bool, default=True) – Transfer outputs to CPU memory

Returns:

torch.Tensor – Token embeddings with shape (num_sequences, max_length, hidden_size)

Note

When require_grad=True, gradients flow through the transformer model for end-to-end training. Set return_on_cpu=False to keep tensors on GPU device for downstream processing.
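batch_size only controls how many sequences go through the model at once. Because the documented output has num_sequences as its first dimension, per-batch results are presumably concatenated in input order. A minimal sketch of that loop follows, where encode_batch is a hypothetical stand-in for the tokenize-and-forward step and both helper names are illustrative.

```python
from typing import Callable, Iterator, List, Sequence, TypeVar

T = TypeVar("T")
R = TypeVar("R")


def iter_batches(items: Sequence[T], batch_size: int) -> Iterator[Sequence[T]]:
    """Yield consecutive slices of at most batch_size items."""
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1")
    for start in range(0, len(items), batch_size):
        yield items[start:start + batch_size]


def encode_in_batches(
    sequences: Sequence[str],
    encode_batch: Callable[[Sequence[str]], List[R]],
    batch_size: int = 8,
) -> List[R]:
    """Run a (hypothetical) per-batch encoder and concatenate results in input order."""
    outputs: List[R] = []
    for batch in iter_batches(sequences, batch_size):
        outputs.extend(encode_batch(batch))
    return outputs


# A toy "encoder" that returns sequence lengths stands in for the model.
print(encode_in_batches(["A", "CG", "UUU"], lambda b: [len(s) for s in b], batch_size=2))
# [1, 2, 3]
```

A smaller batch_size lowers peak memory without changing the result, since only the grouping of sequences differs.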

batch_extract_attention_scores(sequences, batch_size=4, max_length=512, layer_indices=None, head_indices=None, return_on_cpu=True, use_autocast=False, amp_dtype=None)[source]

Extract attention scores from multiple genomic sequences in batches.

This method provides efficient batch processing for attention extraction from multiple sequences, useful for comparative analysis of attention patterns.

Parameters:
  • sequences (List[str]) – List of input DNA or RNA sequences for attention extraction.

  • batch_size (int, default=4) – Number of sequences to process per batch. Smaller batch sizes reduce memory usage.

  • max_length (int, default=512) – Maximum sequence length for tokenization.

  • layer_indices (List[int], optional) – Specific transformer layer indices to extract. If None, extracts attention from all layers.

  • head_indices (List[int], optional) – Specific attention head indices to extract. If None, extracts attention from all heads.

  • return_on_cpu (bool, default=True) – Whether to transfer outputs to CPU memory.

  • use_autocast (bool, default=False) – Whether to enable mixed precision.

  • amp_dtype (torch.dtype, optional) – Data type for automatic mixed precision.

Returns:

List[Dict[str, torch.Tensor]]

List of dictionaries, each containing:
  • 'attentions': Attention weights tensor for each sequence

  • 'tokens': List of tokenized input tokens for each sequence

  • 'attention_mask': Attention mask tensor for each sequence

Example

>>> sequences = ["ATCGATCGATCG", "GGCCTTAACCGG", "TTTTAAAACCCC"]
>>> results = model.batch_extract_attention_scores(sequences, batch_size=2)
>>> print(f"Number of results: {len(results)}")
>>> print(f"First result attention shape: {results[0]['attentions'].shape}")

Note

Batch processing improves efficiency when analyzing multiple sequences. Consider reducing batch_size if encountering memory issues.
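Sequences in a batch are padded to a common length, so comparing attention patterns across sequences usually means discarding padded positions first. The sketch below crops a single layer/head attention map with that sequence's attention mask. It assumes the map is indexed [query][key] and that a mask value of 1 marks a real token; trim_to_valid_tokens is a hypothetical helper operating on plain lists rather than tensors.

```python
from typing import List, Sequence

Matrix = List[List[float]]


def trim_to_valid_tokens(
    attention: Sequence[Sequence[float]], mask: Sequence[int]
) -> Matrix:
    """Keep only rows (queries) and columns (keys) whose mask value is 1."""
    if len(attention) != len(mask):
        raise ValueError("attention and mask must cover the same positions")
    keep = [i for i, m in enumerate(mask) if m == 1]
    return [[attention[r][c] for c in keep] for r in keep]


attn = [[0.5, 0.5, 0.0], [0.2, 0.8, 0.0], [0.0, 0.0, 1.0]]  # last position is padding
print(trim_to_valid_tokens(attn, [1, 1, 0]))  # [[0.5, 0.5], [0.2, 0.8]]
```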

compute_similarity(embedding1, embedding2, dim=0)[source]

Compute cosine similarity between two embeddings.

Parameters:
  • embedding1 (torch.Tensor or np.ndarray) – The first embedding

  • embedding2 (torch.Tensor or np.ndarray) – The second embedding

  • dim (int, optional) – Dimension along which to compute cosine similarity. Defaults to 0

Returns:

float – Cosine similarity score between -1 and 1

Example

>>> emb1 = model.encode("ATCGGCTA")
>>> emb2 = model.encode("GGCTAGCTA")
>>> similarity = model.compute_similarity(emb1, emb2)
>>> print(f"Cosine similarity: {similarity:.4f}")
Cosine similarity: 0.8234
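For the 1-D pooled embeddings returned by encode, dim=0 is the hidden dimension, and the score is dot(a, b) / (||a|| * ||b||). The plain-Python sketch below shows only the math; the method itself may delegate to a torch routine, and the eps guard against zero-norm vectors is an assumption added here.

```python
import math
from typing import Sequence


def cosine_similarity(a: Sequence[float], b: Sequence[float], eps: float = 1e-8) -> float:
    """Cosine similarity of two 1-D vectors: dot(a, b) / (||a|| * ||b||)."""
    if len(a) != len(b):
        raise ValueError("vectors must have the same length")
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    # eps avoids division by zero when either vector is all zeros
    return dot / max(norm_a * norm_b, eps)


print(cosine_similarity([1.0, 0.0], [0.0, 1.0]))   # 0.0 (orthogonal)
print(cosine_similarity([1.0, 0.0], [-1.0, 0.0]))  # -1.0 (opposite)
```

Because the score ignores vector length, scaling an embedding does not change its similarity to others.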
property device

Get the device where the model is located.

encode(sequence, max_length=512, agg='head', keep_dim=False, require_grad: bool = False, return_on_cpu: bool = True, use_autocast: bool = False, amp_dtype=None)[source]

Encode a single sequence into a pooled embedding.

Parameters:
  • sequence (str) – Input DNA or RNA sequence for encoding.

  • max_length (int, default=512) – Maximum sequence length for tokenization.

  • agg (str, default="head") – Aggregation strategy for pooling. Options: “head”, “mean”, “tail”.

  • keep_dim (bool, default=False) – Whether to preserve batch dimension in output.

  • require_grad (bool, default=False) – Whether to preserve gradients for fine-tuning.

  • return_on_cpu (bool, default=True) – Whether to move results to CPU memory.

  • use_autocast (bool, default=False) – Whether to enable mixed precision.

  • amp_dtype (torch.dtype, optional) – Data type for automatic mixed precision.

Returns:

torch.Tensor – Pooled embedding with shape (hidden_size,) or (1, hidden_size) if keep_dim=True.

Example

>>> sequence = "ATCGATCGATCG"
>>> embedding = model.encode(sequence, agg="mean", max_length=200)
>>> print(embedding.shape)
torch.Size([768])
encode_tokens(sequence, max_length=512, use_autocast=False, amp_dtype=None, require_grad: bool = False, return_on_cpu: bool = True)[source]

Encode a single sequence to token-level embeddings.

Parameters:
  • sequence (str) – Input DNA/RNA sequence for token-level encoding

  • max_length (int, default=512) – Maximum sequence length for tokenization

  • use_autocast (bool, default=False) – Whether to enable mixed precision (CUDA only)

  • amp_dtype (torch.dtype, optional) – Data type for automatic mixed precision

  • require_grad (bool, default=False) – Preserve gradient computation graph for fine-tuning

  • return_on_cpu (bool, default=True) – Transfer output to CPU memory

Returns:

torch.Tensor – Token embeddings with shape (max_length, hidden_size)

Example

>>> model = OmniModelForEmbedding("yangheng/OmniGenome-52M")
>>> sequence = "ATCGATCGATCG"
>>> token_embeddings = model.encode_tokens(sequence, max_length=200)
>>> print(f"Token embeddings shape: {token_embeddings.shape}")
Token embeddings shape: torch.Size([200, 768])
extract_attention_scores(sequence, max_length=512, layer_indices=None, head_indices=None, return_on_cpu=True, use_autocast=False, amp_dtype=None)[source]

Extract attention scores from a single genomic sequence.

This method extracts attention weights from transformer layers, providing insights into which positions the model focuses on during sequence processing.

Parameters:
  • sequence (str) – Input DNA or RNA sequence for attention extraction.

  • max_length (int, default=512) – Maximum sequence length for tokenization.

  • layer_indices (List[int], optional) – Specific transformer layer indices to extract. If None, extracts attention from all layers.

  • head_indices (List[int], optional) – Specific attention head indices to extract. If None, extracts attention from all heads.

  • return_on_cpu (bool, default=True) – Whether to transfer output to CPU memory.

  • use_autocast (bool, default=False) – Whether to enable mixed precision.

  • amp_dtype (torch.dtype, optional) – Data type for automatic mixed precision.

Returns:

Dict[str, torch.Tensor]

Dictionary containing:
  • 'attentions': Attention weights tensor with shape (num_layers, num_heads, seq_len, seq_len)

  • 'tokens': List of tokenized input tokens

  • 'attention_mask': Attention mask tensor indicating valid positions

Example

>>> sequence = "ATCGATCGATCG"
>>> result = model.extract_attention_scores(sequence, max_length=200)
>>> print(f"Attention shape: {result['attentions'].shape}")
>>> print(f"First 10 tokens: {result['tokens'][:10]}")

Note

The attention tensor follows the standard transformer format where higher values indicate stronger attention between token pairs.
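The layer_indices/head_indices filtering can be sketched on nested Python lists laid out as (num_layers, num_heads, seq_len, seq_len). Treating None as “keep all” follows the parameter descriptions above. The helper name and the list-based layout are illustrative assumptions; the real method returns a tensor.

```python
from typing import List, Optional, Sequence

Matrix = List[List[float]]


def select_layers_and_heads(
    attentions: Sequence[Sequence[Matrix]],
    layer_indices: Optional[Sequence[int]] = None,
    head_indices: Optional[Sequence[int]] = None,
) -> List[List[Matrix]]:
    """Hypothetical helper: subset a (num_layers, num_heads, seq_len, seq_len) structure.

    None keeps every layer (or head), mirroring the documented defaults.
    """
    layers = range(len(attentions)) if layer_indices is None else layer_indices
    result: List[List[Matrix]] = []
    for layer in layers:
        heads_in_layer = attentions[layer]
        heads = range(len(heads_in_layer)) if head_indices is None else head_indices
        result.append([heads_in_layer[h] for h in heads])
    return result


# 2 layers x 3 heads of 1x1 maps, each tagged with 10 * layer + head.
attn = [[[[float(10 * l + h)]] for h in range(3)] for l in range(2)]
print(select_layers_and_heads(attn, layer_indices=[1], head_indices=[0, 2]))
# [[[[10.0]], [[12.0]]]]
```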

get_attention_statistics(attention_scores, attention_mask=None, layer_aggregation='mean', head_aggregation='mean')[source]

Compute comprehensive statistics from attention scores.

This method analyzes attention patterns by computing various statistical measures that help understand the model’s focus and attention distribution.

Parameters:
  • attention_scores (torch.Tensor) – Attention tensor with shape (num_layers, num_heads, seq_len, seq_len).

  • attention_mask (torch.Tensor, optional) – Attention mask to exclude padding tokens from statistics.

  • layer_aggregation (str, default="mean") – Method to aggregate across transformer layers. Options: “mean”, “max”, “sum”, “first”, “last”.

  • head_aggregation (str, default="mean") – Method to aggregate across attention heads. Options: “mean”, “max”, “sum”.

Returns:

Dict[str, torch.Tensor]

Dictionary containing attention statistics:
  • 'attention_matrix': Aggregated attention matrix

  • 'attention_entropy': Entropy measure of attention distribution

  • 'max_attention_per_position': Maximum attention value for each position

  • 'attention_concentration': Measure of attention concentration (L2 norm)

  • 'self_attention_scores': Self-attention scores (diagonal values)

Example

>>> result = model.extract_attention_scores(sequence)
>>> stats = model.get_attention_statistics(result['attentions'], result['attention_mask'])
>>> print(f"Average attention entropy: {stats['attention_entropy'].mean():.4f}")

Note

Higher entropy indicates more distributed attention, while lower entropy suggests more focused attention patterns.
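The per-position statistics can be sketched for one attention matrix that has already been aggregated across layers and heads. The sketch assumes each row is a softmax distribution over key positions and computes every statistic per query row; the exact axes, normalization, and masking used by get_attention_statistics may differ, and attention_row_statistics is a hypothetical name. A uniform row over n keys reaches the maximum entropy log(n), while a one-hot row has entropy 0, matching the note above.

```python
import math
from typing import Dict, List, Sequence


def attention_row_statistics(matrix: Sequence[Sequence[float]]) -> Dict[str, List[float]]:
    """Per-query statistics for one aggregated (seq_len, seq_len) attention matrix."""
    entropy, row_max, concentration, self_scores = [], [], [], []
    for i, row in enumerate(matrix):
        # Shannon entropy in nats; zero weights contribute nothing (0 * log 0 := 0).
        entropy.append(-sum(p * math.log(p) for p in row if p > 0))
        row_max.append(max(row))
        # L2 norm: larger when the attention mass sits on few positions.
        concentration.append(math.sqrt(sum(p * p for p in row)))
        # Diagonal entry: how much token i attends to itself.
        self_scores.append(row[i])
    return {
        "attention_entropy": entropy,
        "max_attention_per_position": row_max,
        "attention_concentration": concentration,
        "self_attention_scores": self_scores,
    }


stats = attention_row_statistics([[1.0, 0.0], [0.5, 0.5]])
print(stats["self_attention_scores"])  # [1.0, 0.5]
```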

load_embeddings(embedding_path)[source]

Load embeddings from a file.

Parameters:

embedding_path (str) – Path to the saved embeddings

Returns:

torch.Tensor – The loaded embeddings

Example

>>> embeddings = model.load_embeddings("embeddings.pt")
>>> print(f"Loaded embeddings shape: {embeddings.shape}")
Loaded embeddings shape: torch.Size([100, 768])
save_embeddings(embeddings, output_path)[source]

Save the generated embeddings to a file.

Parameters:
  • embeddings (torch.Tensor) – The embeddings to save

  • output_path (str) – Path to save the embeddings

Example

>>> embeddings = model.batch_encode(sequences)
>>> model.save_embeddings(embeddings, "embeddings.pt")
>>> print("Embeddings saved successfully")
to(*args, **kwargs)[source]

Move model to specified device and/or change dtype.

visualize_attention_pattern(attention_result, layer_idx=0, head_idx=0, save_path=None, figsize=(12, 10))[source]

Visualize attention patterns as an interactive heatmap.

This method creates a visual representation of attention weights, helping to understand which sequence positions the model focuses on during processing.

Parameters:
  • attention_result (Dict) – Result dictionary from extract_attention_scores() or batch_extract_attention_scores() containing attention data.

  • layer_idx (int, default=0) – Index of the transformer layer to visualize.

  • head_idx (int, default=0) – Index of the attention head to visualize.

  • save_path (str, optional) – File path to save the visualization image. If None, the plot is not saved to disk.

  • figsize (tuple, default=(12, 10)) – Figure size as (width, height) in inches.

Returns:

matplotlib.figure.Figure

The generated matplotlib figure object, or None if matplotlib is not available.

Example

>>> sequence = "ATCGATCGATCG"
>>> result = model.extract_attention_scores(sequence)
>>> fig = model.visualize_attention_pattern(
...     result, layer_idx=0, head_idx=0, save_path="attention_plot.png"
... )
>>> # fig.show()  # Display the plot

Note

Requires matplotlib for visualization. Install it with: pip install matplotlib. The heatmap uses a blue color scheme where darker colors indicate stronger attention.

MLM Models

Masked Language Model (MLM) for genomic sequences.

This module provides a masked language model implementation specifically designed for genomic sequences. It supports masked language modeling tasks where tokens are randomly masked and the model learns to predict the original tokens.

class omnigenbench.src.model.mlm.model.OmniModelForMLM(config_or_model, tokenizer, *args, **kwargs)[source]

Bases: OmniModel

Masked Language Model for genomic sequences.

This model implements masked language modeling for genomic sequences, where tokens are randomly masked and the model learns to predict the original tokens. It’s useful for pre-training genomic language models and understanding sequence patterns and dependencies.

Variables:

loss_fn – Cross-entropy loss function for masked language modeling

forward(**inputs)[source]

Forward pass for masked language modeling.

Parameters:

**inputs – Input tensors including input_ids, attention_mask, and labels

Returns:

dict – Dictionary containing loss, logits, and last_hidden_state

inference(sequence_or_inputs, **kwargs)[source]

Perform inference for masked language modeling, decoding predictions to sequences.

Parameters:
  • sequence_or_inputs – Input sequences or pre-processed inputs

  • **kwargs – Additional keyword arguments

Returns:

dict – Dictionary containing decoded predictions, logits, and last_hidden_state

loss_function(logits, labels)[source]

Compute the loss for masked language modeling.

Parameters:
  • logits (torch.Tensor) – Model predictions [batch_size, seq_len, vocab_size]

  • labels (torch.Tensor) – Ground truth labels [batch_size, seq_len]

Returns:

torch.Tensor – Computed cross-entropy loss value
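In masked language modeling the loss is normally computed only at masked positions. A common convention, used by PyTorch's CrossEntropyLoss through ignore_index (default -100), labels every unmasked position with -100 so it is skipped. Whether this class follows that convention is an assumption; the plain-Python sketch below shows the mean cross-entropy over the remaining positions of one sequence.

```python
import math
from typing import Sequence

IGNORE_INDEX = -100  # assumed label for positions that were not masked


def masked_cross_entropy(
    logits: Sequence[Sequence[float]],
    labels: Sequence[int],
    ignore_index: int = IGNORE_INDEX,
) -> float:
    """Mean cross-entropy over positions whose label != ignore_index.

    logits: (seq_len, vocab_size) unnormalized scores for one sequence
    labels: (seq_len,) target token ids, or ignore_index for unmasked positions
    """
    total, count = 0.0, 0
    for scores, label in zip(logits, labels):
        if label == ignore_index:
            continue
        # log-sum-exp with max subtraction for numerical stability
        m = max(scores)
        log_norm = m + math.log(sum(math.exp(s - m) for s in scores))
        total += log_norm - scores[label]  # equals -log softmax(scores)[label]
        count += 1
    if count == 0:
        raise ValueError("no masked positions to score")
    return total / count


# Only positions 1 and 2 are scored; position 0 is ignored.
loss = masked_cross_entropy([[0.0, 0.0], [5.0, 0.0], [0.0, 0.0]], [IGNORE_INDEX, 0, 1])
print(round(loss, 4))
```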

predict(sequence_or_inputs, **kwargs)[source]

Generate predictions for masked language modeling.

Parameters:
  • sequence_or_inputs – Input sequences or pre-processed inputs

  • **kwargs – Additional keyword arguments

Returns:

dict – Dictionary containing predictions, logits, and last_hidden_state

RNA Design Models

RNA design model using masked language modeling and evolutionary algorithms.

This module provides an RNA design model that combines masked language modeling with evolutionary algorithms to design RNA sequences that fold into specific target structures. It uses a multi-objective optimization approach to balance structure similarity and thermodynamic stability.

class omnigenbench.src.model.rna_design.model.OmniModelForRNADesign(model='yangheng/OmniGenome-186M', device=None, parallel=False, output_format='RNA', *args, **kwargs)[source]

Bases: Module


This model combines a pre-trained masked language model with evolutionary algorithms to design RNA sequences that fold into specific target structures. It uses a multi-objective optimization approach to balance structure similarity and thermodynamic stability.

Variables:
  • device – Device to run the model on (CPU or GPU)

  • parallel – Whether to use parallel processing for structure prediction

  • tokenizer – Tokenizer for processing RNA sequences

  • model – Pre-trained masked language model

design(structure, mutation_ratio=0.5, num_population=100, num_generation=100)[source]

Design RNA sequences for a target structure using evolutionary algorithms.

Parameters:
  • structure (str) – Target secondary structure in dot-bracket notation

  • mutation_ratio (float) – Mutation rate for the genetic algorithm, between 0.0 and 1.0 (default: 0.5)

  • num_population (int) – Population size for each generation (default: 100)

  • num_generation (int) – Maximum number of evolutionary generations (default: 100)

Returns:

list

List of designed RNA sequences that fold into the target structure.

Returns all sequences with a perfect match (score=0) if any are found; otherwise returns the best sequences from the final population.

Example

>>> model = OmniModelForRNADesign(model="yangheng/OmniGenome-186M")
>>> sequences = model.design(structure="(((...)))", num_population=100, num_generation=50)
>>> print(f"Designed {len(sequences)} sequences")
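design() treats a score of 0 as a perfect match. One simple way to score a candidate is the fraction of positions where its predicted dot-bracket structure differs from the target. This is shown as an illustration, not as the library's objective, which per the description above also weighs thermodynamic stability. The sketch also includes a point-mutation operator driven by mutation_ratio. Both helper names are hypothetical, and predicting a candidate's structure (for example with a folding tool) is outside its scope.

```python
import random
from typing import Optional

NUCLEOTIDES = "ACGU"


def structure_distance(predicted: str, target: str) -> float:
    """Normalized Hamming distance between two dot-bracket strings (0.0 = identical)."""
    if len(predicted) != len(target):
        raise ValueError("structures must have the same length")
    if not target:
        return 0.0
    mismatches = sum(p != t for p, t in zip(predicted, target))
    return mismatches / len(target)


def mutate(sequence: str, mutation_ratio: float, rng: Optional[random.Random] = None) -> str:
    """Replace each base, with probability mutation_ratio, by a different nucleotide."""
    if not 0.0 <= mutation_ratio <= 1.0:
        raise ValueError("mutation_ratio must be in [0.0, 1.0]")
    rng = rng or random.Random()
    out = []
    for base in sequence:
        if rng.random() < mutation_ratio:
            out.append(rng.choice([n for n in NUCLEOTIDES if n != base]))
        else:
            out.append(base)
    return "".join(out)


print(structure_distance("(((...)))", "(((...)))"))  # 0.0, a perfect match
print(mutate("GGGAAACCC", 0.5, random.Random(0)))
```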

Sequence-to-Sequence Models

Sequence-to-sequence model for genomic sequences.

This module provides a sequence-to-sequence model implementation for genomic sequences. It’s designed for tasks where the input and output are both sequences, such as sequence translation, structure prediction, or sequence transformation tasks.

class omnigenbench.src.model.seq2seq.model.OmniModelForSeq2Seq(config_or_model, tokenizer, *args, **kwargs)[source]

Bases: OmniModel

This model implements a sequence-to-sequence architecture for genomic sequences, where the input is one sequence and the output is another sequence. It’s useful for tasks like sequence translation, structure prediction, or sequence transformation. The model can be extended to implement specific seq2seq tasks by overriding the forward, predict, and inference methods.

Augmentation Models

Data augmentation model for genomic sequences.

This module provides a data augmentation model that uses masked language modeling to generate augmented versions of genomic sequences. It’s useful for expanding training datasets and improving model robustness.

class omnigenbench.src.model.augmentation.model.OmniModelForAugmentation(config_or_model=None, noise_ratio=0.15, max_length=1026, instance_num=1, batch_size=32, use_amp=None, *args, **kwargs)[source]

Bases: Module

Data augmentation model for genomic sequences using masked language modeling. This model uses a pre-trained masked language model to generate augmented versions of genomic sequences by randomly masking tokens and predicting replacements. It’s useful for expanding training datasets and improving model generalization.

Variables:
  • tokenizer – Tokenizer for processing genomic sequences

  • model – Pre-trained masked language model

  • device – Device to run the model on (CPU or GPU)

  • noise_ratio – Proportion of tokens to mask for augmentation

  • max_length – Maximum sequence length for tokenization

  • k – Number of augmented instances to generate per sequence

apply_noise_to_sequence(seq)[source]

Apply noise to a single sequence by randomly masking tokens.

Parameters:

seq (str) – Input genomic sequence

Returns:

str – Sequence with randomly masked tokens
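The masking step controlled by noise_ratio (default 0.15) can be sketched at the character level, treating each nucleotide as one token. The mask token string "<mask>" and the choice to mask exactly round(noise_ratio * len(seq)) distinct positions are assumptions for illustration; the library works on tokenizer tokens and may sample differently. The sketch returns a token list so that multi-character mask tokens stay distinguishable; "".join(tokens) gives a string.

```python
import random
from typing import List, Optional

MASK_TOKEN = "<mask>"  # hypothetical; the real mask token comes from the tokenizer


def apply_noise(
    seq: str, noise_ratio: float = 0.15, rng: Optional[random.Random] = None
) -> List[str]:
    """Split seq into single-character tokens and mask a noise_ratio fraction of them."""
    if not 0.0 <= noise_ratio <= 1.0:
        raise ValueError("noise_ratio must be in [0.0, 1.0]")
    rng = rng or random.Random()
    tokens = list(seq)
    num_to_mask = round(noise_ratio * len(tokens))
    # Sample distinct positions so no token is masked twice.
    for i in rng.sample(range(len(tokens)), num_to_mask):
        tokens[i] = MASK_TOKEN
    return tokens


print(apply_noise("ACGUACGUAC", noise_ratio=0.3, rng=random.Random(0)))
```

The masked token list is what an MLM then fills in; see augment_sequence below.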

augment(seq, k=None)[source]

Generate multiple augmented instances for a single sequence.

Parameters:
  • seq (str) – Input genomic sequence

  • k (int, optional) – Number of augmented instances to generate (default: None, uses self.k)

Returns:

list – List of augmented sequences

augment_from_file(input_file, output_file)[source]

Run the full augmentation pipeline, from an input file to an output file.

This method loads sequences from an input file, augments them using the MLM model, and saves the augmented sequences to an output file.

Parameters:
  • input_file (str) – Path to the input file containing sequences

  • output_file (str) – Path to the output file where augmented sequences will be saved

augment_sequence(seq)[source]

Perform augmentation on a single sequence by predicting masked tokens.

Parameters:

seq (str) – Input genomic sequence with masked tokens

Returns:

str – Augmented sequence with predicted tokens replacing masked tokens

augment_sequences(sequences)[source]

Augment a list of sequences by applying noise and performing MLM-based predictions.

Parameters:

sequences (list) – List of genomic sequences to augment

Returns:

list – List of all augmented sequences

load_sequences_from_file(input_file)[source]

Load sequences from a JSON file.

Parameters:

input_file (str) – Path to the input JSON file containing sequences

Returns:

list – List of sequences loaded from the file

save_augmented_sequences(augmented_sequences, output_file)[source]

Save augmented sequences to a JSON file.

Parameters:
  • augmented_sequences (list) – List of augmented sequences to save

  • output_file (str) – Path to the output JSON file

Model Utilities

This module provides utility classes and functions for handling model inputs, pooling operations, and attention mechanisms used across different OmniGenome model types.

class omnigenbench.src.model.module_utils.InteractingAttention(embed_size, num_heads=24)[source]

Bases: Module

An interacting attention mechanism for sequence modeling.

This class implements a multi-head attention mechanism with residual connections and layer normalization. It’s designed for processing sequences where different parts of the sequence need to interact with each other.

Variables:
  • attention – Multi-head attention layer

  • layer_norm – Layer normalization for residual connections

  • fc_out – Output projection layer

forward(query, keys, values)[source]

Forward pass through the interacting attention mechanism.

Parameters:
  • query (torch.Tensor) – Query tensor [batch_size, query_len, embed_size]

  • keys (torch.Tensor) – Key tensor [batch_size, key_len, embed_size]

  • values (torch.Tensor) – Value tensor [batch_size, value_len, embed_size]

Returns:

torch.Tensor – Output tensor with same shape as query
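The shape contract (output has the same shape as query) follows from attending from each query vector over the keys/values and adding the result back to the query before normalizing. The dependency-free sketch below uses a single head without learned projections or the fc_out layer. The real module uses multi-head attention with num_heads heads and learned weights, and the exact order of its residual, layer_norm, and fc_out steps is not documented here, so treat this as an illustration only. As in the parameter list above, key_len must equal value_len.

```python
import math
from typing import List, Sequence

Matrix = List[List[float]]


def _softmax(xs: Sequence[float]) -> List[float]:
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def _layer_norm(x: Sequence[float], eps: float = 1e-5) -> List[float]:
    # Normalize over the feature dimension (no learned scale or shift here).
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]


def attention_with_residual(query: Matrix, keys: Matrix, values: Matrix) -> Matrix:
    """Single-head scaled dot-product attention, then residual + layer norm.

    query: (query_len, d); keys, values: (key_len, d). Output: (query_len, d).
    """
    if len(keys) != len(values):
        raise ValueError("keys and values must have the same length")
    if not query:
        return []
    d = len(query[0])
    out: Matrix = []
    for q in query:
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in keys]
        weights = _softmax(scores)
        attended = [sum(w * v[j] for w, v in zip(weights, values)) for j in range(d)]
        # Residual connection keeps the query's information alongside the attended context.
        out.append(_layer_norm([a + b for a, b in zip(q, attended)]))
    return out
```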

class omnigenbench.src.model.module_utils.OmniPooling(config, *args, **kwargs)[source]

Bases: Module

A flexible pooling layer for OmniGenome models that handles different input formats.

This class provides a unified interface for pooling operations across different model architectures, supporting both causal language models and encoder-based models. It can handle various input formats including tuples, dictionaries, BatchEncoding objects, and tensors.

Variables:
  • config – Model configuration object containing architecture and tokenizer settings

  • pooler – BertPooler instance for non-causal models, None for causal models

forward(inputs, last_hidden_state)[source]

Perform pooling operation on the last hidden state.

This method handles different input formats and applies the appropriate pooling:
  • For causal language models: uses the last non-padded token

  • For encoder models: uses the BertPooler

Parameters:
  • inputs – Input data in various formats (tuple, dict, BatchEncoding, or tensor)

  • last_hidden_state (torch.Tensor) – Hidden states from the model [batch_size, seq_len, hidden_size]

Returns:

torch.Tensor – Pooled representation [batch_size, hidden_size]

Raises:

ValueError – If input format is not supported or cannot be parsed
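For causal language models, “the last non-padded token” can be located from the attention mask. The sketch below assumes right padding with a contiguous run of 1s, so the index is sum(mask) - 1; with left padding the last token would instead sit at seq_len - 1. Both helper names are hypothetical, and the sketch works on plain lists rather than tensors.

```python
from typing import List, Sequence


def last_token_indices(attention_mask: Sequence[Sequence[int]]) -> List[int]:
    """Index of the last non-padded token in each row of a right-padded mask."""
    indices = []
    for row in attention_mask:
        length = sum(row)
        if length == 0:
            raise ValueError("row has no valid tokens")
        indices.append(length - 1)
    return indices


def pool_last_token(
    last_hidden_state: Sequence[Sequence[Sequence[float]]],
    attention_mask: Sequence[Sequence[int]],
) -> List[List[float]]:
    """Pick one vector per sequence: (batch, seq_len, hidden) -> (batch, hidden)."""
    return [
        list(states[i])
        for states, i in zip(last_hidden_state, last_token_indices(attention_mask))
    ]


mask = [[1, 1, 1, 0], [1, 1, 0, 0]]
print(last_token_indices(mask))  # [2, 1]
```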