ptp/components/models/vision/generic_image_encoder.py
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#
# Copyright (C) IBM Corporation 2019
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
__author__ = "Tomasz Kornuta"
import torch
import torchvision.models as models
from ptp.configuration.config_parsing import get_value_from_dictionary
from ptp.configuration.configuration_error import ConfigurationError
from ptp.components.models.model import Model
from ptp.data_types.data_definition import DataDefinition
class GenericImageEncoder(Model):
"""
    Model encoding images with one of several pretrained torchvision backbones: VGG16, DenseNet-121,
    ResNet-152 or ResNet-50. Depending on the operation mode it returns either the backbone's raw
    feature maps or the output of a freshly initialized, resized classification head.
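
    A minimal sketch of the expected configuration (the option names match the reads
    in ``__init__`` below; the stream names are illustrative assumptions):

        model_type: resnet50          # vgg16 | densenet121 | resnet152 | resnet50
        pretrained: True              # load ImageNet weights
        return_feature_maps: True     # return feature maps instead of a prediction vector
        streams:
            inputs: images
            outputs: feature_maps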
"""
def __init__(self, name, config):
"""
        Initializes the ``GenericImageEncoder`` model: loads the selected torchvision backbone and creates/replaces the required layers.
:param name: Name of the model (taken from the configuration file).
:param config: Parameters read from configuration file.
:type config: ``ptp.configuration.ConfigInterface``
"""
super(GenericImageEncoder, self).__init__(name, GenericImageEncoder, config)
# Get key mappings.
self.key_inputs = self.stream_keys["inputs"]
self.key_outputs = self.stream_keys["outputs"]
# Get operation modes.
self.return_feature_maps = self.config["return_feature_maps"]
pretrained = self.config["pretrained"]
# Get model type from configuration.
self.model_type = get_value_from_dictionary("model_type", self.config, "vgg16 | densenet121 | resnet152 | resnet50".split(" | "))
        if self.model_type == 'vgg16':
# Get VGG16
self.model = models.vgg16(pretrained=pretrained)
if self.return_feature_maps:
# Use only the "feature encoder".
self.model = self.model.features
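                # Note (from torchvision's VGG16 definition): ``features`` is the convolutional
                # part of the network; for the 224x224 inputs declared in input_data_definitions()
                # it produces 512 feature maps of size 7x7, matching the globals set below.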
# Height of the returned features tensor (SET)
self.feature_maps_height = 7
self.globals["feature_maps_height"] = self.feature_maps_height
# Width of the returned features tensor (SET)
self.feature_maps_width = 7
self.globals["feature_maps_width"] = self.feature_maps_width
# Depth of the returned features tensor (SET)
self.feature_maps_depth = 512
self.globals["feature_maps_depth"] = self.feature_maps_depth
else:
# Use the whole model, but cut/reshape only the last layer.
self.output_size = self.globals["output_size"]
# "Replace" the last layer.
self.model.classifier._modules['6'] = torch.nn.Linear(4096, self.output_size)
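                # Note (from torchvision's VGG16 definition): ``classifier`` is a Sequential whose
                # module '6' is the final Linear(4096, 1000) layer, so only its output size changes.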
        elif self.model_type == 'densenet121':
# Get densenet121
self.model = models.densenet121(pretrained=pretrained)
if self.return_feature_maps:
raise ConfigurationError("'densenet121' doesn't support 'return_feature_maps' mode (yet)")
# Use the whole model, but cut/reshape only the last layer.
self.output_size = self.globals["output_size"]
self.model.classifier = torch.nn.Linear(1024, self.output_size)
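            # Note (from torchvision's DenseNet-121 definition): the default ``classifier``
            # is Linear(1024, 1000), so only its output size changes here.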
        elif self.model_type == 'resnet152':
            # Get resnet152
            self.model = models.resnet152(pretrained=pretrained)
            if self.return_feature_maps:
                # Get all modules, excluding the last two: (avgpool) and (fc).
                modules = list(self.model.children())[:-2]
                self.model = torch.nn.Sequential(*modules)
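                # Note (from torchvision's ResNet definition): dropping avgpool and fc leaves
                # the layer4 output, i.e. 2048 feature maps of size 7x7 for 224x224 inputs.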
# Height of the returned features tensor (SET)
self.feature_maps_height = 7
self.globals["feature_maps_height"] = self.feature_maps_height
# Width of the returned features tensor (SET)
self.feature_maps_width = 7
self.globals["feature_maps_width"] = self.feature_maps_width
# Depth of the returned features tensor (SET)
self.feature_maps_depth = 2048
self.globals["feature_maps_depth"] = self.feature_maps_depth
else:
# Use the whole model, but cut/reshape only the last layer.
self.output_size = self.globals["output_size"]
self.model.fc = torch.nn.Linear(2048, self.output_size)
        elif self.model_type == 'resnet50':
            # Get resnet50
            self.model = models.resnet50(pretrained=pretrained)
            if self.return_feature_maps:
                # Get all modules, excluding the last two: (avgpool) and (fc).
                modules = list(self.model.children())[:-2]
                self.model = torch.nn.Sequential(*modules)
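                # As with resnet152 above: dropping avgpool and fc leaves the layer4 output,
                # i.e. 2048 feature maps of size 7x7 for 224x224 inputs.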
# Height of the returned features tensor (SET)
self.feature_maps_height = 7
self.globals["feature_maps_height"] = self.feature_maps_height
# Width of the returned features tensor (SET)
self.feature_maps_width = 7
self.globals["feature_maps_width"] = self.feature_maps_width
# Depth of the returned features tensor (SET)
self.feature_maps_depth = 2048
self.globals["feature_maps_depth"] = self.feature_maps_depth
else:
# Use the whole model, but cut/reshape only the last layer.
self.output_size = self.globals["output_size"]
self.model.fc = torch.nn.Linear(2048, self.output_size)
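
    # A minimal standalone sketch of the feature-map mode wrapped above (assumes only
    # torch and torchvision; shapes follow the 224x224 input contract declared below):
    #
    #     encoder = torch.nn.Sequential(*list(models.resnet50(pretrained=True).children())[:-2])
    #     feature_maps = encoder(torch.zeros(1, 3, 224, 224))  # torch.Size([1, 2048, 7, 7])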
def input_data_definitions(self):
"""
Function returns a dictionary with definitions of input data that are required by the component.
:return: dictionary containing input data definitions (each of type :py:class:`ptp.utils.DataDefinition`).
"""
return {
            self.key_inputs: DataDefinition([-1, 3, 224, 224], [torch.Tensor], "Batch of images [BATCH_SIZE x IMAGE_DEPTH x IMAGE_HEIGHT x IMAGE_WIDTH]"),
}
def output_data_definitions(self):
"""
        Function returns a dictionary with definitions of output data produced by the component.
:return: dictionary containing output data definitions (each of type :py:class:`ptp.utils.DataDefinition`).
"""
if self.return_feature_maps:
return {
self.key_outputs: DataDefinition([-1, self.feature_maps_depth, self.feature_maps_height, self.feature_maps_width], [torch.Tensor], "Batch of feature maps [BATCH_SIZE x FEAT_DEPTH x FEAT_HEIGHT x FEAT_WIDTH]")
}
else:
return {
                self.key_outputs: DataDefinition([-1, self.output_size], [torch.Tensor], "Batch of predictions, each represented as a vector of (unnormalized) class scores [BATCH_SIZE x PREDICTION_SIZE]")
}
def forward(self, data_streams):
"""
Main forward pass of the model.
        :param data_streams: DataStreams({'inputs', ....}), where:

            - inputs: expected stream containing images [BATCH_SIZE x IMAGE_DEPTH x IMAGE_HEIGHT x IMAGE_WIDTH]
            - outputs: added stream containing outputs [BATCH_SIZE x PREDICTION_SIZE] (or feature maps, depending on the operation mode)
:type data_streams: ``ptp.data_types.DataStreams``
"""
# Unpack DataStreams.
img = data_streams[self.key_inputs]
#print("{}: input shape: {}, device: {}\n".format(self.name, img.shape, img.device))
outputs = self.model(img)
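        # Depending on the operation mode, outputs are either feature maps
        # [BATCH_SIZE x FEAT_DEPTH x FEAT_HEIGHT x FEAT_WIDTH] or predictions [BATCH_SIZE x PREDICTION_SIZE].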
        # Add outputs to data streams.
data_streams.publish({self.key_outputs: outputs})