# Source code for omnigenbench.src.explainability.shared_methods.tsne_explainer
#
# Author: Shasha Zhou <sz484@exeter.ac.uk>
# Description:
#
# Copyright (C) 2020-2025. All Rights Reserved.
#
# -*- coding: utf-8 -*-
# file: tsne_explainer.py
# time: 2025-06-16 21:33
# author: Shasha Zhou <sz484@exeter.ac.uk>
# Copyright (C) 2020-2025. All Rights Reserved.
from sklearn.manifold import TSNE
from typing import List, Union, Optional
from ...abc.abstract_explainer import AbstractExplainer
from ...misc.utils import fprint
class TSNEExplainer(AbstractExplainer):
    """Visualizes high-dimensional sequence embeddings in 2D using t-SNE.

    This explainer obtains high-dimensional embeddings for a set of input
    sequences, either by encoding them with ``model`` or by loading them from
    a file. It then applies the t-SNE (t-Distributed Stochastic Neighbor
    Embedding) algorithm to project them into a low-dimensional space (two
    dimensions by default). This is useful for visualizing the structure of
    the learned embedding space and observing how sequences with different
    labels cluster.

    Attributes:
        model: The model used to generate sequence embeddings.
        tsne (sklearn.manifold.TSNE): The t-SNE transformer instance. It is
            re-fitted from scratch on every call to ``explain``.
    """

    def __init__(self, model, **kwargs):
        """Initializes the TSNEExplainer.

        Args:
            model: A model object capable of generating embeddings. It should
                provide a ``batch_encode`` method (e.g.
                ``OmniModelForEmbedding``) and, if ``explain`` is called with
                ``embedding_file``, a ``load_embeddings`` method. It is
                passed to the ``AbstractExplainer`` base constructor.
            **kwargs: Additional keyword arguments passed directly to the
                ``sklearn.manifold.TSNE`` constructor, e.g. ``perplexity``,
                ``learning_rate`` or ``max_iter`` (named ``n_iter`` in older
                scikit-learn releases). ``n_components`` defaults to 2 but
                may be overridden here.
        """
        super().__init__(model)
        # Default to a 2D projection, but let an explicit n_components in
        # kwargs win instead of raising "got multiple values for keyword
        # argument 'n_components'".
        kwargs.setdefault("n_components", 2)
        self.tsne = TSNE(**kwargs)

    def explain(
        self,
        sequences: List[str],
        labels: List[Union[int, str]],
        embedding_file: Optional[str] = None,
        **kwargs,
    ):
        """Projects sequence embeddings into a low-dimensional space with t-SNE.

        High-dimensional embeddings are obtained either by encoding
        ``sequences`` with ``self.model.batch_encode`` or, when
        ``embedding_file`` is given, by loading them with
        ``self.model.load_embeddings``. The t-SNE transformer is then fitted
        on these embeddings and the projected coordinates are returned.

        Args:
            sequences: Input sequences to encode. Ignored when
                ``embedding_file`` is provided.
            labels: Labels associated with ``sequences``. Currently unused by
                this method; callers can pair them with the returned
                coordinates themselves (e.g. to colour a scatter plot).
            embedding_file: Optional path to precomputed embeddings, loaded
                via ``self.model.load_embeddings`` instead of encoding
                ``sequences``.
            **kwargs: Extra keyword arguments forwarded to
                ``self.model.batch_encode``.

        Returns:
            The result of ``TSNE.fit_transform``: an array of shape
            ``(n_samples, n_components)``, where ``n_components`` is 2 unless
            overridden in the constructor.
        """
        fprint("Starting t-SNE explanation")
        if embedding_file is not None:
            fprint(f"Loading embeddings from {embedding_file}")
            model_embeddings = self.model.load_embeddings(embedding_file)
        else:
            fprint("Encoding sequences")
            model_embeddings = self.model.batch_encode(sequences, **kwargs)

        # NOTE(review): the model may return a torch tensor (possibly on GPU
        # or tracking gradients), which sklearn cannot convert on its own.
        # Move such objects to a host NumPy array first; NumPy arrays and
        # plain lists have no ``detach`` and pass through unchanged.
        if hasattr(model_embeddings, "detach") and hasattr(model_embeddings, "cpu"):
            model_embeddings = model_embeddings.detach().cpu().numpy()

        fprint("Fitting t-SNE")
        tsne_embeddings = self.tsne.fit_transform(model_embeddings)
        fprint("t-SNE explanation completed")
        return tsne_embeddings