tensorflow/models

View on GitHub
official/modeling/hyperparams/base_config.py

Summary

Maintainability
C
1 day
Test Coverage
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Base configurations to standardize experiments."""

import copy
import dataclasses
import functools
import inspect
import types
import typing
from typing import Any, List, Mapping, Optional, Type, Union

from absl import logging
import tensorflow as tf, tf_keras
import yaml

from official.modeling.hyperparams import params_dict


# Config classes that already have a builder attached; `bind` refuses to attach
# a second builder to any class recorded here.
_BOUND = set()


def bind(config_cls):
  """Returns a decorator that registers a builder on `config_cls`.

  The decorated builder is stored as `config_cls._BUILDER`:
  * a class is stored as-is;
  * a plain function is wrapped so it can live on the class as a method: the
    wrapper ignores its first (`self`) argument and forwards the remaining
    positional and keyword arguments to the function.
  The decorator itself returns the builder unchanged.

  Args:
    config_cls: the class to attach the builder to.

  Returns:
    A decorator that accepts the builder (a class or a function).

  Raises:
    ValueError: if `config_cls` is not a class, if it was already bound, or if
      the builder is neither a class nor a function.
  """
  if not inspect.isclass(config_cls):
    raise ValueError('The bind decorator is supposed to apply on the class '
                     f'attribute. Received {config_cls}, not a class.')

  def decorator(builder):
    if config_cls in _BOUND:
      raise ValueError('Inside a program, we should not bind the config with a'
                       ' class twice.')

    if inspect.isfunction(builder):
      # Accessed through an instance, a function class attribute receives the
      # instance as its first argument; drop it and forward the rest.
      def _wrapper(self, *args, **kwargs):  # pylint: disable=unused-argument
        return builder(*args, **kwargs)

      registered = _wrapper
    elif inspect.isclass(builder):
      registered = builder
    else:
      raise ValueError(f'The `BUILDER` type is not supported: {builder}')

    config_cls._BUILDER = registered  # pylint: disable=protected-access
    _BOUND.add(config_cls)
    return builder

  return decorator


# Origins that denote a union annotation: `typing.Union[...]` / `Optional[...]`
# and, on Python 3.10+, the PEP 604 `X | Y` form whose origin is
# `types.UnionType`.
_UNION_ORIGINS = tuple(
    origin for origin in (Union, getattr(types, 'UnionType', None))
    if origin is not None)


def _is_optional(field):
  """Returns True if `field` is a union type annotation that admits None.

  Recognizes both `Optional[X]` / `Union[X, None]` and, where supported, the
  `X | None` spelling.

  Args:
    field: a type annotation.

  Returns:
    True when `field` is a union whose members include `NoneType`.
  """
  return (typing.get_origin(field) in _UNION_ORIGINS and
          type(None) in typing.get_args(field))


@dataclasses.dataclass
class Config(params_dict.ParamsDict):
  """The base configuration class that supports YAML/JSON based overrides.

  Because of YAML/JSON serialization limitations, some semantics of dataclass
  are not supported:
  * It recursively enforces an allowlist of basic types and container types, so
    it avoids surprises with copy and reuse caused by unanticipated types.
  * Warning: it converts Dict to `Config` even within sequences,
    e.g. for config = Config({'key': [{'a': 42}]}),
         type(config.key[0]) is Config rather than dict.
    If you define/annotate some field as Dict, the field will convert to a
    `Config` instance and lose the dictionary type.
  * Only a single level of sequence is supported: a list/tuple may hold
    immutable values, dicts or ParamsDicts, but not further lists/tuples.
  """
  # The class or method to bind with the params class.
  _BUILDER = None
  # It's safe to add bytes and other immutable types here.
  IMMUTABLE_TYPES = (str, int, float, bool, type(None))
  # It's safe to add set, frozenset and other collections here.
  SEQUENCE_TYPES = (list, tuple)

  default_params: dataclasses.InitVar[Optional[Mapping[str, Any]]] = None
  restrictions: dataclasses.InitVar[Optional[List[str]]] = None

  def __post_init__(self, default_params, restrictions):
    # The InitVars are not dataclass fields; they are forwarded to the
    # ParamsDict initializer.
    super().__init__(
        default_params=default_params,
        restrictions=restrictions)

  @property
  def BUILDER(self):
    """The builder bound to this config class (`None` if nothing is bound)."""
    return self._BUILDER

  @classmethod
  def _get_annotations(cls):
    """Returns valid annotations.

    Note: this is similar to `cls.__annotations__` except it also includes
      annotations from its parent classes, and it drops the annotations that
      `Config` itself declares (e.g. `default_params`, `restrictions`).
    """
    all_annotations = typing.get_type_hints(cls)
    # Removes Config class annotation from the value, e.g., default_params,
    # restrictions, etc.
    for k in Config.__annotations__:
      all_annotations.pop(k, None)
    return all_annotations

  @classmethod
  def _isvalidsequence(cls, v):
    """Check if the input values are valid sequences.

    Args:
      v: Input sequence.

    Returns:
      True if the sequence is valid. Valid sequence includes the sequence
      type in cls.SEQUENCE_TYPES and element type is in cls.IMMUTABLE_TYPES or
      is dict or ParamsDict. An empty sequence is valid.
    """
    if not isinstance(v, cls.SEQUENCE_TYPES):
      return False
    return (all(isinstance(e, cls.IMMUTABLE_TYPES) for e in v) or
            all(isinstance(e, dict) for e in v) or
            all(isinstance(e, params_dict.ParamsDict) for e in v))

  @classmethod
  def _import_config(cls, v, subconfig_type):
    """Returns v with dicts converted to Configs, recursively.

    Args:
      v: the value to import.
      subconfig_type: the ParamsDict subclass used to wrap plain dicts.

    Returns:
      Immutable values unchanged; sequences of the same type with each element
      imported; a deep copy of ParamsDict values; dicts as `subconfig_type`.

    Raises:
      TypeError: if `subconfig_type` is not a ParamsDict subclass, if `v` is
        an invalid (e.g. nested or mixed) sequence, or of an unknown type.
    """
    if not issubclass(subconfig_type, params_dict.ParamsDict):
      raise TypeError(
          'Subconfig_type should be subclass of ParamsDict, found {!r}'.format(
              subconfig_type))
    if isinstance(v, cls.IMMUTABLE_TYPES):
      return v
    elif isinstance(v, cls.SEQUENCE_TYPES):
      # Only support one layer of sequence.
      if not cls._isvalidsequence(v):
        raise TypeError(
            'Invalid sequence: only supports single level {!r} of {!r} or '
            'dict or ParamsDict found: {!r}'.format(cls.SEQUENCE_TYPES,
                                                    cls.IMMUTABLE_TYPES, v))
      import_fn = functools.partial(
          cls._import_config, subconfig_type=subconfig_type)
      return type(v)(map(import_fn, v))
    elif isinstance(v, params_dict.ParamsDict):
      # Deepcopy here is a temporary solution for preserving type in nested
      # Config object.
      return copy.deepcopy(v)
    elif isinstance(v, dict):
      return subconfig_type(v)
    else:
      raise TypeError('Unknown type: {!r}'.format(type(v)))

  @classmethod
  def _export_config(cls, v):
    """Returns v with Configs converted to dicts, recursively.

    Raises:
      TypeError: if `v` (or a sequence element) is a plain dict or of an
        unknown type.
    """
    if isinstance(v, cls.IMMUTABLE_TYPES):
      return v
    elif isinstance(v, cls.SEQUENCE_TYPES):
      return type(v)(map(cls._export_config, v))
    elif isinstance(v, params_dict.ParamsDict):
      return v.as_dict()
    elif isinstance(v, dict):
      raise TypeError('dict value not supported in converting.')
    else:
      raise TypeError('Unknown type: {!r}'.format(type(v)))

  @classmethod
  def _get_subconfig_type(
      cls, k, subconfig_type=None
  ) -> Type[params_dict.ParamsDict]:
    """Get element type by the field name.

    Args:
      k: the key/name of the field.
      subconfig_type: default subconfig_type. If None, it is set to
        Config.

    Returns:
      Config as default. If a type annotation is found for `k` (after
      stripping any `Optional[...]` wrappers),
      1) returns the type of the annotation if it is subtype of Config;
      2) returns the element type if the annotation of `k` is List[SubType]
         or Tuple[SubType] and SubType is a ParamsDict subclass.
      Any other annotation (including bare `List`/`Tuple` or non-class
      element types such as `List[Dict[str, int]]`) yields the default.
    """
    if not subconfig_type:
      subconfig_type = Config

    annotations = cls._get_annotations()
    if k not in annotations:
      return subconfig_type

    type_annotation = annotations[k]
    # Loop for stripping (possibly nested) `Optional` annotations.
    while True:
      if (isinstance(type_annotation, type) and
          issubclass(type_annotation, Config)):
        # Directly Config subtype.
        return type_annotation

      # Check if the field is a sequence of subtypes.
      field_type = typing.get_origin(type_annotation)
      if (isinstance(field_type, type) and
          issubclass(field_type, cls.SEQUENCE_TYPES)):
        # A bare `List`/`Tuple` has no type arguments.
        args = typing.get_args(type_annotation)
        element_type = args[0] if args else None
        # `issubclass` raises TypeError for non-class element types such as
        # `Dict[str, int]`, so check `isinstance(..., type)` first.
        if (isinstance(element_type, type) and
            issubclass(element_type, params_dict.ParamsDict)):
          return element_type
        return subconfig_type

      if _is_optional(type_annotation):
        # Strip the `Optional` annotation and process the first non-None
        # member (so `Union[None, X]` resolves to X rather than NoneType).
        type_annotation = next(
            arg for arg in typing.get_args(type_annotation)
            if arg is not type(None))
        continue

      return subconfig_type

  def _set(self, k, v):
    """Overrides same method in ParamsDict.

    Also called by ParamsDict methods. Dicts are merged into an existing
    non-empty value (or wrapped into a new sub-config); sequences of
    non-immutable values are merged element-wise when the lengths match;
    everything else is imported via `_import_config`.

    Args:
      k: key to set.
      v: value.

    Raises:
      RuntimeError
    """
    subconfig_type = self._get_subconfig_type(k)

    def is_null(k):
      # Missing keys and falsy current values (e.g. None) both mean there is
      # no existing value to merge into.
      return not self.__dict__.get(k)

    if isinstance(v, dict):
      if is_null(k):
        # If the key not exist or the value is None, a new Config-family object
        # should be created for the key.
        self.__dict__[k] = subconfig_type(v)
      else:
        self.__dict__[k].override(v)
    elif (not is_null(k) and isinstance(v, self.SEQUENCE_TYPES) and
          all(not isinstance(e, self.IMMUTABLE_TYPES) for e in v)):
      existing = self.__dict__[k]
      if len(existing) == len(v):
        for current, new_value in zip(existing, v):
          current.override(new_value)
      else:
        # Every element of `v` is non-immutable here, so only a non-empty `v`
        # is an actual mismatch worth warning about.
        if v:
          logging.warning(
              "The list/tuple don't match the value dictionaries provided. Thus, "
              'the list/tuple is determined by the type annotation and '
              'values provided. This is error-prone.')
        self.__dict__[k] = self._import_config(v, subconfig_type)
    else:
      self.__dict__[k] = self._import_config(v, subconfig_type)

  def __setattr__(self, k, v):
    """Routes attribute assignment through `_set`, honoring the lock.

    Raises:
      AttributeError: when assigning `BUILDER` or `_BUILDER` on an instance.
      ValueError: when the config is locked and `k` is not a reserved
        attribute.
    """
    if k == 'BUILDER' or k == '_BUILDER':
      raise AttributeError('`BUILDER` is a property and `_BUILDER` is the '
                           'reserved class attribute. We should only assign '
                           '`_BUILDER` at the class level.')

    # Attributes listed in RESERVED_ATTR bypass the lock check.
    if k not in self.RESERVED_ATTR:
      if getattr(self, '_locked', False):
        raise ValueError('The Config has been locked. No change is allowed.')
    self._set(k, v)

  def _override(self, override_dict, is_strict=True):
    """Overrides same method in ParamsDict.

    Also called by ParamsDict methods. Nested dict/ParamsDict values are merged
    recursively into existing non-empty sub-configs; other values go through
    `_set`.

    Args:
      override_dict: dictionary of key/value pairs to write into this config.
      is_strict: If True, not allows to add new keys.

    Raises:
      KeyError: overriding reserved keys or keys not exist (is_strict=True).
    """
    for k, v in sorted(override_dict.items()):
      if k in self.RESERVED_ATTR:
        raise KeyError('The key {!r} is internally reserved. '
                       'Can not be overridden.'.format(k))
      if k not in self.__dict__:
        if is_strict:
          raise KeyError('The key {!r} does not exist in {!r}. '
                         'To extend the existing keys, use '
                         '`override` with `is_strict` = False.'.format(
                             k, type(self)))
        else:
          self._set(k, v)
      else:
        if isinstance(v, dict) and self.__dict__[k]:
          self.__dict__[k]._override(v, is_strict)  # pylint: disable=protected-access
        elif isinstance(v, params_dict.ParamsDict) and self.__dict__[k]:
          self.__dict__[k]._override(v.as_dict(), is_strict)  # pylint: disable=protected-access
        else:
          self._set(k, v)

  def as_dict(self):
    """Returns a dict representation of params_dict.ParamsDict.

    For the nested params_dict.ParamsDict, a nested dict will be returned.
    Reserved attributes are excluded.
    """
    return {
        k: self._export_config(v)
        for k, v in self.__dict__.items()
        if k not in self.RESERVED_ATTR
    }

  def replace(self, **kwargs):
    """Overrides/returns an unlocked copy with the current config unchanged.

    Raises:
      KeyError: if any keyword is not an existing key (strict override).
    """
    # pylint: disable=protected-access
    params = copy.deepcopy(self)
    params._locked = False
    params._override(kwargs, is_strict=True)
    # pylint: enable=protected-access
    return params

  @classmethod
  def from_yaml(cls, file_path: str):
    """Builds a config by overriding the default `cls()` with a YAML file.

    Note: This only works if the Config has all default values, since the
    config is first constructed with `cls()`.

    Args:
      file_path: path of the YAML file, opened with `tf.io.gfile.GFile`.

    Returns:
      A `cls` instance. An empty YAML file yields the default config.
    """
    with tf.io.gfile.GFile(file_path, 'r') as f:
      # NOTE(review): `FullLoader` resolves Python-specific YAML tags beyond
      # plain data; consider `yaml.SafeLoader` if config files may come from
      # untrusted sources.
      loaded = yaml.load(f, Loader=yaml.FullLoader)
    config = cls()
    # An empty YAML document loads as None, which `override` cannot handle;
    # there is nothing to override in that case.
    if loaded is not None:
      config.override(loaded)
    return config

  @classmethod
  def from_json(cls, file_path: str):
    """Wrapper for `from_yaml`; JSON documents are parsed by the YAML loader."""
    return cls.from_yaml(file_path)

  @classmethod
  def from_args(cls, *args, **kwargs):
    """Builds a config from the given list of arguments.

    Positional arguments are paired, in order, with the annotations declared
    directly on `cls` (extra positional arguments are ignored); keyword
    arguments are then applied on top.
    """
    # Note we intend to keep `__annotations__` instead of `_get_annotations`.
    # Assuming a parent class of (a, b) with the sub-class of (c, d), the
    # sub-class will take (c, d) for args, rather than starting from (a, b).
    attributes = list(cls.__annotations__.keys())
    default_params = dict(zip(attributes, args))
    default_params.update(kwargs)
    return cls(default_params=default_params)