takuseno/d3rlpy

View on GitHub
d3rlpy/models/builders.py

Summary

Maintainability
A
50 mins
Test Coverage
from typing import Sequence, Tuple, cast

import torch
from torch import nn

from ..constants import PositionEncodingType
from ..types import Shape
from .encoders import EncoderFactory
from .q_functions import QFunctionFactory
from .torch import (
    CategoricalPolicy,
    ContinuousDecisionTransformer,
    ContinuousEnsembleQFunctionForwarder,
    DeterministicPolicy,
    DeterministicResidualPolicy,
    DiscreteDecisionTransformer,
    DiscreteEnsembleQFunctionForwarder,
    GlobalPositionEncoding,
    NormalPolicy,
    Parameter,
    PositionEncoding,
    SimplePositionEncoding,
    VAEDecoder,
    VAEEncoder,
    ValueFunction,
    compute_output_size,
)
from .utility import create_activation

# Public builder functions exported by this module. The private helper
# ``_create_position_encoding`` is intentionally not listed.
__all__ = [
    "create_discrete_q_function",
    "create_continuous_q_function",
    "create_deterministic_policy",
    "create_deterministic_residual_policy",
    "create_categorical_policy",
    "create_normal_policy",
    "create_vae_encoder",
    "create_vae_decoder",
    "create_value_function",
    "create_parameter",
    "create_continuous_decision_transformer",
    "create_discrete_decision_transformer",
]


def create_discrete_q_function(
    observation_shape: Shape,
    action_size: int,
    encoder_factory: EncoderFactory,
    q_func_factory: QFunctionFactory,
    device: str,
    n_ensembles: int = 1,
) -> Tuple[nn.ModuleList, DiscreteEnsembleQFunctionForwarder]:
    """Builds an ensemble of discrete-action Q-functions.

    When ``q_func_factory.share_encoder`` is true, a single encoder is
    created up front and handed to every ensemble member; otherwise each
    member receives its own freshly created encoder.

    Args:
        observation_shape: Observation shape passed to the encoder factory.
        action_size: Action size passed to each Q-function and to the
            ensemble forwarder.
        encoder_factory: Factory used to create the encoder(s).
        q_func_factory: Factory used to create each Q-function/forwarder
            pair.
        device: Device the Q-function modules are moved to.
        n_ensembles: Number of ensemble members to create.

    Returns:
        A tuple of the ``nn.ModuleList`` holding the Q-function modules and
        the ensemble forwarder wrapping the per-member forwarders.
    """
    share_encoder = q_func_factory.share_encoder

    if share_encoder:
        encoder = encoder_factory.create(observation_shape)
        hidden_size = compute_output_size([observation_shape], encoder)

        # The shared encoder is handed to every ensemble member, so its
        # gradients are normalized by the ensemble size.
        def _normalize_grad(grad: torch.Tensor) -> torch.Tensor:
            return grad / n_ensembles

        for param in cast(nn.Module, encoder).parameters():
            param.register_hook(_normalize_grad)

    # (q_function, forwarder) pairs, one per ensemble member
    members = []
    for _ in range(n_ensembles):
        if not share_encoder:
            encoder = encoder_factory.create(observation_shape)
            hidden_size = compute_output_size([observation_shape], encoder)
        members.append(
            q_func_factory.create_discrete(encoder, hidden_size, action_size)
        )

    modules = nn.ModuleList([q_func for q_func, _ in members])
    modules.to(device)
    ensemble_forwarder = DiscreteEnsembleQFunctionForwarder(
        [forwarder for _, forwarder in members], action_size
    )
    return modules, ensemble_forwarder


