From 61344c507ee7ae9448b5d3ca8506eed459594a13 Mon Sep 17 00:00:00 2001 From: vinayswamik Date: Fri, 1 Aug 2025 18:29:43 -0500 Subject: [PATCH 01/12] fix the `.gitattributes` file to exclude binary files from text processing --- .gitattributes | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/.gitattributes b/.gitattributes index 4c78e9ae..d5602369 100644 --- a/.gitattributes +++ b/.gitattributes @@ -1,4 +1,7 @@ # Ref : https://docs.github.com/en/get-started/getting-started-with-git/configuring-git-to-handle-line-endings#per-repository-settings -# Set line endings to lf -* text eol=lf \ No newline at end of file +# Set line endings to lf for text files +* text eol=lf + +# Binary files should not be processed as text +*.png binary \ No newline at end of file From f39652eb283a62557c34af9937bd398045461c3c Mon Sep 17 00:00:00 2001 From: vinayswamik Date: Sat, 2 Aug 2025 19:38:57 -0500 Subject: [PATCH 02/12] Add angle type support in QASM --- src/pyqasm/elements.py | 2 + src/pyqasm/entrypoint.py | 3 ++ src/pyqasm/expressions.py | 43 ++++++++++++---- src/pyqasm/maps/expressions.py | 6 ++- src/pyqasm/modules/base.py | 1 + src/pyqasm/pulse/validator.py | 50 ++++++++++++++++++ src/pyqasm/visitor.py | 51 +++++++++++++++++-- tests/qasm3/declarations/test_classical.py | 59 ++++++++++++++++++++++ tests/qasm3/resources/variables.py | 14 +---- 9 files changed, 203 insertions(+), 26 deletions(-) diff --git a/src/pyqasm/elements.py b/src/pyqasm/elements.py index 736007b1..ee0a3e10 100644 --- a/src/pyqasm/elements.py +++ b/src/pyqasm/elements.py @@ -92,6 +92,7 @@ class Variable: # pylint: disable=too-many-instance-attributes value (Optional[int | float | np.ndarray]): Value of the variable. time_unit (Optional[str]): Time unit associated with the duration variable. span (Any): Span of the variable. + angle_bit_string (Optional[str]): Bit string representation of the angle value. shadow (bool): Flag indicating if the current variable is shadowed from its parent scope. is_constant (bool): Flag indicating if the variable is constant. is_register (bool): Flag indicating if the variable is a register. @@ -106,6 +107,7 @@ class Variable: # pylint: disable=too-many-instance-attributes value: Optional[int | float | np.ndarray] = None time_unit: Optional[str] = None span: Any = None + angle_bit_string: Optional[str] = None shadow: bool = False is_constant: bool = False is_qubit: bool = False diff --git a/src/pyqasm/entrypoint.py b/src/pyqasm/entrypoint.py index 2fe06fd5..59435916 100644 --- a/src/pyqasm/entrypoint.py +++ b/src/pyqasm/entrypoint.py @@ -59,6 +59,7 @@ def loads(program: openqasm3.ast.Program | str, **kwargs) -> QasmModule: **kwargs: Additional arguments to pass to the loads function. device_qubits (int): Number of physical qubits available on the target device. device_cycle_time (float): The duration of a hardware device cycle, in seconds. + compiler_angle_type_size (int): The width of the angle type in the compiler. Raises: TypeError: If the input is not a string or an `openqasm3.ast.Program` instance. @@ -91,6 +92,8 @@ def loads(program: openqasm3.ast.Program | str, **kwargs) -> QasmModule: module._device_qubits = dev_qbts if dev_cycle_time := kwargs.get("device_cycle_time"): module._device_cycle_time = dev_cycle_time + if compiler_angle_type_size := kwargs.get("compiler_angle_type_size"): + module._compiler_angle_type_size = compiler_angle_type_size return module diff --git a/src/pyqasm/expressions.py b/src/pyqasm/expressions.py index c2b12a36..337b71ec 100644 --- a/src/pyqasm/expressions.py +++ b/src/pyqasm/expressions.py @@ -17,7 +17,9 @@ """ from openqasm3.ast import ( + AngleType, BinaryExpression, + BitstringLiteral, BitType, BooleanLiteral, BoolType, @@ -53,6 +55,7 @@ class Qasm3ExprEvaluator: """Class for evaluating QASM3 expressions.""" visitor_obj = None + angle_vars_in_expr: list[Variable] = [] @classmethod def set_visitor_obj(cls, visitor_obj) -> None: @@ -70,8 +73,8 @@ def _check_var_in_scope(cls, var_name, expression): """ scope_manager = cls.visitor_obj._scope_manager + var = scope_manager.get_from_global_scope(var_name) if not scope_manager.check_in_scope(var_name): - var = scope_manager.get_from_global_scope(var_name) if var is not None and not var.is_constant: raise_qasm3_error( f"Global variable '{var_name}' must be a constant to use it in a local scope.", @@ -84,6 +87,17 @@ def _check_var_in_scope(cls, var_name, expression): error_node=expression, span=expression.span, ) + if var and isinstance(var.base_type, AngleType) and var not in cls.angle_vars_in_expr: + base_size = var.base_type.size + if len(cls.angle_vars_in_expr) > 0: + for var in cls.angle_vars_in_expr: + if var.base_type.size != base_size: + raise_qasm3_error( + "All 'Angle' variables in binary expression must have the same size", + error_node=expression, + span=expression.span, + ) + cls.angle_vars_in_expr.append(var) @classmethod def _check_var_constant(cls, var_name, const_expr, expression): @@ -254,7 +268,7 @@ def _check_type_size(expression, var_name, var_format, base_type): if isinstance(expression, Identifier): var_name = expression.name if var_name in CONSTANTS_MAP: - if not reqd_type or reqd_type == Qasm3FloatType: + if not reqd_type or reqd_type in (Qasm3FloatType, AngleType): return _check_and_return_value(CONSTANTS_MAP[var_name]) raise_qasm3_error( f"Constant '{var_name}' not allowed in non-float expression", @@ -310,7 +324,7 @@ def _check_type_size(expression, var_name, var_format, base_type): ) return _check_and_return_value(dimensions[index]) - if isinstance(expression, (BooleanLiteral, IntegerLiteral, FloatLiteral)): + if isinstance(expression, (BooleanLiteral, IntegerLiteral, FloatLiteral, BitstringLiteral)): if reqd_type: if reqd_type == BoolType and isinstance(expression, BooleanLiteral): return _check_and_return_value(expression.value) @@ -318,6 +332,8 @@ def _check_type_size(expression, var_name, var_format, base_type): return _check_and_return_value(expression.value) if reqd_type == Qasm3FloatType and isinstance(expression, FloatLiteral): return _check_and_return_value(expression.value) + if reqd_type == AngleType: + return _check_and_return_value(expression.value) raise_qasm3_error( f"Invalid value {expression.value} with type {type(expression)} " f"for required type {reqd_type}", @@ -365,14 +381,23 @@ def _check_type_size(expression, var_name, var_format, base_type): if validate_only: if isinstance(expression.lhs, Cast) and isinstance(expression.rhs, Cast): return (None, statements) + + _lhs, _lhs_stmts = cls.evaluate_expression( + expression.lhs, const_expr, reqd_type, validate_only + ) + _rhs, _rhs_stmts = cls.evaluate_expression( + expression.rhs, const_expr, reqd_type, validate_only + ) + if isinstance(expression.lhs, Cast): - return cls.evaluate_expression( - expression.lhs, const_expr, reqd_type, validate_only - ) + return (_lhs, _lhs_stmts) if isinstance(expression.rhs, Cast): - return cls.evaluate_expression( - expression.rhs, const_expr, reqd_type, validate_only - ) + return (_rhs, _rhs_stmts) + + if type(reqd_type) is type(AngleType) and len(cls.angle_vars_in_expr) > 0: + var_type = cls.angle_vars_in_expr[0].base_type + cls.angle_vars_in_expr.clear() + return (var_type, statements) return (None, statements) lhs_value, lhs_statements = cls.evaluate_expression( diff --git a/src/pyqasm/maps/expressions.py b/src/pyqasm/maps/expressions.py index aac09ca5..77bbb2f4 100644 --- a/src/pyqasm/maps/expressions.py +++ b/src/pyqasm/maps/expressions.py @@ -125,6 +125,8 @@ def qasm_variable_type_cast(openqasm_type, var_name, base_size, rhs_value): if openqasm_type == BitType: return rhs_value != 0 if openqasm_type == AngleType: + if isinstance(rhs_value, bool): + return ((2 * CONSTANTS_MAP["pi"]) * (1 / 2)) if rhs_value else 0.0 return rhs_value # not sure @@ -150,7 +152,7 @@ def qasm_variable_type_cast(openqasm_type, var_name, base_size, rhs_value): ComplexType: complex, DurationType: float, StretchType: float, - # AngleType: None, # not sure + AngleType: float, } # Reference: https://openqasm.com/language/types.html#allowed-casts @@ -160,7 +162,7 @@ def qasm_variable_type_cast(openqasm_type, var_name, base_size, rhs_value): BitType: (bool, int, np.int64, np.bool_), UintType: (bool, int, float, np.int64, np.uint64, np.float64, np.bool_), FloatType: (bool, int, float, np.int64, np.float64, np.bool_), - AngleType: (float, np.float64), + AngleType: (float, np.float64, bool, np.bool_), } ARRAY_TYPE_MAP = { diff --git a/src/pyqasm/modules/base.py b/src/pyqasm/modules/base.py index ee353b49..d2c30121 100644 --- a/src/pyqasm/modules/base.py +++ b/src/pyqasm/modules/base.py @@ -60,6 +60,7 @@ def __init__(self, name: str, program: Program): self._device_qubits: Optional[int] = None self._consolidate_qubits: Optional[bool] = False self._device_cycle_time: Optional[int] = None + self._compiler_angle_type_size: Optional[int] = None @property def name(self) -> str: diff --git a/src/pyqasm/pulse/validator.py b/src/pyqasm/pulse/validator.py index e1e82538..78d9e21f 100644 --- a/src/pyqasm/pulse/validator.py +++ b/src/pyqasm/pulse/validator.py @@ -20,6 +20,7 @@ from openqasm3.ast import ( BinaryExpression, + BitstringLiteral, Box, Cast, ConstantDeclaration, @@ -34,11 +35,60 @@ ) from pyqasm.exceptions import raise_qasm3_error +from pyqasm.maps.expressions import CONSTANTS_MAP class PulseValidator: """Class with validation functions for Pulse visitor""" + @staticmethod + def validate_angle_type_value( + statement: Any, + init_value: int | float, + base_size: int, + compiler_angle_width: Optional[int] = None, + ) -> tuple: + """ + Validates and processes angle type value. + + Args: + statement: The AST statement node + init_value: The evaluated initialization value + base_size: The base size of the angle type + compiler_angle_width: Optional compiler angle width override + + Returns: + tuple: The processed angle value and bit string representation + + Raises: + ValidationError: If the angle initialization is invalid + """ + # Optimize: check both possible fields for BitstringLiteral in one go + init_exp = getattr(statement, "init_expression", None) + rval = getattr(statement, "rvalue", None) + is_bitstring = isinstance(init_exp, BitstringLiteral) or isinstance(rval, BitstringLiteral) + expression = init_exp or rval + if is_bitstring and expression is not None: + angle_type_size = expression.width + if compiler_angle_width: + if angle_type_size != compiler_angle_width: + raise_qasm3_error( + f"BitString angle width '{angle_type_size}' does not match " + f"with compiler angle width '{compiler_angle_width}'", + error_node=statement, + span=statement.span, + ) + angle_type_size = compiler_angle_width + angle_bit_string = format(expression.value, f"0{angle_type_size}b") + angle_val = (2 * CONSTANTS_MAP["pi"]) * (expression.value / (2**angle_type_size)) + else: + angle_val = init_value % (2 * CONSTANTS_MAP["pi"]) + angle_type_size = compiler_angle_width or base_size + bit_string_value = int((2**angle_type_size) * (angle_val / (2 * CONSTANTS_MAP["pi"]))) + angle_bit_string = format(bit_string_value, f"0{angle_type_size}b") + + return angle_val, angle_bit_string + @staticmethod def validate_duration_or_stretch_statements( statement: Statement, diff --git a/src/pyqasm/visitor.py b/src/pyqasm/visitor.py index a1b374ca..905a048a 100644 --- a/src/pyqasm/visitor.py +++ b/src/pyqasm/visitor.py @@ -330,6 +330,11 @@ def _check_variable_type_size( if not hasattr(base_type, "size") or base_type.size is None else Qasm3ExprEvaluator.evaluate_expression(base_type.size, const_expr=True)[0] ) + if ( + isinstance(base_type, qasm3_ast.AngleType) + and self._module._compiler_angle_type_size + ): + base_size = self._module._compiler_angle_type_size except ValidationError as err: raise_qasm3_error( f"Invalid base size for {var_format} '{var_name}'", @@ -1337,12 +1342,27 @@ def _visit_constant_declaration( base_type = statement.type base_size = self._check_variable_type_size(statement, var_name, "constant", base_type) + angle_val_bit_string = None + if isinstance(base_type, qasm3_ast.AngleType): + init_value, angle_val_bit_string = PulseValidator.validate_angle_type_value( + statement, + init_value=init_value, + base_size=base_size, + compiler_angle_width=self._module._compiler_angle_type_size, + ) val_type, _ = Qasm3ExprEvaluator.evaluate_expression( statement.init_expression, validate_only=True ) self._check_variable_cast_type(statement, val_type, var_name, base_type, base_size, True) variable = Variable( - var_name, base_type, base_size, [], init_value, is_constant=True, span=statement.span + var_name, + base_type, + len(angle_val_bit_string) if angle_val_bit_string else base_size, + [], + init_value, + is_constant=True, + span=statement.span, + angle_bit_string=angle_val_bit_string, ) if isinstance(base_type, (qasm3_ast.DurationType, qasm3_ast.StretchType)): @@ -1405,6 +1425,7 @@ def _visit_classical_declaration( base_type = statement.type dimensions = [] final_dimensions = [] + angle_val_bit_string = None if isinstance(base_type, qasm3_ast.StretchType): if statement.init_expression: @@ -1483,9 +1504,21 @@ def _visit_classical_declaration( statement.init_expression, dt=self._module._device_cycle_time ) statements.extend(stmts) + _req_type = ( + type(qasm3_ast.AngleType()) + if isinstance(base_type, qasm3_ast.AngleType) + else None + ) val_type, _ = Qasm3ExprEvaluator.evaluate_expression( - statement.init_expression, validate_only=True + statement.init_expression, validate_only=True, reqd_type=_req_type ) + if isinstance(base_type, qasm3_ast.AngleType): + init_value, angle_val_bit_string = PulseValidator.validate_angle_type_value( + statement, + init_value=init_value, + base_size=base_size, + compiler_angle_width=self._module._compiler_angle_type_size, + ) self._check_variable_cast_type( statement, val_type, var_name, base_type, base_size, False ) @@ -1500,11 +1533,12 @@ def _visit_classical_declaration( variable = Variable( var_name, base_type, - base_size, + len(angle_val_bit_string) if angle_val_bit_string else base_size, final_dimensions, init_value, is_qubit=False, span=statement.span, + angle_bit_string=angle_val_bit_string, ) if isinstance(base_type, qasm3_ast.DurationType): @@ -1637,6 +1671,17 @@ def _visit_classical_assignment( lvar.base_size, # type: ignore[union-attr] False, ) + angle_val_bit_string = None + if isinstance(lvar_base_type, qasm3_ast.AngleType): + rvalue_raw, angle_val_bit_string = PulseValidator.validate_angle_type_value( + statement, + init_value=rvalue_raw, + base_size=lvar.base_size, # type: ignore[union-attr] + compiler_angle_width=self._module._compiler_angle_type_size, + ) + lvar.angle_bit_string = angle_val_bit_string # type: ignore[union-attr] + if angle_val_bit_string: + lvar.base_size = len(angle_val_bit_string) # type: ignore[union-attr] # cast + validation rvalue_eval = None if not isinstance(rvalue_raw, np.ndarray): diff --git a/tests/qasm3/declarations/test_classical.py b/tests/qasm3/declarations/test_classical.py index 72b33f93..b01acd10 100644 --- a/tests/qasm3/declarations/test_classical.py +++ b/tests/qasm3/declarations/test_classical.py @@ -45,6 +45,7 @@ def test_scalar_declarations(): bool i; duration j; stretch st; + angle[8] ang1; """ loads(qasm3_string).validate() @@ -75,6 +76,9 @@ def test_const_declarations(): const duration t8 = t2/t3; const stretch st = 300ns; const stretch st2 = t2/t3; + const angle[8] ang1 = 7 * (pi / 8); + const angle[8] ang2 = 9 * (pi / 8); + const angle[8] ang3 = ang1 + ang2; """ loads(qasm3_string).validate() @@ -98,6 +102,8 @@ def test_scalar_assignments(): duration du = 200us; duration du2; du2 = 300s; + angle[8] ang1; + ang1 = 9 * (pi / 8); """ loads(qasm3_string).validate() @@ -123,6 +129,9 @@ def test_scalar_value_assignment(): duration t9 = t2 - t7; duration t10 = t2 + t7; duration t11 = t2 * t7; + angle[8] ang1 = 7 * (pi / 8); + angle[8] ang2 = 9 * (pi / 8); + angle[8] ang3 = ang1 + ang2; """ b = 5.0 @@ -603,6 +612,43 @@ def test_duration_casting_error(qasm_code, error_message, error_span, caplog): assert error_span in caplog.text +@pytest.mark.parametrize( + "qasm_code,error_message,error_span", + [ + ( + """ + OPENQASM 3.0; + include "stdgates.inc"; + angle[8] ang1 = 7 * (pi / 8); + angle[7] ang2 = 9 * (pi / 8); + angle[8] ang3 = ang1 + ang2; + """, + r"All 'Angle' variables in binary expression must have the same size", + r"Error at line 6, column 12", + ), + ( + """ + OPENQASM 3.0; + include "stdgates.inc"; + angle ang1 = "1000111111"; + """, + r"BitString angle width '10' does not match with compiler angle width '8'", + r"Error at line 4, column 12", + ), + ], +) # pylint: disable-next= too-many-arguments +def test_angle_type_error(qasm_code, error_message, error_span, caplog): + with pytest.raises(ValidationError) as excinfo: + with caplog.at_level("ERROR"): + loads(qasm_code, compiler_angle_type_size=8).validate() + + first = excinfo.value.__cause__ or excinfo.value.__context__ + assert first is not None, "Expected a chained ValidationError" + msg = str(first) + assert error_message in msg + assert error_span in caplog.text + + def test_device_time_duration_(): """Test device cycle time duration""" qasm3_string = """ @@ -613,3 +659,16 @@ def test_device_time_duration_(): const duration t3 =300us; """ loads(qasm3_string, device_cycle_time=1e-9).validate() + + +def test_compiler_angle_type_size(): + """Test compiler angle type size""" + qasm3_string = """ + OPENQASM 3.0; + include "stdgates.inc"; + angle[8] ang1 = 7 * (pi / 8); + const angle[8] ang2 = 9 * (pi / 8); + angle[8] ang3; + ang3 = 100.0; + """ + loads(qasm3_string, compiler_angle_type_size=8).validate() diff --git a/tests/qasm3/resources/variables.py b/tests/qasm3/resources/variables.py index 2c1e233e..2ef514f2 100644 --- a/tests/qasm3/resources/variables.py +++ b/tests/qasm3/resources/variables.py @@ -146,18 +146,6 @@ 8, "float[23] x;", ), - "unsupported_types": ( - """ - OPENQASM 3.0; - include "stdgates.inc"; - - angle x = 3.4; - """, - "Invalid initialization value for variable 'x'", - 5, - 8, - "angle x = 3.4;", - ), "imaginary_variable": ( """ OPENQASM 3.0; @@ -375,6 +363,8 @@ int ccf1 = float(runtime_u) * int(f1); uint ul1 = uint(float[64](int[16](f1))) * 2; const int un = -int(u1); + angle ang1 = angle(true); + angle ang2 = angle(false); """ ), "Bool_test": ( From 21d1d733fa35e380c198d8aee5613ec4c407d9ab Mon Sep 17 00:00:00 2001 From: vinayswamik Date: Sat, 2 Aug 2025 19:43:22 -0500 Subject: [PATCH 03/12] update CHANGELOG.md --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 976034fc..1ebfe326 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,6 +19,7 @@ Types of changes: - A github workflow for validating `CHANGELOG` updates in a PR ([#214](https://github.com/qBraid/pyqasm/pull/214)) - Added `unroll` command support in PYQASM CLI with options skipping files, overwriting originals files, and specifying output paths.([#224](https://github.com/qBraid/pyqasm/pull/224)) - Added `.github/copilot-instructions.md` to the repository to document coding standards and design principles for pyqasm. This file provides detailed guidance on documentation, static typing, formatting, error handling, and adherence to the QASM specification for all code contributions. ([#234](https://github.com/qBraid/pyqasm/pull/234)) +- Added support for `Angle` type in `OPENQASM3` code in pyqasm. ([#239](https://github.com/qBraid/pyqasm/pull/239)) ### Improved / Modified - Added `slots=True` parameter to the data classes in `elements.py` to improve memory efficiency ([#218](https://github.com/qBraid/pyqasm/pull/218)) From cb4462407f0e266ee1fdb5d75f3a72336cf524cd Mon Sep 17 00:00:00 2001 From: vinayswamik Date: Sat, 2 Aug 2025 20:36:01 -0500 Subject: [PATCH 04/12] update test case --- tests/qasm3/declarations/test_classical.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tests/qasm3/declarations/test_classical.py b/tests/qasm3/declarations/test_classical.py index b01acd10..c2266812 100644 --- a/tests/qasm3/declarations/test_classical.py +++ b/tests/qasm3/declarations/test_classical.py @@ -668,7 +668,6 @@ def test_compiler_angle_type_size(): include "stdgates.inc"; angle[8] ang1 = 7 * (pi / 8); const angle[8] ang2 = 9 * (pi / 8); - angle[8] ang3; - ang3 = 100.0; + angle[4] ang3 = "1010"; """ - loads(qasm3_string, compiler_angle_type_size=8).validate() + loads(qasm3_string, compiler_angle_type_size=4).validate() From b28950b7a6a42675d4bc11430ff48b0fa06ff57a Mon Sep 17 00:00:00 2001 From: vinayswamik Date: Sun, 3 Aug 2025 17:42:37 -0500 Subject: [PATCH 05/12] Add support for complex number operations in QASM --- src/pyqasm/expressions.py | 30 ++++++--- src/pyqasm/maps/expressions.py | 15 +++++ src/pyqasm/pulse/validator.py | 14 ++++ src/pyqasm/validator.py | 2 +- src/pyqasm/visitor.py | 49 ++++++++++++++ tests/qasm3/declarations/test_classical.py | 74 ++++++++++++++++++++++ tests/qasm3/test_expressions.py | 8 --- 7 files changed, 174 insertions(+), 18 deletions(-) diff --git a/src/pyqasm/expressions.py b/src/pyqasm/expressions.py index 337b71ec..9eb73c7d 100644 --- a/src/pyqasm/expressions.py +++ b/src/pyqasm/expressions.py @@ -47,7 +47,12 @@ from pyqasm.analyzer import Qasm3Analyzer from pyqasm.elements import Variable from pyqasm.exceptions import ValidationError, raise_qasm3_error -from pyqasm.maps.expressions import CONSTANTS_MAP, TIME_UNITS_MAP, qasm3_expression_op_map +from pyqasm.maps.expressions import ( + CONSTANTS_MAP, + FUNCTION_MAP, + TIME_UNITS_MAP, + qasm3_expression_op_map, +) from pyqasm.validator import Qasm3Validator @@ -217,14 +222,6 @@ def evaluate_expression( # type: ignore[return] if expression is None: return None, [] - if isinstance(expression, (ImaginaryLiteral)): - raise_qasm3_error( - f"Unsupported expression type '{type(expression)}'", - err_type=ValidationError, - error_node=expression, - span=expression.span, - ) - def _check_and_return_value(value): if validate_only: return None, statements @@ -265,6 +262,12 @@ def _check_type_size(expression, var_name, var_format, base_type): ) return base_size + if isinstance(expression, complex): + return _check_and_return_value(expression) + + if isinstance(expression, ImaginaryLiteral): + return _check_and_return_value(expression.value * 1j) + if isinstance(expression, Identifier): var_name = expression.name if var_name in CONSTANTS_MAP: @@ -415,6 +418,15 @@ def _check_type_size(expression, var_name, var_format, base_type): if isinstance(expression, FunctionCall): # function will not return a reqd / const type # Reference : https://openqasm.com/language/types.html#compile-time-constants, para: 5 + if expression.name.name in {"abs", "real", "imag", "sqrt", "sin", "cos", "tan"}: + _val, _ = cls.evaluate_expression( + expression.arguments[0], const_expr, reqd_type, validate_only + ) + if _val is None or validate_only: + return (None, statements) + _val = FUNCTION_MAP[expression.name.name](_val) # type: ignore + return _check_and_return_value(_val) + ret_value, ret_stmts = cls.visitor_obj._visit_function_call(expression) # type: ignore statements.extend(ret_stmts) return _check_and_return_value(ret_value) diff --git a/src/pyqasm/maps/expressions.py b/src/pyqasm/maps/expressions.py index 77bbb2f4..286b42b8 100644 --- a/src/pyqasm/maps/expressions.py +++ b/src/pyqasm/maps/expressions.py @@ -48,6 +48,7 @@ "/": lambda x, y: x / y, "%": lambda x, y: x % y, "==": lambda x, y: x == y, + "**": lambda x, y: x**y, "!=": lambda x, y: x != y, "<": lambda x, y: x < y, ">": lambda x, y: x > y, @@ -128,6 +129,8 @@ def qasm_variable_type_cast(openqasm_type, var_name, base_size, rhs_value): if isinstance(rhs_value, bool): return ((2 * CONSTANTS_MAP["pi"]) * (1 / 2)) if rhs_value else 0.0 return rhs_value # not sure + if openqasm_type == ComplexType: + return rhs_value # IEEE 754 Standard for floats @@ -163,6 +166,7 @@ def qasm_variable_type_cast(openqasm_type, var_name, base_size, rhs_value): UintType: (bool, int, float, np.int64, np.uint64, np.float64, np.bool_), FloatType: (bool, int, float, np.int64, np.float64, np.bool_), AngleType: (float, np.float64, bool, np.bool_), + ComplexType: (complex, np.complex128), } ARRAY_TYPE_MAP = { @@ -186,3 +190,14 @@ def qasm_variable_type_cast(openqasm_type, var_name, base_size, rhs_value): "ms": {"ns": 1_000_000, "s": 1e-3}, "s": {"ns": 1_000_000_000, "s": 1}, } + +# Function map for complex functions +FUNCTION_MAP = { + "abs": np.abs, + "real": lambda v: v.real if isinstance(v, complex) else v, + "imag": lambda v: v.imag if isinstance(v, complex) else v, + "sqrt": np.sqrt, + "sin": np.sin, + "cos": np.cos, + "tan": np.tan, +} diff --git a/src/pyqasm/pulse/validator.py b/src/pyqasm/pulse/validator.py index 78d9e21f..6dbf42d0 100644 --- a/src/pyqasm/pulse/validator.py +++ b/src/pyqasm/pulse/validator.py @@ -18,8 +18,10 @@ """ from typing import Any, Optional +import numpy as np from openqasm3.ast import ( BinaryExpression, + BinaryOperator, BitstringLiteral, Box, Cast, @@ -29,6 +31,7 @@ DurationType, FloatLiteral, Identifier, + ImaginaryLiteral, IntegerLiteral, Statement, StretchType, @@ -270,3 +273,14 @@ def validate_duration_variable( # pylint: disable=too-many-branches error_node=statement, span=statement.span, ) + + @staticmethod + def make_complex_binary_expression(value: complex) -> BinaryExpression: + """ + Make a binary expression from a complex number. + """ + return BinaryExpression( + lhs=FloatLiteral(value.real), + op=(BinaryOperator["+"] if value.imag >= 0 else BinaryOperator["-"]), + rhs=ImaginaryLiteral(np.abs(value.imag)), + ) diff --git a/src/pyqasm/validator.py b/src/pyqasm/validator.py index 14c42e55..9a4cd82e 100644 --- a/src/pyqasm/validator.py +++ b/src/pyqasm/validator.py @@ -181,7 +181,7 @@ def validate_variable_assignment_value( error_node=op_node, span=op_node.span if op_node else None, ) - elif type_to_match == bool: + elif type_to_match in (bool, complex): pass else: raise_qasm3_error( diff --git a/src/pyqasm/visitor.py b/src/pyqasm/visitor.py index 905a048a..9b416d83 100644 --- a/src/pyqasm/visitor.py +++ b/src/pyqasm/visitor.py @@ -430,6 +430,24 @@ def _qubit_register_consolidation( return _valid_statements + def _handle_function_init_expression( + self, expression: Any, init_value: Any + ) -> None | qasm3_ast.Expression: + """Handle function initialization expression. + + Args: + statement (Any): The statement to handle function initialization expression. + init_value (Any): The value to handle function initialization expression. + """ + if isinstance(expression, qasm3_ast.FunctionCall): + func_name = expression.name.name + if func_name in ["abs", "real", "imag", "sqrt", "sin", "cos", "tan"]: + if isinstance(init_value, complex): + return PulseValidator.make_complex_binary_expression(init_value) + if isinstance(init_value, (float, int)): + return qasm3_ast.FloatLiteral(init_value) + return None + def _visit_measurement( # pylint: disable=too-many-locals, too-many-branches self, statement: qasm3_ast.QuantumMeasurementStatement ) -> list[qasm3_ast.QuantumMeasurementStatement]: @@ -1379,6 +1397,15 @@ def _visit_constant_declaration( self._scope_manager.add_var_in_scope(variable) + if isinstance(base_type, qasm3_ast.ComplexType) and isinstance(variable.value, complex): + statement.init_expression = PulseValidator.make_complex_binary_expression(init_value) + + if isinstance(statement.init_expression, qasm3_ast.FunctionCall): + statement.init_expression = ( + self._handle_function_init_expression(statement.init_expression, init_value) + or statement.init_expression + ) + if self._check_only: return [] @@ -1594,6 +1621,15 @@ def _visit_classical_declaration( statements.append(statement) self._module._add_classical_register(var_name, base_size) + if isinstance(base_type, qasm3_ast.ComplexType) and isinstance(variable.value, complex): + statement.init_expression = PulseValidator.make_complex_binary_expression(init_value) + + if isinstance(statement.init_expression, qasm3_ast.FunctionCall): + statement.init_expression = ( + self._handle_function_init_expression(statement.init_expression, init_value) + or statement.init_expression + ) + if self._check_only: return [] @@ -1737,6 +1773,19 @@ def _visit_classical_assignment( lvar.value = rvalue_eval # type: ignore[union-attr] self._scope_manager.update_var_in_scope(lvar) # type: ignore[arg-type] + if isinstance(lvar_base_type, qasm3_ast.ComplexType) and isinstance( + lvar.value, complex # type: ignore[union-attr] + ): + statement.rvalue = PulseValidator.make_complex_binary_expression( + lvar.value # type: ignore[union-attr] + ) + + if isinstance(statement.rvalue, qasm3_ast.FunctionCall): + statement.rvalue = ( + self._handle_function_init_expression(statement.rvalue, rvalue_eval) + or statement.rvalue + ) + if self._check_only: return [] diff --git a/tests/qasm3/declarations/test_classical.py b/tests/qasm3/declarations/test_classical.py index c2266812..6abbdfff 100644 --- a/tests/qasm3/declarations/test_classical.py +++ b/tests/qasm3/declarations/test_classical.py @@ -671,3 +671,77 @@ def test_compiler_angle_type_size(): angle[4] ang3 = "1010"; """ loads(qasm3_string, compiler_angle_type_size=4).validate() + + +def test_complex_type_variables(): + """Test complex type variables""" + qasm3_string = """ + OPENQASM 3.0; + include "stdgates.inc"; + complex c1 = -2.5 - 3.5im; + complex c2 = 3.5 + 2.5im; + complex c3 = 2.0 + c2; + complex c4 = 2.0+sin(π/2) + (3.1 * 5.5im); + complex c5 = c1 * c2; + complex c6 = c1 + c2; + complex c7 = c1 - c2; + complex c8 = c1 / c2; + complex c9 = c1 ** c2; + complex c10 = sqrt(c1); + float c11 = abs(c1 * c2); + float c12 = real(c1); + float c13 = imag(c1); + float c14 = sin(π/2); + const complex c15 = -2.5 - 3.5im; + const complex c16 = 3.5 + 2.5im; + const complex c17 = 2.0 + c16; + const complex c18 = 2.0+sin(π/2) + (3.1 * 5.5im); + const complex c19 = c15 * c16; + const complex c20 = c15 + c16; + const complex c21 = c15 - c16; + const complex c22 = c15 / c16; + const complex c23 = c15 ** c16; + const complex c24 = sqrt(c15); + const float c25 = abs(c15 * c16); + const float c26 = real(c15); + const float c27 = imag(c15); + const float c28 = sin(π/2); + complex c29; + c29 = -2.5 - 3.5im; + complex c30; + c30 = 3.5 + 2.5im; + complex c31; + c31 = 2.0 + c30; + complex c32; + c32 = 2.0+sin(π/2) + (3.1 * 5.5im); + complex c33; + c33 = c29 * c30; + complex c34; + c34 = c29 + c30; + complex c35; + c35 = c29 - c30; + complex c36; + c36 = c29 / c30; + complex c37; + c37 = c29 ** c30; + complex c38; + c38 = sqrt(c29); + float c39; + c39 = abs(c29 * c30); + float c40; + c40 = real(c29); + float c41; + c41 = imag(c29); + float c42; + c42 = sin(π/2); + complex[float[64]] a = 10.0 + 5.0im; + complex[float[64]] b = -2.0 - 7.0im; + complex[float[64]] c = a + b; + complex[float[64]] d = a - b; + complex[float[64]] e = a * b; + complex[float[64]] f = a / b; + complex[float[64]] g = a ** b; + complex[float] h = a + b; + """ + + loads(qasm3_string).validate() diff --git a/tests/qasm3/test_expressions.py b/tests/qasm3/test_expressions.py index 9f023c5e..b07b7a63 100644 --- a/tests/qasm3/test_expressions.py +++ b/tests/qasm3/test_expressions.py @@ -75,14 +75,6 @@ def test_bit_in_expression(): def test_incorrect_expressions(caplog): - with pytest.raises(ValidationError, match=r"Invalid parameter .*"): - with caplog.at_level("ERROR"): - loads("OPENQASM 3; qubit q; rz(1 - 2 + 32im) q;").validate() - assert "Error at line 1, column 32" in caplog.text - assert "32.0im" in caplog.text - - caplog.clear() - with pytest.raises(ValidationError, match=r"Invalid parameter .*"): with caplog.at_level("ERROR"): loads("OPENQASM 3; qubit q; rx(~1.3) q;").validate() From 148f1a3f2467ec695fd427270bd7638b9fbfc5e9 Mon Sep 17 00:00:00 2001 From: vinayswamik Date: Sun, 3 Aug 2025 17:46:07 -0500 Subject: [PATCH 06/12] update CHANGELOG.md --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1ebfe326..588c508f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,7 +19,7 @@ Types of changes: - A github workflow for validating `CHANGELOG` updates in a PR ([#214](https://github.com/qBraid/pyqasm/pull/214)) - Added `unroll` command support in PYQASM CLI with options skipping files, overwriting originals files, and specifying output paths.([#224](https://github.com/qBraid/pyqasm/pull/224)) - Added `.github/copilot-instructions.md` to the repository to document coding standards and design principles for pyqasm. This file provides detailed guidance on documentation, static typing, formatting, error handling, and adherence to the QASM specification for all code contributions. ([#234](https://github.com/qBraid/pyqasm/pull/234)) -- Added support for `Angle` type in `OPENQASM3` code in pyqasm. ([#239](https://github.com/qBraid/pyqasm/pull/239)) +- Added support for `Angle` and `Complex` type in `OPENQASM3` code in pyqasm. ([#239](https://github.com/qBraid/pyqasm/pull/239)) ### Improved / Modified - Added `slots=True` parameter to the data classes in `elements.py` to improve memory efficiency ([#218](https://github.com/qBraid/pyqasm/pull/218)) From 44f5c82d9abb74099e6b24a35219cf1445b2c5a0 Mon Sep 17 00:00:00 2001 From: vinayswamik Date: Mon, 4 Aug 2025 04:41:29 -0500 Subject: [PATCH 07/12] code refactor - Changed `angle_vars_in_expr` to a single variable `angle_var_in_expr` for improved clarity and functionality. - Introduced new trigonometric functions (`arccos`, `arcsin`, `arctan`) to the FUNCTION_MAP. - Adjusted tests to validate new angle expressions and ensure correct bit string conversions. --- src/pyqasm/expressions.py | 31 +++-- src/pyqasm/maps/expressions.py | 3 + src/pyqasm/pulse/validator.py | 3 +- src/pyqasm/visitor.py | 3 +- tests/qasm3/declarations/test_classical.py | 128 +++++++++++++-------- 5 files changed, 104 insertions(+), 64 deletions(-) diff --git a/src/pyqasm/expressions.py b/src/pyqasm/expressions.py index 9eb73c7d..a099166b 100644 --- a/src/pyqasm/expressions.py +++ b/src/pyqasm/expressions.py @@ -60,7 +60,7 @@ class Qasm3ExprEvaluator: """Class for evaluating QASM3 expressions.""" visitor_obj = None - angle_vars_in_expr: list[Variable] = [] + angle_var_in_expr = None @classmethod def set_visitor_obj(cls, visitor_obj) -> None: @@ -92,17 +92,14 @@ def _check_var_in_scope(cls, var_name, expression): error_node=expression, span=expression.span, ) - if var and isinstance(var.base_type, AngleType) and var not in cls.angle_vars_in_expr: - base_size = var.base_type.size - if len(cls.angle_vars_in_expr) > 0: - for var in cls.angle_vars_in_expr: - if var.base_type.size != base_size: - raise_qasm3_error( - "All 'Angle' variables in binary expression must have the same size", - error_node=expression, - span=expression.span, - ) - cls.angle_vars_in_expr.append(var) + if var and isinstance(var.base_type, AngleType): + if cls.angle_var_in_expr and cls.angle_var_in_expr != var.base_type.size: + raise_qasm3_error( + "All 'Angle' variables in binary expression must have the same size", + error_node=expression, + span=expression.span, + ) + cls.angle_var_in_expr = var.base_type.size @classmethod def _check_var_constant(cls, var_name, const_expr, expression): @@ -397,10 +394,10 @@ def _check_type_size(expression, var_name, var_format, base_type): if isinstance(expression.rhs, Cast): return (_rhs, _rhs_stmts) - if type(reqd_type) is type(AngleType) and len(cls.angle_vars_in_expr) > 0: - var_type = cls.angle_vars_in_expr[0].base_type - cls.angle_vars_in_expr.clear() - return (var_type, statements) + if type(reqd_type) is type(AngleType) and cls.angle_var_in_expr: + _var_type = AngleType(cls.angle_var_in_expr) + cls.angle_var_in_expr = None + return (_var_type, statements) return (None, statements) lhs_value, lhs_statements = cls.evaluate_expression( @@ -418,7 +415,7 @@ def _check_type_size(expression, var_name, var_format, base_type): if isinstance(expression, FunctionCall): # function will not return a reqd / const type # Reference : https://openqasm.com/language/types.html#compile-time-constants, para: 5 - if expression.name.name in {"abs", "real", "imag", "sqrt", "sin", "cos", "tan"}: + if expression.name.name in FUNCTION_MAP: _val, _ = cls.evaluate_expression( expression.arguments[0], const_expr, reqd_type, validate_only ) diff --git a/src/pyqasm/maps/expressions.py b/src/pyqasm/maps/expressions.py index 286b42b8..d6be70e1 100644 --- a/src/pyqasm/maps/expressions.py +++ b/src/pyqasm/maps/expressions.py @@ -200,4 +200,7 @@ def qasm_variable_type_cast(openqasm_type, var_name, base_size, rhs_value): "sin": np.sin, "cos": np.cos, "tan": np.tan, + "arccos": np.arccos, + "arcsin": np.arcsin, + "arctan": np.arctan, } diff --git a/src/pyqasm/pulse/validator.py b/src/pyqasm/pulse/validator.py index 6dbf42d0..37d5d437 100644 --- a/src/pyqasm/pulse/validator.py +++ b/src/pyqasm/pulse/validator.py @@ -83,11 +83,12 @@ def validate_angle_type_value( ) angle_type_size = compiler_angle_width angle_bit_string = format(expression.value, f"0{angle_type_size}b") + # Reference: https://openqasm.com/language/types.html#angles angle_val = (2 * CONSTANTS_MAP["pi"]) * (expression.value / (2**angle_type_size)) else: angle_val = init_value % (2 * CONSTANTS_MAP["pi"]) angle_type_size = compiler_angle_width or base_size - bit_string_value = int((2**angle_type_size) * (angle_val / (2 * CONSTANTS_MAP["pi"]))) + bit_string_value = round((2**angle_type_size) * (angle_val / (2 * CONSTANTS_MAP["pi"]))) angle_bit_string = format(bit_string_value, f"0{angle_type_size}b") return angle_val, angle_bit_string diff --git a/src/pyqasm/visitor.py b/src/pyqasm/visitor.py index 9b416d83..2b80ec8f 100644 --- a/src/pyqasm/visitor.py +++ b/src/pyqasm/visitor.py @@ -49,6 +49,7 @@ from pyqasm.maps.expressions import ( ARRAY_TYPE_MAP, CONSTANTS_MAP, + FUNCTION_MAP, MAX_ARRAY_DIMENSIONS, ) from pyqasm.maps.gates import ( @@ -441,7 +442,7 @@ def _handle_function_init_expression( """ if isinstance(expression, qasm3_ast.FunctionCall): func_name = expression.name.name - if func_name in ["abs", "real", "imag", "sqrt", "sin", "cos", "tan"]: + if func_name in FUNCTION_MAP: if isinstance(init_value, complex): return PulseValidator.make_complex_binary_expression(init_value) if isinstance(init_value, (float, int)): diff --git a/tests/qasm3/declarations/test_classical.py b/tests/qasm3/declarations/test_classical.py index 6abbdfff..7d169d0f 100644 --- a/tests/qasm3/declarations/test_classical.py +++ b/tests/qasm3/declarations/test_classical.py @@ -20,6 +20,7 @@ from pyqasm.entrypoint import loads from pyqasm.exceptions import ValidationError +from pyqasm.visitor import QasmVisitor, ScopeManager # pylint: disable=ungrouped-imports from tests.qasm3.resources.variables import ( ASSIGNMENT_TESTS, CASTING_TESTS, @@ -682,58 +683,63 @@ def test_complex_type_variables(): complex c2 = 3.5 + 2.5im; complex c3 = 2.0 + c2; complex c4 = 2.0+sin(π/2) + (3.1 * 5.5im); - complex c5 = c1 * c2; - complex c6 = c1 + c2; - complex c7 = c1 - c2; - complex c8 = c1 / c2; - complex c9 = c1 ** c2; - complex c10 = sqrt(c1); - float c11 = abs(c1 * c2); - float c12 = real(c1); - float c13 = imag(c1); - float c14 = sin(π/2); - const complex c15 = -2.5 - 3.5im; - const complex c16 = 3.5 + 2.5im; - const complex c17 = 2.0 + c16; - const complex c18 = 2.0+sin(π/2) + (3.1 * 5.5im); - const complex c19 = c15 * c16; - const complex c20 = c15 + c16; - const complex c21 = c15 - c16; - const complex c22 = c15 / c16; - const complex c23 = c15 ** c16; - const complex c24 = sqrt(c15); - const float c25 = abs(c15 * c16); - const float c26 = real(c15); - const float c27 = imag(c15); - const float c28 = sin(π/2); - complex c29; - c29 = -2.5 - 3.5im; - complex c30; - c30 = 3.5 + 2.5im; - complex c31; - c31 = 2.0 + c30; + complex c5 = 2.0+arcsin(π/2) + (3.1 * 5.5im); + complex c6 = 2.0+arctan(π/2) + (3.1 * 5.5im); + complex c7 = c1 * c2; + complex c8 = c1 + c2; + complex c9 = c1 - c2; + complex c10 = c1 / c2; + complex c11 = c1 ** c2; + complex c12 = sqrt(c1); + float c13 = abs(c1 * c2); + float c14 = real(c1); + float c15 = imag(c1); + float c16 = sin(π/2); + const complex c17 = -2.5 - 3.5im; + const complex c18 = 3.5 + 2.5im; + const complex c19 = 2.0 + c18; + const complex c20 = 2.0+cos(π/2) + (3.1 * 5.5im); + const complex c21 = 2.0+arccos(π/2) + (3.1 * 5.5im); + const complex c22 = c17 * c18; + const complex c23 = c17 + c18; + const complex c24 = c17 - c18; + const complex c25 = c17 / c18; + const complex c26 = c17 ** c18; + const complex c27 = sqrt(c17); + const float c28 = abs(c17 * c18); + const float c29 = real(c17); + const float c30 = imag(c17); + const float c31 = sin(π/2); complex c32; - c32 = 2.0+sin(π/2) + (3.1 * 5.5im); + c32 = -2.5 - 3.5im; complex c33; - c33 = c29 * c30; + c33 = 3.5 + 2.5im; complex c34; - c34 = c29 + c30; + c34 = 2.0 + c33; complex c35; - c35 = c29 - c30; + c35 = 2.0+tan(π/2) + (3.1 * 5.5im); complex c36; - c36 = c29 / c30; + c36 = 2.0+arctan(π/2) + (3.1 * 5.5im); complex c37; - c37 = c29 ** c30; + c37 = c32 * c33; complex c38; - c38 = sqrt(c29); - float c39; - c39 = abs(c29 * c30); - float c40; - c40 = real(c29); - float c41; - c41 = imag(c29); - float c42; - c42 = sin(π/2); + c38 = c32 + c33; + complex c39; + c39 = c32 - c33; + complex c40; + c40 = c32 / c33; + complex c41; + c41 = c32 ** c33; + complex c42; + c42 = sqrt(c32); + float c43; + c43 = abs(c32 * c33); + float c44; + c44 = real(c32); + float c45; + c45 = imag(c32); + float c46; + c46 = sin(π/2); complex[float[64]] a = 10.0 + 5.0im; complex[float[64]] b = -2.0 - 7.0im; complex[float[64]] c = a + b; @@ -745,3 +751,35 @@ def test_complex_type_variables(): """ loads(qasm3_string).validate() + + +def test_pi_expression_bit_conversion(): + """Test that pi expressions are correctly converted to bit string representations""" + qasm3_string = """ + OPENQASM 3.0; + include "stdgates.inc"; + + angle[4] ang1 = pi / 2; + angle[8] ang2 = 15 * (pi / 16); + angle[4] ang3 = -pi / 2; + angle[4] ang4 = -pi; + angle[8] ang5 = (pi / 2) + (pi / 4); + angle[8] ang6 = (pi / 2) - (pi / 4); + + """ + + result = loads(qasm3_string) + result.validate() + + # Create a visitor to access the scope manager + scope_manager = ScopeManager() + visitor = QasmVisitor(result, scope_manager, check_only=True) + result.accept(visitor) + scope = scope_manager.get_global_scope() + + assert scope["ang1"].angle_bit_string == "0100" # pi/2 + assert scope["ang2"].angle_bit_string == "01111000" # 15*pi/16 + assert scope["ang3"].angle_bit_string == "1100" # -pi/2 (wraps to 3*pi/2) + assert scope["ang4"].angle_bit_string == "1000" # -pi (wraps to 1) + assert scope["ang5"].angle_bit_string == "01100000" # (pi/2) + (pi/4) = 3*pi/4 + assert scope["ang6"].angle_bit_string == "00100000" # (pi/2) - (pi/4) = pi/4 From 3ecfd4ddedbf0037f853464fe5daedb2c4d586e3 Mon Sep 17 00:00:00 2001 From: vinayswamik Date: Tue, 5 Aug 2025 05:03:49 -0500 Subject: [PATCH 08/12] Add extern function support in QASM --- CHANGELOG.md | 2 +- src/README.md | 6 +- src/pyqasm/entrypoint.py | 3 + src/pyqasm/expressions.py | 75 ++++++- src/pyqasm/modules/base.py | 1 + src/pyqasm/pulse/validator.py | 135 ++++++++++++ src/pyqasm/subroutines.py | 52 ++++- src/pyqasm/visitor.py | 147 +++++++++---- tests/qasm3/subroutines/test_subroutines.py | 230 +++++++++++++++++++- 9 files changed, 594 insertions(+), 57 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 588c508f..502703fd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,7 +19,7 @@ Types of changes: - A github workflow for validating `CHANGELOG` updates in a PR ([#214](https://github.com/qBraid/pyqasm/pull/214)) - Added `unroll` command support in PYQASM CLI with options skipping files, overwriting originals files, and specifying output paths.([#224](https://github.com/qBraid/pyqasm/pull/224)) - Added `.github/copilot-instructions.md` to the repository to document coding standards and design principles for pyqasm. This file provides detailed guidance on documentation, static typing, formatting, error handling, and adherence to the QASM specification for all code contributions. ([#234](https://github.com/qBraid/pyqasm/pull/234)) -- Added support for `Angle` and `Complex` type in `OPENQASM3` code in pyqasm. ([#239](https://github.com/qBraid/pyqasm/pull/239)) +- Added support for `Angle`,`extern` and `Complex` type in `OPENQASM3` code in pyqasm. ([#239](https://github.com/qBraid/pyqasm/pull/239)) ### Improved / Modified - Added `slots=True` parameter to the data classes in `elements.py` to improve memory efficiency ([#218](https://github.com/qBraid/pyqasm/pull/218)) diff --git a/src/README.md b/src/README.md index 27eeae33..e37b7d30 100644 --- a/src/README.md +++ b/src/README.md @@ -38,6 +38,6 @@ Source code for OpenQASM 3 program validator and semantic analyzer | Box | ✅ | Completed | | CalibrationStatement | 📋 | Planned | | CalibrationDefinition | 📋 | Planned | -| ComplexType | 📋 | Planned | -| AngleType | 📋 | Planned | -| ExternDeclaration | 📋 | Planned | +| ComplexType | ✅ | Completed | +| AngleType | ✅ | Completed | +| ExternDeclaration | ✅ | Completed | diff --git a/src/pyqasm/entrypoint.py b/src/pyqasm/entrypoint.py index 59435916..f2733962 100644 --- a/src/pyqasm/entrypoint.py +++ b/src/pyqasm/entrypoint.py @@ -60,6 +60,7 @@ def loads(program: openqasm3.ast.Program | str, **kwargs) -> QasmModule: device_qubits (int): Number of physical qubits available on the target device. device_cycle_time (float): The duration of a hardware device cycle, in seconds. compiler_angle_type_size (int): The width of the angle type in the compiler. + extern_functions (list): List of extern functions to be added to the module. Raises: TypeError: If the input is not a string or an `openqasm3.ast.Program` instance. @@ -94,6 +95,8 @@ def loads(program: openqasm3.ast.Program | str, **kwargs) -> QasmModule: module._device_cycle_time = dev_cycle_time if compiler_angle_type_size := kwargs.get("compiler_angle_type_size"): module._compiler_angle_type_size = compiler_angle_type_size + if extern_functions := kwargs.get("extern_functions"): + module._extern_functions = extern_functions return module diff --git a/src/pyqasm/expressions.py b/src/pyqasm/expressions.py index a099166b..8aa965b2 100644 --- a/src/pyqasm/expressions.py +++ b/src/pyqasm/expressions.py @@ -201,6 +201,7 @@ def evaluate_expression( # type: ignore[return] reqd_type=None, validate_only: bool = False, dt=None, + extern_fns=None, ) -> tuple: """Evaluate an expression. Scalar types are assigned by value. @@ -208,6 +209,9 @@ def evaluate_expression( # type: ignore[return] expression (Any): The expression to evaluate. const_expr (bool): Whether the expression is a constant. Defaults to False. reqd_type (Any): The required type of the expression. Defaults to None. + validate_only (bool): Whether to validate the expression only. Defaults to False. + dt (float): The time step of the compiler. Defaults to None. + extern_fns (dict): A dictionary of extern functions. Defaults to None. Returns: tuple[Any, list[Statement]] : The result of the evaluation. @@ -259,6 +263,16 @@ def _check_type_size(expression, var_name, var_format, base_type): ) return base_size + def _is_external_function_call(expression, extern_fns): + """Check if an expression is an external function call""" + return isinstance(expression, FunctionCall) and expression.name.name in extern_fns + + def _get_external_function_return_type(expression, extern_fns): + """Get the return type of an external function call""" + if _is_external_function_call(expression, extern_fns): + return extern_fns[expression.name.name][1] + return None + if isinstance(expression, complex): return _check_and_return_value(expression) @@ -361,11 +375,21 @@ def _check_type_size(expression, var_name, var_format, base_type): return cls.evaluate_expression( expression.expression, const_expr, reqd_type, validate_only ) + # Check for external function in validate_only mode + return_type = _get_external_function_return_type(expression.expression, extern_fns) + if return_type: + return (return_type, statements) return (None, []) operand, returned_stats = cls.evaluate_expression( expression.expression, const_expr, reqd_type ) + + # Handle external function replacement + if _is_external_function_call(expression.expression, extern_fns): + expression.expression = returned_stats[0] + return _check_and_return_value(None) + if expression.op.name == "~" and not isinstance(operand, int): raise_qasm3_error( f"Unsupported expression type '{type(operand)}' in ~ operation", @@ -383,10 +407,10 @@ def _check_type_size(expression, var_name, var_format, base_type): return (None, statements) _lhs, _lhs_stmts = cls.evaluate_expression( - expression.lhs, const_expr, reqd_type, validate_only + expression.lhs, const_expr, reqd_type, validate_only, extern_fns=extern_fns ) _rhs, _rhs_stmts = cls.evaluate_expression( - expression.rhs, const_expr, reqd_type, validate_only + expression.rhs, const_expr, reqd_type, validate_only, extern_fns=extern_fns ) if isinstance(expression.lhs, Cast): @@ -398,15 +422,52 @@ def _check_type_size(expression, var_name, var_format, base_type): _var_type = AngleType(cls.angle_var_in_expr) cls.angle_var_in_expr = None return (_var_type, statements) + + _lhs_return_type = None + _rhs_return_type = None + # Check for external functions in both operands + _lhs_return_type = _get_external_function_return_type(expression.lhs, extern_fns) + _rhs_return_type = _get_external_function_return_type(expression.rhs, extern_fns) + + if _lhs_return_type and _rhs_return_type: + if _lhs_return_type != _rhs_return_type: + raise_qasm3_error( + f"extern function return type mismatch in binary expression: " + f"{type(_lhs_return_type).__name__} and " + f"{type(_rhs_return_type).__name__}", + err_type=ValidationError, + error_node=expression, + span=expression.span, + ) + else: + if _lhs_return_type: + return (_lhs_return_type, statements) + if _rhs_return_type: + return (_rhs_return_type, statements) + return (None, statements) lhs_value, lhs_statements = cls.evaluate_expression( - expression.lhs, const_expr, reqd_type + expression.lhs, const_expr, reqd_type, extern_fns=extern_fns ) + # Handle external function replacement for lhs + lhs_extern_function = False + if _is_external_function_call(expression.lhs, extern_fns): + expression.lhs = lhs_statements[0] + lhs_extern_function = True statements.extend(lhs_statements) + rhs_value, rhs_statements = cls.evaluate_expression( - expression.rhs, const_expr, reqd_type + expression.rhs, const_expr, reqd_type, extern_fns=extern_fns ) + # Handle external function replacement for rhs + rhs_extern_function = False + if _is_external_function_call(expression.rhs, extern_fns): + expression.rhs = rhs_statements[0] + rhs_extern_function = True + if lhs_extern_function or rhs_extern_function: + return (None, []) + statements.extend(rhs_statements) return _check_and_return_value( qasm3_expression_op_map(expression.op.name, lhs_value, rhs_value) @@ -415,6 +476,12 @@ def _check_type_size(expression, var_name, var_format, base_type): if isinstance(expression, FunctionCall): # function will not return a reqd / const type # Reference : https://openqasm.com/language/types.html#compile-time-constants, para: 5 + if validate_only: + return_type = _get_external_function_return_type(expression, extern_fns) + if return_type: + return (return_type, statements) + return (None, statements) + if expression.name.name in FUNCTION_MAP: _val, _ = cls.evaluate_expression( expression.arguments[0], const_expr, reqd_type, validate_only diff --git a/src/pyqasm/modules/base.py b/src/pyqasm/modules/base.py index d2c30121..82e2f961 100644 --- a/src/pyqasm/modules/base.py +++ b/src/pyqasm/modules/base.py @@ -61,6 +61,7 @@ def __init__(self, name: str, program: Program): self._consolidate_qubits: Optional[bool] = False self._device_cycle_time: Optional[int] = None self._compiler_angle_type_size: Optional[int] = None + self._extern_functions: dict[str, tuple[list[str], str]] = {} @property def name(self) -> str: diff --git a/src/pyqasm/pulse/validator.py b/src/pyqasm/pulse/validator.py index 37d5d437..c1b21720 100644 --- a/src/pyqasm/pulse/validator.py +++ b/src/pyqasm/pulse/validator.py @@ -20,21 +20,28 @@ import numpy as np from openqasm3.ast import ( + ArrayLiteral, BinaryExpression, BinaryOperator, BitstringLiteral, + BitType, + BooleanLiteral, Box, Cast, ConstantDeclaration, DelayInstruction, DurationLiteral, DurationType, + Expression, + ExternDeclaration, FloatLiteral, + FunctionCall, Identifier, ImaginaryLiteral, IntegerLiteral, Statement, StretchType, + TimeUnit, ) from pyqasm.exceptions import raise_qasm3_error @@ -285,3 +292,131 @@ def make_complex_binary_expression(value: complex) -> BinaryExpression: op=(BinaryOperator["+"] if value.imag >= 0 else BinaryOperator["-"]), rhs=ImaginaryLiteral(np.abs(value.imag)), ) + + @staticmethod + def validate_extern_declaration(module: Any, statement: ExternDeclaration) -> None: + """ + Validates an extern declaration. + Args: + module: The module object + statement: The extern declaration statement + + Raises: + ValidationError: If the extern declaration is invalid + """ + args = module._extern_functions[statement.name.name][0] + if len(args) != len(statement.arguments): + raise_qasm3_error( + f"Parameter count mismatch for 'extern' subroutine '{statement.name.name}'. " + f"Expected {len(args)} but got {len(statement.arguments)}", + error_node=statement, + span=statement.span, + ) + + def _get_type_string(type_obj) -> str: + """Recursively build type string for nested types""" + type_name = type(type_obj).__name__.replace("Type", "").lower() + if hasattr(type_obj, "base_type") and type_obj.base_type is not None: + base_type_str = _get_type_string(type_obj.base_type) + if hasattr(type_obj, "size") and type_obj.size is not None: + size_val = type_obj.size.value + return f"{type_name}[{size_val}][{base_type_str}]" + return f"{type_name}[{base_type_str}]" + if hasattr(type_obj, "size") and type_obj.size is not None: + size_val = type_obj.size.value + return f"{type_name}[{size_val}]" + return type_name + + for actual_arg, extern_arg in zip(statement.arguments, args): + if actual_arg == extern_arg: + continue + actual_arg_type = _get_type_string(actual_arg.type) + if actual_arg_type != str(extern_arg).lower(): + raise_qasm3_error( + f"Parameter type mismatch for 'extern' subroutine '{statement.name.name}'. " + f"Expected {extern_arg} but got {actual_arg_type}", + error_node=statement, + span=statement.span, + ) + return_type = module._extern_functions[statement.name.name][1] + actual_type_name = _get_type_string(statement.return_type) + if return_type == statement.return_type: + return + if str(return_type).lower() != actual_type_name: + raise_qasm3_error( + f"Return type mismatch for 'extern' subroutine '{statement.name.name}'. Expected " + f"{return_type} but got {actual_type_name}", + error_node=statement, + span=statement.span, + ) + + @staticmethod + def validate_and_process_extern_function_call( # pylint: disable=too-many-branches + statement: FunctionCall, global_scope: dict, device_cycle_time: float | None + ) -> FunctionCall: + """Validate and process extern function arguments + by converting them to appropriate literals. + + Args: + statement: The function call statement to process + global_scope: The global scope of the module + device_cycle_time: The device cycle time of the module + + Returns: + The validated and processed function call statement + + Raises: + ValidationError: If the function call is invalid + """ + fn_name = statement.name.name + # pylint: disable=too-many-nested-blocks + for i, arg in enumerate(statement.arguments): + if isinstance(arg, Identifier): + arg_var = global_scope.get(arg.name) + if arg_var is None: + raise_qasm3_error( + f"Undefined variable '{arg.name}' in extern function '{fn_name}'", + error_node=statement, + span=statement.span, + ) + + assert arg_var is not None + + if arg_var.base_type is not None and isinstance(arg_var.base_type, DurationType): + statement.arguments[i] = DurationLiteral( + float(arg_var.value) if arg_var.value is not None else 0.0, + unit=(TimeUnit.dt if device_cycle_time else TimeUnit.ns), + ) + elif ( + arg_var.base_type is not None + and isinstance(arg_var.base_type, BitType) + and (isinstance(arg_var.dims, list) or isinstance(arg_var.value, np.ndarray)) + ): + array_list: list[Expression] = [] + if isinstance(arg_var.value, np.ndarray): + for val in arg_var.value: + array_list.append(BooleanLiteral(bool(val))) + else: + if arg_var.dims is not None and len(arg_var.dims) > 0: + for _ in range(arg_var.dims[0]): + array_list.append(BooleanLiteral(bool(arg_var.value))) + statement.arguments[i] = ArrayLiteral(array_list) + elif isinstance(arg_var.value, float): + statement.arguments[i] = FloatLiteral(arg_var.value) + elif isinstance(arg_var.value, int): + statement.arguments[i] = IntegerLiteral(arg_var.value) + elif isinstance(arg_var.value, bool): + statement.arguments[i] = BooleanLiteral(arg_var.value) + elif isinstance(arg_var.value, complex): + statement.arguments[i] = PulseValidator.make_complex_binary_expression( + arg_var.value + ) + else: + raise_qasm3_error( + f"Invalid argument type '{arg_var.base_type}' for extern function " + f"'{fn_name}'", + error_node=statement, + span=statement.span, + ) + + return statement diff --git a/src/pyqasm/subroutines.py b/src/pyqasm/subroutines.py index e4b358c4..a3703aa0 100644 --- a/src/pyqasm/subroutines.py +++ b/src/pyqasm/subroutines.py @@ -16,11 +16,13 @@ Module containing the class for validating QASM3 subroutines. """ +import random from typing import Optional from openqasm3.ast import ( AccessControl, ArrayReferenceType, + ExternArgument, Identifier, IndexExpression, IntType, @@ -97,7 +99,7 @@ def process_classical_arg(cls, formal_arg, actual_arg, fn_name, fn_call): formal_arg, actual_arg, actual_arg_name, fn_name, fn_call ) - @classmethod # pylint: disable-next=too-many-arguments + @classmethod # pylint: disable-next=too-many-arguments,too-many-locals def _process_classical_arg_by_value( cls, formal_arg, actual_arg, actual_arg_name, fn_name, fn_call ): @@ -146,13 +148,53 @@ def _process_classical_arg_by_value( error_node=fn_call, span=fn_call.span, ) + + if isinstance(formal_arg, ExternArgument): + formal_arg_type = formal_arg.type + formal_arg_size = None + if hasattr(formal_arg_type, "size") and formal_arg_type.size is not None: + formal_arg_size, _ = Qasm3ExprEvaluator.evaluate_expression( + formal_arg_type.size + ) + actual_arg_var = cls.visitor_obj._scope_manager.get_from_global_scope( + actual_arg_name + ) + if ( + actual_arg_var.base_type != formal_arg_type + or actual_arg_var.base_size != formal_arg_size + ): + if formal_arg_size is not None: + raise_qasm3_error( + f"Argument type mismatch in function '{fn_name}', expected " + f"{type(formal_arg_type).__name__}[{formal_arg_size}] but got " + f"{type(actual_arg_var.base_type).__name__}" + f"[{actual_arg_var.base_size}]", + error_node=fn_call, + span=fn_call.span, + ) + actual_arg_value = Qasm3ExprEvaluator.evaluate_expression(actual_arg)[0] - # save this value to be updated later in scope + if isinstance(formal_arg, ExternArgument): + # Generate a unique name for the extern argument variable + _name = fn_name + while not cls.visitor_obj._scope_manager.check_in_scope(_name) and _name == fn_name: + _name = f"{fn_name}_{random.randint(1, 1_000_000_000)}" + if actual_arg_name: + _var = cls.visitor_obj._scope_manager.get_from_global_scope(actual_arg_name) + _base_size = _var.base_size + _base_type = _var.base_type + else: + _base_size = Qasm3ExprEvaluator.evaluate_expression(formal_arg.type.size)[0] + _base_type = formal_arg.type + else: + _name = formal_arg.name.name + _base_size = Qasm3ExprEvaluator.evaluate_expression(formal_arg.type.size)[0] + _base_type = formal_arg.type return Variable( - name=formal_arg.name.name, - base_type=formal_arg.type, - base_size=Qasm3ExprEvaluator.evaluate_expression(formal_arg.type.size)[0], + name=_name, + base_type=_base_type, + base_size=_base_size, dims=None, value=actual_arg_value, is_constant=False, diff --git a/src/pyqasm/visitor.py b/src/pyqasm/visitor.py index 2b80ec8f..a6e1c756 100644 --- a/src/pyqasm/visitor.py +++ b/src/pyqasm/visitor.py @@ -22,7 +22,7 @@ import logging from collections import OrderedDict, deque from functools import partial -from typing import Any, Callable, Optional, cast +from typing import Any, Callable, Optional, Sequence, cast import numpy as np import openqasm3.ast as qasm3_ast @@ -106,7 +106,9 @@ def __init__( # pylint: disable=too-many-arguments self._global_creg_size_map: dict[str, int] = {} self._custom_gates: dict[str, qasm3_ast.QuantumGateDefinition] = {} self._external_gates: list[str] = [] if external_gates is None else external_gates - self._subroutine_defns: dict[str, qasm3_ast.SubroutineDefinition] = {} + self._subroutine_defns: dict[ + str, qasm3_ast.SubroutineDefinition | qasm3_ast.ExternDeclaration + ] = {} self._check_only: bool = check_only self._unroll_barriers: bool = unroll_barriers self._recording_ext_gate_depth = False @@ -121,6 +123,7 @@ def __init__( # pylint: disable=too-many-arguments self._qubit_register_offsets: OrderedDict = OrderedDict() self._qubit_register_max_offset = 0 self._total_delay_duration_in_box = 0 + self._in_extern_function: bool = False self._scope_manager: ScopeManager = scope_manager @@ -449,6 +452,20 @@ def _handle_function_init_expression( return qasm3_ast.FloatLiteral(init_value) return None + def _handle_extern_function_cleanup( + self, statements: list, statement: qasm3_ast.Statement + ) -> None: + """Clean up extern function state and modify statements if needed. + + Args: + statements: List of statements to potentially modify + statement: The statement to append if in extern function + """ + if self._in_extern_function: + self._in_extern_function = False + statements.clear() + statements.append(statement) + def _visit_measurement( # pylint: disable=too-many-locals, too-many-branches self, statement: qasm3_ast.QuantumMeasurementStatement ) -> list[qasm3_ast.QuantumMeasurementStatement]: @@ -1347,7 +1364,10 @@ def _visit_constant_declaration( try: init_value, stmts = Qasm3ExprEvaluator.evaluate_expression( - statement.init_expression, const_expr=True, dt=self._module._device_cycle_time + statement.init_expression, + const_expr=True, + dt=self._module._device_cycle_time, + extern_fns=self._module._extern_functions, ) except ValidationError as err: raise_qasm3_error( @@ -1362,7 +1382,7 @@ def _visit_constant_declaration( base_type = statement.type base_size = self._check_variable_type_size(statement, var_name, "constant", base_type) angle_val_bit_string = None - if isinstance(base_type, qasm3_ast.AngleType): + if isinstance(base_type, qasm3_ast.AngleType) and not self._in_extern_function: init_value, angle_val_bit_string = PulseValidator.validate_angle_type_value( statement, init_value=init_value, @@ -1370,7 +1390,7 @@ def _visit_constant_declaration( compiler_angle_width=self._module._compiler_angle_type_size, ) val_type, _ = Qasm3ExprEvaluator.evaluate_expression( - statement.init_expression, validate_only=True + statement.init_expression, validate_only=True, extern_fns=self._module._extern_functions ) self._check_variable_cast_type(statement, val_type, var_name, base_type, base_size, True) variable = Variable( @@ -1392,9 +1412,10 @@ def _visit_constant_declaration( variable.time_unit = "ns" # cast + validation - variable.value = Qasm3Validator.validate_variable_assignment_value( - variable, init_value, op_node=statement - ) + if init_value is not None and not self._in_extern_function: + variable.value = Qasm3Validator.validate_variable_assignment_value( + variable, init_value, op_node=statement + ) self._scope_manager.add_var_in_scope(variable) @@ -1406,6 +1427,7 @@ def _visit_constant_declaration( self._handle_function_init_expression(statement.init_expression, init_value) or statement.init_expression ) + self._handle_extern_function_cleanup(statements, statement) if self._check_only: return [] @@ -1529,7 +1551,9 @@ def _visit_classical_declaration( else: try: init_value, stmts = Qasm3ExprEvaluator.evaluate_expression( - statement.init_expression, dt=self._module._device_cycle_time + statement.init_expression, + dt=self._module._device_cycle_time, + extern_fns=self._module._extern_functions, ) statements.extend(stmts) _req_type = ( @@ -1538,9 +1562,12 @@ def _visit_classical_declaration( else None ) val_type, _ = Qasm3ExprEvaluator.evaluate_expression( - statement.init_expression, validate_only=True, reqd_type=_req_type + statement.init_expression, + validate_only=True, + reqd_type=_req_type, + extern_fns=self._module._extern_functions, ) - if isinstance(base_type, qasm3_ast.AngleType): + if isinstance(base_type, qasm3_ast.AngleType) and not self._in_extern_function: init_value, angle_val_bit_string = PulseValidator.validate_angle_type_value( statement, init_value=init_value, @@ -1593,9 +1620,10 @@ def _visit_classical_declaration( ) else: try: - variable.value = Qasm3Validator.validate_variable_assignment_value( - variable, init_value, op_node=statement - ) + if init_value is not None and not self._in_extern_function: + variable.value = Qasm3Validator.validate_variable_assignment_value( + variable, init_value, op_node=statement + ) except ValidationError as err: raise_qasm3_error( f"Invalid initialization value for variable '{var_name}'", @@ -1622,6 +1650,8 @@ def _visit_classical_declaration( statements.append(statement) self._module._add_classical_register(var_name, base_size) + self._handle_extern_function_cleanup(statements, statement) + if isinstance(base_type, qasm3_ast.ComplexType) and isinstance(variable.value, complex): statement.init_expression = PulseValidator.make_complex_binary_expression(init_value) @@ -1696,10 +1726,12 @@ def _visit_classical_assignment( lhs=lvalue, op=binary_op, rhs=rvalue # type: ignore[arg-type] ) rvalue_raw, rhs_stmts = Qasm3ExprEvaluator.evaluate_expression( - rvalue, dt=self._module._device_cycle_time + rvalue, dt=self._module._device_cycle_time, extern_fns=self._module._extern_functions ) # consists of scope check and index validation statements.extend(rhs_stmts) - val_type, _ = Qasm3ExprEvaluator.evaluate_expression(rvalue, validate_only=True) + val_type, _ = Qasm3ExprEvaluator.evaluate_expression( + rvalue, validate_only=True, extern_fns=self._module._extern_functions + ) self._check_variable_cast_type( statement, val_type, @@ -1709,7 +1741,7 @@ def _visit_classical_assignment( False, ) angle_val_bit_string = None - if isinstance(lvar_base_type, qasm3_ast.AngleType): + if isinstance(lvar_base_type, qasm3_ast.AngleType) and not self._in_extern_function: rvalue_raw, angle_val_bit_string = PulseValidator.validate_angle_type_value( statement, init_value=rvalue_raw, @@ -1723,9 +1755,10 @@ def _visit_classical_assignment( rvalue_eval = None if not isinstance(rvalue_raw, np.ndarray): # rhs is a scalar - rvalue_eval = Qasm3Validator.validate_variable_assignment_value( - lvar, rvalue_raw, op_node=statement # type: ignore[arg-type] - ) + if rvalue_raw is not None and not self._in_extern_function: + rvalue_eval = Qasm3Validator.validate_variable_assignment_value( + lvar, rvalue_raw, op_node=statement # type: ignore[arg-type] + ) else: # rhs is a list rvalue_dimensions = list(rvalue_raw.shape) @@ -1787,6 +1820,8 @@ def _visit_classical_assignment( or statement.rvalue ) + self._handle_extern_function_cleanup(statements, statement) + if self._check_only: return [] @@ -2041,7 +2076,9 @@ def _visit_forin_loop(self, statement: qasm3_ast.ForInLoop) -> list[qasm3_ast.St return [] return result - def _visit_subroutine_definition(self, statement: qasm3_ast.SubroutineDefinition) -> list[None]: + def _visit_subroutine_definition( + self, statement: qasm3_ast.SubroutineDefinition | qasm3_ast.ExternDeclaration + ) -> Sequence[None | qasm3_ast.ExternDeclaration]: """Visit a subroutine definition element. Reference: https://openqasm.com/language/subroutines.html#subroutines @@ -2052,6 +2089,7 @@ def _visit_subroutine_definition(self, statement: qasm3_ast.SubroutineDefinition None """ fn_name = statement.name.name + statements = [] if fn_name in CONSTANTS_MAP: raise_qasm3_error( @@ -2073,14 +2111,26 @@ def _visit_subroutine_definition(self, statement: qasm3_ast.SubroutineDefinition span=statement.span, ) + if isinstance(statement, qasm3_ast.ExternDeclaration): + if statement.name.name in self._module._extern_functions: + PulseValidator.validate_extern_declaration(self._module, statement) + self._module._extern_functions[statement.name.name] = ( + statement.arguments, + statement.return_type, + ) + + statements.append(statement) + self._subroutine_defns[fn_name] = statement + if self._check_only: + return [] - return [] + return statements # pylint: disable=too-many-locals, too-many-statements def _visit_function_call( self, statement: qasm3_ast.FunctionCall - ) -> tuple[Any, list[qasm3_ast.Statement]]: + ) -> tuple[Any | None, list[qasm3_ast.Statement | qasm3_ast.FunctionCall]]: """Visit a function call element. Args: @@ -2113,7 +2163,7 @@ def _visit_function_call( quantum_vars, classical_vars = [], [] for actual_arg, formal_arg in zip(statement.arguments, subroutine_def.arguments): - if isinstance(formal_arg, qasm3_ast.ClassicalArgument): + if isinstance(formal_arg, (qasm3_ast.ClassicalArgument, qasm3_ast.ExternArgument)): classical_vars.append( Qasm3SubroutineProcessor.process_classical_arg( formal_arg, actual_arg, fn_name, statement @@ -2147,25 +2197,36 @@ def _visit_function_call( self._function_qreg_transform_map.append(qubit_transform_map) return_statement = None - result = [] - for function_op in subroutine_def.body: - if isinstance(function_op, qasm3_ast.ReturnStatement): - return_statement = copy.copy(function_op) - break - try: - result.extend(self.visit_statement(copy.copy(function_op))) - except (TypeError, copy.Error): - result.extend(self.visit_statement(copy.deepcopy(function_op))) - return_value = None - if return_statement: - return_value, stmts = Qasm3ExprEvaluator.evaluate_expression( - return_statement.expression - ) - return_value = Qasm3Validator.validate_return_statement( - subroutine_def, return_statement, return_value + result: list[qasm3_ast.Statement | qasm3_ast.FunctionCall] = [] + if isinstance(subroutine_def, qasm3_ast.ExternDeclaration): + self._in_extern_function = True + global_scope = self._scope_manager.get_global_scope() + result.append( + PulseValidator.validate_and_process_extern_function_call( + statement, global_scope, self._module._device_cycle_time + ) + if not self._check_only + else statement ) - result.extend(stmts) + else: + for function_op in subroutine_def.body: + if isinstance(function_op, qasm3_ast.ReturnStatement): + return_statement = copy.copy(function_op) + break + try: + result.extend(self.visit_statement(copy.copy(function_op))) + except (TypeError, copy.Error): + result.extend(self.visit_statement(copy.deepcopy(function_op))) + + if return_statement: + return_value, stmts = Qasm3ExprEvaluator.evaluate_expression( + return_statement.expression, extern_fns=self._module._extern_functions + ) + return_value = Qasm3Validator.validate_return_statement( + subroutine_def, return_statement, return_value + ) + result.extend(stmts) # remove qubit transformation map self._function_qreg_transform_map.pop() @@ -2175,7 +2236,7 @@ def _visit_function_call( self._scope_manager.decrement_scope_level() self._scope_manager.pop_scope() - if self._check_only: + if self._check_only and not self._in_extern_function: return return_value, [] return return_value, result @@ -2627,6 +2688,7 @@ def visit_statement(self, statement: qasm3_ast.Statement) -> list[qasm3_ast.Stat qasm3_ast.AliasStatement: self._visit_alias_statement, qasm3_ast.SwitchStatement: self._visit_switch_statement, qasm3_ast.SubroutineDefinition: self._visit_subroutine_definition, + qasm3_ast.ExternDeclaration: self._visit_subroutine_definition, qasm3_ast.ExpressionStatement: lambda x: self._visit_function_call(x.expression), qasm3_ast.IODeclaration: lambda x: [], qasm3_ast.BreakStatement: self._visit_break, @@ -2689,4 +2751,5 @@ def finalize(self, unrolled_stmts): if isinstance(stmt, qasm3_ast.QuantumPhase): if len(stmt.qubits) == len(self._qubit_labels): stmt.qubits = [] + Qasm3ExprEvaluator.angle_var_in_expr = None return unrolled_stmts diff --git a/tests/qasm3/subroutines/test_subroutines.py b/tests/qasm3/subroutines/test_subroutines.py index c9e5d5e7..f1c200fe 100644 --- a/tests/qasm3/subroutines/test_subroutines.py +++ b/tests/qasm3/subroutines/test_subroutines.py @@ -19,10 +19,14 @@ import pytest -from pyqasm.entrypoint import loads +from pyqasm.entrypoint import dumps, loads from pyqasm.exceptions import ValidationError from tests.qasm3.resources.subroutines import SUBROUTINE_INCORRECT_TESTS -from tests.utils import check_single_qubit_gate_op, check_single_qubit_rotation_op +from tests.utils import ( + check_single_qubit_gate_op, + check_single_qubit_rotation_op, + check_unrolled_qasm, +) def test_function_declaration(): @@ -421,3 +425,225 @@ def test_incorrect_custom_ops(test_name, caplog): assert f"Error at line {line_num}, column {col_num}" in caplog.text assert err_line in caplog.text + + +def test_extern_function_call(): + """Test extern function call""" + qasm3_string = """ + OPENQASM 3.0; + include "stdgates.inc"; + float a = 1.0; + int b = 2; + extern func1(float, int) -> bit; + bit c = 2 * func1(a, b); + bit fc = -func1(a, b); + + bit[2] b1 = true; + angle ang1 = pi/2; + extern func2(bit[2], angle) -> complex; + const complex d = func2(b1, ang1); + const complex e = func2(b1, ang1) + 2.0; + const complex f = -func2(b1, ang1); + + duration t1 = 100ns; + bool b2 = true; + extern func3(duration, bool) -> int; + int dd; + dd = func3(t1, b2); + int ee; + ee = func3(t1, b2) + 2; + int ff; + ff = -func3(t1, b2); + int gg; + gg = func3(t1, b2) * func3(t1, b2); + + float[32] fa = 3.14; + float[32] fb = 2.71; + extern func4(float[32], float[32]) -> float[32]; + float[32] fc1 = func4(fa, fb); + + complex[float[64]] ca = 1.0 + 2.0im; + complex[float[64]] cb = 3.0 - 4.0im; + extern func5(complex[float[64]], complex[float[64]]) -> complex[float[64]]; + complex[float[64]] cc1 = func5(ca, cb); + + bit[4] bd = true; + extern func6(bit[4]) -> bit[4]; + bit[4] be1 = func6(bd); + + angle[8] an = pi/4; + extern func7(angle[8]) -> angle[8]; + angle[8] af1 = func7(an); + + bool bl = false; + extern func8(bool) -> bool; + bool bf1 = func8(bl); + + int[24] ix = 42; + extern func9(int[24]) -> int[24]; + int[24] ig1 = func9(ix); + + float[64] fx = 2.718; + extern func10(float[64]) -> float[64]; + float[64] fg1 = func10(fx); + """ + + expected_qasm = """OPENQASM 3.0; + include "stdgates.inc"; + extern func1(float, int) -> bit; + bit[1] c = 2 * func1(1.0, 2); + bit[1] fc = -func1(1.0, 2); + bit[2] b1 = true; + extern func2(bit[2], angle) -> complex; + const complex d = func2({true, true}, 1.5707963267948966); + const complex e = func2({true, true}, 1.5707963267948966) + 2.0; + const complex f = -func2({true, true}, 1.5707963267948966); + extern func3(duration, bool) -> int; + dd = func3(100.0ns, True); + ee = func3(100.0ns, True) + 2; + ff = -func3(100.0ns, True); + gg = func3(100.0ns, True) * func3(100.0ns, True); + extern func4(float[32], float[32]) -> float[32]; + float[32] fc1 = func4(3.14, 2.71); + extern func5(complex[float[64]], complex[float[64]]) -> complex[float[64]]; + complex[float[64]] cc1 = func5(1.0 + 2.0im, 3.0 - 4.0im); + bit[4] bd = true; + extern func6(bit[4]) -> bit[4]; + bit[4] be1 = func6({true, true, true, true}); + extern func7(angle[8]) -> angle[8]; + angle[8] af1 = func7(0.7853981633974483); + extern func8(bool) -> bool; + bool bf1 = func8(False); + extern func9(int[24]) -> int[24]; + int[24] ig1 = func9(42); + extern func10(float[64]) -> float[64]; + float[64] fg1 = func10(2.718); + """ + + extern_functions = { + "func1": (["float", "int"], "bit"), + "func2": (["bit[2]", "angle"], "complex"), + "func3": (["duration", "bool"], "int"), + "func4": (["float[32]", "float[32]"], "float[32]"), + "func5": (["complex[float[64]]", "complex[float[64]]"], "complex[float[64]]"), + "func6": (["bit[4]"], "bit[4]"), + "func7": (["angle[8]"], "angle[8]"), + "func8": (["bool"], "bool"), + "func9": (["int[24]"], "int[24]"), + "func10": (["float[64]"], "float[64]"), + } + + result = loads(qasm3_string, extern_functions=extern_functions) + result.unroll() + unrolled_qasm = dumps(result) + + check_unrolled_qasm(unrolled_qasm, expected_qasm) + + +@pytest.mark.parametrize( + "qasm_code,error_message,error_span", + [ + ( + """ + OPENQASM 3.0; + include "stdgates.inc"; + bit[4] bd = true; + float[64] fx = 2.718; + extern func6(bit[4]) -> bit[4]; + extern func10(float[64]) -> float[64]; + bit[4] be1 = func6(bd) * func10(fx); + """, + r"extern function return type mismatch in binary expression: BitType and FloatType", + r"Error at line 8, column 25", + ), + ( + """ + OPENQASM 3.0; + include "stdgates.inc"; + bit bd = true; + extern func6(bit[4]) -> bit[4]; + bit[4] be1 = func6(bd); + """, + r"Argument type mismatch in function 'func6', expected BitType[4] but got BitType[1]", + r"Error at line 6, column 25", + ), + ( + """ + OPENQASM 3.0; + include "stdgates.inc"; + bit bd = true; + extern func6(bit[4]) -> bit[4]; + bit[4] be1 = func6(fd); + """, + r"Undefined variable 'fd' used for function call 'func6'", + r"Error at line 6, column 25", + ), + ], +) # pylint: disable-next= too-many-arguments +def test_extern_function_call_error(qasm_code, error_message, error_span, caplog): + with pytest.raises(ValidationError) as excinfo: + with caplog.at_level("ERROR"): + loads( + qasm_code, + ).validate() + first = excinfo.value.__cause__ or excinfo.value.__context__ + assert first is not None, "Expected a chained ValidationError" + msg = str(first) + assert error_message in msg + assert error_span in caplog.text + + +@pytest.mark.parametrize( + "qasm_code,error_message,error_span", + [ + ( + """ + OPENQASM 3.0; + include "stdgates.inc"; + float fx = 2.718; + int ix = 42; + extern func1(float) -> bit; + bit be1 = func1(fx); + """, + r"Parameter count mismatch for 'extern' subroutine 'func1'. Expected 2 but got 1", + r"Error at line 6, column 12", + ), + ( + """ + OPENQASM 3.0; + include "stdgates.inc"; + float fx = 2.718; + int ix = 42; + extern func1(float[64], int[64]) -> bit; + bit be1 = func1(fx); + """, + r"Parameter type mismatch for 'extern' subroutine 'func1'." + r" Expected float but got float[64]", + r"Error at line 6, column 12", + ), + ( + """ + OPENQASM 3.0; + include "stdgates.inc"; + float fx = 2.718; + int ix = 42; + extern func1(float, int) -> bit[2]; + bit[2] be1 = func1(fx); + """, + r"Return type mismatch for 'extern' subroutine 'func1'. Expected bit but got bit[2]", + r"Error at line 6, column 12", + ), + ], +) # pylint: disable-next= too-many-arguments +def test_extern_function_dict_call_erro(qasm_code, error_message, error_span, caplog): + with pytest.raises(ValidationError) as excinfo: + with caplog.at_level("ERROR"): + extern_functions = { + "func1": (["float", "int"], "bit"), + "func2": (["bit[2]", "angle"], "complex"), + "func3": (["uint"], "int"), + } + loads(qasm_code, extern_functions=extern_functions).validate() + msg = str(excinfo.value) + assert error_message in msg + assert error_span in caplog.text From eb01ca628376704fcea56b88b7976972d144f8e5 Mon Sep 17 00:00:00 2001 From: vinayswamik Date: Tue, 5 Aug 2025 20:49:25 -0500 Subject: [PATCH 09/12] Enhance QASM type validation and bitstring handling - Updated `Qasm3ExprEvaluator` to include support for `BitstringLiteral` in expression evaluations. - Improved `Qasm3SubroutineProcessor` to validate argument types, including checks for `BitType` and `AngleType`. - Added a new method in `QasmVisitor` to validate the width of bitstring literals. - Adjusted `PulseValidator` to handle string representations of bitstrings. - Updated tests to reflect changes in bitstring handling and validation logic. --- src/pyqasm/expressions.py | 5 +- src/pyqasm/maps/expressions.py | 4 +- src/pyqasm/pulse/validator.py | 25 ++-- src/pyqasm/subroutines.py | 74 ++++++++++-- src/pyqasm/visitor.py | 20 +++- tests/qasm3/subroutines/test_subroutines.py | 120 ++++++++++++++++++-- 6 files changed, 211 insertions(+), 37 deletions(-) diff --git a/src/pyqasm/expressions.py b/src/pyqasm/expressions.py index 8aa965b2..d38a9683 100644 --- a/src/pyqasm/expressions.py +++ b/src/pyqasm/expressions.py @@ -338,7 +338,7 @@ def _get_external_function_return_type(expression, extern_fns): ) return _check_and_return_value(dimensions[index]) - if isinstance(expression, (BooleanLiteral, IntegerLiteral, FloatLiteral, BitstringLiteral)): + if isinstance(expression, (BooleanLiteral, IntegerLiteral, FloatLiteral)): if reqd_type: if reqd_type == BoolType and isinstance(expression, BooleanLiteral): return _check_and_return_value(expression.value) @@ -357,6 +357,9 @@ def _get_external_function_return_type(expression, extern_fns): ) return _check_and_return_value(expression.value) + if isinstance(expression, BitstringLiteral): + return _check_and_return_value(format(expression.value, f"0{expression.width}b")) + if isinstance(expression, DurationLiteral): unit_name = expression.unit.name if dt: diff --git a/src/pyqasm/maps/expressions.py b/src/pyqasm/maps/expressions.py index d6be70e1..7a6eef2e 100644 --- a/src/pyqasm/maps/expressions.py +++ b/src/pyqasm/maps/expressions.py @@ -124,7 +124,7 @@ def qasm_variable_type_cast(openqasm_type, var_name, base_size, rhs_value): # not sure if we wanna hande array bit assignments too. # For now, we only cater to single bit assignment. if openqasm_type == BitType: - return rhs_value != 0 + return rhs_value if openqasm_type == AngleType: if isinstance(rhs_value, bool): return ((2 * CONSTANTS_MAP["pi"]) * (1 / 2)) if rhs_value else 0.0 @@ -162,7 +162,7 @@ def qasm_variable_type_cast(openqasm_type, var_name, base_size, rhs_value): VARIABLE_TYPE_CAST_MAP = { BoolType: (int, float, bool, np.int64, np.float64, np.bool_), IntType: (bool, int, float, np.int64, np.float64, np.bool_), - BitType: (bool, int, np.int64, np.bool_), + BitType: (bool, int, np.int64, np.bool_, str), UintType: (bool, int, float, np.int64, np.uint64, np.float64, np.bool_), FloatType: (bool, int, float, np.int64, np.float64, np.bool_), AngleType: (float, np.float64, bool, np.bool_), diff --git a/src/pyqasm/pulse/validator.py b/src/pyqasm/pulse/validator.py index c1b21720..28d3a78a 100644 --- a/src/pyqasm/pulse/validator.py +++ b/src/pyqasm/pulse/validator.py @@ -20,11 +20,9 @@ import numpy as np from openqasm3.ast import ( - ArrayLiteral, BinaryExpression, BinaryOperator, BitstringLiteral, - BitType, BooleanLiteral, Box, Cast, @@ -32,7 +30,6 @@ DelayInstruction, DurationLiteral, DurationType, - Expression, ExternDeclaration, FloatLiteral, FunctionCall, @@ -382,25 +379,13 @@ def validate_and_process_extern_function_call( # pylint: disable=too-many-branc assert arg_var is not None - if arg_var.base_type is not None and isinstance(arg_var.base_type, DurationType): + if arg_var.base_type is not None and isinstance( + arg_var.base_type, (DurationType, StretchType) + ): statement.arguments[i] = DurationLiteral( float(arg_var.value) if arg_var.value is not None else 0.0, unit=(TimeUnit.dt if device_cycle_time else TimeUnit.ns), ) - elif ( - arg_var.base_type is not None - and isinstance(arg_var.base_type, BitType) - and (isinstance(arg_var.dims, list) or isinstance(arg_var.value, np.ndarray)) - ): - array_list: list[Expression] = [] - if isinstance(arg_var.value, np.ndarray): - for val in arg_var.value: - array_list.append(BooleanLiteral(bool(val))) - else: - if arg_var.dims is not None and len(arg_var.dims) > 0: - for _ in range(arg_var.dims[0]): - array_list.append(BooleanLiteral(bool(arg_var.value))) - statement.arguments[i] = ArrayLiteral(array_list) elif isinstance(arg_var.value, float): statement.arguments[i] = FloatLiteral(arg_var.value) elif isinstance(arg_var.value, int): @@ -411,6 +396,10 @@ def validate_and_process_extern_function_call( # pylint: disable=too-many-branc statement.arguments[i] = PulseValidator.make_complex_binary_expression( arg_var.value ) + elif isinstance(arg_var.value, str): + width = len(arg_var.value) + value = int(arg_var.value, 2) + statement.arguments[i] = BitstringLiteral(value, width) else: raise_qasm3_error( f"Invalid argument type '{arg_var.base_type}' for extern function " diff --git a/src/pyqasm/subroutines.py b/src/pyqasm/subroutines.py index a3703aa0..47ac6b0c 100644 --- a/src/pyqasm/subroutines.py +++ b/src/pyqasm/subroutines.py @@ -19,15 +19,26 @@ import random from typing import Optional +import numpy as np from openqasm3.ast import ( AccessControl, + AngleType, ArrayReferenceType, + BitstringLiteral, + BitType, + BooleanLiteral, + BoolType, + ComplexType, + DurationLiteral, + DurationType, ExternArgument, + FloatType, Identifier, IndexExpression, IntType, QASMNode, QubitDeclaration, + UintType, ) from openqasm3.printer import dumps @@ -99,7 +110,7 @@ def process_classical_arg(cls, formal_arg, actual_arg, fn_name, fn_call): formal_arg, actual_arg, actual_arg_name, fn_name, fn_call ) - @classmethod # pylint: disable-next=too-many-arguments,too-many-locals + @classmethod # pylint: disable-next=too-many-arguments,too-many-locals,too-many-branches def _process_classical_arg_by_value( cls, formal_arg, actual_arg, actual_arg_name, fn_name, fn_call ): @@ -180,13 +191,62 @@ def _process_classical_arg_by_value( _name = fn_name while not cls.visitor_obj._scope_manager.check_in_scope(_name) and _name == fn_name: _name = f"{fn_name}_{random.randint(1, 1_000_000_000)}" - if actual_arg_name: - _var = cls.visitor_obj._scope_manager.get_from_global_scope(actual_arg_name) - _base_size = _var.base_size - _base_type = _var.base_type - else: + if hasattr(formal_arg.type, "size") and formal_arg.type.size is not None: _base_size = Qasm3ExprEvaluator.evaluate_expression(formal_arg.type.size)[0] - _base_type = formal_arg.type + else: + _base_size = None + _base_type = formal_arg.type + + # pylint: disable=too-many-boolean-expressions + if actual_arg_value and actual_arg_name is None: + if ( + ( + isinstance(_base_type, UintType) + and (not isinstance(actual_arg_value, int) or actual_arg_value < 0) + ) + or ( + isinstance(_base_type, AngleType) + and ( + not isinstance(actual_arg_value, float) + or not 0 <= actual_arg_value <= 2 * np.pi + ) + ) + or ( + isinstance(_base_type, FloatType) + and not isinstance(actual_arg_value, float) + ) + or ( + isinstance(_base_type, ComplexType) + and not isinstance(actual_arg_value, complex) + ) + or (isinstance(_base_type, IntType) and not isinstance(actual_arg_value, int)) + or ( + isinstance(_base_type, DurationType) + and not isinstance(actual_arg, DurationLiteral) + ) + or ( + isinstance(_base_type, BoolType) + and ( + not isinstance(actual_arg_value, bool) + and actual_arg_value not in (0, 1) + ) + ) + or ( + isinstance(_base_type, BitType) + and ( + not isinstance(actual_arg, (BitstringLiteral, BooleanLiteral)) + and not isinstance(actual_arg_value, (bool, int)) + ) + ) + ): + raise_qasm3_error( + f"Invalid argument value for '{fn_name}', expected " + f"'{type(_base_type).__name__}' but got value = " + f"{actual_arg_value}.", + error_node=fn_call, + span=fn_call.span, + ) + else: _name = formal_arg.name.name _base_size = Qasm3ExprEvaluator.evaluate_expression(formal_arg.type.size)[0] diff --git a/src/pyqasm/visitor.py b/src/pyqasm/visitor.py index a6e1c756..013d3b8f 100644 --- a/src/pyqasm/visitor.py +++ b/src/pyqasm/visitor.py @@ -466,6 +466,15 @@ def _handle_extern_function_cleanup( statements.clear() statements.append(statement) + def _validate_bitstring_literal_width(self, init_value, base_size, var_name, statement): + if len(init_value) != base_size: + raise_qasm3_error( + f"Invalid bitstring literal '{init_value}' width [{len(init_value)}] " + f"for variable '{var_name}' of size [{base_size}]", + error_node=statement, + span=statement.span, + ) + def _visit_measurement( # pylint: disable=too-many-locals, too-many-branches self, statement: qasm3_ast.QuantumMeasurementStatement ) -> list[qasm3_ast.QuantumMeasurementStatement]: @@ -1404,6 +1413,9 @@ def _visit_constant_declaration( angle_bit_string=angle_val_bit_string, ) + if isinstance(base_type, qasm3_ast.BitType) and isinstance(init_value, str): + self._validate_bitstring_literal_width(init_value, base_size, var_name, statement) + if isinstance(base_type, (qasm3_ast.DurationType, qasm3_ast.StretchType)): PulseValidator.validate_duration_literal_value(init_value, statement, base_type) if self._module._device_cycle_time: @@ -1584,6 +1596,8 @@ def _visit_classical_declaration( span=statement.span, raised_from=err, ) + if isinstance(base_type, qasm3_ast.BitType) and isinstance(init_value, str): + self._validate_bitstring_literal_width(init_value, base_size, var_name, statement) variable = Variable( var_name, @@ -1740,6 +1754,10 @@ def _visit_classical_assignment( lvar.base_size, # type: ignore[union-attr] False, ) + if isinstance(lvar_base_type, qasm3_ast.BitType) and isinstance(rvalue_raw, str): + self._validate_bitstring_literal_width( + rvalue_raw, lvar.base_size, lvar_name, statement # type: ignore[union-attr] + ) angle_val_bit_string = None if isinstance(lvar_base_type, qasm3_ast.AngleType) and not self._in_extern_function: rvalue_raw, angle_val_bit_string = PulseValidator.validate_angle_type_value( @@ -2206,8 +2224,6 @@ def _visit_function_call( PulseValidator.validate_and_process_extern_function_call( statement, global_scope, self._module._device_cycle_time ) - if not self._check_only - else statement ) else: for function_op in subroutine_def.body: diff --git a/tests/qasm3/subroutines/test_subroutines.py b/tests/qasm3/subroutines/test_subroutines.py index f1c200fe..fd3fac09 100644 --- a/tests/qasm3/subroutines/test_subroutines.py +++ b/tests/qasm3/subroutines/test_subroutines.py @@ -467,7 +467,7 @@ def test_extern_function_call(): extern func5(complex[float[64]], complex[float[64]]) -> complex[float[64]]; complex[float[64]] cc1 = func5(ca, cb); - bit[4] bd = true; + bit[4] bd = "0101"; extern func6(bit[4]) -> bit[4]; bit[4] be1 = func6(bd); @@ -495,9 +495,9 @@ def test_extern_function_call(): bit[1] fc = -func1(1.0, 2); bit[2] b1 = true; extern func2(bit[2], angle) -> complex; - const complex d = func2({true, true}, 1.5707963267948966); - const complex e = func2({true, true}, 1.5707963267948966) + 2.0; - const complex f = -func2({true, true}, 1.5707963267948966); + const complex d = func2(True, 1.5707963267948966); + const complex e = func2(True, 1.5707963267948966) + 2.0; + const complex f = -func2(True, 1.5707963267948966); extern func3(duration, bool) -> int; dd = func3(100.0ns, True); ee = func3(100.0ns, True) + 2; @@ -507,9 +507,9 @@ def test_extern_function_call(): float[32] fc1 = func4(3.14, 2.71); extern func5(complex[float[64]], complex[float[64]]) -> complex[float[64]]; complex[float[64]] cc1 = func5(1.0 + 2.0im, 3.0 - 4.0im); - bit[4] bd = true; + bit[4] bd = "0101"; extern func6(bit[4]) -> bit[4]; - bit[4] be1 = func6({true, true, true, true}); + bit[4] be1 = func6("0101"); extern func7(angle[8]) -> angle[8]; angle[8] af1 = func7(0.7853981633974483); extern func8(bool) -> bool; @@ -635,7 +635,7 @@ def test_extern_function_call_error(qasm_code, error_message, error_span, caplog ), ], ) # pylint: disable-next= too-many-arguments -def test_extern_function_dict_call_erro(qasm_code, error_message, error_span, caplog): +def test_extern_function_dict_call_error(qasm_code, error_message, error_span, caplog): with pytest.raises(ValidationError) as excinfo: with caplog.at_level("ERROR"): extern_functions = { @@ -647,3 +647,109 @@ def test_extern_function_dict_call_erro(qasm_code, error_message, error_span, ca msg = str(excinfo.value) assert error_message in msg assert error_span in caplog.text + + +@pytest.mark.parametrize( + "qasm_code,error_message,error_span", + [ + ( + """ + OPENQASM 3.0; + include "stdgates.inc"; + extern func1(float) -> bit; + bit be1 = func1(2); + """, + r"Invalid argument value for 'func1', expected 'FloatType' but got value = 2.", + r"Error at line 5, column 12", + ), + ( + """ + OPENQASM 3.0; + include "stdgates.inc"; + extern func2(uint) -> complex; + complex ce1 = func2(-22); + """, + r"Invalid argument value for 'func2', expected 'UintType' but got value = -22.", + r"Error at line 5, column 12", + ), + ( + """ + OPENQASM 3.0; + include "stdgates.inc"; + extern func3(duration) -> int; + int ie1 = func3(true); + """, + r"Invalid argument value for 'func3', expected 'DurationType' but got value = True.", + r"Error at line 5, column 12", + ), + ( + """ + OPENQASM 3.0; + include "stdgates.inc"; + extern func4(bit[4]) -> float[64]; + float[64] fe1 = func4(3.14); + """, + r"Invalid argument value for 'func4', expected 'BitType' but got value = 3.14.", + r"Error at line 5, column 28", + ), + ( + """ + OPENQASM 3.0; + include "stdgates.inc"; + extern func5(complex[float[64]]) -> complex[float[64]]; + complex[float[64]] ce1 = func5(42); + """, + r"Invalid argument value for 'func5', expected 'ComplexType' but got value = 42.", + r"Error at line 5, column 37", + ), + ( + """ + OPENQASM 3.0; + include "stdgates.inc"; + extern func6(angle[8]) -> angle[8]; + angle[8] ae1 = func6(100ns); + """, + r"Invalid argument value for 'func6', expected 'AngleType' but got value = 100.0.", + r"Error at line 5, column 27", + ), + ( + """ + OPENQASM 3.0; + include "stdgates.inc"; + extern func7(bool) -> bool; + bool be1 = func7(3.14); + """, + r"Invalid argument value for 'func7', expected 'BoolType' but got value = 3.14.", + r"Error at line 5, column 23", + ), + ( + """ + OPENQASM 3.0; + include "stdgates.inc"; + extern func8(int[24]) -> int[24]; + int[24] ie1 = func8(2.0); + """, + r"Invalid argument value for 'func8', expected 'IntType' but got value = 2.0.", + r"Error at line 5, column 12", + ), + ], +) # pylint: disable-next= too-many-arguments +def test_extern_function_value_error(qasm_code, error_message, error_span, caplog): + with pytest.raises(ValidationError) as excinfo: + with caplog.at_level("ERROR"): + extern_functions = { + "func1": (["float"], "bit"), + "func2": (["uint"], "complex"), + "func3": (["duration"], "int"), + "func4": (["bit[4]"], "float[64]"), + "func5": (["complex[float[64]]"], "complex[float[64]]"), + "func6": (["angle[8]"], "angle[8]"), + "func7": (["bool"], "bool"), + "func8": (["int[24]"], "int[24]"), + } + loads(qasm_code, extern_functions=extern_functions).validate() + first = excinfo.value.__cause__ or excinfo.value.__context__ + assert first is not None, "Expected a chained ValidationError" + msg = str(first) + assert error_message in msg + assert error_span in caplog.text From 29e2a3d96dcb8991037066599721bd2d7889955d Mon Sep 17 00:00:00 2001 From: vinayswamik Date: Wed, 6 Aug 2025 22:10:15 -0500 Subject: [PATCH 10/12] code refactor --- src/pyqasm/expressions.py | 5 ---- src/pyqasm/pulse/validator.py | 33 ++++----------------- src/pyqasm/subroutines.py | 3 +- src/pyqasm/visitor.py | 2 -- tests/qasm3/declarations/test_classical.py | 26 ++++++++++++++++ tests/qasm3/subroutines/test_subroutines.py | 1 + 6 files changed, 34 insertions(+), 36 deletions(-) diff --git a/src/pyqasm/expressions.py b/src/pyqasm/expressions.py index d38a9683..2bc8a27b 100644 --- a/src/pyqasm/expressions.py +++ b/src/pyqasm/expressions.py @@ -273,9 +273,6 @@ def _get_external_function_return_type(expression, extern_fns): return extern_fns[expression.name.name][1] return None - if isinstance(expression, complex): - return _check_and_return_value(expression) - if isinstance(expression, ImaginaryLiteral): return _check_and_return_value(expression.value * 1j) @@ -489,8 +486,6 @@ def _get_external_function_return_type(expression, extern_fns): _val, _ = cls.evaluate_expression( expression.arguments[0], const_expr, reqd_type, validate_only ) - if _val is None or validate_only: - return (None, statements) _val = FUNCTION_MAP[expression.name.name](_val) # type: ignore return _check_and_return_value(_val) diff --git a/src/pyqasm/pulse/validator.py b/src/pyqasm/pulse/validator.py index 28d3a78a..23a7c1c1 100644 --- a/src/pyqasm/pulse/validator.py +++ b/src/pyqasm/pulse/validator.py @@ -23,7 +23,6 @@ BinaryExpression, BinaryOperator, BitstringLiteral, - BooleanLiteral, Box, Cast, ConstantDeclaration, @@ -313,15 +312,10 @@ def validate_extern_declaration(module: Any, statement: ExternDeclaration) -> No def _get_type_string(type_obj) -> str: """Recursively build type string for nested types""" type_name = type(type_obj).__name__.replace("Type", "").lower() - if hasattr(type_obj, "base_type") and type_obj.base_type is not None: - base_type_str = _get_type_string(type_obj.base_type) - if hasattr(type_obj, "size") and type_obj.size is not None: - size_val = type_obj.size.value - return f"{type_name}[{size_val}][{base_type_str}]" - return f"{type_name}[{base_type_str}]" - if hasattr(type_obj, "size") and type_obj.size is not None: - size_val = type_obj.size.value - return f"{type_name}[{size_val}]" + if getattr(type_obj, "base_type", None) is not None: + return f"{type_name}[{_get_type_string(type_obj.base_type)}]" + if getattr(type_obj, "size", None) is not None: + return f"{type_name}[{type_obj.size.value}]" return type_name for actual_arg, extern_arg in zip(statement.arguments, args): @@ -365,18 +359,10 @@ def validate_and_process_extern_function_call( # pylint: disable=too-many-branc Raises: ValidationError: If the function call is invalid """ - fn_name = statement.name.name - # pylint: disable=too-many-nested-blocks + for i, arg in enumerate(statement.arguments): if isinstance(arg, Identifier): arg_var = global_scope.get(arg.name) - if arg_var is None: - raise_qasm3_error( - f"Undefined variable '{arg.name}' in extern function '{fn_name}'", - error_node=statement, - span=statement.span, - ) - assert arg_var is not None if arg_var.base_type is not None and isinstance( @@ -390,8 +376,6 @@ def validate_and_process_extern_function_call( # pylint: disable=too-many-branc statement.arguments[i] = FloatLiteral(arg_var.value) elif isinstance(arg_var.value, int): statement.arguments[i] = IntegerLiteral(arg_var.value) - elif isinstance(arg_var.value, bool): - statement.arguments[i] = BooleanLiteral(arg_var.value) elif isinstance(arg_var.value, complex): statement.arguments[i] = PulseValidator.make_complex_binary_expression( arg_var.value @@ -400,12 +384,5 @@ def validate_and_process_extern_function_call( # pylint: disable=too-many-branc width = len(arg_var.value) value = int(arg_var.value, 2) statement.arguments[i] = BitstringLiteral(value, width) - else: - raise_qasm3_error( - f"Invalid argument type '{arg_var.base_type}' for extern function " - f"'{fn_name}'", - error_node=statement, - span=statement.span, - ) return statement diff --git a/src/pyqasm/subroutines.py b/src/pyqasm/subroutines.py index 47ac6b0c..6e5f08b0 100644 --- a/src/pyqasm/subroutines.py +++ b/src/pyqasm/subroutines.py @@ -38,6 +38,7 @@ IntType, QASMNode, QubitDeclaration, + StretchType, UintType, ) from openqasm3.printer import dumps @@ -221,7 +222,7 @@ def _process_classical_arg_by_value( ) or (isinstance(_base_type, IntType) and not isinstance(actual_arg_value, int)) or ( - isinstance(_base_type, DurationType) + isinstance(_base_type, (DurationType, StretchType)) and not isinstance(actual_arg, DurationLiteral) ) or ( diff --git a/src/pyqasm/visitor.py b/src/pyqasm/visitor.py index 013d3b8f..47cca577 100644 --- a/src/pyqasm/visitor.py +++ b/src/pyqasm/visitor.py @@ -446,8 +446,6 @@ def _handle_function_init_expression( if isinstance(expression, qasm3_ast.FunctionCall): func_name = expression.name.name if func_name in FUNCTION_MAP: - if isinstance(init_value, complex): - return PulseValidator.make_complex_binary_expression(init_value) if isinstance(init_value, (float, int)): return qasm3_ast.FloatLiteral(init_value) return None diff --git a/tests/qasm3/declarations/test_classical.py b/tests/qasm3/declarations/test_classical.py index 7d169d0f..5cff8bdb 100644 --- a/tests/qasm3/declarations/test_classical.py +++ b/tests/qasm3/declarations/test_classical.py @@ -80,6 +80,7 @@ def test_const_declarations(): const angle[8] ang1 = 7 * (pi / 8); const angle[8] ang2 = 9 * (pi / 8); const angle[8] ang3 = ang1 + ang2; + const bit[4] a = "1011"; """ loads(qasm3_string).validate() @@ -105,6 +106,8 @@ def test_scalar_assignments(): du2 = 300s; angle[8] ang1; ang1 = 9 * (pi / 8); + bit[4] b; + b = "1011"; """ loads(qasm3_string).validate() @@ -748,6 +751,7 @@ def test_complex_type_variables(): complex[float[64]] f = a / b; complex[float[64]] g = a ** b; complex[float] h = a + b; + complex i = sqrt(1.0 + 2.0im); """ loads(qasm3_string).validate() @@ -783,3 +787,25 @@ def test_pi_expression_bit_conversion(): assert scope["ang4"].angle_bit_string == "1000" # -pi (wraps to 1) assert scope["ang5"].angle_bit_string == "01100000" # (pi/2) + (pi/4) = 3*pi/4 assert scope["ang6"].angle_bit_string == "00100000" # (pi/2) - (pi/4) = pi/4 + + +@pytest.mark.parametrize( + "qasm_code,error_message,error_span", + [ + ( + """ + OPENQASM 3.0; + include "stdgates.inc"; + bit[4] i = "101"; + """, + r"Invalid bitstring literal '101' width [3] for variable 'i' of size [4]", + r"Error at line 4, column 12", + ), + ], +) # pylint: disable-next= too-many-arguments +def test_bit_string_literal_error(qasm_code, error_message, error_span, caplog): + with pytest.raises(ValidationError) as err: + with caplog.at_level("ERROR"): + loads(qasm_code).validate() + assert error_message in str(err.value) + assert error_span in caplog.text diff --git a/tests/qasm3/subroutines/test_subroutines.py b/tests/qasm3/subroutines/test_subroutines.py index fd3fac09..8eeb2825 100644 --- a/tests/qasm3/subroutines/test_subroutines.py +++ b/tests/qasm3/subroutines/test_subroutines.py @@ -534,6 +534,7 @@ def test_extern_function_call(): } result = loads(qasm3_string, extern_functions=extern_functions) + result.validate() result.unroll() unrolled_qasm = dumps(result) From 45ed2b4947a2921304374546376747d9a1595bcd Mon Sep 17 00:00:00 2001 From: vinayswamik Date: Wed, 6 Aug 2025 22:15:44 -0500 Subject: [PATCH 11/12] update test case --- tests/qasm3/declarations/test_classical.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/qasm3/declarations/test_classical.py b/tests/qasm3/declarations/test_classical.py index 5cff8bdb..25d83444 100644 --- a/tests/qasm3/declarations/test_classical.py +++ b/tests/qasm3/declarations/test_classical.py @@ -80,7 +80,7 @@ def test_const_declarations(): const angle[8] ang1 = 7 * (pi / 8); const angle[8] ang2 = 9 * (pi / 8); const angle[8] ang3 = ang1 + ang2; - const bit[4] a = "1011"; + const bit[4] bit_check = "1011"; """ loads(qasm3_string).validate() @@ -106,8 +106,8 @@ def test_scalar_assignments(): du2 = 300s; angle[8] ang1; ang1 = 9 * (pi / 8); - bit[4] b; - b = "1011"; + bit[4] bit_check; + bit_check = "1011"; """ loads(qasm3_string).validate() From e6c08cff9f881dd8677059c286c0613cd0456bd7 Mon Sep 17 00:00:00 2001 From: vinayswamik Date: Mon, 11 Aug 2025 04:18:17 -0500 Subject: [PATCH 12/12] code refactor --- CHANGELOG.md | 26 +++++++++++++++++++++--- src/pyqasm/entrypoint.py | 2 +- src/pyqasm/expressions.py | 42 ++++++++++++++++++++++----------------- src/pyqasm/subroutines.py | 6 ++---- src/pyqasm/visitor.py | 14 ++++++------- 5 files changed, 57 insertions(+), 33 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 911aa1be..d5cf0327 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,11 +19,31 @@ Types of changes: - A github workflow for validating `CHANGELOG` updates in a PR ([#214](https://github.com/qBraid/pyqasm/pull/214)) - Added `unroll` command support in PYQASM CLI with options skipping files, overwriting originals files, and specifying output paths.([#224](https://github.com/qBraid/pyqasm/pull/224)) - Added `.github/copilot-instructions.md` to the repository to document coding standards and design principles for pyqasm. This file provides detailed guidance on documentation, static typing, formatting, error handling, and adherence to the QASM specification for all code contributions. ([#234](https://github.com/qBraid/pyqasm/pull/234)) -<<<<<<< HEAD - Added support for `Angle`,`extern` and `Complex` type in `OPENQASM3` code in pyqasm. ([#239](https://github.com/qBraid/pyqasm/pull/239)) -======= + ###### Example: + ```qasm + OPENQASM 3.0; + include "stdgates.inc"; + angle[8] ang1; + ang1 = 9 * (pi / 8); + angle[8] ang1 = 7 * (pi / 8); + angle[8] ang3 = ang1 + ang2; + + complex c1 = -2.5 - 3.5im; + const complex c2 = 2.0+arccos(π/2) + (3.1 * 5.5im); + const complex c12 = c1 * c2; + + float a = 1.0; + int b = 2; + extern func1(float, int) -> bit; + bit c = 2 * func1(a, b); + bit fc = -func1(a, b); + + bit[4] bd = "0101"; + extern func6(bit[4]) -> bit[4]; + bit[4] be1 = func6(bd); + ``` - Added a new `QasmModule.compare` method to compare two QASM modules, providing a detailed report of differences in gates, qubits, and measurements. This method is useful for comparing two identifying differences in QASM programs, their structure and operations. ([#233](https://github.com/qBraid/pyqasm/pull/233)) ->>>>>>> origin/main ### Improved / Modified - Added `slots=True` parameter to the data classes in `elements.py` to improve memory efficiency ([#218](https://github.com/qBraid/pyqasm/pull/218)) diff --git a/src/pyqasm/entrypoint.py b/src/pyqasm/entrypoint.py index f2733962..53bedf79 100644 --- a/src/pyqasm/entrypoint.py +++ b/src/pyqasm/entrypoint.py @@ -60,7 +60,7 @@ def loads(program: openqasm3.ast.Program | str, **kwargs) -> QasmModule: device_qubits (int): Number of physical qubits available on the target device. device_cycle_time (float): The duration of a hardware device cycle, in seconds. compiler_angle_type_size (int): The width of the angle type in the compiler. - extern_functions (list): List of extern functions to be added to the module. + extern_functions (dict): Dictionary of extern functions to be added to the module. Raises: TypeError: If the input is not a string or an `openqasm3.ast.Program` instance. diff --git a/src/pyqasm/expressions.py b/src/pyqasm/expressions.py index 2bc8a27b..1480feec 100644 --- a/src/pyqasm/expressions.py +++ b/src/pyqasm/expressions.py @@ -201,7 +201,6 @@ def evaluate_expression( # type: ignore[return] reqd_type=None, validate_only: bool = False, dt=None, - extern_fns=None, ) -> tuple: """Evaluate an expression. Scalar types are assigned by value. @@ -211,7 +210,6 @@ def evaluate_expression( # type: ignore[return] reqd_type (Any): The required type of the expression. Defaults to None. validate_only (bool): Whether to validate the expression only. Defaults to False. dt (float): The time step of the compiler. Defaults to None. - extern_fns (dict): A dictionary of extern functions. Defaults to None. Returns: tuple[Any, list[Statement]] : The result of the evaluation. @@ -263,14 +261,16 @@ def _check_type_size(expression, var_name, var_format, base_type): ) return base_size - def _is_external_function_call(expression, extern_fns): + def _is_external_function_call(expression): """Check if an expression is an external function call""" - return isinstance(expression, FunctionCall) and expression.name.name in extern_fns + return isinstance(expression, FunctionCall) and ( + expression.name.name in cls.visitor_obj._module._extern_functions + ) - def _get_external_function_return_type(expression, extern_fns): + def _get_external_function_return_type(expression): """Get the return type of an external function call""" - if _is_external_function_call(expression, extern_fns): - return extern_fns[expression.name.name][1] + if _is_external_function_call(expression): + return cls.visitor_obj._module._extern_functions[expression.name.name][1] return None if isinstance(expression, ImaginaryLiteral): @@ -376,7 +376,7 @@ def _get_external_function_return_type(expression, extern_fns): expression.expression, const_expr, reqd_type, validate_only ) # Check for external function in validate_only mode - return_type = _get_external_function_return_type(expression.expression, extern_fns) + return_type = _get_external_function_return_type(expression.expression) if return_type: return (return_type, statements) return (None, []) @@ -386,7 +386,7 @@ def _get_external_function_return_type(expression, extern_fns): ) # Handle external function replacement - if _is_external_function_call(expression.expression, extern_fns): + if _is_external_function_call(expression.expression): expression.expression = returned_stats[0] return _check_and_return_value(None) @@ -407,10 +407,16 @@ def _get_external_function_return_type(expression, extern_fns): return (None, statements) _lhs, _lhs_stmts = cls.evaluate_expression( - expression.lhs, const_expr, reqd_type, validate_only, extern_fns=extern_fns + expression.lhs, + const_expr, + reqd_type, + validate_only, ) _rhs, _rhs_stmts = cls.evaluate_expression( - expression.rhs, const_expr, reqd_type, validate_only, extern_fns=extern_fns + expression.rhs, + const_expr, + reqd_type, + validate_only, ) if isinstance(expression.lhs, Cast): @@ -426,8 +432,8 @@ def _get_external_function_return_type(expression, extern_fns): _lhs_return_type = None _rhs_return_type = None # Check for external functions in both operands - _lhs_return_type = _get_external_function_return_type(expression.lhs, extern_fns) - _rhs_return_type = _get_external_function_return_type(expression.rhs, extern_fns) + _lhs_return_type = _get_external_function_return_type(expression.lhs) + _rhs_return_type = _get_external_function_return_type(expression.rhs) if _lhs_return_type and _rhs_return_type: if _lhs_return_type != _rhs_return_type: @@ -448,21 +454,21 @@ def _get_external_function_return_type(expression, extern_fns): return (None, statements) lhs_value, lhs_statements = cls.evaluate_expression( - expression.lhs, const_expr, reqd_type, extern_fns=extern_fns + expression.lhs, const_expr, reqd_type ) # Handle external function replacement for lhs lhs_extern_function = False - if _is_external_function_call(expression.lhs, extern_fns): + if _is_external_function_call(expression.lhs): expression.lhs = lhs_statements[0] lhs_extern_function = True statements.extend(lhs_statements) rhs_value, rhs_statements = cls.evaluate_expression( - expression.rhs, const_expr, reqd_type, extern_fns=extern_fns + expression.rhs, const_expr, reqd_type ) # Handle external function replacement for rhs rhs_extern_function = False - if _is_external_function_call(expression.rhs, extern_fns): + if _is_external_function_call(expression.rhs): expression.rhs = rhs_statements[0] rhs_extern_function = True if lhs_extern_function or rhs_extern_function: @@ -477,7 +483,7 @@ def _get_external_function_return_type(expression, extern_fns): # function will not return a reqd / const type # Reference : https://openqasm.com/language/types.html#compile-time-constants, para: 5 if validate_only: - return_type = _get_external_function_return_type(expression, extern_fns) + return_type = _get_external_function_return_type(expression) if return_type: return (return_type, statements) return (None, statements) diff --git a/src/pyqasm/subroutines.py b/src/pyqasm/subroutines.py index 6e5f08b0..fd0ab283 100644 --- a/src/pyqasm/subroutines.py +++ b/src/pyqasm/subroutines.py @@ -16,7 +16,7 @@ Module containing the class for validating QASM3 subroutines. """ -import random +import uuid from typing import Optional import numpy as np @@ -189,9 +189,7 @@ def _process_classical_arg_by_value( if isinstance(formal_arg, ExternArgument): # Generate a unique name for the extern argument variable - _name = fn_name - while not cls.visitor_obj._scope_manager.check_in_scope(_name) and _name == fn_name: - _name = f"{fn_name}_{random.randint(1, 1_000_000_000)}" + _name = f"{fn_name}_{uuid.uuid4()}" if hasattr(formal_arg.type, "size") and formal_arg.type.size is not None: _base_size = Qasm3ExprEvaluator.evaluate_expression(formal_arg.type.size)[0] else: diff --git a/src/pyqasm/visitor.py b/src/pyqasm/visitor.py index 47cca577..005475ac 100644 --- a/src/pyqasm/visitor.py +++ b/src/pyqasm/visitor.py @@ -1374,7 +1374,6 @@ def _visit_constant_declaration( statement.init_expression, const_expr=True, dt=self._module._device_cycle_time, - extern_fns=self._module._extern_functions, ) except ValidationError as err: raise_qasm3_error( @@ -1397,7 +1396,8 @@ def _visit_constant_declaration( compiler_angle_width=self._module._compiler_angle_type_size, ) val_type, _ = Qasm3ExprEvaluator.evaluate_expression( - statement.init_expression, validate_only=True, extern_fns=self._module._extern_functions + statement.init_expression, + validate_only=True, ) self._check_variable_cast_type(statement, val_type, var_name, base_type, base_size, True) variable = Variable( @@ -1563,7 +1563,6 @@ def _visit_classical_declaration( init_value, stmts = Qasm3ExprEvaluator.evaluate_expression( statement.init_expression, dt=self._module._device_cycle_time, - extern_fns=self._module._extern_functions, ) statements.extend(stmts) _req_type = ( @@ -1575,7 +1574,6 @@ def _visit_classical_declaration( statement.init_expression, validate_only=True, reqd_type=_req_type, - extern_fns=self._module._extern_functions, ) if isinstance(base_type, qasm3_ast.AngleType) and not self._in_extern_function: init_value, angle_val_bit_string = PulseValidator.validate_angle_type_value( @@ -1738,11 +1736,13 @@ def _visit_classical_assignment( lhs=lvalue, op=binary_op, rhs=rvalue # type: ignore[arg-type] ) rvalue_raw, rhs_stmts = Qasm3ExprEvaluator.evaluate_expression( - rvalue, dt=self._module._device_cycle_time, extern_fns=self._module._extern_functions + rvalue, + dt=self._module._device_cycle_time, ) # consists of scope check and index validation statements.extend(rhs_stmts) val_type, _ = Qasm3ExprEvaluator.evaluate_expression( - rvalue, validate_only=True, extern_fns=self._module._extern_functions + rvalue, + validate_only=True, ) self._check_variable_cast_type( statement, @@ -2235,7 +2235,7 @@ def _visit_function_call( if return_statement: return_value, stmts = Qasm3ExprEvaluator.evaluate_expression( - return_statement.expression, extern_fns=self._module._extern_functions + return_statement.expression, ) return_value = Qasm3Validator.validate_return_statement( subroutine_def, return_statement, return_value