# Source code for modAL.utils.combination

from typing import Callable, Optional, Sequence, Tuple

import numpy as np
from sklearn.base import BaseEstimator

from modAL.utils.data import modALinput


def make_linear_combination(*functions: Callable, weights: Optional[Sequence] = None) -> Callable:
    """
    Takes the given functions and makes a function which returns the linear combination of the output of original
    functions. It works well with functions returning numpy arrays of the same shape.

    Args:
        *functions: Base functions for the linear combination. The functions shall have the same argument and if they
            return numpy arrays, the returned arrays shall have the same shape.
        weights: Coefficients of the functions in the linear combination. The i-th given function will be multiplied
            with weights[i]. Defaults to a weight of 1.0 for every function.

    Todo:
        Would it be better to accept the functions explicitly as a Collection?

    Returns:
        A function which returns the linear combination of the given functions output.

    Raises:
        AssertionError: If the length of ``weights`` differs from the number of given functions.
    """
    if weights is None:
        weights = np.ones(shape=(len(functions)))
    elif len(functions) != len(weights):
        # An explicit raise (instead of ``assert``) keeps the check active under ``python -O``;
        # AssertionError is kept so existing callers catching it are unaffected.
        raise AssertionError('the length of weights must be the '
                             'same as the number of given functions')

    def linear_combination(*args, **kwargs):
        # sum() starts from 0, so numpy array outputs are accumulated element-wise.
        return sum(weight * function(*args, **kwargs) for weight, function in zip(weights, functions))

    return linear_combination
def make_product(*functions: Callable, exponents: Optional[Sequence] = None) -> Callable:
    """
    Takes the given functions and makes a function which returns the product of the output of original functions.
    It works well with functions returning numpy arrays of the same shape.

    Args:
        *functions: Base functions for the product. The functions shall have the same argument and if they return
            numpy arrays, the returned arrays shall have the same shape.
        exponents: Exponents of the functions in the product. The i-th given function in the product will be raised
            to the power of exponents[i]. Defaults to an exponent of 1.0 for every function.

    Returns:
        A function which returns the product function of the given functions output.

    Raises:
        AssertionError: If the length of ``exponents`` differs from the number of given functions.
    """
    if exponents is None:
        exponents = np.ones(shape=(len(functions)))
    elif len(functions) != len(exponents):
        # An explicit raise (instead of ``assert``) keeps the check active under ``python -O``;
        # AssertionError is kept so existing callers catching it are unaffected.
        raise AssertionError('the length of exponents must be the '
                             'same as the number of given functions')

    def product_function(*args, **kwargs):
        # Stack the powered outputs and multiply along the stacking axis (element-wise for arrays).
        return np.prod([function(*args, **kwargs)**exponent for function, exponent in zip(functions, exponents)],
                       axis=0)

    return product_function
def _select_rows(X, idx):
    """Return the rows of ``X`` at the positions given by ``idx`` (an index, index array or boolean mask)."""
    if hasattr(X, 'iloc'):
        # pandas objects: plain ``[]`` would select columns by label instead of rows by position.
        return X.iloc[idx]
    if isinstance(X, list):
        # Python lists cannot be indexed with an index array, so select element by element.
        positions = np.asarray(idx)
        if positions.ndim == 0:
            return X[positions.item()]
        if positions.dtype == bool:
            return [row for row, keep in zip(X, positions) if keep]
        return [X[i] for i in positions.tolist()]
    # numpy arrays and scipy sparse matrices support fancy row indexing directly.
    return X[idx]


def make_query_strategy(utility_measure: Callable, selector: Callable) -> Callable:
    """
    Takes the given utility measure and selector functions and makes a query strategy by combining them.

    Args:
        utility_measure: Utility measure, for instance :func:`~modAL.disagreement.vote_entropy`, but it can be a custom
            function as well. Should take a classifier and the unlabelled data and should return an array containing
            the utility scores.
        selector: Function selecting instances for query. Should take an array of utility scores and should return an
            array containing the queried items.

    Returns:
        A function which returns queried instances given a classifier and an unlabelled pool. It returns a tuple of
        the selector's indices and the corresponding rows of the pool.
    """
    def query_strategy(classifier: BaseEstimator, X: modALinput) -> Tuple:
        utility = utility_measure(classifier, X)
        query_idx = selector(utility)
        return query_idx, _select_rows(X, query_idx)

    return query_strategy