Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,27 @@ Types of changes:
### Fixed

- Fixed the way how depth is calculated when external gates are defined with unrolling a QASM module. ([#198](https://github.com/qBraid/pyqasm/pull/198))
- Added separate depth calculation for gates inside branching statements. ([#200](https://github.com/qBraid/pyqasm/pull/200))
- **Example:**
```python
OPENQASM 3.0;
include "stdgates.inc";
qubit[4] q;
bit[4] c;
bit[4] c0;
if (c[0]){
x q[0];
h q[0]
}
else {
h q[1];
}
```
```text
Depth = 1
```
- Previously, each gate inside an `if`/`else` block would advance only its own wire depth. Now, when any branching statement is encountered, all qubit‐ and clbit‐depths used inside that block are first incremented by one, then set to the maximum of those new values. This ensures the entire conditional block counts as single “depth” increment, rather than letting individual gates within the same branch float ahead independently.
- In the above snippet, c[0], q[0], and q[1] all jump together to a single new depth for that branch.

### Dependencies

Expand Down
95 changes: 73 additions & 22 deletions src/pyqasm/visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,9 @@ def __init__(
self._curr_scope: int = 0
self._label_scope_level: dict[int, set] = {self._curr_scope: set()}
self._recording_ext_gate_depth = False

self._in_branching_statement: int = 0
self._is_branch_qubits: set[tuple[str, int]] = set()
self._is_branch_clbits: set[tuple[str, int]] = set()
self._init_utilities()

def _init_utilities(self):
Expand Down Expand Up @@ -500,10 +502,12 @@ def _visit_measurement( # pylint: disable=too-many-locals
measure=qasm3_ast.QuantumMeasurement(qubit=src_id), target=None
)
)
src_name, src_id = src_id.name.name, src_id.indices[0][0].value # type: ignore
qubit_node = self._module._qubit_depths[(src_name, src_id)]
qubit_node.depth += 1
qubit_node.num_measurements += 1
# if measurement gate is not in branching statement
if not self._in_branching_statement:
src_name, src_id = src_id.name.name, src_id.indices[0][0].value # type: ignore
qubit_node = self._module._qubit_depths[(src_name, src_id)]
qubit_node.depth += 1
qubit_node.num_measurements += 1
else:
target_name: str = (
target.name if isinstance(target, qasm3_ast.Identifier) else target.name.name
Expand Down Expand Up @@ -533,21 +537,23 @@ def _visit_measurement( # pylint: disable=too-many-locals
measure=qasm3_ast.QuantumMeasurement(qubit=src_id),
target=tgt_id if target else None,
)
src_name, src_id = src_id.name.name, src_id.indices[0][0].value # type: ignore
tgt_name, tgt_id = tgt_id.name.name, tgt_id.indices[0][0].value # type: ignore

qubit_node, clbit_node = (
self._module._qubit_depths[(src_name, src_id)],
self._module._clbit_depths[(tgt_name, tgt_id)],
)
qubit_node.depth += 1
qubit_node.num_measurements += 1
# if measurement gate is not in branching statement
if not self._in_branching_statement:
src_name, src_id = src_id.name.name, src_id.indices[0][0].value # type: ignore
tgt_name, tgt_id = tgt_id.name.name, tgt_id.indices[0][0].value # type: ignore

qubit_node, clbit_node = (
self._module._qubit_depths[(src_name, src_id)],
self._module._clbit_depths[(tgt_name, tgt_id)],
)
qubit_node.depth += 1
qubit_node.num_measurements += 1

clbit_node.depth += 1
clbit_node.num_measurements += 1
clbit_node.depth += 1
clbit_node.num_measurements += 1

qubit_node.depth = max(qubit_node.depth, clbit_node.depth)
clbit_node.depth = max(qubit_node.depth, clbit_node.depth)
qubit_node.depth = max(qubit_node.depth, clbit_node.depth)
clbit_node.depth = max(qubit_node.depth, clbit_node.depth)

unrolled_measurements.append(unrolled_measure)

Expand Down Expand Up @@ -769,7 +775,8 @@ def _update_qubit_depth_for_gate(
qubit_node = self._module._qubit_depths[(qubit.name.name, qubit_id)]
qubit_node.depth = max_involved_depth

def _visit_basic_gate_operation( # pylint: disable=too-many-locals
# pylint: disable=too-many-branches, too-many-locals
def _visit_basic_gate_operation(
self,
operation: qasm3_ast.QuantumGate,
inverse: bool = False,
Expand Down Expand Up @@ -851,8 +858,16 @@ def _visit_basic_gate_operation( # pylint: disable=too-many-locals
result.extend(
self._broadcast_gate_operation(unrolled_gate_function, unrolled_targets, ctrls)
)

self._update_qubit_depth_for_gate(unrolled_targets, ctrls)
# if gate is not in branching statement
if not self._in_branching_statement:
self._update_qubit_depth_for_gate(unrolled_targets, ctrls)
else:
for qubit_subset in unrolled_targets + [ctrls]: # get qreg in branching operations
for qubit in qubit_subset:
assert isinstance(qubit.indices, list) and len(qubit.indices) > 0
assert isinstance(qubit.indices[0], list) and len(qubit.indices[0]) > 0
qubit_idx = Qasm3ExprEvaluator.evaluate_expression(qubit.indices[0][0])[0]
self._is_branch_qubits.add((qubit.name.name, qubit_idx))

# check for duplicate bits
for final_gate in result:
Expand Down Expand Up @@ -952,7 +967,16 @@ def _visit_custom_gate_operation(
# Update the depth only once for the entire custom gate
if self._recording_ext_gate_depth:
self._recording_ext_gate_depth = False
self._update_qubit_depth_for_gate([op_qubits], ctrls)
if not self._in_branching_statement: # if custom gate is not in branching statement
self._update_qubit_depth_for_gate([op_qubits], ctrls)
else:
# get qubit registers in branching operations
for qubit_subset in [op_qubits] + [ctrls]:
for qubit in qubit_subset:
assert isinstance(qubit.indices, list) and len(qubit.indices) > 0
assert isinstance(qubit.indices[0], list) and len(qubit.indices[0]) > 0
qubit_idx = Qasm3ExprEvaluator.evaluate_expression(qubit.indices[0][0])[0]
self._is_branch_qubits.add((qubit.name.name, qubit_idx))

self._restore_context()

Expand Down Expand Up @@ -1611,6 +1635,24 @@ def _evaluate_array_initialization(

return np.array(init_values, dtype=ARRAY_TYPE_MAP[base_type.__class__])

# update branching operators depth
def _update_branching_gate_depths(self) -> None:
"""Updates the depth of the circuit after applying branching statements."""
all_nodes = [
self._module._qubit_depths[(name, idx)] for name, idx in self._is_branch_qubits
] + [self._module._clbit_depths[(name, idx)] for name, idx in self._is_branch_clbits]

try:
max_depth = max(node.depth + 1 for node in all_nodes)
except ValueError:
max_depth = 0

for node in all_nodes:
node.depth = max_depth

self._is_branch_clbits.clear()
self._is_branch_qubits.clear()

def _visit_branching_statement(
self, statement: qasm3_ast.BranchingStatement
) -> list[qasm3_ast.Statement]:
Expand All @@ -1626,6 +1668,7 @@ def _visit_branching_statement(
self._push_scope({})
self._curr_scope += 1
self._label_scope_level[self._curr_scope] = set()
self._in_branching_statement += 1

result = []
condition = statement.condition
Expand Down Expand Up @@ -1659,6 +1702,9 @@ def _visit_branching_statement(
reg_idx, self._global_creg_size_map[reg_name], qubit=False, op_node=condition
)

# getting creg for depth counting
self._is_branch_clbits.add((reg_name, reg_idx))

new_if_block = qasm3_ast.BranchingStatement(
condition=qasm3_ast.BinaryExpression(
op=qasm3_ast.BinaryOperator["=="],
Expand Down Expand Up @@ -1690,6 +1736,8 @@ def _visit_branching_statement(
rhs_value -= 1

size = self._global_creg_size_map[reg_name]
# getting cregs for depth counting
self._is_branch_clbits.update((reg_name, i) for i in range(size))
rhs_value_str = bin(int(rhs_value))[2:].zfill(size)
else_block = self.visit_basic_block(statement.else_block)

Expand Down Expand Up @@ -1734,6 +1782,9 @@ def ravel(bit_ind):
self._curr_scope -= 1
self._pop_scope()
self._restore_context()
self._in_branching_statement -= 1
if not self._in_branching_statement:
self._update_branching_gate_depths()

if self._check_only:
return []
Expand Down
133 changes: 131 additions & 2 deletions tests/qasm3/test_depth.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,7 +426,6 @@ def test_qasm3_depth_no_branching(program, expected_depth):
assert result.depth() == expected_depth


@pytest.mark.skip(reason="Not implemented branching conditions depth")
@pytest.mark.parametrize(
"program, expected_depth",
[
Expand Down Expand Up @@ -460,8 +459,9 @@ def test_qasm3_depth_no_branching(program, expected_depth):
measure q[0] -> c[0];

if (c==1) measure q[1] -> c[1];
if (c==3) measure q[1] -> c[1];
""",
4,
5,
),
(
"""
Expand All @@ -488,10 +488,139 @@ def test_qasm3_depth_no_branching(program, expected_depth):
""",
8,
),
(
"""
OPENQASM 3.0;
include "stdgates.inc";
gate custom a, b{
cx a, b;
h a;
}
qubit[4] q;
bit[4] c;
bit[4] c0;
h q;
measure q -> c0;
if(c0[0]){
x q[0];
cx q[0], q[1];
if (c0[1]){
cx q[1], q[2];
}
}
if (c[0]){
custom q[2], q[3];
}
array[int[32], 8] arr;
arr[0] = 1;
if(arr[0] >= 1){
h q[0];
h q[1];
}
""",
4,
),
(
"""
OPENQASM 3.0;
include "stdgates.inc";
qubit[1] q;
bit[4] c;
if(c == 3){
h q[0];
}
if(c >= 3){
h q[0];
} else {
x q[0];
}
if(c <= 3){
h q[0];
} else {
x q[0];
}
if(c[0] < 4){
h q[0];
} else {
x q[0];
}
""",
4,
),
(
"""
OPENQASM 3.0;
include "stdgates.inc";
qubit[2] q;
bit[2] c;
h q[0];
cx q[0], q[1];
c[0] = measure q[0];
c[1] = measure q[1];
if (c[0] == false) {
if (c[1] == true) {
x q[0];
}
else {
if (c[1] == false){
x q[1];
}
else {
z q[0];
}
}
}

if (c == 0) {
x q[0];
}
else {
y q[1];
}
x q[0];
""",
6,
),
],
)
def test_qasm3_depth_branching(program, expected_depth):
"""Test calculating depth of qasm3 circuit with branching conditions"""
result = loads(program)
result.unroll()
result.remove_barriers()
assert result.depth() == expected_depth


def test_qasm3_depth_branching_for_external_gates():
"""Test calculating depth of qasm3 circuit with external gates inside branching conditions"""
qasm3_string = """
OPENQASM 3.0;
include "stdgates.inc";
bit[2] c;
gate my_gate q1, q2 {
h q1;
cx q1, q2;
h q2;
}
gate my_gate_two q1, q2 {
cx q1, q2;
}

qubit[2] q;
if (c == 0){
measure q -> c;
my_gate q[0], q[1];
}
else {
if (c[0] == false) {
my_gate q[1], q[0];
}
else{
measure q -> c;
}
}
my_gate_two q[0], q[1];
"""
result = loads(qasm3_string)
result._external_gates = ["my_gate", "my_gate_two"]
assert result.depth() == 2