qutip/qutip-qip

View on GitHub
tests/test_qir.py

Summary

Maintainability
B
5 hrs
Test Coverage
import typing
import pytest
import numpy as np
import random
from numpy.testing import assert_allclose
from qutip_qip.circuit import QubitCircuit

# Will skip tests in this entire file
# if PyQIR is not available.
pqg = pytest.importorskip("pyqir.generator")

pqp = pytest.importorskip("pyqir.parser")
import pyqir.parser

from qutip_qip import qir

T = typing.TypeVar("T")


def _assert_is_single(collection: typing.List[T]) -> T:
    assert len(collection) == 1
    return collection[0]


def _assert_arg_is_qubit(
    arg: pyqir.parser.QirOperand, idx: typing.Optional[int] = None
):
    """
    Assert that ``arg`` is a qubit constant operand.

    When ``idx`` is given, the qubit's ``id`` must also equal ``idx``;
    passing ``None`` skips the id comparison.
    """
    assert isinstance(arg, pyqir.parser.QirQubitConstant)
    assert idx is None or arg.id == idx


def _assert_arg_is_result(
    arg: pyqir.parser.QirOperand, idx: typing.Optional[int] = None
):
    """
    Assert that ``arg`` is a result constant operand.

    When ``idx`` is given, the result's ``id`` must also equal ``idx``;
    passing ``None`` skips the id comparison.
    """
    assert isinstance(arg, pyqir.parser.QirResultConstant)
    assert idx is None or arg.id == idx


def _assert_arg_is_double(
    arg: pyqir.parser.QirOperand, angle: typing.Optional[float] = None
):
    """
    Assert that ``arg`` is a double-precision constant operand.

    When ``angle`` is given, the constant's ``value`` must match it within
    the default tolerance of ``numpy.testing.assert_allclose``.
    """
    assert isinstance(arg, pyqir.parser.QirDoubleConstant)
    if angle is None:
        # No expected value supplied: only the operand kind is checked.
        return
    np.testing.assert_allclose(arg.value, angle)


def _assert_is_simple_qis_call(
    inst: pyqir.parser.QirInstr, gate_name: str, targets: typing.List[int]
):
    """
    Assert that ``inst`` calls the QIS body function for ``gate_name`` and
    that its arguments are exactly the qubit constants listed in
    ``targets``, in the same order.
    """
    assert isinstance(inst, pyqir.parser.QirQisCallInstr)
    expected_name = f"__quantum__qis__{gate_name}__body"
    assert inst.func_name == expected_name
    call_args = inst.func_args
    assert len(call_args) == len(targets)
    # Lengths already match, so zip pairs every argument with its target.
    for qubit_id, operand in zip(targets, call_args):
        _assert_arg_is_qubit(operand, qubit_id)


def _assert_is_rotation_qis_call(
    inst: pyqir.parser.QirInstr, gate_name: str, angle: float, target: int
):
    """
    Assert that ``inst`` is a QIS rotation call for ``gate_name``.

    The call must take exactly two arguments: the rotation ``angle`` as a
    double constant, followed by the ``target`` qubit constant.
    """
    assert isinstance(inst, pyqir.parser.QirQisCallInstr)
    expected_name = f"__quantum__qis__{gate_name}__body"
    assert inst.func_name == expected_name
    operands = inst.func_args
    assert len(operands) == 2
    theta_operand, qubit_operand = operands
    _assert_arg_is_double(theta_operand, angle)
    _assert_arg_is_qubit(qubit_operand, target)


def _assert_is_measurement_qis_call(
    inst: pyqir.parser.QirInstr, gate_name: str, target: int, result: int
):
    """
    Assert that ``inst`` is a QIS measurement call for ``gate_name``.

    The call must take exactly two arguments: the measured ``target`` qubit
    constant, followed by the ``result`` constant it is written into.
    """
    assert isinstance(inst, pyqir.parser.QirQisCallInstr)
    expected_name = f"__quantum__qis__{gate_name}__body"
    assert inst.func_name == expected_name
    operands = inst.func_args
    assert len(operands) == 2
    qubit_operand, result_operand = operands
    _assert_arg_is_qubit(qubit_operand, target)
    _assert_arg_is_result(result_operand, result)


