LiberTEM/LiberTEM

View on GitHub
src/libertem/io/dataset/dask.py

Summary

Maintainability
A
2 hrs
Test Coverage
import logging
import itertools
import numpy as np
import dask.array as da

from libertem.common import Shape, Slice
from libertem.io.dataset.base import (
    DataSet, DataSetMeta, BasePartition, File, FileSet, DataSetException
)
from libertem.io.dataset.base.backend_mmap import MMapFile, MMapBackend, MMapBackendImpl
from libertem.common.messageconverter import MessageConverter

log = logging.getLogger(__name__)


class DaskDatasetParams(MessageConverter):
    """
    Message converter for the parameters of a :class:`DaskDataSet`.

    Validates incoming raw parameters against :attr:`SCHEMA` and extracts the
    optional keyword arguments understood by the dataset.
    """
    SCHEMA = {
        "$schema": "http://json-schema.org/draft-07/schema#",
        "$id": "http://libertem.org/DaskDatasetParams.schema.json",
        "title": "DaskDatasetParams",
        "type": "object",
        "properties": {
            "type": {"const": "DASK"},
            "sig_dims": {"type": "number", "minimum": 1},
            "preserve_dimensions": {"type": "boolean"},
            "min_size": {"type": "number", "minimum": 1},
        },
        "required": ["type"],
    }

    def convert_to_python(self, raw_data):
        """
        Return the subset of ``raw_data`` that maps onto DaskDataSet keyword
        arguments; keys that are absent in ``raw_data`` are simply omitted.
        """
        passthrough_keys = ("sig_dims", "preserve_dimensions", "min_size")
        return {key: raw_data[key] for key in passthrough_keys if key in raw_data}


class FakeDaskMMapFile(MMapFile):
    """
    Implementing the same interface as MMapFile, without filesystem backing

    Instead of memory-mapping a file, :meth:`open` computes the Dask array
    chunk referenced by ``self.desc._array`` and exposes the in-memory result
    through the same ``_arr`` / ``_mmap`` attributes.
    """
    def open(self):
        """
        Compute the Dask chunk backing this 'file' and expose it as ``_mmap``.

        Returns
        -------
        self
        """
        # scheduler='threads' ensures that upstream computation for this array
        # chunk happens completely on this worker and not elsewhere
        self._arr = self.desc._array.compute(scheduler='threads')
        # Dask can create Fortran-ordered (or otherwise non-C-contiguous)
        # arrays when .compute is called. The backend calls np.frombuffer on
        # self._mmap, which raises a ValueError for such arrays, so expose a
        # C-contiguous array. np.ascontiguousarray does not copy arrays that
        # are already C-contiguous.
        self._mmap = np.ascontiguousarray(self._arr)
        return self

    def close(self):
        """Drop the references to the computed chunk so it can be freed."""
        del self._arr
        del self._mmap


class DaskBackend(MMapBackend):
    """
    I/O backend used by :class:`DaskDataSet`; it inherits the MMap backend and
    only swaps in a Dask-specific implementation.
    """
    def get_impl(self):
        """Return a new :class:`DaskBackendImpl` instance."""
        impl = DaskBackendImpl()
        return impl


class DaskBackendImpl(MMapBackendImpl):
    """
    MMap backend implementation whose 'files' are computed Dask chunks
    rather than memory-mapped files on disk.
    """
    # File class used by the inherited MMap implementation (presumably
    # instantiated per fileset entry — confirm in MMapBackendImpl)
    FILE_CLS = FakeDaskMMapFile


