# -*- coding: utf-8 -*-
# file: auto_bench_cli.py
# time: 19:18 05/02/2025
# author: YANG, HENG <hy345@exeter.ac.uk> (杨恒)
# Homepage: https://yangheng95.github.io
# github: https://github.com/yangheng95
# huggingface: https://huggingface.co/yangheng
# google scholar: https://scholar.google.com/citations?user=NPq5a_0AAAAJ&hl=en
# Copyright (C) 2019-2025. All Rights Reserved.
import argparse
import os
from typing import Optional

# Handle both relative and absolute imports
try:
    from ..auto_train.auto_train import AutoTrain
    from ...src.misc.utils import fprint, load_module_from_path
    from ..config.auto_config import AutoConfig
except ImportError:
    # Fallback for direct execution
    import sys
    from pathlib import Path

    sys.path.insert(0, str(Path(__file__).parent.parent.parent))
    from omnigenbench.auto.auto_train.auto_train import AutoTrain
    from omnigenbench.auto.config.auto_config import AutoConfig
    from omnigenbench.src.misc.utils import fprint, load_module_from_path
[docs]
def train_command(args: Optional[list] = None):
    """
    Entry point for the OmniGenome auto-train command-line interface.

    Parses the known CLI options. Unrecognised ``--key``, ``--key=value`` or
    ``--key value`` options are added as extra attributes. Attributes of any
    ``AutoConfig`` instance found in ``<dataset>/config.py`` are then added.
    An ``AutoTrain`` instance is built from the result and run.

    Precedence (highest first): known CLI options, extra CLI options,
    ``config.py`` values. A lower-precedence source never overwrites a value
    that is already set.

    :param args: A list of command-line arguments. If None, `sys.argv` is used.
    """
    parser = create_parser()
    parsed_args, unknown_args = parser.parse_known_args(args)

    # Everything below (config lookup, AutoTrain) needs a dataset; fail with a
    # proper usage error instead of a TypeError from os.path.join(None, ...).
    if parsed_args.dataset is None:
        parser.error("the following arguments are required: -d/--dataset")

    # Extra CLI options are applied before config.py so that explicit
    # command-line input is not silently overridden by the config file.
    for key, value in _parse_extra_args(unknown_args).items():
        if not hasattr(parsed_args, key):
            setattr(parsed_args, key, value)

    _merge_dataset_config(parsed_args)

    model_path = parsed_args.model
    fprint(f"\n>> Starting training for model: {model_path}")
    if "multimolecule" in model_path:
        # Optional dependency: only imported when a multimolecule model is requested.
        from multimolecule import RnaTokenizer, AutoModelForTokenPrediction

        tokenizer = RnaTokenizer.from_pretrained(model_path)
        model = AutoModelForTokenPrediction.from_pretrained(
            model_path, trust_remote_code=True
        ).base_model
    else:
        tokenizer = parsed_args.tokenizer
        model = model_path

    # Copy the namespace dict explicitly: vars() returns the Namespace's own
    # __dict__, so popping from it directly would mutate parsed_args.
    train_kwargs = dict(vars(parsed_args))
    train_kwargs.pop("model")
    train_kwargs.pop("tokenizer")
    dataset = train_kwargs.pop("dataset")
    overwrite = train_kwargs.pop("overwrite", False)
    trainer = train_kwargs.pop("trainer", "accelerate")

    # Initialize AutoTraining; remaining options are forwarded as keywords.
    autotrain = AutoTrain(
        dataset=dataset,
        config_or_model=model,
        tokenizer=tokenizer,
        overwrite=overwrite,
        trainer=trainer,
        **train_kwargs,
    )
    # Run training with the same extra options.
    autotrain.run(**train_kwargs)


def _parse_extra_args(tokens):
    """
    Convert unrecognised ``--key``, ``--key=value`` and ``--key value`` tokens to a dict.

    A bare ``--key`` followed by a token that does not start with ``--`` takes
    that token as its value. Otherwise the key maps to ``True``. Tokens that do
    not start with ``--`` and empty keys are ignored.

    :param tokens: The ``unknown_args`` list returned by ``parse_known_args``.
    :return: A dict mapping option names (leading dashes stripped) to values.
    """
    extra = {}
    i = 0
    while i < len(tokens):
        token = tokens[i]
        i += 1
        if not token.startswith("--"):
            continue
        key = token.lstrip("-")
        value = True  # Default to True for flags
        if "=" in key:
            key, value = key.split("=", 1)
        elif i < len(tokens) and not tokens[i].startswith("--"):
            # "--key value" form: consume the following token as the value.
            value = tokens[i]
            i += 1
        if key:
            extra[key] = value
    return extra


def _merge_dataset_config(parsed_args):
    """
    Copy ``AutoConfig`` attributes from ``<dataset>/config.py`` onto *parsed_args*.

    Only attributes that *parsed_args* does not already have are set. Does
    nothing if the config file does not exist.

    :param parsed_args: The ``argparse.Namespace``; its ``dataset`` must be a str.
    """
    config_path = os.path.join(parsed_args.dataset, "config.py")
    if not os.path.exists(config_path):
        return
    print(f"Loading configuration from: {config_path}")
    config = load_module_from_path("module_name", config_path)
    for attr_name in dir(config):
        attr = getattr(config, attr_name)
        if isinstance(attr, AutoConfig):
            for key, value in vars(attr).items():
                if not hasattr(parsed_args, key):
                    setattr(parsed_args, key, value)
[docs]
def create_parser() -> argparse.ArgumentParser:
    """
    Creates the argument parser for the auto-train CLI.

    ``--model`` and ``--dataset`` are required. Unrecognised options are not
    rejected when parsed with ``parse_known_args``; ``train_command`` forwards
    them to training.

    :return: An `argparse.ArgumentParser` instance.
    """
    parser = argparse.ArgumentParser(
        description="Genomic Foundation Model Training Suite",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    # Required arguments
    parser.add_argument(
        "-d",
        "--dataset",
        type=str,
        required=True,  # training cannot proceed without a dataset
        help="Path to the dataset directory (may contain a config.py training configuration).",
    )
    parser.add_argument(
        "-t",
        "--tokenizer",
        type=str,
        default=None,
        help="Path to the tokenizer to use (HF tokenizer ID or local path).",
    )
    parser.add_argument(
        "-m",
        "--model",
        type=str,
        required=True,
        help="Path to the model to train (HF model ID or local path).",
    )
    # Optional arguments
    parser.add_argument(
        "--overwrite",
        action="store_true",
        help="Overwrite existing training results, otherwise resume from checkpoint.",
    )
    parser.add_argument(
        "--trainer",
        type=str,
        default="native",
        choices=["native", "accelerate", "hf_trainer"],
        help="Trainer to use for training. Options: native, accelerate, hf_trainer.",
    )
    return parser
[docs]
def run_train():
    """
    Console-script hook for the 'autotrain' command.

    Delegates to ``train_command`` without an explicit argument list, so the
    arguments are read from the process command line.
    """
    train_command(args=None)
if __name__ == "__main__":
    # Executed as a script: parse the process command line and start training.
    train_command(None)