Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ Types of changes:
### Improved / Modified
- Added `slots=True` parameter to the data classes in `elements.py` to improve memory efficiency ([#218](https://github.com/qBraid/pyqasm/pull/218))
- Updated the documentation to include core features in the `README` ([#219](https://github.com/qBraid/pyqasm/pull/219))
- Added support to `device qubit` resgister consolidation.([#222](https://github.com/qBraid/pyqasm/pull/222))

### Deprecated

Expand Down
2 changes: 2 additions & 0 deletions src/pyqasm/elements.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ class Variable: # pylint: disable=too-many-instance-attributes
base_size (int): Base size of the variable.
dims (Optional[List[int]]): Dimensions of the variable.
value (Optional[int | float | np.ndarray]): Value of the variable.
span (Any): Span of the variable.
is_constant (bool): Flag indicating if the variable is constant.
is_register (bool): Flag indicating if the variable is a register.
readonly (bool): Flag indicating if the variable is readonly.
Expand All @@ -99,6 +100,7 @@ class Variable: # pylint: disable=too-many-instance-attributes
base_size: int
dims: Optional[list[int]] = None
value: Optional[int | float | np.ndarray] = None
span: Any = None
is_constant: bool = False
is_register: bool = False
readonly: bool = False
Expand Down
13 changes: 9 additions & 4 deletions src/pyqasm/entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
import openqasm3.ast


def load(filename: str) -> QasmModule:
def load(filename: str, **kwargs) -> QasmModule:
"""Loads an OpenQASM program into a `QasmModule` object.

Args:
Expand All @@ -44,15 +44,18 @@ def load(filename: str) -> QasmModule:
raise TypeError("Input 'filename' must be of type 'str'.")
with open(filename, "r", encoding="utf-8") as file:
program = file.read()
return loads(program)
return loads(program, **kwargs)


def loads(program: openqasm3.ast.Program | str) -> QasmModule:
def loads(program: openqasm3.ast.Program | str, **kwargs) -> QasmModule:
"""Loads an OpenQASM program into a `QasmModule` object.

Args:
program (openqasm3.ast.Program or str): The OpenQASM program to validate.

**kwargs: Additional arguments to pass to the loads function.
device_qubits (int): Number of physical qubits available on the target device.

Raises:
TypeError: If the input is not a string or an `openqasm3.ast.Program` instance.
ValidationError: If the program fails parsing or semantic validation.
Expand All @@ -79,7 +82,9 @@ def loads(program: openqasm3.ast.Program | str) -> QasmModule:

qasm_module = Qasm3Module if program.version.startswith("3") else Qasm2Module
module = qasm_module("main", program)

# Store device_qubits on the module for later use
if dev_qbts := kwargs.get("device_qubits"):
module._device_qubits = dev_qbts
return module


Expand Down
1 change: 1 addition & 0 deletions src/pyqasm/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,6 +392,7 @@ def _check_type_size(expression, var_name, var_format, base_type):
dims=[],
value=var_value,
is_constant=const_expr,
span=expression.span,
)
cast_var_value = Qasm3Validator.validate_variable_assignment_value(
variable, var_value, expression
Expand Down
15 changes: 15 additions & 0 deletions src/pyqasm/modules/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ def __init__(self, name: str, program: Program):
self._unrolled_ast = Program(statements=[])
self._external_gates: list[str] = []
self._decompose_native_gates: Optional[bool] = None
self._device_qubits: Optional[int] = None
self._consolidate_qubits: Optional[bool] = False

@property
def name(self) -> str:
Expand Down Expand Up @@ -519,6 +521,13 @@ def validate(self):
self.num_qubits, self.num_clbits = 0, 0
visitor = QasmVisitor(self, check_only=True)
self.accept(visitor)
# Implicit validation: check total qubits if device_qubits is set and not consolidating
if self._device_qubits:
if self.num_qubits > self._device_qubits:
raise ValidationError(
# pylint: disable-next=line-too-long
f"Total qubits '{self.num_qubits}' exceed device qubits '{self._device_qubits}'."
)
except (ValidationError, NotImplementedError) as err:
self.num_qubits, self.num_clbits = -1, -1
raise err
Expand All @@ -534,6 +543,9 @@ def unroll(self, **kwargs):
max_loop_iters (int): Max number of iterations for unrolling loops. Defaults to 1e9.
check_only (bool): If True, only check the program without executing it.
Defaults to False.
device_qubits (int): Number of physical qubits available on the target device.
consolidate_qubits (bool): If True, consolidate all quantum registers into
single register.

Raises:
ValidationError: If the module fails validation during unrolling.
Expand All @@ -545,12 +557,15 @@ def unroll(self, **kwargs):
"""
if not kwargs:
kwargs = {}

try:
self.num_qubits, self.num_clbits = 0, 0
if ext_gates := kwargs.get("external_gates"):
self._external_gates = ext_gates
else:
self._external_gates = []
if consolidate_qbts := kwargs.get("consolidate_qubits"):
self._consolidate_qubits = consolidate_qbts
visitor = QasmVisitor(module=self, **kwargs)
self.accept(visitor)
except (ValidationError, UnrollError) as err:
Expand Down
3 changes: 3 additions & 0 deletions src/pyqasm/subroutines.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ def _process_classical_arg_by_value(
dims=None,
value=actual_arg_value,
is_constant=False,
span=fn_call.span,
)

@classmethod # pylint: disable-next=too-many-arguments,too-many-locals,too-many-branches
Expand Down Expand Up @@ -346,6 +347,7 @@ def _process_classical_arg_by_reference(
dims=formal_dimensions,
value=actual_array_view, # this is the VIEW of the actual array
readonly=readonly_arr,
span=fn_call.span,
)

@classmethod # pylint: disable-next=too-many-arguments
Expand Down Expand Up @@ -454,4 +456,5 @@ def process_quantum_arg( # pylint: disable=too-many-locals
dims=None,
value=None,
is_constant=False,
span=fn_call.span,
)
133 changes: 132 additions & 1 deletion src/pyqasm/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

"""
from copy import deepcopy
from typing import Any, NamedTuple, Optional
from typing import Any, NamedTuple, Optional, Sequence, cast

import numpy as np
from openqasm3.ast import (
Expand All @@ -37,9 +37,11 @@
QASMNode,
QuantumBarrier,
QuantumGate,
QuantumMeasurementStatement,
QuantumPhase,
QuantumReset,
RangeDefinition,
Statement,
UintType,
UnaryExpression,
UnaryOperator,
Expand Down Expand Up @@ -438,3 +440,132 @@ def get_type_string(variable: Variable) -> str:
if is_array:
type_str += f", {', '.join([str(dim) for dim in dims])}]"
return type_str

@staticmethod
def consolidate_qubit_registers( # pylint: disable=too-many-branches, too-many-locals, too-many-statements
unrolled_stmts: Sequence[Statement] | Statement,
qubit_register_offsets: dict[str, int],
global_qreg_size_map: dict[str, int],
device_qubits: int | None,
) -> Sequence[Statement] | Statement:
"""Transform statements by mapping qubit registers to device qubit register indices

Args:
unrolled_stmts : The statements or single statement to transform.
qubit_register_offsets (dict): Mapping from register name to its
offset in the global qubit array.
global_qreg_size_map (dict): original global qubit register mapping.
device_qubits (int): Total number of device qubits

Returns:
The transformed statements or statement with qubit registers mapped to device indices.
"""
if device_qubits is None:
device_qubits = sum(global_qreg_size_map.values())

def _get_pyqasm_device_qubit_index(
reg: str, idx: int, qubit_reg_offsets: dict[str, int], global_qreg: dict[str, int]
):
_offsets = qubit_reg_offsets
_n_qubits = global_qreg[reg]
if not 0 <= idx < _n_qubits:
raise IndexError(f"{reg}[{idx}] out of range (0..{_n_qubits-1})")
return _offsets[reg] + idx

if isinstance(unrolled_stmts, QuantumBarrier):
_qubit_id = cast(Identifier, unrolled_stmts.qubits[0]) # type: ignore[union-attr]
if not isinstance(_qubit_id, IndexedIdentifier):
_start = _get_pyqasm_device_qubit_index(
_qubit_id.name, 0, qubit_register_offsets, global_qreg_size_map
)
_end = _get_pyqasm_device_qubit_index(
_qubit_id.name,
global_qreg_size_map[_qubit_id.name] - 1,
qubit_register_offsets,
global_qreg_size_map,
)
if _start == 0:
_qubit_id.name = f"__PYQASM_QUBITS__[:{_end+1}]"
elif _end == device_qubits - 1:
_qubit_id.name = f"__PYQASM_QUBITS__[{_start}:]"
else:
_qubit_id.name = f"__PYQASM_QUBITS__[{_start}:{_end+1}]"
else:
_qubit_str = cast(str, unrolled_stmts.qubits[0].name) # type: ignore[union-attr]
_qubit_ind = cast(
list, unrolled_stmts.qubits[0].indices
) # type: ignore[union-attr]
for multi_ind in _qubit_ind:
for ind in multi_ind:
pyqasm_ind = _get_pyqasm_device_qubit_index(
_qubit_str.name, ind.value, qubit_register_offsets, global_qreg_size_map
)
ind.value = pyqasm_ind
_qubit_str.name = "__PYQASM_QUBITS__"

if isinstance(unrolled_stmts, list): # pylint: disable=too-many-nested-blocks
if isinstance(unrolled_stmts[0], QuantumMeasurementStatement):
for stmt in unrolled_stmts:
_qubit_id = cast(
Identifier, stmt.measure.qubit.name
) # type: ignore[union-attr]
_qubit_ind = cast(list, stmt.measure.qubit.indices) # type: ignore[union-attr]
for multiple_ind in _qubit_ind:
for ind in multiple_ind:
_pyqasm_val = _get_pyqasm_device_qubit_index(
_qubit_id.name,
ind.value,
qubit_register_offsets,
global_qreg_size_map,
)
ind.value = _pyqasm_val
_qubit_id.name = "__PYQASM_QUBITS__"

if isinstance(unrolled_stmts[0], QuantumReset):
for stmt in unrolled_stmts:
_qubit_str = cast(str, stmt.qubits.name.name) # type: ignore[union-attr]
_qubit_ind = cast(list, stmt.qubits.indices) # type: ignore[union-attr]
for multiple_ind in _qubit_ind:
for ind in multiple_ind:
_pyqasm_val = _get_pyqasm_device_qubit_index(
_qubit_str, ind.value, qubit_register_offsets, global_qreg_size_map
)
ind.value = _pyqasm_val
stmt.qubits.name.name = "__PYQASM_QUBITS__" # type: ignore[union-attr]

if isinstance(unrolled_stmts[0], QuantumBarrier):
for stmt in unrolled_stmts:
_qubit_ind_id = cast(
IndexedIdentifier, stmt.qubits[0]
) # type: ignore[union-attr]
_original_qubit_name = _qubit_ind_id.name.name
for multiple_ind in _qubit_ind_id.indices:
for ind in multiple_ind: # type: ignore[union-attr]
ind_val = cast(IntegerLiteral, ind) # type: ignore[union-attr]
pyqasm_val = _get_pyqasm_device_qubit_index(
_original_qubit_name,
ind_val.value,
qubit_register_offsets,
global_qreg_size_map,
)
ind_val.value = pyqasm_val
_qubit_ind_id.name.name = "__PYQASM_QUBITS__"

if isinstance(unrolled_stmts[0], QuantumGate):
for stmt in unrolled_stmts:
stmt_qubits: list[IndexedIdentifier] = []
for qubit in stmt.qubits:
pyqasm_val = _get_pyqasm_device_qubit_index(
qubit.name.name,
qubit.indices[0][0].value,
qubit_register_offsets,
global_qreg_size_map,
)
stmt_qubits.append(
IndexedIdentifier(
Identifier("__PYQASM_QUBITS__"), [[IntegerLiteral(pyqasm_val)]]
)
)
stmt.qubits = stmt_qubits

return unrolled_stmts
1 change: 1 addition & 0 deletions src/pyqasm/validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,7 @@ def validate_return_statement( # pylint: disable=inconsistent-return-statements
base_size,
None,
None,
span=return_statement.span,
),
return_value,
op_node=return_statement,
Expand Down
Loading