LiberTEM/LiberTEM

View on GitHub
src/libertem/io/dataset/raw_csr.py

Summary

Maintainability
B
6 hrs
Test Coverage
import typing
import os

import scipy.sparse
import numpy as np
import numba
import tomli
from sparseconverter import SCIPY_CSR, ArrayBackend, for_backend, NUMPY

from libertem.common import Slice, Shape
from libertem.common.math import prod, count_nonzero
from libertem.io.corrections.corrset import CorrectionSet
from libertem.io.dataset.base import (
    DataTile, DataSet
)
from libertem.io.dataset.base.meta import DataSetMeta
from libertem.io.dataset.base.partition import Partition
from libertem.io.dataset.base.tiling_scheme import TilingScheme
from libertem.common.messageconverter import MessageConverter
from libertem.common.numba import numba_dtypes

if typing.TYPE_CHECKING:
    from libertem.io.dataset.base.backend import IOBackend
    from libertem.common.executor import JobExecutor
    import numpy.typing as nt


def load_toml(path: str):
    """
    Parse the TOML file at :code:`path` and return the parsed content.

    The file is opened in binary mode, as required by :code:`tomli.load`.
    """
    with open(path, "rb") as toml_file:
        parsed = tomli.load(toml_file)
    return parsed


class RawCSRDatasetParams(MessageConverter):
    """
    Message converter for the parameters of a raw CSR dataset.

    :code:`SCHEMA` describes the accepted message; only ``type`` and ``path``
    are required, the shapes and ``sync_offset`` are optional.
    """
    SCHEMA = {
        "$schema": "http://json-schema.org/draft-07/schema#",
        "$id": "http://libertem.org/RawCSRDatasetParams.schema.json",
        "title": "RawCSRDatasetParams",
        "type": "object",
        "properties": {
            "type": {"const": "RAW_CSR"},
            "path": {"type": "string"},
            "nav_shape": {
                "type": "array",
                "items": {"type": "number", "minimum": 1},
                "minItems": 2,
                "maxItems": 2
            },
            "sig_shape": {
                "type": "array",
                "items": {"type": "number", "minimum": 1},
                "minItems": 2,
                "maxItems": 2
            },
            "sync_offset": {"type": "number"},
        },
        "required": ["type", "path"]
    }

    def convert_to_python(self, raw_data):
        """
        Build the keyword arguments for the dataset from a raw message.

        ``path`` is always copied; ``nav_shape`` and ``sig_shape`` are
        converted to tuples when present, ``sync_offset`` is passed through
        unchanged when present.
        """
        converted = {"path": raw_data["path"]}
        # Shapes arrive as lists and are turned into tuples
        for shape_key in ("nav_shape", "sig_shape"):
            if shape_key in raw_data:
                converted[shape_key] = tuple(raw_data[shape_key])
        if "sync_offset" in raw_data:
            converted["sync_offset"] = raw_data["sync_offset"]
        return converted


class CSRDescriptor(typing.NamedTuple):
    """
    Locations and dtypes of the three files that make up a raw CSR dataset.

    Instances are created by :func:`get_descriptor`, which joins the file
    names from the TOML file with the directory of the TOML file and passes
    the dtype entries through as found in the TOML file.
    """
    # File with the index pointers and the dtype used to memory-map it
    indptr_file: str
    indptr_dtype: np.dtype
    # File with the indices (coordinates) and the dtype used to memory-map it
    indices_file: str
    indices_dtype: np.dtype
    # File with the values and the dtype used to memory-map it
    data_file: str
    data_dtype: np.dtype


class CSRTriple(typing.NamedTuple):
    """
    The three arrays of a CSR matrix, as returned by :func:`get_triple`.
    """
    # Index pointers: entries i and i + 1 delimit row i in indices and data
    indptr: np.ndarray
    # Indices (coordinates) of the stored values
    indices: np.ndarray
    # The stored values
    data: np.ndarray


