Source code for omnigenbench.src.abc.abstract_model

# -*- coding: utf-8 -*-
# file: abstract_model.py
# time: 18:36 06/04/2024
# author: YANG, HENG <hy345@exeter.ac.uk> (杨恒)
# github: https://github.com/yangheng95
# huggingface: https://huggingface.co/yangheng
# google scholar: https://scholar.google.com/citations?user=NPq5a_0AAAAJ&hl=en
# Copyright (C) 2019-2025. All Rights Reserved.
import json
import os
import shutil
import time
import warnings
import inspect
from importlib import import_module

import dill
import findfile
import torch
from transformers import AutoModel, AutoConfig, AutoTokenizer, BatchEncoding

from ..misc.utils import fprint, env_meta_info
from .embedding_mixin import EmbeddingMixin

warnings.filterwarnings("once")


[docs] def count_parameters(model): """ This function iterates through all parameters of a PyTorch model and counts only those that require gradients (i.e., trainable parameters). Args: model (torch.nn.Module): A PyTorch model. Returns: int: The total number of trainable parameters. Example: >>> model = OmniModelForSequenceClassification(config, tokenizer) >>> num_params = count_parameters(model) >>> print(f"Model has {num_params} trainable parameters") """ return sum(p.numel() for p in model.parameters() if p.requires_grad)
[docs] class OmniModel(EmbeddingMixin, torch.nn.Module): """ Abstract base class providing a unified interface for genomic foundation models in the OmniGenBench framework. This class handles model initialization, forward passes, loss computation, prediction interfaces, and model persistence while maintaining compatibility with HuggingFace's ecosystem. **Architectural Pattern**: This class follows the Template Method pattern, providing common infrastructure while delegating task-specific behavior to subclasses (OmniModelForSequenceClassification, OmniModelForTokenClassification, etc.). **Inherited Capabilities** (via EmbeddingMixin): - **Embedding Generation**: ``batch_encode()``, ``encode()``, ``encode_tokens()`` for extracting fixed-length sequence representations from genomic sequences - **Attention Extraction**: ``extract_attention_scores()``, ``batch_extract_attention_scores()`` for model interpretability and attention weight visualization - **Similarity Computation**: ``compute_similarity()`` for sequence comparison and relatedness analysis - **Visualization Tools**: ``visualize_attention_pattern()`` for generating attention heatmaps and understanding model focus All task-specific OmniModel subclasses automatically inherit these capabilities, enabling representation learning and interpretability without additional implementation. **Design Philosophy**: By inheriting from both EmbeddingMixin and torch.nn.Module, this class seamlessly integrates sequence embedding capabilities with PyTorch's standard training infrastructure, making it compatible with native PyTorch training loops, HuggingFace Trainer, and Accelerate-based distributed training. **Task-Specific Subclasses**: Users should instantiate concrete implementations rather than this abstract class directly: - ``OmniModelForSequenceClassification``: Sequence-level classification tasks (e.g., promoter identification, functional annotation) - ``OmniModelForMultiLabelSequenceClassification``: Multi-label classification (e.g., transcription factor binding site prediction with 919 TFs) - ``OmniModelForTokenClassification``: Per-nucleotide predictions (e.g., splice site detection, secondary structure annotation) - ``OmniModelForSequenceRegression``: Sequence-level continuous predictions (e.g., gene expression levels, binding affinity scores) - ``OmniModelForTokenRegression``: Per-nucleotide continuous predictions (e.g., chromatin accessibility profiles, conservation scores) - ``OmniModelForRNADesign``: Structure-guided RNA sequence generation (genetic algorithm + masked language model) - ``OmniModelForEmbedding``: Representation learning and feature extraction """ def __init__(self, config_or_model, tokenizer, *args, **kwargs): """ Initializes the genomic foundation model with flexible input types. This method handles three initialization patterns: 1. **From pre-trained path** (recommended): Loads model from HuggingFace Hub or local directory. The architecture is automatically detected via ``config.json`` using the ``auto_map`` or ``architectures`` fields. 2. **From PyTorch module**: Wraps an existing ``nn.Module`` with OmniModel interface, useful for integrating custom architectures or models loaded via other means. 3. **From configuration**: Initializes a new model from AutoConfig specification, typically used for training models from scratch. The initialization process automatically detects the underlying architecture via HuggingFace's ``config.json`` (using ``auto_map`` or ``architectures`` fields), eliminating manual architecture specification for standard models. Args: config_or_model: One of the following: - **str**: Path or HuggingFace Hub identifier (e.g., "yangheng/OmniGenome-186M"). Can be a local path to a model directory or a Hub model ID. - **torch.nn.Module**: Pre-instantiated PyTorch model to wrap with OmniModel interface, enabling use of custom architectures within the framework. - **AutoConfig**: Configuration object for new model initialization, used when training models from scratch or with custom configurations. tokenizer: Tokenizer instance compatible with the model architecture. Used for sequence preprocessing during inference. Should implement either OmniTokenizer interface or HuggingFace tokenizer protocol. *args: Additional positional arguments passed to torch.nn.Module.__init__ **kwargs: Additional keyword arguments: - **label2id** (dict, optional): Mapping from class labels to integer IDs. Required for classification tasks. Example: {"negative": 0, "positive": 1}. Either this or num_labels must be provided. - **num_labels** (int, optional): Number of output classes. Alternative to label2id for when label names are not available. If both provided, they must be consistent (len(label2id) must equal num_labels). - **trust_remote_code** (bool, optional): Whether to trust remote code when loading from HuggingFace Hub. Defaults to True. Set to False for security-critical environments where only vetted models should be loaded. - **ignore_mismatched_sizes** (bool, optional): Whether to ignore size mismatches when loading pre-trained weights (e.g., different classifier head dimensions). Defaults to False. Set to True when fine-tuning for a different number of labels than the pre-trained model. - **dropout** (float, optional): Dropout probability for regularization in classification/regression heads. Defaults to 0.0. Typical values: 0.1-0.5. - **dataset_class** (type, optional): Dataset class used for preprocessing. Enables models to use the dataset's ``prepare_input`` method during inference, allowing custom field handling beyond basic tokenization. Useful when inference requires the same complex preprocessing as training. - **problem_type** (str, optional): Type of prediction problem. Common values: "single_label_classification", "multi_label_classification", "regression". Affects loss calculation and output interpretation. Raises: ValueError: If neither label2id nor num_labels is provided, or if they are inconsistent (len(label2id) != num_labels). Also raised if config_or_model is an unsupported type (not str, nn.Module, or AutoConfig). RuntimeError: If the hidden size cannot be determined from the config (model must define one of: hidden_size, n_embd, or d_model), or if the model architecture cannot be auto-detected from config.json (missing both architectures and auto_map). FileNotFoundError: If the specified model path does not exist locally and cannot be found on HuggingFace Hub. Check model path/ID spelling and internet connectivity. Example: >>> # Pattern 1: Initialize from pre-trained model (recommended) >>> from omnigenbench import OmniModelForSequenceClassification, OmniTokenizer >>> tokenizer = OmniTokenizer.from_pretrained("yangheng/OmniGenome-186M") >>> model = OmniModelForSequenceClassification( ... "yangheng/OmniGenome-186M", ... tokenizer=tokenizer, ... num_labels=2, ... problem_type="single_label_classification" ... ) >>> print(f"Model has {count_parameters(model):,} trainable parameters") >>> # Pattern 2: Initialize with label2id mapping >>> label2id = {"background": 0, "promoter": 1, "enhancer": 2} >>> model = OmniModelForSequenceClassification( ... "yangheng/OmniGenome-186M", ... tokenizer=tokenizer, ... label2id=label2id # num_labels inferred automatically as 3 ... ) >>> # Pattern 3: Initialize from configuration (for custom models) >>> from transformers import AutoConfig >>> config = AutoConfig.from_pretrained("yangheng/OmniGenome-186M") >>> config.num_labels = 10 >>> model = OmniModelForSequenceClassification(config, tokenizer) >>> # Pattern 4: Wrap existing PyTorch module >>> from transformers import AutoModel >>> base_model = AutoModel.from_pretrained("yangheng/OmniGenome-186M") >>> model = OmniModelForSequenceClassification( ... base_model, tokenizer, num_labels=2 ... ) >>> # Pattern 5: Initialize with dataset class for complex preprocessing >>> from omnigenbench import OmniDatasetForSequenceClassification >>> model = OmniModelForSequenceClassification( ... "yangheng/OmniGenome-186M", ... tokenizer=tokenizer, ... num_labels=2, ... dataset_class=OmniDatasetForSequenceClassification ... ) >>> # Now model.inference() can use dataset's prepare_input method """ self.loss_fn = None label2id = kwargs.pop("label2id", None) trust_remote_code = kwargs.pop("trust_remote_code", True) num_labels = kwargs.pop("num_labels", len(label2id) if label2id else None) ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False) dataset_class = kwargs.pop("dataset_class", None) if label2id and not num_labels: num_labels = len(label2id) elif num_labels and not label2id: label2id = {str(i): i for i in range(num_labels)} elif not label2id and not num_labels: raise ValueError( "Either label2id or num_labels must be provided to initialize the model." ) else: if len(label2id) != num_labels: raise ValueError( "The length of label2id does not match num_labels. " f"Expected {num_labels}, but got {len(label2id)}." ) # do not change the order of the following lines super().__init__(*args, **kwargs) if isinstance(config_or_model, str): config = AutoConfig.from_pretrained( config_or_model, num_labels=num_labels, label2id=label2id, trust_remote_code=trust_remote_code, ) # Load the model from either `architectures` or `auto_map` if hasattr(config, "auto_map") and config.auto_map: architectures = list(set(config.auto_map.keys()) - set(["AutoConfig"])) if architectures: model_cls_name = ( "AutoModel" if "AutoModel" in architectures else architectures[-1] ) if "multimolecule" in config_or_model.__repr__().lower(): model_cls = getattr( import_module(f"multimolecule"), model_cls_name ) else: model_cls = getattr( import_module(f"transformers"), model_cls_name ) model = model_cls.from_pretrained( config_or_model, config=config, trust_remote_code=trust_remote_code, ignore_mismatched_sizes=ignore_mismatched_sizes, ).base_model else: raise ValueError( f"Model cannot be instantiated from '{config_or_model}'. " f"The configuration must contain either 'architectures' or 'auto_map' field. " f"Please verify the model path/ID is correct and config.json is properly formatted." ) elif hasattr(config, "architectures") and config.architectures: model_cls_name = ( AutoModel if "AutoModel" in config.architectures else config.architectures[-1] ) if hasattr(import_module(f"multimolecule"), model_cls_name): model_cls = getattr(import_module(f"multimolecule"), model_cls_name) elif hasattr(import_module(f"transformers"), model_cls_name): model_cls = getattr(import_module(f"transformers"), model_cls_name) else: raise ValueError( f"Model class '{model_cls_name}' not found in transformers or multimolecule libraries." ) model = model_cls.from_pretrained( config_or_model, config=config, trust_remote_code=trust_remote_code, ignore_mismatched_sizes=ignore_mismatched_sizes, ).base_model else: raise ValueError( f"Model configuration from '{config_or_model}' is missing both 'architectures' and 'auto_map' fields. " f"Cannot determine the model architecture. Please ensure the model has a valid config.json file." ) self.model = model self.model.config = config del model_cls elif isinstance(config_or_model, torch.nn.Module): self.model = config_or_model self.model.config.num_labels = ( num_labels if len(label2id) == num_labels else len(label2id) ) self.model.config.label2id = label2id elif isinstance(config_or_model, AutoConfig): config = config_or_model config.num_labels = ( num_labels if len(label2id) == num_labels else len(label2id) ) config.label2id = label2id self.model = AutoModel.from_config(config) self.model.config = config else: raise ValueError( f"Invalid type for config_or_model: {type(config_or_model).__name__}. " f"Expected one of: str (model path/ID), torch.nn.Module (model instance), " f"or AutoConfig (configuration object)." ) # Update the config self.config = self.model.config if isinstance(label2id, dict): self.config.label2id = label2id self.config.id2label = {v: k for k, v in label2id.items()} if ( not hasattr(self.config, "num_labels") or len(self.config.id2label) != self.config.num_labels ): fprint( "Warning: The number of labels in the config is not equal to the number of labels in the label2id dictionary. " ) fprint( "Please check the label2id dictionary and the num_labels parameter in the config." ) self.config.num_labels = len(self.config.id2label) assert ( len(self.config.label2id) == num_labels ), f"Expected {num_labels} labels, but got {len(self.config.label2id)} in label2id dictionary." # The metadata of the model self.metadata = env_meta_info() self.metadata["model_cls"] = self.__class__.__name__ # Store dataset class for data preprocessing during inference if dataset_class is not None: self.dataset_class = dataset_class self.metadata["dataset_cls"] = dataset_class.__name__ self.metadata["dataset_module"] = dataset_class.__module__ elif hasattr(self, "dataset_class"): dataset_class = self.dataset_class self.metadata["dataset_cls"] = dataset_class.__name__ self.metadata["dataset_module"] = dataset_class.__module__ else: self.dataset_class = None self.metadata["dataset_cls"] = None self.metadata["dataset_module"] = None fprint( "Warning: No dataset_class is provided for the model, please set 'dataset_class=...' " "when initializing the model if you want to use the dataset's prepare_input method during inference." ) # The config of the model if hasattr(self.config, "n_embd") and self.config.n_embd: self.config.hidden_size = self.config.n_embd elif hasattr(self.config, "d_model") and self.config.d_model: self.config.hidden_size = self.config.d_model elif hasattr(self.config, "hidden_size") and self.config.hidden_size: self.config.hidden_size = self.config.hidden_size else: raise RuntimeError( "The hidden size of the model is not found in the config." ) # The tokenizer of the model self.tokenizer = tokenizer self.metadata["tokenizer_cls"] = self.tokenizer.__class__.__name__ if hasattr(self.tokenizer, "base_tokenizer"): self.pad_token_id = self.tokenizer.base_tokenizer.pad_token_id else: self.pad_token_id = self.tokenizer.pad_token_id self.dropout = torch.nn.Dropout(kwargs.get("dropout", 0.0)) self.activation = torch.nn.Tanh() # Device management: track device but don't move yet (subclass layers not created) # Let device movement happen explicitly via .to() or automatically during forward pass try: self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu") except Exception: self._device = torch.device("cpu")
[docs] def last_hidden_state_forward(self, **inputs): """ Performs a forward pass to get the last hidden state from the base model. It also handles compatibility with different model architectures by mapping input parameters appropriately. Args: **inputs: The inputs to the model, compatible with the base model's forward method. Typically includes 'input_ids', 'attention_mask', and other model-specific parameters. Returns: torch.Tensor: The last hidden state tensor. Example: >>> inputs = { ... 'input_ids': torch.tensor([[1, 2, 3, 4]]), ... 'attention_mask': torch.tensor([[1, 1, 1, 1]]) ... } >>> hidden_states = model.last_hidden_state_forward(**inputs) """ model = self.model input_mapping = {} inputs["output_hidden_states"] = True if "strippedhyena" in model.__class__.__name__.lower(): inputs["x"] = inputs["input_ids"] # For compatibility with Evo models if isinstance(inputs, BatchEncoding) or isinstance(inputs, dict): # Determine the input parameter names of the model's forward method forward_params = inspect.signature(model.forward).parameters # Map the inputs to the forward method parameters for param in forward_params: if param in inputs: input_mapping[param] = inputs[param] # 对于未在模型签名中声明的关键参数,可以给出警告或日志 ignored_keys = set(inputs.keys()) - set(input_mapping.keys()) if ignored_keys: warnings.warn(f"Warning: Ignored keys in inputs: {ignored_keys}") inputs = input_mapping elif isinstance(inputs, tuple): input_ids = inputs[0] attention_mask = inputs[1] if len(inputs) > 1 else None inputs = {"input_ids": input_ids, "attention_mask": attention_mask} elif isinstance(inputs, torch.Tensor): shape = inputs.shape try: if len(shape) == 3: if shape[1] == 2: input_ids = inputs[:, 0] attention_mask = inputs[:, 1] else: input_ids = inputs[0] attention_mask = inputs[1] if len(inputs) > 1 else None elif len(shape) == 2: input_ids = inputs attention_mask = None else: raise ValueError( f"Failed to get the input_ids and attention_mask from the inputs, got shape {shape}." ) except: raise ValueError( f"Failed to get the input_ids and attention_mask from the inputs, got shape {shape}." ) inputs = {"input_ids": input_ids, "attention_mask": attention_mask} else: raise ValueError( f"The inputs should be a tuple, BatchEncoding or a dictionary-like object, got {type(inputs)}." ) # 执行模型 outputs = model(**inputs) if not hasattr(outputs, "last_hidden_state"): warnings.warn( f"last_hidden_state not found in the outputs from the {model.__class__.__name__} model." ) if hasattr(outputs, "last_hidden_state"): last_hidden_state = outputs.last_hidden_state elif isinstance(outputs, dict) and "last_hidden_state" in outputs: last_hidden_state = outputs["last_hidden_state"] elif hasattr(outputs, "hidden_states"): last_hidden_state = outputs.hidden_states[-1] elif isinstance(outputs, (list, tuple, torch.Tensor)): if len(outputs) <= 2: # For Evo models that return a tuple of (last_hidden_state, logits) last_hidden_state = outputs[0] elif len(outputs) >= 3: last_hidden_state = outputs[-1] else: raise ValueError( f"Cannot find the last hidden state in the outputs from the {model.__class__.__name__} model, " f"please check the model architecture." ) return last_hidden_state
[docs] def loss_function(self, logits, labels): """ Calculates the loss. This method should be implemented by concrete model classes to define how the loss is calculated for their specific task (classification, regression, etc.). Args: logits (torch.Tensor): The model's output logits. labels (torch.Tensor): The ground truth labels. Returns: torch.Tensor: The calculated loss. Raises: NotImplementedError: If the method is not implemented by the subclass. Example: >>> # In a classification model >>> loss = model.loss_function(logits, labels) """ raise NotImplementedError( "The loss_function() function should be implemented for your model." )
[docs] def set_loss_fn(self, loss_function): """ Sets a custom loss function for the model. The loss function should be compatible with the model's output format. Args: loss_function (callable): A callable loss function that takes logits and labels as arguments. Example: >>> import torch.nn as nn >>> model.set_loss_fn(nn.CrossEntropyLoss()) """ self.loss_fn = loss_function try: self.loss_fn.weight.to(self.model.device) except AttributeError: # If the loss function does not have a weight attribute, we assume it's not weighted pass
[docs] def predict(self, sequence_or_inputs, **kwargs): """ This method takes raw sequences or tokenized inputs and returns the raw model outputs (logits, hidden states, etc.) without post-processing. It's useful for getting the model's direct predictions for further processing. If the model has a dataset_class, this method will use the dataset's prepare_input method for data preprocessing, allowing for more complex data preparation including custom field handling. Args: sequence_or_inputs: Can be one of: - str: A single sequence (e.g., "ATCGATCG") - list: A list of sequences (e.g., ["ATCGATCG", "GCTAGCTA"]) - dict: A dictionary with 'sequence'/'seq' and optionally 'label'/'labels' and other custom fields (e.g., {"sequence": "ATCG", "label": 1}) - BatchEncoding/dict with 'input_ids': Already tokenized inputs **kwargs: Additional arguments for tokenization or dataset preparation. Common options include 'max_length', 'padding', 'truncation', etc. Returns: dict: A dictionary containing the raw model outputs, typically including `logits`, `last_hidden_state`, and other model-specific outputs. Example: >>> # Predict on a single sequence >>> outputs = model.predict("ATCGATCG") >>> >>> # Predict on multiple sequences >>> outputs = model.predict(["ATCGATCG", "GCTAGCTA"]) >>> >>> # Predict with dict input (if dataset_class is set) >>> outputs = model.predict({"sequence": "ATCGATCG", "label": 1}) """ # Please implement the predict() function for your model raw_outputs = self._forward_from_raw_input(sequence_or_inputs, **kwargs) return raw_outputs
[docs] def inference(self, sequence_or_inputs, **kwargs): """ This method takes raw sequences or tokenized inputs and returns processed predictions that are ready for human consumption. It typically includes post-processing steps like converting logits to class labels or probabilities. If the model has a dataset_class, this method will use the dataset's prepare_input method for data preprocessing, allowing for more complex data preparation including custom field handling. Args: sequence_or_inputs: Can be one of: - str: A single sequence (e.g., "ATCGATCG") - list: A list of sequences (e.g., ["ATCGATCG", "GCTAGCTA"]) - dict: A dictionary with 'sequence'/'seq' and optionally 'label'/'labels' and other custom fields (e.g., {"sequence": "ATCG", "label": 1}) - BatchEncoding/dict with 'input_ids': Already tokenized inputs **kwargs: Additional arguments for tokenization, dataset preparation, or inference. Common options include 'max_length', 'padding', 'truncation', etc. Returns: dict: A dictionary containing the processed predictions, typically including 'predictions', 'confidence', and other human-readable outputs. Example: >>> # Inference on a single sequence >>> results = model.inference("ATCGATCG") >>> print(results['predictions']) # Class labels >>> >>> # Inference on multiple sequences >>> results = model.inference(["ATCGATCG", "GCTAGCTA"]) >>> >>> # Inference with dict input (if dataset_class is set) >>> results = model.inference({"sequence": "ATCGATCG", "label": 1}) >>> print(results['predictions']) """ # Please implement the predict() function for your model raw_outputs = self._forward_from_raw_input(sequence_or_inputs, **kwargs) return raw_outputs
def __call__(self, **inputs): """ The main forward pass of the model, suitable for training loops. This method is the primary interface for model forward passes during training. It handles both tokenized inputs and raw sequences, calculates loss if labels are provided, and returns a comprehensive output dictionary. Args: **inputs: A dictionary of tokenized inputs, potentially including labels. Can also handle raw sequences that will be tokenized automatically. Returns: dict: A dictionary containing logits, last_hidden_state, labels, and loss (if labels were provided). Example: >>> # Training forward pass >>> outputs = model( ... input_ids=torch.tensor([[1, 2, 3, 4]]), ... attention_mask=torch.tensor([[1, 1, 1, 1]]), ... labels=torch.tensor([0]) ... ) >>> loss = outputs['loss'] """ # For transformer trainer integration, we need to pop the "inputs" to be a tokenized inputs object. # For native trainer, the inputs are already tokenized inputs object labels = inputs.pop("labels", None) inputs = inputs.pop("inputs", inputs) inputs["labels"] = labels if isinstance(inputs, dict): labels = inputs.get("labels", None) label = inputs.get("label", None) labels = labels if labels is not None else label # if labels is None: # warnings.warn( # "No labels are provided in the inputs, the model will not calculate the loss." # ) elif isinstance(inputs, tuple): labels = inputs[1] inputs = inputs[0] elif labels is not None: labels = labels outputs = self.forward(**inputs) if labels is not None: outputs["loss"] = self._calculate_loss(outputs, labels) else: outputs["loss"] = None return outputs def _calculate_loss(self, outputs, labels): """ Internal method to calculate loss if not already present in outputs. :param outputs: The dictionary of model outputs. :param labels: The ground truth labels. :return: The calculated loss. """ loss = outputs.get("loss", None) if loss is not None: return loss logits = outputs["logits"] if logits is not None or labels is not None: loss = self.loss_function(logits, labels) return loss else: raise RuntimeError( "The output of the forward() function should be a dictionary-like objective" " and have either 'loss', or 'logits' and 'labels' attribute." ) # ==================== Save Helper Methods ==================== def _save_base_files(self, path): """ Copy base model configuration files to save directory. This method copies essential configuration files from the original model directory to the target save directory, excluding weight files to avoid duplication. Files with extensions .bin, .json, .txt, and .py are copied. Args: path (str): Target directory path where files will be copied. Note: - pytorch_model.bin and model.safetensors are excluded as they will be saved separately in _save_weights() - Only files from the original model path are copied """ for file in findfile.find_files( self.config.name_or_path, or_key=["bin", "json", "txt", "py"], exclude_key=["pytorch_model.bin", "model.safetensors"], return_relative_path=False, ): shutil.copyfile(file, f"{path}/{os.path.basename(file)}") def _save_custom_model_class(self, path, metadata): """ Save custom model class source file if it's user-defined. This method detects and saves the source code of custom model classes that are defined outside the omnigenbench/omnigenome packages. This enables loading models without requiring the original source code. Args: path (str): Target directory path where the custom model file will be saved. metadata (dict): Metadata dictionary that will be updated with custom model file information. Side Effects: - Creates 'custom_model.py' in the target directory if model is custom - Updates metadata with 'custom_model_file' key - Prints confirmation message on success - Prints warning on failure (non-fatal) Note: - Only saves source code for models NOT in omnigenbench/omnigenome packages - Failures are logged but don't interrupt the save process """ def _copy_module_file(src_file, dst_name): dst_file = os.path.join(path, dst_name) os.makedirs(os.path.dirname(dst_file), exist_ok=True) shutil.copyfile(src_file, dst_file) metadata["custom_model_file"] = dst_name fprint(f"Saved model source to: {dst_file}") try: model_class = self.__class__ model_source_file = inspect.getfile(model_class) if ( "omnigenbench" not in model_source_file and "omnigenome" not in model_source_file ): _copy_module_file(model_source_file, "custom_model.py") return except (TypeError, OSError) as e: fprint(f"Could not save custom model source file: {e}") # If the Omni wrapper class is from the framework but the base model is remote code try: base_model = getattr(self, "model", None) if base_model is None: return base_cls = base_model.__class__ base_src = inspect.getfile(base_cls) if base_src and "omnigenbench" not in base_src and "omnigenome" not in base_src: base_dir = os.path.dirname(base_src) base_name = os.path.basename(base_src) # If the remote module sits inside a package (has __init__.py), copy the package tree if os.path.exists(os.path.join(base_dir, "__init__.py")): dst_pkg = os.path.join(path) if not os.path.exists(dst_pkg): shutil.copytree(base_dir, dst_pkg, dirs_exist_ok=True) fprint(f"Saved remote code package to: {dst_pkg}") # Ensure the specific model file is copied dst_model_file = os.path.join(dst_pkg, base_name) if not os.path.exists(dst_model_file): shutil.copyfile(base_src, dst_model_file) fprint(f"Saved model source to: {dst_model_file}") rel_file = os.path.join(os.path.basename(base_dir), base_name) metadata["custom_model_file"] = rel_file else: _copy_module_file(base_src, base_name) except (TypeError, OSError, AttributeError) as e: fprint(f"Could not save base model source file: {e}") def _save_custom_dataset_class(self, path, metadata): """ Save custom dataset class source file if available. This method saves the source code of custom dataset classes associated with the model. This ensures that data preprocessing logic is preserved and can be used during inference without requiring the original code. Args: path (str): Target directory path where the custom dataset file will be saved. metadata (dict): Metadata dictionary that will be updated with custom dataset file information. Side Effects: - Creates 'custom_dataset.py' in the target directory if dataset is custom - Updates metadata with 'custom_dataset_file' and 'custom_dataset_class' keys - Prints confirmation message on success - Prints warning on failure (non-fatal) Note: - Requires model to have 'dataset_class' attribute - Only saves source code for datasets NOT in omnigenbench/omnigenome packages - Silently returns if model has no dataset_class attribute - Failures are logged but don't interrupt the save process """ if not hasattr(self, "dataset_class"): return try: dataset_class = self.dataset_class dataset_source_file = inspect.getfile(dataset_class) # Check if it's a user-defined dataset (not from omnigenbench/omnigenome) if ( "omnigenbench" not in dataset_source_file and "omnigenome" not in dataset_source_file ): custom_dataset_path = os.path.join(path, "custom_dataset.py") shutil.copyfile(dataset_source_file, custom_dataset_path) metadata["custom_dataset_file"] = "custom_dataset.py" metadata["custom_dataset_class"] = dataset_class.__name__ fprint(f"Saved custom dataset class source to: {custom_dataset_path}") except (TypeError, OSError, AttributeError) as e: fprint(f"Could not save custom dataset source file: {e}") def _collect_metadata(self): """ Collect all metadata to be saved with the model. This method gathers comprehensive metadata about the model, including: - Loss function information (class name and module) - Model class information (name and module) - Custom attributes (num_labels, num_classes, label mappings, etc.) - Dataset metadata (if present) Returns: dict: A dictionary containing all model metadata with the following structure: { 'model_cls': str, # Model class name 'model_module': str, # Model module path 'loss_fn_class': str, # Loss function class name (optional) 'loss_fn_module': str, # Loss function module path (optional) 'custom_attrs': dict, # Custom model attributes (optional) 'dataset_metadata': dict, # Dataset metadata (optional) ... (other metadata from self.metadata) } Note: - Only serializable attributes (int, float, str, bool, list, dict) are saved - Custom attributes checked: num_labels, num_classes, threshold, label2idx, idx2label, tissue_names, tissue_columns - Base metadata is copied from self.metadata """ metadata = self.metadata.copy() # Loss function metadata if self.loss_fn is not None: metadata["loss_fn_class"] = self.loss_fn.__class__.__name__ metadata["loss_fn_module"] = self.loss_fn.__class__.__module__ # Model class metadata model_class = self.__class__ metadata["model_cls"] = model_class.__name__ metadata["model_module"] = model_class.__module__ # Custom attributes custom_attrs = {} for attr_name in [ "num_labels", "num_classes", "threshold", "label2idx", "idx2label", "tissue_names", "tissue_columns", ]: if hasattr(self, attr_name): attr_value = getattr(self, attr_name) if isinstance(attr_value, (int, float, str, bool, list, dict)): custom_attrs[attr_name] = attr_value if custom_attrs: metadata["custom_attrs"] = custom_attrs # Dataset metadata if hasattr(self, "metadata") and "dataset_metadata" in self.metadata: metadata["dataset_metadata"] = self.metadata["dataset_metadata"] return metadata def _save_weights(self, path): """ Save model weights and tokenizer to disk. This method handles the serialization of model weights and tokenizer. It attempts multiple saving strategies to ensure compatibility: 1. Save tokenizer using dill serialization 2. Try to save base model using save_pretrained() (HuggingFace style) 3. Save complete state dict as fallback Args: path (str): Target directory path where weights and tokenizer will be saved. Side Effects: - Creates 'tokenizer.bin' in the target directory - Creates model weight files (via save_pretrained if available) - Creates 'pytorch_model.bin' containing complete state dict Note: - Tokenizer is serialized using dill to preserve all attributes - Base model save_pretrained() may fail for custom models (non-fatal) - Complete state dict is always saved as backup """ # Save tokenizer with open(f"{path}/tokenizer.bin", "wb") as f: dill.dump(self.tokenizer, f) # Try to save the underlying base model try: self.model.save_pretrained(f"{path}", safe_serialization=False) except AttributeError: # Fallback: if the OmniModel subclass provides its own `save_pretrained`, use it if hasattr(self, "save_pretrained"): try: self.save_pretrained(path, overwrite=True) except Exception: pass # Save complete state dict including all components with open(f"{path}/pytorch_model.bin", "wb") as f: torch.save(self.state_dict(), f)
[docs] def save(self, path, overwrite=False, dtype=torch.float16, **kwargs): """ Save the complete model, tokenizer, and metadata to a directory. This method performs a comprehensive save operation that includes: - Base model configuration files - Custom model and dataset class source code (if applicable) - Complete metadata including loss function and custom attributes - Model weights and tokenizer The save process follows a 6-step workflow: 1. Save base configuration files 2. Collect comprehensive metadata 3. Save custom model class source (if user-defined) 4. Save custom dataset class source (if available) 5. Save metadata to JSON file 6. Save model weights and tokenizer Args: path (str): Target directory path for saving the model. overwrite (bool, optional): Whether to overwrite existing directory. If False, a timestamp will be appended to path. Defaults to False. dtype (torch.dtype, optional): Data type for saving model weights. Model is temporarily converted to this dtype during saving. Defaults to torch.float16. **kwargs: Additional keyword arguments (reserved for future use). Side Effects: - Creates target directory if it doesn't exist - Temporarily moves model to CPU and converts dtype - Restores original device and dtype after saving - Prints confirmation message on completion Example: >>> # Basic save >>> model.save("my_model") >>> # Save with overwrite >>> model.save("my_model", overwrite=True) >>> # Save with specific dtype >>> model.save("my_model", dtype=torch.float32) Note: - Model is set to eval mode before saving - Original device and dtype are preserved - Custom classes are only saved if defined outside framework packages - Failures in non-critical steps (e.g., custom class saving) are logged but don't stop the process """ self.eval() # Handle path conflicts if os.path.exists(path) and not overwrite: fprint( f"The path {path} already exists, please set overwrite=True to overwrite it. " f"Rename the path to {path}_{time.strftime('%Y%m%d_%H%M%S')} to save it with a timestamp." ) path = f"{path}_{time.strftime('%Y%m%d_%H%M%S')}" if not os.path.exists(path): os.makedirs(path) # Store original device and dtype _device = self.model.device _dtype = self.model.dtype self.model.to(dtype).to("cpu") # Save model and tokenizer config self.tokenizer.save_pretrained(path) self.model.save_pretrained(path) # Step 1: Save base files self._save_base_files(path) # Step 2: Collect metadata metadata = self._collect_metadata() # Step 3: Save custom model class self._save_custom_model_class(path, metadata) # Step 4: Save custom dataset class self._save_custom_dataset_class(path, metadata) # Step 5: Save metadata to JSON with open(f"{path}/metadata.json", "w", encoding="utf8") as f: json.dump(metadata, f, indent=2) # Step 6: Save weights and tokenizer self._save_weights(path) # Restore original device and dtype self.model.to(_dtype).to(_device) fprint(f"The model is saved to {path}.")
# ==================== Load Helper Methods ==================== def _load_metadata(self, path): """ Load and validate metadata from saved model directory. This method reads the metadata.json file and performs validation to ensure the saved model matches the current model class. Args: path (str): Directory path containing the saved model and metadata.json. Returns: dict: Loaded metadata dictionary containing model information. Raises: ValueError: If the saved model class doesn't match the current model class. FileNotFoundError: If metadata.json is not found in the directory. json.JSONDecodeError: If metadata.json is malformed. Example: >>> metadata = model._load_metadata("checkpoint") >>> print(metadata['model_cls']) # 'OmniModelForSequenceClassification' Note: - Validates that saved model class matches current class - This check ensures type safety when loading models """ with open(f"{path}/metadata.json", "r", encoding="utf8") as f: metadata = json.load(f) if metadata["model_cls"] != self.__class__.__name__: raise ValueError( f"The model class in the loaded model is {metadata['model_cls']}, " f"but the current model class is {self.__class__.__name__}." ) return metadata def _load_config(self, path, **kwargs): """ Load model configuration and check for differences with current config. This method loads the saved configuration and compares it with the current model's configuration, warning about any differences found. Args: path (str): Directory path containing the saved model configuration. **kwargs: Additional arguments passed to AutoConfig.from_pretrained(). Returns: AutoConfig: Loaded configuration object. Side Effects: - Prints warnings for any configuration differences found - Warnings include the key name and both values (saved vs current) Example: >>> config = model._load_config("checkpoint", trust_remote_code=True) Warning: The value of the key num_labels in the loaded model is 10, but the current value is 5. Note: - trust_remote_code is set to True by default to support custom model files - Configuration differences don't prevent loading but are logged - Useful for detecting model version mismatches """ config = AutoConfig.from_pretrained(path, trust_remote_code=True, **kwargs) for key, value in config.__dict__.items(): if key not in self.config.__dict__ or self.config.__dict__[key] != value: fprint( f"Warning: The value of the key {key} in the loaded model is {value}, " f"but the current value is {self.config.__dict__.get(key, None)}." ) return config def _load_dataset_class(self, path, metadata): """ Restore the dataset class from metadata. This method attempts to dynamically import the dataset class that was saved with the model. This allows the model to use the same data preprocessing logic during inference as was used during training. Args: path (str): Directory path containing the saved model. metadata (dict): Metadata dictionary containing dataset class information. Expected keys: 'dataset_cls', 'dataset_module', 'custom_dataset_file', 'custom_dataset_class' Side Effects: - Sets self.dataset_class to the restored dataset class - Prints confirmation message on success - Prints warning on failure (non-fatal) Note: - Silently returns if dataset class info is not in metadata - Tries multiple loading strategies: built-in modules, custom files - Import or attribute errors are caught and logged as warnings - Custom dataset classes are loaded from custom_dataset.py if available """ if "dataset_cls" not in metadata and "custom_dataset_class" not in metadata: return dataset_cls = None # Method 1: Try to load from custom_dataset.py file if "custom_dataset_file" in metadata and "custom_dataset_class" in metadata: try: custom_dataset_path = os.path.join( path, metadata["custom_dataset_file"] ) if os.path.exists(custom_dataset_path): import importlib.util spec = importlib.util.spec_from_file_location( "custom_dataset_module", custom_dataset_path ) custom_dataset_module = importlib.util.module_from_spec(spec) spec.loader.exec_module(custom_dataset_module) dataset_cls = getattr( custom_dataset_module, metadata["custom_dataset_class"] ) fprint( f"Restored custom dataset class: {metadata['custom_dataset_class']} from {metadata['custom_dataset_file']}" ) except (ImportError, AttributeError, OSError) as e: warnings.warn(f"Could not restore custom dataset class from file: {e}") # Method 2: Try to load from built-in omnigenbench/omnigenome modules if ( dataset_cls is None and "dataset_module" in metadata and "dataset_cls" in metadata ): try: dataset_module = import_module(metadata["dataset_module"]) dataset_cls = getattr(dataset_module, metadata["dataset_cls"]) fprint( f"Restored dataset class: {metadata['dataset_cls']} from {metadata['dataset_module']}" ) except (ImportError, AttributeError) as e: warnings.warn(f"Could not restore dataset class from module: {e}") if dataset_cls is not None: self.dataset_class = dataset_cls def _load_loss_function(self, metadata): """ Restore saved loss function from metadata. This method attempts to restore the loss function that was used during training by dynamically importing it based on saved metadata. Args: metadata (dict): Metadata dictionary containing loss function information. Expected keys: 'loss_fn_class', 'loss_fn_module' Side Effects: - Sets self.loss_fn to the restored loss function instance - Prints confirmation message on success - Prints warning on failure (non-fatal) Example: >>> # metadata contains: {'loss_fn_class': 'CrossEntropyLoss', 'loss_fn_module': 'torch.nn.modules.loss'} >>> model._load_loss_function(metadata) Restored loss function: CrossEntropyLoss from torch.nn.modules.loss Note: - Silently returns if loss function info is not in metadata - Import or attribute errors are caught and logged as warnings - Loss function is instantiated with default parameters - Custom loss functions must be importable at load time """ if "loss_fn_class" not in metadata or "loss_fn_module" not in metadata: return try: loss_module = import_module(metadata["loss_fn_module"]) loss_class = getattr(loss_module, metadata["loss_fn_class"]) self.loss_fn = loss_class() fprint( f"Restored loss function: {metadata['loss_fn_class']} from {metadata['loss_fn_module']}" ) except (ImportError, AttributeError) as e: warnings.warn(f"Could not restore loss function: {e}") def _load_weights(self, path, **kwargs): """ Load model weights with validation and compatibility checks. This method loads the saved state dictionary and performs thorough validation by comparing saved weights with the current model structure. It reports any missing or unexpected keys. Args: path (str): Directory path containing pytorch_model.bin file. **kwargs: Additional arguments, may include: - device (str): Device to map loaded tensors to (e.g., 'cpu', 'cuda:0') Side Effects: - Loads weights into current model's state dict - Prints warnings for missing or unexpected keys - Uses strict=False to allow partial loading Warnings: - Missing keys: Parameters in current model not found in saved weights - Unexpected keys: Parameters in saved weights not found in current model Example: >>> model._load_weights("checkpoint", device="cuda:0") Warning: Missing keys in loaded weights: {'classifier.bias'} Warning: Unexpected keys in loaded weights: {'old_layer.weight'} Note: - strict=False allows loading with architecture mismatches - Missing keys will be randomly initialized - Unexpected keys are ignored - Device mapping prevents CUDA OOM when loading on different device """ with open(f"{path}/pytorch_model.bin", "rb") as f: loaded_state_dict = torch.load(f, map_location=kwargs.get("device", "cpu")) # Check if keys match between current and loaded state dict current_keys = set(self.state_dict().keys()) loaded_keys = set(loaded_state_dict.keys()) missing_keys = current_keys - loaded_keys unexpected_keys = loaded_keys - current_keys if missing_keys: warnings.warn(f"Missing keys in loaded weights: {missing_keys}") if unexpected_keys: warnings.warn(f"Unexpected keys in loaded weights: {unexpected_keys}") self.load_state_dict(loaded_state_dict, strict=False) def _load_tokenizer(self, path): """ Load saved tokenizer from binary file. This method deserializes the tokenizer that was saved using dill. The tokenizer is essential for proper text preprocessing during inference. Args: path (str): Directory path containing tokenizer.bin file. Side Effects: - Sets self.tokenizer to the loaded tokenizer instance - Silently returns if tokenizer.bin doesn't exist Example: >>> model._load_tokenizer("checkpoint") >>> print(type(model.tokenizer)) # <class 'omnigenbench.OmniTokenizer'> Note: - Tokenizer is saved/loaded using dill for complete serialization - If tokenizer.bin doesn't exist, current tokenizer is preserved - Dill is used instead of pickle to handle complex tokenizer objects """ tokenizer_path = f"{path}/tokenizer.bin" if os.path.exists(tokenizer_path): with open(tokenizer_path, "rb") as f: self.tokenizer = dill.load(f)
[docs] def load(self, path, **kwargs): """ Load a complete model from a saved directory. This method performs a comprehensive load operation that restores: - Model metadata and configuration - Loss function (if saved) - Model weights with validation - Tokenizer The load process follows a 6-step workflow: 1. Load and validate metadata 2. Load model configuration 3. Restore dataset class (if available) 4. Restore loss function (if available) 5. Load model weights with validation 6. Load tokenizer Args: path (str): Directory path containing the saved model. **kwargs: Additional keyword arguments passed to loading functions. Common options include: - device (str): Device to load model to (e.g., 'cpu', 'cuda:0') - trust_remote_code (bool): Whether to trust custom code Returns: OmniModel: The loaded model instance (self). Raises: ValueError: If saved model class doesn't match current class. FileNotFoundError: If required files (metadata.json, pytorch_model.bin) are missing. Example: >>> # Basic load >>> loaded_model = model.load("checkpoint") >>> # Load to specific device >>> loaded_model = model.load("checkpoint", device="cuda:0") >>> # Load with custom code trust >>> loaded_model = model.load("checkpoint", trust_remote_code=True) Side Effects: - Updates all model attributes (weights, config, tokenizer, loss_fn) - Prints warnings for configuration differences - Prints warnings for weight loading issues (missing/unexpected keys) - Prints confirmation messages for restored components Note: - Model class must match the saved model class - Partial loading is supported (missing keys will be randomly initialized) - Loss function restoration is optional (warnings only if fails) - Custom models require trust_remote_code=True (enabled by default) """ # Step 1: Load metadata metadata = self._load_metadata(path) # Step 2: Load configuration config = self._load_config(path, **kwargs) # Step 3: Restore dataset class self._load_dataset_class(path, metadata) # Step 4: Restore loss function self._load_loss_function(metadata) # Step 5: Load weights self._load_weights(path, **kwargs) # Step 6: Load tokenizer self._load_tokenizer(path) return self
def _forward_from_raw_input(self, sequence_or_inputs, **kwargs): """ Tokenizes raw input and performs a forward pass in no_grad mode. This method supports two preprocessing strategies: 1. If dataset_class is available, use dataset's prepare_input method for preprocessing 2. Otherwise, use direct tokenizer (backward compatible) :param sequence_or_inputs: A sequence (str), list of sequences, dict with 'sequence' key, or tokenized inputs (BatchEncoding/dict). :param kwargs: Additional arguments for tokenization or dataset preparation. :return: A dictionary containing the raw model outputs and the tokenized inputs. """ # Check if inputs are already tokenized if isinstance(sequence_or_inputs, (BatchEncoding, dict)): # If it's a dict, check if it contains 'input_ids' (tokenized) or raw data if ( isinstance(sequence_or_inputs, dict) and "input_ids" in sequence_or_inputs ): inputs = sequence_or_inputs # If it's a dict without 'input_ids', it might be raw data for dataset.prepare_input elif (isinstance(sequence_or_inputs, dict) and hasattr( self, "dataset_class") and self.dataset_class is not None): # Use dataset's prepare_input method try: # Create a temporary dataset instance for using prepare_input max_length = kwargs.pop("max_length", 1024) dataset_instance = self.dataset_class( dataset_name_or_path=None, tokenizer=self.tokenizer, max_length=max_length, **kwargs, ) inputs = dataset_instance.prepare_input( sequence_or_inputs, **kwargs ) # Remove batch dimension if present (prepare_input may add it) for key, value in inputs.items(): if ( isinstance(value, torch.Tensor) and value.dim() > 1 and value.size(0) == 1 ): inputs[key] = value.squeeze(0) except Exception as e: warnings.warn( f"Failed to use dataset.prepare_input: {e}. Falling back to tokenizer." ) # Fallback to tokenizer for dict input if "sequence" in sequence_or_inputs or "seq" in sequence_or_inputs: seq = sequence_or_inputs.get( "sequence", sequence_or_inputs.get("seq") ) inputs = self.tokenizer( seq, padding=kwargs.pop("padding", True), max_length=kwargs.pop("max_length", 1024), truncation=kwargs.pop("truncation", True), return_tensors=kwargs.pop("return_tensors", "pt"), **kwargs, ) else: inputs = sequence_or_inputs else: inputs = sequence_or_inputs # Handle string or list of strings elif isinstance(sequence_or_inputs, (str, list)): # If dataset_class is available, try to use its prepare_input method if hasattr(self, "dataset_class") and self.dataset_class is not None: try: # Prepare instance(s) for dataset.prepare_input if isinstance(sequence_or_inputs, str): instance = sequence_or_inputs else: # For list of sequences, we'll process them one by one instance = sequence_or_inputs max_length = kwargs.pop("max_length", 1024) dataset_instance = self.dataset_class( dataset_name_or_path=None, tokenizer=self.tokenizer, max_length=max_length, **kwargs, ) if isinstance(instance, list): # Process list of sequences batch_inputs = [] for seq in instance: inp = dataset_instance.prepare_input(seq, **kwargs) batch_inputs.append(inp) # Stack all inputs inputs = { key: torch.stack( [ ( inp[key].squeeze(0) if inp[key].dim() > 0 else inp[key] ) for inp in batch_inputs ] ) for key in batch_inputs[0].keys() } else: inputs = dataset_instance.prepare_input(instance, **kwargs) # Remove batch dimension if present for key, value in inputs.items(): if ( isinstance(value, torch.Tensor) and value.dim() > 1 and value.size(0) == 1 ): inputs[key] = value.squeeze(0) except Exception as e: warnings.warn( f"Failed to use dataset.prepare_input: {e}. Falling back to tokenizer." ) # Fallback to tokenizer inputs = self.tokenizer( sequence_or_inputs, padding=kwargs.pop("padding", True), max_length=kwargs.pop("max_length", 1024), truncation=kwargs.pop("truncation", True), return_tensors=kwargs.pop("return_tensors", "pt"), **kwargs, ) else: # No dataset_class, use tokenizer directly (backward compatible) inputs = self.tokenizer( sequence_or_inputs, padding=kwargs.pop("padding", True), max_length=kwargs.pop("max_length", 1024), truncation=kwargs.pop("truncation", True), return_tensors=kwargs.pop("return_tensors", "pt"), **kwargs, ) else: raise ValueError(f"Unsupported input type: {type(sequence_or_inputs)}") # Ensure inputs are on the correct device and add batch dimension if needed if not isinstance(inputs, (BatchEncoding, dict)): raise ValueError(f"Processed inputs must be a dict, got {type(inputs)}") # Add batch dimension if missing for key, value in inputs.items(): if isinstance(value, torch.Tensor) and value.dim() == 1: inputs[key] = value.unsqueeze(0) # Move tensors to the model's device target_device = self.device inputs = { k: v.to(target_device) if isinstance(v, torch.Tensor) else v for k, v in inputs.items() } with torch.no_grad(): raw_outputs = self(**inputs) raw_outputs["inputs"] = inputs return raw_outputs @property def device(self): """Return the actual device of model parameters, not cached value.""" # Always infer from actual parameters to handle subclass layers correctly try: return next(self.parameters()).device except StopIteration: return torch.device("cpu")
[docs] def to(self, *args, **kwargs): """Move model to specified device/dtype and keep device tracking in sync.""" super().to(*args, **kwargs) self.model.to(*args, **kwargs) # Update internal device tracking if a device/dtype is specified # Try to derive device from args/kwargs or model parameters updated_device = None for arg in args: if isinstance(arg, torch.device): updated_device = arg elif isinstance(arg, str) and ("cuda" in arg or "cpu" in arg): updated_device = torch.device(arg) if "device" in kwargs: dev = kwargs.get("device") updated_device = dev if isinstance(dev, torch.device) else torch.device(dev) if updated_device is None: try: updated_device = next(self.model.parameters()).device except StopIteration: updated_device = self.device self._device = updated_device # Mirror `.device` for compatibility with existing call sites try: self.model.device = updated_device for module in self.model.modules(): try: module.device = updated_device except Exception: pass except Exception: pass return self
[docs] @staticmethod def from_pretrained(config_or_model, tokenizer, *args, **kwargs): """ Loads a pre-trained model and tokenizer. :param config_or_model: The name or path of the pre-trained model. :param tokenizer: The tokenizer to use. :param args: Additional positional arguments. :param kwargs: Additional keyword arguments. :return: An instance of `OmniModel`. """ config = kwargs.pop("config", None) if config is None: config = AutoConfig.from_pretrained(config_or_model, **kwargs) base_model = AutoModel.from_pretrained(config_or_model, **kwargs) if tokenizer is None: tokenizer = AutoTokenizer.from_pretrained(base_model, **kwargs) return OmniModel(config, base_model, tokenizer, *args, **kwargs)
[docs] def model_info(self): """ Prints and returns detailed information about the model. :return: A string containing the model information. """ info = f"Model Name: {self.__class__.__name__}\n" info += f"Model Metadata: {self.metadata}\n" info += f"Base Model Name: {self.config.name_or_path}\n" info += f"Model Type: {self.config.model_type}\n" info += f"Model Architecture: {self.config.architectures}\n" info += f"Model Parameters: {count_parameters(self.model) / 1e6} M\n" info += f"Model Config: {self.config}\n" fprint(info) return info