Source code for omnigenbench.src.model.regression.model

# -*- coding: utf-8 -*-
# file: model.py
# time: 18:36 06/04/2024
# author: YANG, HENG <hy345@exeter.ac.uk> (杨恒)
# github: https://github.com/yangheng95
# huggingface: https://huggingface.co/yangheng
# google scholar: https://scholar.google.com/citations?user=NPq5a_0AAAAJ&hl=en
# Copyright (C) 2019-2025. All Rights Reserved.
"""
Regression models for OmniGenome framework.

This module provides various regression model implementations for genomic sequence analysis,
including token-level regression, sequence-level regression, structural imputation,
and matrix regression/classification tasks.
"""
import torch

from .resnet import resnet_b16
from ...abc.abstract_model import OmniModel
from ..module_utils import OmniPooling


class OmniModelForTokenRegression(OmniModel):
    """
    Token-level regression model for genomic sequences.

    A linear head is applied to the backbone's last hidden state, producing
    ``config.num_labels`` continuous values for every token position. Typical
    uses are per-position targets such as binding affinities or expression
    levels.

    Attributes:
        classifier: Linear layer mapping ``hidden_size`` to ``num_labels``
        loss_fn: Mean squared error loss function
    """

    def __init__(self, config_or_model, tokenizer, *args, **kwargs):
        """
        Build the token regression head on top of the base model.

        Args:
            config_or_model: Model configuration or pre-trained model
            tokenizer: Tokenizer for processing input sequences
            *args: Additional positional arguments forwarded to the base class
            **kwargs: Additional keyword arguments forwarded to the base class
        """
        super().__init__(config_or_model, tokenizer, *args, **kwargs)
        self.metadata["model_name"] = type(self).__name__
        in_features = self.config.hidden_size
        out_features = self.config.num_labels
        self.classifier = torch.nn.Linear(in_features, out_features)
        self.loss_fn = torch.nn.MSELoss()

    def forward(self, **inputs):
        """
        Run the backbone and the per-token regression head.

        Args:
            **inputs: Model inputs; an optional ``labels`` entry is removed
                before the remaining entries are passed to the backbone

        Returns:
            dict: ``logits`` (head output), ``last_hidden_state`` (the hidden
            state after dropout and activation, i.e. what the head consumed)
            and ``labels`` (``None`` when no labels were supplied)
        """
        labels = inputs.pop("labels", None)
        hidden = self.last_hidden_state_forward(**inputs)
        # dropout/activation are not defined in this class; presumably
        # provided by OmniModel.
        hidden = self.activation(self.dropout(hidden))
        return {
            "logits": self.classifier(hidden),
            "last_hidden_state": hidden,
            "labels": labels,
        }

    def predict(self, sequence_or_inputs, **kwargs):
        """
        Generate stacked predictions for token-level regression.

        Args:
            sequence_or_inputs: Input sequences or pre-processed inputs
            **kwargs: Additional keyword arguments for the raw-input forward

        Returns:
            dict: ``predictions`` (per-sample logits stacked with
            ``torch.vstack`` and moved to the model device), ``logits`` and
            ``last_hidden_state``
        """
        raw = self._forward_from_raw_input(sequence_or_inputs, **kwargs)
        logits = raw["logits"]
        last_hidden_state = raw["last_hidden_state"]

        per_sample = [logits[idx].cpu() for idx in range(logits.shape[0])]
        # Scalars (empty shape) cannot be vstack-ed, so build a tensor instead.
        if per_sample[0].shape:
            stacked = torch.vstack(per_sample)
        else:
            stacked = torch.tensor(per_sample)

        return {
            "predictions": stacked.to(self.model.device),
            "logits": logits,
            "last_hidden_state": last_hidden_state,
        }

    def inference(self, sequence_or_inputs, **kwargs):
        """
        Perform inference and trim padding plus the first/last position.

        For every sample the logits at non-padding positions are selected and
        then the first and last remaining positions are dropped (presumably
        the special start/end tokens).

        Args:
            sequence_or_inputs: Input sequences or pre-processed inputs; a
                non-list input yields un-batched (first-sample) outputs
            **kwargs: Additional keyword arguments for the raw-input forward

        Returns:
            dict: ``predictions``, ``logits`` and ``last_hidden_state``
        """
        raw = self._forward_from_raw_input(sequence_or_inputs, **kwargs)
        inputs = raw["inputs"]
        logits = raw["logits"]
        last_hidden_state = raw["last_hidden_state"]

        predictions = []
        for idx in range(logits.shape[0]):
            not_pad = inputs["input_ids"][idx].ne(self.config.pad_token_id)
            trimmed = logits[idx][not_pad][1:-1]
            predictions.append(trimmed.detach().cpu())

        if isinstance(sequence_or_inputs, list):
            return {
                "predictions": predictions,
                "logits": logits,
                "last_hidden_state": last_hidden_state,
            }
        # Single input: unwrap the batch dimension.
        return {
            "predictions": predictions[0],
            "logits": logits[0],
            "last_hidden_state": last_hidden_state[0],
        }

    def loss_function(self, logits, labels):
        """
        Compute the MSE loss over all non-ignored targets.

        Targets equal to ``config.ignore_y`` (or ``-100`` when the config has
        no such attribute) are excluded before the loss is evaluated.

        Args:
            logits (torch.Tensor): Model predictions
            labels (torch.Tensor): Ground truth labels

        Returns:
            torch.Tensor: Computed loss value
        """
        ignore_value = getattr(self.config, "ignore_y", -100)
        flat_logits = logits.view(-1)
        flat_labels = labels.view(-1)
        keep = torch.where(flat_labels != ignore_value)
        return self.loss_fn(flat_logits[keep], flat_labels[keep])