class RawCSRDataSet(DataSet):
    """
    Read sparse data in compressed sparse row (CSR) format from a triple of files
    that contain the index pointers, the coordinates and the values. See
    `Wikipedia article on the CSR format
    <https://en.wikipedia.org/wiki/Sparse_matrix#Compressed_sparse_row_(CSR,_CRS_or_Yale_format)>`_
    for more information on the format.

    The necessary parameters are specified in a TOML file like this:

    .. code-block::

        [params]

        filetype = "raw_csr"
        nav_shape = [512, 512]
        sig_shape = [516, 516]

        [raw_csr]

        indptr_file = "rowind.dat"
        indptr_dtype = "<i4"

        indices_file = "coords.dat"
        indices_dtype = "<i4"

        data_file = "values.dat"
        data_dtype = "<i4"

    Both the navigation and signal axis are flattened in the file, so that existing
    CSR libraries like scipy.sparse can be used directly by memory-mapping or
    reading the file contents.

    Parameters
    ----------

    path : str
        Path to the TOML file with file names and other parameters for the sparse dataset.
    nav_shape : Tuple[int, int], optional
        A nav_shape to apply to the dataset overriding the shape
        value read from the TOML file, by default None. This can
        be used to read a subset of the data, or reshape the
        contained data.
    sig_shape : Tuple[int, int], optional
        A sig_shape to apply to the dataset overriding the shape
        value read from the TOML file, by default None.
    sync_offset : int, optional, by default 0
        If positive, number of frames to skip from start
        If negative, number of blank frames to insert at start
    io_backend : IOBackend, optional
        The I/O backend to use, see :ref:`io backends`, by default None.

    Examples
    --------

    >>> ds = ctx.load("raw_csr", path='./path_to.toml')  # doctest: +SKIP
    """

    def __init__(
        self,
        path: str,
        nav_shape: typing.Optional[tuple[int, ...]] = None,
        sig_shape: typing.Optional[tuple[int, ...]] = None,
        sync_offset: int = 0,
        io_backend: typing.Optional["IOBackend"] = None
    ):
        # No I/O backend is supported, see get_supported_io_backends()
        if io_backend is not None:
            raise NotImplementedError()
        super().__init__(io_backend=io_backend)
        self._path = path
        if nav_shape is not None:
            nav_shape = tuple(nav_shape)
        self._nav_shape = nav_shape
        if sig_shape is not None:
            sig_shape = tuple(sig_shape)
        self._sig_shape = sig_shape
        self._sync_offset = sync_offset
        # Both are populated in initialize()
        self._conf = None
        self._descriptor = None

    def initialize(self, executor: "JobExecutor") -> "DataSet":
        """
        Load the TOML file through the executor, resolve the shapes, check
        the files and build the dataset metadata.

        Shapes passed to the constructor take precedence over the TOML file.

        Raises
        ------
        ValueError
            If the filetype in the TOML file is not ``raw_csr`` or if a
            user-supplied sig_shape has a different size than the one
            from the TOML file.
        """
        self._conf = conf = executor.run_function(load_toml, self._path)
        assert conf is not None
        if conf['params']['filetype'].lower() != 'raw_csr':
            raise ValueError(f"Filetype is not CSR, found {conf['params']['filetype']}")
        nav_shape = tuple(conf['params']['nav_shape'])
        sig_shape = tuple(conf['params']['sig_shape'])
        if self._nav_shape is None:
            self._nav_shape = nav_shape
        if self._sig_shape is None:
            self._sig_shape = sig_shape
        else:
            # The signal axis is flattened in the file, so a user-supplied
            # sig_shape can only reshape it, not change its size
            if prod(self._sig_shape) != prod(sig_shape):
                raise ValueError(f"Sig size mismatch between {self._sig_shape} and {sig_shape}.")

        shape = Shape(self._nav_shape + self._sig_shape, sig_dims=len(self._sig_shape))
        self._descriptor = descriptor = executor.run_function(get_descriptor, self._path)
        executor.run_function(
            check,
            descriptor=descriptor,
            nav_shape=self._nav_shape,
            sig_shape=self._sig_shape
        )
        image_count = executor.run_function(get_nav_size, descriptor=descriptor)
        self._image_count = image_count
        self._nav_shape_product = int(prod(self._nav_shape))
        self._sync_offset_info = self.get_sync_offset_info()
        self._meta = DataSetMeta(
            shape=shape,
            array_backends=[SCIPY_CSR],
            image_count=image_count,
            raw_dtype=descriptor.data_dtype,
            dtype=None,
            metadata=None,
            sync_offset=self._sync_offset,
        )
        return self

    @property
    def dtype(self) -> "nt.DTypeLike":
        """The raw dtype of the values, as stored in the dataset metadata."""
        assert self._meta is not None
        return self._meta.raw_dtype

    @property
    def shape(self) -> Shape:
        """The full shape (nav + sig) of the dataset."""
        assert self._meta is not None
        return self._meta.shape

    @property
    def array_backends(self) -> typing.Sequence[ArrayBackend]:
        """The array backends of this dataset, as stored in the metadata."""
        assert self._meta is not None
        return self._meta.array_backends

    def get_base_shape(self, roi):
        """Return a base shape of one complete frame; :code:`roi` is not used."""
        return (1, ) + tuple(self.shape.sig)

    def get_max_io_size(self):
        # High value since referring to dense for the time being
        # Compromise between memory use during densification and
        # performance with native sparse
        return int(1024*1024*20)

    def check_valid(self) -> bool:
        return True  # TODO

    @staticmethod
    def _get_filesize(path):
        """Return the size in bytes of the file at :code:`path`."""
        return os.stat(path).st_size

    def supports_correction(self):
        """Corrections are not supported for this dataset."""
        return False

    @classmethod
    def detect_params(cls, path: str, executor: "JobExecutor"):
        """
        Try to detect whether :code:`path` is the TOML file of a raw CSR dataset.

        Only files with a supported extension or a size below 1 MB are
        parsed at all.

        Returns
        -------
        dict or False
            A dictionary with the detected ``parameters`` and ``info``, or
            :code:`False` if the file is not a usable raw CSR description.
        """
        try:
            _, extension = os.path.splitext(path)
            has_extension = extension.lstrip('.') in cls.get_supported_extensions()
            under_size_lim = executor.run_function(cls._get_filesize, path) < 2**20  # 1 MB
            if not (has_extension or under_size_lim):
                return False
            conf = executor.run_function(load_toml, path)
            if "params" not in conf:
                return False

            if "filetype" not in conf["params"]:
                return False
            if conf["params"]["filetype"].lower() != "raw_csr":
                return False
            descriptor = executor.run_function(get_descriptor, path)
            image_count = executor.run_function(get_nav_size, descriptor=descriptor)
            return {
                "parameters": {
                    'path': path,
                    "nav_shape": conf["params"]["nav_shape"],
                    "sig_shape": conf["params"]["sig_shape"],
                    "sync_offset": 0,
                },
                "info": {
                    "image_count": image_count,
                }
            }
        except (
            TypeError, UnicodeDecodeError, tomli.TOMLDecodeError, OSError,
            # Incomplete or malformed descriptions (missing shapes or file
            # section, non-string filetype, index pointer file that cannot be
            # mapped) mean "not detected" and should not abort detection
            KeyError, AttributeError, ValueError,
        ):
            return False

    @classmethod
    def get_msg_converter(cls) -> type["MessageConverter"]:
        """Return the message converter class for this dataset's parameters."""
        return RawCSRDatasetParams

    def get_diagnostics(self):
        """Return the dtypes of the three files as a list of name/value entries."""
        return [
            {"name": "data dtype", "value": str(self._descriptor.data_dtype)},
            {"name": "indptr dtype", "value": str(self._descriptor.indptr_dtype)},
            {"name": "indices dtype", "value": str(self._descriptor.indices_dtype)},
        ]  # TODO: nonzero elements?

    @classmethod
    def get_supported_extensions(cls) -> set[str]:
        """Return the file extensions handled by this dataset (without dot)."""
        return {"toml"}

    def get_cache_key(self) -> str:
        raise NotImplementedError()  # TODO

    @classmethod
    def get_supported_io_backends(cls) -> list[str]:
        return []  # FIXME: we may want to read using a backend in the future

    def adjust_tileshape(
        self,
        tileshape: tuple[int, ...],
        roi: typing.Optional[np.ndarray]
    ) -> tuple[int, ...]:
        """
        Keep the requested tile depth, but always use the full signal shape.

        The partition rejects tiling schemes that slice in the signal
        dimensions, see :meth:`RawCSRPartition.validate_tiling_scheme`.
        """
        return (tileshape[0],) + tuple(self._sig_shape)

    def need_decode(
        self,
        read_dtype: "nt.DTypeLike",
        roi: typing.Optional[np.ndarray],
        corrections: typing.Optional[CorrectionSet]
    ) -> bool:
        """Delegate to the base class implementation."""
        return super().need_decode(read_dtype, roi, corrections)

    def get_partitions(self) -> typing.Generator[Partition, None, None]:
        """Yield one :class:`RawCSRPartition` per slice from :code:`get_slices()`."""
        assert self._meta is not None
        for part_slice, start, stop in self.get_slices():
            yield RawCSRPartition(
                descriptor=self._descriptor,
                meta=self._meta,
                partition_slice=part_slice,
                start_frame=start,
                num_frames=stop - start,
                io_backend=None,
                decoder=None,
            )


