benvial/gyptis

View on GitHub
src/gyptis/plot.py

Summary

Maintainability
B
4 hrs
Test Coverage
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Author: Benjamin Vial
# This file is part of gyptis
# Version: 1.0.2
# License: MIT
# See the documentation at gyptis.gitlab.io


import copy

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
from dolfin.common.plotting import mesh2triang
from matplotlib.colors import LinearSegmentedColormap
from matplotlib.transforms import Affine2D
from matplotlib.tri import Triangulation

from . import dolfin
from .complex import *
from .utils.helpers import project_iterative

colors = dict(
    red=(210 / 255, 95 / 255, 95 / 255), green=(69 / 255, 149 / 255, 125 / 255)
)

if "gyptis" not in plt.colormaps():
    matplotlib.colormaps.register(
        cmap=LinearSegmentedColormap.from_list(
            "gyptis", [colors["green"], (1, 1, 1), colors["red"]], N=100
        )
    )
if "gyptis_r" not in plt.colormaps():
    matplotlib.colormaps.register(
        cmap=LinearSegmentedColormap.from_list(
            "gyptis_r", [colors["red"], (1, 1, 1), colors["green"]], N=100
        )
    )

if "gyptis_white" not in plt.colormaps():
    matplotlib.colormaps.register(
        cmap=LinearSegmentedColormap.from_list(
            "gyptis_white", [(1, 1, 1), (1, 1, 1), (1, 1, 1)], N=100
        )
    )

if "gyptis_black" not in plt.colormaps():
    matplotlib.colormaps.register(
        cmap=LinearSegmentedColormap.from_list(
            "gyptis_black", [(0, 0, 0), (0, 0, 0), (0, 0, 0)], N=100
        )
    )


def get_boundaries(markers, domain=None, shift=(0, 0)):
    data = markers.array()
    triang = mesh2triang(markers.mesh())
    ids = np.unique(data) if domain is None else [domain]
    triangulations = []
    for id in ids:
        triang_ = copy.deepcopy(triang)
        triang_.set_mask(data != id)
        triangulations.append(triang_)

    sub_bnds = []
    for triangtest in triangulations:
        maskedTris = triangtest.get_masked_triangles()
        verts = np.stack((triangtest.x[maskedTris], triangtest.y[maskedTris]), axis=-1)
        all_vert = np.vstack(verts).T
        sub_triang = Triangulation(*all_vert)

        boundaries = []
        for i in range(len(sub_triang.triangles)):
            boundaries.extend(
                (
                    sub_triang.triangles[i, j],
                    sub_triang.triangles[i, (j + 1) % 3],
                )
                for j in range(3)
                if sub_triang.neighbors[i, j] < 0
            )
        boundaries = np.asarray(boundaries)

        bndpnts = (
            shift[0] + sub_triang.x[boundaries].T,
            shift[1] + sub_triang.y[boundaries].T,
        )

        sub_bnds.append(bndpnts)

    return sub_bnds


def plot_boundaries(markers, domain=None, shift=(0, 0), ax=None, **kwargs):
    sub_bnds = get_boundaries(markers, domain=domain, shift=shift)
    if "c" not in kwargs and "color" not in kwargs:
        kwargs["color"] = "k"
    if ax is None:
        ax = plt.gca()
    lines = []
    for bndpnts in sub_bnds:
        line = ax.plot(*bndpnts, **kwargs)
        lines.append(line)
    return lines


# def plot_subdomains(markers, alpha=0.3):
#     a = dolfin.plot(markers, cmap="binary", alpha=alpha, lw=0.00, edgecolor="face")
#     return a
#     # a.set_edgecolors((0.1, 0.2, 0.5, 0.))


def plot_markers(markers, subdomains, ax=None, geometry=None, colorbar=True, **kwargs):
    if ax is None:
        fig, ax = plt.subplots(1, 1)
    if "cmap" not in kwargs:
        kwargs["cmap"] = "inferno"
    plt.sca(ax)
    n = len(subdomains)
    cmap_disc = plt.cm.get_cmap(kwargs["cmap"])
    bounds = np.array(list(subdomains.values()))
    srt = np.argsort(bounds)
    ids = np.array(list(subdomains.keys()))[srt]
    bounds = bounds[srt]
    bounds = np.hstack([bounds, [bounds[-1] + 1]])

    tt = [(bounds[i + 1] + bounds[i]) / 2 for i in range(n)]

    norm = matplotlib.colors.BoundaryNorm(bounds, cmap_disc.N)
    kwargs["cmap"] = cmap_disc
    p = dolfin.plot(markers, norm=norm, **kwargs)
    if colorbar:
        cbar = plt.colorbar(
            matplotlib.cm.ScalarMappable(norm=norm, cmap=cmap_disc), ticks=tt, ax=ax
        )

        cbar.ax.set_yticklabels(ids)
        cbar.ax.tick_params(length=0)
    else:
        cbar = None
    kwargs.pop("cmap")
    # cbar = plt.colorbar(p,ticks=list(subdomains.keys()),cmap=cmap_disc) if colorbar else None

    if geometry:
        geometry.plot_subdomains(**kwargs)
    return p, cbar


def plot_subdomains(markers, **kwargs):
    return plot_boundaries(markers, **kwargs)


def plot_real(
    fplot,
    ax=None,
    geometry=None,
    proj_space=None,
    colorbar=True,
    orientation="vertical",
    **kwargs,
):
    proj = proj_space is not None
    if "cmap" not in kwargs:
        kwargs["cmap"] = "RdBu_r"
    if ax is None:
        fig, ax = plt.subplots(1, 1)
    if proj:
        fplot = project_iterative(fplot, proj_space)
    plt.sca(ax)
    p = dolfin.plot(fplot, **kwargs)
    cbar = plt.colorbar(p, orientation=orientation) if colorbar else None
    kwargs.pop("cmap")
    if geometry:
        geometry.plot_subdomains(**kwargs)
    return p, cbar


def plot_complex(
    fplot,
    ax=None,
    geometry=None,
    proj_space=None,
    colorbar=True,
    orientation="vertical",
    **kwargs,
):
    if ax is None:
        fig, ax = plt.subplots(1, 2)
    P, C = [], []
    for a, f in zip(ax, [fplot.real, fplot.imag]):
        p, cbar = plot_real(
            f,
            ax=a,
            geometry=geometry,
            proj_space=proj_space,
            colorbar=colorbar,
            orientation=orientation,
            **kwargs,
        )
        P.append(p)
        C.append(cbar)
    return P, C


def plot(fplot, ax=None, geometry=None, proj_space=None, colorbar=True, **kwargs):
    returnfunc = plot_complex if iscomplex(fplot) else plot_real
    return returnfunc(fplot, ax, geometry, proj_space, colorbar, **kwargs)


def pause(interval):
    backend = plt.rcParams["backend"]
    if backend in matplotlib.rcsetup.interactive_bk:
        figManager = matplotlib._pylab_helpers.Gcf.get_active()
        if figManager is not None:
            canvas = figManager.canvas
            if canvas.figure.stale:
                canvas.draw()
            canvas.start_event_loop(interval)
            return


def check_plot_type(plottype, f):
    if plottype == "real":
        fplot = f.real
    elif plottype == "imag":
        fplot = f.imag
    elif plottype == "module":
        fplot = f.module
    elif plottype == "phase":
        fplot = f.phase
    else:
        raise (
            ValueError(
                f"wrong plot type {plottype}, choose between real, imag, module or phase"
            )
        )
    return fplot