Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ Types of changes:
```
- Previously, each gate inside an `if`/`else` block would advance only its own wire depth. Now, when any branching statement is encountered, all qubit‐ and clbit‐depths used inside that block are first incremented by one, then set to the maximum of those new values. This ensures the entire conditional block counts as single “depth” increment, rather than letting individual gates within the same branch float ahead independently.
- In the above snippet, c[0], q[0], and q[1] all jump together to a single new depth for that branch.

- Added initial support to explicit casting by converting the declarations into implicit casting logic. ([#205](https://github.com/qBraid/pyqasm/pull/205))
### Dependencies

### Other
Expand Down
1 change: 1 addition & 0 deletions src/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ Source code for OpenQASM 3 program validator and semantic analyzer
| ForLoops | ✅ | Completed |
| RangeDefinition | ✅ | Completed |
| QuantumGate | ✅ | Completed |
| Cast | ✅ | Completed |
| QuantumGateModifier (ctrl) | 📋 | Planned |
| WhileLoop | 📋 | Planned |
| IODeclaration | 📋 | Planned |
Expand Down
82 changes: 82 additions & 0 deletions src/pyqasm/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@
"""
from openqasm3.ast import (
BinaryExpression,
BitType,
BooleanLiteral,
BoolType,
Cast,
DurationLiteral,
Expression,
FloatLiteral,
Expand All @@ -40,6 +42,7 @@
)

from pyqasm.analyzer import Qasm3Analyzer
from pyqasm.elements import Variable
from pyqasm.exceptions import ValidationError, raise_qasm3_error
from pyqasm.maps.expressions import CONSTANTS_MAP, qasm3_expression_op_map
from pyqasm.validator import Qasm3Validator
Expand Down Expand Up @@ -205,6 +208,33 @@ def _process_variable(var_name: str, indices=None):
Qasm3ExprEvaluator._check_var_initialized(var_name, var_value, expression)
return _check_and_return_value(var_value)

def _check_type_size(expression, var_name, var_format, base_type):
base_size = 1
if not isinstance(base_type, BoolType):
initial_size = 1 if isinstance(base_type, BitType) else 32
try:
base_size = (
initial_size
if not hasattr(base_type, "size") or base_type.size is None
else Qasm3ExprEvaluator.evaluate_expression(
base_type.size, const_expr=True
)[0]
)
except ValidationError as err:
raise_qasm3_error(
f"Invalid base size for {var_format} '{var_name}'",
error_node=expression,
span=expression.span,
raised_from=err,
)
if not isinstance(base_size, int) or base_size <= 0:
raise_qasm3_error(
f"Invalid base size '{base_size}' for {var_format} '{var_name}'",
error_node=expression,
span=expression.span,
)
return base_size

if isinstance(expression, Identifier):
var_name = expression.name
if var_name in CONSTANTS_MAP:
Expand Down Expand Up @@ -283,6 +313,13 @@ def _process_variable(var_name: str, indices=None):
return _check_and_return_value(expression.value)

if isinstance(expression, UnaryExpression):
if validate_only:
if isinstance(expression.expression, Cast):
return cls.evaluate_expression(
expression.expression, const_expr, reqd_type, validate_only
)
return (None, [])

operand, returned_stats = cls.evaluate_expression(
expression.expression, const_expr, reqd_type
)
Expand All @@ -298,6 +335,19 @@ def _process_variable(var_name: str, indices=None):
return _check_and_return_value(qasm3_expression_op_map(op_name, operand))

if isinstance(expression, BinaryExpression):
if validate_only:
if isinstance(expression.lhs, Cast) and isinstance(expression.rhs, Cast):
return (None, statements)
if isinstance(expression.lhs, Cast):
return cls.evaluate_expression(
expression.lhs, const_expr, reqd_type, validate_only
)
if isinstance(expression.rhs, Cast):
return cls.evaluate_expression(
expression.rhs, const_expr, reqd_type, validate_only
)
return (None, statements)

lhs_value, lhs_statements = cls.evaluate_expression(
expression.lhs, const_expr, reqd_type
)
Expand All @@ -317,6 +367,38 @@ def _process_variable(var_name: str, indices=None):
statements.extend(ret_stmts)
return _check_and_return_value(ret_value)

if isinstance(expression, Cast):
if validate_only:
return (expression.type, statements)

var_name = ""
if isinstance(expression.argument, Identifier):
var_name = expression.argument.name

var_value, cast_stmts = cls.evaluate_expression(
expression=expression.argument, const_expr=const_expr
)

var_format = "variable"
if var_name == "":
var_name = f"{var_value}"
var_format = "value"

cast_type_size = _check_type_size(expression, var_name, var_format, expression.type)
variable = Variable(
name=var_name,
base_type=expression.type,
base_size=cast_type_size,
dims=[],
value=var_value,
is_constant=const_expr,
)
cast_var_value = Qasm3Validator.validate_variable_assignment_value(
variable, var_value, expression
)
statements.extend(cast_stmts)
return _check_and_return_value(cast_var_value)

raise_qasm3_error(
f"Unsupported expression type {type(expression)}",
err_type=ValidationError,
Expand Down
137 changes: 98 additions & 39 deletions src/pyqasm/visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,6 +462,83 @@ def _get_op_bits(

return openqasm_bits

def _check_variable_type_size(
self, statement: qasm3_ast.Statement, var_name: str, var_format: str, base_type: Any
) -> int:
"""Get the size of the given variable type.

Args:
statement: current statement to get span.
var_name(str): variable name of the current operation.
base_type (Any): Base type of the variable.
is_const (bool): whether the statement is constant declaration or not.
Returns:
Int: size of the variable base type.
"""
base_size = 1
if not isinstance(base_type, qasm3_ast.BoolType):
initial_size = 1 if isinstance(base_type, qasm3_ast.BitType) else 32
try:
base_size = (
initial_size
if not hasattr(base_type, "size") or base_type.size is None
else Qasm3ExprEvaluator.evaluate_expression(base_type.size, const_expr=True)[0]
)
except ValidationError as err:
raise_qasm3_error(
f"Invalid base size for {var_format} '{var_name}'",
error_node=statement,
span=statement.span,
raised_from=err,
)
if not isinstance(base_size, int) or base_size <= 0:
raise_qasm3_error(
f"Invalid base size '{base_size}' for {var_format} '{var_name}'",
error_node=statement,
span=statement.span,
)
return base_size

# pylint: disable-next=too-many-arguments
def _check_variable_cast_type(
self,
statement: qasm3_ast.Statement,
val_type: Any,
var_name: str,
base_type: Any,
base_size: Any,
is_const: bool,
) -> None:
"""Checks the declaration type and cast type of current variable.

Args:
statement: current statement to get span.
val_type(Any): type of cast to apply on variable.
var_name(str): declaration variable name.
base_type (Any): Base type of the declaration variable.
base_size(Any): literal to get the base size of the declaration variable.
is_const (bool): whether the statement is constant declaration or not.
Returns:
None
"""
if not val_type:
val_type = base_type

var_format = "variable"
if is_const:
var_format = "constant"

val_type_size = self._check_variable_type_size(statement, var_name, var_format, val_type)
if not isinstance(val_type, type(base_type)) or val_type_size != base_size:
raise_qasm3_error(
f"Declaration type: "
f"'{(type(base_type).__name__).replace('Type', '')}[{base_size}]' and "
f"Cast type: '{(type(val_type).__name__).replace('Type', '')}[{val_type_size}]',"
f" should be same for '{var_name}'",
error_node=statement,
span=statement.span,
)

def _visit_measurement( # pylint: disable=too-many-locals
self, statement: qasm3_ast.QuantumMeasurementStatement
) -> list[qasm3_ast.QuantumMeasurementStatement]:
Expand Down Expand Up @@ -1289,28 +1366,11 @@ def _visit_constant_declaration(
statements.extend(stmts)

base_type = statement.type
if isinstance(base_type, qasm3_ast.BoolType):
base_size = 1
elif hasattr(base_type, "size"):
if base_type.size is None:
base_size = 32 # default for now
else:
try:
base_size = Qasm3ExprEvaluator.evaluate_expression(
base_type.size, const_expr=True
)[0]
if not isinstance(base_size, int) or base_size <= 0:
raise ValidationError(
f"Invalid base size {base_size} for variable '{var_name}'"
)
except ValidationError as err:
raise_qasm3_error(
f"Invalid base size for constant '{var_name}'",
error_node=statement,
span=statement.span,
raised_from=err,
)

base_size = self._check_variable_type_size(statement, var_name, "constant", base_type)
val_type, _ = Qasm3ExprEvaluator.evaluate_expression(
statement.init_expression, validate_only=True
)
self._check_variable_cast_type(statement, val_type, var_name, base_type, base_size, True)
variable = Variable(var_name, base_type, base_size, [], init_value, is_constant=True)

# cast + validation
Expand Down Expand Up @@ -1368,22 +1428,7 @@ def _visit_classical_declaration(
dimensions = base_type.dimensions
base_type = base_type.base_type

base_size = 1
if not isinstance(base_type, qasm3_ast.BoolType):
initial_size = 1 if isinstance(base_type, qasm3_ast.BitType) else 32
try:
base_size = (
initial_size
if not hasattr(base_type, "size") or base_type.size is None
else Qasm3ExprEvaluator.evaluate_expression(base_type.size, const_expr=True)[0]
)
except ValidationError as err:
raise_qasm3_error(
f"Invalid base size for variable '{var_name}'",
error_node=statement,
span=statement.span,
raised_from=err,
)
base_size = self._check_variable_type_size(statement, var_name, "variable", base_type)
Qasm3Validator.validate_classical_type(base_type, base_size, var_name, statement)

# initialize the bit register
Expand Down Expand Up @@ -1436,6 +1481,12 @@ def _visit_classical_declaration(
statement.init_expression
)
statements.extend(stmts)
val_type, _ = Qasm3ExprEvaluator.evaluate_expression(
statement.init_expression, validate_only=True
)
self._check_variable_cast_type(
statement, val_type, var_name, base_type, base_size, False
)
except ValidationError as err:
raise_qasm3_error(
f"Invalid initialization value for variable '{var_name}'",
Expand Down Expand Up @@ -1553,7 +1604,15 @@ def _visit_classical_assignment(
rvalue
) # consists of scope check and index validation
statements.extend(rhs_stmts)

val_type, _ = Qasm3ExprEvaluator.evaluate_expression(rvalue, validate_only=True)
self._check_variable_cast_type(
statement,
val_type,
lvar_name,
lvar.base_type, # type: ignore[union-attr]
lvar.base_size, # type: ignore[union-attr]
False,
)
# cast + validation
rvalue_eval = None
if not isinstance(rvalue_raw, np.ndarray):
Expand Down
27 changes: 26 additions & 1 deletion tests/qasm3/declarations/test_classical.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,12 @@

from pyqasm.entrypoint import loads
from pyqasm.exceptions import ValidationError
from tests.qasm3.resources.variables import ASSIGNMENT_TESTS, DECLARATION_TESTS
from tests.qasm3.resources.variables import (
ASSIGNMENT_TESTS,
CASTING_TESTS,
DECLARATION_TESTS,
FAIL_CASTING_TESTS,
)
from tests.utils import check_single_qubit_rotation_op


Expand Down Expand Up @@ -389,3 +394,23 @@ def test_incorrect_assignments(test_name, caplog):

assert f"Error at line {line_num}, column {col_num}" in caplog.text
assert err_line in caplog.text


@pytest.mark.parametrize("test_name", CASTING_TESTS.keys())
def test_explicit_casting(test_name):
qasm_input = CASTING_TESTS[test_name]
loads(qasm_input).validate()


@pytest.mark.parametrize("test_name", FAIL_CASTING_TESTS.keys())
def test_incorrect_casting(test_name, caplog):
qasm_input, error_message, line_num, col_num, err_line = FAIL_CASTING_TESTS[test_name]
with pytest.raises(ValidationError) as excinfo:
loads(qasm_input).validate()

first = excinfo.value.__cause__ or excinfo.value.__context__
assert first is not None, "Expected a chained ValidationError"
msg = str(first)
assert error_message in msg
assert f"Error at line {line_num}, column {col_num}" in caplog.text
assert err_line in caplog.text
Loading