MushroomRL/mushroom-rl

View on GitHub
mushroom_rl/utils/mujoco/viewer.py

Summary

Maintainability
D
2 days
Test Coverage
F
9%
import os
import glfw
import mujoco
import time
import collections
from itertools import cycle
import numpy as np


def _import_egl(width, height):
    """
    Creates an OpenGL context through MuJoCo's EGL bindings.

    The backend module is imported only when this function is called, so a machine without EGL
    support only fails if this backend is actually requested.

    Args:
        width (int): Width of the context.
        height (int): Height of the context.

    Returns:
        A ``mujoco.egl.GLContext`` instance.

    """
    from mujoco import egl as _egl_backend

    context = _egl_backend.GLContext(width, height)
    return context


def _import_glfw(width, height):
    """
    Creates an OpenGL context through MuJoCo's GLFW bindings.

    The backend module is imported only when this function is called, so a machine without GLFW
    support only fails if this backend is actually requested.

    Args:
        width (int): Width of the context.
        height (int): Height of the context.

    Returns:
        A ``mujoco.glfw.GLContext`` instance.

    """
    from mujoco import glfw as _glfw_backend

    context = _glfw_backend.GLContext(width, height)
    return context


def _import_osmesa(width, height):
    """
    Creates an OpenGL context through MuJoCo's OSMesa bindings.

    The backend module is imported only when this function is called, so a machine without OSMesa
    support only fails if this backend is actually requested.

    Args:
        width (int): Width of the context.
        height (int): Height of the context.

    Returns:
        A ``mujoco.osmesa.GLContext`` instance.

    """
    from mujoco import osmesa as _osmesa_backend

    context = _osmesa_backend.GLContext(width, height)
    return context


# Headless OpenGL context factories, keyed by the names accepted in the MUJOCO_GL environment variable.
# When MUJOCO_GL is not set, MujocoViewer.setup_opengl_backend_headless tries them in this order.
_ALL_RENDERERS = collections.OrderedDict(
    glfw=_import_glfw,
    egl=_import_egl,
    osmesa=_import_osmesa,
)