class RawCSRPartition(Partition):
    """
    Partition of a raw CSR dataset that yields tiles as scipy.sparse CSR matrices.
    """

    def __init__(
        self,
        descriptor: CSRDescriptor,
        start_frame: int,
        num_frames: int,
        *args,
        **kwargs
    ):
        # Own state is assigned before the base class constructor runs
        self._descriptor = descriptor
        self._start_frame = start_frame
        self._num_frames = num_frames
        self._corrections = CorrectionSet()
        self._worker_context = None
        super().__init__(*args, **kwargs)

    def set_corrections(self, corrections: typing.Optional[CorrectionSet]):
        """Accept only missing or empty corrections, raise otherwise."""
        if corrections is None:
            return
        if not corrections.have_corrections():
            return
        raise NotImplementedError("corrections not implemented for raw CSR data set")

    def validate_tiling_scheme(self, tiling_scheme: TilingScheme):
        """Raise :code:`ValueError` unless the tiling scheme has exactly one entry."""
        num_entries = len(tiling_scheme)
        if num_entries != 1:
            raise ValueError("Cannot slice CSR data in sig dimensions")

    def get_locations(self):
        # Allow using any worker by default
        return None

    def get_tiles(
        self,
        tiling_scheme: TilingScheme,
        dest_dtype="float32",
        roi=None,
        array_backend: typing.Optional[ArrayBackend] = None
    ):
        """
        Yield the tiles of this partition, optionally restricted by :code:`roi`.

        Only the SCIPY_CSR array backend (or :code:`None`) is accepted.
        """
        assert array_backend is None or array_backend == SCIPY_CSR
        tiling_scheme = tiling_scheme.adjust_for_partition(self)
        self.validate_tiling_scheme(tiling_scheme)
        triple = get_triple(self._descriptor)
        corrections = self._corrections
        if corrections is not None and corrections.have_corrections():
            raise NotImplementedError(
                "corrections are not yet supported for raw CSR"
            )
        sync_offset = self.meta.sync_offset
        if roi is not None:
            tiles = read_tiles_with_roi(
                triple, self.slice, sync_offset, tiling_scheme, roi, dest_dtype
            )
        else:
            tiles = read_tiles_straight(
                triple, self.slice, sync_offset, tiling_scheme, dest_dtype
            )
        yield from tiles


