Source code for omnigenbench.src.tokenizer.kmers_tokenizer

# -*- coding: utf-8 -*-
# file: kmers_tokenizer.py
# time: 18:31 08/04/2024
# author: YANG, HENG <hy345@exeter.ac.uk> (杨恒)
# github: https://github.com/yangheng95
# huggingface: https://huggingface.co/yangheng
# google scholar: https://scholar.google.com/citations?user=NPq5a_0AAAAJ&hl=en
# Copyright (C) 2019-2025. All Rights Reserved.
import warnings

from ..abc.abstract_tokenizer import OmniTokenizer

warnings.filterwarnings("once")


class OmniKmersTokenizer(OmniTokenizer):
    """
    A k-mer based tokenizer for genomic sequences.

    The tokenizer slides a window of ``k`` characters over each input
    sequence, advancing by ``k - overlap`` characters per step, and uses a
    base tokenizer to convert the resulting k-mers into token IDs.

    Note:
        The window start positions cover the whole sequence, so the last
        k-mer(s) of a sequence may be shorter than ``k``.

    Attributes:
        base_tokenizer: The underlying tokenizer for converting k-mers to IDs
        k: Size of k-mers
        overlap: Number of overlapping positions between consecutive k-mers
        max_length: Maximum number of token IDs per sequence, including the
            two special tokens added by ``__call__``
        metadata: Dictionary containing tokenizer metadata (created by the
            parent class; this class stores ``tokenizer_name`` in it)

    Example:
        >>> from omnigenbench import OmniKmersTokenizer
        >>> from transformers import AutoTokenizer
        >>> base_tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t12_35M_UR50D")
        >>> tokenizer = OmniKmersTokenizer(base_tokenizer, k=4, overlap=2)
        >>> sequence = "ACGUAGGUAUCGUAGA"
        >>> tokens = tokenizer.tokenize(sequence)
        >>> print(tokens)
        [['ACGU', 'GUAG', 'AGGU', 'GUAU', 'AUCG', 'CGUA', 'UAGA', 'GA']]
    """

    def __init__(self, base_tokenizer=None, k=3, overlap=0, max_length=512, **kwargs):
        """
        Initialize the OmniKmersTokenizer.

        Args:
            base_tokenizer: The base tokenizer for converting k-mers to token IDs
            k (int, optional): Size of k-mers. Defaults to 3
            overlap (int, optional): Number of overlapping positions between
                consecutive k-mers. Must be smaller than ``k``. Defaults to 0
            max_length (int, optional): Maximum number of token IDs per
                sequence (special tokens included). Defaults to 512
            **kwargs: Additional keyword arguments passed to parent class
        """
        super().__init__(base_tokenizer, **kwargs)
        self.k = k
        self.overlap = overlap
        self.max_length = max_length
        self.metadata["tokenizer_name"] = self.__class__.__name__

    def __call__(self, sequence, **kwargs):
        """
        Tokenize a sequence or list of sequences into padded model inputs.

        Each sequence is optionally normalised (U/T conversion), split into
        k-mers, truncated to ``max_length - 2`` k-mers, wrapped in BOS/EOS
        (falling back to CLS/SEP) IDs, and finally padded to the longest
        sequence in the batch.

        Args:
            sequence (str or list): Input sequence(s) to tokenize
            **kwargs: Additional keyword arguments. ``max_length`` overrides
                ``self.max_length`` for this call; ``None`` disables truncation

        Returns:
            The output of ``base_tokenizer.pad`` with ``return_tensors="pt"``,
            containing ``input_ids`` and ``attention_mask``

        Warns:
            UserWarning: If more than 10% of the IDs of a sequence are the
                base tokenizer's unknown-token ID

        Example:
            >>> tokenizer = OmniKmersTokenizer(base_tokenizer, k=4, overlap=2)
            >>> tokenized = tokenizer("ACGUAGGUAUCGUAGA")
            >>> print(tokenized['input_ids'].shape)
            torch.Size([1, 10])
        """
        # Work on a list of sequences throughout so that batches are never
        # merged into a single string.
        sequences = [sequence] if isinstance(sequence, str) else list(sequence)

        # NOTE(review): u2t / t2u are not set in this class; presumably they
        # are provided by OmniTokenizer - confirm.
        # Upper-case first so that lowercase bases are converted as well.
        if self.u2t:
            sequences = [seq.upper().replace("U", "T") for seq in sequences]
        if self.t2u:
            sequences = [seq.upper().replace("T", "U") for seq in sequences]

        # Reserve two positions per sequence for the BOS and EOS tokens.
        max_length = kwargs.get("max_length", self.max_length)
        token_budget = None if max_length is None else max(max_length - 2, 0)
        sequence_tokens = [
            tokens[:token_budget] for tokens in self.tokenize(sequences)
        ]

        base = self.base_tokenizer
        bos_id = base.bos_token_id if base.bos_token_id is not None else base.cls_token_id
        eos_id = base.eos_token_id if base.eos_token_id is not None else base.sep_token_id
        unk_id = base.unk_token_id

        input_ids = [
            [bos_id] + base.convert_tokens_to_ids(tokens) + [eos_id]
            for tokens in sequence_tokens
        ]

        for i, ids in enumerate(input_ids):
            # len(ids) >= 2 because BOS and EOS are always present.
            if ids.count(unk_id) / len(ids) > 0.1:
                warnings.warn(
                    f"Unknown tokens are more than 10% in the {i}th sequence, please check the tokenization process."
                )

        tokenized_inputs = {
            "input_ids": input_ids,
            "attention_mask": [[1] * len(ids) for ids in input_ids],
        }
        # Pad to the longest sequence of the batch so that sequences of
        # different lengths can be stacked into one tensor.
        return base.pad(
            tokenized_inputs,
            padding="longest",
            return_attention_mask=True,
            return_tensors="pt",
        )

    @staticmethod
    def from_pretrained(config_or_model, **kwargs):
        """
        Create a k-mers tokenizer from a pre-trained model.

        Args:
            config_or_model (str): Name or path of the pre-trained model
            **kwargs: Additional keyword arguments. ``k``, ``overlap`` and
                ``max_length`` are consumed by ``OmniKmersTokenizer``; every
                other argument is forwarded to ``AutoTokenizer.from_pretrained``

        Returns:
            OmniKmersTokenizer: Initialized k-mers tokenizer

        Example:
            >>> tokenizer = OmniKmersTokenizer.from_pretrained("facebook/esm2_t12_35M_UR50D")
            >>> print(type(tokenizer))
            <class 'omnigenbench.src.tokenizer.kmers_tokenizer.OmniKmersTokenizer'>
        """
        from transformers import AutoTokenizer

        # Split off the k-mer settings; omitted ones keep the __init__ defaults.
        kmer_options = {
            name: kwargs.pop(name)
            for name in ("k", "overlap", "max_length")
            if name in kwargs
        }
        base_tokenizer = AutoTokenizer.from_pretrained(config_or_model, **kwargs)
        return OmniKmersTokenizer(base_tokenizer, **kmer_options)

    def tokenize(self, sequence, **kwargs):
        """
        Convert sequence(s) into k-mers.

        A window of ``self.k`` characters is moved over each sequence in
        steps of ``self.k - self.overlap``. Windows that run past the end of
        the sequence are kept, so trailing k-mers may be shorter than ``k``.

        Args:
            sequence (str or list): Input sequence(s) to convert to k-mers
            **kwargs: Additional keyword arguments (unused)

        Returns:
            list: List of k-mer lists, one per input sequence

        Raises:
            ValueError: If ``overlap`` is not smaller than ``k``

        Example:
            >>> tokenizer = OmniKmersTokenizer(base_tokenizer, k=4, overlap=2)
            >>> k_mers = tokenizer.tokenize("ACGUAGGUAUCGUAGA")
            >>> print(k_mers[0][:3])
            ['ACGU', 'GUAG', 'AGGU']
        """
        sequences = [sequence] if isinstance(sequence, str) else sequence

        step = self.k - self.overlap
        if step <= 0:
            # A zero step would make range() fail with an unrelated message;
            # a negative step would silently yield no k-mers at all.
            raise ValueError(
                f"overlap ({self.overlap}) must be smaller than k ({self.k})."
            )

        return [
            [seq[start : start + self.k] for start in range(0, len(seq), step)]
            for seq in sequences
        ]

    def encode(self, input_ids, **kwargs):
        """
        Encode the input by delegating to the base tokenizer's ``encode``.

        Args:
            input_ids: Input passed unchanged to ``base_tokenizer.encode``
            **kwargs: Additional keyword arguments forwarded to the base tokenizer

        Returns:
            Whatever ``base_tokenizer.encode`` returns
        """
        return self.base_tokenizer.encode(input_ids, **kwargs)

    def decode(self, input_ids, **kwargs):
        """
        Decode input IDs by delegating to the base tokenizer's ``decode``.

        Args:
            input_ids: Input IDs to decode
            **kwargs: Additional keyword arguments forwarded to the base tokenizer

        Returns:
            Whatever ``base_tokenizer.decode`` returns
        """
        return self.base_tokenizer.decode(input_ids, **kwargs)

    def encode_plus(self, sequence, **kwargs):
        """
        Encode a sequence with additional information.

        This method is not yet implemented for k-mers tokenizer.

        Args:
            sequence: Input sequence
            **kwargs: Additional keyword arguments

        Raises:
            NotImplementedError: Always; this method is not implemented yet
        """
        raise NotImplementedError("The encode_plus() function is not implemented yet.")

if __name__ == "__main__":
    # Small manual demo: split an RNA sequence into k-mers and encode it.
    # This module uses a relative import, so it has to be executed as part of
    # its package (e.g. with ``python -m``), not as a standalone file.
    from transformers import AutoTokenizer

    RNA = "ACGUAGGUAUCGUAGA"
    # base_tokenizer_name = 'bert-base-cased'
    base_tokenizer_name = "facebook/esm2_t12_35M_UR50D"
    base_tokenizer = AutoTokenizer.from_pretrained(base_tokenizer_name)
    tokenizer = OmniKmersTokenizer(base_tokenizer, k=4, overlap=2, max_length=512)

    tokens = tokenizer.tokenize(RNA)
    print(tokens)

    tokenized_inputs = tokenizer(RNA)
    print(tokenized_inputs)