From a19b05612343d15d894cdd5ceb5b1cf4392b5e68 Mon Sep 17 00:00:00 2001 From: vinayswamik Date: Fri, 27 Jun 2025 15:27:17 -0500 Subject: [PATCH 01/10] update visitor.py, base.py and added test_device_qubits.py - Added a new parameter `device_qubits` to `QasmVisitor` and `QasmModule` to manage qubit consolidation. - Introduced methods for qubit index mapping and register consolidation to ensure compatibility with device constraints. - Updated measurement, barrier, and gate operations to utilize the new device qubit logic. - Added unit tests to validate the functionality of qubit consolidation and error handling for exceeding device qubit limits. --- src/pyqasm/modules/base.py | 3 + src/pyqasm/visitor.py | 155 +++++++++++++++++++- tests/qasm3/test_device_qubits.py | 233 ++++++++++++++++++++++++++++++ 3 files changed, 386 insertions(+), 5 deletions(-) create mode 100644 tests/qasm3/test_device_qubits.py diff --git a/src/pyqasm/modules/base.py b/src/pyqasm/modules/base.py index 3b18db71..ada90739 100644 --- a/src/pyqasm/modules/base.py +++ b/src/pyqasm/modules/base.py @@ -57,6 +57,7 @@ def __init__(self, name: str, program: Program): self._unrolled_ast = Program(statements=[]) self._external_gates: list[str] = [] self._decompose_native_gates: Optional[bool] = None + self._device_qubits: Optional[int] = 0 @property def name(self) -> str: @@ -551,6 +552,8 @@ def unroll(self, **kwargs): self._external_gates = ext_gates else: self._external_gates = [] + if device_qbts := kwargs.get("device_qubits"): + self._device_qubits = device_qbts visitor = QasmVisitor(module=self, **kwargs) self.accept(visitor) except (ValidationError, UnrollError) as err: diff --git a/src/pyqasm/visitor.py b/src/pyqasm/visitor.py index ef71d358..1095d1d1 100644 --- a/src/pyqasm/visitor.py +++ b/src/pyqasm/visitor.py @@ -22,7 +22,7 @@ import logging from collections import deque from functools import partial -from typing import Any, Callable, Optional +from typing import Any, Callable, Optional, cast import numpy as np import openqasm3.ast as qasm3_ast @@ -86,6 +86,7 @@ def __init__( # pylint: disable=too-many-arguments external_gates: list[str] | None = None, unroll_barriers: bool = True, max_loop_iters: int = int(1e9), + device_qubits: int = 0, ): self._module = module self._scope: deque = deque([{}]) @@ -113,6 +114,8 @@ def __init__( # pylint: disable=too-many-arguments self._measurement_set: set[str] = set() self._init_utilities() self._loop_limit = max_loop_iters + self._device_qubits: int = device_qubits + self._in_generic_gate_op_scope: int = 0 def _init_utilities(self): """Initialize the utilities for the visitor.""" @@ -559,7 +562,66 @@ def _check_variable_cast_type( span=statement.span, ) - def _visit_measurement( # pylint: disable=too-many-locals + def _get_pyqasm_device_qubit_index(self, reg: str, idx: int) -> int: + """ + Returns qubit index in __PYQASM_QUBITS__ ordering for given quantum register and index. + + Args: + reg (str): The name of the quantum register. + idx (int): The index within the quantum register. + + Returns: + int: The global __PYQASM_QUBITS__ index corresponding to reg[idx]. + + Raises: + IndexError: If the index `idx` is out of range for the register. + """ + _offsets: dict[str, int] = {} + _offset = 0 + for name, n_qubits in self._global_qreg_size_map.items(): + _offsets[name] = _offset + _offset += n_qubits + _n_qubits = self._global_qreg_size_map[reg] + if not 0 <= idx < _n_qubits: + raise IndexError(f"{reg}[{idx}] out of range (0..{_n_qubits-1})") + return _offsets[reg] + idx + + def _qubit_register_consolidation(self, unrolled_stmts: list) -> list[qasm3_ast.Statement]: + """ + Consolidate all quantum registers into a single register '__PYQASM_QUBITS__'. + + Args: + unrolled_stmts (list): The list of statements to process and modify in-place. + + Raises: + ValidationError: If the total number of qubits exceeds the available device qubits, + or if the reserved register '__PYQASM_QUBITS__' is already declared + in the original QASM program. + """ + total_qubits = sum(self._global_qreg_size_map.values()) + if total_qubits > self._device_qubits: + raise_qasm3_error( + f"Total qubits '({total_qubits})' exceed device qubits '({self._device_qubits})'." + ) + + if "__PYQASM_QUBITS__" in self._global_qreg_size_map: + raise_qasm3_error( + "Original QASM program already declares reserved register '__PYQASM_QUBITS__'." + ) + + removable_statements = [] + for stmt in unrolled_stmts: + if isinstance(stmt, qasm3_ast.QubitDeclaration): + removable_statements.append(stmt) + for r_stmt in removable_statements: + unrolled_stmts.remove(r_stmt) + pyqasm_reg_id = qasm3_ast.Identifier("__PYQASM_QUBITS__") + pyqasm_reg_size = qasm3_ast.IntegerLiteral(self._device_qubits) + pyqasm_reg_stmt = qasm3_ast.QubitDeclaration(pyqasm_reg_id, pyqasm_reg_size) + unrolled_stmts.insert(1, pyqasm_reg_stmt) + return unrolled_stmts + + def _visit_measurement( # pylint: disable=too-many-locals, too-many-branches self, statement: qasm3_ast.QuantumMeasurementStatement ) -> list[qasm3_ast.QuantumMeasurementStatement]: """Visit a measurement statement element. @@ -659,6 +721,18 @@ def _visit_measurement( # pylint: disable=too-many-locals unrolled_measurements.append(unrolled_measure) + if self._device_qubits: + for stmt in unrolled_measurements: + _qubit_id = cast( + qasm3_ast.Identifier, stmt.measure.qubit.name + ) # type: ignore[union-attr] + _qubit_ind = cast(list, stmt.measure.qubit.indices) # type: ignore[union-attr] + for multiple_ind in _qubit_ind: + for ind in multiple_ind: + _pyqasm_val = self._get_pyqasm_device_qubit_index(_qubit_id.name, ind.value) + ind.value = _pyqasm_val + _qubit_id.name = "__PYQASM_QUBITS__" + if self._check_only: return [] @@ -699,12 +773,24 @@ def _visit_reset(self, statement: qasm3_ast.QuantumReset) -> list[qasm3_ast.Quan unrolled_resets.append(unrolled_reset) + if self._device_qubits: + for stmt in unrolled_resets: + _qubit_str = cast(str, stmt.qubits.name.name) # type: ignore[union-attr] + _qubit_ind = cast(list, stmt.qubits.indices) # type: ignore[union-attr] + for multiple_ind in _qubit_ind: + for ind in multiple_ind: + _pyqasm_val = self._get_pyqasm_device_qubit_index(_qubit_str, ind.value) + ind.value = _pyqasm_val + stmt.qubits.name.name = "__PYQASM_QUBITS__" # type: ignore[union-attr] + if self._check_only: return [] return unrolled_resets - def _visit_barrier(self, barrier: qasm3_ast.QuantumBarrier) -> list[qasm3_ast.QuantumBarrier]: + def _visit_barrier( # pylint: disable=too-many-locals, too-many-branches + self, barrier: qasm3_ast.QuantumBarrier + ) -> list[qasm3_ast.QuantumBarrier]: """Visit a barrier statement element. Args: @@ -749,8 +835,48 @@ def _visit_barrier(self, barrier: qasm3_ast.QuantumBarrier) -> list[qasm3_ast.Qu return [] if not self._unroll_barriers: + if self._device_qubits: + _qubit_id = cast( + qasm3_ast.Identifier, barrier.qubits[0] + ) # type: ignore[union-attr] + if not isinstance(_qubit_id, qasm3_ast.IndexedIdentifier): + _start = self._get_pyqasm_device_qubit_index(_qubit_id.name, 0) + _end = self._get_pyqasm_device_qubit_index( + _qubit_id.name, self._global_qreg_size_map[_qubit_id.name] - 1 + ) + if _start == 0: + _qubit_id.name = f"__PYQASM_QUBITS__[:{_end+1}]" + elif _end == self._device_qubits - 1: + _qubit_id.name = f"__PYQASM_QUBITS__[{_start}:]" + else: + _qubit_id.name = f"__PYQASM_QUBITS__[{_start}:{_end+1}]" + else: + _qubit_str = cast(str, barrier.qubits[0].name) # type: ignore[union-attr] + _qubit_ind = cast(list, barrier.qubits[0].indices) # type: ignore[union-attr] + for multi_ind in _qubit_ind: + for ind in multi_ind: + pyqasm_ind = self._get_pyqasm_device_qubit_index( + _qubit_str.name, ind.value + ) + ind.value = pyqasm_ind + _qubit_str.name = "__PYQASM_QUBITS__" return [barrier] + if self._device_qubits: + for stmt in unrolled_barriers: + _qubit_ind_id = cast( + qasm3_ast.IndexedIdentifier, stmt.qubits[0] + ) # type: ignore[union-attr] + _original_qubit_name = _qubit_ind_id.name.name + for multiple_ind in _qubit_ind_id.indices: + for ind in multiple_ind: # type: ignore[union-attr] + ind_val = cast(qasm3_ast.IntegerLiteral, ind) # type: ignore[union-attr] + pyqasm_val = self._get_pyqasm_device_qubit_index( + _original_qubit_name, ind_val.value + ) + ind_val.value = pyqasm_val + _qubit_ind_id.name.name = "__PYQASM_QUBITS__" + return unrolled_barriers def _get_op_parameters(self, operation: qasm3_ast.QuantumGate) -> list[float]: @@ -1244,7 +1370,7 @@ def _visit_phase_operation( return [operation] - def _visit_generic_gate_operation( # pylint: disable=too-many-branches + def _visit_generic_gate_operation( # pylint: disable=too-many-branches, too-many-statements self, operation: qasm3_ast.QuantumGate | qasm3_ast.QuantumPhase, ctrls: Optional[list[qasm3_ast.IndexedIdentifier]] = None, @@ -1261,6 +1387,7 @@ def _visit_generic_gate_operation( # pylint: disable=too-many-branches negctrls = [] if ctrls is None: ctrls = [] + self._in_generic_gate_op_scope += 1 # only needs to be done once for a gate operation if ( @@ -1364,7 +1491,23 @@ def _visit_generic_gate_operation( # pylint: disable=too-many-branches qasm3_ast.QuantumGate([], qasm3_ast.Identifier("x"), [], [ctrl]) for ctrl in negctrls ] result = negs + result + negs # type: ignore - + self._in_generic_gate_op_scope -= 1 + if self._device_qubits and not self._in_generic_gate_op_scope: + result_copy = copy.deepcopy(result) + for stmt, c_stmt in zip(result, result_copy): + for qubit, c_qubit in zip(stmt.qubits, c_stmt.qubits): + _original_qubit_name = cast(qasm3_ast.Identifier, c_qubit.name) + for multi_ind, c_multi_ind in zip( + qubit.indices, c_qubit.indices # type: ignore[union-attr] + ): + for ind, c_ind in zip(multi_ind, c_multi_ind): + pyqasm_val = self._get_pyqasm_device_qubit_index( + _original_qubit_name.name, c_ind.value # type: ignore[union-attr] + ) + ind.value = pyqasm_val + for stmt in result: + for qubit in stmt.qubits: + qubit.name.name = "__PYQASM_QUBITS__" # type: ignore[union-attr] if self._check_only: return [] @@ -2460,6 +2603,8 @@ def finalize(self, unrolled_stmts): """ # remove the gphase qubits if they use ALL qubits + if self._device_qubits: + unrolled_stmts = self._qubit_register_consolidation(unrolled_stmts) for stmt in unrolled_stmts: # Rule 1 if isinstance(stmt, qasm3_ast.QuantumPhase): diff --git a/tests/qasm3/test_device_qubits.py b/tests/qasm3/test_device_qubits.py new file mode 100644 index 00000000..33c53001 --- /dev/null +++ b/tests/qasm3/test_device_qubits.py @@ -0,0 +1,233 @@ +# Copyright 2025 qBraid +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Module containing unit tests for the qubit register consolidation. + +""" + +import pytest + +from pyqasm.entrypoint import dumps, loads +from pyqasm.exceptions import ValidationError +from tests.utils import check_unrolled_qasm + + +def test_reset(): + qasm = """OPENQASM 3.0; + include "stdgates.inc"; + qubit[2] q; + qreg q2[3]; + reset q2; + reset q[1]; + """ + expected_qasm = """OPENQASM 3.0; + include "stdgates.inc"; + qubit[5] __PYQASM_QUBITS__; + reset __PYQASM_QUBITS__[2]; + reset __PYQASM_QUBITS__[3]; + reset __PYQASM_QUBITS__[4]; + reset __PYQASM_QUBITS__[1]; + """ + + result = loads(qasm) + result.unroll(device_qubits=5) + check_unrolled_qasm(dumps(result), expected_qasm) + + +def test_barrier(): + qasm = """OPENQASM 3.0; + include "stdgates.inc"; + qubit[2] q; + qreg q2[3]; + barrier q2; + barrier q[1]; + """ + expected_qasm = """OPENQASM 3.0; + include "stdgates.inc"; + qubit[5] __PYQASM_QUBITS__; + barrier __PYQASM_QUBITS__[2]; + barrier __PYQASM_QUBITS__[3]; + barrier __PYQASM_QUBITS__[4]; + barrier __PYQASM_QUBITS__[1]; + """ + result = loads(qasm) + result.unroll(device_qubits=5) + check_unrolled_qasm(dumps(result), expected_qasm) + + +def test_unrolled_barrier(): + qasm = """OPENQASM 3.0; + include "stdgates.inc"; + qubit[2] q; + qreg q2[3]; + barrier q[0]; + barrier q2; + barrier q; + """ + expected_qasm = """OPENQASM 3.0; + include "stdgates.inc"; + qubit[5] __PYQASM_QUBITS__; + barrier __PYQASM_QUBITS__[0]; + barrier __PYQASM_QUBITS__[2:]; + barrier __PYQASM_QUBITS__[:2]; + """ + result = loads(qasm) + result.unroll(unroll_barriers=False, device_qubits=5) + check_unrolled_qasm(dumps(result), expected_qasm) + + +def test_measurement(): + qasm = """OPENQASM 3.0; + include "stdgates.inc"; + qubit[4] q; + qreg q2[3]; + bit[3] c; + measure q2 -> c; + c[0] = measure q[0]; + c = measure q[:3]; + c = measure q2; + measure q2[1] -> c[2]; + """ + expected_qasm = """OPENQASM 3.0; + include "stdgates.inc"; + qubit[7] __PYQASM_QUBITS__; + bit[3] c; + c[0] = measure __PYQASM_QUBITS__[4]; + c[1] = measure __PYQASM_QUBITS__[5]; + c[2] = measure __PYQASM_QUBITS__[6]; + c[0] = measure __PYQASM_QUBITS__[0]; + c[0] = measure __PYQASM_QUBITS__[0]; + c[1] = measure __PYQASM_QUBITS__[1]; + c[2] = measure __PYQASM_QUBITS__[2]; + c[0] = measure __PYQASM_QUBITS__[4]; + c[1] = measure __PYQASM_QUBITS__[5]; + c[2] = measure __PYQASM_QUBITS__[6]; + c[2] = measure __PYQASM_QUBITS__[5]; + """ + result = loads(qasm) + result.unroll(device_qubits=7) + check_unrolled_qasm(dumps(result), expected_qasm) + + +def test_gates(): + qasm = """OPENQASM 3.0; + include "stdgates.inc"; + qubit[4] data; + qubit[2] ancilla; + bit[3] c; + x data[3]; + cx data[0], ancilla[1]; + crx (0.1) ancilla[0], data[2]; + gate custom_rccx a, b, c{ + rccx a, b, c; + } + custom_rccx ancilla[0], data[1], data[0]; + if(c[0]){ + x data[0]; + cx data[1], ancilla[1]; + } + if(c[1] == 1){ + cx ancilla[0], data[2]; + } + """ + expected_qasm = """OPENQASM 3.0; + include "stdgates.inc"; + qubit[6] __PYQASM_QUBITS__; + bit[3] c; + x __PYQASM_QUBITS__[3]; + cx __PYQASM_QUBITS__[0], __PYQASM_QUBITS__[5]; + rz(1.5707963267948966) __PYQASM_QUBITS__[2]; + rx(1.5707963267948966) __PYQASM_QUBITS__[2]; + rz(3.141592653589793) __PYQASM_QUBITS__[2]; + rx(1.5707963267948966) __PYQASM_QUBITS__[2]; + rz(3.141592653589793) __PYQASM_QUBITS__[2]; + cx __PYQASM_QUBITS__[4], __PYQASM_QUBITS__[2]; + rz(0) __PYQASM_QUBITS__[2]; + rx(1.5707963267948966) __PYQASM_QUBITS__[2]; + rz(3.0915926535897933) __PYQASM_QUBITS__[2]; + rx(1.5707963267948966) __PYQASM_QUBITS__[2]; + rz(3.141592653589793) __PYQASM_QUBITS__[2]; + cx __PYQASM_QUBITS__[4], __PYQASM_QUBITS__[2]; + rz(0) __PYQASM_QUBITS__[2]; + rx(1.5707963267948966) __PYQASM_QUBITS__[2]; + rz(3.191592653589793) __PYQASM_QUBITS__[2]; + rx(1.5707963267948966) __PYQASM_QUBITS__[2]; + rz(1.5707963267948966) __PYQASM_QUBITS__[2]; + rz(3.141592653589793) __PYQASM_QUBITS__[0]; + rx(1.5707963267948966) __PYQASM_QUBITS__[0]; + rz(4.71238898038469) __PYQASM_QUBITS__[0]; + rx(1.5707963267948966) __PYQASM_QUBITS__[0]; + rz(3.141592653589793) __PYQASM_QUBITS__[0]; + h __PYQASM_QUBITS__[0]; + rx(0.7853981633974483) __PYQASM_QUBITS__[0]; + h __PYQASM_QUBITS__[0]; + cx __PYQASM_QUBITS__[1], __PYQASM_QUBITS__[0]; + h __PYQASM_QUBITS__[0]; + rx(-0.7853981633974483) __PYQASM_QUBITS__[0]; + h __PYQASM_QUBITS__[0]; + cx __PYQASM_QUBITS__[4], __PYQASM_QUBITS__[0]; + h __PYQASM_QUBITS__[0]; + rx(0.7853981633974483) __PYQASM_QUBITS__[0]; + h __PYQASM_QUBITS__[0]; + cx __PYQASM_QUBITS__[1], __PYQASM_QUBITS__[0]; + h __PYQASM_QUBITS__[0]; + rx(-0.7853981633974483) __PYQASM_QUBITS__[0]; + h __PYQASM_QUBITS__[0]; + rz(3.141592653589793) __PYQASM_QUBITS__[0]; + rx(1.5707963267948966) __PYQASM_QUBITS__[0]; + rz(4.71238898038469) __PYQASM_QUBITS__[0]; + rx(1.5707963267948966) __PYQASM_QUBITS__[0]; + rz(3.141592653589793) __PYQASM_QUBITS__[0]; + if (c[0] == true) { + x __PYQASM_QUBITS__[0]; + cx __PYQASM_QUBITS__[1], __PYQASM_QUBITS__[5]; + } + if (c[1] == true) { + cx __PYQASM_QUBITS__[4], __PYQASM_QUBITS__[2]; + } + """ + result = loads(qasm) + result.unroll(device_qubits=6) + check_unrolled_qasm(dumps(result), expected_qasm) + + +@pytest.mark.parametrize( + "qasm_code,error_message", + [ + ( + """ + OPENQASM 3.0; + include "stdgates.inc"; + qubit[4] data; + qubit[3] ancilla; + """, + r"Total qubits '(7)' exceed device qubits '(6)'.", + ), + ( + """ + OPENQASM 3.0; + include "stdgates.inc"; + qubit[4] data; + qubit[2] __PYQASM_QUBITS__; + """, + r"Original QASM program already declares reserved register '__PYQASM_QUBITS__'.", + ), + ], +) # pylint: disable-next= too-many-arguments +def test_incorrect_qubit_reg_consolidation(qasm_code, error_message, caplog): + with pytest.raises(ValidationError) as err: + with caplog.at_level("ERROR"): + loads(qasm_code).unroll(device_qubits=6) + assert error_message in str(err.value) From 38289fb8120e78b5291975c9a18406beda00e55d Mon Sep 17 00:00:00 2001 From: vinayswamik Date: Fri, 27 Jun 2025 16:04:31 -0500 Subject: [PATCH 02/10] update test_device_qubits.py update test case --- tests/qasm3/test_device_qubits.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/qasm3/test_device_qubits.py b/tests/qasm3/test_device_qubits.py index 33c53001..bc27121b 100644 --- a/tests/qasm3/test_device_qubits.py +++ b/tests/qasm3/test_device_qubits.py @@ -72,16 +72,20 @@ def test_unrolled_barrier(): include "stdgates.inc"; qubit[2] q; qreg q2[3]; + qubit[2] q3; barrier q[0]; barrier q2; barrier q; + barrier q4; + """ expected_qasm = """OPENQASM 3.0; include "stdgates.inc"; qubit[5] __PYQASM_QUBITS__; barrier __PYQASM_QUBITS__[0]; - barrier __PYQASM_QUBITS__[2:]; + barrier __PYQASM_QUBITS__[2:5]; barrier __PYQASM_QUBITS__[:2]; + barrier __PYQASM_QUBITS__[5:]; """ result = loads(qasm) result.unroll(unroll_barriers=False, device_qubits=5) From 66359bde19d98ab673e1de631686f565f5448b06 Mon Sep 17 00:00:00 2001 From: vinayswamik Date: Fri, 27 Jun 2025 16:07:56 -0500 Subject: [PATCH 03/10] update CHANGELOG.md --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index ae2d47eb..9e43a5bd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -21,6 +21,7 @@ Types of changes: ### Improved / Modified - Added `slots=True` parameter to the data classes in `elements.py` to improve memory efficiency ([#218](https://github.com/qBraid/pyqasm/pull/218)) - Updated the documentation to include core features in the `README` ([#219](https://github.com/qBraid/pyqasm/pull/219)) +- Added support to `device qubit` resgister consolidation.([#222](https://github.com/qBraid/pyqasm/pull/222)) ### Deprecated From 14890ef45a92617cc12d576adaee82e5182e77c4 Mon Sep 17 00:00:00 2001 From: vinayswamik Date: Fri, 27 Jun 2025 21:27:29 -0500 Subject: [PATCH 04/10] update test cases --- tests/qasm3/test_device_qubits.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/qasm3/test_device_qubits.py b/tests/qasm3/test_device_qubits.py index bc27121b..69993fdd 100644 --- a/tests/qasm3/test_device_qubits.py +++ b/tests/qasm3/test_device_qubits.py @@ -76,7 +76,7 @@ def test_unrolled_barrier(): barrier q[0]; barrier q2; barrier q; - barrier q4; + barrier q3; """ expected_qasm = """OPENQASM 3.0; @@ -88,7 +88,7 @@ def test_unrolled_barrier(): barrier __PYQASM_QUBITS__[5:]; """ result = loads(qasm) - result.unroll(unroll_barriers=False, device_qubits=5) + result.unroll(unroll_barriers=False, device_qubits=7) check_unrolled_qasm(dumps(result), expected_qasm) From 25aae83c55a17a63b0756383affc5d4d9346780b Mon Sep 17 00:00:00 2001 From: vinayswamik Date: Fri, 27 Jun 2025 21:39:45 -0500 Subject: [PATCH 05/10] update test_device_qubits.py --- tests/qasm3/test_device_qubits.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/qasm3/test_device_qubits.py b/tests/qasm3/test_device_qubits.py index 69993fdd..3bab18be 100644 --- a/tests/qasm3/test_device_qubits.py +++ b/tests/qasm3/test_device_qubits.py @@ -77,11 +77,10 @@ def test_unrolled_barrier(): barrier q2; barrier q; barrier q3; - """ expected_qasm = """OPENQASM 3.0; include "stdgates.inc"; - qubit[5] __PYQASM_QUBITS__; + qubit[7] __PYQASM_QUBITS__; barrier __PYQASM_QUBITS__[0]; barrier __PYQASM_QUBITS__[2:5]; barrier __PYQASM_QUBITS__[:2]; From d668f4f8f93a3f925b77ddd77029368924c2336d Mon Sep 17 00:00:00 2001 From: vinayswamik Date: Wed, 2 Jul 2025 01:13:44 -0500 Subject: [PATCH 06/10] Code refactor - Added `consolidate_qubits` parameter to `QasmVisitor`, `QasmModule`, and related methods to support qubit consolidation and device constraints. - Updated the `load` and `loads` functions to accept new parameters for better integration with device specifications. - Implemented transformation logic for qubit register mapping in statements to align with device qubit indices. - Enhanced unit tests to validate the new functionality, including error handling for exceeding device qubit limits and ensuring correct unrolling of QASM programs. --- src/pyqasm/entrypoint.py | 21 ++- src/pyqasm/modules/base.py | 27 +++- src/pyqasm/transformer.py | 143 +++++++++++++++++++- src/pyqasm/visitor.py | 215 ++++++++++++++---------------- tests/qasm3/test_device_qubits.py | 62 +++++++-- 5 files changed, 330 insertions(+), 138 deletions(-) diff --git a/src/pyqasm/entrypoint.py b/src/pyqasm/entrypoint.py index af083021..92bd331e 100644 --- a/src/pyqasm/entrypoint.py +++ b/src/pyqasm/entrypoint.py @@ -30,11 +30,15 @@ import openqasm3.ast -def load(filename: str) -> QasmModule: +def load( + filename: str, *, device_qubits: int | None = None, consolidate_qubits: bool = False +) -> QasmModule: """Loads an OpenQASM program into a `QasmModule` object. Args: filename (str): The filename of the OpenQASM program to validate. + device_qubits (int): Number of physical qubits available on the target device. + consolidate_qubits (bool): If True, consolidate all quantum registers into single register. Returns: QasmModule: An object containing the parsed qasm representation along with @@ -44,14 +48,21 @@ def load(filename: str) -> QasmModule: raise TypeError("Input 'filename' must be of type 'str'.") with open(filename, "r", encoding="utf-8") as file: program = file.read() - return loads(program) + return loads(program, device_qubits=device_qubits, consolidate_qubits=consolidate_qubits) -def loads(program: openqasm3.ast.Program | str) -> QasmModule: +def loads( + program: "openqasm3.ast.Program | str", + *, + device_qubits: int | None = None, + consolidate_qubits: bool = False, +) -> QasmModule: """Loads an OpenQASM program into a `QasmModule` object. Args: program (openqasm3.ast.Program or str): The OpenQASM program to validate. + device_qubits (int): Number of physical qubits available on the target device. + consolidate_qubits (bool): If True, consolidate all quantum registers into single register. Raises: TypeError: If the input is not a string or an `openqasm3.ast.Program` instance. @@ -79,7 +90,9 @@ def loads(program: openqasm3.ast.Program | str) -> QasmModule: qasm_module = Qasm3Module if program.version.startswith("3") else Qasm2Module module = qasm_module("main", program) - + # Store device_qubits and consolidate_qubits on the module for later use + module._device_qubits = device_qubits + module._consolidate_qubits = consolidate_qubits return module diff --git a/src/pyqasm/modules/base.py b/src/pyqasm/modules/base.py index ada90739..b8a995c9 100644 --- a/src/pyqasm/modules/base.py +++ b/src/pyqasm/modules/base.py @@ -57,7 +57,8 @@ def __init__(self, name: str, program: Program): self._unrolled_ast = Program(statements=[]) self._external_gates: list[str] = [] self._decompose_native_gates: Optional[bool] = None - self._device_qubits: Optional[int] = 0 + self._device_qubits: Optional[int] = None + self._consolidate_qubits: Optional[bool] = False @property def name(self) -> str: @@ -518,8 +519,20 @@ def validate(self): return try: self.num_qubits, self.num_clbits = 0, 0 - visitor = QasmVisitor(self, check_only=True) + visitor = QasmVisitor( + self, + check_only=True, + device_qubits=self._device_qubits, + ) self.accept(visitor) + # Implicit validation: check total qubits if device_qubits is set and not consolidating + if self._device_qubits: + total_qubits = sum(self._qubit_registers.values()) + if total_qubits > self._device_qubits: + raise ValidationError( + # pylint: disable-next=line-too-long + f"Total qubits '{total_qubits}' exceed device qubits '{self._device_qubits}'." + ) except (ValidationError, NotImplementedError) as err: self.num_qubits, self.num_clbits = -1, -1 raise err @@ -535,6 +548,9 @@ def unroll(self, **kwargs): max_loop_iters (int): Max number of iterations for unrolling loops. Defaults to 1e9. check_only (bool): If True, only check the program without executing it. Defaults to False. + device_qubits (int): Number of physical qubits available on the target device. + consolidate_qubits (bool): If True, consolidate all quantum registers into + single register. Raises: ValidationError: If the module fails validation during unrolling. @@ -546,6 +562,11 @@ def unroll(self, **kwargs): """ if not kwargs: kwargs = {} + # Use module attributes if not overridden by kwargs + if "device_qubits" not in kwargs: + kwargs["device_qubits"] = self._device_qubits + if "consolidate_qubits" not in kwargs: + kwargs["consolidate_qubits"] = self._consolidate_qubits try: self.num_qubits, self.num_clbits = 0, 0 if ext_gates := kwargs.get("external_gates"): @@ -554,6 +575,8 @@ def unroll(self, **kwargs): self._external_gates = [] if device_qbts := kwargs.get("device_qubits"): self._device_qubits = device_qbts + if consolidate_qbts := kwargs.get("consolidate_qubits"): + self._consolidate_qubits = consolidate_qbts visitor = QasmVisitor(module=self, **kwargs) self.accept(visitor) except (ValidationError, UnrollError) as err: diff --git a/src/pyqasm/transformer.py b/src/pyqasm/transformer.py index 119ac358..1de9905f 100644 --- a/src/pyqasm/transformer.py +++ b/src/pyqasm/transformer.py @@ -17,7 +17,7 @@ """ from copy import deepcopy -from typing import Any, NamedTuple, Optional +from typing import Any, NamedTuple, Optional, Sequence, cast import numpy as np from openqasm3.ast import ( @@ -37,9 +37,11 @@ QASMNode, QuantumBarrier, QuantumGate, + QuantumMeasurementStatement, QuantumPhase, QuantumReset, RangeDefinition, + Statement, UintType, UnaryExpression, UnaryOperator, @@ -438,3 +440,142 @@ def get_type_string(variable: Variable) -> str: if is_array: type_str += f", {', '.join([str(dim) for dim in dims])}]" return type_str + + @staticmethod + def transform_qubit_reg_in_statemets( # pylint: disable=too-many-branches, too-many-locals, too-many-statements + unrolled_stmts: Sequence[Statement] | Statement, + qubit_register_offsets: dict[str, int], + global_qreg_size_map: dict[str, int], + device_qubits: int | None, + ) -> Sequence[Statement] | Statement: + """Transform statements by mapping qubit registers to device qubit register indices + + Args: + unrolled_stmts : The statements or single statement to transform. + qubit_register_offsets (dict): Mapping from register name to its + offset in the global qubit array. + global_qreg_size_map (dict): original global qubit register mapping. + device_qubits (int): Total number of device qubits + + Returns: + The transformed statements or statement with qubit registers mapped to device indices. + """ + if device_qubits is None: + device_qubits = sum(global_qreg_size_map.values()) + + def _get_pyqasm_device_qubit_index( + reg: str, idx: int, qubit_reg_offsets: dict[str, int], global_qreg: dict[str, int] + ): + _offsets = qubit_reg_offsets + _n_qubits = global_qreg[reg] + if not 0 <= idx < _n_qubits: + raise IndexError(f"{reg}[{idx}] out of range (0..{_n_qubits-1})") + return _offsets[reg] + idx + + if isinstance(unrolled_stmts, QuantumBarrier): + _qubit_id = cast(Identifier, unrolled_stmts.qubits[0]) # type: ignore[union-attr] + if not isinstance(_qubit_id, IndexedIdentifier): + _start = _get_pyqasm_device_qubit_index( + _qubit_id.name, 0, qubit_register_offsets, global_qreg_size_map + ) + _end = _get_pyqasm_device_qubit_index( + _qubit_id.name, + global_qreg_size_map[_qubit_id.name] - 1, + qubit_register_offsets, + global_qreg_size_map, + ) + if _start == 0: + _qubit_id.name = f"__PYQASM_QUBITS__[:{_end+1}]" + elif _end == device_qubits - 1: + _qubit_id.name = f"__PYQASM_QUBITS__[{_start}:]" + else: + _qubit_id.name = f"__PYQASM_QUBITS__[{_start}:{_end+1}]" + else: + _qubit_str = cast(str, unrolled_stmts.qubits[0].name) # type: ignore[union-attr] + _qubit_ind = cast( + list, unrolled_stmts.qubits[0].indices + ) # type: ignore[union-attr] + for multi_ind in _qubit_ind: + for ind in multi_ind: + pyqasm_ind = _get_pyqasm_device_qubit_index( + _qubit_str.name, ind.value, qubit_register_offsets, global_qreg_size_map + ) + ind.value = pyqasm_ind + _qubit_str.name = "__PYQASM_QUBITS__" + return unrolled_stmts + + if isinstance(unrolled_stmts, list): # pylint: disable=too-many-nested-blocks + if isinstance(unrolled_stmts[0], QuantumMeasurementStatement): + for stmt in unrolled_stmts: + _qubit_id = cast( + Identifier, stmt.measure.qubit.name + ) # type: ignore[union-attr] + _qubit_ind = cast(list, stmt.measure.qubit.indices) # type: ignore[union-attr] + for multiple_ind in _qubit_ind: + for ind in multiple_ind: + _pyqasm_val = _get_pyqasm_device_qubit_index( + _qubit_id.name, + ind.value, + qubit_register_offsets, + global_qreg_size_map, + ) + ind.value = _pyqasm_val + _qubit_id.name = "__PYQASM_QUBITS__" + return unrolled_stmts + + if isinstance(unrolled_stmts[0], QuantumReset): + for stmt in unrolled_stmts: + _qubit_str = cast(str, stmt.qubits.name.name) # type: ignore[union-attr] + _qubit_ind = cast(list, stmt.qubits.indices) # type: ignore[union-attr] + for multiple_ind in _qubit_ind: + for ind in multiple_ind: + _pyqasm_val = _get_pyqasm_device_qubit_index( + _qubit_str, ind.value, qubit_register_offsets, global_qreg_size_map + ) + ind.value = _pyqasm_val + stmt.qubits.name.name = "__PYQASM_QUBITS__" # type: ignore[union-attr] + return unrolled_stmts + + if isinstance(unrolled_stmts[0], QuantumBarrier): + for stmt in unrolled_stmts: + _qubit_ind_id = cast( + IndexedIdentifier, stmt.qubits[0] + ) # type: ignore[union-attr] + _original_qubit_name = _qubit_ind_id.name.name + for multiple_ind in _qubit_ind_id.indices: + for ind in multiple_ind: # type: ignore[union-attr] + ind_val = cast(IntegerLiteral, ind) # type: ignore[union-attr] + pyqasm_val = _get_pyqasm_device_qubit_index( + _original_qubit_name, + ind_val.value, + qubit_register_offsets, + global_qreg_size_map, + ) + ind_val.value = pyqasm_val + _qubit_ind_id.name.name = "__PYQASM_QUBITS__" + return unrolled_stmts + + if isinstance(unrolled_stmts[0], QuantumGate): + unrolled_copy = deepcopy(unrolled_stmts) + for stmt, c_stmt in zip(unrolled_stmts, unrolled_copy): + for qubit, c_qubit in zip(stmt.qubits, c_stmt.qubits): + _original_qubit_name = cast( + Identifier, c_qubit.name + ) # type: ignore[assignment] + for multi_ind, c_multi_ind in zip( + qubit.indices, c_qubit.indices # type: ignore[union-attr] + ): + for ind, c_ind in zip(multi_ind, c_multi_ind): + pyqasm_val = _get_pyqasm_device_qubit_index( + _original_qubit_name.name, + c_ind.value, # type: ignore[union-attr] + qubit_register_offsets, + global_qreg_size_map, + ) + ind.value = pyqasm_val + for stmt in unrolled_stmts: + for qubit in stmt.qubits: + qubit.name.name = "__PYQASM_QUBITS__" # type: ignore[union-attr] + return unrolled_stmts + + raise ValueError("Unexpected input to transform") diff --git a/src/pyqasm/visitor.py b/src/pyqasm/visitor.py index 1095d1d1..174dd28b 100644 --- a/src/pyqasm/visitor.py +++ b/src/pyqasm/visitor.py @@ -77,6 +77,8 @@ class QasmVisitor: external_gates (list[str]): List of gates that should not be unrolled. unroll_barriers (bool): If True, barriers will be unrolled. Defaults to True. check_only (bool): If True, only check the program without executing it. Defaults to False. + device_qubits (int): Number of physical qubits available on the target device. + consolidate_qubits (bool): If True, consolidate all quantum registers into single register. """ def __init__( # pylint: disable=too-many-arguments @@ -86,7 +88,8 @@ def __init__( # pylint: disable=too-many-arguments external_gates: list[str] | None = None, unroll_barriers: bool = True, max_loop_iters: int = int(1e9), - device_qubits: int = 0, + device_qubits: int | None = None, + consolidate_qubits: bool = False, ): self._module = module self._scope: deque = deque([{}]) @@ -114,8 +117,10 @@ def __init__( # pylint: disable=too-many-arguments self._measurement_set: set[str] = set() self._init_utilities() self._loop_limit = max_loop_iters - self._device_qubits: int = device_qubits + self._device_qubits: int | None = device_qubits + self._consolidate_qubits: bool = consolidate_qubits self._in_generic_gate_op_scope: int = 0 + self._qubit_register_offsets: dict[str, int] = {} def _init_utilities(self): """Initialize the utilities for the visitor.""" @@ -355,6 +360,15 @@ def _visit_quantum_register( self._module._add_qubit_register(register_name, register_size) + # Inline: Update offsets after adding a new register if device_qubits is set + if self._consolidate_qubits: + offsets = {} + offset = 0 + for name, n_qubits in self._global_qreg_size_map.items(): + offsets[name] = offset + offset += n_qubits + self._qubit_register_offsets = offsets + logger.debug("Added labels for register '%s'", str(register)) if self._check_only: @@ -562,31 +576,9 @@ def _check_variable_cast_type( span=statement.span, ) - def _get_pyqasm_device_qubit_index(self, reg: str, idx: int) -> int: - """ - Returns qubit index in __PYQASM_QUBITS__ ordering for given quantum register and index. - - Args: - reg (str): The name of the quantum register. - idx (int): The index within the quantum register. - - Returns: - int: The global __PYQASM_QUBITS__ index corresponding to reg[idx]. - - Raises: - IndexError: If the index `idx` is out of range for the register. - """ - _offsets: dict[str, int] = {} - _offset = 0 - for name, n_qubits in self._global_qreg_size_map.items(): - _offsets[name] = _offset - _offset += n_qubits - _n_qubits = self._global_qreg_size_map[reg] - if not 0 <= idx < _n_qubits: - raise IndexError(f"{reg}[{idx}] out of range (0..{_n_qubits-1})") - return _offsets[reg] + idx - - def _qubit_register_consolidation(self, unrolled_stmts: list) -> list[qasm3_ast.Statement]: + def _qubit_register_consolidation( + self, unrolled_stmts: list, total_qubits: int + ) -> list[qasm3_ast.Statement]: """ Consolidate all quantum registers into a single register '__PYQASM_QUBITS__'. @@ -598,28 +590,38 @@ def _qubit_register_consolidation(self, unrolled_stmts: list) -> list[qasm3_ast. or if the reserved register '__PYQASM_QUBITS__' is already declared in the original QASM program. """ - total_qubits = sum(self._global_qreg_size_map.values()) - if total_qubits > self._device_qubits: + if total_qubits > self._device_qubits: # type: ignore raise_qasm3_error( - f"Total qubits '({total_qubits})' exceed device qubits '({self._device_qubits})'." + f"Total qubits '({total_qubits})' exceed device qubits '({self._device_qubits})'.", ) if "__PYQASM_QUBITS__" in self._global_qreg_size_map: raise_qasm3_error( - "Original QASM program already declares reserved register '__PYQASM_QUBITS__'." + "Original QASM program already declares quantum register '__PYQASM_QUBITS__'.", + ) + + if "__PYQASM_QUBITS__" in self._global_creg_size_map: + raise_qasm3_error( + "Original QASM program already declares classical register '__PYQASM_QUBITS__'.", + ) + + global_scope = self._get_global_scope() + if "__PYQASM_QUBITS__" in global_scope: + raise_qasm3_error( + "Variable '__PYQASM_QUBITS__' is already exists", ) - removable_statements = [] - for stmt in unrolled_stmts: - if isinstance(stmt, qasm3_ast.QubitDeclaration): - removable_statements.append(stmt) - for r_stmt in removable_statements: - unrolled_stmts.remove(r_stmt) pyqasm_reg_id = qasm3_ast.Identifier("__PYQASM_QUBITS__") - pyqasm_reg_size = qasm3_ast.IntegerLiteral(self._device_qubits) + pyqasm_reg_size = qasm3_ast.IntegerLiteral(self._device_qubits) # type: ignore pyqasm_reg_stmt = qasm3_ast.QubitDeclaration(pyqasm_reg_id, pyqasm_reg_size) - unrolled_stmts.insert(1, pyqasm_reg_stmt) - return unrolled_stmts + + _valid_statements: list[qasm3_ast.Statement] = [] + _valid_statements.append(pyqasm_reg_stmt) + for stmt in unrolled_stmts: + if not isinstance(stmt, qasm3_ast.QubitDeclaration): + _valid_statements.append(stmt) + + return _valid_statements def _visit_measurement( # pylint: disable=too-many-locals, too-many-branches self, statement: qasm3_ast.QuantumMeasurementStatement @@ -721,17 +723,16 @@ def _visit_measurement( # pylint: disable=too-many-locals, too-many-branches unrolled_measurements.append(unrolled_measure) - if self._device_qubits: - for stmt in unrolled_measurements: - _qubit_id = cast( - qasm3_ast.Identifier, stmt.measure.qubit.name - ) # type: ignore[union-attr] - _qubit_ind = cast(list, stmt.measure.qubit.indices) # type: ignore[union-attr] - for multiple_ind in _qubit_ind: - for ind in multiple_ind: - _pyqasm_val = self._get_pyqasm_device_qubit_index(_qubit_id.name, ind.value) - ind.value = _pyqasm_val - _qubit_id.name = "__PYQASM_QUBITS__" + if self._consolidate_qubits: + unrolled_measurements = cast( + list[qasm3_ast.QuantumMeasurementStatement], + Qasm3Transformer.transform_qubit_reg_in_statemets( + unrolled_measurements, + self._qubit_register_offsets, + self._global_qreg_size_map, + self._device_qubits, + ), + ) if self._check_only: return [] @@ -773,15 +774,16 @@ def _visit_reset(self, statement: qasm3_ast.QuantumReset) -> list[qasm3_ast.Quan unrolled_resets.append(unrolled_reset) - if self._device_qubits: - for stmt in unrolled_resets: - _qubit_str = cast(str, stmt.qubits.name.name) # type: ignore[union-attr] - _qubit_ind = cast(list, stmt.qubits.indices) # type: ignore[union-attr] - for multiple_ind in _qubit_ind: - for ind in multiple_ind: - _pyqasm_val = self._get_pyqasm_device_qubit_index(_qubit_str, ind.value) - ind.value = _pyqasm_val - stmt.qubits.name.name = "__PYQASM_QUBITS__" # type: ignore[union-attr] + if self._consolidate_qubits: + unrolled_resets = cast( + list[qasm3_ast.QuantumReset], + Qasm3Transformer.transform_qubit_reg_in_statemets( + unrolled_resets, + self._qubit_register_offsets, + self._global_qreg_size_map, + self._device_qubits, + ), + ) if self._check_only: return [] @@ -835,47 +837,28 @@ def _visit_barrier( # pylint: disable=too-many-locals, too-many-branches return [] if not self._unroll_barriers: - if self._device_qubits: - _qubit_id = cast( - qasm3_ast.Identifier, barrier.qubits[0] - ) # type: ignore[union-attr] - if not isinstance(_qubit_id, qasm3_ast.IndexedIdentifier): - _start = self._get_pyqasm_device_qubit_index(_qubit_id.name, 0) - _end = self._get_pyqasm_device_qubit_index( - _qubit_id.name, self._global_qreg_size_map[_qubit_id.name] - 1 - ) - if _start == 0: - _qubit_id.name = f"__PYQASM_QUBITS__[:{_end+1}]" - elif _end == self._device_qubits - 1: - _qubit_id.name = f"__PYQASM_QUBITS__[{_start}:]" - else: - _qubit_id.name = f"__PYQASM_QUBITS__[{_start}:{_end+1}]" - else: - _qubit_str = cast(str, barrier.qubits[0].name) # type: ignore[union-attr] - _qubit_ind = cast(list, barrier.qubits[0].indices) # type: ignore[union-attr] - for multi_ind in _qubit_ind: - for ind in multi_ind: - pyqasm_ind = self._get_pyqasm_device_qubit_index( - _qubit_str.name, ind.value - ) - ind.value = pyqasm_ind - _qubit_str.name = "__PYQASM_QUBITS__" + if self._consolidate_qubits: + barrier = cast( + qasm3_ast.QuantumBarrier, + Qasm3Transformer.transform_qubit_reg_in_statemets( + barrier, + self._qubit_register_offsets, + self._global_qreg_size_map, + self._device_qubits, + ), + ) return [barrier] - if self._device_qubits: - for stmt in unrolled_barriers: - _qubit_ind_id = cast( - qasm3_ast.IndexedIdentifier, stmt.qubits[0] - ) # type: ignore[union-attr] - _original_qubit_name = _qubit_ind_id.name.name - for multiple_ind in _qubit_ind_id.indices: - for ind in multiple_ind: # type: ignore[union-attr] - ind_val = cast(qasm3_ast.IntegerLiteral, ind) # type: ignore[union-attr] - pyqasm_val = self._get_pyqasm_device_qubit_index( - _original_qubit_name, ind_val.value - ) - ind_val.value = pyqasm_val - _qubit_ind_id.name.name = "__PYQASM_QUBITS__" + if self._consolidate_qubits: + unrolled_barriers = cast( + list[qasm3_ast.QuantumBarrier], + Qasm3Transformer.transform_qubit_reg_in_statemets( + unrolled_barriers, + self._qubit_register_offsets, + self._global_qreg_size_map, + self._device_qubits, + ), + ) return unrolled_barriers @@ -1492,22 +1475,17 @@ def _visit_generic_gate_operation( # pylint: disable=too-many-branches, too-man ] result = negs + result + negs # type: ignore self._in_generic_gate_op_scope -= 1 - if self._device_qubits and not self._in_generic_gate_op_scope: - result_copy = copy.deepcopy(result) - for stmt, c_stmt in zip(result, result_copy): - for qubit, c_qubit in zip(stmt.qubits, c_stmt.qubits): - _original_qubit_name = cast(qasm3_ast.Identifier, c_qubit.name) - for multi_ind, c_multi_ind in zip( - qubit.indices, c_qubit.indices # type: ignore[union-attr] - ): - for ind, c_ind in zip(multi_ind, c_multi_ind): - pyqasm_val = self._get_pyqasm_device_qubit_index( - _original_qubit_name.name, c_ind.value # type: ignore[union-attr] - ) - ind.value = pyqasm_val - for stmt in result: - for qubit in stmt.qubits: - qubit.name.name = "__PYQASM_QUBITS__" # type: ignore[union-attr] + if self._consolidate_qubits and not self._in_generic_gate_op_scope: + result = cast( + list[qasm3_ast.QuantumGate | qasm3_ast.QuantumPhase], + Qasm3Transformer.transform_qubit_reg_in_statemets( + result, + self._qubit_register_offsets, + self._global_qreg_size_map, + self._device_qubits, + ), + ) + if self._check_only: return [] @@ -2603,8 +2581,11 @@ def finalize(self, unrolled_stmts): """ # remove the gphase qubits if they use ALL qubits - if self._device_qubits: - unrolled_stmts = self._qubit_register_consolidation(unrolled_stmts) + if self._consolidate_qubits: + total_qubits = sum(self._global_qreg_size_map.values()) + if self._device_qubits is None: + self._device_qubits = total_qubits + unrolled_stmts = self._qubit_register_consolidation(unrolled_stmts, total_qubits) for stmt in unrolled_stmts: # Rule 1 if isinstance(stmt, qasm3_ast.QuantumPhase): diff --git a/tests/qasm3/test_device_qubits.py b/tests/qasm3/test_device_qubits.py index 3bab18be..167267cb 100644 --- a/tests/qasm3/test_device_qubits.py +++ b/tests/qasm3/test_device_qubits.py @@ -33,8 +33,8 @@ def test_reset(): reset q[1]; """ expected_qasm = """OPENQASM 3.0; - include "stdgates.inc"; qubit[5] __PYQASM_QUBITS__; + include "stdgates.inc"; reset __PYQASM_QUBITS__[2]; reset __PYQASM_QUBITS__[3]; reset __PYQASM_QUBITS__[4]; @@ -42,7 +42,7 @@ def test_reset(): """ result = loads(qasm) - result.unroll(device_qubits=5) + result.unroll(device_qubits=5, consolidate_qubits=True) check_unrolled_qasm(dumps(result), expected_qasm) @@ -55,15 +55,15 @@ def test_barrier(): barrier q[1]; """ expected_qasm = """OPENQASM 3.0; - include "stdgates.inc"; qubit[5] __PYQASM_QUBITS__; + include "stdgates.inc"; barrier __PYQASM_QUBITS__[2]; barrier __PYQASM_QUBITS__[3]; barrier __PYQASM_QUBITS__[4]; barrier __PYQASM_QUBITS__[1]; """ - result = loads(qasm) - result.unroll(device_qubits=5) + result = loads(qasm, device_qubits=5, consolidate_qubits=True) + result.unroll() check_unrolled_qasm(dumps(result), expected_qasm) @@ -79,15 +79,15 @@ def test_unrolled_barrier(): barrier q3; """ expected_qasm = """OPENQASM 3.0; - include "stdgates.inc"; qubit[7] __PYQASM_QUBITS__; + include "stdgates.inc"; barrier __PYQASM_QUBITS__[0]; barrier __PYQASM_QUBITS__[2:5]; barrier __PYQASM_QUBITS__[:2]; barrier __PYQASM_QUBITS__[5:]; """ - result = loads(qasm) - result.unroll(unroll_barriers=False, device_qubits=7) + result = loads(qasm, device_qubits=7, consolidate_qubits=True) + result.unroll(unroll_barriers=False) check_unrolled_qasm(dumps(result), expected_qasm) @@ -104,8 +104,8 @@ def test_measurement(): measure q2[1] -> c[2]; """ expected_qasm = """OPENQASM 3.0; - include "stdgates.inc"; qubit[7] __PYQASM_QUBITS__; + include "stdgates.inc"; bit[3] c; c[0] = measure __PYQASM_QUBITS__[4]; c[1] = measure __PYQASM_QUBITS__[5]; @@ -119,7 +119,7 @@ def test_measurement(): c[2] = measure __PYQASM_QUBITS__[6]; c[2] = measure __PYQASM_QUBITS__[5]; """ - result = loads(qasm) + result = loads(qasm, consolidate_qubits=True) result.unroll(device_qubits=7) check_unrolled_qasm(dumps(result), expected_qasm) @@ -146,8 +146,8 @@ def test_gates(): } """ expected_qasm = """OPENQASM 3.0; - include "stdgates.inc"; qubit[6] __PYQASM_QUBITS__; + include "stdgates.inc"; bit[3] c; x __PYQASM_QUBITS__[3]; cx __PYQASM_QUBITS__[0], __PYQASM_QUBITS__[5]; @@ -202,10 +202,25 @@ def test_gates(): } """ result = loads(qasm) - result.unroll(device_qubits=6) + result.unroll(device_qubits=6, consolidate_qubits=True) check_unrolled_qasm(dumps(result), expected_qasm) +def test_validate(caplog): + with pytest.raises(ValidationError, match=r"Total qubits '4' exceed device qubits '3'."): + with caplog.at_level("ERROR"): + qasm3_string = """ + OPENQASM 3.0; + include "stdgates.inc"; + qubit[4] q; + bit[4] c; + for int i in [0:2] { + h q[0]; + } + """ + loads(qasm3_string, device_qubits=3).validate() + + @pytest.mark.parametrize( "qasm_code,error_message", [ @@ -225,12 +240,31 @@ def test_gates(): qubit[4] data; qubit[2] __PYQASM_QUBITS__; """, - r"Original QASM program already declares reserved register '__PYQASM_QUBITS__'.", + r"Original QASM program already declares quantum register '__PYQASM_QUBITS__'.", + ), + ( + """ + OPENQASM 3.0; + include "stdgates.inc"; + qubit[6] data; + bit[2] __PYQASM_QUBITS__; + """, + r"Original QASM program already declares classical register '__PYQASM_QUBITS__'.", + ), + ( + """ + OPENQASM 3.0; + include "stdgates.inc"; + qubit[6] data; + bit[2] class_data; + int __PYQASM_QUBITS__; + """, + r"Variable '__PYQASM_QUBITS__' is already exists", ), ], ) # pylint: disable-next= too-many-arguments def test_incorrect_qubit_reg_consolidation(qasm_code, error_message, caplog): with pytest.raises(ValidationError) as err: with caplog.at_level("ERROR"): - loads(qasm_code).unroll(device_qubits=6) + loads(qasm_code).unroll(device_qubits=6, consolidate_qubits=True) assert error_message in str(err.value) From a455495058fe97b75edae2d26aa9be147809a502 Mon Sep 17 00:00:00 2001 From: vinayswamik Date: Wed, 2 Jul 2025 01:28:11 -0500 Subject: [PATCH 07/10] linting --- src/pyqasm/entrypoint.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/pyqasm/entrypoint.py b/src/pyqasm/entrypoint.py index 92bd331e..8105c5b7 100644 --- a/src/pyqasm/entrypoint.py +++ b/src/pyqasm/entrypoint.py @@ -31,7 +31,7 @@ def load( - filename: str, *, device_qubits: int | None = None, consolidate_qubits: bool = False + filename: str, device_qubits: int | None = None, consolidate_qubits: bool = False ) -> QasmModule: """Loads an OpenQASM program into a `QasmModule` object. @@ -52,8 +52,7 @@ def load( def loads( - program: "openqasm3.ast.Program | str", - *, + program: openqasm3.ast.Program | str, device_qubits: int | None = None, consolidate_qubits: bool = False, ) -> QasmModule: From 0c6dc9cfdc3aca5cfa91443f305d20a0b552d73d Mon Sep 17 00:00:00 2001 From: vinayswamik Date: Wed, 2 Jul 2025 16:21:10 -0500 Subject: [PATCH 08/10] code refactor - Introduced a new `span` attribute in the `Variable` class to capture the span of variables. - Refactored `load` and `loads` functions to accept additional keyword arguments for improved flexibility. - Updated various parts of the codebase to utilize the new `span` attribute during variable creation and validation. - Enhanced unit tests to ensure proper handling of the new functionality and validate error messages related to variable definitions. --- src/pyqasm/elements.py | 2 + src/pyqasm/entrypoint.py | 23 +++------ src/pyqasm/expressions.py | 1 + src/pyqasm/modules/base.py | 19 ++----- src/pyqasm/subroutines.py | 3 ++ src/pyqasm/transformer.py | 42 +++++++-------- src/pyqasm/validator.py | 1 + src/pyqasm/visitor.py | 86 ++++++++++++++++--------------- tests/qasm3/test_device_qubits.py | 48 +++++++++++------ 9 files changed, 115 insertions(+), 110 deletions(-) diff --git a/src/pyqasm/elements.py b/src/pyqasm/elements.py index 3e4936c7..ac084c75 100644 --- a/src/pyqasm/elements.py +++ b/src/pyqasm/elements.py @@ -89,6 +89,7 @@ class Variable: # pylint: disable=too-many-instance-attributes base_size (int): Base size of the variable. dims (Optional[List[int]]): Dimensions of the variable. value (Optional[int | float | np.ndarray]): Value of the variable. + span (Any): Span of the variable. is_constant (bool): Flag indicating if the variable is constant. is_register (bool): Flag indicating if the variable is a register. readonly (bool): Flag indicating if the variable is readonly. @@ -99,6 +100,7 @@ class Variable: # pylint: disable=too-many-instance-attributes base_size: int dims: Optional[list[int]] = None value: Optional[int | float | np.ndarray] = None + span: Any = None is_constant: bool = False is_register: bool = False readonly: bool = False diff --git a/src/pyqasm/entrypoint.py b/src/pyqasm/entrypoint.py index 8105c5b7..b977e937 100644 --- a/src/pyqasm/entrypoint.py +++ b/src/pyqasm/entrypoint.py @@ -30,15 +30,11 @@ import openqasm3.ast -def load( - filename: str, device_qubits: int | None = None, consolidate_qubits: bool = False -) -> QasmModule: +def load(filename: str, **kwargs) -> QasmModule: """Loads an OpenQASM program into a `QasmModule` object. Args: filename (str): The filename of the OpenQASM program to validate. - device_qubits (int): Number of physical qubits available on the target device. - consolidate_qubits (bool): If True, consolidate all quantum registers into single register. Returns: QasmModule: An object containing the parsed qasm representation along with @@ -48,20 +44,17 @@ def load( raise TypeError("Input 'filename' must be of type 'str'.") with open(filename, "r", encoding="utf-8") as file: program = file.read() - return loads(program, device_qubits=device_qubits, consolidate_qubits=consolidate_qubits) + return loads(program, **kwargs) -def loads( - program: openqasm3.ast.Program | str, - device_qubits: int | None = None, - consolidate_qubits: bool = False, -) -> QasmModule: +def loads(program: openqasm3.ast.Program | str, **kwargs) -> QasmModule: """Loads an OpenQASM program into a `QasmModule` object. Args: program (openqasm3.ast.Program or str): The OpenQASM program to validate. + + **kwargs: Additional arguments to pass to the loads function. device_qubits (int): Number of physical qubits available on the target device. - consolidate_qubits (bool): If True, consolidate all quantum registers into single register. Raises: TypeError: If the input is not a string or an `openqasm3.ast.Program` instance. @@ -89,9 +82,9 @@ def loads( qasm_module = Qasm3Module if program.version.startswith("3") else Qasm2Module module = qasm_module("main", program) - # Store device_qubits and consolidate_qubits on the module for later use - module._device_qubits = device_qubits - module._consolidate_qubits = consolidate_qubits + # Store device_qubits on the module for later use + if dev_qbts := kwargs.get("device_qubits"): + module._device_qubits = dev_qbts return module diff --git a/src/pyqasm/expressions.py b/src/pyqasm/expressions.py index 62926e13..d5642ac4 100644 --- a/src/pyqasm/expressions.py +++ b/src/pyqasm/expressions.py @@ -392,6 +392,7 @@ def _check_type_size(expression, var_name, var_format, base_type): dims=[], value=var_value, is_constant=const_expr, + span=expression.span, ) cast_var_value = Qasm3Validator.validate_variable_assignment_value( variable, var_value, expression diff --git a/src/pyqasm/modules/base.py b/src/pyqasm/modules/base.py index b8a995c9..137975d6 100644 --- a/src/pyqasm/modules/base.py +++ b/src/pyqasm/modules/base.py @@ -519,19 +519,14 @@ def validate(self): return try: self.num_qubits, self.num_clbits = 0, 0 - visitor = QasmVisitor( - self, - check_only=True, - device_qubits=self._device_qubits, - ) + visitor = QasmVisitor(self, check_only=True) self.accept(visitor) # Implicit validation: check total qubits if device_qubits is set and not consolidating if self._device_qubits: - total_qubits = sum(self._qubit_registers.values()) - if total_qubits > self._device_qubits: + if self.num_qubits > self._device_qubits: raise ValidationError( # pylint: disable-next=line-too-long - f"Total qubits '{total_qubits}' exceed device qubits '{self._device_qubits}'." + f"Total qubits '{self.num_qubits}' exceed device qubits '{self._device_qubits}'." ) except (ValidationError, NotImplementedError) as err: self.num_qubits, self.num_clbits = -1, -1 @@ -562,19 +557,13 @@ def unroll(self, **kwargs): """ if not kwargs: kwargs = {} - # Use module attributes if not overridden by kwargs - if "device_qubits" not in kwargs: - kwargs["device_qubits"] = self._device_qubits - if "consolidate_qubits" not in kwargs: - kwargs["consolidate_qubits"] = self._consolidate_qubits + try: self.num_qubits, self.num_clbits = 0, 0 if ext_gates := kwargs.get("external_gates"): self._external_gates = ext_gates else: self._external_gates = [] - if device_qbts := kwargs.get("device_qubits"): - self._device_qubits = device_qbts if consolidate_qbts := kwargs.get("consolidate_qubits"): self._consolidate_qubits = consolidate_qbts visitor = QasmVisitor(module=self, **kwargs) diff --git a/src/pyqasm/subroutines.py b/src/pyqasm/subroutines.py index 4f4096e3..894a2b4f 100644 --- a/src/pyqasm/subroutines.py +++ b/src/pyqasm/subroutines.py @@ -156,6 +156,7 @@ def _process_classical_arg_by_value( dims=None, value=actual_arg_value, is_constant=False, + span=fn_call.span, ) @classmethod # pylint: disable-next=too-many-arguments,too-many-locals,too-many-branches @@ -346,6 +347,7 @@ def _process_classical_arg_by_reference( dims=formal_dimensions, value=actual_array_view, # this is the VIEW of the actual array readonly=readonly_arr, + span=fn_call.span, ) @classmethod # pylint: disable-next=too-many-arguments @@ -454,4 +456,5 @@ def process_quantum_arg( # pylint: disable=too-many-locals dims=None, value=None, is_constant=False, + span=fn_call.span, ) diff --git a/src/pyqasm/transformer.py b/src/pyqasm/transformer.py index 1de9905f..b6fdaec6 100644 --- a/src/pyqasm/transformer.py +++ b/src/pyqasm/transformer.py @@ -29,6 +29,7 @@ FloatLiteral, Identifier, IndexedIdentifier, + IndexElement, IndexExpression, IntegerLiteral, ) @@ -442,7 +443,7 @@ def get_type_string(variable: Variable) -> str: return type_str @staticmethod - def transform_qubit_reg_in_statemets( # pylint: disable=too-many-branches, too-many-locals, too-many-statements + def consolidate_qubit_registers( # pylint: disable=too-many-branches, too-many-locals, too-many-statements unrolled_stmts: Sequence[Statement] | Statement, qubit_register_offsets: dict[str, int], global_qreg_size_map: dict[str, int], @@ -502,7 +503,6 @@ def _get_pyqasm_device_qubit_index( ) ind.value = pyqasm_ind _qubit_str.name = "__PYQASM_QUBITS__" - return unrolled_stmts if isinstance(unrolled_stmts, list): # pylint: disable=too-many-nested-blocks if isinstance(unrolled_stmts[0], QuantumMeasurementStatement): @@ -521,7 +521,6 @@ def _get_pyqasm_device_qubit_index( ) ind.value = _pyqasm_val _qubit_id.name = "__PYQASM_QUBITS__" - return unrolled_stmts if isinstance(unrolled_stmts[0], QuantumReset): for stmt in unrolled_stmts: @@ -534,7 +533,6 @@ def _get_pyqasm_device_qubit_index( ) ind.value = _pyqasm_val stmt.qubits.name.name = "__PYQASM_QUBITS__" # type: ignore[union-attr] - return unrolled_stmts if isinstance(unrolled_stmts[0], QuantumBarrier): for stmt in unrolled_stmts: @@ -553,29 +551,27 @@ def _get_pyqasm_device_qubit_index( ) ind_val.value = pyqasm_val _qubit_ind_id.name.name = "__PYQASM_QUBITS__" - return unrolled_stmts if isinstance(unrolled_stmts[0], QuantumGate): - unrolled_copy = deepcopy(unrolled_stmts) - for stmt, c_stmt in zip(unrolled_stmts, unrolled_copy): - for qubit, c_qubit in zip(stmt.qubits, c_stmt.qubits): - _original_qubit_name = cast( - Identifier, c_qubit.name - ) # type: ignore[assignment] - for multi_ind, c_multi_ind in zip( - qubit.indices, c_qubit.indices # type: ignore[union-attr] - ): - for ind, c_ind in zip(multi_ind, c_multi_ind): + for stmt in unrolled_stmts: + stmt_qubits: list[IndexedIdentifier] = [] + for qubit in stmt.qubits: + _original_qbt_name = cast(Identifier, qubit.name) + qubit_indices: list[IndexElement] = [] + for multi_ind in qubit.indices: # type: ignore[union-attr] + qubit_sub_ind: IndexElement = [] + for ind in multi_ind: pyqasm_val = _get_pyqasm_device_qubit_index( - _original_qubit_name.name, - c_ind.value, # type: ignore[union-attr] + _original_qbt_name.name, + ind.value, # type: ignore[union-attr] qubit_register_offsets, global_qreg_size_map, ) - ind.value = pyqasm_val - for stmt in unrolled_stmts: - for qubit in stmt.qubits: - qubit.name.name = "__PYQASM_QUBITS__" # type: ignore[union-attr] - return unrolled_stmts + qubit_sub_ind.append(IntegerLiteral(pyqasm_val)) + qubit_indices.append(qubit_sub_ind) + stmt_qubits.append( + IndexedIdentifier(Identifier("__PYQASM_QUBITS__"), qubit_indices) + ) + stmt.qubits = stmt_qubits - raise ValueError("Unexpected input to transform") + return unrolled_stmts diff --git a/src/pyqasm/validator.py b/src/pyqasm/validator.py index 90d0ac44..14c42e55 100644 --- a/src/pyqasm/validator.py +++ b/src/pyqasm/validator.py @@ -335,6 +335,7 @@ def validate_return_statement( # pylint: disable=inconsistent-return-statements base_size, None, None, + span=return_statement.span, ), return_value, op_node=return_statement, diff --git a/src/pyqasm/visitor.py b/src/pyqasm/visitor.py index 174dd28b..026f6be1 100644 --- a/src/pyqasm/visitor.py +++ b/src/pyqasm/visitor.py @@ -20,7 +20,7 @@ """ import copy import logging -from collections import deque +from collections import OrderedDict, deque from functools import partial from typing import Any, Callable, Optional, cast @@ -88,7 +88,6 @@ def __init__( # pylint: disable=too-many-arguments external_gates: list[str] | None = None, unroll_barriers: bool = True, max_loop_iters: int = int(1e9), - device_qubits: int | None = None, consolidate_qubits: bool = False, ): self._module = module @@ -117,10 +116,10 @@ def __init__( # pylint: disable=too-many-arguments self._measurement_set: set[str] = set() self._init_utilities() self._loop_limit = max_loop_iters - self._device_qubits: int | None = device_qubits self._consolidate_qubits: bool = consolidate_qubits self._in_generic_gate_op_scope: int = 0 - self._qubit_register_offsets: dict[str, int] = {} + self._qubit_register_offsets: OrderedDict = OrderedDict() + self._qubit_register_max_offset = 0 def _init_utilities(self): """Initialize the utilities for the visitor.""" @@ -346,7 +345,14 @@ def _visit_quantum_register( self._add_var_in_scope( Variable( - register_name, qasm3_ast.QubitDeclaration, register_size, None, None, False, True + register_name, + qasm3_ast.QubitDeclaration, + register_size, + None, + None, + register.span, + False, + True, ) ) size_map[f"{register_name}"] = register_size @@ -362,12 +368,14 @@ def _visit_quantum_register( # Inline: Update offsets after adding a new register if device_qubits is set if self._consolidate_qubits: - offsets = {} - offset = 0 - for name, n_qubits in self._global_qreg_size_map.items(): - offsets[name] = offset - offset += n_qubits - self._qubit_register_offsets = offsets + self._qubit_register_offsets[register_name] = self._qubit_register_max_offset + self._qubit_register_max_offset += register_size + # offsets = {} + # offset = 0 + # for name, n_qubits in self._global_qreg_size_map.items(): + # offsets[name] = offset + # offset += n_qubits + # self._qubit_register_offsets = offsets logger.debug("Added labels for register '%s'", str(register)) @@ -590,29 +598,22 @@ def _qubit_register_consolidation( or if the reserved register '__PYQASM_QUBITS__' is already declared in the original QASM program. """ - if total_qubits > self._device_qubits: # type: ignore + if total_qubits > self._module._device_qubits: # type: ignore raise_qasm3_error( - f"Total qubits '({total_qubits})' exceed device qubits '({self._device_qubits})'.", - ) - - if "__PYQASM_QUBITS__" in self._global_qreg_size_map: - raise_qasm3_error( - "Original QASM program already declares quantum register '__PYQASM_QUBITS__'.", - ) - - if "__PYQASM_QUBITS__" in self._global_creg_size_map: - raise_qasm3_error( - "Original QASM program already declares classical register '__PYQASM_QUBITS__'.", + # pylint: disable-next=line-too-long + f"Total qubits '({total_qubits})' exceed device qubits '({self._module._device_qubits})'.", ) global_scope = self._get_global_scope() - if "__PYQASM_QUBITS__" in global_scope: - raise_qasm3_error( - "Variable '__PYQASM_QUBITS__' is already exists", - ) + for var, val in global_scope.items(): + if var == "__PYQASM_QUBITS__": + raise_qasm3_error( + "Variable '__PYQASM_QUBITS__' is already defined", + span=val.span, + ) pyqasm_reg_id = qasm3_ast.Identifier("__PYQASM_QUBITS__") - pyqasm_reg_size = qasm3_ast.IntegerLiteral(self._device_qubits) # type: ignore + pyqasm_reg_size = qasm3_ast.IntegerLiteral(self._module._device_qubits) # type: ignore pyqasm_reg_stmt = qasm3_ast.QubitDeclaration(pyqasm_reg_id, pyqasm_reg_size) _valid_statements: list[qasm3_ast.Statement] = [] @@ -726,11 +727,11 @@ def _visit_measurement( # pylint: disable=too-many-locals, too-many-branches if self._consolidate_qubits: unrolled_measurements = cast( list[qasm3_ast.QuantumMeasurementStatement], - Qasm3Transformer.transform_qubit_reg_in_statemets( + Qasm3Transformer.consolidate_qubit_registers( unrolled_measurements, self._qubit_register_offsets, self._global_qreg_size_map, - self._device_qubits, + self._module._device_qubits, ), ) @@ -777,11 +778,11 @@ def _visit_reset(self, statement: qasm3_ast.QuantumReset) -> list[qasm3_ast.Quan if self._consolidate_qubits: unrolled_resets = cast( list[qasm3_ast.QuantumReset], - Qasm3Transformer.transform_qubit_reg_in_statemets( + Qasm3Transformer.consolidate_qubit_registers( unrolled_resets, self._qubit_register_offsets, self._global_qreg_size_map, - self._device_qubits, + self._module._device_qubits, ), ) @@ -840,11 +841,11 @@ def _visit_barrier( # pylint: disable=too-many-locals, too-many-branches if self._consolidate_qubits: barrier = cast( qasm3_ast.QuantumBarrier, - Qasm3Transformer.transform_qubit_reg_in_statemets( + Qasm3Transformer.consolidate_qubit_registers( barrier, self._qubit_register_offsets, self._global_qreg_size_map, - self._device_qubits, + self._module._device_qubits, ), ) return [barrier] @@ -852,11 +853,11 @@ def _visit_barrier( # pylint: disable=too-many-locals, too-many-branches if self._consolidate_qubits: unrolled_barriers = cast( list[qasm3_ast.QuantumBarrier], - Qasm3Transformer.transform_qubit_reg_in_statemets( + Qasm3Transformer.consolidate_qubit_registers( unrolled_barriers, self._qubit_register_offsets, self._global_qreg_size_map, - self._device_qubits, + self._module._device_qubits, ), ) @@ -1478,11 +1479,11 @@ def _visit_generic_gate_operation( # pylint: disable=too-many-branches, too-man if self._consolidate_qubits and not self._in_generic_gate_op_scope: result = cast( list[qasm3_ast.QuantumGate | qasm3_ast.QuantumPhase], - Qasm3Transformer.transform_qubit_reg_in_statemets( + Qasm3Transformer.consolidate_qubit_registers( result, self._qubit_register_offsets, self._global_qreg_size_map, - self._device_qubits, + self._module._device_qubits, ), ) @@ -1539,7 +1540,9 @@ def _visit_constant_declaration( statement.init_expression, validate_only=True ) self._check_variable_cast_type(statement, val_type, var_name, base_type, base_size, True) - variable = Variable(var_name, base_type, base_size, [], init_value, is_constant=True) + variable = Variable( + var_name, base_type, base_size, [], init_value, is_constant=True, span=statement.span + ) # cast + validation variable.value = Qasm3Validator.validate_variable_assignment_value( @@ -1670,6 +1673,7 @@ def _visit_classical_declaration( final_dimensions, init_value, is_register=isinstance(base_type, qasm3_ast.BitType), + span=statement.span, ) # validate the assignment @@ -2583,8 +2587,8 @@ def finalize(self, unrolled_stmts): # remove the gphase qubits if they use ALL qubits if self._consolidate_qubits: total_qubits = sum(self._global_qreg_size_map.values()) - if self._device_qubits is None: - self._device_qubits = total_qubits + if self._module._device_qubits is None: + self._module._device_qubits = total_qubits unrolled_stmts = self._qubit_register_consolidation(unrolled_stmts, total_qubits) for stmt in unrolled_stmts: # Rule 1 diff --git a/tests/qasm3/test_device_qubits.py b/tests/qasm3/test_device_qubits.py index 167267cb..9261e00b 100644 --- a/tests/qasm3/test_device_qubits.py +++ b/tests/qasm3/test_device_qubits.py @@ -41,8 +41,8 @@ def test_reset(): reset __PYQASM_QUBITS__[1]; """ - result = loads(qasm) - result.unroll(device_qubits=5, consolidate_qubits=True) + result = loads(qasm, device_qubits=5) + result.unroll(consolidate_qubits=True) check_unrolled_qasm(dumps(result), expected_qasm) @@ -62,8 +62,8 @@ def test_barrier(): barrier __PYQASM_QUBITS__[4]; barrier __PYQASM_QUBITS__[1]; """ - result = loads(qasm, device_qubits=5, consolidate_qubits=True) - result.unroll() + result = loads(qasm, device_qubits=5) + result.unroll(consolidate_qubits=True) check_unrolled_qasm(dumps(result), expected_qasm) @@ -86,8 +86,8 @@ def test_unrolled_barrier(): barrier __PYQASM_QUBITS__[:2]; barrier __PYQASM_QUBITS__[5:]; """ - result = loads(qasm, device_qubits=7, consolidate_qubits=True) - result.unroll(unroll_barriers=False) + result = loads(qasm, device_qubits=7) + result.unroll(unroll_barriers=False, consolidate_qubits=True) check_unrolled_qasm(dumps(result), expected_qasm) @@ -119,8 +119,8 @@ def test_measurement(): c[2] = measure __PYQASM_QUBITS__[6]; c[2] = measure __PYQASM_QUBITS__[5]; """ - result = loads(qasm, consolidate_qubits=True) - result.unroll(device_qubits=7) + result = loads(qasm, device_qubits=7) + result.unroll(consolidate_qubits=True) check_unrolled_qasm(dumps(result), expected_qasm) @@ -201,8 +201,8 @@ def test_gates(): cx __PYQASM_QUBITS__[4], __PYQASM_QUBITS__[2]; } """ - result = loads(qasm) - result.unroll(device_qubits=6, consolidate_qubits=True) + result = loads(qasm, device_qubits=6) + result.unroll(consolidate_qubits=True) check_unrolled_qasm(dumps(result), expected_qasm) @@ -222,7 +222,7 @@ def test_validate(caplog): @pytest.mark.parametrize( - "qasm_code,error_message", + "qasm_code, error_message", [ ( """ @@ -233,6 +233,18 @@ def test_validate(caplog): """, r"Total qubits '(7)' exceed device qubits '(6)'.", ), + ], +) # pylint: disable-next= too-many-arguments +def test_incorrect_device_qubits(qasm_code, error_message, caplog): + with pytest.raises(ValidationError) as err: + with caplog.at_level("ERROR"): + loads(qasm_code, device_qubits=6).unroll(consolidate_qubits=True) + assert error_message in str(err.value) + + +@pytest.mark.parametrize( + "qasm_code,error_message,error_span", + [ ( """ OPENQASM 3.0; @@ -240,7 +252,8 @@ def test_validate(caplog): qubit[4] data; qubit[2] __PYQASM_QUBITS__; """, - r"Original QASM program already declares quantum register '__PYQASM_QUBITS__'.", + r"Variable '__PYQASM_QUBITS__' is already defined", + r"Error at line 5, column 12", ), ( """ @@ -249,7 +262,8 @@ def test_validate(caplog): qubit[6] data; bit[2] __PYQASM_QUBITS__; """, - r"Original QASM program already declares classical register '__PYQASM_QUBITS__'.", + r"Variable '__PYQASM_QUBITS__' is already defined", + r"Error at line 5, column 12", ), ( """ @@ -259,12 +273,14 @@ def test_validate(caplog): bit[2] class_data; int __PYQASM_QUBITS__; """, - r"Variable '__PYQASM_QUBITS__' is already exists", + r"Variable '__PYQASM_QUBITS__' is already defined", + r"Error at line 6, column 12", ), ], ) # pylint: disable-next= too-many-arguments -def test_incorrect_qubit_reg_consolidation(qasm_code, error_message, caplog): +def test_incorrect_qubit_reg(qasm_code, error_message, error_span, caplog): with pytest.raises(ValidationError) as err: with caplog.at_level("ERROR"): - loads(qasm_code).unroll(device_qubits=6, consolidate_qubits=True) + loads(qasm_code, device_qubits=6).unroll(consolidate_qubits=True) assert error_message in str(err.value) + assert error_span in caplog.text From 5894c3ced6ae795cfe07e8c4b894f1fd4443fe22 Mon Sep 17 00:00:00 2001 From: vinayswamik Date: Wed, 2 Jul 2025 16:35:58 -0500 Subject: [PATCH 09/10] linting --- src/pyqasm/entrypoint.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/pyqasm/entrypoint.py b/src/pyqasm/entrypoint.py index b977e937..5facd8f9 100644 --- a/src/pyqasm/entrypoint.py +++ b/src/pyqasm/entrypoint.py @@ -53,8 +53,8 @@ def loads(program: openqasm3.ast.Program | str, **kwargs) -> QasmModule: Args: program (openqasm3.ast.Program or str): The OpenQASM program to validate. - **kwargs: Additional arguments to pass to the loads function. - device_qubits (int): Number of physical qubits available on the target device. + **kwargs: Additional arguments to pass to the loads function. + device_qubits (int): Number of physical qubits available on the target device. Raises: TypeError: If the input is not a string or an `openqasm3.ast.Program` instance. From 06a921d3c1ff58cd7f3de8df472e39142a1fcbeb Mon Sep 17 00:00:00 2001 From: vinayswamik Date: Thu, 3 Jul 2025 02:32:04 -0500 Subject: [PATCH 10/10] code refactor --- src/pyqasm/transformer.py | 24 +++++++++--------------- src/pyqasm/visitor.py | 10 +++------- 2 files changed, 12 insertions(+), 22 deletions(-) diff --git a/src/pyqasm/transformer.py b/src/pyqasm/transformer.py index b6fdaec6..2c9087c7 100644 --- a/src/pyqasm/transformer.py +++ b/src/pyqasm/transformer.py @@ -29,7 +29,6 @@ FloatLiteral, Identifier, IndexedIdentifier, - IndexElement, IndexExpression, IntegerLiteral, ) @@ -556,21 +555,16 @@ def _get_pyqasm_device_qubit_index( for stmt in unrolled_stmts: stmt_qubits: list[IndexedIdentifier] = [] for qubit in stmt.qubits: - _original_qbt_name = cast(Identifier, qubit.name) - qubit_indices: list[IndexElement] = [] - for multi_ind in qubit.indices: # type: ignore[union-attr] - qubit_sub_ind: IndexElement = [] - for ind in multi_ind: - pyqasm_val = _get_pyqasm_device_qubit_index( - _original_qbt_name.name, - ind.value, # type: ignore[union-attr] - qubit_register_offsets, - global_qreg_size_map, - ) - qubit_sub_ind.append(IntegerLiteral(pyqasm_val)) - qubit_indices.append(qubit_sub_ind) + pyqasm_val = _get_pyqasm_device_qubit_index( + qubit.name.name, + qubit.indices[0][0].value, + qubit_register_offsets, + global_qreg_size_map, + ) stmt_qubits.append( - IndexedIdentifier(Identifier("__PYQASM_QUBITS__"), qubit_indices) + IndexedIdentifier( + Identifier("__PYQASM_QUBITS__"), [[IntegerLiteral(pyqasm_val)]] + ) ) stmt.qubits = stmt_qubits diff --git a/src/pyqasm/visitor.py b/src/pyqasm/visitor.py index 026f6be1..6fb16d83 100644 --- a/src/pyqasm/visitor.py +++ b/src/pyqasm/visitor.py @@ -366,16 +366,12 @@ def _visit_quantum_register( self._module._add_qubit_register(register_name, register_size) - # Inline: Update offsets after adding a new register if device_qubits is set + # _qubit_register_offsets maps each original quantum register to its + # starting index in the consolidated register, enabling correct + # translation of qubit indices after consolidation. if self._consolidate_qubits: self._qubit_register_offsets[register_name] = self._qubit_register_max_offset self._qubit_register_max_offset += register_size - # offsets = {} - # offset = 0 - # for name, n_qubits in self._global_qreg_size_map.items(): - # offsets[name] = offset - # offset += n_qubits - # self._qubit_register_offsets = offsets logger.debug("Added labels for register '%s'", str(register))