def sliced_indptr(triple: CSRTriple, partition_slice: Slice, sync_offset: int):
    """
    Cut the index pointers that belong to a partition out of the full array.

    Returns a tuple :code:`(skip, indptr)`. :code:`skip` is zero or negative;
    it is negative when the partition origin shifted by :code:`sync_offset`
    lies before the first frame. :code:`indptr` is the matching slice of
    :code:`triple.indptr`, including one extra entry for the end of the
    last frame.
    """
    nav = partition_slice.shape.nav
    assert len(nav) == 1
    # Position of the first partition frame in the file, possibly negative
    first_frame = partition_slice.origin[0] + sync_offset
    skip = first_frame if first_frame < 0 else 0
    lower = first_frame if first_frame > 0 else 0
    upper = max(0, first_frame + nav[0] + 1)
    return skip, triple.indptr[lower:upper]


def get_triple(descriptor: CSRDescriptor) -> CSRTriple:
    """
    Memory-map the three files named in :code:`descriptor` read-only and
    return them as a :class:`CSRTriple`.
    """
    def _map_readonly(file_name, dtype) -> np.ndarray:
        return np.memmap(file_name, dtype=dtype, mode='r')

    values = _map_readonly(descriptor.data_file, descriptor.data_dtype)
    coords = _map_readonly(descriptor.indices_file, descriptor.indices_dtype)
    pointers = _map_readonly(descriptor.indptr_file, descriptor.indptr_dtype)

    return CSRTriple(indptr=pointers, indices=coords, data=values)


