Source code for omnigenbench.src.explainability.shared_methods.squid_explainer

# -*- coding: utf-8 -*-
# file: squid_explainer.py
# time: 2025-06-23 15:19
# author: Shasha Zhou <sz484@exeter.ac.uk>
# Copyright (C) 2020-2025. All Rights Reserved.

from ...abc.abstract_explainer import AbstractExplainer
from ...misc.utils import fprint
from typing import List, Optional, Tuple, Dict
import numpy as np
import random
from itertools import combinations, product
from ...model.classification.model import (
    OmniModelForSequenceClassification,
    OmniModelForTokenClassification,
)
from ...model.regression.model import (
    OmniModelForSequenceRegression,
    OmniModelForTokenRegression,
)
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader, random_split
import math
import os


class SQUIDExplainer(AbstractExplainer):
    """Explains model predictions using the SQUID method.

    SQUID (Surrogate-based QUantitative-epistatic-Interaction-Discovery) uses
    in-silico mutagenesis to generate a dataset, which is then used to train a
    simpler, interpretable surrogate model. From this surrogate model, additive
    (first-order) or pairwise (second-order) feature attributions can be
    extracted.

    Attributes:
        model: The target deep learning model to explain.
        gpmap (str): The type of genotype-phenotype map for the surrogate
            model, either 'additive' or 'pairwise'.
        token_to_id (Dict[str, int]): A mapping from sequence characters to
            integer IDs. Rebuilt from the input sequence on every `explain`
            call.
        alphabet (List[str]): The unique characters of the input sequence, in
            order of first appearance.
        num_tokens (int): The size of the alphabet.

    Reference:
        Seitz, E.E., McCandlish, D.M., Kinney, J.B., and Koo P.K.
        Interpreting cis-regulatory mechanisms from genomic deep neural
        networks using surrogate models. Nat Mach Intell (2024).
        https://doi.org/10.1038/s42256-024-00851-5
    """

    def __init__(self, model, gpmap: str = "additive", **kwargs):
        """Initializes the SQUIDExplainer.

        Args:
            model: The pre-trained model to be explained.
            gpmap (str, optional): The type of surrogate model to fit. Can be
                'additive' for first-order effects or 'pairwise' for
                second-order effects. Defaults to "additive".
            **kwargs: Additional keyword arguments (currently unused).
        """
        super().__init__(model)
        self.model = model
        self.token_to_id = {}
        self.alphabet = []
        self.num_tokens = 0
        self.gpmap = gpmap

    def explain(
        self,
        sequence: str,
        mut_type: str = "random",
        mut_rate: float = 0.1,
        uniform: bool = False,
        max_order: int = -1,
        mut_window: Optional[Tuple[int, int]] = None,
        inter_window: Optional[Tuple[int, int]] = None,
        context_agnositc: bool = False,
        num_sim: int = 10000,
        seed: Optional[int] = None,
        save_window: Optional[Tuple[int, int]] = None,
        batch_size: int = 32,
        **kwargs,
    ):
        """Generates feature attributions for an input sequence using SQUID.

        This method performs three main steps:
            1. Generates a dataset of mutated sequences and their
               corresponding model predictions (in-silico MAVE).
            2. Trains an interpretable surrogate model on this dataset.
            3. Extracts the learned parameters from the surrogate model,
               which represent the feature attributions.

        Args:
            sequence (str): The input sequence to explain.
            mut_type (str, optional): The mutagenesis strategy. Can be
                "random" or "combinatorial". Defaults to "random".
            mut_rate (float, optional): The average mutation rate for
                'random' mutagenesis. Defaults to 0.1.
            uniform (bool, optional): If True, use a fixed number of
                mutations per sequence for 'random' mutagenesis.
                Defaults to False.
            max_order (int, optional): The maximum order of mutations for
                'combinatorial' mutagenesis. -1 means all orders.
                Defaults to -1.
            mut_window (Tuple[int, int], optional): The half-open
                (start, end) slice of the sequence to mutate.
                Defaults to None (mutate the whole sequence).
            inter_window (Tuple[int, int], optional): A (start, end) window,
                or a list of such windows, whose content is replaced by
                random alphabet symbols in every generated sequence. Only
                applied when `mut_window` is given. Defaults to None.
            context_agnositc (bool, optional): If True, randomize the context
                outside the mutation window. Only applied when `mut_window`
                is given. Defaults to False.
            num_sim (int, optional): The number of mutated sequences to
                generate. Defaults to 10000.
            seed (Optional[int], optional): A random seed for
                reproducibility. Defaults to None.
            save_window (Tuple[int, int], optional): The (start, end) slice
                of each full sequence that is one-hot encoded for training
                the surrogate model. Only applied when `mut_window` is given.
                Defaults to None.
            batch_size (int, optional): Batch size for getting predictions
                from the target model. Defaults to 32.
            **kwargs: Forwarded to the mutagenesis callable and to
                `self.model.predict`.

        Returns:
            np.ndarray: The learned parameters from the surrogate model.
                - If `gpmap` is 'additive', returns `theta_lc` with shape
                  (L, A), representing first-order effects.
                - If `gpmap` is 'pairwise', returns `theta_lclc` with shape
                  (L, A, L, A), representing second-order effects.
        """
        fprint("Starting SQUID explanation")
        print(f"Using gpmap type: {self.gpmap}")

        # Build the alphabet from the input sequence itself; dict.fromkeys
        # keeps the order of first appearance while removing duplicates.
        unique_chars = list(dict.fromkeys(sequence))
        self.token_to_id = {ch: idx for idx, ch in enumerate(unique_chars)}
        self.alphabet = unique_chars
        self.num_tokens = len(unique_chars)

        # in silico MAVE data generation
        print("Generating in silico MAVE data...")
        _, x_mut, y_mut = self._generate_in_silico_mave(
            sequence,
            mut_type,
            mut_rate,
            uniform,
            max_order,
            mut_window,
            inter_window,
            context_agnositc,
            num_sim,
            seed,
            save_window,
            batch_size,
            **kwargs,
        )

        # build and train surrogate model
        print("Building and training surrogate model...")
        surrogate_model = SQUIDSurrogateModel(
            input_shape=x_mut.shape,
            gpmap=self.gpmap,
            num_tasks=len(y_mut[0]),
            token_to_id=self.token_to_id,
            seed=seed,
        )
        surrogate_model.fit(x_mut, y_mut)
        _, theta_lc, theta_lclc = surrogate_model.get_params()

        if self.gpmap == "additive":
            return theta_lc
        elif self.gpmap == "pairwise":
            return theta_lclc

    def _generate_in_silico_mave(
        self,
        sequence: str,
        mut_type: str = "random",
        mut_rate: float = 0.1,
        uniform: bool = False,
        max_order: int = -1,
        mut_window: Optional[Tuple[int, int]] = None,
        inter_window: Optional[Tuple[int, int]] = None,
        context_agnositc: bool = False,
        num_sim: int = 10000,
        seed: Optional[int] = None,
        save_window: Optional[Tuple[int, int]] = None,
        batch_size: int = 32,
        **kwargs,
    ):
        """(Private) Generates an in-silico MAVE dataset.

        Creates a dataset of mutated sequences and uses the wrapped model to
        predict their corresponding phenotypes.

        Args:
            (See `explain` method for argument descriptions)

        Returns:
            Tuple[List[str], np.ndarray, np.ndarray]: A tuple containing:
                - A list of the generated full sequences (strings).
                - The one-hot encoded mutated sequences (`x_mut`).
                - The model predictions for each sequence (`y_mut`).

        Raises:
            ValueError: If `mut_type` is not 'random' or 'combinatorial'.
            AssertionError: If a window lies outside the sequence.
        """
        if seed is not None:
            np.random.seed(seed)
            random.seed(seed)

        if mut_type == "random":
            mutagenesis = SQUIDRandomMutagenesis(
                alphabet=self.alphabet, mut_rate=mut_rate, uniform=uniform, seed=seed
            )
        elif mut_type == "combinatorial":
            # The sequence handed to the mutagenesis object is already sliced
            # to `mut_window` below, so the window must not be applied a second
            # time (its indices would refer to the un-sliced sequence).
            mutagenesis = SQUIDCombinatorialMutagenesis(
                alphabet=self.alphabet,
                max_order=max_order,
                mut_window=None,
                seed=seed,
            )
        else:
            raise ValueError(
                f"Invalid mut_type: {mut_type}. Must be 'random' or 'combinatorial'."
            )

        L = len(sequence)

        # No mutation window: mutate the whole sequence and return directly.
        if mut_window is None:
            mutated_parts = mutagenesis(sequence, num_sim, **kwargs)
            x_mut = np.stack([self._one_hot_encode(m_seq) for m_seq in mutated_parts])
            y_mut = np.array(self._mut_predictor(mutated_parts, batch_size, **kwargs))
            return mutated_parts, x_mut, y_mut

        mut_start, mut_end = mut_window
        assert (
            mut_start >= 0 and mut_end <= L
        ), f"Invalid mut_window: {mut_window}. Must be within the sequence length [0, {L}]."
        mutated_parts = mutagenesis(sequence[mut_start:mut_end], num_sim, **kwargs)

        # Normalise `inter_window` to a list of (start, end) pairs once.
        if inter_window is None:
            inter_windows = []
        elif isinstance(inter_window[0], (list, tuple)):
            inter_windows = list(inter_window)
        else:
            inter_windows = [inter_window]
        for w_start, w_end in inter_windows:
            assert (
                w_start >= 0 and w_end <= L
            ), f"Invalid inter_window: {inter_window}. Must be within the sequence length [0, {L}]."

        full_seqs = []
        for m_seq in mutated_parts:
            # Re-attach the flanking context around the mutated window.
            if context_agnositc:
                # random.choices (plural) is the API that accepts `k`.
                left = "".join(random.choices(self.alphabet, k=mut_start))
                right = "".join(random.choices(self.alphabet, k=L - mut_end))
            else:
                left = sequence[:mut_start]
                right = sequence[mut_end:]
            full_seq = left + m_seq + right

            # Overwrite each inter-window with random symbols.
            for w_start, w_end in inter_windows:
                inter_seq = "".join(random.choices(self.alphabet, k=w_end - w_start))
                full_seq = full_seq[:w_start] + inter_seq + full_seq[w_end:]
            full_seqs.append(full_seq)

        save_start, save_end = 0, L
        if save_window is not None:
            save_start, save_end = save_window
            assert (
                save_start >= 0 and save_end <= L
            ), f"Invalid save_window: {save_window}. Must be within the sequence length [0, {L}]."
            # The saved region must fully contain the mutated region.
            if save_window[0] > mut_window[0] or save_window[1] < mut_window[1]:
                save_start, save_end = 0, L
                fprint(
                    "Conflicting save_window and mut_window. save_window is ignored."
                )

        x_mut = np.stack(
            [self._one_hot_encode(full_seq[save_start:save_end]) for full_seq in full_seqs]
        )
        y_mut = np.array(self._mut_predictor(full_seqs, batch_size, **kwargs))
        return full_seqs, x_mut, y_mut

    def _mut_predictor(self, mutated_parts: List[str], batch_size: int = 32, **kwargs):
        """Predict the output of the wrapped model for the mutated sequences.

        Args:
            mutated_parts: the sequences to score, list of strings
            batch_size: number of sequences passed to `model.predict` at once
            **kwargs: additional arguments forwarded to `model.predict`

        Returns:
            y_mut: the concatenated model outputs, numpy array

        Raises:
            ValueError: If the wrapped model is not one of the supported
                OmniModel classification/regression classes.
        """
        # Classification models expose "logits", regression ones "predictions".
        if isinstance(
            self.model,
            (OmniModelForSequenceClassification, OmniModelForTokenClassification),
        ):
            output_key = "logits"
        elif isinstance(
            self.model,
            (OmniModelForSequenceRegression, OmniModelForTokenRegression),
        ):
            output_key = "predictions"
        else:
            raise ValueError(
                f"Model type {type(self.model)} not supported for SQUID explainer."
            )

        outputs = []
        for i in range(0, len(mutated_parts), batch_size):
            batch = mutated_parts[i : i + batch_size]
            outputs.append(self.model.predict(batch, **kwargs)[output_key])
        return torch.cat(outputs, dim=0).detach().cpu().numpy()

    def _one_hot_encode(self, sequence: str):
        """One-hot encode the sequence using `self.token_to_id`.

        Args:
            sequence: the sequence to encode, string

        Returns:
            one_hot: integer numpy array of shape (len(sequence), num_tokens)

        Raises:
            ValueError: If a character is not part of the alphabet.
        """
        one_hot = np.zeros((len(sequence), self.num_tokens), dtype=int)
        for i, token in enumerate(sequence):
            token_id = self.token_to_id.get(token, -1)
            if token_id == -1:
                raise ValueError(f"Token {token} not found in the alphabet.")
            one_hot[i, token_id] = 1
        return one_hot
