from typing import Sequence
import numpy as np
from sklearn.exceptions import NotFittedError
from sklearn.base import BaseEstimator
[docs]def check_class_labels(*args: BaseEstimator) -> bool:
"""
Checks the known class labels for each classifier.
Args:
*args: Classifier objects to check the known class labels.
Returns:
True, if class labels match for all classifiers, False otherwise.
"""
try:
classes_ = [estimator.classes_ for estimator in args]
except AttributeError:
raise NotFittedError('Not all estimators are fitted. Fit all estimators before using this method.')
for classifier_idx in range(len(args) - 1):
if not np.array_equal(classes_[classifier_idx], classes_[classifier_idx+1]):
return False
return True
[docs]def check_class_proba(proba: np.ndarray, known_labels: Sequence, all_labels: Sequence) -> np.ndarray:
"""
Checks the class probabilities and reshapes it if not all labels are present in the classifier.
Args:
proba: The class probabilities of a classifier.
known_labels: The class labels known by the classifier.
all_labels: All class labels.
Returns:
Class probabilities augmented such that the probability of all classes is present. If the classifier is unaware
of a particular class, all probabilities are zero.
"""
# TODO: rewrite this function using numpy.insert
label_idx_map = -np.ones(len(all_labels), dtype='int')
for known_label_idx, known_label in enumerate(known_labels):
# finds the position of label in all_labels
for label_idx, label in enumerate(all_labels):
if np.array_equal(label, known_label):
label_idx_map[label_idx] = known_label_idx
break
aug_proba = np.hstack((proba, np.zeros(shape=(proba.shape[0], 1))))
return aug_proba[:, label_idx_map]