class OmniModelForSequenceRegression(OmniModel):
    """
    Sequence-level regression model for genomic sequences.

    The backbone's hidden states are pooled into one representation per
    sequence and a linear head predicts ``config.num_labels`` continuous
    values for the whole input (e.g. overall expression level or binding
    affinity).

    Attributes:
        pooler: OmniPooling layer for sequence-level representation
        classifier: Linear layer mapping ``hidden_size`` to ``num_labels``
        loss_fn: Mean squared error loss function
    """

    def __init__(self, config_or_model, tokenizer, *args, **kwargs):
        """
        Build the pooling layer and regression head on top of the base model.

        Args:
            config_or_model: Model configuration or pre-trained model
            tokenizer: Tokenizer for processing input sequences
            *args: Additional positional arguments forwarded to the base class
            **kwargs: Additional keyword arguments forwarded to the base class
        """
        super().__init__(config_or_model, tokenizer, *args, **kwargs)
        self.metadata["model_name"] = type(self).__name__
        self.pooler = OmniPooling(self.config)
        in_features = self.config.hidden_size
        out_features = self.config.num_labels
        self.classifier = torch.nn.Linear(in_features, out_features)
        self.loss_fn = torch.nn.MSELoss()

    def forward(self, **inputs):
        """
        Run the backbone, pool the hidden states and apply the head.

        Args:
            **inputs: Model inputs; an optional ``labels`` entry is removed
                before the remaining entries are passed to the backbone and
                to the pooler

        Returns:
            dict: ``logits`` (head output), ``last_hidden_state`` (the pooled
            representation fed to the head) and ``labels``
        """
        labels = inputs.pop("labels", None)
        hidden = self.last_hidden_state_forward(**inputs)
        hidden = self.activation(self.dropout(hidden))
        # The pooler receives the label-free inputs alongside the hidden states.
        pooled = self.pooler(inputs, hidden)
        return {
            "logits": self.classifier(pooled),
            "last_hidden_state": pooled,
            "labels": labels,
        }

    def predict(self, sequence_or_inputs, **kwargs):
        """
        Generate stacked predictions for sequence-level regression.

        Args:
            sequence_or_inputs: Input sequences or pre-processed inputs
            **kwargs: Additional keyword arguments for the raw-input forward

        Returns:
            dict: ``predictions`` (per-sample logits stacked and moved to the
            model device), ``logits`` and ``last_hidden_state``
        """
        raw = self._forward_from_raw_input(sequence_or_inputs, **kwargs)
        logits = raw["logits"]
        last_hidden_state = raw["last_hidden_state"]

        per_sample = [logits[idx].cpu() for idx in range(logits.shape[0])]
        # Scalars (empty shape) cannot be vstack-ed, so build a tensor instead.
        if per_sample[0].shape:
            stacked = torch.vstack(per_sample)
        else:
            stacked = torch.tensor(per_sample)

        return {
            "predictions": stacked.to(self.model.device),
            "logits": logits,
            "last_hidden_state": last_hidden_state,
        }

    def inference(self, sequence_or_inputs, **kwargs):
        """
        Perform inference for sequence-level regression.

        Args:
            sequence_or_inputs: Input sequences or pre-processed inputs; a
                non-list input yields un-batched (first-sample) outputs
            **kwargs: Additional keyword arguments for the raw-input forward

        Returns:
            dict: ``predictions`` (CPU copies of the per-sample logits),
            ``logits`` and ``last_hidden_state``
        """
        raw = self._forward_from_raw_input(sequence_or_inputs, **kwargs)
        logits = raw["logits"]
        last_hidden_state = raw["last_hidden_state"]

        predictions = [logits[idx].cpu() for idx in range(logits.shape[0])]

        if isinstance(sequence_or_inputs, list):
            return {
                "predictions": predictions,
                "logits": logits,
                "last_hidden_state": last_hidden_state,
            }
        # Single input: unwrap the batch dimension.
        return {
            "predictions": predictions[0],
            "logits": logits[0],
            "last_hidden_state": last_hidden_state[0],
        }

    def loss_function(self, logits, labels):
        """
        Compute the MSE loss over all non-ignored targets.

        Targets equal to ``config.ignore_y`` (or ``-100`` when the config has
        no such attribute) are excluded before the loss is evaluated.

        Args:
            logits (torch.Tensor): Model predictions
            labels (torch.Tensor): Ground truth labels

        Returns:
            torch.Tensor: Computed loss value
        """
        ignore_value = getattr(self.config, "ignore_y", -100)
        flat_logits = logits.view(-1)
        flat_labels = labels.view(-1)
        keep = torch.where(flat_labels != ignore_value)
        return self.loss_fn(flat_logits[keep], flat_labels[keep])
