LucaCappelletti94/ddd_subplots

View on GitHub
ddd_subplots/rotate.py

Summary

Maintainability
A
3 hrs
Test Coverage
A
97%
"""Package to produce rotating 3d plots."""
import os
import warnings
from typing import Callable, List, Union

import imageio
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
import matplotlib.pyplot as plt
from matplotlib.axis import Axis
from matplotlib.axes import Axes
import numpy as np
import cv2
from sklearn.preprocessing import MinMaxScaler
from tqdm.auto import trange


def rotate_along_last_axis(
    x: np.ndarray, y: np.ndarray, *features: List[np.ndarray], theta: float
) -> List[np.ndarray]:
    """Rotate the first two axes by ``theta``, leaving the others untouched.

    The pair ``(x, y)`` is interpreted as the complex number ``x + iy`` and
    multiplied by ``exp(i * theta)``, which rotates it around the origin.

    Parameters
    ---------------------
    x: np.ndarray,
        First axis of the points vector.
    y: np.ndarray,
        Second axis of the points vector.
    features: List[np.ndarray],
        Extra features, returned unchanged after the rotated pair.
    theta: float,
        Rotation angle to apply.

    Returns
    ----------------------
    List with the rotated first axis, the rotated second axis and then
    the extra features in their original order.
    """
    rotated = np.exp(1j * theta) * (x + 1j * y)
    return [np.real(rotated), np.imag(rotated), *features]


def rotating_spiral(*features: List[np.ndarray], theta: float) -> np.ndarray:
    """Rotate the features pair by pair, cycling through all of them.

    One pass is executed per feature. Each pass rotates the two leading
    features and then moves the rotated first one to the back, so that the
    following pass acts on the next pair.

    Parameters
    ---------------------
    features: List[np.ndarray],
        Features to be rotated.
    theta: float,
        Theta for the current variation.

    Returns
    ----------------------
    Numpy array stacking the rotated features vertically, each one
    divided by the square root of two.
    """
    coordinates = list(features)
    for step in range(len(coordinates)):
        # The first pass rotates by theta, every later pass by twice theta.
        angle = theta * min(2**step, 2)
        head, *tail = rotate_along_last_axis(  # pylint: disable=no-value-for-parameter
            *coordinates, theta=angle
        )
        # Cycle the rotated first axis to the end of the list.
        coordinates = [*tail, head]
    scale = np.sqrt(2)
    return np.vstack([coordinate / scale for coordinate in coordinates])


def render_frame(
    func: Callable,
    points: Union[np.ndarray, List[np.ndarray]],
    theta: float,
    args: List,
    **kwargs,
) -> np.ndarray:
    """Return the rendered frame as an RGB image.

    Parameters
    -----------------------
    func: Callable,
        Function to call to render the frame. It must return a tuple
        whose first two elements are the figure and the axes.
    points: Union[np.ndarray, List[np.ndarray]],
        The points to be rotated and rendered.
    theta: float,
        The amount of rotation.
    args: List,
        The list of positional arguments forwarded to `func`.
    kwargs: Dict,
        The dictionary of keyword arguments forwarded to `func`.

    Returns
    -----------------------
    Numpy array of type uint8 with shape (height, width, 3) containing
    the RGB pixels of the rendered figure.

    Raises
    -----------------------
    ValueError
        If the rendering function does not return a tuple.
    """
    points = [rotating_spiral(*matrix.T, theta=theta).T for matrix in points]

    # Only the first three rotated dimensions are handed to the renderer.
    points = [matrix[:, :3] if matrix.shape[1] > 3 else matrix for matrix in points]

    # A single points cloud is passed directly instead of a one-element list.
    returned_value = func(points[0] if len(points) == 1 else points, *args, **kwargs)

    if not isinstance(returned_value, tuple):
        raise ValueError(
            "The provided rendering function does not return "
            "a tuple with figure and axes!"
        )

    fig, axis = returned_value[:2]

    try:
        canvas = FigureCanvas(fig)

        # A narrower window is used when any cloud has more than two columns.
        window = 1.0
        if any(matrix.shape[1] > 2 for matrix in points):
            window = 0.6

        # Normalize a single axes object to an array so one loop handles all.
        if isinstance(axis, (Axes, Axis)):
            axis = np.array([axis])

        for ax in axis.flatten():
            ax.set_axis_off()
            ax.set_xticklabels([])
            ax.set_yticklabels([])
            ax.set_xlim(-window, window)
            ax.set_ylim(-window, window)
            for spine in ax.spines.values():
                spine.set_visible(False)
            ax.axis("off")
            try:
                # Only axes exposing the z-axis methods (3D) get past here.
                ax.set_zlim(-window, window)
                ax.set_zticklabels([])
            except AttributeError:
                pass

        canvas.draw()  # draw the canvas, cache the renderer

        # `tostring_rgb` was deprecated in Matplotlib 3.8 and later removed,
        # so the pixels are read from the RGBA buffer and alpha is dropped.
        # The copy makes the frame contiguous and independent of the canvas.
        rgba = np.asarray(canvas.buffer_rgba())
        data = np.ascontiguousarray(rgba[:, :, :3])
    finally:
        # Always release the figure, even when drawing fails, so that
        # figures do not pile up across frames.
        plt.close(fig)
        plt.close()

    return data


