IBM/pytorchpipe

View on GitHub
ptp/utils/data_streams_parallel.py

Summary

Maintainability
D
2 days
Test Coverage
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#
# Copyright (C) tkornuta, IBM Corporation 2019
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import torch

from torch.nn.parallel._functions import Scatter, Gather
from torch.nn.parallel.replicate import replicate
from torch.nn.parallel.parallel_apply import parallel_apply


from ptp.data_types.data_streams import DataStreams


def data_streams_scatter(inputs, target_gpus, dim=0):
    r"""
    Recursively distributes ``inputs`` over the devices in ``target_gpus``.

    Tensors are handed to ``Scatter.apply`` (split along ``dim``). Non-empty
    tuples, lists, dicts and ``DataStreams`` are traversed element by element
    and rebuilt once per target. Every other object (including empty
    containers) is not copied: the same reference is returned for each target.

    :param inputs: Object (possibly nested) to be distributed.
    :param target_gpus: Sequence of target devices.
    :param dim: Dimension along which tensors are split (default: 0).
    :return: Scattered structure produced by the recursive mapping.
    """
    def scatter_map(item):
        if isinstance(item, torch.Tensor):
            return Scatter.apply(target_gpus, None, dim, item)

        # Sequences: scatter every element, then transpose so that the outer
        # level enumerates targets instead of elements.
        if isinstance(item, (tuple, list)) and len(item) > 0:
            per_target = zip(*map(scatter_map, item))
            if isinstance(item, tuple):
                return list(per_target)
            return list(map(list, per_target))

        # Mappings (plain dicts and DataStreams): scatter the (key, value)
        # pairs and rebuild one mapping of the original type per target.
        if isinstance(item, (dict, DataStreams)) and len(item) > 0:
            per_target = zip(*map(scatter_map, item.items()))
            return list(map(type(item), per_target))

        # Anything else is returned "unscattered": one reference per target.
        # NOTE: the original author suspected this fallback to be the cause
        # of an issue observed with SentenceEmbeddings.
        # TODO: further investigate.
        return [item for _ in target_gpus]

    # scatter_map is recursive, so its closure cell refers to the function
    # itself, which in turn refers to the cell. Resetting the name to None
    # once the work is done clears the cell and breaks that reference cycle.
    try:
        scattered = scatter_map(inputs)
    finally:
        scatter_map = None
    return scattered


def data_streams_scatter_kwargs(inputs, kwargs, target_gpus, dim=0):
    r"""
    Scatters positional inputs and a kwargs dictionary together.

    Each of ``inputs`` and ``kwargs`` is scattered with
    ``data_streams_scatter`` when it is truthy; otherwise it contributes an
    empty list. The shorter result is then padded (with empty tuples for
    inputs, empty dicts for kwargs) so that both have the same length.

    :param inputs: Positional arguments to scatter (may be empty/None).
    :param kwargs: Keyword arguments to scatter (may be empty/None).
    :param target_gpus: Sequence of target devices.
    :param dim: Dimension along which tensors are split (default: 0).
    :return: Pair ``(inputs, kwargs)``, both tuples of equal length.
    """
    scattered_args = data_streams_scatter(inputs, target_gpus, dim) if inputs else []
    scattered_kwargs = data_streams_scatter(kwargs, target_gpus, dim) if kwargs else []

    # Positive: args are too short; negative: kwargs are too short.
    missing = len(scattered_kwargs) - len(scattered_args)
    if missing > 0:
        scattered_args.extend(() for _ in range(missing))
    elif missing < 0:
        scattered_kwargs.extend({} for _ in range(-missing))

    return tuple(scattered_args), tuple(scattered_kwargs)


def data_streams_gather(outputs, target_device, dim=0):
    r"""
    Recursively merges per-device results into a single structure.

    The type of the first element of ``outputs`` decides how the whole group
    is merged: tensors go through ``Gather.apply`` onto ``target_device``
    (-1 means the CPU), ``None`` stays ``None``, mappings (``DataStreams``
    and dicts) are merged key by key, and any other container is merged
    position by position and rebuilt with its original type.

    :param outputs: Sequence with one result per device.
    :param target_device: Device on which tensors are gathered.
    :param dim: Dimension along which tensors are gathered (default: 0).
    :return: Merged structure.
    :raises ValueError: When mappings differ in their number of keys.
    """
    def gather_map(pieces):
        head = pieces[0]
        if isinstance(head, torch.Tensor):
            return Gather.apply(target_device, dim, *pieces)
        if head is None:
            return None

        if isinstance(head, (DataStreams, dict)):
            if any(len(piece) != len(head) for piece in pieces):
                raise ValueError('All dicts must have the same number of keys')
            # Rebuild a mapping of the same type from gathered (key, value) pairs.
            return type(head)(
                (key, gather_map([piece[key] for piece in pieces]))
                for key in head)

        # Generic container: gather matching positions across all pieces.
        return type(head)(map(gather_map, zip(*pieces)))

    # A recursive nested function creates a reference cycle through its own
    # closure cell; resetting the name to None afterwards clears it.
    try:
        gathered = gather_map(outputs)
    finally:
        gather_map = None
    return gathered


