LiberTEM/LiberTEM

View on GitHub
src/libertem/io/corrections/detector.py

Summary

Maintainability
D
1 day
Test Coverage
import numpy as np
import numba
import sparse

from libertem.common.math import prod
from libertem.common.sparse import is_sparse
from libertem.common.numba import (
    numba_ravel_multi_index_multi, numba_unravel_index_multi
)


@numba.njit(cache=True, nogil=True)
def _correct_numba_inplace(buffer, dark_image, gain_map, exclude_pixels, repair_environments,
        repair_counts):
    '''
    Numerical work horse to perform detector corrections

    For every row of :code:`buffer` this first applies
    :code:`(value - dark) * gain` to every signal element and afterwards
    replaces each excluded pixel with the mean of its repair environment,
    taken from the already dark/gain-corrected values of the same row.

    This function uses blocked processing for cache efficiency, hence the nested loops. It is
    about 4x faster than a naive numpy implementation.

    Parameters
    ----------

    buffer:
        (n, m) with data, modified in-place.

    dark_image:
        (m) with dark frame to be subtracted first, or None to subtract nothing

    gain_map:
        (m) with gain map to multiply the data with after subtraction,
        or None to leave the data unscaled

    exclude_pixels:
        int(k) array of indices in the flattened signal dimension to patch

    repair_environments:
        Array with environments for each pixel to use as a reference.
        Array of shape (k, max_repair_counts). Entries are indices into the
        flattened signal dimension.

    repair_counts:
        Array(int) of length k with number of valid entries in each entry in repair_environments.
        Pixels with a count of 0 are left unrepaired.


    Returns
    -------
    buffer (modified in-place)
    '''
    # Rows of `buffer` processed together in one block
    nav_blocksize = 4
    # Evaluates to 128 signal elements per block.
    # NOTE(review): presumably derived from a cache budget of 2**20 bytes - confirm
    sig_blocksize = 2**20 // (512 * 4 * nav_blocksize)

    # Number of complete blocks along the first axis and the rows left over
    nav_blocks = buffer.shape[0] // nav_blocksize
    nav_remainder = buffer.shape[0] % nav_blocksize

    # Number of complete blocks along the second axis and the elements left over
    sig_blocks = buffer.shape[1] // sig_blocksize
    sig_remainder = buffer.shape[1] % sig_blocksize

    # Neutral element 0 for the subtraction if no dark frame is given
    def _get_dark_px(sig):
        if dark_image is None:
            return 0
        return dark_image[sig]

    # Neutral element 1 for the multiplication if no gain map is given
    def _get_gain_px(sig):
        if gain_map is None:
            return 1
        return gain_map[sig]

    for nav_block in range(nav_blocks):
        # Dark and gain blocked in sig
        for sig_block in range(sig_blocks):
            for nav in range(nav_block * nav_blocksize, (nav_block + 1) * nav_blocksize):
                for sig in range(sig_block * sig_blocksize, (sig_block + 1) * sig_blocksize):
                    buffer[nav, sig] = (buffer[nav, sig] - _get_dark_px(sig)) * _get_gain_px(sig)

        # Dark and gain remainder of sig blocks
        for nav in range(nav_block * nav_blocksize, (nav_block + 1) * nav_blocksize):
            for sig in range(sig_blocks*sig_blocksize, sig_blocks*sig_blocksize + sig_remainder):
                buffer[nav, sig] = (buffer[nav, sig] - _get_dark_px(sig)) * _get_gain_px(sig)

        # Hole repair blocked in nav. This runs after all rows of the block
        # have been dark/gain-corrected, so the mean uses corrected values.
        for i, p in enumerate(exclude_pixels):
            if repair_counts[i] > 0:  # Avoid div0: pixels without environment stay as they are
                for nav in range(nav_block * nav_blocksize, (nav_block + 1) * nav_blocksize):
                    acc = 0
                    # Only the first repair_counts[i] entries of the environment are valid
                    for index in repair_environments[i, :repair_counts[i]]:
                        acc += buffer[nav, index]
                    buffer[nav, p] = acc / repair_counts[i]

    # Processing unblocked nav remainder: same steps, one row at a time
    for nav in range(nav_blocks * nav_blocksize, nav_blocks * nav_blocksize + nav_remainder):
        # Dark and gain unblocked in sig
        for sig in range(buffer.shape[1]):
            buffer[nav, sig] = (buffer[nav, sig] - _get_dark_px(sig)) * _get_gain_px(sig)

        # Hole repair
        for i, p in enumerate(exclude_pixels):
            if repair_counts[i] > 0:  # Avoid div0: pixels without environment stay as they are
                acc = 0
                for index in repair_environments[i, :repair_counts[i]]:
                    acc += buffer[nav, index]
                buffer[nav, p] = acc / repair_counts[i]

    return buffer


