takuseno/d3rlpy

View on GitHub
tutorials/atari.ipynb

Summary

Maintainability
Test Coverage
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Install d3rlpy and d4rl-atari."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use the %pip magic so the packages are installed into the environment\n",
    "# of the running kernel rather than whichever pip is first on the shell PATH.\n",
    "%pip install git+https://github.com/takuseno/d4rl-atari d3rlpy"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Set up the dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import d3rlpy\n",
    "\n",
    "# Load the 1% Breakout dataset and its environment.\n",
    "atari_options = dict(fraction=0.01, index=0, num_stack=4)\n",
    "dataset, env = d3rlpy.datasets.get_atari_transitions(\n",
    "    \"breakout\", **atari_options\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Set up data-driven deep reinforcement learning."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the components of the discrete CQL configuration one at a time.\n",
    "optim_factory = d3rlpy.models.optimizers.AdamFactory(eps=1e-2 / 32)\n",
    "q_func_factory = d3rlpy.models.q_functions.QRQFunctionFactory(n_quantiles=200)\n",
    "observation_scaler = d3rlpy.preprocessing.PixelObservationScaler()\n",
    "reward_scaler = d3rlpy.preprocessing.ClipRewardScaler(-1.0, 1.0)\n",
    "\n",
    "cql_config = d3rlpy.algos.DiscreteCQLConfig(\n",
    "    learning_rate=5e-5,\n",
    "    optim_factory=optim_factory,\n",
    "    batch_size=32,\n",
    "    alpha=4.0,\n",
    "    q_func_factory=q_func_factory,\n",
    "    observation_scaler=observation_scaler,\n",
    "    target_update_interval=2000,\n",
    "    reward_scaler=reward_scaler,\n",
    ")\n",
    "# Instantiate the algorithm on the first CUDA device.\n",
    "cql = cql_config.create(device=\"cuda:0\")\n",
    "\n",
    "# Evaluator that scores the policy in the environment loaded with the dataset.\n",
    "env_scorer = d3rlpy.metrics.EnvironmentEvaluator(env, epsilon=0.001)\n",
    "\n",
    "# Train for 1,000,000 steps, with 10,000 steps per epoch.\n",
    "fit_options = dict(\n",
    "    n_steps=1000000,\n",
    "    n_steps_per_epoch=10000,\n",
    "    evaluators={\"environment\": env_scorer},\n",
    ")\n",
    "cql.fit(dataset, **fit_options)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}