def create_continuous_q_function(
    observation_shape: Shape,
    action_size: int,
    encoder_factory: EncoderFactory,
    q_func_factory: QFunctionFactory,
    device: str,
    n_ensembles: int = 1,
) -> Tuple[nn.ModuleList, ContinuousEnsembleQFunctionForwarder]:
    """Builds an ensemble of continuous-action Q-functions.

    Encoders are created with ``encoder_factory.create_with_action``. When
    ``q_func_factory.share_encoder`` is true, a single encoder is created up
    front and handed to every ensemble member; otherwise each member
    receives its own freshly created encoder.

    Args:
        observation_shape: Observation shape passed to the encoder factory.
        action_size: Action size passed to the encoder factory and to the
            ensemble forwarder.
        encoder_factory: Factory used to create the encoder(s).
        q_func_factory: Factory used to create each Q-function/forwarder
            pair.
        device: Device the Q-function modules are moved to.
        n_ensembles: Number of ensemble members to create.

    Returns:
        A tuple of the ``nn.ModuleList`` holding the Q-function modules and
        the ensemble forwarder wrapping the per-member forwarders.
    """
    share_encoder = q_func_factory.share_encoder
    # input shapes used to compute the encoder output size
    input_shapes = [observation_shape, (action_size,)]

    if share_encoder:
        encoder = encoder_factory.create_with_action(
            observation_shape, action_size
        )
        hidden_size = compute_output_size(input_shapes, encoder)

        # The shared encoder is handed to every ensemble member, so its
        # gradients are normalized by the ensemble size.
        def _normalize_grad(grad: torch.Tensor) -> torch.Tensor:
            return grad / n_ensembles

        for param in cast(nn.Module, encoder).parameters():
            param.register_hook(_normalize_grad)

    # (q_function, forwarder) pairs, one per ensemble member
    members = []
    for _ in range(n_ensembles):
        if not share_encoder:
            encoder = encoder_factory.create_with_action(
                observation_shape, action_size
            )
            hidden_size = compute_output_size(
                [observation_shape, (action_size,)], encoder
            )
        members.append(q_func_factory.create_continuous(encoder, hidden_size))

    modules = nn.ModuleList([q_func for q_func, _ in members])
    modules.to(device)
    ensemble_forwarder = ContinuousEnsembleQFunctionForwarder(
        [forwarder for _, forwarder in members], action_size
    )
    return modules, ensemble_forwarder


def create_deterministic_policy(
    observation_shape: Shape,
    action_size: int,
    encoder_factory: EncoderFactory,
    device: str,
) -> DeterministicPolicy:
    """Builds a ``DeterministicPolicy`` and moves it to ``device``.

    Args:
        observation_shape: Observation shape passed to the encoder factory.
        action_size: Action size passed to the policy.
        encoder_factory: Factory used to create the observation encoder.
        device: Device the policy is moved to.

    Returns:
        The constructed policy.
    """
    obs_encoder = encoder_factory.create(observation_shape)
    feature_size = compute_output_size([observation_shape], obs_encoder)
    deterministic_policy = DeterministicPolicy(
        encoder=obs_encoder,
        hidden_size=feature_size,
        action_size=action_size,
    )
    deterministic_policy.to(device)
    return deterministic_policy


def create_deterministic_residual_policy(
    observation_shape: Shape,
    action_size: int,
    scale: float,
    encoder_factory: EncoderFactory,
    device: str,
) -> DeterministicResidualPolicy:
    """Builds a ``DeterministicResidualPolicy`` and moves it to ``device``.

    The encoder is created with ``create_with_action``, i.e. it is built
    from both the observation shape and the action size.

    Args:
        observation_shape: Observation shape passed to the encoder factory.
        action_size: Action size passed to the encoder factory and policy.
        scale: Scale value forwarded to the policy.
        encoder_factory: Factory used to create the encoder.
        device: Device the policy is moved to.

    Returns:
        The constructed policy.
    """
    input_shapes = [observation_shape, (action_size,)]
    obs_action_encoder = encoder_factory.create_with_action(
        observation_shape, action_size
    )
    feature_size = compute_output_size(input_shapes, obs_action_encoder)
    residual_policy = DeterministicResidualPolicy(
        encoder=obs_action_encoder,
        hidden_size=feature_size,
        action_size=action_size,
        scale=scale,
    )
    residual_policy.to(device)
    return residual_policy


