official/projects/lra/exponential_moving_average.py
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Keras-based MegaEncoder block layer."""
from typing import Optional
import tensorflow as tf, tf_keras


class MultiHeadEMA(tf_keras.layers.Layer):
  """Exponential moving average (EMA) layer.

  See "Mega: Moving Average Equipped Gated Attention"
  (https://arxiv.org/abs/2209.10655) for details.
"""

  def __init__(
      self, embed_dim, ndim=2, bidirectional=False, truncation=None, **kwargs
  ):
    """Initializes the layer.

    Args:
      embed_dim: Model (channel) dimension `D`.
      ndim: Number of EMA dimensions `N` per channel.
      bidirectional: Whether to run the EMA in both directions.
      truncation: Optional maximum kernel length.
      **kwargs: Keyword arguments for the base `tf_keras.layers.Layer`.
    """
    super().__init__(**kwargs)
    self.embed_dim = embed_dim
    self.ndim = ndim
    self.bidirectional = bidirectional
    self.truncation = truncation
    self.scale = tf.math.sqrt(1.0 / self.ndim)
    self.kernel_dim = 2 * embed_dim if self.bidirectional else embed_dim
    self._kernel = None
    self._coeffs = None

  def build(self, input_shape):
    # Per-channel, per-dimension EMA parameters. In the Mega paper's notation
    # these correspond to alpha (damping), delta (decay), beta (expansion),
    # gamma (projection), and the residual weight.
    self.damping_factor = self.add_weight(
        shape=(self.kernel_dim, self.ndim, 1),
        initializer="random_normal",
        trainable=True,
        name="damping_factor",
        dtype=tf.float32,
    )
    self.decay_factor = self.add_weight(
        shape=(self.kernel_dim, self.ndim, 1),
        initializer="random_normal",
        trainable=True,
        name="decay_factor",
        dtype=tf.float32,
    )
    self.ema_expansion_matrix = self.add_weight(
        shape=(self.kernel_dim, self.ndim, 1),
        initializer="random_normal",
        trainable=True,
        name="ema_expansion_matrix",
        dtype=tf.float32,
    )
    self.kernel_projection_matrix = self.add_weight(
        shape=(self.kernel_dim, self.ndim),
        initializer="random_normal",
        trainable=True,
        name="kernel_projection_matrix",
        dtype=tf.float32,
    )
    self.residual_weight = self.add_weight(
        shape=(self.embed_dim,),
        initializer="ones",
        trainable=True,
        name="residual_weight",
        dtype=tf.float32,
    )
    super().build(input_shape)

  def _calc_coeffs(self):
    self._coeffs = None
    # D x N x 1. The underlying EMA recurrence is
    #   h_t = (alpha * beta) * x_t + (1 - alpha * delta) * h_{t-1},
    # where alpha is the damping factor and delta the decay factor.
    damping_factor = tf.math.sigmoid(self.damping_factor)
    decay_factor = tf.math.sigmoid(self.decay_factor)
    previous_timestep_weight = 1.0 - damping_factor * decay_factor
    return damping_factor, previous_timestep_weight

  def _compute_kernel(self, length: int):
    self._kernel = None
    # D x N x 1
    damping_factor, previous_timestep_weight = self._calc_coeffs()
    # D x N x L Vandermonde-style matrix: entry (d, n, t) holds
    # t * log(previous_timestep_weight), so exp(vander) is the t-th power of
    # the previous-timestep weight.
    vander = tf.cast(
        tf.reshape(tf.range(length), shape=(1, 1, length)),
        dtype=damping_factor.dtype,
    ) * tf.math.log(previous_timestep_weight)
    kernel = (damping_factor * self.ema_expansion_matrix) * tf.math.exp(vander)
    # D x L: project the N EMA dimensions down to one kernel per channel.
    return tf.einsum(
        "dnl,dn->dl", kernel, self.kernel_projection_matrix * self.scale
    )

  def coeffs(self, training=False):
    """Returns EMA coefficients, recomputed in training, cached otherwise."""
    if training:
      return self._calc_coeffs()
    if self._coeffs is None:
      self._coeffs = self._calc_coeffs()
    return self._coeffs
def kernel(self, length: int):
assert self.truncation is None, "WEIRD!"
kernel_size = (
length if self.truncation is None else min(self.truncation, length)
)
return self._compute_kernel(kernel_size)

  def call(self, x, padding_mask: Optional[tf.Tensor] = None) -> tf.Tensor:
    """Computes the multi-head EMA over a time-major input.

    Input shape: Time x Batch x Channel.

    Args:
      x: Tensor input of shape `(seq_len, batch, embed_dim)`.
      padding_mask: Optional tensor of shape `(batch, seq_len)` multiplied
        into the input before the convolution; valid positions are indicated
        by 1s and padding positions by 0s.
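        For example, for a batch of sequences with lengths `[3, 5]` padded
        to a common length of 5, a suitable mask is
        `tf.cast(tf.sequence_mask([3, 5], 5), tf.float32)`.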

    Returns:
      transformed: The transformed tensor, of the same shape as `x`.
    """
    seq_len, _, embed_dim = x.shape
    assert embed_dim == self.embed_dim
    # The kernel and FFT lengths below require a static sequence length.
    if seq_len is None:
      seq_len = 1
    # L x B x D
    residual = x * self.residual_weight
    # L x B x D -> B x D x L
    x = tf.transpose(x, perm=(1, 2, 0))
    # Zero out padded positions before the convolution.
    if padding_mask is not None:
      x = x * tf.cast(tf.expand_dims(padding_mask, axis=1), x.dtype)
    k = self.kernel(seq_len)
    kernel_size = k.shape[1]
    fft_len = seq_len
    s = 0
    if self.bidirectional:
      k1, k2 = tf.split(k, [self.embed_dim, self.embed_dim], axis=0)
      # D x 2*L-1: combine the causal (forward) kernel, left-padded, with
      # the reversed anti-causal (backward) kernel, right-padded.
      padding_l = tf.constant([[0, 0], [kernel_size - 1, 0]])
      padding_r = tf.constant([[0, 0], [0, kernel_size - 1]])
      padding_x = tf.constant([[0, 0], [0, 0], [kernel_size - 1, 0]])
      k = tf.pad(k1, padding_l) + tf.pad(tf.reverse(k2, axis=[-1]), padding_r)
      x = tf.pad(x, padding_x)
      fft_len = fft_len + kernel_size - 1
      s = 2 * kernel_size - 2
    # Long convolution via FFT: with an FFT length of 2 * fft_len, the
    # circular convolution computed by rfft/irfft equals the linear one.
    k_f = tf.signal.rfft(
        k, fft_length=tf.constant([2 * fft_len], dtype=tf.int32)
    )
    x_f = tf.signal.rfft(
        x, fft_length=tf.constant([2 * fft_len], dtype=tf.int32)
    )
    # B x D x L
    out = tf.signal.irfft(
        x_f * k_f, fft_length=tf.constant([2 * fft_len], dtype=tf.int32)
    )[..., s : s + seq_len]
    # B x D x L -> L x B x D, then add the weighted residual and apply SiLU.
    out = tf.nn.silu(tf.transpose(out, perm=(2, 0, 1)) + residual)
    return out
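

if __name__ == "__main__":
  # Illustrative smoke test, not part of the library API: checks that the
  # FFT-based convolution in `call` matches an explicit recurrent EMA in
  # the unidirectional case (a sketch assuming batch size 1; all names
  # below are local to this block).
  tf.random.set_seed(0)
  seq_len, dim, ndim = 8, 4, 2
  layer = MultiHeadEMA(embed_dim=dim, ndim=ndim, bidirectional=False)
  x = tf.random.normal((seq_len, 1, dim))  # (time, batch, channel)
  y = layer(x)

  # Recurrent reference:
  #   h_t = (alpha * beta) * x_t + (1 - alpha * delta) * h_{t-1}
  #   y_t = silu(gamma . h_t + residual_weight * x_t)
  alpha = tf.math.sigmoid(layer.damping_factor)  # D x N x 1
  q = 1.0 - alpha * tf.math.sigmoid(layer.decay_factor)  # D x N x 1
  beta = layer.ema_expansion_matrix  # D x N x 1
  gamma = layer.kernel_projection_matrix * layer.scale  # D x N
  h = tf.zeros((dim, ndim, 1))
  steps = []
  for t in range(seq_len):
    xt = tf.reshape(x[t], (dim, 1, 1))  # valid because batch size is 1
    h = alpha * beta * xt + q * h
    steps.append(tf.einsum("dn,dn->d", tf.squeeze(h, -1), gamma))
  ref = tf.stack(steps)[:, tf.newaxis, :]  # L x B x D
  ref = tf.nn.silu(ref + x * layer.residual_weight)
  tf.debugging.assert_near(y, ref, atol=1e-4)
  print("FFT and recurrent EMA outputs match:", y.shape)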