class DataStreamsParallel(torch.nn.DataParallel):
    """
    Modified DataParallel wrapper enabling operation on DataStreams objects.

    .. warning::
        Compatible with PyTorch v1.0.1 !!

    """
    def __init__(self, module, device_ids=None, output_device=None, dim=0):
        """
        Initializes the wrapper by forwarding all arguments to ``torch.nn.DataParallel``.

        :param module: Model to be wrapped.
        :param device_ids: Devices to be used (passed through to the base class).
        :param output_device: Device used when gathering (passed through to the base class).
        :param dim: Dimension along which tensors are scattered/gathered (default: 0).
        """
        super().__init__(module, device_ids, output_device, dim)

    def forward(self, *inputs, **kwargs):
        """
        Performs "parallelized forward" pass by scattering batch into several batches, distributing models on different GPUs, performing parallel pass and gathering results into a single (returned) DataStreams.

        When there are no devices, or only a single device, nothing is scattered and the wrapped module is called directly with the original arguments (its result is returned as is).

        .. warning::
            As the "external" operations are changing inputs to tuple of DataStreams objects, extension of main DataStreams must be done "outside" of this method.

        :param inputs: Positional inputs; in the multi-device path each one must offer ``items()`` (e.g. ``DataStreams``).
        :param kwargs: Keyword arguments passed to the wrapped module.
        :return: Result of the wrapped module (zero/one device) or the first gathered DataStreams (many devices).
        """
        # Zero or one device: nothing to scatter, so call the module directly.
        # The single-device branch used to index the *unscattered* arguments
        # (``*inputs[0], **kwargs[0]``), which always failed with KeyError on
        # ``kwargs[0]`` because ``kwargs`` is a plain dict with string keys.
        # NOTE(review): inputs are not moved to the device here - this assumes
        # that the caller has already placed them there; confirm.
        if not self.device_ids or len(self.device_ids) == 1:
            return self.module(*inputs, **kwargs)

        # Preprocessing: keep only the inputs that are important to the wrapped model (optimization).
        # The definitions are looked up once instead of once per key.
        required_keys = self.module.input_data_definitions().keys()
        inputs_tuple = tuple(
            DataStreams({key: value for key, value in item.items() if key in required_keys})
            for item in inputs)

        # Scatter inputs into several tuples.
        inputs_tuple, kwargs = self.scatter(inputs_tuple, kwargs, self.device_ids)

        # Create replicas of the module on all devices.
        replicas = self.replicate(self.module, self.device_ids[:len(inputs_tuple)])

        # Pass scattered inputs through those replicas.
        # The returned value is intentionally not used: the scattered inputs
        # themselves are gathered below (presumably the replicas store their
        # outputs in them - verify against the wrapped model).
        self.parallel_apply(replicas, inputs_tuple, kwargs)

        # Gather tuple. This cannot be done "in place"!
        gathered_tuple = self.gather(inputs_tuple, self.output_device)

        # Return 0-th element of the gathered tuple, i.e. a single DataStreams.
        return gathered_tuple[0]

    def replicate(self, module, device_ids):
        """
        Replicates the module on the given devices (delegates to ``torch.nn.parallel.replicate``).
        """
        return replicate(module, device_ids)

    def scatter(self, inputs, kwargs, device_ids):
        """
        Scatters inputs and kwargs over the given devices along ``self.dim``, using the DataStreams-aware scatter.
        """
        return data_streams_scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)

    def parallel_apply(self, replicas, inputs, kwargs):
        """
        Applies the replicas to the scattered inputs, using as many devices as there are replicas.
        """
        return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])

    def gather(self, outputs, output_device):
        """
        Gathers outputs on the output device along ``self.dim``, using the DataStreams-aware gather.
        """
        return data_streams_gather(outputs, output_device, dim=self.dim)

    def add_statistics(self, stat_col):
        """
        Adds statistics for the wrapped model.

        :param stat_col: ``StatisticsCollector``.
        """
        self.module.add_statistics(stat_col)

    def collect_statistics(self, stat_col, data_streams):
        """
        Collects statistics for the wrapped model.

        :param stat_col: :py:class:`ptp.utils.StatisticsCollector`.

        :param data_streams: ``DataStreams`` containing inputs, targets etc.
        :type data_streams: :py:class:`ptp.data_types.DataStreams`
        """
        self.module.collect_statistics(stat_col, data_streams)

    def add_aggregators(self, stat_agg):
        """
        Adds aggregators for the wrapped model.

        :param stat_agg: ``StatisticsAggregator``.
        """
        self.module.add_aggregators(stat_agg)

    def aggregate_statistics(self, stat_col, stat_agg):
        """
        Aggregates statistics for the wrapped model.

        :param stat_col: ``StatisticsCollector``

        :param stat_agg: ``StatisticsAggregator``
        """
        self.module.aggregate_statistics(stat_col, stat_agg)