
# NOTE: page residue from the web scrape ("View on GitHub", "2 days", "Test Coverage"); not part of the module.
import numpy as np

from collections import defaultdict

import torch

from mushroom_rl.core.serialization import Serializable
from .array_backend import ArrayBackend

from ._impl import *

class DatasetInfo(Serializable):
    """
    Metadata describing how a dataset must be allocated and interpreted: array backend and device, horizon and
    discount factor of the environment, shapes and data types of states and actions, shape of the policy internal
    state (if any) and number of parallel environments.

    """
    def __init__(self, backend, device, horizon, gamma, state_shape, state_dtype, action_shape, action_dtype,
                 policy_state_shape, n_envs=1):
        """
        Constructor.

        Args:
            backend (str): name of the array backend used by the dataset;
            device (str, None): device of the arrays; it can be set only with the "torch" backend;
            horizon (int): horizon of the environment;
            gamma (float): discount factor;
            state_shape (tuple): shape of a single state;
            state_dtype: data type of the states;
            action_shape (tuple): shape of a single action;
            action_dtype: data type of the actions;
            policy_state_shape (tuple, None): shape of the policy internal state, None if the agent is stateless;
            n_envs (int, 1): number of parallel environments.

        """
        assert backend == "torch" or device is None

        self.backend = backend
        self.device = device
        self.horizon = horizon
        self.gamma = gamma
        self.state_shape = state_shape
        self.state_dtype = state_dtype
        self.action_shape = action_shape
        self.action_dtype = action_dtype
        self.policy_state_shape = policy_state_shape
        self.n_envs = n_envs

        super().__init__()

        # NOTE(review): registration reconstructed from the Serializable convention (attribute name -> save
        # method); confirm the attribute list and methods against the Serializable implementation.
        self._add_save_attr(
            backend='primitive',
            device='primitive',
            horizon='primitive',
            gamma='primitive',
            state_shape='primitive',
            state_dtype='primitive',
            action_shape='primitive',
            action_dtype='primitive',
            policy_state_shape='primitive',
            n_envs='primitive'
        )

    @property
    def is_agent_stateful(self):
        """
        Returns:
            True if a policy state shape is defined, False otherwise.

        """
        return self.policy_state_shape is not None

    @staticmethod
    def create_dataset_info(mdp_info, agent_info, n_envs=1, device=None):
        """
        Build the dataset info of a dataset collected from the environment. The backend is the one of the
        environment.

        Args:
            mdp_info: environment information, providing backend, horizon, gamma and the observation/action spaces;
            agent_info: agent information, providing the policy state shape;
            n_envs (int, 1): number of parallel environments;
            device (str, None): device of the arrays.

        Returns:
            The dataset info.

        """
        backend = mdp_info.backend
        horizon = mdp_info.horizon
        gamma = mdp_info.gamma
        state_shape = mdp_info.observation_space.shape
        state_dtype = mdp_info.observation_space.data_type
        action_shape = mdp_info.action_space.shape
        action_dtype = mdp_info.action_space.data_type
        policy_state_shape = agent_info.policy_state_shape

        return DatasetInfo(backend, device, horizon, gamma, state_shape, state_dtype,
                           action_shape, action_dtype, policy_state_shape, n_envs)

    @staticmethod
    def create_replay_memory_info(mdp_info, agent_info, device=None):
        """
        Build the dataset info of a replay memory. Unlike ``create_dataset_info``, the backend is the one of the
        agent and the number of environments is left to its default.

        Args:
            mdp_info: environment information, providing horizon, gamma and the observation/action spaces;
            agent_info: agent information, providing backend and policy state shape;
            device (str, None): device of the arrays.

        Returns:
            The dataset info.

        """
        backend = agent_info.backend
        horizon = mdp_info.horizon
        gamma = mdp_info.gamma
        state_shape = mdp_info.observation_space.shape
        state_dtype = mdp_info.observation_space.data_type  # FIXME: this may cause issues, needs fix
        action_shape = mdp_info.action_space.shape
        action_dtype = mdp_info.action_space.data_type  # FIXME: this may cause issues, needs fix
        policy_state_shape = agent_info.policy_state_shape

        return DatasetInfo(backend, device, horizon, gamma, state_shape, state_dtype,
                           action_shape, action_dtype, policy_state_shape)