@numba.njit(cache=True, nogil=True)
def environments(excluded_pixels, sigshape):
    '''
    Calculate a hypercube surface around a pixel, excluding frame boundaries

    For each excluded pixel, all positions that are offset by one of the
    :code:`3**len(sigshape)` coordinate offsets are collected, skipping the
    pixel itself and any position outside of :code:`sigshape`.

    Parameters
    ----------
    excluded_pixels : numpy.ndarray
        Integer array indexed as :code:`excluded_pixels[dim, i]`: first axis is
        the signal dimension, second axis the pixel number.
    sigshape : numpy.ndarray
        Extent of each signal dimension, indexed as :code:`sigshape[dim]`.

    Returns
    -------
    repairs, repair_counts
        repairs : numpy.ndarray
            Array of dtype intp with shape
            (number of excluded pixels, len(sigshape), 3**len(sigshape) - 1)
            containing the coordinates of the environment pixels. Only the first
            :code:`repair_counts[i]` entries along the last axis are valid for
            pixel :code:`i`, the remainder is zero.
        repair_counts : numpy.ndarray
            Array with length exclude_pixels, containing the number of pixels
            in the repair environment
    '''

    # Largest possible environment: the full 3x3x...x3 cube minus the center pixel
    max_repair_count = 3**len(sigshape) - 1
    num_pixels = len(excluded_pixels[0])
    repairs = np.zeros((num_pixels, len(sigshape), max_repair_count), dtype=np.intp)
    repair_counts = np.zeros(num_pixels, dtype=np.intp)
    all_indices = np.arange(3**len(sigshape), dtype=np.intp)
    coord_shape = np.full(len(sigshape), 3, dtype=np.intp)
    # Indexed below as coord_offsets[dim, position].
    # NOTE(review): assuming the helper mirrors np.unravel_index, the entries
    # are the offsets -1, 0, +1 for each dimension - confirm
    coord_offsets = numba_unravel_index_multi(all_indices, coord_shape) - 1
    for i in range(num_pixels):
        repair_count = 0
        for position in range(coord_offsets.shape[1]):
            # `select` is used as an accumulator: adding acts as logical "or",
            # multiplying as logical "and".
            select = False
            for dim in range(coord_offsets.shape[0]):
                coord = coord_offsets[dim, position] + excluded_pixels[dim, i]
                # Any of the coordinates is different, i.e. this is not the pixel itself
                select += (coord != excluded_pixels[dim, i])
            for dim in range(coord_offsets.shape[0]):
                coord = coord_offsets[dim, position] + excluded_pixels[dim, i]
                # All of the coordinates are within bounds
                select *= (coord >= 0)
                select *= (coord < sigshape[dim])
            if select:
                # Store the coordinates in the next free slot of this pixel's environment
                for dim in range(coord_offsets.shape[0]):
                    coord = coord_offsets[dim, position] + excluded_pixels[dim, i]
                    repairs[i, dim, repair_count] = coord
                repair_count += 1
        repair_counts[i] = repair_count

    return repairs, repair_counts


class RepairValueError(ValueError):
    '''
    Raised when an excluded pixel has an empty repair environment
    and empty environments are not allowed.
    '''


@numba.njit(cache=True, nogil=True)
def flatten_filter(excluded_pixels, repairs, repair_counts, sig_shape):
    '''
    Flatten excluded pixels and repair environments and filter for collisions

    Both the excluded pixels and their repair environments are converted
    to indices in the flattened signal dimension. Environment entries that
    are excluded pixels themselves are dropped so that only "good" pixels
    remain as a reference.

    Parameters
    ----------
    excluded_pixels
        Coordinates of the excluded pixels, one row per signal dimension.
    repairs
        Environment coordinates indexed as (pixel, signal dimension, entry).
    repair_counts
        Number of valid entries in :code:`repairs` for each pixel.
    sig_shape
        Shape of the signal dimensions.

    Returns
    -------
    excluded_flat, repair_flat, new_repair_counts
        Flat indices of the excluded pixels, array of shape
        (number of pixels, 3**len(sig_shape) - 1) with the flat indices of the
        remaining environment pixels, and the number of valid entries per pixel
        in that array.

    Note
    ----
    An environment that ends up empty is not rejected here, since raising
    RepairValueError from this compiled function is an open TODO. Callers
    have to inspect the returned counts themselves.
    '''
    excluded_flat = numba_ravel_multi_index_multi(excluded_pixels, sig_shape)
    num_excluded = len(excluded_flat)

    max_repair_count = 3**len(sig_shape) - 1
    repair_flat = np.zeros((num_excluded, max_repair_count), dtype=np.intp)
    new_repair_counts = np.zeros_like(repair_counts)

    # Dictionary used as a lookup table for "is this flat index excluded?"
    is_excluded = {}
    for flat_index in excluded_flat:
        is_excluded[flat_index] = True

    for pixel in range(num_excluded):
        count = repair_counts[pixel]
        candidates = numba_ravel_multi_index_multi(repairs[pixel, ..., :count], sig_shape)
        # Compact the surviving candidates to the front of the output row
        kept = 0
        for slot in range(count):
            candidate = candidates[slot]
            if candidate not in is_excluded:
                repair_flat[pixel, kept] = candidate
                kept += 1
        new_repair_counts[pixel] = kept

    return (excluded_flat, repair_flat, new_repair_counts)


