# -*- coding: utf-8 -*-
# file: __init__.py
# time: 2025-06-16 21:06
# author: Shasha Zhou <sz484@exeter.ac.uk>
# Copyright (C) 2020-2025. All Rights Reserved.
from ...abc.abstract_explainer import AbstractExplainer
from ...model.embedding.model import OmniModelForEmbedding
from ..shared_methods.tsne_explainer import TSNEExplainer
import pandas as pd
import plotly.express as px
from ...misc.utils import fprint
EXPLAINER_REGISTRY = {
"tsne": TSNEExplainer,
}
[docs]
def get_explainer(name: str) -> AbstractExplainer:
"""Retrieves an explainer class from the registry by its name.
This function acts as a factory, allowing for dynamic selection of the
dimensionality reduction algorithm to be used.
Args:
name (str): The name of the explainer method to retrieve (e.g., "tsne").
Returns:
AbstractExplainer: The explainer class corresponding to the given name.
"""
fprint(f"Getting explainer with method: {name}")
return EXPLAINER_REGISTRY[name]
[docs]
class Visualization2DExplainer(AbstractExplainer):
"""A high-level explainer for creating 2D visualizations of sequence embeddings.
This class provides a convenient wrapper around various dimensionality reduction
algorithms (like t-SNE) to generate and visualize 2D representations of
high-dimensional sequence embeddings. It simplifies the process of creating
interactive scatter plots to explore the structure of the embedding space.
Attributes:
model (OmniModelForEmbedding): The model used for generating embeddings.
ExplainerClass (AbstractExplainer): The specific dimensionality reduction class being used (e.g., TSNEExplainer).
explainer (AbstractExplainer): An instance of the `ExplainerClass`.
"""
def __init__(self, model, method: str = "tsne"):
"""Initializes the Visualization2DExplainer.
Args:
model (OmniModelForEmbedding): The model to explain. It must be an instance of
`OmniModelForEmbedding` as it needs the
`batch_encode` method.
method (str, optional): The dimensionality reduction method to use.
Currently, only "tsne" is supported. Defaults to "tsne".
"""
fprint(f"Initializing Visualization2DExplainer with method: {method}")
super().__init__(model)
assert isinstance(
model, OmniModelForEmbedding
), "Model must be an instance of OmniModelForEmbedding"
self.ExplainerClass = get_explainer(method)
self.explainer = self.ExplainerClass(model)
fprint("Visualization2DExplainer initialized successfully")
[docs]
def explain(self, sequences, labels=None, **kwargs):
"""Generates the 2D embeddings for the input sequences.
This method acts as a wrapper, calling the `explain` method of the underlying
dimensionality reduction explainer (e.g., TSNEExplainer).
Args:
sequences (List[str]): The list of input sequences to explain.
labels (Optional[List[Any]], optional): A list of corresponding labels.
Not used in computation but passed down.
Defaults to None.
**kwargs: Additional keyword arguments to be passed to the underlying
explainer's `explain` method (e.g., `perplexity` for t-SNE).
Returns:
np.ndarray: An array of shape `(n_sequences, 2)` containing the
generated 2D coordinates.
"""
fprint(f"Generating explanations for {len(sequences)} sequences")
self.sequences = sequences
self.labels = labels
embeddings = self.explainer.explain(sequences, labels, **kwargs)
fprint(f"Generated embeddings with shape: {embeddings.shape}")
return embeddings
[docs]
def visualize(
self,
embeddings,
sequences,
labels=None,
width=800,
height=600,
title="2D Visualization of Sequence Embeddings",
point_size=8,
point_opacity=0.8,
wrap_width=50,
color_palette=None,
save_path=None,
**kwargs,
):
"""Creates an interactive 2D scatter plot of the embeddings.
This method uses Plotly Express to generate a rich, interactive visualization
where each point represents a sequence. Hovering over a point reveals its
sequence and label.
Args:
embeddings (np.ndarray): The 2D coordinates to visualize, shape `(n, 2)`.
sequences (List[str]): The original sequences, used for hover-over tooltips.
labels (Optional[List[Any]], optional): Labels for coloring points. If None, all points
are assigned a single 'Unlabeled' category. Defaults to None.
width (int, optional): The width of the figure in pixels. Defaults to 800.
height (int, optional): The height of the figure in pixels. Defaults to 600.
title (str, optional): The title of the plot. Defaults to "2D Visualization of Sequence Embeddings".
point_size (int, optional): The size of the scatter plot points. Defaults to 8.
point_opacity (float, optional): The opacity of the points. Defaults to 0.8.
wrap_width (int, optional): The maximum width for sequence text in the hover tooltip
before it's truncated. Defaults to 50.
color_palette (Optional[List[str]], optional): A list of CSS colors to use. If None, a default
Plotly palette is used. Defaults to None.
save_path (Optional[str], optional): The file path to save the interactive plot as an HTML file.
If None, the plot is not saved. Defaults to None.
**kwargs: Not currently used, but included for future extensibility.
Returns:
plotly.graph_objs._figure.Figure: The Plotly scatter plot figure object, which can be
further customized or displayed.
"""
fprint("Starting visualization process")
fprint(f"Processing {len(sequences)} sequences for visualization")
# Truncate long sequences
wrapped_sequences = [
seq[:wrap_width] + "..." if len(seq) > wrap_width else seq
for seq in sequences
]
# Handle labels
if labels is None:
labels_str = ["Unlabeled"] * len(sequences)
fprint("No labels provided, using 'Unlabeled' for all points")
else:
labels_str = [str(label) for label in labels]
fprint(f"Processing {len(set(labels_str))} unique labels")
# Unique labels and colors
unique_labels = sorted(set(labels_str))
if color_palette is None:
color_palette = (
px.colors.qualitative.Set3
+ px.colors.qualitative.Pastel
+ px.colors.qualitative.Bold
)
color_discrete_map = {
label: color_palette[i % len(color_palette)]
for i, label in enumerate(unique_labels)
}
# DataFrame for plotting
df = pd.DataFrame(
{
"x": embeddings[:, 0],
"y": embeddings[:, 1],
"label": labels_str,
"sequence": wrapped_sequences,
}
)
# Create scatter plot
fig = px.scatter(
df,
x="x",
y="y",
color="label",
hover_data={"sequence": True, "label": True},
color_discrete_map=color_discrete_map,
labels={"x": "Component 1", "y": "Component 2"},
)
# Style
fig.update_traces(
marker=dict(size=point_size, opacity=point_opacity, line=dict(width=0.5))
)
fig.update_layout(
width=width,
height=height,
title={
"text": title,
"x": 0.5,
"xanchor": "center",
"yanchor": "top",
},
legend_title_text="Label",
legend=dict(bordercolor="Black", borderwidth=0.5, itemsizing="constant"),
plot_bgcolor="rgba(245, 245, 245, 1)",
paper_bgcolor="rgba(255,255,255,1)",
)
if save_path:
fprint(f"Saving visualization to: {save_path}")
fig.write_html(save_path)
fprint("Visualization completed successfully")
return fig
def __call__(self, sequences, labels=None, **kwargs):
"""A convenience method to generate and visualize the explanation in one step.
This method chains the `explain` and `visualize` calls, providing a simple
one-line interface for the most common use case.
Args:
sequences (List[str]): The list of input sequences.
labels (Optional[List[Any]], optional): The corresponding labels for the sequences.
Defaults to None.
**kwargs: Additional keyword arguments passed to both the `explain` and
`visualize` methods.
Returns:
plotly.graph_objs._figure.Figure: The final, interactive Plotly figure object.
"""
embeddings = self.explainer.explain(sequences, labels, **kwargs)
fig = self.visualize(embeddings, sequences, labels, **kwargs)
return fig