Ptrskay3/PySprint

View on GitHub
pysprint/core/bases/_apply.py

Summary

Maintainability
A
0 mins
Test Coverage
import types
import warnings
from collections.abc import Iterable
from inspect import getfullargspec

import numpy as np


class _DatasetApply:
    """
    Helper class to apply function to
    `pysprint.core.bases.dataset.Dataset` objects.

    Parameters
    ----------
    obj
        The object to operate on. Its ``x`` or ``y`` attribute (selected
        by `axis`) is read with ``getattr`` and must support ``len``.
    func
        Either the name of a method of `obj` (``str``), a numpy ufunc,
        or any other callable. Callables that are not ufuncs are wrapped
        with ``np.vectorize`` before being applied.
    axis
        ``"x"`` or ``0`` to operate on ``obj.x``, ``"y"`` or ``1`` to
        operate on ``obj.y``.
    args
        Extra positional arguments forwarded to `func`.
    kwargs
        Extra keyword arguments forwarded to `func`. The mapping is
        copied, so the caller's dict is never modified.

    Raises
    ------
    ValueError
        If `axis` is not one of the accepted values.
    """
    def __init__(
            self,
            obj,
            func,
            axis=None,
            args=None,
            kwargs=None
    ):

        self.obj = obj
        self.args = args or ()
        # Copy: `perform` may inject ``axis`` and that must not leak
        # into the dictionary owned by the caller.
        self.kwargs = dict(kwargs) if kwargs else {}

        self.f = func
        self.axis = axis

        if self.axis == "x" or self.axis == 0:
            self.target = "x"
        elif self.axis == "y" or self.axis == 1:
            self.target = "y"
        else:
            raise ValueError("Axis must be 'x', 'y', '0' or '1'.")
        # Length every function result is reconciled against.
        self.shape = len(getattr(self.obj, self.target))

    def perform(self):
        """
        Apply the specified function.

        Returns
        -------
        The object itself when `func` was given as a method name,
        otherwise the validated array that was written to the target
        attribute.

        Raises
        ------
        TypeError
            If `func` is neither a string, a ufunc nor a callable.
        """
        if isinstance(self.f, str):
            method = getattr(self.obj, self.f)
            if self._accepts_axis(method):
                self.kwargs["axis"] = self.axis
            # Let's assume we don't mess up the shape internally
            method(*self.args, **self.kwargs)
            return self.obj  # we need to return this because of `inplacify` deco.

        if isinstance(self.f, np.ufunc):
            return self._apply_to_target(self.f)

        if callable(self.f):
            if self._accepts_axis(self.f):
                self.kwargs["axis"] = self.axis
            # Vectorize into a local wrapper; `self.f` is left untouched
            # so that `perform` can be called more than once.
            return self._apply_to_target(np.vectorize(self.f))

        raise TypeError(
            "func must be a method name, a numpy ufunc or a callable."
        )

    @staticmethod
    def _accepts_axis(func):
        """Return True if `func` declares a parameter named ``axis``."""
        try:
            spec = getfullargspec(func)
        except TypeError:
            # Signature cannot be introspected (e.g. some builtins).
            return False
        return "axis" in spec.args or "axis" in spec.kwonlyargs

    def _apply_to_target(self, func):
        """Call `func` on the target attribute and store the validated result."""
        current = getattr(self.obj, self.target)
        value = self._validate(func(current, *self.args, **self.kwargs))

        setattr(self.obj, self.target, value)
        if self.target == "y":
            # `y_norm` is overwritten with the same array whenever `y` changes.
            setattr(self.obj, "y_norm", value)
        return value

    def _validate(self, val):
        """
        Convert a function result to a 1-D float64 array matching the
        length of the target attribute.

        Raises
        ------
        TypeError
            If `val` is not iterable.
        ValueError
            If the x axis is modified and the result contains duplicated
            values whose unique count differs from the expected length,
            or if the result cannot be broadcast to the expected length.
        """
        if not isinstance(val, (Iterable, np.ndarray)):
            raise TypeError("Function should return a number or Iterable type.")

        val = np.asarray(val, dtype=np.float64)

        if val.ndim != 1:
            # `ravel` matches `np.concatenate(val).ravel()` for ndim >= 2
            # and additionally copes with 0-d results.
            val = val.ravel()
            warnings.warn("Function return value was flattened.")

        n_unique = len(np.unique(val))
        if len(val) != n_unique:
            if n_unique != self.shape and self.target == "x":
                raise ValueError(
                    "Function returned duplicated values which is not allowed when"
                    " modifying the x axis. After filtering to unique values "
                    f"a {n_unique}-length array was produced, "
                    f"but {self.shape} was expected."
                )
            # NOTE(review): results containing duplicates are returned
            # without any length reconciliation - confirm this is intended.
            return val

        if len(val) != self.shape:
            return self._broadcast(val)
        return val

    def _broadcast(self, val):
        """
        Truncate or repeat `val` so that its length equals ``self.shape``.

        Longer input is cut to the first ``self.shape`` items. Shorter
        input has each element repeated ``self.shape // len(val)`` times,
        which requires the expected length to be a multiple of the
        input length.

        Raises
        ------
        ValueError
            If `val` is empty or its length does not divide the
            expected length.
        """
        if len(val) > self.shape:
            return val[:self.shape]
        if len(val) == self.shape:
            return val
        if len(val) == 0 or self.shape % len(val) != 0:
            raise ValueError("Cannot broadcast safely to the desired shape.")
        return np.repeat(val, self.shape // len(val))