Source code for omnigenbench.src.model.mlm.model

# -*- coding: utf-8 -*-
# file: model.py
# time: 13:30 10/04/2024
# author: YANG, HENG <hy345@exeter.ac.uk> (杨恒)
# github: https://github.com/yangheng95
# huggingface: https://huggingface.co/yangheng
# google scholar: https://scholar.google.com/citations?user=NPq5a_0AAAAJ&hl=en
# Copyright (C) 2019-2025. All Rights Reserved.
"""
Masked Language Model (MLM) for genomic sequences.

This module provides a masked language model implementation specifically designed
for genomic sequences. It supports masked language modeling tasks where tokens
are randomly masked and the model learns to predict the original tokens.
"""
import numpy as np
import torch
from transformers import BatchEncoding

from ...abc.abstract_model import OmniModel


class OmniModelForMLM(OmniModel):
    """
    Masked Language Model for genomic sequences.

    This model implements masked language modeling for genomic sequences,
    where tokens are randomly masked and the model learns to predict the
    original tokens. It's useful for pre-training genomic language models and
    understanding sequence patterns and dependencies.

    Attributes:
        loss_fn: Cross-entropy loss module used by :meth:`loss_function`.
    """

    def __init__(self, config_or_model, tokenizer, *args, **kwargs):
        """
        Initialize the MLM model.

        Args:
            config_or_model: Model configuration or pre-trained model
            tokenizer: Tokenizer for processing input sequences
            *args: Additional positional arguments forwarded to the base class
            **kwargs: Additional keyword arguments forwarded to the base class

        Raises:
            ValueError: If the class name of the wrapped model does not
                contain ``"MaskedLM"``, i.e. it does not look like a model
                with a language-modeling head.
        """
        super().__init__(config_or_model, tokenizer, *args, **kwargs)
        self.metadata["model_name"] = self.__class__.__name__
        # The check is name-based: only wrapped models whose class name
        # contains "MaskedLM" are accepted.
        if "MaskedLM" not in self.model.__class__.__name__:
            raise ValueError(
                "The model does not have a language model head, "
                "which is required for MLM. "
                "Please use a model that supports masked language modeling."
            )
        self.loss_fn = torch.nn.CrossEntropyLoss()

    def forward(self, **inputs):
        """
        Forward pass for masked language modeling.

        Args:
            **inputs: Must contain the keyword ``inputs``, a mapping that is
                unpacked into the wrapped model call (presumably holding
                input_ids, attention_mask and, optionally, labels -- confirm
                against the caller).

        Returns:
            dict: Dictionary with the keys ``loss``, ``logits`` and
            ``last_hidden_state``. ``loss`` and ``logits`` are ``None`` when
            the wrapped model does not return them.

        Raises:
            KeyError: If the ``inputs`` keyword is not supplied.
        """
        model_inputs = inputs.pop("inputs")
        outputs = self.model(**model_inputs, output_hidden_states=True)

        # Prefer an explicit last_hidden_state; otherwise fall back to the
        # final entry of the requested hidden_states.
        last_hidden_state = (
            outputs["last_hidden_state"]
            if "last_hidden_state" in outputs
            else outputs["hidden_states"][-1]
        )
        logits = outputs["logits"] if "logits" in outputs else None
        loss = outputs["loss"] if "loss" in outputs else None

        return {
            "loss": loss,
            "logits": logits,
            "last_hidden_state": last_hidden_state,
        }

    def predict(self, sequence_or_inputs, **kwargs):
        """
        Generate token-id predictions for masked language modeling.

        Args:
            sequence_or_inputs: Input sequences or pre-processed inputs
            **kwargs: Additional keyword arguments forwarded to
                ``_forward_from_raw_input``

        Returns:
            dict: Dictionary containing ``predictions`` (argmax over the last
            logits dimension), ``logits`` and ``last_hidden_state``. When the
            input is not a ``list``, only the first sample of each entry is
            returned.
        """
        raw_outputs = self._forward_from_raw_input(sequence_or_inputs, **kwargs)
        logits = raw_outputs["logits"]
        last_hidden_state = raw_outputs["last_hidden_state"]

        # One CPU tensor of predicted ids per sample in the batch.
        predictions = [
            sample_logits.argmax(dim=-1).cpu() for sample_logits in logits
        ]

        if not isinstance(sequence_or_inputs, list):
            return {
                "predictions": predictions[0],
                "logits": logits[0],
                "last_hidden_state": last_hidden_state[0],
            }

        # Per-sample predictions with at least one dimension are stacked;
        # 0-dim predictions are gathered into one tensor on the model device.
        if predictions[0].shape:
            batch_predictions = torch.stack(predictions)
        else:
            batch_predictions = torch.tensor(predictions).to(self.model.device)

        return {
            "predictions": batch_predictions,
            "logits": logits,
            "last_hidden_state": last_hidden_state,
        }

    def inference(self, sequence_or_inputs, **kwargs):
        """
        Perform inference for masked language modeling, decoding predictions
        back to sequences.

        Args:
            sequence_or_inputs: Input sequences or pre-processed inputs
            **kwargs: Additional keyword arguments forwarded to
                ``_forward_from_raw_input``

        Returns:
            dict: Dictionary containing ``predictions`` (for each sample, the
            decoded prediction split into a list of single characters),
            ``logits`` and ``last_hidden_state``. When the input is not a
            ``list``, only the first sample of each entry is returned.
        """
        raw_outputs = self._forward_from_raw_input(sequence_or_inputs, **kwargs)
        inputs = raw_outputs["inputs"]
        logits = raw_outputs["logits"]
        last_hidden_state = raw_outputs["last_hidden_state"]

        pad_token_id = self.config.pad_token_id
        predictions = []
        for i, sample_logits in enumerate(logits):
            # Keep only the non-padding positions, then drop the first and
            # last of them (presumably the special start/end tokens --
            # confirm against the tokenizer in use).
            not_padding = inputs["input_ids"][i].ne(pad_token_id)
            token_logits = sample_logits[not_padding][1:-1]
            decoded = self.tokenizer.decode(token_logits.argmax(dim=-1))
            predictions.append(list(decoded.replace(" ", "")))

        if not isinstance(sequence_or_inputs, list):
            return {
                "predictions": predictions[0],
                "logits": logits[0],
                "last_hidden_state": last_hidden_state[0],
            }

        return {
            "predictions": predictions,
            "logits": logits,
            "last_hidden_state": last_hidden_state,
        }

    def loss_function(self, logits, labels):
        """
        Compute the loss for masked language modeling.

        Args:
            logits (torch.Tensor): Model predictions
                [batch_size, seq_len, vocab_size]
            labels (torch.Tensor): Ground truth labels [batch_size, seq_len]

        Returns:
            torch.Tensor: Computed cross-entropy loss value
        """
        # Flatten using the logits' own last dimension rather than the
        # tokenizer's vocabulary size: the two can differ (added tokens,
        # padded output layers). reshape() also accepts non-contiguous
        # tensors, which view() rejects.
        flat_logits = logits.reshape(-1, logits.size(-1))
        flat_labels = labels.reshape(-1)
        return self.loss_fn(flat_logits, flat_labels)