SforAiDl/genrl

View on GitHub
docs/source/usage/tutorials/Deep/DQN.rst

Summary

Maintainability
Test Coverage
=====================
Deep Q-Networks (DQN)
=====================

For background on Deep RL, its core definitions and problem formulations refer to :ref:`Deep RL Background<Background>`

Objective
=========

The DQN uses the concept of Q-learning. When the state space is too huge, it require a large number of epochs to explore and update the Q-value of every state even at least once.
Hence, we make use of function approximators. DQN uses a neural network as a function approximator and objective is to get as close to the Bellman Expectation of the Q-value function as possible. This is done by minimising the loss function which is defined as

.. math::

    E_{(s, a, s', r) \sim D}[r + \gamma max_{a'} Q(s', a';\theta_{i}^{-}) - Q(s, a; \theta_i)]^2

Unlike in regular Q-learning, DQNs need more stability while updating so we often use a second neural network which we call our target model.

Algorithm Details
=================

Epsilon-Greedy Action Selection
-------------------------------

.. literalinclude:: ../../../../../genrl/agents/deep/dqn/dqn.py
    :lines: 97-129
    :lineno-start: 97

We choose the greedy action with a probability of :math:`1 - \epsilon` and the rest of the time, we sample the action randomly. During evaluation, we use only greedy actions to judge how well the agent performs.

Experience Replay
-----------------

Whenever an experience is played through (during the training loop), the experience is stored in what we call a Replay Buffer.

.. literalinclude:: ../../../../../genrl/trainers/offpolicy.py
    :lines: 91-104
    :lineno-start: 91

The transitions are later sampled in batches from the replay buffer for updating the network.

Update Q-value Network
----------------------

Once our Replay Buffer has enough experiences, we start updating the Q-value networks in the following code according to the above objective.

.. literalinclude:: ../../../../../genrl/trainers/offpolicy.py
    :lines: 145-203
    :lineno-start: 145

The function `get_q_values` calculates the Q-values of the states sampled from the replay buffer. The `get_target_q_values` function will get the Q-values of the same states as calculated by the target network.
The `update_params` function is used to calculate the MSE Loss between the Q-values and the Target Q-values and updated using Stochastic Gradient Descent.

Training through the API
========================

.. code-block:: python

    from genrl.agents import DQN
    from genrl.environments import VectorEnv
    from genrl.trainers import OffPolicyTrainer

    env = VectorEnv("CartPole-v0")
    agent = DQN("mlp", env)
    trainer = OffPolicyTrainer(agent, env, max_timesteps=20000)
    trainer.train()
    trainer.evaluate()

Variants of DQN
===============

Some of the other variants of DQN that we have implemented in the repo are:
- Double DQN 
- Dueling DQN
- Prioritized Replay DQN
- Noisy DQN
- Categorical DQN

For some extensions of the DQN (like DoubleDQN) we have provided the methods in a file under genrl/agents/dqn/utils.py

.. code-block:: python

    class DuelingDQN(DQN):
        def __init__(self, *args, **kwargs):
            super(DuelingDQN, self).__init__(*args, **kwargs)
            self.dqn_type = "dueling"  # You can choose "noisy" for NoisyDQN and "categorical" for CategoricalDQN
            self._create_model()

        def get_target_q_values(self, *args):
            return ddqn_q_target(self, *args)

The above two snippets define the same class. You can find similar APIs for the other variants in the `genrl/deep/agents/dqn` folder.