takuseno/d3rlpy

View on GitHub
d3rlpy/cli.py

Summary

Maintainability
B
5 hrs
Test Coverage
# pylint: disable=redefined-builtin,exec-used
# type: ignore

import glob
import json
import os
import subprocess
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple

import click
import gym
import gymnasium
import numpy as np
from gym.wrappers import RecordVideo

from ._version import __version__
from .algos import QLearningAlgoBase, TransformerAlgoBase
from .base import load_learnable
from .metrics.utility import (
    evaluate_qlearning_with_environment,
    evaluate_transformer_with_environment,
)

if TYPE_CHECKING:
    import matplotlib.pyplot


def print_stats(path: str) -> None:
    """Print summary statistics of a saved metrics CSV file.

    The file is parsed as comma-separated numeric rows. The first three
    columns are reported as epoch, total steps and metric value
    respectively; epoch and total steps are taken from the last row.

    Args:
        path: Path to the CSV file to summarize.
    """
    # ndmin=2 keeps a single-row file two-dimensional. Without it
    # np.loadtxt returns a 1-D array and the column indexing below fails.
    data = np.loadtxt(path, delimiter=",", ndmin=2)
    values = data[:, 2]
    print("FILE NAME  : ", path)
    print("EPOCH      : ", data[-1, 0])
    print("TOTAL STEPS: ", data[-1, 1])
    print("MAX VALUE  : ", np.max(values))
    print("MIN VALUE  : ", np.min(values))
    print("STD VALUE  : ", np.std(values))


def get_plt() -> "matplotlib.pyplot":
    """Import and return ``matplotlib.pyplot`` on demand.

    matplotlib is imported inside the function so that it is only needed
    by the plotting commands. When seaborn can be imported as well, its
    default style is activated before the module is returned.

    Returns:
        The ``matplotlib.pyplot`` module.
    """
    import matplotlib.pyplot as pyplot

    try:
        import seaborn

        # seaborn is optional: apply its default theme only when present
        seaborn.set()
    except ImportError:
        pass

    return pyplot


def _compute_moving_average(values: np.ndarray, window: int) -> np.ndarray:
    assert values.ndim == 1
    results: List[float] = []
    # average over past data
    for i in range(values.shape[0]):
        start = max(0, i - window)
        results.append(float(np.mean(values[start : i + 1])))
    return np.array(results)


# Root command group of the d3rlpy command line interface. Sub-commands in
# this file attach to it through the ``@cli.command`` decorator.
# (Documented with comments on purpose: click uses a docstring as help text.)
@click.group()
def cli() -> None:
    # Print a banner that includes the package version from ``_version``.
    print(f"d3rlpy command line interface (Version {__version__})")


# ``stats`` sub-command: prints summary statistics for a single metrics file
# by delegating to ``print_stats``.
# (Documented with comments on purpose: click uses a docstring as help text.)
@cli.command(short_help="Show statistics of save metrics.")
@click.argument("path")
def stats(path: str) -> None:
    # ``path`` is the positional argument declared by the decorator above.
    print_stats(path)


