Source code for omnigenbench.src.model.module_utils

# -*- coding: utf-8 -*-
# file: module_utils.py
# time: 22:53 18/07/2024
# author: YANG, HENG <hy345@exeter.ac.uk> (杨恒)
# github: https://github.com/yangheng95
# huggingface: https://huggingface.co/yangheng
# google scholar: https://scholar.google.com/citations?user=NPq5a_0AAAAJ&hl=en
# Copyright (C) 2019-2025. All Rights Reserved.
"""
This module provides utility classes and functions for handling model inputs,
pooling operations, and attention mechanisms used across different OmniGenome model types.
"""
import torch
import torch.nn as nn

from transformers.models.bert.modeling_bert import BertPooler
from transformers.tokenization_utils_base import BatchEncoding


class OmniPooling(torch.nn.Module):
    """
    A flexible pooling layer for OmniGenome models that handles different input formats.

    This class provides a unified interface for pooling operations across different
    model architectures, supporting both causal language models and encoder-based
    models. It can handle various input formats including tuples, dictionaries,
    BatchEncoding objects, and tensors.

    Attributes:
        config: Model configuration object. ``architectures`` / ``auto_map`` are read
            to detect causal language models and ``pad_token_id`` is read when
            pooling for causal models.
        pooler: BertPooler instance for non-causal models, None for causal models
    """

    def __init__(self, config, *args, **kwargs):
        """
        Initialize the OmniPooling layer.

        Args:
            config: Model configuration object containing architecture information
            *args: Additional positional arguments forwarded to ``torch.nn.Module``
            **kwargs: Additional keyword arguments forwarded to ``torch.nn.Module``
        """
        super().__init__(*args, **kwargs)
        self.config = config
        # The conditional is evaluated first, so BertPooler is only built for
        # non-causal configurations.
        self.pooler = BertPooler(self.config) if not self._is_causal_lm() else None

    def forward(self, inputs, last_hidden_state):
        """
        Perform pooling operation on the last hidden state.

        This method handles different input formats and applies appropriate pooling:
        - For causal language models: selects the hidden state at index
          ``(number of non-pad tokens) - 1`` of every sequence, i.e. the last
          non-padded token when sequences are right-padded.
        - For encoder models: uses the BertPooler.

        When the configuration's ``pad_token_id`` is ``None`` the padding cannot be
        located from ``input_ids``; the attention mask is used instead if one was
        supplied, otherwise the last position of every sequence is selected.

        Args:
            inputs: Input data in various formats (tuple, dict, BatchEncoding, or tensor)
            last_hidden_state (torch.Tensor): Hidden states from the model
                [batch_size, seq_len, hidden_size]

        Returns:
            torch.Tensor: Pooled representation [batch_size, hidden_size]

        Raises:
            ValueError: If input format is not supported or cannot be parsed
        """
        # Inputs are always parsed first so unsupported input types are rejected
        # for encoder models as well.
        input_ids, attention_mask = self._unpack_inputs(inputs)

        if self.pooler is not None:
            return self.pooler(last_hidden_state)

        if input_ids is None:
            # Only reachable for tensors whose rank is neither 2 nor 3.
            raise ValueError(
                f"Failed to get the input_ids and attention_mask from the inputs, "
                f"got shape {getattr(inputs, 'shape', None)}."
            )

        pad_token_id = getattr(self.config, "pad_token_id", -100)
        if pad_token_id is not None:
            sequence_lengths = input_ids.ne(pad_token_id).sum(dim=1) - 1
        elif attention_mask is not None:
            # No pad token configured: count the attended positions instead.
            sequence_lengths = attention_mask.ne(0).sum(dim=1) - 1
        else:
            # Nothing tells us where padding starts; fall back to the last position.
            sequence_lengths = torch.full(
                (input_ids.size(0),), input_ids.size(1) - 1, dtype=torch.long
            )

        # Keep both index tensors on the device of the tensor being indexed.
        device = last_hidden_state.device
        batch_indices = torch.arange(input_ids.size(0), device=device)
        return last_hidden_state[batch_indices, sequence_lengths.to(device)]

    @staticmethod
    def _unpack_inputs(inputs):
        """
        Extract ``(input_ids, attention_mask)`` from the supported input formats.

        Args:
            inputs: A tuple ``(input_ids[, attention_mask])``, a tensor, or a
                dict / BatchEncoding holding ``"input_ids"`` and optionally
                ``"attention_mask"``.

        Returns:
            tuple: ``(input_ids, attention_mask)``. ``attention_mask`` is None when
            it is not present; both are None for a tensor whose rank is neither
            2 nor 3.

        Raises:
            ValueError: If the input type is unsupported or a tensor cannot be split
        """
        if isinstance(inputs, tuple):
            return inputs[0], (inputs[1] if len(inputs) > 1 else None)

        if isinstance(inputs, torch.Tensor):
            shape = inputs.shape
            try:
                if len(shape) == 3:
                    # compatible with hf_trainer in AutoBenchmark
                    if shape[1] == 2:
                        return inputs[:, 0], inputs[:, 1]
                    return inputs[0], (inputs[1] if len(inputs) > 1 else None)
                if len(shape) == 2:
                    return inputs, None
            except (IndexError, RuntimeError) as err:
                raise ValueError(
                    f"Failed to get the input_ids and attention_mask from the inputs, got shape {shape}."
                ) from err
            return None, None

        if isinstance(inputs, (BatchEncoding, dict)):
            input_ids = inputs["input_ids"]
            attention_mask = (
                inputs["attention_mask"] if "attention_mask" in inputs else None
            )
            return input_ids, attention_mask

        raise ValueError(
            f"The inputs should be a tuple, BatchEncoding or a dictionary-like object, got {type(inputs)}."
        )

    def _is_causal_lm(self):
        """
        Check if the model is a causal language model.

        The model is treated as causal when the string ``"CausalLM"`` occurs in the
        string form of ``config.architectures`` or ``config.auto_map``. Missing
        attributes are treated as not matching.

        Returns:
            bool: True if the model is a causal language model, False otherwise
        """
        architectures = getattr(self.config, "architectures", None)
        auto_map = getattr(self.config, "auto_map", None)
        return "CausalLM" in str(architectures) or "CausalLM" in str(auto_map)
