# research/seq_flow_lite/utils/tflite_utils.py
# Copyright 2020 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utils to convert to a TFLite model."""
import tensorflow.compat.v1 as tf
def _dump_graph_in_text_format(filename, graph_def):
"""Dump a tensorflow graph in readable text format."""
f = open(filename, 'w')
for node in graph_def.node:
f.write('Node: %s (%s)\n' % (node.name, node.op))
for input_name in node.input:
f.write('\tInput: %s\n' % input_name)
f.close()
def get_mean_stddev_values(min_value_of_features, max_value_of_features):
  """Gets Mean and Stddev values for given min/max float values.

  Mirrors the fake-quant nudging scheme: the float range
  [min_value_of_features, max_value_of_features] is mapped onto the 8-bit
  range [0, 255], the zero point is nudged to the nearest integer inside
  that range, and the resulting (zero_point, 1/scale) pair is returned as
  the (mean, stddev) quantized-input stats.

  Args:
    min_value_of_features: Minimum float value the features can take.
    max_value_of_features: Maximum float value the features can take.

  Returns:
    A (mean_value, stddev_value) tuple of floats.
  """
  quant_min_float = 0.0
  quant_max_float = 255.0
  quant_span = quant_max_float - quant_min_float

  nudged_scale = (
      (max_value_of_features - min_value_of_features) / quant_span)

  # Nudge the zero point to an integer clamped into [quant_min, quant_max].
  zero_point_from_min = (
      quant_min_float - min_value_of_features / nudged_scale)
  clamped_zero_point = min(
      max(zero_point_from_min, quant_min_float), quant_max_float)
  nudged_zero_point = int(round(clamped_zero_point))

  nudged_min = (quant_min_float - nudged_zero_point) * nudged_scale
  nudged_max = (quant_max_float - nudged_zero_point) * nudged_scale
  scale = (nudged_max - nudged_min) / quant_span

  mean_value = (
      (quant_min_float - min_value_of_features) /
      (max_value_of_features - min_value_of_features) * quant_max_float)
  stddev_value = 1 / scale
  return mean_value, stddev_value
class InterpreterWithCustomOps(tf.lite.Interpreter):
  """Extended tf.lite.Interpreter with op introspection helpers.

  Adds access to per-op details and an op-name histogram for a model, plus a
  convenience check that the histogram matches an expected op count map.
  """

  def __init__(self,
               model_content,
               custom_op_registerers=None,
               experimental_preserve_all_tensors=False):
    """Creates the interpreter.

    Args:
      model_content: Serialized TFLite model (flatbuffer bytes).
      custom_op_registerers: Optional list of custom op registerer functions;
        stored on the instance before the base constructor runs.
        NOTE(review): the list is not forwarded to the base constructor here —
        presumably the base class reads `self._custom_op_registerers`; confirm
        against the tf.lite.Interpreter implementation in use.
      experimental_preserve_all_tensors: Forwarded to tf.lite.Interpreter.
    """
    self._custom_op_registerers = custom_op_registerers or []
    super(InterpreterWithCustomOps, self).__init__(
        model_content=model_content,
        experimental_preserve_all_tensors=experimental_preserve_all_tensors)

  def op_details(self):
    """Returns per-op details from the interpreter, or {} if unavailable."""
    op_details = {}
    try:
      op_details = self._get_ops_details()  # Accessing experimental method.
    except AttributeError:
      print('Unable to access op details')
    return op_details

  def op_histogram(self):
    """Returns a dict mapping op name to its occurrence count in the model."""
    op_hist = {}
    for op in self.op_details():
      name = op['op_name']
      op_hist[name] = op_hist.get(name, 0) + 1
    return op_hist

  def check_op_histogram(self, expected):
    """Checks the model's op histogram against an expected count map.

    Args:
      expected: Dict mapping op name to expected count. The caller's dict is
        not modified (a copy is consumed internally).

    Returns:
      True iff the model's op histogram exactly matches `expected`;
      mismatches are reported via print().
    """
    passed = True
    # Work on a copy: the original implementation deleted matched keys from
    # the caller's dict, silently mutating the argument.
    expected = dict(expected)
    for k, v in self.op_histogram().items():
      if k not in expected:
        print('Unexpected key {} found {} times.'.format(k, v))
        passed = False
        continue
      elif expected[k] != v:
        print('Expected {} counts of key {} found {}.'.format(
            expected[k], k, v))
        passed = False
      del expected[k]
    for k, v in expected.items():
      print('Missing expected key {} value {}.'.format(k, v))
      passed = False
    return passed
def set_output_quantized_for_custom_ops(graph_def, use_mlir=True):
  """Set output types/quantized flag for custom/unsupported ops.

  Mutates `graph_def` in place: marks known custom ops as quantizable (MLIR
  converter) or sets their quantized output attributes (TOCO), and applies
  the TOCO-specific op renames.

  Args:
    graph_def: GraphDef proto to annotate in place.
    use_mlir: If True, tag ops for the MLIR converter; otherwise set the
      legacy TOCO attributes and rename the projection ops.
  """
  float_type = tf.float32.as_datatype_enum
  quantized_custom_ops = {
      name: [float_type] for name in (
          'SequenceStringProjection',
          'SequenceStringProjectionV2',
          'PoolingOp',
          'ExpectedValueOp',
          'LayerNorm',
          'UniformCausalAttn',
          'DynamicUniformCausalAttn',
          'RnnDecoderReadState',
          'RnnDecoderWriteState',
      )
  }
  custom_op_renames = {
      'SequenceStringProjection': 'SEQUENCE_STRING_PROJECTION',
      'SequenceStringProjectionV2': 'SEQUENCE_STRING_PROJECTION_V2',
  }
  for node in graph_def.node:
    output_types = quantized_custom_ops.get(node.op)
    if output_types is not None:
      if use_mlir:
        node.attr['_tfl_quant_trait'].s = b'fully_quantizable'
      else:
        node.attr['_output_quantized'].b = True
        node.attr['_output_types'].list.type[:] = output_types
    if not use_mlir and node.op in custom_op_renames:
      node.op = custom_op_renames[node.op]
def generate_tflite(session,
                    graph,
                    input_tensors,
                    output_tensors,
                    use_mlir=True):
  """Generate TFLite model from a session, graph and input/output tensors.

  Args:
    session: Active tf.Session holding the variable values to freeze.
    graph: tf.Graph to convert.
    input_tensors: List of input tensors of the model.
    output_tensors: List of output tensors of the model.
    use_mlir: Whether to use the MLIR converter (vs legacy TOCO).

  Returns:
    The serialized TFLite flatbuffer produced by the converter.
  """
  # Freeze variables into constants so the converter sees a static graph.
  frozen_graph_def = tf.graph_util.convert_variables_to_constants(
      session, graph.as_graph_def(),
      [tensor.name.split(':')[0] for tensor in output_tensors])
  set_output_quantized_for_custom_ops(frozen_graph_def, use_mlir)

  converter = tf.lite.TFLiteConverter(frozen_graph_def, input_tensors,
                                      output_tensors)
  converter.allow_custom_ops = True
  converter.experimental_new_converter = use_mlir
  converter.inference_type = tf.uint8
  converter.default_ranges_stats = (127.5, 127.5)
  input_stats = {}
  for tensor in input_tensors:
    input_stats[tensor.op.name] = (127.5, 127.5)
  converter.quantized_input_stats = input_stats
  return converter.convert()