# research/object_detection/utils/variables_helper.py
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Helper functions for manipulating collections of variables during training.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import logging
import os
import re
import tensorflow.compat.v1 as tf
import tf_slim as slim
from tensorflow.python.ops import variables as tf_variables
# Error text raised when a 'detection' checkpoint still contains variables
# named with the 'feature_extractor' prefix.
DETECTION_FEATURE_EXTRACTOR_MSG = (
    "The checkpoint type 'detection' is not supported when it contains "
    "variable\n"
    "names with 'feature_extractor'. Please download the new checkpoint "
    "file\n"
    "from model zoo.\n"
)

# Maps a checkpoint type to a (variable name prefix, error message) pair.
# Checkpoints of that type holding a variable whose name starts with the
# prefix are no longer supported.
DEPRECATED_CHECKPOINT_MAP = {
    'detection': (
        'feature_extractor',
        DETECTION_FEATURE_EXTRACTOR_MSG,
    ),
}
# TODO(derekjchow): Consider replacing with tf.contrib.filter_variables in
# tensorflow/contrib/framework/python/ops/variables.py
def filter_variables(variables, filter_regex_list, invert=False):
  """Drops (or, with `invert`, keeps only) variables matching given regexes.

  A variable "matches" when `re.match` succeeds for at least one of the
  patterns against `var.op.name`; `re.match` anchors at the start of the name.
  Falsy entries of `filter_regex_list` (e.g. empty strings or None) are
  skipped.

  Args:
    variables: a list of tensorflow variables.
    filter_regex_list: a list of string regular expressions.
    invert: (boolean). If False, the matching variables are removed and the
      rest returned. If True, the complement is returned: only the matching
      variables are kept.

  Returns:
    a list of filtered variables, in their original order.
  """
  active_patterns = [regex for regex in filter_regex_list if regex]

  def _passes_filter(variable):
    # True when no pattern matches, i.e. the variable survives the filter.
    return not any(
        re.match(regex, variable.op.name) for regex in active_patterns)

  return [
      variable for variable in variables
      if _passes_filter(variable) != invert
  ]
def multiply_gradients_matching_regex(grads_and_vars, regex_list, multiplier):
  """Scales the gradients of variables whose names match a regex.

  Args:
    grads_and_vars: A list of gradient to variable pairs (tuples).
    regex_list: A list of string regular expressions; a variable is selected
      when its op name matches any of them.
    multiplier: A (float) multiplier to apply to each selected gradient.

  Returns:
    grads_and_vars: A list of gradient to variable pairs (tuples), as produced
      by `slim.learning.multiply_gradients`.
  """
  all_vars = [entry[1] for entry in grads_and_vars]
  # invert=True turns the filter into a selector: only matches are returned.
  selected_vars = filter_variables(all_vars, regex_list, invert=True)

  for selected in selected_vars:
    logging.info('Applying multiplier %f to variable [%s]',
                 multiplier, selected.op.name)

  multipliers_by_var = {
      selected: float(multiplier) for selected in selected_vars
  }
  return slim.learning.multiply_gradients(grads_and_vars, multipliers_by_var)
def freeze_gradients_matching_regex(grads_and_vars, regex_list):
  """Freeze gradients whose variable names match a regular expression.

  "Freezing" here means dropping the (gradient, variable) pair from the list,
  so that no update is applied to the variable.

  Args:
    grads_and_vars: A list of gradient to variable pairs (tuples).
    regex_list: A list of string regular expressions.

  Returns:
    grads_and_vars: A list of gradient to variable pairs (tuples) that do not
      contain the variables and gradients matching the regex.
  """
  variables = [pair[1] for pair in grads_and_vars]
  # invert=True makes filter_variables return only the matching variables.
  matching_vars = filter_variables(variables, regex_list, invert=True)

  # Compare by object identity rather than with `in` on a list: list
  # membership falls back to `==`, which is not a plain boolean for variables
  # when tensor equality is enabled, and is a linear scan per pair. The
  # objects in `matching_vars` are the very objects taken from
  # `grads_and_vars`, so identity is the right notion of "same variable".
  frozen_ids = {id(var) for var in matching_vars}
  kept_grads_and_vars = [pair for pair in grads_and_vars
                         if id(pair[1]) not in frozen_ids]

  for var in matching_vars:
    logging.info('Freezing variable [%s]', var.op.name)
  return kept_grads_and_vars
