tensorflow/models

View on GitHub
official/modeling/hyperparams/params_dict.py

Summary

Maintainability
D
2 days
Test Coverage
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""A parameter dictionary class which supports the nest structure."""

import collections
import copy
import re

import six
import tensorflow as tf, tf_keras
import yaml

# Regex pattern that matches one key-value pair at a time in a
# comma-separated key-value pair string. It splits each k-v pair on the = sign,
# and matches on values that are within single quotes, double quotes, single
# values (e.g. floats, ints, etc.), and lists within brackets.
#
# Named groups:
#   name: the (possibly dotted) key; it must start with an ASCII letter.
#   bracketed_index: an optional `[`, optional digits and an optional `]`
#     directly after the name; it is the empty string when no index is given.
#   val: the raw text of the value, including any surrounding quotes or
#     brackets.
# The final group consumes the pair separator (a comma plus any whitespace
# after it) or the end of the string.
_PARAM_RE = re.compile(
    r"""
  (?P<name>[a-zA-Z][\w\.]*)(?P<bracketed_index>\[?[0-9]*\]?)  # variable name: "var" or "x" followed by optional index: "[0]" or "[23]"
  \s*=\s*
  ((?P<val>\'(.*?)\'           # single quote
  |
  \"(.*?)\"                    # double quote
  |
  [^,\[]*                      # single value
  |
  \[[^\]]*\]))                 # list of values
  ($|,\s*)""", re.VERBOSE)

# Recognizes a constant operand of a restriction string (see
# `ParamsDict.validate`): text starting with a digit, text starting with `-`
# followed by a digit, or the text `None`.
_CONST_VALUE_RE = re.compile(r'(\d.*|-\d.*|None)')

# YAML loader with an additional implicit resolver for the float tag. The
# regular expression accepts the following cases:
# 1- Decimal number with an optional exponential term (sign of the exponent
#    optional), e.g. `1.5`, `1.5e3`, `1.5e-3`.
# 2- Integer number with an exponential term, e.g. `1e5`, `1e-5`.
# 3- Decimal number starting with a dot, with an optional signed exponential
#    term, e.g. `.5`, `.5e+3`.
# 4- Sexagesimal (colon separated) number with a fractional part,
#    e.g. `1:30.5`.
#
# The pattern is a raw string, so a literal dot must be spelled `\.`; `\\.`
# would instead require a backslash character followed by any character, which
# makes cases 1, 3 and 4 unmatchable for ordinary numbers.
#
# NOTE: `_LOADER` is an alias of `yaml.FullLoader`, not a subclass, so the
# resolver is registered on `yaml.FullLoader` itself.

_LOADER = yaml.FullLoader
_LOADER.add_implicit_resolver(
    'tag:yaml.org,2002:float',
    re.compile(r'''
    ^(?:[-+]?(?:[0-9][0-9_]*)\.[0-9_]*(?:[eE][-+]?[0-9]+)?
    |
    [-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+)
    |
    \.[0-9_]+(?:[eE][-+][0-9]+)?
    |
    [-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\.[0-9_]*)$''', re.X),
    list('-+0123456789.'))


class ParamsDict(object):
  """A hyperparameter container class.

  Parameters are stored as instance attributes. A dict value is wrapped in a
  nested ParamsDict, so nested parameters can be reached with dotted attribute
  access (e.g. `params.a.b`). Every other value is stored as a deep copy.
  """

  # Attribute names used for internal bookkeeping. They cannot be used as
  # parameter keys and are left out of `as_dict`.
  RESERVED_ATTR = ['_locked', '_restrictions']

  def __init__(self, default_params=None, restrictions=None):
    """Instantiate a ParamsDict.

    Instantiate a ParamsDict given a set of default parameters and a list of
    restrictions. The restrictions are only stored (as a copy) here; they are
    checked when `validate` is called.

    Args:
      default_params: a Python dict or another ParamsDict object including the
        default parameters to initialize.
      restrictions: a list of strings, which define a list of restrictions to
        ensure the consistency of different parameters internally. Each
        restriction string is defined as a binary relation with a set of
        operators, including {'==', '!=',  '<', '<=', '>', '>='}.
    """
    self._locked = False
    self._restrictions = []
    if restrictions:
      self._restrictions = restrictions
    if default_params is None:
      default_params = {}
    # Non-strict, so that the defaults are allowed to define new keys.
    self.override(default_params, is_strict=False)

  def _set(self, k, v):
    """Stores `v` under `k`, wrapping dicts and deep-copying anything else."""
    if isinstance(v, dict):
      self.__dict__[k] = ParamsDict(v)
    else:
      self.__dict__[k] = copy.deepcopy(v)

  def __setattr__(self, k, v):
    """Sets the value of the existing key.

    Note that this does not allow directly defining a new key. Use the
    `override` method with `is_strict=False` instead.

    Args:
      k: the key string.
      v: the value to be used to set the key `k`.

    Raises:
      KeyError: if k is not defined in the ParamsDict.
      ValueError: if the ParamsDict instance has been locked.
    """
    # Reserved attributes bypass both checks so that `__init__` and `lock` can
    # assign them.
    if k not in ParamsDict.RESERVED_ATTR:
      if k not in self.__dict__:
        raise KeyError('The key `{}` does not exist. '
                       'To extend the existing keys, use '
                       '`override` with `is_strict` = False.'.format(k))
      if self._locked:
        raise ValueError('The ParamsDict has been locked. '
                         'No change is allowed.')
    self._set(k, v)

  def __getattr__(self, k):
    """Gets the value of the existing key.

    Python only calls this hook after normal attribute lookup has failed.

    Args:
      k: the key string.

    Returns:
      the value of the key.

    Raises:
      AttributeError: if k is not defined in the ParamsDict.
    """
    if k not in self.__dict__:
      raise AttributeError('The key `{}` does not exist. '.format(k))
    return self.__dict__[k]

  def __contains__(self, key):
    """Implements the membership test operator."""
    return key in self.__dict__

  def get(self, key, value=None):
    """Accesses through built-in dictionary get method."""
    return self.__dict__.get(key, value)

  def __delattr__(self, k):
    """Deletes the key and removes its values.

    Args:
      k: the key string.

    Raises:
      AttributeError: if k is reserved or not defined in the ParamsDict.
      ValueError: if the ParamsDict instance has been locked.
    """
    if k in ParamsDict.RESERVED_ATTR:
      raise AttributeError(
          'The key `{}` is reserved. No change is allowed. '.format(k))
    if k not in self.__dict__:
      raise AttributeError('The key `{}` does not exist. '.format(k))
    if self._locked:
      raise ValueError('The ParamsDict has been locked. No change is allowed.')
    del self.__dict__[k]

  def override(self, override_params, is_strict=True):
    """Override the ParamsDict with a set of given params.

    Args:
      override_params: a dict or a ParamsDict specifying the parameters to be
        overridden.
      is_strict: a boolean specifying whether override is strict or not. If
        True, keys in `override_params` must be present in the ParamsDict. If
        False, keys in `override_params` can be different from what is currently
        defined in the ParamsDict. In this case, the ParamsDict will be extended
        to include the new keys.

    Raises:
      KeyError: if a key is reserved, or if `is_strict` is True and a key is
        not defined in the ParamsDict.
      ValueError: if the ParamsDict instance has been locked.
    """
    if self._locked:
      raise ValueError('The ParamsDict has been locked. No change is allowed.')
    if isinstance(override_params, ParamsDict):
      override_params = override_params.as_dict()
    self._override(override_params, is_strict)  # pylint: disable=protected-access

  def _override(self, override_dict, is_strict=True):
    """The implementation of `override`."""
    for k, v in override_dict.items():
      if k in ParamsDict.RESERVED_ATTR:
        raise KeyError('The key `{}` is internally reserved. '
                       'Can not be overridden.'.format(k))
      if k not in self.__dict__:
        if is_strict:
          raise KeyError('The key `{}` does not exist. '
                         'To extend the existing keys, use '
                         '`override` with `is_strict` = False.'.format(k))
        self._set(k, v)
      elif isinstance(v, dict):
        # Merge into the existing nested value instead of replacing it.
        self.__dict__[k]._override(v, is_strict)  # pylint: disable=protected-access
      elif isinstance(v, ParamsDict):
        self.__dict__[k]._override(v.as_dict(), is_strict)  # pylint: disable=protected-access
      else:
        self.__dict__[k] = copy.deepcopy(v)

  def lock(self):
    """Makes the ParamsDict immutable.

    Only the flag of this instance is set; nested ParamsDict values carry
    their own `_locked` flag, which this method does not touch.
    """
    self._locked = True

  def as_dict(self):
    """Returns a dict representation of ParamsDict.

    For the nested ParamsDict, a nested dict will be returned. Other values
    are deep-copied and reserved attributes are left out.
    """
    params_dict = {}
    for k, v in self.__dict__.items():
      if k in ParamsDict.RESERVED_ATTR:
        continue
      if isinstance(v, ParamsDict):
        params_dict[k] = v.as_dict()
      else:
        params_dict[k] = copy.deepcopy(v)
    return params_dict

  def validate(self):
    """Validate the parameters consistency based on the restrictions.

    This method validates the internal consistency using the pre-defined list of
    restrictions. A restriction is defined as a string which specifies a binary
    operation. The supported binary operations are {'==', '!=', '<', '<=', '>',
    '>='}. Note that the meaning of these operators are consistent with the
    underlying Python implementation. Users should make sure the restrictions
    they define make sense for the types of their values.

    Each side of a restriction is either a dotted key path or a constant (a
    number, parsed with `float`, or `None`).

    For example, for a ParamsDict like the following
    ```
    a:
      a1: 1
      a2: 2
    b:
      bb:
        bb1: 10
        bb2: 20
      ccc:
        a1: 1
        a3: 3
    ```
    one can define two restrictions like this
    ['a.a1 == b.ccc.a1', 'a.a2 <= b.bb.bb2']

    What it enforces are:
     - a.a1 = 1 == b.ccc.a1 = 1
     - a.a2 = 2 <= b.bb.bb2 = 20

    Raises:
      KeyError: if any of the following happens
        (1) any of parameters in any of restrictions is not defined in
            ParamsDict,
        (2) any inconsistency violating the restriction is found.
      ValueError: if the restriction defined in the string is not supported,
        is not a binary relation, or contains a constant that `float` cannot
        parse.
    """

    def _resolve(operand, params_dict):
      """Returns the value of one (already stripped) restriction operand."""
      # `fullmatch` keeps a dotted key that merely starts with `None` (e.g.
      # `NoneBlock.size`) from being mistaken for the `None` constant.
      if _CONST_VALUE_RE.fullmatch(operand) is not None:
        return None if operand == 'None' else float(operand)
      value = params_dict
      for token in operand.split('.'):
        value = value[token]
      return value

    # (operator symbol, predicate that is True when the restriction is
    # violated). Order matters: every two-character operator is tested before
    # the one-character operator it contains.
    violation_checks = (
        ('==', lambda left, right: left != right),
        ('!=', lambda left, right: left == right),
        ('<=', lambda left, right: left > right),
        ('<', lambda left, right: left >= right),
        ('>=', lambda left, right: left < right),
        ('>', lambda left, right: left <= right),
    )

    params_dict = self.as_dict()
    for restriction in self._restrictions:
      for symbol, is_violated in violation_checks:
        if symbol in restriction:
          break
      else:
        raise ValueError('Unsupported relation in restriction.')
      tokens = restriction.split(symbol)
      if len(tokens) != 2:
        raise ValueError('Only support binary relation in restriction.')
      left_v = _resolve(tokens[0].strip(), params_dict)
      right_v = _resolve(tokens[1].strip(), params_dict)
      if is_violated(left_v, right_v):
        raise KeyError(
            'Found inconsistency between key `{}` and key `{}`.'.format(
                tokens[0], tokens[1]))


def read_yaml_to_params_dict(file_path: str):
  """Loads the YAML file at `file_path` and wraps its content in a ParamsDict.

  Args:
    file_path: path of the YAML file, opened through `tf.io.gfile.GFile`.

  Returns:
    a ParamsDict built from the parsed YAML content.
  """
  with tf.io.gfile.GFile(file_path, 'r') as yaml_file:
    loaded_params = yaml.load(yaml_file, Loader=_LOADER)
  return ParamsDict(loaded_params)


def save_params_dict_to_yaml(params, file_path):
  """Writes `params.as_dict()` to `file_path` as YAML.

  Lists are written in flow style (e.g. `[1, 2, 3]`), everything else in block
  style.

  Args:
    params: the ParamsDict to save.
    file_path: path of the output file, opened through `tf.io.gfile.GFile`.
  """

  def _represent_list_inline(dumper, data):
    # u'tag:yaml.org,2002:seq' is the YAML internal tag for sequence.
    return dumper.represent_sequence(
        u'tag:yaml.org,2002:seq', data, flow_style=True)

  with tf.io.gfile.GFile(file_path, 'w') as yaml_file:
    # NOTE(review): this registers the list representer through the yaml
    # module on every call, so it presumably also affects later `yaml.dump`
    # calls made elsewhere - confirm before relying on it.
    yaml.add_representer(list, _represent_list_inline)
    serializable_params = params.as_dict()
    yaml.dump(serializable_params, yaml_file, default_flow_style=False)


def nested_csv_str_to_json_str(csv_str):
  """Turns a comma-separated `key=value` string into a JSON-like string.

  Keys may be nested with '.', in which case all pairs sharing the same first
  key component are grouped into one nested mapping, e.g.

    "a=1, b=2, c.a=2, c.b=3, d.a.a=5"

  becomes

    "{a : 1, b : 2, c : {a : 2, b : 3}, d : {a : {a : 5}}}"

  A key without dots may carry a bracketed index (`x[0]=1,x[1]=2`); such pairs
  are collected into one list per key. Their values must be numbers: a value
  containing a '.' is parsed with `float`, any other value with `int`.

  Whitespace around '=' and after ',' is accepted ("a=1,b=2", "a = 1, b = 2"
  and "a=1, b=2" are equivalent), but whitespace before a key or after the
  last value is not (" a=1,b=2" and "a=1,b=2 " are rejected).

  Only flat values are supported: nested lists such as "a=[[1,2,3],[4,5,6]]"
  are not. Quoted strings such as "a='hello'" are supported.

  Args:
    csv_str: the comma separated string.

  Returns:
    the converted JSON string, or '' when `csv_str` is empty.

  Raises:
    ValueError: if `csv_str` cannot be parsed as key/value pairs, if an indexed
      value is not a number, or if some index of an array is never assigned.
  """
  if not csv_str:
    return ''

  entries = []  # Formatted top-level `key : value` items, in output order.
  groups = collections.defaultdict(list)  # First key component -> sub-pairs.
  arrays = {}  # Indexed key -> list of the values assigned so far.

  cursor = 0
  end = len(csv_str)
  while cursor < end:
    match = _PARAM_RE.match(csv_str, cursor)
    if match is None:
      raise ValueError('Malformed hyperparameter value while parsing '
                       'CSV string: %s' % csv_str[cursor:])
    cursor = match.end()
    name = match.group('name')
    raw_value = match.group('val')
    bracketed_index = match.group('bracketed_index')

    # Indexed element of a top-level array, e.g. `x[2]=5`.
    if bracketed_index and '.' not in name:
      # Strip the surrounding '[' and ']' to get the position.
      slot = int(bracketed_index[1:-1])
      element = float(raw_value) if '.' in raw_value else int(raw_value)
      cells = arrays.setdefault(name, [])
      # Grow the list with placeholders so that `slot` becomes addressable.
      missing = slot + 1 - len(cells)
      if missing > 0:
        cells.extend([None] * missing)
      cells[slot] = element
      continue

    # If a GCS path (e.g. gs://...) is provided, wrap this in quotes
    # as yaml.load would otherwise throw an exception.
    # NOTE: `[gs://]` is a character class, so this quotes every value that is
    # not already quoted and starts with one of `g`, `s`, `:` or `/`, not only
    # values with a `gs://` prefix.
    if re.match(r'(?=[^\"\'])(?=[gs://])', raw_value):
      raw_value = '\'{}\''.format(raw_value)

    head, dot, remainder = name.partition('.')
    if dot:
      # Defer nested keys: they are converted recursively once every pair of
      # the group has been collected.
      groups[head].append(remainder + bracketed_index + '=' + raw_value)
    else:
      entries.append('%s : %s' % (name, raw_value))

  for head, sub_pairs in groups.items():
    nested_json = nested_csv_str_to_json_str(','.join(sub_pairs))
    entries.append('%s : %s' % (head, nested_json))

  # Add array parameters and check that the array is fully initialized.
  for array_name, cells in arrays.items():
    if any(cell is None for cell in cells):
      raise ValueError('Did not pass all values of array: %s' % array_name)
    entries.append('%s : %s' % (array_name, cells))

  return '{' + ', '.join(entries) + '}'


def override_params_dict(params, dict_or_string_or_yaml_file, is_strict):
  """Override a given ParamsDict using a dict, JSON/YAML/CSV string or YAML file.

  The input is interpreted as follows:
  1. An empty/falsy input leaves `params` untouched.
  2. A dict is used directly as the override.
  3. A string is first converted with `nested_csv_str_to_json_str`; if that
     raises ValueError the string is kept unchanged. The result is parsed as
     YAML/JSON:
     3.1. if parsing yields a dict, that dict is used as the override;
     3.2. otherwise the string is treated as the path of a YAML file, which is
          loaded and used as the override.
  4. Any other input type is rejected.

  Args:
    params: a ParamsDict object to be overridden.
    dict_or_string_or_yaml_file: a Python dict, JSON/YAML/CSV string or path to
      a YAML file specifying the parameters to be overridden.
    is_strict: a boolean specifying whether override is strict or not.

  Returns:
    params: the overridden ParamsDict object.

  Raises:
    ValueError: if the input is neither a dict nor a string.
  """
  source = dict_or_string_or_yaml_file
  if not source:
    return params

  if isinstance(source, dict):
    params.override(source, is_strict)
    return params

  if not isinstance(source, six.string_types):
    raise ValueError('Unknown input type to parse.')

  try:
    source = nested_csv_str_to_json_str(source)
  except ValueError:
    # Not a CSV string; fall through and try it as YAML/JSON or a file path.
    pass

  parsed = yaml.load(source, Loader=_LOADER)
  if isinstance(parsed, dict):
    params.override(parsed, is_strict)
  else:
    with tf.io.gfile.GFile(source) as yaml_file:
      params.override(yaml.load(yaml_file, Loader=_LOADER), is_strict)
  return params