def create_normal_policy(
    observation_shape: Shape,
    action_size: int,
    encoder_factory: EncoderFactory,
    device: str,
    min_logstd: float = -20.0,
    max_logstd: float = 2.0,
    use_std_parameter: bool = False,
) -> NormalPolicy:
    """Builds a ``NormalPolicy`` and moves it to ``device``.

    Args:
        observation_shape: Observation shape passed to the encoder factory.
        action_size: Action size passed to the policy.
        encoder_factory: Factory used to create the observation encoder.
        device: Device the policy is moved to.
        min_logstd: Forwarded to ``NormalPolicy`` as ``min_logstd``.
        max_logstd: Forwarded to ``NormalPolicy`` as ``max_logstd``.
        use_std_parameter: Forwarded to ``NormalPolicy`` as
            ``use_std_parameter``.

    Returns:
        The constructed policy.
    """
    obs_encoder = encoder_factory.create(observation_shape)
    feature_size = compute_output_size([observation_shape], obs_encoder)
    normal_policy = NormalPolicy(
        encoder=obs_encoder,
        hidden_size=feature_size,
        action_size=action_size,
        min_logstd=min_logstd,
        max_logstd=max_logstd,
        use_std_parameter=use_std_parameter,
    )
    normal_policy.to(device)
    return normal_policy


def create_categorical_policy(
    observation_shape: Shape,
    action_size: int,
    encoder_factory: EncoderFactory,
    device: str,
) -> CategoricalPolicy:
    """Builds a ``CategoricalPolicy`` and moves it to ``device``.

    Args:
        observation_shape: Observation shape passed to the encoder factory.
        action_size: Action size passed to the policy.
        encoder_factory: Factory used to create the observation encoder.
        device: Device the policy is moved to.

    Returns:
        The constructed policy.
    """
    obs_encoder = encoder_factory.create(observation_shape)
    feature_size = compute_output_size([observation_shape], obs_encoder)
    categorical_policy = CategoricalPolicy(
        encoder=obs_encoder,
        hidden_size=feature_size,
        action_size=action_size,
    )
    categorical_policy.to(device)
    return categorical_policy


def create_vae_encoder(
    observation_shape: Shape,
    action_size: int,
    latent_size: int,
    encoder_factory: EncoderFactory,
    device: str,
    min_logstd: float = -20.0,
    max_logstd: float = 2.0,
) -> VAEEncoder:
    """Builds a ``VAEEncoder`` and moves it to ``device``.

    The underlying encoder is created with ``create_with_action`` from the
    observation shape and the action size.

    Args:
        observation_shape: Observation shape passed to the encoder factory.
        action_size: Action size passed to the encoder factory.
        latent_size: Latent size forwarded to ``VAEEncoder``.
        encoder_factory: Factory used to create the underlying encoder.
        device: Device the VAE encoder is moved to.
        min_logstd: Forwarded to ``VAEEncoder`` as ``min_logstd``.
        max_logstd: Forwarded to ``VAEEncoder`` as ``max_logstd``.

    Returns:
        The constructed VAE encoder.
    """
    input_shapes = [observation_shape, (action_size,)]
    obs_action_encoder = encoder_factory.create_with_action(
        observation_shape, action_size
    )
    feature_size = compute_output_size(input_shapes, obs_action_encoder)
    module = VAEEncoder(
        encoder=obs_action_encoder,
        hidden_size=feature_size,
        latent_size=latent_size,
        min_logstd=min_logstd,
        max_logstd=max_logstd,
    )
    module.to(device)
    return module


