MushroomRL/mushroom-rl

View on GitHub
mushroom_rl/rl_utils/running_stats.py

Summary

Maintainability
C
1 day
Test Coverage
F
56%
import numpy as np
from collections import deque
from mushroom_rl.core import Serializable, ArrayBackend


class RunningStandardization(Serializable):
    """
    Compute a running standardization of values according to Welford's online
    algorithm: https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm

    Batches of samples are merged into the running statistics with the
    parallel variant of the algorithm (Chan et al.). As long as the minimum
    learning rate ``alpha`` is not active, updating with a whole batch yields
    the same statistics as updating with its samples one at a time.
    The statistics start from a single pseudo-sample with zero mean and unit
    variance (``n = 1``, ``m = 0``, ``s = 1``).

    """
    def __init__(self, shape, backend, alpha=1e-32):
        """
        Constructor.

        Args:
            shape (tuple): shape of the data to standardize;
            backend (str): name of the backend to be used;
            alpha (float, 1e-32): minimum learning rate used to update the
                running mean.

        """
        assert backend in ["numpy", "torch"]
        self._shape = shape

        assert 0. < alpha < 1.
        self._alpha = alpha

        self._array_backend = ArrayBackend.get_array_backend(backend)

        # Initialize _n, _m and _s through reset() so that construction and
        # reset always produce statistics of the same shape.
        self.reset()

        self._add_save_attr(
            _shape='primitive',
            _alpha='primitive',
            _array_backend='pickle',
            _n='primitive',
            _m='primitive',
            _s='primitive'
        )

    def reset(self):
        """
        Reset the mean and standard deviation to their initial values: a
        single pseudo-sample with zero mean and unit variance.

        """
        self._n = 1
        # Same shape as the per-feature batch mean computed in update_stats
        # (the original reset used a leading singleton dimension instead).
        self._m = self._array_backend.zeros(*self._shape)
        self._s = self._array_backend.ones(*self._shape)

    def update_stats(self, value):
        """
        Update the statistics with the current data value.

        Args:
            value (Array): current data value to use for the update; a batch
                of samples stacked along the first dimension (lower
                dimensional inputs are promoted with ``atleast_2d``).

        """
        value = self._array_backend.atleast_2d(value)
        batch_size = len(value)
        batch_mean = value.mean(0)
        # Sum of squared deviations of the batch samples from the batch mean.
        batch_s = ((value - batch_mean) ** 2).sum(0)

        self._n += batch_size
        alpha = max(batch_size / self._n, self._alpha)
        new_m = (1 - alpha) * self._m + alpha * batch_mean
        # Parallel merge (Chan et al.): s_a + s_b + delta^2 * n_a * n_b / n.
        # Written as n_b * (mean_b - m_old) * (mean_b - m_new), this is exact
        # for alpha = n_b / n and reduces to the classic single-sample Welford
        # update (x - m_old) * (x - m_new) when batch_size == 1.
        new_s = self._s + batch_s + batch_size * (batch_mean - self._m) * (batch_mean - new_m)
        self._m, self._s = new_m, new_s

    @property
    def mean(self):
        """
        Returns:
            The estimated mean value, with singleton dimensions squeezed.

        """
        return self._array_backend.squeeze(self._m)

    @property
    def std(self):
        """
        Returns:
            The estimated standard deviation value, i.e. ``sqrt(s / n)``,
            with singleton dimensions squeezed.

        """
        return self._array_backend.squeeze(self._array_backend.sqrt(self._s / self._n))


class RunningExpWeightedAverage(Serializable):
    """
    Compute an exponentially weighted moving average.

    Each incoming sample ``x`` moves the filter as
    ``avg <- (1 - alpha) * avg + alpha * x``.

    """
    def __init__(self, shape, alpha, backend, init_value=None):
        """
        Constructor.

        Args:
            shape (tuple): shape of a single data sample;
            alpha (float): learning rate, i.e. the weight given to each new
                sample;
            backend (str): name of the backend to be used;
            init_value (Array, None): initial value of the filter; zeros are
                used when it is not provided.

        """
        assert backend in ["numpy", "torch"]
        self._shape = shape
        self._alpha = alpha
        self._array_backend = ArrayBackend.get_array_backend(backend)
        self.reset(init_value)

        self._add_save_attr(
            _shape='primitive',
            _alpha='primitive',
            _array_backend='pickle',
            _avg_value='primitive',
        )

    def reset(self, init_value=None):
        """
        Bring the filter back to its initial state.

        Args:
            init_value (Array, None): value the average restarts from; when
                ``None`` the average restarts from zeros of shape
                ``(1, *shape)``.

        """
        backend = self._array_backend
        if init_value is not None:
            self._avg_value = backend.atleast_2d(backend.convert(init_value))
        else:
            self._avg_value = backend.zeros(1, *self._shape)

    def update_stats(self, value):
        """
        Fold the given data into the moving average, one sample at a time.

        Args:
            value (Array): current data value to use for the update; samples
                are taken along the first dimension after ``atleast_2d``.

        """
        samples = self._array_backend.atleast_2d(value)
        decay = 1. - self._alpha
        for sample in samples:
            self._avg_value = decay * self._avg_value + self._alpha * sample

    @property
    def mean(self):
        """
        Returns:
            The current value of the moving average (not squeezed).

        """
        return self._avg_value


class RunningAveragedWindow(Serializable):
    """
    Compute the running average using a window of fixed size.

    The window is a bounded deque holding the most recent ``window_size``
    samples. It starts with the initial value (zeros by default), which
    contributes to the mean until newer samples push it out.

    """
    def __init__(self, shape, window_size, backend, init_value=None):
        """
        Constructor.

        Args:
            shape (tuple): shape of a single data sample;
            window_size (int): maximum number of samples kept in the window;
            backend (str): name of the backend to be used;
            init_value (Array, None): initial value of the filter; zeros are
                used when it is not provided.

        """
        assert backend in ["numpy", "torch"]
        self._shape = shape
        self._window_size = window_size
        self._array_backend = ArrayBackend.get_array_backend(backend)
        self.reset(init_value)

        self._add_save_attr(
            _shape='primitive',
            _window_size='primitive',
            _array_backend='pickle',
            _avg_buffer='primitive',
        )

    def reset(self, init_value=None):
        """
        Empty the window and seed it with a single initial sample.

        Args:
            init_value (Array, None): sample to seed the window with; when
                ``None`` a zero sample is used.

        """
        if init_value is None:
            # Iterating a (1, *shape) array yields exactly one zero sample.
            seed = self._array_backend.zeros(1, *self._shape)
        else:
            seed = [self._array_backend.convert(init_value)]
        self._avg_buffer = deque(seed, maxlen=self._window_size)

    def update_stats(self, value):
        """
        Push the given data into the window, oldest samples being dropped.

        Args:
            value (Array): current data value to use for the update; samples
                are taken along the first dimension after ``atleast_2d``.

        """
        # The deque's maxlen discards from the left as new samples arrive.
        self._avg_buffer.extend(self._array_backend.atleast_2d(value))

    @property
    def mean(self):
        """
        Returns:
            The average of the samples currently in the window.

        """
        window = self._array_backend.convert(self._avg_buffer)
        return window.mean(0)