src/qutip_qip/circuit/mat_renderer.py
"""
Module for rendering a quantum circuit using matplotlib library.
"""
from typing import Union, Optional, List, Dict
from dataclasses import dataclass
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.axes import Axes
from matplotlib.patches import (
FancyBboxPatch,
Circle,
Arc,
FancyArrow,
)
from .base_renderer import BaseRenderer, StyleConfig
from ..operations import Gate, Measurement
from ..circuit import QubitCircuit
# Only the renderer class is part of this module's public API.
__all__ = ["MatRenderer"]
class MatRenderer(BaseRenderer):
    """
    Class to render a quantum circuit using matplotlib.

    Parameters
    ----------
    qc : QubitCircuit Object
        The quantum circuit to be rendered.
    ax : Axes Object, optional
        The axes object to plot the circuit. The default is None, in which
        case a new figure with a single subplot is created.
    **style
        Additional style arguments to be passed to the `StyleConfig` dataclass.
    """

    def __init__(
        self,
        qc: QubitCircuit,
        ax: Optional[Axes] = None,
        **style,
    ) -> None:
        # user defined style (``**style`` always arrives as a dict, possibly empty)
        self.style = StyleConfig(**style)
        super().__init__(self.style)

        self._qc = qc
        self._ax = ax
        self._qwires = qc.N
        self._cwires = qc.num_cbits

        # default geometry values, used directly as axes data coordinates
        self._cwire_sep = 0.02  # half the gap between the two lines of a classical wire
        self._min_gate_height = 0.2
        self._min_gate_width = 0.2
        self._default_layers = 2
        self._arrow_length = 0.06
        self._connector_r = 0.01
        self._target_node_r = 0.12
        self._control_node_r = 0.05
        self._display_layer_len = 0
        self._start_pad = 0.1
        # per-qubit list of horizontal layer widths; the first entry is the
        # padding before the first gate, and the sum is the wire's length
        self._layer_list = {i: [self._start_pad] for i in range(self._qwires)}

        # fig config
        self._zorder = {
            "wire": 1,
            "wire_label": 1,
            "gate": 2,
            "node": 2,
            "bridge": 2,
            "connector": 3,
            "gate_label": 3,
            "node_label": 3,
        }

        if self.style.fig_height is None:
            self.style.fig_height = (
                self._qwires + self._cwires
            ) * self.style.wire_sep

        if self._ax is None:
            self.fig = plt.figure()
            self._ax = self.fig.add_subplot(111)
            self.fig.set_dpi(self.style.dpi)
        else:
            self.fig = self._ax.get_figure()

    # ------------------------------------------------------------------
    # geometry helpers
    # ------------------------------------------------------------------

    def _qwire_y(self, pos: int) -> float:
        """
        Return the y coordinate of quantum wire ``pos``.

        Classical wires occupy ``y = i * wire_sep`` for ``i < num_cbits``;
        quantum wires are offset past them by ``num_cbits`` slots.
        """
        return (pos + self._cwires) * self.style.wire_sep

    def _node_x(self, xskip: float) -> float:
        """
        Return the x coordinate of a node, swap mark or bridge whose layer
        starts at horizontal offset ``xskip``.
        """
        return xskip + self.style.gate_margin + self.style.gate_pad

    def _circuit_length(self) -> float:
        """
        Return the length of the longest quantum wire, i.e. the largest sum
        of recorded layer widths (0 when there are no quantum wires).
        """
        return max(
            (sum(self._layer_list[i]) for i in range(self._qwires)),
            default=0,
        )

    def _get_text_width(
        self,
        text: str,
        fontsize: float,
        fontweight: Union[float, str],
        fontfamily: str,
        fontstyle: str,
    ) -> float:
        """
        Get the width of the text to be plotted.

        The text is temporarily added to the axes, measured with the
        figure's renderer and removed again.

        Parameters
        ----------
        text : str
            The text to be plotted.
        fontsize : float
            The fontsize of the text.
        fontweight : str or float
            The fontweight of the text.
        fontfamily : str
            The fontfamily of the text.
        fontstyle : str
            The fontstyle of the text.

        Returns
        -------
        float
            The width of the text in inches.
        """
        figure = self._ax.figure
        text_obj = plt.Text(
            0,
            0,
            text,
            fontsize=fontsize,
            fontweight=fontweight,
            fontfamily=fontfamily,
            fontstyle=fontstyle,
        )
        self._ax.add_artist(text_obj)
        try:
            bbox = text_obj.get_window_extent(
                renderer=figure.canvas.get_renderer()
            )
        finally:
            # the text is only a measuring probe; never leave it on the axes
            text_obj.remove()
        # the extent is in display pixels at the figure's dpi, which differs
        # from style.dpi when the caller supplied their own axes
        return bbox.width / figure.dpi

    # ------------------------------------------------------------------
    # wires and labels
    # ------------------------------------------------------------------

    def _add_wire(self) -> None:
        """
        Add the wires to the circuit.

        Each quantum wire is a single line; each classical wire is drawn as
        two parallel lines ``_cwire_sep`` above and below its y position.
        All wires extend from x = 0 to the longest wire plus the end
        extension.
        """
        max_len = (
            self._circuit_length()
            + self.style.end_wire_ext * self.style.layer_sep
        )

        def add_hline(y: float) -> None:
            # a horizontal wire segment spanning the whole circuit
            self._ax.add_line(
                plt.Line2D(
                    [0, max_len],
                    [y, y],
                    lw=1,
                    color=self.style.wire_color,
                    zorder=self._zorder["wire"],
                )
            )

        for i in range(self._qwires):
            add_hline(self._qwire_y(i))

        for i in range(self._cwires):
            y = i * self.style.wire_sep
            add_hline(y + self._cwire_sep)
            add_hline(y - self._cwire_sep)

    def _add_wire_labels(self) -> None:
        """
        Add the wire labels to the circuit.

        When ``style.wire_label`` is None it is filled with default labels:
        classical wires first (``c_i``), then quantum wires (``q_i``). This
        also records ``max_label_width``, which ``_fig_config`` uses for the
        left x-limit.
        """
        if self.style.wire_label is None:
            self.style.wire_label = [
                f"$c_{{{i}}}$" for i in range(self._cwires)
            ] + [f"$q_{{{i}}}$" for i in range(self._qwires)]

        self.max_label_width = max(
            (
                self._get_text_width(
                    label,
                    self.style.fontsize,
                    "normal",
                    "monospace",
                    "normal",
                )
                for label in self.style.wire_label
            ),
            default=0,
        )

        for i, label in enumerate(self.style.wire_label):
            wire_label = plt.Text(
                -self.style.label_pad,
                i * self.style.wire_sep,
                label,
                fontsize=self.style.fontsize,
                fontweight="normal",
                fontfamily="monospace",
                fontstyle="normal",
                verticalalignment="center",
                horizontalalignment="right",
                zorder=self._zorder["wire_label"],
                color=self.style.wire_color,
            )
            self._ax.add_artist(wire_label)

    # ------------------------------------------------------------------
    # primitive marks
    # ------------------------------------------------------------------

    def _draw_control_node(self, pos: int, xskip: float, color: str) -> None:
        """
        Draw the control node for the multi-qubit gate.

        Parameters
        ----------
        pos : int
            The position of the control node, in terms of the wire number.
        xskip : float
            The horizontal value for getting to requested layer.
        color : str
            The color of the control node. HEX code or color name supported by matplotlib are valid.
        """
        control_node = Circle(
            (self._node_x(xskip), self._qwire_y(pos)),
            self._control_node_r,
            color=color,
            zorder=self._zorder["node"],
        )
        self._ax.add_artist(control_node)

    def _draw_target_node(self, pos: int, xskip: float, color: str) -> None:
        """
        Draw the target node (a filled circle with a "+" on top) for the
        multi-qubit gate.

        Parameters
        ----------
        pos : int
            The position of the target node, in terms of the wire number.
        xskip : float
            The horizontal value for getting to requested layer.
        color : str
            The color of the target node. HEX code or color name supported by matplotlib are valid.
            The "+" is drawn in ``style.color``.
        """
        x = self._node_x(xskip)
        y = self._qwire_y(pos)
        half = self._target_node_r / 2

        target_node = Circle(
            (x, y),
            self._target_node_r,
            color=color,
            zorder=self._zorder["node"],
        )
        vertical_line = plt.Line2D(
            (x, x),
            (y - half, y + half),
            lw=1.5,
            color=self.style.color,
            zorder=self._zorder["node_label"],
        )
        horizontal_line = plt.Line2D(
            (x - half, x + half),
            (y, y),
            lw=1.5,
            color=self.style.color,
            zorder=self._zorder["node_label"],
        )

        self._ax.add_artist(target_node)
        self._ax.add_line(vertical_line)
        self._ax.add_line(horizontal_line)

    def _draw_qbridge(
        self, pos1: int, pos2: int, xskip: float, color: str
    ) -> None:
        """
        Draw the bridge between the control and target nodes for the multi-qubit gate.

        Parameters
        ----------
        pos1 : int
            The position of the first node for the bridge, in terms of the wire number.
        pos2 : int
            The position of the second node for the bridge, in terms of the wire number.
        xskip : float
            The horizontal value for getting to requested layer.
        color : str
            The color of the bridge. HEX code or color name supported by matplotlib are valid.
        """
        x = self._node_x(xskip)
        bridge = plt.Line2D(
            [x, x],
            [self._qwire_y(pos1), self._qwire_y(pos2)],
            color=color,
            zorder=self._zorder["bridge"],
        )
        self._ax.add_line(bridge)

    def _draw_cbridge(
        self, c_pos: int, q_pos: int, xskip: float, color: str
    ) -> None:
        """
        Draw the double-line bridge, ending in an arrow head, from a quantum
        wire down to a classical wire for the measurement gate.

        Parameters
        ----------
        c_pos : int
            The position of the classical wire, in terms of the wire number.
        q_pos : int
            The position of the quantum wire, in terms of the wire number.
        xskip : float
            The horizontal value for getting to requested layer.
        color : str
            The color of the bridge.
        """
        x_mid = xskip + self.style.gate_margin + self._min_gate_width / 2
        y_quantum = self._qwire_y(q_pos)
        # the lines stop short of the classical wire to leave room for the head
        y_classical = c_pos * self.style.wire_sep + self._arrow_length

        cbridge_l = plt.Line2D(
            (x_mid - self._cwire_sep, x_mid - self._cwire_sep),
            (y_classical, y_quantum),
            color=color,
            zorder=self._zorder["bridge"],
        )
        cbridge_r = plt.Line2D(
            (x_mid + self._cwire_sep, x_mid + self._cwire_sep),
            (y_classical, y_quantum),
            color=color,
            zorder=self._zorder["bridge"],
        )
        end_arrow = FancyArrow(
            x_mid,
            y_classical,
            0,
            -self._cwire_sep * 3,
            width=0,
            head_width=self._cwire_sep * 5,
            head_length=self._cwire_sep * 3,
            length_includes_head=True,
            color=color,
            zorder=self._zorder["bridge"],
        )

        self._ax.add_line(cbridge_l)
        self._ax.add_line(cbridge_r)
        self._ax.add_artist(end_arrow)

    def _draw_swap_mark(self, pos: int, xskip: float, color: str) -> None:
        """
        Draw the swap mark (an "X") for the SWAP gate.

        Parameters
        ----------
        pos : int
            The position of the swap mark, in terms of the wire number.
        xskip : float
            The horizontal value for getting to requested layer.
        color : str
            The color of the swap mark.
        """
        x = self._node_x(xskip)
        y = self._qwire_y(pos)
        dx = self._min_gate_width / 3
        dy = self._min_gate_height / 2

        dia_left = plt.Line2D(
            [x + dx, x - dx],
            [y + dy, y - dy],
            color=color,
            linewidth=2,
            zorder=self._zorder["gate"],
        )
        dia_right = plt.Line2D(
            [x - dx, x + dx],
            [y + dy, y - dy],
            color=color,
            linewidth=2,
            zorder=self._zorder["gate"],
        )

        self._ax.add_line(dia_left)
        self._ax.add_line(dia_right)

    # ------------------------------------------------------------------
    # argument formatting
    # ------------------------------------------------------------------

    @staticmethod
    def _format_pi_multiple(num: int) -> str:
        r"""
        Return the mathtext for ``num * pi``: ``0``, ``\pi``, ``-\pi`` or
        ``<num>\pi``.
        """
        if num == 0:
            return "0"
        if num == 1:
            return "\\pi"
        if num == -1:
            return "-\\pi"
        return f"{num}\\pi"

    def to_pi_fraction(self, value: float, tolerance: float = 0.01) -> str:
        r"""
        Convert a value to a string fraction of pi.

        Integer multiples of pi are tried first, then the denominators
        2, 3, 4, 6, 8 and 12 in increasing order, so the first match is
        already in lowest terms. Values that match none of these are
        rounded to two decimals.

        Parameters
        ----------
        value : float
            The value to be converted.
        tolerance : float, optional
            The tolerance for the fraction. The default is 0.01.

        Returns
        -------
        str
            The value in terms of pi, wrapped in square brackets, e.g.
            ``[\pi/2]``, ``[-\pi]``, ``[3\pi/4]``, ``[0]`` or ``[0.5]``.
        """
        pi_value = value / np.pi

        # int() guards against round() returning a float for numpy scalars
        whole = int(round(pi_value))
        if abs(pi_value - whole) < tolerance:
            return f"[{self._format_pi_multiple(whole)}]"

        for denom in (2, 3, 4, 6, 8, 12):
            fraction_value = pi_value * denom
            num = int(round(fraction_value))
            if abs(fraction_value - num) < tolerance:
                return f"[{self._format_pi_multiple(num)}/{denom}]"

        return f"[{round(value, 2)}]"

    # ------------------------------------------------------------------
    # gates
    # ------------------------------------------------------------------

    def _draw_singleq_gate(self, gate: Gate, layer: int) -> None:
        """
        Draw the single qubit gate.

        Uses the per-gate style attributes (``text``, ``color``, ``font*``,
        ``showarg``) prepared by ``canvas_plot``.

        Parameters
        ----------
        gate : Gate Object
            The gate to be plotted.
        layer : int
            The layer the gate is acting on.
        """
        gate_wire = gate.targets[0]
        if gate.arg_value is not None and self.showarg is True:
            pi_frac = self.to_pi_fraction(gate.arg_value)
            text = f"${{{self.text}}}_{{{pi_frac}}}$"
        else:
            text = self.text

        text_width = self._get_text_width(
            text,
            self.fontsize,
            self.fontweight,
            self.fontfamily,
            self.fontstyle,
        )
        gate_width = max(
            text_width + self.style.gate_pad * 2, self._min_gate_width
        )
        box_left = self._get_xskip([gate_wire], layer) + self.style.gate_margin
        wire_y = self._qwire_y(gate_wire)

        gate_text = plt.Text(
            box_left + gate_width / 2,
            wire_y,
            text,
            color=self.fontcolor,
            fontsize=self.fontsize,
            fontweight=self.fontweight,
            fontfamily=self.fontfamily,
            fontstyle=self.fontstyle,
            verticalalignment="center",
            horizontalalignment="center",
            zorder=self._zorder["gate_label"],
        )
        gate_patch = FancyBboxPatch(
            (box_left, wire_y - self._min_gate_height / 2),
            gate_width,
            self._min_gate_height,
            boxstyle=self.style.bulge,
            mutation_scale=0.3,
            facecolor=self.color,
            edgecolor=self.color,
            zorder=self._zorder["gate"],
        )

        self._ax.add_artist(gate_text)
        self._ax.add_patch(gate_patch)
        self._manage_layers(gate_width, [gate_wire], layer)

    def _draw_multiq_gate(self, gate: Gate, layer: int) -> None:
        """
        Draw the multi-qubit gate.

        CNOT/CX, SWAP and TOFFOLI get dedicated symbols. Every other gate is
        drawn as a labelled box spanning its targets, with control nodes for
        any controls. The layer is reserved on every wire between the lowest
        and highest entry of ``merged_wires`` (set by ``canvas_plot``).

        Parameters
        ----------
        gate : Gate Object
            The gate to be plotted.
        layer : int
            The layer the gate is acting on.
        """
        wire_list = list(
            range(self.merged_wires[0], self.merged_wires[-1] + 1)
        )
        xskip = self._get_xskip(wire_list, layer)

        if gate.name in ("CNOT", "CX"):
            self._draw_control_node(gate.controls[0], xskip, self.color)
            self._draw_target_node(gate.targets[0], xskip, self.color)
            self._draw_qbridge(
                gate.targets[0], gate.controls[0], xskip, self.color
            )
            self._manage_layers(
                2 * self.style.gate_pad + self._target_node_r / 3,
                wire_list,
                layer,
                xskip,
            )

        elif gate.name == "SWAP":
            self._draw_swap_mark(gate.targets[0], xskip, self.color)
            self._draw_swap_mark(gate.targets[1], xskip, self.color)
            self._draw_qbridge(
                gate.targets[0], gate.targets[1], xskip, self.color
            )
            self._manage_layers(
                2 * (self.style.gate_pad + self._min_gate_width / 3),
                wire_list,
                layer,
                xskip,
            )

        elif gate.name == "TOFFOLI":
            self._draw_control_node(gate.controls[0], xskip, self.color)
            self._draw_control_node(gate.controls[1], xskip, self.color)
            self._draw_target_node(gate.targets[0], xskip, self.color)
            self._draw_qbridge(
                gate.targets[0], gate.controls[0], xskip, self.color
            )
            self._draw_qbridge(
                gate.targets[0], gate.controls[1], xskip, self.color
            )
            self._manage_layers(
                2 * self.style.gate_pad + self._target_node_r / 3,
                wire_list,
                layer,
                xskip,
            )

        else:
            self._draw_boxed_multiq_gate(gate, wire_list, layer, xskip)

    def _draw_boxed_multiq_gate(
        self, gate: Gate, wire_list: List[int], layer: int, xskip: float
    ) -> None:
        """
        Draw a generic multi-qubit gate as one labelled box covering its
        targets, with connector dots on each target wire (when there are
        several targets) and control nodes bridged to the first target.

        Parameters
        ----------
        gate : Gate Object
            The gate to be plotted.
        wire_list : list of int
            Every wire the gate occupies, passed on to ``_manage_layers``.
        layer : int
            The layer the gate is acting on.
        xskip : float
            The horizontal value for getting to requested layer.
        """
        adj_targets = [t + self._cwires for t in sorted(gate.targets)]
        text_width = self._get_text_width(
            self.text,
            self.fontsize,
            self.fontweight,
            self.fontfamily,
            self.fontstyle,
        )
        gate_width = max(
            text_width + self.style.gate_pad * 2, self._min_gate_width
        )
        box_left = xskip + self.style.gate_margin

        gate_text = plt.Text(
            box_left + gate_width / 2,
            (adj_targets[0] + adj_targets[-1]) / 2 * self.style.wire_sep,
            self.text,
            color=self.fontcolor,
            fontsize=self.fontsize,
            fontweight=self.fontweight,
            fontfamily=self.fontfamily,
            fontstyle=self.fontstyle,
            verticalalignment="center",
            horizontalalignment="center",
            zorder=self._zorder["gate_label"],
        )
        gate_patch = FancyBboxPatch(
            (
                box_left,
                adj_targets[0] * self.style.wire_sep
                - self._min_gate_height / 2,
            ),
            gate_width,
            self._min_gate_height
            + self.style.wire_sep * (adj_targets[-1] - adj_targets[0]),
            boxstyle=self.style.bulge,
            mutation_scale=0.3,
            facecolor=self.color,
            edgecolor=self.color,
            zorder=self._zorder["gate"],
        )

        # small dots on both box edges mark which wires the box acts on
        if len(gate.targets) > 1:
            for wire in adj_targets:
                y = wire * self.style.wire_sep
                connector_l = Circle(
                    (box_left + self._connector_r, y),
                    self._connector_r,
                    color=self.fontcolor,
                    zorder=self._zorder["connector"],
                )
                connector_r = Circle(
                    (box_left + gate_width - self._connector_r, y),
                    self._connector_r,
                    color=self.fontcolor,
                    zorder=self._zorder["connector"],
                )
                self._ax.add_artist(connector_l)
                self._ax.add_artist(connector_r)

        # add control nodes and qbridges if control qubits are present
        controls = getattr(gate, "controls", None)
        if controls is not None:
            # _draw_control_node/_draw_qbridge add gate_margin + gate_pad to
            # the offset; shift it so the nodes land on the box centre even
            # when the box was widened to _min_gate_width
            node_xskip = xskip + gate_width / 2 - self.style.gate_pad
            for control in controls:
                self._draw_control_node(control, node_xskip, self.color)
                self._draw_qbridge(
                    control,
                    gate.targets[0],
                    node_xskip,
                    self.color,
                )

        self._ax.add_artist(gate_text)
        self._ax.add_patch(gate_patch)
        self._manage_layers(gate_width, wire_list, layer, xskip)

    def _draw_measure(self, c_pos: int, q_pos: int, layer: int) -> None:
        """
        Draw the measurement gate (a meter box) and its bridge to the
        classical wire.

        The layer is reserved on every quantum wire from 0 up to the
        highest entry of ``merged_wires`` (set by ``canvas_plot``).

        Parameters
        ----------
        c_pos : int
            The position of the classical wire.
        q_pos : int
            The position of the quantum wire.
        layer : int
            The layer the gate is acting on.
        """
        wire_list = list(range(0, self.merged_wires[-1] + 1))
        xskip = self._get_xskip(wire_list, layer)
        box_left = xskip + self.style.gate_margin
        box_bottom = self._qwire_y(q_pos) - self._min_gate_height / 2
        meter_x = box_left + self._min_gate_width / 2

        measure_box = FancyBboxPatch(
            (box_left, box_bottom),
            self._min_gate_width,
            self._min_gate_height,
            boxstyle=self.style.bulge,
            mutation_scale=0.3,
            facecolor=self.style.bgcolor,
            edgecolor=self.style.measure_color,
            linewidth=1.25,
            zorder=self._zorder["gate"],
        )
        arc = Arc(
            (meter_x, box_bottom),
            self._min_gate_width * 1.5,
            self._min_gate_height * 1,
            angle=0,
            theta1=0,
            theta2=180,
            color=self.style.measure_color,
            linewidth=1.25,
            zorder=self._zorder["gate_label"],
        )
        arrow = FancyArrow(
            meter_x,
            box_bottom,
            self._min_gate_width * 0.7,
            self._min_gate_height * 0.7,
            length_includes_head=True,
            head_width=0,
            linewidth=1.25,
            color=self.style.measure_color,
            zorder=self._zorder["gate_label"],
        )

        self._draw_cbridge(c_pos, q_pos, xskip, color=self.style.measure_color)
        self._manage_layers(
            self._min_gate_width,
            wire_list,
            layer,
            xskip,
        )
        self._ax.add_patch(measure_box)
        self._ax.add_artist(arc)
        self._ax.add_artist(arrow)

    # ------------------------------------------------------------------
    # public API
    # ------------------------------------------------------------------

    def canvas_plot(self) -> None:
        """
        Plot the quantum circuit.

        Draws the wire labels, then every measurement and gate of the circuit
        in order, then the wires. It then configures the figure, applies a
        tight layout to ``self.fig`` and calls ``plt.show()``.
        """
        self._add_wire_labels()

        for gate in self._qc.gates:
            if isinstance(gate, Measurement):
                self.merged_wires = sorted(gate.targets)
                layer = max(
                    len(self._layer_list[i])
                    for i in range(0, self.merged_wires[-1] + 1)
                )
                self._draw_measure(
                    gate.classical_store, gate.targets[0], layer
                )

            elif isinstance(gate, Gate):
                gate_style = gate.style if gate.style is not None else {}
                self.text = (
                    gate.arg_label if gate.arg_label is not None else gate.name
                )
                self.color = gate_style.get(
                    "color",
                    self.style.theme.get(
                        gate.name, self.style.theme["default_gate"]
                    ),
                )
                self.fontsize = gate_style.get("fontsize", self.style.fontsize)
                self.fontcolor = gate_style.get("fontcolor", self.style.color)
                self.fontweight = gate_style.get("fontweight", "normal")
                self.fontstyle = gate_style.get("fontstyle", "normal")
                self.fontfamily = gate_style.get("fontfamily", "monospace")
                self.showarg = gate_style.get("showarg", False)

                # a missing ``controls`` attribute means "no controls"
                controls = getattr(gate, "controls", None)

                # multi-qubit gate
                if len(gate.targets) > 1 or controls is not None:
                    merged = list(gate.targets)
                    if controls is not None:
                        merged += controls
                    self.merged_wires = sorted(merged)
                    layer = max(
                        len(self._layer_list[i])
                        for i in range(
                            self.merged_wires[0], self.merged_wires[-1] + 1
                        )
                    )
                    self._draw_multiq_gate(gate, layer)
                else:
                    self._draw_singleq_gate(
                        gate, len(self._layer_list[gate.targets[0]])
                    )

        self._add_wire()
        self._fig_config()
        # act on the figure being rendered, not whichever one is current
        self.fig.tight_layout()
        plt.show()

    def _fig_config(self) -> None:
        """
        Configure the figure settings.

        Sets the background colour and size, fits the axis limits to the
        wires and labels, adds the optional title, and hides the axes.
        """
        self.fig.set_facecolor(self.style.bgcolor)
        self.fig.set_size_inches(
            self.style.fig_width, self.style.fig_height, forward=True
        )
        self._ax.set_ylim(
            -self.style.padding,
            self.style.padding
            + (self._qwires + self._cwires - 1) * self.style.wire_sep,
        )
        self._ax.set_xlim(
            -self.style.padding - self.max_label_width - self.style.label_pad,
            self.style.padding
            + self.style.end_wire_ext * self.style.layer_sep
            + self._circuit_length(),
        )
        if self.style.title is not None:
            self._ax.set_title(
                self.style.title, pad=10, color=self.style.wire_color
            )
        self._ax.set_aspect("equal")
        self._ax.axis("off")

    def save(self, filename: str, **kwargs) -> None:
        """
        Save the circuit to a file.

        Parameters
        ----------
        filename : str
            The name of the file to save the circuit to.
        **kwargs
            Additional arguments to be passed to `Figure.savefig`
            (``bbox_inches="tight"`` is always applied).
        """
        self.fig.savefig(filename, bbox_inches="tight", **kwargs)