def create_vae_decoder(
    observation_shape: Shape,
    action_size: int,
    latent_size: int,
    encoder_factory: EncoderFactory,
    device: str,
) -> VAEDecoder:
    """Builds a ``VAEDecoder`` and moves it to ``device``.

    The underlying encoder is created with ``create_with_action``, with
    ``latent_size`` supplied in place of an action size.

    Args:
        observation_shape: Observation shape passed to the encoder factory.
        action_size: Action size forwarded to ``VAEDecoder``.
        latent_size: Latent size passed to the encoder factory.
        encoder_factory: Factory used to create the underlying encoder.
        device: Device the VAE decoder is moved to.

    Returns:
        The constructed VAE decoder.
    """
    input_shapes = [observation_shape, (latent_size,)]
    obs_latent_encoder = encoder_factory.create_with_action(
        observation_shape, latent_size
    )
    feature_size = compute_output_size(input_shapes, obs_latent_encoder)
    module = VAEDecoder(
        encoder=obs_latent_encoder,
        hidden_size=feature_size,
        action_size=action_size,
    )
    module.to(device)
    return module


def create_value_function(
    observation_shape: Shape, encoder_factory: EncoderFactory, device: str
) -> ValueFunction:
    """Builds a ``ValueFunction`` and moves it to ``device``.

    Args:
        observation_shape: Observation shape passed to the encoder factory.
        encoder_factory: Factory used to create the observation encoder.
        device: Device the value function is moved to.

    Returns:
        The constructed value function.
    """
    obs_encoder = encoder_factory.create(observation_shape)
    feature_size = compute_output_size([observation_shape], obs_encoder)
    module = ValueFunction(obs_encoder, feature_size)
    module.to(device)
    return module


def create_parameter(
    shape: Sequence[int], initial_value: float, device: str
) -> Parameter:
    """Builds a ``Parameter`` filled with a constant value.

    Args:
        shape: Shape of the float32 tensor wrapped by the parameter.
        initial_value: Value every element is initialized to.
        device: Device passed to ``Parameter.to``.

    Returns:
        The constructed parameter.
    """
    initial_tensor = torch.full(shape, initial_value, dtype=torch.float32)
    module = Parameter(initial_tensor)
    # NOTE(review): the result of ``to`` is discarded, which assumes
    # ``Parameter`` is an ``nn.Module`` that is moved in place - confirm.
    module.to(device)
    return module


def _create_position_encoding(
    position_encoding_type: PositionEncodingType,
    embed_dim: int,
    max_timestep: int,
    context_size: int,
) -> PositionEncoding:
    """Instantiates the position encoding selected by its type.

    Both supported encodings are constructed with ``max_timestep + 1``.

    Args:
        position_encoding_type: Which position encoding to build.
        embed_dim: Embedding dimension passed to the encoding.
        max_timestep: Maximum timestep; one is added before it is passed to
            the encoding.
        context_size: Context size, only used by the global encoding.

    Returns:
        The constructed position encoding.

    Raises:
        ValueError: If ``position_encoding_type`` is neither ``SIMPLE`` nor
            ``GLOBAL``.
    """
    num_timesteps = max_timestep + 1

    if position_encoding_type == PositionEncodingType.SIMPLE:
        return SimplePositionEncoding(embed_dim, num_timesteps)

    if position_encoding_type == PositionEncodingType.GLOBAL:
        return GlobalPositionEncoding(embed_dim, num_timesteps, context_size)

    raise ValueError(
        f"invalid position_encoding_type: {position_encoding_type}"
    )