# ``plot`` sub-command: draws one curve per metrics CSV file on a shared axis.
#
# Arguments/options (see the decorators for the CLI spelling):
#   path       -- one or more CSV files; columns 0, 1 and 2 are read as
#                 epoch, step and metric value.
#   window     -- moving average window passed to _compute_moving_average.
#   show_steps -- use column 1 (steps) instead of column 0 (epochs) as x.
#   show_max   -- draw a dashed horizontal line at the largest raw value.
#   label      -- optional legend labels, one per path.
#   xlim/ylim  -- optional axis limits.
#   title      -- optional plot title.
#   ylabel     -- y-axis label.
#   save       -- if given, the figure is written to this path instead of
#                 being shown interactively.
#
# Raises click.UsageError when no path is given or when the number of labels
# does not match the number of paths.
#
# NOTE: documented with comments rather than a docstring because click would
# print a docstring as the command's help text.
@cli.command(short_help="Plot saved metrics (requires matplotlib).")
@click.argument("path", nargs=-1)
@click.option(
    "--window", default=1, show_default=True, help="Moving average window."
)
@click.option("--show-steps", is_flag=True, help="Use iterations on x-axis.")
@click.option("--show-max", is_flag=True, help="Show maximum value.")
@click.option("--label", multiple=True, help="Label in legend.")
@click.option("--xlim", nargs=2, type=float, help="Limit on x-axis (tuple).")
@click.option("--ylim", nargs=2, type=float, help="Limit on y-axis (tuple).")
@click.option("--title", help="Title of the plot.")
@click.option("--ylabel", default="value", help="Label on y-axis.")
@click.option("--save", help="Flag to save the plot as an image.")
def plot(
    path: Sequence[str],
    window: int,
    show_steps: bool,
    show_max: bool,
    label: Optional[Sequence[str]],
    xlim: Optional[Tuple[float, float]],
    ylim: Optional[Tuple[float, float]],
    title: Optional[str],
    ylabel: str,
    save: Optional[str],
) -> None:
    # Validate user input with explicit errors (asserts vanish under -O and
    # produce a traceback instead of a usage message).
    if not path:
        raise click.UsageError("at least one path must be provided.")
    if label and len(label) != len(path):
        raise click.UsageError(
            "--label must be provided as many times as the number of paths"
        )

    plt = get_plt()

    max_y_values = []
    min_x_values = []
    max_x_values = []

    # column 1 holds steps, column 0 holds epochs
    x_column = 1 if show_steps else 0

    for i, p in enumerate(path):
        # ndmin=2 keeps single-row files two-dimensional
        data = np.loadtxt(p, delimiter=",", ndmin=2)

        x_data = data[:, x_column]
        # filter to smooth data
        y_data = _compute_moving_average(data[:, 2], window)

        # create label: explicit label, else "<parent dir>/<file>", else path
        if label:
            _label = label[i]
        else:
            parts = p.split(os.sep)
            _label = "/".join(parts[-2:]) if len(parts) > 1 else p

        # the maximum is taken from the raw (unsmoothed) values
        max_y_values.append(np.max(data[:, 2]))
        min_x_values.append(np.min(x_data))
        max_x_values.append(np.max(x_data))

        # show statistics
        print("")
        print_stats(p)

        plt.plot(x_data, y_data, label=_label)

    if show_max:
        overall_max = np.max(max_y_values)
        plt.plot(
            [np.min(min_x_values), np.max(max_x_values)],
            [overall_max, overall_max],
            color="black",
            linestyle="dashed",
        )

    plt.xlabel("steps" if show_steps else "epochs")
    plt.ylabel(ylabel)

    if xlim:
        plt.xlim(xlim[0], xlim[1])
    if ylim:
        plt.ylim(ylim[0], ylim[1])

    if title:
        plt.title(title)

    plt.legend()
    if save:
        plt.savefig(save)
    else:
        plt.show()


