nstarman/utilipy

View on GitHub
utilipy/data_utils/xfm/graph.py

Summary

Maintainability
C
1 day
Test Coverage
# -*- coding: utf-8 -*-

"""Data Transformation Graph."""

__author__ = "Nathaniel Starkman"
__credits__ = ["Astropy"]


__all__ = [
    "TransformGraph",
]


##############################################################################
# IMPORTS

# BUILT-IN
import heapq
import typing as T
from collections import defaultdict

# PROJECT-SPECIFIC
from .transformations import CompositeTransform, _default_xfm_set
from utilipy.utils import functools, inspect

##############################################################################
# PARAMETERS


##############################################################################
# CODE
##############################################################################


class TransformGraph:
    """Graph representing the paths between data types.

    Notes
    -----
    Note that the _graph key-value order is "totype": "fromtype".
    This is the opposite order to Astropy's, since this TransformGraph
    needs to support:

    - catch-all conversions, such as converting from any sequence ``data``
      to type ``list`` by applying ``list(data)``
    - sub-type conversions, where the conversion a -> b (with types A, B)
      works on any subtype of A (like A1(A)).


    .. todo::

        - catch-all conversions to a specific type
        - multiple input option conversions to a specific type
        - multiple output option conversions by choosing one with shortest path
        - pre-register basic conversions like None->None, list->tuple

    """

    def __init__(self, seed_basic: bool = True):
        """Data Transformation Graph.

        Parameters
        ----------
        seed_basic : bool
            whether to start with a basic set of transformations

            .. todo::

                Generate documentation / colored graph like Astropy's

        """
        # graph, in reverse order
        self._graph = _default_xfm_set if seed_basic else defaultdict(dict)
        self.invalidate_cache()  # generates cache entries

    # /def

    @property
    def _cached_names(self):
        if self._cached_names_dct is None:
            self._cached_names_dct = dct = {}
            for c in self.type_set:
                nm = getattr(c, "name", None)
                if nm is not None:
                    if not isinstance(nm, list):
                        nm = [nm]
                    for name in nm:
                        dct[name] = c

        return self._cached_names_dct

    # /def

    @property
    def type_set(self):
        """A `set` of all data types present in this `TransformGraph`."""
        if self._cached_type_set is None:
            self._cached_type_set = set()
            for a in self._graph:
                self._cached_type_set.add(a)
                for b in self._graph[a]:
                    self._cached_type_set.add(b)

        return self._cached_type_set.copy()

    # /def

    def invalidate_cache(self):
        """Clears all caching attributes.

        Invalidates the cache that stores optimizations for traversing the
        transform graph.  This is called automatically when transforms
        are added or removed, but will need to be called manually if
        weights on transforms are modified inplace.

        """
        self._cached_names_dct: T.Optional[dict] = None
        self._cached_type_set: T.Optional[set] = None
        # self._cached_frame_attributes = None
        # self._cached_component_names = None
        self._shortestpaths = {}
        self._composite_cache = {}

    # /def

    def add_transform(self, fromtype, totype, transform):
        """Add a new data transformation to the graph.

        .. todo::

           - support an "Any" option in fromtype
           - support adding a tuple of types as the "fromtype"
           - support subtypes in "fromtype"

        Parameters
        ----------
        fromtype : class
            The data class to start from.
        totype : class
            The data class to transform into.
        transform : DataTransform or similar callable
            The transformation object. Typically a `DataTransform` object,
            although it may be some other callable that is called with the same
            signature.

        Raises
        ------
        TypeError
            If ``fromtype`` or ``totype`` are not classes or ``transform`` is
            not callable.

        """
        if not inspect.isclass(fromtype) and fromtype is not None:
            raise TypeError("fromtype must be a class")
        if not inspect.isclass(totype) and totype is not None:
            raise TypeError("totype must be a class")
        if not callable(transform):
            raise TypeError("transform must be callable")

        type_set = self.type_set.copy()
        type_set.add(fromtype)
        type_set.add(totype)

        self._graph[totype][fromtype] = transform
        self.invalidate_cache()

    # /def

    def remove_transform(self, fromtype, totype, transform):
        """Removes a data transform from the graph.

        .. todo::

            - support removing catch-all transformations

        Parameters
        ----------
        fromtype : class or `None`
            The coordinate frame *class* to start from. If `None`,
            ``transform`` will be searched for and removed (``totype`` must
            also be `None`).
        totype : class or `None`
            The coordinate frame *class* to transform into. If `None`,
            ``transform`` will be searched for and removed (``fromtype`` must
            also be `None`).
        transform : callable or `None`
            The transformation object to be removed or `None`.  If `None`
            and ``totype`` and ``fromtype`` are supplied, there will be no
            check to ensure the correct object is removed.

        """
        if fromtype is None or totype is None:
            if not (totype is None and fromtype is None):
                raise ValueError(
                    "fromtype and totype must both be None if either are"
                )
            if transform is None:
                raise ValueError("cannot give all Nones to remove_transform")

            # search for the requested transform by brute force and remove it
            for a in self._graph:
                agraph = self._graph[a]
                for b in agraph:
                    if agraph[b] is transform:
                        del agraph[b]
                        fromtype = a
                        break

                # If transform was found, need to break out of outer loop
                if fromtype:
                    break
            else:
                raise ValueError(
                    "Could not find transform {} in the "
                    "graph".format(transform)
                )

        else:
            if transform is None:
                self._graph[totype].pop(fromtype, None)
            else:
                curr = self._graph[totype].get(fromtype, None)
                if curr is transform:
                    self._graph[totype].pop(fromtype)
                else:
                    raise ValueError(
                        "Current transform from {} to {} is not "
                        "{}".format(fromtype, totype, transform)
                    )

        # Remove the subgraph if it is now empty
        if self._graph[totype] == {}:
            self._graph.pop(totype)

        self.invalidate_cache()

    # /def

    def _construct_path(self, fromtype, totype):
        """Construct path using Dijkstra's algorithm.

        Parameters
        ----------
        fromtype
        totype

        Returns
        -------
        path : list
        priority : int

        """
        inf = float("inf")

        nodes = []
        # first make the list of nodes
        for a in self._graph:
            if a not in nodes:
                nodes.append(a)
            for b in self._graph[a]:
                if b not in nodes:
                    nodes.append(b)

        if fromtype not in nodes or totype not in nodes:
            # fromtype or totype are isolated or not registered, so there's
            # certainly no way to get from one to the other
            return None, inf

        edgeweights = {}
        # construct another graph that is a dict of dicts of priorities
        # (used as edge weights in Dijkstra's algorithm)
        for a in self._graph:
            edgeweights[a] = aew = {}
            agraph = self._graph[a]
            for b in agraph:
                aew[b] = float(
                    agraph[b].priority if hasattr(agraph[b], "priority") else 1
                )

        # entries in q are [distance, count, nodeobj, pathlist]
        # count is needed because in py 3.x, tie-breaking fails on the nodes.
        # this way, insertion order is preserved if the weights are the same
        # q = [[inf,i,n,[]] for i, n in enumerate(nodes) if n is not fromtype]
        # q.insert(0, [0, -1, fromtype, []])
        q = [[inf, i, n, []] for i, n in enumerate(nodes) if n is not totype]
        q.insert(0, [0, -1, totype, []])

        # this dict stores the distance to node from ``fromtype`` and the path
        result = {}

        # definitely starts as a valid heap because of the insert line;
        # from the node to itself is always the shortest distance
        while len(q) > 0:
            d, orderi, n, path = heapq.heappop(q)

            if d == inf:
                # everything left is unreachable from fromtype,
                # just copy them to the results and jump out of the loop
                result[n] = (None, d)
                for d, orderi, n, path in q:
                    result[n] = (None, d)
                break
            else:
                result[n] = (path, d)
                path.append(n)
                if n not in edgeweights:
                    # a system that can be transformed to, but not from.
                    continue
                for n2 in edgeweights[n]:
                    if n2 not in result:  # already visited
                        # find where n2 is in the heap
                        for i in range(len(q)):
                            if q[i][2] == n2:
                                break
                        else:
                            raise ValueError(
                                "n2 not in heap - this should be impossible!"
                            )

                        newd = d + edgeweights[n][n2]
                        if newd < q[i][0]:
                            q[i][0] = newd
                            q[i][3] = list(path)
                            heapq.heapify(q)

        # cache for later use
        # self._shortestpaths[totype] = result # FIXME
        # return result[fromtype]
        # FIXME
        path, d = result[fromtype]
        return path[::-1], d

    # /def

    def find_shortest_path(self, fromtype, totype):
        """Compute shortest path along graph from one system to another.

        Parameters
        ----------
        fromtype : class
            The coordinate frame class to start from.
        totype : class
            The coordinate frame class to transform into.

        Returns
        -------
        path : list of classes or `None`
            The path from ``fromtype`` to ``totype`` as an in-order sequence
            of classes.  This list includes *both* ``fromtype`` and
            ``totype``. Is `None` if there is no possible path.
        distance : number
            The total distance/priority from ``fromtype`` to ``totype``.  If
            priorities are not set this is the number of transforms
            needed. Is ``inf`` if there is no possible path.

        """
        # ----------------------------------
        # special-case the 0 or 1-path

        if totype is fromtype:
            if fromtype not in self._graph[totype]:
                # Means there's no transform necessary to go from it to itself.
                return [totype], 0

        if fromtype in self._graph[totype]:
            # this will also catch the case where totype is fromtype, but has
            # a defined transform.
            t = self._graph[totype][fromtype]
            return (
                [fromtype, totype],
                float(t.priority if hasattr(t, "priority") else 1),
            )

        # ----------------------------------
        # otherwise, need to construct the path:

        inf = float("inf")

        # TODO verify this works for catch-alls
        if totype in self._shortestpaths:
            # already have a cached result
            fpaths = self._shortestpaths[totype]
            if fromtype in fpaths:
                return fpaths[fromtype]
            else:
                path, priority = None, inf

        path, priority = self._construct_path(fromtype, totype)

        return path, priority

    # /def

    def get_transform(self, fromtype, totype):
        """Generate `CompositeTransform` for a datatype transformation.

        Parameters
        ----------
        fromtype : class
            The coordinate frame class to start from.
        totype : class
            The coordinate frame class to transform into.

        Returns
        -------
        trans : `CompositeTransform` or `None`
            If there is a path from ``fromtype`` to ``totype``, this is a
            transform object for that path. If no path could be found, this is
            `None`.

        Notes
        -----
        This function always returns a `CompositeTransform`, because
        `CompositeTransform` is slightly more adaptable in the way it can be
        called than other transform classes. Specifically, it takes care of
        intermediate steps of transformations in a way that is consistent with
        1-hop transformations.

        """
        if not inspect.isclass(fromtype) and fromtype is not None:
            raise TypeError("fromtype is not a class")
        if not inspect.isclass(totype) and totype is not None:
            raise TypeError("totype is not a class")

        path, distance = self.find_shortest_path(fromtype, totype)

        if path is None:
            return None

        transforms = []
        currtype = path[0]
        for p in path[1:]:  # first element is fromtype so we skip it
            transforms.append(self._graph[p][currtype])
            currtype = p

        fttuple = (fromtype, totype)
        if fttuple not in self._composite_cache:
            comptrans = CompositeTransform(
                transforms, fromtype, totype, register_graph=False
            )
            self._composite_cache[fttuple] = comptrans

        return self._composite_cache[fttuple]

    # /def

    def lookup_name(self, name: str):
        """Tries to locate the class with the provided alias.

        Parameters
        ----------
        name : str
            The alias to look up.

        Returns
        -------
        datacls
            The data class corresponding to the ``name`` or `None` if
            no such class exists.

        """
        return self._cached_names.get(name, None)

    # /def

    def get_names(self):
        """Returns all available transform names.

        Returns
        -------
        nms : list
            The aliases for coordinate systems.
            They will all be valid arguments to `lookup_name`.

        """
        return list(self._cached_names.keys())

    # /def

    def register(
        self, transcls, fromtype, totype, priority: int = 1, **kwargs
    ):
        """A function decorator for defining a transformation.

        .. note::
            If decorating a static method of a class, ``@staticmethod``
            should be  added *above* this decorator.

        Parameters
        ----------
        transcls : class
            The class of the transformation object to create.
        fromtype : class
            The data class to start from.
        totype : class
            The data class to transform into.
        priority : number
            The priority if this transform when finding the shortest
            coordinate transform path - large numbers are lower priorities.

        Additional keyword arguments are passed into the ``transcls``
        constructor.

        Returns
        -------
        deco : function
            A function that can be called on another function as a decorator
            (see example).

        Notes
        -----
        This decorator assumes the first argument of the ``transcls``
        initializer accepts a callable, and that the second and third
        are ``fromtype`` and ``totype``. If this is not true, you should just
        initialize the class manually and use `add_transform` instead of
        using this decorator.

        Examples
        --------
        ::

            graph = TransformGraph()

            @graph.transform(DataTransform, list, tuple)
            def list_to_tuple(data):
                return tuple(data)

        """

        # create decorator
        def register_xfm_decorator(func: T.Callable):
            # this doesn't do anything directly with the transform because
            # ``register_graph=self`` stores it in the transform graph
            # automatically
            transcls(
                func,
                fromtype,
                totype,
                priority=priority,
                register_graph=self,
                **kwargs,
            )
            return func

        # /def

        return register_xfm_decorator

    # /def

    transform = register  # for similarity to Astropy's
    # /def

    def function_decorator(
        self,
        function: T.Optional[T.Callable] = None,
        *,
        _doc_style: str = "numpy",
        _doc_fmt: T.Dict[str, T.Any] = {},
        **arguments,
    ):
        """Apply data transformations to function arguments.

        Parameters
        ----------
        function : T.Callable or None, optional
            the function to be decoratored
            if None, then returns decorator to apply.

        **arguments: dict
            argument information, where keyword is the argument parameter
            name in `function`. The values are either the desired output type
            or a 3-element *tuple* in the following order
            (outtype, (args), dict(kwargs)). The args and kwargs are passed
            into the transformation.

        Returns
        -------
        wrapper : T.Callable
            wrapper for `function` that manage input catalog tables.
            includes the original function in a method `.__wrapped__`

        Other Parameters
        ----------------
        _doc_style: str or formatter, optional
            `function` docstring style. Parameter to `wraps`.
        _doc_fmt: dict, optional
            `function` docstring format arguments. Parameter to `wraps`.

        Notes
        -----
        .. todo::

            - scrape output type from function argument annotation
            - support a multiple possible output types (Union[etc]), choosing
              the one with the shortest path

        """
        if function is None:  # allowing for optional arguments
            return functools.partial(
                self.function_decorator,
                _doc_style=_doc_style,
                _doc_fmt=_doc_fmt,
                **arguments,
            )

        sig = inspect.fuller_signature(function)
        _doc_fmt.update({"argkeys": ", ".join(arguments.keys())})

        @functools.wraps(function, _doc_style=_doc_style, _doc_fmt=_doc_fmt)
        def wrapper(*args, _skip_decorator=False, **kwargs):
            """Wrapper docstring.

            Other Parameters
            ----------------
            _skip_decorator : bool, optional
                Whether to skip the decorator.
                default {_skip_decorator}

            Notes
            -----
            This function is wrapped with a data `~TransformGraph` decorator.
            See `~TransformGraph.function_decorator` for details.
            The transformation arguments are also attached to this function
            as the attribute ``._transforms``.
            The affected arguments are: {argkeys}

            """
            if _skip_decorator:  # whether to skip decorator or keep going
                return function(*args, **kwargs)
            # else:

            ba = sig.bind_partial_with_defaults(*args, **kwargs)

            for name, outtype in arguments.items():
                data = ba.arguments[name]  # get the data to be transformed

                # The values are either the desired output type
                # or a 3-element *tuple* in the following order
                # (outtype, (args), dict(kwargs)). The args and kwargs
                # are passed into the transformation.
                t_args, t_kw = (), {}  # assume no input (kw)args
                if isinstance(outtype, tuple):  # output type or info tuple
                    if len(outtype) == 3:  # it's an info tuple
                        outtype, t_args, t_kw = outtype

                # get transformation
                fromtype = type(data) if data is not None else None
                t = self.get_transform(fromtype, outtype)
                # print(t, type(data), outtype)
                # apply transformation
                ba.arguments[name] = t(data, *t_args, **t_kw)

            # /def

            return_ = function(*ba.args, **ba.kwargs)

            return return_

        # /def

        wrapper._transforms = arguments

        return wrapper

        # /def

    decorate = function_decorator
    # /def

    # def copy(self):
    #     """Deep-copy self"""
    #     return copy.deepcopy(self)

    # # /def


# /class


##############################################################################
# END