Skip to content
174 changes: 142 additions & 32 deletions pyqasm/visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
import copy
import logging
from collections import deque
from typing import Any, Optional, Union
from functools import partial
from typing import Any, Callable, Optional, Union

import numpy as np
import openqasm3.ast as qasm3_ast
Expand Down Expand Up @@ -50,9 +51,10 @@ class QasmVisitor:
Args:
initialize_runtime (bool): If True, quantum runtime will be initialized. Defaults to True.
record_output (bool): If True, output of the circuit will be recorded. Defaults to True.
external_gates (list[str]): List of gates that should not be unrolled.
"""

def __init__(self, module, check_only: bool = False):
def __init__(self, module, check_only: bool = False, external_gates: list[str] | None = None):
self._module = module
self._scope: deque = deque([{}])
self._context: deque = deque([Context.GLOBAL])
Expand All @@ -65,6 +67,7 @@ def __init__(self, module, check_only: bool = False):
self._function_qreg_transform_map: deque = deque([]) # for nested functions
self._global_creg_size_map: dict[str, int] = {}
self._custom_gates: dict[str, qasm3_ast.QuantumGateDefinition] = {}
self._external_gates: list[str] = [] if external_gates is None else external_gates
self._subroutine_defns: dict[str, qasm3_ast.SubroutineDefinition] = {}
self._check_only: bool = check_only
self._curr_scope: int = 0
Expand Down Expand Up @@ -608,6 +611,73 @@ def _visit_gate_definition(self, definition: qasm3_ast.QuantumGateDefinition) ->

return []

def _unroll_multiple_target_qubits(
self, operation: qasm3_ast.QuantumGate, gate_qubit_count: int
) -> list[list[qasm3_ast.IndexedIdentifier]]:
"""Unroll the complete list of all qubits that the given operation is applied to.
E.g. this maps 'cx q[0], q[1], q[2], q[3]' to [[q[0], q[1]], [q[2], q[3]]]

Args:
operation (qasm3_ast.QuantumGate): The gate to be applied.
gate_qubit_count (list[int]): The number of qubits that a single gate acts on.

Returns:
The list of all targets that the unrolled gate should act on.
"""
op_qubits = self._get_op_bits(operation, self._global_qreg_size_map)
if len(op_qubits) % gate_qubit_count != 0:
raise_qasm3_error(
f"Invalid number of qubits {len(op_qubits)} for operation {operation.name.name}",
span=operation.span,
)
qubit_subsets = []
for i in range(0, len(op_qubits), gate_qubit_count):
# we apply the gate on the qubit subset linearly
qubit_subsets.append(op_qubits[i : i + gate_qubit_count])
return qubit_subsets

def _broadcast_gate_operation(
self, gate_function: Callable, all_targets: list[list[qasm3_ast.IndexedIdentifier]]
) -> list[qasm3_ast.QuantumGate]:
"""Broadcasts the application of a gate onto multiple sets of target qubits.

Args:
gate_function (callable): The gate that should be applied to multiple target qubits.
(All arguments of the callable should be qubits, i.e. all non-qubit arguments of the
gate should already be evaluated, e.g. using functools.partial).
all_targets (list[list[qasm3_ast.IndexedIdentifier]]):
The list of target of target qubits.
The length of this list indicates the number of time the gate is invoked.
Returns:
List of all executed gates.
"""
result = []
for targets in all_targets:
result.extend(gate_function(*targets))
return result

def _update_qubit_depth_for_gate(self, all_targets: list[list[qasm3_ast.IndexedIdentifier]]):
"""Updates the depth of the circuit after applying a broadcasted gate.

Args:
all_targes: The list of qubits on which a gate was just added.

Returns:
None
"""
for qubit_subset in all_targets:
max_involved_depth = 0
for qubit in qubit_subset:
qubit_name, qubit_id = qubit.name.name, qubit.indices[0][0].value # type: ignore
qubit_node = self._module._qubit_depths[(qubit_name, qubit_id)]
qubit_node.num_gates += 1
max_involved_depth = max(max_involved_depth, qubit_node.depth + 1)

for qubit in qubit_subset:
qubit_name, qubit_id = qubit.name.name, qubit.indices[0][0].value # type: ignore
qubit_node = self._module._qubit_depths[(qubit_name, qubit_id)]
qubit_node.depth = max_involved_depth

def _visit_basic_gate_operation( # pylint: disable=too-many-locals
self, operation: qasm3_ast.QuantumGate, inverse: bool = False
) -> list[qasm3_ast.QuantumGate]:
Expand All @@ -633,9 +703,7 @@ def _visit_basic_gate_operation( # pylint: disable=too-many-locals
ValidationError: If the number of qubits is invalid.

"""