class SQUIDBaseMutagenesis:
    """Common interface for the in silico MAVE sequence generators.

    Concrete mutagenesis strategies subclass this and override `__call__`.
    """

    def __call__(self, sequence: str, num_sim: int = 100, **kwargs):
        """Produce an in silico MAVE library derived from `sequence`.

        Args:
            sequence: the reference sequence to mutate, string
            num_sim: how many simulated sequences to produce, int (default 100)
            **kwargs: strategy-specific extra arguments

        Returns:
            list: the mutated sequences (provided by subclasses)

        Raises:
            NotImplementedError: always, in this base class.
        """
        message = "This method should be implemented by the subclass."
        raise NotImplementedError(message)
class SQUIDRandomMutagenesis(SQUIDBaseMutagenesis):
    """Generates in silico MAVE data for a sequence using random mutagenesis.

    Each simulated sequence receives a number of point substitutions at
    distinct positions; every substituted position is replaced by a different
    symbol drawn from `alphabet`.
    """

    def __init__(
        self,
        alphabet: List[str],
        mut_rate: float = 0.1,
        uniform: bool = False,
        seed: Optional[int] = None,
    ):
        """Initialize the random mutagenesis strategy.

        Args:
            alphabet: the symbols that may be substituted in
            mut_rate: average fraction of positions mutated per sequence
            uniform: if True every sequence gets exactly
                ceil(L * mut_rate) mutations; otherwise the count is drawn
                from a Poisson distribution with that mean
            seed: seed for `numpy.random` and `random`; None leaves the
                global generators untouched
        """
        self.mut_rate = mut_rate
        self.uniform = uniform
        self.seed = seed
        self.alphabet = alphabet

    def __call__(self, sequence: str, num_sim: int = 10000, **kwargs):
        """Return an in silico MAVE based on the given sequence.

        Args:
            sequence: the sequence to mutate, string
            num_sim: the number of mutated sequences to generate, int
            **kwargs: additional arguments (ignored)

        Returns:
            list: `num_sim` mutated sequences, each as long as `sequence`
        """
        if self.seed is not None:
            np.random.seed(self.seed)
            random.seed(self.seed)

        L = len(sequence)
        avg_num_mut = int(np.ceil(L * self.mut_rate))

        if self.uniform:
            # Clamp to [0, L]: positions are sampled without replacement, so a
            # count above the sequence length (mut_rate > 1) would fail.
            num_muts = [min(max(avg_num_mut, 0), L)] * num_sim
        else:
            num_muts = np.random.poisson(avg_num_mut, size=num_sim)
            num_muts = np.clip(num_muts, 0, L)

        # The substitution candidates depend only on the reference symbol, so
        # build them once per distinct symbol instead of once per mutation.
        alternatives = {
            symbol: [l for l in self.alphabet if l != symbol]
            for symbol in set(sequence)
        }

        mutated_seqs = []
        for n_mut in num_muts:
            seq_list = list(sequence)
            mut_positions = np.random.choice(L, n_mut, replace=False)
            for pos in mut_positions:
                seq_list[pos] = random.choice(alternatives[seq_list[pos]])
            mutated_seqs.append("".join(seq_list))

        return mutated_seqs