class Dataset(Serializable):
    """
    Collection of transitions (state, action, reward, next state, absorbing flag, last flag and, optionally, policy
    states) together with per-step info, per-episode info and a list of policy parameters.

    """
    def __init__(self, dataset_info, n_steps=None, n_episodes=None):
        """
        Constructor. Exactly one of ``n_steps`` and ``n_episodes`` must be provided.

        Args:
            dataset_info (DatasetInfo): the metadata describing the dataset;
            n_steps (int, None): number of steps the dataset is sized for;
            n_episodes (int, None): number of episodes the dataset is sized for; it requires a finite horizon.

        """
        assert (n_steps is not None and n_episodes is None) or (n_steps is None and n_episodes is not None)

        self._array_backend = ArrayBackend.get_array_backend(dataset_info.backend)

        if n_steps is not None:
            n_samples = n_steps
        else:
            # The size can be derived from the episode count only when every episode is bounded.
            horizon = dataset_info.horizon
            assert np.isfinite(horizon)

            n_samples = horizon * n_episodes

        if dataset_info.n_envs == 1:
            base_shape = (n_samples,)
            mask_shape = None
        else:
            base_shape = (n_samples, dataset_info.n_envs)
            mask_shape = base_shape

        state_shape = base_shape + dataset_info.state_shape
        action_shape = base_shape + dataset_info.action_shape
        reward_shape = base_shape

        if dataset_info.is_agent_stateful:
            policy_state_shape = base_shape + dataset_info.policy_state_shape
        else:
            policy_state_shape = None

        self._info = defaultdict(list)
        self._episode_info = defaultdict(list)
        self._theta_list = list()

        if dataset_info.backend == 'numpy':
            self._data = NumpyDataset(dataset_info.state_dtype, state_shape,
                                      dataset_info.action_dtype, action_shape,
                                      reward_shape, base_shape,
                                      policy_state_shape, mask_shape)
        elif dataset_info.backend == 'torch':
            self._data = TorchDataset(dataset_info.state_dtype, state_shape,
                                      dataset_info.action_dtype, action_shape, reward_shape, base_shape,
                                      policy_state_shape, mask_shape, device=dataset_info.device)
        else:
            self._data = ListDataset(policy_state_shape is not None, mask_shape is not None)

        self._dataset_info = dataset_info

        super().__init__()

        self._add_all_save_attr()

    @classmethod
    def generate(cls, mdp_info, agent_info, n_steps=None, n_episodes=None, n_envs=1):
        """
        Creates an empty dataset from the environment and agent information.

        Args:
            mdp_info: the environment information;
            agent_info: the agent information;
            n_steps (int, None): number of steps the dataset is sized for;
            n_episodes (int, None): number of episodes the dataset is sized for;
            n_envs (int, 1): number of parallel environments.

        Returns:
            A new instance of the dataset.

        """
        dataset_info = DatasetInfo.create_dataset_info(mdp_info, agent_info, n_envs)

        return cls(dataset_info, n_steps, n_episodes)

    @classmethod
    def create_raw_instance(cls, dataset=None):
        """
        Creates an empty instance of the Dataset and populates essential data structures

        Args:
            dataset (Dataset, None): a template dataset to be used to create the new instance.

        Returns:
            A new empty instance of the dataset.

        """
        new_dataset = cls.__new__(cls)

        if dataset is not None:
            new_dataset._array_backend = dataset._array_backend
            new_dataset._dataset_info = dataset._dataset_info
        else:
            # Defined explicitly so that the attribute always exists on a raw instance.
            new_dataset._array_backend = None
            new_dataset._dataset_info = None

        new_dataset._info = None
        new_dataset._episode_info = None
        new_dataset._data = None
        new_dataset._theta_list = None

        # __init__ is bypassed, so the serializable attributes are registered here.
        new_dataset._add_all_save_attr()

        return new_dataset

    @classmethod
    def from_array(cls, states, actions, rewards, next_states, absorbings, lasts,
                   policy_state=None, policy_next_state=None, info=None, episode_info=None, theta_list=None,
                   horizon=None, gamma=0.99, backend='numpy', device=None):
        """
        Creates a dataset of transitions from the provided arrays.

        Args:
            states (array): array of states;
            actions (array): array of actions;
            rewards (array): array of rewards;
            next_states (array): array of next_states;
            absorbings (array): array of absorbing flags;
            lasts (array): array of last flags;
            policy_state (array, None): array of policy internal states;
            policy_next_state (array, None): array of next policy internal states;
            info (dict, None): dictionary of step info;
            episode_info (dict, None): dictionary of episode info;
            theta_list (list, None): list of policy parameters;
            horizon (int, None): horizon of the mdp;
            gamma (float, 0.99): discount factor;
            backend (str, 'numpy'): backend to be used by the dataset;
            device (str, None): device stored in the dataset info.

        Returns:
            A dataset containing the provided transitions.

        """
        assert len(states) == len(actions) == len(rewards) == len(next_states) == len(absorbings) == len(lasts)

        if policy_state is not None:
            assert len(states) == len(policy_state) == len(policy_next_state)

        dataset = cls.create_raw_instance()

        if info is None:
            dataset._info = defaultdict(list)
        else:
            dataset._info = info.copy()

        if episode_info is None:
            dataset._episode_info = defaultdict(list)
        else:
            dataset._episode_info = episode_info.copy()

        if theta_list is None:
            dataset._theta_list = list()
        else:
            dataset._theta_list = theta_list

        dataset._array_backend = ArrayBackend.get_array_backend(backend)
        # NOTE(review): policy_state / policy_next_state are validated above but not forwarded to the data
        # containers; confirm whether the containers' from_array is expected to receive them.
        if backend == 'numpy':
            dataset._data = NumpyDataset.from_array(states, actions, rewards, next_states, absorbings, lasts)
        elif backend == 'torch':
            dataset._data = TorchDataset.from_array(states, actions, rewards, next_states, absorbings, lasts)
        else:
            dataset._data = ListDataset.from_array(states, actions, rewards, next_states, absorbings, lasts)

        state_shape = states.shape[1:]
        action_shape = actions.shape[1:]
        policy_state_shape = None if policy_state is None else policy_state.shape[1:]

        dataset._dataset_info = DatasetInfo(backend, device, horizon, gamma, state_shape, states.dtype,
                                            action_shape, actions.dtype, policy_state_shape)

        return dataset

    def append(self, step, info):
        """
        Append a transition and its step info to the dataset.

        Args:
            step (tuple): the fields of the transition, unpacked into the data container;
            info (dict): the step info, appended key by key.

        """
        self._data.append(*step)
        self._append_info(self._info, info)

    def append_episode_info(self, info):
        """
        Append the info of an episode.

        Args:
            info (dict): the episode info, appended key by key.

        """
        self._append_info(self._episode_info, info)

    def append_theta(self, theta):
        """
        Append policy parameters to the list of parameters.

        Args:
            theta: the policy parameters to store.

        """
        self._theta_list.append(theta)

    def get_info(self, field, index=None):
        """
        Args:
            field (str): the key of the step info to retrieve;
            index (None): the index of the entry to retrieve; if None, all the entries are returned.

        Returns:
            The step info stored under ``field``, possibly indexed by ``index``.

        """
        if index is None:
            return self._info[field]
        else:
            return self._info[field][index]

    def clear(self):
        """
        Remove all the stored info, policy parameters and transitions.

        """
        self._episode_info = defaultdict(list)
        self._theta_list = list()
        self._info = defaultdict(list)

        self._data.clear()

    def get_view(self, index, copy=False):
        """
        Args:
            index: the index (e.g. slice or array of indices) selecting the transitions;
            copy (bool, False): forwarded to the data container view.

        Returns:
            A new dataset holding the selected transitions and the matching step info. The episode info of the
            returned dataset is empty.

        """
        dataset = self.create_raw_instance(dataset=self)

        info_slice = defaultdict(list)
        for key in self._info.keys():
            info_slice[key] = self._info[key][index]

        dataset._info = info_slice
        dataset._episode_info = defaultdict(list)
        dataset._data = self._data.get_view(index, copy)

        return dataset

    def item(self):
        """
        Returns:
            The only transition of a dataset of length one.

        """
        assert len(self) == 1
        return self[0]

    def __getitem__(self, index):
        # Slices and index arrays produce a dataset view, an in-range integer produces a single transition.
        # Raising IndexError on anything else is what terminates ``for sample in dataset`` loops.
        if isinstance(index, (slice, np.ndarray, torch.Tensor)):
            return self.get_view(index)
        elif isinstance(index, int) and index < len(self._data):
            return self._data[index]
        else:
            raise IndexError

    def __add__(self, other):
        result = self.create_raw_instance(dataset=self)
        new_info = self._merge_info(self.info, other.info)
        new_episode_info = self._merge_info(self.episode_info, other.episode_info)

        result._info = new_info
        result._episode_info = new_episode_info
        result._theta_list = self._theta_list + other._theta_list
        result._data = self._data + other._data

        return result

    def __len__(self):
        return len(self._data)

    @property
    def state(self):
        return self._data.state

    @property
    def action(self):
        return self._data.action

    @property
    def reward(self):
        return self._data.reward

    @property
    def next_state(self):
        return self._data.next_state

    @property
    def absorbing(self):
        return self._data.absorbing

    @property
    def last(self):
        return self._data.last

    @property
    def policy_state(self):
        return self._data.policy_state

    @property
    def policy_next_state(self):
        return self._data.policy_next_state

    @property
    def info(self):
        return self._info

    @property
    def episode_info(self):
        return self._episode_info

    @property
    def theta_list(self):
        return self._theta_list

    @property
    def episodes_length(self):
        """
        Compute the length of each episode in the dataset.

        Returns:
            A list of length of each episode in the dataset.

        """
        lengths = list()
        l = 0
        for sample in self:
            l += 1
            # The last element of a transition is its last flag.
            if sample[-1] == 1:
                lengths.append(l)
                l = 0

        return self._array_backend.from_list(lengths)

    @property
    def n_episodes(self):
        return self._data.n_episodes

    @property
    def undiscounted_return(self):
        return self.compute_J()

    @property
    def discounted_return(self):
        return self.compute_J(self._dataset_info.gamma)

    @property
    def array_backend(self):
        return self._array_backend

    @property
    def is_stateful(self):
        return self._data.is_stateful

    def parse(self, to=None):
        """
        Return the dataset as set of arrays.

        Args:
            to (str, None):  the backend to be used for the returned arrays. By default, the dataset backend is used.

        Returns:
            A tuple containing the arrays that define the dataset, i.e. state, action, reward, next state,
            absorbing and last

        """
        if to is None:
            to = self._array_backend.get_backend_name()
        return self._convert(self.state, self.action, self.reward, self.next_state, self.absorbing, self.last, to=to)

    def parse_policy_state(self, to=None):
        """
        Return the policy states stored in the dataset as set of arrays.

        Args:
            to (str, None):  the backend to be used for the returned arrays. By default, the dataset backend is used.

        Returns:
            A tuple containing the policy state and the policy next state arrays.

        """
        if to is None:
            to = self._array_backend.get_backend_name()
        return self._convert(self.policy_state, self.policy_next_state, to=to)

    def select_first_episodes(self, n_episodes):
        """
        Return the first ``n_episodes`` episodes in the provided dataset.

        Args:
            n_episodes (int): the number of episodes to pick from the dataset;

        Returns:
            A subset of the dataset containing the first ``n_episodes`` episodes.

        """
        assert n_episodes > 0, 'Number of episodes must be greater than zero.'

        last_idxs = np.argwhere(self.last).ravel()
        return self[:last_idxs[n_episodes - 1] + 1]

    def select_random_samples(self, n_samples):
        """
        Return the randomly picked desired number of samples in the provided
        dataset.

        Args:
            n_samples (int): the number of samples to pick from the dataset.

        Returns:
            A subset of the dataset containing randomly picked ``n_samples``
            samples. If ``n_samples`` is zero, an empty numpy array is returned.

        """
        assert n_samples >= 0, 'Number of samples must be greater than or equal to zero.'

        if n_samples == 0:
            return np.array([[]])

        idxs = np.random.randint(len(self), size=n_samples)

        return self[idxs]

    def get_init_states(self):
        """
        Get the initial states of a dataset

        Returns:
            An array of initial states of the considered dataset.

        """
        pick = True
        x_0 = list()
        for step in self:
            if pick:
                x_0.append(step[0])
            # The step following a last step is the first one of a new episode.
            pick = step[-1]
        return self._array_backend.from_list(x_0)

    def compute_J(self, gamma=1.):
        """
        Compute the cumulative discounted reward of each episode in the dataset.

        Args:
            gamma (float, 1.): discount factor.

        Returns:
            The cumulative discounted reward of each episode in the dataset.

        """
        js = list()

        rewards = self.reward
        lasts = self.last
        n_steps = len(self)

        j = 0.
        episode_steps = 0
        for i in range(n_steps):
            j += gamma ** episode_steps * rewards[i]
            episode_steps += 1
            # A trailing, unfinished episode is closed at the end of the dataset.
            if lasts[i] or i == n_steps - 1:
                js.append(j)
                j = 0.
                episode_steps = 0

        if len(js) == 0:
            js = [0.]

        return self._array_backend.from_list(js)

    def compute_metrics(self, gamma=1.):
        """
        Compute the metrics of each complete episode in the dataset.

        Args:
            gamma (float, 1.): the discount factor.

        Returns:
            The minimum score reached in an episode,
            the maximum score reached in an episode,
            the mean score reached,
            the median score reached,
            the number of completed episodes.

            If no episode has been completed, it returns 0 for all values.

        """
        # Only the steps up to (and including) the most recent last flag belong to complete episodes.
        lasts = self.last
        n_complete_steps = 0
        for i in reversed(range(len(self))):
            if lasts[i]:
                n_complete_steps = i + 1
                break

        dataset = self[:n_complete_steps]

        if len(dataset) > 0:
            J = dataset.compute_J(gamma)
            median = self._array_backend.median(J)
            return J.min(), J.max(), J.mean(), median, len(J)
        else:
            return 0, 0, 0, 0, 0

    def _convert(self, *arrays, to='numpy'):
        if to == 'numpy':
            return self._array_backend.arrays_to_numpy(*arrays)
        elif to == 'torch':
            return self._array_backend.arrays_to_torch(*arrays)
        else:
            # Raise (rather than return the exception class) so that an unsupported target fails loudly.
            raise NotImplementedError(f"Conversion to backend '{to}' is not supported")

    def _add_all_save_attr(self):
        # NOTE(review): reconstructed from the Serializable convention (attribute name -> save method); confirm
        # the chosen methods against the Serializable implementation.
        self._add_save_attr(
            _info='pickle',
            _episode_info='pickle',
            _theta_list='pickle',
            _data='mushroom',
            _array_backend='primitive',
            _dataset_info='mushroom'
        )

    @staticmethod
    def _append_info(info, step_info):
        # Append every value of step_info to the list stored under the same key.
        for key, value in step_info.items():
            info[key].append(value)

    @staticmethod
    def _merge_info(info, other_info):
        # Concatenate, key by key, the entries of the two dictionaries; only the keys of ``info`` are considered.
        new_info = defaultdict(list)
        for key in info.keys():
            new_info[key] = info[key] + other_info[key]
        return new_info