logger.debug("Visiting basic gate operation '%s'", str(operation))
op_qubits = self._get_op_bits(operation, self._global_qreg_size_map)
inverse_action = None
if not inverse:
qasm_func, op_qubit_count = map_qasm_op_to_callable(operation.name.name)
Expand All @@ -645,45 +713,24 @@ def _visit_basic_gate_operation( # pylint: disable=too-many-locals
operation.name.name
)

op_parameters = None

if len(op_qubits) % op_qubit_count != 0:
raise_qasm3_error(
f"Invalid number of qubits {len(op_qubits)} for operation {operation.name.name}",
span=operation.span,
)
op_parameters = []

if len(operation.arguments) > 0: # parametric gate
op_parameters = self._get_op_parameters(operation)
if inverse_action == InversionOp.INVERT_ROTATION:
op_parameters = [-1 * param for param in op_parameters]

result = []
for i in range(0, len(op_qubits), op_qubit_count):
# we apply the gate on the qubit subset linearly
qubit_subset = op_qubits[i : i + op_qubit_count]
unrolled_gate = []
if op_parameters is not None:
unrolled_gate = qasm_func(*op_parameters, *qubit_subset)
else:
unrolled_gate = qasm_func(*qubit_subset)
result.extend(unrolled_gate)

# update qubit depths
max_involved_depth = 0
for qubit in qubit_subset:
qubit_name, qubit_id = qubit.name.name, qubit.indices[0][0].value # type: ignore
qubit_node = self._module._qubit_depths[(qubit_name, qubit_id)]
qubit_node.num_gates += 1
max_involved_depth = max(max_involved_depth, qubit_node.depth + 1)
unrolled_targets = self._unroll_multiple_target_qubits(operation, op_qubit_count)
unrolled_gate_function = partial(qasm_func, *op_parameters)
result.extend(self._broadcast_gate_operation(unrolled_gate_function, unrolled_targets))

for qubit in qubit_subset:
qubit_name, qubit_id = qubit.name.name, qubit.indices[0][0].value # type: ignore
qubit_node = self._module._qubit_depths[(qubit_name, qubit_id)]
qubit_node.depth = max_involved_depth
self._update_qubit_depth_for_gate(unrolled_targets)

if self._check_only:
return []

return result