def correct(
        buffer, dark_image=None, gain_map=None, excluded_pixels=None, repair_descriptor=None,
        inplace=False, sig_shape=None, allow_empty=False):
    '''
    Function to perform detector corrections

    The dark frame is subtracted, the result is multiplied with the gain map
    and finally excluded pixels are replaced with the mean of their repair
    environment.

    This function delegates the processing to a function written with numba that is
    about 4x faster than a naive numpy implementation.

    Parameters
    ----------

    buffer:
        shape (*nav, *sig) with data. It is modified in-place if inplace==True.

    dark_image:
        shape (*sig) with dark frame to be subtracted first

    gain_map:
        shape (*sig) with gain map to multiply the data with after subtraction

    excluded_pixels:
        int(sigs, k) array of indices in the signal dimension to patch.
        The first dimension is the number of signal dimensions, the second the number of pixels

    repair_descriptor : RepairDescriptor
        This allows to re-use the calculation and filtering of repair environments when
        specified instead of excluded_pixels. This is particularly advantageous for tiled
        processing.

    inplace:
        If True, modify the input buffer in-place.
        If False (default), copy the input buffer before correcting.

    sig_shape:
        Shape of the signal dimensions. Only used if neither :code:`dark_image`
        nor :code:`gain_map` are given, otherwise their shape takes precedence.

    allow_empty:
        If False (default), a :class:`RepairValueError` is raised if an excluded
        pixel has no valid pixel in its repair environment. If True, such
        pixels are left unrepaired.

    Returns
    -------
    shape (*nav, *sig) If inplace==True, this is :code:`buffer` modified in-place.

    Raises
    ------
    ValueError
        If the signal shape cannot be determined or if both
        :code:`repair_descriptor` and :code:`excluded_pixels` are given.
    TypeError
        If :code:`inplace` is True and :code:`buffer` is neither floating point nor complex.
    RepairValueError
        If a repair environment is empty and :code:`allow_empty` is False.
    '''
    s = buffer.shape

    if dark_image is not None:
        sig_shape = dark_image.shape
        dark_image = dark_image.flatten()
    if gain_map is not None:
        sig_shape = gain_map.shape
        gain_map = gain_map.flatten()
    if sig_shape is None:
        raise ValueError("need either `dark_image`, `gain_map`, or `sig_shape`")
    # Explicit stop index instead of a negative one: s[0:-0] would be empty
    # for an empty sig_shape.
    nav_shape = s[:len(s) - len(sig_shape)]

    if inplace:
        if buffer.dtype.kind not in ('f', 'c'):
            raise TypeError("In-place correction only supported for floating point data.")
        out = buffer
    else:
        # astype() is always a copy even if it is the same dtype
        out = buffer.astype(np.result_type(np.float32, buffer))

    if repair_descriptor is None:
        repair_descriptor = RepairDescriptor(
            sig_shape=sig_shape,
            excluded_pixels=excluded_pixels,
            allow_empty=allow_empty
        )
    else:
        repair_descriptor.check_empty_repairs(allow_empty=allow_empty)
        if excluded_pixels is not None:
            raise ValueError("Invalid arguments: Both repair_descriptor and excluded_pixels set")

    flat = out.reshape((prod(nav_shape), prod(sig_shape)))
    # reshape() silently returns a copy if the memory layout of `out` does not
    # allow a flattened view, for example with sliced or Fortran-ordered
    # input. In that case the kernel works on a temporary array and the
    # result has to be written back explicitly.
    is_view = np.may_share_memory(flat, out)

    _correct_numba_inplace(
        buffer=flat,
        dark_image=dark_image,
        gain_map=gain_map,
        exclude_pixels=repair_descriptor.exclude_flat,
        repair_environments=repair_descriptor.repair_flat,
        repair_counts=repair_descriptor.repair_counts,
    )
    if not is_view:
        out[...] = flat.reshape(out.shape)
    return out


