SforAiDl/genrl

View on GitHub
genrl/environments/suite.py

Summary

Maintainability
A
35 mins
Test Coverage
from typing import List

import gym

from genrl.environments import (
    AtariPreprocessing,
    FireReset,
    FrameStack,
    GymWrapper,
    NoopReset,
)
from genrl.environments.time_limit import AtariTimeLimit, TimeLimit
from genrl.environments.torch import TorchWrapper
from genrl.environments.vec_env import SerialVecEnv, SubProcessVecEnv, VecEnv


def VectorEnv(
    env_id: str,
    n_envs: int = 2,
    parallel: int = False,
    env_type: str = "gym",
) -> VecEnv:
    """
        Chooses the kind of Vector Environment that is required

        :param env_id: Gym environment to be vectorised
        :param n_envs: Number of environments
        :param parallel: True if we want environments to run parallely and (
    subprocesses, False if we want environments to run serially one after the other)
        :param env_type: Type of environment. Currently, we support ["gym", "atari"]
        :type env_id: string
        :type n_envs: int
        :type parallel: False
        :type env_type: string
        :returns: Vector Environment
        :rtype: object
    """
    wrapper = AtariEnv if env_type == "atari" else GymEnv

    envs = [TorchWrapper(wrapper(env_id)) for _ in range(n_envs)]

    if parallel:
        venv = SubProcessVecEnv(envs, n_envs)
    else:
        venv = SerialVecEnv(envs, n_envs)

    return venv


def GymEnv(env_id: str) -> gym.Env:
    """
    Function to apply wrappers for all regular Gym envs by Trainer class

    :param env: Environment Name
    :type env: string
    :returns: Gym Environment
    :rtype: object
    """
    env = gym.make(env_id)
    return GymWrapper(TimeLimit(env))


def AtariEnv(
    env_id: str,
    wrapper_list: List = [
        AtariPreprocessing,
        NoopReset,
        FireReset,
        AtariTimeLimit,
        FrameStack,
    ],
) -> gym.Env:
    """
    Function to apply wrappers for all Atari envs by Trainer class

    :param env: Environment Name
    :type env: string
    :param wrapper_list: List of wrappers to use
    :type wrapper_list: list or tuple
    :returns: Gym Atari Environment
    :rtype: object
    """
    env = gym.make(env_id)
    env = GymWrapper(env)

    if "NoFrameskip" in env_id:
        frameskip = 1
    elif "Deterministic" in env_id:
        frameskip = 4
    else:
        frameskip = (2, 5)

    for wrapper in wrapper_list:
        if wrapper is AtariPreprocessing:
            env = wrapper(env, frameskip)
        else:
            env = wrapper(env)

    return env