tensorflow/models

View on GitHub
official/core/registry.py

Summary

Maintainability
A
2 hrs
Test Coverage
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Registry utility."""


def register(registered_collection, reg_key):
  """Build a decorator that files a function or class into a registry dict.

  A string reg_key is treated as a "/"-separated path. Every segment except
  the last names a nested dictionary, which is created on demand. The last
  segment is the key under which the decorated object is stored. For example,
  reg_key="my_model/my_exp/my_config_0" stores the object at
  registered_collection["my_model"]["my_exp"]["my_config_0"]. Any other kind
  of reg_key is used directly as a key of registered_collection. Objects
  registered this way are meant to be retrieved with lookup() from this file.

  Args:
    registered_collection: a dictionary that receives the decorated function
      or class.
    reg_key: the key to register under. A string may be hierarchical, like
      my_model/my_exp/my_config_0.
  Returns:
    A decorator that stores its argument and returns it unchanged.
  Raises:
    KeyError: when a path segment is already occupied by a registered
      function or class, or when the final key is already present.
  """

  def decorator(fn_or_cls):
    """Store fn_or_cls under reg_key and hand it back unchanged."""
    target = registered_collection
    leaf = reg_key
    if isinstance(reg_key, str):
      # str.split always yields at least one element, so `leaf` is always set.
      *parents, leaf = reg_key.split("/")
      for depth, name in enumerate(parents):
        if name not in target:
          target[name] = {}
        target = target[name]
        # A non-dict here means a function/class already sits on this path.
        if not isinstance(target, dict):
          raise KeyError(
              f"Collection path {name} at position {depth} already registered "
              "as a function or class.")

    if leaf in target:
      raise KeyError(f"Function or class {leaf} registered multiple times.")

    target[leaf] = fn_or_cls
    return fn_or_cls

  return decorator


def lookup(registered_collection, reg_key):
  """Lookup and return decorated function or class in the collection.

  Lookup decorated function or class in registered_collection, in a
  hierarchical order. For example, when
  reg_key="my_model/my_exp/my_config_0",
  this function will return
  registered_collection["my_model"]["my_exp"]["my_config_0"].

  Args:
    registered_collection: a dictionary. The decorated function or class will be
      retrieved from this collection.
    reg_key: The key for retrieving the registered function or class. If reg_key
      is a string, it can be hierarchical like my_model/my_exp/my_config_0
  Returns:
    The registered function or class. For a string reg_key that names an
    intermediate path, the nested dictionary at that path is returned.
  Raises:
    LookupError: when reg_key cannot be found. This includes a hierarchical
      reg_key that tries to descend through an entry holding a registered
      function or class rather than a nested collection.
  """
  if isinstance(reg_key, str):
    hierarchy = reg_key.split("/")
    collection = registered_collection
    last_idx = len(hierarchy) - 1
    for h_idx, entry_name in enumerate(hierarchy):
      if entry_name not in collection:
        raise LookupError(
            f"collection path {entry_name} at position {h_idx} is never "
            f"registered. Please make sure the {entry_name} and its library is "
            "imported and linked to the trainer binary.")
      collection = collection[entry_name]
      # register() only creates nested dicts on the path. Descending into
      # anything else (a registered function/class) would make the next
      # `in` test raise TypeError instead of the documented LookupError.
      if h_idx < last_idx and not isinstance(collection, dict):
        raise LookupError(
            f"collection path {entry_name} at position {h_idx} is registered "
            f"as a function or class, so {reg_key} cannot be resolved.")
    return collection
  else:
    if reg_key not in registered_collection:
      raise LookupError(
          f"registration key {reg_key} is never "
          f"registered. Please make sure the {reg_key} and its library is "
          "imported and linked to the trainer binary.")
    return registered_collection[reg_key]