class RepairDescriptor:
    '''
    Repair environments for a set of excluded pixels in flattened form

    Attributes set by the constructor, all taken from :func:`flatten_filter`:

    exclude_flat
        Indices of the excluded pixels in the flattened signal dimension.
    repair_flat
        Indices of the environment pixels for each excluded pixel.
    repair_counts
        Number of valid entries in :code:`repair_flat` for each excluded pixel.

    Parameters
    ----------
    sig_shape
        Shape of the signal dimensions.
    excluded_pixels
        Coordinates of the pixels to repair, first axis is the signal
        dimension. None (default) means no excluded pixels.
    allow_empty
        If False (default), raise :class:`RepairValueError` if any
        excluded pixel has an empty repair environment.
    '''
    def __init__(self, sig_shape, excluded_pixels=None, allow_empty=False):
        if excluded_pixels is not None:
            excluded_pixels = np.array(excluded_pixels)
        else:
            # No pixels to repair: empty coordinate array with one row per dimension
            excluded_pixels = np.zeros((len(sig_shape), 0), dtype=np.intp)

        environment_coords, environment_sizes = environments(
            excluded_pixels, np.array(sig_shape)
        )
        flattened = flatten_filter(
            excluded_pixels, environment_coords, environment_sizes, sig_shape
        )
        self.exclude_flat, self.repair_flat, self.repair_counts = flattened

        self.check_empty_repairs(allow_empty=allow_empty)

    def empty_repairs(self):
        '''
        Return the positions in :code:`repair_counts` that are zero,
        in the format of :func:`numpy.argwhere`.
        '''
        has_no_environment = self.repair_counts == 0
        return np.argwhere(has_no_environment)

    def check_empty_repairs(self, allow_empty):
        '''
        Raise :class:`RepairValueError` if there are empty repair
        environments, unless :code:`allow_empty` is true.
        '''
        if allow_empty:
            return
        empty = self.empty_repairs()
        if len(empty) > 0:
            raise RepairValueError(
                f"Empty repair environments for pixel(s) number {empty}."
            )


def correct_dot_masks(masks, gain_map, excluded_pixels=None, allow_empty=False):
    '''
    Fold gain correction and pixel repair into masks

    The weight that each mask assigns to an excluded pixel is set to zero and
    distributed in equal parts to the pixels of its repair environment.
    Afterwards the masks are multiplied with the gain map. The dot product of
    the returned masks with uncorrected data is then equivalent to the dot
    product of the original masks with data that was gain-corrected and
    repaired. A dark frame is not taken into account.

    Parameters
    ----------
    masks
        Dense or sparse array with shape (..., *sig).
    gain_map
        Array with shape (*sig). Its shape defines the signal shape.
    excluded_pixels
        Coordinates of the pixels to repair, first axis is the signal
        dimension. None (default) skips the repair step.
    allow_empty
        Passed on to :class:`RepairDescriptor`. If False (default), a
        :class:`RepairValueError` is raised for excluded pixels without
        valid repair environment. If True, the mask weight of such pixels is
        set to zero without being redistributed.

    Returns
    -------
    Corrected masks with the same shape as :code:`masks`.
    '''
    mask_shape = masks.shape
    sig_shape = gain_map.shape
    masks = masks.reshape((-1, prod(sig_shape)))

    if excluded_pixels is not None:
        if is_sparse(masks):
            # NOTE(review): the sparse path keeps the dtype of the masks, so
            # fractional weights may still be truncated for integer sparse
            # masks - confirm
            result = sparse.DOK(masks)
        else:
            # The redistributed weights are fractions: promote integer and
            # boolean masks to floating point so that they are not truncated
            # on assignment. astype() always returns a copy.
            result = masks.astype(np.result_type(np.float32, masks.dtype))
        desc = RepairDescriptor(sig_shape, excluded_pixels=excluded_pixels, allow_empty=allow_empty)
        for e, r, c in zip(desc.exclude_flat, desc.repair_flat, desc.repair_counts):
            result[:, e] = 0
            if c == 0:
                # Empty environment, only possible with allow_empty=True:
                # nothing to redistribute, and dividing by c would be a
                # division by zero.
                continue
            rep = masks[:, e] / c
            # We have to loop because of sparse.pydata limitations
            for m in range(result.shape[0]):
                for rr in r[:c]:
                    result[m, rr] = result[m, rr] + rep[m]
        if is_sparse(result):
            result = sparse.COO(result)
    else:
        result = masks
    result = result * gain_map.flatten()
    return result.reshape(mask_shape)