nigroup/nideep

View on GitHub
nideep/datasets/balance.py

Summary

Maintainability
C
1 day
Test Coverage
import numpy as np

CLNAME_OTHER = 'other'

class Balancer(object):
    '''
    Balance class counts
    '''
    def get_class_count(self, other_clname=CLNAME_OTHER):
        """
        Return per-class occurrence count

        Keyword argumented:
        other_clname -- a name for an overall negative class (all inactive)
        """
        self.has_other_cl = other_clname is not None and other_clname != ''

        if self.l.ndim == 2:
            class_count = np.sum(self.l, axis=0)
            other_count = int(len(self.l) - np.sum(class_count))
        d = {}
        for i, x in enumerate(class_count):
            d[i] = int(x)
        if self.has_other_cl:
            d[other_clname] = other_count
        return d

    def oversample_idxs_to_balance_class_count(self, class_counts):
        """
        Determine indices with with which we can sample from the dataset
        and get a balanced inter-class distribution
        """
        num_examples, num_classes = self.l.shape
        mx = np.max(class_counts)
        # include each sample at least once
        idxs = np.arange(num_examples).reshape(-1, 1)  # reshape to column vector

        for i, c in enumerate(class_counts):
            delta_ = mx - c
            if delta_ > 0:
                if i < num_classes:
                    rows = np.where(self.l[:, i] == 1)[0]
                else:  # cannot index other class
                    rows = np.where(np.sum(self.l, axis=-1) <= 0)[0]
                rows_to_sample = rows[np.random.randint(0, high=rows.size,
                                                        size=(delta_, 1))
                                      ]
                idxs = np.vstack((idxs, rows_to_sample))
        return idxs.ravel()

    def sample_idxs_to_target_count(self, class_counts, target_count):
        """
        Determine indices with with which we can sample from the dataset
        and get a balanced inter-class distribution
        Classes with count < target_count will sub-sampled without replacement.
        Classes with count > target_count will get over-sampled.
        Classes with count equal to target_count will be copied.
        """
        _, num_classes = self.l.shape
        idxs = None  # some classes might get sub-sampled

        for i, _ in enumerate(class_counts):
            if i < num_classes:
                rows = np.where(self.l[:, i] == 1)[0]
            else:  # cannot index 'other' class / negative class, all absent
                rows = np.where(np.sum(self.l, axis=-1) <= 0)[0]
            # over-sample OR sub-sample?
            if rows.size <= target_count:

                # include each once and reshape to column vector
                rows_idxs = np.arange(rows.size).reshape(-1, 1)
                tile_factor = int(target_count / rows.size)
                rows_idxs = np.tile(rows_idxs, (tile_factor, 1))
                rows_to_sample = rows[rows_idxs]
                if rows_idxs.size < target_count:

                    remain = target_count - rows_idxs.size
                    # sample without replacement
                    rows_remain = np.random.choice(rows, size=(remain, 1),
                                                   replace=False)
                    rows_to_sample = np.vstack((rows_to_sample, rows_remain))
            else:
                rows_to_sample = np.random.choice(rows, size=(target_count, 1),
                                                  replace=False)
            if i > 0:
                idxs = np.vstack((idxs, rows_to_sample))
            else:
                idxs = rows_to_sample  # initial assignment
        return idxs.ravel()

    def get_idxs_to_balance_class_count(self, class_counts, target_count=None):
        """
        Determine indices with with which we can sample from the dataset
        and get a balanced inter-class distribution
        """
        if target_count is None:
            return self.oversample_idxs_to_balance_class_count(class_counts)
        else:
            return self.sample_idxs_to_target_count(class_counts, target_count)

    def __init__(self, labls):
        '''
        Constructor
        labls -- ground truth
        '''
        self.l = labls
        self.has_other_cl = True