diff --git a/CHANGELOG.md b/CHANGELOG.md index 5bdecd93..23678e8c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -50,7 +50,7 @@ Types of changes: ``` - Previously, each gate inside an `if`/`else` block would advance only its own wire depth. Now, when any branching statement is encountered, all qubit‐ and clbit‐depths used inside that block are first incremented by one, then set to the maximum of those new values. This ensures the entire conditional block counts as single “depth” increment, rather than letting individual gates within the same branch float ahead independently. - In the above snippet, c[0], q[0], and q[1] all jump together to a single new depth for that branch. - +- Added initial support to explicit casting by converting the declarations into implicit casting logic. ([#205](https://github.com/qBraid/pyqasm/pull/205)) ### Dependencies ### Other diff --git a/src/README.md b/src/README.md index d781dab6..8950e0fe 100644 --- a/src/README.md +++ b/src/README.md @@ -26,6 +26,7 @@ Source code for OpenQASM 3 program validator and semantic analyzer | ForLoops | ✅ | Completed | | RangeDefinition | ✅ | Completed | | QuantumGate | ✅ | Completed | +| Cast | ✅ | Completed | | QuantumGateModifier (ctrl) | 📋 | Planned | | WhileLoop | 📋 | Planned | | IODeclaration | 📋 | Planned | diff --git a/src/pyqasm/expressions.py b/src/pyqasm/expressions.py index f89ad620..62926e13 100644 --- a/src/pyqasm/expressions.py +++ b/src/pyqasm/expressions.py @@ -18,8 +18,10 @@ """ from openqasm3.ast import ( BinaryExpression, + BitType, BooleanLiteral, BoolType, + Cast, DurationLiteral, Expression, FloatLiteral, @@ -40,6 +42,7 @@ ) from pyqasm.analyzer import Qasm3Analyzer +from pyqasm.elements import Variable from pyqasm.exceptions import ValidationError, raise_qasm3_error from pyqasm.maps.expressions import CONSTANTS_MAP, qasm3_expression_op_map from pyqasm.validator import Qasm3Validator @@ -205,6 +208,33 @@ def _process_variable(var_name: str, indices=None): Qasm3ExprEvaluator._check_var_initialized(var_name, var_value, expression) return _check_and_return_value(var_value) + def _check_type_size(expression, var_name, var_format, base_type): + base_size = 1 + if not isinstance(base_type, BoolType): + initial_size = 1 if isinstance(base_type, BitType) else 32 + try: + base_size = ( + initial_size + if not hasattr(base_type, "size") or base_type.size is None + else Qasm3ExprEvaluator.evaluate_expression( + base_type.size, const_expr=True + )[0] + ) + except ValidationError as err: + raise_qasm3_error( + f"Invalid base size for {var_format} '{var_name}'", + error_node=expression, + span=expression.span, + raised_from=err, + ) + if not isinstance(base_size, int) or base_size <= 0: + raise_qasm3_error( + f"Invalid base size '{base_size}' for {var_format} '{var_name}'", + error_node=expression, + span=expression.span, + ) + return base_size + if isinstance(expression, Identifier): var_name = expression.name if var_name in CONSTANTS_MAP: @@ -283,6 +313,13 @@ def _process_variable(var_name: str, indices=None): return _check_and_return_value(expression.value) if isinstance(expression, UnaryExpression): + if validate_only: + if isinstance(expression.expression, Cast): + return cls.evaluate_expression( + expression.expression, const_expr, reqd_type, validate_only + ) + return (None, []) + operand, returned_stats = cls.evaluate_expression( expression.expression, const_expr, reqd_type ) @@ -298,6 +335,19 @@ def _process_variable(var_name: str, indices=None): return _check_and_return_value(qasm3_expression_op_map(op_name, operand)) if isinstance(expression, BinaryExpression): + if validate_only: + if isinstance(expression.lhs, Cast) and isinstance(expression.rhs, Cast): + return (None, statements) + if isinstance(expression.lhs, Cast): + return cls.evaluate_expression( + expression.lhs, const_expr, reqd_type, validate_only + ) + if isinstance(expression.rhs, Cast): + return cls.evaluate_expression( + expression.rhs, const_expr, reqd_type, validate_only + ) + return (None, statements) + lhs_value, lhs_statements = cls.evaluate_expression( expression.lhs, const_expr, reqd_type ) @@ -317,6 +367,38 @@ def _process_variable(var_name: str, indices=None): statements.extend(ret_stmts) return _check_and_return_value(ret_value) + if isinstance(expression, Cast): + if validate_only: + return (expression.type, statements) + + var_name = "" + if isinstance(expression.argument, Identifier): + var_name = expression.argument.name + + var_value, cast_stmts = cls.evaluate_expression( + expression=expression.argument, const_expr=const_expr + ) + + var_format = "variable" + if var_name == "": + var_name = f"{var_value}" + var_format = "value" + + cast_type_size = _check_type_size(expression, var_name, var_format, expression.type) + variable = Variable( + name=var_name, + base_type=expression.type, + base_size=cast_type_size, + dims=[], + value=var_value, + is_constant=const_expr, + ) + cast_var_value = Qasm3Validator.validate_variable_assignment_value( + variable, var_value, expression + ) + statements.extend(cast_stmts) + return _check_and_return_value(cast_var_value) + raise_qasm3_error( f"Unsupported expression type {type(expression)}", err_type=ValidationError, diff --git a/src/pyqasm/visitor.py b/src/pyqasm/visitor.py index 0fd6d1ec..be563d7d 100644 --- a/src/pyqasm/visitor.py +++ b/src/pyqasm/visitor.py @@ -462,6 +462,83 @@ def _get_op_bits( return openqasm_bits + def _check_variable_type_size( + self, statement: qasm3_ast.Statement, var_name: str, var_format: str, base_type: Any + ) -> int: + """Get the size of the given variable type. + + Args: + statement: current statement to get span. + var_name(str): variable name of the current operation. + base_type (Any): Base type of the variable. + is_const (bool): whether the statement is constant declaration or not. + Returns: + Int: size of the variable base type. + """ + base_size = 1 + if not isinstance(base_type, qasm3_ast.BoolType): + initial_size = 1 if isinstance(base_type, qasm3_ast.BitType) else 32 + try: + base_size = ( + initial_size + if not hasattr(base_type, "size") or base_type.size is None + else Qasm3ExprEvaluator.evaluate_expression(base_type.size, const_expr=True)[0] + ) + except ValidationError as err: + raise_qasm3_error( + f"Invalid base size for {var_format} '{var_name}'", + error_node=statement, + span=statement.span, + raised_from=err, + ) + if not isinstance(base_size, int) or base_size <= 0: + raise_qasm3_error( + f"Invalid base size '{base_size}' for {var_format} '{var_name}'", + error_node=statement, + span=statement.span, + ) + return base_size + + # pylint: disable-next=too-many-arguments + def _check_variable_cast_type( + self, + statement: qasm3_ast.Statement, + val_type: Any, + var_name: str, + base_type: Any, + base_size: Any, + is_const: bool, + ) -> None: + """Checks the declaration type and cast type of current variable. + + Args: + statement: current statement to get span. + val_type(Any): type of cast to apply on variable. + var_name(str): declaration variable name. + base_type (Any): Base type of the declaration variable. + base_size(Any): literal to get the base size of the declaration variable. + is_const (bool): whether the statement is constant declaration or not. + Returns: + None + """ + if not val_type: + val_type = base_type + + var_format = "variable" + if is_const: + var_format = "constant" + + val_type_size = self._check_variable_type_size(statement, var_name, var_format, val_type) + if not isinstance(val_type, type(base_type)) or val_type_size != base_size: + raise_qasm3_error( + f"Declaration type: " + f"'{(type(base_type).__name__).replace('Type', '')}[{base_size}]' and " + f"Cast type: '{(type(val_type).__name__).replace('Type', '')}[{val_type_size}]'," + f" should be same for '{var_name}'", + error_node=statement, + span=statement.span, + ) + def _visit_measurement( # pylint: disable=too-many-locals self, statement: qasm3_ast.QuantumMeasurementStatement ) -> list[qasm3_ast.QuantumMeasurementStatement]: @@ -1289,28 +1366,11 @@ def _visit_constant_declaration( statements.extend(stmts) base_type = statement.type - if isinstance(base_type, qasm3_ast.BoolType): - base_size = 1 - elif hasattr(base_type, "size"): - if base_type.size is None: - base_size = 32 # default for now - else: - try: - base_size = Qasm3ExprEvaluator.evaluate_expression( - base_type.size, const_expr=True - )[0] - if not isinstance(base_size, int) or base_size <= 0: - raise ValidationError( - f"Invalid base size {base_size} for variable '{var_name}'" - ) - except ValidationError as err: - raise_qasm3_error( - f"Invalid base size for constant '{var_name}'", - error_node=statement, - span=statement.span, - raised_from=err, - ) - + base_size = self._check_variable_type_size(statement, var_name, "constant", base_type) + val_type, _ = Qasm3ExprEvaluator.evaluate_expression( + statement.init_expression, validate_only=True + ) + self._check_variable_cast_type(statement, val_type, var_name, base_type, base_size, True) variable = Variable(var_name, base_type, base_size, [], init_value, is_constant=True) # cast + validation @@ -1368,22 +1428,7 @@ def _visit_classical_declaration( dimensions = base_type.dimensions base_type = base_type.base_type - base_size = 1 - if not isinstance(base_type, qasm3_ast.BoolType): - initial_size = 1 if isinstance(base_type, qasm3_ast.BitType) else 32 - try: - base_size = ( - initial_size - if not hasattr(base_type, "size") or base_type.size is None - else Qasm3ExprEvaluator.evaluate_expression(base_type.size, const_expr=True)[0] - ) - except ValidationError as err: - raise_qasm3_error( - f"Invalid base size for variable '{var_name}'", - error_node=statement, - span=statement.span, - raised_from=err, - ) + base_size = self._check_variable_type_size(statement, var_name, "variable", base_type) Qasm3Validator.validate_classical_type(base_type, base_size, var_name, statement) # initialize the bit register @@ -1436,6 +1481,12 @@ def _visit_classical_declaration( statement.init_expression ) statements.extend(stmts) + val_type, _ = Qasm3ExprEvaluator.evaluate_expression( + statement.init_expression, validate_only=True + ) + self._check_variable_cast_type( + statement, val_type, var_name, base_type, base_size, False + ) except ValidationError as err: raise_qasm3_error( f"Invalid initialization value for variable '{var_name}'", @@ -1553,7 +1604,15 @@ def _visit_classical_assignment( rvalue ) # consists of scope check and index validation statements.extend(rhs_stmts) - + val_type, _ = Qasm3ExprEvaluator.evaluate_expression(rvalue, validate_only=True) + self._check_variable_cast_type( + statement, + val_type, + lvar_name, + lvar.base_type, # type: ignore[union-attr] + lvar.base_size, # type: ignore[union-attr] + False, + ) # cast + validation rvalue_eval = None if not isinstance(rvalue_raw, np.ndarray): diff --git a/tests/qasm3/declarations/test_classical.py b/tests/qasm3/declarations/test_classical.py index b600f6a2..b4424019 100644 --- a/tests/qasm3/declarations/test_classical.py +++ b/tests/qasm3/declarations/test_classical.py @@ -20,7 +20,12 @@ from pyqasm.entrypoint import loads from pyqasm.exceptions import ValidationError -from tests.qasm3.resources.variables import ASSIGNMENT_TESTS, DECLARATION_TESTS +from tests.qasm3.resources.variables import ( + ASSIGNMENT_TESTS, + CASTING_TESTS, + DECLARATION_TESTS, + FAIL_CASTING_TESTS, +) from tests.utils import check_single_qubit_rotation_op @@ -389,3 +394,23 @@ def test_incorrect_assignments(test_name, caplog): assert f"Error at line {line_num}, column {col_num}" in caplog.text assert err_line in caplog.text + + +@pytest.mark.parametrize("test_name", CASTING_TESTS.keys()) +def test_explicit_casting(test_name): + qasm_input = CASTING_TESTS[test_name] + loads(qasm_input).validate() + + +@pytest.mark.parametrize("test_name", FAIL_CASTING_TESTS.keys()) +def test_incorrect_casting(test_name, caplog): + qasm_input, error_message, line_num, col_num, err_line = FAIL_CASTING_TESTS[test_name] + with pytest.raises(ValidationError) as excinfo: + loads(qasm_input).validate() + + first = excinfo.value.__cause__ or excinfo.value.__context__ + assert first is not None, "Expected a chained ValidationError" + msg = str(first) + assert error_message in msg + assert f"Error at line {line_num}, column {col_num}" in caplog.text + assert err_line in caplog.text diff --git a/tests/qasm3/resources/variables.py b/tests/qasm3/resources/variables.py index 2f1c7494..263ead09 100644 --- a/tests/qasm3/resources/variables.py +++ b/tests/qasm3/resources/variables.py @@ -94,7 +94,7 @@ include "stdgates.inc"; int[32.1] x; """, - "Invalid base size 32.1 for variable 'x'", + "Invalid base size '32.1' for variable 'x'", 4, 8, "int[32.1] x;", @@ -105,7 +105,7 @@ include "stdgates.inc"; const int[32.1] x = 3; """, - "Invalid base size for constant 'x'", + "Invalid base size '32.1' for constant 'x'", 4, 8, "const int[32.1] x = 3;", @@ -361,3 +361,165 @@ "x[3] = 3;", ), } + +CASTING_TESTS = { + "General_test": ( + """ + OPENQASM 3.0; + include "stdgates.inc"; + const float[64] f1 = 2.5; + uint[8] runtime_u = 7; + int[32] i2 = 2 * int[32](float[64](int[16](f1))); + const int[8] i1 = int[8](f1); + const uint u1 = 2 * uint(f1); + int ccf1 = float(runtime_u) * int(f1); + uint ul1 = uint(float[64](int[16](f1))) * 2; + const int un = -int(u1); + """ + ), + "Bool_test": ( + """ + OPENQASM 3.0; + include "stdgates.inc"; + + bool b_false = false; + bool b_true = true; + + int i1 = int(b_false); + uint[16] u1 = uint[16](b_true); + float[32] f0 = float[32](b_false); + + bit b; + b = b_true; + + bit[4] bits_from_true = bit[4](b_true); + + bool b_nested = bool(float[32](uint[8](int[8](bit[8](bool(true)))))); + """ + ), + "Int_test": ( + """ + OPENQASM 3.0; + include "stdgates.inc"; + + int[4] x = -3; + bool b = bool(x); + uint[8] ux = uint[8](x); + float[32] f = float[32](x); + bit[4] bits = bit[4](x); + """ + ), + "Unsigned_Int_test": ( + """ + OPENQASM 3.0; + include "stdgates.inc"; + + uint[8] x = 3; + bool b = bool(x); + int[8] i = int[8](x); + float[32] f = float[32](x); + bit[4] bits = bit[4](x); + """ + ), + "Float_test": ( + """ + OPENQASM 3.0; + include "stdgates.inc"; + + const float[64] two_pi = 6.283185307179586; + float[64] f = two_pi * (127. / 512.); + bool b = bool(f); + int i = int(f); + uint u = uint(f); + // angle[8] a = angle[8](f); + """ + ), + "Bit_test": ( + """ + OPENQASM 3.0; + include "stdgates.inc"; + + int v = 15; + bit[4] x = v; + bool b = bool(x); + int[32] i = int[32](x); + uint[32] u = uint[32](x); + // angle[4] a = angle[4](x); + """ + ), +} + +FAIL_CASTING_TESTS = { + "Float_to_Bit_test": ( + """ + OPENQASM 3.0; + include "stdgates.inc"; + const float[64] f1 = 2.5; + const bit[2] b1 = bit[2](f1); + """, + "Cannot cast to . Invalid assignment " + "of type to variable f1 of type ", + 5, + 8, + "const bit[2] b1 = bit[2](f1);", + ), + "Const_to_non-Const_test": ( + """ + OPENQASM 3.0; + include "stdgates.inc"; + uint[8] runtime_u = 7; + const int[16] i2 = int[16](runtime_u); + """, + "Expected variable 'runtime_u' to be constant in given expression", + 5, + 35, + "const int[16] i2 = int[16](runtime_u);", + ), + "Declaration_vs_Cast": ( + """ + OPENQASM 3.0; + include "stdgates.inc"; + int v = 15; + int[32] i = uint[32](v); + """, + "Declaration type: 'Int[32]' and Cast type: 'Uint[32]', should be same for 'i'", + 5, + 8, + "int[32] i = uint[32](v);", + ), + "Incorrect_base_size_for_cast_variable": ( + """ + OPENQASM 3.0; + include "stdgates.inc"; + const float[64] f1 = 2.5; + const int[32] i1 = int[32.5](f1); + """, + "Invalid base size '32.5' for variable 'f1'", + 5, + 27, + "int[32.5](f1);", + ), + "Unsupported_expression": ( + """ + OPENQASM 3.0; + include "stdgates.inc"; + + duration d1 = 1ns; + """, + "Unsupported expression type ''", + 5, + 22, + "1.0ns", + ), + "Incorrect_base_size_for_direct_value_in_cast": ( + """ + OPENQASM 3.0; + include "stdgates.inc"; + const uint[32] iu = uint[12.2](24); + """, + "Invalid base size '12.2' for value '24'", + 4, + 28, + "uint[12.2](24);", + ), +}