class OmniModelForStructuralImputation(OmniModelForSequenceRegression):
    """
    Sequence regression variant intended for structural imputation.

    The class reuses the pooling layer and regression head of
    ``OmniModelForSequenceRegression`` and additionally registers a
    single-entry embedding table.

    NOTE(review): ``embedding`` is created here but is not referenced by the
    ``forward`` defined in this class — confirm whether it is used elsewhere.

    Attributes:
        embedding: ``torch.nn.Embedding(1, hidden_size)``
        loss_fn: Mean squared error loss function
    """

    def __init__(self, config_or_model, tokenizer, *args, **kwargs):
        """
        Initialize the structural imputation model.

        Args:
            config_or_model: Model configuration or pre-trained model
            tokenizer: Tokenizer for processing input sequences
            *args: Additional positional arguments forwarded to the parent
            **kwargs: Additional keyword arguments forwarded to the parent
        """
        super().__init__(config_or_model, tokenizer, *args, **kwargs)
        self.metadata["model_name"] = type(self).__name__
        self.loss_fn = torch.nn.MSELoss()
        hidden_size = self.config.hidden_size
        self.embedding = torch.nn.Embedding(1, hidden_size)

    def forward(self, **inputs):
        """
        Run the backbone, pool the hidden states and apply the head.

        Args:
            **inputs: Model inputs; an optional ``labels`` entry is removed
                before the remaining entries are passed to the backbone and
                to the pooler

        Returns:
            dict: ``logits`` (head output), ``last_hidden_state`` (the pooled
            representation fed to the head) and ``labels``
        """
        labels = inputs.pop("labels", None)
        hidden = self.last_hidden_state_forward(**inputs)
        hidden = self.activation(self.dropout(hidden))
        pooled = self.pooler(inputs, hidden)
        return {
            "logits": self.classifier(pooled),
            "last_hidden_state": pooled,
            "labels": labels,
        }
