Source code for perturb_lib.evaluators.collection.standard_ones

"""Copyright (C) 2025  GlaxoSmithKline plc

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

Collection of standard evaluators of perturbation models.
"""

from numpy.typing import NDArray
from scipy.stats import pearsonr
from sklearn.metrics import mean_absolute_error, r2_score, root_mean_squared_error

from perturb_lib import PlibData
from perturb_lib.evaluators.access import register_evaluator
from perturb_lib.evaluators.base import PlibEvaluatorMixin
from perturb_lib.utils import inherit_docstring


@register_evaluator
@inherit_docstring
class RMSE(PlibEvaluatorMixin):
    """Root-mean-square error (RMSE) between predictions and ground truth."""

    def evaluate(self, predictions: NDArray, true_values: PlibData) -> float:  # noqa: D102
        # Targets are the "value" field of the full slice of `true_values`;
        # sklearn metrics take (y_true, y_pred) in that order.
        observed = true_values[:]["value"]
        return root_mean_squared_error(y_true=observed, y_pred=predictions)
@register_evaluator
@inherit_docstring
class MAE(PlibEvaluatorMixin):
    """Mean absolute error (MAE) between predictions and ground truth."""

    def evaluate(self, predictions: NDArray, true_values: PlibData) -> float:  # noqa: D102
        # Targets are the "value" field of the full slice of `true_values`;
        # sklearn metrics take (y_true, y_pred) in that order.
        observed = true_values[:]["value"]
        return mean_absolute_error(y_true=observed, y_pred=predictions)
@register_evaluator
@inherit_docstring
class R2(PlibEvaluatorMixin):
    """R2 score (coefficient of determination).

    Represents the proportion of the variance of the ground truth (y) that is
    explained by the model's predictions.
    """

    def evaluate(self, predictions: NDArray, true_values: PlibData) -> float:  # noqa: D102
        # Argument order matters for R2 (it is not symmetric): sklearn expects
        # the ground truth first, then the predictions.
        observed = true_values[:]["value"]
        return r2_score(y_true=observed, y_pred=predictions)
@register_evaluator
@inherit_docstring
class Pearson(PlibEvaluatorMixin):
    """Pearson correlation coefficient.

    Measures the linear relationship between predictions and ground truth.
    Only the correlation coefficient is reported; the p-value that scipy also
    computes (whose significance test assumes normally distributed samples)
    is discarded. The coefficient is undefined (NaN) when either input is
    constant.
    """

    def evaluate(self, predictions: NDArray, true_values: PlibData) -> float:  # noqa: D102
        # Targets are the "value" field of the full slice of `true_values`.
        # `pearsonr` returns a result object holding both the coefficient and
        # a p-value; only the coefficient (`.statistic`) is used as the score.
        return pearsonr(true_values[:]["value"], predictions).statistic