Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ Types of changes:

### Fixed
- Fixed Complex value initialization error. ([#253](https://github.com/qBraid/pyqasm/pull/253))
- Fixed duplicate qubit argument check in function calls and Error in function call with aliased qubit. ([#260](https://github.com/qBraid/pyqasm/pull/260))


### Dependencies
- Bumps `@actions/checkout` from 4 to 5 ([#250](https://github.com/qBraid/pyqasm/pull/250))
Expand Down
67 changes: 59 additions & 8 deletions src/pyqasm/subroutines.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@
from pyqasm.exceptions import ValidationError, raise_qasm3_error
from pyqasm.expressions import Qasm3ExprEvaluator
from pyqasm.transformer import Qasm3Transformer
from pyqasm.validator import Qasm3Validator


class Qasm3SubroutineProcessor:
Expand All @@ -65,6 +64,41 @@ def set_visitor_obj(cls, visitor_obj) -> None:
"""
cls.visitor_obj = visitor_obj

@staticmethod
def validate_unique_qubits(qubit_map: dict, reg_name: str, indices: list) -> bool:
"""
Validate that qubits passed for a given actual register are unique across
all quantum arguments in a function call and within the same argument itself.

This function mutates the provided `qubit_map` by tracking which indices of
each register have already been used while validating earlier arguments.

Args:
qubit_map (dict): Map used for duplicate detection; keys are register names,
values are sets of previously seen indices for that register.
reg_name (str): Actual register name appearing in the call (e.g., 'q').
indices (list): Concrete qubit indices being bound for this argument.

Returns:
bool: False if any duplicate is detected (within this argument or across
previously processed arguments); True otherwise. On success, the
map is updated with the new indices for subsequent checks.
"""
seen = qubit_map.setdefault(reg_name, set())

# Reject duplicates within the same argument (e.g., q[0], q[0]).
if len(set(indices)) != len(indices):
return False

# Reject duplicates against indices already seen for this register.
for idx in indices:
if idx in seen:
return False

# Update the seen set so subsequent arguments are validated against it.
seen.update(indices)
return True

@staticmethod
def get_fn_actual_arg_name(actual_arg: Identifier | IndexExpression) -> Optional[str]:
"""Get the name of the actual argument passed to a function.
Expand Down Expand Up @@ -525,8 +559,15 @@ def process_quantum_arg( # pylint: disable=too-many-locals
span=fn_call.span,
)

# Include alias register sizes when resolving actual target qubits
# so that aliased identifiers like `let a = q[i]; dummy(a);` are valid.
merged_size_map = {
**actual_qreg_size_map,
**getattr(cls.visitor_obj, "_global_alias_size_map", {}),
}

actual_qids, actual_qubits_size = Qasm3Transformer.get_target_qubits(
actual_arg, actual_qreg_size_map, actual_arg_name
actual_arg, merged_size_map, actual_arg_name
)

if formal_qubit_size != actual_qubits_size:
Expand All @@ -540,18 +581,28 @@ def process_quantum_arg( # pylint: disable=too-many-locals
span=fn_call.span,
)

if not Qasm3Validator.validate_unique_qubits(
duplicate_qubit_map, actual_arg_name, actual_qids
):
# If the actual argument is an alias, resolve to the underlying
# register name and indices for duplicate detection and mapping.
resolved_reg_name = actual_arg_name
resolved_qids = list(actual_qids)
if getattr(actual_arg_var, "is_alias", False):
resolved_pairs = [
cls.visitor_obj._alias_qubit_labels[(actual_arg_name, qid)] for qid in actual_qids
]
# All alias pairs point to the same underlying register
resolved_reg_name = resolved_pairs[0][0] if resolved_pairs else actual_arg_name
resolved_qids = [pair[1] for pair in resolved_pairs]

if not cls.validate_unique_qubits(duplicate_qubit_map, resolved_reg_name, resolved_qids):
raise_qasm3_error(
f"Duplicate qubit argument for register '{actual_arg_name}' "
f"Duplicate qubit argument for register '{resolved_reg_name}' "
f"in function call for '{fn_name}'",
error_node=fn_call,
span=fn_call.span,
)

for idx, qid in enumerate(actual_qids):
qubit_transform_map[(formal_reg_name, idx)] = (actual_arg_name, qid)
for idx, qid in enumerate(resolved_qids):
qubit_transform_map[(formal_reg_name, idx)] = (resolved_reg_name, qid)

return Variable(
name=formal_reg_name,
Expand Down
4 changes: 3 additions & 1 deletion src/pyqasm/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,9 @@ def get_target_qubits(
qid, qreg_size_map[target_name], qubit=True, op_node=target
)
target_qubits_size = len(target_qids)
elif isinstance(target.index[0], (IntegerLiteral, Identifier)): # "(q[0]); OR (q[i]);"
elif isinstance(
target.index[0], (IntegerLiteral, Identifier, BinaryExpression)
): # "(q[0]); OR (q[i]); OR (q[i+1]);"
target_qids = [Qasm3ExprEvaluator.evaluate_expression(target.index[0])[0]]
Qasm3Validator.validate_register_index(
target_qids[0], qreg_size_map[target_name], qubit=True, op_node=target
Expand Down
21 changes: 0 additions & 21 deletions src/pyqasm/validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,24 +340,3 @@ def validate_return_statement( # pylint: disable=inconsistent-return-statements
return_value,
op_node=return_statement,
)

@staticmethod
def validate_unique_qubits(qubit_map: dict, reg_name: str, indices: list) -> bool:
"""
Validates that the qubits in the given register are unique.

Args:
qubit_map (dict): Dictionary of qubits.
reg_name (str): The name of the register.
indices (list): A list of indices representing the qubits.

Returns:
bool: True if the qubits are unique, False otherwise.
"""
if reg_name not in qubit_map:
qubit_map[reg_name] = set(indices)
else:
for idx in indices:
if idx in qubit_map[reg_name]:
return False
return True
56 changes: 23 additions & 33 deletions src/pyqasm/visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2516,48 +2516,35 @@ def _visit_alias_statement(self, statement: qasm3_ast.AliasStatement) -> list[No
)
aliased_reg_size = self._global_qreg_size_map[aliased_reg_name]
if isinstance(value, qasm3_ast.Identifier): # "let alias = q;"
for i in range(aliased_reg_size):
self._alias_qubit_labels[(alias_reg_name, i)] = (aliased_reg_name, i)
target_qids = list(range(aliased_reg_size))
alias_reg_size = aliased_reg_size
elif isinstance(value, qasm3_ast.IndexExpression):
if isinstance(value.index, qasm3_ast.DiscreteSet): # "let alias = q[{0,1}];"
qids = Qasm3Transformer.extract_values_from_discrete_set(value.index, statement)
for i, qid in enumerate(qids):
if isinstance(value.index, qasm3_ast.DiscreteSet):
target_qids = Qasm3Transformer.extract_values_from_discrete_set(
value.index, statement
)
for qid in target_qids:
Qasm3Validator.validate_register_index(
qid,
self._global_qreg_size_map[aliased_reg_name],
qubit=True,
op_node=statement,
)
self._alias_qubit_labels[(alias_reg_name, i)] = (aliased_reg_name, qid)
alias_reg_size = len(qids)
elif len(value.index) != 1: # like "let alias = q[0,1];"?
raise_qasm3_error(
"An index set can be specified by a single integer (signed or unsigned), "
"a comma-separated list of integers contained in braces {a,b,c,…}, "
"or a range",
error_node=statement,
span=statement.span,
)
elif isinstance(value.index[0], qasm3_ast.IntegerLiteral): # "let alias = q[0];"
qid = value.index[0].value
Qasm3Validator.validate_register_index(
qid, self._global_qreg_size_map[aliased_reg_name], qubit=True, op_node=statement
)
self._alias_qubit_labels[(alias_reg_name, 0)] = (
aliased_reg_name,
value.index[0].value,
)
alias_reg_size = 1
elif isinstance(value.index[0], qasm3_ast.RangeDefinition): # "let alias = q[0:1:2];"
qids = Qasm3Transformer.get_qubits_from_range_definition(
value.index[0],
aliased_reg_size,
is_qubit_reg=True,
alias_reg_size = len(target_qids)
else:
if len(value.index) != 1:
raise_qasm3_error(
"An index set can be specified by a single integer (signed or unsigned), "
"a comma-separated list of integers contained in braces {a,b,c,…}, "
"or a range",
error_node=statement,
span=statement.span,
)
target_qids, alias_reg_size = Qasm3Transformer.get_target_qubits(
value, {aliased_reg_name: aliased_reg_size}, aliased_reg_name
)
for i, qid in enumerate(qids):
self._alias_qubit_labels[(alias_reg_name, i)] = (aliased_reg_name, qid)
alias_reg_size = len(qids)
for i, qid in enumerate(target_qids):
self._alias_qubit_labels[(alias_reg_name, i)] = (aliased_reg_name, qid)

# we are updating as the alias can be redefined as well
alias_var = Variable(
Expand All @@ -2569,6 +2556,9 @@ def _visit_alias_statement(self, statement: qasm3_ast.AliasStatement) -> list[No
is_alias=True,
span=statement.span,
)
# Mark alias variables that reference qubits as qubit variables so they
# can be passed as quantum arguments to subroutines.
alias_var.is_qubit = True

if self._scope_manager.check_in_scope(alias_reg_name):
# means, the alias is present in current scope
Expand Down
34 changes: 34 additions & 0 deletions tests/qasm3/resources/subroutines.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,40 @@
8, # Column number
"my_function(1)", # Complete line
),
"test_duplicate_qubit_args_singletons": (
"""
OPENQASM 3;
include "stdgates.inc";

def my_function(qubit a, qubit b) {
h b;
return;
}
qubit[2] q;
my_function(q[0], q[0]);
""",
r"Duplicate qubit argument for register 'q' in function call for 'my_function'",
10,
8,
"my_function(q[0], q[0])",
),
"test_duplicate_qubit_args_singletons_2": (
"""
OPENQASM 3;
include "stdgates.inc";

def my_function(qubit[2] p) {
h p;
return;
}
qubit[3] q;
my_function(q[{0, 0}]);
""",
r"Duplicate qubit argument for register 'q' in function call for 'my_function'",
10,
8,
"my_function(q[{0, 0}])",
),
"redefinition_raises_error": (
"""
OPENQASM 3;
Expand Down
30 changes: 30 additions & 0 deletions tests/qasm3/subroutines/test_subroutines.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,36 @@ def my_function(int[32] a, qubit q_arg) {
check_single_qubit_rotation_op(result.unrolled_ast, 3, [2, 1, 0], [2, 1, 0], "rx")


def test_alias_arg_from_loop_validates():
"""Alias of a dynamic indexed qubit used as function argument should validate."""
qasm_str = """
OPENQASM 3.0;
include "stdgates.inc";

qubit[4] q;

def dummy(qubit[1] q_arg) -> bool {
h q_arg;
return true;
}

for int i in [0:2]
{
let new_q = q[i];
dummy(new_q);
}
for int i in [0:2]
{
let new_q = q[i+1];
dummy(new_q);
}
"""

result = loads(qasm_str)
# Should not raise ValidationError
result.unroll()


def test_function_call_with_return():
"""Test that a function call with a return value is correctly parsed."""
qasm_str = """OPENQASM 3.0;
Expand Down
Loading