class OmniModelForTokenRegressionWith2DStructure(OmniModelForTokenRegression):
    """
    Token regression variant intended for inputs carrying 2D structure.

    The class adds no layers of its own: every entry of ``inputs`` other than
    ``labels`` is handed unchanged to ``last_hidden_state_forward`` and the
    parent's linear head is applied to each token.

    NOTE(review): any structure-specific handling must therefore happen in
    the backbone/base class — confirm against ``OmniModel``.
    """

    def __init__(self, config_or_model, tokenizer, *args, **kwargs):
        """
        Initialize the 2D structure-aware token regression model.

        Args:
            config_or_model: Model configuration or pre-trained model
            tokenizer: Tokenizer for processing input sequences
            *args: Additional positional arguments forwarded to the parent
            **kwargs: Additional keyword arguments forwarded to the parent
        """
        super().__init__(config_or_model, tokenizer, *args, **kwargs)
        self.metadata["model_name"] = type(self).__name__

    def forward(self, **inputs):
        """
        Run the backbone and the per-token regression head.

        Args:
            **inputs: Model inputs; an optional ``labels`` entry is removed
                before the remaining entries are passed to the backbone

        Returns:
            dict: ``logits`` (head output), ``last_hidden_state`` (the hidden
            state after dropout and activation) and ``labels``
        """
        labels = inputs.pop("labels", None)
        hidden = self.last_hidden_state_forward(**inputs)
        hidden = self.activation(self.dropout(hidden))
        return {
            "logits": self.classifier(hidden),
            "last_hidden_state": hidden,
            "labels": labels,
        }
class OmniModelForSequenceRegressionWith2DStructure(OmniModelForSequenceRegression):
    """
    Sequence regression variant intended for inputs carrying 2D structure.

    The class adds no layers of its own: every entry of ``inputs`` other than
    ``labels`` is handed unchanged to ``last_hidden_state_forward`` and to the
    parent's pooler, and the parent's linear head produces the output.

    NOTE(review): any structure-specific handling must therefore happen in
    the backbone/base class or the pooler — confirm against ``OmniModel``.
    """

    def __init__(self, config_or_model, tokenizer, *args, **kwargs):
        """
        Initialize the 2D structure-aware sequence regression model.

        Args:
            config_or_model: Model configuration or pre-trained model
            tokenizer: Tokenizer for processing input sequences
            *args: Additional positional arguments forwarded to the parent
            **kwargs: Additional keyword arguments forwarded to the parent
        """
        super().__init__(config_or_model, tokenizer, *args, **kwargs)
        self.metadata["model_name"] = type(self).__name__

    def forward(self, **inputs):
        """
        Run the backbone, pool the hidden states and apply the head.

        Args:
            **inputs: Model inputs; an optional ``labels`` entry is removed
                before the remaining entries are passed to the backbone and
                to the pooler

        Returns:
            dict: ``logits`` (head output), ``last_hidden_state`` (the pooled
            representation fed to the head) and ``labels``
        """
        labels = inputs.pop("labels", None)
        hidden = self.last_hidden_state_forward(**inputs)
        hidden = self.activation(self.dropout(hidden))
        pooled = self.pooler(inputs, hidden)
        return {
            "logits": self.classifier(pooled),
            "last_hidden_state": pooled,
            "labels": labels,
        }
