# -*- coding: utf-8 -*-
# file: model.py
# time: 18:37 22/09/2024
# author: YANG, HENG <hy345@exeter.ac.uk> (杨恒)
# github: https://github.com/yangheng95
# huggingface: https://huggingface.co/yangheng
# google scholar: https://scholar.google.com/citations?user=NPq5a_0AAAAJ&hl=en
# Copyright (C) 2019-2025. All Rights Reserved.
import contextlib

import torch
from transformers import AutoTokenizer, AutoModel

from ....src.misc.utils import fprint
class OmniModelForEmbedding(torch.nn.Module):
    """
    Unified interface for loading a pre-trained model and generating embeddings
    from genomic sequences.

    Supports pooled embeddings ("head", "mean" or "tail" aggregation), token-level
    embeddings, batched encoding, optional gradient tracking and optional CUDA
    autocast (mixed precision).

    Attributes:
        tokenizer: The tokenizer used to process input sequences.
        model: The pre-trained model; its output must expose ``last_hidden_state``.
        _device: Device chosen at construction (CUDA if available, else CPU) and
            updated by :meth:`to`. :attr:`device` falls back to it only when the
            wrapped model has no parameters.

    Example:
        >>> from omnigenbench import OmniModelForEmbedding
        >>> model = OmniModelForEmbedding("anonymous8/OmniGenome-186M")
        >>> sequences = ["ATCGGCTA", "GGCTAGCTA"]
        >>> embeddings = model.batch_encode(sequences)
        >>> print(f"Embeddings shape: {embeddings.shape}")
        Embeddings shape: torch.Size([2, 768])
    """

    # Pooling strategies accepted by ``batch_encode`` / ``encode``.
    _SUPPORTED_AGGS = ("head", "mean", "tail")

    def __init__(self, config_or_model, tokenizer=None, *args, **kwargs):
        """
        Initialize the embedding model.

        Args:
            config_or_model (str): Name or path of the pre-trained model to load.
            tokenizer (optional): Pre-loaded tokenizer. If None, it is loaded
                from ``config_or_model`` with ``AutoTokenizer.from_pretrained``.
            *args: Extra positional arguments passed to ``AutoModel.from_pretrained``.
            **kwargs: Extra keyword arguments passed to ``AutoModel.from_pretrained``.
        """
        super().__init__()
        # Use the provided tokenizer or load one from the model path.
        self.tokenizer = (
            tokenizer
            if tokenizer is not None
            else AutoTokenizer.from_pretrained(config_or_model)
        )
        self.model = AutoModel.from_pretrained(config_or_model, *args, **kwargs)
        self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self._device)
        self.model.eval()  # Inference mode (disables dropout etc.)

    # ------------------------------------------------------------------ #
    # Internal helpers
    # ------------------------------------------------------------------ #
    @classmethod
    def _check_agg(cls, agg):
        """Raise ``ValueError`` if ``agg`` is not a supported pooling strategy."""
        if agg not in cls._SUPPORTED_AGGS:
            raise ValueError(f"Unsupported agg: {agg}")

    @staticmethod
    @contextlib.contextmanager
    def _forward_context(device, require_grad=False, use_autocast=False, amp_dtype=None):
        """Context for one forward pass: grad mode plus optional CUDA autocast.

        Gradient tracking and autocast are applied independently, so
        ``use_autocast`` is honoured both for plain inference and when
        ``require_grad`` is set. Autocast is only enabled on CUDA devices.
        """
        grad_ctx = torch.enable_grad() if require_grad else torch.no_grad()
        is_cuda = isinstance(device, torch.device) and device.type == "cuda"
        amp_ctx = (
            torch.autocast(device_type="cuda", dtype=amp_dtype)
            if use_autocast and is_cuda
            else contextlib.nullcontext()
        )
        with grad_ctx, amp_ctx:
            yield

    def _tokenize(self, sequences, max_length, padding):
        """Tokenize ``sequences`` and move every returned tensor to ``self.device``."""
        inputs = self.tokenizer(
            sequences,
            return_tensors="pt",
            padding=padding,
            truncation=True,
            max_length=max_length,
        )
        device = self.device
        return {key: value.to(device) for key, value in inputs.items()}

    @staticmethod
    def _pool(hidden, attention_mask, agg):
        """Pool token states ``hidden`` (B, L, H) into one vector per row (B, H).

        ``attention_mask`` (B, L) marks real tokens with non-zero values. It is
        moved to ``hidden``'s device so pooling works whether or not the hidden
        states were transferred to the CPU.
        """
        mask = attention_mask.to(hidden.device)
        if agg == "head":
            # First position of every row (presumably a CLS/BOS-style token,
            # depending on the tokenizer).
            return hidden[:, 0, :]
        if agg == "mean":
            # Average over real tokens only; padded positions contribute zero.
            return (hidden * mask.unsqueeze(-1)).sum(1) / mask.sum(1, keepdim=True)
        if agg == "tail":
            # Index of the last real token in each row. Taking the largest
            # position whose mask is set is correct for both right- and
            # left-padded batches (``mask.sum() - 1`` only works for right padding).
            positions = torch.arange(mask.size(1), device=mask.device)
            last = (positions * mask.ne(0)).max(dim=1).values
            rows = torch.arange(hidden.size(0), device=hidden.device)
            return hidden[rows, last]
        raise ValueError(f"Unsupported agg: {agg}")

    # ------------------------------------------------------------------ #
    # Public encoding API
    # ------------------------------------------------------------------ #
    def batch_encode(
        self,
        sequences,
        batch_size=8,
        max_length=512,
        agg="head",
        require_grad: bool = False,
        return_on_cpu: bool = True,
        use_autocast: bool = False,
        amp_dtype=None,
    ):
        """Batch encode sequences into aggregated (pooled) embeddings.

        Args:
            sequences (List[str]): Input DNA or RNA sequences for encoding.
            batch_size (int, default=8): Number of sequences to process per batch.
            max_length (int, default=512): Maximum sequence length for tokenization.
            agg (str, default="head"): Aggregation method for pooling.
                Options: "head", "mean", "tail".
            require_grad (bool, default=False): Whether to preserve gradients for fine-tuning.
            return_on_cpu (bool, default=True): Whether to move results to CPU memory.
            use_autocast (bool, default=False): Whether to enable mixed precision
                (CUDA only; applied regardless of ``require_grad``).
            amp_dtype (torch.dtype, optional): Data type for automatic mixed precision.

        Returns:
            torch.Tensor: Pooled embeddings with shape (num_sequences, hidden_size).
            An empty input returns an empty float32 CPU tensor of shape (0, hidden_size).

        Raises:
            ValueError: If ``agg`` is not one of the supported options.

        Note:
            When require_grad=True, gradients flow through the model for end-to-end training.

        Example:
            >>> sequences = ["ATCGGCTA", "GGCTAGCTA"]
            >>> embeddings = model.batch_encode(sequences, batch_size=4, agg="mean")
            >>> print(embeddings.shape)
            torch.Size([2, 768])
        """
        if not sequences:
            # Empty tensor with the correct trailing dimension (0, hidden_size).
            hidden_size = self.model.config.hidden_size
            return torch.empty(0, hidden_size, dtype=torch.float32)
        # Validate before spending a forward pass.
        self._check_agg(agg)

        pooled_batches = []
        for start in range(0, len(sequences), batch_size):
            inputs = self._tokenize(
                sequences[start : start + batch_size], max_length, padding=True
            )
            # Pooling stays inside the context so that, with require_grad=True,
            # the graph is built even if the caller is inside torch.no_grad().
            with self._forward_context(
                self.device, require_grad, use_autocast, amp_dtype
            ):
                hidden = self.model(**inputs).last_hidden_state  # (B, L, H)
                if return_on_cpu:
                    hidden = hidden.cpu()
                pooled_batches.append(
                    self._pool(hidden, inputs["attention_mask"], agg)
                )
        return torch.cat(pooled_batches, 0)

    def batch_encode_tokens(
        self,
        sequences,
        batch_size=8,
        max_length=512,
        use_autocast=False,
        amp_dtype=None,
        require_grad: bool = False,
        return_on_cpu: bool = True,
    ):
        """
        Encode sequences to token-level embeddings (last_hidden_state).

        Every batch is padded to ``max_length`` so batches can be concatenated.

        Args:
            sequences (List[str]): Input DNA/RNA sequences for token-level encoding.
            batch_size (int, default=8): Number of sequences to process per batch.
            max_length (int, default=512): Maximum sequence length for tokenization.
            use_autocast (bool, default=False): Enable mixed precision
                (CUDA only; applied regardless of ``require_grad``).
            amp_dtype (torch.dtype, optional): Data type for automatic mixed precision.
            require_grad (bool, default=False): Preserve the gradient graph for fine-tuning.
            return_on_cpu (bool, default=True): Transfer outputs to CPU memory.

        Returns:
            torch.Tensor: Token embeddings with shape (num_sequences, max_length, hidden_size).
            An empty input returns an empty float32 CPU tensor of shape
            (0, max_length, hidden_size).

        Note:
            When require_grad=True, gradients flow through the transformer model for
            end-to-end training. Set return_on_cpu=False to keep tensors on the model
            device for downstream processing.
        """
        if not sequences:
            # torch.cat([]) would raise; mirror batch_encode's empty handling.
            hidden_size = self.model.config.hidden_size
            return torch.empty(0, max_length, hidden_size, dtype=torch.float32)

        chunks = []
        for start in range(0, len(sequences), batch_size):
            inputs = self._tokenize(
                sequences[start : start + batch_size], max_length, padding="max_length"
            )
            with self._forward_context(
                self.device, require_grad, use_autocast, amp_dtype
            ):
                hidden = self.model(**inputs).last_hidden_state  # (B, L, H)
                if return_on_cpu:
                    hidden = hidden.cpu()
            chunks.append(hidden)
        return torch.cat(chunks, dim=0)

    def encode_tokens(
        self,
        sequence,
        max_length=512,
        use_autocast=False,
        amp_dtype=None,
        require_grad: bool = False,
        return_on_cpu: bool = True,
    ):
        """
        Encode a single sequence to token-level embeddings.

        Args:
            sequence (str): Input DNA/RNA sequence for token-level encoding.
            max_length (int, default=512): Maximum sequence length (output is padded to it).
            use_autocast (bool, default=False): Enable mixed precision
                (CUDA only; applied regardless of ``require_grad``).
            amp_dtype (torch.dtype, optional): Data type for automatic mixed precision.
            require_grad (bool, default=False): Preserve the gradient graph for fine-tuning.
            return_on_cpu (bool, default=True): Transfer output to CPU memory.

        Returns:
            torch.Tensor: Token embeddings with shape (max_length, hidden_size).

        Example:
            >>> model = OmniModelForEmbedding("yangheng/OmniGenome-52M")
            >>> sequence = "ATCGATCGATCG"
            >>> token_embeddings = model.encode_tokens(sequence, max_length=200)
            >>> print(f"Token embeddings shape: {token_embeddings.shape}")
            Token embeddings shape: torch.Size([200, 768])
        """
        inputs = self._tokenize(sequence, max_length, padding="max_length")
        with self._forward_context(self.device, require_grad, use_autocast, amp_dtype):
            hidden = self.model(**inputs).last_hidden_state  # (1, L, H)
            if return_on_cpu:
                hidden = hidden.cpu()
        # Remove the batch dimension for a single sequence.
        return hidden.squeeze(0)  # (L, H)

    def encode(
        self,
        sequence,
        max_length=512,
        agg="head",
        keep_dim=False,
        require_grad: bool = False,
        return_on_cpu: bool = True,
        use_autocast: bool = False,
        amp_dtype=None,
    ):
        """Encode a single sequence into a pooled embedding.

        Args:
            sequence (str): Input DNA or RNA sequence for encoding.
            max_length (int, default=512): Maximum sequence length for tokenization.
            agg (str, default="head"): Aggregation strategy for pooling.
                Options: "head", "mean", "tail".
            keep_dim (bool, default=False): Whether to preserve the batch dimension.
            require_grad (bool, default=False): Whether to preserve gradients for fine-tuning.
            return_on_cpu (bool, default=True): Whether to move results to CPU memory.
            use_autocast (bool, default=False): Whether to enable mixed precision
                (CUDA only; applied regardless of ``require_grad``).
            amp_dtype (torch.dtype, optional): Data type for automatic mixed precision.

        Returns:
            torch.Tensor: Pooled embedding with shape (hidden_size,), or
            (1, hidden_size) if keep_dim=True.

        Raises:
            ValueError: If ``agg`` is not one of the supported options.

        Example:
            >>> sequence = "ATCGATCGATCG"
            >>> embedding = model.encode(sequence, agg="mean", max_length=200)
            >>> print(embedding.shape)
            torch.Size([768])
        """
        self._check_agg(agg)
        inputs = self._tokenize(sequence, max_length, padding=True)
        with self._forward_context(self.device, require_grad, use_autocast, amp_dtype):
            hidden = self.model(**inputs).last_hidden_state  # (1, L, H)
            if return_on_cpu:
                hidden = hidden.cpu()
            emb = self._pool(hidden, inputs["attention_mask"], agg)  # (1, H)
        return emb if keep_dim else emb.squeeze(0)

    # ------------------------------------------------------------------ #
    # Persistence and similarity
    # ------------------------------------------------------------------ #
    def save_embeddings(self, embeddings, output_path):
        """
        Save embeddings to a file with ``torch.save``.

        Args:
            embeddings (torch.Tensor): The embeddings to save.
            output_path (str): Destination path.

        Example:
            >>> embeddings = model.batch_encode(sequences)
            >>> model.save_embeddings(embeddings, "embeddings.pt")
        """
        torch.save(embeddings, output_path)
        fprint(f"Embeddings saved to {output_path}")

    def load_embeddings(self, embedding_path):
        """
        Load embeddings from a file, mapping them onto ``self.device``.

        Args:
            embedding_path (str): Path to the saved embeddings.

        Returns:
            torch.Tensor: The loaded embeddings.

        Example:
            >>> embeddings = model.load_embeddings("embeddings.pt")
            >>> print(f"Loaded embeddings shape: {embeddings.shape}")
            Loaded embeddings shape: torch.Size([100, 768])
        """
        # NOTE(review): torch.load unpickles the file; only load trusted files
        # (consider weights_only=True where the installed torch supports it).
        embeddings = torch.load(embedding_path, map_location=self.device)
        fprint(f"Embeddings loaded from {embedding_path}")
        return embeddings

    def compute_similarity(self, embedding1, embedding2, dim=0):
        """
        Compute cosine similarity between two embeddings.

        Args:
            embedding1 (torch.Tensor or np.ndarray): The first embedding.
            embedding2 (torch.Tensor or np.ndarray): The second embedding.
            dim (int, optional): Dimension along which to compute cosine similarity.
                Defaults to 0.

        Returns:
            torch.Tensor: Cosine similarity along ``dim`` (values in [-1, 1]);
            a 0-dim tensor for two 1-D inputs. Use ``float(...)`` for a Python number.

        Example:
            >>> emb1 = model.encode("ATCGGCTA")
            >>> emb2 = model.encode("GGCTAGCTA")
            >>> similarity = model.compute_similarity(emb1, emb2)
            >>> print(f"Cosine similarity: {similarity:.4f}")
        """
        import numpy as np

        # Convert NumPy arrays to tensors if needed.
        if isinstance(embedding1, np.ndarray):
            embedding1 = torch.from_numpy(embedding1)
        if isinstance(embedding2, np.ndarray):
            embedding2 = torch.from_numpy(embedding2)
        return torch.nn.functional.cosine_similarity(embedding1, embedding2, dim=dim)

    # ------------------------------------------------------------------ #
    # Attention analysis
    # ------------------------------------------------------------------ #
    def get_attention_statistics(
        self,
        attention_scores,
        attention_mask=None,
        layer_aggregation="mean",
        head_aggregation="mean",
    ):
        """Compute statistics from attention scores.

        Args:
            attention_scores (torch.Tensor): Attention tensor with shape
                (num_layers, num_heads, seq_len, seq_len).
            attention_mask (torch.Tensor, optional): Padding mask; query/key pairs where
                either position is masked out are zeroed before aggregation.
            layer_aggregation (str, default="mean"): How to aggregate across layers.
                Options: "mean", "max", "sum", "first", "last".
            head_aggregation (str, default="mean"): How to aggregate across heads.
                Options: "mean", "max", "sum".

        Returns:
            Dict[str, torch.Tensor]:
                - 'attention_matrix': Aggregated (seq_len, seq_len) attention matrix
                - 'attention_entropy': Row-wise entropy ``-sum(p * log(p))``
                - 'max_attention_per_position': Row-wise maximum attention value
                - 'attention_concentration': Row-wise sum of squared weights
                - 'self_attention_scores': Diagonal of the aggregated matrix

        Raises:
            ValueError: If an aggregation option is not supported.

        Example:
            >>> result = model.extract_attention_scores(sequence)
            >>> stats = model.get_attention_statistics(result['attentions'], result['attention_mask'])
            >>> print(f"Average attention entropy: {stats['attention_entropy'].mean():.4f}")

        Note:
            Rows are not renormalised after masking or "sum"/"max" aggregation, so the
            entropy is a relative measure in those cases. Higher entropy indicates more
            distributed attention; lower entropy indicates more focused attention.
        """
        if attention_mask is not None:
            # Outer product of the mask with itself: entry (i, j) is non-zero only
            # when both query i and key j are real tokens. A (seq_len,) mask gives
            # (1, seq_len, seq_len); a (1, seq_len) mask gives (1, 1, seq_len, seq_len).
            # Either broadcasts over layers and heads.
            mask = attention_mask.unsqueeze(0).unsqueeze(0)
            mask = mask * attention_mask.unsqueeze(0).unsqueeze(-1)
            attention_scores = attention_scores * mask

        # Aggregate across heads (dim 1).
        if head_aggregation == "mean":
            head_aggregated = attention_scores.mean(dim=1)
        elif head_aggregation == "max":
            head_aggregated = attention_scores.max(dim=1)[0]
        elif head_aggregation == "sum":
            head_aggregated = attention_scores.sum(dim=1)
        else:
            raise ValueError(f"Unsupported head_aggregation: {head_aggregation}")

        # Aggregate across layers (dim 0).
        if layer_aggregation == "mean":
            layer_aggregated = head_aggregated.mean(dim=0)
        elif layer_aggregation == "max":
            layer_aggregated = head_aggregated.max(dim=0)[0]
        elif layer_aggregation == "sum":
            layer_aggregated = head_aggregated.sum(dim=0)
        elif layer_aggregation == "first":
            layer_aggregated = head_aggregated[0]
        elif layer_aggregation == "last":
            layer_aggregated = head_aggregated[-1]
        else:
            raise ValueError(f"Unsupported layer_aggregation: {layer_aggregation}")

        return {
            "attention_matrix": layer_aggregated,
            # Epsilon avoids log(0) for masked or zero-weight entries.
            "attention_entropy": -torch.sum(
                layer_aggregated * torch.log(layer_aggregated + 1e-9), dim=-1
            ),
            "max_attention_per_position": layer_aggregated.max(dim=-1)[0],
            # How concentrated each row's attention is.
            "attention_concentration": (layer_aggregated**2).sum(dim=-1),
            # Diagonal values (each position attending to itself).
            "self_attention_scores": torch.diag(layer_aggregated),
        }

    def visualize_attention_pattern(
        self,
        attention_result,
        layer_idx=0,
        head_idx=0,
        save_path=None,
        figsize=(12, 10),
    ):
        """Visualize one layer/head attention matrix as a heatmap.

        Args:
            attention_result (Dict): Result dictionary from extract_attention_scores() or
                batch_extract_attention_scores() with keys "attentions", "tokens" and
                "attention_mask".
            layer_idx (int, default=0): Index of the transformer layer to visualize.
            head_idx (int, default=0): Index of the attention head to visualize.
            save_path (str, optional): File path to save the image. If None, nothing is saved.
            figsize (tuple, default=(12, 10)): Figure size as (width, height) in inches.

        Returns:
            matplotlib.figure.Figure: The generated figure, or None if matplotlib is
            not available.

        Example:
            >>> sequence = "ATCGATCGATCG"
            >>> result = model.extract_attention_scores(sequence)
            >>> fig = model.visualize_attention_pattern(
            ...     result, layer_idx=0, head_idx=0, save_path="attention_plot.png"
            ... )

        Note:
            Requires matplotlib (pip install matplotlib). Uses the "Blues" colormap,
            where darker colors indicate stronger attention.
        """
        try:
            import matplotlib.pyplot as plt
        except ImportError:
            fprint(
                "matplotlib is required for visualization. Install with: pip install matplotlib"
            )
            return None

        def _as_numpy(value):
            # Tensors may require grad, live on the GPU, or be bfloat16 (which
            # NumPy cannot represent); normalise before converting. Non-tensor
            # inputs (e.g. NumPy arrays) are returned unchanged.
            if isinstance(value, torch.Tensor):
                value = value.detach().cpu()
                if value.dtype == torch.bfloat16:
                    value = value.float()
                return value.numpy()
            return value

        attention_matrix = _as_numpy(attention_result["attentions"][layer_idx, head_idx])
        tokens = attention_result["tokens"]
        attention_mask = _as_numpy(attention_result["attention_mask"])

        # Keep only the real (non-padding) positions. This prefix slice assumes
        # right padding — TODO confirm for left-padding tokenizers.
        seq_len = int(attention_mask.sum())
        attention_matrix = attention_matrix[:seq_len, :seq_len]
        tokens = tokens[:seq_len]

        fig, ax = plt.subplots(figsize=figsize)
        im = ax.imshow(attention_matrix, cmap="Blues", aspect="auto")
        ax.set_xticks(range(len(tokens)))
        ax.set_yticks(range(len(tokens)))
        ax.set_xticklabels(tokens, rotation=45, ha="right")
        ax.set_yticklabels(tokens)
        fig.colorbar(im, ax=ax, label="Attention Weight")
        ax.set_title(f"Attention Pattern - Layer {layer_idx}, Head {head_idx}")
        ax.set_xlabel("Key Positions")
        ax.set_ylabel("Query Positions")
        fig.tight_layout()
        if save_path:
            fig.savefig(save_path, dpi=300, bbox_inches="tight")
            fprint(f"Attention visualization saved to {save_path}")
        return fig

    # ------------------------------------------------------------------ #
    # Device handling
    # ------------------------------------------------------------------ #
    @property
    def device(self):
        """Device of the wrapped model's parameters.

        Reading it from the parameters keeps it correct after ``.cuda()`` /
        ``.cpu()``, which do not go through :meth:`to`. Falls back to the tracked
        ``_device`` when the model has no parameters.
        """
        try:
            return next(self.model.parameters()).device
        except (AttributeError, StopIteration):
            return self._device

    def to(self, *args, **kwargs):
        """Move the module (including the wrapped model) and/or change dtype.

        Accepts the same arguments as ``torch.nn.Module.to`` and keeps the tracked
        device in sync.

        Returns:
            OmniModelForEmbedding: ``self``, to allow chaining.
        """
        # nn.Module.to() recurses into registered submodules, so this also
        # moves self.model.
        super().to(*args, **kwargs)
        try:
            # After the move, the parameters' device is the ground truth.
            self._device = next(self.model.parameters()).device
        except StopIteration:
            # Parameter-less model: fall back to an explicitly requested device.
            requested = kwargs.get("device")
            if requested is None:
                requested = next(
                    (
                        arg
                        for arg in args
                        if isinstance(arg, torch.device)
                        or (isinstance(arg, str) and ("cuda" in arg or "cpu" in arg))
                    ),
                    None,
                )
            if requested is not None:
                self._device = torch.device(requested)
        return self
# Example usage: run this module directly to exercise the embedding API end to end.
if __name__ == "__main__":
    demo_sequences = ["ATCGGCTA", "GGCTAGCTA"]
    embedder = OmniModelForEmbedding("anonymous8/OmniGenome-186M")

    # Pool one sequence into a single vector.
    single = embedder.encode(demo_sequences[0])
    fprint(f"Single embedding shape: {single.shape}")

    # Pool every demo sequence in one batched call.
    batch = embedder.batch_encode(demo_sequences)
    fprint(f"Embeddings for sequences: {batch}")

    # Round-trip the batch through disk.
    embedder.save_embeddings(batch, "embeddings.pt")
    restored = embedder.load_embeddings("embeddings.pt")

    # Cosine similarity between the two restored embeddings.
    similarity = embedder.compute_similarity(restored[0], restored[1])
    fprint(f"Cosine similarity: {similarity}")