alice4space/qutree

View on GitHub
qutree/__init__.py

Summary

Maintainability
A
3 hrs
Test Coverage
"""Plot sets of multiqubit quantum pure states as a binary tree of Bloch spheres."""

# Package metadata (version string, author name and contact address).
__version__ = "0.1.1"
__author__ = "Alice Barthe"
__email__ = "alice.barthe@cern.ch"

from pathlib import Path
from typing import Optional, Tuple, Union

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import qutip
from matplotlib import cm


def nphi_psi(psi: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    r"""Split complex amplitudes into their modulus and their phase.

    Args:
        psi: array of complex values

    Returns:
        r: array with the modulus of every entry of ``psi``
        phi: array with the phase of every entry, wrapped into :math:`[0, 2\pi]`
    """
    full_turn = 2 * np.pi
    modulus = np.abs(psi)
    # np.angle yields values in (-pi, pi]; the modulo shifts the negative ones up.
    phase = np.mod(np.angle(psi), full_turn)

    return modulus, phase


def nthe_n0n1(n0: np.ndarray, n1: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    r"""Convert a pair of amplitudes into a norm and a Bloch polar angle.

    Compute :math:`n` and :math:`\theta` such that
    :math:`[n0,n1] = n[cos(\theta/2),sin(\theta/2)]`; the returned angle is twice
    the planar polar angle of the point :math:`(n0, n1)`.

    Args:
        n0: array of absolute values
        n1: array of absolute values

    Returns:
        n: array of norm
        the: array of angles, in :math:`[0, \pi]` when both inputs are non-negative
    """
    squared_norm = np.square(n0) + np.square(n1)
    planar_angle = np.arctan2(n1, n0)

    return np.sqrt(squared_norm), 2 * planar_angle


def phimp_phi01(phi0, phi1):
    r"""Turn the phases of the two subspaces into a local and a global phase.

    The global phase is the phase of subspace 0. The local phase is
    ``phi1 - phi0``, shifted up by :math:`2\pi` wherever ``phi1 < phi0`` so that
    it is never negative.

    Args:
        phi0: array of subspace 0 phases
        phi1: array of subspace 1 phases

    Returns:
        phil: array of local phases
        phig: array of global phases
    """
    # Boolean mask times 2*pi: adds a full turn only where the difference is negative.
    wrap_around = (phi1 < phi0) * 2 * np.pi
    local_phase = phi1 - phi0 + wrap_around
    return local_phase, phi0


def thephi_to_xyz(the: np.ndarray, phi: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    r"""Map spherical angles to cartesian coordinates on the unit sphere.

    Args:
        the: array of polar angles :math:`\theta` in :math:`[0,\pi]`
        phi: array of azimuthal angles :math:`\phi` in :math:`[0,2 \pi]`

    Returns:
        x: array of the x axis coordinate
        y: array of the y axis coordinate
        z: array of the z axis coordinate
    """
    # sin(theta) is the radius of the circle of latitude, shared by x and y.
    ring_radius = np.sin(the)

    return ring_radius * np.cos(phi), ring_radius * np.sin(phi), np.cos(the)


def fun_recursive(
    psi: np.ndarray,
    coord: str = "",
    tree_vals: Optional[list] = None,
    tree_idxs: Optional[list] = None,
    tol: Optional[float] = 1e-4,
) -> Tuple[np.ndarray, np.ndarray, list, list]:
    r"""Recursively split the states in halves and record one Bloch sphere per node.

    The first axis of ``psi`` is cut in two halves (subspace 0 and subspace 1).
    Each half is processed recursively until only two rows remain; every level
    then combines the amplitudes and phases of its two halves into the
    :math:`\theta` and local phase :math:`\phi` of one Bloch sphere, which is
    appended to the registers together with its coordinate.

    Args:
        psi: array representing the state in the current subspace :math:`(2^N,M)`
        coord: string representing the current coordinate in the binary tree
        tree_vals: register storing the :math:`\theta` and :math:`\phi` of the Bloch
            sphere at each coordinate; a new list is created when omitted
        tree_idxs: register storing the coordinates; a new list is created when
            omitted
        tol: tolerance to delete a subspace. When not ``None``, the amplitude of
            each half is its norm, and for every sample whose half-norm is below
            ``tol`` that half is replaced by the other half before recursing.
            When ``None``, the amplitudes returned by the recursive calls are used
            and no replacement takes place.

    Returns:
        phip: the global phases of the local subspace
        n: the amplitude of the local subspace
        tree_vals: register storing the :math:`\theta` and :math:`\phi` of the Bloch sphere at each coordinate
        tree_idxs: register storing the coordinates
    """
    # Fresh registers per top-level call: a literal ``[]`` default would be shared
    # by every call and keep accumulating nodes from earlier invocations.
    if tree_vals is None:
        tree_vals = []
    if tree_idxs is None:
        tree_idxs = []

    N = psi.shape[0]

    if N == 2:
        # Leaf: the two rows are directly the amplitudes of the two subspaces.
        ns, phis = nphi_psi(psi)
        n0 = ns[0, :]
        n1 = ns[1, :]
        phi0 = phis[0, :]
        phi1 = phis[1, :]

    else:
        # Copies, because the tolerance branch below overwrites columns in place.
        psi0 = psi[: N // 2, :].copy()
        psi1 = psi[N // 2 :, :].copy()

        if tol is not None:
            # Amplitudes are taken from the halves *before* any replacement.
            n0 = np.sqrt(np.sum(np.abs(psi0) ** 2, axis=0))
            n1 = np.sqrt(np.sum(np.abs(psi1) ** 2, axis=0))

            i0 = np.where(n0 < tol)
            i1 = np.where(n1 < tol)

            # Samples with a negligible half get that half replaced by the other one.
            psi0[:, i0] = psi1[:, i0]
            psi1[:, i1] = psi0[:, i1]

            # The amplitudes returned by the recursion are discarded in this branch.
            phi0, _, tree_vals, tree_idxs = fun_recursive(
                psi0, coord + "0", tree_vals, tree_idxs, tol
            )
            phi1, _, tree_vals, tree_idxs = fun_recursive(
                psi1, coord + "1", tree_vals, tree_idxs, tol
            )

        else:
            phi0, n0, tree_vals, tree_idxs = fun_recursive(
                psi0, coord + "0", tree_vals, tree_idxs, tol
            )
            phi1, n1, tree_vals, tree_idxs = fun_recursive(
                psi1, coord + "1", tree_vals, tree_idxs, tol
            )

    n, the = nthe_n0n1(n0, n1)
    # phimp_phi01 returns (local phase, global phase), in that order.
    phim, phip = phimp_phi01(phi0, phi1)

    # Children are appended before their parent, so the root (coord "") comes last.
    tree_vals.append(np.stack([the, phim]))
    tree_idxs.append(coord)

    return phip, n, tree_vals, tree_idxs


def bloch_points(
    xp: np.ndarray,
    yp: np.ndarray,
    zp: np.ndarray,
    colors,
    ax: Optional[mpl.axes.Axes] = None,
    azim: float = -40.0,
    elev: float = 30.0,
) -> None:
    """Draw a set of points on a Bloch sphere.

    Args:
        xp: array of the x axis coordinate
        yp: array of the y axis coordinate
        zp: array of the z axis coordinate
        colors: array of colors for each point
        ax: ax where the Bloch Sphere will be plotted, if None one is created
        azim: azimuth of the view, only used when ``ax`` is None
        elev: elevation of the view, only used when ``ax`` is None
    """
    target_ax = ax
    if target_ax is None:
        # No axes supplied: build a standalone figure holding a single 3D subplot.
        new_fig = plt.figure(constrained_layout=True)
        target_ax = new_fig.add_subplot(1, 1, 1, projection="3d", azim=azim, elev=elev)

    sphere = qutip.Bloch(axes=target_ax)
    sphere.make_sphere()
    # "m" is qutip's point-drawing mode; presumably the multi-colour mode that
    # assigns one colour per point - TODO confirm against the qutip docs.
    sphere.add_points([xp, yp, zp], "m")
    sphere.point_color = list(colors)
    sphere.render()


class BBT:
    """Bloch Binary Tree of a set of multi-qubit pure states.

    ``add_data`` computes the Bloch sphere angles of every tree node and the
    color of every sample; ``plot_tree`` draws one Bloch sphere per node.
    """

    def __init__(self, num_qubits: int) -> None:
        """Class containing the data to plot the Bloch Binary Tree.

        Args:
            num_qubits: number of qubits
        """
        self.num_qubits = num_qubits

    def add_data(
        self, states: np.ndarray, colors: Optional[np.ndarray] = None, cmap: str = "jet"
    ) -> None:
        """Add data of states to the tree computing the Bloch spheres parameters.

        Sets ``num_samples``, ``tree_vals``, ``tree_idxs`` and ``samples_colors``.

        Args:
            states: complex array of shape (2**num_qubits, M) holding M samples
            colors: array with values between 0 and 1: either a 2D array of
                already generated RGBA colors (validated as 4 rows by M columns),
                or a 1D array of M values that are mapped through the colormap.
                By default the colors follow the sample indexes.
            cmap: string name of a matplotlib colormap (default : 'jet')

        Raises:
            AssertionError: if the shape of ``states`` or the shape/range of
                ``colors`` is not the expected one.
            ValueError: if ``colors`` is neither a 1D nor a 2D array.
        """
        assert (
            states.shape[0] == 2**self.num_qubits
        ), "The data doesn't fit the expected number of qubits"
        assert len(states.shape) == 2, "the data must be a 2D tensor"

        self.num_samples = states.shape[1]

        # Explicit registers, filled in place by the recursion (no tolerance pruning).
        tree_vals: list = []
        tree_idxs: list = []

        fun_recursive(states, tree_vals=tree_vals, tree_idxs=tree_idxs, tol=None)

        self.tree_vals = tree_vals
        self.tree_idxs = tree_idxs

        if colors is None:
            # Default: one color per sample, evenly spread over the colormap.
            # plt.get_cmap is used because cm.get_cmap no longer exists in recent
            # Matplotlib releases.
            cm_custom = plt.get_cmap(cmap, self.num_samples)
            self.samples_colors = cm_custom(np.linspace(0, 1, self.num_samples))
            return

        assert (colors.max() <= 1) and (
            colors.min() >= 0
        ), "the colors values must be between 0 and 1"

        if len(colors.shape) == 2:
            # case in which the color is a 2D tensor, colors matrix directly
            # NOTE(review): this branch stores a (4, M) array while the colormap
            # branches store what looks like an (M, 4) array - confirm which
            # layout the plotting code expects.
            assert (
                colors.shape[0] == 4
            ), "the colors matrix must be a 2D tensor of 4 (RGBA) by the number of samples"
            assert (
                colors.shape[1] == self.num_samples
            ), "the colors matrix must be a 2D tensor of 4 (RGBA) by the number of samples"
            self.samples_colors = colors
        elif len(colors.shape) == 1:
            # case in which the color is a 1D tensor, create colors matrix with cmap
            assert (
                colors.shape[0] == self.num_samples
            ), "the colors must have the same length as the number of samples"
            cm_custom = plt.get_cmap(cmap, 256)
            self.samples_colors = cm_custom(colors)
        else:
            # Previously this case silently left samples_colors unset (or stale).
            raise ValueError("the colors must be a 1D or a 2D tensor")

    def plot_tree(
        self,
        azim: float = -40.0,
        elev: float = 30.0,
        size_sphere: float = 2.0,
        dst_file: Union[Path, str] = "",
    ) -> None:
        """Display the binary tree.

        One Bloch sphere is drawn per node stored by ``add_data``: the root sits
        in the top row and the deepest nodes in the bottom row. The figure is
        saved once, after all the spheres have been drawn.

        Args:
            azim: viewing angle -> azimuth (default:-40)
            elev: viewing angle -> elevation (default:30)
            size_sphere: size of each single sphere.
            dst_file: name of the destination file, leave blank to not save the plot.
        """
        # Width and height of one sphere cell, in figure-relative coordinates.
        dw = 1 / (2 ** (self.num_qubits - 1))
        dh = 1 / self.num_qubits
        fig = plt.figure(figsize=(size_sphere / dw, size_sphere / dh))

        # Full-figure axes with its frame and ticks switched off.
        background_ax = fig.add_axes([0, 0, 1, 1])
        background_ax.axis("off")

        for i, d in enumerate(self.tree_idxs):
            # Shorter coordinates are closer to the root, hence higher in the figure.
            nlevel = self.num_qubits - len(d)
            h = (nlevel - 1) * dh
            if len(d) == 0:
                # Root: horizontally centred.
                w = 0.5 - dw / 2
            else:
                nbit = int(d, 2)
                if nlevel == 1:
                    # Bottom row: cells are laid side by side.
                    w = dw * nbit
                else:
                    # Intermediate rows: centred within their share of the width.
                    w = (1 + 2 * nbit) / 2 ** (len(d) + 1) - dw / 2
            ax = fig.add_axes([w, h, dw, dh], projection="3d", azim=azim, elev=elev)
            ax.set_facecolor((0, 0, 0, 0))  # remove background
            the = self.tree_vals[i][0, :]
            phi = self.tree_vals[i][1, :]
            x, y, z = thephi_to_xyz(the, phi)
            bloch_points(x, y, z, colors=self.samples_colors, ax=ax)

        # Save once the whole tree is drawn, not after every single sphere.
        if dst_file != "":
            fig.savefig(dst_file, transparent=True)
            print(f"file successfully saved to {dst_file}")