# official/recommendation/ranking/data/data_pipeline.py
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Data pipeline for the Ranking model.
This module defines various input datasets for the Ranking model.
"""
from typing import List
import tensorflow as tf, tf_keras
from official.recommendation.ranking.configs import config
class CriteoTsvReader:
  """Input reader callable for pre-processed Criteo data.

  Raw Criteo data is assumed to be preprocessed in the following way:
  1. Missing values are replaced with zeros.
  2. Negative values are replaced with zeros.
  3. Integer features are transformed by log(x+1) and are hence tf.float32.
  4. Categorical data is bucketized and are hence tf.int32.
  """

  def __init__(self,
               file_pattern: str,
               params: config.DataConfig,
               num_dense_features: int,
               vocab_sizes: List[int],
               use_synthetic_data: bool = False,
               use_cached_data: bool = False):
    """Initializes the reader.

    Args:
      file_pattern: Glob pattern matching the pre-processed TSV shards.
      params: Data config carrying (at least) global_batch_size, sharding,
        num_shards_per_host, cycle_length and is_training.
      num_dense_features: Number of dense float feature columns per record.
      vocab_sizes: One vocabulary size per categorical feature column.
      use_synthetic_data: If True, serve random in-memory data and never
        read `file_pattern`.
      use_cached_data: If True, cache a single batch and repeat it forever
        (useful to measure model throughput without input-pipeline cost).
    """
    self._file_pattern = file_pattern
    self._params = params
    self._num_dense_features = num_dense_features
    self._vocab_sizes = vocab_sizes
    self._use_synthetic_data = use_synthetic_data
    self._use_cached_data = use_cached_data

  def __call__(self, ctx: tf.distribute.InputContext) -> tf.data.Dataset:
    """Builds the per-host input dataset.

    Args:
      ctx: Distribution input context; may be None when not running under a
        distribution strategy.

    Returns:
      A `tf.data.Dataset` of `(features, label)` batches where `features`
      holds a 'dense_features' tensor of shape [batch, num_dense_features]
      and a 'sparse_features' dict mapping str(column_index) -> [batch]
      int32 tensors; `label` has shape [batch, 1].
    """
    params = self._params
    # Per-replica batch size: split the global batch across replicas when a
    # distribution context is available.
    batch_size = ctx.get_per_replica_batch_size(
        params.global_batch_size) if ctx else params.global_batch_size

    if self._use_synthetic_data:
      return self._generate_synthetic_data(ctx, batch_size)

    @tf.function
    def _parse_fn(example: tf.Tensor):
      """Parser function for pre-processed Criteo TSV records."""
      # Column layout per line: 1 label, then the dense float columns, then
      # the categorical int columns.
      label_defaults = [[0.0]]
      dense_defaults = [
          [0.0] for _ in range(self._num_dense_features)
      ]
      num_sparse_features = len(self._vocab_sizes)
      categorical_defaults = [
          [0] for _ in range(num_sparse_features)
      ]
      record_defaults = label_defaults + dense_defaults + categorical_defaults
      fields = tf.io.decode_csv(
          example, record_defaults, field_delim='\t', na_value='-1')

      num_labels = 1
      label = tf.reshape(fields[0], [batch_size, 1])

      features = {}
      num_dense = len(dense_defaults)

      dense_features = []
      offset = num_labels
      for idx in range(num_dense):
        dense_features.append(fields[idx + offset])
      features['dense_features'] = tf.stack(dense_features, axis=1)
      offset += num_dense

      features['sparse_features'] = {}
      for idx in range(num_sparse_features):
        features['sparse_features'][str(idx)] = fields[idx + offset]

      return features, label

    filenames = tf.data.Dataset.list_files(self._file_pattern, shuffle=False)

    # Shard the full dataset according to host number.
    # Each host will get 1 / num_of_hosts portion of the data.
    if params.sharding and ctx and ctx.num_input_pipelines > 1:
      filenames = filenames.shard(ctx.num_input_pipelines,
                                  ctx.input_pipeline_id)

    num_shards_per_host = 1
    if params.sharding:
      num_shards_per_host = params.num_shards_per_host

    def make_dataset(shard_index):
      """Reads and parses one of this host's `num_shards_per_host` shards."""
      filenames_for_shard = filenames.shard(num_shards_per_host, shard_index)
      dataset = tf.data.TextLineDataset(filenames_for_shard)
      if params.is_training:
        dataset = dataset.repeat()
      # Batch raw lines first so decode_csv parses a whole batch per call.
      dataset = dataset.batch(batch_size, drop_remainder=True)
      dataset = dataset.map(_parse_fn,
                            num_parallel_calls=tf.data.AUTOTUNE)
      return dataset

    indices = tf.data.Dataset.range(num_shards_per_host)
    dataset = indices.interleave(
        map_func=make_dataset,
        cycle_length=params.cycle_length,
        num_parallel_calls=tf.data.AUTOTUNE)
    dataset = dataset.prefetch(tf.data.AUTOTUNE)

    if self._use_cached_data:
      dataset = dataset.take(1).cache().repeat()

    return dataset

  def _generate_synthetic_data(self, ctx: tf.distribute.InputContext,
                               batch_size: int) -> tf.data.Dataset:
    """Creates synthetic data based on the parameter batch size.

    The label is derived deterministically from the features (thresholded
    mean of the normalized dense and sparse feature means), so a model can
    actually fit this data.

    Args:
      ctx: Input Context (used only for the replica count; may be None).
      batch_size: per replica batch size.

    Returns:
      The synthetic dataset.
    """
    params = self._params
    num_dense = self._num_dense_features
    num_replicas = ctx.num_replicas_in_sync if ctx else 1
    # 1000 batches per replica for training and evaluation alike. (The
    # original if/else on params.is_training had two identical branches.)
    dataset_size = 1000 * batch_size * num_replicas

    dense_tensor = tf.random.uniform(
        shape=(dataset_size, num_dense), maxval=1.0, dtype=tf.float32)

    sparse_tensors = []
    for size in self._vocab_sizes:
      sparse_tensors.append(
          tf.random.uniform(
              shape=(dataset_size,), maxval=int(size), dtype=tf.int32))

    sparse_tensor_elements = {
        str(i): sparse_tensors[i] for i in range(len(sparse_tensors))
    }

    # the mean is in [0, 1] interval.
    dense_tensor_mean = tf.math.reduce_mean(dense_tensor, axis=1)

    sparse_tensors = tf.stack(sparse_tensors, axis=-1)
    sparse_tensors_mean = tf.math.reduce_sum(sparse_tensors, axis=1)
    # Normalized by the total vocabulary size, so the value lands in [0, 1).
    sparse_tensors_mean = tf.cast(sparse_tensors_mean, dtype=tf.float32)
    sparse_tensors_mean /= sum(self._vocab_sizes)

    # the label is in [0, 1] interval.
    label_tensor = (dense_tensor_mean + sparse_tensors_mean) / 2.0
    # Using the threshold 0.5 to convert to 0/1 labels (int cast truncates).
    label_tensor = tf.cast(label_tensor + 0.5, tf.int32)

    input_elem = {'dense_features': dense_tensor,
                  'sparse_features': sparse_tensor_elements}, label_tensor

    dataset = tf.data.Dataset.from_tensor_slices(input_elem)
    dataset = dataset.cache()
    if params.is_training:
      dataset = dataset.repeat()

    if self._use_cached_data:
      dataset = dataset.take(1).cache().repeat()

    return dataset.batch(batch_size, drop_remainder=True)
def train_input_fn(params: config.Task) -> CriteoTsvReader:
  """Returns callable object of batched training examples.

  Args:
    params: hyperparams to create input pipelines.

  Returns:
    CriteoTsvReader callable for training dataset.
  """
  train_data = params.train_data
  model = params.model
  return CriteoTsvReader(
      file_pattern=train_data.input_path,
      params=train_data,
      num_dense_features=model.num_dense_features,
      vocab_sizes=model.vocab_sizes,
      use_synthetic_data=params.use_synthetic_data)
def eval_input_fn(params: config.Task) -> CriteoTsvReader:
  """Returns callable object of batched eval examples.

  Args:
    params: hyperparams to create input pipelines.

  Returns:
    CriteoTsvReader callable for eval dataset.
  """
  validation_data = params.validation_data
  model = params.model
  return CriteoTsvReader(
      file_pattern=validation_data.input_path,
      params=validation_data,
      num_dense_features=model.num_dense_features,
      vocab_sizes=model.vocab_sizes,
      use_synthetic_data=params.use_synthetic_data)