class VectorizedDataset(Dataset):
    """
    Dataset collecting transitions from several environments in parallel. Every stored step has one entry per
    environment and a mask selecting which entries are valid; policy parameters are kept in one list per
    environment.

    """
    def __init__(self, dataset_info, n_steps=None, n_episodes=None):
        """
        Constructor.

        Args:
            dataset_info (DatasetInfo): the metadata describing the dataset;
            n_steps (int, None): number of steps the dataset is sized for;
            n_episodes (int, None): number of episodes the dataset is sized for.

        """
        super().__init__(dataset_info, n_steps, n_episodes)

        self._initialize_theta_list(self._dataset_info.n_envs)

    def append(self, step, info):
        raise RuntimeError("Trying to use append on a vectorized dataset")

    def append_vectorized(self, step, info, mask):
        """
        Append one step of every environment.

        Args:
            step (tuple): the fields of the transitions, unpacked into the data container;
            info (dict): the step info (currently not stored);
            mask (array): mask forwarded to the data container together with the step.

        """
        self._data.append(*step, mask=mask)
        self._append_info(self._info, {})  # FIXME: handle properly info

    def append_theta_vectorized(self, theta, mask):
        """
        Append the policy parameters of the environments selected by the mask.

        Args:
            theta (list): policy parameters, one entry per environment;
            mask (array): flags selecting the environments whose parameters are stored.

        """
        for i in range(len(theta)):
            if mask[i]:
                self._theta_list[i].append(theta[i])

    def clear(self, n_steps_per_fit=None):
        """
        Clear the dataset. If more than ``n_steps_per_fit`` valid steps are stored, the steps in excess are kept
        as the new content of the dataset.

        Args:
            n_steps_per_fit (int, None): number of steps consumed by a fit; if None, everything is discarded.

        """
        n_envs = len(self._theta_list)

        residual_data = None
        if n_steps_per_fit is not None:
            n_steps_dataset = self._data.mask.sum().item()

            if n_steps_dataset > n_steps_per_fit:
                n_extra_steps = n_steps_dataset - n_steps_per_fit
                n_parallel_steps = int(np.ceil(n_extra_steps / self._dataset_info.n_envs))
                view_size = slice(-n_parallel_steps, None)
                residual_data = self._data.get_view(view_size, copy=True)
                original_shape = residual_data.mask.shape
                # flatten() may return a copy (it always does for numpy arrays), so the flat mask is edited
                # through a named variable and then written back, instead of assigning into a temporary.
                flat_mask = residual_data.mask.flatten()
                flat_mask[n_extra_steps:] = False
                residual_data.mask = flat_mask.reshape(original_shape)

        super().clear()
        self._initialize_theta_list(n_envs)

        if n_steps_per_fit is not None and residual_data is not None:
            self._data = residual_data

    def flatten(self, n_steps_per_fit=None):
        """
        Convert the vectorized dataset into a plain dataset containing only the steps selected by the mask.

        Args:
            n_steps_per_fit (int, None): if provided, only the first ``n_steps_per_fit`` packed steps are kept.

        Returns:
            The flattened dataset, or None if the dataset is empty.

        """
        if len(self) == 0:
            return None

        states = self._array_backend.pack_padded_sequence(self._data.state, self._data.mask)
        actions = self._array_backend.pack_padded_sequence(self._data.action, self._data.mask)
        rewards = self._array_backend.pack_padded_sequence(self._data.reward, self._data.mask)
        next_states = self._array_backend.pack_padded_sequence(self._data.next_state, self._data.mask)
        absorbings = self._array_backend.pack_padded_sequence(self._data.absorbing, self._data.mask)

        # The final step of every environment is flagged as last before packing.
        # NOTE(review): this writes into whatever object self._data.last returns — confirm that modifying it
        # in place is intended.
        last_padded = self._data.last
        last_padded[-1, :] = True
        lasts = self._array_backend.pack_padded_sequence(last_padded, self._data.mask)

        policy_state = None
        policy_next_state = None

        if self._data.is_stateful:
            policy_state = self._array_backend.pack_padded_sequence(self._data.policy_state, self._data.mask)
            policy_next_state = self._array_backend.pack_padded_sequence(self._data.policy_next_state, self._data.mask)

        if n_steps_per_fit is not None:
            states = states[:n_steps_per_fit]
            actions = actions[:n_steps_per_fit]
            rewards = rewards[:n_steps_per_fit]
            next_states = next_states[:n_steps_per_fit]
            absorbings = absorbings[:n_steps_per_fit]
            lasts = lasts[:n_steps_per_fit]

            if self._data.is_stateful:
                policy_state = policy_state[:n_steps_per_fit]
                policy_next_state = policy_next_state[:n_steps_per_fit]

        flat_theta_list = self._flatten_theta_list()

        return Dataset.from_array(states, actions, rewards, next_states, absorbings, lasts,
                                  policy_state=policy_state, policy_next_state=policy_next_state,
                                  info=None, episode_info=None, theta_list=flat_theta_list,  # FIXME: handle properly info
                                  horizon=self._dataset_info.horizon, gamma=self._dataset_info.gamma,
                                  backend=self._array_backend.get_backend_name())

    def _flatten_theta_list(self):
        # Concatenate the per-environment lists of policy parameters into a single list.
        flat_theta_list = list()

        for env_theta_list in self._theta_list:
            flat_theta_list += env_theta_list

        return flat_theta_list

    def _initialize_theta_list(self, n_envs):
        # One independent list of policy parameters per environment.
        self._theta_list = list()
        for i in range(n_envs):
            self._theta_list.append(list())

    @property
    def mask(self):
        return self._data.mask