from __future__ import annotations
import os
import json
from typing import Optional
import torch
import torch.nn as nn
from dataclasses import dataclass, field
import warnings
from ..abc.abstract_model import OmniModel
__all__ = [
"OmniCNNBaseline",
"OmniRNNBaseline",
"OmniBPNetBaseline",
"OmniBasenjiBaseline",
"OmniDeepSTARRBaseline",
"OmniGenericBaseline",
"create_baseline",
]
# ---------------- Utility -----------------
class _MaskedGlobalMaxPool1d(nn.Module):
    """Global max-pooling over the sequence axis that can ignore padded positions."""

    def forward(
        self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Masked global max-pooling over the sequence dimension.

        Parameters
        ----------
        x : torch.Tensor
            Channels-last features of shape ``(batch_size, seq_len, hidden_size)``.
            Pooling reduces dimension 1, so callers holding a
            ``(batch_size, channels, seq_len)`` tensor must transpose it first.
        attention_mask : torch.Tensor, optional
            Mask of shape ``(batch_size, seq_len)``; entries equal to 0 mark
            padding, any other value marks a valid position.

        Returns
        -------
        torch.Tensor
            Pooled tensor of shape ``(batch_size, hidden_size)``.

        Notes
        -----
        - If ``attention_mask`` is ``None``, a plain ``max(dim=1)`` is applied.
        - Otherwise padded positions are filled with ``-inf`` before the max.
        - A row without any valid position pools to zeros instead of ``-inf``,
          so that downstream linear layers do not produce ``inf``/``NaN``.
        """
        if attention_mask is None:
            return x.max(dim=1).values

        is_padding = attention_mask.unsqueeze(-1).eq(0)
        pooled = x.masked_fill(is_padding, float("-inf")).max(dim=1).values

        # Rows that consist only of padding have nothing to pool; replace the
        # resulting -inf values by a neutral 0.
        has_valid = attention_mask.ne(0).any(dim=1, keepdim=True)
        return pooled.masked_fill(~has_valid, 0.0)
# ---------------- Simple Baselines with Heads (legacy style) -----------------
class OmniCNNBaseline(OmniModel):
    """A simple 1D-CNN baseline with global max pooling for multi-label tasks.

    This legacy-style model builds an embedding layer followed by multiple
    convolutional filters with kernel sizes specified in ``kernel_sizes``,
    concatenates the features, pools them, and applies a linear classifier
    with a sigmoid for multi-label probabilities.

    Parameters
    ----------
    tokenizer : Any
        Tokenizer providing ``vocab_size`` or ``get_vocab()`` and ``pad_token_id``.
    num_labels : int
        Number of output labels.
    embed_dim : int, optional
        Token embedding dimension, by default 128.
    num_filters : int, optional
        Number of output channels per convolutional filter, by default 128.
    kernel_sizes : tuple[int, ...], optional
        Convolution kernel sizes, by default ``(3, 5, 7)``.
    dropout : float, optional
        Dropout probability, by default 0.1.

    Inputs
    ------
    input_ids : torch.LongTensor of shape ``(batch_size, seq_len)``
    attention_mask : torch.LongTensor of shape ``(batch_size, seq_len)``, optional
    labels : torch.FloatTensor of shape ``(batch_size, num_labels)``, optional

    Outputs
    -------
    dict
        - ``logits``: ``(batch_size, num_labels)`` in ``[0,1]`` after sigmoid.
        - ``last_hidden_state``: pooled hidden vector ``(batch_size, hidden_size)``.
        - ``labels``: passthrough of input labels if provided.

    Notes
    -----
    - Loss used is BCE (``nn.BCELoss``) expecting labels in ``{0,1}`` floats.
    - See :meth:`predict` and :meth:`inference` for convenience wrappers.

    Pseudocode
    ----------
    .. code-block:: python

        x = Embedding(input_ids)
        x = Dropout(x)
        feats = [Conv1D_k(ReLU)(x_T) for k in kernel_sizes]
        feats = concat(feats, dim=channels).T
        pooled = masked_global_max_pool(feats, attention_mask)
        logits = sigmoid(Linear(pooled))
    """

    def __init__(self, tokenizer, *args, **kwargs):
        """Build the embedding, the convolution branches and the classifier.

        The baseline-specific options (``embed_dim``, ``num_filters``,
        ``kernel_sizes``, ``dropout``, ``num_labels``) are removed from
        ``kwargs``; everything else is forwarded to ``OmniModel.__init__``.

        Raises
        ------
        KeyError
            If ``num_labels`` is not supplied.
        ValueError
            If ``kernel_sizes`` is empty.
        """
        embed_dim = kwargs.pop("embed_dim", 128)
        num_filters = kwargs.pop("num_filters", 128)
        kernel_sizes = tuple(kwargs.pop("kernel_sizes", (3, 5, 7)))
        dropout = kwargs.pop("dropout", 0.1)
        num_labels = kwargs.pop("num_labels")
        if not kernel_sizes:
            raise ValueError("kernel_sizes must contain at least one kernel size.")

        hidden_size = num_filters * len(kernel_sizes)

        # Plain attribute bag that plays the role of a model config.
        class Cfg: ...

        cfg = Cfg()
        cfg.hidden_size = hidden_size
        cfg.num_labels = num_labels
        cfg.label2id = {str(i): i for i in range(num_labels)}
        cfg.id2label = {i: str(i) for i in range(num_labels)}
        cfg.name_or_path = "CNNBaseline"
        cfg.model_type = "cnn"
        cfg.architectures = ["CNNBaseline"]
        cfg.pad_token_id = getattr(tokenizer, "pad_token_id", -100)
        cfg.embed_dim = embed_dim
        cfg.num_filters = num_filters
        cfg.kernel_sizes = list(kernel_sizes)
        cfg.dropout = dropout

        # Parameter-free module passed to OmniModel.__init__ as its first
        # argument; it only carries the config and an empty buffer from which
        # ``device`` / ``dtype`` are reported.
        # NOTE(review): presumably this stands in for a backbone model that
        # OmniModel expects — confirm against OmniModel.
        class _Stub(nn.Module):
            def __init__(self, config):
                super().__init__()
                self.config = config
                self.register_buffer("_dev_tracker", torch.empty(0))

            @property
            def device(self):
                return self._dev_tracker.device

            @property
            def dtype(self):
                return self._dev_tracker.dtype

        super().__init__(_Stub(cfg), tokenizer, *args, num_labels=num_labels, **kwargs)

        vocab_size = getattr(self.tokenizer, "vocab_size", None) or len(
            self.tokenizer.get_vocab()
        )
        padding_idx = self._resolve_padding_idx(self.pad_token_id, vocab_size)

        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=padding_idx)
        # padding="same" equals ``k // 2`` for odd kernels and additionally keeps
        # the sequence length for even kernels, so all branches can be
        # concatenated and stay aligned with the attention mask.
        self.convs = nn.ModuleList(
            [
                nn.Sequential(
                    nn.Conv1d(embed_dim, num_filters, k, padding="same"), nn.ReLU()
                )
                for k in kernel_sizes
            ]
        )
        self.pool = _MaskedGlobalMaxPool1d()
        self.classifier = nn.Linear(hidden_size, num_labels)
        self.sigmoid = nn.Sigmoid()
        self.dropout = nn.Dropout(dropout)
        self.loss_fn = nn.BCELoss()

    @staticmethod
    def _resolve_padding_idx(pad_id, vocab_size):
        """Return ``pad_id`` as a valid ``nn.Embedding`` padding index, else ``None``.

        Tensors are converted to ``int``. Anything that is not an integer inside
        ``[0, vocab_size)`` (for instance the ``-100`` sentinel) yields ``None``
        because ``nn.Embedding`` would reject it or mark an unrelated row as padding.
        """
        if isinstance(pad_id, torch.Tensor):
            pad_id = int(pad_id.item())
        if not isinstance(pad_id, int):
            return None
        if not 0 <= pad_id < vocab_size:
            return None
        return pad_id

    @staticmethod
    def _cfg_value(cfg_dict, key, default):
        """Read ``key`` from ``cfg_dict`` treating a stored ``null`` like a missing key."""
        value = cfg_dict.get(key)
        return default if value is None else value

    @staticmethod
    def _load_saved_tokenizer(save_directory, map_location=None):
        """Best-effort recovery of the tokenizer stored in ``save_directory``.

        Tries ``transformers.AutoTokenizer`` when Hugging Face tokenizer files
        are present, then falls back to ``tokenizer.bin``. Returns ``None`` when
        neither works.
        """
        tokenizer = None
        hf_files = ("tokenizer_config.json", "vocab.json")
        if any(os.path.exists(os.path.join(save_directory, n)) for n in hf_files):
            try:
                from transformers import AutoTokenizer

                tokenizer = AutoTokenizer.from_pretrained(save_directory)
            except Exception as exc:  # best-effort: fall through to tokenizer.bin
                warnings.warn(
                    f"AutoTokenizer could not load '{save_directory}': {exc}",
                    RuntimeWarning,
                    stacklevel=3,
                )
                tokenizer = None

        pickled_path = os.path.join(save_directory, "tokenizer.bin")
        if tokenizer is None and os.path.exists(pickled_path):
            # NOTE(review): tokenizer.bin is a pickled Python object; only load
            # checkpoints from trusted sources. Recent PyTorch releases default
            # ``torch.load`` to ``weights_only=True`` — confirm this call still
            # succeeds on the supported PyTorch versions.
            with open(pickled_path, "rb") as f:
                tokenizer = torch.load(f, map_location=map_location)
        return tokenizer

    def _build_config_dict(self):
        """Collect the JSON-serialisable configuration written to ``config.json``."""
        return {
            "model_type": self.config.model_type,
            "architectures": self.config.architectures,
            "num_labels": self.config.num_labels,
            "label2id": self.config.label2id,
            "id2label": self.config.id2label,
            "pad_token_id": self.config.pad_token_id,
            "hidden_size": self.config.hidden_size,
            "embed_dim": getattr(self.config, "embed_dim", None),
            "num_filters": getattr(self.config, "num_filters", None),
            "kernel_sizes": getattr(self.config, "kernel_sizes", None),
            "dropout": getattr(self.config, "dropout", None),
            "vocab_size": self.embedding.num_embeddings,
            "model_cls": self.__class__.__name__,
            "library_name": "OMNIGENBENCH",
        }

    def save_pretrained(self, save_directory: str, overwrite: bool = True):
        """Write config, tokenizer, weights and metadata into ``save_directory``.

        Files written: ``config.json``, the tokenizer (through its own
        ``save_pretrained`` when available, otherwise ``tokenizer.bin``),
        ``pytorch_model.bin`` and ``metadata.json``.

        Parameters
        ----------
        save_directory : str
            Target directory; created if it does not exist.
        overwrite : bool, optional
            Accepted for interface compatibility; this implementation does not
            consult it and always overwrites existing files.
        """
        os.makedirs(save_directory, exist_ok=True)

        config_path = os.path.join(save_directory, "config.json")
        with open(config_path, "w", encoding="utf8") as f:
            json.dump(self._build_config_dict(), f, ensure_ascii=False, indent=2)

        if hasattr(self.tokenizer, "save_pretrained"):
            self.tokenizer.save_pretrained(save_directory)
        else:
            with open(os.path.join(save_directory, "tokenizer.bin"), "wb") as f:
                torch.save(self.tokenizer, f)

        torch.save(self.state_dict(), os.path.join(save_directory, "pytorch_model.bin"))

        metadata = getattr(self, "metadata", None)
        if metadata is None:
            metadata = {}
        metadata["model_cls"] = self.__class__.__name__
        metadata["library_name"] = metadata.get("library_name", "OMNIGENBENCH")
        with open(
            os.path.join(save_directory, "metadata.json"), "w", encoding="utf8"
        ) as f:
            json.dump(metadata, f, ensure_ascii=False, indent=2)

    @classmethod
    def from_pretrained(
        cls, save_directory: str, tokenizer=None, map_location=None, **kwargs
    ):
        """Rebuild a model from a directory written by :meth:`save_pretrained`.

        Parameters
        ----------
        save_directory : str
            Directory containing ``config.json`` and ``pytorch_model.bin``.
        tokenizer : Any, optional
            Tokenizer to use; when ``None`` it is loaded from ``save_directory``.
        map_location : optional
            Forwarded to ``torch.load``; the weights default to ``"cpu"``.
        **kwargs
            Extra keyword arguments forwarded to the constructor.

        Raises
        ------
        ValueError
            If no tokenizer was given and none could be loaded.
        """
        with open(
            os.path.join(save_directory, "config.json"), "r", encoding="utf8"
        ) as f:
            cfg_dict = json.load(f)

        if tokenizer is None:
            tokenizer = cls._load_saved_tokenizer(save_directory, map_location)
        if tokenizer is None:
            raise ValueError("Tokenizer could not be loaded; please provide one.")

        # Legacy fallback chain: embed_dim -> hidden_size -> constructor default.
        embed_dim = cls._cfg_value(
            cfg_dict, "embed_dim", cls._cfg_value(cfg_dict, "hidden_size", 128)
        )
        model = cls(
            tokenizer,
            embed_dim=embed_dim,
            num_filters=cls._cfg_value(cfg_dict, "num_filters", 128),
            kernel_sizes=tuple(cls._cfg_value(cfg_dict, "kernel_sizes", (3, 5, 7))),
            dropout=cls._cfg_value(cfg_dict, "dropout", 0.1),
            num_labels=cfg_dict["num_labels"],
            label2id=cfg_dict.get("label2id"),
            **kwargs,
        )

        state_dict = torch.load(
            os.path.join(save_directory, "pytorch_model.bin"),
            map_location=map_location or "cpu",
        )
        load_result = model.load_state_dict(state_dict, strict=False)
        # strict=False tolerates key mismatches; surface them instead of
        # silently running with partially initialised weights.
        if load_result.missing_keys or load_result.unexpected_keys:
            warnings.warn(
                f"State dict mismatch while loading '{save_directory}': "
                f"missing={list(load_result.missing_keys)}, "
                f"unexpected={list(load_result.unexpected_keys)}",
                RuntimeWarning,
                stacklevel=2,
            )

        meta_path = os.path.join(save_directory, "metadata.json")
        if os.path.exists(meta_path):
            with open(meta_path, "r", encoding="utf8") as f:
                model.metadata = json.load(f)
        return model

    def forward(self, **inputs):
        """Forward pass.

        Parameters
        ----------
        input_ids : torch.LongTensor
            Token ids of shape ``(batch_size, seq_len)``.
        attention_mask : torch.LongTensor, optional
            Mask of shape ``(batch_size, seq_len)`` with 1 for valid tokens.
        labels : torch.FloatTensor, optional
            Multi-hot label matrix ``(batch_size, num_labels)``.

        Returns
        -------
        dict
            Dictionary with ``logits``, ``last_hidden_state``, and optional ``labels``.
        """
        labels = inputs.pop("labels", None)
        attention_mask = inputs.get("attention_mask")

        embedded = self.dropout(self.embedding(inputs["input_ids"]))
        # Conv1d expects channels-first input: (batch, embed_dim, seq_len).
        channels_first = embedded.transpose(1, 2)
        feats = torch.cat([conv(channels_first) for conv in self.convs], dim=1)
        # Back to channels-last for the masked pooling over the sequence axis.
        pooled = self.pool(feats.transpose(1, 2), attention_mask)
        pooled = self.dropout(pooled)
        logits = self.sigmoid(self.classifier(pooled))
        return {"logits": logits, "last_hidden_state": pooled, "labels": labels}

    def predict(self, sequence_or_inputs, **kwargs):
        """Return probabilities for each label.

        This calls ``_forward_from_raw_input`` with the given sequences or
        inputs and returns the resulting probabilities.

        Returns
        -------
        dict
            Keys: ``predictions`` (alias of ``logits``), ``logits``, ``last_hidden_state``.
        """
        out = self._forward_from_raw_input(sequence_or_inputs, **kwargs)
        return {
            "predictions": out["logits"],
            "logits": out["logits"],
            "last_hidden_state": out["last_hidden_state"],
        }

    def inference(self, sequence_or_inputs, threshold: float = 0.5, **kwargs):
        """Return binary predictions with a threshold.

        Parameters
        ----------
        threshold : float, optional
            Decision threshold in ``[0,1]`` applied to probabilities, by default 0.5.

        Returns
        -------
        dict
            ``predictions``, ``logits``, ``confidence`` (maximum probability) and
            ``last_hidden_state``, all moved to CPU. When ``sequence_or_inputs``
            is not a ``list`` only the first batch element is returned.
        """
        out = self._forward_from_raw_input(sequence_or_inputs, **kwargs)
        logits = out["logits"]
        hidden = out["last_hidden_state"]
        preds = (logits >= threshold).to(torch.int)

        if not isinstance(sequence_or_inputs, list):
            return {
                "predictions": preds[0].cpu(),
                "logits": logits[0].cpu(),
                "confidence": torch.max(logits[0]).cpu(),
                "last_hidden_state": hidden[0].cpu(),
            }
        return {
            "predictions": preds.cpu(),
            "logits": logits.cpu(),
            "confidence": torch.max(logits, dim=-1)[0].cpu(),
            "last_hidden_state": hidden.cpu(),
        }

    def loss_function(self, logits, labels):
        """Binary cross-entropy loss with logits already in probability space.

        Notes
        -----
        This legacy baseline uses ``BCELoss`` assuming inputs were passed through
        a sigmoid already. Prefer ``BCEWithLogitsLoss`` for numerical stability
        in new code.
        """
        return self.loss_fn(logits.view(-1), labels.view(-1).to(torch.float32))
# ---------------- RNN baseline -----------------
class OmniRNNBaseline(OmniModel):
    """A simple BiLSTM baseline for sequence modeling.

    Embeds tokens, applies a multi-layer LSTM (optionally bidirectional), then
    mean-pools over valid tokens and classifies with a sigmoid layer for
    multi-label probabilities.

    Parameters
    ----------
    tokenizer : Any
        Tokenizer with ``vocab_size`` and ``pad_token_id``.
    num_labels : int
        Number of output labels.
    embed_dim : int, optional
        Embedding dimension, by default 128.
    hidden_dim : int, optional
        LSTM hidden size per direction, by default 256.
    num_layers : int, optional
        Number of LSTM layers, by default 1.
    bidirectional : bool, optional
        Whether to use a bidirectional LSTM, by default True.
    dropout : float, optional
        Dropout probability, by default 0.1.

    Outputs
    -------
    dict
        - ``logits``: probabilities after sigmoid ``(batch_size, num_labels)``.
        - ``last_hidden_state``: pooled vector ``(batch_size, hidden_size)``.
        - ``labels``: passthrough if provided.

    Pseudocode
    ----------
    .. code-block:: python

        x = Embedding(input_ids)
        x = Dropout(x)
        seq_out, _ = LSTM(x)
        pooled = masked_mean(seq_out, attention_mask)
        logits = sigmoid(Linear(pooled))
    """

    def __init__(self, tokenizer, *args, **kwargs):
        """Build the embedding, the LSTM and the classifier.

        The baseline-specific options (``embed_dim``, ``hidden_dim``,
        ``num_layers``, ``bidirectional``, ``dropout``, ``num_labels``) are
        removed from ``kwargs``; everything else is forwarded to
        ``OmniModel.__init__``.

        Raises
        ------
        KeyError
            If ``num_labels`` is not supplied.
        """
        embed_dim = kwargs.pop("embed_dim", 128)
        hidden_dim = kwargs.pop("hidden_dim", 256)
        num_layers = kwargs.pop("num_layers", 1)
        bidirectional = kwargs.pop("bidirectional", True)
        dropout = kwargs.pop("dropout", 0.1)
        num_labels = kwargs.pop("num_labels")

        # Both directions are concatenated in the LSTM output.
        hidden_size = hidden_dim * (2 if bidirectional else 1)

        # Plain attribute bag that plays the role of a model config.
        class Cfg: ...

        cfg = Cfg()
        cfg.hidden_size = hidden_size
        cfg.num_labels = num_labels
        cfg.label2id = {str(i): i for i in range(num_labels)}
        cfg.id2label = {i: str(i) for i in range(num_labels)}
        cfg.name_or_path = "RNNBaseline"
        cfg.model_type = "rnn"
        cfg.architectures = ["RNNBaseline"]
        cfg.pad_token_id = getattr(tokenizer, "pad_token_id", -100)
        cfg.embed_dim = embed_dim
        cfg.hidden_dim = hidden_dim
        cfg.num_layers = num_layers
        cfg.bidirectional = bidirectional
        cfg.dropout = dropout

        # Parameter-free module passed to OmniModel.__init__ as its first
        # argument; it only carries the config and an empty buffer from which
        # ``device`` / ``dtype`` are reported.
        # NOTE(review): presumably this stands in for a backbone model that
        # OmniModel expects — confirm against OmniModel.
        class _Stub(nn.Module):
            def __init__(self, config):
                super().__init__()
                self.config = config
                self.register_buffer("_dev_tracker", torch.empty(0))

            @property
            def device(self):
                return self._dev_tracker.device

            @property
            def dtype(self):
                return self._dev_tracker.dtype

        super().__init__(_Stub(cfg), tokenizer, *args, num_labels=num_labels, **kwargs)

        vocab_size = getattr(self.tokenizer, "vocab_size", None) or len(
            self.tokenizer.get_vocab()
        )
        padding_idx = self._resolve_padding_idx(self.pad_token_id, vocab_size)

        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=padding_idx)
        self.lstm = nn.LSTM(
            embed_dim,
            hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            # Inter-layer dropout is only passed when there is more than one layer.
            dropout=0.0 if num_layers == 1 else dropout,
            bidirectional=bidirectional,
        )
        self.classifier = nn.Linear(hidden_size, num_labels)
        self.sigmoid = nn.Sigmoid()
        self.dropout = nn.Dropout(dropout)
        self.loss_fn = nn.BCELoss()

    @staticmethod
    def _resolve_padding_idx(pad_id, vocab_size):
        """Return ``pad_id`` as a valid ``nn.Embedding`` padding index, else ``None``.

        Tensors are converted to ``int``. Anything that is not an integer inside
        ``[0, vocab_size)`` (for instance the ``-100`` sentinel) yields ``None``
        because ``nn.Embedding`` would reject it or mark an unrelated row as padding.
        """
        if isinstance(pad_id, torch.Tensor):
            pad_id = int(pad_id.item())
        if not isinstance(pad_id, int):
            return None
        if not 0 <= pad_id < vocab_size:
            return None
        return pad_id

    @staticmethod
    def _cfg_value(cfg_dict, key, default):
        """Read ``key`` from ``cfg_dict`` treating a stored ``null`` like a missing key."""
        value = cfg_dict.get(key)
        return default if value is None else value

    @staticmethod
    def _load_saved_tokenizer(save_directory, map_location=None):
        """Best-effort recovery of the tokenizer stored in ``save_directory``.

        Tries ``transformers.AutoTokenizer`` when Hugging Face tokenizer files
        are present, then falls back to ``tokenizer.bin``. Returns ``None`` when
        neither works.
        """
        tokenizer = None
        hf_files = ("tokenizer_config.json", "vocab.json")
        if any(os.path.exists(os.path.join(save_directory, n)) for n in hf_files):
            try:
                from transformers import AutoTokenizer

                tokenizer = AutoTokenizer.from_pretrained(save_directory)
            except Exception as exc:  # best-effort: fall through to tokenizer.bin
                warnings.warn(
                    f"AutoTokenizer could not load '{save_directory}': {exc}",
                    RuntimeWarning,
                    stacklevel=3,
                )
                tokenizer = None

        pickled_path = os.path.join(save_directory, "tokenizer.bin")
        if tokenizer is None and os.path.exists(pickled_path):
            # NOTE(review): tokenizer.bin is a pickled Python object; only load
            # checkpoints from trusted sources. Recent PyTorch releases default
            # ``torch.load`` to ``weights_only=True`` — confirm this call still
            # succeeds on the supported PyTorch versions.
            with open(pickled_path, "rb") as f:
                tokenizer = torch.load(f, map_location=map_location)
        return tokenizer

    def _build_config_dict(self):
        """Collect the JSON-serialisable configuration written to ``config.json``."""
        return {
            "model_type": self.config.model_type,
            "architectures": self.config.architectures,
            "num_labels": self.config.num_labels,
            "label2id": self.config.label2id,
            "id2label": self.config.id2label,
            "pad_token_id": self.config.pad_token_id,
            "hidden_size": self.config.hidden_size,
            "embed_dim": getattr(self.config, "embed_dim", None),
            "hidden_dim": getattr(self.config, "hidden_dim", None),
            "num_layers": getattr(self.config, "num_layers", None),
            "bidirectional": getattr(self.config, "bidirectional", None),
            "dropout": getattr(self.config, "dropout", None),
            "vocab_size": self.embedding.num_embeddings,
            "model_cls": self.__class__.__name__,
            "library_name": "OMNIGENBENCH",
        }

    def save_pretrained(self, save_directory: str, overwrite: bool = True):
        """Write config, tokenizer, weights and metadata into ``save_directory``.

        Files written: ``config.json``, the tokenizer (through its own
        ``save_pretrained`` when available, otherwise ``tokenizer.bin``),
        ``pytorch_model.bin`` and ``metadata.json``.

        Parameters
        ----------
        save_directory : str
            Target directory; created if it does not exist.
        overwrite : bool, optional
            Accepted for interface compatibility; this implementation does not
            consult it and always overwrites existing files.
        """
        os.makedirs(save_directory, exist_ok=True)

        config_path = os.path.join(save_directory, "config.json")
        with open(config_path, "w", encoding="utf8") as f:
            json.dump(self._build_config_dict(), f, ensure_ascii=False, indent=2)

        if hasattr(self.tokenizer, "save_pretrained"):
            self.tokenizer.save_pretrained(save_directory)
        else:
            with open(os.path.join(save_directory, "tokenizer.bin"), "wb") as f:
                torch.save(self.tokenizer, f)

        torch.save(self.state_dict(), os.path.join(save_directory, "pytorch_model.bin"))

        metadata = getattr(self, "metadata", None)
        if metadata is None:
            metadata = {}
        metadata["model_cls"] = self.__class__.__name__
        metadata["library_name"] = metadata.get("library_name", "OMNIGENBENCH")
        with open(
            os.path.join(save_directory, "metadata.json"), "w", encoding="utf8"
        ) as f:
            json.dump(metadata, f, ensure_ascii=False, indent=2)

    @classmethod
    def from_pretrained(
        cls, save_directory: str, tokenizer=None, map_location=None, **kwargs
    ):
        """Rebuild a model from a directory written by :meth:`save_pretrained`.

        Parameters
        ----------
        save_directory : str
            Directory containing ``config.json`` and ``pytorch_model.bin``.
        tokenizer : Any, optional
            Tokenizer to use; when ``None`` it is loaded from ``save_directory``.
        map_location : optional
            Forwarded to ``torch.load``; the weights default to ``"cpu"``.
        **kwargs
            Extra keyword arguments forwarded to the constructor.

        Raises
        ------
        ValueError
            If no tokenizer was given and none could be loaded.
        """
        with open(
            os.path.join(save_directory, "config.json"), "r", encoding="utf8"
        ) as f:
            cfg_dict = json.load(f)

        if tokenizer is None:
            tokenizer = cls._load_saved_tokenizer(save_directory, map_location)
        if tokenizer is None:
            raise ValueError("Tokenizer could not be loaded; please provide one.")

        # Legacy fallback chain: embed_dim -> hidden_size -> constructor default.
        embed_dim = cls._cfg_value(
            cfg_dict, "embed_dim", cls._cfg_value(cfg_dict, "hidden_size", 128)
        )
        model = cls(
            tokenizer,
            embed_dim=embed_dim,
            hidden_dim=cls._cfg_value(cfg_dict, "hidden_dim", 256),
            num_layers=cls._cfg_value(cfg_dict, "num_layers", 1),
            bidirectional=cls._cfg_value(cfg_dict, "bidirectional", True),
            dropout=cls._cfg_value(cfg_dict, "dropout", 0.1),
            num_labels=cfg_dict["num_labels"],
            label2id=cfg_dict.get("label2id"),
            **kwargs,
        )

        state_dict = torch.load(
            os.path.join(save_directory, "pytorch_model.bin"),
            map_location=map_location or "cpu",
        )
        load_result = model.load_state_dict(state_dict, strict=False)
        # strict=False tolerates key mismatches; surface them instead of
        # silently running with partially initialised weights.
        if load_result.missing_keys or load_result.unexpected_keys:
            warnings.warn(
                f"State dict mismatch while loading '{save_directory}': "
                f"missing={list(load_result.missing_keys)}, "
                f"unexpected={list(load_result.unexpected_keys)}",
                RuntimeWarning,
                stacklevel=2,
            )

        meta_path = os.path.join(save_directory, "metadata.json")
        if os.path.exists(meta_path):
            with open(meta_path, "r", encoding="utf8") as f:
                model.metadata = json.load(f)
        return model

    def forward(self, **inputs):
        """Forward pass producing multi-label probabilities and hidden state.

        Reads ``input_ids`` (required) plus optional ``attention_mask`` and
        ``labels`` from the keyword arguments.

        Returns
        -------
        dict
            Dictionary with keys ``logits``, ``last_hidden_state``, optional ``labels``.
        """
        labels = inputs.pop("labels", None)
        attention_mask = inputs.get("attention_mask")

        embedded = self.dropout(self.embedding(inputs["input_ids"]))
        seq_out, _ = self.lstm(embedded)
        # NOTE(review): ``_mask_mean_pool`` is not defined in this class;
        # presumably inherited from OmniModel — confirm.
        pooled = self._mask_mean_pool(seq_out, attention_mask)
        pooled = self.dropout(pooled)
        logits = self.sigmoid(self.classifier(pooled))
        return {"logits": logits, "last_hidden_state": pooled, "labels": labels}

    def predict(self, sequence_or_inputs, **kwargs):
        """Return probabilities for each label (alias of forward logits)."""
        out = self._forward_from_raw_input(sequence_or_inputs, **kwargs)
        return {
            "predictions": out["logits"],
            "logits": out["logits"],
            "last_hidden_state": out["last_hidden_state"],
        }

    def inference(self, sequence_or_inputs, threshold: float = 0.5, **kwargs):
        """Return binary predictions using the specified probability threshold.

        Returns
        -------
        dict
            ``predictions``, ``logits``, ``confidence`` (maximum probability) and
            ``last_hidden_state``, all moved to CPU. When ``sequence_or_inputs``
            is not a ``list`` only the first batch element is returned.
        """
        out = self._forward_from_raw_input(sequence_or_inputs, **kwargs)
        logits = out["logits"]
        hidden = out["last_hidden_state"]
        preds = (logits >= threshold).to(torch.int)

        if not isinstance(sequence_or_inputs, list):
            return {
                "predictions": preds[0].cpu(),
                "logits": logits[0].cpu(),
                "confidence": torch.max(logits[0]).cpu(),
                "last_hidden_state": hidden[0].cpu(),
            }
        return {
            "predictions": preds.cpu(),
            "logits": logits.cpu(),
            "confidence": torch.max(logits, dim=-1)[0].cpu(),
            "last_hidden_state": hidden.cpu(),
        }

    def loss_function(self, logits, labels):
        """Binary cross-entropy loss on probabilities."""
        return self.loss_fn(logits.view(-1), labels.view(-1).to(torch.float32))
# ---------------- BPNet-style baseline -----------------
class OmniBPNetBaseline(OmniModel):
    """A lightweight BPNet-like dilated-convolution baseline.

    The model converts token ids to one-hot nucleotides (A,C,G,T), applies a
    first convolution and a stack of exponentially-dilated 1D convolutions with
    residual connections, averages globally, then classifies with a sigmoid.

    Parameters
    ----------
    tokenizer : Any
        Tokenizer that can convert tokens to ids for "A", "C", "G", "T".
    num_labels : int
        Number of outputs.
    n_filters : int, optional
        Number of channels in convolution blocks, by default 64.
    n_dilated_layers : int, optional
        Number of dilated layers, by default 9.
    conv1_kernel_size : int, optional
        Kernel size of the first conv, by default 25.
    dil_kernel_size : int, optional
        Kernel size of dilated convs, by default 3.
    dropout : float, optional
        Dropout after global pooling, by default 0.1.

    Pseudocode
    ----------
    .. code-block:: python

        X = one_hot(input_ids)  # (B, 4, L)
        H = relu(Conv1D(4->C, k=25)(X))
        for i in range(n_layers):
            R = H
            H = relu(DilatedConv1D(C->C, k=3, d=2**i)(H))
            H = H + R
        g = GlobalAvgPool1D(H)  # (B, C)
        y = sigmoid(Linear(C->num_labels)(Dropout(g)))
    """

    def __init__(self, tokenizer, *args, **kwargs):
        """Build the one-hot lookup table, the convolution stack and the classifier.

        The baseline-specific options (``num_labels``, ``n_filters``,
        ``n_dilated_layers``, ``conv1_kernel_size``, ``dil_kernel_size``,
        ``dropout``) are removed from ``kwargs``; everything else is forwarded
        to ``OmniModel.__init__``.

        Raises
        ------
        KeyError
            If ``num_labels`` is not supplied.
        """
        n_outputs = kwargs.pop("num_labels")
        n_filters = kwargs.pop("n_filters", 64)
        n_dilated_layers = kwargs.pop("n_dilated_layers", 9)
        conv1_kernel_size = kwargs.pop("conv1_kernel_size", 25)
        dil_kernel_size = kwargs.pop("dil_kernel_size", 3)
        dropout = kwargs.pop("dropout", 0.1)

        # Plain attribute bag that plays the role of a model config.
        class Cfg: ...

        cfg = Cfg()
        cfg.hidden_size = n_filters
        cfg.num_labels = n_outputs
        cfg.label2id = {str(i): i for i in range(n_outputs)}
        cfg.id2label = {i: str(i) for i in range(n_outputs)}
        cfg.name_or_path = "BPNetBaseline"
        cfg.model_type = "bpnet"
        cfg.architectures = ["BPNetBaseline"]
        cfg.pad_token_id = getattr(tokenizer, "pad_token_id", -100)
        cfg.n_filters = n_filters
        cfg.n_dilated_layers = n_dilated_layers
        cfg.conv1_kernel_size = conv1_kernel_size
        cfg.dil_kernel_size = dil_kernel_size
        cfg.dropout = dropout

        # Parameter-free module passed to OmniModel.__init__ as its first
        # argument; it only carries the config and an empty buffer from which
        # ``device`` / ``dtype`` are reported.
        # NOTE(review): presumably this stands in for a backbone model that
        # OmniModel expects — confirm against OmniModel.
        class _Stub(nn.Module):
            def __init__(self, config):
                super().__init__()
                self.config = config
                self.register_buffer("_dev_tracker", torch.empty(0))

            @property
            def device(self):
                return self._dev_tracker.device

            @property
            def dtype(self):
                return self._dev_tracker.dtype

        super().__init__(_Stub(cfg), tokenizer, *args, num_labels=n_outputs, **kwargs)

        vocab_size = getattr(self.tokenizer, "vocab_size", None) or len(
            self.tokenizer.get_vocab()
        )

        # Lookup table: row = token id, column = nucleotide channel (A, C, G, T).
        # Tokens that are not nucleotides keep an all-zero row.
        weight = torch.zeros(vocab_size, 4)
        # Support both uppercase and lowercase nucleotide tokens
        for channel, toks in enumerate([("A", "a"), ("C", "c"), ("G", "g"), ("T", "t")]):
            for tok in toks:
                try:
                    tid = self.tokenizer.convert_tokens_to_ids(tok)
                except Exception:
                    # Best-effort: a tokenizer lacking this token is tolerated.
                    continue
                if isinstance(tid, int) and 0 <= tid < weight.size(0):
                    weight[tid, channel] = 1.0
        if not bool(weight.any()):
            warnings.warn(
                "No A/C/G/T token could be mapped to an id; the one-hot input "
                "will be all zeros.",
                RuntimeWarning,
                stacklevel=2,
            )
        # Non-persistent: rebuilt from the tokenizer, never stored in the state dict.
        self.register_buffer("_one_hot_weight", weight, persistent=False)

        self.conv1 = nn.Conv1d(4, n_filters, conv1_kernel_size, padding="same")
        self.dilated_convs = nn.ModuleList(
            [
                nn.Conv1d(
                    n_filters,
                    n_filters,
                    dil_kernel_size,
                    padding="same",
                    dilation=2**i,
                )
                for i in range(n_dilated_layers)
            ]
        )
        self.global_avg_pool = nn.AdaptiveAvgPool1d(1)
        self.classifier = nn.Linear(n_filters, n_outputs)
        self.sigmoid = nn.Sigmoid()
        self.loss_fn = nn.BCELoss()
        self.dropout_layer = nn.Dropout(dropout)

    @staticmethod
    def _cfg_value(cfg_dict, key, default):
        """Read ``key`` from ``cfg_dict`` treating a stored ``null`` like a missing key."""
        value = cfg_dict.get(key)
        return default if value is None else value

    @staticmethod
    def _load_saved_tokenizer(save_directory, map_location=None):
        """Best-effort recovery of the tokenizer stored in ``save_directory``.

        Tries ``transformers.AutoTokenizer`` when Hugging Face tokenizer files
        are present, then falls back to ``tokenizer.bin``. Returns ``None`` when
        neither works.
        """
        tokenizer = None
        hf_files = ("tokenizer_config.json", "vocab.json")
        if any(os.path.exists(os.path.join(save_directory, n)) for n in hf_files):
            try:
                from transformers import AutoTokenizer

                tokenizer = AutoTokenizer.from_pretrained(save_directory)
            except Exception as exc:  # best-effort: fall through to tokenizer.bin
                warnings.warn(
                    f"AutoTokenizer could not load '{save_directory}': {exc}",
                    RuntimeWarning,
                    stacklevel=3,
                )
                tokenizer = None

        pickled_path = os.path.join(save_directory, "tokenizer.bin")
        if tokenizer is None and os.path.exists(pickled_path):
            # NOTE(review): tokenizer.bin is a pickled Python object; only load
            # checkpoints from trusted sources. Recent PyTorch releases default
            # ``torch.load`` to ``weights_only=True`` — confirm this call still
            # succeeds on the supported PyTorch versions.
            with open(pickled_path, "rb") as f:
                tokenizer = torch.load(f, map_location=map_location)
        return tokenizer

    def _tokens_to_one_hot(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Map token ids ``(B, L)`` to nucleotide channels ``(B, 4, L)``."""
        one_hot = torch.nn.functional.embedding(input_ids, self._one_hot_weight)
        return one_hot.permute(0, 2, 1)

    def _build_config_dict(self):
        """Collect the JSON-serialisable configuration written to ``config.json``."""
        return {
            "model_type": self.config.model_type,
            "architectures": self.config.architectures,
            "num_labels": self.config.num_labels,
            "label2id": self.config.label2id,
            "id2label": self.config.id2label,
            "pad_token_id": self.config.pad_token_id,
            "hidden_size": self.config.hidden_size,
            "embed_dim": getattr(self.config, "embed_dim", None),
            "n_filters": getattr(self.config, "n_filters", None),
            "n_dilated_layers": getattr(self.config, "n_dilated_layers", None),
            "conv1_kernel_size": getattr(self.config, "conv1_kernel_size", None),
            "dil_kernel_size": getattr(self.config, "dil_kernel_size", None),
            "dropout": getattr(self.config, "dropout", None),
            # This model has no ``embedding`` layer; the one-hot table has one
            # row per vocabulary entry.
            "vocab_size": self._one_hot_weight.size(0),
            "model_cls": self.__class__.__name__,
            "library_name": "OMNIGENBENCH",
        }

    def save_pretrained(self, save_directory: str, overwrite: bool = True):
        """Write config, tokenizer, weights and metadata into ``save_directory``.

        Files written: ``config.json``, the tokenizer (through its own
        ``save_pretrained`` when available, otherwise ``tokenizer.bin``),
        ``pytorch_model.bin`` and ``metadata.json``.

        Parameters
        ----------
        save_directory : str
            Target directory; created if it does not exist.
        overwrite : bool, optional
            Accepted for interface compatibility; this implementation does not
            consult it and always overwrites existing files.
        """
        os.makedirs(save_directory, exist_ok=True)

        config_path = os.path.join(save_directory, "config.json")
        with open(config_path, "w", encoding="utf8") as f:
            json.dump(self._build_config_dict(), f, ensure_ascii=False, indent=2)

        if hasattr(self.tokenizer, "save_pretrained"):
            self.tokenizer.save_pretrained(save_directory)
        else:
            with open(os.path.join(save_directory, "tokenizer.bin"), "wb") as f:
                torch.save(self.tokenizer, f)

        torch.save(self.state_dict(), os.path.join(save_directory, "pytorch_model.bin"))

        metadata = getattr(self, "metadata", None)
        if metadata is None:
            metadata = {}
        metadata["model_cls"] = self.__class__.__name__
        metadata["library_name"] = metadata.get("library_name", "OMNIGENBENCH")
        with open(
            os.path.join(save_directory, "metadata.json"), "w", encoding="utf8"
        ) as f:
            json.dump(metadata, f, ensure_ascii=False, indent=2)

    @classmethod
    def from_pretrained(
        cls, save_directory: str, tokenizer=None, map_location=None, **kwargs
    ):
        """Rebuild a model from a directory written by :meth:`save_pretrained`.

        Parameters
        ----------
        save_directory : str
            Directory containing ``config.json`` and ``pytorch_model.bin``.
        tokenizer : Any, optional
            Tokenizer to use; when ``None`` it is loaded from ``save_directory``.
        map_location : optional
            Forwarded to ``torch.load``; the weights default to ``"cpu"``.
        **kwargs
            Extra keyword arguments forwarded to the constructor.

        Raises
        ------
        ValueError
            If no tokenizer was given and none could be loaded.
        """
        with open(
            os.path.join(save_directory, "config.json"), "r", encoding="utf8"
        ) as f:
            cfg_dict = json.load(f)

        if tokenizer is None:
            tokenizer = cls._load_saved_tokenizer(save_directory, map_location)
        if tokenizer is None:
            raise ValueError("Tokenizer could not be loaded; please provide one.")

        model = cls(
            tokenizer,
            n_filters=cls._cfg_value(cfg_dict, "n_filters", 64),
            n_dilated_layers=cls._cfg_value(cfg_dict, "n_dilated_layers", 9),
            conv1_kernel_size=cls._cfg_value(cfg_dict, "conv1_kernel_size", 25),
            dil_kernel_size=cls._cfg_value(cfg_dict, "dil_kernel_size", 3),
            dropout=cls._cfg_value(cfg_dict, "dropout", 0.1),
            num_labels=cfg_dict["num_labels"],
            label2id=cfg_dict.get("label2id"),
            **kwargs,
        )

        state_dict = torch.load(
            os.path.join(save_directory, "pytorch_model.bin"),
            map_location=map_location or "cpu",
        )
        load_result = model.load_state_dict(state_dict, strict=False)
        # strict=False tolerates key mismatches; surface them instead of
        # silently running with partially initialised weights.
        if load_result.missing_keys or load_result.unexpected_keys:
            warnings.warn(
                f"State dict mismatch while loading '{save_directory}': "
                f"missing={list(load_result.missing_keys)}, "
                f"unexpected={list(load_result.unexpected_keys)}",
                RuntimeWarning,
                stacklevel=2,
            )

        meta_path = os.path.join(save_directory, "metadata.json")
        if os.path.exists(meta_path):
            with open(meta_path, "r", encoding="utf8") as f:
                model.metadata = json.load(f)
        return model

    def forward(self, **inputs):
        """Forward pass returning probabilities and last hidden state.

        Reads ``input_ids`` (required) and optional ``labels`` from the keyword
        arguments. The global average runs over every sequence position; no
        attention mask is applied.

        Returns
        -------
        dict
            Keys: ``logits`` (probabilities), ``last_hidden_state`` (pooled), optional ``labels``.
        """
        labels = inputs.pop("labels", None)
        hidden = self._tokens_to_one_hot(inputs["input_ids"])  # [B,4,L]
        hidden = torch.relu(self.conv1(hidden))
        for dilated_conv in self.dilated_convs:
            # Residual connection around every dilated convolution.
            hidden = torch.relu(dilated_conv(hidden)) + hidden
        pooled = self.global_avg_pool(hidden).squeeze(-1)
        pooled = self.dropout_layer(pooled)
        logits = self.sigmoid(self.classifier(pooled))
        return {"logits": logits, "last_hidden_state": pooled, "labels": labels}

    def predict(self, sequence_or_inputs, **kwargs):
        """Return probabilities for each label."""
        out = self._forward_from_raw_input(sequence_or_inputs, **kwargs)
        return {
            "predictions": out["logits"],
            "logits": out["logits"],
            "last_hidden_state": out["last_hidden_state"],
        }

    def inference(self, sequence_or_inputs, threshold: float = 0.5, **kwargs):
        """Return thresholded predictions along with confidence scores.

        Returns
        -------
        dict
            ``predictions``, ``logits``, ``confidence`` (maximum probability) and
            ``last_hidden_state``, all moved to CPU. When ``sequence_or_inputs``
            is not a ``list`` only the first batch element is returned.
        """
        out = self._forward_from_raw_input(sequence_or_inputs, **kwargs)
        logits = out["logits"]
        hidden = out["last_hidden_state"]
        preds = (logits >= threshold).to(torch.int)

        if not isinstance(sequence_or_inputs, list):
            return {
                "predictions": preds[0].cpu(),
                "logits": logits[0].cpu(),
                "confidence": torch.max(logits[0]).cpu(),
                "last_hidden_state": hidden[0].cpu(),
            }
        return {
            "predictions": preds.cpu(),
            "logits": logits.cpu(),
            "confidence": torch.max(logits, dim=-1)[0].cpu(),
            "last_hidden_state": hidden.cpu(),
        }

    def loss_function(self, logits, labels):
        """Binary cross-entropy loss on probabilities (``nn.BCELoss``)."""
        return self.loss_fn(logits.view(-1), labels.view(-1).to(torch.float32))
# ---------------- Basenji-style baseline -----------------
class OmniBasenjiBaseline(OmniModel):
"""Basenji-like 1D CNN with dilations adapted for tokenizer-driven inputs.
This baseline maps token ids to A/C/G/T channels, stacks conv+pool blocks,
applies dilated residual blocks, then global-average-pools and classifies.
"""
def __init__(self, tokenizer, *args, **kwargs):
num_labels = kwargs.pop("num_labels")
conv1kc = kwargs.pop("conv1kc", 64)
conv1ks = kwargs.pop("conv1ks", 15)
pool1ks = kwargs.pop("pool1ks", 8)
conv2kc = kwargs.pop("conv2kc", 64)
conv2ks = kwargs.pop("conv2ks", 5)
pool2ks = kwargs.pop("pool2ks", 4)
conv3kc = kwargs.pop("conv3kc", round(64 * 1.125))
conv3ks = kwargs.pop("conv3ks", 5)
pool3ks = kwargs.pop("pool3ks", 4)
convdc = kwargs.pop("convdc", 6)
dropout = kwargs.pop("dropout", 0.1)
class Cfg: ...
cfg = Cfg()
cfg.hidden_size = 64
cfg.num_labels = num_labels
cfg.label2id = {str(i): i for i in range(num_labels)}
cfg.id2label = {i: str(i) for i in range(num_labels)}
cfg.name_or_path = "BasenjiBaseline"
cfg.model_type = "basenji"
cfg.architectures = ["BasenjiBaseline"]
cfg.pad_token_id = getattr(tokenizer, "pad_token_id", -100)
cfg.dropout = dropout
cfg.convdc = convdc
class _Stub(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.register_buffer("_dev_tracker", torch.empty(0))
@property
def device(self):
return self._dev_tracker.device
@property
def dtype(self):
return self._dev_tracker.dtype
super().__init__(_Stub(cfg), tokenizer, num_labels=num_labels, *args, **kwargs)
# Build token-id -> one-hot(A,C,G,T) projection
vocab_size = getattr(self.tokenizer, "vocab_size", None) or len(
self.tokenizer.get_vocab()
)
weight = torch.zeros(vocab_size, 4)
for i, toks in enumerate([("A", "a"), ("C", "c"), ("G", "g"), ("T", "t")]):
for tok in toks:
try:
tid = self.tokenizer.convert_tokens_to_ids(tok)
if (
tid is not None
and isinstance(tid, int)
and 0 <= tid < weight.size(0)
):
weight[tid, i] = 1.0
except Exception:
pass
self.register_buffer("_one_hot_weight", weight, persistent=False)
self.act = nn.GELU()
self.conv_block_1 = nn.Sequential(
self.act,
nn.Conv1d(
4, conv1kc, kernel_size=conv1ks, padding=conv1ks // 2, bias=False
),
nn.BatchNorm1d(conv1kc, momentum=0.9, affine=True),
nn.MaxPool1d(kernel_size=pool1ks, ceil_mode=True),
nn.Dropout(p=0.2),
)
self.conv_block_2 = nn.Sequential(
self.act,
nn.Conv1d(
conv1kc, conv2kc, kernel_size=conv2ks, padding=conv2ks // 2, bias=False
),
nn.BatchNorm1d(conv2kc, momentum=0.9, affine=True),
nn.MaxPool1d(kernel_size=pool2ks, ceil_mode=True),
nn.Dropout(p=0.2),
)
self.conv_block_3 = nn.Sequential(
self.act,
nn.Conv1d(
conv2kc, conv3kc, kernel_size=conv3ks, padding=conv3ks // 2, bias=False
),
nn.BatchNorm1d(conv3kc, momentum=0.9, affine=True),
nn.MaxPool1d(kernel_size=pool3ks, ceil_mode=True),
nn.Dropout(p=0.2),
)
self.dilations = nn.ModuleList()
for i in range(convdc):
self.dilations.append(
nn.Sequential(
self.act,
nn.Conv1d(
conv3kc,
32,
kernel_size=3,
padding=2**i,
dilation=2**i,
bias=False,
),
nn.BatchNorm1d(32, momentum=0.9, affine=True),
self.act,
nn.Conv1d(32, 72, kernel_size=1, padding=0, bias=False),
nn.BatchNorm1d(72, momentum=0.9, affine=True),
nn.Dropout(p=0.25),
)
)
self.conv_block_4 = nn.Sequential(
self.act,
nn.Conv1d(72, 64, kernel_size=1, stride=1, padding=0, bias=False),
nn.BatchNorm1d(64, momentum=0.9, affine=True),
nn.Dropout(p=0.1),
)
self.global_avg_pool = nn.AdaptiveAvgPool1d(1)
self.classifier = nn.Linear(64, num_labels)
self.sigmoid = nn.Sigmoid()
self.loss_fn = nn.BCELoss()
def _tokens_to_one_hot(self, input_ids: torch.Tensor) -> torch.Tensor:
return torch.nn.functional.embedding(input_ids, self._one_hot_weight).permute(
0, 2, 1
)
def forward(self, **inputs):
"""Forward pass returning probabilities and last hidden state.
Returns
-------
dict
Keys: ``logits`` (probabilities), ``last_hidden_state`` (pooled), optional ``labels``.
"""
labels = inputs.pop("labels", None)
x = self._tokens_to_one_hot(inputs["input_ids"]) # [B,4,L]
x = torch.relu(self.conv1(x))
for layer in self.dilated_convs:
residual = x
x = torch.relu(layer(x))
x = x + residual
x = self.conv_block_4(x)
seq_out = x.transpose(1, 2)
pooled = self.global_avg_pool(x).squeeze(-1)
pooled = self.dropout(pooled)
logits = self.sigmoid(self.classifier(pooled))
return {"logits": logits, "last_hidden_state": x, "labels": labels}
def _build_config_dict(self):
return {
"model_type": self.config.model_type,
"architectures": self.config.architectures,
"num_labels": self.config.num_labels,
"label2id": self.config.label2id,
"id2label": self.config.id2label,
"pad_token_id": self.config.pad_token_id,
"hidden_size": self.config.hidden_size,
"embed_dim": getattr(self.config, "embed_dim", None),
"n_filters": getattr(self.config, "n_filters", None),
"n_dilated_layers": getattr(self.config, "n_dilated_layers", None),
"conv1_kernel_size": getattr(self.config, "conv1_kernel_size", None),
"dil_kernel_size": getattr(self.config, "dil_kernel_size", None),
"dropout": getattr(self.config, "dropout", None),
"vocab_size": self.embedding.num_embeddings,
"model_cls": self.__class__.__name__,
"library_name": "OMNIGENBENCH",
}
[docs]
def save_pretrained(self, save_directory: str, overwrite: bool = True):
    """Write config, tokenizer, weights and metadata into ``save_directory``.

    Files produced: ``config.json`` (from ``_build_config_dict``), the
    tokenizer (via its own ``save_pretrained`` when it has one, otherwise
    pickled to ``tokenizer.bin``), ``pytorch_model.bin`` (the state dict) and
    ``metadata.json``.

    Parameters
    ----------
    save_directory : str
        Target directory; created if missing.
    overwrite : bool
        Accepted for interface compatibility; this method does not consult it.
    """
    os.makedirs(save_directory, exist_ok=True)

    config_path = os.path.join(save_directory, "config.json")
    with open(config_path, "w", encoding="utf8") as config_file:
        json.dump(self._build_config_dict(), config_file, ensure_ascii=False, indent=2)

    if not hasattr(self.tokenizer, "save_pretrained"):
        # No native serialiser: fall back to pickling the tokenizer object.
        tokenizer_path = os.path.join(save_directory, "tokenizer.bin")
        with open(tokenizer_path, "wb") as tokenizer_file:
            torch.save(self.tokenizer, tokenizer_file)
    else:
        self.tokenizer.save_pretrained(save_directory)

    weights_path = os.path.join(save_directory, "pytorch_model.bin")
    torch.save(self.state_dict(), weights_path)

    # The instance's metadata dict (if any) is updated in place before dumping.
    metadata = getattr(self, "metadata", {})
    metadata["model_cls"] = self.__class__.__name__
    metadata.setdefault("library_name", "OMNIGENBENCH")
    metadata_path = os.path.join(save_directory, "metadata.json")
    with open(metadata_path, "w", encoding="utf8") as metadata_file:
        json.dump(metadata, metadata_file, ensure_ascii=False, indent=2)
[docs]
@classmethod
def from_pretrained(
    cls, save_directory: str, tokenizer=None, map_location=None, **kwargs
):
    """Rebuild a model from a directory written by ``save_pretrained``.

    Parameters
    ----------
    save_directory : str
        Directory containing ``config.json`` and ``pytorch_model.bin``, and
        optionally tokenizer files and ``metadata.json``.
    tokenizer : Any, optional
        Tokenizer to use. When ``None`` it is loaded from the directory: first
        through ``transformers.AutoTokenizer`` if Hugging Face tokenizer files
        are present, then from the pickled ``tokenizer.bin``.
    map_location : Any, optional
        Forwarded to ``torch.load``; the weights default to ``"cpu"``.
    **kwargs
        Extra constructor arguments. They override values read from ``config.json``.

    Returns
    -------
    The constructed model with weights (non-strict) and metadata loaded.

    Raises
    ------
    ValueError
        If no tokenizer was supplied and none could be loaded.
    """
    with open(
        os.path.join(save_directory, "config.json"), "r", encoding="utf8"
    ) as f:
        cfg_dict = json.load(f)

    if tokenizer is None:
        has_hf_files = any(
            os.path.exists(os.path.join(save_directory, name))
            for name in ("tokenizer_config.json", "vocab.json")
        )
        if has_hf_files:
            try:
                from transformers import AutoTokenizer

                tokenizer = AutoTokenizer.from_pretrained(save_directory)
            except Exception as exc:
                # Best effort: fall through to the pickled tokenizer, but say why.
                warnings.warn(
                    f"Could not load a Hugging Face tokenizer from '{save_directory}': {exc}"
                )
                tokenizer = None
        tokenizer_bin = os.path.join(save_directory, "tokenizer.bin")
        if tokenizer is None and os.path.exists(tokenizer_bin):
            # NOTE(security): this unpickles an arbitrary Python object, so only
            # load directories you trust. ``weights_only=False`` is required
            # because newer PyTorch releases refuse non-tensor pickles by default.
            with open(tokenizer_bin, "rb") as f:
                tokenizer = torch.load(
                    f, map_location=map_location, weights_only=False
                )
    if tokenizer is None:
        raise ValueError("Tokenizer could not be loaded; please provide one.")

    hparam_defaults = {
        "n_filters": 64,
        "n_dilated_layers": 9,
        "conv1_kernel_size": 25,
        "dil_kernel_size": 3,
        "dropout": 0.1,
    }
    init_kwargs = {}
    for name, default in hparam_defaults.items():
        # ``config.json`` stores ``null`` for hyperparameters the config did not
        # carry; treat that the same as a missing key.
        stored = cfg_dict.get(name)
        init_kwargs[name] = default if stored is None else stored
    init_kwargs["num_labels"] = cfg_dict["num_labels"]
    init_kwargs["label2id"] = cfg_dict.get("label2id")
    # Caller overrides win instead of colliding with the keywords above.
    init_kwargs.update(kwargs)
    model = cls(tokenizer, **init_kwargs)

    state_dict = torch.load(
        os.path.join(save_directory, "pytorch_model.bin"),
        map_location=map_location or "cpu",
    )
    model.load_state_dict(state_dict, strict=False)

    meta_path = os.path.join(save_directory, "metadata.json")
    if os.path.exists(meta_path):
        with open(meta_path, "r", encoding="utf8") as f:
            model.metadata = json.load(f)
    return model
[docs]
def forward(self, **inputs):
    """Run the convolutional stack and return per-label probabilities.

    NOTE(review): an earlier definition of ``forward`` with the same body
    exists in this class; this later one is the definition Python binds.

    Parameters
    ----------
    **inputs
        Must contain ``input_ids``. ``labels`` is optional and is returned untouched.

    Returns
    -------
    dict
        ``logits``: sigmoid outputs of the classifier (probabilities, not raw logits).
        ``last_hidden_state``: the un-pooled, channels-first feature map produced by ``conv_block_4``.
        ``labels``: the labels passed in, or ``None``.
    """
    labels = inputs.pop("labels", None)
    features = self._tokens_to_one_hot(inputs["input_ids"])  # [B,4,L]
    features = torch.relu(self.conv1(features))
    for dilated_conv in self.dilated_convs:
        # Residual connection around each dilated convolution.
        features = torch.relu(dilated_conv(features)) + features
    features = self.conv_block_4(features)
    pooled = self.dropout(self.global_avg_pool(features).squeeze(-1))
    probabilities = self.sigmoid(self.classifier(pooled))
    return {
        "logits": probabilities,
        "last_hidden_state": features,
        "labels": labels,
    }
[docs]
def predict(self, sequence_or_inputs, **kwargs):
    """Run the model on raw input and return its probabilities.

    The input is passed to ``self._forward_from_raw_input`` (defined outside
    this block). The same probability object is exposed under both
    ``predictions`` and ``logits``, alongside ``last_hidden_state``.
    """
    outputs = self._forward_from_raw_input(sequence_or_inputs, **kwargs)
    probabilities = outputs["logits"]
    return dict(
        predictions=probabilities,
        logits=probabilities,
        last_hidden_state=outputs["last_hidden_state"],
    )
[docs]
def inference(self, sequence_or_inputs, threshold: float = 0.5, **kwargs):
    """Return thresholded predictions along with confidence scores.

    Parameters
    ----------
    sequence_or_inputs
        Raw input passed to ``self._forward_from_raw_input``. A ``list`` is
        treated as a batch; anything else is treated as a single example, and
        the leading batch dimension is stripped from every output.
    threshold : float
        Probabilities ``>= threshold`` become 1, the rest 0.
    **kwargs
        Forwarded to ``self._forward_from_raw_input``.

    Returns
    -------
    dict
        ``predictions`` (int tensor), ``logits`` (probabilities), ``confidence``
        (highest probability per example) and ``last_hidden_state``. All are on
        CPU; previously the batch branch left ``last_hidden_state`` on the model
        device, unlike every other field.
    """
    out = self._forward_from_raw_input(sequence_or_inputs, **kwargs)
    probs = out["logits"]
    preds = (probs >= threshold).to(torch.int)
    hidden = out["last_hidden_state"]

    if not isinstance(sequence_or_inputs, list):
        # Single example: drop the batch dimension everywhere.
        probs, preds, hidden = probs[0], preds[0], hidden[0]
        confidence = torch.max(probs)
    else:
        confidence = torch.max(probs, dim=-1)[0]

    return {
        "predictions": preds.cpu(),
        "logits": probs.cpu(),
        "confidence": confidence.cpu(),
        "last_hidden_state": hidden.cpu(),
    }
[docs]
def loss_function(self, logits, labels):
    """Compute ``self.loss_fn`` on the flattened probabilities and labels.

    ``reshape`` is used instead of ``view`` so that non-contiguous inputs
    (for example transposed or sliced label tensors) are flattened instead of
    raising a ``RuntimeError``. Labels are cast to ``float32``.
    """
    flat_probs = logits.reshape(-1)
    flat_labels = labels.reshape(-1).to(torch.float32)
    return self.loss_fn(flat_probs, flat_labels)
[docs]
class OmniDeepSTARRBaseline(OmniModel):
    """DeepSTARR-like CNN with global pooling and MLP head adapted for tokenizer inputs.

    Token ids are projected to A/C/G/T one-hot channels, passed through four
    Conv1d + BatchNorm + ReLU + MaxPool blocks, globally average-pooled, and fed
    to a two-layer MLP followed by a sigmoid classifier.

    Parameters
    ----------
    tokenizer : Any
        Tokenizer exposing ``vocab_size`` (or ``get_vocab``) and ``convert_tokens_to_ids``.
    num_labels : int
        Required keyword; number of sigmoid outputs.
    dropout_prob, num_filters1..4, kernel_size1..4, dense_neurons1, dense_neurons2 :
        Optional keywords; defaults are listed in ``_HPARAM_DEFAULTS``.

    Any remaining positional or keyword arguments are forwarded to ``OmniModel.__init__``.
    """

    # Architecture hyperparameters and their defaults. ``__init__``,
    # ``_build_config_dict`` and ``from_pretrained`` all read this table so the
    # saved config always describes the network that was actually built.
    _HPARAM_DEFAULTS = {
        "dropout_prob": 0.4,
        "num_filters1": 256,
        "kernel_size1": 7,
        "num_filters2": 60,
        "kernel_size2": 3,
        "num_filters3": 60,
        "kernel_size3": 5,
        "num_filters4": 120,
        "kernel_size4": 3,
        "dense_neurons1": 256,
        "dense_neurons2": 256,
    }

    def __init__(self, tokenizer, *args, **kwargs):
        num_labels = kwargs.pop("num_labels")
        hp = {
            name: kwargs.pop(name, default)
            for name, default in self._HPARAM_DEFAULTS.items()
        }

        class Cfg: ...

        cfg = Cfg()
        cfg.hidden_size = hp["num_filters4"]
        cfg.num_labels = num_labels
        cfg.label2id = {str(i): i for i in range(num_labels)}
        cfg.id2label = {i: str(i) for i in range(num_labels)}
        cfg.name_or_path = "DeepSTARRBaseline"
        cfg.model_type = "deepstarr"
        cfg.architectures = ["DeepSTARRBaseline"]
        cfg.pad_token_id = getattr(tokenizer, "pad_token_id", -100)
        # Keep the architecture hyperparameters on the config so they can be saved.
        for name, value in hp.items():
            setattr(cfg, name, value)

        class _Stub(nn.Module):
            """Minimal module carrying ``config`` plus device/dtype properties."""

            def __init__(self, config):
                super().__init__()
                self.config = config
                self.register_buffer("_dev_tracker", torch.empty(0))

            @property
            def device(self):
                return self._dev_tracker.device

            @property
            def dtype(self):
                return self._dev_tracker.dtype

        super().__init__(_Stub(cfg), tokenizer, *args, num_labels=num_labels, **kwargs)

        # Build token-id -> one-hot(A,C,G,T) projection
        vocab_size = getattr(self.tokenizer, "vocab_size", None) or len(
            self.tokenizer.get_vocab()
        )
        weight = torch.zeros(vocab_size, 4)
        nucleotides = (("A", "a"), ("C", "c"), ("G", "g"), ("T", "t"))
        for channel, variants in enumerate(nucleotides):
            for tok in variants:
                try:
                    tid = self.tokenizer.convert_tokens_to_ids(tok)
                except Exception:
                    # Best effort: a token the tokenizer cannot resolve is skipped.
                    continue
                if isinstance(tid, int) and 0 <= tid < weight.size(0):
                    weight[tid, channel] = 1.0
        self.register_buffer("_one_hot_weight", weight, persistent=False)

        self.conv = nn.Sequential(
            self._conv_block(4, hp["num_filters1"], hp["kernel_size1"]),
            self._conv_block(hp["num_filters1"], hp["num_filters2"], hp["kernel_size2"]),
            self._conv_block(hp["num_filters2"], hp["num_filters3"], hp["kernel_size3"]),
            self._conv_block(hp["num_filters3"], hp["num_filters4"], hp["kernel_size4"]),
        )
        self.global_avg_pool = nn.AdaptiveAvgPool1d(1)
        self.mlp = nn.Sequential(
            nn.Linear(hp["num_filters4"], hp["dense_neurons1"]),
            nn.BatchNorm1d(hp["dense_neurons1"]),
            nn.ReLU(),
            nn.Dropout(hp["dropout_prob"]),
            nn.Linear(hp["dense_neurons1"], hp["dense_neurons2"]),
            nn.BatchNorm1d(hp["dense_neurons2"]),
            nn.ReLU(),
            nn.Dropout(hp["dropout_prob"]),
        )
        self.classifier = nn.Linear(hp["dense_neurons2"], num_labels)
        self.sigmoid = nn.Sigmoid()
        self.loss_fn = nn.BCELoss()

    @staticmethod
    def _conv_block(in_ch, out_ch, k):
        """Conv1d -> BatchNorm1d -> ReLU -> MaxPool1d(2) building block."""
        return nn.Sequential(
            nn.Conv1d(in_ch, out_ch, kernel_size=k, padding=k // 2),
            nn.BatchNorm1d(out_ch),
            nn.ReLU(),
            nn.MaxPool1d(2),
        )

    def _tokens_to_one_hot(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Map token ids ``[B, L]`` to one-hot channels ``[B, 4, L]``."""
        return torch.nn.functional.embedding(input_ids, self._one_hot_weight).permute(
            0, 2, 1
        )

    def forward(self, **inputs):
        """Forward pass returning probabilities and the MLP features.

        The class previously defined ``forward`` twice. The later copy used
        layers (``conv1``, ``dilated_convs``, ``conv_block_4``, ``dropout``) that
        this model never creates, and it shadowed this DeepSTARR implementation.
        Only this implementation is kept.

        Returns
        -------
        dict
            ``logits``: sigmoid probabilities ``[B, num_labels]``.
            ``last_hidden_state``: MLP output features ``[B, dense_neurons2]``.
            ``labels``: the labels passed in, or ``None``.
        """
        labels = inputs.pop("labels", None)
        x = self._tokens_to_one_hot(inputs["input_ids"])  # [B,4,L]
        x = self.conv(x)
        x = self.global_avg_pool(x).squeeze(-1)
        feats = self.mlp(x)
        logits = self.sigmoid(self.classifier(feats))
        return {"logits": logits, "last_hidden_state": feats, "labels": labels}

    def _build_config_dict(self):
        """Collect the JSON-serialisable configuration written to ``config.json``.

        ``vocab_size`` is the number of rows of the one-hot projection table,
        because this model has no embedding layer. The DeepSTARR
        hyperparameters are included so ``from_pretrained`` can rebuild the
        same architecture.
        """
        cfg = self.config
        config_dict = {
            "model_type": cfg.model_type,
            "architectures": cfg.architectures,
            "num_labels": cfg.num_labels,
            "label2id": cfg.label2id,
            "id2label": cfg.id2label,
            "pad_token_id": cfg.pad_token_id,
            "hidden_size": cfg.hidden_size,
        }
        for name, default in self._HPARAM_DEFAULTS.items():
            config_dict[name] = getattr(cfg, name, default)
        config_dict["vocab_size"] = self._one_hot_weight.size(0)
        config_dict["model_cls"] = self.__class__.__name__
        config_dict["library_name"] = "OMNIGENBENCH"
        return config_dict

    def save_pretrained(self, save_directory: str, overwrite: bool = True):
        """Write config, tokenizer, weights and metadata into ``save_directory``.

        ``overwrite`` is accepted for interface compatibility; this method does
        not consult it.
        """
        os.makedirs(save_directory, exist_ok=True)
        with open(
            os.path.join(save_directory, "config.json"), "w", encoding="utf8"
        ) as f:
            json.dump(self._build_config_dict(), f, ensure_ascii=False, indent=2)
        if hasattr(self.tokenizer, "save_pretrained"):
            self.tokenizer.save_pretrained(save_directory)
        else:
            # No native serialiser: fall back to pickling the tokenizer object.
            with open(os.path.join(save_directory, "tokenizer.bin"), "wb") as f:
                torch.save(self.tokenizer, f)
        torch.save(self.state_dict(), os.path.join(save_directory, "pytorch_model.bin"))
        metadata = getattr(self, "metadata", {})
        metadata["model_cls"] = self.__class__.__name__
        metadata["library_name"] = metadata.get("library_name", "OMNIGENBENCH")
        with open(
            os.path.join(save_directory, "metadata.json"), "w", encoding="utf8"
        ) as f:
            json.dump(metadata, f, ensure_ascii=False, indent=2)

    @classmethod
    def from_pretrained(
        cls, save_directory: str, tokenizer=None, map_location=None, **kwargs
    ):
        """Rebuild a model from a directory written by ``save_pretrained``.

        The DeepSTARR hyperparameters are restored from ``config.json``; missing
        or ``null`` entries fall back to ``_HPARAM_DEFAULTS``. ``kwargs``
        override values read from the config.

        Raises
        ------
        ValueError
            If no tokenizer was supplied and none could be loaded.
        """
        with open(
            os.path.join(save_directory, "config.json"), "r", encoding="utf8"
        ) as f:
            cfg_dict = json.load(f)

        if tokenizer is None:
            has_hf_files = any(
                os.path.exists(os.path.join(save_directory, name))
                for name in ("tokenizer_config.json", "vocab.json")
            )
            if has_hf_files:
                try:
                    from transformers import AutoTokenizer

                    tokenizer = AutoTokenizer.from_pretrained(save_directory)
                except Exception as exc:
                    # Best effort: fall through to the pickled tokenizer, but say why.
                    warnings.warn(
                        f"Could not load a Hugging Face tokenizer from '{save_directory}': {exc}"
                    )
                    tokenizer = None
            tokenizer_bin = os.path.join(save_directory, "tokenizer.bin")
            if tokenizer is None and os.path.exists(tokenizer_bin):
                # NOTE(security): this unpickles an arbitrary Python object, so
                # only load directories you trust. ``weights_only=False`` is
                # required because newer PyTorch releases refuse non-tensor
                # pickles by default.
                with open(tokenizer_bin, "rb") as f:
                    tokenizer = torch.load(
                        f, map_location=map_location, weights_only=False
                    )
        if tokenizer is None:
            raise ValueError("Tokenizer could not be loaded; please provide one.")

        init_kwargs = {}
        for name, default in cls._HPARAM_DEFAULTS.items():
            stored = cfg_dict.get(name)
            init_kwargs[name] = default if stored is None else stored
        init_kwargs["num_labels"] = cfg_dict["num_labels"]
        init_kwargs["label2id"] = cfg_dict.get("label2id")
        # Caller overrides win instead of colliding with the keywords above.
        init_kwargs.update(kwargs)
        model = cls(tokenizer, **init_kwargs)

        state_dict = torch.load(
            os.path.join(save_directory, "pytorch_model.bin"),
            map_location=map_location or "cpu",
        )
        model.load_state_dict(state_dict, strict=False)
        meta_path = os.path.join(save_directory, "metadata.json")
        if os.path.exists(meta_path):
            with open(meta_path, "r", encoding="utf8") as f:
                model.metadata = json.load(f)
        return model

    def predict(self, sequence_or_inputs, **kwargs):
        """Return probabilities for each label."""
        out = self._forward_from_raw_input(sequence_or_inputs, **kwargs)
        return {
            "predictions": out["logits"],
            "logits": out["logits"],
            "last_hidden_state": out["last_hidden_state"],
        }

    def inference(self, sequence_or_inputs, threshold: float = 0.5, **kwargs):
        """Return thresholded predictions along with confidence scores.

        A ``list`` input is treated as a batch; anything else is treated as a
        single example whose batch dimension is stripped. All returned tensors
        are moved to CPU.
        """
        out = self._forward_from_raw_input(sequence_or_inputs, **kwargs)
        probs = out["logits"]
        preds = (probs >= threshold).to(torch.int)
        hidden = out["last_hidden_state"]
        if not isinstance(sequence_or_inputs, list):
            probs, preds, hidden = probs[0], preds[0], hidden[0]
            confidence = torch.max(probs)
        else:
            confidence = torch.max(probs, dim=-1)[0]
        return {
            "predictions": preds.cpu(),
            "logits": probs.cpu(),
            "confidence": confidence.cpu(),
            "last_hidden_state": hidden.cpu(),
        }

    def loss_function(self, logits, labels):
        """Binary cross-entropy on flattened probabilities and labels.

        ``reshape`` is used so non-contiguous tensors are accepted.
        """
        return self.loss_fn(
            logits.reshape(-1), labels.reshape(-1).to(torch.float32)
        )
# ---------------- Generic Multi-Task Framework -----------------
@dataclass
class BaselineConfig:
    """Hyperparameter bundle shared by the baseline backbones and heads.

    Attributes
    ----------
    backbone_type : str
        Key into ``BACKBONE_REGISTRY`` ("cnn", "rnn", "bpnet", "deepstarr", "basenji", "deepsea").
    task_name : str
        Key into ``HEAD_REGISTRY`` ("classification", "multilabel_classification",
        "regression", "token_classification", "token_regression").
    vocab_size : int
        Number of token ids; sizes the embedding / one-hot tables.
    hidden_size : int
        Feature width the heads expect from the backbone.
    num_labels : int
        Output width of the head.
    loss_type : str
        Stored only; none of the heads in this file read it.
    label2id, id2label : dict
        Label mapping dictionaries.
    pad_token_id : int
        Padding id, used as ``padding_idx`` by the embedding backbones and as
        ``ignore_index`` by the token-classification loss.
    embed_dim, num_filters, kernel_sizes, dropout :
        CNN backbone hyperparameters (``embed_dim`` and ``dropout`` are shared with the RNN).
    hidden_dim, num_layers, bidirectional :
        LSTM backbone hyperparameters.
    n_filters, n_dilated_layers, conv1_kernel_size, dil_kernel_size :
        BPNet-style backbone hyperparameters.
    """

    # --- what to build -------------------------------------------------
    backbone_type: str = "cnn"
    task_name: str = "multilabel_classification"

    # --- sizes shared by backbone and head -----------------------------
    vocab_size: int = 0
    hidden_size: int = 128
    num_labels: int = 2
    loss_type: str = "auto"

    # --- label bookkeeping (fresh dict per instance) --------------------
    label2id: dict = field(default_factory=dict)
    id2label: dict = field(default_factory=dict)
    pad_token_id: int = -100

    # --- CNN / RNN backbones --------------------------------------------
    embed_dim: int = 128
    num_filters: int = 128
    kernel_sizes: tuple = (3, 5, 7)
    dropout: float = 0.1
    hidden_dim: int = 256
    num_layers: int = 1
    bidirectional: bool = True

    # --- BPNet-style backbone -------------------------------------------
    n_filters: int = 64
    n_dilated_layers: int = 9
    conv1_kernel_size: int = 25
    dil_kernel_size: int = 3
class BackboneBase(nn.Module):
    """Interface for backbones that expose per-token and pooled features."""

    def forward(self, input_ids, attention_mask=None):
        """Turn token ids into features; subclasses must override.

        Parameters
        ----------
        input_ids : torch.LongTensor
            Token ids, shape ``(batch_size, seq_len)``.
        attention_mask : torch.LongTensor, optional
            Shape ``(batch_size, seq_len)``; 1 marks a valid token, 0 marks padding.

        Returns
        -------
        dict
            ``{"sequence_output": ..., "hidden_state": ...}`` holding per-token
            and pooled features respectively.

        Raises
        ------
        NotImplementedError
            Always, in this base class.
        """
        raise NotImplementedError
class CNNBackbone(BackboneBase):
    """Embedding + multi-kernel 1D-CNN with masked global max-pooling."""

    def __init__(self, cfg: BaselineConfig):
        """Build the CNN backbone.

        Parameters
        ----------
        cfg : BaselineConfig
            Supplies ``vocab_size``, ``embed_dim``, ``pad_token_id``,
            ``num_filters``, ``kernel_sizes`` and ``dropout``.
        """
        super().__init__()
        self.embedding = nn.Embedding(
            cfg.vocab_size, cfg.embed_dim, padding_idx=cfg.pad_token_id
        )
        # ``padding="same"`` keeps every branch at the input length. With the
        # previous ``padding=k // 2`` an even kernel size produced one extra
        # position, so branches could not be concatenated with odd-kernel
        # branches or aligned with the attention mask. For odd kernel sizes the
        # two paddings are identical.
        self.convs = nn.ModuleList(
            [
                nn.Sequential(
                    nn.Conv1d(cfg.embed_dim, cfg.num_filters, k, padding="same"),
                    nn.ReLU(),
                )
                for k in cfg.kernel_sizes
            ]
        )
        self.dropout = nn.Dropout(cfg.dropout)
        self.pool = _MaskedGlobalMaxPool1d()
        # One ``num_filters``-wide feature group per kernel size.
        self.out_dim = cfg.num_filters * len(cfg.kernel_sizes)

    def forward(self, input_ids, attention_mask=None):
        """Return token-level and pooled features.

        Returns
        -------
        dict
            ``sequence_output``: ``(B, L, out_dim)`` concatenated branch outputs.
            ``hidden_state``: ``(B, out_dim)`` masked max over the sequence.
        """
        embedded = self.dropout(self.embedding(input_ids))
        channels_first = embedded.transpose(1, 2)  # Conv1d expects (B, C, L)
        branch_outputs = [branch(channels_first) for branch in self.convs]
        feats = torch.cat(branch_outputs, dim=1).transpose(1, 2)
        pooled = self.pool(feats, attention_mask)
        return {"sequence_output": feats, "hidden_state": pooled}
class RNNBackbone(BackboneBase):
    """Token embedding followed by a (bi)LSTM, mean-pooled over valid positions."""

    def __init__(self, cfg: BaselineConfig):
        """Create the embedding, LSTM and dropout layers from ``cfg``."""
        super().__init__()
        self.embedding = nn.Embedding(
            cfg.vocab_size, cfg.embed_dim, padding_idx=cfg.pad_token_id
        )
        # Inter-layer dropout is only passed to the LSTM when it is stacked.
        inter_layer_dropout = cfg.dropout if cfg.num_layers != 1 else 0.0
        self.lstm = nn.LSTM(
            cfg.embed_dim,
            cfg.hidden_dim,
            num_layers=cfg.num_layers,
            batch_first=True,
            dropout=inter_layer_dropout,
            bidirectional=cfg.bidirectional,
        )
        self.dropout = nn.Dropout(cfg.dropout)
        n_directions = 2 if cfg.bidirectional else 1
        self.out_dim = cfg.hidden_dim * n_directions

    def forward(self, input_ids, attention_mask=None):
        """Return per-token LSTM outputs and their (masked) mean.

        With an ``attention_mask`` the mean is taken over positions where the
        mask is non-zero (the divisor is clamped to at least ``1e-6``).
        Without one, it is a plain mean over the sequence dimension.
        """
        embedded = self.dropout(self.embedding(input_ids))
        seq_out, _ = self.lstm(embedded)
        if attention_mask is None:
            pooled = seq_out.mean(dim=1)
        else:
            weights = attention_mask.unsqueeze(-1).to(seq_out.dtype)
            summed = (seq_out * weights).sum(dim=1)
            pooled = summed / weights.sum(dim=1).clamp_min(1e-6)
        return {"sequence_output": seq_out, "hidden_state": self.dropout(pooled)}
class BPNetBackbone(BackboneBase):
    """BPNet-style stack: one wide convolution, then residual dilated convolutions."""

    def __init__(self, cfg: BaselineConfig):
        """Create the convolution stack and an all-zero one-hot table.

        The table is a non-persistent buffer; it is expected to be filled in
        through ``set_one_hot``.
        """
        super().__init__()
        empty_table = torch.zeros(cfg.vocab_size, 4)
        self.register_buffer("_one_hot_weight", empty_table, persistent=False)
        self.conv1 = nn.Conv1d(4, cfg.n_filters, cfg.conv1_kernel_size, padding="same")
        dilated_layers = []
        for exponent in range(cfg.n_dilated_layers):
            # Dilation doubles with every layer: 1, 2, 4, ...
            dilated_layers.append(
                nn.Conv1d(
                    cfg.n_filters,
                    cfg.n_filters,
                    cfg.dil_kernel_size,
                    padding="same",
                    dilation=2**exponent,
                )
            )
        self.dilated_convs = nn.ModuleList(dilated_layers)
        self.global_avg_pool = nn.AdaptiveAvgPool1d(1)
        self.out_dim = cfg.n_filters
        self.dropout = nn.Dropout(cfg.dropout)

    def set_one_hot(self, weight):
        """Install the token-id -> A/C/G/T projection table.

        Parameters
        ----------
        weight : torch.Tensor
            Must have the same shape as the current table, ``(vocab_size, 4)``.
            A tensor of any other shape is ignored.
        """
        if weight.shape != self._one_hot_weight.shape:
            return
        self._one_hot_weight = weight

    def _tokens_to_one_hot(self, ids):
        """Project token ids ``(B, L)`` to one-hot channels ``(B, 4, L)``."""
        one_hot = torch.nn.functional.embedding(ids, self._one_hot_weight)
        return one_hot.permute(0, 2, 1)

    def forward(self, input_ids, attention_mask=None):
        """Return per-token features and their dropout-regularised average.

        ``attention_mask`` is accepted for interface compatibility; this
        method does not use it.
        """
        hidden = torch.relu(self.conv1(self._tokens_to_one_hot(input_ids)))
        for dilated_conv in self.dilated_convs:
            hidden = torch.relu(dilated_conv(hidden)) + hidden
        seq_out = hidden.transpose(1, 2)
        pooled = self.dropout(self.global_avg_pool(hidden).squeeze(-1))
        return {"sequence_output": seq_out, "hidden_state": pooled}
class DeepSTARRBackbone(BackboneBase):
    """DeepSTARR-style convolutional backbone with global average pooling.

    Token ids become A/C/G/T one-hots, run through four
    conv + BatchNorm + ReLU + max-pool stages, and the pooled result is
    linearly projected to ``cfg.hidden_size``.
    """

    def __init__(self, cfg: BaselineConfig):
        """Create the fixed-width convolution tower and the output projection."""
        super().__init__()
        self.register_buffer(
            "_one_hot_weight", torch.zeros(cfg.vocab_size, 4), persistent=False
        )

        # (in_channels, out_channels, kernel_size) per stage; DeepSTARR defaults.
        stage_specs = (
            (4, 256, 7),
            (256, 60, 3),
            (60, 60, 5),
            (60, 120, 3),
        )

        def make_stage(in_ch, out_ch, kernel):
            return nn.Sequential(
                nn.Conv1d(in_ch, out_ch, kernel_size=kernel, padding=kernel // 2),
                nn.BatchNorm1d(out_ch),
                nn.ReLU(),
                nn.MaxPool1d(2),
            )

        self.conv = nn.Sequential(*[make_stage(*spec) for spec in stage_specs])
        self.global_avg_pool = nn.AdaptiveAvgPool1d(1)
        self.dropout = nn.Dropout(cfg.dropout)
        # Project the last stage's channels to the width the heads expect.
        last_stage_channels = stage_specs[-1][1]
        self.proj = nn.Linear(last_stage_channels, cfg.hidden_size)
        self.out_dim = cfg.hidden_size

    def set_one_hot(self, weight: torch.Tensor):
        """Install the one-hot table; a tensor of a different shape is ignored."""
        if weight.shape != self._one_hot_weight.shape:
            return
        self._one_hot_weight = weight

    def _tokens_to_one_hot(self, ids: torch.Tensor) -> torch.Tensor:
        """Project token ids ``(B, L)`` to one-hot channels ``(B, 4, L)``."""
        one_hot = torch.nn.functional.embedding(ids, self._one_hot_weight)
        return one_hot.permute(0, 2, 1)

    def forward(self, input_ids, attention_mask=None):
        """Return the conv feature map (channels last) and the projected pooled vector.

        ``attention_mask`` is accepted for interface compatibility; this
        method does not use it.
        """
        feature_map = self.conv(self._tokens_to_one_hot(input_ids))
        seq_out = feature_map.transpose(1, 2)
        pooled = self.global_avg_pool(feature_map).squeeze(-1)
        pooled = self.proj(self.dropout(pooled))
        return {"sequence_output": seq_out, "hidden_state": pooled}
class BasenjiBackbone(BackboneBase):
    """Basenji-like convolutional backbone with dilations and global pooling.

    Three pooled convolution stages are followed by six dilated stages, a
    1x1 convolution, global average pooling and a projection to
    ``cfg.hidden_size``. Every stage starts with the shared GELU activation.
    """

    def __init__(self, cfg: BaselineConfig):
        """Create the convolution tower; channel widths are fixed, not read from ``cfg``."""
        super().__init__()
        self.register_buffer(
            "_one_hot_weight", torch.zeros(cfg.vocab_size, 4), persistent=False
        )
        self.act = nn.GELU()

        stem_channels = 64
        mid_channels = 64
        tower_channels = round(64 * 1.125)
        n_dilated_stages = 6

        # NOTE(review): in PyTorch ``momentum`` is the weight given to the new
        # batch statistic, so 0.9 makes the running stats track the latest
        # batches closely. Confirm this is intended.
        def pooled_stage(in_ch, out_ch, kernel, pool):
            return nn.Sequential(
                self.act,
                nn.Conv1d(
                    in_ch, out_ch, kernel_size=kernel, padding=kernel // 2, bias=False
                ),
                nn.BatchNorm1d(out_ch, momentum=0.9, affine=True),
                nn.MaxPool1d(kernel_size=pool, ceil_mode=True),
                nn.Dropout(p=0.2),
            )

        def dilated_stage(rate):
            # Dilated 3-wide conv down to 32 channels, then a 1x1 conv back up to 72.
            return nn.Sequential(
                self.act,
                nn.Conv1d(
                    tower_channels,
                    32,
                    kernel_size=3,
                    padding=rate,
                    dilation=rate,
                    bias=False,
                ),
                nn.BatchNorm1d(32, momentum=0.9, affine=True),
                self.act,
                nn.Conv1d(32, 72, kernel_size=1, padding=0, bias=False),
                nn.BatchNorm1d(72, momentum=0.9, affine=True),
                nn.Dropout(p=0.25),
            )

        self.conv_block_1 = pooled_stage(4, stem_channels, 15, 8)
        self.conv_block_2 = pooled_stage(stem_channels, mid_channels, 5, 4)
        self.conv_block_3 = pooled_stage(mid_channels, tower_channels, 5, 4)
        self.dilations = nn.ModuleList(
            [dilated_stage(2**exponent) for exponent in range(n_dilated_stages)]
        )
        self.conv_block_4 = nn.Sequential(
            self.act,
            nn.Conv1d(72, 64, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm1d(64, momentum=0.9, affine=True),
            nn.Dropout(p=0.1),
        )
        self.global_avg_pool = nn.AdaptiveAvgPool1d(1)
        self.dropout = nn.Dropout(cfg.dropout)
        self.proj = nn.Linear(64, cfg.hidden_size)
        self.out_dim = cfg.hidden_size

    def set_one_hot(self, weight: torch.Tensor):
        """Install the one-hot table; a tensor of a different shape is ignored."""
        if weight.shape != self._one_hot_weight.shape:
            return
        self._one_hot_weight = weight

    def _tokens_to_one_hot(self, ids: torch.Tensor) -> torch.Tensor:
        """Project token ids ``(B, L)`` to one-hot channels ``(B, 4, L)``."""
        one_hot = torch.nn.functional.embedding(ids, self._one_hot_weight)
        return one_hot.permute(0, 2, 1)

    def forward(self, input_ids, attention_mask=None):
        """Return the final feature map (channels last) and the projected pooled vector.

        ``attention_mask`` is accepted for interface compatibility; this
        method does not use it.
        """
        hidden = self._tokens_to_one_hot(input_ids)
        for stage in (self.conv_block_1, self.conv_block_2, self.conv_block_3):
            hidden = stage(hidden)
        for dilated in self.dilations:
            hidden = dilated(hidden)
        hidden = self.conv_block_4(hidden)
        seq_out = hidden.transpose(1, 2)
        pooled = self.global_avg_pool(hidden).squeeze(-1)
        pooled = self.proj(self.dropout(pooled))
        return {"sequence_output": seq_out, "hidden_state": pooled}
class DeepSEABackbone(BackboneBase):
    """DeepSEA-style 1D CNN backbone adapted for tokenizer inputs.

    Pipeline:
      - token ids -> A/C/G/T one-hot ``(B, 4, L)``
      - Conv(4 -> 320, k=8) + ReLU + MaxPool(4)
      - Conv(320 -> 480, k=8) + ReLU + MaxPool(4)
      - Conv(480 -> 960, k=8) + ReLU
      - global average pool -> dropout -> linear projection to ``cfg.hidden_size``

    Global average pooling plus the projection lets the backbone accept
    variable-length inputs and match the width the heads expect.
    """

    def __init__(self, cfg: BaselineConfig):
        """Create the three convolution stages and the output projection."""
        super().__init__()
        self.register_buffer(
            "_one_hot_weight", torch.zeros(cfg.vocab_size, 4), persistent=False
        )

        kernel = 8
        half_kernel = kernel // 2

        def conv_relu(in_ch, out_ch):
            return [
                nn.Conv1d(in_ch, out_ch, kernel_size=kernel, padding=half_kernel),
                nn.ReLU(),
            ]

        self.conv_block_1 = nn.Sequential(
            *conv_relu(4, 320), nn.MaxPool1d(kernel_size=4, stride=4)
        )
        self.conv_block_2 = nn.Sequential(
            *conv_relu(320, 480), nn.MaxPool1d(kernel_size=4, stride=4)
        )
        self.conv_block_3 = nn.Sequential(*conv_relu(480, 960))
        self.global_avg_pool = nn.AdaptiveAvgPool1d(1)
        self.dropout = nn.Dropout(cfg.dropout)
        self.proj = nn.Linear(960, cfg.hidden_size)
        self.out_dim = cfg.hidden_size

    def set_one_hot(self, weight: torch.Tensor):
        """Install the one-hot table; non-tensors and mismatched shapes are ignored."""
        if not isinstance(weight, torch.Tensor):
            return
        if weight.shape != self._one_hot_weight.shape:
            return
        self._one_hot_weight = weight

    def _tokens_to_one_hot(self, ids: torch.Tensor) -> torch.Tensor:
        """Project token ids ``(B, L)`` to one-hot channels ``(B, 4, L)``."""
        one_hot = torch.nn.functional.embedding(ids, self._one_hot_weight)
        return one_hot.permute(0, 2, 1)

    def forward(self, input_ids, attention_mask=None):
        """Return the final feature map (channels last) and the projected pooled vector.

        ``attention_mask`` is accepted for interface compatibility; this
        method does not use it.
        """
        hidden = self._tokens_to_one_hot(input_ids)
        for stage in (self.conv_block_1, self.conv_block_2, self.conv_block_3):
            hidden = stage(hidden)
        seq_out = hidden.transpose(1, 2)
        pooled = self.global_avg_pool(hidden).squeeze(-1)
        pooled = self.proj(self.dropout(pooled))
        return {"sequence_output": seq_out, "hidden_state": pooled}
# Lookup table from the ``backbone_type`` string to the backbone class implementing it.
BACKBONE_REGISTRY = dict(
    cnn=CNNBackbone,
    rnn=RNNBackbone,
    bpnet=BPNetBackbone,
    deepstarr=DeepSTARRBackbone,
    basenji=BasenjiBackbone,
    deepsea=DeepSEABackbone,
)
class HeadBase(nn.Module):
    """Interface for prediction heads that consume a backbone's feature dict."""

    def forward(self, features: dict, labels=None):
        """Produce logits (and a loss when labels are given); subclasses must override.

        Parameters
        ----------
        features : dict
            Backbone output. Heads read ``hidden_state`` or ``sequence_output``
            depending on their type.
        labels : torch.Tensor, optional
            Supervision tensor; its shape depends on the task.

        Returns
        -------
        dict
            ``logits`` and ``loss`` (``None`` when no labels are supplied).

        Raises
        ------
        NotImplementedError
            Always, in this base class.
        """
        raise NotImplementedError

    def postprocess(self, logits):
        """Hand the logits back unchanged; subclasses override to decode them."""
        return logits
class SequenceClassificationHead(HeadBase):
    """Linear classifier over the pooled vector, trained with cross-entropy."""

    def __init__(self, cfg: BaselineConfig):
        super().__init__()
        self.classifier = nn.Linear(cfg.hidden_size, cfg.num_labels)
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, features, labels=None):
        """Return ``(B, num_labels)`` logits and, when labels are given, the loss.

        Labels are cast to ``long`` before being passed to the loss.
        """
        logits = self.classifier(features["hidden_state"])
        if labels is None:
            return {"logits": logits, "loss": None}
        return {"logits": logits, "loss": self.loss_fn(logits, labels.long())}

    def postprocess(self, logits):
        """Return the index of the most probable class along the last dimension."""
        probabilities = torch.softmax(logits, dim=-1)
        return probabilities.argmax(dim=-1)
class MultiLabelClassificationHead(HeadBase):
    """Linear multi-label head over the pooled vector, trained with ``BCEWithLogitsLoss``."""

    def __init__(self, cfg: BaselineConfig):
        super().__init__()
        self.classifier = nn.Linear(cfg.hidden_size, cfg.num_labels)
        self.loss_fn = nn.BCEWithLogitsLoss()

    def forward(self, features, labels=None):
        """Return raw (pre-sigmoid) logits and, when labels are given, the loss.

        Labels are cast to ``float`` before being passed to the loss.
        """
        logits = self.classifier(features["hidden_state"])
        if labels is None:
            return {"logits": logits, "loss": None}
        return {"logits": logits, "loss": self.loss_fn(logits, labels.float())}

    def postprocess(self, logits):
        """Convert logits to independent per-label probabilities with a sigmoid."""
        return logits.sigmoid()
class RegressionHead(HeadBase):
    """Regression head over pooled features using MSELoss."""

    def __init__(self, cfg: BaselineConfig):
        super().__init__()
        self.regressor = nn.Linear(cfg.hidden_size, cfg.num_labels)
        self.loss_fn = nn.MSELoss()

    def forward(self, features, labels=None):
        """Return continuous outputs and optional MSE loss.

        Predictions and labels are flattened before the loss. ``reshape`` is
        used instead of ``view`` so that non-contiguous label tensors (for
        example slices or transposes) do not raise.

        Returns
        -------
        dict
            ``logits``: predictions ``(B, num_labels)``.
            ``loss``: MSE, or ``None`` when ``labels`` is ``None``.
        """
        preds = self.regressor(features["hidden_state"])
        loss = None
        if labels is not None:
            loss = self.loss_fn(preds.reshape(-1), labels.reshape(-1).float())
        return {"logits": preds, "loss": loss}
class TokenClassificationHead(HeadBase):
    """Per-token classification head over ``sequence_output`` using CrossEntropy."""

    def __init__(self, cfg: BaselineConfig):
        super().__init__()
        self.classifier = nn.Linear(cfg.hidden_size, cfg.num_labels)
        # NOTE(review): label positions equal to ``cfg.pad_token_id`` are skipped
        # by the loss. Confirm the label padding value matches the token
        # padding id, otherwise a real class id could be ignored.
        self.loss_fn = nn.CrossEntropyLoss(ignore_index=cfg.pad_token_id)

    def forward(self, features, labels=None):
        """Return per-token logits ``(B, L, num_labels)`` and optional loss.

        Logits and labels are flattened over batch and sequence before the
        loss. ``reshape`` is used instead of ``view`` so that non-contiguous
        tensors do not raise.
        """
        logits = self.classifier(features["sequence_output"])
        loss = None
        if labels is not None:
            flat_logits = logits.reshape(-1, logits.size(-1))
            loss = self.loss_fn(flat_logits, labels.reshape(-1).long())
        return {"logits": logits, "loss": loss}

    def postprocess(self, logits):
        """Return the most probable class index for every token."""
        return torch.argmax(torch.softmax(logits, dim=-1), dim=-1)
class TokenRegressionHead(HeadBase):
    """Per-token regression head over ``sequence_output`` using MSELoss."""

    def __init__(self, cfg: BaselineConfig):
        super().__init__()
        self.regressor = nn.Linear(cfg.hidden_size, cfg.num_labels)
        self.loss_fn = nn.MSELoss()

    def forward(self, features, labels=None):
        """Return per-token continuous outputs and optional loss.

        Predictions and labels are flattened before the loss. ``reshape`` is
        used instead of ``view`` so that non-contiguous label tensors do not
        raise.

        Returns
        -------
        dict
            ``logits``: predictions ``(B, L, num_labels)``.
            ``loss``: MSE, or ``None`` when ``labels`` is ``None``.
        """
        preds = self.regressor(features["sequence_output"])
        loss = None
        if labels is not None:
            loss = self.loss_fn(preds.reshape(-1), labels.reshape(-1).float())
        return {"logits": preds, "loss": loss}
# Lookup table from the ``task_name`` string to the head class implementing it.
HEAD_REGISTRY = dict(
    classification=SequenceClassificationHead,
    multilabel_classification=MultiLabelClassificationHead,
    regression=RegressionHead,
    token_classification=TokenClassificationHead,
    token_regression=TokenRegressionHead,
)
[docs]
class OmniGenericBaseline(OmniModel):
"""Generic baseline model wiring a simple backbone to a selected head.
This class provides a flexible way to create small baselines for different
tasks by choosing among CNN/LSTM/BPNet-style backbones and head types.
Parameters
----------
tokenizer : Any
Tokenizer instance used for vocabulary size and padding index. Must not be None.
backbone_type : str
One of {"cnn", "rnn", "bpnet", "deepstarr", "basenji"}.
task_name : str
One of {"multilabel_classification", "classification", "regression", "token_classification", "token_regression"}.
num_labels : int
Number of outputs for the head.
label2id : dict, optional
Mapping from label string to id; default builds {"0":0, ...}.
Other keyword arguments are forwarded to the chosen backbone configuration
(e.g., ``embed_dim``, ``hidden_dim``, ``n_filters``, etc.).
"""
def __init__(self, tokenizer, *args, **kwargs):
    """Assemble a registered backbone and task head around ``tokenizer``.

    Parameters
    ----------
    tokenizer : object
        Required. ``vocab_size`` (or ``get_vocab()``) sizes the vocabulary;
        ``pad_token_id`` and ``convert_tokens_to_ids`` are consulted when present.
    *args
        Forwarded unchanged to the parent constructor.
    **kwargs
        ``backbone_type``, ``task_name`` and ``num_labels`` are required and,
        together with the optional ``label2id``, are consumed here. A missing
        or ``None`` ``label2id`` builds ``{"0": 0, ...}``. Every remaining key
        is forwarded to the parent constructor and is also read for backbone
        hyperparameters (``embed_dim``, ``hidden_dim``, ``n_filters``, ...).

    Raises
    ------
    ValueError
        If ``tokenizer`` is ``None``, or ``backbone_type`` / ``task_name`` is
        missing or not present in the corresponding registry.
    KeyError
        If ``num_labels`` is not supplied.
    """
    # Enforce required arguments with helpful errors
    if tokenizer is None:
        raise ValueError(
            "tokenizer is required for OmniGenericBaseline (got None). Provide a tokenizer with vocab_size/get_vocab and pad_token_id."
        )
    backbone_type = kwargs.pop("backbone_type", None)
    task_name = kwargs.pop("task_name", None)
    if backbone_type is None:
        raise ValueError(
            f"backbone_type must be provided explicitly. Choices: {sorted(BACKBONE_REGISTRY)}"
        )
    if task_name is None:
        raise ValueError(
            f"task_name must be provided explicitly. Choices: {sorted(HEAD_REGISTRY)}"
        )
    # Fail fast on unknown names, before any module is constructed.
    if backbone_type not in BACKBONE_REGISTRY:
        raise ValueError(
            f"Unknown backbone_type='{backbone_type}'. Valid options: {sorted(BACKBONE_REGISTRY)}"
        )
    if task_name not in HEAD_REGISTRY:
        raise ValueError(
            f"Unknown task_name='{task_name}'. Valid options: {sorted(HEAD_REGISTRY)}"
        )
    num_labels = kwargs.pop("num_labels")
    # An explicit ``label2id=None`` must behave like "not given"; a plain
    # ``pop(key, default)`` would hand back the None and crash below.
    label2id = kwargs.pop("label2id", None)
    if label2id is None:
        label2id = {str(i): i for i in range(num_labels)}
    id2label = {v: k for k, v in label2id.items()}

    # --- robust pad token handling: don't treat 0 as falsy, validate range later
    _pad_attr = getattr(tokenizer, "pad_token_id", None)
    pad_token_id = int(_pad_attr) if _pad_attr is not None else None

    class Cfg: ...

    cfg = Cfg()
    cfg.hidden_size = kwargs.get("hidden_size", 128)
    cfg.num_labels = num_labels
    cfg.label2id = label2id
    cfg.id2label = id2label
    cfg.name_or_path = f"GenericBaseline-{backbone_type}-{task_name}"
    cfg.model_type = "baseline"
    cfg.architectures = ["OmniGenericBaseline"]
    # pad_token_id is stored on the config after vocab_size is known (below).

    class _Stub(nn.Module):
        """Parameter-free placeholder handed to the parent constructor."""

        def __init__(self, config):
            super().__init__()
            self.config = config
            self.register_buffer("_dev_tracker", torch.empty(0))

        @property
        def device(self):
            return self._dev_tracker.device

        @property
        def dtype(self):
            return self._dev_tracker.dtype

    super().__init__(
        _Stub(cfg),
        tokenizer,
        num_labels=num_labels,
        label2id=label2id,
        *args,
        **kwargs,
    )
    vocab_size = getattr(self.tokenizer, "vocab_size", None) or len(
        self.tokenizer.get_vocab()
    )
    # Finalize a safe pad_token_id: prefer tokenizer value; otherwise search
    # common pad tokens; else default to 0.
    if pad_token_id is None or not (-vocab_size <= pad_token_id < vocab_size):
        resolved = None
        try:
            for tok in ("<pad>", "[PAD]", "PAD", "pad"):
                tid = self.tokenizer.convert_tokens_to_ids(tok)
                if isinstance(tid, int) and 0 <= tid < vocab_size:
                    resolved = tid
                    break
        except Exception:
            # Tokenizer is an arbitrary user object; lookup is best-effort.
            resolved = None
        pad_token_id = int(resolved) if resolved is not None else 0
    # Store on config now that it's validated; _build_config_dict reads it.
    self.config.pad_token_id = pad_token_id

    self.baseline_cfg = BaselineConfig(
        backbone_type=backbone_type,
        task_name=task_name,
        vocab_size=vocab_size,
        hidden_size=cfg.hidden_size,
        num_labels=num_labels,
        pad_token_id=pad_token_id,
        embed_dim=kwargs.get("embed_dim", 128),
        num_filters=kwargs.get("num_filters", 128),
        kernel_sizes=kwargs.get("kernel_sizes", (3, 5, 7)),
        dropout=kwargs.get("dropout", 0.1),
        hidden_dim=kwargs.get("hidden_dim", 256),
        num_layers=kwargs.get("num_layers", 1),
        bidirectional=kwargs.get("bidirectional", True),
        n_filters=kwargs.get("n_filters", 64),
        n_dilated_layers=kwargs.get("n_dilated_layers", 9),
        conv1_kernel_size=kwargs.get("conv1_kernel_size", 25),
        dil_kernel_size=kwargs.get("dil_kernel_size", 3),
        label2id=label2id,
        id2label=dict(id2label),
    )
    self.backbone = BACKBONE_REGISTRY[backbone_type](self.baseline_cfg)

    # Setup nucleotide one-hot projection for backbones that support it
    weight = torch.zeros(vocab_size, 4)
    for i, toks in enumerate([("A", "a"), ("C", "c"), ("G", "g"), ("T", "t")]):
        for tok in toks:
            try:
                tid = self.tokenizer.convert_tokens_to_ids(tok)
                if tid is not None and 0 <= tid < weight.size(0):
                    weight[tid, i] = 1.0
            except Exception:
                # Best-effort: a tokenizer without this token leaves the row at zero.
                pass
    if hasattr(self.backbone, "set_one_hot"):
        try:
            self.backbone.set_one_hot(weight)
        except Exception as err:
            # Keep construction best-effort, but do not hide the failure.
            warnings.warn(
                f"Could not install the nucleotide one-hot projection on the backbone: {err}",
                RuntimeWarning,
            )

    # The head must be sized from what the backbone actually emits.
    self.baseline_cfg.hidden_size = getattr(
        self.backbone, "out_dim", self.baseline_cfg.hidden_size
    )
    self.config.hidden_size = self.baseline_cfg.hidden_size
    self.head = HEAD_REGISTRY[task_name](self.baseline_cfg)
@property
def device(self):
    """Device on which the first model parameter lives."""
    first_param = next(self.parameters())
    return first_param.device
def _build_config_dict(self):
    """Collect everything ``save_pretrained`` writes to ``config.json``.

    Model-level fields come from ``self.config``; backbone and head
    hyperparameters come from ``self.baseline_cfg``. ``kernel_sizes`` is
    converted to a list. Key order is fixed so the saved file is stable.
    """
    model_cfg = self.config
    hyper = self.baseline_cfg

    payload = {
        field_name: getattr(model_cfg, field_name)
        for field_name in (
            "model_type",
            "architectures",
            "num_labels",
            "label2id",
            "id2label",
            "pad_token_id",
            "hidden_size",
        )
    }
    for field_name in (
        "backbone_type",
        "task_name",
        "vocab_size",
        "embed_dim",
        "num_filters",
    ):
        payload[field_name] = getattr(hyper, field_name)
    payload["kernel_sizes"] = list(hyper.kernel_sizes)
    for field_name in (
        "dropout",
        "hidden_dim",
        "num_layers",
        "bidirectional",
        "n_filters",
        "n_dilated_layers",
        "conv1_kernel_size",
        "dil_kernel_size",
    ):
        payload[field_name] = getattr(hyper, field_name)
    payload["model_cls"] = self.__class__.__name__
    payload["library_name"] = "OMNIGENBENCH"
    return payload
[docs]
def save_pretrained(self, save_directory: str, overwrite: bool = True):
    """Write config, tokenizer, weights and metadata into ``save_directory``.

    Files
    -----
    - ``config.json``: output of ``_build_config_dict``.
    - tokenizer files: written by the tokenizer's own ``save_pretrained`` when
      it has one, otherwise the tokenizer object goes to ``tokenizer.bin``.
    - ``pytorch_model.bin``: the model ``state_dict``.
    - ``metadata.json``: ``self.metadata`` (if any) plus class/library names.

    ``overwrite`` is accepted for interface compatibility and is not consulted.
    """
    os.makedirs(save_directory, exist_ok=True)

    config_path = os.path.join(save_directory, "config.json")
    with open(config_path, "w", encoding="utf8") as config_file:
        json.dump(
            self._build_config_dict(), config_file, ensure_ascii=False, indent=2
        )

    tok = self.tokenizer
    if not hasattr(tok, "save_pretrained"):
        tokenizer_path = os.path.join(save_directory, "tokenizer.bin")
        with open(tokenizer_path, "wb") as tokenizer_file:
            torch.save(tok, tokenizer_file)
    else:
        tok.save_pretrained(save_directory)

    weights_path = os.path.join(save_directory, "pytorch_model.bin")
    torch.save(self.state_dict(), weights_path)

    # An existing ``self.metadata`` dict is updated in place.
    meta = getattr(self, "metadata", {})
    meta["model_cls"] = self.__class__.__name__
    meta["library_name"] = meta.get("library_name", "OMNIGENBENCH")
    meta_path = os.path.join(save_directory, "metadata.json")
    with open(meta_path, "w", encoding="utf8") as meta_file:
        json.dump(meta, meta_file, ensure_ascii=False, indent=2)
[docs]
@classmethod
def from_pretrained(
    cls, save_directory: str, tokenizer=None, map_location=None, **kwargs
):
    """Load model, config, and tokenizer from a directory.

    Parameters
    ----------
    save_directory : str
        Directory containing ``config.json`` and ``pytorch_model.bin``.
    tokenizer : object, optional
        Tokenizer to use. When omitted, one is loaded from ``save_directory``
        (Hugging Face files first, then ``tokenizer.bin``).
    map_location : optional
        Passed to ``torch.load``; weights default to ``"cpu"`` when unset.
    **kwargs
        Accepted for interface compatibility; not used.

    Raises
    ------
    ValueError
        If no tokenizer was given and none could be loaded.
    """
    with open(
        os.path.join(save_directory, "config.json"), "r", encoding="utf8"
    ) as f:
        cfg_dict = json.load(f)

    if tokenizer is None:
        has_hf_files = any(
            os.path.exists(os.path.join(save_directory, name))
            for name in ("tokenizer_config.json", "vocab.json")
        )
        if has_hf_files:
            try:
                from transformers import AutoTokenizer

                tokenizer = AutoTokenizer.from_pretrained(save_directory)
            except Exception:
                # Optional dependency / incompatible files: fall through to tokenizer.bin.
                tokenizer = None
        tokenizer_bin = os.path.join(save_directory, "tokenizer.bin")
        if tokenizer is None and os.path.exists(tokenizer_bin):
            # SECURITY: tokenizer.bin is a pickled object; only load
            # checkpoints from trusted sources.
            with open(tokenizer_bin, "rb") as f:
                tokenizer = torch.load(f, map_location=map_location)
    if tokenizer is None:
        raise ValueError("Tokenizer could not be loaded; please provide one.")

    # Backward-compat: warn if backbone/task missing; fall back to defaults
    if "backbone_type" not in cfg_dict:
        warnings.warn(
            "Missing 'backbone_type' in saved config; defaulting to 'cnn'. Please resave the model to include this field.",
            RuntimeWarning,
        )
    if "task_name" not in cfg_dict:
        warnings.warn(
            "Missing 'task_name' in saved config; defaulting to 'multilabel_classification'. Please resave the model to include this field.",
            RuntimeWarning,
        )

    init_kwargs = {
        "backbone_type": cfg_dict.get("backbone_type", "cnn"),
        "task_name": cfg_dict.get("task_name", "multilabel_classification"),
        "num_labels": cfg_dict["num_labels"],
        "embed_dim": cfg_dict.get("embed_dim", 128),
        "num_filters": cfg_dict.get("num_filters", 128),
        "kernel_sizes": tuple(cfg_dict.get("kernel_sizes", (3, 5, 7))),
        "dropout": cfg_dict.get("dropout", 0.1),
        "hidden_dim": cfg_dict.get("hidden_dim", 256),
        "num_layers": cfg_dict.get("num_layers", 1),
        "bidirectional": cfg_dict.get("bidirectional", True),
        "n_filters": cfg_dict.get("n_filters", 64),
        "n_dilated_layers": cfg_dict.get("n_dilated_layers", 9),
        "conv1_kernel_size": cfg_dict.get("conv1_kernel_size", 25),
        "dil_kernel_size": cfg_dict.get("dil_kernel_size", 3),
    }
    # Forward label2id only when the config really has one: an explicit
    # ``label2id=None`` would defeat the constructor's default mapping.
    saved_label2id = cfg_dict.get("label2id")
    if saved_label2id is not None:
        init_kwargs["label2id"] = saved_label2id
    model = cls(tokenizer, **init_kwargs)

    state_dict = torch.load(
        os.path.join(save_directory, "pytorch_model.bin"),
        map_location=map_location or "cpu",
    )
    # Loading stays non-strict for old checkpoints, but mismatches are
    # reported instead of silently leaving weights at their random init.
    load_result = model.load_state_dict(state_dict, strict=False)
    missing = list(getattr(load_result, "missing_keys", ()))
    unexpected = list(getattr(load_result, "unexpected_keys", ()))
    if missing or unexpected:
        warnings.warn(
            f"Checkpoint does not fully match the model (missing keys: {missing}; unexpected keys: {unexpected}).",
            RuntimeWarning,
        )

    meta_path = os.path.join(save_directory, "metadata.json")
    if os.path.exists(meta_path):
        with open(meta_path, "r", encoding="utf8") as f:
            model.metadata = json.load(f)
    return model
[docs]
def forward(self, **inputs):
    """Run the backbone, then the task head, and bundle their outputs.

    Returns
    -------
    dict
        ``logits`` from the head, the passthrough ``labels`` (``None`` when
        absent) and ``last_hidden_state`` (the backbone's ``hidden_state``
        entry, if any). ``sequence_output`` is added when the backbone
        returns it, and ``loss`` when the head produced one.
    """
    labels = inputs.get("labels")
    token_ids = inputs["input_ids"]
    mask = inputs.get("attention_mask")

    features = self.backbone(token_ids, mask)
    head_result = self.head(features, labels=labels)

    result = {
        "logits": head_result["logits"],
        "labels": labels,
        "last_hidden_state": features.get("hidden_state"),
    }
    if "sequence_output" in features:
        result["sequence_output"] = features["sequence_output"]
    loss = head_result.get("loss")
    if loss is not None:
        result["loss"] = loss
    return result
[docs]
def predict(self, sequence_or_inputs, **kwargs):
    """Run the model on raw input and post-process logits with the task head.

    Returns a dict with ``predictions`` (output of ``head.postprocess``),
    the raw ``logits`` and ``last_hidden_state`` (``None`` when unavailable).
    """
    forward_out = self._forward_from_raw_input(sequence_or_inputs, **kwargs)
    raw_logits = forward_out["logits"]
    result = {"predictions": self.head.postprocess(raw_logits)}
    result["logits"] = raw_logits
    result["last_hidden_state"] = forward_out.get("last_hidden_state")
    return result
[docs]
def inference(self, sequence_or_inputs, threshold: float = 0.5, **kwargs):
    """Convenience inference wrapper producing final predictions on CPU.

    Behavior depends on ``task_name``:

    - ``multilabel_classification``: probabilities thresholded at ``threshold``.
    - ``classification``: argmax over the post-processed probabilities.
    - others: the post-processed outputs are returned as predictions.

    Parameters
    ----------
    sequence_or_inputs
        Raw input. A ``list`` is treated as a batch; anything else is treated
        as a single example and the leading batch dimension is stripped.
    threshold : float, optional
        Decision threshold for multilabel classification, by default 0.5.

    Returns
    -------
    dict
        ``predictions``, ``logits``, ``probabilities`` and
        ``last_hidden_state`` (``None`` if the model did not return one),
        all moved to CPU.
    """
    out = self._forward_from_raw_input(sequence_or_inputs, **kwargs)
    logits = out["logits"]
    probs = self.head.postprocess(logits)

    task = self.baseline_cfg.task_name
    if task == "multilabel_classification":
        preds = (probs >= threshold).to(torch.int)
    elif task == "classification":
        preds = probs.argmax(dim=-1)
    else:
        preds = probs

    hidden = out.get("last_hidden_state")
    if not isinstance(sequence_or_inputs, list):
        return {
            "predictions": preds[0].cpu(),
            "logits": logits[0].cpu(),
            "probabilities": probs[0].cpu(),
            "last_hidden_state": hidden[0].cpu() if hidden is not None else None,
        }
    return {
        "predictions": preds.cpu(),
        "logits": logits.cpu(),
        "probabilities": probs.cpu(),
        # Moved to CPU like every other returned tensor; previously the batch
        # branch alone handed back the on-device tensor.
        "last_hidden_state": hidden.cpu() if hidden is not None else None,
    }
[docs]
def loss_function(self, logits, labels):
    """Compute the training loss that matches ``baseline_cfg.task_name``.

    - ``multilabel_classification``: ``BCEWithLogitsLoss`` on float labels.
    - ``classification``: ``CrossEntropyLoss`` on integer labels.
    - ``regression``: ``MSELoss`` with labels reshaped to the logits' shape.
    - anything else (token-level tasks): ``MSELoss`` over flattened tensors.
    """
    task = self.baseline_cfg.task_name
    if task == "multilabel_classification":
        return nn.BCEWithLogitsLoss()(logits, labels.float())
    if task == "classification":
        return nn.CrossEntropyLoss()(logits, labels.long())
    if task == "regression":
        target = labels.float()
        # Align the target with the logits instead of flattening it alone:
        # (B, 1) logits against (B,) labels would otherwise broadcast to a
        # (B, B) error matrix, and (B, K) logits against B*K labels would fail.
        if target.shape != logits.shape and target.numel() == logits.numel():
            target = target.reshape(logits.shape)
        return nn.MSELoss()(logits, target)
    # Token-level tasks default
    # NOTE(review): this assumes one logit per label element; a
    # token_classification head with several classes would need a different loss.
    return nn.MSELoss()(logits.reshape(-1), labels.reshape(-1).float())
[docs]
def create_baseline(
    tokenizer,
    *,
    backbone_type: str,
    task_name: str,
    num_labels: int,
    label2id: Optional[dict] = None,
    **kwargs,
):
    """Factory for building baselines via OmniGenericBaseline.

    Parameters
    ----------
    tokenizer
        Tokenizer handed to the model.
    backbone_type, task_name, num_labels
        Required keyword-only selectors forwarded to ``OmniGenericBaseline``.
    label2id : dict, optional
        Label mapping. Forwarded only when given, so that omitting it lets
        the model build its default mapping.
    **kwargs
        Extra options forwarded unchanged.

    Example:
    model = create_baseline(
    tokenizer,
    backbone_type="deepstarr",
    task_name="multilabel_classification",
    num_labels=8,
    )
    """
    # Passing ``label2id=None`` explicitly would override the constructor's
    # default mapping with None, so the key is only added when it has a value.
    if label2id is not None:
        kwargs["label2id"] = label2id
    return OmniGenericBaseline(
        tokenizer,
        backbone_type=backbone_type,
        task_name=task_name,
        num_labels=num_labels,
        **kwargs,
    )
# ---------------- Base Backbones for OmniModel Wrappers -----------------
class _BaseConfig:
    """Plain attribute bag: every keyword argument becomes an instance attribute."""

    def __init__(self, **kwargs):
        vars(self).update(kwargs)
def _build_label_maps(num_labels, label2id=None):
    """Return a consistent ``(label2id, id2label)`` pair.

    When ``label2id`` is ``None`` the default ``{"0": 0, ..., "n-1": n-1}``
    mapping for ``num_labels`` labels is created; a supplied mapping is
    returned as-is. ``id2label`` is always the inverse of ``label2id``.

    Returns
    -------
    tuple(dict, dict)
        ``(label2id, id2label)``.
    """
    if label2id is None:
        label2id = {str(idx): idx for idx in range(num_labels)}
    inverse = {idx: name for name, idx in label2id.items()}
    return label2id, inverse
class OmniCNNBaseModel(nn.Module):
    """Lightweight CNN backbone returning per-token hidden states, to be wrapped by OmniModelFor* wrappers.

    Token ids are embedded, passed through one ``Conv1d + ReLU`` branch per
    entry of ``kernel_sizes`` and the branch outputs are concatenated along
    the channel axis. Forward returns a dictionary with ``last_hidden_state``
    of shape ``(B, L, H)`` where ``H = num_filters * len(kernel_sizes)``.
    """

    def __init__(
        self,
        vocab_size,
        embed_dim=128,
        num_filters=128,
        kernel_sizes=(3, 5, 7),
        dropout=0.1,
        pad_token_id=0,
        num_labels=2,
        label2id=None,
        name_or_path="cnn-backbone",
    ):
        super().__init__()
        label2id, id2label = _build_label_maps(num_labels, label2id)
        # device tracker buffer for OmniModel compatibility
        self.register_buffer("_dev_tracker", torch.empty(0), persistent=False)
        self.config = _BaseConfig(
            hidden_size=num_filters * len(kernel_sizes),
            embed_dim=embed_dim,
            num_filters=num_filters,
            kernel_sizes=list(kernel_sizes),
            dropout=dropout,
            pad_token_id=pad_token_id,
            vocab_size=vocab_size,
            num_labels=num_labels,
            label2id=label2id,
            id2label=id2label,
            name_or_path=name_or_path,
            model_type="cnn",
            architectures=["OmniCNNBaseModel"],
        )
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=pad_token_id)
        # padding="same" equals k // 2 for odd kernels and, unlike k // 2, also
        # keeps the length at L for even kernels, so branches of any kernel
        # size can be concatenated.
        self.convs = nn.ModuleList(
            [
                nn.Sequential(
                    nn.Conv1d(embed_dim, num_filters, k, padding="same"), nn.ReLU()
                )
                for k in kernel_sizes
            ]
        )
        self.dropout = nn.Dropout(dropout)

    @property
    def device(self):
        """Device of the tracker buffer (follows ``.to()`` / ``.cuda()``)."""
        return self._dev_tracker.device

    @property
    def dtype(self):
        """Dtype of the tracker buffer."""
        return self._dev_tracker.dtype

    def forward(self, input_ids, attention_mask=None, **kwargs):
        """Return token-level features only.

        Parameters
        ----------
        input_ids : torch.LongTensor
            Shape ``(B, L)``.
        attention_mask : torch.LongTensor, optional
            Unused here; kept for compatibility.

        Returns
        -------
        dict
            ``{"last_hidden_state": (B, L, H)}``.
        """
        embedded = self.dropout(self.embedding(input_ids))
        # Conv1d expects channels first; transpose once for all branches.
        channels_first = embedded.transpose(1, 2)
        branch_outputs = [conv(channels_first) for conv in self.convs]
        feats = torch.cat(branch_outputs, dim=1).transpose(1, 2)
        return {"last_hidden_state": feats}

    def save_pretrained(self, save_directory, **kwargs):
        """Save backbone weights and config to ``save_directory``.

        Writes ``config.json`` (public config attributes plus class/library
        names), ``pytorch_model.bin`` and ``metadata.json``.
        """
        os.makedirs(save_directory, exist_ok=True)
        cfg_dict = {
            k: getattr(self.config, k)
            for k in self.config.__dict__
            if not k.startswith("_")
        }
        cfg_dict.update(
            {"model_cls": self.__class__.__name__, "library_name": "OMNIGENBENCH"}
        )
        with open(
            os.path.join(save_directory, "config.json"), "w", encoding="utf8"
        ) as f:
            json.dump(cfg_dict, f, ensure_ascii=False, indent=2)
        torch.save(self.state_dict(), os.path.join(save_directory, "pytorch_model.bin"))
        with open(
            os.path.join(save_directory, "metadata.json"), "w", encoding="utf8"
        ) as f:
            json.dump(
                {"model_cls": self.__class__.__name__, "library_name": "OMNIGENBENCH"},
                f,
                indent=2,
            )

    @classmethod
    def from_pretrained(cls, save_directory, map_location=None, **kwargs):
        """Load backbone from a directory containing ``config.json`` and weights.

        ``vocab_size`` must be present in the config; every other
        hyperparameter falls back to the constructor default. Weights are
        loaded non-strictly, onto ``map_location`` (CPU when unset).
        """
        with open(
            os.path.join(save_directory, "config.json"), "r", encoding="utf8"
        ) as f:
            cfg = json.load(f)
        model = cls(
            vocab_size=cfg["vocab_size"],
            embed_dim=cfg.get("embed_dim", 128),
            num_filters=cfg.get("num_filters", 128),
            kernel_sizes=tuple(cfg.get("kernel_sizes", (3, 5, 7))),
            dropout=cfg.get("dropout", 0.1),
            pad_token_id=cfg.get("pad_token_id", 0),
            num_labels=cfg.get("num_labels", 2),
            label2id=cfg.get("label2id"),
            name_or_path=cfg.get("name_or_path", "cnn-backbone"),
        )
        state = torch.load(
            os.path.join(save_directory, "pytorch_model.bin"),
            map_location=map_location or "cpu",
        )
        model.load_state_dict(state, strict=False)
        return model
class OmniRNNBaseModel(nn.Module):
    """LSTM backbone (bidirectional by default) that emits per-token hidden states.

    Intended to be wrapped by OmniModel task wrappers; ``forward`` returns
    ``{"last_hidden_state": (B, L, H)}`` with
    ``H = hidden_dim * (2 if bidirectional else 1)``.
    """

    def __init__(
        self,
        vocab_size,
        embed_dim=128,
        hidden_dim=256,
        num_layers=1,
        bidirectional=True,
        dropout=0.1,
        pad_token_id=0,
        num_labels=2,
        label2id=None,
        name_or_path="rnn-backbone",
    ):
        super().__init__()
        label2id, id2label = _build_label_maps(num_labels, label2id)
        # Non-persistent buffer whose only job is to report device/dtype.
        self.register_buffer("_dev_tracker", torch.empty(0), persistent=False)

        direction_count = 2 if bidirectional else 1
        # Dropout between LSTM layers is only enabled when layers are stacked.
        inter_layer_dropout = dropout if num_layers != 1 else 0.0

        self.config = _BaseConfig(
            hidden_size=hidden_dim * direction_count,
            embed_dim=embed_dim,
            hidden_dim=hidden_dim,
            num_layers=num_layers,
            bidirectional=bidirectional,
            dropout=dropout,
            pad_token_id=pad_token_id,
            vocab_size=vocab_size,
            num_labels=num_labels,
            label2id=label2id,
            id2label=id2label,
            name_or_path=name_or_path,
            model_type="rnn",
            architectures=["OmniRNNBaseModel"],
        )
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=pad_token_id)
        self.lstm = nn.LSTM(
            embed_dim,
            hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            dropout=inter_layer_dropout,
            bidirectional=bidirectional,
        )
        self.dropout = nn.Dropout(dropout)

    @property
    def device(self):
        """Device of the tracker buffer."""
        return self._dev_tracker.device

    @property
    def dtype(self):
        """Dtype of the tracker buffer."""
        return self._dev_tracker.dtype

    def forward(self, input_ids, attention_mask=None, **kwargs):
        """Embed ``input_ids``, apply dropout and run the LSTM.

        ``attention_mask`` and extra keyword arguments are accepted but not
        used. Returns ``{"last_hidden_state": (B, L, H)}``.
        """
        embedded = self.dropout(self.embedding(input_ids))
        per_token, _final_state = self.lstm(embedded)
        return {"last_hidden_state": per_token}

    def save_pretrained(self, save_directory, **kwargs):
        """Write ``config.json``, ``pytorch_model.bin`` and ``metadata.json``."""
        os.makedirs(save_directory, exist_ok=True)
        identity = {
            "model_cls": self.__class__.__name__,
            "library_name": "OMNIGENBENCH",
        }

        # Public config attributes first, then the class/library identity.
        payload = {
            key: value
            for key, value in vars(self.config).items()
            if not key.startswith("_")
        }
        payload.update(identity)

        config_path = os.path.join(save_directory, "config.json")
        with open(config_path, "w", encoding="utf8") as config_file:
            json.dump(payload, config_file, ensure_ascii=False, indent=2)

        weights_path = os.path.join(save_directory, "pytorch_model.bin")
        torch.save(self.state_dict(), weights_path)

        meta_path = os.path.join(save_directory, "metadata.json")
        with open(meta_path, "w", encoding="utf8") as meta_file:
            json.dump(identity, meta_file, indent=2)

    @classmethod
    def from_pretrained(cls, save_directory, map_location=None, **kwargs):
        """Rebuild the backbone from ``config.json`` and load its weights.

        ``vocab_size`` is required in the config; other fields fall back to
        the constructor defaults. Weights are loaded non-strictly onto
        ``map_location`` (CPU when unset). Extra ``kwargs`` are not used.
        """
        config_path = os.path.join(save_directory, "config.json")
        with open(config_path, "r", encoding="utf8") as config_file:
            saved = json.load(config_file)

        init_args = {"vocab_size": saved["vocab_size"]}
        fallbacks = (
            ("embed_dim", 128),
            ("hidden_dim", 256),
            ("num_layers", 1),
            ("bidirectional", True),
            ("dropout", 0.1),
            ("pad_token_id", 0),
            ("num_labels", 2),
            ("label2id", None),
            ("name_or_path", "rnn-backbone"),
        )
        for key, default in fallbacks:
            init_args[key] = saved.get(key, default)
        backbone = cls(**init_args)

        weights = torch.load(
            os.path.join(save_directory, "pytorch_model.bin"),
            map_location=map_location or "cpu",
        )
        backbone.load_state_dict(weights, strict=False)
        return backbone
class OmniBPNetBaseModel(nn.Module):
    """BPNet-style dilated CNN backbone that returns per-token hidden states.

    This lightweight base model mirrors OmniCNNBaseModel/OmniRNNBaseModel patterns and
    is suitable for wrapping by OmniModelFor* tasks that expect
    ``{"last_hidden_state": (B, L, H)}``.

    Token ids are mapped to 4 channels through a ``(vocab_size, 4)`` lookup
    matrix that starts as all zeros. It is not stored in the state dict, so it
    must be installed with :meth:`set_one_hot` (``from_pretrained`` does this
    automatically when a tokenizer is available).

    Parameters
    ----------
    vocab_size : int
        Vocabulary size for the token-to-one-hot projection.
    n_filters : int, optional
        Number of channels in convolution blocks, by default 64.
    n_dilated_layers : int, optional
        Number of dilated residual layers, by default 9.
    conv1_kernel_size : int, optional
        Kernel size for the first conv, by default 25.
    dil_kernel_size : int, optional
        Kernel size for dilated convs, by default 3.
    dropout : float, optional
        Dropout prob applied to features, by default 0.1.
    pad_token_id : int, optional
        Padding token id, by default 0.
    num_labels : int, optional
        Unused here; included for config parity, by default 2.
    label2id : dict, optional
        Optional label map stored in config.
    name_or_path : str, optional
        Model name for config metadata, by default "bpnet-backbone".
    """

    def __init__(
        self,
        vocab_size,
        n_filters=64,
        n_dilated_layers=9,
        conv1_kernel_size=25,
        dil_kernel_size=3,
        dropout=0.1,
        pad_token_id=0,
        num_labels=2,
        label2id=None,
        name_or_path="bpnet-backbone",
    ):
        super().__init__()
        label2id, id2label = _build_label_maps(num_labels, label2id)
        self.register_buffer("_dev_tracker", torch.empty(0), persistent=False)
        self.config = _BaseConfig(
            hidden_size=n_filters,
            n_filters=n_filters,
            n_dilated_layers=n_dilated_layers,
            conv1_kernel_size=conv1_kernel_size,
            dil_kernel_size=dil_kernel_size,
            dropout=dropout,
            pad_token_id=pad_token_id,
            vocab_size=vocab_size,
            num_labels=num_labels,
            label2id=label2id,
            id2label=id2label,
            name_or_path=name_or_path,
            model_type="bpnet",
            architectures=["OmniBPNetBaseModel"],
        )
        self.register_buffer(
            "_one_hot_weight", torch.zeros(vocab_size, 4), persistent=False
        )
        self.conv1 = nn.Conv1d(4, n_filters, conv1_kernel_size, padding="same")
        # Dilation doubles with depth: 1, 2, 4, ...
        self.dilated_convs = nn.ModuleList(
            [
                nn.Conv1d(
                    n_filters,
                    n_filters,
                    dil_kernel_size,
                    padding="same",
                    dilation=2**i,
                )
                for i in range(n_dilated_layers)
            ]
        )
        # Not used by forward(); kept so the module's attributes stay unchanged.
        self.global_avg_pool = nn.AdaptiveAvgPool1d(1)
        self.out_dim = n_filters
        self.dropout = nn.Dropout(dropout)

    @property
    def device(self):
        """Device of the tracker buffer."""
        return self._dev_tracker.device

    @property
    def dtype(self):
        """Dtype of the tracker buffer."""
        return self._dev_tracker.dtype

    def set_one_hot(self, weight: torch.Tensor):
        """Optionally set the id->one-hot projection matrix of shape (vocab_size, 4).

        The matrix is moved to the device and dtype of the current buffer so
        it can be installed after the model was moved. A value that is not a
        tensor of the expected shape is ignored with a warning.
        """
        current = self._one_hot_weight
        if not isinstance(weight, torch.Tensor) or weight.shape != current.shape:
            warnings.warn(
                f"Ignoring one-hot projection: expected a tensor of shape {tuple(current.shape)}.",
                RuntimeWarning,
            )
            return
        self._one_hot_weight = weight.detach().to(
            device=current.device, dtype=current.dtype
        )

    @staticmethod
    def _nucleotide_one_hot(tokenizer, vocab_size):
        """Build the ``(vocab_size, 4)`` A/C/G/T lookup matrix from a tokenizer."""
        weight = torch.zeros(vocab_size, 4)
        for channel, variants in enumerate(
            (("A", "a"), ("C", "c"), ("G", "g"), ("T", "t"))
        ):
            for tok in variants:
                try:
                    tid = tokenizer.convert_tokens_to_ids(tok)
                except Exception:
                    # Best-effort: tokens the tokenizer cannot map stay at zero.
                    continue
                if isinstance(tid, int) and 0 <= tid < vocab_size:
                    weight[tid, channel] = 1.0
        return weight

    @staticmethod
    def _try_load_tokenizer(save_directory, map_location=None):
        """Best-effort tokenizer discovery in ``save_directory``; ``None`` if absent."""
        tokenizer = None
        has_hf_files = any(
            os.path.exists(os.path.join(save_directory, name))
            for name in ("tokenizer_config.json", "vocab.json")
        )
        if has_hf_files:
            try:
                from transformers import AutoTokenizer

                tokenizer = AutoTokenizer.from_pretrained(save_directory)
            except Exception:
                tokenizer = None
        tokenizer_bin = os.path.join(save_directory, "tokenizer.bin")
        if tokenizer is None and os.path.exists(tokenizer_bin):
            # SECURITY: tokenizer.bin is a pickled object; only load
            # checkpoints from trusted sources.
            with open(tokenizer_bin, "rb") as f:
                tokenizer = torch.load(f, map_location=map_location)
        return tokenizer

    def _tokens_to_one_hot(self, ids: torch.Tensor) -> torch.Tensor:
        """Look up ids in the projection matrix and return channels-first ``[B, 4, L]``."""
        return torch.nn.functional.embedding(ids, self._one_hot_weight).permute(0, 2, 1)

    def forward(self, input_ids, attention_mask=None, **inputs):
        """Return per-token hidden states only.

        ``attention_mask`` and extra keyword inputs are accepted but unused.

        Returns
        -------
        dict
            ``{"last_hidden_state": (batch, seq_len, hidden_size)}``.
        """
        x = self._tokens_to_one_hot(input_ids)  # [B,4,L]
        x = torch.relu(self.conv1(x))
        for layer in self.dilated_convs:
            # Residual connection around every dilated convolution.
            x = x + torch.relu(layer(x))
        seq_out = x.transpose(1, 2)  # (B, L, C)
        seq_out = self.dropout(seq_out)
        return {"last_hidden_state": seq_out}

    def _build_config_dict(self):
        """Return the serializable configuration written to ``config.json``."""
        return {
            "model_type": self.config.model_type,
            "architectures": self.config.architectures,
            "num_labels": self.config.num_labels,
            "label2id": self.config.label2id,
            "id2label": self.config.id2label,
            "pad_token_id": self.config.pad_token_id,
            "hidden_size": self.config.hidden_size,
            "embed_dim": getattr(self.config, "embed_dim", None),
            "n_filters": getattr(self.config, "n_filters", None),
            "n_dilated_layers": getattr(self.config, "n_dilated_layers", None),
            "conv1_kernel_size": getattr(self.config, "conv1_kernel_size", None),
            "dil_kernel_size": getattr(self.config, "dil_kernel_size", None),
            "dropout": getattr(self.config, "dropout", None),
            # This backbone has no embedding layer; the vocabulary size lives in the config.
            "vocab_size": self.config.vocab_size,
            "model_cls": self.__class__.__name__,
            "library_name": "OMNIGENBENCH",
        }

    def save_pretrained(self, save_directory: str, overwrite: bool = True):
        """Save config, weights and metadata (plus a tokenizer, if one is attached).

        The backbone does not own a tokenizer. If a ``tokenizer`` attribute
        has been attached to the instance it is saved as well; otherwise that
        step is skipped. ``overwrite`` is accepted for interface
        compatibility and is not consulted.
        """
        os.makedirs(save_directory, exist_ok=True)
        with open(
            os.path.join(save_directory, "config.json"), "w", encoding="utf8"
        ) as f:
            json.dump(self._build_config_dict(), f, ensure_ascii=False, indent=2)
        tokenizer = getattr(self, "tokenizer", None)
        if tokenizer is not None:
            if hasattr(tokenizer, "save_pretrained"):
                tokenizer.save_pretrained(save_directory)
            else:
                with open(os.path.join(save_directory, "tokenizer.bin"), "wb") as f:
                    torch.save(tokenizer, f)
        torch.save(self.state_dict(), os.path.join(save_directory, "pytorch_model.bin"))
        metadata = getattr(self, "metadata", {})
        metadata["model_cls"] = self.__class__.__name__
        metadata["library_name"] = metadata.get("library_name", "OMNIGENBENCH")
        with open(
            os.path.join(save_directory, "metadata.json"), "w", encoding="utf8"
        ) as f:
            json.dump(metadata, f, ensure_ascii=False, indent=2)

    @classmethod
    def from_pretrained(
        cls, save_directory: str, tokenizer=None, map_location=None, **kwargs
    ):
        """Load the backbone from ``save_directory``.

        Parameters
        ----------
        save_directory : str
            Directory containing ``config.json`` and ``pytorch_model.bin``.
        tokenizer : object, optional
            Used only to rebuild the A/C/G/T lookup matrix, which is not part
            of the saved weights. When omitted, a tokenizer stored in
            ``save_directory`` is used if one can be loaded. Without any
            tokenizer the matrix stays zero until ``set_one_hot`` is called.
        map_location : optional
            Passed to ``torch.load``; weights default to ``"cpu"`` when unset.
        **kwargs
            Constructor overrides applied on top of the saved configuration.

        Raises
        ------
        ValueError
            If the vocabulary size is neither in the config nor derivable
            from a tokenizer.
        """
        with open(
            os.path.join(save_directory, "config.json"), "r", encoding="utf8"
        ) as f:
            cfg_dict = json.load(f)
        if tokenizer is None:
            tokenizer = cls._try_load_tokenizer(save_directory, map_location)

        vocab_size = cfg_dict.get("vocab_size")
        if vocab_size is None and tokenizer is not None:
            vocab_size = getattr(tokenizer, "vocab_size", None) or len(
                tokenizer.get_vocab()
            )
        if vocab_size is None:
            raise ValueError(
                "vocab_size is missing from config.json and no tokenizer is available to infer it."
            )

        init_kwargs = {
            "vocab_size": vocab_size,
            "n_filters": cfg_dict.get("n_filters", 64),
            "n_dilated_layers": cfg_dict.get("n_dilated_layers", 9),
            "conv1_kernel_size": cfg_dict.get("conv1_kernel_size", 25),
            "dil_kernel_size": cfg_dict.get("dil_kernel_size", 3),
            "dropout": cfg_dict.get("dropout", 0.1),
            "pad_token_id": cfg_dict.get("pad_token_id", 0),
            "num_labels": cfg_dict.get("num_labels", 2),
            "label2id": cfg_dict.get("label2id"),
        }
        init_kwargs.update(kwargs)
        model = cls(**init_kwargs)

        state_dict = torch.load(
            os.path.join(save_directory, "pytorch_model.bin"),
            map_location=map_location or "cpu",
        )
        model.load_state_dict(state_dict, strict=False)
        if tokenizer is not None:
            model.set_one_hot(
                cls._nucleotide_one_hot(tokenizer, model.config.vocab_size)
            )
        meta_path = os.path.join(save_directory, "metadata.json")
        if os.path.exists(meta_path):
            with open(meta_path, "r", encoding="utf8") as f:
                model.metadata = json.load(f)
        return model