class TestConverter:
    """
    Test suite that checks that conversions from circuits to QIR produce
    correct QIR modules.

    Note that since literal byte equivalence is not guaranteed, our testing
    strategy will be to use the PyQIR parser package to read back the QIR that
    we export and check for semantic equivalence.
    """

    @staticmethod
    def _convert_to_entrypoint(circuit):
        """
        Convert ``circuit`` (a ``QubitCircuit``) to a parsed QIR module and
        return the module's entry point function.

        Fails the calling test unless the module has exactly one entry point.
        """
        parsed_qir_module: pyqir.parser.QirModule = qir.circuit_to_qir(
            circuit, format=qir.QirFormat.MODULE
        )
        return _assert_is_single(parsed_qir_module.entrypoint_funcs)

    def test_simple_x_circuit(self):
        """
        Test to check conversion of a circuit
        containing a single-qubit gate (X).
        """
        circuit = QubitCircuit(1)
        circuit.add_gate("X", targets=[0])
        parsed_func = self._convert_to_entrypoint(circuit)
        assert parsed_func.required_qubits == 1
        assert parsed_func.required_results == 0
        parsed_block = _assert_is_single(parsed_func.blocks)
        inst = _assert_is_single(parsed_block.instructions)
        _assert_is_simple_qis_call(inst, "x", [0])

    def test_simple_cnot_circuit(self):
        """
        Test to check conversion of a circuit containing a single
        two-qubit controlled gate (CX). It should be exported as a ``cnot``
        call whose arguments are the control qubit followed by the target.
        """
        circuit = QubitCircuit(2)
        circuit.add_gate("CX", targets=[1], controls=[0])
        parsed_func = self._convert_to_entrypoint(circuit)
        assert parsed_func.required_qubits == 2
        assert parsed_func.required_results == 0
        parsed_block = _assert_is_single(parsed_func.blocks)
        inst = _assert_is_single(parsed_block.instructions)
        _assert_is_simple_qis_call(inst, "cnot", [0, 1])

    def test_simple_rz_circuit(self):
        """
        Test to check conversion of a circuit containing a single
        parameterized rotation gate (RZ). The rotation angle must be
        exported as a double argument ahead of the target qubit.
        """
        circuit = QubitCircuit(1)
        circuit.add_gate("RZ", targets=[0], arg_value=0.123)
        parsed_func = self._convert_to_entrypoint(circuit)
        assert parsed_func.required_qubits == 1
        assert parsed_func.required_results == 0
        parsed_block = _assert_is_single(parsed_func.blocks)
        inst = _assert_is_single(parsed_block.instructions)
        _assert_is_rotation_qis_call(inst, "rz", 0.123, 0)

    def test_teleport_circuit(self):
        """
        Test conversion of a three-qubit teleportation circuit.

        The circuit contains mid-circuit measurements and two classically
        controlled corrections. The test checks the instructions and
        terminator of every basic block in the resulting control flow.
        """
        # NB: this test is a bit detailed, as it checks metadata throughout
        #     control flow in a teleportation circuit.
        circuit = QubitCircuit(3, num_cbits=2)
        msg, here, there = range(3)
        circuit.add_gate("RZ", targets=[msg], arg_value=0.123)
        circuit.add_gate("SNOT", targets=[here])
        circuit.add_gate("CNOT", targets=[there], controls=[here])
        circuit.add_gate("CNOT", targets=[here], controls=[msg])
        circuit.add_gate("SNOT", targets=[msg])
        circuit.add_measurement("Z", targets=[msg], classical_store=0)
        circuit.add_measurement("Z", targets=[here], classical_store=1)
        circuit.add_gate("X", targets=[there], classical_controls=[0])
        circuit.add_gate("Z", targets=[there], classical_controls=[1])

        parsed_func = self._convert_to_entrypoint(circuit)
        assert parsed_func.required_qubits == 3
        assert parsed_func.required_results == 2
        # entry, then one (then, else, continue) triple per classically
        # controlled gate.
        assert len(parsed_func.blocks) == 7

        def assert_readresult(inst, result: int):
            """Check ``inst`` reads result ``result``; return its output name."""
            assert isinstance(inst, pyqir.parser.QirQisCallInstr)
            assert inst.func_name == "__quantum__qis__read_result__body"
            arg = _assert_is_single(inst.func_args)
            _assert_arg_is_result(arg, result)
            return inst.output_name

        entry = parsed_func.blocks[0]
        then = parsed_func.blocks[1]
        else_ = parsed_func.blocks[2]
        continue_ = parsed_func.blocks[3]
        then2 = parsed_func.blocks[4]
        else3 = parsed_func.blocks[5]
        continue4 = parsed_func.blocks[6]

        # Entry block
        # NB: We only check the name of the entry point block, none of the
        #     others names are semantically relevant, and thus can change
        #     without that being a breaking change.
        assert entry.name == "entry"
        _assert_is_rotation_qis_call(entry.instructions[0], "rz", 0.123, msg)
        _assert_is_simple_qis_call(entry.instructions[1], "h", [here])
        _assert_is_simple_qis_call(
            entry.instructions[2], "cnot", [here, there]
        )
        _assert_is_simple_qis_call(entry.instructions[3], "cnot", [msg, here])
        _assert_is_simple_qis_call(entry.instructions[4], "h", [msg])
        _assert_is_measurement_qis_call(entry.instructions[5], "mz", msg, 0)
        _assert_is_measurement_qis_call(entry.instructions[6], "mz", here, 1)
        cond_label = assert_readresult(entry.instructions[7], 0)
        # rz, h, cnot, cnot, h, mz, mz, read_result: like every other block
        # below, the entry block must contain no unexpected extra instructions.
        assert len(entry.instructions) == 8
        term = entry.terminator
        assert isinstance(term, pyqir.parser.QirCondBrTerminator)
        cond = term.condition
        assert isinstance(cond, pyqir.parser.QirLocalOperand)
        assert cond.name == cond_label
        assert term.true_dest == then.name
        assert term.false_dest == else_.name

        # Then block
        inst = _assert_is_single(then.instructions)
        _assert_is_simple_qis_call(inst, "x", [there])
        term = then.terminator
        assert isinstance(term, pyqir.parser.QirBrTerminator)
        assert term.dest == continue_.name

        # else block
        assert len(else_.instructions) == 0
        term = else_.terminator
        assert isinstance(term, pyqir.parser.QirBrTerminator)
        assert term.dest == continue_.name

        # continue block
        inst = _assert_is_single(continue_.instructions)
        cond_label = assert_readresult(inst, 1)
        term = continue_.terminator
        assert isinstance(term, pyqir.parser.QirCondBrTerminator)
        cond = term.condition
        assert isinstance(cond, pyqir.parser.QirLocalOperand)
        assert cond.name == cond_label
        assert term.true_dest == then2.name
        assert term.false_dest == else3.name

        # then2 block
        inst = _assert_is_single(then2.instructions)
        _assert_is_simple_qis_call(inst, "z", [there])
        term = then2.terminator
        assert isinstance(term, pyqir.parser.QirBrTerminator)
        assert term.dest == continue4.name

        # else3 block
        assert len(else3.instructions) == 0
        term = else3.terminator
        assert isinstance(term, pyqir.parser.QirBrTerminator)
        assert term.dest == continue4.name

        # continue4 block
        assert len(continue4.instructions) == 0
        term = continue4.terminator
        assert isinstance(term, pyqir.parser.QirRetTerminator)
        assert term.operand is None