MushroomRL/mushroom-rl
mushroom_rl/core/_impl/torch_dataset.py

import torch

from mushroom_rl.core.serialization import Serializable
from mushroom_rl.utils.torch import TorchUtils


class TorchDataset(Serializable):
    def __init__(self, state_type, state_shape, action_type, action_shape, reward_shape, flag_shape,
                 policy_state_shape, mask_shape, device=None):
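        """
        Constructor.

        Pre-allocates the torch buffers holding the dataset on the selected device. Every
        ``*_shape`` argument is the full buffer shape, i.e. its first dimension is the dataset
        capacity, while ``self._len`` tracks how many transitions have been filled so far.
        """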

        self._device = TorchUtils.get_device(device)
        self._state_type = state_type
        self._action_type = action_type

        self._states = torch.empty(*state_shape, dtype=self._state_type, device=self._device)
        self._actions = torch.empty(*action_shape, dtype=self._action_type, device=self._device)
        self._rewards = torch.empty(*reward_shape, dtype=torch.float, device=self._device)
        self._next_states = torch.empty(*state_shape, dtype=self._state_type, device=self._device)
        self._absorbing = torch.empty(flag_shape, dtype=torch.bool, device=self._device)
        self._last = torch.empty(flag_shape, dtype=torch.bool, device=self._device)
        self._len = 0

        if policy_state_shape is None:
            self._policy_states = None
            self._policy_next_states = None
        else:
            self._policy_states = torch.empty(policy_state_shape, dtype=torch.float, device=self._device)
            self._policy_next_states = torch.empty(policy_state_shape, dtype=torch.float, device=self._device)

        if mask_shape is None:
            self._mask = None
        else:
            self._mask = torch.empty(mask_shape, dtype=torch.bool, device=self._device)

        self._add_all_save_attr()

    @classmethod
    def create_new_instance(cls, dataset=None):
        """
        Creates an empty instance of the dataset, copying only the essential metadata from the template.

        Args:
            dataset (TorchDataset, None): a template dataset to be used to create the new instance.

        Returns:
            A new empty instance of the dataset.

        """
        new_dataset = cls.__new__(cls)

        if dataset is not None:
            new_dataset._device = dataset._device
            new_dataset._state_type = dataset._state_type
            new_dataset._action_type = dataset._action_type
        else:
            new_dataset._device = None
            new_dataset._state_type = None
            new_dataset._action_type = None

        new_dataset._states = None
        new_dataset._actions = None
        new_dataset._rewards = None
        new_dataset._next_states = None
        new_dataset._absorbing = None
        new_dataset._last = None
        new_dataset._len = None
        new_dataset._policy_states = None
        new_dataset._policy_next_states = None
        new_dataset._mask = None

        new_dataset._add_all_save_attr()

        return new_dataset

    @classmethod
    def from_array(cls, states, actions, rewards, next_states, absorbings, lasts,
                   policy_states=None, policy_next_states=None):
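        """
        Creates a dataset from the given arrays, converting them to torch tensors if needed.

        Args:
            states (array, Tensor): array of states;
            actions (array, Tensor): array of actions;
            rewards (array, Tensor): array of rewards;
            next_states (array, Tensor): array of next states;
            absorbings (array, Tensor): array of absorbing flags;
            lasts (array, Tensor): array of last flags;
            policy_states (array, Tensor, None): array of policy states;
            policy_next_states (array, Tensor, None): array of policy next states.

        Returns:
            A new dataset wrapping the given data.

        """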
        if not isinstance(states, torch.Tensor):
            states = torch.as_tensor(states)
            actions = torch.as_tensor(actions)
            rewards = torch.as_tensor(rewards)
            next_states = torch.as_tensor(next_states)
            absorbings = torch.as_tensor(absorbings)
            lasts = torch.as_tensor(lasts)

        dataset = cls.create_new_instance()

        dataset._state_type = states.dtype
        dataset._action_type = actions.dtype

        dataset._states = torch.as_tensor(states)
        dataset._actions = torch.as_tensor(actions)
        dataset._rewards = torch.as_tensor(rewards)
        dataset._next_states = torch.as_tensor(next_states)
        dataset._absorbing = torch.as_tensor(absorbings, dtype=torch.bool)
        dataset._last = torch.as_tensor(lasts, dtype=torch.bool)
        dataset._len = len(lasts)

        if policy_states is not None and policy_next_states is not None:
            if not isinstance(policy_states, torch.Tensor):
                policy_states = torch.as_tensor(policy_states)
                policy_next_states = torch.as_tensor(policy_next_states)

            dataset._policy_states = policy_states
            dataset._policy_next_states = policy_next_states
        else:
            dataset._policy_states = None
            dataset._policy_next_states = None

        return dataset

    def __len__(self):
        return self._len

    def append(self, state, action, reward, next_state, absorbing, last, policy_state=None, policy_next_state=None,
               mask=None):
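        """
        Appends a single transition at position ``self._len`` of the pre-allocated buffers and
        increments the length counter. No bounds check is performed on the buffer capacity.
        """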
        i = self._len   # todo: handle index out of bounds?

        self._states[i] = state
        self._actions[i] = action
        self._rewards[i] = reward
        self._next_states[i] = next_state
        self._absorbing[i] = absorbing
        self._last[i] = last

        if self.is_stateful:
            self._policy_states[i] = policy_state
            self._policy_next_states[i] = policy_next_state

        if mask is not None:
            self._mask[i] = mask

        self._len += 1

    def clear(self):
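        """
        Re-allocates empty buffers with the same shapes and resets the length counter to zero.
        """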
        self._states = torch.empty_like(self._states)
        self._actions = torch.empty_like(self._actions)
        self._rewards = torch.empty_like(self._rewards)
        self._next_states = torch.empty_like(self._next_states)
        self._absorbing = torch.empty_like(self._absorbing)
        self._last = torch.empty_like(self._last)

        if self.is_stateful:
            self._policy_states = torch.empty_like(self._policy_states)
            self._policy_next_states = torch.empty_like(self._policy_next_states)

        self._len = 0

    def get_view(self, index, copy=False):
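        """
        Returns a dataset containing only the transitions selected by ``index``.

        Args:
            index (int, slice, Tensor): the indices of the transitions to select;
            copy (bool, False): if True, the selected data is copied into freshly allocated
                buffers of the original capacity, otherwise the returned dataset references
                the selected data directly.

        Returns:
            A new dataset with the selected transitions.

        """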
        view = self.create_new_instance(self)

        if copy:
            view._states = torch.empty_like(self._states)
            view._actions = torch.empty_like(self._actions)
            view._rewards = torch.empty_like(self._rewards)
            view._next_states = torch.empty_like(self._next_states)
            view._absorbing = torch.empty_like(self._absorbing)
            view._last = torch.empty_like(self._last)

            new_states = self.state[index, ...]
            new_len = new_states.shape[0]

            view._states[:new_len] = new_states
            view._actions[:new_len] = self.action[index, ...]
            view._rewards[:new_len] = self.reward[index, ...]
            view._next_states[:new_len] = self.next_state[index, ...]
            view._absorbing[:new_len] = self.absorbing[index, ...]
            view._last[:new_len] = self.last[index, ...]
            view._len = new_len

            if self.is_stateful:
                view._policy_states = torch.empty_like(self._policy_states)
                view._policy_next_states = torch.empty_like(self._policy_next_states)

                view._policy_states[:new_len] = self._policy_states[index, ...]
                view._policy_next_states[:new_len] = self._policy_next_states[index, ...]

            if self._mask is not None:
                view._mask = torch.empty_like(self._mask)
                view._mask[:new_len] = self._mask[index, ...]
        else:
            view._states = self.state[index, ...]
            view._actions = self.action[index, ...]
            view._rewards = self.reward[index, ...]
            view._next_states = self.next_state[index, ...]
            view._absorbing = self.absorbing[index, ...]
            view._last = self.last[index, ...]
            view._len = view._states.shape[0]

            if self.is_stateful:
                view._policy_states = self._policy_states[index, ...]
                view._policy_next_states = self._policy_next_states[index, ...]

            if self._mask is not None:
                view._mask = self._mask[index, ...]

        return view

    def __getitem__(self, index):
        return self._states[index], self._actions[index], self._rewards[index], self._next_states[index], \
               self._absorbing[index], self._last[index]

    def __add__(self, other):
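        """
        Concatenates two datasets. The last flag of the final transition of the first dataset
        is forced to True, so that episodes from the two datasets are not merged.
        """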
        result = self.create_new_instance(self)

        result._states = torch.concatenate((self.state, other.state))
        result._actions = torch.concatenate((self.action, other.action))
        result._rewards = torch.concatenate((self.reward, other.reward))
        result._next_states = torch.concatenate((self.next_state, other.next_state))
        result._absorbing = torch.concatenate((self.absorbing, other.absorbing))
        result._last = torch.concatenate((self.last, other.last))
        result._last[len(self) - 1] = True
        result._len = len(self) + len(other)

        if self.is_stateful:
            result._policy_states = torch.concatenate((self.policy_state, other.policy_state))
            result._policy_next_states = torch.concatenate((self.policy_next_state, other.policy_next_state))

        return result

    @property
    def state(self):
        return self._states[:len(self)]

    @property
    def action(self):
        return self._actions[:len(self)]

    @property
    def reward(self):
        return self._rewards[:len(self)]

    @property
    def next_state(self):
        return self._next_states[:len(self)]

    @property
    def absorbing(self):
        return self._absorbing[:len(self)]

    @property
    def last(self):
        return self._last[:len(self)]

    @property
    def policy_state(self):
        return self._policy_states[:len(self)]

    @property
    def policy_next_state(self):
        return self._policy_next_states[:len(self)]

    @property
    def is_stateful(self):
        return self._policy_states is not None

    @property
    def mask(self):
        return self._mask[:len(self)]

    @mask.setter
    def mask(self, new_mask):
        self._mask[:len(self)] = new_mask

    @property
    def n_episodes(self):
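        """
        Number of episodes stored in the dataset; a trailing unfinished episode is counted too.
        """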
        n_episodes = self.last.sum()

        if not self.last[-1]:
            n_episodes += 1

        return n_episodes

    def _add_all_save_attr(self):
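        """
        Declares how each attribute has to be serialized by ``Serializable``.
        """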
        self._add_save_attr(
            _state_type='primitive',
            _action_type='primitive',
            _device='primitive',
            _states='torch',
            _actions='torch',
            _rewards='torch',
            _next_states='torch',
            _absorbing='torch',
            _last='torch',
            _policy_states='torch',
            _policy_next_states='torch',
            _len='primitive'
        )
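

# Minimal usage sketch (illustrative only, not part of the library). It assumes the buffer-shape
# convention used above, where the first dimension of every ``*_shape`` is the dataset capacity:
#
#   dataset = TorchDataset(state_type=torch.float32, state_shape=(100, 4),
#                          action_type=torch.int64, action_shape=(100, 1),
#                          reward_shape=(100,), flag_shape=(100,),
#                          policy_state_shape=None, mask_shape=None)
#   dataset.append(torch.zeros(4), torch.tensor([0]), 0.0, torch.ones(4),
#                  absorbing=False, last=True)
#   assert len(dataset) == 1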