official/projects/pruning/tasks/image_classification.py
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Image classification task definition."""
from absl import logging
import tensorflow as tf, tf_keras
import tensorflow_model_optimization as tfmot
from official.core import task_factory
from official.projects.pruning.configs import image_classification as exp_cfg
from official.vision.modeling.backbones import mobilenet
from official.vision.modeling.layers import nn_blocks
from official.vision.tasks import image_classification
@task_factory.register_task_cls(exp_cfg.ImageClassificationTask)
class ImageClassificationTask(image_classification.ImageClassificationTask):
  """A task for image classification with pruning."""

  # Maps a block layer class to the suffixes of the weight names inside that
  # block which should be exposed for pruning. `_make_block_prunable` looks a
  # layer up by its exact class (subclasses are not matched) and selects every
  # weight whose full name ends with one of these suffixes.
  _BLOCK_LAYER_SUFFIX_MAP = {
      mobilenet.Conv2DBNBlock: ('conv2d/kernel:0',),
      nn_blocks.BottleneckBlock: (
          'conv2d/kernel:0',
          'conv2d_1/kernel:0',
          'conv2d_2/kernel:0',
          'conv2d_3/kernel:0',
      ),
      nn_blocks.InvertedBottleneckBlock: (
          'conv2d/kernel:0',
          'conv2d_1/kernel:0',
          'conv2d_2/kernel:0',
          'conv2d_3/kernel:0',
          'depthwise_conv2d/depthwise_kernel:0',
      ),
      nn_blocks.ResidualBlock: (
          'conv2d/kernel:0',
          'conv2d_1/kernel:0',
          'conv2d_2/kernel:0',
      ),
  }

  def build_model(self) -> tf_keras.Model:
    """Builds classification model with pruning.

    If `task_config.pruning` is None, the base classification model is
    returned unchanged. Otherwise the model is cloned so that supported blocks
    expose their prunable weights. It is then optionally restored from
    `pruning_cfg.pretrained_original_checkpoint` and wrapped with
    `tfmot.sparsity.keras.prune_low_magnitude`.

    Returns:
      The Keras model, wrapped for pruning when pruning is configured.

    Raises:
      NotImplementedError: If `pruning_cfg.pruning_schedule` is neither
        'PolynomialDecay' nor 'ConstantSparsity'.
    """
    model = super().build_model()
    if self.task_config.pruning is None:
      return model
    pruning_cfg = self.task_config.pruning

    # `_make_block_prunable` returns non-Model layers as the same objects
    # (attaching `get_prunable_weights` to supported blocks in place), so the
    # clone's leaf layers are shared with `model`.
    prunable_model = tf_keras.models.clone_model(
        model,
        clone_function=self._make_block_prunable,
    )

    original_checkpoint = pruning_cfg.pretrained_original_checkpoint
    if original_checkpoint is not None:
      # Restore before the pruning wrappers are added. `expect_partial`
      # tolerates unused checkpoint values, while
      # `assert_existing_objects_matched` still requires every existing
      # tracked object to be restored.
      ckpt = tf.train.Checkpoint(model=prunable_model, **model.checkpoint_items)
      status = ckpt.read(original_checkpoint)
      status.expect_partial().assert_existing_objects_matched()

    pruning_params = self._build_pruning_params(pruning_cfg)
    pruned_model = tfmot.sparsity.keras.prune_low_magnitude(
        prunable_model, **pruning_params)

    self._log_pruned_weights(model, pruned_model)
    return pruned_model

  @staticmethod
  def _build_pruning_params(pruning_cfg):
    """Translates the pruning config into `prune_low_magnitude` kwargs.

    Args:
      pruning_cfg: The task's pruning config section.

    Returns:
      A dict with a 'pruning_schedule' entry and, when configured, a
      'sparsity_m_by_n' entry.

    Raises:
      NotImplementedError: If the configured schedule name is unsupported.
    """
    pruning_params = {}
    if pruning_cfg.sparsity_m_by_n is not None:
      pruning_params['sparsity_m_by_n'] = pruning_cfg.sparsity_m_by_n

    if pruning_cfg.pruning_schedule == 'PolynomialDecay':
      pruning_params['pruning_schedule'] = tfmot.sparsity.keras.PolynomialDecay(
          initial_sparsity=pruning_cfg.initial_sparsity,
          final_sparsity=pruning_cfg.final_sparsity,
          begin_step=pruning_cfg.begin_step,
          end_step=pruning_cfg.end_step,
          frequency=pruning_cfg.frequency)
    elif pruning_cfg.pruning_schedule == 'ConstantSparsity':
      # Only `final_sparsity` is used as the constant target; `end_step` from
      # the config is not passed to this schedule.
      pruning_params['pruning_schedule'] = tfmot.sparsity.keras.ConstantSparsity(
          target_sparsity=pruning_cfg.final_sparsity,
          begin_step=pruning_cfg.begin_step,
          frequency=pruning_cfg.frequency)
    else:
      raise NotImplementedError(
          'Only PolynomialDecay and ConstantSparsity are currently supported. Not support %s'
          % pruning_cfg.pruning_schedule)
    return pruning_params

  @staticmethod
  def _log_pruned_weights(model, pruned_model):
    """Logs which weights are pruned and which are not, for debugging.

    Args:
      model: The model as built before pruning wrappers were added.
      pruned_model: The model returned by `prune_low_magnitude`.
    """
    pruned_weights = [
        weight.name
        for layer in collect_prunable_layers(pruned_model)
        for weight, _, _ in layer.pruning_vars
    ]
    # Set lookup keeps the filter below linear instead of quadratic.
    pruned_weight_names = set(pruned_weights)
    unpruned_weights = [
        weight.name
        for weight in pruned_model.weights
        if weight.name not in pruned_weight_names
    ]
    # NOTE(review): the denominator counts `model.weights` (pre-wrapping),
    # whereas `unpruned_weights` is drawn from `pruned_model.weights`, which
    # may also hold pruning bookkeeping variables, so the two lists need not
    # add up to the denominator.
    logging.info(
        '%d / %d weights are pruned.\nPruned weights: [ \n%s \n],\n'
        'Unpruned weights: [ \n%s \n],',
        len(pruned_weights), len(model.weights), ', '.join(pruned_weights),
        ', '.join(unpruned_weights))

  def _make_block_prunable(
      self, layer: tf_keras.layers.Layer) -> tf_keras.layers.Layer:
    """Clone function that exposes the prunable weights of supported blocks.

    Nested `tf_keras.Model`s are cloned recursively with this same function.
    A layer whose exact class is a key of `_BLOCK_LAYER_SUFFIX_MAP` gets a
    `get_prunable_weights` function attached in place. That function returns
    the layer's weights whose names end with one of the mapped suffixes. Every
    other layer is returned untouched.

    Args:
      layer: The layer being cloned.

    Returns:
      A recursively cloned model for `tf_keras.Model` inputs. Otherwise the
      same `layer` object, possibly with `get_prunable_weights` attached.
    """
    if isinstance(layer, tf_keras.Model):
      return tf_keras.models.clone_model(
          layer, input_tensors=None, clone_function=self._make_block_prunable)

    suffixes = self._BLOCK_LAYER_SUFFIX_MAP.get(layer.__class__)
    if suffixes is None:
      return layer

    # Ordered by suffix first, then by position within `layer.weights`.
    prunable_weights = [
        weight
        for suffix in suffixes
        for weight in layer.weights
        if weight.name.endswith(suffix)
    ]

    def get_prunable_weights():
      return prunable_weights

    # Presumably picked up by tfmot's pruning wrapper as the
    # `get_prunable_weights` hook of the PrunableLayer interface.
    layer.get_prunable_weights = get_prunable_weights
    return layer
def collect_prunable_layers(model):
  """Recursively gathers the pruning-wrapper layers contained in `model`.

  Walks `model.layers` in order. Any layer that is itself a `tf_keras.Model`
  is searched recursively. Any layer whose class is named `PruneLowMagnitude`
  is included.

  Args:
    model: The model (anything exposing a `layers` sequence) to search.

  Returns:
    A list of the `PruneLowMagnitude` layers found, in traversal order.
  """
  found = []
  for sublayer in model.layers:
    if isinstance(sublayer, tf_keras.Model):
      found.extend(collect_prunable_layers(sublayer))
    # Matched by class name so the tfmot wrapper class need not be imported.
    is_pruning_wrapper = sublayer.__class__.__name__ == 'PruneLowMagnitude'
    if is_pruning_wrapper:
      found.append(sublayer)
  return found