diff --git a/CHANGELOG.md b/CHANGELOG.md index 0fa9716d..5bdecd93 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -29,6 +29,27 @@ Types of changes: ### Fixed - Fixed the way how depth is calculated when external gates are defined with unrolling a QASM module. ([#198](https://github.com/qBraid/pyqasm/pull/198)) +- Added separate depth calculation for gates inside branching statements. ([#200](https://github.com/qBraid/pyqasm/pull/200)) + - **Example:** + ```python + OPENQASM 3.0; + include "stdgates.inc"; + qubit[4] q; + bit[4] c; + bit[4] c0; + if (c[0]){ + x q[0]; + h q[0] + } + else { + h q[1]; + } + ``` + ```text + Depth = 1 + ``` + - Previously, each gate inside an `if`/`else` block would advance only its own wire depth. Now, when any branching statement is encountered, all qubit‐ and clbit‐depths used inside that block are first incremented by one, then set to the maximum of those new values. This ensures the entire conditional block counts as single “depth” increment, rather than letting individual gates within the same branch float ahead independently. + - In the above snippet, c[0], q[0], and q[1] all jump together to a single new depth for that branch. ### Dependencies diff --git a/src/pyqasm/visitor.py b/src/pyqasm/visitor.py index a454b3a3..0fd6d1ec 100644 --- a/src/pyqasm/visitor.py +++ b/src/pyqasm/visitor.py @@ -89,7 +89,9 @@ def __init__( self._curr_scope: int = 0 self._label_scope_level: dict[int, set] = {self._curr_scope: set()} self._recording_ext_gate_depth = False - + self._in_branching_statement: int = 0 + self._is_branch_qubits: set[tuple[str, int]] = set() + self._is_branch_clbits: set[tuple[str, int]] = set() self._init_utilities() def _init_utilities(self): @@ -500,10 +502,12 @@ def _visit_measurement( # pylint: disable=too-many-locals measure=qasm3_ast.QuantumMeasurement(qubit=src_id), target=None ) ) - src_name, src_id = src_id.name.name, src_id.indices[0][0].value # type: ignore - qubit_node = self._module._qubit_depths[(src_name, src_id)] - qubit_node.depth += 1 - qubit_node.num_measurements += 1 + # if measurement gate is not in branching statement + if not self._in_branching_statement: + src_name, src_id = src_id.name.name, src_id.indices[0][0].value # type: ignore + qubit_node = self._module._qubit_depths[(src_name, src_id)] + qubit_node.depth += 1 + qubit_node.num_measurements += 1 else: target_name: str = ( target.name if isinstance(target, qasm3_ast.Identifier) else target.name.name @@ -533,21 +537,23 @@ def _visit_measurement( # pylint: disable=too-many-locals measure=qasm3_ast.QuantumMeasurement(qubit=src_id), target=tgt_id if target else None, ) - src_name, src_id = src_id.name.name, src_id.indices[0][0].value # type: ignore - tgt_name, tgt_id = tgt_id.name.name, tgt_id.indices[0][0].value # type: ignore - - qubit_node, clbit_node = ( - self._module._qubit_depths[(src_name, src_id)], - self._module._clbit_depths[(tgt_name, tgt_id)], - ) - qubit_node.depth += 1 - qubit_node.num_measurements += 1 + # if measurement gate is not in branching statement + if not self._in_branching_statement: + src_name, src_id = src_id.name.name, src_id.indices[0][0].value # type: ignore + tgt_name, tgt_id = tgt_id.name.name, tgt_id.indices[0][0].value # type: ignore + + qubit_node, clbit_node = ( + self._module._qubit_depths[(src_name, src_id)], + self._module._clbit_depths[(tgt_name, tgt_id)], + ) + qubit_node.depth += 1 + qubit_node.num_measurements += 1 - clbit_node.depth += 1 - clbit_node.num_measurements += 1 + clbit_node.depth += 1 + clbit_node.num_measurements += 1 - qubit_node.depth = max(qubit_node.depth, clbit_node.depth) - clbit_node.depth = max(qubit_node.depth, clbit_node.depth) + qubit_node.depth = max(qubit_node.depth, clbit_node.depth) + clbit_node.depth = max(qubit_node.depth, clbit_node.depth) unrolled_measurements.append(unrolled_measure) @@ -769,7 +775,8 @@ def _update_qubit_depth_for_gate( qubit_node = self._module._qubit_depths[(qubit.name.name, qubit_id)] qubit_node.depth = max_involved_depth - def _visit_basic_gate_operation( # pylint: disable=too-many-locals + # pylint: disable=too-many-branches, too-many-locals + def _visit_basic_gate_operation( self, operation: qasm3_ast.QuantumGate, inverse: bool = False, @@ -851,8 +858,16 @@ def _visit_basic_gate_operation( # pylint: disable=too-many-locals result.extend( self._broadcast_gate_operation(unrolled_gate_function, unrolled_targets, ctrls) ) - - self._update_qubit_depth_for_gate(unrolled_targets, ctrls) + # if gate is not in branching statement + if not self._in_branching_statement: + self._update_qubit_depth_for_gate(unrolled_targets, ctrls) + else: + for qubit_subset in unrolled_targets + [ctrls]: # get qreg in branching operations + for qubit in qubit_subset: + assert isinstance(qubit.indices, list) and len(qubit.indices) > 0 + assert isinstance(qubit.indices[0], list) and len(qubit.indices[0]) > 0 + qubit_idx = Qasm3ExprEvaluator.evaluate_expression(qubit.indices[0][0])[0] + self._is_branch_qubits.add((qubit.name.name, qubit_idx)) # check for duplicate bits for final_gate in result: @@ -952,7 +967,16 @@ def _visit_custom_gate_operation( # Update the depth only once for the entire custom gate if self._recording_ext_gate_depth: self._recording_ext_gate_depth = False - self._update_qubit_depth_for_gate([op_qubits], ctrls) + if not self._in_branching_statement: # if custom gate is not in branching statement + self._update_qubit_depth_for_gate([op_qubits], ctrls) + else: + # get qubit registers in branching operations + for qubit_subset in [op_qubits] + [ctrls]: + for qubit in qubit_subset: + assert isinstance(qubit.indices, list) and len(qubit.indices) > 0 + assert isinstance(qubit.indices[0], list) and len(qubit.indices[0]) > 0 + qubit_idx = Qasm3ExprEvaluator.evaluate_expression(qubit.indices[0][0])[0] + self._is_branch_qubits.add((qubit.name.name, qubit_idx)) self._restore_context() @@ -1611,6 +1635,24 @@ def _evaluate_array_initialization( return np.array(init_values, dtype=ARRAY_TYPE_MAP[base_type.__class__]) + # update branching operators depth + def _update_branching_gate_depths(self) -> None: + """Updates the depth of the circuit after applying branching statements.""" + all_nodes = [ + self._module._qubit_depths[(name, idx)] for name, idx in self._is_branch_qubits + ] + [self._module._clbit_depths[(name, idx)] for name, idx in self._is_branch_clbits] + + try: + max_depth = max(node.depth + 1 for node in all_nodes) + except ValueError: + max_depth = 0 + + for node in all_nodes: + node.depth = max_depth + + self._is_branch_clbits.clear() + self._is_branch_qubits.clear() + def _visit_branching_statement( self, statement: qasm3_ast.BranchingStatement ) -> list[qasm3_ast.Statement]: @@ -1626,6 +1668,7 @@ def _visit_branching_statement( self._push_scope({}) self._curr_scope += 1 self._label_scope_level[self._curr_scope] = set() + self._in_branching_statement += 1 result = [] condition = statement.condition @@ -1659,6 +1702,9 @@ def _visit_branching_statement( reg_idx, self._global_creg_size_map[reg_name], qubit=False, op_node=condition ) + # getting creg for depth counting + self._is_branch_clbits.add((reg_name, reg_idx)) + new_if_block = qasm3_ast.BranchingStatement( condition=qasm3_ast.BinaryExpression( op=qasm3_ast.BinaryOperator["=="], @@ -1690,6 +1736,8 @@ def _visit_branching_statement( rhs_value -= 1 size = self._global_creg_size_map[reg_name] + # getting cregs for depth counting + self._is_branch_clbits.update((reg_name, i) for i in range(size)) rhs_value_str = bin(int(rhs_value))[2:].zfill(size) else_block = self.visit_basic_block(statement.else_block) @@ -1734,6 +1782,9 @@ def ravel(bit_ind): self._curr_scope -= 1 self._pop_scope() self._restore_context() + self._in_branching_statement -= 1 + if not self._in_branching_statement: + self._update_branching_gate_depths() if self._check_only: return [] diff --git a/tests/qasm3/test_depth.py b/tests/qasm3/test_depth.py index 487c1f2c..68dca449 100644 --- a/tests/qasm3/test_depth.py +++ b/tests/qasm3/test_depth.py @@ -426,7 +426,6 @@ def test_qasm3_depth_no_branching(program, expected_depth): assert result.depth() == expected_depth -@pytest.mark.skip(reason="Not implemented branching conditions depth") @pytest.mark.parametrize( "program, expected_depth", [ @@ -460,8 +459,9 @@ def test_qasm3_depth_no_branching(program, expected_depth): measure q[0] -> c[0]; if (c==1) measure q[1] -> c[1]; +if (c==3) measure q[1] -> c[1]; """, - 4, + 5, ), ( """ @@ -488,10 +488,139 @@ def test_qasm3_depth_no_branching(program, expected_depth): """, 8, ), + ( + """ +OPENQASM 3.0; +include "stdgates.inc"; +gate custom a, b{ + cx a, b; + h a; +} +qubit[4] q; +bit[4] c; +bit[4] c0; +h q; +measure q -> c0; +if(c0[0]){ + x q[0]; + cx q[0], q[1]; + if (c0[1]){ + cx q[1], q[2]; + } +} +if (c[0]){ + custom q[2], q[3]; +} +array[int[32], 8] arr; +arr[0] = 1; +if(arr[0] >= 1){ + h q[0]; + h q[1]; +} +""", + 4, + ), + ( + """ +OPENQASM 3.0; +include "stdgates.inc"; +qubit[1] q; +bit[4] c; +if(c == 3){ + h q[0]; +} +if(c >= 3){ + h q[0]; +} else { + x q[0]; +} +if(c <= 3){ + h q[0]; +} else { + x q[0]; +} +if(c[0] < 4){ + h q[0]; +} else { + x q[0]; +} +""", + 4, + ), + ( + """ +OPENQASM 3.0; +include "stdgates.inc"; +qubit[2] q; +bit[2] c; +h q[0]; +cx q[0], q[1]; +c[0] = measure q[0]; +c[1] = measure q[1]; +if (c[0] == false) { + if (c[1] == true) { + x q[0]; + } + else { + if (c[1] == false){ + x q[1]; + } + else { + z q[0]; + } + } +} + +if (c == 0) { + x q[0]; +} +else { + y q[1]; +} +x q[0]; +""", + 6, + ), ], ) def test_qasm3_depth_branching(program, expected_depth): """Test calculating depth of qasm3 circuit with branching conditions""" result = loads(program) result.unroll() + result.remove_barriers() assert result.depth() == expected_depth + + +def test_qasm3_depth_branching_for_external_gates(): + """Test calculating depth of qasm3 circuit with external gates inside branching conditions""" + qasm3_string = """ + OPENQASM 3.0; + include "stdgates.inc"; + bit[2] c; + gate my_gate q1, q2 { + h q1; + cx q1, q2; + h q2; + } + gate my_gate_two q1, q2 { + cx q1, q2; + } + + qubit[2] q; + if (c == 0){ + measure q -> c; + my_gate q[0], q[1]; + } + else { + if (c[0] == false) { + my_gate q[1], q[0]; + } + else{ + measure q -> c; + } + } + my_gate_two q[0], q[1]; + """ + result = loads(qasm3_string) + result._external_gates = ["my_gate", "my_gate_two"] + assert result.depth() == 2