class DaskDataSet(DataSet):
    """
    .. versionadded:: 0.9.0

    Wraps a Dask.array.array such that it can be processed by LiberTEM.
    Partitions are created to be aligned with the array chunking. When
    the array chunking is not compatible with LiberTEM the wrapper
    merges chunks until compatibility is achieved.

    The best-case scenario is for the original array to be chunked in
    the leftmost navigation dimension. If instead another navigation
    dimension is chunked then the user can set `preserve_dimensions=False`
    to re-order the navigation shape to achieve better chunking for LiberTEM.
    If more than one navigation dimension is chunked, the class will do
    its best to merge chunks without creating partitions which are too large.

    LiberTEM requires that a partition contains only whole signal frames,
    so any signal dimension chunking is immediately merged by this class.

    This wrapper is most useful when the Dask array was created using
    lazy I/O via `dask.delayed`, or via `dask.array` operations.
    The major assumption is that the chunks in the array can each be
    individually evaluated without having to read or compute more data
    than the chunk itself contains. If this is not the case then this class
    could perform very poorly due to read amplification, or even crash the Dask
    workers.

    As the class performs rechunking using a merge-only strategy it will never
    split chunks which were present in the original array. If the array
    is originally very lightly chunked, then the corresponding LiberTEM partitions
    will be very large. In addition, overly-chunked arrays (for example one chunk per
    frame) can incur excessive Dask task graph overheads and should be avoided
    where possible.

    Parameters
    ----------

    dask_array: dask.array.array
        A Dask array

    sig_dims: int
        Number of dimensions in dask_array.shape counting from the right
        to treat as signal dimensions. Must be at least 1 and smaller than
        the number of array dimensions.

    preserve_dimensions: bool, optional
        If False, allow optimization of the dask_array chunking by
        re-ordering the nav_shape to put the most chunked dimensions first.
        This can help when more than one nav dimension is chunked.

    min_size: float, optional
        The minimum partition size in bytes if the array chunking allows
        an order-preserving merge strategy. The default min_size is 128 MiB.

    io_backend: None, optional
        For compatibility, accept an io_backend argument. Only None is
        supported; any other value raises a DataSetException.

    Example
    -------

    >>> import dask.array as da
    >>>
    >>> d_arr = da.ones((4, 4, 64, 64), chunks=(2, -1, -1, -1))
    >>> ds = ctx.load('dask', dask_array=d_arr, sig_dims=2)

    The array is chunked into 2 blocks along the zeroth dimension. As the
    whole array is smaller than the default ``min_size``, these blocks are
    merged and the resulting dataset has a single partition.
    """
    # TODO add mechanism to re-order the dimensions of results automatically
    # if preserve_dimensions is set to False
    def __init__(self, dask_array, *, sig_dims, preserve_dimensions=True,
                 min_size=None, io_backend=None):
        super().__init__(io_backend=io_backend)
        if io_backend is not None:
            raise DataSetException("DaskDataSet currently doesn't support alternative I/O backends")

        self._check_array(dask_array, sig_dims)
        self._array = dask_array
        # _check_array accepts numpy integers too; store a plain int
        self._sig_dims = int(sig_dims)
        self._sig_shape = self._array.shape[-self._sig_dims:]
        self._dtype = self._array.dtype
        self._preserve_dimension = preserve_dimensions
        self._min_size = min_size
        if self._min_size is None:
            # TODO add a method to determine a sensible partition byte-size
            self._min_size = self._default_min_size

    @property
    def array(self):
        """The wrapped Dask array (re-chunked/transposed after ``initialize``)."""
        return self._array

    def get_io_backend(self):
        return DaskBackend()

    def initialize(self, executor):
        """
        Adapt the array chunking for LiberTEM and build the dataset metadata.

        ``executor`` is not used by this implementation.
        """
        self._array = self._adapt_chunking(self._array, self._sig_dims)
        self._nav_shape = self._array.shape[:-self._sig_dims]

        self._nav_shape_product = int(np.prod(self._nav_shape))
        self._image_count = self._nav_shape_product
        shape = Shape(self._nav_shape + self._sig_shape, sig_dims=self._sig_dims)
        self._meta = DataSetMeta(
            shape=shape,
            raw_dtype=np.dtype(self._dtype),
            sync_offset=0,
            image_count=self._nav_shape_product,
        )
        return self

    @property
    def dtype(self):
        return self._meta.raw_dtype

    @property
    def shape(self):
        return self._meta.shape

    @classmethod
    def get_msg_converter(cls):
        return DaskDatasetParams

    @property
    def _default_min_size(self):
        """
        Default minimum chunk size if not supplied at init
        """
        return 128 * (2**20)  # 128 MiB

    def _chunk_slices(self, array):
        """
        Return, for every chunk of ``array``, the tuple of per-axis slices
        selecting that chunk (in C-order of the chunk grid).
        """
        chunks = array.chunks
        boundaries = tuple(tuple(self.chunks_to_slices(chunk_lengths)) for chunk_lengths in chunks)
        return tuple(itertools.product(*boundaries))

    def _adapt_chunking(self, array, sig_dims):
        """
        Rechunk (and optionally transpose) ``array`` so that every chunk can be
        used as a LiberTEM partition.

        Steps:

        1. merge any chunking in the signal dimensions;
        2. if ``preserve_dimensions`` is False, move the most-chunked nav
           dimensions to the front;
        3. merge nav chunks where needed so that each chunk covers a
           contiguous range of the flattened navigation dimension;
        4. merge neighbouring chunks towards ``self._min_size``.

        Returns the adapted Dask array.
        """
        n_dimension = array.ndim
        # Handle chunked signal dimensions by merging just in case
        sig_dim_idxs = [*range(n_dimension)[-sig_dims:]]
        if any(len(array.chunks[c]) > 1 for c in sig_dim_idxs):
            original_n_chunks = [len(c) for c in array.chunks]
            array = array.rechunk({idx: -1 for idx in sig_dim_idxs})
            log.warning('Merging sig dim chunks as LiberTEM does not '
                        'support partitioning along the sig axes. '
                        f'Original n_blocks: {original_n_chunks}. '
                        f'New n_blocks: {[len(c) for c in array.chunks]}.')
        # Warn if there is no nav_dim chunking
        n_nav_chunks = [len(dim_chunking) for dim_chunking in array.chunks[:-sig_dims]]
        if set(n_nav_chunks) == {1}:
            log.warning('Dask array is not chunked in navigation dimensions, '
                        'cannot split into nav-partitions without loading the '
                        'whole dataset on each worker. '
                        f'Array shape: {array.shape}. '
                        f'Chunking: {array.chunks}. '
                        f'array size {array.nbytes / 2**20:.1f} MiB.')
            # If we are here there is nothing else to do.
            return array
        # Orient the nav dimensions so that the zeroth dimension is
        # the most chunked, this obviously changes the dataset nav_shape !
        if not self._preserve_dimension:
            # Stable descending sort: equally-chunked nav dimensions keep their
            # relative order, so an already well-ordered array is not transposed
            nav_sort_order = sorted(range(len(n_nav_chunks)),
                                    key=lambda dim: -n_nav_chunks[dim])
            sort_order = nav_sort_order + sig_dim_idxs
            if sort_order != list(range(n_dimension)):
                original_shape = array.shape
                original_n_chunks = [len(c) for c in array.chunks]
                array = da.transpose(array, axes=sort_order)
                log.warning('Re-ordered nav_dimensions to improve partitioning, '
                            'create the dataset with preserve_dimensions=True '
                            'to suppress this behaviour. '
                            f'Original shape: {original_shape} with '
                            f'n_blocks: {original_n_chunks}. '
                            f'New shape: {array.shape} with '
                            f'n_blocks: {[len(c) for c in array.chunks]}.')
        # Handle chunked nav_dimensions
        # We can allow nav_dimensions to be fully chunked (one chunk per element)
        # up-to-but-not-including the first non-fully chunked dimension. After this point
        # we must merge/rechunk all subsequent nav dimensions to ensure continuity
        # of frame indexes in a flattened nav dimension. This should be removed
        # if we allow non-contiguous flat_idx Partitions
        n_nav_dims = n_dimension - sig_dims
        nav_rechunk_dict = {}
        for dim_idx, dim_chunking in enumerate(array.chunks[:-sig_dims]):
            if set(dim_chunking) == {1}:
                continue
            # First non-fully chunked nav dim: every later nav dim that is
            # split must become a single chunk. Later dims can only add
            # subsets of these, so we can stop here.
            for merge_i in range(dim_idx + 1, n_nav_dims):
                if len(array.chunks[merge_i]) > 1:
                    nav_rechunk_dict[merge_i] = -1
            break
        if nav_rechunk_dict:
            original_n_chunks = [len(c) for c in array.chunks]
            array = array.rechunk(nav_rechunk_dict)
            log.warning('Merging nav dimension chunks according to scheme '
                        f'{nav_rechunk_dict} as we cannot maintain continuity '
                        'of frame indexing in the flattened navigation dimension. '
                        f'Original n_blocks: {original_n_chunks}. '
                        f'New n_blocks: {[len(c) for c in array.chunks]}.')
        # Merge remaining chunks maintaining C-ordering until all chunks reach
        # the target minimum size self._min_size (or no chunking is left)
        new_chunking, new_min, new_max = merge_until_target(array, self._min_size)
        if new_chunking != array.chunks:
            original_n_chunks = [len(c) for c in array.chunks]
            chunksizes = get_chunksizes(array)
            orig_min, orig_max = chunksizes.min(), chunksizes.max()
            array = array.rechunk(new_chunking)
            log.warning('Applying re-chunking to increase minimum partition size. '
                        f'n_blocks: {original_n_chunks} => {[len(c) for c in array.chunks]}. '
                        f'Min chunk size {orig_min / 2**20:.1f} => {new_min / 2**20:.1f} MiB, '
                        f'Max chunk size {orig_max / 2**20:.1f} => {new_max / 2**20:.1f} MiB.')
        return array

    def _check_array(self, array, sig_dims):
        """
        Validate that ``array`` can be wrapped with ``sig_dims`` signal dimensions.

        Returns True if valid.

        Raises
        ------
        DataSetException
            If ``array`` is not a Dask array, ``sig_dims`` is not a positive
            integer, the array has unknown shape or chunk sizes, or no
            navigation dimension would remain.
        """
        if not isinstance(array, da.Array):
            raise DataSetException('Expected a Dask array as input, received '
                                   f'{type(array)}.')
        # bool is a subclass of int but never a meaningful dimension count;
        # sig_dims=0 would make shape[-0:] select the whole shape
        is_integer = (isinstance(sig_dims, (int, np.integer))
                      and not isinstance(sig_dims, bool))
        if not is_integer or sig_dims < 1:
            raise DataSetException('Expected a positive integer sig_dims, '
                                   f'received {sig_dims}.')
        if (any(np.isnan(c).any() for c in array.shape)
                or any(np.isnan(c).any() for c in array.chunks)):
            raise DataSetException('Dask array has an unknown shape or chunk sizes '
                                   'so cannot be interpreted as a LiberTEM partitions. '
                                   'Run array.compute_chunk_sizes() '
                                   'before passing to DaskDataSet, though this '
                                   'may be performance-intensive. Chunking: '
                                   f'{array.chunks}, Shape {array.shape}')
        if sig_dims >= array.ndim:
            raise DataSetException(f'Number of sig_dims {sig_dims} not compatible '
                                   f'with number of array dims {array.ndim}, '
                                   'must be able to create partitions along nav '
                                   'dimensions.')
        return True

    def check_valid(self):
        return self._check_array(self._array, self._sig_dims)

    def get_num_partitions(self):
        # One partition per chunk: the product of the block counts per axis,
        # computed without materialising every chunk index combination
        return int(np.prod([len(c) for c in self._array.chunks]))

    @staticmethod
    def chunks_to_slices(chunk_lengths):
        """Yield consecutive ``slice`` objects covering ``chunk_lengths``."""
        prior = 0
        for c in chunk_lengths:
            newc = c + prior
            yield slice(prior, newc)
            prior = newc

    @staticmethod
    def slices_to_shape(slices):
        return tuple(s.stop - s.start for s in slices)

    @staticmethod
    def slices_to_origin(slices):
        return tuple(s.start for s in slices)

    @staticmethod
    def flatten_nav(slices, nav_shape, sig_dims):
        """
        Because LiberTEM partitions are set up with a flat nav dimension
        we must flatten the Dask array slices. This is ensured to be possible
        by earlier calls to _adapt_chunking but should be removed if ever
        partitions are able to have >1D navigation axes.

        Returns ``(flat_slices, start_frame, end_frame)`` where ``end_frame``
        is exclusive.
        """
        nav_slices = slices[:-sig_dims]
        sig_slices = slices[-sig_dims:]
        start_frame = np.ravel_multi_index([s.start for s in nav_slices], nav_shape)
        end_frame = 1 + np.ravel_multi_index([s.stop - 1 for s in nav_slices], nav_shape)
        nav_slice = slice(start_frame, end_frame)
        return (nav_slice,) + sig_slices, start_frame, end_frame

    def get_slices(self):
        """
        Generates the LiberTEM slices which correspond to the chunks
        in the Dask array backing the dataset

        Generates both the flat_nav slice for creating the LiberTEM partition
        and also the full_slices used to index into the dask array
        """
        chunk_slices = self._chunk_slices(self._array)

        for full_slices in chunk_slices:
            flat_slices, start_frame, end_frame = self.flatten_nav(full_slices, self._nav_shape,
                                                                   self._sig_dims)
            flat_slice = Slice(origin=self.slices_to_origin(flat_slices),
                               shape=Shape(self.slices_to_shape(flat_slices),
                                           sig_dims=self._sig_dims))
            yield full_slices, flat_slice, start_frame, end_frame

    def _get_fileset(self):
        """
        The fileset is set up to have one 'file' per partition
        which corresponds to one 'file' per Dask chunk
        """
        partitions = []
        for full_slices, _, start, stop in self.get_slices():
            partitions.append(DaskFile(
                array_chunk=self._array[full_slices],
                path=None,
                start_idx=start,
                end_idx=stop,
                native_dtype=self._dtype,
                sig_shape=self.shape.sig
            ))
        return DaskFileSet(partitions)

    def get_partitions(self):
        """
        Partitions contain a reference to the whole array and the whole
        fileset, but the part_slice and start_frame/num_frames provided mean
        that the subsequent call to get_read_ranges() means only one 'file'
        is read/.compute(), and this corresponds to the partition *exactly*
        """
        fileset = self._get_fileset()
        for _, part_slice, start, stop in self.get_slices():
            yield DaskPartition(
                self._array,
                meta=self._meta,
                fileset=fileset,
                partition_slice=part_slice,
                start_frame=start,
                num_frames=stop - start,
                io_backend=self.get_io_backend(),
                decoder=self.get_decoder()
            )

    def __repr__(self):
        meta = getattr(self, '_meta', None)
        if meta is None:
            # Not initialized yet: describe the (not yet adapted) wrapped array
            # instead of failing on the missing metadata
            dtype, shape = self._dtype, self._array.shape
        else:
            dtype, shape = meta.raw_dtype, meta.shape
        return (f"<DaskDataSet of {dtype} shape={shape}, "
                f"n_blocks={[len(c) for c in self._array.chunks]}>")


