# tensorflow/models

# View on GitHub
# official/recommendation/create_ncf_data.py

# Summary

# Maintainability
# A
# 55 mins
# Test Coverage
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Binary to generate training/evaluation dataset for NCF model."""

import json

# pylint: disable=g-bad-import-order
# Import libraries
from absl import app
from absl import flags
import tensorflow as tf, tf_keras
# pylint: enable=g-bad-import-order

from official.recommendation import movielens
from official.recommendation import data_preprocessing

# Command-line flags controlling where data is written and how the
# training/evaluation datasets are generated.
flags.DEFINE_string(
    "data_dir", None,
    "The input data dir at which training and evaluation tf record files "
    "will be saved.")
flags.DEFINE_string("meta_data_file_path", None,
                    "The path in which input meta data will be written.")
flags.DEFINE_enum("dataset", "ml-20m", ["ml-1m", "ml-20m"],
                  "Dataset to be trained/evaluated.")
flags.DEFINE_enum(
    "constructor_type", "bisection", ["bisection", "materialized"],
    "Strategy to use for generating false negatives. materialized has a "
    "precompute that scales badly, but a faster per-epoch construction "
    "time and can be faster on very large systems.")
flags.DEFINE_integer("num_train_epochs", 14,
                     "Total number of training epochs to generate.")
flags.DEFINE_integer(
    "num_negative_samples", 4,
    "Number of negative instances to pair with positive instance.")
flags.DEFINE_integer(
    "train_prebatch_size", 99000,
    "Batch size to be used for prebatching the dataset "
    "for training.")
# The help text below previously said "for training." (copied from the
# train_prebatch_size flag above); this flag is passed on as eval_batch_size.
flags.DEFINE_integer(
    "eval_prebatch_size", 99000,
    "Batch size to be used for prebatching the dataset "
    "for evaluation.")

# Global flag registry; values are available once absl has parsed argv.
FLAGS = flags.FLAGS


def prepare_raw_data(flag_obj):
  """Fetches the raw dataset and builds the data producer plus its metadata.

  Args:
    flag_obj: Object exposing the flag values as attributes (`dataset`,
      `data_dir`, `num_train_epochs`, `train_prebatch_size`,
      `eval_prebatch_size`, `num_negative_samples`, `constructor_type`).

  Returns:
    A `(producer, input_metadata)` tuple: the producer object returned by
    `data_preprocessing.instantiate_pipeline`, and a dict of values describing
    the data that the producer will generate.
  """
  dataset_name = flag_obj.dataset
  target_dir = flag_obj.data_dir
  movielens.download(dataset_name, target_dir)

  pipeline_params = dict(
      train_epochs=flag_obj.num_train_epochs,
      batch_size=flag_obj.train_prebatch_size,
      eval_batch_size=flag_obj.eval_prebatch_size,
      batches_per_step=1,
      stream_files=True,
      num_neg=flag_obj.num_negative_samples,
  )

  # The same directory is used both as the data dir and as the epoch dir.
  pipeline = data_preprocessing.instantiate_pipeline(
      dataset=dataset_name,
      data_dir=target_dir,
      params=pipeline_params,
      constructor_type=flag_obj.constructor_type,
      epoch_dir=target_dir,
      generate_data_offline=True)
  user_count, item_count, producer = pipeline

  # Key order is kept stable because this dict is later serialized as JSON.
  # pylint: disable=protected-access
  input_metadata = dict(
      num_users=user_count,
      num_items=item_count,
      constructor_type=flag_obj.constructor_type,
      num_train_elements=producer._elements_in_epoch,
      num_eval_elements=producer._eval_elements_in_epoch,
      num_train_epochs=flag_obj.num_train_epochs,
      train_prebatch_size=flag_obj.train_prebatch_size,
      eval_prebatch_size=flag_obj.eval_prebatch_size,
      num_train_steps=producer.train_batches_per_epoch,
      num_eval_steps=producer.eval_batches_per_epoch,
  )
  # pylint: enable=protected-access

  return producer, input_metadata


def generate_data():
  """Runs the data producer, then saves the input metadata as JSON.

  Reads its configuration from the module-level `FLAGS`. The metadata is
  written to `FLAGS.meta_data_file_path` with 4-space indentation and a
  trailing newline.
  """
  data_producer, metadata = prepare_raw_data(FLAGS)
  data_producer.run()

  metadata_path = FLAGS.meta_data_file_path
  with tf.io.gfile.GFile(metadata_path, "w") as metadata_file:
    serialized = json.dumps(metadata, indent=4)
    metadata_file.write(serialized + "\n")


def main(_):
  """Entry point handed to `app.run`; the positional argument is ignored."""
  generate_data()


if __name__ == "__main__":
  # Both output locations must be supplied on the command line.
  flags.mark_flags_as_required(["data_dir", "meta_data_file_path"])
  app.run(main)