Source code for omnigenbench.src.metric.ranking_metric
# -*- coding: utf-8 -*-
# file: ranking_metric.py
# time: 13:27 09/04/2024
# author: YANG, HENG <hy345@exeter.ac.uk> (杨恒)
# github: https://github.com/yangheng95
# huggingface: https://huggingface.co/yangheng
# google scholar: https://scholar.google.com/citations?user=NPq5a_0AAAAJ&hl=en
# Copyright (C) 2019-2025. All Rights Reserved.
import types
import warnings
import numpy as np
import sklearn.metrics as metrics
from ..abc.abstract_metric import OmniMetric
[docs]
class RankingMetric(OmniMetric):
    """
    Metric class exposing scikit-learn metrics for ranking tasks.

    Any attribute that is not found on the instance is looked up in
    ``sklearn.metrics``. If it resolves to a plain function, a wrapper is
    returned that accepts either ``(y_true, y_score)`` or a single
    HuggingFace ``EvalPrediction``-like object, drops the positions whose
    label equals ``ignore_y`` and returns ``{metric_name: value}``.

    Attributes:
        ignore_y: Label value to exclude from evaluation. The wrappers read it
            from the instance; it is expected to be set by the parent
            initializer (not visible in this file), ``None`` disables masking.

    Example:
        >>> import numpy as np
        >>> from omnigenbench import RankingMetric
        >>> metric = RankingMetric(ignore_y=-100)
        >>> y_true = np.array([0, 1, 1, 0, 1])
        >>> y_score = np.array([0.1, 0.9, 0.8, 0.2, 0.7])
        >>> result = metric.roc_auc_score(y_true, y_score)
        >>> sorted(result)
        ['roc_auc_score']
    """

    def __init__(self, *args, **kwargs):
        """
        Initialize the RankingMetric class.

        Args:
            *args: Positional arguments forwarded to the parent class.
            **kwargs: Keyword arguments forwarded to the parent class.
        """
        super().__init__(*args, **kwargs)

    def __getattr__(self, name):
        """
        Build a wrapper around the scikit-learn metric called ``name``.

        Python only calls this hook when normal attribute lookup fails, so
        real attributes and methods of the instance are never shadowed.

        Args:
            name (str): Name of a function in ``sklearn.metrics``.

        Returns:
            callable: Wrapper computing the requested metric.

        Raises:
            AttributeError: If ``sklearn.metrics`` has no plain function
                called ``name``.
        """
        metric_func = getattr(metrics, name, None)
        if not isinstance(metric_func, types.FunctionType):
            # Report the real class name (the message used to hard-code an
            # unrelated one), which also keeps hasattr()/copy/pickle probes
            # for missing dunder attributes behaving normally.
            raise AttributeError(
                f"'{type(self).__name__}' object has no attribute '{name}'"
            )

        def wrapper(y_true=None, y_score=None, *args, **kwargs):
            """
            Compute the wrapped metric from true labels and predicted scores.

            Args:
                y_true: True labels, or an object whose class is named
                    ``EvalPrediction`` carrying labels and predictions.
                y_score: Predicted scores. May be omitted when ``y_true`` is
                    an ``EvalPrediction``.
                *args: Extra positional arguments for the metric.
                **kwargs: Extra keyword arguments for the metric.

            Returns:
                dict: ``{name: value}`` for the requested metric.
            """
            # HuggingFace trainers hand over one EvalPrediction object; the
            # check is by class name so transformers need not be imported.
            if y_true.__class__.__name__ == "EvalPrediction":
                eval_prediction = y_true
                if hasattr(eval_prediction, "label_ids"):
                    y_true = eval_prediction.label_ids
                if hasattr(eval_prediction, "labels"):
                    y_true = eval_prediction.labels
                predictions = eval_prediction.predictions
                # Pick the first entry shaped like the labels that is not
                # simply a copy of the labels themselves.
                for candidate in predictions:
                    if candidate.shape == y_true.shape and not np.all(
                        candidate == y_true
                    ):
                        y_score = candidate
                        break
                # Fallback: predictions may be a single array aligned with
                # the labels rather than a collection of arrays; the scan
                # above then compares individual rows and never matches.
                if (
                    y_score is None
                    and isinstance(predictions, np.ndarray)
                    and predictions.shape == y_true.shape
                ):
                    y_score = predictions

            # NOTE(review): flatten() is not defined in this file; presumably
            # inherited from OmniMetric and returning 1-D arrays - confirm.
            y_true, y_score = RankingMetric.flatten(y_true, y_score)

            if self.ignore_y is not None:
                # Build the mask only when it is actually needed.
                keep_idx = np.where(y_true != self.ignore_y)
                y_true = y_true[keep_idx]
                try:
                    y_score = y_score[keep_idx]
                except Exception as e:
                    # Best effort: keep the unmasked scores and warn.
                    warnings.warn(str(e))

            if type(self).compute is not RankingMetric.compute:
                # A subclass supplied its own compute(); keep honouring it.
                value = self.compute(y_true, y_score, *args, **kwargs)
            else:
                # The base compute() only raises NotImplementedError, so call
                # the scikit-learn function that was looked up instead.
                value = metric_func(y_true, y_score, *args, **kwargs)
            return {name: value}

        return wrapper

    def compute(self, y_true, y_score, *args, **kwargs):
        """
        Hook for subclasses that want custom metric computation.

        When a subclass overrides this method, the dynamically created
        wrappers call it instead of the scikit-learn function.

        Args:
            y_true: The true values.
            y_score: The predicted values (scores for ranking).
            *args: Additional positional arguments for the metric.
            **kwargs: Additional keyword arguments for the metric.

        Returns:
            The computed ranking metric value.

        Raises:
            NotImplementedError: Always, unless overridden in a child class.
        """
        raise NotImplementedError(
            "Method compute() is not implemented in the child class."
        )