def rotate(
    func: Callable,
    points: Union[np.ndarray, List[np.ndarray]],
    path: str,
    *args,
    fps: int = 24,
    duration: int = 1,
    verbose: bool = False,
    **kwargs,
):
    """Create rotating GIF or video of given image.

    Parameters
    -----------------------
    func: Callable
        function returning the figure and axes.
    points: Union[np.ndarray, List[np.ndarray]]
        The 3D or 4D array to rotate or roto-translate.
    path: str
        path where to save the animation. The format is detected from
        the extension: `gif`, `webm`, `mp4` and `avi` are supported.
    *args
        positional arguments to be passed to the `func` callable.
    fps: int = 24
        number of FPS to create.
    duration: int = 1
        Duration of the rotation in seconds.
    verbose: bool = False
        whether to be verbose about frame creation.
    **kwargs
        keyword argument to be passed to the `func` callable

    Raises
    -----------------------
    ValueError
        If the provided points cloud is None.
    ValueError
        If the provided points cloud is not a numpy array.
    ValueError
        If the format detected from the path extension is not supported.
    ValueError
        If the target file has not been created.
    """
    if not isinstance(points, list):
        points = [points]

    for i, points_cloud in enumerate(points):
        if points_cloud is None:
            raise ValueError(f"The provided points cloud at index {i} is None!")
        if not isinstance(points_cloud, np.ndarray):
            raise ValueError(
                f"The provided points cloud at index {i} is not a numpy array! "
                f"Instead it is a {type(points_cloud)}!"
            )

    video_encodings = {"mp4": "MP4V", "avi": "FMP4", "webm": "vp80"}
    extension = path.split(".")[-1]
    is_gif = path.endswith(".gif")

    if not is_gif and extension not in video_encodings:
        raise ValueError(
            "The provided format, as detected from the provided "
            "path extension, is not supported! "
            f"The path you have provided is `{path}`."
        )

    scaled_points = [
        MinMaxScaler(feature_range=(-1, 1)).fit_transform(matrix) for matrix in points
    ]

    # The file system is only touched once the input has been validated,
    # so that an invalid call never deletes a pre-existing file.
    directory = os.path.dirname(path)
    if directory:
        os.makedirs(directory, exist_ok=True)

    if os.path.exists(path):
        os.remove(path)

    total_frames = duration * fps

    gif_writer = None
    video_writer = None
    fourcc = None

    if is_gif:
        gif_writer = imageio.get_writer(path, mode="I", fps=fps)
    else:
        fourcc = cv2.VideoWriter_fourcc(  # pylint: disable=no-member
            *video_encodings[extension]
        )

    try:
        for frame in trange(
            total_frames,
            desc="Rendering",
            disable=not verbose,
            dynamic_ncols=True,
            leave=False,
        ):
            rendered_frame = render_frame(
                func=func,
                points=scaled_points,
                theta=2 * np.pi * frame / total_frames,
                args=args,
                **kwargs,
            )

            if is_gif:
                gif_writer.append_data(rendered_frame)
                continue

            # The video writer needs the frame size, which is only known
            # after the first frame has been rendered.
            if video_writer is None:
                height, width, _ = rendered_frame.shape
                video_writer = cv2.VideoWriter(  # pylint: disable=no-member
                    path, fourcc, fps, (width, height)
                )

            # Frames are rendered as RGB while OpenCV expects BGR, so the
            # channel order is reversed before writing.
            video_writer.write(np.ascontiguousarray(rendered_frame[:, :, ::-1]))
    finally:
        # Writers are always closed so the file is flushed to disk before
        # any post-processing and no handle leaks when rendering fails.
        if gif_writer is not None:
            gif_writer.close()
        if video_writer is not None:
            video_writer.release()

    if is_gif:
        try:
            from pygifsicle import optimize  # pylint: disable=import-outside-toplevel

            optimize(path)
        except ImportError:
            warnings.warn(
                "The `pygifsicle` package is not installed. "
                "It is not possible to optimize the GIF "
                "file size, which might be very large. "
                "Considering installing it with `pip install pygifsicle`, "
                "which will require `gifsicle` to be installed "
                "in your system."
            )

    if not os.path.exists(path):
        raise ValueError(
            (
                f"The expected target path file `{path}` was "
                "not created. Tipically this is caused by some "
                "errors in the encoding of the file that has "
                "been chosen. Please take a look at the log that "
                "has be printed in either the console or the jupyter "
                "kernel."
            )
        )