# ``plot_all`` sub-command: plots every ``*.csv`` file found directly inside a
# directory as one subplot each, arranged in a near-square grid.
#
#   path  -- directory containing the CSV files (and optionally params.json,
#            whose key/value pairs are printed first).
#   title -- optional figure-level title.
#   save  -- if given, the figure is written to this path instead of being
#            shown interactively.
#
# Raises click.UsageError when the directory contains no CSV file.
#
# NOTE: documented with comments rather than a docstring because click would
# print a docstring as the command's help text.
@cli.command(short_help="Plot saved metrics in a grid (requires matplotlib).")
@click.argument("path")
@click.option("--title", help="Title of the plot.")
@click.option("--save", help="Flag to save the plot as an image.")
def plot_all(
    path: str,
    title: Optional[str],
    save: Optional[str],
) -> None:
    plt = get_plt()

    # print params.json
    params_path = os.path.join(path, "params.json")
    if os.path.exists(params_path):
        with open(params_path, "r", encoding="utf-8") as f:
            params = json.load(f)
        print("")
        for k, v in params.items():
            print(f"{k}={v}")

    metrics_names = sorted(glob.glob(os.path.join(path, "*.csv")))
    if not metrics_names:
        # Without this guard the grid computation below divides by zero.
        raise click.UsageError(f"No CSV files found in '{path}'.")

    # smallest near-square grid that holds every metric
    n_cols = int(np.ceil(len(metrics_names) ** 0.5))
    n_rows = int(np.ceil(len(metrics_names) / n_cols))

    plt.figure(figsize=(12, 7))

    # subplots are filled row by row, so the flat index is the grid position
    for index, metrics_path in enumerate(metrics_names):
        plt.subplot(n_rows, n_cols, index + 1)

        # ndmin=2 keeps single-row files two-dimensional
        data = np.loadtxt(metrics_path, delimiter=",", ndmin=2)

        plt.plot(data[:, 0], data[:, 2])
        plt.title(os.path.basename(metrics_path))
        plt.xlabel("epoch")
        plt.ylabel("value")

    if title:
        plt.suptitle(title)

    plt.tight_layout()
    if save:
        plt.savefig(save)
    else:
        plt.show()


# ``export`` sub-command: loads a saved model with ``load_learnable`` and
# writes its policy to ``output_path`` via ``save_policy``.
#
#   model_path  -- path of the saved model to load.
#   output_path -- destination passed to ``save_policy``.
#
# Raises ValueError when the loaded object is not a QLearningAlgoBase.
#
# NOTE: documented with comments rather than a docstring because click would
# print a docstring as the command's help text.
@cli.command(
    short_help="Export saved model as inference model format (ONNX or TorchScript)."
)
@click.argument("model_path")
@click.argument("output_path")
def export(model_path: str, output_path: str) -> None:
    # load saved model
    print(f"Loading {model_path}...")
    algo = load_learnable(model_path)

    # Explicit check instead of ``assert`` so it still runs under ``-O``.
    if not isinstance(algo, QLearningAlgoBase):
        raise ValueError("Currently, only Q-learning algorithms are supported.")

    # export inference model
    print(f"Exporting to {output_path}...")
    algo.save_policy(output_path)


def _exec_to_create_env(code: str) -> gym.Env[Any, Any]:
    """Build an environment by executing a snippet of Python code.

    The snippet runs with this module's globals and a fresh local
    namespace, and is expected to bind the created environment to the
    name ``env``.

    SECURITY: ``code`` is passed to ``exec`` unmodified, so it can run
    arbitrary Python. Only pass code the invoking user fully trusts.

    Args:
        code: Python source that defines ``env``.

    Returns:
        The object bound to ``env`` by the executed code.

    Raises:
        RuntimeError: If the code does not define ``env``.
    """
    print(f"Executing '{code}'")
    namespace: Dict[str, Any] = {}
    exec(code, globals(), namespace)
    if "env" in namespace:
        return namespace["env"]  # type: ignore
    raise RuntimeError("env must be defined in env_header.")


