# LiberTEM: src/libertem/io/dataset/memory.py
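"""
In-memory dataset for LiberTEM, used mainly for testing: wraps a NumPy
array so it can be read through the regular tiled I/O machinery without
filesystem backing.
"""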
import time
import logging
from typing import Optional, TYPE_CHECKING
from collections.abc import Sequence

import psutil
import numpy as np
from sparseconverter import (
    BACKENDS, ArrayBackend, for_backend, get_backend, SPARSE_BACKENDS, get_device_class,
    conversion_cost
)

from libertem.common.math import prod, count_nonzero, flat_nonzero
from libertem.common.messageconverter import MessageConverter
from libertem.io.dataset.base import (
    FileSet, BasePartition, DataSet, DataSetMeta, TilingScheme,
    File, MMapBackend, DataSetException
)
from libertem.io.dataset.base.backend_mmap import MMapBackendImpl, MMapFile
from libertem.common import Shape, Slice
from libertem.io.dataset.base import DataTile

if TYPE_CHECKING:
    from libertem.common.executor import JobExecutor

log = logging.getLogger(__name__)


class FakeMMapFile(MMapFile):
    """
    Implements the same interface as MMapFile, but without filesystem backing
    """
    def open(self):
        self._arr = self.desc._data
        self._mmap = self.desc._data
        return self

    def close(self):
        self._arr = None
        self._mmap = None


class MemBackend(MMapBackend):
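    """
    I/O backend that serves data from memory through the mmap code path.
    """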
    def get_impl(self):
        return MemBackendImpl()


class MemBackendImpl(MMapBackendImpl):
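    """
    Backend implementation that reuses the mmap reading logic, with
    :class:`FakeMMapFile` standing in for real memory-mapped files.
    """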
    FILE_CLS = FakeMMapFile

    def _set_readahead_hints(self, roi, fileset):
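        # no readahead hints needed: the data already resides in memory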
        pass

    def _get_tiles_roi(self, tiling_scheme, open_files, read_ranges, roi, sync_offset):
        ds_sig_shape = tiling_scheme.dataset_shape.sig
        sig_dims = tiling_scheme.shape.sig.dims
        slices, ranges, scheme_indices = read_ranges

        fh = open_files[0]
        memmap = fh.array.reshape((fh.desc.num_frames,) + tuple(ds_sig_shape))
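        # build a flat boolean mask over the stored frames: shift the ROI by
        # sync_offset and clip to the valid frame range, so that indexing the
        # reshaped array below selects exactly the ROI frames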
        flat_roi = np.full(roi.reshape((-1,)).shape, False)
        roi_nonzero = flat_nonzero(roi)
        offset_roi = np.clip(roi_nonzero + sync_offset, 0, flat_roi.size)
        flat_roi[offset_roi] = True
        data_w_roi = memmap[flat_roi]

        for idx in range(slices.shape[0]):
            origin, shape = slices[idx]
            scheme_idx = scheme_indices[idx]

            tile_slice = Slice(
                origin=origin,
                shape=Shape(shape, sig_dims=sig_dims)
            )
            if sync_offset >= 0:
                data_slice = tile_slice.get()
            else:
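                # with a negative sync_offset, the first |sync_offset| ROI
                # positions fall before the stored data; shift tile origins
                # down by the number of ROI frames skipped there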
                frames_to_skip = count_nonzero(roi.reshape((-1,))[:abs(sync_offset)])
                data_slice = Slice(
                    origin=(origin[0] - frames_to_skip,) + tuple(origin[-sig_dims:]),
                    shape=Shape(shape, sig_dims=sig_dims)
                )
                data_slice = data_slice.get()
            data = data_w_roi[data_slice]
            yield DataTile(
                data,
                tile_slice=tile_slice,
                scheme_idx=scheme_idx,
            )

    def get_tiles(
        self, decoder, tiling_scheme, fileset, read_ranges, roi, native_dtype, read_dtype,
        sync_offset, corrections, array_backend: ArrayBackend,
    ):
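        # three code paths: without a ROI, tiles are served straight from the
        # array (non-negative sync_offset) or via the generic copying path
        # (negative sync_offset); with a ROI, frames are gathered first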
        if roi is None:
            # support arbitrary tiling in case of no roi
            with self.open_files(fileset) as open_files:
                if sync_offset >= 0:
                    for tile in self._get_tiles_straight(
                        tiling_scheme, open_files, read_ranges, sync_offset
                    ):
                        if tile.dtype != read_dtype or not tile.c_contiguous:
                            data = tile.data.astype(read_dtype)
                        else:
                            data = tile.data
                        self.preprocess(data, tile.tile_slice, corrections)
                        data = for_backend(data, array_backend)
                        yield DataTile(data, tile.tile_slice, tile.scheme_idx)
                else:
                    for tile in self._get_tiles_w_copy(
                        tiling_scheme=tiling_scheme,
                        open_files=open_files,
                        read_ranges=read_ranges,
                        read_dtype=read_dtype,
                        native_dtype=native_dtype,
                        decoder=decoder,
                        corrections=corrections,
                    ):
                        data = for_backend(tile.data, array_backend)
                        yield DataTile(data, tile.tile_slice, tile.scheme_idx)
        else:
            with self.open_files(fileset) as open_files:
                for tile in self._get_tiles_roi(
                    tiling_scheme=tiling_scheme,
                    open_files=open_files,
                    read_ranges=read_ranges,
                    roi=roi,
                    sync_offset=sync_offset,
                ):
                    data = tile.data.astype(read_dtype)
                    self.preprocess(data, tile.tile_slice, corrections)
                    data = for_backend(data, array_backend)
                    yield DataTile(data, tile.tile_slice, tile.scheme_idx)


class MemDatasetParams(MessageConverter):
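    """
    Validates and converts MEMORY dataset parameters sent through the
    web API into keyword arguments for :class:`MemoryDataSet`.
    """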
    SCHEMA = {
        "$schema": "http://json-schema.org/draft-07/schema#",
        "$id": "http://libertem.org/MEMDatasetParams.schema.json",
        "title": "MEMDatasetParams",
        "type": "object",
        "properties": {
            "type": {"const": "MEMORY"},
            "tileshape": {
                "type": "array",
                "items": {"type": "number", "minimum": 1},
            },
            "datashape": {
                "type": "array",
                "items": {"type": "number", "minimum": 1},
            },
            "num_partitions": {"type": "number", "minimum": 1},
            "sig_dims": {"type": "number", "minimum": 1},
            "check_cast": {"type": "boolean"},
            "crop_frames": {"type": "boolean"},
            "tiledelay": {"type": "number"},
            "nav_shape": {
                "type": "array",
                "items": {"type": "number", "minimum": 1},
                "minItems": 2,
                "maxItems": 2
            },
            "sig_shape": {
                "type": "array",
                "items": {"type": "number", "minimum": 1},
                "minItems": 2,
                "maxItems": 2
            },
            "sync_offset": {"type": "number"},
            "array_backend": {"type": "string"},
        },
        "required": ["type", "tileshape", "num_partitions"],
    }

    def convert_to_python(self, raw_data):
        data = {
            k: raw_data[k]
            for k in ["tileshape", "num_partitions", "sig_dims", "check_cast",
                      "crop_frames", "tiledelay", "datashape", "array_backend"]
            if k in raw_data
        }
        if "nav_shape" in raw_data:
            data["nav_shape"] = tuple(raw_data["nav_shape"])
        if "sig_shape" in raw_data:
            data["sig_shape"] = tuple(raw_data["sig_shape"])
        if "sync_offset" in raw_data:
            data["sync_offset"] = raw_data["sync_offset"]
        return data


class MemoryFile(File):
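    """
    A :class:`~libertem.io.dataset.base.File` that wraps an in-memory
    array instead of a file on disk.
    """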
    def __init__(self, data, check_cast=True, *args, **kwargs):
        self._data = data
        self._check_cast = check_cast
        super().__init__(*args, **kwargs)

    @property
    def data(self):
        return self._data


class MemoryDataSet(DataSet):
    '''
    This dataset is constructed from a NumPy array in memory for testing
    purposes. It is not recommended for production use since it performs poorly with a
    distributed executor.

    Examples
    --------

    >>> data = np.zeros((2, 2, 64, 64), dtype=np.float32)
    >>> ds = ctx.load('memory', data=data, sig_dims=2)
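
    Additional constructor arguments, for example ``num_partitions`` or
    ``sync_offset``, can be passed through ``ctx.load`` as well
    (illustrative values):

    >>> ds = ctx.load(
    ...     'memory', data=data, sig_dims=2,
    ...     num_partitions=2, sync_offset=0,
    ... )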
    '''
    def __init__(self, tileshape=None, num_partitions=None, data=None, sig_dims=None,
                 check_cast=True, tiledelay=None, datashape=None, base_shape=None,
                 force_need_decode=False, io_backend=None,
                 nav_shape=None, sig_shape=None, sync_offset=0, array_backends=None):
        super().__init__(io_backend=io_backend)
        if io_backend is not None:
            raise ValueError("MemoryDataSet currently doesn't support alternative I/O backends")
        # For HTTP API testing purposes: allow creating an empty dataset with a given shape
        if data is None:
            if datashape is None:
                raise DataSetException(
                    'MemoryDataSet requires either data (np.ndarray) '
                    'or datashape (tuple | Shape); both arguments are None'
                )
            data = np.zeros(datashape, dtype=np.float32)
        if num_partitions is None:
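            # default to one partition per physical CPU core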
            num_partitions = psutil.cpu_count(logical=False)
        # if tileshape is None:
        #     sig_shape = data.shape[-sig_dims:]
        #     target = 2**20
        #     framesize = prod(sig_shape)
        #     framecount = max(1, min(prod(data.shape[:-sig_dims]), int(target / framesize)))
        #     tileshape = (framecount, ) + sig_shape
        self.data = data
        self._base_shape = base_shape
        self.num_partitions = num_partitions

        if sig_dims is None:
            if sig_shape is not None:
                sig_dims = len(sig_shape)
            elif nav_shape is not None:
                sig_dims = len(data.shape) - len(nav_shape)
            else:
                sig_dims = 2
        else:
            if sig_shape is not None and len(sig_shape) != sig_dims:
                raise ValueError(
                    f"Length of sig_shape {sig_shape} not matching sig_dims {sig_dims}."
                )

        self.sig_dims = sig_dims
        if nav_shape is None:
            nav_shape = data.shape[:-sig_dims]
        else:
            nav_shape = tuple(nav_shape)
        if sig_shape is None:
            sig_shape = data.shape[-sig_dims:]
        else:
            sig_shape = tuple(sig_shape)
        if self.data.size % prod(sig_shape) != 0:
            raise ValueError("Data size is not a multiple of sig shape")
        self._image_count = self.data.size // prod(sig_shape)
        self._nav_shape = nav_shape
        self._sig_shape = sig_shape
        self._sync_offset = sync_offset
        self._array_backends = array_backends
        self._check_cast = check_cast
        self._tiledelay = tiledelay
        self._force_need_decode = force_need_decode
        self._nav_shape_product = int(prod(nav_shape))
        self._shape = Shape(
            nav_shape + sig_shape, sig_dims=self.sig_dims
        )
        if tileshape is None:
            self.tileshape = None
        else:
            assert len(tileshape) == self.sig_dims + 1
            self.tileshape = Shape(tileshape, sig_dims=self.sig_dims)
        self._sync_offset_info = self.get_sync_offset_info()
        self._meta = DataSetMeta(
            shape=self._shape,
            array_backends=self.array_backends,
            raw_dtype=self.data.dtype,
            sync_offset=self._sync_offset,
            image_count=self._image_count,
        )

    def initialize(self, executor):
        return self

    @classmethod
    def get_msg_converter(cls):
        return MemDatasetParams

    @classmethod
    def detect_params(cls, data: np.ndarray, executor: "JobExecutor"):
        try:
            _ = data.shape
            return {
                "parameters": {
                    "data": data,
                },
                "info": {}
            }
        except AttributeError:
            return False

    @property
    def dtype(self):
        return self.data.dtype

    @property
    def shape(self):
        return self._shape

    @property
    def array_backends(self) -> Sequence[ArrayBackend]:
        """
        All backends can be returned on request

        .. versionadded:: 0.11.0
        """
        if self._array_backends is None:
            native = get_backend(self.data)
            is_sparse = native in SPARSE_BACKENDS
            native_device_class = get_device_class(native)
            cost_metric = {}
            # Sort by tuple (cost_override, conversion_cost),
            # meaning preference for native backend, same sparsity
            # and same device class take precedence over measured conversion cost
            for backend in BACKENDS:
                cost_metric[backend] = [5, conversion_cost(native, backend)]
                if backend == native:
                    cost_metric[backend][0] -= 2
                # sparse==sparse or dense==dense
                if (backend in SPARSE_BACKENDS) == is_sparse:
                    cost_metric[backend][0] -= 2
                # Same device class
                if get_device_class(backend) == native_device_class:
                    cost_metric[backend][0] -= 1
            return tuple(sorted(BACKENDS, key=lambda k: cost_metric[k]))
        else:
            return self._array_backends

    def check_valid(self):
        return True

    def get_cache_key(self):
        raise TypeError("memory data set is not cacheable yet")

    def get_num_partitions(self):
        return self.num_partitions

    def get_base_shape(self, roi):
        if self.tileshape is not None:
            return self.tileshape
        if self._base_shape is not None:
            return self._base_shape
        return super().get_base_shape(roi)

    def adjust_tileshape(
        self, tileshape: tuple[int, ...], roi: Optional[np.ndarray],
    ) -> tuple[int, ...]:
        if self.tileshape is not None:
            return tuple(self.tileshape)
        return super().adjust_tileshape(tileshape, roi)

    def need_decode(self, read_dtype, roi, corrections):
        if self._force_need_decode:
            return True
        return super().need_decode(read_dtype, roi, corrections)

    def get_io_backend(self):
        return MemBackend()

    def get_partitions(self):
        fileset = FileSet([
            MemoryFile(
                path=None,
                start_idx=0,
                end_idx=self._image_count,
                native_dtype=self.data.dtype,
                sig_shape=self.shape.sig,
                data=self.data.reshape((-1, *self.shape.sig)),
                check_cast=self._check_cast,
            )
        ])

        for part_slice, start, stop in self.get_slices():
            log.debug(
                "creating partition slice %s start %s stop %s",
                part_slice, start, stop,
            )
            yield MemPartition(
                meta=self._meta,
                partition_slice=part_slice,
                fileset=fileset,
                start_frame=start,
                num_frames=stop - start,
                tiledelay=self._tiledelay,
                tileshape=self.tileshape,
                force_need_decode=self._force_need_decode,
                io_backend=self.get_io_backend(),
                decoder=self.get_decoder(),
            )


class MemPartition(BasePartition):
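    """
    Partition over a :class:`MemoryDataSet`. Supports an artificial delay
    per tile (``tiledelay``) and a forced tile shape, both for testing.
    """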
    def __init__(self, tiledelay, tileshape, force_need_decode=False,
                 *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._tiledelay = tiledelay
        self._tileshape = tileshape
        self._force_tileshape = True
        self._force_need_decode = force_need_decode

    def get_io_backend(self):
        return MemBackend()

    def get_macrotile(self, *args, **kwargs):
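        # temporarily disable the forced tile shape so the superclass can
        # build a single macrotile covering the whole partition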
        self._force_tileshape = False
        mt = super().get_macrotile(*args, **kwargs)
        self._force_tileshape = True
        return mt

    def get_tiles(self, *args, **kwargs):
        if args and isinstance(args[0], TilingScheme):
            tiling_scheme = args[0]
            args = args[1:]
            intent = tiling_scheme.intent
        elif 'tiling_scheme' in kwargs:
            tiling_scheme = kwargs.pop('tiling_scheme')
            intent = tiling_scheme.intent
        else:
            # no tiling scheme was passed in; the forced-tileshape
            # branch below must create one
            intent = None
        # force our own tiling_scheme, if a tileshape is given:
        if self._tileshape is not None and self._force_tileshape:
            tiling_scheme = TilingScheme.make_for_shape(
                tileshape=self._tileshape,
                dataset_shape=self.meta.shape,
                intent=intent,
            )
        tiles = super().get_tiles(tiling_scheme, *args, **kwargs)
        if self._tiledelay:
            log.debug("delayed get_tiles, tiledelay=%.3f" % self._tiledelay)
            for tile in tiles:
                yield tile
                time.sleep(self._tiledelay)
        else:
            yield from tiles