def check(descriptor: CSRDescriptor, nav_shape, sig_shape, debug=False):
    """
    Sanity-check the files named in :code:`descriptor`.

    Always verifies that the data and indices arrays have the same shape and
    raises :code:`RuntimeError` otherwise. With :code:`debug=True` the value
    ranges of the indices and index pointers are asserted as well.
    :code:`nav_shape` is currently not used.
    """
    triple = get_triple(descriptor)
    indices = triple.indices
    if indices.shape != triple.data.shape:
        raise RuntimeError('Shape mismatch between data and indices.')
    if not debug:
        return
    indptr = triple.indptr
    # Indices must address a position within the flattened signal
    assert np.min(indices) >= 0
    assert np.max(indices) < prod(sig_shape)
    # Index pointers must stay within the indices array and reach its end
    assert np.min(indptr) >= 0
    assert np.max(indptr) == len(indices)


def get_descriptor(path: str) -> CSRDescriptor:
    """
    Get a CSRDescriptor from the path to a toml sidecar file

    The file names in the TOML file are interpreted relative to the
    directory that contains the TOML file. The dtype entries are passed
    through as found in the TOML file.

    Raises
    ------
    ValueError
        If the filetype in the TOML file is not ``raw_csr`` (case-insensitive).
    KeyError
        If required entries or the section with the file description
        are missing.
    """
    conf = load_toml(path)
    assert conf is not None
    filetype = conf['params']['filetype']
    if filetype.lower() != 'raw_csr':
        raise ValueError(f"Filetype is not CSR, found {filetype}")

    base_path = os.path.dirname(path)
    # make sure the key is not case sensitive to follow the convention of
    # the Context.load() function: a section spelled exactly like the filetype
    # is preferred, otherwise any spelling of "raw_csr" is accepted.
    if filetype in conf:
        csr_conf = conf[filetype]
    else:
        candidates = [
            key for key in conf
            if isinstance(key, str) and key.lower() == 'raw_csr'
        ]
        if not candidates:
            raise KeyError(filetype)
        csr_conf = conf[candidates[0]]
    return CSRDescriptor(
        indptr_file=os.path.join(base_path, csr_conf['indptr_file']),
        indptr_dtype=csr_conf['indptr_dtype'],
        indices_file=os.path.join(base_path, csr_conf['indices_file']),
        indices_dtype=csr_conf['indices_dtype'],
        data_file=os.path.join(base_path, csr_conf['data_file']),
        data_dtype=csr_conf['data_dtype'],
    )


def get_nav_size(descriptor: CSRDescriptor) -> int:
    '''
    To run efficiently on a remote worker for dataset initialization

    Returns the number of entries in the index pointer file minus one,
    determined by memory-mapping the file read-only.
    '''
    pointers = np.memmap(
        descriptor.indptr_file, dtype=descriptor.indptr_dtype, mode='r'
    )
    num_pointers = len(pointers)
    return num_pointers - 1


