# -*- coding: utf-8 -*-
# file: __init__.py
# time: 2025-07-10 13:57
# author: Shasha Zhou <sz484@exeter.ac.uk>
# Copyright (C) 2020-2025. All Rights Reserved.
from ...abc.abstract_explainer import AbstractExplainer
from ..shared_methods.squid_explainer import SQUIDExplainer
import plotly.graph_objects as go
import numpy as np
from ...misc.utils import fprint
EXPLAINER_REGISTRY = {
"squid": SQUIDExplainer,
}
[docs]
def get_explainer(name: str) -> AbstractExplainer:
"""Retrieves an explainer class from the registry by its name.
Args:
name (str): The name of the explainer method to retrieve.
Returns:
AbstractExplainer: The explainer class corresponding to the given name.
"""
fprint(f"Getting explainer with method: {name}")
return EXPLAINER_REGISTRY[name]
[docs]
class EpistasisExplainer(AbstractExplainer):
"""Explains and visualizes pairwise epistatic interactions in a sequence.
This explainer uses an underlying method (like SQUID) to fit a pairwise
surrogate model to the target model's predictions. It then extracts the
second-order interaction terms (epistasis) and visualizes them as an
interactive heatmap, showing the effect of combining mutations at two
different positions.
Attributes:
ExplainerClass (AbstractExplainer): The underlying explainer class (e.g., SQUIDExplainer).
explainer (AbstractExplainer): An instance of the explainer, configured for pairwise analysis.
matrix (np.ndarray): The most recently computed epistatic interaction matrix.
"""
def __init__(self, model, method: str = "squid"):
"""Initializes the EpistasisExplainer.
Args:
model (Any): The model to explain, which must be compatible with the
chosen underlying explainer method.
method (str, optional): The method to use for explaining epistasis.
Defaults to "squid".
"""
fprint(f"Initializing EpistasisExplainer with method: {method}")
super().__init__(model)
self.ExplainerClass = get_explainer(method)
self.explainer = self.ExplainerClass(model, gpmap="pairwise")
fprint("EpistasisExplainer initialized successfully")
[docs]
def explain(self, sequence, **kwargs):
"""Computes the pairwise interaction matrix for a given sequence.
This method calls the underlying SQUID explainer to generate the epistasis
matrix (`theta_lclc`), which quantifies the interaction effect between
every pair of possible mutations.
Args:
sequence (str): The input sequence to explain.
**kwargs: Additional keyword arguments passed to the underlying
explainer's `explain` method.
Returns:
np.ndarray: A 4D numpy array of shape `(L, A, L, A)`, where `L` is the
sequence length and `A` is the alphabet size. `matrix[l1, c1, l2, c2]`
represents the interaction effect between character `c1` at
position `l1` and character `c2` at position `l2`.
"""
fprint(f"Generating explanations for sequence: {sequence}")
matrix = self.explainer.explain(sequence, **kwargs)
self.matrix = matrix
return matrix
[docs]
def visualize_heatmap(self, matrix, sequence: str, save_path=None, **kwargs):
"""Visualizes the epistatic interaction matrix as an interactive heatmap.
This method creates a detailed heatmap where each cell represents the
interaction strength between two specific mutations. The heatmap is
lower-triangular to avoid redundancy.
Args:
matrix (np.ndarray): The 4D epistasis matrix from the `explain` method.
sequence (str): The original sequence, used for context.
save_path (str, optional): Path to save the interactive HTML plot. Defaults to None.
**kwargs: Not currently used, but included for future extensibility.
"""
fprint("Visualizing the heatmap...")
L, A, _, _ = matrix.shape
matrix = matrix.reshape(L * A, L * A)
mask = np.tri(L * A, L * A, k=0, dtype=bool)
masked_matrix = np.where(mask, matrix, np.nan)
labels = [f"{i}:{base}" for i in range(L) for base in self.explainer.alphabet]
fig = go.Figure(
data=go.Heatmap(
z=masked_matrix,
x=labels,
y=labels,
colorscale="RdBu_r",
zmid=0,
zmin=-np.max(np.abs(matrix)),
zmax=np.max(np.abs(matrix)),
colorbar=dict(thickness=10, len=0.5, yanchor="middle", y=0.5),
hovertemplate="From %{y}<br>To %{x}<br>Value: %{z:.4f}<extra></extra>",
)
)
shapes = []
stair_path = ""
for i in range(L * A):
x = i - 0.5
y = i - 0.5
if i == 0:
stair_path += f"M {x},{y} "
stair_path += f"L {x + 1},{y} "
if i < L * A - 1:
stair_path += f"L {x + 1},{y + 1} "
stair_path += f"L {(L*A) - 0.5},{(L*A) - 0.5}"
shapes.append(
dict(
type="path",
path=stair_path,
line=dict(color="black", width=0.4),
fillcolor="rgba(0,0,0,0)",
layer="above",
)
)
for i in range(0, L + 1):
grid_pos = i * A - 0.5
shapes.append(
dict(
type="line",
x0=grid_pos,
y0=grid_pos,
x1=grid_pos,
y1=L * A - 0.5,
line=dict(color="black", width=0.4),
)
)
shapes.append(
dict(
type="line",
x0=-0.5,
y0=grid_pos,
x1=grid_pos,
y1=grid_pos,
line=dict(color="black", width=0.4),
)
)
fig.update_layout(
title=None,
xaxis=dict(
scaleanchor="y",
scaleratio=1,
tickangle=45,
showgrid=False,
showticklabels=False,
),
yaxis=dict(
scaleanchor="x",
scaleratio=1,
autorange="reversed",
showgrid=False,
showticklabels=False,
),
shapes=shapes,
width=500,
height=500,
margin=dict(t=0, b=0, l=20, r=50),
plot_bgcolor="white",
)
fig.show()
def __call__(self, sequence, save_path=None, **kwargs):
"""A convenience method to explain and visualize in one step.
Args:
sequence (str): The sequence to explain.
save_path (str, optional): The path to save the figure. Defaults to None.
**kwargs: Additional keyword arguments passed to the `explain` method.
"""
fprint(f"Generating explanations for sequence: {sequence}")
matrix = self.explainer.explain(sequence, gpmap="additive", **kwargs)
self.visualize_heatmap(matrix, sequence, save_path=save_path, **kwargs)