def get_variables_available_in_checkpoint(variables,
                                          checkpoint_path,
                                          include_global_step=True):
  """Returns the subset of variables that can be restored from a checkpoint.

  A variable is returned only if the checkpoint holds an entry under its name
  and the stored shape equals the variable's shape. Variables that are missing
  or shape-incompatible are skipped with a warning.

  TODO(rathodv): force input and output to be a dictionary.

  Args:
    variables: a list or dictionary of variables to find in checkpoint. For a
      dictionary, the keys are used as the checkpoint names. For a list, the
      name is `variable.name` for partitioned variables and
      `variable.op.name` otherwise.
    checkpoint_path: path to the checkpoint to restore variables from.
    include_global_step: whether to include `global_step` variable, if it
      exists. Default True.

  Returns:
    A list or dictionary of variables, matching the type of `variables`.

  Raises:
    ValueError: if `variables` is not a list or dict.
  """
  if isinstance(variables, list):
    name_to_variable = {}
    for candidate in variables:
      is_partitioned = isinstance(candidate, tf_variables.PartitionedVariable)
      key = candidate.name if is_partitioned else candidate.op.name
      name_to_variable[key] = candidate
  elif isinstance(variables, dict):
    name_to_variable = variables
  else:
    raise ValueError('`variables` is expected to be a list or dict.')

  reader = tf.train.NewCheckpointReader(checkpoint_path)
  ckpt_shapes = reader.get_variable_to_shape_map()
  if not include_global_step:
    # Hide the checkpoint's global step so it is reported as unavailable.
    ckpt_shapes.pop(tf.GraphKeys.GLOBAL_STEP, None)

  restorable = {}
  for name, variable in sorted(name_to_variable.items()):
    if name not in ckpt_shapes:
      logging.warning('Variable [%s] is not available in checkpoint', name)
      continue

    ckpt_shape = ckpt_shapes[name]
    model_shape = variable.shape.as_list()
    if ckpt_shape == model_shape:
      restorable[name] = variable
      continue

    logging.warning('Variable [%s] is available in checkpoint, but has an '
                    'incompatible shape with model variable. Checkpoint '
                    'shape: [%s], model variable shape: [%s]. This '
                    'variable will not be initialized from the checkpoint.',
                    name, ckpt_shape, model_shape)

  if isinstance(variables, list):
    return list(restorable.values())
  return restorable
def get_global_variables_safely():
  """Returns tf.global_variables(), refusing to do so under eager execution.

  The eager check is made inside `tf.init_scope()`. When executing eagerly the
  global variables collection is not tracked, so a ValueError is raised
  instead; use a Keras model's .variables property in that case.

  Returns:
    The result of tf.global_variables()

  Raises:
    ValueError: if eager execution is enabled.
  """
  with tf.init_scope():
    running_eagerly = tf.executing_eagerly()

  if running_eagerly:
    raise ValueError("Global variables collection is not tracked when "
                     "executing eagerly. Use a Keras model's `.variables` "
                     "attribute instead.")
  return tf.global_variables()
def ensure_checkpoint_supported(checkpoint_path, checkpoint_type, model_dir):
  """Validates that a checkpoint may be used to start training.

  Two checks are performed, in this order:
    1. If `checkpoint_type` has an entry in DEPRECATED_CHECKPOINT_MAP, the
       checkpoint must not contain any variable whose name starts with the
       blocked prefix for that type.
    2. The directory containing `checkpoint_path` must differ from
       `model_dir` (both compared as absolute paths).

  Args:
    checkpoint_path: str, path to checkpoint.
    checkpoint_type: str, denotes the type of checkpoint.
    model_dir: The model directory to store intermediate training checkpoints.

  Raises:
    RuntimeError: If
      1. We detect a deprecated checkpoint file.
      2. model_dir and checkpoint_path are in the same directory.
  """
  # The checkpoint is listed unconditionally, before either check runs.
  ckpt_variables = tf.train.list_variables(checkpoint_path)

  deprecated_entry = DEPRECATED_CHECKPOINT_MAP.get(checkpoint_type)
  if deprecated_entry is not None:
    blocked_prefix, msg = deprecated_entry
    for var_name, _ in ckpt_variables:
      if not var_name.startswith(blocked_prefix):
        continue
      # The first offending variable is enough to reject the checkpoint.
      tf.logging.error('Found variable name - %s with prefix %s', var_name,
                       blocked_prefix)
      raise RuntimeError(msg)

  ckpt_dir = os.path.abspath(os.path.dirname(checkpoint_path))
  model_dir_abs = os.path.abspath(model_dir)
  if ckpt_dir != model_dir_abs:
    return

  raise RuntimeError(
      'Checkpoint dir ({}) and model_dir ({}) cannot be same.'.format(
          ckpt_dir, model_dir_abs) +
      ' Please set model_dir to a different path.')