class SQUIDCombinatorialMutagenesis(SQUIDBaseMutagenesis):
    """Generates in silico MAVE data for a sequence using combinatorial
    mutagenesis, i.e. the exhaustive enumeration of all substitution variants
    up to a given mutation order.
    """

    def __init__(
        self,
        alphabet: List[str],
        max_order: int = 1,
        mut_window: Optional[Tuple[int, int]] = None,
        seed: Optional[int] = None,
    ):
        """Initialize the SQUIDCombinatorialMutagenesis class.

        Args:
            alphabet: the symbols that may be substituted in
            max_order: the maximum order of mutations to generate.
                If -1, generates all possible mutations.
                If 1, generates only single mutations (all SNVs).
                If 2, generates single and double mutations, etc.
                int, default is 1
            mut_window: [int, int], optional. Index of start and end of the
                mutated region. If provided, only generates mutations within
                this window (inclusive on both ends). For example,
                mut_window=(4, 6) will generate mutations at positions 4, 5,
                and 6. (Defaults to None, which means the entire sequence is
                mutated)
            seed: the seed for the random number generator, default is None,
                int
        """
        self.max_order = max_order
        self.mut_window = mut_window
        self.seed = seed
        self.alphabet = alphabet

    def __call__(self, sequence: str, num_sim: int = 10000, **kwargs):
        """Return an in silico MAVE based on the given sequence.

        Args:
            sequence: the sequence to mutate, string
            num_sim: this parameter is ignored, int
            **kwargs: additional arguments (ignored)

        Returns:
            list: the reference sequence followed by every variant with
                1..max_order substitutions inside the window

        Raises:
            ValueError: If `max_order` is below -1 or larger than the window,
                or if `mut_window` falls outside the sequence.
        """
        if self.max_order < -1:
            raise ValueError(
                f"Invalid max_order: {self.max_order}. Must be -1 or greater."
            )

        if self.seed is not None:
            np.random.seed(self.seed)
            random.seed(self.seed)

        L = len(sequence)

        if self.mut_window is not None:
            start_pos, end_pos = self.mut_window
            end_pos += 1  # the window is inclusive; convert to half-open
            if start_pos < 0 or end_pos > L:
                raise ValueError(
                    f"Invalid mut_window: {self.mut_window}. Must be within the sequence length [0, {L}]."
                )
        else:
            start_pos, end_pos = 0, L

        window_len = end_pos - start_pos
        if self.max_order > window_len:
            raise ValueError(
                f"max_order: {self.max_order} is greater than the window length: {window_len}."
            )
        max_order = window_len if self.max_order == -1 else self.max_order

        ref_seq = list(sequence)
        mutated_seqs = ["".join(ref_seq)]  # the wild type always comes first

        # All alternative symbols for every position inside the window.
        alt_base_dict = {
            i: [l for l in self.alphabet if l != ref_seq[i]]
            for i in range(start_pos, end_pos)
        }

        # Enumerate position subsets lazily, then every symbol assignment.
        for order in range(1, max_order + 1):
            for positions in combinations(range(start_pos, end_pos), order):
                alt_base_lists = [alt_base_dict[pos] for pos in positions]
                for alt_bases in product(*alt_base_lists):
                    new_seq = ref_seq.copy()
                    for pos, alt_base in zip(positions, alt_bases):
                        new_seq[pos] = alt_base
                    mutated_seqs.append("".join(new_seq))

        return mutated_seqs