class DaskFile(File):
    """
    One entry of a :class:`DaskFileSet`: it carries a slice of the Dask array
    (one chunk) instead of relying on a path on disk.
    """
    def __init__(self, *args, array_chunk=None, **kwargs):
        """
        Store ``array_chunk`` (the Dask array already sliced down to the
        single chunk of one LiberTEM partition), then forward all remaining
        arguments to :class:`File`.
        """
        # assigned before the base initializer runs, as in the original order
        self._array = array_chunk
        super().__init__(*args, **kwargs)


class DaskFileSet(FileSet):
    """FileSet holding one :class:`DaskFile` per Dask chunk; adds no behaviour."""


class DaskPartition(BasePartition):
    """
    Partition of a :class:`DaskDataSet`; keeps a reference to the full Dask
    array and forwards everything else to :class:`BasePartition`.
    """
    def __init__(self, dask_array, *args, **kwargs):
        # keep the reference before the base class initializes itself
        self._array = dask_array
        super().__init__(*args, **kwargs)


def array_mult(*arrays, dtype=np.float64):
    """
    Outer product of all ``arrays``, computed in ``dtype``.

    The result has shape ``arrays[0].shape + arrays[1].shape + ...``; a single
    argument is returned as a copy converted to ``dtype``.

    Parameters
    ----------
    *arrays : array_like
        One or more 1D sequences (e.g. the chunk lengths of each axis).
    dtype : numpy dtype, optional
        Type in which the products are computed (default float64).

    Raises
    ------
    RuntimeError
        If no arrays are given.
    """
    if not arrays:
        raise RuntimeError('Unexpected number of arrays')
    # Convert every operand *before* multiplying. Multiplying integer inputs
    # and casting afterwards can silently overflow (int64 chunk sizes), and
    # `dtype` must apply for any number of operands. Folding from the right
    # keeps the original multiplication order a0 * (a1 * (a2 * ...)).
    result = np.array(arrays[-1], dtype=dtype)
    for arr in reversed(arrays[:-1]):
        result = np.multiply.outer(np.asarray(arr, dtype=dtype), result)
    return result


