diff --git a/CHANGELOG.md b/CHANGELOG.md index 369f94ff..0fa9716d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -28,6 +28,8 @@ Types of changes: ### Fixed +- Fixed the way how depth is calculated when external gates are defined with unrolling a QASM module. ([#198](https://github.com/qBraid/pyqasm/pull/198)) + ### Dependencies ### Other diff --git a/src/pyqasm/modules/base.py b/src/pyqasm/modules/base.py index 5b6890cb..e33df77f 100644 --- a/src/pyqasm/modules/base.py +++ b/src/pyqasm/modules/base.py @@ -55,6 +55,7 @@ def __init__(self, name: str, program: Program): self._has_barriers: Optional[bool] = None self._validated_program = False self._unrolled_ast = Program(statements=[]) + self._external_gates: list[str] = [] @property def name(self) -> str: @@ -278,7 +279,10 @@ def depth(self): qasm_module = self.copy() qasm_module._qubit_depths = {} qasm_module._clbit_depths = {} - qasm_module.unroll() + + # Unroll using any external gates that have been recorded for this + # module + qasm_module.unroll(external_gates=self._external_gates) max_depth = 0 max_qubit_depth, max_clbit_depth = 0, 0 @@ -539,6 +543,10 @@ def unroll(self, **kwargs): kwargs = {} try: self.num_qubits, self.num_clbits = 0, 0 + if ext_gates := kwargs.get("external_gates"): + self._external_gates = ext_gates + else: + self._external_gates = [] visitor = QasmVisitor(module=self, **kwargs) self.accept(visitor) except (ValidationError, UnrollError) as err: diff --git a/src/pyqasm/visitor.py b/src/pyqasm/visitor.py index bb24ef47..a454b3a3 100644 --- a/src/pyqasm/visitor.py +++ b/src/pyqasm/visitor.py @@ -88,6 +88,7 @@ def __init__( self._unroll_barriers: bool = unroll_barriers self._curr_scope: int = 0 self._label_scope_level: dict[int, set] = {self._curr_scope: set()} + self._recording_ext_gate_depth = False self._init_utilities() @@ -750,22 +751,23 @@ def _update_qubit_depth_for_gate( Returns: None """ - for qubit_subset in all_targets: - max_involved_depth = 0 - for qubit in qubit_subset + ctrls: - assert isinstance(qubit.indices[0], list) - _qid_ = qubit.indices[0][0] - qubit_id = Qasm3ExprEvaluator.evaluate_expression(_qid_)[0] # type: ignore - qubit_node = self._module._qubit_depths[(qubit.name.name, qubit_id)] - qubit_node.num_gates += 1 - max_involved_depth = max(max_involved_depth, qubit_node.depth + 1) - - for qubit in qubit_subset + ctrls: - assert isinstance(qubit.indices[0], list) - _qid_ = qubit.indices[0][0] - qubit_id = Qasm3ExprEvaluator.evaluate_expression(_qid_)[0] # type: ignore - qubit_node = self._module._qubit_depths[(qubit.name.name, qubit_id)] - qubit_node.depth = max_involved_depth + if not self._recording_ext_gate_depth: + for qubit_subset in all_targets: + max_involved_depth = 0 + for qubit in qubit_subset + ctrls: + assert isinstance(qubit.indices[0], list) + _qid_ = qubit.indices[0][0] + qubit_id = Qasm3ExprEvaluator.evaluate_expression(_qid_)[0] # type: ignore + qubit_node = self._module._qubit_depths[(qubit.name.name, qubit_id)] + qubit_node.num_gates += 1 + max_involved_depth = max(max_involved_depth, qubit_node.depth + 1) + + for qubit in qubit_subset + ctrls: + assert isinstance(qubit.indices[0], list) + _qid_ = qubit.indices[0][0] + qubit_id = Qasm3ExprEvaluator.evaluate_expression(_qid_)[0] # type: ignore + qubit_node = self._module._qubit_depths[(qubit.name.name, qubit_id)] + qubit_node.depth = max_involved_depth def _visit_basic_gate_operation( # pylint: disable=too-many-locals self, @@ -913,6 +915,11 @@ def _visit_custom_gate_operation( gate_definition_ops.reverse() self._push_context(Context.GATE) + + # Pause recording the depth of new gates because we are processing the + # definition of a custom gate here - handle the depth separately afterwards + self._recording_ext_gate_depth = gate_name in self._external_gates + result = [] for gate_op in gate_definition_ops: if isinstance(gate_op, (qasm3_ast.QuantumGate, qasm3_ast.QuantumPhase)): @@ -942,6 +949,11 @@ def _visit_custom_gate_operation( span=gate_op.span, ) + # Update the depth only once for the entire custom gate + if self._recording_ext_gate_depth: + self._recording_ext_gate_depth = False + self._update_qubit_depth_for_gate([op_qubits], ctrls) + self._restore_context() if self._check_only: @@ -969,7 +981,6 @@ def _visit_external_gate_operation( Returns: list[qasm3_ast.QuantumGate]: The quantum gate that was collected. """ - logger.debug("Visiting external gate operation '%s'", str(operation)) gate_name: str = operation.name.name if ctrls is None: diff --git a/tests/qasm3/test_depth.py b/tests/qasm3/test_depth.py index aac75b0e..487c1f2c 100644 --- a/tests/qasm3/test_depth.py +++ b/tests/qasm3/test_depth.py @@ -51,25 +51,63 @@ def test_gate_depth(): assert result.depth() == 5 -@pytest.mark.skip(reason="Not implemented computing depth of external gates") -def test_gate_depth_external_function(): - qasm3_string = """ - OPENQASM 3; - include "stdgates.inc"; +QASM3_STRING_1 = """ +OPENQASM 3; +include "stdgates.inc"; - gate my_gate() q { - h q; - x q; - } +gate my_gate() q { + h q; + x q; +} - qubit q; - my_gate() q; - """ - result = loads(qasm3_string) +qubit q; +my_gate() q; +""" + +QASM3_STRING_2 = """ +OPENQASM 3.0; +include "stdgates.inc"; +gate my_gate q1, q2 { + h q1; + cx q1, q2; + h q2; +} +qubit[2] q; +my_gate q[0], q[1]; +""" + +QASM3_STRING_3 = """ +OPENQASM 3.0; +include "stdgates.inc"; +gate my_gate q1, q2 { } +qubit[2] q; +my_gate q[0], q[1]; +""" + + +@pytest.mark.parametrize( + ["input_qasm_str", "first_depth", "second_depth", "num_qubits"], + [ + (QASM3_STRING_1, 1, 2, 1), + (QASM3_STRING_2, 1, 3, 2), + (QASM3_STRING_3, 1, 0, 2), + ], +) +def test_gate_depth_external_function(input_qasm_str, first_depth, second_depth, num_qubits): + result = loads(input_qasm_str) result.unroll(external_gates=["my_gate"]) - assert result.num_qubits == 1 + assert result.num_qubits == num_qubits + + for i in range(num_qubits): + assert result._qubit_depths[("q", i)].num_gates == 1 + assert result.num_clbits == 0 - assert result.depth() == 1 + assert result.depth() == first_depth + + # Check that unrolling with no external_gates flushes the internally stored + # external gates and influences the depth calculation + result.unroll() + assert result.depth() == second_depth def test_pow_gate_depth():