# tensorflow/python/training/warm_starting_util_test.py
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for warm_starting_util."""
import os
import numpy as np
from tensorflow.python.checkpoint import checkpoint as tracking_util
from tensorflow.python.feature_column import feature_column_lib as fc
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training import checkpoint_utils
from tensorflow.python.training import saver as saver_lib
from tensorflow.python.training import warm_starting_util as ws_util
# Short aliases for the init_ops initializers used throughout this file; the
# tests call them (e.g. `ones()`) to build the `initializer` argument of
# `get_variable`.
ones = init_ops.ones_initializer
norms = init_ops.truncated_normal_initializer
rand = init_ops.random_uniform_initializer
zeros = init_ops.zeros_initializer
class WarmStartingUtilTest(test.TestCase):
def _write_vocab(self, string_values, file_name):
  """Writes a newline-separated vocabulary file into the test temp dir.

  Args:
    string_values: Iterable of strings; joined with "\n" (no trailing
      newline is added).
    file_name: Name of the file to create inside `self.get_temp_dir()`.

  Returns:
    The full path of the file that was written.
  """
  vocab_path = os.path.join(self.get_temp_dir(), file_name)
  with open(vocab_path, "w") as vocab_out:
    joined_entries = "\n".join(string_values)
    vocab_out.write(joined_entries)
  return vocab_path
def _write_checkpoint(self, sess):
  """Initializes all global variables and saves them with a `Saver`.

  The checkpoint is saved under the prefix "model" in the test temp dir with
  `global_step=0`.

  Args:
    sess: Session handed to `Saver.save`.
  """
  init_op = variables.global_variables_initializer()
  self.evaluate(init_op)
  saver_lib.Saver().save(
      sess, os.path.join(self.get_temp_dir(), "model"), global_step=0)
def _create_prev_run_var(self,
                         var_name,
                         shape=None,
                         initializer=None,
                         partitioner=None):
  """Creates one variable in a fresh graph and writes it to a checkpoint.

  Args:
    var_name: Name passed to `variable_scope.get_variable`.
    shape: Optional shape forwarded to `get_variable`.
    initializer: Optional initializer forwarded to `get_variable`.
    partitioner: Optional partitioner forwarded to `get_variable`. When
      truthy, the created variable must be a `PartitionedVariable`.

  Returns:
    A `(var, value)` tuple, where `value` is `self.evaluate(var)`. When
    `partitioner` is truthy, `var` is the list returned by
    `_get_variable_list()` rather than the partitioned variable itself.
  """
  with ops.Graph().as_default() as g:
    with self.session(graph=g) as sess:
      var = variable_scope.get_variable(
          var_name,
          shape=shape,
          initializer=initializer,
          partitioner=partitioner)
      self._write_checkpoint(sess)
      if partitioner:
        # assertIsInstance reports the actual type on failure, which
        # assertTrue(isinstance(...)) does not.
        self.assertIsInstance(var, variables.PartitionedVariable)
        var = var._get_variable_list()
      # Evaluate while the session that initialized the variable is active.
      return var, self.evaluate(var)
def _create_prev_run_vars(self,
                          var_names,
                          shapes,
                          initializers):
  """Creates several variables in a fresh graph and checkpoints them.

  Args:
    var_names: Names of the variables to create.
    shapes: Shapes, paired positionally with `var_names`.
    initializers: Initializers, paired positionally with `var_names`.

  Returns:
    A list with the evaluated value of each created variable, in creation
    order.
  """
  with ops.Graph().as_default() as graph, self.session(graph=graph) as sess:
    created = [
        variable_scope.get_variable(name, shape=var_shape, initializer=init)
        for name, var_shape, init in zip(var_names, shapes, initializers)
    ]
    self._write_checkpoint(sess)
    values = []
    for created_var in created:
      values.append(self.evaluate(created_var))
    return values
def _create_dummy_inputs(self):
  """Returns a dict of placeholders keyed by the feature names used here.

  "sc_int" is an int32 sparse placeholder, "sc_hash"/"sc_keys"/"sc_vocab"
  are string sparse placeholders, and "real" is a dense float32 placeholder.
  """
  features = {}
  features["sc_int"] = array_ops.sparse_placeholder(dtypes.int32)
  for string_feature in ("sc_hash", "sc_keys", "sc_vocab"):
    features[string_feature] = array_ops.sparse_placeholder(dtypes.string)
  features["real"] = array_ops.placeholder(dtypes.float32)
  return features
def _create_linear_model(self, feature_cols, partitioner):
  """Builds a one-unit linear model and returns its column-to-variable map.

  Args:
    feature_cols: Feature columns passed to `fc.linear_model`.
    partitioner: Partitioner installed on the enclosing variable scope.

  Returns:
    The `cols_to_vars` dict populated by `fc.linear_model`.
  """
  column_variables = {}
  with variable_scope.variable_scope("", partitioner=partitioner):
    # Building the model creates the variables under the partitioned scope.
    dummy_features = self._create_dummy_inputs()
    fc.linear_model(
        features=dummy_features,
        feature_columns=feature_cols,
        units=1,
        cols_to_vars=column_variables)
  return column_variables
def _assert_cols_to_vars(self, cols_to_vars, cols_to_expected_values, sess):
  """Checks each column's variables against their expected values.

  Args:
    cols_to_vars: Dict mapping a column to an iterable of its variables.
    cols_to_expected_values: Dict mapping a column to a sequence of expected
      values, indexed in the same order as the column's variables.
    sess: Session used to evaluate each variable.
  """
  for col in cols_to_expected_values:
    expected_values = cols_to_expected_values[col]
    for index, variable in enumerate(cols_to_vars[col]):
      expected = expected_values[index]
      actual = variable.eval(sess)
      self.assertAllClose(expected, actual)
def testWarmStartVar(self):
  """Restores a variable from a checkpoint via `init_from_checkpoint`."""
  saved_weights = [[0.5], [1.], [1.5], [2.]]
  _, prev_val = self._create_prev_run_var(
      "fruit_weights", initializer=saved_weights)

  with ops.Graph().as_default() as graph, self.session(graph=graph) as sess:
    # The new variable starts at zero, so matching values must come from
    # the checkpoint.
    fruit_weights = variable_scope.get_variable(
        "fruit_weights", initializer=[[0.], [0.], [0.], [0.]])
    prev_tensor_name, var = ws_util._get_var_info(fruit_weights)
    assignment_map = {prev_tensor_name: var}
    checkpoint_utils.init_from_checkpoint(self.get_temp_dir(),
                                          assignment_map)
    self.evaluate(variables.global_variables_initializer())
    restored = fruit_weights.eval(sess)
    self.assertAllClose(prev_val, restored)
def testWarmStartVarPrevVarPartitioned(self):
  """Restores a plain variable from a checkpoint saved with a partitioner."""

  def two_row_shards(shape, dtype):  # pylint: disable=unused-argument
    return [2, 1]

  _, shard_values = self._create_prev_run_var(
      "fruit_weights",
      shape=[4, 1],
      initializer=[[0.5], [1.], [1.5], [2.]],
      partitioner=two_row_shards)
  # Rebuild the full previous value from its two saved slices.
  first_shard, second_shard = shard_values[0], shard_values[1]
  prev_val = np.concatenate([first_shard, second_shard], axis=0)

  with ops.Graph().as_default() as graph, self.session(graph=graph) as sess:
    fruit_weights = variable_scope.get_variable(
        "fruit_weights", initializer=[[0.], [0.], [0.], [0.]])
    prev_tensor_name, var = ws_util._get_var_info(fruit_weights)
    assignment_map = {prev_tensor_name: var}
    checkpoint_utils.init_from_checkpoint(self.get_temp_dir(),
                                          assignment_map)
    self.evaluate(variables.global_variables_initializer())
    restored = fruit_weights.eval(sess)
    self.assertAllClose(prev_val, restored)
def testWarmStartVarCurrentVarPartitioned(self):
  """Restores a partitioned variable from a checkpoint saved unpartitioned."""
  _, prev_val = self._create_prev_run_var(
      "fruit_weights", initializer=[[0.5], [1.], [1.5], [2.]])

  with ops.Graph().as_default() as g:
    with self.session(graph=g) as sess:
      fruit_weights = variable_scope.get_variable(
          "fruit_weights",
          shape=[4, 1],
          initializer=[[0.], [0.], [0.], [0.]],
          partitioner=lambda shape, dtype: [2, 1])
      # assertIsInstance reports the actual type on failure, which
      # assertTrue(isinstance(...)) does not.
      self.assertIsInstance(fruit_weights, variables.PartitionedVariable)
      prev_tensor_name, var = ws_util._get_var_info(fruit_weights)
      checkpoint_utils.init_from_checkpoint(self.get_temp_dir(),
                                            {prev_tensor_name: var})
      self.evaluate(variables.global_variables_initializer())

      # Join the two slices so they can be compared with the saved value.
      fruit_weights = fruit_weights._get_variable_list()
      new_val = np.concatenate(
          [fruit_weights[0].eval(sess), fruit_weights[1].eval(sess)], axis=0)
      self.assertAllClose(prev_val, new_val)
def testWarmStartVarBothVarsPartitioned(self):
  """Restores a partitioned variable from a partitioned, renamed one.

  The checkpoint holds "old_scope/fruit_weights" while the new graph creates
  "new_scope/fruit_weights", so the previous name is passed to
  `_get_var_info` explicitly.
  """
  _, weights = self._create_prev_run_var(
      "old_scope/fruit_weights",
      shape=[4, 1],
      initializer=[[0.5], [1.], [1.5], [2.]],
      partitioner=lambda shape, dtype: [2, 1])
  # Rebuild the full previous value from its two saved slices.
  prev_val = np.concatenate([weights[0], weights[1]], axis=0)

  # New session and new graph.
  with ops.Graph().as_default() as g:
    with self.session(graph=g) as sess:
      fruit_weights = variable_scope.get_variable(
          "new_scope/fruit_weights",
          shape=[4, 1],
          initializer=[[0.], [0.], [0.], [0.]],
          partitioner=lambda shape, dtype: [2, 1])
      # assertIsInstance reports the actual type on failure, which
      # assertTrue(isinstance(...)) does not.
      self.assertIsInstance(fruit_weights, variables.PartitionedVariable)
      prev_tensor_name, var = ws_util._get_var_info(
          fruit_weights, prev_tensor_name="old_scope/fruit_weights")
      checkpoint_utils.init_from_checkpoint(self.get_temp_dir(),
                                            {prev_tensor_name: var})
      self.evaluate(variables.global_variables_initializer())

      # Join the two slices so they can be compared with the saved value.
      fruit_weights = fruit_weights._get_variable_list()
      new_val = np.concatenate(
          [fruit_weights[0].eval(sess), fruit_weights[1].eval(sess)], axis=0)
      self.assertAllClose(prev_val, new_val)
def testWarmStartVarWithVocab(self):
  """Remaps checkpointed rows onto a reordered vocab with one new entry."""
  old_entries = ["apple", "banana", "guava", "orange"]
  prev_vocab_path = self._write_vocab(old_entries, "old_vocab")
  self._create_prev_run_var(
      "fruit_weights", initializer=[[0.5], [1.], [1.5], [2.]])

  # The new vocab lists the old entries in reverse and appends "raspberry".
  new_entries = ["orange", "guava", "banana", "apple", "raspberry"]
  new_vocab_path = self._write_vocab(new_entries, "new_vocab")

  # Old rows in reversed order, then a zero row for the new entry.
  expected_weights = [[2.], [1.5], [1.], [0.5], [0.]]
  with ops.Graph().as_default() as graph, self.session(graph=graph) as sess:
    fruit_weights = variable_scope.get_variable(
        "fruit_weights", initializer=[[0.] for _ in range(5)])
    ws_util._warm_start_var_with_vocab(fruit_weights, new_vocab_path, 5,
                                       self.get_temp_dir(), prev_vocab_path)
    self.evaluate(variables.global_variables_initializer())
    self.assertAllClose(expected_weights, fruit_weights.eval(sess))
def testWarmStartVarWithColumnVocab(self):
  """Remaps checkpointed columns (axis=1) onto a reordered, larger vocab."""
  prev_vocab_path = self._write_vocab(["apple", "orange"], "old_vocab")
  saved_layer = [[0.5, 0.3], [1., 0.8], [1.5, 1.2], [2., 2.3]]
  self._create_prev_run_var("fruit_output_layer", initializer=saved_layer)

  # The new vocab swaps the two old entries and appends "banana".
  new_vocab_path = self._write_vocab(["orange", "apple", "banana"],
                                     "new_vocab")

  # Saved columns swapped, plus a zero column for the new entry.
  expected_layer = [[0.3, 0.5, 0.], [0.8, 1.0, 0.], [1.2, 1.5, 0.],
                    [2.3, 2., 0.]]
  with ops.Graph().as_default() as graph, self.session(graph=graph) as sess:
    fruit_output_layer = variable_scope.get_variable(
        "fruit_output_layer",
        initializer=[[0., 0., 0.] for _ in range(4)])
    ws_util._warm_start_var_with_vocab(
        fruit_output_layer,
        new_vocab_path,
        current_vocab_size=3,
        prev_ckpt=self.get_temp_dir(),
        prev_vocab_path=prev_vocab_path,
        axis=1)
    self.evaluate(variables.global_variables_initializer())
    self.assertAllClose(expected_layer, fruit_output_layer.eval(sess))
def testWarmStartVarWithVocabConstrainedOldVocabSize(self):
  """Remaps rows when only the first two old vocab entries are used."""
  old_entries = ["apple", "banana", "guava", "orange"]
  prev_vocab_path = self._write_vocab(old_entries, "old_vocab")
  self._create_prev_run_var(
      "fruit_weights", initializer=[[0.5], [1.], [1.5], [2.]])

  # The new vocab lists the old entries in reverse and appends "raspberry".
  new_entries = ["orange", "guava", "banana", "apple", "raspberry"]
  new_vocab_path = self._write_vocab(new_entries, "new_vocab")

  # With previous_vocab_size=2 the old vocab is limited to
  # ['apple', 'banana'], so only those two rows carry saved values.
  expected_weights = [[0.], [0.], [1.], [0.5], [0.]]
  with ops.Graph().as_default() as graph, self.session(graph=graph) as sess:
    fruit_weights = variable_scope.get_variable(
        "fruit_weights", initializer=[[0.] for _ in range(5)])
    ws_util._warm_start_var_with_vocab(
        fruit_weights,
        new_vocab_path,
        5,
        self.get_temp_dir(),
        prev_vocab_path,
        previous_vocab_size=2)
    self.evaluate(variables.global_variables_initializer())
    self.assertAllClose(expected_weights, fruit_weights.eval(sess))
def testWarmStartVarWithVocabPrevVarPartitioned(self):
  """Remaps rows by vocab when the checkpointed variable was partitioned."""

  def two_row_shards(shape, dtype):  # pylint: disable=unused-argument
    return [2, 1]

  old_entries = ["apple", "banana", "guava", "orange"]
  prev_vocab_path = self._write_vocab(old_entries, "old_vocab")
  self._create_prev_run_var(
      "fruit_weights",
      shape=[4, 1],
      initializer=[[0.5], [1.], [1.5], [2.]],
      partitioner=two_row_shards)

  # The new vocab lists the old entries in reverse and appends "raspberry".
  new_entries = ["orange", "guava", "banana", "apple", "raspberry"]
  new_vocab_path = self._write_vocab(new_entries, "new_vocab")

  expected_weights = [[2.], [1.5], [1.], [0.5], [0.]]
  with ops.Graph().as_default() as graph, self.session(graph=graph) as sess:
    fruit_weights = variable_scope.get_variable(
        "fruit_weights", initializer=[[0.] for _ in range(5)])
    ws_util._warm_start_var_with_vocab(fruit_weights, new_vocab_path, 5,
                                       self.get_temp_dir(), prev_vocab_path)
    self.evaluate(variables.global_variables_initializer())
    self.assertAllClose(expected_weights, fruit_weights.eval(sess))
def testWarmStartVarWithColumnVocabPrevVarPartitioned(self):
  """Remaps columns (axis=1) when the checkpointed variable was partitioned."""

  def two_row_shards(shape, dtype):  # pylint: disable=unused-argument
    return [2, 1]

  prev_vocab_path = self._write_vocab(["apple", "orange"], "old_vocab")
  self._create_prev_run_var(
      "fruit_output_layer",
      shape=[4, 2],
      initializer=[[0.5, 0.3], [1., 0.8], [1.5, 1.2], [2., 2.3]],
      partitioner=two_row_shards)

  # The new vocab swaps the two old entries and appends "banana".
  new_vocab_path = self._write_vocab(["orange", "apple", "banana"],
                                     "new_vocab")

  expected_layer = [[0.3, 0.5, 0.], [0.8, 1.0, 0.], [1.2, 1.5, 0.],
                    [2.3, 2., 0.]]
  with ops.Graph().as_default() as graph, self.session(graph=graph) as sess:
    fruit_output_layer = variable_scope.get_variable(
        "fruit_output_layer",
        initializer=[[0., 0., 0.] for _ in range(4)])
    ws_util._warm_start_var_with_vocab(
        fruit_output_layer,
        new_vocab_path,
        current_vocab_size=3,
        prev_ckpt=self.get_temp_dir(),
        prev_vocab_path=prev_vocab_path,
        axis=1)
    self.evaluate(variables.global_variables_initializer())
    self.assertAllClose(expected_layer, fruit_output_layer.eval(sess))
def testWarmStartVarWithVocabCurrentVarPartitioned(self):
  """Remaps rows by vocab into a partitioned variable with one OOV bucket."""
  prev_vocab_path = self._write_vocab(["apple", "banana", "guava", "orange"],
                                      "old_vocab")
  self._create_prev_run_var(
      "fruit_weights", initializer=[[0.5], [1.], [1.5], [2.]])

  # New vocab with elements in reverse order and one new element.
  new_vocab_path = self._write_vocab(
      ["orange", "guava", "banana", "apple", "raspberry"], "new_vocab")
  # New session and new graph.
  with ops.Graph().as_default() as g:
    with self.session(graph=g) as sess:
      # Six rows: five vocab entries plus the single OOV bucket.
      fruit_weights = variable_scope.get_variable(
          "fruit_weights",
          shape=[6, 1],
          initializer=[[0.], [0.], [0.], [0.], [0.], [0.]],
          partitioner=lambda shape, dtype: [2, 1])
      ws_util._warm_start_var_with_vocab(
          fruit_weights,
          new_vocab_path,
          5,
          self.get_temp_dir(),
          prev_vocab_path,
          current_oov_buckets=1)
      self.evaluate(variables.global_variables_initializer())
      # assertIsInstance reports the actual type on failure, which
      # assertTrue(isinstance(...)) does not.
      self.assertIsInstance(fruit_weights, variables.PartitionedVariable)
      # Each of the two slices is checked separately.
      fruit_weights_vars = fruit_weights._get_variable_list()
      self.assertAllClose([[2.], [1.5], [1.]],
                          fruit_weights_vars[0].eval(sess))
      self.assertAllClose([[0.5], [0.], [0.]],
                          fruit_weights_vars[1].eval(sess))
def testWarmStartVarWithColumnVocabCurrentVarPartitioned(self):
  """Remaps columns (axis=1) by vocab into a partitioned variable."""
  prev_vocab_path = self._write_vocab(["apple", "orange"], "old_vocab")
  self._create_prev_run_var(
      "fruit_output_layer",
      initializer=[[0.5, 0.3], [1., 0.8], [1.5, 1.2], [2., 2.3]])

  # New vocab with elements in reverse order and one new element.
  new_vocab_path = self._write_vocab(["orange", "apple", "banana"],
                                     "new_vocab")
  # New session and new graph.
  with ops.Graph().as_default() as g:
    with self.session(graph=g) as sess:
      fruit_output_layer = variable_scope.get_variable(
          "fruit_output_layer",
          shape=[4, 3],
          initializer=[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.],
                       [0., 0., 0.]],
          partitioner=lambda shape, dtype: [2, 1])
      ws_util._warm_start_var_with_vocab(fruit_output_layer, new_vocab_path,
                                         current_vocab_size=3,
                                         prev_ckpt=self.get_temp_dir(),
                                         prev_vocab_path=prev_vocab_path,
                                         axis=1)
      self.evaluate(variables.global_variables_initializer())
      # assertIsInstance reports the actual type on failure, which
      # assertTrue(isinstance(...)) does not.
      self.assertIsInstance(fruit_output_layer,
                            variables.PartitionedVariable)
      # Each of the two slices is checked separately.
      fruit_output_layer_vars = fruit_output_layer._get_variable_list()
      self.assertAllClose([[0.3, 0.5, 0.], [0.8, 1.0, 0.]],
                          fruit_output_layer_vars[0].eval(sess))
      self.assertAllClose([[1.2, 1.5, 0.], [2.3, 2., 0.]],
                          fruit_output_layer_vars[1].eval(sess))
def testWarmStartVarWithVocabBothVarsPartitioned(self):
  """Remaps rows by vocab when old and new variables are both partitioned."""
  prev_vocab_path = self._write_vocab(["apple", "banana", "guava", "orange"],
                                      "old_vocab")
  self._create_prev_run_var(
      "fruit_weights",
      shape=[4, 1],
      initializer=[[0.5], [1.], [1.5], [2.]],
      partitioner=lambda shape, dtype: [2, 1])

  # New vocab with elements in reverse order and two new elements.
  new_vocab_path = self._write_vocab(
      ["orange", "guava", "banana", "apple", "raspberry",
       "blueberry"], "new_vocab")
  # New session and new graph.
  with ops.Graph().as_default() as g:
    with self.session(graph=g) as sess:
      fruit_weights = variable_scope.get_variable(
          "fruit_weights",
          shape=[6, 1],
          initializer=[[0.], [0.], [0.], [0.], [0.], [0.]],
          partitioner=lambda shape, dtype: [2, 1])
      ws_util._warm_start_var_with_vocab(fruit_weights, new_vocab_path, 6,
                                         self.get_temp_dir(), prev_vocab_path)
      self.evaluate(variables.global_variables_initializer())
      # assertIsInstance reports the actual type on failure, which
      # assertTrue(isinstance(...)) does not.
      self.assertIsInstance(fruit_weights, variables.PartitionedVariable)
      # Each of the two slices is checked separately.
      fruit_weights_vars = fruit_weights._get_variable_list()
      self.assertAllClose([[2.], [1.5], [1.]],
                          fruit_weights_vars[0].eval(sess))
      self.assertAllClose([[0.5], [0.], [0.]],
                          fruit_weights_vars[1].eval(sess))
def testWarmStartVarWithColumnVocabBothVarsPartitioned(self):
  """Remaps columns (axis=1) when old and new variables are partitioned."""
  prev_vocab_path = self._write_vocab(["apple", "orange"], "old_vocab")
  self._create_prev_run_var(
      "fruit_output_layer",
      shape=[4, 2],
      initializer=[[0.5, 0.3], [1., 0.8], [1.5, 1.2], [2., 2.3]],
      partitioner=lambda shape, dtype: [2, 1])

  # New vocab with elements in reverse order and one new element.
  new_vocab_path = self._write_vocab(["orange", "apple", "banana"],
                                     "new_vocab")
  # New session and new graph.
  with ops.Graph().as_default() as g:
    with self.session(graph=g) as sess:
      fruit_output_layer = variable_scope.get_variable(
          "fruit_output_layer",
          shape=[4, 3],
          initializer=[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.],
                       [0., 0., 0.]],
          partitioner=lambda shape, dtype: [2, 1])
      ws_util._warm_start_var_with_vocab(fruit_output_layer, new_vocab_path,
                                         current_vocab_size=3,
                                         prev_ckpt=self.get_temp_dir(),
                                         prev_vocab_path=prev_vocab_path,
                                         axis=1)
      self.evaluate(variables.global_variables_initializer())
      # assertIsInstance reports the actual type on failure, which
      # assertTrue(isinstance(...)) does not.
      self.assertIsInstance(fruit_output_layer,
                            variables.PartitionedVariable)
      # Each of the two slices is checked separately.
      fruit_output_layer_vars = fruit_output_layer._get_variable_list()
      self.assertAllClose([[0.3, 0.5, 0.], [0.8, 1.0, 0.]],
                          fruit_output_layer_vars[0].eval(sess))
      self.assertAllClose([[1.2, 1.5, 0.], [2.3, 2., 0.]],
                          fruit_output_layer_vars[1].eval(sess))
def testWarmStart_ListOfVariables(self):
  """Warm-starts a variable given to `warm_start` as a variable object."""
  # Save checkpoint from which to warm-start.
  _, prev_int_val = self._create_prev_run_var("v1", shape=[10, 1],
                                              initializer=ones())
  # Verify we initialized the values correctly.
  self.assertAllEqual(np.ones([10, 1]), prev_int_val)

  # New graph, new session with warm-starting.
  with ops.Graph().as_default() as g:
    # The session object is never referenced by name in this test, so it
    # is entered without an `as` binding.
    with self.session(graph=g):
      # Initialize with zeros.
      var = variable_scope.get_variable(
          "v1",
          shape=[10, 1],
          initializer=zeros())
      ws_util.warm_start(self.get_temp_dir(), vars_to_warm_start=[var])
      self.evaluate(variables.global_variables_initializer())
      # Verify weights were correctly warm-started (init overridden to ones).
      self.assertAllEqual(var, prev_int_val)
def testWarmStart_ListOfStrings(self):
  """Warm-starts a variable selected by name in `vars_to_warm_start`."""
  # Save checkpoint from which to warm-start.
  _, prev_int_val = self._create_prev_run_var("v1", shape=[10, 1],
                                              initializer=ones())
  # Verify we initialized the values correctly.
  self.assertAllEqual(np.ones([10, 1]), prev_int_val)

  # New graph, new session with warm-starting.
  with ops.Graph().as_default() as g:
    # The session object is never referenced by name in this test, so it
    # is entered without an `as` binding.
    with self.session(graph=g):
      # Initialize with zeros.
      var = variable_scope.get_variable(
          "v1",
          shape=[10, 1],
          initializer=zeros())
      ws_util.warm_start(self.get_temp_dir(), vars_to_warm_start=["v1"])
      self.evaluate(variables.global_variables_initializer())
      # Verify weights were correctly warm-started (init overridden to ones).
      self.assertAllEqual(var, prev_int_val)
def testWarmStart_TwoVarsFromTheSamePrevVar(self):
  """Warm-starts both "v1" and "v2" from the checkpointed "v1"."""
  # Checkpoint a single all-ones variable named "v1".
  _, prev_int_val = self._create_prev_run_var(
      "v1", shape=[10, 1], initializer=ones())
  self.assertAllEqual(np.ones([10, 1]), prev_int_val)

  # Fresh graph and session in which warm-starting is applied.
  with ops.Graph().as_default() as graph, self.session(graph=graph):
    # Both variables start at zero.
    var = variable_scope.get_variable(
        "v1", shape=[10, 1], initializer=zeros())
    var2 = variable_scope.get_variable(
        "v2", shape=[10, 1], initializer=zeros())
    # "v2" is mapped onto the checkpoint's "v1".
    name_mapping = {"v2": "v1"}
    ws_util.warm_start(
        self.get_temp_dir(),
        vars_to_warm_start=["v1", "v2"],
        var_name_to_prev_var_name=name_mapping)
    self.evaluate(variables.global_variables_initializer())
    # Both must now hold the checkpointed ones.
    self.assertAllEqual(var, prev_int_val)
    self.assertAllEqual(var2, prev_int_val)
def testWarmStart_ListOfRegexes(self):
  """Warm-starts only the variables matched by a list of regexes."""
  # Save checkpoint from which to warm-start.
  [prev_v1_val, prev_v1_momentum_val,
   prev_v2_val, _] = self._create_prev_run_vars(
       var_names=["v1", "v1/Momentum", "v2", "v2/Momentum"],
       shapes=[[10, 1]] * 4,
       initializers=[ones()] * 4)

  # New graph, new session with warm-starting.
  with ops.Graph().as_default() as g:
    # The session object is never referenced by name in this test, so it
    # is entered without an `as` binding.
    with self.session(graph=g):
      # Initialize with zeros.
      v1 = variable_scope.get_variable(
          "v1",
          shape=[10, 1],
          initializer=zeros())
      v1_momentum = variable_scope.get_variable(
          "v1/Momentum",
          shape=[10, 1],
          initializer=zeros())
      v2 = variable_scope.get_variable(
          "v2",
          shape=[10, 1],
          initializer=zeros())
      v2_momentum = variable_scope.get_variable(
          "v2/Momentum",
          shape=[10, 1],
          initializer=zeros())
      ws_util.warm_start(self.get_temp_dir(),
                         # This warm-starts both v1 and v1/Momentum, but only
                         # v2 (and not v2/Momentum).
                         vars_to_warm_start=["v1", "v2[^/]"])
      self.evaluate(variables.global_variables_initializer())
      # Verify the selection of weights were correctly warm-started (init
      # overridden to ones).
      self.assertAllEqual(v1, prev_v1_val)
      self.assertAllEqual(v1_momentum, prev_v1_momentum_val)
      self.assertAllEqual(v2, prev_v2_val)
      # v2/Momentum was not selected, so it keeps its zero initializer.
      self.assertAllEqual(v2_momentum, np.zeros([10, 1]))
def testWarmStart_SparseColumnIntegerized(self):
  """Warm-starts linear-model weights of an identity categorical column."""
  sc_int = fc.categorical_column_with_identity("sc_int", num_buckets=10)

  # Checkpoint all-ones weights named "linear_model/sc_int/weights".
  _, prev_int_val = self._create_prev_run_var(
      "linear_model/sc_int/weights", shape=[10, 1], initializer=ones())
  self.assertAllEqual(np.ones([10, 1]), prev_int_val)

  def single_partition(shape, dtype):  # pylint: disable=unused-argument
    return [1] * len(shape)

  # Baseline without warm-starting: the test expects all-zero weights.
  with ops.Graph().as_default() as graph, self.session(graph=graph) as sess:
    cols_to_vars = self._create_linear_model([sc_int], single_partition)
    self.evaluate(variables.global_variables_initializer())
    baseline = {sc_int: [np.zeros([10, 1])]}
    self._assert_cols_to_vars(cols_to_vars, baseline, sess)

  # Same model with warm-starting: weights must match the checkpoint.
  with ops.Graph().as_default() as graph, self.session(graph=graph) as sess:
    cols_to_vars = self._create_linear_model([sc_int], single_partition)
    ws_util.warm_start(self.get_temp_dir(), vars_to_warm_start=".*sc_int.*")
    self.evaluate(variables.global_variables_initializer())
    restored = {sc_int: [prev_int_val]}
    self._assert_cols_to_vars(cols_to_vars, restored, sess)
def testWarmStart_SparseColumnHashed(self):
  """Warm-starts linear-model weights of a hash-bucket categorical column."""
  sc_hash = fc.categorical_column_with_hash_bucket(
      "sc_hash", hash_bucket_size=15)

  # Checkpoint weights named "linear_model/sc_hash/weights".
  _, prev_hash_val = self._create_prev_run_var(
      "linear_model/sc_hash/weights", shape=[15, 1], initializer=norms())

  def single_partition(shape, dtype):  # pylint: disable=unused-argument
    return [1] * len(shape)

  # Baseline without warm-starting: the test expects all-zero weights.
  with ops.Graph().as_default() as graph, self.session(graph=graph) as sess:
    cols_to_vars = self._create_linear_model([sc_hash], single_partition)
    self.evaluate(variables.global_variables_initializer())
    baseline = {sc_hash: [np.zeros([15, 1])]}
    self._assert_cols_to_vars(cols_to_vars, baseline, sess)

  # Same model with warm-starting: weights must match the checkpoint.
  with ops.Graph().as_default() as graph, self.session(graph=graph) as sess:
    cols_to_vars = self._create_linear_model([sc_hash], single_partition)
    ws_util.warm_start(
        self.get_temp_dir(), vars_to_warm_start=".*sc_hash.*")
    self.evaluate(variables.global_variables_initializer())
    restored = {sc_hash: [prev_hash_val]}
    self._assert_cols_to_vars(cols_to_vars, restored, sess)
def testWarmStart_SparseColumnVocabulary(self):
  """Warm-starts a vocabulary-file column without passing any vocab info."""
  # Vocabulary backing the sparse column "sc_vocab".
  vocab_entries = ["apple", "banana", "guava", "orange"]
  vocab_path = self._write_vocab(vocab_entries, "vocab")
  sc_vocab = fc.categorical_column_with_vocabulary_file(
      "sc_vocab", vocabulary_file=vocab_path, vocabulary_size=4)

  # Checkpoint all-ones weights named "linear_model/sc_vocab/weights".
  _, prev_vocab_val = self._create_prev_run_var(
      "linear_model/sc_vocab/weights", shape=[4, 1], initializer=ones())

  def single_partition(shape, dtype):  # pylint: disable=unused-argument
    return [1] * len(shape)

  # Baseline without warm-starting: the test expects all-zero weights.
  with ops.Graph().as_default() as graph, self.session(graph=graph) as sess:
    cols_to_vars = self._create_linear_model([sc_vocab], single_partition)
    self.evaluate(variables.global_variables_initializer())
    baseline = {sc_vocab: [np.zeros([4, 1])]}
    self._assert_cols_to_vars(cols_to_vars, baseline, sess)

  # Same model with warm-starting.
  with ops.Graph().as_default() as graph, self.session(graph=graph) as sess:
    cols_to_vars = self._create_linear_model([sc_vocab], single_partition)
    # No old vocab is handed to warm_start here; the test expects the
    # checkpointed weights to be restored unchanged.
    ws_util.warm_start(
        self.get_temp_dir(), vars_to_warm_start=".*sc_vocab.*")
    self.evaluate(variables.global_variables_initializer())
    restored = {sc_vocab: [prev_vocab_val]}
    self._assert_cols_to_vars(cols_to_vars, restored, sess)
def testWarmStart_ExplicitCheckpointFile(self):
  """Warm-starts from an explicit checkpoint prefix instead of a directory."""
  # Vocabulary backing the sparse column "sc_vocab".
  vocab_entries = ["apple", "banana", "guava", "orange"]
  vocab_path = self._write_vocab(vocab_entries, "vocab")
  sc_vocab = fc.categorical_column_with_vocabulary_file(
      "sc_vocab", vocabulary_file=vocab_path, vocabulary_size=4)

  # Checkpoint all-ones weights named "linear_model/sc_vocab/weights".
  _, prev_vocab_val = self._create_prev_run_var(
      "linear_model/sc_vocab/weights", shape=[4, 1], initializer=ones())

  def single_partition(shape, dtype):  # pylint: disable=unused-argument
    return [1] * len(shape)

  # Baseline without warm-starting: the test expects all-zero weights.
  with ops.Graph().as_default() as graph, self.session(graph=graph) as sess:
    cols_to_vars = self._create_linear_model([sc_vocab], single_partition)
    self.evaluate(variables.global_variables_initializer())
    baseline = {sc_vocab: [np.zeros([4, 1])]}
    self._assert_cols_to_vars(cols_to_vars, baseline, sess)

  # Same model with warm-starting.
  with ops.Graph().as_default() as graph, self.session(graph=graph) as sess:
    cols_to_vars = self._create_linear_model([sc_vocab], single_partition)
    # No old vocab is handed to warm_start here. The checkpoint is named by
    # its file prefix rather than by the directory that contains it.
    checkpoint_prefix = os.path.join(self.get_temp_dir(), "model-0")
    ws_util.warm_start(
        checkpoint_prefix, vars_to_warm_start=".*sc_vocab.*")
    self.evaluate(variables.global_variables_initializer())
    restored = {sc_vocab: [prev_vocab_val]}
    self._assert_cols_to_vars(cols_to_vars, restored, sess)
def testWarmStart_SparseColumnVocabularyConstrainedVocabSizes(self):
  """Warm-starts with old and new vocab sizes smaller than their files."""
  # The old vocab file has three entries but only the first two are used.
  old_vocab_path = self._write_vocab(["apple", "guava", "banana"],
                                     "old_vocab")
  old_vocab_size = 2  # ['apple', 'guava']

  # The new vocab file has four entries; with vocabulary_size=2 the column
  # only uses ['apple', 'banana'].
  current_vocab_path = self._write_vocab(
      ["apple", "banana", "guava", "orange"], "current_vocab")
  sc_vocab = fc.categorical_column_with_vocabulary_file(
      "sc_vocab", vocabulary_file=current_vocab_path, vocabulary_size=2)

  # Checkpoint all-ones weights named "linear_model/sc_vocab/weights".
  self._create_prev_run_var(
      "linear_model/sc_vocab/weights", shape=[2, 1], initializer=ones())

  def single_partition(shape, dtype):  # pylint: disable=unused-argument
    return [1] * len(shape)

  # Baseline without warm-starting: the test expects all-zero weights.
  with ops.Graph().as_default() as graph, self.session(graph=graph) as sess:
    cols_to_vars = self._create_linear_model([sc_vocab], single_partition)
    self.evaluate(variables.global_variables_initializer())
    baseline = {sc_vocab: [np.zeros([2, 1])]}
    self._assert_cols_to_vars(cols_to_vars, baseline, sess)

  # Same model with warm-starting and explicit vocab information.
  with ops.Graph().as_default() as graph, self.session(graph=graph) as sess:
    cols_to_vars = self._create_linear_model([sc_vocab], single_partition)
    vocab_info = ws_util.VocabInfo(
        new_vocab=sc_vocab.vocabulary_file,
        new_vocab_size=sc_vocab.vocabulary_size,
        num_oov_buckets=sc_vocab.num_oov_buckets,
        old_vocab=old_vocab_path,
        old_vocab_size=old_vocab_size)
    vocab_info_by_var = {"linear_model/sc_vocab/weights": vocab_info}
    ws_util.warm_start(
        ckpt_to_initialize_from=self.get_temp_dir(),
        vars_to_warm_start=".*sc_vocab.*",
        var_name_to_vocab_info=vocab_info_by_var)
    self.evaluate(variables.global_variables_initializer())
    # 'banana' is not among the first two old entries, so its row is newly
    # initialized while 'apple' keeps the checkpointed value.
    expected = {sc_vocab: [[[1], [0]]]}
    self._assert_cols_to_vars(cols_to_vars, expected, sess)
def testWarmStart_BucketizedColumn(self):
  """Warm-starts linear-model weights of a bucketized numeric column."""
  real = fc.numeric_column("real")
  real_bucket = fc.bucketized_column(real, boundaries=[0., 1., 2., 3.])

  # Checkpoint weights named "linear_model/real_bucketized/weights".
  _, prev_bucket_val = self._create_prev_run_var(
      "linear_model/real_bucketized/weights",
      shape=[5, 1],
      initializer=norms())

  def single_partition(shape, dtype):  # pylint: disable=unused-argument
    return [1] * len(shape)

  # Baseline without warm-starting: the test expects all-zero weights.
  with ops.Graph().as_default() as graph, self.session(graph=graph) as sess:
    cols_to_vars = self._create_linear_model([real_bucket], single_partition)
    self.evaluate(variables.global_variables_initializer())
    baseline = {real_bucket: [np.zeros([5, 1])]}
    self._assert_cols_to_vars(cols_to_vars, baseline, sess)

  # Same model with warm-starting: weights must match the checkpoint.
  with ops.Graph().as_default() as graph, self.session(graph=graph) as sess:
    cols_to_vars = self._create_linear_model([real_bucket], single_partition)
    ws_util.warm_start(
        self.get_temp_dir(), vars_to_warm_start=".*real_bucketized.*")
    self.evaluate(variables.global_variables_initializer())
    restored = {real_bucket: [prev_bucket_val]}
    self._assert_cols_to_vars(cols_to_vars, restored, sess)
def testWarmStart_MultipleCols(self):
  """Warm-starts every column of a multi-column linear model plus its bias."""
  # Vocabulary backing the sparse column "sc_vocab".
  vocab_path = self._write_vocab(["apple", "banana", "guava", "orange"],
                                 "vocab")

  # Feature columns under test.
  sc_int = fc.categorical_column_with_identity("sc_int", num_buckets=10)
  sc_hash = fc.categorical_column_with_hash_bucket(
      "sc_hash", hash_bucket_size=15)
  sc_keys = fc.categorical_column_with_vocabulary_list(
      "sc_keys", vocabulary_list=["a", "b", "c", "e"])
  sc_vocab = fc.categorical_column_with_vocabulary_file(
      "sc_vocab", vocabulary_file=vocab_path, vocabulary_size=4)
  real = fc.numeric_column("real")
  real_bucket = fc.bucketized_column(real, boundaries=[0., 1., 2., 3.])
  cross = fc.crossed_column([sc_keys, sc_vocab], hash_bucket_size=20)
  all_linear_cols = [sc_int, sc_hash, sc_keys, sc_vocab, real_bucket, cross]

  # Checkpoint one variable per column, plus a bias variable so the test
  # can check that the bias is warm-started as well.
  with ops.Graph().as_default() as graph, self.session(graph=graph) as sess:
    saved_var_specs = [
        ("linear_model/sc_int/weights", [10, 1], ones()),
        ("linear_model/sc_hash/weights", [15, 1], norms()),
        ("linear_model/sc_keys/weights", [4, 1], rand()),
        ("linear_model/sc_vocab/weights", [4, 1], ones()),
        ("linear_model/real_bucketized/weights", [5, 1], norms()),
        ("linear_model/sc_keys_X_sc_vocab/weights", [20, 1], rand()),
        ("linear_model/bias_weights", [1], rand()),
    ]
    saved_vars = []
    for saved_name, saved_shape, saved_init in saved_var_specs:
      saved_vars.append(
          variable_scope.get_variable(
              saved_name, shape=saved_shape, initializer=saved_init))
    self._write_checkpoint(sess)
    (prev_int_val, prev_hash_val, prev_keys_val, prev_vocab_val,
     prev_bucket_val, prev_cross_val, prev_bias_val) = sess.run(saved_vars)

  def single_partition(shape, dtype):  # pylint: disable=unused-argument
    return [1] * len(shape)

  # Baseline without warm-starting: the test expects all-zero weights.
  with ops.Graph().as_default() as graph, self.session(graph=graph) as sess:
    cols_to_vars = self._create_linear_model(all_linear_cols,
                                             single_partition)
    self.evaluate(variables.global_variables_initializer())
    baseline = {
        sc_int: [np.zeros([10, 1])],
        sc_hash: [np.zeros([15, 1])],
        sc_keys: [np.zeros([4, 1])],
        sc_vocab: [np.zeros([4, 1])],
        real_bucket: [np.zeros([5, 1])],
        cross: [np.zeros([20, 1])],
    }
    self._assert_cols_to_vars(cols_to_vars, baseline, sess)

  # Same model with warm-starting; vocab info is given for sc_vocab only.
  with ops.Graph().as_default() as graph, self.session(graph=graph) as sess:
    cols_to_vars = self._create_linear_model(all_linear_cols,
                                             single_partition)
    vocab_info = ws_util.VocabInfo(
        new_vocab=sc_vocab.vocabulary_file,
        new_vocab_size=sc_vocab.vocabulary_size,
        num_oov_buckets=sc_vocab.num_oov_buckets,
        old_vocab=vocab_path)
    vocab_info_by_var = {"linear_model/sc_vocab/weights": vocab_info}
    ws_util.warm_start(
        self.get_temp_dir(), var_name_to_vocab_info=vocab_info_by_var)
    self.evaluate(variables.global_variables_initializer())
    # Every column, and the "bias" entry, must match the checkpoint.
    restored = {
        sc_int: [prev_int_val],
        sc_hash: [prev_hash_val],
        sc_keys: [prev_keys_val],
        sc_vocab: [prev_vocab_val],
        real_bucket: [prev_bucket_val],
        cross: [prev_cross_val],
        "bias": [prev_bias_val],
    }
    self._assert_cols_to_vars(cols_to_vars, restored, sess)
def testWarmStartMoreSettings(self):
  """Warm-starts a partitioned linear model using all optional settings.

  A single `warm_start` call combines a `vars_to_warm_start` regex, a vocab
  remapping for "sc_vocab" and a checkpoint-name override for "sc_keys".
  """
  # The new vocab lists the four old entries in reverse order and then adds
  # two entries that have no counterpart in the old vocab.
  old_vocab_file = self._write_vocab(
      ["apple", "banana", "guava", "orange"], "old_vocab")
  new_vocab_file = self._write_vocab(
      ["orange", "guava", "banana", "apple", "raspberry", "blueberry"],
      "new_vocab")

  # Feature columns of the model that gets warm-started.
  sc_hash = fc.categorical_column_with_hash_bucket(
      "sc_hash", hash_bucket_size=15)
  sc_keys = fc.categorical_column_with_vocabulary_list(
      "sc_keys", vocabulary_list=["a", "b", "c", "e"])
  sc_vocab = fc.categorical_column_with_vocabulary_file(
      "sc_vocab", vocabulary_file=new_vocab_file, vocabulary_size=6)
  linear_columns = [sc_hash, sc_keys, sc_vocab]

  # Source checkpoint. The [4, 1] variable later mapped onto sc_keys is
  # stored under the unrelated name "some_other_name".
  with ops.Graph().as_default() as g, self.session(graph=g) as sess:
    variable_scope.get_variable(
        "linear_model/sc_hash/weights", shape=[15, 1], initializer=norms())
    renamed_keys_var = variable_scope.get_variable(
        "some_other_name", shape=[4, 1], initializer=rand())
    variable_scope.get_variable(
        "linear_model/sc_vocab/weights",
        initializer=[[0.5], [1.], [2.], [3.]])
    self._write_checkpoint(sess)
    saved_keys_val = self.evaluate(renamed_keys_var)

  def _split_first_dim(shape, dtype):  # pylint:disable=unused-argument
    # At most two slices along the first axis; other axes are left whole.
    return [min(2, shape.dims[0].value)] + [1] * (len(shape) - 1)

  # Fresh graph and session, this time with warm-starting.
  with ops.Graph().as_default() as g, self.session(graph=g) as sess:
    cols_to_vars = self._create_linear_model(linear_columns,
                                             _split_first_dim)
    remap_info = ws_util.VocabInfo(
        new_vocab=sc_vocab.vocabulary_file,
        new_vocab_size=sc_vocab.vocabulary_size,
        num_oov_buckets=sc_vocab.num_oov_buckets,
        old_vocab=old_vocab_file)
    vocab_var_name = ws_util._infer_var_name(cols_to_vars[sc_vocab])
    keys_var_name = ws_util._infer_var_name(cols_to_vars[sc_keys])
    ws_util.warm_start(
        self.get_temp_dir(),
        vars_to_warm_start=".*(sc_keys|sc_vocab).*",
        var_name_to_vocab_info={vocab_var_name: remap_info},
        var_name_to_prev_var_name={keys_var_name: "some_other_name"})
    self.evaluate(variables.global_variables_initializer())

    # sc_keys carries the saved values (split across the two partitions),
    # sc_hash is expected to stay all-zero because it is not warm-started,
    # and sc_vocab holds the old weights reordered to follow the new vocab
    # with zeros for the two new entries.
    expected = {
        sc_keys: np.split(saved_keys_val, 2),
        sc_hash: [np.zeros([8, 1]), np.zeros([7, 1])],
        sc_vocab: [
            np.array([[3.], [2.], [1.]]),
            np.array([[0.5], [0.], [0.]]),
        ],
    }
    self._assert_cols_to_vars(cols_to_vars, expected, sess)
def testWarmStartMoreSettingsNoPartitioning(self):
  """Warm-starts an unpartitioned linear model using all optional settings.

  Same scenario as the partitioned variant, but the model is built with
  `partitioner=None`, so each column is checked against a single array.
  """
  # The new vocab lists the four old entries in reverse order and then adds
  # two entries that have no counterpart in the old vocab.
  old_vocab_file = self._write_vocab(
      ["apple", "banana", "guava", "orange"], "old_vocab")
  new_vocab_file = self._write_vocab(
      ["orange", "guava", "banana", "apple", "raspberry", "blueberry"],
      "new_vocab")

  # Feature columns of the model that gets warm-started.
  sc_hash = fc.categorical_column_with_hash_bucket(
      "sc_hash", hash_bucket_size=15)
  sc_keys = fc.categorical_column_with_vocabulary_list(
      "sc_keys", vocabulary_list=["a", "b", "c", "e"])
  sc_vocab = fc.categorical_column_with_vocabulary_file(
      "sc_vocab", vocabulary_file=new_vocab_file, vocabulary_size=6)
  linear_columns = [sc_hash, sc_keys, sc_vocab]

  # Source checkpoint. The [4, 1] variable later mapped onto sc_keys is
  # stored under the unrelated name "some_other_name".
  with ops.Graph().as_default() as g, self.session(graph=g) as sess:
    variable_scope.get_variable(
        "linear_model/sc_hash/weights", shape=[15, 1], initializer=norms())
    renamed_keys_var = variable_scope.get_variable(
        "some_other_name", shape=[4, 1], initializer=rand())
    variable_scope.get_variable(
        "linear_model/sc_vocab/weights",
        initializer=[[0.5], [1.], [2.], [3.]])
    self._write_checkpoint(sess)
    saved_keys_val = self.evaluate(renamed_keys_var)

  # Fresh graph and session, this time with warm-starting.
  with ops.Graph().as_default() as g, self.session(graph=g) as sess:
    cols_to_vars = self._create_linear_model(
        linear_columns, partitioner=None)
    remap_info = ws_util.VocabInfo(
        new_vocab=sc_vocab.vocabulary_file,
        new_vocab_size=sc_vocab.vocabulary_size,
        num_oov_buckets=sc_vocab.num_oov_buckets,
        old_vocab=old_vocab_file)
    vocab_var_name = ws_util._infer_var_name(cols_to_vars[sc_vocab])
    keys_var_name = ws_util._infer_var_name(cols_to_vars[sc_keys])
    ws_util.warm_start(
        self.get_temp_dir(),
        vars_to_warm_start=".*(sc_keys|sc_vocab).*",
        var_name_to_vocab_info={vocab_var_name: remap_info},
        var_name_to_prev_var_name={keys_var_name: "some_other_name"})
    self.evaluate(variables.global_variables_initializer())

    # sc_keys carries the saved values, sc_hash is expected to stay
    # all-zero because it is not warm-started, and sc_vocab holds the old
    # weights reordered to follow the new vocab with zeros for the two new
    # entries.
    expected = {
        sc_keys: [saved_keys_val],
        sc_hash: [np.zeros([15, 1])],
        sc_vocab: [np.array([[3.], [2.], [1.], [0.5], [0.], [0.]])],
    }
    self._assert_cols_to_vars(cols_to_vars, expected, sess)
def testWarmStartVarsToWarmstartIsNone(self):
  """Checks `vars_to_warm_start=None` limits warm-starting to vocab vars.

  Only the variable named in `var_name_to_vocab_info` is expected to pick up
  checkpoint values; the `var_name_to_prev_var_name` entry has no effect.
  """
  # The new vocab lists the four old entries in reverse order and then adds
  # two entries that have no counterpart in the old vocab.
  old_vocab_file = self._write_vocab(
      ["apple", "banana", "guava", "orange"], "old_vocab")
  new_vocab_file = self._write_vocab(
      ["orange", "guava", "banana", "apple", "raspberry", "blueberry"],
      "new_vocab")

  # Feature columns of the model that gets warm-started.
  sc_hash = fc.categorical_column_with_hash_bucket(
      "sc_hash", hash_bucket_size=15)
  sc_keys = fc.categorical_column_with_vocabulary_list(
      "sc_keys", vocabulary_list=["a", "b", "c", "e"])
  sc_vocab = fc.categorical_column_with_vocabulary_file(
      "sc_vocab", vocabulary_file=new_vocab_file, vocabulary_size=6)
  linear_columns = [sc_hash, sc_keys, sc_vocab]

  # Source checkpoint.
  with ops.Graph().as_default() as g, self.session(graph=g) as sess:
    variable_scope.get_variable(
        "linear_model/sc_hash/weights", shape=[15, 1], initializer=norms())
    variable_scope.get_variable(
        "some_other_name", shape=[4, 1], initializer=rand())
    variable_scope.get_variable(
        "linear_model/sc_vocab/weights",
        initializer=[[0.5], [1.], [2.], [3.]])
    self._write_checkpoint(sess)

  def _split_first_dim(shape, dtype):  # pylint:disable=unused-argument
    # At most two slices along the first axis; other axes are left whole.
    return [min(2, shape.dims[0].value)] + [1] * (len(shape) - 1)

  # Fresh graph and session, this time with warm-starting.
  with ops.Graph().as_default() as g, self.session(graph=g) as sess:
    cols_to_vars = self._create_linear_model(linear_columns,
                                             _split_first_dim)
    remap_info = ws_util.VocabInfo(
        new_vocab=sc_vocab.vocabulary_file,
        new_vocab_size=sc_vocab.vocabulary_size,
        num_oov_buckets=sc_vocab.num_oov_buckets,
        old_vocab=old_vocab_file)
    vocab_var_name = ws_util._infer_var_name(cols_to_vars[sc_vocab])
    keys_var_name = ws_util._infer_var_name(cols_to_vars[sc_keys])
    ws_util.warm_start(
        self.get_temp_dir(),
        # None is a special value: only the variable listed in
        # var_name_to_vocab_info (the sc_vocab weights) is warm-started.
        vars_to_warm_start=None,
        var_name_to_vocab_info={vocab_var_name: remap_info},
        # Supplied on purpose: with vars_to_warm_start=None this mapping
        # is overridden and sc_keys is not warm-started.
        var_name_to_prev_var_name={keys_var_name: "some_other_name"})
    self.evaluate(variables.global_variables_initializer())

    # Only sc_vocab holds checkpoint values (remapped to the new vocab);
    # sc_keys and sc_hash are expected to stay all-zero.
    expected = {
        sc_keys: [np.zeros([2, 1]), np.zeros([2, 1])],
        sc_hash: [np.zeros([8, 1]), np.zeros([7, 1])],
        sc_vocab: [
            np.array([[3.], [2.], [1.]]),
            np.array([[0.5], [0.], [0.]]),
        ],
    }
    self._assert_cols_to_vars(cols_to_vars, expected, sess)
def testWarmStartEmbeddingColumn(self):
  """Warm-starts partitioned `input_layer` embedding weights via a vocab."""
  # The new vocab lists the four old entries in reverse order and then adds
  # two entries that have no counterpart in the old vocab.
  old_vocab_file = self._write_vocab(
      ["apple", "banana", "guava", "orange"], "old_vocab")
  new_vocab_file = self._write_vocab(
      ["orange", "guava", "banana", "apple", "raspberry", "blueberry"],
      "new_vocab")

  # Source checkpoint: one two-dimensional embedding row per old vocab entry.
  with ops.Graph().as_default() as g, self.session(graph=g) as sess:
    variable_scope.get_variable(
        "input_layer/sc_vocab_embedding/embedding_weights",
        initializer=[[0.5, 0.4], [1., 1.1], [2., 2.2], [3., 3.3]])
    self._write_checkpoint(sess)

  def _split_first_dim(shape, dtype):  # pylint:disable=unused-argument
    # At most two slices along the first axis; other axes are left whole.
    return [min(2, shape.dims[0].value)] + [1] * (len(shape) - 1)

  # Embedding column over the new vocab.
  sc_vocab = fc.categorical_column_with_vocabulary_file(
      "sc_vocab", vocabulary_file=new_vocab_file, vocabulary_size=6)
  embedded_vocab = fc.embedding_column(
      categorical_column=sc_vocab,
      dimension=2)
  deep_columns = [embedded_vocab]

  # Fresh graph and session, this time with warm-starting.
  with ops.Graph().as_default() as g, self.session(graph=g) as sess:
    cols_to_vars = {}
    with variable_scope.variable_scope("", partitioner=_split_first_dim):
      # Building the input layer creates the variables and fills
      # cols_to_vars.
      fc.input_layer(
          features=self._create_dummy_inputs(),
          feature_columns=deep_columns,
          cols_to_vars=cols_to_vars)
    # constant_initializer cannot be used with load_and_remap, hence a
    # uniform initializer whose bounds coincide. In practice a truncated
    # normal initializer would be used.
    fill_initializer = init_ops.random_uniform_initializer(
        minval=0.42, maxval=0.42)
    remap_info = ws_util.VocabInfo(
        new_vocab=sc_vocab.vocabulary_file,
        new_vocab_size=sc_vocab.vocabulary_size,
        num_oov_buckets=sc_vocab.num_oov_buckets,
        old_vocab=old_vocab_file,
        backup_initializer=fill_initializer)
    embedding_var_name = ws_util._infer_var_name(
        cols_to_vars[embedded_vocab])
    ws_util.warm_start(
        self.get_temp_dir(),
        var_name_to_vocab_info={embedding_var_name: remap_info})
    self.evaluate(variables.global_variables_initializer())

    # The embedding rows follow the new vocab order; the rows for the two
    # new entries hold the 0.42 produced by backup_initializer.
    expected = {
        embedded_vocab: [
            np.array([[3., 3.3], [2., 2.2], [1., 1.1]]),
            np.array([[0.5, 0.4], [0.42, 0.42], [0.42, 0.42]]),
        ],
    }
    self._assert_cols_to_vars(cols_to_vars, expected, sess)
def testWarmStartEmbeddingColumnLinearModel(self):
  """Warm-starts an embedding column that is fed into `linear_model`.

  Both the linear weights and the vocab-remapped embedding weights of the
  column are checked against the checkpoint.
  """
  # The new vocab lists the four old entries in reverse order and then adds
  # two entries that have no counterpart in the old vocab.
  old_vocab_file = self._write_vocab(
      ["apple", "banana", "guava", "orange"], "old_vocab")
  new_vocab_file = self._write_vocab(
      ["orange", "guava", "banana", "apple", "raspberry", "blueberry"],
      "new_vocab")

  # Source checkpoint: embedding rows plus the linear weights on top of them.
  with ops.Graph().as_default() as g, self.session(graph=g) as sess:
    variable_scope.get_variable(
        "linear_model/sc_vocab_embedding/embedding_weights",
        initializer=[[0.5, 0.4], [1., 1.1], [2., 2.2], [3., 3.3]])
    variable_scope.get_variable(
        "linear_model/sc_vocab_embedding/weights",
        initializer=[[0.69], [0.71]])
    self._write_checkpoint(sess)

  def _split_first_dim(shape, dtype):  # pylint:disable=unused-argument
    # At most two slices along the first axis; other axes are left whole.
    return [min(2, shape.dims[0].value)] + [1] * (len(shape) - 1)

  # Embedding column over the new vocab.
  sc_vocab = fc.categorical_column_with_vocabulary_file(
      "sc_vocab", vocabulary_file=new_vocab_file, vocabulary_size=6)
  embedded_vocab = fc.embedding_column(
      categorical_column=sc_vocab,
      dimension=2)
  deep_columns = [embedded_vocab]

  # Fresh graph and session, this time with warm-starting.
  with ops.Graph().as_default() as g, self.session(graph=g) as sess:
    cols_to_vars = {}
    with variable_scope.variable_scope("", partitioner=_split_first_dim):
      # Building the linear model creates the variables and fills
      # cols_to_vars.
      fc.linear_model(
          features=self._create_dummy_inputs(),
          feature_columns=deep_columns,
          cols_to_vars=cols_to_vars)

    # VocabInfo for the embedding weights. constant_initializer cannot be
    # used with load_and_remap, hence a uniform initializer whose bounds
    # coincide. In practice a truncated normal initializer would be used.
    fill_initializer = init_ops.random_uniform_initializer(
        minval=0.42, maxval=0.42)
    remap_info = ws_util.VocabInfo(
        new_vocab=sc_vocab.vocabulary_file,
        new_vocab_size=sc_vocab.vocabulary_size,
        num_oov_buckets=sc_vocab.num_oov_buckets,
        old_vocab=old_vocab_file,
        backup_initializer=fill_initializer)
    ws_util.warm_start(
        self.get_temp_dir(),
        vars_to_warm_start=".*sc_vocab.*",
        var_name_to_vocab_info={
            "linear_model/sc_vocab_embedding/embedding_weights": remap_info
        })
    self.evaluate(variables.global_variables_initializer())

    # Linear weights come back unchanged; the embedding rows follow the new
    # vocab order, and the rows for the two new entries hold the 0.42
    # produced by backup_initializer.
    linear_part_0 = np.array([[0.69]])
    linear_part_1 = np.array([[0.71]])
    embedding_part_0 = np.array([[3., 3.3], [2., 2.2], [1., 1.1]])
    embedding_part_1 = np.array([[0.5, 0.4], [0.42, 0.42], [0.42, 0.42]])
    expected = {
        embedded_vocab: [
            linear_part_0,
            linear_part_1,
            embedding_part_0,
            embedding_part_1,
        ],
    }
    self._assert_cols_to_vars(cols_to_vars, expected, sess)
def testErrorConditions(self):
  """Checks the TypeError and ValueError paths of the warm-start helpers."""

  def _two_row_slices(shape, dtype):  # pylint:disable=unused-argument
    # Fixed partitioning: two slices along axis 0, one along axis 1.
    return [2, 1]

  partitioned_x = variable_scope.get_variable(
      "x",
      shape=[4, 1],
      initializer=ones(),
      partitioner=_two_row_slices)

  # A list holding a PartitionedVariable is an invalid type when
  # warm-starting with a vocab.
  wrapped_in_list = [partitioned_x]
  with self.assertRaises(TypeError):
    ws_util._warm_start_var_with_vocab(wrapped_in_list, "/tmp", 5, "/tmp",
                                       "/tmp")

  # Names that match no variable in the graph must raise ValueError.
  with ops.Graph().as_default():
    with self.cached_session() as sess:
      variable_scope.get_variable(
          "x",
          shape=[4, 1],
          initializer=ones(),
          partitioner=_two_row_slices)
      self._write_checkpoint(sess)

    # Arguments are built outside the assertRaises blocks so that only the
    # warm_start call itself can satisfy the assertion.
    ckpt_dir = self.get_temp_dir()
    unused_vocab_info = {"y": ws_util.VocabInfo("", 1, 0, "")}
    with self.assertRaises(ValueError):
      ws_util.warm_start(ckpt_dir, var_name_to_vocab_info=unused_vocab_info)

    ckpt_dir = self.get_temp_dir()
    unused_prev_name = {"y": "y2"}
    with self.assertRaises(ValueError):
      ws_util.warm_start(ckpt_dir, var_name_to_prev_var_name=unused_prev_name)
def testWarmStartFromObjectBasedCheckpoint(self):
  """Warm-starts a variable from a checkpoint written by `Checkpoint.save`."""
  saved_values = [[0.5], [1.], [1.5], [2.]]

  # First graph: create the variable and save it with the object-based API.
  with ops.Graph().as_default() as g, self.session(graph=g):
    source_var = variable_scope.get_variable(
        "fruit_weights",
        initializer=saved_values)
    self.evaluate(variables.global_variables_initializer())
    ckpt_prefix = os.path.join(self.get_temp_dir(), "checkpoint")
    tracking_util.Checkpoint(v=source_var).save(ckpt_prefix)

  # Second graph: a same-named variable that would be all zeros unless it is
  # warm-started from the checkpoint directory.
  with ops.Graph().as_default() as g, self.session(graph=g):
    fruit_weights = variable_scope.get_variable(
        "fruit_weights", initializer=[[0.], [0.], [0.], [0.]])
    ws_util.warm_start(self.get_temp_dir())
    self.evaluate(variables.global_variables_initializer())
    restored_values = self.evaluate(fruit_weights)
    self.assertAllClose(saved_values, restored_values)
# When this file is executed directly (rather than imported), hand control to
# `main()` from tensorflow.python.platform.test.
if __name__ == "__main__":
  test.main()