class SQUIDAdditiveGPMap(nn.Module):
    """Additive genotype-phenotype map: φ = θ_0 + Σ_{l,c} θ_{l,c} x_{l,c}."""

    def __init__(self, L: int, A: int, reg_strength: float = 0.0):
        """Create zero-initialised additive parameters.

        Args:
            L: number of sequence positions
            A: alphabet size
            reg_strength: multiplier applied by `l2_regularizer`
        """
        super().__init__()
        # Per-position, per-character effects θ_{l,c}; shape (L, A).
        self.theta_lc = nn.Parameter(torch.zeros(L, A))
        # Constant offset θ_0; shape (1,).
        self.theta_0 = nn.Parameter(torch.zeros(1))
        self.reg_strength = reg_strength

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map one-hot input of shape (N, L, A) to latent φ of shape (N, 1)."""
        # Broadcast θ over the batch, then collapse positions and characters.
        weighted = x * self.theta_lc
        latent = torch.sum(weighted, dim=(1, 2)) + self.theta_0  # (N,)
        return latent.unsqueeze(-1)  # (N, 1)

    def l2_regularizer(self) -> torch.Tensor:
        """Return reg_strength * mean(θ_{l,c}²), or a zero tensor if disabled."""
        if self.reg_strength != 0:
            mean_square = self.theta_lc.pow(2).mean()
            return self.reg_strength * mean_square
        return torch.tensor(0.0, device=self.theta_lc.device)
class SQUIDPairwiseGPMap(nn.Module):
    """Full pairwise model:
    φ = θ_0 + Σ θ_{l,c} x_{l,c} + Σ θ_{l1,c1,l2,c2} x_{l1,c1} x_{l2,c2}.

    The interaction tensor is never stored explicitly. Instead every
    (position, character) pair owns a `rank`-dimensional factor vector (a
    CP-style decomposition with *K* = `rank` latent factors), which avoids the
    O(L²A²) memory footprint of the dense tensor.
    """

    def __init__(self, L: int, A: int, rank: int = 8, reg_strength: float = 0.0):
        """Create the additive terms and the low-rank interaction factors.

        Args:
            L: number of sequence positions
            A: alphabet size
            rank: number of latent factors per (position, character)
            reg_strength: multiplier applied by `l2_regularizer`
        """
        super().__init__()
        self.L = L
        self.A = A
        self.rank = rank
        self.theta_lc = nn.Parameter(torch.zeros(L, A))
        self.theta_0 = nn.Parameter(torch.zeros(1))
        # Small random start for the factor embeddings; shape (L, A, rank).
        self.U = nn.Parameter(torch.randn(L, A, rank) * 0.01)
        self.reg_strength = reg_strength

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map one-hot input of shape (N, L, A) to latent φ of shape (N, 1)."""
        # First-order contribution.
        first_order = (x * self.theta_lc).sum(dim=(1, 2))  # (N,)

        # Second-order contribution from the factorised tensor:
        #   0.5 * Σ_k [ (Σ_{l,c} U_{l,c,k} x_{l,c})² − Σ_{l,c} (U_{l,c,k} x_{l,c})² ]
        # Subtracting the squared individual terms removes the diagonal.
        projections = torch.einsum("nlc,lck->nk", x, self.U)  # (N, rank)
        squared_total = projections.pow(2).sum(dim=1)
        factor_sq_norms = self.U.pow(2).sum(dim=-1)  # (L, A)
        diagonal = (x.pow(2) * factor_sq_norms).sum(dim=(1, 2))
        second_order = 0.5 * (squared_total - diagonal)

        latent = first_order + second_order + self.theta_0  # (N,)
        return latent.unsqueeze(-1)

    def l2_regularizer(self) -> torch.Tensor:
        """Return reg_strength * (mean(θ_{l,c}²) + mean(U²)), or zero if disabled."""
        if self.reg_strength != 0:
            penalty = self.theta_lc.pow(2).mean() + self.U.pow(2).mean()
            return self.reg_strength * penalty
        return torch.tensor(0.0, device=self.theta_lc.device)
