Source code for omnigenbench.src.model.augmentation.model

# -*- coding: utf-8 -*-
# file: model.py
# time: 18:37 22/09/2024
# author: YANG, HENG <hy345@exeter.ac.uk> (杨恒)
# github: https://github.com/yangheng95
# huggingface: https://huggingface.co/yangheng
# google scholar: https://scholar.google.com/citations?user=NPq5a_0AAAAJ&hl=en
# Copyright (C) 2019-2025. All Rights Reserved.
"""
Data augmentation model for genomic sequences.

This module provides a data augmentation model that uses masked language modeling
to generate augmented versions of genomic sequences. It's useful for expanding
training datasets and improving model robustness.
"""
import torch
import random
import json
import tqdm
from transformers import AutoModelForMaskedLM, AutoTokenizer
import autocuda


[docs] class OmniModelForAugmentation(torch.nn.Module): """ Data augmentation model for genomic sequences using masked language modeling. This model uses a pre-trained masked language model to generate augmented versions of genomic sequences by randomly masking tokens and predicting replacements. It's useful for expanding training datasets and improving model generalization. Attributes: tokenizer: Tokenizer for processing genomic sequences model: Pre-trained masked language model device: Device to run the model on (CPU or GPU) noise_ratio: Proportion of tokens to mask for augmentation max_length: Maximum sequence length for tokenization k: Number of augmented instances to generate per sequence """ def __init__( self, config_or_model=None, noise_ratio=0.15, max_length=1026, instance_num=1, batch_size=32, use_amp=None, *args, **kwargs ): """ Initialize the augmentation model. Args: config_or_model (str): Path or model name for loading the pre-trained model noise_ratio (float): The proportion of tokens to mask in each sequence for augmentation (default: 0.15) max_length (int): The maximum sequence length for tokenization (default: 1026) instance_num (int): Number of augmented instances to generate per sequence (default: 1) batch_size (int): Batch size used when running the MLM forward pass (default: 32) use_amp (bool|None): Whether to use automatic mixed precision for speed on GPU (default: auto-detect) *args: Additional positional arguments **kwargs: Additional keyword arguments """ super().__init__() try: self.tokenizer = AutoTokenizer.from_pretrained(config_or_model) except Exception as e: if "RnaTokenizer" in str(e): from multimolecule import RnaTokenizer self.tokenizer = RnaTokenizer.from_pretrained(config_or_model) self.model = AutoModelForMaskedLM.from_pretrained( config_or_model, trust_remote_code=True ) self.device = autocuda.auto_cuda() self.model.to(self.device) self.model.eval() # Hyperparameters for augmentation self.noise_ratio = noise_ratio self.max_length = max_length self.k = instance_num self.batch_size = int(batch_size) if batch_size is not None else 32 if use_amp is None: self.use_amp = torch.cuda.is_available() else: self.use_amp = bool(use_amp)
[docs] def load_sequences_from_file(self, input_file): """ Load sequences from a JSON file. Args: input_file (str): Path to the input JSON file containing sequences Returns: list: List of sequences loaded from the file """ sequences = [] with open(input_file, "r") as f: for line in f.readlines(): sequences.append(json.loads(line)["seq"]) return sequences
[docs] def apply_noise_to_sequence(self, seq): """ Apply noise to a single sequence by randomly masking tokens. Args: seq (str): Input genomic sequence Returns: str: Sequence with randomly masked tokens """ seq_list = self.tokenizer.tokenize(seq) if not seq_list: return seq mask_cnt = ( max(1, int(len(seq_list) * self.noise_ratio)) if self.noise_ratio > 0 else 0 ) for _ in range(mask_cnt): random_idx = random.randint(0, len(seq_list) - 1) seq_list[random_idx] = self.tokenizer.mask_token return "".join(seq_list)
def _augment_batch(self, seq_list): """ Augment a batch of masked sequences using a single forward pass. Args: seq_list (List[str]): List of masked sequences Returns: List[str]: Augmented sequences """ if not seq_list: return [] tokenized_inputs = self.tokenizer( seq_list, padding=True, truncation=True, max_length=self.max_length, return_tensors="pt", ) with torch.no_grad(): if self.use_amp: # autocast speeds up FP16/BF16 on GPU while preserving accuracy for decoding autocast_ctx = torch.cuda.amp.autocast( dtype=torch.float16 if torch.cuda.is_available() else None ) else: # dummy context manager class _Noop: def __enter__(self): return None def __exit__(self, exc_type, exc, tb): return False autocast_ctx = _Noop() with autocast_ctx: outputs = self.model(**tokenized_inputs.to(self.device)) logits = ( outputs["logits"] if isinstance(outputs, dict) else outputs.logits ) predicted_tokens = logits.argmax(dim=-1).detach().cpu() input_ids = tokenized_inputs["input_ids"].cpu() mask_id = self.tokenizer.mask_token_id # Replace [MASK] positions with top-1 predictions mask_positions = input_ids == mask_id input_ids[mask_positions] = predicted_tokens[mask_positions] # Decode back to strings augmented_sequences = self.tokenizer.batch_decode( input_ids, skip_special_tokens=True ) return augmented_sequences
[docs] def augment_sequence(self, seq): """ Perform augmentation on a single sequence by predicting masked tokens. Args: seq (str): Input genomic sequence with masked tokens Returns: str: Augmented sequence with predicted tokens replacing masked tokens """ # Keep single-sample path for API compatibility, implemented via batch for efficiency return self._augment_batch([seq])[0]
[docs] def augment(self, seq, k=None): """ Generate multiple augmented instances for a single sequence. Args: seq (str): Input genomic sequence k (int, optional): Number of augmented instances to generate (default: None, uses self.k) Returns: list: List of augmented sequences """ k = self.k if k is None else int(k) # Prepare K noised variants, then process in mini-batches noised_variants = [self.apply_noise_to_sequence(seq) for _ in range(k)] augmented_sequences = [] for i in range(0, len(noised_variants), self.batch_size): batch = noised_variants[i : i + self.batch_size] augmented_sequences.extend(self._augment_batch(batch)) return augmented_sequences
[docs] def augment_sequences(self, sequences): """ Augment a list of sequences by applying noise and performing MLM-based predictions. Args: sequences (list): List of genomic sequences to augment Returns: list: List of all augmented sequences """ all_augmented_sequences = [] total = len(sequences) * max(1, self.k) pbar = tqdm.tqdm(total=total, desc="Augmenting Sequences", unit="inst") buffer = [] # We accumulate masked sequences across inputs until batch_size, then run one forward pass for seq in sequences: # generate k masked copies for this sequence for _ in range(self.k): buffer.append(self.apply_noise_to_sequence(seq)) # flush when buffer reaches batch size if len(buffer) >= self.batch_size: all_augmented_sequences.extend(self._augment_batch(buffer)) pbar.update(len(buffer)) buffer = [] # flush the tail if buffer: all_augmented_sequences.extend(self._augment_batch(buffer)) pbar.update(len(buffer)) pbar.close() return all_augmented_sequences
[docs] def save_augmented_sequences(self, augmented_sequences, output_file): """ Save augmented sequences to a JSON file. Args: augmented_sequences (list): List of augmented sequences to save output_file (str): Path to the output JSON file """ with open(output_file, "w") as f: for seq in augmented_sequences: f.write(json.dumps({"aug_seq": seq}) + "\n")
[docs] def augment_from_file(self, input_file, output_file): """ Main function to handle the augmentation process from a file input to a file output. This method loads sequences from an input file, augments them using the MLM model, and saves the augmented sequences to an output file. Args: input_file (str): Path to the input file containing sequences output_file (str): Path to the output file where augmented sequences will be saved """ sequences = self.load_sequences_from_file(input_file) augmented_sequences = self.augment_sequences(sequences) self.save_augmented_sequences(augmented_sequences, output_file)
# Example usage if __name__ == "__main__": model = OmniModelForAugmentation( config_or_model="anonymous8/OmniGenome-186M", noise_ratio=0.2, # Example noise ratio max_length=1026, # Maximum token length instance_num=3, # Number of augmented instances per sequence batch_size=64, ) aug = model.augment_sequence("ATCTTGCATTGAAG") input_file = "toy_datasets/test.json" output_file = "toy_datasets/augmented_sequences.json" model.augment_from_file(input_file, output_file)