Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,38 @@ Types of changes:

### Added
- Added the `pulse` extra dependency to the `pyproject.toml` file, which includes the `openpulse` package. This allows users to install pulse-related functionality when needed. ([#195](https://github.com/qBraid/pyqasm/pull/195))
- Added support for unrolling `while` loops with compile time condition evaluation. Users can now use `unroll` on while loops which do not have conditions depending on quantum measurements. ([#206](https://github.com/qBraid/pyqasm/pull/206)) Eg. -

```python
import pyqasm

qasm_str = """
OPENQASM 3.0;
qubit[4] q;
int i = 0;
while (i < 3) {
h q[i];
cx q[i], q[i+1];
i += 1;
}

"""
result = pyqasm.loads(qasm_str)
result.unroll()
print(result)

# **Output**

# OPENQASM 3.0;
# qubit[4] q;
# h q[0];
# cx q[0], q[1];
# h q[1];
# cx q[1], q[2];
# h q[2];
# cx q[2], q[3];

```

### Improved / Modified

Expand Down
3 changes: 2 additions & 1 deletion src/pyqasm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,13 +61,14 @@
__version__ = version("pyqasm")

from .entrypoint import dump, dumps, load, loads
from .exceptions import PyQasmError, QasmParsingError, ValidationError
from .exceptions import LoopLimitExceededError, PyQasmError, QasmParsingError, ValidationError
from .modules import Qasm2Module, Qasm3Module, QasmModule
from .printer import draw

__all__ = [
"PyQasmError",
"ValidationError",
"LoopLimitExceededError",
"QasmParsingError",
"load",
"loads",
Expand Down
25 changes: 25 additions & 0 deletions src/pyqasm/analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

import numpy as np
from openqasm3.ast import (
BinaryExpression,
DiscreteSet,
Expression,
Identifier,
Expand All @@ -34,6 +35,7 @@
QuantumMeasurementStatement,
RangeDefinition,
Span,
UnaryExpression,
)

from pyqasm.exceptions import QasmParsingError, ValidationError, raise_qasm3_error
Expand Down Expand Up @@ -292,3 +294,26 @@ def verify_gate_qubits(gate: QuantumGate, span: Optional[Span] = None):
error_node=gate,
span=span,
)

@staticmethod
def condition_depends_on_measurement(condition: Expression, measurement_set: set[str]) -> bool:
"""Recursively check if the condition depends on a classical register set by measurement."""

def _depends(expr) -> bool:
if isinstance(expr, Identifier):
return expr.name in measurement_set

if isinstance(expr, IndexExpression):
# Check if the collection being indexed is in the measurement set
if isinstance(expr.collection, Identifier):
return expr.collection.name in measurement_set
return _depends(expr.collection) or _depends(expr.index)

if isinstance(expr, BinaryExpression):
return _depends(expr.lhs) or _depends(expr.rhs)

if isinstance(expr, UnaryExpression):
return _depends(expr.expression)
return False

return _depends(condition)
34 changes: 34 additions & 0 deletions src/pyqasm/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,40 @@ class QasmParsingError(QASM3ParsingError):
where the given program could not be correctly parsed."""


class LoopLimitExceededError(PyQasmError):
"""Exception raised when a loop limit is exceeded during unrolling or other operations."""

def __init__(self, message: str = "Loop limit exceeded."):
super().__init__(message)


class LoopControlSignal(Exception):
"""Base class for loop control signals like break and continue.
This class is used to signal control flow changes within loops during AST traversal."""

def __init__(self, signal_type: str):
assert signal_type in ("break", "continue")
self.signal_type = signal_type


class BreakSignal(LoopControlSignal):
"""Signal to break out of a loop during AST traversal."""

def __init__(self, msg: Optional[str] = None):
if msg is None:
msg = "break"
super().__init__(msg)


class ContinueSignal(LoopControlSignal):
"""Signal to continue to the next iteration of a loop during AST traversal."""

def __init__(self, msg: Optional[str] = None):
if msg is None:
msg = "continue"
super().__init__("continue")


def raise_qasm3_error(
message: Optional[str] = None,
err_type: Type[Exception] = ValidationError,
Expand Down
1 change: 1 addition & 0 deletions src/pyqasm/modules/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -528,6 +528,7 @@ def unroll(self, **kwargs):
**kwargs: Additional arguments to pass to the QasmVisitor.
external_gates (list[str]): List of gates that should not be unrolled.
unroll_barriers (bool): If True, barriers will be unrolled. Defaults to True.
max_loop_iters (int): Max number of iterations for unrolling loops. Defaults to 1e9.
check_only (bool): If True, only check the program without executing it.
Defaults to False.

Expand Down
106 changes: 100 additions & 6 deletions src/pyqasm/visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,28 @@
from openqasm3.printer import dumps

from pyqasm.analyzer import Qasm3Analyzer
from pyqasm.elements import ClbitDepthNode, Context, InversionOp, QubitDepthNode, Variable
from pyqasm.exceptions import ValidationError, raise_qasm3_error
from pyqasm.elements import (
ClbitDepthNode,
Context,
InversionOp,
QubitDepthNode,
Variable,
)
from pyqasm.exceptions import (
BreakSignal,
ContinueSignal,
LoopControlSignal,
LoopLimitExceededError,
ValidationError,
raise_qasm3_error,
)
from pyqasm.expressions import Qasm3ExprEvaluator
from pyqasm.maps import SWITCH_BLACKLIST_STMTS
from pyqasm.maps.expressions import ARRAY_TYPE_MAP, CONSTANTS_MAP, MAX_ARRAY_DIMENSIONS
from pyqasm.maps.expressions import (
ARRAY_TYPE_MAP,
CONSTANTS_MAP,
MAX_ARRAY_DIMENSIONS,
)
from pyqasm.maps.gates import (
map_qasm_ctrl_op_to_callable,
map_qasm_inv_op_to_callable,
Expand Down Expand Up @@ -62,12 +79,13 @@ class QasmVisitor:
check_only (bool): If True, only check the program without executing it. Defaults to False.
"""

def __init__(
def __init__( # pylint: disable=too-many-arguments
self,
module,
check_only: bool = False,
external_gates: list[str] | None = None,
unroll_barriers: bool = True,
max_loop_iters: int = int(1e9),
):
self._module = module
self._scope: deque = deque([{}])
Expand All @@ -92,7 +110,9 @@ def __init__(
self._in_branching_statement: int = 0
self._is_branch_qubits: set[tuple[str, int]] = set()
self._is_branch_clbits: set[tuple[str, int]] = set()
self._measurement_set: set[str] = set()
self._init_utilities()
self._loop_limit = max_loop_iters

def _init_utilities(self):
"""Initialize the utilities for the visitor."""
Expand Down Expand Up @@ -555,6 +575,11 @@ def _visit_measurement( # pylint: disable=too-many-locals
qubit_node.depth = max(qubit_node.depth, clbit_node.depth)
clbit_node.depth = max(qubit_node.depth, clbit_node.depth)

if isinstance(target, qasm3_ast.Identifier):
self._measurement_set.add(target.name)
elif isinstance(target, qasm3_ast.IndexedIdentifier):
self._measurement_set.add(target.name.name)

unrolled_measurements.append(unrolled_measure)

if self._check_only:
Expand Down Expand Up @@ -878,6 +903,18 @@ def _visit_basic_gate_operation(

return result

def _visit_break(self, statement: qasm3_ast.BreakStatement) -> None:
raise_qasm3_error(
err_type=BreakSignal,
error_node=statement,
)

def _visit_continue(self, statement: qasm3_ast.ContinueStatement) -> None:
raise_qasm3_error(
err_type=ContinueSignal,
error_node=statement,
)

def _visit_custom_gate_operation(
self,
operation: qasm3_ast.QuantumGate,
Expand Down Expand Up @@ -1990,8 +2027,62 @@ def _visit_function_call(

return return_value, result

def _visit_while_loop(self, statement: qasm3_ast.WhileLoop) -> None:
pass
def _visit_while_loop(self, statement: qasm3_ast.WhileLoop) -> list[qasm3_ast.Statement]:
"""Visit a while-loop element.

Args:
statement (qasm3_ast.WhileLoop) - the while-loop AST node
Returns:
list[qasm3_ast.Statement] - flattened/unrolled statements
Raises:
ValidationError - if loop condition is non-classical or dynamic
LoopLimitExceededError - if the loop exceeds the maximum limit"""

result = []

loop_counter = 0
max_iterations = self._loop_limit

if Qasm3Analyzer.condition_depends_on_measurement(
statement.while_condition, self._measurement_set
):
raise_qasm3_error(
"Cannot unroll while-loop with condition depending on quantum measurement result.",
error_node=statement,
span=statement.span,
)

while True:
cond_value = Qasm3ExprEvaluator.evaluate_expression(statement.while_condition)[0]
if not cond_value:
break

self._push_context(Context.BLOCK)
self._push_scope({})

try:
result.extend(self.visit_basic_block(statement.block))
except LoopControlSignal as lcs:
self._pop_scope()
self._restore_context()
if lcs.signal_type == "break":
break
if lcs.signal_type == "continue":
continue

self._pop_scope()
self._restore_context()

loop_counter += 1
if loop_counter >= max_iterations:
raise_qasm3_error(
"Loop exceeded max allowed iterations",
err_type=LoopLimitExceededError,
error_node=statement,
span=statement.span,
)

return result

def _visit_alias_statement(self, statement: qasm3_ast.AliasStatement) -> list[None]:
"""Visit an alias statement element.
Expand Down Expand Up @@ -2232,11 +2323,14 @@ def visit_statement(self, statement: qasm3_ast.Statement) -> list[qasm3_ast.Stat
qasm3_ast.ConstantDeclaration: self._visit_constant_declaration,
qasm3_ast.BranchingStatement: self._visit_branching_statement,
qasm3_ast.ForInLoop: self._visit_forin_loop,
qasm3_ast.WhileLoop: self._visit_while_loop,
qasm3_ast.AliasStatement: self._visit_alias_statement,
qasm3_ast.SwitchStatement: self._visit_switch_statement,
qasm3_ast.SubroutineDefinition: self._visit_subroutine_definition,
qasm3_ast.ExpressionStatement: lambda x: self._visit_function_call(x.expression),
qasm3_ast.IODeclaration: lambda x: [],
qasm3_ast.BreakStatement: self._visit_break,
qasm3_ast.ContinueStatement: self._visit_continue,
}

visitor_function = visit_map.get(type(statement))
Expand Down
Loading