def _visit_custom_gate_operation(
Expand Down Expand Up @@ -765,6 +812,67 @@ def _visit_custom_gate_operation(

return result

def _visit_external_gate_operation(
self, operation: qasm3_ast.QuantumGate, inverse: bool = False
) -> list[qasm3_ast.QuantumGate]:
"""Visit an external gate operation element.

Args:
operation (qasm3_ast.QuantumGate): The external gate operation to visit.
inverse (bool): Whether the operation is an inverse operation. Defaults to False.

If True, the gate operation is applied in reverse order and the
inverse modifier is appended to each gate call.
See https://openqasm.com/language/gates.html#inverse-modifier
for more clarity.

Returns:
list[qasm3_ast.QuantumGate]: The quantum gate that was collected.
"""

logger.debug("Visiting external gate operation '%s'", str(operation))
gate_name: str = operation.name.name

if gate_name in self._custom_gates:
# Ignore result, this is just for validation
self._visit_custom_gate_operation(operation, inverse=inverse)
# Don't need to check if custom gate exists, since we just validated the call
gate_qubit_count = len(self._custom_gates[gate_name].qubits)
else:
# Ignore result, this is just for validation
self._visit_basic_gate_operation(operation, inverse=inverse)
# Don't need to check if basic gate exists, since we just validated the call
_, gate_qubit_count = map_qasm_op_to_callable(operation.name.name)

op_parameters = [
qasm3_ast.FloatLiteral(param) for param in self._get_op_parameters(operation)
]

self._push_context(Context.GATE)

modifiers = []
if inverse:
modifiers = [qasm3_ast.QuantumGateModifier(qasm3_ast.GateModifierName.inv, None)]

def gate_function(*qubits):
Comment thread
TheGupta2012 marked this conversation as resolved.
return [
qasm3_ast.QuantumGate(
modifiers=modifiers,
name=qasm3_ast.Identifier(gate_name),
qubits=list(qubits),
arguments=list(op_parameters),
)
]

all_targets = self._unroll_multiple_target_qubits(operation, gate_qubit_count)
result = self._broadcast_gate_operation(gate_function, all_targets)

self._restore_context()
if self._check_only:
return []

return result

def _collapse_gate_modifiers(self, operation: qasm3_ast.QuantumGate) -> tuple:
"""Collapse the gate modifiers of a gate operation.
Some analysis is required to get this result.
Expand Down Expand Up @@ -828,7 +936,9 @@ def _visit_generic_gate_operation(
# apply the power first and then inverting the result
result = []
for _ in range(power_value):
if operation.name.name in self._custom_gates:
if operation.name.name in self._external_gates:
result.extend(self._visit_external_gate_operation(operation, inverse_value))
elif operation.name.name in self._custom_gates:
result.extend(self._visit_custom_gate_operation(operation, inverse_value))
else:
result.extend(self._visit_basic_gate_operation(operation, inverse_value))
Expand Down
2 changes: 1 addition & 1 deletion tests/qasm3/resources/gates.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ def test_fixture():
qubit[3] q1;
cx q1; // invalid application of gate, as we apply it to 3 qubits in blocks of 2
""",
"Invalid number of qubits 3 for operation .*",
"Invalid number of qubits 3 for operation cx",
),
"unsupported_parameter_type": (
"""
Expand Down
21 changes: 21 additions & 0 deletions tests/qasm3/test_depth.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,27 @@ def test_gate_depth():
assert result.depth() == 5


@pytest.mark.skip(reason="Not implemented computing depth of external gates")
def test_gate_depth_external_function():
qasm3_string = """
OPENQASM 3;
include "stdgates.inc";

gate my_gate() q {
h q;
x q;
}

qubit q;
my_gate() q;
"""
result = load(qasm3_string)
result.unroll(external_gates=["my_gate"])
assert result.num_qubits == 1
assert result.num_clbits == 0
assert result.depth() == 1


def test_pow_gate_depth():
qasm3_string = """
OPENQASM 3;
Expand Down
45 changes: 45 additions & 0 deletions tests/qasm3/test_gates.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
)
from tests.utils import (
check_custom_qasm_gate_op,
check_custom_qasm_gate_op_with_external_gates,
check_single_qubit_gate_op,
check_single_qubit_rotation_op,
check_three_qubit_gate_op,
Expand Down Expand Up @@ -138,6 +139,36 @@ def test_qasm_u3_gates():
check_single_qubit_rotation_op(result.unrolled_ast, 1, [0], [0.5, 0.5, 0.5], "u3")


def test_qasm_u3_gates_external():
qasm3_string = """
OPENQASM 3;
include "stdgates.inc";

qubit[2] q1;
u3(0.5, 0.5, 0.5) q1[0];
"""
result = load(qasm3_string)
result.unroll(external_gates=["u3"])
assert result.num_qubits == 2
assert result.num_clbits == 0
check_single_qubit_gate_op(result.unrolled_ast, 1, [0], "u3")


def test_qasm_u3_gates_external_with_multiple_qubits():
qasm3_string = """
OPENQASM 3;
include "stdgates.inc";

qubit[2] q1;
u3(0.5, 0.5, 0.5) q1;
"""
result = load(qasm3_string)
result.unroll(external_gates=["u3"])
assert result.num_qubits == 2
assert result.num_clbits == 0
check_single_qubit_gate_op(result.unrolled_ast, 2, [0, 1], "u3")


def test_qasm_u2_gates():
qasm3_string = """
OPENQASM 3;
Expand Down Expand Up @@ -174,6 +205,20 @@ def test_custom_ops(test_name, request):
check_custom_qasm_gate_op(result.unrolled_ast, gate_type)


@pytest.mark.parametrize("test_name", custom_op_tests)
def test_custom_ops_with_external_gates(test_name, request):
qasm3_string = request.getfixturevalue(test_name)
gate_type = test_name.removeprefix("Fixture_")
result = load(qasm3_string)
result.unroll(external_gates=["custom", "custom1"])

assert result.num_qubits == 2
assert result.num_clbits == 0

# Check for custom gate definition
check_custom_qasm_gate_op_with_external_gates(result.unrolled_ast, gate_type)


def test_pow_gate_modifier():
qasm3_string = """
OPENQASM 3;
Expand Down
15 changes: 15 additions & 0 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,3 +233,18 @@ def check_custom_qasm_gate_op(unrolled_ast, test_type):
if test_type not in test_function_map:
raise ValueError(f"Unknown test type {test_type}")
test_function_map[test_type](unrolled_ast)


def check_custom_qasm_gate_op_with_external_gates(unrolled_ast, test_type):
if test_type == "simple":
check_two_qubit_gate_op(unrolled_ast, 1, [(0, 1)], "custom")
elif test_type == "nested":
check_two_qubit_gate_op(unrolled_ast, 1, [(0, 1)], "custom")
elif test_type == "complex":
# Only custom1 is external, custom2 and custom3 should be unrolled
check_single_qubit_gate_op(unrolled_ast, 1, [0], "custom1")
check_single_qubit_gate_op(unrolled_ast, 1, [0], "ry")
check_single_qubit_gate_op(unrolled_ast, 1, [0], "rz")
check_two_qubit_gate_op(unrolled_ast, 1, [[0, 1]], "cx")
else:
raise ValueError(f"Unknown test type {test_type}")