class MujocoViewer:
    """
    Class that creates a viewer for mujoco environments.

    Depending on the ``headless`` flag, frames are either shown in an interactive glfw window (with
    mouse and keyboard controls) or rendered into an offscreen buffer of a headless OpenGL context.

    """

    def __init__(self, model, dt, width=1920, height=1080, start_paused=False,
                 custom_render_callback=None, record=False, camera_params=None,
                 default_camera_mode="static", hide_menu_on_startup=None,
                 geom_group_visualization_on_startup=None, headless=False):
        """
        Constructor.

        Args:
            model: Mujoco model.
            dt (float): Timestep of the environment, (not the simulation).
            width (int): Width of the viewer window.
            height (int): Height of the viewer window.
            start_paused (bool): If True, the rendering is paused in the beginning of the simulation.
                Has no effect in headless mode, where there is no keyboard to resume the rendering.
            custom_render_callback (func): Custom render callback function, which is called with the
                viewport and the rendering context during rendering.
            record (bool): If true, frames are returned during rendering and the window cannot be resized.
            camera_params (dict): Dictionary of dictionaries including custom parameterization of the three cameras.
                Checkout the function get_default_camera_params() to know what parameters are expected. If some camera
                type specification or parameter is missing, the default one is used. The dictionary is not modified.
            default_camera_mode (str): Camera mode used on startup; one of "static", "follow" or "top_static".
            hide_menu_on_startup (bool): If True, the menu is hidden on startup. If None, the menu is hidden
                only in headless mode.
            geom_group_visualization_on_startup (int/list): int or list defining which geom group_ids should be
                visualized on startup. If None, all are visualized.
            headless (bool): If True, render will be done in headless mode.

        Raises:
            RuntimeError: If glfw or its window cannot be initialized, or if the requested rendering
                buffer is not supported.

        """

        # by default the menu is only hidden when there is no window to show it in
        if hide_menu_on_startup is None:
            hide_menu_on_startup = bool(headless)

        self.button_left = False
        self.button_right = False
        self.button_middle = False
        self.last_x = 0
        self.last_y = 0
        self.dt = dt

        self.frames = 0
        self.start_time = time.time()

        self._headless = headless
        self._model = model
        self._font_scale = 100

        if headless:
            # use the OpenGL backend that is available on the machine
            self._opengl_context = self.setup_opengl_backend_headless(width, height)
            self._opengl_context.make_current()
            self._width, self._height = self.update_headless_size(width, height)
        else:
            self._width, self._height = width, height
            if not glfw.init():
                raise RuntimeError("Could not initialize glfw.")
            glfw.window_hint(glfw.COCOA_RETINA_FRAMEBUFFER, 0)
            # Window hints only apply to windows created afterwards, so resizability has to be set before
            # create_window. While recording, the window size is fixed to keep all frames the same size.
            # The hint is always set explicitly so it does not leak from a previously created viewer.
            glfw.window_hint(glfw.RESIZABLE, glfw.FALSE if record else glfw.TRUE)
            self._window = glfw.create_window(width=self._width, height=self._height,
                                              title="MuJoCo", monitor=None, share=None)
            if not self._window:
                raise RuntimeError("Could not create the glfw window.")
            glfw.make_context_current(self._window)
            glfw.set_mouse_button_callback(self._window, self.mouse_button)
            glfw.set_cursor_pos_callback(self._window, self.mouse_move)
            glfw.set_key_callback(self._window, self.keyboard)
            glfw.set_scroll_callback(self._window, self.scroll)

        self._set_mujoco_buffers()

        self._viewport = mujoco.MjrRect(0, 0, self._width, self._height)
        self._loop_count = 0
        self._time_per_render = 1 / 60.
        self._run_speed_factor = 1.0
        self._paused = start_paused

        self._scene = mujoco.MjvScene(self._model, 1000)
        self._scene_option = mujoco.MjvOption()
        self._camera = mujoco.MjvCamera()
        mujoco.mjv_defaultFreeCamera(model, self._camera)
        if camera_params is None:
            self._camera_params = self.get_default_camera_params()
        else:
            self._camera_params = self._assert_camera_params(camera_params)

        self._all_camera_modes = ("static", "follow", "top_static")
        assert default_camera_mode in self._all_camera_modes
        # advance the cycling iterator to the default mode, so that TAB continues from there
        self._camera_mode_iter = cycle(self._all_camera_modes)
        self._camera_mode = None
        self._camera_mode_target = next(self._camera_mode_iter)
        while self._camera_mode_target != default_camera_mode:
            self._camera_mode_target = next(self._camera_mode_iter)
        self._set_camera()

        self.custom_render_callback = custom_render_callback

        self._overlay = {}
        self._hide_menu = hide_menu_on_startup

        if geom_group_visualization_on_startup is not None:
            assert type(geom_group_visualization_on_startup) in (list, int)
            if type(geom_group_visualization_on_startup) is int:
                geom_group_visualization_on_startup = [geom_group_visualization_on_startup]
            geomgroup = self._scene_option.geomgroup
            for group_id in range(len(geomgroup)):
                if group_id not in geom_group_visualization_on_startup:
                    geomgroup[group_id] = False

    def load_new_model(self, model):
        """
        Loads a new model to the viewer, and resets the scene and context.
        This is used in MultiMujoco environments.

        The rendering context is configured exactly like in the constructor, i.e., the offscreen
        buffer of the new model is enlarged to the viewer size in headless mode and the
        rendering buffer is selected and checked.

        Args:
            model: Mujoco model.

        Raises:
            RuntimeError: If the requested rendering buffer is not supported.

        """

        self._model = model
        if self._headless:
            # the new model may have a smaller offscreen buffer than the frames we render
            self.update_headless_size(self._width, self._height)
        self._scene = mujoco.MjvScene(model, 1000)
        self._set_mujoco_buffers()

    def mouse_button(self, window, button, act, mods):
        """
        Mouse button callback for glfw.

        Args:
            window: glfw window.
            button: glfw button id.
            act: glfw action.
            mods: glfw mods.

        """

        self.button_left = glfw.get_mouse_button(self._window, glfw.MOUSE_BUTTON_LEFT) == glfw.PRESS
        self.button_right = glfw.get_mouse_button(self._window, glfw.MOUSE_BUTTON_RIGHT) == glfw.PRESS
        self.button_middle = glfw.get_mouse_button(self._window, glfw.MOUSE_BUTTON_MIDDLE) == glfw.PRESS

        self.last_x, self.last_y = glfw.get_cursor_pos(self._window)

    def mouse_move(self, window, x_pos, y_pos):
        """
        Mouse move callback for glfw. Moves the camera while a mouse button is pressed.

        Args:
            window:  glfw window.
            x_pos: Current mouse x position.
            y_pos: Current mouse y position.

        """

        if not (self.button_left or self.button_right or self.button_middle):
            return

        dx = x_pos - self.last_x
        dy = y_pos - self.last_y
        self.last_x = x_pos
        self.last_y = y_pos

        width, height = glfw.get_window_size(self._window)
        if width == 0 or height == 0:
            # e.g. minimized window: there is no meaningful relative displacement
            return

        mod_shift = (glfw.get_key(self._window, glfw.KEY_LEFT_SHIFT) == glfw.PRESS
                     or glfw.get_key(self._window, glfw.KEY_RIGHT_SHIFT) == glfw.PRESS)

        if self.button_right:
            action = mujoco.mjtMouse.mjMOUSE_MOVE_H if mod_shift else mujoco.mjtMouse.mjMOUSE_MOVE_V
        elif self.button_left:
            action = mujoco.mjtMouse.mjMOUSE_ROTATE_H if mod_shift else mujoco.mjtMouse.mjMOUSE_ROTATE_V
        else:
            action = mujoco.mjtMouse.mjMOUSE_ZOOM

        mujoco.mjv_moveCamera(self._model, action, dx / width, dy / height, self._scene, self._camera)

    def keyboard(self, window, key, scancode, act, mods):
        """
        Keyboard callback for glfw. Actions are triggered when a key is released.

        Args:
            window: glfw window.
            key: glfw key event.
            scancode: glfw scancode.
            act: glfw action.
            mods: glfw mods.

        """

        if act != glfw.RELEASE:
            return

        flags = self._scene_option.flags

        if key == glfw.KEY_SPACE:
            self._paused = not self._paused
        elif key == glfw.KEY_C:
            for flag in (mujoco.mjtVisFlag.mjVIS_CONTACTFORCE, mujoco.mjtVisFlag.mjVIS_CONSTRAINT):
                flags[flag] = not flags[flag]
        elif key == glfw.KEY_T:
            flags[mujoco.mjtVisFlag.mjVIS_TRANSPARENT] = not flags[mujoco.mjtVisFlag.mjVIS_TRANSPARENT]
        elif glfw.KEY_0 <= key <= glfw.KEY_9:
            group_id = key - glfw.KEY_0
            geomgroup = self._scene_option.geomgroup
            # MuJoCo only has len(geomgroup) geom groups (mjNGROUP), which can be fewer than ten;
            # indexing beyond that would raise inside the glfw callback
            if group_id < len(geomgroup):
                geomgroup[group_id] = not geomgroup[group_id]
        elif key == glfw.KEY_TAB:
            self._camera_mode_target = next(self._camera_mode_iter)
        elif key == glfw.KEY_S:
            self._run_speed_factor /= 2.0
        elif key == glfw.KEY_F:
            self._run_speed_factor *= 2.0
        elif key == glfw.KEY_E:
            self._scene_option.frame = not self._scene_option.frame
        elif key == glfw.KEY_H:
            self._hide_menu = not self._hide_menu

    def scroll(self, window, x_offset, y_offset):
        """
        Scrolling callback for glfw. Zooms the camera.

        Args:
            window: glfw window.
            x_offset: x scrolling offset.
            y_offset: y scrolling offset.

        """

        mujoco.mjv_moveCamera(self._model, mujoco.mjtMouse.mjMOUSE_ZOOM, 0, 0.05 * y_offset, self._scene, self._camera)

    def _set_mujoco_buffers(self):
        """
        Creates the MuJoCo rendering context for the current model and selects the offscreen
        buffer (headless) or the window buffer.

        Raises:
            RuntimeError: If the selected buffer is not supported.

        """

        self._context = mujoco.MjrContext(self._model, mujoco.mjtFontScale(self._font_scale))
        if self._headless:
            mujoco.mjr_setBuffer(mujoco.mjtFramebuffer.mjFB_OFFSCREEN, self._context)
            if self._context.currentBuffer != mujoco.mjtFramebuffer.mjFB_OFFSCREEN:
                raise RuntimeError("Offscreen rendering not supported")
        else:
            mujoco.mjr_setBuffer(mujoco.mjtFramebuffer.mjFB_WINDOW, self._context)
            if self._context.currentBuffer != mujoco.mjtFramebuffer.mjFB_WINDOW:
                raise RuntimeError("Window rendering not supported")

    def update_headless_size(self, width, height):
        """
        Makes sure the offscreen buffer of the model can hold frames of the given size.

        The offscreen buffer of an MjrContext is sized from ``model.vis.global_.offwidth`` and
        ``offheight``, so these are enlarged (never shrunk) where needed. This must be called before
        the rendering context is created to take effect.

        Args:
            width (int): Requested frame width.
            height (int): Requested frame height.

        Returns:
            The requested width and height, i.e., the size of the rendered frames.

        """

        vis_global = self._model.vis.global_
        vis_global.offwidth = max(width, vis_global.offwidth)
        vis_global.offheight = max(height, vis_global.offheight)
        return width, height

    def render(self, data, record):
        """
        Main rendering function.

        Args:
            data: Mujoco data structure.
            record (bool): If true, frames are returned during rendering.

        Returns:
            If record is True, frames are returned during rendering, else None.

        """

        # While paused, keep redrawing the same state so the window stays responsive. The pause can only
        # be toggled from the keyboard callback, so waiting here in headless mode would never terminate.
        while self._paused and not self._headless:
            self._render_frame(data)

        if record:
            self._loop_count = 1
        else:
            # accumulate how many frames fit into one environment step, scaled by the run speed factor
            self._loop_count += self.dt / (self._time_per_render * self._run_speed_factor)

        while self._loop_count > 0:
            self._render_frame(data)
            self._set_camera()
            self._loop_count -= 1

        return self.read_pixels() if record else None

    def _render_frame(self, data):
        """
        Renders a single frame of ``data``. With a window, the frame is displayed and pending glfw
        events are processed; the process exits if the window was closed.

        Args:
            data: Mujoco data structure.

        """

        if not self._headless:
            self._create_overlay()

        render_start = time.time()

        mujoco.mjv_updateScene(self._model, data, self._scene_option, None, self._camera,
                               mujoco.mjtCatBit.mjCAT_ALL, self._scene)

        if not self._headless:
            # mjr_render works in framebuffer pixels, which can differ from the window size
            # (screen coordinates) on high-DPI displays
            self._viewport.width, self._viewport.height = glfw.get_framebuffer_size(self._window)

        mujoco.mjr_render(self._viewport, self._scene, self._context)

        if not self._hide_menu:
            for gridpos, (text_left, text_right) in self._overlay.items():
                mujoco.mjr_overlay(mujoco.mjtFont.mjFONT_SHADOW, gridpos, self._viewport,
                                   text_left, text_right, self._context)

        if self.custom_render_callback is not None:
            self.custom_render_callback(self._viewport, self._context)

        if not self._headless:
            glfw.swap_buffers(self._window)
            glfw.poll_events()
            if glfw.window_should_close(self._window):
                self.stop()
                raise SystemExit(0)

        self.frames += 1
        self._overlay.clear()
        # exponential moving average of the time needed to render one frame
        self._time_per_render = 0.9 * self._time_per_render + 0.1 * (time.time() - render_start)

    def read_pixels(self, depth=False):
        """
        Reads the pixels of the current viewport from the active rendering buffer.

        Args:
            depth (bool): If True, depth map is also returned.

        Returns:
            If depth is True, tuple of np.arrays (rgb and depth), else just a single
            np.array for the rgb image.

        """

        # Size the buffers like the viewport used by the last render call, so that they always match
        # the region read by mjr_readPixels (the window may have been resized in the meantime).
        width, height = self._viewport.width, self._viewport.height
        rgb_img = np.zeros((height, width, 3), dtype=np.uint8)
        depth_img = np.zeros((height, width, 1), dtype=np.float32) if depth else None
        mujoco.mjr_readPixels(rgb_img, depth_img, self._viewport, self._context)

        # OpenGL returns the rows bottom-up, hence the vertical flip
        if depth:
            return np.flipud(rgb_img), np.flipud(depth_img)
        return np.flipud(rgb_img)

    def stop(self):
        """
        Destroys the glfw window (no-op in headless mode).

        """

        if not self._headless:
            glfw.destroy_window(self._window)

    def _create_overlay(self):
        """
        This function creates and adds all overlays used in the viewer.

        """

        topleft = mujoco.mjtGridPos.mjGRID_TOPLEFT
        bottomright = mujoco.mjtGridPos.mjGRID_BOTTOMRIGHT

        def add_overlay(gridpos, text1, text2="", make_new_line=True):
            # each grid position holds a [left column, right column] pair of strings
            entry = self._overlay.setdefault(gridpos, ["", ""])
            line_end = "\n" if make_new_line else ""
            entry[0] += text1 + line_end
            entry[1] += text2 + line_end

        add_overlay(
            bottomright,
            "Framerate:",
            str(int(1 / self._time_per_render * self._run_speed_factor)), make_new_line=False)

        # the number keys 0-9 can only address as many geom groups as MuJoCo provides
        last_group_key = min(len(self._scene_option.geomgroup), 10) - 1
        for hint in ("Press SPACE to pause.",
                     "Press H to hide the menu.",
                     "Press TAB to switch cameras.",
                     "Press T to make the model transparent.",
                     "Press E to toggle reference frames.",
                     f"Press 0-{last_group_key} to disable/enable geom group visualization."):
            add_overlay(topleft, hint)

        visualize_contact = "On" if self._scene_option.flags[mujoco.mjtVisFlag.mjVIS_CONTACTFORCE] else "Off"
        add_overlay(topleft, "Contact force visualization (Press C):", visualize_contact)

        add_overlay(topleft, "Camera mode:", self._camera_mode)

        add_overlay(
            topleft,
            "Run speed = %.3f x real time" % self._run_speed_factor,
            "[S]lower, [F]aster", make_new_line=False)

    def _set_camera(self):
        """
        Sets the camera mode to the current camera mode target. Allowed camera
        modes are "follow" in which the model is tracked, "static" that is a static
        camera at the default camera position, and "top_static" that is a static
        camera on top of the model. Nothing happens if the target mode is already active.

        """

        target = self._camera_mode_target
        if target == self._camera_mode:
            return

        if target == "follow":
            self._camera.fixedcamid = -1
            self._camera.type = mujoco.mjtCamera.mjCAMERA_TRACKING
            self._camera.trackbodyid = 0
        elif target in ("static", "top_static"):
            self._camera.fixedcamid = 0
            self._camera.type = mujoco.mjtCamera.mjCAMERA_FREE
            self._camera.trackbodyid = -1
        else:
            return

        self._set_camera_properties(target)

    def _set_camera_properties(self, mode):
        """
        Sets the camera properties "distance", "elevation", and "azimuth" (and "lookat" if given)
        as well as the camera mode based on the provided mode.

        Args:
            mode (str): Camera mode. (either "follow", "static", or "top_static")

        """

        cam_params = self._camera_params[mode]
        self._camera.distance = cam_params["distance"]
        self._camera.elevation = cam_params["elevation"]
        self._camera.azimuth = cam_params["azimuth"]
        if "lookat" in cam_params:
            self._camera.lookat = np.array(cam_params["lookat"])
        self._camera_mode = mode

    def _assert_camera_params(self, camera_params):
        """
        Asserts if the provided camera parameters are valid or not. Also, if
        properties of some camera types are not specified, the default parameters
        are used. The provided dictionary is not modified.

        Args:
            camera_params (dict): Dictionary of dictionaries containing parameters for each camera type.

        Returns:
            New dictionary of dictionaries with parameters for each camera type.

        Raises:
            AssertionError: If an unknown camera type or parameter is provided.

        """

        default_camera_params = self.get_default_camera_params()

        # check if the provided camera types and parameters are valid
        for cam_type, params in camera_params.items():
            assert cam_type in default_camera_params, f"Camera type \"{cam_type}\" is unknown. Allowed " \
                                                      f"camera types are {list(default_camera_params.keys())}."
            for param in params:
                assert param in default_camera_params[cam_type], f"Parameter \"{param}\" of camera type " \
                                                                 f"\"{cam_type}\" is unknown. Allowed " \
                                                                 f"parameters are" \
                                                                 f" {list(default_camera_params[cam_type].keys())}"

        # merge the provided parameters onto fresh copies of the defaults, leaving the caller's dicts untouched
        return {cam_type: {**defaults, **camera_params.get(cam_type, {})}
                for cam_type, defaults in default_camera_params.items()}

    @staticmethod
    def get_default_camera_params():
        """
        Getter for default camera parameterization.

        Returns:
            Dictionary of dictionaries with default parameters for each camera type.

        """

        return dict(static=dict(distance=15.0, elevation=-45.0, azimuth=90.0, lookat=np.array([0.0, 0.0, 0.0])),
                    follow=dict(distance=3.5, elevation=0.0, azimuth=90.0),
                    top_static=dict(distance=5.0, elevation=-90.0, azimuth=90.0, lookat=np.array([0.0, 0.0, 0.0])))

    def setup_opengl_backend_headless(self, width, height):
        """
        Creates a headless OpenGL context. The backend named in the MUJOCO_GL environment variable is
        used if set; otherwise the backends in ``_ALL_RENDERERS`` are tried in order.

        Args:
            width (int): Width of the context.
            height (int): Height of the context.

        Returns:
            The created OpenGL context.

        Raises:
            RuntimeError: If MUJOCO_GL names an unknown backend, or no backend could be created.

        """

        backend = os.environ.get("MUJOCO_GL")
        if backend is not None:
            # only the lookup is guarded, so errors raised while creating the context are not misreported
            try:
                make_context = _ALL_RENDERERS[backend]
            except KeyError:
                raise RuntimeError(
                    "Environment variable {} must be one of {!r}: got {!r}.".format(
                        "MUJOCO_GL", list(_ALL_RENDERERS.keys()), backend
                    )
                ) from None
            return make_context(width, height)

        # iterate through all OpenGL backends to see which one is available
        for make_context in _ALL_RENDERERS.values():
            try:
                return make_context(width, height)
            except Exception:
                # backend not importable or not usable on this machine: try the next one
                continue

        raise RuntimeError(
            "No OpenGL backend could be imported. Attempting to create a "
            "rendering context will result in a RuntimeError."
        )