def get_last_chunked_dim(chunking):
    """
    Index of the right-most axis of ``chunking`` that has more than one
    chunk, or -1 if no axis is split.
    """
    for dim in range(len(chunking) - 1, -1, -1):
        if len(chunking[dim]) > 1:
            return dim
    return -1


def get_chunksizes(array, chunking=None):
    """
    Byte size of every chunk of ``array`` under ``chunking``.

    ``chunking`` defaults to ``array.chunks``. Returns a float array with one
    axis per dimension up to and including the last chunked one. If nothing
    is chunked, returns a 0-d array holding ``array.nbytes``.
    """
    chunking = array.chunks if chunking is None else chunking
    last_chunked = get_last_chunked_dim(chunking)
    if last_chunked < 0:
        # a single chunk: its size is the whole array
        return np.asarray(array.nbytes)
    # bytes spanned by the trailing, un-split dimensions
    trailing_bytes = (np.prod(array.shape[last_chunked + 1:], dtype=np.float64)
                      * array.dtype.itemsize)
    return array_mult(*chunking[:last_chunked + 1]) * trailing_bytes


def modify_chunking(chunking, dim, merge_idxs):
    """
    Return a copy of ``chunking`` where, along axis ``dim``, the chunks from
    the smallest to the second-smallest index in ``merge_idxs`` (inclusive)
    are merged into a single chunk.
    """
    ordered = sorted(merge_idxs)
    lo, hi = ordered[0], ordered[1]
    axis_chunks = tuple(chunking[dim])
    merged_axis = axis_chunks[:lo] + (sum(axis_chunks[lo:hi + 1]),) + axis_chunks[hi + 1:]
    return chunking[:dim] + (merged_axis,) + chunking[dim + 1:]


