tensorflow/models

View on GitHub
official/recommendation/ranking/data/data_pipeline_multi_hot.py

Summary

Maintainability
C
7 hrs
Test Coverage
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Data pipeline for the Ranking model.

This module defines various input datasets for the Ranking model.
"""

from typing import List
import tensorflow as tf, tf_keras

from official.recommendation.ranking.configs import config


class CriteoTsvReaderMultiHot:
  """Input reader callable for pre-processed Multi Hot Criteo data.

  Raw Criteo data is assumed to be preprocessed in the following way:
  1. Missing values are replaced with zeros.
  2. Negative values are replaced with zeros.
  3. Integer features are transformed by log(x+1) and are hence tf.float32.
  4. Categorical data is bucketized and are hence tf.int32.
  
  Implements a TsvReaderMultiHot for reading data from a criteo dataset and 
  generate multi hot synthetic data using the provided vocab_sizes and
  multi_hot_sizes, also includes a complete synthetic data generator as well as 
  a TFRecordReader to read data from pre materialized multi hot synthetic 
  dataset that converted to TFRecords
  """

  def __init__(self,
               file_pattern: str,
               params: config.DataConfig,
               num_dense_features: int,
               vocab_sizes: List[int],
               multi_hot_sizes: List[int],
               use_synthetic_data: bool = False):
    self._file_pattern = file_pattern
    self._params = params
    self._num_dense_features = num_dense_features
    self._vocab_sizes = vocab_sizes
    self._use_synthetic_data = use_synthetic_data
    self._multi_hot_sizes = multi_hot_sizes

  def __call__(self, ctx: tf.distribute.InputContext) -> tf.data.Dataset:
    params = self._params
    # Per replica batch size.
    batch_size = ctx.get_per_replica_batch_size(
        params.global_batch_size) if ctx else params.global_batch_size
    if self._use_synthetic_data:
      return self._generate_synthetic_data(ctx, batch_size)

    @tf.function
    def _parse_fn(example: tf.Tensor):
      """Parser function for pre-processed Criteo TSV records."""
      label_defaults = [[0.0]]
      dense_defaults = [
          [0.0] for _ in range(self._num_dense_features)
      ]
      num_sparse_features = len(self._vocab_sizes)
      categorical_defaults = [
          [0] for _ in range(num_sparse_features)
      ]
      record_defaults = label_defaults + dense_defaults + categorical_defaults
      fields = tf.io.decode_csv(
          example, record_defaults, field_delim='\t', na_value='-1')

      num_labels = 1
      label = tf.reshape(fields[0], [batch_size, 1])

      features = {}
      num_dense = len(dense_defaults)

      dense_features = []
      offset = num_labels
      for idx in range(num_dense):
        dense_features.append(fields[idx + offset])
      features['dense_features'] = tf.stack(dense_features, axis=1)

      offset += num_dense
      features['sparse_features'] = {}

      sparse_tensors = []
      for idx, (vocab_size, multi_hot_size) in enumerate(
          zip(self._vocab_sizes, self._multi_hot_sizes)
      ):
        sparse_tensor = tf.reshape(fields[idx + offset], [batch_size, 1])
        sparse_tensor_synthetic = tf.random.uniform(
            shape=(batch_size, multi_hot_size - 1),
            maxval=int(vocab_size),
            dtype=tf.int32,
        )
        sparse_tensors.append(
            tf.sparse.from_dense(
                tf.concat([sparse_tensor, sparse_tensor_synthetic], axis=1)
            )
        )

      sparse_tensor_elements = {
          str(i): sparse_tensors[i] for i in range(len(sparse_tensors))
      }

      features['sparse_features'] = sparse_tensor_elements

      return features, label

    filenames = tf.data.Dataset.list_files(self._file_pattern, shuffle=False)

    # Shard the full dataset according to host number.
    # Each host will get 1 / num_of_hosts portion of the data.
    if params.sharding and ctx and ctx.num_input_pipelines > 1:
      filenames = filenames.shard(ctx.num_input_pipelines,
                                  ctx.input_pipeline_id)

    num_shards_per_host = 1
    if params.sharding:
      num_shards_per_host = params.num_shards_per_host

    def make_dataset(shard_index):
      filenames_for_shard = filenames.shard(num_shards_per_host, shard_index)
      dataset = tf.data.TextLineDataset(filenames_for_shard)
      if params.is_training:
        dataset = dataset.repeat()
      dataset = dataset.batch(batch_size, drop_remainder=True)
      dataset = dataset.map(_parse_fn,
                            num_parallel_calls=tf.data.experimental.AUTOTUNE)
      return dataset

    indices = tf.data.Dataset.range(num_shards_per_host)
    dataset = indices.interleave(
        map_func=make_dataset,
        cycle_length=params.cycle_length,
        num_parallel_calls=tf.data.experimental.AUTOTUNE)

    dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
    if self._params.use_cached_data:
      dataset = dataset.take(1).cache().repeat()

    return dataset

  def _generate_synthetic_data(self, ctx: tf.distribute.InputContext,
                               batch_size: int) -> tf.data.Dataset:
    """Creates synthetic data based on the parameter batch size.

    Args:
      ctx: Input Context
      batch_size: per replica batch size.

    Returns:
      The synthetic dataset.
    """
    params = self._params
    num_dense = self._num_dense_features
    num_replicas = ctx.num_replicas_in_sync if ctx else 1

    if params.is_training:
      dataset_size = 50 * batch_size * num_replicas
    else:
      dataset_size = 50 * batch_size * num_replicas
    dense_tensor = tf.random.uniform(
        shape=(dataset_size, num_dense), maxval=1.0, dtype=tf.float32
    )

    sparse_tensors = []
    for vocab_size, multi_hot_size in zip(
        self._vocab_sizes, self._multi_hot_sizes
    ):
      sparse_tensors.append(
          tf.sparse.from_dense(
              tf.random.uniform(
                  shape=(dataset_size, multi_hot_size),
                  maxval=int(vocab_size),
                  dtype=tf.int32,
              )
          )
      )

    sparse_tensor_elements = {
        str(i): sparse_tensors[i] for i in range(len(sparse_tensors))
    }

    # the mean is in [0, 1] interval.
    dense_tensor_mean = tf.math.reduce_mean(dense_tensor, axis=1)

    # the label is in [0, 1] interval.
    label_tensor = (dense_tensor_mean)
    # Using the threshold 0.5 to convert to 0/1 labels.
    label_tensor = tf.cast(label_tensor + 0.5, tf.int32)

    input_elem = {'dense_features': dense_tensor,
                  'sparse_features': sparse_tensor_elements}, label_tensor

    dataset = tf.data.Dataset.from_tensor_slices(input_elem)
    dataset = dataset.cache()
    if params.is_training:
      dataset = dataset.repeat()

    return dataset.batch(batch_size, drop_remainder=True)


class CriteoTFRecordReader(object):
  """Input reader fn for TFRecords that have been serialized in batched form."""

  def __init__(self,
               file_pattern: str,
               params: config.DataConfig,
               num_dense_features: int,
               vocab_sizes: List[int],
               multi_hot_sizes: List[int],
               use_cached_data: bool = False):
    self._file_pattern = file_pattern
    self._params = params
    self._num_dense_features = num_dense_features
    self._vocab_sizes = vocab_sizes
    self._multi_hot_sizes = multi_hot_sizes
    self._use_cached_data = use_cached_data

    self.label_features = 'label'
    self.dense_features = ['dense-feature-%d' % x for x in range(1, 14)]
    self.sparse_features = ['sparse-feature-%d' % x for x in range(14, 40)]

  def __call__(self, ctx: tf.distribute.InputContext):
    params = self._params
    # Per replica batch size.
    batch_size = (
        ctx.get_per_replica_batch_size(params.global_batch_size)
        if ctx
        else params.global_batch_size
    )

    def _get_feature_spec():
      feature_spec = {}
      feature_spec[self.label_features] = tf.io.FixedLenFeature(
          [], dtype=tf.int64
      )
      for dense_feat in self.dense_features:
        feature_spec[dense_feat] = tf.io.FixedLenFeature(
            [],
            dtype=tf.float32,
        )
      for i, sparse_feat in enumerate(self.sparse_features):
        feature_spec[sparse_feat] = tf.io.FixedLenFeature(
            [self._multi_hot_sizes[i]], dtype=tf.int64
        )
      return feature_spec

    def _parse_fn(serialized_example):
      feature_spec = _get_feature_spec()
      parsed_features = tf.io.parse_single_example(
          serialized_example, feature_spec
      )
      label = parsed_features[self.label_features]
      features = {}
      int_features = []
      for dense_ft in self.dense_features:
        int_features.append(parsed_features[dense_ft])
      features['dense_features'] = tf.stack(int_features)

      features['sparse_features'] = {}
      for i, sparse_ft in enumerate(self.sparse_features):
        features['sparse_features'][str(i)] = tf.sparse.from_dense(
            parsed_features[sparse_ft]
        )

      return features, label

    filenames = tf.data.Dataset.list_files(self._file_pattern, shuffle=False)
    # Shard the full dataset according to host number.
    # Each host will get 1 / num_of_hosts portion of the data.
    if params.sharding and ctx and ctx.num_input_pipelines > 1:
      filenames = filenames.shard(ctx.num_input_pipelines,
                                  ctx.input_pipeline_id)

    num_shards_per_host = 1
    if params.sharding:
      num_shards_per_host = params.num_shards_per_host

    def make_dataset(shard_index):
      filenames_for_shard = filenames.shard(num_shards_per_host, shard_index)
      dataset = tf.data.TFRecordDataset(
          filenames_for_shard
      )
      if params.is_training:
        dataset = dataset.repeat()
      dataset = dataset.map(
          _parse_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE
      )
      return dataset

    indices = tf.data.Dataset.range(num_shards_per_host)
    dataset = indices.interleave(
        map_func=make_dataset,
        cycle_length=params.cycle_length,
        num_parallel_calls=tf.data.experimental.AUTOTUNE,
    )

    dataset = dataset.batch(
        batch_size,
        drop_remainder=True,
        num_parallel_calls=tf.data.experimental.AUTOTUNE,
    )
    dataset = dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
    if self._use_cached_data:
      dataset = dataset.take(1).cache().repeat()

    return dataset


def train_input_fn(params: config.Task) -> CriteoTsvReaderMultiHot:
  """Returns callable object of batched training examples.

  Args:
    params: hyperparams to create input pipelines.

  Returns:
    CriteoTsvReader callable for training dataset.
  """
  return CriteoTsvReaderMultiHot(
      file_pattern=params.train_data.input_path,
      params=params.train_data,
      vocab_sizes=params.model.vocab_sizes,
      num_dense_features=params.model.num_dense_features,
      multi_hot_sizes=params.model.multi_hot_sizes,
      use_synthetic_data=params.use_synthetic_data)


def eval_input_fn(params: config.Task) -> CriteoTsvReaderMultiHot:
  """Returns callable object of batched eval examples.

  Args:
    params: hyperparams to create input pipelines.

  Returns:
    CriteoTsvReader callable for eval dataset.
  """

  return CriteoTsvReaderMultiHot(
      file_pattern=params.validation_data.input_path,
      params=params.validation_data,
      vocab_sizes=params.model.vocab_sizes,
      num_dense_features=params.model.num_dense_features,
      multi_hot_sizes=params.model.multi_hot_sizes,
      use_synthetic_data=params.use_synthetic_data)