# tensorflow/models
#
# View on GitHub
# official/recommendation/ranking/common.py
#
# Summary
#
# Maintainability
# A
# 45 mins
# Test Coverage
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Flags and common definitions for Ranking Models."""

from absl import flags
import tensorflow as tf, tf_keras

from official.common import flags as tfm_flags

# Module-level alias for absl's flag values; `define_flags()` uses it to
# override defaults of the common flags it registers.
FLAGS = flags.FLAGS


def define_flags() -> None:
  """Registers every command-line flag used to train a Ranking model.

  The common flags from `official.common.flags` are defined first. Their
  `experiment` and `mode` defaults are then overridden so that a bare
  invocation runs the `dlrm_criteo` experiment in `train_and_eval` mode.
  Finally, the ranking-specific `seed` and `profile_steps` flags are added.
  """
  tfm_flags.define_flags()

  # (flag name, new default) pairs applied on top of the common flags.
  ranking_defaults = (
      ('experiment', 'dlrm_criteo'),
      ('mode', 'train_and_eval'),
  )
  for flag_name, flag_default in ranking_defaults:
    FLAGS.set_default(name=flag_name, value=flag_default)

  seed_help = 'This value will be used to seed both NumPy and TensorFlow.'
  flags.DEFINE_integer(name='seed', default=None, help=seed_help)

  profile_steps_help = (
      'Save profiling data to model dir at given range of global steps. '
      'The value must be a comma separated pair of positive integers, '
      'specifying the first and last step to profile. For example, '
      '"--profile_steps=2,4" triggers the profiler to process 3 steps, starting'
      ' from the 2nd step. Note that profiler has a non-trivial performance '
      'overhead, and the output file can be gigantic if profiling many steps.')
  flags.DEFINE_string(
      name='profile_steps', default='20,40', help=profile_steps_help)


@tf_keras.utils.register_keras_serializable(package='RANKING')
class WarmUpAndPolyDecay(tf_keras.optimizers.schedules.LearningRateSchedule):
  """Learning rate callable for the embeddings.

  The base rate is first scaled linearly with the batch size:
  `adj_lr = learning_rate * batch_size / 2048`. The schedule is then:

    * linear warmup from 0 towards `adj_lr` for steps < `warmup_steps`,
    * constant `adj_lr` from `warmup_steps` up to `decay_start_steps`,
    * polynomial decay `adj_lr * (1 - t / decay_steps) ** decay_exp`, where
      `t = step - decay_start_steps` is clipped at `decay_steps`, for steps
      after `decay_start_steps` (only when `decay_steps > 0`).

  The returned learning rate is never below 0.01, so the earliest warmup
  steps and the tail of the decay are clamped to that floor.
  """

  def __init__(self,
               batch_size: int,
               decay_exp: float = 2.0,
               learning_rate: float = 40.0,
               warmup_steps: int = 8000,
               decay_steps: int = 12000,
               decay_start_steps: int = 10000):
    """Initializes the schedule.

    Args:
      batch_size: Batch size; `learning_rate` is multiplied by
        `batch_size / 2048`.
      decay_exp: Exponent of the polynomial decay.
      learning_rate: Base learning rate before batch-size scaling.
      warmup_steps: Number of linear warmup steps; 0 disables warmup.
      decay_steps: Length of the decay phase in steps; a non-positive value
        disables decay.
      decay_start_steps: Step after which the decay phase begins.
    """
    super().__init__()
    self.batch_size = batch_size
    self.decay_exp = decay_exp
    self.learning_rate = learning_rate
    self.warmup_steps = warmup_steps
    self.decay_steps = decay_steps
    self.decay_start_steps = decay_start_steps

  def __call__(self, step):
    """Computes the learning rate for `step`.

    Args:
      step: The current training step, as a Python number or a tensor.

    Returns:
      A float32 tensor holding the scheduled learning rate, floored at 0.01.
    """
    warmup_steps = self.warmup_steps
    adj_lr = self.learning_rate * (self.batch_size / 2048)

    global_step = tf.cast(step, tf.float32)
    decay_steps = tf.cast(self.decay_steps, tf.float32)
    decay_start_step = tf.cast(self.decay_start_steps, tf.float32)

    if warmup_steps > 0:
      warmup_lr = tf.cast(step / warmup_steps * adj_lr, tf.float32)
    else:
      # No warmup phase: `global_step < warmup_steps` below is never true for
      # a non-negative step, so this placeholder is never selected. It only
      # avoids dividing by zero. Do not return early here, or the decay phase
      # and the final floor would be skipped.
      warmup_lr = tf.constant(adj_lr, tf.float32)

    steps_since_decay_start = global_step - decay_start_step
    # Clip so the decay factor bottoms out at 0 once the decay phase ends.
    already_decayed_steps = tf.minimum(steps_since_decay_start, decay_steps)
    decay_lr = adj_lr * (
        (decay_steps - already_decayed_steps) / decay_steps)**self.decay_exp
    # Kept for parity with the original; the 0.01 floor below dominates it.
    decay_lr = tf.maximum(0.0001, decay_lr)

    lr = tf.where(
        global_step < warmup_steps, warmup_lr,
        tf.where(
            tf.logical_and(decay_steps > 0, global_step > decay_start_step),
            decay_lr, adj_lr))

    return tf.maximum(0.01, lr)

  def get_config(self):
    """Returns the constructor arguments for Keras (de)serialization."""
    return {
        'batch_size': self.batch_size,
        'decay_exp': self.decay_exp,
        'learning_rate': self.learning_rate,
        'warmup_steps': self.warmup_steps,
        'decay_steps': self.decay_steps,
        'decay_start_steps': self.decay_start_steps
    }