Source code for modAL.utils.validation

from typing import Sequence

import numpy as np
from sklearn.exceptions import NotFittedError
from sklearn.base import BaseEstimator


[docs]def check_class_labels(*args: BaseEstimator) -> bool: """ Checks the known class labels for each classifier. Args: *args: Classifier objects to check the known class labels. Returns: True, if class labels match for all classifiers, False otherwise. """ try: classes_ = [estimator.classes_ for estimator in args] except AttributeError: raise NotFittedError('Not all estimators are fitted. Fit all estimators before using this method.') for classifier_idx in range(len(args) - 1): if not np.array_equal(classes_[classifier_idx], classes_[classifier_idx+1]): return False return True
[docs]def check_class_proba(proba: np.ndarray, known_labels: Sequence, all_labels: Sequence) -> np.ndarray: """ Checks the class probabilities and reshapes it if not all labels are present in the classifier. Args: proba: The class probabilities of a classifier. known_labels: The class labels known by the classifier. all_labels: All class labels. Returns: Class probabilities augmented such that the probability of all classes is present. If the classifier is unaware of a particular class, all probabilities are zero. """ # TODO: rewrite this function using numpy.insert label_idx_map = -np.ones(len(all_labels), dtype='int') for known_label_idx, known_label in enumerate(known_labels): # finds the position of label in all_labels for label_idx, label in enumerate(all_labels): if np.array_equal(label, known_label): label_idx_map[label_idx] = known_label_idx break aug_proba = np.hstack((proba, np.zeros(shape=(proba.shape[0], 1)))) return aug_proba[:, label_idx_map]