# takuseno/d3rlpy: d3rlpy/models/optimizers.py
import dataclasses
from typing import Iterable, Sequence, Tuple

from torch import nn
from torch.optim import SGD, Adam, AdamW, Optimizer, RMSprop

from ..serializable_config import DynamicConfig, generate_config_registration

__all__ = [
    "OptimizerFactory",
    "SGDFactory",
    "AdamFactory",
    "AdamWFactory",
    "RMSpropFactory",
    "GPTAdamWFactory",
    "register_optimizer_factory",
    "make_optimizer_field",
]


def _get_parameters_from_named_modules(
    named_modules: Iterable[Tuple[str, nn.Module]]
) -> Sequence[nn.Parameter]:
    # retrieve unique set of parameters
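    # A dict keyed by the parameter objects is used instead of a set so that
    # the original iteration order of the parameters is preserved (sets do
    # not guarantee ordering).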
    params_dict = {}
    for _, module in named_modules:
        for param in module.parameters():
            if param not in params_dict:
                params_dict[param] = param
    return list(params_dict.values())


@dataclasses.dataclass()
class OptimizerFactory(DynamicConfig):
    """A factory class that creates an optimizer object in a lazy way.

    The optimizers in algorithms can be configured through this factory class.
    """

    def create(
        self, named_modules: Iterable[Tuple[str, nn.Module]], lr: float
    ) -> Optimizer:
        """Returns an optimizer object.

        Args:
            named_modules: Iterable of tuples of module names and modules.
            lr (float): Learning rate.

        Returns:
            torch.optim.Optimizer: an optimizer object.
        """
        raise NotImplementedError
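
# Illustrative usage of a concrete subclass such as ``AdamFactory`` below.
# The module and learning rate are placeholders for this example; any
# ``nn.Module`` whose ``named_modules()`` iterator is passed in works:
#
#     import torch.nn as nn
#
#     encoder = nn.Linear(4, 256)
#     factory = AdamFactory(weight_decay=1e-4)
#     optim = factory.create(encoder.named_modules(), lr=3e-4)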


@dataclasses.dataclass()
class SGDFactory(OptimizerFactory):
    """An alias for SGD optimizer.

    .. code-block:: python

        from d3rlpy.optimizers import SGDFactory

        factory = SGDFactory(weight_decay=1e-4)

    Args:
        momentum: momentum factor.
        dampening: dampening for momentum.
        weight_decay: weight decay (L2 penalty).
        nesterov: flag to enable Nesterov momentum.
    """

    momentum: float = 0.0
    dampening: float = 0.0
    weight_decay: float = 0.0
    nesterov: bool = False

    def create(
        self, named_modules: Iterable[Tuple[str, nn.Module]], lr: float
    ) -> SGD:
        return SGD(
            _get_parameters_from_named_modules(named_modules),
            lr=lr,
            momentum=self.momentum,
            dampening=self.dampening,
            weight_decay=self.weight_decay,
            nesterov=self.nesterov,
        )

    @staticmethod
    def get_type() -> str:
        return "sgd"


@dataclasses.dataclass()
class AdamFactory(OptimizerFactory):
    """An alias for Adam optimizer.

    .. code-block:: python

        from d3rlpy.optimizers import AdamFactory

        factory = AdamFactory(weight_decay=1e-4)

    Args:
        betas: coefficients used for computing running averages of
            gradient and its square.
        eps: term added to the denominator to improve numerical stability.
        weight_decay: weight decay (L2 penalty).
        amsgrad: flag to use the AMSGrad variant of this algorithm.
    """

    betas: Tuple[float, float] = (0.9, 0.999)
    eps: float = 1e-8
    weight_decay: float = 0.0
    amsgrad: bool = False

    def create(
        self, named_modules: Iterable[Tuple[str, nn.Module]], lr: float
    ) -> Adam:
        return Adam(
            _get_parameters_from_named_modules(named_modules),
            lr=lr,
            betas=self.betas,
            eps=self.eps,
            weight_decay=self.weight_decay,
            amsgrad=self.amsgrad,
        )

    @staticmethod
    def get_type() -> str:
        return "adam"


@dataclasses.dataclass()
class AdamWFactory(OptimizerFactory):
    """An alias for AdamW optimizer.

    .. code-block:: python

        from d3rlpy.optimizers import AdamWFactory

        factory = AdamWFactory(weight_decay=1e-4)

    Args:
        betas: coefficients used for computing running averages of
            gradient and its square.
        eps: term added to the denominator to improve numerical stability.
        weight_decay: weight decay (L2 penalty).
        amsgrad: flag to use the AMSGrad variant of this algorithm.
    """

    betas: Tuple[float, float] = (0.9, 0.999)
    eps: float = 1e-8
    weight_decay: float = 0.0
    amsgrad: bool = False

    def create(
        self, named_modules: Iterable[Tuple[str, nn.Module]], lr: float
    ) -> AdamW:
        return AdamW(
            _get_parameters_from_named_modules(named_modules),
            lr=lr,
            betas=self.betas,
            eps=self.eps,
            weight_decay=self.weight_decay,
            amsgrad=self.amsgrad,
        )

    @staticmethod
    def get_type() -> str:
        return "adam_w"


@dataclasses.dataclass()
class RMSpropFactory(OptimizerFactory):
    """An alias for RMSprop optimizer.

    .. code-block:: python

        from d3rlpy.optimizers import RMSpropFactory

        factory = RMSpropFactory(weight_decay=1e-4)

    Args:
        alpha: smoothing constant.
        eps: term added to the denominator to improve numerical stability.
        weight_decay: weight decay (L2 penalty).
        momentum: momentum factor.
        centered: flag to compute the centered RMSProp, in which the
            gradient is normalized by an estimate of its variance.
    """

    alpha: float = 0.95
    eps: float = 1e-2
    weight_decay: float = 0.0
    momentum: float = 0.0
    centered: bool = True

    def create(
        self, named_modules: Iterable[Tuple[str, nn.Module]], lr: float
    ) -> RMSprop:
        return RMSprop(
            _get_parameters_from_named_modules(named_modules),
            lr=lr,
            alpha=self.alpha,
            eps=self.eps,
            weight_decay=self.weight_decay,
            momentum=self.momentum,
            centered=self.centered,
        )

    @staticmethod
    def get_type() -> str:
        return "rmsprop"


@dataclasses.dataclass()
class GPTAdamWFactory(OptimizerFactory):
    """AdamW optimizer for Decision Transformer architectures.

    .. code-block:: python

        from d3rlpy.optimizers import GPTAdamWFactory

        factory = GPTAdamWFactory(weight_decay=1e-4)

    Args:
        betas: coefficients used for computing running averages of
            gradient and its square.
        eps: term added to the denominator to improve numerical stability.
        weight_decay: weight decay (L2 penalty).
        amsgrad: flag to use the AMSGrad variant of this algorithm.
    """

    betas: Tuple[float, float] = (0.9, 0.999)
    eps: float = 1e-8
    weight_decay: float = 0.0
    amsgrad: bool = False

    def create(
        self, named_modules: Iterable[Tuple[str, nn.Module]], lr: float
    ) -> AdamW:
        named_modules = list(named_modules)
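        # Split parameters into two groups in the GPT/minGPT style:
        # weights of Linear and Conv2d modules receive weight decay, while
        # biases and the weights of LayerNorm and Embedding modules do not.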
        params_dict = {}
        decay = set()
        no_decay = set()
        for module_name, module in named_modules:
            for param_name, param in module.named_parameters():
                full_name = (
                    f"{module_name}.{param_name}" if module_name else param_name
                )

                if full_name not in params_dict:
                    params_dict[full_name] = param

                if param_name.endswith("bias"):
                    no_decay.add(full_name)
                elif param_name.endswith("weight") and isinstance(
                    module, (nn.Linear, nn.Conv2d)
                ):
                    decay.add(full_name)
                elif param_name.endswith("weight") and isinstance(
                    module, (nn.LayerNorm, nn.Embedding)
                ):
                    no_decay.add(full_name)

        # assign any parameters not matched above to the no_decay group
        all_names = set(params_dict.keys())
        remaining = all_names.difference(decay | no_decay)
        no_decay.update(remaining)
        assert len(decay | no_decay) == len(
            _get_parameters_from_named_modules(named_modules)
        )
        assert len(decay & no_decay) == 0

        optim_groups = [
            {
                "params": [params_dict[name] for name in sorted(list(decay))],
                "weight_decay": self.weight_decay,
            },
            {
                "params": [
                    params_dict[name] for name in sorted(list(no_decay))
                ],
                "weight_decay": 0.0,
            },
        ]

        return AdamW(
            optim_groups,
            lr=lr,
            betas=self.betas,
            eps=self.eps,
            amsgrad=self.amsgrad,
        )

    @staticmethod
    def get_type() -> str:
        return "gpt_adam_w"


register_optimizer_factory, make_optimizer_field = generate_config_registration(
    OptimizerFactory, lambda: AdamFactory()
)


register_optimizer_factory(SGDFactory)
register_optimizer_factory(AdamFactory)
register_optimizer_factory(AdamWFactory)
register_optimizer_factory(RMSpropFactory)
register_optimizer_factory(GPTAdamWFactory)
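
# Typical downstream usage (a sketch; ``MyAlgoConfig`` and its base class are
# hypothetical, not part of this file): algorithm configs declare an optimizer
# field via ``make_optimizer_field()``, which, given the registration above,
# presumably defaults to ``AdamFactory()``, and registered factories can then
# be resolved by the names returned from ``get_type()``.
#
#     @dataclasses.dataclass()
#     class MyAlgoConfig(SomeConfigBase):
#         optim_factory: OptimizerFactory = make_optimizer_field()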