takuseno/d3rlpy

View on GitHub
tutorials/cartpole.ipynb

Summary

Maintainability
Test Coverage
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Set up rendering dependencies for Google Colaboratory."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Install the system and Python packages used for off-screen rendering and video recording.\n",
    "# Only stdout is silenced: stderr stays visible so a failed install is not hidden.\n",
    "!apt-get install -y xvfb ffmpeg > /dev/null\n",
    "# %pip (rather than !pip) installs into the environment of the running kernel.\n",
    "%pip install -q pyvirtualdisplay pygame moviepy"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Install d3rlpy!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# This tutorial uses the d3rlpy v2 API (Config objects + .create()), so require >=2.0.\n",
    "# %pip (rather than !pip) installs into the environment of the running kernel.\n",
    "%pip install \"d3rlpy>=2.0\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Set up the CartPole dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import d3rlpy\n",
    "\n",
    "# fix the random seed so that repeated runs of this notebook are comparable\n",
    "d3rlpy.seed(0)\n",
    "\n",
    "# get CartPole dataset (returns the offline dataset and the matching environment)\n",
    "dataset, env = d3rlpy.datasets.get_cartpole()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Set up the data-driven deep reinforcement learning algorithm."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "# use the GPU when one is available; otherwise fall back to the CPU\n",
    "# so the notebook still runs on a runtime without an accelerator\n",
    "device = 'cuda:0' if torch.cuda.is_available() else 'cpu:0'\n",
    "\n",
    "# set up CQL algorithm\n",
    "cql = d3rlpy.algos.DiscreteCQLConfig().create(device=device)\n",
    "\n",
    "# start training\n",
    "cql.fit(\n",
    "    dataset,\n",
    "    n_steps=10000,\n",
    "    n_steps_per_epoch=1000,\n",
    "    evaluators={\n",
    "        'environment': d3rlpy.metrics.EnvironmentEvaluator(env), # evaluate with CartPole-v1 environment\n",
    "    },\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Record video!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import gym\n",
    "from gym.wrappers import RecordVideo\n",
    "\n",
    "# start virtual display\n",
    "d3rlpy.notebook_utils.start_virtual_display()\n",
    "\n",
    "# wrap the environment with the RecordVideo wrapper\n",
    "env = RecordVideo(gym.make(\"CartPole-v1\", render_mode=\"rgb_array\"), './video')\n",
    "\n",
    "# evaluate\n",
    "d3rlpy.metrics.evaluate_qlearning_with_environment(cql, env)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's see how it works!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "d3rlpy.notebook_utils.render_video(\"video/rl-video-episode-0.mp4\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7"
  },
  "colab": {
    "provenance": [],
    "gpuType": "T4"
  },
  "accelerator": "GPU"
 },
 "nbformat": 4,
 "nbformat_minor": 4
}