MushroomRL/mushroom-rl
mushroom_rl/environments/gymnasium_atari.py

from copy import deepcopy
from collections import deque

import numpy as np
import gymnasium as gym

from mushroom_rl.core import Environment, MDPInfo
from mushroom_rl.rl_utils.spaces import *
from mushroom_rl.utils.frames import LazyFrames, preprocess_frame
from mushroom_rl.utils.viewer import ImageViewer


class MaxAndSkip(gym.Wrapper):
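    """
    Gymnasium wrapper that repeats each action for skip consecutive frames
    and returns the pixel-wise maximum (or average) of the last two observed
    frames, following the Atari preprocessing of Mnih et al., 2015.

    """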
    def __init__(self, env, skip, max_pooling=True):
        gym.Wrapper.__init__(self, env)
        self._obs_buffer = np.zeros((2,) + env.observation_space.shape,
                                    dtype=np.uint8)
        self._skip = skip
        self._max_pooling = max_pooling

    def step(self, action):
        total_reward = 0.
        for i in range(self._skip):
            obs, reward, absorbing, truncated, info = self.env.step(action)

            # Keep the last two frames so they can be pooled below
            if i == self._skip - 2:
                self._obs_buffer[0] = obs
            if i == self._skip - 1:
                self._obs_buffer[1] = obs
            total_reward += reward
            if absorbing or truncated:
                break

        # Pool the last two frames to remove the flickering of Atari sprites
        if self._max_pooling:
            frame = self._obs_buffer.max(axis=0)
        else:
            frame = self._obs_buffer.mean(axis=0)

        return frame, total_reward, absorbing, truncated, info

    def reset(self, **kwargs):
        return self.env.reset(**kwargs)


class GymnasiumAtari(Environment):
    """
    The Atari environment as presented in:
    "Human-level control through deep reinforcement learning". Mnih et al.,
    2015.

    """
    def __init__(self, name, width=84, height=84, ends_at_life=False,
                 max_pooling=True, history_length=4, max_no_op_actions=30,
                 headless=False):
        """
        Constructor.

        Args:
            name (str): id name of the Atari game in Gymnasium;
            width (int, 84): width of the screen;
            height (int, 84): height of the screen;
            ends_at_life (bool, False): whether the episode ends when a life
                is lost or not;
            max_pooling (bool, True): whether to do max-pooling or
                average-pooling of the last two frames when using NoFrameskip;
            history_length (int, 4): number of frames to form a state;
            max_no_op_actions (int, 30): maximum number of no-op actions to
                execute at the beginning of an episode;
            headless (bool, False): if True, the rendering is forced to be
                headless.

        """
        # MDP creation
        if 'NoFrameskip' in name:
            self.env = MaxAndSkip(gym.make(name, render_mode='rgb_array'),
                                  history_length, max_pooling)
        else:
            self.env = gym.make(name, render_mode='rgb_array')

        # MDP parameters
        self._headless = headless
        self._img_size = (width, height)
        self._episode_ends_at_life = ends_at_life
        self._max_lives = self.env.unwrapped.ale.lives()
        self._lives = self._max_lives
        self._force_fire = None
        self._real_reset = True
        self._max_no_op_actions = max_no_op_actions
        self._history_length = history_length
        self._current_no_op = None

        # The no-op starts in reset/step rely on action 0 being NOOP
        assert self.env.unwrapped.get_action_meanings()[0] == 'NOOP'

        # MDP properties
        action_space = Discrete(self.env.action_space.n)
        observation_space = Box(
            low=0., high=255., shape=(history_length, self._img_size[1], self._img_size[0]))
        horizon = 1e4  # instead of np.inf
        gamma = .99
        dt = 1/60
        mdp_info = MDPInfo(observation_space, action_space, gamma, horizon, dt)

        # Viewer
        self._viewer = ImageViewer((self._img_size[1], self._img_size[0]), dt,
                                   headless=self._headless)

        super().__init__(mdp_info)

    def reset(self, state=None):
        info = {}
        if self._real_reset:
            state, info = self.env.reset()
            self._state = preprocess_frame(state, self._img_size)
            self._state = deque([deepcopy(
                self._state) for _ in range(self._history_length)],
                maxlen=self._history_length
            )
            self._lives = self._max_lives

        self._force_fire = self.env.unwrapped.get_action_meanings()[1] == 'FIRE'

        # Sample the number of no-op actions to execute at the beginning of
        # the episode
        self._current_no_op = np.random.randint(self._max_no_op_actions + 1)

        return LazyFrames(list(self._state), self._history_length), info

    def step(self, action):
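        # The mushroom_rl core passes actions as arrays; take the scalar id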
        action = action[0]

        # Force FIRE action to start episodes in games with lives
        if self._force_fire:
            obs, _, _, _, _ = self.env.env.step(1)
            self._force_fire = False
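        # Execute the no-op actions sampled at reset before the first
        # agent action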
        while self._current_no_op > 0:
            obs, _, _, _, _ = self.env.env.step(0)
            self._current_no_op -= 1

        obs, reward, absorbing, _, info = self.env.step(action)
        self._real_reset = absorbing

        if info['lives'] != self._lives:
            if self._episode_ends_at_life:
                absorbing = True
            self._lives = info['lives']
            self._force_fire = self.env.unwrapped.get_action_meanings()[1] == 'FIRE'

        self._state.append(preprocess_frame(obs, self._img_size))

        return LazyFrames(list(self._state), self._history_length), reward, absorbing, info

    def render(self, record=False):
        img = self.env.render()

        self._viewer.display(img)

        if record:
            return img
        else:
            return None

    def stop(self):
        self.env.close()
        self._viewer.close()
        self._real_reset = True

    def set_episode_end(self, ends_at_life):
        """
        Setter.

        Args:
            ends_at_life (bool): whether the episode ends when a life is
                lost or not.

        """
        self._episode_ends_at_life = ends_at_life
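

# Minimal usage sketch (not part of the original module). It assumes the
# Gymnasium Atari bindings are installed (e.g. pip install "gymnasium[atari]")
# so that 'BreakoutNoFrameskip-v4' is a registered environment id; the random
# policy below is only there to exercise the reset/step API.
if __name__ == '__main__':
    mdp = GymnasiumAtari('BreakoutNoFrameskip-v4', ends_at_life=True,
                         headless=True)

    state, info = mdp.reset()
    for _ in range(100):
        # Actions are passed as arrays, as expected by GymnasiumAtari.step
        action = [np.random.randint(mdp.info.action_space.n)]
        state, reward, absorbing, info = mdp.step(action)
        if absorbing:
            state, info = mdp.reset()

    mdp.stop()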