def read_tiles_straight(
    triple: CSRTriple,
    partition_slice: Slice,
    sync_offset: int,
    tiling_scheme: TilingScheme,
    dest_dtype: np.dtype,
):
    """
    Yield the tiles of a partition without a ROI.

    Each tile holds up to :code:`tiling_scheme.depth` consecutive frames as a
    :code:`scipy.sparse.csr_matrix` with the signal dimensions flattened.
    The values are converted to :code:`dest_dtype` if they have a different
    dtype.
    """
    assert len(tiling_scheme) == 1

    skip, indptr = sliced_indptr(
        triple,
        partition_slice=partition_slice,
        sync_offset=sync_offset
    )

    sig_shape = tuple(partition_slice.shape.sig)
    sig_size = partition_slice.shape.sig.size
    sig_dims = len(sig_shape)
    sig_origin = (0, ) * sig_dims

    depth = tiling_scheme.depth
    # One index pointer more than frames
    num_frames = len(indptr) - 1

    # Technically, one could use the slicing implementation of csr_matrix here.
    # However, it is slower, presumably because it takes a copy
    # Furthermore it provides a template to use an actual I/O backend here
    # instead of memory mapping.
    for first in range(0, num_frames, depth):
        last = min(first + depth, num_frames)
        frames_in_tile = last - first
        if frames_in_tile <= 0:
            continue

        # Range of this tile within the data and indices arrays
        lower = indptr[first]
        upper = indptr[last]
        values = triple.data[lower:upper]
        if dest_dtype != values.dtype:
            values = values.astype(dest_dtype)
        coords = triple.indices[lower:upper]

        # Index pointers relative to the start of this tile
        tile_indptr = indptr[first:last + 1]
        tile_indptr = tile_indptr - tile_indptr[0]

        arr = scipy.sparse.csr_matrix(
            (values, coords, tile_indptr),
            shape=(frames_in_tile, sig_size)
        )
        # skip is a negative value or 0
        nav_origin = partition_slice.origin[0] + (first - skip)
        yield DataTile(
            data=arr,
            tile_slice=Slice(
                origin=(nav_origin, ) + sig_origin,
                shape=Shape((arr.shape[0], ) + sig_shape, sig_dims=sig_dims),
            ),
            scheme_idx=0,
        )


def populate_tile(
    indptr_tile_start: "np.ndarray",
    indptr_tile_stop: "np.ndarray",
    orig_data: "np.ndarray",
    orig_indices: "np.ndarray",
    data_out: "np.ndarray",
    indices_out: "np.ndarray",
    indptr_out: "np.ndarray",
):
    """
    Gather the selected rows of a CSR matrix into contiguous output arrays.

    Row ``i`` of the output is the range
    ``indptr_tile_start[i]:indptr_tile_stop[i]`` of :code:`orig_data` and
    :code:`orig_indices`. The rows are written back to back into
    :code:`data_out` and :code:`indices_out`, and the matching index pointers
    into :code:`indptr_out`. All outputs are modified in place.
    """
    indptr_out[0] = 0
    write_pos = 0
    num_rows = min(len(indptr_tile_start), len(indptr_tile_stop))
    for row in range(num_rows):
        src_start = indptr_tile_start[row]
        src_stop = indptr_tile_stop[row]
        write_end = write_pos + (src_stop - src_start)
        data_out[write_pos:write_end] = orig_data[src_start:src_stop]
        indices_out[write_pos:write_end] = orig_indices[src_start:src_stop]
        write_pos = write_end
        indptr_out[row + 1] = write_pos


# Numba-compiled variant of populate_tile(); read_tiles_with_roi() selects it
# when can_use_numba() reports that all dtypes of the triple are supported.
populate_tile_numba = numba.njit(populate_tile)


def can_use_numba(triple: CSRTriple) -> bool:
    """
    Return :code:`True` if the dtypes of data, indices and index pointers
    are all contained in :code:`numba_dtypes`.
    """
    for array in (triple.data, triple.indices, triple.indptr):
        if array.dtype not in numba_dtypes:
            return False
    return True