class OmniModelForMatrixRegression(OmniModel):
    """
    Regression model operating on matrix inputs.

    A ``resnet_b16`` network processes the ``matrix_inputs`` tensor and a
    linear head maps its output to ``config.num_labels`` values. Intended for
    matrix-shaped targets such as contact maps.

    Attributes:
        resnet: ResNet backbone for processing matrix inputs
        classifier: Linear layer with one input feature and ``num_labels``
            outputs
        loss_fn: Mean squared error loss function
    """

    def __init__(self, config_or_model, tokenizer, *args, **kwargs):
        """
        Build the ResNet backbone and regression head.

        Args:
            config_or_model: Model configuration or pre-trained model
            tokenizer: Tokenizer for processing input sequences
            *args: Additional positional arguments forwarded to the base class
            **kwargs: Additional keyword arguments forwarded to the base class
        """
        super().__init__(config_or_model, tokenizer, *args, **kwargs)
        self.metadata["model_name"] = type(self).__name__
        # NOTE(review): channels is fixed at 128 rather than taken from the
        # config — confirm this matches the expected matrix_inputs.
        self.resnet = resnet_b16(channels=128, bbn=16)
        # The head takes a single input feature, so the ResNet output is
        # assumed to have a trailing dimension of size 1 — TODO confirm.
        self.classifier = torch.nn.Linear(1, self.config.num_labels)
        self.loss_fn = torch.nn.MSELoss()

    def forward(self, **inputs):
        """
        Run the ResNet on ``matrix_inputs`` and apply the regression head.

        Args:
            **inputs: Must contain ``matrix_inputs``; may contain ``labels``

        Returns:
            dict: ``logits`` (head output), ``last_hidden_state`` (the raw
            ResNet output) and ``labels``

        Raises:
            ValueError: If ``matrix_inputs`` is missing or ``None``
        """
        labels = inputs.pop("labels", None)
        matrix_inputs = inputs.pop("matrix_inputs", None)
        if matrix_inputs is None:
            raise ValueError("matrix_inputs is required for matrix regression")

        features = self.resnet(matrix_inputs)
        return {
            "logits": self.classifier(features),
            "last_hidden_state": features,
            "labels": labels,
        }

    def predict(self, sequence_or_inputs, **kwargs):
        """
        Generate stacked predictions for matrix regression.

        Args:
            sequence_or_inputs: Input sequences or pre-processed inputs
            **kwargs: Additional keyword arguments for the raw-input forward

        Returns:
            dict: ``predictions`` (per-sample logits stacked and moved to the
            model device), ``logits`` and ``last_hidden_state``
        """
        raw = self._forward_from_raw_input(sequence_or_inputs, **kwargs)
        logits = raw["logits"]
        last_hidden_state = raw["last_hidden_state"]

        per_sample = [logits[idx].cpu() for idx in range(logits.shape[0])]
        # Scalars (empty shape) cannot be vstack-ed, so build a tensor instead.
        if per_sample[0].shape:
            stacked = torch.vstack(per_sample)
        else:
            stacked = torch.tensor(per_sample)

        return {
            "predictions": stacked.to(self.model.device),
            "logits": logits,
            "last_hidden_state": last_hidden_state,
        }

    def inference(self, sequence_or_inputs, **kwargs):
        """
        Perform inference for matrix regression.

        Args:
            sequence_or_inputs: Input sequences or pre-processed inputs; a
                non-list input yields un-batched (first-sample) outputs
            **kwargs: Additional keyword arguments for the raw-input forward

        Returns:
            dict: ``predictions`` (CPU copies of the per-sample logits),
            ``logits`` and ``last_hidden_state``
        """
        raw = self._forward_from_raw_input(sequence_or_inputs, **kwargs)
        logits = raw["logits"]
        last_hidden_state = raw["last_hidden_state"]

        predictions = [logits[idx].cpu() for idx in range(logits.shape[0])]

        if isinstance(sequence_or_inputs, list):
            return {
                "predictions": predictions,
                "logits": logits,
                "last_hidden_state": last_hidden_state,
            }
        # Single input: unwrap the batch dimension.
        return {
            "predictions": predictions[0],
            "logits": logits[0],
            "last_hidden_state": last_hidden_state[0],
        }

    def loss_function(self, logits, labels):
        """
        Compute the MSE loss over all non-ignored targets.

        Targets equal to ``config.ignore_y`` (or ``-100`` when the config has
        no such attribute) are excluded before the loss is evaluated.

        Args:
            logits (torch.Tensor): Model predictions
            labels (torch.Tensor): Ground truth labels

        Returns:
            torch.Tensor: Computed loss value
        """
        ignore_value = getattr(self.config, "ignore_y", -100)
        flat_logits = logits.view(-1)
        flat_labels = labels.view(-1)
        keep = torch.where(flat_labels != ignore_value)
        return self.loss_fn(flat_logits[keep], flat_labels[keep])
