official/projects/pix2seq/modeling/pix2seq_model.py
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implements A Language Modeling Framework for Object Detection.
Model paper: https://arxiv.org/abs/2109.10852
This module does not support Keras de/serialization. Please use
tf.train.Checkpoint for object based saving and loading and tf.saved_model.save
for graph serialization.
"""
import math
from typing import Any, List, Mapping, Optional, Union
import tensorflow as tf, tf_keras
from official.modeling import tf_utils
from official.projects.pix2seq.modeling import transformer
def get_shape(x):
static = x.shape.as_list()
dynamic = tf.shape(x)
return [dynamic[i] if s is None else s for i, s in enumerate(static)]
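# Illustrative example (not part of the model): for a feature map `x` with
# static shape [None, 64, 64, 256], get_shape(x) returns
# [tf.shape(x)[0], 64, 64, 256], i.e. a dynamic batch size with static
# spatial and channel dimensions.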
def get_variable_initializer(name=None):
  # Only the default truncated-normal initializer is provided; any other
  # `name` value falls through and returns None.
  if name is None:
    return tf_keras.initializers.TruncatedNormal(mean=0.0, stddev=0.02)
def add_seq_pos_emb(
self, pos_encoding, max_seq_len, dim, name_prefix=None, initializer=None
):
"""Add seq_pos_emb variable/tensor to model instance referenced by `self`."""
if name_prefix is None:
name_prefix = self.name
if initializer is None:
initializer = get_variable_initializer()
if pos_encoding == "learned":
self.seq_pos_emb = self.add_weight(
shape=(max_seq_len + 1, dim),
initializer=initializer,
name="%s/seq_pos_embedding" % name_prefix,
)
# (gunho) currently only 'learned' positional encoding is supported
elif pos_encoding == "sin_cos":
self.seq_pos_emb = None
else:
raise ValueError("Unknown pos encoding %s" % pos_encoding)
def add_vocab_token_emb(
self,
vocab_size,
dim,
shared_embedding,
output_bias,
name_prefix=None,
initializer=None,
):
"""Add token_embedding variable to model instance referenced by `self`."""
if name_prefix is None:
name_prefix = self.name
if initializer is None:
initializer = get_variable_initializer()
if shared_embedding:
self.token_embedding = self.add_weight(
shape=[vocab_size, dim],
initializer=initializer,
name="%s/token_embedding" % name_prefix,
)
else:
self.inp_token_embedding = self.add_weight(
shape=[vocab_size, dim],
initializer=initializer,
name="%s/inp_token_embedding" % name_prefix,
)
self.outp_token_embedding = self.add_weight(
shape=[vocab_size, dim],
initializer=initializer,
name="%s/outp_token_embedding" % name_prefix,
)
if output_bias:
self.outp_bias = self.add_weight(
shape=[vocab_size],
initializer=initializer,
name="%s/outp_bias" % name_prefix,
)
def get_ar_mask(seq_len, dtype=tf.float32):
"""Get autoregressive causal mask so the model cannot attends to the future.
Args:
seq_len: a `int` or `int` tensor specifying the sequence length.
dtype: tf data type for the return tensor.
Returns:
tensor of shape [1, 1, seq_len, seq_len] with ones for locations to be
masked out.
"""
valid_locs = tf.linalg.band_part(
tf.ones([seq_len, seq_len], dtype=dtype), -1, 0
)
valid_locs = tf.reshape(valid_locs, [1, 1, seq_len, seq_len])
return 1.0 - valid_locs
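# Illustrative example (not part of the model): get_ar_mask(3) evaluates to
#   [[[[0., 1., 1.],
#      [0., 0., 1.],
#      [0., 0., 0.]]]]
# so position i may attend only to positions <= i; ones mark the future
# locations that are masked out.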
def position_embedding_sine(
attention_mask,
num_pos_features=256,
temperature=10000.0,
normalize=True,
scale=2 * math.pi,
):
"""Sine-based positional embeddings for 2D images.
Args:
attention_mask: a `bool` Tensor specifying the size of the input image to
the Transformer and which elements are padded, of size [batch_size,
height, width]
    num_pos_features: an `int` specifying the number of positional features,
should be equal to the hidden size of the Transformer network
temperature: a `float` specifying the temperature of the positional
embedding. Any type that is converted to a `float` can also be accepted.
normalize: a `bool` determining whether the positional embeddings should be
normalized between [0, scale] before application of the sine and cos
functions.
    scale: a `float` specifying the scale of the embeddings before application
      of the sine and cosine functions, used only when normalize is True.
  Returns:
    embeddings: a `float` tensor of shape [batch_size, height, width,
      num_pos_features] specifying the positional embeddings based on sine
      features.
"""
if num_pos_features % 2 != 0:
raise ValueError(
"Number of embedding features (num_pos_features) must be even when "
"column and row embeddings are concatenated."
)
num_pos_features = num_pos_features // 2
# Produce row and column embeddings based on total size of the image
# <tf.float>[batch_size, height, width]
attention_mask = tf.cast(attention_mask, tf.float32)
row_embedding = tf.cumsum(attention_mask, 1)
col_embedding = tf.cumsum(attention_mask, 2)
if normalize:
eps = 1e-6
row_embedding = row_embedding / (row_embedding[:, -1:, :] + eps) * scale
col_embedding = col_embedding / (col_embedding[:, :, -1:] + eps) * scale
dim_t = tf.range(num_pos_features, dtype=row_embedding.dtype)
dim_t = tf.pow(temperature, 2 * (dim_t // 2) / num_pos_features)
# Creates positional embeddings for each row and column position
# <tf.float>[batch_size, height, width, num_pos_features]
pos_row = tf.expand_dims(row_embedding, -1) / dim_t
pos_col = tf.expand_dims(col_embedding, -1) / dim_t
pos_row = tf.stack(
[tf.sin(pos_row[:, :, :, 0::2]), tf.cos(pos_row[:, :, :, 1::2])], axis=4
)
pos_col = tf.stack(
[tf.sin(pos_col[:, :, :, 0::2]), tf.cos(pos_col[:, :, :, 1::2])], axis=4
)
final_shape = tf_utils.get_shape_list(pos_row)[:3] + [-1]
pos_row = tf.reshape(pos_row, final_shape)
pos_col = tf.reshape(pos_col, final_shape)
output = tf.concat([pos_row, pos_col], -1)
embeddings = tf.cast(output, tf.float32)
return embeddings
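# Illustrative example (not part of the model): for an all-ones mask of shape
# [2, 16, 16] and num_pos_features=256, the returned embeddings have shape
# [2, 16, 16, 256]; the first 128 channels encode the (normalized) row
# positions and the last 128 channels the column positions.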
def top_logits(
logits: tf.Tensor, k: int = 0, p: float = 1.0, mask: float = -1e10
) -> tf.Tensor:
"""Remove low probability logits via masking.
Args:
logits: class logits in shape of (batch size, total_classes).
k: specifying top k largest logits to keep.
    p: specifying a probability for finding a minimum set of largest logits to
      keep, where their cumulative probability is no less than p (in this
      implementation, the kept set is the largest one whose cumulative
      probability is no more than p, and the top-1 logit is always kept).
mask: a value that's used to replace logits that don't satisfy the keep
conditions.
Returns:
logits where low probability ones are replaced with mask.
"""
mask = tf.ones_like(logits) * mask
if k > 0:
min_logits = tf.nn.top_k(logits, k=k)[0][:, -1:]
logits = tf.where(logits < min_logits, mask, logits)
if p < 1.0:
sorted_logits = tf.sort(logits, direction="DESCENDING", axis=-1)
cum_probs = tf.cumsum(tf.nn.softmax(sorted_logits, axis=-1), axis=-1)
min_logits = -tf.reduce_max(
tf.where(cum_probs <= p, -sorted_logits, mask), -1, keepdims=True
)
min_logits = tf.minimum(min_logits, sorted_logits[:, :1])
logits = tf.where(logits < min_logits, mask, logits)
return logits
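# Illustrative example (not part of the model): keeping only the top-2 logits
# per row, with the default p=1.0.
#   top_logits(tf.constant([[1.0, 3.0, 2.0, 0.5]]), k=2)
#   # -> [[-1e10, 3.0, 2.0, -1e10]]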
class Pix2Seq(tf_keras.Model):
"""Pix2Seq model with Keras.
  Pix2Seq consists of a backbone, an input token embedding, and a
  Pix2SeqTransformer.
"""
def __init__(
self,
backbone,
backbone_endpoint_name,
max_seq_len,
vocab_size,
hidden_size,
num_heads,
num_encoder_layers=6,
num_decoder_layers=6,
drop_path=0.1,
drop_units=0.1,
drop_att=0.0,
temperature=1.0,
top_k=0,
top_p=0.4,
eos_token: int | None = None,
**kwargs
):
super().__init__(**kwargs)
self._backbone = backbone
self._backbone_endpoint_name = backbone_endpoint_name
self._max_seq_len = max_seq_len
self._vocab_size = vocab_size
self._hidden_size = hidden_size
self._num_heads = num_heads
self._num_encoder_layers = num_encoder_layers
self._num_decoder_layers = num_decoder_layers
self._drop_path = drop_path
self._drop_units = drop_units
self._drop_att = drop_att
if hidden_size % 2 != 0:
raise ValueError("hidden_size must be a multiple of 2.")
self._dropout = tf_keras.layers.Dropout(self._drop_units)
self._stem_projection = tf_keras.layers.Dense(
self._hidden_size, name="stem_projection"
)
self._stem_ln = tf_keras.layers.LayerNormalization(
epsilon=1e-6, name="stem_ln"
)
self._transformer = Pix2SeqTransformer(
max_seq_len=self._max_seq_len,
vocab_size=self._vocab_size,
hidden_size=self._hidden_size,
pos_encoding="learned",
num_encoder_layers=self._num_encoder_layers,
num_decoder_layers=self._num_decoder_layers,
drop_path=self._drop_path,
drop_units=self._drop_units,
drop_att=self._drop_att,
num_heads=self._num_heads,
)
self._temperature = temperature
self._top_k = top_k
self._top_p = top_p
self._eos_token = eos_token
@property
def backbone(self) -> tf_keras.Model:
return self._backbone
@property
def transformer(self) -> tf_keras.Model:
return self._transformer
def get_config(self):
return {
"backbone": self._backbone,
"backbone_endpoint_name": self._backbone_endpoint_name,
"max_seq_len": self._max_seq_len,
"vocab_size": self._vocab_size,
"hidden_size": self._hidden_size,
"num_encoder_layers": self._num_encoder_layers,
"num_decoder_layers": self._num_decoder_layers,
"drop_path": self._drop_path,
"drop_units": self._drop_units,
"drop_att": self._drop_att,
"temperature": self._temperature,
"top_k": self._top_k,
"top_p": self._top_p,
"eos_token": self._eos_token,
"num_heads": self._num_heads,
}
@classmethod
def from_config(cls, config):
return cls(**config)
@property
def checkpoint_items(
self,
) -> Mapping[str, Union[tf_keras.Model, tf_keras.layers.Layer]]:
"""Returns a dictionary of items to be additionally checkpointed."""
items = dict(backbone=self.backbone, transformer=self.transformer)
return items
def _generate_image_mask(
self, inputs: tf.Tensor, target_shape: tf.Tensor
) -> tf.Tensor:
"""Generates image mask from input image."""
mask = tf.expand_dims(
tf.cast(
tf.not_equal(tf.reduce_sum(inputs, axis=-1), 0.3), inputs.dtype
),
axis=-1,
)
mask = tf.image.resize(
mask, target_shape, method=tf.image.ResizeMethod.NEAREST_NEIGHBOR
)
return mask
def call(
self,
inputs: tf.Tensor,
targets: Optional[tf.Tensor] = None,
training: bool = None,
use_teacher_forcing_for_eval: bool = False,
use_input_as_backbone_features=False,
) -> List[Any]:
if use_input_as_backbone_features:
features = inputs
else:
features = self._backbone(inputs)[self._backbone_endpoint_name]
mask = tf.ones_like(features)
batch_size, h, w, num_channels = get_shape(features)
features = tf.reshape(features, [batch_size, h * w, num_channels])
features = self._stem_ln(
self._stem_projection(self._dropout(features, training))
)
pos_emb = position_embedding_sine(
mask[:, :, :, 0], num_pos_features=self._hidden_size
)
pos_emb = tf.reshape(pos_emb, [batch_size, -1, self._hidden_size])
pos_emb = tf.cast(pos_emb, features.dtype)
tokens = None
inputs = {
"inputs": features,
"tokens": targets,
"pos_emb": pos_emb,
}
if training:
logits = self._transformer(inputs, training=True)
elif use_teacher_forcing_for_eval:
logits = self._transformer(inputs, training=False)
else:
tokens, logits = self._transformer.infer(
inputs,
temperature=self._temperature,
top_k=self._top_k,
top_p=self._top_p,
eos_token=self._eos_token,
)
return [tokens, logits]
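# Hedged usage sketch: the backbone choice, endpoint name, and sizes below are
# assumptions for illustration only, not values prescribed by this module.
#   from official.vision.modeling.backbones import resnet
#   backbone = resnet.ResNet(model_id=50)
#   model = Pix2Seq(
#       backbone=backbone,
#       backbone_endpoint_name="5",
#       max_seq_len=500,
#       vocab_size=3000,
#       hidden_size=256,
#       num_heads=8,
#   )
#   images = tf.zeros([2, 640, 640, 3])
#   targets = tf.zeros([2, 500], dtype=tf.int64)
#   _, logits = model(images, targets, training=True)  # [2, 500, 3000]
#   tokens, logits = model(images, targets[:, :1], training=False)  # sampling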
def _create_cond_fn(seq_len: int, eos_token: int | None, prompt_len: int):
"""Returns a loop condition for decoder.
Args:
seq_len: the maximum sequence length.
eos_token: if not None, enable early termination based on end-of-sequence
token.
prompt_len: the length of prompt sequence.
"""
def cond(step, caches, tokens, logits):
del caches
del logits
within_seq_len = (seq_len > prompt_len) & (step < seq_len - 1)
if eos_token is None:
return within_seq_len
else:
tokens = tokens[prompt_len:step]
reached_eos = tf.reduce_all(tf.reduce_any(tokens == eos_token, axis=0))
return within_seq_len & tf.logical_not(reached_eos)
return cond
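# Illustrative example (not part of the model): with seq_len=6, prompt_len=1
# and eos_token=0, the loop keeps running only while some sequence in the
# batch has not yet emitted the EOS token.
#   cond = _create_cond_fn(seq_len=6, eos_token=0, prompt_len=1)
#   tokens = tf.constant([[7, 7], [0, 4], [3, 0], [0, 0], [0, 0], [0, 0]])
#   cond(3, None, tokens, None)  # -> False: every sequence emitted EOS.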
class Pix2SeqTransformer(tf_keras.layers.Layer):
"""Encoder and Decoder of Pix2Seq."""
def __init__(
self,
max_seq_len,
vocab_size,
hidden_size,
pos_encoding="learned",
num_encoder_layers=6,
num_decoder_layers=6,
drop_path=0.1,
drop_units=0.1,
drop_att=0.0,
shared_embedding=True,
output_bias=True,
num_heads=8,
**kwargs
):
super().__init__(**kwargs)
self._max_seq_len = max_seq_len
self._vocab_size = vocab_size
self._hidden_size = hidden_size
self._pos_encoding = pos_encoding
self._num_encoder_layers = num_encoder_layers
self._num_decoder_layers = num_decoder_layers
self._drop_path = drop_path
self._drop_units = drop_units
self._drop_att = drop_att
self._shared_embedding = shared_embedding
self._output_bias = output_bias
self._num_heads = num_heads
add_seq_pos_emb(
self, self._pos_encoding, self._max_seq_len, self._hidden_size
)
add_vocab_token_emb(
self,
self._vocab_size,
self._hidden_size,
self._shared_embedding,
self._output_bias,
)
if self._num_encoder_layers > 0:
self._encoder = transformer.TransformerEncoder(
num_layers=self._num_encoder_layers,
dim=self._hidden_size,
mlp_ratio=4,
num_heads=self._num_heads,
drop_path=self._drop_path,
drop_units=self._drop_units,
drop_att=self._drop_att,
)
else:
self._encoder = None
self._output_ln_enc = tf_keras.layers.LayerNormalization(
epsilon=1e-6, name="output_ln_enc"
)
self._proj = tf_keras.layers.Dense(self._hidden_size, name="proj/linear")
self._proj_ln = tf_keras.layers.LayerNormalization(
epsilon=1e-6, name="proj/ln"
)
self._proj_mlp = transformer.MLP(
num_layers=1,
dim=self._hidden_size,
mlp_ratio=4,
drop_path=self._drop_path,
drop_units=self._drop_units,
name="proj/mlp",
)
self._decoder = transformer.TransformerDecoder(
num_layers=self._num_decoder_layers,
dim=self._hidden_size,
mlp_ratio=4,
num_heads=self._num_heads,
drop_path=self._drop_path,
drop_units=self._drop_units,
drop_att=self._drop_att,
)
self._output_ln_dec = tf_keras.layers.LayerNormalization(
epsilon=1e-6, name="output_ln_dec"
)
def get_config(self):
return {
"max_seq_len": self._max_seq_len,
"vocab_size": self._vocab_size,
"hidden_size": self._hidden_size,
"pos_encoding": self._pos_encoding,
"num_encoder_layers": self._num_encoder_layers,
"num_decoder_layers": self._num_decoder_layers,
"drop_path": self._drop_path,
"drop_units": self._drop_units,
"drop_att": self._drop_att,
"shared_embedding": self._shared_embedding,
"output_bias": self._output_bias,
"num_heads": self._num_heads,
}
def call(self, inputs: tf.Tensor, training: bool = None):
sources = inputs["inputs"]
targets = inputs["tokens"]
mem_pos_embed = inputs["pos_emb"]
sources = sources + mem_pos_embed
if self._encoder is not None:
encoded = self._encoder(sources, None, training=training, ret_list=False)
else:
encoded = sources
encoded = self._output_ln_enc(encoded)
encoded = self._proj_ln(self._proj(encoded))
encoded = encoded + mem_pos_embed
encoded = self._proj_mlp(encoded, training=training)
seq_len = tf.shape(targets)[1]
seq_pos_emb = tf.expand_dims(self.seq_pos_emb[:seq_len], 0)
    if self._shared_embedding:
      inp_embedding = outp_embedding = self.token_embedding
    else:
      inp_embedding = self.inp_token_embedding
      outp_embedding = self.outp_token_embedding
target_emb = tf.gather(inp_embedding, targets) + seq_pos_emb
self_attention_mask = 1.0 - get_ar_mask(seq_len, target_emb.dtype)
decoded, _ = self._decoder(
target_emb, encoded, None, self_attention_mask, None, training
)
decoded = self._output_ln_dec(decoded)
decoded = tf.cast(decoded, seq_pos_emb.dtype)
outp_embedding = tf.cast(outp_embedding, seq_pos_emb.dtype)
logits = tf.matmul(decoded, outp_embedding, transpose_b=True)
if self._output_bias:
logits = tf.nn.bias_add(logits, self.outp_bias)
return logits
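  # Illustrative shape walk-through for the teacher-forcing call (assuming
  # hidden_size=256 and vocab_size=3000): inputs["inputs"] is [bsz, hw, 256],
  # inputs["tokens"] is an integer [bsz, seq_len] tensor, inputs["pos_emb"] is
  # [bsz, hw, 256], and the returned logits are [bsz, seq_len, 3000].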
def infer(
self,
inputs: tf.Tensor,
max_seq_len=None,
temperature=1.0,
top_k=0,
top_p=0.4,
sampling_callback=None,
eos_token: int | None = None,
):
"""Autoregressive (without teacher-forcing) prediction.
Note: the autoregressive sampling/inference time can be further optimized by
caching *transformed* key / value inside multi-head attention for the
`encoded` and previously generated tokens, but this may make the code less
readable.
Args:
      inputs: a dict of inference inputs. "tokens" - `int` prompt tokens with
        shape of (bsz, prompt_len). "inputs" - `float` encoded representations
        for conditioning, with shape of (bsz, size, dim); this can be optional
        in case of a pure decoder. "pos_emb" - positional encodings added to
        the encoded representations.
max_seq_len: `int` of max generated sequence length (including prompt).
temperature: `float` scalar for scaling the logits before sampling.
top_k: `int` scalar for truncating top-k tokens according to logits before
token sampling.
top_p: `float` scalar specifying the threshold of cumulative probability
for truncating tokens before token sampling.
      sampling_callback: a callback `function` that takes `next_logits` (and
        the current step, temperature, top_k and top_p) and returns
        `next_token`. This is used when users need a specific logic for
        sampling. Defaults to `None` with standard free-form sampling.
      eos_token: if not None, stop inference early based on this
        end-of-sequence (EOS) token. This does not change the sequence length;
        however, for each sequence, tokens after the EOS token are set to the
        EOS token and the corresponding logit values are
        implementation-defined.
Returns:
sampled tokens with shape of (bsz, max_seq_len-prompt_len).
logits (temperature-scaled) associated with sampled token, in shape of
(bsz, max_seq_len-prompt_len, vocab_size).
"""
sources = inputs["inputs"]
prompt = inputs["tokens"]
mem_pos_embed = inputs["pos_emb"]
sources = sources + mem_pos_embed
if self._encoder is not None:
encoded = self._encoder(sources, None, training=False, ret_list=False)
else:
encoded = sources
encoded = self._output_ln_enc(encoded)
encoded = self._proj_ln(self._proj(encoded))
encoded = encoded + mem_pos_embed
encoded = self._proj_mlp(encoded, training=False)
bsz = tf.shape(prompt)[0]
prompt_len = tf.shape(prompt)[1]
seq_len = self._max_seq_len if max_seq_len is None else max_seq_len
# (gunho) 500 (self._max_seq_len) -> 501 for prompt seq
seq_len = seq_len + 1
seq_pos_emb = tf.expand_dims(self.seq_pos_emb, 0)
    if self._shared_embedding:
      inp_embedding = outp_embedding = self.token_embedding
    else:
      inp_embedding = self.inp_token_embedding
      outp_embedding = self.outp_token_embedding
# Each step reads caches[:step] and tokens[step:next_step] and updates
# tokens[next_step], logits[next_step] and caches[step:next_step].
# On the first step, step=0, next_step=prompt_len. On subsequent steps
# next_step = step + 1.
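    # For example (illustrative): with prompt_len=5, the prompt step runs with
    # step=0, writes caches[0:5], tokens[5] and logits[5], and returns
    # next_step=5; the following step reads caches[0:5] and tokens[5], and
    # writes caches[5], tokens[6] and logits[6].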
def loop_body(step, caches, tokens, logits, is_prompt=False):
if is_prompt:
assert step == 0
x = tf.gather(inp_embedding, tf.transpose(tokens[:prompt_len]))
input_pos_embed = seq_pos_emb[:, :prompt_len]
x += input_pos_embed
self_attention_mask = 1.0 - get_ar_mask(prompt_len, x.dtype)
caches_in = None
else:
x = tf.gather(inp_embedding, tf.transpose(tokens[step]))
input_pos_embed = seq_pos_emb[:, step]
x += input_pos_embed
x = tf.expand_dims(x, 1) # (bsz, 1, d)
self_attention_mask = tf.ones([1, 1, 1, 1])
caches_in = tf.transpose(caches[:step], [1, 2, 0, 3])
decoded, caches_out = self._decoder(
x, encoded, caches_in, self_attention_mask, None, training=False
)
decoded = self._output_ln_dec(decoded)
# (gunho) transformer.py uses tf.float32 for numeric stability.
decoded = tf.cast(decoded, seq_pos_emb.dtype)
next_logits = tf.matmul( # only take the last for sampling next token.
decoded, outp_embedding, transpose_b=True
)[:, -1]
if self._output_bias:
next_logits = tf.nn.bias_add(next_logits, self.outp_bias)
      # Scale and truncate logits and sample next token.
if sampling_callback:
next_token = sampling_callback(
next_logits, step, temperature, top_k, top_p
)
else:
sampling_logits = next_logits / tf.cast(temperature, tf.float32)
sampling_logits = top_logits(sampling_logits, k=top_k, p=top_p)
next_token = tf.random.categorical(
sampling_logits, num_samples=1, dtype=tf.int32
)[:, 0]
# Update internal states.
next_step = step + (prompt_len if is_prompt else 1)
caches_out = tf.transpose(caches_out, [2, 0, 1, 3])
if is_prompt:
caches = tf.tensor_scatter_nd_update(
caches,
tf.range(prompt_len)[:, tf.newaxis],
caches_out,
)
else:
caches = tf.tensor_scatter_nd_update(caches, [[step]], caches_out)
      # Ensure sampled tokens match the integer dtype of the `tokens` buffer.
      next_token = tf.cast(next_token, tokens.dtype)
      tokens = tf.tensor_scatter_nd_update(tokens, [[next_step]], [next_token])
logits = tf.tensor_scatter_nd_update(logits, [[next_step]], [next_logits])
return (next_step, caches, tokens, logits)
caches_var = tf.zeros(
[seq_len - 1, self._num_decoder_layers, bsz, self._hidden_size]
)
tokens_var = tf.zeros([seq_len, bsz], dtype=tf.int64)
logits_var = tf.zeros([seq_len, bsz, self._vocab_size], dtype=tf.float32)
indices = tf.expand_dims(tf.range(prompt_len), -1)
tokens_var = tf.tensor_scatter_nd_update(
tokens_var, indices, tf.transpose(prompt, [1, 0])
)
step = 0
step, caches_var, tokens_var, logits_var = loop_body(
step, caches_var, tokens_var, logits_var, is_prompt=True
)
step, _, tokens_var, logits_var = tf.while_loop(
cond=_create_cond_fn(
seq_len=seq_len, eos_token=eos_token, prompt_len=prompt_len
),
body=loop_body,
loop_vars=[step, caches_var, tokens_var, logits_var],
)
# If stopping early based on eos_token, assign eos_token to all tokens after
# stopping occurs.
if eos_token is not None:
tokens_var = tf.where(
tf.range(seq_len)[:, tf.newaxis] >= step,
tf.cast(eos_token, tokens_var.dtype),
tokens_var,
)
sampled_tokens = tf.transpose(tokens_var[prompt_len:], [1, 0])
sampled_tokens_logits = tf.transpose(logits_var[prompt_len:], [1, 0, 2])
return sampled_tokens, sampled_tokens_logits