{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Setup rendering dependencies for Google Colaboratory."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!apt-get install -y xvfb ffmpeg > /dev/null 2>&1\n",
    "!pip install pyvirtualdisplay pygame moviepy > /dev/null 2>&1"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Install d3rlpy!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install d3rlpy"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Setup cartpole environment."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import gym\n",
    "\n",
    "env = gym.make('CartPole-v1')\n",
    "eval_env = gym.make('CartPole-v1')"
   ]
  },
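  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick sanity check (not part of the original tutorial), inspect the observation and action spaces: CartPole-v1 exposes a 4-dimensional continuous observation and 2 discrete actions, which is what the DQN configured below consumes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# sanity check of the spaces DQN will see\n",
    "print(env.observation_space)  # Box with shape (4,)\n",
    "print(env.action_space)       # Discrete(2)"
   ]
  },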
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Setup data-driven deep reinforcement learning algorithm."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import d3rlpy\n",
    "\n",
    "# setup DQN algorithm\n",
    "dqn = d3rlpy.algos.DQNConfig(\n",
    "    learning_rate=1e-3,\n",
    "    target_update_interval=100,\n",
    ").create(device='cuda:0')\n",
    "\n",
    "# setup explorer\n",
    "explorer = d3rlpy.algos.ConstantEpsilonGreedy(epsilon=0.3)\n",
    "\n",
    "# setup replay buffer\n",
    "buffer = d3rlpy.dataset.create_fifo_replay_buffer(limit=50000, env=env)\n",
    "\n",
    "# start training\n",
    "dqn.fit_online(\n",
    "    env,\n",
    "    buffer,\n",
    "    explorer,\n",
    "    eval_env=eval_env,\n",
    "    n_steps=50000,\n",
    "    n_steps_per_epoch=10000,\n",
    ")"
   ]
  },
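  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Optionally, persist the trained agent. This is a minimal sketch that is not part of the original tutorial: the method names follow d3rlpy's documented v2 API (`save`, `save_policy`, `d3rlpy.load_learnable`), and the file names are arbitrary placeholders, so verify them against the version you installed."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# optional (not in the original tutorial): save the full algorithm so it can\n",
    "# be reloaded later with d3rlpy.load_learnable('dqn_cartpole.d3')\n",
    "dqn.save('dqn_cartpole.d3')\n",
    "\n",
    "# export a standalone greedy policy (TorchScript for a .pt file name)\n",
    "dqn.save_policy('dqn_cartpole_policy.pt')"
   ]
  },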
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Record video!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from gym.wrappers import RecordVideo\n",
    "\n",
    "# start virtual display\n",
    "d3rlpy.notebook_utils.start_virtual_display()\n",
    "\n",
    "# wrap Monitor wrapper\n",
    "env = RecordVideo(gym.make(\"CartPole-v1\", render_mode=\"rgb_array\"), './video')\n",
    "\n",
    "# evaluate\n",
    "d3rlpy.metrics.evaluate_qlearning_with_environment(dqn, env)"
   ]
  },
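  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a small aid (not part of the original tutorial), list what RecordVideo actually wrote to ./video; the next cell assumes the default file name 'rl-video-episode-0.mp4'."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import glob\n",
    "\n",
    "# RecordVideo names files 'rl-video-episode-<n>.mp4' by default\n",
    "print(sorted(glob.glob('./video/*.mp4')))"
   ]
  },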
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's see how it works!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "d3rlpy.notebook_utils.render_video(\"video/rl-video-episode-0.mp4\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7"
  },
  "colab": {
    "provenance": [],
    "gpuType": "T4"
  },
  "accelerator": "GPU"
 },
 "nbformat": 4,
 "nbformat_minor": 4
}