def create_continuous_decision_transformer(
    observation_shape: Shape,
    action_size: int,
    encoder_factory: EncoderFactory,
    num_heads: int,
    max_timestep: int,
    num_layers: int,
    context_size: int,
    attn_dropout: float,
    resid_dropout: float,
    embed_dropout: float,
    activation_type: str,
    position_encoding_type: PositionEncodingType,
    device: str,
) -> ContinuousDecisionTransformer:
    """Builds a ``ContinuousDecisionTransformer`` and moves it to ``device``.

    The encoder's output size is used both as the transformer's
    ``embed_size`` and as the position encoding's embedding dimension.

    Args:
        observation_shape: Observation shape passed to the encoder factory.
        action_size: Action size forwarded to the transformer.
        encoder_factory: Factory used to create the observation encoder.
        num_heads: Forwarded to the transformer as ``num_heads``.
        max_timestep: Maximum timestep used to build the position encoding.
        num_layers: Forwarded to the transformer as ``num_layers``.
        context_size: Forwarded to the transformer and the position
            encoding.
        attn_dropout: Forwarded to the transformer as ``attn_dropout``.
        resid_dropout: Forwarded to the transformer as ``resid_dropout``.
        embed_dropout: Forwarded to the transformer as ``embed_dropout``.
        activation_type: Name resolved through ``create_activation``.
        position_encoding_type: Which position encoding to build.
        device: Device the transformer is moved to.

    Returns:
        The constructed transformer.
    """
    obs_encoder = encoder_factory.create(observation_shape)
    embed_size = compute_output_size([observation_shape], obs_encoder)

    # NOTE(review): ``_create_position_encoding`` adds 1 to ``max_timestep``
    # again on top of the +1 applied here; kept as-is so the position
    # encoding is constructed with the same value as before.
    pos_encoding = _create_position_encoding(
        position_encoding_type, embed_size, max_timestep + 1, context_size
    )

    activation = create_activation(activation_type)
    model = ContinuousDecisionTransformer(
        encoder=obs_encoder,
        embed_size=embed_size,
        position_encoding=pos_encoding,
        action_size=action_size,
        num_heads=num_heads,
        context_size=context_size,
        num_layers=num_layers,
        attn_dropout=attn_dropout,
        resid_dropout=resid_dropout,
        embed_dropout=embed_dropout,
        activation=activation,
    )
    model.to(device)
    return model


def create_discrete_decision_transformer(
    observation_shape: Shape,
    action_size: int,
    encoder_factory: EncoderFactory,
    num_heads: int,
    max_timestep: int,
    num_layers: int,
    context_size: int,
    attn_dropout: float,
    resid_dropout: float,
    embed_dropout: float,
    activation_type: str,
    embed_activation_type: str,
    position_encoding_type: PositionEncodingType,
    device: str,
) -> DiscreteDecisionTransformer:
    """Builds a ``DiscreteDecisionTransformer`` and moves it to ``device``.

    The encoder's output size is used both as the transformer's
    ``embed_size`` and as the position encoding's embedding dimension.

    Args:
        observation_shape: Observation shape passed to the encoder factory.
        action_size: Action size forwarded to the transformer.
        encoder_factory: Factory used to create the observation encoder.
        num_heads: Forwarded to the transformer as ``num_heads``.
        max_timestep: Maximum timestep used to build the position encoding.
        num_layers: Forwarded to the transformer as ``num_layers``.
        context_size: Forwarded to the transformer and the position
            encoding.
        attn_dropout: Forwarded to the transformer as ``attn_dropout``.
        resid_dropout: Forwarded to the transformer as ``resid_dropout``.
        embed_dropout: Forwarded to the transformer as ``embed_dropout``.
        activation_type: Name resolved through ``create_activation`` for
            the transformer's ``activation``.
        embed_activation_type: Name resolved through ``create_activation``
            for the transformer's ``embed_activation``.
        position_encoding_type: Which position encoding to build.
        device: Device the transformer is moved to.

    Returns:
        The constructed transformer.
    """
    obs_encoder = encoder_factory.create(observation_shape)
    embed_size = compute_output_size([observation_shape], obs_encoder)

    # NOTE(review): ``_create_position_encoding`` adds 1 to ``max_timestep``
    # again on top of the +1 applied here; kept as-is so the position
    # encoding is constructed with the same value as before.
    pos_encoding = _create_position_encoding(
        position_encoding_type, embed_size, max_timestep + 1, context_size
    )

    activation = create_activation(activation_type)
    embed_activation = create_activation(embed_activation_type)
    model = DiscreteDecisionTransformer(
        encoder=obs_encoder,
        embed_size=embed_size,
        position_encoding=pos_encoding,
        action_size=action_size,
        num_heads=num_heads,
        context_size=context_size,
        num_layers=num_layers,
        attn_dropout=attn_dropout,
        resid_dropout=resid_dropout,
        embed_dropout=embed_dropout,
        activation=activation,
        embed_activation=embed_activation,
    )
    model.to(device)
    return model