def findall(sequence, val):
    """Return every index at which ``sequence`` contains a value equal to ``val``."""
    return [position for position, item in enumerate(sequence) if item == val]


def neighbour_idxs(sequence, idx):
    """
    Return ``(left, right)``, the indices of the neighbours of ``sequence[idx]``.

    A missing neighbour at either end of the sequence is reported as ``None``;
    a single-element sequence therefore yields ``(None, None)``.

    Raises
    ------
    IndexError
        If ``idx`` is not a valid non-negative index into ``sequence``.
    """
    max_idx = len(sequence) - 1
    if not 0 <= idx <= max_idx:
        # previously a bare `raise` outside of an except block, which only
        # produced "RuntimeError: No active exception to reraise"
        raise IndexError(
            f'Index {idx} out of range for sequence of length {len(sequence)}'
        )
    left = idx - 1 if idx > 0 else None
    right = idx + 1 if idx < max_idx else None
    return (left, right)


def min_neighbour(sequence, idx):
    """
    Index of the neighbour of ``sequence[idx]`` with the smaller value.

    At either end of the sequence the only existing neighbour is returned.
    On a tie the left neighbour wins.
    """
    left, right = neighbour_idxs(sequence, idx)
    if left is not None and right is not None:
        return right if sequence[right] < sequence[left] else left
    return right if left is None else left


