Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,34 @@ Types of changes:

### Fixed
- Fixed bug in release workflow(s) that caused discrepancy between `pyqasm.__version__` and `importlib.metadata.version` ([#147](https://github.com/qBraid/pyqasm/pull/147))
- Fixed a bug in broadcast operation for duplicate qubits so that the following -

```qasm
OPENQASM 3.0;
include "stdgates.inc";
qubit[3] q;
qubit[2] q2;
cx q[0], q[1], q[1], q[2];
cx q2, q2;
```

will unroll correctly to -

```qasm
OPENQASM 3.0;
include "stdgates.inc";
qubit[3] q;
qubit[2] q2;
// cx q[0], q[1], q[1], q[2];
cx q[0], q[1];
cx q[1], q[2];

// cx q2, q2;
cx q2[0], q2[1];
cx q2[0], q2[1];
```

The logic for duplicate qubit detection is moved out of the `QasmVisitor._get_op_bits` into `Qasm3Analyzer` class and is executed post gate broadcast operation ([#155](https://github.com/qBraid/pyqasm/pull/155)).

### Dependencies

Expand Down
48 changes: 48 additions & 0 deletions src/pyqasm/analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,14 @@
DiscreteSet,
Expression,
Identifier,
IndexedIdentifier,
IndexExpression,
IntegerLiteral,
IntType,
QuantumGate,
QuantumMeasurementStatement,
RangeDefinition,
Span,
)

from pyqasm.exceptions import QasmParsingError, ValidationError, raise_qasm3_error
Expand Down Expand Up @@ -234,3 +237,48 @@ def extract_qasm_version(qasm: str) -> float: # type: ignore[return]
return float(f"{major}.{minor}")

raise_qasm3_error("Could not determine the OpenQASM version.", err_type=QasmParsingError)

@staticmethod
def extract_duplicate_qubit(qubit_list: list[IndexedIdentifier]):
"""
Extracts the duplicate qubit from a list of qubits.

Args:
qubit_list (list[IndexedIdentifier]): The list of qubits.

Returns:
tuple(string, int): The duplicate qubit name and id.
"""
qubit_set = set()
for qubit in qubit_list:
assert isinstance(qubit, IndexedIdentifier)
qubit_name = qubit.name.name
qubit_id = qubit.indices[0][0].value # type: ignore
if (qubit_name, qubit_id) in qubit_set:
return (qubit_name, qubit_id)
qubit_set.add((qubit_name, qubit_id))
return None

@staticmethod
def verify_gate_qubits(gate: QuantumGate, span: Optional[Span] = None):
"""
Verify the qubits for a quantum gate.

Args:
gate (QuantumGate): The quantum gate.
span (Span, optional): The span of the gate.

Raises:
ValidationError: If qubits are duplicated.

Returns:
None
"""
# 1. check for duplicate bits
duplicate_qubit = Qasm3Analyzer.extract_duplicate_qubit(gate.qubits) # type: ignore
if duplicate_qubit:
qubit_name, qubit_id = duplicate_qubit
raise_qasm3_error(
f"Duplicate qubit {qubit_name}[{qubit_id}] in gate {gate.name.name}",
span=span,
)
20 changes: 9 additions & 11 deletions src/pyqasm/visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,6 @@ def _get_op_bits(
list[qasm3_ast.IndexedIdentifier] : The bits for the operation.
"""
openqasm_bits = []
visited_bits = set()
bit_list = []
original_size_map = reg_size_map

Expand Down Expand Up @@ -413,16 +412,6 @@ def _get_op_bits(
)
for bit_id in bit_ids
]
# check for duplicate bits
for bit_id in new_bits:
bit_name, bit_value = bit_id.name.name, bit_id.indices[0][0].value
if tuple((bit_name, bit_value)) in visited_bits:
raise_qasm3_error(
f"Duplicate {'qubit' if qubits else 'clbit'} "
f"{bit_name}[{bit_value}] argument",
span=operation.span,
)
visited_bits.add((bit_name, bit_value))

openqasm_bits.extend(new_bits)

Expand Down Expand Up @@ -794,6 +783,11 @@ def _visit_basic_gate_operation( # pylint: disable=too-many-locals
)

self._update_qubit_depth_for_gate(unrolled_targets, ctrls)

# check for duplicate bits
for final_gate in result:
Qasm3Analyzer.verify_gate_qubits(final_gate, operation.span)

if self._check_only:
return []

Expand Down Expand Up @@ -950,6 +944,10 @@ def gate_function(*qubits):
all_targets = self._unroll_multiple_target_qubits(operation, gate_qubit_count)
result = self._broadcast_gate_operation(gate_function, all_targets)

# check for any duplicates
for final_gate in result:
Qasm3Analyzer.verify_gate_qubits(final_gate, operation.span)

self._restore_context()
if self._check_only:
return []
Expand Down
10 changes: 10 additions & 0 deletions tests/qasm3/resources/gates.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,16 @@ def test_fixture():
""",
"Undefined identifier a in.*",
),
"duplicate_qubits": (
"""
OPENQASM 3;
include "stdgates.inc";

qubit[2] q1;
cx q1[0] , q1[0]; // duplicate qubit
""",
r"Duplicate qubit q1\[0\] in gate cx",
),
}

# qasm_input, expected_error
Expand Down
11 changes: 0 additions & 11 deletions tests/qasm3/test_barrier.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,14 +146,3 @@ def test_incorrect_barrier():
ValidationError, match="Index 3 out of range for register of size 2 in qubit"
):
loads(out_of_bounds).validate()

duplicate = """
OPENQASM 3.0;

qubit[2] q1;

barrier q1, q1;
"""

with pytest.raises(ValidationError, match=r"Duplicate qubit .*argument"):
loads(duplicate).validate()
17 changes: 17 additions & 0 deletions tests/qasm3/test_gates.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,23 @@ def test_inverse_global_phase():
check_unrolled_qasm(dumps(module), qasm3_expected)


def test_duplicate_qubit_broadcast():
qasm3_string = """
OPENQASM 3.0;
include "stdgates.inc";
qubit[3] q;

cx q[0], q[1], q[1], q[2];"""

module = loads(qasm3_string)
module.unroll()

assert module.num_qubits == 3
assert module.num_clbits == 0

check_two_qubit_gate_op(module.unrolled_ast, 2, [[0, 1], [1, 2]], "cx")


@pytest.mark.parametrize("test_name", custom_op_tests)
def test_custom_ops_with_external_gates(test_name, request):
qasm3_string = request.getfixturevalue(test_name)
Expand Down