# ``record`` sub-command: runs evaluation episodes with a saved model while
# recording every episode to video files in the ``out`` directory.
#
#   model_path    -- path of the saved model to load.
#   env_id        -- environment id passed to ``gym.make``/``gymnasium.make``.
#   env_header    -- alternative to env_id: Python code that defines ``env``.
#   out           -- output directory for the recorded videos.
#   n_episodes    -- number of episodes to run.
#   target_return -- required when the model is a TransformerAlgoBase.
#   use_gymnasium -- create the environment with Gymnasium instead of Gym.
#
# Raises ValueError when neither env_id nor env_header is given or the loaded
# model type is unsupported, and click.UsageError when a transformer model is
# loaded without --target-return.
#
# NOTE: documented with comments rather than a docstring because click would
# print a docstring as the command's help text.
@cli.command(short_help="Record episodes with the saved model.")
@click.argument("model_path")
@click.option("--env-id", default=None, help="Gym environment id.")
@click.option(
    "--env-header", default=None, help="One-liner to create environment."
)
@click.option("--out", default="videos", help="Output directory path.")
@click.option("--n-episodes", default=3, help="Number of episodes to record.")
@click.option(
    "--target-return",
    default=None,
    type=float,
    help="Target return for Decision Transformer variants.",
)
@click.option("--use-gymnasium", is_flag=True, help="Flag to use Gymnasium.")
def record(
    model_path: str,
    env_id: Optional[str],
    env_header: Optional[str],
    out: str,
    n_episodes: int,
    target_return: Optional[float],
    use_gymnasium: bool,
) -> None:
    # fail fast on missing environment arguments, before loading the model
    if env_id is None and env_header is None:
        raise ValueError("env_id or env_header must be provided.")

    # load saved model
    print(f"Loading {model_path}...")
    algo = load_learnable(model_path)

    # validate the model before an environment is created, so that nothing
    # has to be cleaned up on these error paths
    if isinstance(algo, TransformerAlgoBase):
        if target_return is None:
            raise click.UsageError("--target-return must be specified.")
    elif not isinstance(algo, QLearningAlgoBase):
        raise ValueError("invalid algo type.")

    # create environment
    env: gym.Env[Any, Any]
    if env_id is not None:
        make_env = gymnasium.make if use_gymnasium else gym.make
        env = make_env(env_id, render_mode="rgb_array")
    else:
        env = _exec_to_create_env(env_header)

    # wrap environment so that every episode is recorded
    # NOTE(review): the RecordVideo wrapper imported from ``gym.wrappers`` is
    # used even when the environment comes from Gymnasium -- confirm this
    # combination is intended.
    wrapped_env = RecordVideo(
        env,
        out,
        episode_trigger=lambda ep: True,
    )

    # run episodes; always close the wrapper so resources are released
    try:
        if isinstance(algo, QLearningAlgoBase):
            evaluate_qlearning_with_environment(algo, wrapped_env, n_episodes)
        else:
            evaluate_transformer_with_environment(
                algo.as_stateful_wrapper(float(target_return)),
                wrapped_env,
                n_episodes,
            )
    finally:
        wrapped_env.close()


# ``play`` sub-command: runs evaluation episodes with a saved model in an
# environment rendered in "human" mode and prints the resulting score.
#
#   model_path    -- path of the saved model to load.
#   env_id        -- environment id passed to ``gym.make``/``gymnasium.make``.
#   env_header    -- alternative to env_id: Python code that defines ``env``.
#   n_episodes    -- number of episodes to run.
#   target_return -- required when the model is a TransformerAlgoBase.
#   use_gymnasium -- create the environment with Gymnasium instead of Gym.
#
# Raises ValueError when neither env_id nor env_header is given or the loaded
# model type is unsupported, and click.UsageError when a transformer model is
# loaded without --target-return.
#
# NOTE: documented with comments rather than a docstring because click would
# print a docstring as the command's help text.
@cli.command(short_help="Run evaluation episodes with rendering.")
@click.argument("model_path")
@click.option("--env-id", default=None, help="Gym environment id.")
@click.option(
    "--env-header", default=None, help="One-liner to create environment."
)
@click.option("--n-episodes", default=3, help="Number of episodes to run.")
@click.option(
    "--target-return",
    default=None,
    type=float,
    help="Target return for Decision Transformer variants.",
)
@click.option("--use-gymnasium", is_flag=True, help="Flag to use Gymnasium.")
def play(
    model_path: str,
    env_id: Optional[str],
    env_header: Optional[str],
    n_episodes: int,
    target_return: Optional[float],
    use_gymnasium: bool,
) -> None:
    # fail fast on missing environment arguments, before loading the model
    if env_id is None and env_header is None:
        raise ValueError("env_id or env_header must be provided.")

    # load saved model
    print(f"Loading {model_path}...")
    algo = load_learnable(model_path)

    # validate the model before an environment is created, so that nothing
    # has to be cleaned up on these error paths
    if isinstance(algo, TransformerAlgoBase):
        if target_return is None:
            raise click.UsageError("--target-return must be specified.")
    elif not isinstance(algo, QLearningAlgoBase):
        raise ValueError("invalid algo type.")

    # create environment
    env: gym.Env[Any, Any]
    if env_id is not None:
        make_env = gymnasium.make if use_gymnasium else gym.make
        env = make_env(env_id, render_mode="human")
    else:
        env = _exec_to_create_env(env_header)

    # run episodes; always close the environment so resources are released
    try:
        if isinstance(algo, QLearningAlgoBase):
            score = evaluate_qlearning_with_environment(algo, env, n_episodes)
        else:
            score = evaluate_transformer_with_environment(
                algo.as_stateful_wrapper(float(target_return)),
                env,
                n_episodes,
            )
    finally:
        env.close()
    print(f"Score: {score}")


