# d3rlpy/models/torch/transformers.py
import math
from abc import ABCMeta, abstractmethod
from typing import Tuple

import torch
import torch.nn.functional as F
from torch import nn

from ...torch_utility import GEGLU
from ...types import TorchObservation
from .encoders import Encoder
from .parameters import Parameter, get_parameter

__all__ = [
    "ContinuousDecisionTransformer",
    "DiscreteDecisionTransformer",
    "PositionEncoding",
    "SimplePositionEncoding",
    "GlobalPositionEncoding",
    "GatoTransformer",
]
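

# Builds a (1, 1, T, T) lower-triangular mask so that each position can only
# attend to itself and earlier positions.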
def create_attention_mask(context_size: int) -> torch.Tensor:
mask = torch.ones(context_size, context_size, dtype=torch.float32)
return torch.tril(mask).view(1, 1, context_size, context_size)
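

# Multi-head causal self-attention. Separate linear layers produce queries,
# keys and values, which are split across heads, attended with scaled dot
# products under the causal mask, and projected back to the embedding size.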
class CausalSelfAttention(nn.Module): # type: ignore
_num_heads: int
_context_size: int
_k: nn.Linear
_q: nn.Linear
_v: nn.Linear
_proj: nn.Linear
_attn_dropout: nn.Dropout
_proj_dropout: nn.Dropout
_mask: torch.Tensor
def __init__(
self,
embed_size: int,
num_heads: int,
context_size: int,
attn_dropout: float,
resid_dropout: float,
):
super().__init__()
self._num_heads = num_heads
self._context_size = context_size
self._k = nn.Linear(embed_size, embed_size)
self._q = nn.Linear(embed_size, embed_size)
self._v = nn.Linear(embed_size, embed_size)
self._proj = nn.Linear(embed_size, embed_size)
self._attn_dropout = nn.Dropout(attn_dropout)
self._proj_dropout = nn.Dropout(resid_dropout)
mask = create_attention_mask(context_size)
self.register_buffer("_mask", mask)
def forward(self, x: torch.Tensor) -> torch.Tensor:
assert x.dim() == 3, f"Expects (B, T, N), but got {x.shape}"
batch_size, context_size, _ = x.shape
assert context_size <= self._context_size, "Exceeds context_size"
# (B, T, N) -> (B, T, H, N / H) -> (B, H, T, N / H)
shape = (batch_size, context_size, self._num_heads, -1)
k = self._k(x).view(shape).transpose(1, 2)
q = self._q(x).view(shape).transpose(1, 2)
v = self._v(x).view(shape).transpose(1, 2)
# (B, H, T, N / H) -> (B, H, T, T)
qkT = torch.matmul(q, k.transpose(2, 3))
attention = qkT / math.sqrt(k.shape[-1])
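        # mask out future positions so each token only attends to the past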
attention = attention.masked_fill(
self._mask[..., :context_size, :context_size] == 0, float("-inf")
)
attention = F.softmax(attention, dim=-1)
attention = self._attn_dropout(attention)
# (B, H, T, T) x (B, H, T, N / H) -> (B, H, T, N / H)
output = torch.matmul(attention, v)
# (B, H, T, N / H) -> (B, T, N)
output = output.transpose(1, 2).reshape(batch_size, context_size, -1)
return self._proj_dropout(self._proj(output))
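

# Position-wise feed-forward network. The hidden size is specified separately
# before and after the activation so that activations such as GEGLU, which
# shrink the feature dimension (here, by half), can be used.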
class MLP(nn.Module): # type: ignore
_l1: nn.Linear
_l2: nn.Linear
_dropout: nn.Dropout
_activation: nn.Module
def __init__(
self,
in_size: int,
out_size: int,
pre_activation_hidden_size: int,
post_activation_hidden_size: int,
dropout: float,
activation: nn.Module,
):
super().__init__()
self._l1 = nn.Linear(in_size, pre_activation_hidden_size)
self._l2 = nn.Linear(post_activation_hidden_size, out_size)
self._dropout = nn.Dropout(dropout)
self._activation = activation
def forward(self, x: torch.Tensor) -> torch.Tensor:
h = self._activation(self._l1(x))
h = self._dropout(self._l2(h))
return h
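

# Pre-norm transformer block: LayerNorm -> causal self-attention -> residual
# add, then LayerNorm -> feed-forward MLP -> residual add.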
class Block(nn.Module): # type: ignore
_attention: CausalSelfAttention
_mlp: MLP
_layer_norm1: nn.LayerNorm
_layer_norm2: nn.LayerNorm
def __init__(
self,
layer_width: int,
pre_activation_ff_hidden_size: int,
post_activation_ff_hidden_size: int,
num_heads: int,
context_size: int,
attn_dropout: float,
resid_dropout: float,
activation: nn.Module,
):
super().__init__()
self._attention = CausalSelfAttention(
embed_size=layer_width,
num_heads=num_heads,
context_size=context_size,
attn_dropout=attn_dropout,
resid_dropout=resid_dropout,
)
self._mlp = MLP(
in_size=layer_width,
out_size=layer_width,
pre_activation_hidden_size=pre_activation_ff_hidden_size,
post_activation_hidden_size=post_activation_ff_hidden_size,
dropout=resid_dropout,
activation=activation,
)
self._layer_norm1 = nn.LayerNorm(layer_width, eps=0.003)
self._layer_norm2 = nn.LayerNorm(layer_width, eps=0.003)
def forward(self, x: torch.Tensor) -> torch.Tensor:
norm_x = self._layer_norm1(x)
x = x + self._attention(norm_x)
norm_x = self._layer_norm2(x)
x = x + self._mlp(norm_x)
return x
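

# Interface for position (timestep) embeddings added to the token embeddings.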
class PositionEncoding(nn.Module, metaclass=ABCMeta): # type: ignore
@abstractmethod
def forward(self, t: torch.Tensor) -> torch.Tensor:
raise NotImplementedError
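

# Learned absolute embedding indexed directly by the environment timestep.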
class SimplePositionEncoding(PositionEncoding):
def __init__(self, embed_dim: int, max_timestep: int):
super().__init__()
self._embed = nn.Embedding(max_timestep, embed_dim)
def forward(self, t: torch.Tensor) -> torch.Tensor:
assert t.dim() == 2, "Expects (B, T)"
# (B, T) -> (B, T, N)
return self._embed(t)
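

# Combines a learned global embedding of the last timestep with a learned
# per-slot embedding over the 3 * context_size token positions, similar to the
# position encoding used in the original Decision Transformer Atari code.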
class GlobalPositionEncoding(PositionEncoding):
def __init__(self, embed_dim: int, max_timestep: int, context_size: int):
super().__init__()
self._embed_dim = embed_dim
self._global_position_embedding = nn.Embedding(max_timestep, embed_dim)
self._block_position_embedding = Parameter(
torch.zeros(1, 3 * context_size, embed_dim, dtype=torch.float32)
)
def forward(self, t: torch.Tensor) -> torch.Tensor:
assert t.dim() == 2, "Expects (B, T)"
_, context_size = t.shape
# (B, 1) -> (B, 1, N)
global_embedding = self._global_position_embedding(t[:, -1:])
# (1, 3 * Cmax, N) -> (1, T, N)
block_embedding = get_parameter(self._block_position_embedding)[
:, :context_size, :
]
# (B, 1, N) + (1, T, N) -> (B, T, N)
return global_embedding + block_embedding
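

# Minimal GPT-2 style decoder: embedding dropout, a stack of causal
# transformer blocks, and a final LayerNorm.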
class GPT2(nn.Module): # type: ignore
_transformer: nn.Sequential
_layer_norm: nn.LayerNorm
_dropout: nn.Dropout
def __init__(
self,
layer_width: int,
pre_activation_ff_hidden_size: int,
post_activation_ff_hidden_size: int,
num_heads: int,
context_size: int,
num_layers: int,
attn_dropout: float,
resid_dropout: float,
embed_dropout: float,
activation: nn.Module,
):
super().__init__()
blocks = [
Block(
layer_width=layer_width,
pre_activation_ff_hidden_size=pre_activation_ff_hidden_size,
post_activation_ff_hidden_size=post_activation_ff_hidden_size,
num_heads=num_heads,
context_size=context_size,
attn_dropout=attn_dropout,
resid_dropout=resid_dropout,
activation=activation,
)
for _ in range(num_layers)
]
self._transformer = nn.Sequential(*blocks)
self._layer_norm = nn.LayerNorm(layer_width, eps=0.003)
self._dropout = nn.Dropout(embed_dropout)
def forward(self, x: torch.Tensor) -> torch.Tensor:
h = self._dropout(x)
h = self._transformer(h)
h = self._layer_norm(h)
return h
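

# GPT-2 style weight initialization: normal(0, 0.02) for linear and embedding
# weights, zero biases, and identity-like LayerNorm parameters.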
def _init_weights(module: nn.Module) -> None:
if isinstance(module, (nn.Linear, nn.Embedding)):
module.weight.data.normal_(mean=0.0, std=0.02)
if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
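

# Decision Transformer for continuous action spaces. Each timestep contributes
# three tokens (return-to-go, state, action); actions are predicted from the
# outputs at the state-token positions and squashed with tanh.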
class ContinuousDecisionTransformer(nn.Module): # type: ignore
_encoder: Encoder
_position_encoding: PositionEncoding
_action_embed: nn.Linear
_rtg_embed: nn.Linear
_gpt2: GPT2
    _embed_ln: nn.LayerNorm
    _output: nn.Linear
def __init__(
self,
encoder: Encoder,
embed_size: int,
position_encoding: PositionEncoding,
action_size: int,
num_heads: int,
context_size: int,
num_layers: int,
attn_dropout: float,
resid_dropout: float,
embed_dropout: float,
activation: nn.Module,
):
super().__init__()
self._position_encoding = position_encoding
self._embed_ln = nn.LayerNorm(embed_size)
self._gpt2 = GPT2(
layer_width=embed_size,
pre_activation_ff_hidden_size=4 * embed_size,
post_activation_ff_hidden_size=4 * embed_size,
num_heads=num_heads,
context_size=3 * context_size,
num_layers=num_layers,
attn_dropout=attn_dropout,
resid_dropout=resid_dropout,
embed_dropout=embed_dropout,
activation=activation,
)
self.apply(_init_weights)
self._encoder = encoder
self._rtg_embed = nn.Linear(1, embed_size)
self._action_embed = nn.Linear(action_size, embed_size)
self._output = nn.Linear(embed_size, action_size)
def forward(
self,
x: TorchObservation,
action: torch.Tensor,
return_to_go: torch.Tensor,
timesteps: torch.Tensor,
) -> torch.Tensor:
batch_size, context_size, _ = return_to_go.shape
position_embedding = self._position_encoding(timesteps)
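        # flatten (B, T, ...) observations to (B * T, ...) for the encoder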
if isinstance(x, torch.Tensor):
flat_x = x.view(-1, *x.shape[2:])
else:
flat_x = [_x.view(-1, *_x.shape[2:]) for _x in x]
flat_state_embedding = self._encoder(flat_x)
state_embedding = flat_state_embedding.view(
batch_size, context_size, -1
)
state_embedding = state_embedding + position_embedding
action_embedding = self._action_embed(action) + position_embedding
rtg_embedding = self._rtg_embed(return_to_go) + position_embedding
# (B, T, N) -> (B, 3, T, N)
h = torch.stack(
[rtg_embedding, state_embedding, action_embedding], dim=1
)
# (B, 3, T, N) -> (B, T, 3, N) -> (B, T * 3, N)
h = h.transpose(1, 2).reshape(batch_size, 3 * context_size, -1)
        # during inference, drop the last action token so the model does not simply copy it
if not self.training:
h = h[:, :-1, :]
h = self._gpt2(self._embed_ln(h))
return torch.tanh(self._output(h[:, 1::3, :]))
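

# Decision Transformer for discrete action spaces. The same (return-to-go,
# state, action) token layout is used, but actions are embedded with an
# nn.Embedding and the head returns both softmax probabilities and raw logits.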
class DiscreteDecisionTransformer(nn.Module): # type: ignore
_encoder: Encoder
_position_encoding: PositionEncoding
_action_embed: nn.Embedding
_rtg_embed: nn.Linear
_gpt2: GPT2
_output: nn.Linear
_embed_activation: nn.Module
def __init__(
self,
encoder: Encoder,
embed_size: int,
position_encoding: PositionEncoding,
action_size: int,
num_heads: int,
context_size: int,
num_layers: int,
attn_dropout: float,
resid_dropout: float,
embed_dropout: float,
activation: nn.Module,
embed_activation: nn.Module,
):
super().__init__()
self._position_encoding = position_encoding
self._gpt2 = GPT2(
layer_width=embed_size,
pre_activation_ff_hidden_size=4 * embed_size,
post_activation_ff_hidden_size=4 * embed_size,
num_heads=num_heads,
context_size=3 * context_size,
num_layers=num_layers,
attn_dropout=attn_dropout,
resid_dropout=resid_dropout,
embed_dropout=embed_dropout,
activation=activation,
)
self._output = nn.Linear(embed_size, action_size, bias=False)
self._action_embed = nn.Embedding(action_size, embed_size)
self.apply(_init_weights)
self._encoder = encoder
self._rtg_embed = nn.Linear(1, embed_size)
self._embed_activation = embed_activation
def forward(
self,
x: TorchObservation,
action: torch.Tensor,
return_to_go: torch.Tensor,
timesteps: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
batch_size, context_size, _ = return_to_go.shape
position_embedding = self._position_encoding(timesteps)
if isinstance(x, torch.Tensor):
flat_x = x.reshape(-1, *x.shape[2:])
else:
flat_x = [_x.reshape(-1, *_x.shape[2:]) for _x in x]
flat_state_embedding = self._encoder(flat_x)
state_embedding = flat_state_embedding.view(
batch_size, context_size, -1
)
flat_action = action.view(batch_size, context_size).long()
action_embedding = self._action_embed(flat_action)
rtg_embedding = self._rtg_embed(return_to_go)
# (B, T, N) -> (B, 3, T, N)
h = torch.stack(
[rtg_embedding, state_embedding, action_embedding], dim=1
)
h = self._embed_activation(h)
h = h + position_embedding.view(batch_size, 1, context_size, -1)
# (B, 3, T, N) -> (B, T, 3, N) -> (B, T * 3, N)
h = h.transpose(1, 2).reshape(batch_size, 3 * context_size, -1)
        # during inference, drop the last action token so the model does not simply copy it
if not self.training:
h = h[:, :-1, :]
h = self._gpt2(h)
        # predict action logits from the outputs at the state-token positions
logits = self._output(h[:, 1::3, :])
return F.softmax(logits, dim=-1), logits
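

# Gato-style sequence model over discrete tokens. Observation tokens receive a
# learned per-position embedding and action tokens a single learned embedding;
# the feed-forward layers use GEGLU, so the pre-activation hidden size is
# doubled. The token embedding table reserves one extra entry for a separator
# token, and the head predicts a distribution over the vocabulary.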
class GatoTransformer(nn.Module): # type: ignore
_gpt2: GPT2
_token_embed: nn.Embedding
_observation_pos_embed: nn.Embedding
_action_pos_embed: Parameter
_output: nn.Linear
_embed_activation: nn.Module
def __init__(
self,
layer_width: int,
max_observation_length: int,
vocab_size: int,
num_heads: int,
context_size: int,
num_layers: int,
attn_dropout: float,
resid_dropout: float,
embed_dropout: float,
embed_activation: nn.Module,
):
super().__init__()
self._gpt2 = GPT2(
layer_width=layer_width,
pre_activation_ff_hidden_size=2 * 4 * layer_width,
post_activation_ff_hidden_size=4 * layer_width,
num_heads=num_heads,
context_size=context_size,
num_layers=num_layers,
attn_dropout=attn_dropout,
resid_dropout=resid_dropout,
embed_dropout=embed_dropout,
activation=GEGLU(),
)
self._output = nn.Linear(layer_width, vocab_size, bias=False)
# +1 for separator token
self._token_embed = nn.Embedding(vocab_size + 1, layer_width)
self._observation_pos_embed = nn.Embedding(
max_observation_length, layer_width
)
self._action_pos_embed = Parameter(
torch.zeros(1, 1, layer_width, dtype=torch.float32)
)
self.apply(_init_weights)
self._embed_activation = embed_activation
def forward(
self,
tokens: torch.Tensor,
observation_masks: torch.Tensor,
observation_positions: torch.Tensor,
action_masks: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
# TODO: Support text and patch tokens
assert tokens.ndim == 2
batch_size, context_size = tokens.shape
assert observation_masks.shape == (batch_size, context_size, 1)
assert observation_positions.shape == (batch_size, context_size)
assert action_masks.shape == (batch_size, context_size, 1)
# (B, T, N)
embeddings = self._embed_activation(self._token_embed(tokens))
# add local observation embedding
embeddings = (
embeddings
+ observation_masks
* self._observation_pos_embed(observation_positions)
)
# add action embedding
embeddings = embeddings + action_masks * get_parameter(
self._action_pos_embed
)
# (B, T, N) -> (B, T, N)
h = self._gpt2(embeddings)
# (B, T, N) -> (B, T, vocab)
logits = self._output(h)
return F.softmax(logits, dim=-1), logits