# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Perceiver basic decoder."""

import collections

import tensorflow as tf, tf_keras

from official.nlp.modeling import layers
from official.projects.perceiver.modeling.layers import utils


class Decoder(tf_keras.layers.Layer):
  """Perceiver Decoder layer.

  Decodes the latent array with a single cross-attention block, given a
  user-provided decoder query. This layer implements the decoder from
  "Perceiver: General Perception with Iterative Attention"
  (https://arxiv.org/abs/2103.03206).

  References:
    [Attention Is All You Need](https://arxiv.org/abs/1706.03762)
    [Perceiver: General Perception with Iterative
    Attention](https://arxiv.org/abs/2103.03206)
    (https://github.com/deepmind/deepmind-research/blob/master/perceiver/perceiver.py)
    (https://github.com/tensorflow/models/blob/871c4e0a393ef4385534bee55354a5df8aa1ccf4/official/nlp/modeling/layers/transformer_encoder_block.py)
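
  Example (a minimal usage sketch; shapes and argument values below are
  illustrative only):

    decoder = Decoder(output_last_dim=64, num_heads=2)
    decoder_query = tf.ones((2, 4, 64))  # [batch, output_index_dim, d_model]
    latent_z = tf.ones((2, 8, 64))       # [batch, latent_index_dim, d_latents]
    decoded = decoder((decoder_query, latent_z))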
  """

  def __init__(self,
               output_last_dim,
               qk_last_dim=None,
               v_last_dim=None,
               use_query_residual=False,
               output_w_init=None,
               num_heads=1,
               name="decoder",
               **kwargs):
    """Init.

    Args:
      output_last_dim:
        Last dim size for output.
      qk_last_dim:
        When set, determines the last dimension of the attention score output.
        Check `qk_last_dim` doc in `utils.build_cross_attention_block_args`.
      v_last_dim:
        When set, determines the value's last dimension in the multi-head
        attention.
        Check `v_last_dim` doc in `utils._build_transformer_encoder_block_args`.
      use_query_residual:
        Toggle to execute residual connection after attention.
      output_w_init:
        Output layer kernel initializer.
      num_heads:
        Number of attention heads for the `TransformerEncoderBlock`.
      name:
        Sets the `tf_keras.layers.Layer` name.
      **kwargs:
        Any keyword arguments to pass through to `tf_keras.layers.Layer`.
    """
    super().__init__(name=name, **kwargs)

    self._output_last_dim = output_last_dim
    self._output_w_init = output_w_init
    self._use_query_residual = use_query_residual
    self._qk_last_dim = qk_last_dim
    self._v_last_dim = v_last_dim
    self._final_project = False  # Make variable if needed
    self._num_heads = num_heads

    # Omitted `concat_preprocessed_input` for MLM use-case.

  def build(self, input_shape):
    """Build layers using `input_shape`.

    Args:
      input_shape:
        Input shape(s) of the layer call.
    """
    decoder_query_shape = input_shape[0]
    z_shape = input_shape[1]
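    # Cross-attention block in which the decoder query is the query tensor and
    # the latent array `z` provides the keys and values.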
    self._decoding_cross_attn = layers.TransformerEncoderBlock(
        **utils.build_cross_attention_block_args(
            (decoder_query_shape, z_shape),
            widening_factor=1,
            dropout_prob=0.0,
            num_heads=self._num_heads,
            shape_for_attn="kv",
            qk_last_dim=self._qk_last_dim,
            v_last_dim=self._v_last_dim,
            use_query_residual=self._use_query_residual))

  def call(self, inputs, training=None, query_mask=None):
    """Return decoded output of latent vector via the query.

    Args:
      inputs:
        A tuple of the Perceiver decoder query tensor and the latent tensor
        (z). For the cross-attention block, `z` is the key-value tensor and
        the decoder query is the query tensor. The latent tensor comes from
        the self-attention processing blocks and the decoder query comes from
        the user to query for the desired output.
      training:
        Flag to indicate training status.
      query_mask:
        Mask used to create the attention mask for the query tensor in the
        cross-attention block.

    Returns:
      `tf.Tensor` with the decoded output of the latent tensor via the query.
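
    Example (illustrative sketch; shapes and the all-ones mask are
    placeholders):

      decoder_query = tf.ones((2, 4, 64))
      z = tf.ones((2, 8, 64))
      query_mask = tf.ones((2, 4), dtype=tf.int32)
      decoded = decoder((decoder_query, z), query_mask=query_mask)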
    """
    if not isinstance(inputs, collections.abc.Sequence):
      raise ValueError("`inputs` must be a sequence.")
    if len(inputs) != 2:
      raise ValueError("`inputs` must have two elements.")

    query, z = inputs
    # Cross-attention decoding.
    # key, value: B x N x K; query: B x M x K
    # Attention maps -> B x M x N
    # Output -> B x M x K
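    # Only query positions may be masked; every latent (key/value) position is
    # treated as valid, which is why the kv_mask below is all ones.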
    if query_mask is None:
      attention_mask = None
    else:
      attention_mask = utils.make_cross_attention_mask(
          query_mask=query_mask,
          kv_mask=tf.ones(tf.shape(z)[:2], dtype=tf.int32))

    output = self._decoding_cross_attn(
        (query, z, attention_mask),
        training=training)

    return output