Explainability Module
Epistasis
- class omnigenbench.src.explainability.epistasis.explainer.EpistasisExplainer(model, method: str = 'squid')
Bases: AbstractExplainer
Explains and visualizes pairwise epistatic interactions in a sequence.
This explainer uses an underlying method (like SQUID) to fit a pairwise surrogate model to the target model’s predictions. It then extracts the second-order interaction terms (epistasis) and visualizes them as an interactive heatmap, showing the effect of combining mutations at two different positions.
- Variables:
ExplainerClass (AbstractExplainer) – The underlying explainer class (e.g., SQUIDExplainer).
explainer (AbstractExplainer) – An instance of the explainer, configured for pairwise analysis.
matrix (np.ndarray) – The most recently computed epistatic interaction matrix.
- explain(sequence, **kwargs)
Computes the pairwise interaction matrix for a given sequence.
This method calls the underlying SQUID explainer to generate the epistasis matrix (theta_lclc), which quantifies the interaction effect between every pair of possible mutations.
- Parameters:
sequence (str) – The input sequence to explain.
**kwargs – Additional keyword arguments passed to the underlying explainer’s explain method.
- Returns:
np.ndarray – A 4D NumPy array of shape (L, A, L, A), where L is the sequence length and A is the alphabet size. matrix[l1, c1, l2, c2] represents the interaction effect between character c1 at position l1 and character c2 at position l2.
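To illustrate the indexing convention above, the following sketch ranks the strongest pairwise interactions in an array of shape (L, A, L, A). The array is synthetic, and the helper name `top_interactions` and the alphabet ordering "ACGT" are assumptions made for this example; they are not part of the library.

```python
import numpy as np

def top_interactions(matrix, alphabet, k=3):
    """Return the k strongest interactions as (l1, c1, l2, c2, value) tuples.

    Only pairs with l1 > l2 are scanned, so each pair of positions is
    visited once. `alphabet` maps the A axis to characters (assumed order).
    """
    L, A = matrix.shape[0], matrix.shape[1]
    results = []
    for l1 in range(L):
        for l2 in range(l1):
            block = matrix[l1, :, l2, :]  # (A, A) block for this position pair
            for c1 in range(A):
                for c2 in range(A):
                    results.append(
                        (l1, alphabet[c1], l2, alphabet[c2], float(block[c1, c2]))
                    )
    # Rank by absolute interaction strength, strongest first
    results.sort(key=lambda r: abs(r[4]), reverse=True)
    return results[:k]

# Synthetic stand-in for the array returned by explain(): small noise
# plus one planted interaction between 'A' at position 3 and 'G' at position 1.
rng = np.random.default_rng(0)
matrix = rng.normal(scale=0.01, size=(5, 4, 5, 4))
matrix[3, 0, 1, 2] = 2.5

top = top_interactions(matrix, "ACGT", k=1)
print(top)
```

With a real model, the synthetic array would be replaced by the result of `explain(sequence)`, and the alphabet ordering would need to match whatever the underlying explainer uses.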
- visualize_heatmap(matrix, sequence: str, save_path=None, **kwargs)
Visualizes the epistatic interaction matrix as an interactive heatmap.
This method creates a detailed heatmap where each cell represents the interaction strength between two specific mutations. The heatmap is lower-triangular to avoid redundancy.
- Parameters:
matrix (np.ndarray) – The 4D epistasis matrix from the explain method.
sequence (str) – The original sequence, used for context.
save_path (str, optional) – Path to save the interactive HTML plot. Defaults to None.
**kwargs – Not currently used, but included for future extensibility.
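One possible way to obtain a lower-triangular layout from the 4D matrix is sketched below. The source does not show how the library flattens the array internally, so the function name `flatten_lower_triangle` and the row/column ordering are assumptions for illustration only.

```python
import numpy as np

def flatten_lower_triangle(matrix):
    """Flatten an (L, A, L, A) array to (L*A, L*A) and keep only l1 > l2.

    Row index is l1 * A + c1 and column index is l2 * A + c2. Entries for
    which the row position is not strictly greater than the column position
    are set to NaN so a plotting library can leave them blank.
    """
    L, A = matrix.shape[0], matrix.shape[1]
    flat = matrix.reshape(L * A, L * A).astype(float)
    pos = np.arange(L * A) // A            # sequence position of each row/column
    keep = pos[:, None] > pos[None, :]     # True only where l1 > l2
    return np.where(keep, flat, np.nan)

# Tiny synthetic example: L = 2 positions, A = 2 characters
demo = np.arange(16, dtype=float).reshape(2, 2, 2, 2)
flat = flatten_lower_triangle(demo)
print(flat.shape)
```

Masking by position rather than by raw matrix index also blanks the within-position blocks, which carry no pairwise information between distinct positions.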
- omnigenbench.src.explainability.epistasis.explainer.get_explainer(name: str) → AbstractExplainer
Retrieves an explainer class from the registry by its name.
- Parameters:
name (str) – The name of the explainer method to retrieve.
- Returns:
AbstractExplainer – The explainer class corresponding to the given name.
Sequence Logo
- class omnigenbench.src.explainability.sequence_logo.explainer.SequenceLogoExplainer(model, method: str = 'squid')
Bases: AbstractExplainer
A high-level wrapper for generating and visualizing model explanations.
This class provides a simple interface to use various underlying attribution methods (like ‘squid’) to explain a model’s predictions on a given sequence. It can generate attribution scores and visualize them as either a sequence logo or an interactive heatmap.
- Variables:
ExplainerClass – The underlying explainer class retrieved from the registry.
explainer – An instance of the ExplainerClass used to compute attributions.
matrix – Stores the most recently computed attribution matrix.
Example
>>> from omnigenbench import OmniModelForPrediction
>>> from omnigenbench.explainers import SequenceLogoExplainer
>>> # Load a model trained for a specific task
>>> model = OmniModelForPrediction.from_pretrained("anonymous8/OmniGenome-186M-Promoter")
>>> # Initialize the explainer
>>> explainer = SequenceLogoExplainer(model)
>>> sequence = "AGCGTTAGAC"
>>> # Generate and visualize the explanation as a sequence logo
>>> explainer(sequence, visualize_type="logo")
- explain(sequence, **kwargs)
Generates an attribution matrix for a given sequence.
This method uses the underlying explainer (e.g., ‘squid’) to compute the attribution scores for each character at each position in the input sequence.
- Parameters:
sequence (str) – The input DNA or protein sequence to explain.
**kwargs – Additional keyword arguments to be passed to the underlying explainer’s explain method.
- Returns:
np.ndarray – A matrix of attribution scores, typically with a shape of (sequence_length, alphabet_size).
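As a small illustration of working with a matrix of shape (sequence_length, alphabet_size), the sketch below picks out the score of the character actually present at each position. The helper name `observed_scores` and the column ordering "ACGT" are assumptions for this example; the ordering used by the real explainer should be checked before relying on it.

```python
import numpy as np

def observed_scores(matrix, sequence, alphabet="ACGT"):
    """Return the attribution score of the observed character at each position.

    `alphabet` gives the assumed meaning of the columns of `matrix`.
    """
    cols = [alphabet.index(ch) for ch in sequence]   # column index per position
    rows = np.arange(len(sequence))                  # one row per position
    return matrix[rows, cols]

# Synthetic attribution matrix for a length-4 sequence
scores = np.arange(16).reshape(4, 4)
per_position = observed_scores(scores, "AGCG")
print(per_position)
```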
- visualize_heatmap(matrix, sequence: str, save_path=None, **kwargs)
Visualizes an attribution matrix as an interactive heatmap.
This method uses the plotly library to create a heatmap where the color of each cell represents the attribution score for a specific character at a specific position. The plot is interactive, allowing for hovering to see exact values.
- Parameters:
matrix (np.ndarray) – The attribution matrix to visualize, with shape (sequence_length, alphabet_size).
sequence (str) – The input sequence, used for labeling the x-axis.
save_path (str, optional) – The file path to save the generated plot. Note: saving interactive plots may require additional libraries like ‘kaleido’.
**kwargs – Additional keyword arguments for customizing the plot, including:
- title (str): The title of the plot.
- width (int): The width of the plot in pixels.
- height (int): The height of the plot in pixels.
- xaxis_title (str): The title for the x-axis.
- yaxis_title (str): The title for the y-axis.
- visualize_logo(matrix, save_path=None, **kwargs)
Visualizes an attribution matrix as a sequence logo.
This method uses the logomaker library to create a sequence logo. The height of each character visually represents its attribution score at that position. The plot can be customized using various keyword arguments.
- Parameters:
matrix (np.ndarray) – The attribution matrix to render as a sequence logo.
save_path (str, optional) – The file path to save the generated figure. Defaults to None.
**kwargs – Additional keyword arguments, mainly used for the following:
- figsize: The size of the figure.
- ylabel: The label of the y-axis.
- xlabel: The label of the x-axis.
- title: The title of the figure.
- color_scheme: The color scheme of the logo.
- show_spines (bool): Whether to show the spines of the logo.
- spines (list): The spines to show.
- show_ticks (bool): Whether to show x/y ticks.
- colors (dict): The colors of the logo.
- omnigenbench.src.explainability.sequence_logo.explainer.get_explainer(name: str) → AbstractExplainer
Retrieves an explainer class from the registry by name.
This function acts as a factory to access different explanation methods that have been registered in the EXPLAINER_REGISTRY.
- Parameters:
name (str) – The name of the explainer method to retrieve.
- Returns:
AbstractExplainer – The explainer class corresponding to the given name.
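The registry-backed factory described above can be pictured with a generic sketch like the one below. This shows the general pattern only: the decorator `register_explainer`, the class `DummySquidExplainer`, and the `ValueError` raised for unknown names are all hypothetical, and the dictionary here merely stands in for the library's `EXPLAINER_REGISTRY`, whose actual contents and error behavior are not documented in this section.

```python
# Hypothetical stand-in for the library's registry: maps method names to classes.
EXPLAINER_REGISTRY = {}

def register_explainer(name):
    """Class decorator that stores the class in the registry under `name`."""
    def decorator(cls):
        EXPLAINER_REGISTRY[name] = cls
        return cls
    return decorator

@register_explainer("squid")
class DummySquidExplainer:
    """Placeholder class used only to populate the example registry."""
    def __init__(self, model):
        self.model = model

def get_explainer(name):
    """Look up an explainer class by name (returns the class, not an instance)."""
    try:
        return EXPLAINER_REGISTRY[name]
    except KeyError:
        # Assumed behavior for unknown names; the real function may differ.
        known = ", ".join(sorted(EXPLAINER_REGISTRY))
        raise ValueError(f"Unknown explainer '{name}'. Known: {known}") from None

ExplainerClass = get_explainer("squid")
explainer = ExplainerClass(model=None)
print(type(explainer).__name__)
```

Returning the class rather than an instance matches the documented return value and lets the wrapper decide how to construct the explainer.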
Visualization
- class omnigenbench.src.explainability.visualization_2d.explainer.Visualization2DExplainer(model, method: str = 'tsne')
Bases: AbstractExplainer
A high-level explainer for creating 2D visualizations of sequence embeddings.
This class provides a convenient wrapper around various dimensionality reduction algorithms (like t-SNE) to generate and visualize 2D representations of high-dimensional sequence embeddings. It simplifies the process of creating interactive scatter plots to explore the structure of the embedding space.
- Variables:
model (OmniModelForEmbedding) – The model used for generating embeddings.
ExplainerClass (AbstractExplainer) – The specific dimensionality reduction class being used (e.g., TSNEExplainer).
explainer (AbstractExplainer) – An instance of the ExplainerClass.
- explain(sequences, labels=None, **kwargs)
Generates the 2D embeddings for the input sequences.
This method acts as a wrapper, calling the explain method of the underlying dimensionality reduction explainer (e.g., TSNEExplainer).
- Parameters:
sequences (List[str]) – The list of input sequences to explain.
labels (Optional[List[Any]], optional) – A list of corresponding labels. Not used in computation but passed down. Defaults to None.
**kwargs – Additional keyword arguments to be passed to the underlying explainer’s explain method (e.g., perplexity for t-SNE).
- Returns:
np.ndarray – An array of shape (n_sequences, 2) containing the generated 2D coordinates.
- visualize(embeddings, sequences, labels=None, width=800, height=600, title='2D Visualization of Sequence Embeddings', point_size=8, point_opacity=0.8, wrap_width=50, color_palette=None, save_path=None, **kwargs)
Creates an interactive 2D scatter plot of the embeddings.
This method uses Plotly Express to generate a rich, interactive visualization where each point represents a sequence. Hovering over a point reveals its sequence and label.
- Parameters:
embeddings (np.ndarray) – The 2D coordinates to visualize, shape (n, 2).
sequences (List[str]) – The original sequences, used for hover-over tooltips.
labels (Optional[List[Any]], optional) – Labels for coloring points. If None, all points are assigned a single ‘Unlabeled’ category. Defaults to None.
width (int, optional) – The width of the figure in pixels. Defaults to 800.
height (int, optional) – The height of the figure in pixels. Defaults to 600.
title (str, optional) – The title of the plot. Defaults to “2D Visualization of Sequence Embeddings”.
point_size (int, optional) – The size of the scatter plot points. Defaults to 8.
point_opacity (float, optional) – The opacity of the points. Defaults to 0.8.
wrap_width (int, optional) – The maximum width for sequence text in the hover tooltip before it’s truncated. Defaults to 50.
color_palette (Optional[List[str]], optional) – A list of CSS colors to use. If None, a default Plotly palette is used. Defaults to None.
save_path (Optional[str], optional) – The file path to save the interactive plot as an HTML file. If None, the plot is not saved. Defaults to None.
**kwargs – Not currently used, but included for future extensibility.
- Returns:
plotly.graph_objs._figure.Figure – The Plotly scatter plot figure object, which can be further customized or displayed.
- omnigenbench.src.explainability.visualization_2d.explainer.get_explainer(name: str) → AbstractExplainer
Retrieves an explainer class from the registry by its name.
This function acts as a factory, allowing for dynamic selection of the dimensionality reduction algorithm to be used.
- Parameters:
name (str) – The name of the explainer method to retrieve (e.g., “tsne”).
- Returns:
AbstractExplainer – The explainer class corresponding to the given name.