Source code for omnigenbench.src.metric.classification_metric
# -*- coding: utf-8 -*-
# file: classification_metric.py
# time: 12:57 09/04/2024
# author: YANG, HENG <hy345@exeter.ac.uk> (杨恒)
# github: https://github.com/yangheng95
# huggingface: https://huggingface.co/yangheng
# google scholar: https://scholar.google.com/citations?user=NPq5a_0AAAAJ&hl=en
# Copyright (C) 2019-2025. All Rights Reserved.
import types
import warnings
import numpy as np
import sklearn.metrics as metrics
from ..abc.abstract_metric import OmniMetric
[docs]
class ClassificationMetric(OmniMetric):
    """
    Classification metric interface for the OmniGenome framework.

    Any plain function exported by ``sklearn.metrics`` (e.g. ``accuracy_score``,
    ``f1_score``) can be accessed as an attribute of an instance. The attribute
    access returns a wrapper that preprocesses its inputs (unpacks Hugging Face
    ``EvalPrediction`` objects, flattens the arrays, drops positions whose
    ground-truth label equals ``ignore_y``) and then calls the scikit-learn
    function.

    Attributes:
        metric_func (callable): A callable metric function from sklearn.metrics
            (expected to be stored by ``OmniMetric.__init__``).
        ignore_y (any): A value in the ground truth labels to be ignored during
            metric computation. Defaults to -100 (expected to be stored by
            ``OmniMetric.__init__``).
        kwargs (dict): Additional keyword arguments forwarded to every metric
            computation.
    """

    def __init__(self, metric_func=None, ignore_y=-100, *args, **kwargs):
        """
        Initializes the classification metric.

        Args:
            metric_func (callable, optional): A callable metric function from
                sklearn.metrics. If None, ``compute`` raises unless a subclass
                overrides it or a metric is accessed by attribute name.
            ignore_y (any, optional): A value in the ground truth labels to be
                ignored during metric computation. Defaults to -100.
            *args: Additional positional arguments passed to the base class.
            **kwargs: Additional keyword arguments; passed to the base class
                and stored on ``self.kwargs`` for use in metric calls.

        Example:
            >>> # Initialize with a specific metric function
            >>> metric = ClassificationMetric(metrics.accuracy_score)
            >>> # Initialize with ignore value
            >>> metric = ClassificationMetric(ignore_y=-100)
        """
        super().__init__(metric_func, ignore_y, *args, **kwargs)
        self.kwargs = kwargs

    def __getattribute__(self, name):
        """
        Custom attribute getter that provides dynamic access to scikit-learn metrics.

        If ``name`` is a plain function in ``sklearn.metrics``, a wrapper around
        that function is returned; otherwise normal attribute lookup is used.
        Note that the sklearn lookup takes precedence over regular attributes
        with the same name.

        Side effect (kept for backward compatibility): accessing a metric name
        also rebinds ``self.compute`` to the raw scikit-learn function, which
        shadows the ``compute`` method defined on the class for this instance.

        Args:
            name (str): The attribute name to get.

        Returns:
            callable: A wrapper function for the requested metric, or the original
                attribute if it's not a metric function.

        Example:
            >>> metric = ClassificationMetric()
            >>> # Access any scikit-learn metric
            >>> accuracy_fn = metric.accuracy_score
            >>> result = accuracy_fn(y_true, y_pred)
        """
        metric_func = getattr(metrics, name, None)
        if not isinstance(metric_func, types.FunctionType):
            return super().__getattribute__(name)

        # Backward compatibility: earlier versions exposed the most recently
        # accessed metric as ``self.compute``. The wrapper below no longer
        # relies on it, so accessing another metric cannot change its result.
        setattr(self, "compute", metric_func)

        def wrapper(y_true=None, y_pred=None, *args, **kwargs):
            """
            Compute the metric, based on the true and predicted values.

            Positions where ``y_true`` equals the instance's ``ignore_y`` are
            removed from both arrays before the metric is evaluated (skipped
            when ``ignore_y`` is None).

            Args:
                y_true: The true values (ground truth labels), or a Hugging
                    Face ``EvalPrediction`` object carrying both labels and
                    predictions.
                y_pred: The predicted values (model predictions). Ignored
                    when a matching entry is found in an ``EvalPrediction``.
                *args: Additional positional arguments for the metric function.
                **kwargs: Additional keyword arguments for the metric function.
                    Entries in ``self.kwargs`` override same-named entries.

            Returns:
                dict: A dictionary with the metric name as key and its value.

            Example:
                >>> # Standard usage
                >>> result = accuracy_fn(y_true, y_pred)
                >>> print(result)  # {'accuracy_score': 0.85}
                >>> # With Hugging Face EvalPrediction
                >>> result = accuracy_fn(eval_prediction)
                >>> print(result)  # {'accuracy_score': 0.85}
            """
            # Duck-typed check by class name so that transformers does not
            # have to be imported here.
            if y_true.__class__.__name__ == "EvalPrediction":
                eval_prediction = y_true
                if hasattr(eval_prediction, "label_ids"):
                    y_true = eval_prediction.label_ids
                if hasattr(eval_prediction, "labels"):
                    y_true = eval_prediction.labels
                # Pick the first prediction entry that has the same shape as
                # the labels but is not simply a copy of the labels.
                for candidate in eval_prediction.predictions:
                    if candidate.shape == y_true.shape and not np.all(
                        candidate == y_true
                    ):
                        # Fix: this used to be stored in an unused local
                        # (``y_score``), leaving ``y_pred`` as None.
                        y_pred = candidate
                        break

            # ``flatten`` is not defined in this class; presumably inherited
            # from OmniMetric -- TODO confirm.
            y_true, y_pred = ClassificationMetric.flatten(y_true, y_pred)

            if self.ignore_y is not None:
                keep_idx = np.where(y_true != self.ignore_y)
                y_true = y_true[keep_idx]
                try:
                    y_pred = y_pred[keep_idx]
                except Exception as e:
                    # Deliberate best-effort: if the predictions cannot be
                    # masked with the label index, warn and pass them on
                    # unmasked.
                    warnings.warn(str(e))

            kwargs.update(self.kwargs)
            # Call the function captured at attribute-access time rather than
            # ``self.compute``, which a later attribute access may have
            # rebound to a different metric.
            return {name: metric_func(y_true, y_pred, *args, **kwargs)}

        return wrapper

    def compute(self, y_true, y_pred, *args, **kwargs):
        """
        Compute the metric, based on the true and predicted values.

        Calls ``self.metric_func`` directly with the given inputs; no
        flattening or ``ignore_y`` filtering is applied here. Entries in
        ``self.kwargs`` override same-named entries in ``kwargs``.

        Args:
            y_true: The true values (ground truth labels).
            y_pred: The predicted values (model predictions).
            *args: Additional positional arguments for the metric function.
            **kwargs: Additional keyword arguments for the metric function.

        Returns:
            The value returned by ``self.metric_func`` (not wrapped in a dict).

        Raises:
            NotImplementedError: If no metric function is provided and the method
                is not implemented by the subclass.

        Example:
            >>> metric = ClassificationMetric(metrics.accuracy_score)
            >>> result = metric.compute(y_true, y_pred)
            >>> print(result)  # 0.85
        """
        if self.metric_func is not None:
            kwargs.update(self.kwargs)
            return self.metric_func(y_true, y_pred, *args, **kwargs)
        else:
            raise NotImplementedError(
                "Method compute() is not implemented in the child class."
            )