takuseno/d3rlpy

View on GitHub
docs/tutorials/customize_neural_network.rst

Summary

Maintainability
Test Coverage
************************
Customize Neural Network
************************

In this tutorial, you can learn how to integrate your own neural network models
to d3rlpy.
Please check :doc:`../references/network_architectures` for more information.

Prepare PyTorch Model
---------------------

If you're familiar with PyTorch, this step should be easy for you.

.. code-block:: python

  import torch
  import torch.nn as nn
  import d3rlpy

  class CustomEncoder(nn.Module):
      def __init__(self, observation_shape, feature_size):
          super().__init__()
          self.feature_size = feature_size
          self.fc1 = nn.Linear(observation_shape[0], feature_size)
          self.fc2 = nn.Linear(feature_size, feature_size)

      def forward(self, x):
          h = torch.relu(self.fc1(x))
          h = torch.relu(self.fc2(h))
          return h


Setup EncoderFactory
--------------------

Once you setup your PyTorch model, you need to setup ``EncoderFactory`` as a
dataclass class. In your ``EncoderFactory`` class, you need to define
``create`` and ``get_type``.
``get_type`` method is used to serialize your customized neural network
configuration.

.. code-block:: python

  import dataclasses

  @dataclasses.dataclass()
  class CustomEncoderFactory(d3rlpy.models.EncoderFactory):
      feature_size: int

      def create(self, observation_shape):
          return CustomEncoder(observation_shape, self.feature_size)

      @staticmethod
      def get_type() -> str:
          return "custom"


Now, you can use your model with d3rlpy.

.. code-block:: python

  # integrate your model into d3rlpy algorithm
  dqn = d3rlpy.algos.DQNConfig(encoder_factory=CustomEncoderFactory(64)).create()


Support Q-function for Actor-Critic
-----------------------------------

In the above example, your original model is designed for the network that
takes an observation as an input.
However, if you customize a Q-function of actor-critic algorithm (e.g. SAC),
you need to prepare an action-conditioned model.

.. code-block:: python

  class CustomEncoderWithAction(nn.Module):
      def __init__(self, observation_shape, action_size, feature_size):
          super().__init__()
          self.feature_size = feature_size
          self.fc1 = nn.Linear(observation_shape[0] + action_size, feature_size)
          self.fc2 = nn.Linear(feature_size, feature_size)

        def forward(self, x, action):
            h = torch.cat([x, action], dim=1)
            h = torch.relu(self.fc1(h))
            h = torch.relu(self.fc2(h))
            return h

Finally, you can update your ``CustomEncoderFactory`` as follows.

.. code-block:: python

  @dataclasses.dataclass()
  class CustomEncoderFactory(d3rlpy.models.EncoderFactory):
      feature_size: int

      def create(self, observation_shape):
          return CustomEncoder(observation_shape, self.feature_size)

      def create_with_action(self, observation_shape, action_size, discrete_action):
          return CustomEncoderWithAction(observation_shape, action_size, self.feature_size)

      @staticmethod
      def get_type() -> str:
          return "custom"

Now, you can customize actor-critic algorithms.

.. code-block:: python

  encoder_factory = CustomEncoderFactory(64)

  sac = d3rlpy.algos.SACConfig(
      actor_encoder_factory=encoder_factory,
      critic_encoder_factory=encoder_factory,
  ).create()


Make your models loadable
-------------------------

If you want ``load_learnable`` method to load the algorithm configuration including
your encoder configuration, you need to register your encoder factory.

.. code-block:: python

   from d3rlpy.models.encoders import register_encoder_factory

   # register your own encoder factory
   register_encoder_factory(CustomEncoderFactory)

   # load algorithm from d3
   dqn = d3rlpy.load_learnable("model.d3")