class SQUIDGlobalEpistasis(nn.Module):
    """Single-hidden-layer sigmoid network modelling the global-epistasis
    (GE) non-linearity applied to the scalar latent phenotype."""

    def __init__(self, hidden_nodes: int = 50):
        """Build the 1 → hidden_nodes → 1 network.

        Args:
            hidden_nodes: width of the sigmoid hidden layer
        """
        super().__init__()
        layers = [
            nn.Linear(1, hidden_nodes),
            nn.Sigmoid(),
            nn.Linear(hidden_nodes, 1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, z):
        """Apply the non-linearity to `z` (last dimension of size 1)."""
        transformed = self.net(z)
        return transformed
class SQUIDSurrogateModel(nn.Module):
    """Interpretable surrogate model fitted to in silico MAVE data.

    The model chains a genotype-phenotype map (additive or low-rank pairwise),
    an optional global-epistasis non-linearity and a linear head producing one
    output per task.
    """

    def __init__(
        self,
        input_shape: Tuple[int, int, int],
        num_tasks: int,
        gpmap: str = "additive",
        regression_type: str = "GE",
        linearity: str = "nonlinear",
        noise: str = "Gaussian",
        noise_order: int = 0,
        reg_strength: float = 0.1,
        hidden_nodes: int = 50,
        token_to_id: Optional[Dict[str, int]] = None,
        deduplicate: bool = True,
        gpu: bool = True,
        pairwise_rank: int = 8,
        seed: Optional[int] = None,
    ):
        """Build the surrogate model.

        Args:
            input_shape: (N, L, A) shape of the one-hot training data
            num_tasks: number of outputs of the final linear head
            gpmap: 'additive' or 'pairwise'
            regression_type: 'GE' (regression) or 'MPA' (classification)
            linearity: 'nonlinear' inserts the global-epistasis network,
                anything else uses the identity
            noise: 'Gaussian' (MSE loss) or 'Cauchy' for GE regression
            noise_order: stored only when `linearity` is 'nonlinear',
                otherwise forced to 0
            reg_strength: L2 regularisation strength of the GP-map
            hidden_nodes: hidden width of the global-epistasis network
            token_to_id: mapping from symbol to one-hot column; None or an
                empty mapping falls back to {"A": 0, "U": 1, "G": 2, "C": 3}
            deduplicate: drop duplicated (x, y) rows before training
            gpu: use CUDA when available
            pairwise_rank: number of latent factors for the pairwise GP-map
            seed: seed for the train/val/test split (42 when None)

        Raises:
            ValueError: For an unsupported `gpmap` or `regression_type`.
        """
        super().__init__()
        self.N, self.L, self.A = input_shape
        self.gpmap_type = gpmap
        self.regression_type = regression_type
        self.linearity = linearity
        self.noise = noise
        self.noise_order = noise_order if linearity == "nonlinear" else 0
        self.reg_strength = reg_strength
        self.hidden_nodes = hidden_nodes
        # A None default (instead of a shared mutable `{}`) makes the RNA
        # fallback alphabet actually reachable.
        self.token_to_id = (
            token_to_id if token_to_id else {"A": 0, "U": 1, "G": 2, "C": 3}
        )
        self.deduplicate = deduplicate
        self.device = torch.device(
            "cuda" if (gpu and torch.cuda.is_available()) else "cpu"
        )
        self.seed = seed

        # Build GP-map
        if self.gpmap_type == "additive":
            self.gpmap = SQUIDAdditiveGPMap(self.L, self.A, reg_strength=reg_strength)
        elif self.gpmap_type == "pairwise":
            print(f"Using pairwise gpmap with rank: {pairwise_rank}")
            self.gpmap = SQUIDPairwiseGPMap(
                self.L, self.A, rank=pairwise_rank, reg_strength=reg_strength
            )
        else:
            raise ValueError(f"Unsupported gpmap type: {self.gpmap_type}")

        # Non-linearity (for GE regressions)
        if self.linearity == "nonlinear":
            self.nonlinearity = SQUIDGlobalEpistasis(hidden_nodes=hidden_nodes)
        else:
            self.nonlinearity = nn.Identity()

        # Final task-specific head. GE: one regression output per task.
        # MPA: one logit per class. Both are a linear map from the scalar
        # latent phenotype.
        if self.regression_type not in ("GE", "MPA"):
            raise ValueError("regression_type must be 'GE' or 'MPA'")
        self.head = nn.Linear(1, num_tasks)

        # Register noise parameters now so the optimizer built in `fit`
        # actually receives them.
        self._ensure_noise_params()

        self.to(self.device)

    def _ensure_noise_params(self) -> None:
        """Register the learnable Cauchy log-scale if the noise model needs it."""
        if self.noise == "Cauchy" and not hasattr(self, "log_gamma"):
            self.log_gamma = nn.Parameter(torch.zeros(1, device=self.device))

    def dataframe(
        self, x: np.ndarray, y: np.ndarray
    ) -> "Tuple[List[str], torch.Tensor]":
        """Decode one-hot data back to sequences and wrap targets as a tensor.

        MAVE-NN wanted a pandas DataFrame. Here we simply return *seq_list*
        (for inspection) and a float32 PyTorch tensor built from `y`.

        Args:
            x: one-hot encoded sequences, last axis is the alphabet axis
            y: targets

        Returns:
            Tuple of the decoded sequence strings and the target tensor.
        """
        # Invert the mapping once; decoding by id *value* (rather than by dict
        # key position) stays correct when ids are not in insertion order.
        id_to_token = {idx: token for token, idx in self.token_to_id.items()}

        seq_ids = torch.argmax(torch.tensor(x, dtype=torch.long), dim=-1)
        seq_list = [
            "".join(id_to_token[idx] for idx in row) for row in seq_ids.tolist()
        ]
        y_tensor = torch.tensor(y, dtype=torch.float32)
        return seq_list, y_tensor

    def fit(
        self,
        x: np.ndarray,
        y: np.ndarray,
        learning_rate: float = 5e-4,
        epochs: int = 500,
        batch_size: int = 128,
        early_stopping: bool = True,
        patience: int = 25,
        save_dir: Optional[str] = None,
        verbose: int = 1,
    ) -> Tuple[nn.Module, List[str]]:
        """End-to-end training loop with a 60/20/20 train/val/test split.

        Args:
            x: one-hot encoded sequences, shape (N, L, A)
            y: targets with N rows
            learning_rate: Adam learning rate
            epochs: maximum number of epochs
            batch_size: mini-batch size for all loaders
            early_stopping: stop when the validation loss has not improved by
                more than 1e-4 for `patience` epochs and restore the best
                weights
            patience: early-stopping patience in epochs
            save_dir: if given, the state dict is written to
                `<save_dir>/surrogate_mavenn_pt.pt`
            verbose: print progress every 10 epochs when truthy

        Returns:
            The trained *nn.Module* and the sequence list (for external
            evaluation if needed).

        Raises:
            ValueError: If the data leaves no sample for the training split.
        """
        x = np.asarray(x)
        y = np.asarray(y)

        # Convert to tensors
        x_tensor = torch.tensor(x, dtype=torch.float32)  # (N,L,A)
        y_tensor = torch.tensor(y, dtype=torch.float32)

        # Deduplicate if requested. Use the row count of the data actually
        # passed in rather than the N recorded at construction time.
        if self.deduplicate:
            n_rows = x.shape[0]
            flat = np.concatenate(
                [x.reshape(n_rows, -1), y.reshape(n_rows, -1)], axis=1
            )
            _, indices = np.unique(flat, axis=0, return_index=True)
            x_tensor = x_tensor[indices]
            y_tensor = y_tensor[indices]

        # Build dataset & split
        dataset = TensorDataset(x_tensor, y_tensor)
        n_total = len(dataset)
        n_train = int(0.6 * n_total)
        n_val = int(0.2 * n_total)
        n_test = n_total - n_train - n_val
        if n_train == 0:
            raise ValueError(
                f"Not enough samples to train the surrogate model: {n_total}"
            )
        self.train_set, self.val_set, self.test_set = random_split(
            dataset,
            [n_train, n_val, n_test],
            generator=torch.Generator().manual_seed(
                self.seed if self.seed is not None else 42
            ),
        )
        train_loader = DataLoader(self.train_set, batch_size=batch_size, shuffle=True)
        val_loader = DataLoader(self.val_set, batch_size=batch_size)
        self.test_loader = DataLoader(self.test_set, batch_size=batch_size)

        # Noise parameters must exist before the optimizer collects parameters.
        self._ensure_noise_params()

        # Optimiser
        optimizer = torch.optim.Adam(self.parameters(), lr=learning_rate)

        best_val_loss = math.inf
        best_state = None
        epochs_without_improve = 0
        self.train_history = {"loss": [], "val_loss": []}

        for epoch in range(1, epochs + 1):
            self.train()
            total_loss = 0.0
            for xb, yb in train_loader:
                xb, yb = xb.to(self.device), yb.to(self.device)
                optimizer.zero_grad()
                loss = self._nll(xb, yb)
                (loss + self.gpmap.l2_regularizer()).backward()
                optimizer.step()
                total_loss += loss.item() * xb.size(0)
            avg_loss = total_loss / n_train

            # Validation
            self.eval()
            if n_val > 0:
                val_total = 0.0
                with torch.no_grad():
                    for xb, yb in val_loader:
                        xb, yb = xb.to(self.device), yb.to(self.device)
                        val_total += self._nll(xb, yb).item() * xb.size(0)
                val_loss = val_total / n_val
            else:
                # Too few samples for a validation split: monitor the
                # training loss instead of dividing by zero.
                val_loss = avg_loss

            self.train_history["loss"].append(avg_loss)
            self.train_history["val_loss"].append(val_loss)

            if verbose and epoch % 10 == 0:
                print(
                    f"Epoch {epoch:4d} | loss={avg_loss:.4f} | val_loss={val_loss:.4f}"
                )

            # Early stopping
            if early_stopping:
                if val_loss < best_val_loss - 1e-4:  # significant improvement
                    best_val_loss = val_loss
                    epochs_without_improve = 0
                    # state_dict() returns references to the live tensors, so
                    # snapshot by cloning; otherwise later optimizer steps
                    # would silently overwrite the "best" weights.
                    best_state = {
                        name: tensor.detach().clone()
                        for name, tensor in self.state_dict().items()
                    }
                else:
                    epochs_without_improve += 1
                    if epochs_without_improve >= patience:
                        if verbose:
                            print("Early stopping triggered.")
                        break

        # Restore best weights if early stopping recorded a snapshot
        if early_stopping and best_state is not None:
            self.load_state_dict(best_state)

        # Optionally save
        if save_dir is not None:
            os.makedirs(save_dir, exist_ok=True)
            torch.save(
                self.state_dict(), os.path.join(save_dir, "surrogate_mavenn_pt.pt")
            )

        seq_list, _ = self.dataframe(x, y)
        return self, seq_list

    def _nll(self, xb: torch.Tensor, yb: torch.Tensor) -> torch.Tensor:
        """Negative log-likelihood / loss depending on regression type and
        noise model.

        Raises:
            NotImplementedError: For GE regression with a noise model other
                than 'Gaussian' or 'Cauchy'.
        """
        preds = self._predict_continuous(xb)  # (N,num_tasks) continuous

        if self.regression_type == "GE":
            if self.noise == "Gaussian":
                return F.mse_loss(preds, yb)
            elif self.noise == "Cauchy":
                # −log p(y|φ,γ) where p is Cauchy with scale γ. We learn a
                # global logγ parameter (normally registered in __init__/fit).
                self._ensure_noise_params()
                gamma = torch.exp(self.log_gamma)
                return (
                    torch.log(math.pi * gamma)
                    + torch.log(1 + ((yb - preds) / gamma) ** 2).mean()
                )
            else:
                raise NotImplementedError("Only Gaussian or Cauchy noise implemented.")
        else:  # MPA – classification over *num_tasks* categories
            return F.cross_entropy(preds, yb.long())

    def _predict_continuous(self, xb: torch.Tensor) -> torch.Tensor:
        """Forward pass returning *continuous* outputs: GP-map, optional GE
        non-linearity, then the task-specific linear head. Used by both GE and
        MPA (where logits are derived from the continuous φ)."""
        xb = xb.to(self.device)
        z = self.gpmap(xb)  # (N,1)
        z_nl = self.nonlinearity(z)  # (N,1)
        out = self.head(z_nl)  # (N,num_tasks)
        return out

    @torch.no_grad()
    def get_info(self, verbose: int = 1) -> float:
        """Compute a *heuristic* predictive information metric on the test
        split created by `fit`.

        For GE we return the test-set R². For MPA we return accuracy. This
        does *not* attempt to reproduce the variational information bound used
        by MAVE-NN, but is often a useful quick-and-dirty proxy.
        """
        self.eval()
        all_preds, all_y = [], []
        for xb, yb in self.test_loader:
            xb, yb = xb.to(self.device), yb.to(self.device)
            preds = self._predict_continuous(xb)
            all_preds.append(preds.cpu())
            all_y.append(yb.cpu())
        y_true = torch.cat(all_y, dim=0)
        y_pred = torch.cat(all_preds, dim=0)

        if self.regression_type == "GE":
            ss_res = ((y_true - y_pred) ** 2).sum()
            ss_tot = ((y_true - y_true.mean()) ** 2).sum()
            r2 = 1.0 - ss_res / ss_tot
            if verbose:
                print(f"Test R²: {r2:.4f}")
            return r2.item()
        else:
            # classification accuracy
            acc = (y_pred.argmax(dim=1) == y_true.long()).float().mean()
            if verbose:
                print(f"Test accuracy: {acc:.4f}")
            return acc.item()

    def get_params(self) -> Tuple[np.ndarray, np.ndarray, Optional[np.ndarray]]:
        """Return (theta_0, theta_lc, theta_lclc).

        Pairwise parameters are only returned if gpmap_type == 'pairwise';
        otherwise the third element is None.
        """
        theta_0 = self.gpmap.theta_0.detach().cpu().numpy()
        theta_lc = self.gpmap.theta_lc.detach().cpu().numpy()
        theta_lclc = None
        if self.gpmap_type == "pairwise":
            # Reconstruct full L×A×L×A tensor from low-rank factors:
            # θ_{l1,c1,l2,c2} ≈ Σ_k U_{l1,c1,k} U_{l2,c2,k}
            U = self.gpmap.U.detach().cpu().numpy()  # (L,A,K)
            theta_lclc = np.einsum("lak, mbk -> lamb", U, U)  # shape (L,A,L,A)
        return theta_0, theta_lc, theta_lclc