tensorflow/tensorflow

View on GitHub
tensorflow/tools/proto_splitter/split.py

Summary

Maintainability
A
2 hrs
Test Coverage
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Basic interface for Python-based Splitter."""

import abc
from collections.abc import Sequence
import time
from typing import Optional, Union

from absl import logging
import riegeli

from google.protobuf import message
from tensorflow.python.lib.io import file_io
from tensorflow.tools.proto_splitter import chunk_pb2
from tensorflow.tools.proto_splitter import util
from tensorflow.tools.proto_splitter import version as version_lib
from tensorflow.tools.proto_splitter import versions_pb2


class Splitter(abc.ABC):
  """An abstract class for splitting and writing protos that are > 2GB.

  See the README on how to use or subclass this class.
  """

  @property
  @abc.abstractmethod
  def version_def(self) -> versions_pb2.VersionDef:
    """Version info about the splitter and merge implementation required."""

  @abc.abstractmethod
  def split(
      self,
  ) -> tuple[Sequence[Union[message.Message, bytes]], chunk_pb2.ChunkedMessage]:
    """Splits proto message into a Sequence of protos/bytes."""

  @abc.abstractmethod
  def write(self, file_prefix: str) -> str:
    """Serializes proto to disk.

    Args:
      file_prefix: string prefix of the filepath.

    Returns:
      The actual path the proto is written to.
    """


class ComposableSplitter(Splitter):
  """A Splitter that can be composed with other splitters.

  This Splitter writes to the riegeli file format.

  See README for details.
  """

  def __init__(
      self,
      proto,
      *,
      proto_as_initial_chunk: bool = True,
      parent_splitter: Optional["ComposableSplitter"] = None,
      fields_in_parent: Optional[util.FieldTypes] = None,
  ):
    """Initializes ComposableSplitter.

    Args:
      proto: Proto message to split.
      proto_as_initial_chunk: Whether to initialize chunks with the
        user-provided proto as the initial chunk.
      parent_splitter: The parent `ComposableSplitter` object.
      fields_in_parent: Fields to access `proto` from the parent splitter's
        proto.
    """
    self._proto = proto
    self._parent_splitter = parent_splitter
    self._fields_in_parent = fields_in_parent

    # Whether chunks have been created. See `build_chunks()`.
    self._built = False

    # Keep a list of chunk ids in the order in which they were added to the
    # list.
    self._add_chunk_order = []
    self._fix_chunk_order = False

    # Initialize chunks and ChunkedMessage (optionally with the first chunk as
    # the user-provided proto.
    if parent_splitter is not None:
      # If this is not the root Splitter class, skip the initialization of
      # the chunks/message since the parent's will be updated instead.
      self._chunks = None
      self._chunked_message = None
    elif proto_as_initial_chunk:
      self._chunks = [self._proto]
      self._chunked_message = chunk_pb2.ChunkedMessage(chunk_index=0)
      self._add_chunk_order.append(id(self._proto))
    else:
      self._chunks = []
      self._chunked_message = chunk_pb2.ChunkedMessage()

  def build_chunks(self) -> None:
    """Builds the Splitter object by generating chunks from the proto.

    Subclasses of `ComposableChunks` should only need to override this method.

    This method should be called once per Splitter to create the chunks.
    Users should call the methods `split` or `write` instead.
    """

  @property
  def version_def(self) -> versions_pb2.VersionDef:
    """Version info about the splitter and join implementation required."""
    return versions_pb2.VersionDef(
        splitter_version=1,
        join_version=0,
        bad_consumers=version_lib.get_bad_versions(),
    )

  def split(
      self,
  ) -> tuple[Sequence[Union[message.Message, bytes]], chunk_pb2.ChunkedMessage]:
    """Splits a proto message into a Sequence of protos/bytes."""
    if self._parent_splitter:
      raise ValueError(
          "A child ComposableSplitter's `split` method should not be called "
          "directly, since it inherit chunks from a parent object. Please call "
          "the parent's `split()` method instead."
      )

    assert self._chunks is not None
    assert self._chunked_message is not None

    if not self._built:
      self.build_chunks()
      self._fix_chunks()
      self._built = True
    return self._chunks, self._chunked_message

  def write(self, file_prefix: str) -> str:
    """Serializes a proto to disk.

    The writer writes all chunks into a riegeli file. The chunk metadata
    (ChunkMetadata) is written at the very end.

    Args:
      file_prefix: string prefix of the filepath. The writer will automatically
        attach a `.pb` or `.cpb` (chunked pb) suffix depending on whether the
        proto is split.

    Returns:
      The actual filepath the proto is written to. The filepath will be
      different depending on whether the proto is split, i.e., whether it will
      be a pb or not.
    """
    if self._parent_splitter is not None:
      raise ValueError(
          "A child ComposableSplitter's `write` method should not be called "
          "directly, since it inherits unrelated chunks from a parent object. "
          "Please call the parent's `write()` method instead."
      )

    start_time = time.time()
    chunks, chunked_message = self.split()

    if not chunked_message.chunked_fields:
      path = f"{file_prefix}.pb"
      file_io.atomic_write_string_to_file(
          path, self._proto.SerializeToString(deterministic=True)
      )
      logging.info("Unchunked file exported to %s", path)
      return path

    path = f"{file_prefix}.cpb"
    with riegeli.RecordWriter(file_io.FileIO(path, "wb")) as f:
      metadata = chunk_pb2.ChunkMetadata(
          message=chunked_message, version=self.version_def
      )
      for chunk in chunks:
        if isinstance(chunk, message.Message):
          f.write_message(chunk)
          chunk_type = chunk_pb2.ChunkInfo.Type.MESSAGE
          size = chunk.ByteSize()
        else:
          f.write_record(chunk)
          chunk_type = chunk_pb2.ChunkInfo.Type.BYTES
          size = len(chunk)
        metadata.chunks.add(
            type=chunk_type, size=size, offset=f.last_pos.numeric
        )
      f.write_message(metadata)

    end = time.time()

    logging.info("Chunked file exported to %s", path)
    logging.info(
        "Total time spent splitting and writing the message: %s",
        end - start_time,
    )
    logging.info(
        "Number of chunks created (including initial message): %s",
        len(chunks),
    )
    return path

  def add_chunk(
      self,
      chunk: Union[message.Message, bytes],
      field_tags: util.FieldTypes,
      index=None,
  ) -> None:
    """Adds a new chunk and updates the ChunkedMessage proto.

    Args:
      chunk: Proto message or bytes.
      field_tags: Field information about the placement of the chunked data
        within self._proto.
      index: Optional index at which to insert the chunk. The chunk ordering is
        important for merging.
    """
    if self._parent_splitter is not None:
      self._parent_splitter.add_chunk(
          chunk, self._fields_in_parent + field_tags, index
      )
    else:
      assert self._chunks is not None
      assert self._chunked_message is not None
      field = self._chunked_message.chunked_fields.add(
          field_tag=util.get_field_tag(self._proto, field_tags)
      )
      new_chunk_index = len(self._chunks)
      field.message.chunk_index = new_chunk_index
      self._add_chunk_order.append(id(chunk))

      if index is None:
        self._chunks.append(chunk)
      else:
        self._chunks.insert(index, chunk)
        self._fix_chunk_order = True

  def _fix_chunks(self) -> None:
    """Fixes chunk indices in the ChunkedMessage."""
    if not self._fix_chunk_order:
      return

    # The chunk_index of each nested ChunkedMessage is set to the length of the
    # list when the chunk was added. This would be fine if the chunks were
    # always added to the end of the list. However, this is not always the case
    # the indices must be updated.

    # Use the address of each chunk (python `id`) as lookup keys to the
    # ordered chunk indices.
    chunk_indices = {id(chunk): i for i, chunk in enumerate(self._chunks)}

    to_fix = [self._chunked_message]
    while to_fix:
      for field in to_fix.pop().chunked_fields:
        if field.message.chunked_fields:
          to_fix.append(field.message)
        if not field.message.HasField("chunk_index"):
          continue
        chunk_addr = self._add_chunk_order[field.message.chunk_index]
        assert (
            chunk_addr in chunk_indices
        ), f"Found unexpected chunk {chunk_addr}"
        new_chunk_index = chunk_indices[chunk_addr]
        field.message.chunk_index = new_chunk_index

    self._add_chunk_order = [id(chunk) for chunk in self._chunks]
    self._fix_chunk_order = False