# -*- coding: utf-8 -*-
# file: auto_train.py
# time: 11:54 14/04/2024
# author: YANG, HENG <hy345@exeter.ac.uk> (杨恒)
# github: https://github.com/yangheng95
# huggingface: https://huggingface.co/yangheng
# google scholar: https://scholar.google.com/citations?user=NPq5a_0AAAAJ&hl=en
# Copyright (C) 2019-2025. All Rights Reserved.
import os
import time
import warnings
import findfile
import torch
from metric_visualizer import MetricVisualizer
from transformers import TrainingArguments, Trainer as HFTrainer
from ...auto.config.auto_config import AutoConfig
from ...src.lora.lora_model import OmniLoraModel
from ...src.abc.abstract_tokenizer import OmniTokenizer
from ...src.abc.abstract_dataset import OmniDataset
from ...src.misc.utils import (
seed_everything,
fprint,
load_module_from_path,
clean_temp_checkpoint,
)
from ...src.trainer.accelerate_trainer import AccelerateTrainer
from ...src.trainer.trainer import Trainer
# Root directory for AutoTrain outputs: metric-visualizer (.mv) records, their
# CSV exports, HuggingFace Trainer output directories and saved models.
autotrain_evaluations = "./autotrain_evaluations"
class AutoTrain:
    """
    Train genomic models on benchmark datasets with minimal configuration.

    This class provides a comprehensive framework for training genomic models
    on various datasets with minimal configuration. It handles dataset loading,
    model initialization, training configuration, and result tracking.

    AutoTrain supports various training scenarios including:
        - Single dataset training with multiple seeds
        - Different trainer backends (native, accelerate, huggingface)
        - Automatic metric visualization and result tracking
        - Configurable training parameters

    Attributes:
        dataset_name_or_path (str): The dataset argument as given, with any
            trailing "/" removed.
        dataset (str): Local dataset directory (the download directory for
            hub datasets).
        is_hub_dataset (bool): True when ``dataset_name_or_path`` did not exist
            locally and was therefore downloaded from the hub.
        cache_dir (str): Download directory for hub datasets (None for local
            datasets unless given explicitly).
        config_or_model: Model name/path (trailing "/" removed) or a model instance.
        model_name (str): Short model name used in record and file names.
        tokenizer: Tokenizer instance, tokenizer name/path, or None.
        autocast (str): The autocast precision to use ('fp16', 'bf16', etc.).
        overwrite (bool): Whether to start a fresh metric record instead of
            resuming a previously saved one.
        trainer (str): The trainer backend: 'accelerate', 'hf_trainer', or any
            other value for the native Trainer.
        mv_path (str): Path the metric visualizer is dumped to.
        mv (MetricVisualizer): The metric visualizer instance.
    """

    def __init__(
        self,
        dataset,
        config_or_model,
        tokenizer=None,
        **kwargs,
    ):
        """
        Initialize the AutoTrain instance.

        Args:
            dataset (str): The name or path of the dataset to use for training.
                Can be a local path or a HuggingFace Hub dataset name.
                A path that does not exist locally is treated as a hub dataset
                name and downloaded automatically.
            config_or_model (str): The model instance, model name or model path
                of the model to train.
            tokenizer: The tokenizer instance or tokenizer name/path to use.
                If None, it will be loaded from the model path.
            **kwargs: Additional keyword arguments.
                - autocast (str): The autocast precision to use ('fp16', 'bf16', etc.).
                  Defaults to 'fp16'.
                - overwrite (bool): Whether to overwrite existing training results.
                  Defaults to False.
                - trainer (str): The trainer to use ('native', 'accelerate', 'hf_trainer').
                  Defaults to 'accelerate'.
                - cache_dir (str): Directory to cache downloaded datasets from hub.
                  Defaults to './__OMNIGENBENCH_DATA__/datasets/<dataset>'.

        Example:
            >>> # Initialize with a local dataset path
            >>> trainer = AutoTrain("/path/to/dataset", "yangheng/OmniGenome-186M")
            >>> # Initialize with a HuggingFace Hub dataset name (auto-downloads)
            >>> trainer = AutoTrain("translation_efficiency_prediction",
            ...                     "yangheng/OmniGenome-186M")
            >>> # Initialize with custom settings
            >>> trainer = AutoTrain("dataset_name", "model_name",
            ...                     autocast="bf16", trainer="accelerate")
        """
        self.dataset_name_or_path = (
            dataset.rstrip("/") if isinstance(dataset, str) else dataset
        )
        self.autocast = kwargs.pop("autocast", "fp16")
        self.overwrite = kwargs.pop("overwrite", False)
        self.trainer = kwargs.pop("trainer", "accelerate")
        self.cache_dir = kwargs.pop("cache_dir", None)

        # Anything that is not an existing local path is treated as a hub dataset name.
        self.is_hub_dataset = not os.path.exists(self.dataset_name_or_path)
        if self.is_hub_dataset:
            fprint(f"Detected HuggingFace Hub dataset: {self.dataset_name_or_path}")
            fprint("Downloading dataset from hub...")
            if self.cache_dir is None:
                self.cache_dir = os.path.join(
                    os.getcwd(),
                    f"__OMNIGENBENCH_DATA__/datasets/{self.dataset_name_or_path}",
                )
            OmniDataset._download_dataset_from_hub(
                self.dataset_name_or_path, self.cache_dir
            )
            self.dataset = self.cache_dir
            fprint(f"Dataset downloaded to: {self.dataset}")
        else:
            self.dataset = self.dataset_name_or_path
            fprint(f"Using local dataset: {self.dataset}")

        self.config_or_model = config_or_model
        self.tokenizer = tokenizer
        if isinstance(config_or_model, str):
            self.config_or_model = config_or_model.rstrip("/")
        # Derive the name from the *stripped* path so "org/model/" -> "model", not "".
        self.model_name = self._model_name_of(self.config_or_model)
        if isinstance(tokenizer, str):
            self.tokenizer = tokenizer.rstrip("/")

        os.makedirs(autotrain_evaluations, exist_ok=True)
        time_str = time.strftime("%Y%m%d_%H%M%S", time.localtime())
        # Use the dataset name (not its full path) in record file names.
        dataset_name = os.path.basename(self.dataset_name_or_path)
        mv_name = f"{dataset_name}-{self.model_name}"
        self.mv_path = f"{autotrain_evaluations}/{mv_name}-{time_str}.mv"
        # Resume from the last matching record file found, unless overwriting.
        mv_paths = findfile.find_files(
            autotrain_evaluations,
            [dataset_name, self.model_name, ".mv"],
        )
        if mv_paths and not self.overwrite:
            self.mv = MetricVisualizer.load(mv_paths[-1])
            self.mv.summary(round=4)
        else:
            self.mv = MetricVisualizer(self.mv_path)
        self.train_info()

    @staticmethod
    def _model_name_of(config_or_model):
        """Return the short model name: last path component for strings, else the class name."""
        if isinstance(config_or_model, str):
            return config_or_model.rstrip("/").split("/")[-1]
        return config_or_model.__class__.__name__

    def train_info(self):
        """
        Print and return information about the current training setup.

        Returns:
            str: A string containing training setup information.

        Example:
            >>> info = trainer.train_info()
            >>> print(info)
        """
        info = f"Dataset Root: {self.dataset}\n"
        info += f"Model Name or Path: {self.model_name}\n"
        info += f"Tokenizer: {self.tokenizer}\n"
        info += f"Metric Visualizer Path: {self.mv_path}\n"
        fprint(info)
        return info

    def run(self, **kwargs):
        """
        Train the model on the dataset once per configured seed.

        This method loads the dataset configuration, initializes the model and
        tokenizer, and runs training across multiple seeds. It supports various
        training backends and automatic result tracking. Seeds are skipped once
        the metric record already holds at least as many trials as there are
        seeds.

        Configs can be loaded from both local datasets and HuggingFace Hub
        datasets (which are downloaded to a local directory in ``__init__``).

        Args:
            **kwargs: Additional keyword arguments that override the default
                parameters in the dataset configuration. Every kwarg is merged
                into the config; keys unknown to the config trigger a warning.

                Special kwargs:
                    - config (AutoConfig): Provide a pre-loaded config instead of
                      auto-loading from the dataset directory.
                    - lora_config (dict): If given, the model is wrapped with
                      ``OmniLoraModel(model, **lora_config)``.
                    - save_model (bool): Whether to save the trained model
                      (default True; may also be set in the dataset config).

        Raises:
            ValueError: If ``config`` is not an AutoConfig instance, if the
                dataset's config module defines no AutoConfig instance, or if
                no model was specified.
            FileNotFoundError: If no config ``.py`` file is found in the
                dataset directory.

        Example:
            >>> # Run training with default settings (auto-loads config)
            >>> trainer = AutoTrain("translation_efficiency_prediction", "yangheng/OmniGenome-186M")
            >>> trainer.run()
            >>> # Run with custom parameters
            >>> trainer.run(learning_rate=1e-4, batch_size=16)
            >>> # Run with custom config
            >>> custom_config = AutoConfig.from_hub("my_dataset")
            >>> trainer.run(config=custom_config)
        """
        clean_temp_checkpoint(1)  # clean temp checkpoint older than 1 day
        _kwargs = kwargs.copy()
        train_config = self._load_train_config(_kwargs)
        fprint(train_config.args)

        tokenizer = self._resolve_tokenizer()

        self._merge_kwargs_into_config(train_config, _kwargs)
        fprint(
            f"Autotrain Config for {self.dataset}:",
            "\n".join([f"{k}: {v}" for k, v in train_config.items()]),
        )

        if not isinstance(train_config["seeds"], list):
            train_config["seeds"] = [train_config["seeds"]]
        random_seeds = train_config["seeds"]
        # Loop-invariant settings, computed once instead of per seed.
        batch_size = train_config.get("batch_size", 8)
        record_name = f"{os.path.basename(self.dataset)}-{self.model_name}".split(
            "/"
        )[-1]

        for seed in random_seeds:
            if self._record_complete(record_name, len(random_seeds)):
                continue
            seed_everything(seed)
            model = self._build_model(
                train_config, tokenizer, kwargs.get("lora_config", None)
            )
            train_set, valid_set, test_set = self._build_datasets(
                train_config, tokenizer, model, _kwargs
            )

            optimizer = None  # only the native/accelerate backends create one
            if self.trainer == "hf_trainer":
                trainer, metrics = self._train_with_hf_trainer(
                    model,
                    train_config,
                    train_set,
                    valid_set,
                    test_set,
                    batch_size,
                    seed,
                    kwargs,
                )
            else:
                trainer, metrics, optimizer = self._train_with_native_trainer(
                    model,
                    train_config,
                    train_set,
                    valid_set,
                    test_set,
                    batch_size,
                    seed,
                    _kwargs,
                )

            # run() kwargs were merged into train_config above, so read the flag
            # from there (reading it from _kwargs would always yield the default).
            if train_config.get("save_model", True):
                self._save_model(trainer)

            self._log_test_metrics(record_name, metrics)
            del model, trainer, optimizer
            torch.cuda.empty_cache()

    def _load_train_config(self, _kwargs):
        """Return the AutoConfig from ``_kwargs["config"]`` (popped) or from the dataset's config .py file."""
        if "config" in _kwargs:
            train_config = _kwargs.pop("config")
            if not isinstance(train_config, AutoConfig):
                raise ValueError("Provided config must be an AutoConfig instance")
            fprint("Using provided config")
            return train_config

        try:
            fprint(f"Loading config from dataset directory: {self.dataset}")
            train_config_path = findfile.find_file(
                self.dataset,
                ["config", ".py"],
            )
            if not train_config_path:
                raise FileNotFoundError(
                    f"Could not find config.py in {self.dataset}"
                )
            fprint(f"Found config file: {train_config_path}")
            config_module = load_module_from_path("config", train_config_path)
            # Use the first AutoConfig instance (in dir() order) found in the module.
            train_config = None
            for attr_name in dir(config_module):
                candidate = getattr(config_module, attr_name)
                if isinstance(candidate, AutoConfig):
                    train_config = candidate
                    break
            if train_config is None:
                raise ValueError(
                    f"Could not find AutoConfig instance in {train_config_path}"
                )
            fprint(f"Loaded config for {self.dataset} from {train_config_path}")
            return train_config
        except Exception as e:
            fprint(f"Warning: Failed to load config from dataset: {e}")
            fprint(
                "You may need to provide a config manually or ensure the dataset contains a valid config.py file"
            )
            raise

    def _resolve_tokenizer(self):
        """Return a tokenizer instance, loading it when only a name/path (or nothing) was given."""
        if not self.tokenizer:
            return OmniTokenizer.from_pretrained(
                self.config_or_model, trust_remote_code=True
            )
        if isinstance(self.tokenizer, str):
            # A tokenizer name/path must be loaded; passing the raw string on
            # to the model and dataset classes would not work.
            return OmniTokenizer.from_pretrained(
                self.tokenizer, trust_remote_code=True
            )
        return self.tokenizer

    @staticmethod
    def _merge_kwargs_into_config(train_config, _kwargs):
        """Copy every kwarg into ``train_config``, then drop merged keys from ``_kwargs`` (in place)."""
        for key, value in _kwargs.items():
            if key in train_config:
                fprint("Override", key, "with", value, "according to the input kwargs")
            else:
                warnings.warn(
                    f"kwarg: {key} not found in train_config while setting {key} = {value}"
                )
            train_config.update({key: value})
        # Every kwarg now lives in train_config; removing it from _kwargs avoids
        # passing it twice. NOTE(review): this leaves _kwargs empty, so nothing
        # extra is forwarded to the dataset/trainer constructors.
        for key in list(_kwargs):
            if key in train_config:
                _kwargs.pop(key)

    def _record_complete(self, record_name, num_seeds):
        """Return True if the metric record already holds at least ``num_seeds`` trials."""
        records = self.mv.transpose()
        if record_name not in records:
            return False
        first_metric_trials = next(iter(records[record_name].values()), [])
        return len(first_metric_trials) >= num_seeds

    def _build_model(self, train_config, tokenizer, lora_config):
        """Instantiate ``train_config["model_cls"]`` and optionally wrap it with LoRA."""
        if not self.config_or_model:
            raise ValueError(
                "config_or_model is not specified. Please provide a valid model name or path."
            )
        model_cls = train_config["model_cls"]
        model = model_cls(
            self.config_or_model,
            tokenizer=tokenizer,
            label2id=train_config.label2id,
            num_labels=train_config["num_labels"],
            trust_remote_code=True,
            ignore_mismatched_sizes=True,
        )
        if lora_config is not None:
            fprint("Applying LoRA to the model with config:", lora_config)
            model = OmniLoraModel(model, **lora_config)
        return model

    def _build_datasets(self, train_config, tokenizer, model, extra_kwargs):
        """Build and return ``(train_set, valid_set, test_set)``; only the train split is shuffled."""
        dataset_cls = train_config["dataset_cls"]
        max_length = train_config["max_length"]
        # Never exceed the model's positional capacity when it is declared.
        if hasattr(model.config, "max_position_embeddings"):
            max_length = min(max_length, model.config.max_position_embeddings)
        common = dict(
            tokenizer=tokenizer,
            label2id=train_config["label2id"],
            max_length=max_length,
            structure_in=train_config.get("structure_in", False),
            max_examples=train_config.get("max_examples", None),
            drop_long_seq=train_config.get("drop_long_seq", False),
        )
        # Construction order (train, test, valid) is kept from the original code.
        train_set = dataset_cls(
            dataset_name_or_path=train_config["train_file"],
            shuffle=train_config.get("shuffle", True),
            **common,
            **extra_kwargs,
        )
        test_set = dataset_cls(
            dataset_name_or_path=train_config["test_file"],
            shuffle=False,
            **common,
            **extra_kwargs,
        )
        valid_set = dataset_cls(
            dataset_name_or_path=train_config["valid_file"],
            shuffle=False,
            **common,
            **extra_kwargs,
        )
        return train_set, valid_set, test_set

    def _train_with_hf_trainer(
        self,
        model,
        train_config,
        train_set,
        valid_set,
        test_set,
        batch_size,
        seed,
        run_kwargs,
    ):
        """Train with ``transformers.Trainer``; return ``(trainer, metrics)``."""
        # Forward only kwargs that TrainingArguments understands; output_dir is fixed.
        hf_kwargs = {
            k: v
            for k, v in run_kwargs.items()
            if hasattr(TrainingArguments, k) and k != "output_dir"
        }
        # The loop seed must reach TrainingArguments: HFTrainer reseeds from
        # args.seed on construction, which would make every seed identical.
        hf_kwargs.pop("seed", None)
        per_device_batch_size = hf_kwargs.pop("batch_size", batch_size)
        training_args = TrainingArguments(
            output_dir=f"{autotrain_evaluations}/{self.model_name}",
            num_train_epochs=hf_kwargs.pop(
                "num_train_epochs", train_config["epochs"]
            ),
            per_device_train_batch_size=per_device_batch_size,
            per_device_eval_batch_size=per_device_batch_size,
            gradient_accumulation_steps=hf_kwargs.pop(
                "gradient_accumulation_steps",
                train_config.get("gradient_accumulation_steps", 1),
            ),
            learning_rate=hf_kwargs.pop(
                "learning_rate", train_config.get("learning_rate", 2e-5)
            ),
            weight_decay=hf_kwargs.pop(
                "weight_decay", train_config.get("weight_decay", 0)
            ),
            eval_strategy=hf_kwargs.pop("eval_strategy", "epoch"),
            save_strategy=hf_kwargs.pop("save_strategy", "epoch"),
            # Mixed precision follows self.autocast unless given explicitly
            # (the default autocast "fp16" keeps the previous fp16=True).
            fp16=hf_kwargs.pop("fp16", self.autocast == "fp16"),
            bf16=hf_kwargs.pop("bf16", self.autocast == "bf16"),
            seed=seed,
            remove_unused_columns=False,
            label_names=["labels"],
            **hf_kwargs,
        )
        eval_set = valid_set if len(valid_set) else test_set

        compute_metrics = train_config["compute_metrics"]
        if isinstance(compute_metrics, (list, tuple)):
            if len(compute_metrics) > 1:
                fprint(
                    "Multiple metrics not supported by HFTrainer, using the first one metric only."
                )
            compute_metrics = compute_metrics[0]

        trainer = HFTrainer(
            model=model,
            args=training_args,
            train_dataset=train_set,
            eval_dataset=eval_set,
            compute_metrics=compute_metrics,
        )
        # Baseline evaluation before training.
        print(trainer.evaluate(eval_set))
        train_result = trainer.train()
        eval_result = trainer.evaluate()
        test_result = trainer.evaluate(test_set if len(test_set) else eval_set)
        metrics = {
            "train": train_result.metrics,
            "eval": eval_result,
            "test": test_result,
        }
        fprint(metrics)
        return trainer, metrics

    def _train_with_native_trainer(
        self,
        model,
        train_config,
        train_set,
        valid_set,
        test_set,
        batch_size,
        seed,
        extra_kwargs,
    ):
        """Train with AccelerateTrainer or the native Trainer; return ``(trainer, metrics, optimizer)``."""
        # Only parameters that require gradients are optimized (e.g. LoRA adapters).
        optimizer = torch.optim.AdamW(
            filter(lambda p: p.requires_grad, model.parameters()),
            lr=train_config.get("learning_rate", 2e-5),
            weight_decay=train_config.get("weight_decay", 0),
        )
        trainer_cls = AccelerateTrainer if self.trainer == "accelerate" else Trainer
        fprint(f"Using Trainer: {trainer_cls}")
        trainer = trainer_cls(
            model=model,
            train_dataset=train_set,
            eval_dataset=valid_set,
            test_dataset=test_set,
            batch_size=batch_size,
            patience=train_config.get("patience", 3),
            epochs=train_config["epochs"],
            gradient_accumulation_steps=train_config.get(
                "gradient_accumulation_steps", 1
            ),
            optimizer=optimizer,
            loss_fn=train_config.get("loss_fn", None),
            compute_metrics=train_config["compute_metrics"],
            seed=seed,
            autocast=self.autocast,
            **extra_kwargs,
        )
        metrics = trainer.train()
        return trainer, metrics, optimizer

    def _save_model(self, trainer):
        """Save the trained model under ``<autotrain_evaluations>/<dataset dir name>/<model_name>``."""
        # Join only the dataset's directory name: os.path.join discards earlier
        # components when given an absolute path (as for downloaded hub datasets),
        # which would otherwise write the model into the dataset directory.
        dataset_dir_name = os.path.basename(os.path.normpath(self.dataset))
        save_path = os.path.join(
            autotrain_evaluations, dataset_dir_name, self.model_name
        )
        fprint(f"Saving model to {save_path}")
        os.makedirs(save_path, exist_ok=True)
        if self.trainer == "hf_trainer":
            # transformers.Trainer.save_model has no ``overwrite`` parameter.
            trainer.save_model(save_path)
        else:
            trainer.save_model(save_path, overwrite=True)

    @staticmethod
    def _final_test_metrics(metrics):
        """
        Return the final test metrics as a dict, converting values to float where possible.

        ``metrics["test"]`` is a list of per-evaluation dicts for the native
        trainers (the last entry is used) or a single dict for HFTrainer.
        """
        test_metrics = metrics.get("test") or {}
        if isinstance(test_metrics, (list, tuple)):
            test_metrics = test_metrics[-1] if test_metrics else {}
        final = {}
        for key, value in test_metrics.items():
            try:
                value = float(value)
            except (TypeError, ValueError, RuntimeError, OverflowError):
                pass  # keep values that cannot be converted to float as-is
            final[key] = value
        return final

    def _log_test_metrics(self, record_name, metrics):
        """Log the final test metrics to the visualizer and persist the .mv and .csv files."""
        if not metrics:
            return
        for key, value in self._final_test_metrics(metrics).items():
            self.mv.log(f"{record_name}", f"{key}", value)
        self.mv.summary(round=4)
        self.mv.dump(self.mv_path)
        self.mv.to_csv(self.mv_path.replace(".mv", ".csv"))