Source code for cardinal.base

from typing import List
from abc import ABC, abstractmethod
from warnings import warn

import numpy as np

from .typeutils import (RandomStateType, check_random_state,
                        NotEnoughSamplesWarning)


class BaseQuerySampler(ABC):
    """Abstract base class for query samplers.

    A query sampler is an object that takes labeled and/or unlabeled
    samples as input and uses knowledge from them to select the most
    informative ones.

    Args:
        batch_size: Number of samples to select.
    """

    def __init__(self, batch_size: int):
        self.batch_size = batch_size

    @abstractmethod
    def fit(self, X: np.ndarray, y: np.ndarray = None):
        """Fit the model on labeled samples.

        Args:
            X: Labeled samples of shape (n_samples, n_features).
            y: Labels of shape (n_samples).

        Returns:
            The object itself.
        """
        pass

    @abstractmethod
    def select_samples(self, X: np.ndarray) -> np.ndarray:
        """Select the samples to annotate from unlabeled data.

        Args:
            X: Pool of unlabeled samples of shape (n_samples, n_features).

        Returns:
            Indices of the selected samples of shape (batch_size).
        """
        pass

    def _not_enough_samples(self, X: np.ndarray) -> bool:
        """Tell whether ``X`` holds fewer samples than ``batch_size``.

        When it does, a ``NotEnoughSamplesWarning`` is emitted so that the
        caller can fall back to returning all the available samples.

        Args:
            X: Pool of samples; only its first dimension is inspected.

        Returns:
            True if ``X.shape[0]`` is strictly lower than ``batch_size``.
        """
        n_available = X.shape[0]
        cond = n_available < self.batch_size
        if cond:
            warn(f'Requested {self.batch_size} samples but data only has '
                 f'{n_available}. All available data will be returned',
                 NotEnoughSamplesWarning)
        return cond
class ScoredQuerySampler(BaseQuerySampler):
    """Abstract base class for query samplers relying on a total order.

    Query sampling methods often score all the samples and then pick
    samples using these scores. This base class handles the selection
    system, so that only a scoring method is required from subclasses.

    Args:
        batch_size: Number of samples to select.
        strategy: Describes how to select the samples based on scores.
            Can be "top" or "weighted".
        random_state: Random seeding.
    """

    def __init__(self, batch_size: int, strategy: str = 'top',
                 random_state: RandomStateType = None):
        super().__init__(batch_size)
        self.strategy = strategy
        self.random_state = check_random_state(random_state)

    @abstractmethod
    def score_samples(self, X: np.ndarray) -> np.ndarray:
        """Give an informativeness score to unlabeled samples.

        Args:
            X: Samples to evaluate.

        Returns:
            Scores of the samples.
        """
        pass

    def select_samples(self, X: np.ndarray) -> np.ndarray:
        """Select the samples from unlabeled data using the internal scoring.

        The selection follows ``self.strategy``:

        * ``'top'`` keeps the ``batch_size`` samples with the highest scores.
        * ``'weighted'`` draws ``batch_size`` distinct samples at random with
          probabilities proportional to their scores. The normalized scores
          are passed as ``p`` to ``random_state.choice``, which expects valid
          probabilities.

        If the pool holds fewer than ``batch_size`` samples, all indices are
        returned without scoring. Otherwise the scores are stored in the
        ``sample_scores_`` attribute.

        Args:
            X: Pool of unlabeled samples of shape (n_samples, n_features).

        Returns:
            Indices of the selected samples of shape (batch_size).

        Raises:
            ValueError: If ``self.strategy`` is not a known strategy.
        """
        if self._not_enough_samples(X):
            return np.arange(X.shape[0])

        sample_scores = self.score_samples(X)
        self.sample_scores_ = sample_scores

        if self.strategy == 'top':
            ranking = np.argsort(sample_scores)
            # Use an explicit start index rather than ``[-batch_size:]``:
            # since -0 == 0, the negative form returns *every* index when
            # batch_size is 0 instead of an empty selection.
            start = max(len(ranking) - self.batch_size, 0)
            index = ranking[start:]
        elif self.strategy == 'weighted':
            index = self.random_state.choice(
                np.arange(X.shape[0]), size=self.batch_size, replace=False,
                p=sample_scores / np.sum(sample_scores))
        else:
            raise ValueError('Unknown sample selection strategy {}'
                             .format(self.strategy))
        return index