def read_tiles_with_roi(
    triple: CSRTriple,
    partition_slice: Slice,
    sync_offset: int,
    tiling_scheme: TilingScheme,
    roi: np.ndarray,
    dest_dtype: np.dtype,
):
    """
    Yield the tiles of a partition, restricted to the frames selected by a ROI.

    Each tile holds up to :code:`tiling_scheme.depth` selected frames as a
    :code:`scipy.sparse.csr_matrix` with the signal dimensions flattened and
    values of dtype :code:`dest_dtype`. The navigation origin of a tile
    counts only frames selected by the ROI.
    """
    assert len(tiling_scheme) == 1
    roi = roi.reshape((-1, ))
    part_start = max(0, partition_slice.origin[0])
    # Number of selected frames in front of this partition
    tile_offset = count_nonzero(roi[:part_start])
    part_roi = partition_slice.get(roi, nav_only=True)

    skip, indptr = sliced_indptr(triple, partition_slice=partition_slice, sync_offset=sync_offset)

    if skip < 0:
        # The first -skip frames of this partition are blank frames inserted
        # by a negative sync_offset. No tile is generated for them, but the
        # selected ones still occupy positions in the result. This matches
        # read_tiles_straight(), which shifts the tile origin by -skip.
        tile_offset += count_nonzero(part_roi[:-skip])
        skipped_part_roi = part_roi[-skip:]
    else:
        skipped_part_roi = part_roi

    # Frames at the end of the partition that are not present in the file
    roi_overhang = max(0, len(skipped_part_roi) - len(indptr) + 1)
    if roi_overhang:
        real_part_roi = skipped_part_roi[:-roi_overhang]
    else:
        real_part_roi = skipped_part_roi

    real_part_roi = for_backend(real_part_roi, NUMPY)

    sig_shape = tuple(partition_slice.shape.sig)
    sig_size = partition_slice.shape.sig.size
    sig_dims = len(sig_shape)

    # Start and stop of each selected frame within data and indices
    start_values = indptr[:-1][real_part_roi]
    stop_values = indptr[1:][real_part_roi]

    # Implementing this "by hand" instead of fancy indexing to provide a template to use an
    # actual I/O backend here instead of memory mapping.
    # The native scipy.sparse.csr_matrix implementation of fancy indexing
    # with a boolean mask for nav is very fast.

    if can_use_numba(triple):
        my_populate_tile = populate_tile_numba
    else:
        my_populate_tile = populate_tile

    num_selected = len(start_values)
    for indptr_start in range(0, num_selected, tiling_scheme.depth):
        indptr_stop = min(indptr_start + tiling_scheme.depth, num_selected)
        # Don't read empty slices
        if indptr_stop - indptr_start <= 0:
            continue
        # Cast to int64 to avoid later upcasting to float64 in case of uint64
        # We can safely assume that files have less than 2**63 entries so that casting
        # from uint64 to int64 should be safe
        indptr_tile_start = start_values[indptr_start:indptr_stop].astype(np.int64)
        indptr_tile_stop = stop_values[indptr_start:indptr_stop].astype(np.int64)
        size = int(np.sum(indptr_tile_stop - indptr_tile_start))

        data = np.zeros(dtype=dest_dtype, shape=size)
        indices = np.zeros(dtype=triple.indices.dtype, shape=size)
        indptr_slice = np.zeros(
            dtype=indptr.dtype, shape=indptr_stop - indptr_start + 1
        )
        my_populate_tile(
            indptr_tile_start=indptr_tile_start,
            indptr_tile_stop=indptr_tile_stop,
            orig_data=triple.data,
            orig_indices=triple.indices,
            data_out=data,
            indices_out=indices,
            indptr_out=indptr_slice,
        )

        arr = scipy.sparse.csr_matrix(
            (data, indices, indptr_slice),
            shape=(indptr_stop - indptr_start, sig_size)
        )
        tile_slice = Slice(
            origin=(tile_offset + indptr_start, ) + (0, ) * sig_dims,
            shape=Shape((indptr_stop - indptr_start, ) + sig_shape, sig_dims=sig_dims),
        )
        yield DataTile(
            data=arr,
            tile_slice=tile_slice,
            scheme_idx=0,
        )