class OmniModelForMatrixClassification(OmniModel):
    """
    Binary classification model operating on matrix inputs.

    A ``resnet_b16`` network processes the ``matrix_inputs`` tensor and a
    linear head emits a single raw logit per element. Training uses
    ``BCEWithLogitsLoss`` (which applies the sigmoid internally); at
    prediction time the logits are passed through ``sigmoid`` and thresholded
    at 0.5.

    Attributes:
        classifier: Linear layer mapping ``hidden_size`` to one logit
        sigmoid: Sigmoid used to turn logits into probabilities
        loss_fn: Binary cross-entropy-with-logits loss function
        cnn: ResNet backbone for processing matrix inputs
    """

    def __init__(self, config_or_model, tokenizer, *args, **kwargs):
        """
        Build the ResNet backbone and binary classification head.

        Args:
            config_or_model: Model configuration or pre-trained model
            tokenizer: Tokenizer for processing input sequences
            *args: Additional positional arguments forwarded to the base class
            **kwargs: Additional keyword arguments forwarded to the base class
        """
        super().__init__(config_or_model, tokenizer, *args, **kwargs)
        self.metadata["model_name"] = self.__class__.__name__
        # Binary classification: the head emits a single logit.
        self.classifier = torch.nn.Linear(self.config.hidden_size, 1)
        self.sigmoid = torch.nn.Sigmoid()
        # BCEWithLogitsLoss expects raw logits, so forward() must not apply
        # the sigmoid itself.
        self.loss_fn = torch.nn.BCEWithLogitsLoss()
        self.cnn = resnet_b16(channels=self.config.hidden_size, bbn=16)

    def forward(self, **inputs):
        """
        Run the ResNet on ``matrix_inputs`` and apply the classification head.

        Args:
            **inputs: Must contain ``matrix_inputs``; may contain ``labels``

        Returns:
            dict: ``logits`` (raw, pre-sigmoid head output),
            ``last_hidden_state`` (the raw ResNet output) and ``labels``

        Raises:
            ValueError: If ``matrix_inputs`` is missing or ``None``
        """
        labels = inputs.pop("labels", None)
        matrix_inputs = inputs.pop("matrix_inputs", None)
        if matrix_inputs is None:
            raise ValueError("matrix_inputs is required for matrix classification")

        # The backbone is registered as ``self.cnn`` in __init__; calling a
        # non-existent ``self.resnet`` here raised AttributeError.
        features = self.cnn(matrix_inputs)
        logits = self.classifier(features)
        return {
            "logits": logits,
            "last_hidden_state": features,
            "labels": labels,
        }

    def predict(self, sequence_or_inputs, **kwargs):
        """
        Generate hard 0/1 predictions for matrix classification.

        Args:
            sequence_or_inputs: Input sequences or pre-processed inputs
            **kwargs: Additional keyword arguments for the raw-input forward

        Returns:
            dict: ``predictions`` (float 0/1 classes stacked and moved to the
            model device), ``logits`` (raw) and ``last_hidden_state``
        """
        raw_outputs = self._forward_from_raw_input(sequence_or_inputs, **kwargs)
        logits = raw_outputs["logits"]
        last_hidden_state = raw_outputs["last_hidden_state"]

        predictions = []
        for i in range(logits.shape[0]):
            # Logits are pre-sigmoid (see loss_fn), so convert to
            # probabilities before applying the 0.5 decision threshold.
            pred_class = (self.sigmoid(logits[i]) > 0.5).float()
            predictions.append(pred_class.cpu())

        # Scalars (empty shape) cannot be vstack-ed, so build a tensor instead.
        if predictions[0].shape:
            stacked = torch.vstack(predictions)
        else:
            stacked = torch.tensor(predictions)

        return {
            "predictions": stacked.to(self.model.device),
            "logits": logits,
            "last_hidden_state": last_hidden_state,
        }

    def inference(self, sequence_or_inputs, **kwargs):
        """
        Perform inference and trim padding plus the first/last position.

        For every sample the logits at non-padding positions are selected and
        the first and last remaining positions are dropped (presumably the
        special start/end tokens). The trimmed logits are converted to
        probabilities and thresholded at 0.5.

        Args:
            sequence_or_inputs: Input sequences or pre-processed inputs; a
                non-list input yields un-batched (first-sample) outputs
            **kwargs: Additional keyword arguments for the raw-input forward

        Returns:
            dict: ``predictions`` (float 0/1 classes), ``probabilities``
            (sigmoid of the trimmed logits), ``logits`` (raw) and
            ``last_hidden_state``
        """
        raw_outputs = self._forward_from_raw_input(sequence_or_inputs, **kwargs)
        inputs = raw_outputs["inputs"]
        logits = raw_outputs["logits"]
        last_hidden_state = raw_outputs["last_hidden_state"]

        predictions = []
        probabilities = []
        for i in range(logits.shape[0]):
            # NOTE(review): assumes the first logits dimension after the
            # batch lines up with input_ids positions — confirm for matrix
            # outputs.
            not_pad = inputs["input_ids"][i].ne(self.config.pad_token_id)
            i_logit = logits[i][not_pad][1:-1]
            probs = self.sigmoid(i_logit)
            pred_class = (probs > 0.5).float()
            predictions.append(pred_class.detach().cpu())
            probabilities.append(probs.detach().cpu())

        if not isinstance(sequence_or_inputs, list):
            # Single input: unwrap the batch dimension.
            return {
                "predictions": predictions[0],
                "probabilities": probabilities[0],
                "logits": logits[0],
                "last_hidden_state": last_hidden_state[0],
            }
        return {
            "predictions": predictions,
            "probabilities": probabilities,
            "logits": logits,
            "last_hidden_state": last_hidden_state,
        }

    def loss_function(self, logits, labels):
        """
        Compute the binary cross-entropy loss over all non-ignored targets.

        Targets equal to ``config.ignore_y`` (or ``-100`` when the config has
        no such attribute) are excluded before the loss is evaluated.

        Args:
            logits (torch.Tensor): Raw (pre-sigmoid) model outputs
            labels (torch.Tensor): Ground truth 0/1 labels

        Returns:
            torch.Tensor: Computed loss value
        """
        padding_value = getattr(self.config, "ignore_y", -100)
        # The head always emits one logit per element, so flatten directly;
        # reshaping by config.num_labels misaligned logits and labels
        # whenever num_labels != 1.
        flat_logits = logits.reshape(-1)
        flat_labels = labels.reshape(-1)
        mask = torch.where(flat_labels != padding_value)

        filtered_logits = flat_logits[mask]
        # BCEWithLogitsLoss requires floating-point targets.
        filtered_targets = flat_labels[mask].float()
        return self.loss_fn(filtered_logits, filtered_targets)