Explainability Module

Epistasis

class omnigenbench.src.explainability.epistasis.explainer.EpistasisExplainer(model, method: str = 'squid')

Bases: AbstractExplainer

Explains and visualizes pairwise epistatic interactions in a sequence.

This explainer uses an underlying method (like SQUID) to fit a pairwise surrogate model to the target model’s predictions. It then extracts the second-order interaction terms (epistasis) and visualizes them as an interactive heatmap, showing the effect of combining mutations at two different positions.

Variables:
  • ExplainerClass (AbstractExplainer) – The underlying explainer class (e.g., SQUIDExplainer).

  • explainer (AbstractExplainer) – An instance of the explainer, configured for pairwise analysis.

  • matrix (np.ndarray) – The most recently computed epistatic interaction matrix.

explain(sequence, **kwargs)

Computes the pairwise interaction matrix for a given sequence.

This method calls the underlying SQUID explainer to generate the epistasis matrix (theta_lclc), which quantifies the interaction effect between every pair of possible mutations.

Parameters:
  • sequence (str) – The input sequence to explain.

  • **kwargs – Additional keyword arguments passed to the underlying explainer’s explain method.

Returns:

np.ndarray – A 4D NumPy array of shape (L, A, L, A), where L is the sequence length and A is the alphabet size. matrix[l1, c1, l2, c2] is the interaction effect between character c1 at position l1 and character c2 at position l2.
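A minimal sketch of looking up entries in the returned array. It assumes the alphabet order used to build the matrix is known (here a hypothetical ['A', 'C', 'G', 'T']) and uses a random stand-in instead of a real explain() result; the helpers interaction and strongest_pair are illustrative and not part of the API:

```python
import numpy as np

# Hypothetical alphabet order; the real order comes from the underlying explainer.
alphabet = ["A", "C", "G", "T"]
char_index = {c: i for i, c in enumerate(alphabet)}

L, A = 6, len(alphabet)
rng = np.random.default_rng(0)
# Random stand-in for the (L, A, L, A) array returned by explain().
matrix = rng.normal(size=(L, A, L, A))


def interaction(matrix, l1, c1, l2, c2):
    """Effect of character c1 at position l1 combined with c2 at position l2."""
    return matrix[l1, char_index[c1], l2, char_index[c2]]


def strongest_pair(matrix):
    """Largest-magnitude interaction between two different positions."""
    scores = np.abs(matrix)  # np.abs returns a new array, so matrix is untouched
    pos = np.arange(matrix.shape[0])
    scores[pos, :, pos, :] = 0.0  # ignore same-position entries
    l1, c1, l2, c2 = np.unravel_index(np.argmax(scores), scores.shape)
    return int(l1), alphabet[c1], int(l2), alphabet[c2]


print(interaction(matrix, 1, "G", 4, "T"))
print(strongest_pair(matrix))
```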

visualize_heatmap(matrix, sequence: str, save_path=None, **kwargs)

Visualizes the epistatic interaction matrix as an interactive heatmap.

This method creates a detailed heatmap where each cell represents the interaction strength between two specific mutations. The heatmap is lower-triangular to avoid redundancy.

Parameters:
  • matrix (np.ndarray) – The 4D epistasis matrix from the explain method.

  • sequence (str) – The original sequence, used for context.

  • save_path (str, optional) – Path to save the interactive HTML plot. Defaults to None.

  • **kwargs – Not currently used, but included for future extensibility.
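One way to obtain a lower-triangular layout is to flatten the (L, A, L, A) array into an (L·A) × (L·A) grid whose row and column index r = l·A + c identifies a single mutation, then mask the diagonal and everything above it. The sketch below assumes that layout and a hypothetical "wild-type, position, mutant" tick-label convention; it is not the method's actual code. The masked grid could then be handed to a heatmap function:

```python
import numpy as np


def lower_triangular_grid(matrix):
    """Flatten an (L, A, L, A) epistasis array to a masked (L*A, L*A) grid.

    Row/column index r = l * A + c identifies character c at position l.
    Entries on or above the diagonal become NaN, so each unordered pair of
    mutations appears only once.
    """
    L, A = matrix.shape[:2]
    grid = matrix.reshape(L * A, L * A).astype(float)  # astype returns a copy
    upper = np.triu(np.ones_like(grid, dtype=bool))  # diagonal and above
    grid[upper] = np.nan
    return grid


def axis_labels(sequence, alphabet):
    """Hypothetical tick labels such as 'A0C' (wild-type A at position 0 -> C)."""
    return [f"{sequence[l]}{l}{c}" for l in range(len(sequence)) for c in alphabet]


matrix = np.arange(3 * 2 * 3 * 2, dtype=float).reshape(3, 2, 3, 2)
grid = lower_triangular_grid(matrix)
print(grid.shape)  # (6, 6)
```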

omnigenbench.src.explainability.epistasis.explainer.get_explainer(name: str) → AbstractExplainer

Retrieves an explainer class from the registry by its name.

Parameters:

name (str) – The name of the explainer method to retrieve.

Returns:

AbstractExplainer – The explainer class corresponding to the given name.
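A registry of this kind is often a dictionary from names to classes. The sketch below assumes such a dictionary; _REGISTRY, register, and the placeholder classes are all hypothetical, and the real registry contents may differ. Note that the lookup returns a class, which the caller then instantiates:

```python
from typing import Dict, Type


class AbstractExplainer:
    """Placeholder standing in for the library's AbstractExplainer."""

    def __init__(self, model):
        self.model = model


_REGISTRY: Dict[str, Type[AbstractExplainer]] = {}


def register(name: str):
    """Class decorator that records an explainer class under `name`."""
    def decorator(cls: Type[AbstractExplainer]) -> Type[AbstractExplainer]:
        _REGISTRY[name] = cls
        return cls
    return decorator


@register("squid")
class SQUIDLikeExplainer(AbstractExplainer):
    """Hypothetical stand-in for SQUIDExplainer."""


def get_explainer(name: str) -> Type[AbstractExplainer]:
    """Return the registered class (not an instance) for `name`."""
    try:
        return _REGISTRY[name]
    except KeyError:
        raise ValueError(
            f"Unknown explainer {name!r}; available: {sorted(_REGISTRY)}"
        ) from None


ExplainerClass = get_explainer("squid")
explainer = ExplainerClass(model=None)  # instantiate the looked-up class
```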

Shared Methods

class omnigenbench.src.explainability.shared_methods.squid_explainer.SQUIDAdditiveGPMap(L: int, A: int, reg_strength: float = 0.0)

Bases: Module

Additive genotype‑phenotype map: φ = θ_0 + Σ_{l,c} θ_{l,c} x_{l,c}.

forward(x: Tensor) → Tensor

x is one‑hot with shape (N, L, A). Outputs latent φ with shape (N, 1).
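The additive map is a single contraction over positions and characters. Below is a NumPy sketch of the forward pass, standing in for the PyTorch module, with random parameters and a hypothetical one_hot helper:

```python
import numpy as np


def one_hot(seqs, alphabet):
    """Encode equal-length strings as an (N, L, A) one-hot array."""
    index = {c: i for i, c in enumerate(alphabet)}
    x = np.zeros((len(seqs), len(seqs[0]), len(alphabet)))
    for n, s in enumerate(seqs):
        for l, c in enumerate(s):
            x[n, l, index[c]] = 1.0
    return x


def additive_phi(x, theta_0, theta_lc):
    """phi = theta_0 + sum_{l,c} theta_lc[l, c] * x[:, l, c], shaped (N, 1)."""
    return (theta_0 + np.einsum("nla,la->n", x, theta_lc))[:, None]


alphabet = ["A", "C", "G", "T"]
rng = np.random.default_rng(0)
theta_0 = 0.5
theta_lc = rng.normal(size=(5, len(alphabet)))
x = one_hot(["ACGTA", "TTGCA"], alphabet)
phi = additive_phi(x, theta_0, theta_lc)
print(phi.shape)  # (2, 1)
```

Because x is one-hot, φ for each sequence is θ_0 plus the θ_{l,c} of the character actually present at each position.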

l2_regularizer() → Tensor

class omnigenbench.src.explainability.shared_methods.squid_explainer.SQUIDBaseMutagenesis

Bases: object

Base class for generating in silico MAVE (multiplexed assay of variant effect) data for a given sequence.

class omnigenbench.src.explainability.shared_methods.squid_explainer.SQUIDCombinatorialMutagenesis(alphabet: List[str], max_order: int = 1, mut_window: Tuple[int, int] | None = None, seed: int | None = None)

Bases: SQUIDBaseMutagenesis

Generates in silico MAVE data for a given sequence using combinatorial mutagenesis.
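Combinatorial mutagenesis can be sketched as exhaustive enumeration of every variant with 1 to max_order substitutions inside mut_window. This sketch assumes that reading of the parameters, including max_order = -1 meaning "all orders" (as documented for SQUIDExplainer.explain) and mut_window being a half-open (start, end) range; it is not the class's actual implementation:

```python
from itertools import combinations, product
from typing import Iterator, List, Optional, Tuple


def combinatorial_variants(
    sequence: str,
    alphabet: List[str],
    max_order: int = 1,
    mut_window: Optional[Tuple[int, int]] = None,
) -> Iterator[str]:
    """Yield every variant with 1..max_order substitutions inside mut_window."""
    start, end = mut_window if mut_window is not None else (0, len(sequence))
    positions = range(start, end)
    top = len(positions) if max_order == -1 else max_order
    for order in range(1, top + 1):
        for sites in combinations(positions, order):
            # For each chosen site, every character except the original one.
            options = [[c for c in alphabet if c != sequence[p]] for p in sites]
            for chars in product(*options):
                seq = list(sequence)
                for p, c in zip(sites, chars):
                    seq[p] = c
                yield "".join(seq)


variants = list(combinatorial_variants("ACG", ["A", "C", "G", "T"], max_order=2))
print(len(variants))  # 36 = 3*3 single mutants + 3*9 double mutants
```

For a window of width W the number of variants is Σ_k C(W, k)·(A − 1)^k, which grows quickly with max_order.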

class omnigenbench.src.explainability.shared_methods.squid_explainer.SQUIDExplainer(model, gpmap: str = 'additive', **kwargs)

Bases: AbstractExplainer

Explains model predictions using the SQUID method.

SQUID (Surrogate Quantitative Interpretability for Deepnets) is a method that uses in silico mutagenesis to generate a dataset, which is then used to train a simpler, interpretable surrogate model. Additive (first-order) or pairwise (second-order) feature attributions can then be extracted from this surrogate model.

Variables:
  • model – The target deep learning model to explain.

  • gpmap (str) – The type of genotype-phenotype map for the surrogate model, either ‘additive’ or ‘pairwise’.

  • token_to_id (Dict[str, int]) – A mapping from sequence characters to integer IDs.

  • alphabet (List[str]) – The list of unique characters in the input sequence.

  • num_tokens (int) – The size of the alphabet.

Reference:

Seitz, E.E., McCandlish, D.M., Kinney, J.B., and Koo, P.K. Interpreting cis-regulatory mechanisms from genomic deep neural networks using surrogate models. Nat Mach Intell (2024). https://doi.org/10.1038/s42256-024-00851-5

explain(sequence: str, mut_type: str = 'random', mut_rate: float = 0.1, uniform: bool = False, max_order: int = -1, mut_window: Tuple[int, int] | None = None, inter_window: Tuple[int, int] | None = None, context_agnositc: bool = False, num_sim: int = 10000, seed: int | None = None, save_window: Tuple[int, int] | None = None, batch_size: int = 32, **kwargs)

Generates feature attributions for an input sequence using the SQUID method.

This method performs three main steps:

  1. Generates a dataset of mutated sequences and their corresponding model predictions (an in silico MAVE).

  2. Trains an interpretable surrogate model on this dataset.

  3. Extracts the learned parameters from the surrogate model, which represent the feature attributions.
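The three steps can be illustrated with a toy, NumPy-only sketch. Everything here is hypothetical: toy_model stands in for the target model (an additive function plus noise, so the right answer is known), and ordinary least squares replaces the trained surrogate that SQUIDSurrogateModel fits. It shows the idea, not the library's pipeline:

```python
import numpy as np

ALPHABET = ["A", "C", "G", "T"]
rng = np.random.default_rng(0)


def random_mutants(sequence, num_sim, mut_rate, rng):
    """Mutate each position independently with probability mut_rate."""
    wt = np.array([ALPHABET.index(c) for c in sequence])
    seqs = np.tile(wt, (num_sim, 1))
    mask = rng.random(seqs.shape) < mut_rate
    # Adding 1..A-1 (mod A) guarantees the character actually changes.
    shifts = rng.integers(1, len(ALPHABET), size=seqs.shape)
    seqs[mask] = (seqs[mask] + shifts[mask]) % len(ALPHABET)
    return seqs  # (num_sim, L), integer-encoded


def one_hot(idx):
    return np.eye(len(ALPHABET))[idx]  # (N, L, A)


# Hypothetical target model: a known additive function plus a little noise.
L = 8
true_theta = rng.normal(size=(L, len(ALPHABET)))


def toy_model(x):
    return np.einsum("nla,la->n", x, true_theta) + 0.05 * rng.normal(size=len(x))


# Step 1: build the in silico MAVE dataset.
seqs = random_mutants("ACGTACGT", num_sim=2000, mut_rate=0.1, rng=rng)
x = one_hot(seqs)
y = toy_model(x)

# Step 2: fit an additive surrogate (least squares instead of a network).
design = np.hstack([np.ones((len(x), 1)), x.reshape(len(x), -1)])
coef, *_ = np.linalg.lstsq(design, y, rcond=None)

# Step 3: extract theta_lc with shape (L, A), centred per position.
theta_lc = coef[1:].reshape(L, len(ALPHABET))
theta_lc -= theta_lc.mean(axis=1, keepdims=True)
```

Because the one-hot columns at each position sum to one, additive parameters are identifiable only up to a per-position offset; centring each row removes that ambiguity before the result is compared with the known effects.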

Parameters:
  • sequence (str) – The input sequence to explain.

  • mut_type (str, optional) – The mutagenesis strategy. Can be “random” or “combinatorial”. Defaults to “random”.

  • mut_rate (float, optional) – The average mutation rate for ‘random’ mutagenesis. Defaults to 0.1.

  • uniform (bool, optional) – If True, use a fixed number of mutations per sequence for ‘random’ mutagenesis. Defaults to False.

  • max_order (int, optional) – The maximum order of mutations for ‘combinatorial’ mutagenesis. -1 means all orders. Defaults to -1.

  • mut_window (Tuple[int, int], optional) – The (start, end) window within the sequence to apply mutations. Defaults to None.

  • inter_window (Tuple[int, int], optional) – A window for inter-mutational analysis. Defaults to None.

  • context_agnositc (bool, optional) – If True, randomize the context outside the mutation window. Defaults to False.

  • num_sim (int, optional) – The number of mutated sequences to generate. Defaults to 10000.

  • seed (Optional[int], optional) – A random seed for reproducibility. Defaults to None.

  • save_window (Tuple[int, int], optional) – The window of the sequence to use for training the surrogate model. Defaults to None.

  • batch_size (int, optional) – Batch size for getting predictions from the target model. Defaults to 32.

  • **kwargs – Additional arguments passed to the surrogate model fitting process.

Returns:

np.ndarray – The learned parameters from the surrogate model:
  • If gpmap is ‘additive’, returns theta_lc with shape (L, A), representing first-order effects.

  • If gpmap is ‘pairwise’, returns theta_lclc with shape (L, A, L, A), representing second-order effects.

class omnigenbench.src.explainability.shared_methods.squid_explainer.SQUIDGlobalEpistasis(hidden_nodes: int = 50)

Bases: Module

A simple network with one hidden layer of sigmoid basis functions, used to model the global-epistasis (GE) nonlinearity.
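One natural form for such a network is a weighted sum of sigmoids, g(φ) = a + Σ_k b_k·σ(c_k·φ + d_k), with k running over hidden_nodes. The NumPy sketch below assumes this form (the names a, b, c, d are hypothetical), and the module's exact parametrisation may differ. Keeping b_k and c_k non-negative makes g non-decreasing in φ:

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def ge_nonlinearity(phi, a, b, c, d):
    """g(phi) = a + sum_k b[k] * sigmoid(c[k] * phi + d[k]).

    phi has shape (N,); b, c and d have shape (hidden_nodes,).
    """
    return a + sigmoid(np.outer(phi, c) + d) @ b


hidden_nodes = 50
rng = np.random.default_rng(0)
# Non-negative b and c keep g non-decreasing in phi.
b = np.abs(rng.normal(size=hidden_nodes))
c = np.abs(rng.normal(size=hidden_nodes))
d = rng.normal(size=hidden_nodes)
phi = np.linspace(-3, 3, 7)
y_hat = ge_nonlinearity(phi, 0.0, b, c, d)
```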

forward(z)

Defines the computation performed at every call. Call the module instance (e.g. module(z)) rather than forward() directly: calling the instance runs any registered hooks, whereas calling forward() directly silently skips them.

class omnigenbench.src.explainability.shared_methods.squid_explainer.SQUIDPairwiseGPMap(L: int, A: int, rank: int = 8, reg_strength: float = 0.0)

Bases: Module

Full pairwise model: φ = θ_0 + Σ θ_{l,c} x_{l,c} + Σ θ_{l1,c1,l2,c2} x_{l1,c1} x_{l2,c2}. The interaction tensor is stored in a factorised, low-rank form so that the model scales to reasonable sequence lengths without the O(L²A²) memory footprint of a dense tensor. It uses a CP (CANDECOMP/PARAFAC) decomposition with K = rank latent factors.
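To see why a factorised interaction tensor saves memory, consider a rank-K CP form θ[l1,c1,l2,c2] = Σ_k P[l1,k]·Q[c1,k]·R[l2,k]·S[c2,k]. The pairwise sum then splits into K products of two additive-style sums, so the full (L, A, L, A) tensor is never materialised. The NumPy sketch below assumes this particular factorisation (P, Q, R, S are hypothetical names), sums over all (l1, l2) pairs including l1 = l2 for simplicity, and checks the factorised result against the dense tensor:

```python
import numpy as np


def cp_pairwise_energy(x, P, Q, R, S):
    """sum_{l1,c1,l2,c2} theta[l1,c1,l2,c2] x[n,l1,c1] x[n,l2,c2] via CP factors.

    x: (N, L, A) one-hot; P, R: (L, K); Q, S: (A, K).
    Cost is O(N*L*A*K); no O(L^2*A^2) tensor is built.
    """
    left = np.einsum("nla,lk,ak->nk", x, P, Q)   # (N, K)
    right = np.einsum("nla,lk,ak->nk", x, R, S)  # (N, K)
    return (left * right).sum(axis=1)            # (N,)


def dense_theta(P, Q, R, S):
    """Materialise the full (L, A, L, A) tensor (feasible only for small L)."""
    return np.einsum("ik,jk,lk,mk->ijlm", P, Q, R, S)


L, A, K, N = 6, 4, 3, 5
rng = np.random.default_rng(0)
P, R = rng.normal(size=(L, K)), rng.normal(size=(L, K))
Q, S = rng.normal(size=(A, K)), rng.normal(size=(A, K))
x = np.eye(A)[rng.integers(A, size=(N, L))]  # random one-hot sequences

fast = cp_pairwise_energy(x, P, Q, R, S)
slow = np.einsum("nij,ijlm,nlm->n", x, dense_theta(P, Q, R, S), x)
```

For L = 6, A = 4 and K = 3 the dense tensor has 576 entries, while the four factor matrices hold 60; the gap widens quadratically with L.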

forward(x: Tensor) → Tensor

Defines the computation performed at every call. Call the module instance (e.g. module(x)) rather than forward() directly: calling the instance runs any registered hooks, whereas calling forward() directly silently skips them.

l2_regularizer() → Tensor

class omnigenbench.src.explainability.shared_methods.squid_explainer.SQUIDRandomMutagenesis(alphabet: List[str], mut_rate: float = 0.1, uniform: bool = False, seed: int | None = None)

Bases: SQUIDBaseMutagenesis

Generates in silico MAVE data for a given sequence using random mutagenesis.
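A sketch of the two random-mutagenesis modes. It assumes uniform=False mutates each position independently with probability mut_rate, and uniform=True mutates exactly round(mut_rate·L) positions per sequence; this is an interpretation of the parameter descriptions under SQUIDExplainer.explain, and random_mutagenesis is a hypothetical helper, not the class's code:

```python
import numpy as np
from typing import List, Optional


def random_mutagenesis(
    sequence: str,
    alphabet: List[str],
    num_sim: int,
    mut_rate: float = 0.1,
    uniform: bool = False,
    seed: Optional[int] = None,
) -> List[str]:
    rng = np.random.default_rng(seed)
    L = len(sequence)
    n_fixed = int(round(mut_rate * L))
    out = []
    for _ in range(num_sim):
        if uniform:
            # Exactly n_fixed distinct positions in every sequence.
            sites = rng.choice(L, size=n_fixed, replace=False)
        else:
            # Independent per-position mutation: count ~ Binomial(L, mut_rate).
            sites = np.flatnonzero(rng.random(L) < mut_rate)
        seq = list(sequence)
        for p in sites:
            # Draw a replacement that differs from the original character.
            choices = [c for c in alphabet if c != sequence[p]]
            seq[p] = choices[rng.integers(len(choices))]
        out.append("".join(seq))
    return out


wt = "ACGTACGTAC"
fixed = random_mutagenesis(wt, list("ACGT"), num_sim=5, mut_rate=0.2, uniform=True, seed=0)
```

With uniform=True every sequence in fixed differs from wt at exactly two positions; with uniform=False the count varies around mut_rate·L.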

class omnigenbench.src.explainability.shared_methods.squid_explainer.SQUIDSurrogateModel(input_shape: Tuple[int, int, int], num_tasks: int, gpmap: str = 'additive', regression_type: str = 'GE', linearity: str = 'nonlinear', noise: str = 'Gaussian', noise_order: int = 0, reg_strength: float = 0.1, hidden_nodes: int = 50, token_to_id: Dict[str, int] = {}, deduplicate: bool = True, gpu: bool = True, pairwise_rank: int = 8, seed: int | None = None)

Bases: Module

dataframe(x: ndarray, y: ndarray) → Tuple[List[str], Tensor]

MAVE-NN expects a pandas DataFrame; this method instead returns seq_list (for inspection) and y as a PyTorch tensor of shape (N,) or (N, num_tasks).
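A NumPy-only sketch of that conversion. It assumes x is one-hot with shape (N, L, A) and that a hypothetical id_to_token list (the inverse of token_to_id) names each channel; np.asarray stands in for the PyTorch tensor, and squeezing a single-task (N, 1) target to (N,) is an assumption:

```python
import numpy as np
from typing import List, Tuple


def to_seq_list_and_targets(
    x: np.ndarray, y: np.ndarray, id_to_token: List[str]
) -> Tuple[List[str], np.ndarray]:
    """Decode one-hot x back to strings and return y as a float array."""
    ids = x.argmax(axis=-1)  # (N, L): channel index at each position
    seq_list = ["".join(id_to_token[i] for i in row) for row in ids]
    y = np.asarray(y, dtype=np.float32)
    if y.ndim == 2 and y.shape[1] == 1:
        y = y[:, 0]  # assumed: a single task becomes shape (N,)
    return seq_list, y


id_to_token = ["A", "C", "G", "T"]
idx = np.array([[0, 1, 2], [3, 3, 0]])
x = np.eye(4)[idx]  # (2, 3, 4) one-hot
seq_list, y = to_seq_list_and_targets(x, np.array([[0.5], [1.5]]), id_to_token)
```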

fit(x: ndarray, y: ndarray, learning_rate: float = 0.0005, epochs: int = 500, batch_size: int = 128, early_stopping: bool = True, patience: int = 25, save_dir: str | None = None, verbose: int = 1) → Tuple[Module, List[str]]

End‑to‑end training loop with a 60/20/20 train/val/test split. Returns the trained nn.Module and the sequence list (for external evaluation if needed).
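A minimal sketch of a shuffled 60/20/20 index split. The helper name is hypothetical, and how fit actually shuffles and rounds is not documented here. Integer arithmetic keeps the split sizes exact, and any remainder goes to the test set:

```python
import numpy as np
from typing import Optional, Tuple


def split_60_20_20(
    n: int, seed: Optional[int] = None
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Shuffle indices 0..n-1 and split them 60/20/20 into train/val/test."""
    perm = np.random.default_rng(seed).permutation(n)
    n_train, n_val = (6 * n) // 10, (2 * n) // 10
    return perm[:n_train], perm[n_train:n_train + n_val], perm[n_train + n_val:]


train_idx, val_idx, test_idx = split_60_20_20(10_000, seed=0)
print(len(train_idx), len(val_idx), len(test_idx))  # 6000 2000 2000
```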

get_info(verbose: int = 1) → float

Computes a heuristic proxy for predictive information: the test-set R² for GE regression, or the accuracy for MPA regression. It does not reproduce the variational information bound used by MAVE-NN, but is often a useful quick proxy.
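For reference, both metrics are short to compute. This sketch takes R² to be the coefficient of determination, 1 − SS_res/SS_tot; that definition is an assumption, since the squared Pearson correlation is another common reading of "R²". The helper names are hypothetical:

```python
import numpy as np


def r_squared(y_true, y_pred) -> float:
    """Coefficient of determination: 1 - SS_res / SS_tot."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    ss_res = np.sum((y_true - y_pred) ** 2)
    ss_tot = np.sum((y_true - y_true.mean()) ** 2)
    return float(1.0 - ss_res / ss_tot)


def accuracy(y_true, y_pred) -> float:
    """Fraction of predictions that exactly match the labels."""
    return float(np.mean(np.asarray(y_true) == np.asarray(y_pred)))


y = np.array([1.0, 2.0, 3.0, 4.0])
print(r_squared(y, y))                          # 1.0
print(r_squared(y, np.full_like(y, y.mean())))  # 0.0
```

A perfect prediction gives 1.0, and always predicting the mean gives 0.0; worse-than-mean predictions give negative values.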

get_params() → Tuple[ndarray, ndarray, ndarray | None]

Returns (theta_0, theta_lc, theta_lclc). theta_lclc is only returned when gpmap_type == ‘pairwise’; otherwise it is None.

class omnigenbench.src.explainability.shared_methods.tsne_explainer.TSNEExplainer(model, **kwargs)

Bases: AbstractExplainer

Visualizes high-dimensional sequence embeddings in 2D using t-SNE.

This explainer generates high-dimensional embeddings from a set of input sequences using a given model. It then applies the t-SNE (t-Distributed Stochastic Neighbor Embedding) algorithm to project these embeddings into a two-dimensional space. This is useful for visualizing the structure of the learned embedding space and observing how sequences with different labels cluster.

Variables:
  • model – The model used to generate sequence embeddings.

  • tsne (sklearn.manifold.TSNE) – The t-SNE transformer instance.

explain(sequences: List[str], labels: List[str | int], embedding_file: str | None = None, **kwargs)

Generates 2D embeddings for a set of sequences using t-SNE.

This method first obtains high-dimensional embeddings for the input sequences, either by generating them with the model or by loading them from embedding_file. It then runs t-SNE on these embeddings to project them into a two-dimensional representation suitable for plotting.
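A sketch of the projection step with scikit-learn's TSNE on random stand-in embeddings (in the explainer these would come from the model or from embedding_file). The perplexity and random_state values are illustrative choices, not the explainer's defaults, and perplexity must be smaller than the number of sequences:

```python
import numpy as np
from sklearn.manifold import TSNE

rng = np.random.default_rng(0)
# Stand-in embeddings: two well-separated groups of 20 points in 64 dimensions.
embeddings = np.vstack([
    rng.normal(loc=0.0, size=(20, 64)),
    rng.normal(loc=5.0, size=(20, 64)),
])
labels = ["class_a"] * 20 + ["class_b"] * 20

# t-SNE has no separate transform step: fit_transform embeds exactly these points.
tsne = TSNE(n_components=2, perplexity=10, random_state=0)
coords = tsne.fit_transform(embeddings)
print(coords.shape)  # (40, 2)
```

Because t-SNE cannot project new points into an existing embedding, adding sequences means rerunning fit_transform on the full set.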

Visualization

class omnigenbench.src.explainability.visualization_2d.explainer.Visualization2DExplainer(model, method: str = 'tsne')

Bases: AbstractExplainer

A high-level explainer for creating 2D visualizations of sequence embeddings.

This class provides a convenient wrapper around various dimensionality reduction algorithms (like t-SNE) to generate and visualize 2D representations of high-dimensional sequence embeddings. It simplifies the process of creating interactive scatter plots to explore the structure of the embedding space.

Variables:
  • model (OmniModelForEmbedding) – The model used for generating embeddings.

  • ExplainerClass (AbstractExplainer) – The specific dimensionality reduction class being used (e.g., TSNEExplainer).

  • explainer (AbstractExplainer) – An instance of the ExplainerClass.

explain(sequences, labels=None, **kwargs)

Generates the 2D embeddings for the input sequences.

This method acts as a wrapper, calling the explain method of the underlying dimensionality reduction explainer (e.g., TSNEExplainer).

Parameters:
  • sequences (List[str]) – The list of input sequences to explain.

  • labels (Optional[List[Any]], optional) – A list of corresponding labels. Not used in computation but passed down. Defaults to None.

  • **kwargs – Additional keyword arguments to be passed to the underlying explainer’s explain method (e.g., perplexity for t-SNE).

Returns:

np.ndarray – An array of shape (n_sequences, 2) containing the generated 2D coordinates.

visualize(embeddings, sequences, labels=None, width=800, height=600, title='2D Visualization of Sequence Embeddings', point_size=8, point_opacity=0.8, wrap_width=50, color_palette=None, save_path=None, **kwargs)

Creates an interactive 2D scatter plot of the embeddings.

This method uses Plotly Express to generate a rich, interactive visualization where each point represents a sequence. Hovering over a point reveals its sequence and label.

Parameters:
  • embeddings (np.ndarray) – The 2D coordinates to visualize, shape (n, 2).

  • sequences (List[str]) – The original sequences, used for hover-over tooltips.

  • labels (Optional[List[Any]], optional) – Labels for coloring points. If None, all points are assigned a single ‘Unlabeled’ category. Defaults to None.

  • width (int, optional) – The width of the figure in pixels. Defaults to 800.

  • height (int, optional) – The height of the figure in pixels. Defaults to 600.

  • title (str, optional) – The title of the plot. Defaults to “2D Visualization of Sequence Embeddings”.

  • point_size (int, optional) – The size of the scatter plot points. Defaults to 8.

  • point_opacity (float, optional) – The opacity of the points. Defaults to 0.8.

  • wrap_width (int, optional) – The maximum width for sequence text in the hover tooltip before it’s truncated. Defaults to 50.

  • color_palette (Optional[List[str]], optional) – A list of CSS colors to use. If None, a default Plotly palette is used. Defaults to None.

  • save_path (Optional[str], optional) – The file path to save the interactive plot as an HTML file. If None, the plot is not saved. Defaults to None.

  • **kwargs – Not currently used, but included for future extensibility.

Returns:

plotly.graph_objs._figure.Figure – The Plotly scatter plot figure object, which can be further customized or displayed.

omnigenbench.src.explainability.visualization_2d.explainer.get_explainer(name: str) → AbstractExplainer

Retrieves an explainer class from the registry by its name.

This function acts as a factory, allowing for dynamic selection of the dimensionality reduction algorithm to be used.

Parameters:

name (str) – The name of the explainer method to retrieve (e.g., “tsne”).

Returns:

AbstractExplainer – The explainer class corresponding to the given name.