eager-dev/eagerx

View on GitHub
eagerx/engines/openai_gym/objects.py

Summary

Maintainability
A
45 mins
Test Coverage
A
98%
# OTHER IMPORTS
from typing import Optional, List

# ROS IMPORTS
import gymnasium as gym

# EAGERx IMPORTS
from eagerx.core.space import Space
from eagerx.engines.openai_gym.engine import GymEngine
from eagerx.core.entities import Object
from eagerx.core.specs import ObjectSpec
from eagerx.core.graph_engine import EngineGraph
import eagerx.core.register as register


class GymObject(Object):
    @classmethod
    @register.sensors(
        observation=Space(dtype="float32"),
        reward=Space(dtype="float32"),
        terminated=Space(low=0, high=1, shape=(), dtype="int64"),
        truncated=Space(low=0, high=1, shape=(), dtype="int64"),
        image=Space(dtype="uint8"),
    )
    @register.actuators(action=Space(dtype="float32"))
    @register.engine_states()
    def make(
        cls,
        name: str,
        sensors: Optional[List[str]] = None,
        env_id: str = "Pendulum-v1",
        rate: float = 20.0,
        always_render: bool = False,
        default_action=None,
        render_shape: Optional[List[int]] = None,
    ):
        """Object spec of GymObject"""
        spec = cls.get_specification()

        # Set default node params
        spec.config.name = name
        spec.config.sensors = sensors if isinstance(sensors, list) else ["observation", "reward", "truncated", "terminated"]
        spec.config.actuators = ["action"]

        # Set custom node params
        spec.config.render_shape = render_shape if render_shape else [200, 200]
        spec.config.env_id = env_id
        spec.config.always_render = always_render
        spec.config.default_action = default_action

        # Set spaces
        env = gym.make(spec.config.env_id)
        obs_space = env.observation_space
        if isinstance(obs_space, gym.spaces.Discrete):
            spec.sensors.observation.space = Space(low=0, high=obs_space.n - 1, shape=(), dtype="int64")
        elif isinstance(obs_space, gym.spaces.Box):
            spec.sensors.observation.space = Space(obs_space.low, obs_space.high, shape=obs_space.shape, dtype=obs_space.dtype)
        else:
            raise NotImplementedError("Space not compatible with this object.")
        act_space = env.action_space
        if isinstance(act_space, gym.spaces.Discrete):
            spec.actuators.action.space = Space(low=0, high=act_space.n - 1, shape=(), dtype="int64")
        elif isinstance(act_space, gym.spaces.Box):
            spec.actuators.action.space = Space(act_space.low, act_space.high, shape=act_space.shape, dtype=act_space.dtype)
        else:
            raise NotImplementedError("Space not compatible with this object.")

        spec.sensors.reward.space = Space(low=env.reward_range[0], high=env.reward_range[1], shape=tuple(), dtype="float32")
        shape = (spec.config.render_shape[0], spec.config.render_shape[1], 3)
        spec.sensors.image.space = Space(low=0, high=255, shape=shape, dtype="uint8")

        # Set rates
        spec.sensors.observation.rate = rate
        spec.sensors.reward.rate = rate
        spec.sensors.truncated.rate = rate
        spec.sensors.terminated.rate = rate
        spec.sensors.image.rate = rate
        spec.actuators.action.rate = rate
        return spec

    @staticmethod
    @register.engine(GymEngine)  # This decorator pre-initializes engine implementation with default object_params
    def openai_gym(spec: ObjectSpec, graph: EngineGraph):
        """Engine-specific implementation (GymEngine) of the object."""
        # Set engine arguments (nothing to set here in this case)
        spec.engine.env_id = spec.config.env_id

        # Create sensor engine nodes
        # Rate=None, because we will connect them to sensors (thus uses the rate set in the agnostic specification)
        from eagerx.engines.openai_gym.enginenodes import (
            ObservationSensor,
            RewardSensor,
            TruncatedSensor,
            TerminatedSensor,
            ActionActuator,
            GymImage,
        )

        obs = ObservationSensor.make("obs", rate=spec.sensors.observation.rate, process=2)
        obs.outputs.observation.space = spec.sensors.observation.space
        rwd = RewardSensor.make("rwd", rate=spec.sensors.reward.rate, process=2)
        truncated = TruncatedSensor.make("truncated", rate=spec.sensors.truncated.rate, process=2)
        terminated = TerminatedSensor.make("terminated", rate=spec.sensors.terminated.rate, process=2)
        image = GymImage.make(
            "image",
            shape=spec.config.render_shape,
            always_render=spec.config.always_render,
            rate=spec.sensors.image.rate,
            process=2,
        )

        # Create actuator engine nodes
        # Rate=None, because we will connect it to an actuator (thus uses the rate set in the agnostic specification)
        action = ActionActuator.make(
            "action", rate=spec.actuators.action.rate, process=2, zero_action=spec.config.default_action
        )
        action.inputs.action.space = spec.actuators.action.space
        action.outputs.action_applied.space = spec.actuators.action.space

        # Connect all engine nodes
        graph.add([obs, rwd, truncated, terminated, image, action])
        graph.connect(source=obs.outputs.observation, sensor="observation")
        graph.connect(source=rwd.outputs.reward, sensor="reward")
        graph.connect(source=truncated.outputs.truncated, sensor="truncated")
        graph.connect(source=terminated.outputs.terminated, sensor="terminated")
        graph.connect(source=image.outputs.image, sensor="image")
        graph.connect(actuator="action", target=action.inputs.action)

        # Check graph validity (comment out)
        # graph.is_valid(plot=True)