def _install_module(
    name: List[str], upgrade: bool = False, check: bool = True
) -> None:
    """Install packages with pip in a subprocess.

    Args:
        name: Package specifiers passed to ``pip install``.
        upgrade: If True, ``-U`` is added so existing packages are upgraded.
        check: Forwarded to ``subprocess.run``; if True a non-zero exit
            status raises ``subprocess.CalledProcessError``.
    """
    import sys

    # Run pip through the current interpreter so packages land in the same
    # environment this program runs in; a bare ``pip3`` found on PATH may
    # belong to a different Python installation. Fall back to ``pip3`` only
    # when the interpreter path is unknown.
    pip = [sys.executable, "-m", "pip"] if sys.executable else ["pip3"]
    args = ["-U", *name] if upgrade else list(name)
    subprocess.run([*pip, "install", *args], check=check)


def _uninstall_module(name: List[str], check: bool = True) -> None:
    """Uninstall packages with pip in a subprocess, without prompting.

    Args:
        name: Package names passed to ``pip uninstall -y``.
        check: Forwarded to ``subprocess.run``; if True a non-zero exit
            status raises ``subprocess.CalledProcessError``.
    """
    import sys

    # Run pip through the current interpreter so the packages are removed
    # from the environment this program runs in; a bare ``pip3`` found on
    # PATH may belong to a different Python installation. Fall back to
    # ``pip3`` only when the interpreter path is unknown.
    pip = [sys.executable, "-m", "pip"] if sys.executable else ["pip3"]
    subprocess.run([*pip, "uninstall", "-y", *name], check=check)


# ``install`` sub-command: installs one of the supported optional package
# sets, selected by ``name`` ("atari", "d4rl_atari", "d4rl" or "minari").
#
# Raises ValueError for any other name.
#
# NOTE: documented with comments rather than a docstring because click would
# print a docstring as the command's help text.
@cli.command(short_help="Install additional packages.")
@click.argument("name")
def install(name: str) -> None:
    atari_packages = ["gym[atari,accept-rom-license]"]

    if name == "atari":
        _install_module(atari_packages, upgrade=True)
    elif name == "d4rl_atari":
        # d4rl-atari needs the Atari packages first. They are installed
        # directly here: after decoration ``install`` is a click Command, so
        # calling ``install("atari")`` would start a new CLI invocation that
        # parses the string as an argument list instead of running this
        # function.
        _install_module(atari_packages, upgrade=True)
        _install_module(["git+https://github.com/takuseno/d4rl-atari"])
    elif name == "d4rl":
        _install_module(["git+https://github.com/takuseno/D4RL"])
        _install_module(["gym"], upgrade=True)
        _uninstall_module(["pybullet"])
    elif name == "minari":
        _install_module(["minari==0.4.2", "gymnasium_robotics"], upgrade=True)
    else:
        raise ValueError(f"Unsupported command: {name}")