MushroomRL/mushroom-rl
mushroom_rl/core/dataset.py

from collections import defaultdict

import numpy as np
import torch

from mushroom_rl.core.serialization import Serializable
from .array_backend import ArrayBackend

from ._impl import *


class DatasetInfo(Serializable):
    def __init__(self, backend, device, horizon, gamma, state_shape, state_dtype, action_shape, action_dtype,
                 policy_state_shape, n_envs=1):
        assert backend == "torch" or device is None

        self.backend = backend
        self.device = device
        self.horizon = horizon
        self.gamma = gamma
        self.state_shape = state_shape
        self.state_dtype = state_dtype
        self.action_shape = action_shape
        self.action_dtype = action_dtype
        self.policy_state_shape = policy_state_shape
        self.n_envs = n_envs

        super().__init__()

        self._add_save_attr(
            backend='primitive',
            gamma='primitive',
            horizon='primitive',
            state_shape='primitive',
            state_dtype='primitive',
            action_shape='primitive',
            action_dtype='primitive',
            policy_state_shape='primitive',
            n_envs='primitive'
        )

    @property
    def is_agent_stateful(self):
        return self.policy_state_shape is not None

    @staticmethod
    def create_dataset_info(mdp_info, agent_info, n_envs=1, device=None):
        backend = mdp_info.backend
        horizon = mdp_info.horizon
        gamma = mdp_info.gamma
        state_shape = mdp_info.observation_space.shape
        state_dtype = mdp_info.observation_space.data_type
        action_shape = mdp_info.action_space.shape
        action_dtype = mdp_info.action_space.data_type
        policy_state_shape = agent_info.policy_state_shape

        return DatasetInfo(backend, device, horizon, gamma, state_shape, state_dtype,
                           action_shape, action_dtype, policy_state_shape, n_envs)
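
    # Usage sketch (illustrative, not from the library's docs; assumes ``mdp_info`` and
    # ``agent_info`` expose the attributes read above, as they do elsewhere in MushroomRL):
    #
    #     dataset_info = DatasetInfo.create_dataset_info(mdp_info, agent_info, n_envs=8)
    #     dataset = Dataset(dataset_info, n_steps=1000)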

    @staticmethod
    def create_replay_memory_info(mdp_info, agent_info, device=None):
        backend = agent_info.backend
        horizon = mdp_info.horizon
        gamma = mdp_info.gamma
        state_shape = mdp_info.observation_space.shape
        state_dtype = mdp_info.observation_space.data_type  # FIXME: this may cause issues, needs fix
        action_shape = mdp_info.action_space.shape
        action_dtype = mdp_info.action_space.data_type  # FIXME: this may cause issues, needs fix
        policy_state_shape = agent_info.policy_state_shape

        return DatasetInfo(backend, device, horizon, gamma, state_shape, state_dtype,
                           action_shape, action_dtype, policy_state_shape)


class Dataset(Serializable):
    def __init__(self, dataset_info, n_steps=None, n_episodes=None):
        assert (n_steps is not None and n_episodes is None) or (n_steps is None and n_episodes is not None)

        self._array_backend = ArrayBackend.get_array_backend(dataset_info.backend)

        if n_steps is not None:
            n_samples = n_steps
        else:
            horizon = dataset_info.horizon
            assert np.isfinite(horizon)

            n_samples = horizon * n_episodes

        if dataset_info.n_envs == 1:
            base_shape = (n_samples,)
            mask_shape = None
        else:
            base_shape = (n_samples, dataset_info.n_envs)
            mask_shape = base_shape

        state_shape = base_shape + dataset_info.state_shape
        action_shape = base_shape + dataset_info.action_shape
        reward_shape = base_shape

        if dataset_info.is_agent_stateful:
            policy_state_shape = base_shape + dataset_info.policy_state_shape
        else:
            policy_state_shape = None

        self._info = defaultdict(list)
        self._episode_info = defaultdict(list)
        self._theta_list = list()

        if dataset_info.backend == 'numpy':
            self._data = NumpyDataset(dataset_info.state_dtype, state_shape,
                                      dataset_info.action_dtype, action_shape,
                                      reward_shape, base_shape,
                                      policy_state_shape, mask_shape)
        elif dataset_info.backend == 'torch':
            self._data = TorchDataset(dataset_info.state_dtype, state_shape,
                                      dataset_info.action_dtype, action_shape, reward_shape, base_shape,
                                      policy_state_shape, mask_shape, device=dataset_info.device)
        else:
            self._data = ListDataset(policy_state_shape is not None, mask_shape is not None)

        self._dataset_info = dataset_info

        super().__init__()

        self._add_all_save_attr()

    @classmethod
    def generate(cls, mdp_info, agent_info, n_steps=None, n_episodes=None, n_envs=1):
        dataset_info = DatasetInfo.create_dataset_info(mdp_info, agent_info, n_envs)

        return cls(dataset_info, n_steps, n_episodes)

    @classmethod
    def create_raw_instance(cls, dataset=None):
        """
        Creates an empty instance of the Dataset, setting up only the save attributes. If a template dataset is
        provided, its array backend and dataset info are reused.

        Args:
            dataset (Dataset, None): a template dataset to be used to create the new instance.

        Returns:
            A new empty instance of the dataset.

        """
        new_dataset = cls.__new__(cls)

        if dataset is not None:
            new_dataset._array_backend = dataset._array_backend
            new_dataset._dataset_info = dataset._dataset_info
        else:
            new_dataset._dataset_info = None

        new_dataset._info = None
        new_dataset._episode_info = None
        new_dataset._data = None
        new_dataset._theta_list = None

        new_dataset._add_all_save_attr()

        return new_dataset
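
    # Illustrative note: this factory is what get_view() and __add__() below rely on to
    # build a dataset sharing the backend and dataset info of an existing one, e.g.:
    #
    #     empty = Dataset.create_raw_instance(dataset=existing_dataset)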

    @classmethod
    def from_array(cls, states, actions, rewards, next_states, absorbings, lasts,
                   policy_state=None, policy_next_state=None, info=None, episode_info=None, theta_list=None,
                   horizon=None, gamma=0.99, backend='numpy', device=None):
        """
        Creates a dataset of transitions from the provided arrays.

        Args:
            states (array): array of states;
            actions (array): array of actions;
            rewards (array): array of rewards;
            next_states (array): array of next_states;
            absorbings (array): array of absorbing flags;
            lasts (array): array of last flags;
            policy_state (array, None): array of policy internal states;
            policy_next_state (array, None): array of next policy internal states;
            info (dict, None): dictionary of step info;
            episode_info (dict, None): dictionary of episode info;
            theta_list (list, None): list of policy parameters;
            horizon (int, None): horizon of the MDP;
            gamma (float, 0.99): discount factor;
            backend (str, 'numpy'): backend to be used by the dataset;
            device (str, None): device where the dataset is stored when the torch backend is used.

        Returns:
            The dataset of transitions built from the provided arrays.

        """
        assert len(states) == len(actions) == len(rewards) == len(next_states) == len(absorbings) == len(lasts)

        if policy_state is not None:
            assert len(states) == len(policy_state) == len(policy_next_state)

        dataset = cls.create_raw_instance()

        if info is None:
            dataset._info = defaultdict(list)
        else:
            dataset._info = info.copy()

        if episode_info is None:
            dataset._episode_info = defaultdict(list)
        else:
            dataset._episode_info = episode_info.copy()

        if theta_list is None:
            dataset._theta_list = list()
        else:
            dataset._theta_list = theta_list

        dataset._array_backend = ArrayBackend.get_array_backend(backend)
        if backend == 'numpy':
            dataset._data = NumpyDataset.from_array(states, actions, rewards, next_states, absorbings, lasts)
        elif backend == 'torch':
            dataset._data = TorchDataset.from_array(states, actions, rewards, next_states, absorbings, lasts)
        else:
            dataset._data = ListDataset.from_array(states, actions, rewards, next_states, absorbings, lasts)

        state_shape = states.shape[1:]
        action_shape = actions.shape[1:]
        policy_state_shape = None if policy_state is None else policy_state.shape[1:]

        dataset._dataset_info = DatasetInfo(backend, device, horizon, gamma, state_shape, states.dtype,
                                            action_shape, actions.dtype, policy_state_shape)

        return dataset
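
    # Usage sketch (illustrative only; the shapes, dtypes and values below are arbitrary
    # assumptions, not taken from the library):
    #
    #     n = 100
    #     states = np.random.rand(n, 4).astype(np.float32)
    #     actions = np.random.rand(n, 2).astype(np.float32)
    #     rewards = np.random.rand(n)
    #     next_states = np.random.rand(n, 4).astype(np.float32)
    #     absorbings = np.zeros(n, dtype=bool)
    #     lasts = np.zeros(n, dtype=bool)
    #     lasts[-1] = True
    #     dataset = Dataset.from_array(states, actions, rewards, next_states,
    #                                  absorbings, lasts, gamma=0.99, backend='numpy')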

    def append(self, step, info):
        self._data.append(*step)
        self._append_info(self._info, info)

    def append_episode_info(self, info):
        self._append_info(self._episode_info, info)

    def append_theta(self, theta):
        self._theta_list.append(theta)

    def get_info(self, field, index=None):
        if index is None:
            return self._info[field]
        else:
            return self._info[field][index]

    def clear(self):
        self._episode_info = defaultdict(list)
        self._theta_list = list()
        self._info = defaultdict(list)

        self._data.clear()

    def get_view(self, index, copy=False):
        dataset = self.create_raw_instance(dataset=self)

        info_slice = defaultdict(list)
        for key in self._info.keys():
            info_slice[key] = self._info[key][index]

        dataset._info = info_slice
        dataset._episode_info = defaultdict(list)
        dataset._data = self._data.get_view(index, copy)

        return dataset

    def item(self):
        assert len(self) == 1
        return self[0]

    def __getitem__(self, index):
        if isinstance(index, (slice, np.ndarray, torch.Tensor)):
            return self.get_view(index)
        elif isinstance(index, int) and index < len(self._data):
            return self._data[index]
        else:
            raise IndexError

    def __add__(self, other):
        result = self.create_raw_instance(dataset=self)
        new_info = self._merge_info(self.info, other.info)
        new_episode_info = self._merge_info(self.episode_info, other.episode_info)

        result._info = new_info
        result._episode_info = new_episode_info
        result._theta_list = self._theta_list + other._theta_list
        result._data = self._data + other._data

        return result

    def __len__(self):
        return len(self._data)

    @property
    def state(self):
        return self._data.state

    @property
    def action(self):
        return self._data.action

    @property
    def reward(self):
        return self._data.reward

    @property
    def next_state(self):
        return self._data.next_state

    @property
    def absorbing(self):
        return self._data.absorbing

    @property
    def last(self):
        return self._data.last

    @property
    def policy_state(self):
        return self._data.policy_state

    @property
    def policy_next_state(self):
        return self._data.policy_next_state

    @property
    def info(self):
        return self._info

    @property
    def episode_info(self):
        return self._episode_info

    @property
    def theta_list(self):
        return self._theta_list

    @property
    def episodes_length(self):
        """
        Compute the length of each episode in the dataset.

        Returns:
            An array containing the length of each episode in the dataset.

        """
        lengths = list()
        l = 0
        for sample in self:
            l += 1
            if sample[-1] == 1:
                lengths.append(l)
                l = 0

        return self._array_backend.from_list(lengths)
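
    # Usage sketch (illustrative):
    #
    #     lengths = dataset.episodes_length  # one entry per episode terminated by a last flag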

    @property
    def n_episodes(self):
        return self._data.n_episodes

    @property
    def undiscounted_return(self):
        return self.compute_J()

    @property
    def discounted_return(self):
        return self.compute_J(self._dataset_info.gamma)

    @property
    def array_backend(self):
        return self._array_backend

    @property
    def is_stateful(self):
        return self._data.is_stateful

    def parse(self, to=None):
        """
        Return the dataset as a set of arrays.

        Args:
            to (str, None): the backend to be used for the returned arrays. By default, the dataset backend is used.

        Returns:
            A tuple containing the arrays that define the dataset, i.e. state, action, reward, next state, absorbing
            and last.

        """
        if to is None:
            to = self._array_backend.get_backend_name()
        return self._convert(self.state, self.action, self.reward, self.next_state, self.absorbing, self.last, to=to)
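
    # Usage sketch (illustrative; assumes ``dataset`` is a filled Dataset):
    #
    #     states, actions, rewards, next_states, absorbings, lasts = dataset.parse(to='numpy')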

    def parse_policy_state(self, to=None):
        """
        Return the policy states of the dataset as a set of arrays.

        Args:
            to (str, None): the backend to be used for the returned arrays. By default, the dataset backend is used.

        Returns:
            A tuple containing the policy state and policy next state arrays.

        """
        if to is None:
            to = self._array_backend.get_backend_name()
        return self._convert(self.policy_state, self.policy_next_state, to=to)
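
    # Usage sketch (illustrative; only meaningful for stateful policies):
    #
    #     policy_states, policy_next_states = dataset.parse_policy_state(to='torch')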

    def select_first_episodes(self, n_episodes):
        """
        Return the first ``n_episodes`` episodes of the dataset.

        Args:
            n_episodes (int): the number of episodes to pick from the dataset.

        Returns:
            A subset of the dataset containing the first ``n_episodes`` episodes.

        """
        assert n_episodes > 0, 'Number of episodes must be greater than zero.'

        last_idxs = np.argwhere(self.last).ravel()
        return self[:last_idxs[n_episodes - 1] + 1]
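
    # Usage sketch (illustrative; assumes the dataset holds at least 5 complete episodes):
    #
    #     first_five_episodes = dataset.select_first_episodes(5)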

    def select_random_samples(self, n_samples):
        """
        Return ``n_samples`` transitions picked uniformly at random from the dataset.

        Args:
            n_samples (int): the number of samples to pick from the dataset.

        Returns:
            A subset of the dataset containing ``n_samples`` randomly picked
            samples.

        """
        assert n_samples >= 0, 'Number of samples must be greater than or equal to zero.'

        if n_samples == 0:
            return np.array([[]])

        idxs = np.random.randint(len(self), size=n_samples)

        return self[idxs]
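
    # Usage sketch (illustrative): draw a random minibatch of 32 transitions.
    #
    #     batch = dataset.select_random_samples(32)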

    def get_init_states(self):
        """
        Get the initial state of each episode in the dataset.

        Returns:
            An array of initial states of the considered dataset.

        """
        pick = True
        x_0 = list()
        for step in self:
            if pick:
                x_0.append(step[0])
            pick = step[-1]
        return self._array_backend.from_list(x_0)
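
    # Usage sketch (illustrative): collect the first state of every episode, e.g. to
    # inspect the initial state distribution.
    #
    #     initial_states = dataset.get_init_states()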

    def compute_J(self, gamma=1.):
        """
        Compute the cumulative discounted reward of each episode in the dataset.

        Args:
            gamma (float, 1.): discount factor.

        Returns:
            The cumulative discounted reward of each episode in the dataset.

        """
        js = list()

        j = 0.
        episode_steps = 0
        for i in range(len(self)):
            j += gamma ** episode_steps * self.reward[i]
            episode_steps += 1
            if self.last[i] or i == len(self) - 1:
                js.append(j)
                j = 0.
                episode_steps = 0

        if len(js) == 0:
            js = [0.]

        return self._array_backend.from_list(js)
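
    # The loop above accumulates, for each episode, the return J = sum_t gamma^t * r_t,
    # resetting the exponent at every episode boundary. Usage sketch (illustrative):
    #
    #     discounted_returns = dataset.compute_J(gamma=0.99)
    #     undiscounted_returns = dataset.compute_J()  # gamma defaults to 1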

    def compute_metrics(self, gamma=1.):
        """
        Compute the metrics of each complete episode in the dataset.

        Args:
            gamma (float, 1.): the discount factor.

        Returns:
            The minimum score reached in an episode,
            the maximum score reached in an episode,
            the mean score reached,
            the median score reached,
            the number of completed episodes.

            If no episode has been completed, it returns 0 for all values.

        """
        i = 0
        for i in reversed(range(len(self))):
            if self.last[i]:
                i += 1
                break

        dataset = self[:i]

        if len(dataset) > 0:
            J = dataset.compute_J(gamma)
            median = self._array_backend.median(J)
            return J.min(), J.max(), J.mean(), median, len(J)
        else:
            return 0, 0, 0, 0, 0
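
    # Usage sketch (illustrative; only complete episodes contribute to the statistics):
    #
    #     min_J, max_J, mean_J, median_J, n_completed = dataset.compute_metrics(gamma=0.99)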

    def _convert(self, *arrays, to='numpy'):
        if to == 'numpy':
            return self._array_backend.arrays_to_numpy(*arrays)
        elif to == 'torch':
            return self._array_backend.arrays_to_torch(*arrays)
        else:
            raise NotImplementedError

    def _add_all_save_attr(self):
        self._add_save_attr(
            _info='pickle',
            _episode_info='pickle',
            _theta_list='pickle',
            _data='mushroom',
            _array_backend='primitive',
            _dataset_info='mushroom'
        )

    @staticmethod
    def _append_info(info, step_info):
        for key, value in step_info.items():
            info[key].append(value)

    @staticmethod
    def _merge_info(info, other_info):
        new_info = defaultdict(list)
        for key in info.keys():
            new_info[key] = info[key] + other_info[key]
        return new_info


class VectorizedDataset(Dataset):
    def __init__(self, dataset_info, n_steps=None, n_episodes=None):
        super().__init__(dataset_info, n_steps, n_episodes)

        self._initialize_theta_list(self._dataset_info.n_envs)

    def append(self, step, info):
        raise RuntimeError("Trying to use append on a vectorized dataset")

    def append_vectorized(self, step, info, mask):
        self._data.append(*step, mask=mask)
        self._append_info(self._info, {})  # FIXME: handle info properly

    def append_theta_vectorized(self, theta, mask):
        for i in range(len(theta)):
            if mask[i]:
                self._theta_list[i].append(theta[i])

    def clear(self, n_steps_per_fit=None):
        n_envs = len(self._theta_list)

        residual_data = None
        if n_steps_per_fit is not None:
            n_steps_dataset = self._data.mask.sum().item()

            if n_steps_dataset > n_steps_per_fit:
                n_extra_steps = n_steps_dataset - n_steps_per_fit
                n_parallel_steps = int(np.ceil(n_extra_steps / self._dataset_info.n_envs))
                view_size = slice(-n_parallel_steps, None)
                residual_data = self._data.get_view(view_size, copy=True)
                # Invalidate the mask entries beyond the residual steps. reshape(-1) is used
                # instead of flatten(), whose copy semantics would silently drop the in-place
                # assignment on the numpy backend.
                mask = residual_data.mask
                original_shape = mask.shape
                flat_mask = mask.reshape(-1)
                flat_mask[n_extra_steps:] = False
                residual_data.mask = flat_mask.reshape(original_shape)

        super().clear()
        self._initialize_theta_list(n_envs)

        if n_steps_per_fit is not None and residual_data is not None:
            self._data = residual_data

    def flatten(self, n_steps_per_fit=None):
        if len(self) == 0:
            return None

        states = self._array_backend.pack_padded_sequence(self._data.state, self._data.mask)
        actions = self._array_backend.pack_padded_sequence(self._data.action, self._data.mask)
        rewards = self._array_backend.pack_padded_sequence(self._data.reward, self._data.mask)
        next_states = self._array_backend.pack_padded_sequence(self._data.next_state, self._data.mask)
        absorbings = self._array_backend.pack_padded_sequence(self._data.absorbing, self._data.mask)

        last_padded = self._data.last
        last_padded[-1, :] = True
        lasts = self._array_backend.pack_padded_sequence(last_padded, self._data.mask)

        policy_state = None
        policy_next_state = None

        if self._data.is_stateful:
            policy_state = self._array_backend.pack_padded_sequence(self._data.policy_state, self._data.mask)
            policy_next_state = self._array_backend.pack_padded_sequence(self._data.policy_next_state, self._data.mask)

        if n_steps_per_fit is not None:
            states = states[:n_steps_per_fit]
            actions = actions[:n_steps_per_fit]
            rewards = rewards[:n_steps_per_fit]
            next_states = next_states[:n_steps_per_fit]
            absorbings = absorbings[:n_steps_per_fit]
            lasts = lasts[:n_steps_per_fit]

            if self._data.is_stateful:
                policy_state = policy_state[:n_steps_per_fit]
                policy_next_state = policy_next_state[:n_steps_per_fit]

        flat_theta_list = self._flatten_theta_list()

        return Dataset.from_array(states, actions, rewards, next_states, absorbings, lasts,
                                  policy_state=policy_state, policy_next_state=policy_next_state,
                                  info=None, episode_info=None, theta_list=flat_theta_list,  # FIXME: handle info properly
                                  horizon=self._dataset_info.horizon, gamma=self._dataset_info.gamma,
                                  backend=self._array_backend.get_backend_name())
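
    # Usage sketch (illustrative; assumes ``vec_dataset`` was filled from a vectorized
    # environment): flatten() packs the per-environment padded sequences into a single
    # plain Dataset that the rest of the library can consume.
    #
    #     flat_dataset = vec_dataset.flatten()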

    def _flatten_theta_list(self):
        flat_theta_list = list()

        for env_theta_list in self._theta_list:
            flat_theta_list += env_theta_list

        return flat_theta_list

    def _initialize_theta_list(self, n_envs):
        self._theta_list = list()
        for i in range(n_envs):
            self._theta_list.append(list())

    @property
    def mask(self):
        return self._data.mask