# class InteractingAttention(nn.Module): # def __init__(self, embed_size, num_heads=12): # super(InteractingAttention, self).__init__() # self.num_heads = num_heads # self.embed_size = embed_size # # assert embed_size % num_heads == 0, "Embedding size should be divisible by number of heads" # # self.head_dim = embed_size // num_heads # # self.values = nn.Linear(self.head_dim, self.head_dim, bias=False) # self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False) # self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False) # self.layer_norm = nn.LayerNorm(num_heads * self.head_dim, eps=1e-6) # # self.fc_out = nn.Linear(num_heads * self.head_dim, embed_size) # # # def forward(self, query, keys, values): # def forward(self, query, keys, values): # # N = query.shape[0] # value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1] # # # Split embedding into self.num_heads pieces # values = values.reshape(N, value_len, self.num_heads, self.head_dim) # keys = keys.reshape(N, key_len, self.num_heads, self.head_dim) # queries = query.reshape(N, query_len, self.num_heads, self.head_dim) # # values = self.values(values) # (N, value_len, heads, head_dim) # keys = self.keys(keys) # (N, key_len, heads, head_dim) # queries = self.queries(queries) # (N, query_len, heads, head_dim) # # energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys]) # # attention = torch.softmax(energy / (self.embed_size ** (1 / 2)), dim=3) # (N, heads, query_len, key_len) # # out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape( # N, query_len, self.num_heads * self.head_dim # ) # out = self.layer_norm(out + query) # out = self.fc_out(out) # out = self.layer_norm(out + query) # return out
class InteractingAttention(nn.Module):
    """
    Multi-head attention block with residual connections and layer normalization.

    The block attends from ``query`` to ``keys``/``values``, adds the result back
    onto the query and normalizes it, then applies a linear projection followed by
    a second residual connection and normalization. The same ``layer_norm`` module
    is applied at both normalization steps.

    Attributes:
        attention: Multi-head attention layer (batch-first)
        layer_norm: Layer normalization used after both residual connections
        fc_out: Output projection layer
    """

    def __init__(self, embed_size, num_heads=24):
        """
        Build the attention, normalization and projection layers.

        Args:
            embed_size (int): Size of the embedding dimension
            num_heads (int): Number of attention heads (default: 24)

        Raises:
            AssertionError: If embed_size is not divisible by num_heads
        """
        super().__init__()

        remainder = embed_size % num_heads
        assert remainder == 0, "Embedding size should be divisible by number of heads"

        # Submodules are registered in a fixed order: attention, layer_norm, fc_out.
        self.attention = nn.MultiheadAttention(
            embed_dim=embed_size,
            num_heads=num_heads,
            batch_first=True,
        )
        self.layer_norm = nn.LayerNorm(embed_size, eps=1e-6)
        self.fc_out = nn.Linear(embed_size, embed_size)

    def forward(self, query, keys, values):
        """
        Run the attention block.

        Args:
            query (torch.Tensor): Query tensor [batch_size, query_len, embed_size]
            keys (torch.Tensor): Key tensor [batch_size, key_len, embed_size]
            values (torch.Tensor): Value tensor [batch_size, value_len, embed_size]

        Returns:
            torch.Tensor: Output tensor with same shape as query
        """
        # nn.MultiheadAttention returns (output, attention_weights); only the
        # output is needed here.
        attended = self.attention(query, keys, values)[0]

        # First residual connection + normalization around the attention output.
        normed = self.layer_norm(attended + query)

        # Second residual connection + normalization around the projection.
        projected = self.fc_out(normed)
        return self.layer_norm(projected + normed)