def min_with_min_neighbor(sequence):
    """
    Pick the pair of adjacent indices to merge in ``sequence``.

    Candidates are each occurrence of the minimum value paired with its
    smaller neighbour; the pair with the smallest summed value is returned,
    with ties resolved in favour of the right-most candidate.
    """
    smallest = min(sequence)
    candidates = [(pos, min_neighbour(sequence, pos))
                  for pos in findall(sequence, smallest)]
    pair_sums = [sum(get_values(sequence, pair)) for pair in candidates]
    best_positions = findall(pair_sums, min(pair_sums))
    # breaking ties from right
    return candidates[best_positions[-1]]


def get_values(sequence, idxs):
    """Return ``[sequence[i] for i in idxs]`` as a list."""
    return list(map(sequence.__getitem__, idxs))


def merge_until_target(array, target, min_chunks=0):
    """
    Compute a chunking for ``array`` by repeatedly merging neighbouring chunks
    along the last chunked dimension until every chunk reaches ``target`` bytes.

    If the whole array is smaller than ``target`` it is treated as a single
    chunk. Merging also stops once there are no more than ``min_chunks``
    chunks, or when no chunked dimension is left.

    Parameters
    ----------
    array : dask.array.Array
        Array whose chunking is evaluated (it is not modified).
    target : int or float
        Target minimum chunk size in bytes.
    min_chunks : int, optional
        Stop merging as soon as there are no more than this many chunks.

    Returns
    -------
    chunking : tuple of tuple of int
        The merged chunking, suitable for ``array.rechunk``.
    min_size, max_size : float
        Smallest and largest chunk size in bytes under ``chunking``.
    """
    chunking = array.chunks
    if array.nbytes < target:
        # A really small dataset, better to treat as one partition
        chunking = tuple((s,) for s in array.shape)
    # Sizes must describe `chunking` (not the original array.chunks) so the
    # returned min/max are also correct for the small-dataset shortcut above
    chunksizes = _clipped_chunksizes(array, chunking)
    while chunksizes.size > min_chunks and chunksizes.min() < target:
        last_chunked_dim = get_last_chunked_dim(chunking)
        if last_chunked_dim < 0:
            # No chunking, by definition complete
            break
        to_merge = min_with_min_neighbor(chunking[last_chunked_dim])
        chunking = modify_chunking(chunking, last_chunked_dim, to_merge)
        chunksizes = _clipped_chunksizes(array, chunking)
    return chunking, chunksizes.min(), chunksizes.max()


def _clipped_chunksizes(array, chunking):
    """Chunk sizes in bytes for ``chunking``, clipping overflowed (negative) values to 0."""
    chunksizes = get_chunksizes(array, chunking=chunking)
    if (chunksizes < 0).any():
        log.warning('Overflow in chunksize calculation, will be clipped!')
    return np.clip(chunksizes, 0., np.inf)