diff --git a/.gitattributes b/.gitattributes index 4c78e9ae..d5602369 100644 --- a/.gitattributes +++ b/.gitattributes @@ -1,4 +1,7 @@ # Ref : https://docs.github.com/en/get-started/getting-started-with-git/configuring-git-to-handle-line-endings#per-repository-settings -# Set line endings to lf -* text eol=lf \ No newline at end of file +# Set line endings to lf for text files +* text eol=lf + +# Binary files should not be processed as text +*.png binary \ No newline at end of file diff --git a/CHANGELOG.md b/CHANGELOG.md index 7b48491a..d5cf0327 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,6 +19,30 @@ Types of changes: - A github workflow for validating `CHANGELOG` updates in a PR ([#214](https://github.com/qBraid/pyqasm/pull/214)) - Added `unroll` command support in PYQASM CLI with options skipping files, overwriting originals files, and specifying output paths.([#224](https://github.com/qBraid/pyqasm/pull/224)) - Added `.github/copilot-instructions.md` to the repository to document coding standards and design principles for pyqasm. This file provides detailed guidance on documentation, static typing, formatting, error handling, and adherence to the QASM specification for all code contributions. ([#234](https://github.com/qBraid/pyqasm/pull/234)) +- Added support for `Angle`,`extern` and `Complex` type in `OPENQASM3` code in pyqasm. ([#239](https://github.com/qBraid/pyqasm/pull/239)) + ###### Example: + ```qasm + OPENQASM 3.0; + include "stdgates.inc"; + angle[8] ang1; + ang1 = 9 * (pi / 8); + angle[8] ang1 = 7 * (pi / 8); + angle[8] ang3 = ang1 + ang2; + + complex c1 = -2.5 - 3.5im; + const complex c2 = 2.0+arccos(π/2) + (3.1 * 5.5im); + const complex c12 = c1 * c2; + + float a = 1.0; + int b = 2; + extern func1(float, int) -> bit; + bit c = 2 * func1(a, b); + bit fc = -func1(a, b); + + bit[4] bd = "0101"; + extern func6(bit[4]) -> bit[4]; + bit[4] be1 = func6(bd); + ``` - Added a new `QasmModule.compare` method to compare two QASM modules, providing a detailed report of differences in gates, qubits, and measurements. This method is useful for comparing two identifying differences in QASM programs, their structure and operations. ([#233](https://github.com/qBraid/pyqasm/pull/233)) ### Improved / Modified diff --git a/src/README.md b/src/README.md index 27eeae33..e37b7d30 100644 --- a/src/README.md +++ b/src/README.md @@ -38,6 +38,6 @@ Source code for OpenQASM 3 program validator and semantic analyzer | Box | ✅ | Completed | | CalibrationStatement | 📋 | Planned | | CalibrationDefinition | 📋 | Planned | -| ComplexType | 📋 | Planned | -| AngleType | 📋 | Planned | -| ExternDeclaration | 📋 | Planned | +| ComplexType | ✅ | Completed | +| AngleType | ✅ | Completed | +| ExternDeclaration | ✅ | Completed | diff --git a/src/pyqasm/elements.py b/src/pyqasm/elements.py index 736007b1..ee0a3e10 100644 --- a/src/pyqasm/elements.py +++ b/src/pyqasm/elements.py @@ -92,6 +92,7 @@ class Variable: # pylint: disable=too-many-instance-attributes value (Optional[int | float | np.ndarray]): Value of the variable. time_unit (Optional[str]): Time unit associated with the duration variable. span (Any): Span of the variable. + angle_bit_string (Optional[str]): Bit string representation of the angle value. shadow (bool): Flag indicating if the current variable is shadowed from its parent scope. is_constant (bool): Flag indicating if the variable is constant. is_register (bool): Flag indicating if the variable is a register. @@ -106,6 +107,7 @@ class Variable: # pylint: disable=too-many-instance-attributes value: Optional[int | float | np.ndarray] = None time_unit: Optional[str] = None span: Any = None + angle_bit_string: Optional[str] = None shadow: bool = False is_constant: bool = False is_qubit: bool = False diff --git a/src/pyqasm/entrypoint.py b/src/pyqasm/entrypoint.py index 2fe06fd5..53bedf79 100644 --- a/src/pyqasm/entrypoint.py +++ b/src/pyqasm/entrypoint.py @@ -59,6 +59,8 @@ def loads(program: openqasm3.ast.Program | str, **kwargs) -> QasmModule: **kwargs: Additional arguments to pass to the loads function. device_qubits (int): Number of physical qubits available on the target device. device_cycle_time (float): The duration of a hardware device cycle, in seconds. + compiler_angle_type_size (int): The width of the angle type in the compiler. + extern_functions (dict): Dictionary of extern functions to be added to the module. Raises: TypeError: If the input is not a string or an `openqasm3.ast.Program` instance. @@ -91,6 +93,10 @@ def loads(program: openqasm3.ast.Program | str, **kwargs) -> QasmModule: module._device_qubits = dev_qbts if dev_cycle_time := kwargs.get("device_cycle_time"): module._device_cycle_time = dev_cycle_time + if compiler_angle_type_size := kwargs.get("compiler_angle_type_size"): + module._compiler_angle_type_size = compiler_angle_type_size + if extern_functions := kwargs.get("extern_functions"): + module._extern_functions = extern_functions return module diff --git a/src/pyqasm/expressions.py b/src/pyqasm/expressions.py index c2b12a36..1480feec 100644 --- a/src/pyqasm/expressions.py +++ b/src/pyqasm/expressions.py @@ -17,7 +17,9 @@ """ from openqasm3.ast import ( + AngleType, BinaryExpression, + BitstringLiteral, BitType, BooleanLiteral, BoolType, @@ -45,7 +47,12 @@ from pyqasm.analyzer import Qasm3Analyzer from pyqasm.elements import Variable from pyqasm.exceptions import ValidationError, raise_qasm3_error -from pyqasm.maps.expressions import CONSTANTS_MAP, TIME_UNITS_MAP, qasm3_expression_op_map +from pyqasm.maps.expressions import ( + CONSTANTS_MAP, + FUNCTION_MAP, + TIME_UNITS_MAP, + qasm3_expression_op_map, +) from pyqasm.validator import Qasm3Validator @@ -53,6 +60,7 @@ class Qasm3ExprEvaluator: """Class for evaluating QASM3 expressions.""" visitor_obj = None + angle_var_in_expr = None @classmethod def set_visitor_obj(cls, visitor_obj) -> None: @@ -70,8 +78,8 @@ def _check_var_in_scope(cls, var_name, expression): """ scope_manager = cls.visitor_obj._scope_manager + var = scope_manager.get_from_global_scope(var_name) if not scope_manager.check_in_scope(var_name): - var = scope_manager.get_from_global_scope(var_name) if var is not None and not var.is_constant: raise_qasm3_error( f"Global variable '{var_name}' must be a constant to use it in a local scope.", @@ -84,6 +92,14 @@ def _check_var_in_scope(cls, var_name, expression): error_node=expression, span=expression.span, ) + if var and isinstance(var.base_type, AngleType): + if cls.angle_var_in_expr and cls.angle_var_in_expr != var.base_type.size: + raise_qasm3_error( + "All 'Angle' variables in binary expression must have the same size", + error_node=expression, + span=expression.span, + ) + cls.angle_var_in_expr = var.base_type.size @classmethod def _check_var_constant(cls, var_name, const_expr, expression): @@ -192,6 +208,8 @@ def evaluate_expression( # type: ignore[return] expression (Any): The expression to evaluate. const_expr (bool): Whether the expression is a constant. Defaults to False. reqd_type (Any): The required type of the expression. Defaults to None. + validate_only (bool): Whether to validate the expression only. Defaults to False. + dt (float): The time step of the compiler. Defaults to None. Returns: tuple[Any, list[Statement]] : The result of the evaluation. @@ -203,14 +221,6 @@ def evaluate_expression( # type: ignore[return] if expression is None: return None, [] - if isinstance(expression, (ImaginaryLiteral)): - raise_qasm3_error( - f"Unsupported expression type '{type(expression)}'", - err_type=ValidationError, - error_node=expression, - span=expression.span, - ) - def _check_and_return_value(value): if validate_only: return None, statements @@ -251,10 +261,25 @@ def _check_type_size(expression, var_name, var_format, base_type): ) return base_size + def _is_external_function_call(expression): + """Check if an expression is an external function call""" + return isinstance(expression, FunctionCall) and ( + expression.name.name in cls.visitor_obj._module._extern_functions + ) + + def _get_external_function_return_type(expression): + """Get the return type of an external function call""" + if _is_external_function_call(expression): + return cls.visitor_obj._module._extern_functions[expression.name.name][1] + return None + + if isinstance(expression, ImaginaryLiteral): + return _check_and_return_value(expression.value * 1j) + if isinstance(expression, Identifier): var_name = expression.name if var_name in CONSTANTS_MAP: - if not reqd_type or reqd_type == Qasm3FloatType: + if not reqd_type or reqd_type in (Qasm3FloatType, AngleType): return _check_and_return_value(CONSTANTS_MAP[var_name]) raise_qasm3_error( f"Constant '{var_name}' not allowed in non-float expression", @@ -318,6 +343,8 @@ def _check_type_size(expression, var_name, var_format, base_type): return _check_and_return_value(expression.value) if reqd_type == Qasm3FloatType and isinstance(expression, FloatLiteral): return _check_and_return_value(expression.value) + if reqd_type == AngleType: + return _check_and_return_value(expression.value) raise_qasm3_error( f"Invalid value {expression.value} with type {type(expression)} " f"for required type {reqd_type}", @@ -327,6 +354,9 @@ def _check_type_size(expression, var_name, var_format, base_type): ) return _check_and_return_value(expression.value) + if isinstance(expression, BitstringLiteral): + return _check_and_return_value(format(expression.value, f"0{expression.width}b")) + if isinstance(expression, DurationLiteral): unit_name = expression.unit.name if dt: @@ -345,11 +375,21 @@ def _check_type_size(expression, var_name, var_format, base_type): return cls.evaluate_expression( expression.expression, const_expr, reqd_type, validate_only ) + # Check for external function in validate_only mode + return_type = _get_external_function_return_type(expression.expression) + if return_type: + return (return_type, statements) return (None, []) operand, returned_stats = cls.evaluate_expression( expression.expression, const_expr, reqd_type ) + + # Handle external function replacement + if _is_external_function_call(expression.expression): + expression.expression = returned_stats[0] + return _check_and_return_value(None) + if expression.op.name == "~" and not isinstance(operand, int): raise_qasm3_error( f"Unsupported expression type '{type(operand)}' in ~ operation", @@ -365,23 +405,75 @@ def _check_type_size(expression, var_name, var_format, base_type): if validate_only: if isinstance(expression.lhs, Cast) and isinstance(expression.rhs, Cast): return (None, statements) + + _lhs, _lhs_stmts = cls.evaluate_expression( + expression.lhs, + const_expr, + reqd_type, + validate_only, + ) + _rhs, _rhs_stmts = cls.evaluate_expression( + expression.rhs, + const_expr, + reqd_type, + validate_only, + ) + if isinstance(expression.lhs, Cast): - return cls.evaluate_expression( - expression.lhs, const_expr, reqd_type, validate_only - ) + return (_lhs, _lhs_stmts) if isinstance(expression.rhs, Cast): - return cls.evaluate_expression( - expression.rhs, const_expr, reqd_type, validate_only - ) + return (_rhs, _rhs_stmts) + + if type(reqd_type) is type(AngleType) and cls.angle_var_in_expr: + _var_type = AngleType(cls.angle_var_in_expr) + cls.angle_var_in_expr = None + return (_var_type, statements) + + _lhs_return_type = None + _rhs_return_type = None + # Check for external functions in both operands + _lhs_return_type = _get_external_function_return_type(expression.lhs) + _rhs_return_type = _get_external_function_return_type(expression.rhs) + + if _lhs_return_type and _rhs_return_type: + if _lhs_return_type != _rhs_return_type: + raise_qasm3_error( + f"extern function return type mismatch in binary expression: " + f"{type(_lhs_return_type).__name__} and " + f"{type(_rhs_return_type).__name__}", + err_type=ValidationError, + error_node=expression, + span=expression.span, + ) + else: + if _lhs_return_type: + return (_lhs_return_type, statements) + if _rhs_return_type: + return (_rhs_return_type, statements) + return (None, statements) lhs_value, lhs_statements = cls.evaluate_expression( expression.lhs, const_expr, reqd_type ) + # Handle external function replacement for lhs + lhs_extern_function = False + if _is_external_function_call(expression.lhs): + expression.lhs = lhs_statements[0] + lhs_extern_function = True statements.extend(lhs_statements) + rhs_value, rhs_statements = cls.evaluate_expression( expression.rhs, const_expr, reqd_type ) + # Handle external function replacement for rhs + rhs_extern_function = False + if _is_external_function_call(expression.rhs): + expression.rhs = rhs_statements[0] + rhs_extern_function = True + if lhs_extern_function or rhs_extern_function: + return (None, []) + statements.extend(rhs_statements) return _check_and_return_value( qasm3_expression_op_map(expression.op.name, lhs_value, rhs_value) @@ -390,6 +482,19 @@ def _check_type_size(expression, var_name, var_format, base_type): if isinstance(expression, FunctionCall): # function will not return a reqd / const type # Reference : https://openqasm.com/language/types.html#compile-time-constants, para: 5 + if validate_only: + return_type = _get_external_function_return_type(expression) + if return_type: + return (return_type, statements) + return (None, statements) + + if expression.name.name in FUNCTION_MAP: + _val, _ = cls.evaluate_expression( + expression.arguments[0], const_expr, reqd_type, validate_only + ) + _val = FUNCTION_MAP[expression.name.name](_val) # type: ignore + return _check_and_return_value(_val) + ret_value, ret_stmts = cls.visitor_obj._visit_function_call(expression) # type: ignore statements.extend(ret_stmts) return _check_and_return_value(ret_value) diff --git a/src/pyqasm/maps/expressions.py b/src/pyqasm/maps/expressions.py index aac09ca5..7a6eef2e 100644 --- a/src/pyqasm/maps/expressions.py +++ b/src/pyqasm/maps/expressions.py @@ -48,6 +48,7 @@ "/": lambda x, y: x / y, "%": lambda x, y: x % y, "==": lambda x, y: x == y, + "**": lambda x, y: x**y, "!=": lambda x, y: x != y, "<": lambda x, y: x < y, ">": lambda x, y: x > y, @@ -123,9 +124,13 @@ def qasm_variable_type_cast(openqasm_type, var_name, base_size, rhs_value): # not sure if we wanna hande array bit assignments too. # For now, we only cater to single bit assignment. if openqasm_type == BitType: - return rhs_value != 0 + return rhs_value if openqasm_type == AngleType: + if isinstance(rhs_value, bool): + return ((2 * CONSTANTS_MAP["pi"]) * (1 / 2)) if rhs_value else 0.0 return rhs_value # not sure + if openqasm_type == ComplexType: + return rhs_value # IEEE 754 Standard for floats @@ -150,17 +155,18 @@ def qasm_variable_type_cast(openqasm_type, var_name, base_size, rhs_value): ComplexType: complex, DurationType: float, StretchType: float, - # AngleType: None, # not sure + AngleType: float, } # Reference: https://openqasm.com/language/types.html#allowed-casts VARIABLE_TYPE_CAST_MAP = { BoolType: (int, float, bool, np.int64, np.float64, np.bool_), IntType: (bool, int, float, np.int64, np.float64, np.bool_), - BitType: (bool, int, np.int64, np.bool_), + BitType: (bool, int, np.int64, np.bool_, str), UintType: (bool, int, float, np.int64, np.uint64, np.float64, np.bool_), FloatType: (bool, int, float, np.int64, np.float64, np.bool_), - AngleType: (float, np.float64), + AngleType: (float, np.float64, bool, np.bool_), + ComplexType: (complex, np.complex128), } ARRAY_TYPE_MAP = { @@ -184,3 +190,17 @@ def qasm_variable_type_cast(openqasm_type, var_name, base_size, rhs_value): "ms": {"ns": 1_000_000, "s": 1e-3}, "s": {"ns": 1_000_000_000, "s": 1}, } + +# Function map for complex functions +FUNCTION_MAP = { + "abs": np.abs, + "real": lambda v: v.real if isinstance(v, complex) else v, + "imag": lambda v: v.imag if isinstance(v, complex) else v, + "sqrt": np.sqrt, + "sin": np.sin, + "cos": np.cos, + "tan": np.tan, + "arccos": np.arccos, + "arcsin": np.arcsin, + "arctan": np.arctan, +} diff --git a/src/pyqasm/modules/base.py b/src/pyqasm/modules/base.py index e82f5037..47629b59 100644 --- a/src/pyqasm/modules/base.py +++ b/src/pyqasm/modules/base.py @@ -91,6 +91,8 @@ def __init__(self, name: str, program: Program): self._consolidate_qubits: Optional[bool] = False self._user_operations: list[str] = ["load"] self._device_cycle_time: Optional[int] = None + self._compiler_angle_type_size: Optional[int] = None + self._extern_functions: dict[str, tuple[list[str], str]] = {} @property def name(self) -> str: diff --git a/src/pyqasm/pulse/validator.py b/src/pyqasm/pulse/validator.py index e1e82538..23a7c1c1 100644 --- a/src/pyqasm/pulse/validator.py +++ b/src/pyqasm/pulse/validator.py @@ -18,27 +18,84 @@ """ from typing import Any, Optional +import numpy as np from openqasm3.ast import ( BinaryExpression, + BinaryOperator, + BitstringLiteral, Box, Cast, ConstantDeclaration, DelayInstruction, DurationLiteral, DurationType, + ExternDeclaration, FloatLiteral, + FunctionCall, Identifier, + ImaginaryLiteral, IntegerLiteral, Statement, StretchType, + TimeUnit, ) from pyqasm.exceptions import raise_qasm3_error +from pyqasm.maps.expressions import CONSTANTS_MAP class PulseValidator: """Class with validation functions for Pulse visitor""" + @staticmethod + def validate_angle_type_value( + statement: Any, + init_value: int | float, + base_size: int, + compiler_angle_width: Optional[int] = None, + ) -> tuple: + """ + Validates and processes angle type value. + + Args: + statement: The AST statement node + init_value: The evaluated initialization value + base_size: The base size of the angle type + compiler_angle_width: Optional compiler angle width override + + Returns: + tuple: The processed angle value and bit string representation + + Raises: + ValidationError: If the angle initialization is invalid + """ + # Optimize: check both possible fields for BitstringLiteral in one go + init_exp = getattr(statement, "init_expression", None) + rval = getattr(statement, "rvalue", None) + is_bitstring = isinstance(init_exp, BitstringLiteral) or isinstance(rval, BitstringLiteral) + expression = init_exp or rval + if is_bitstring and expression is not None: + angle_type_size = expression.width + if compiler_angle_width: + if angle_type_size != compiler_angle_width: + raise_qasm3_error( + f"BitString angle width '{angle_type_size}' does not match " + f"with compiler angle width '{compiler_angle_width}'", + error_node=statement, + span=statement.span, + ) + angle_type_size = compiler_angle_width + angle_bit_string = format(expression.value, f"0{angle_type_size}b") + # Reference: https://openqasm.com/language/types.html#angles + angle_val = (2 * CONSTANTS_MAP["pi"]) * (expression.value / (2**angle_type_size)) + else: + angle_val = init_value % (2 * CONSTANTS_MAP["pi"]) + angle_type_size = compiler_angle_width or base_size + bit_string_value = round((2**angle_type_size) * (angle_val / (2 * CONSTANTS_MAP["pi"]))) + angle_bit_string = format(bit_string_value, f"0{angle_type_size}b") + + return angle_val, angle_bit_string + @staticmethod def validate_duration_or_stretch_statements( statement: Statement, @@ -220,3 +277,112 @@ def validate_duration_variable( # pylint: disable=too-many-branches error_node=statement, span=statement.span, ) + + @staticmethod + def make_complex_binary_expression(value: complex) -> BinaryExpression: + """ + Make a binary expression from a complex number. + """ + return BinaryExpression( + lhs=FloatLiteral(value.real), + op=(BinaryOperator["+"] if value.imag >= 0 else BinaryOperator["-"]), + rhs=ImaginaryLiteral(np.abs(value.imag)), + ) + + @staticmethod + def validate_extern_declaration(module: Any, statement: ExternDeclaration) -> None: + """ + Validates an extern declaration. + Args: + module: The module object + statement: The extern declaration statement + + Raises: + ValidationError: If the extern declaration is invalid + """ + args = module._extern_functions[statement.name.name][0] + if len(args) != len(statement.arguments): + raise_qasm3_error( + f"Parameter count mismatch for 'extern' subroutine '{statement.name.name}'. " + f"Expected {len(args)} but got {len(statement.arguments)}", + error_node=statement, + span=statement.span, + ) + + def _get_type_string(type_obj) -> str: + """Recursively build type string for nested types""" + type_name = type(type_obj).__name__.replace("Type", "").lower() + if getattr(type_obj, "base_type", None) is not None: + return f"{type_name}[{_get_type_string(type_obj.base_type)}]" + if getattr(type_obj, "size", None) is not None: + return f"{type_name}[{type_obj.size.value}]" + return type_name + + for actual_arg, extern_arg in zip(statement.arguments, args): + if actual_arg == extern_arg: + continue + actual_arg_type = _get_type_string(actual_arg.type) + if actual_arg_type != str(extern_arg).lower(): + raise_qasm3_error( + f"Parameter type mismatch for 'extern' subroutine '{statement.name.name}'. " + f"Expected {extern_arg} but got {actual_arg_type}", + error_node=statement, + span=statement.span, + ) + return_type = module._extern_functions[statement.name.name][1] + actual_type_name = _get_type_string(statement.return_type) + if return_type == statement.return_type: + return + if str(return_type).lower() != actual_type_name: + raise_qasm3_error( + f"Return type mismatch for 'extern' subroutine '{statement.name.name}'. Expected " + f"{return_type} but got {actual_type_name}", + error_node=statement, + span=statement.span, + ) + + @staticmethod + def validate_and_process_extern_function_call( # pylint: disable=too-many-branches + statement: FunctionCall, global_scope: dict, device_cycle_time: float | None + ) -> FunctionCall: + """Validate and process extern function arguments + by converting them to appropriate literals. + + Args: + statement: The function call statement to process + global_scope: The global scope of the module + device_cycle_time: The device cycle time of the module + + Returns: + The validated and processed function call statement + + Raises: + ValidationError: If the function call is invalid + """ + + for i, arg in enumerate(statement.arguments): + if isinstance(arg, Identifier): + arg_var = global_scope.get(arg.name) + assert arg_var is not None + + if arg_var.base_type is not None and isinstance( + arg_var.base_type, (DurationType, StretchType) + ): + statement.arguments[i] = DurationLiteral( + float(arg_var.value) if arg_var.value is not None else 0.0, + unit=(TimeUnit.dt if device_cycle_time else TimeUnit.ns), + ) + elif isinstance(arg_var.value, float): + statement.arguments[i] = FloatLiteral(arg_var.value) + elif isinstance(arg_var.value, int): + statement.arguments[i] = IntegerLiteral(arg_var.value) + elif isinstance(arg_var.value, complex): + statement.arguments[i] = PulseValidator.make_complex_binary_expression( + arg_var.value + ) + elif isinstance(arg_var.value, str): + width = len(arg_var.value) + value = int(arg_var.value, 2) + statement.arguments[i] = BitstringLiteral(value, width) + + return statement diff --git a/src/pyqasm/subroutines.py b/src/pyqasm/subroutines.py index e4b358c4..fd0ab283 100644 --- a/src/pyqasm/subroutines.py +++ b/src/pyqasm/subroutines.py @@ -16,16 +16,30 @@ Module containing the class for validating QASM3 subroutines. """ +import uuid from typing import Optional +import numpy as np from openqasm3.ast import ( AccessControl, + AngleType, ArrayReferenceType, + BitstringLiteral, + BitType, + BooleanLiteral, + BoolType, + ComplexType, + DurationLiteral, + DurationType, + ExternArgument, + FloatType, Identifier, IndexExpression, IntType, QASMNode, QubitDeclaration, + StretchType, + UintType, ) from openqasm3.printer import dumps @@ -97,7 +111,7 @@ def process_classical_arg(cls, formal_arg, actual_arg, fn_name, fn_call): formal_arg, actual_arg, actual_arg_name, fn_name, fn_call ) - @classmethod # pylint: disable-next=too-many-arguments + @classmethod # pylint: disable-next=too-many-arguments,too-many-locals,too-many-branches def _process_classical_arg_by_value( cls, formal_arg, actual_arg, actual_arg_name, fn_name, fn_call ): @@ -146,13 +160,100 @@ def _process_classical_arg_by_value( error_node=fn_call, span=fn_call.span, ) + + if isinstance(formal_arg, ExternArgument): + formal_arg_type = formal_arg.type + formal_arg_size = None + if hasattr(formal_arg_type, "size") and formal_arg_type.size is not None: + formal_arg_size, _ = Qasm3ExprEvaluator.evaluate_expression( + formal_arg_type.size + ) + actual_arg_var = cls.visitor_obj._scope_manager.get_from_global_scope( + actual_arg_name + ) + if ( + actual_arg_var.base_type != formal_arg_type + or actual_arg_var.base_size != formal_arg_size + ): + if formal_arg_size is not None: + raise_qasm3_error( + f"Argument type mismatch in function '{fn_name}', expected " + f"{type(formal_arg_type).__name__}[{formal_arg_size}] but got " + f"{type(actual_arg_var.base_type).__name__}" + f"[{actual_arg_var.base_size}]", + error_node=fn_call, + span=fn_call.span, + ) + actual_arg_value = Qasm3ExprEvaluator.evaluate_expression(actual_arg)[0] - # save this value to be updated later in scope + if isinstance(formal_arg, ExternArgument): + # Generate a unique name for the extern argument variable + _name = f"{fn_name}_{uuid.uuid4()}" + if hasattr(formal_arg.type, "size") and formal_arg.type.size is not None: + _base_size = Qasm3ExprEvaluator.evaluate_expression(formal_arg.type.size)[0] + else: + _base_size = None + _base_type = formal_arg.type + + # pylint: disable=too-many-boolean-expressions + if actual_arg_value and actual_arg_name is None: + if ( + ( + isinstance(_base_type, UintType) + and (not isinstance(actual_arg_value, int) or actual_arg_value < 0) + ) + or ( + isinstance(_base_type, AngleType) + and ( + not isinstance(actual_arg_value, float) + or not 0 <= actual_arg_value <= 2 * np.pi + ) + ) + or ( + isinstance(_base_type, FloatType) + and not isinstance(actual_arg_value, float) + ) + or ( + isinstance(_base_type, ComplexType) + and not isinstance(actual_arg_value, complex) + ) + or (isinstance(_base_type, IntType) and not isinstance(actual_arg_value, int)) + or ( + isinstance(_base_type, (DurationType, StretchType)) + and not isinstance(actual_arg, DurationLiteral) + ) + or ( + isinstance(_base_type, BoolType) + and ( + not isinstance(actual_arg_value, bool) + and actual_arg_value not in (0, 1) + ) + ) + or ( + isinstance(_base_type, BitType) + and ( + not isinstance(actual_arg, (BitstringLiteral, BooleanLiteral)) + and not isinstance(actual_arg_value, (bool, int)) + ) + ) + ): + raise_qasm3_error( + f"Invalid argument value for '{fn_name}', expected " + f"'{type(_base_type).__name__}' but got value = " + f"{actual_arg_value}.", + error_node=fn_call, + span=fn_call.span, + ) + + else: + _name = formal_arg.name.name + _base_size = Qasm3ExprEvaluator.evaluate_expression(formal_arg.type.size)[0] + _base_type = formal_arg.type return Variable( - name=formal_arg.name.name, - base_type=formal_arg.type, - base_size=Qasm3ExprEvaluator.evaluate_expression(formal_arg.type.size)[0], + name=_name, + base_type=_base_type, + base_size=_base_size, dims=None, value=actual_arg_value, is_constant=False, diff --git a/src/pyqasm/validator.py b/src/pyqasm/validator.py index 14c42e55..9a4cd82e 100644 --- a/src/pyqasm/validator.py +++ b/src/pyqasm/validator.py @@ -181,7 +181,7 @@ def validate_variable_assignment_value( error_node=op_node, span=op_node.span if op_node else None, ) - elif type_to_match == bool: + elif type_to_match in (bool, complex): pass else: raise_qasm3_error( diff --git a/src/pyqasm/visitor.py b/src/pyqasm/visitor.py index a1b374ca..005475ac 100644 --- a/src/pyqasm/visitor.py +++ b/src/pyqasm/visitor.py @@ -22,7 +22,7 @@ import logging from collections import OrderedDict, deque from functools import partial -from typing import Any, Callable, Optional, cast +from typing import Any, Callable, Optional, Sequence, cast import numpy as np import openqasm3.ast as qasm3_ast @@ -49,6 +49,7 @@ from pyqasm.maps.expressions import ( ARRAY_TYPE_MAP, CONSTANTS_MAP, + FUNCTION_MAP, MAX_ARRAY_DIMENSIONS, ) from pyqasm.maps.gates import ( @@ -105,7 +106,9 @@ def __init__( # pylint: disable=too-many-arguments self._global_creg_size_map: dict[str, int] = {} self._custom_gates: dict[str, qasm3_ast.QuantumGateDefinition] = {} self._external_gates: list[str] = [] if external_gates is None else external_gates - self._subroutine_defns: dict[str, qasm3_ast.SubroutineDefinition] = {} + self._subroutine_defns: dict[ + str, qasm3_ast.SubroutineDefinition | qasm3_ast.ExternDeclaration + ] = {} self._check_only: bool = check_only self._unroll_barriers: bool = unroll_barriers self._recording_ext_gate_depth = False @@ -120,6 +123,7 @@ def __init__( # pylint: disable=too-many-arguments self._qubit_register_offsets: OrderedDict = OrderedDict() self._qubit_register_max_offset = 0 self._total_delay_duration_in_box = 0 + self._in_extern_function: bool = False self._scope_manager: ScopeManager = scope_manager @@ -330,6 +334,11 @@ def _check_variable_type_size( if not hasattr(base_type, "size") or base_type.size is None else Qasm3ExprEvaluator.evaluate_expression(base_type.size, const_expr=True)[0] ) + if ( + isinstance(base_type, qasm3_ast.AngleType) + and self._module._compiler_angle_type_size + ): + base_size = self._module._compiler_angle_type_size except ValidationError as err: raise_qasm3_error( f"Invalid base size for {var_format} '{var_name}'", @@ -425,6 +434,45 @@ def _qubit_register_consolidation( return _valid_statements + def _handle_function_init_expression( + self, expression: Any, init_value: Any + ) -> None | qasm3_ast.Expression: + """Handle function initialization expression. + + Args: + statement (Any): The statement to handle function initialization expression. + init_value (Any): The value to handle function initialization expression. + """ + if isinstance(expression, qasm3_ast.FunctionCall): + func_name = expression.name.name + if func_name in FUNCTION_MAP: + if isinstance(init_value, (float, int)): + return qasm3_ast.FloatLiteral(init_value) + return None + + def _handle_extern_function_cleanup( + self, statements: list, statement: qasm3_ast.Statement + ) -> None: + """Clean up extern function state and modify statements if needed. + + Args: + statements: List of statements to potentially modify + statement: The statement to append if in extern function + """ + if self._in_extern_function: + self._in_extern_function = False + statements.clear() + statements.append(statement) + + def _validate_bitstring_literal_width(self, init_value, base_size, var_name, statement): + if len(init_value) != base_size: + raise_qasm3_error( + f"Invalid bitstring literal '{init_value}' width [{len(init_value)}] " + f"for variable '{var_name}' of size [{base_size}]", + error_node=statement, + span=statement.span, + ) + def _visit_measurement( # pylint: disable=too-many-locals, too-many-branches self, statement: qasm3_ast.QuantumMeasurementStatement ) -> list[qasm3_ast.QuantumMeasurementStatement]: @@ -1323,7 +1371,9 @@ def _visit_constant_declaration( try: init_value, stmts = Qasm3ExprEvaluator.evaluate_expression( - statement.init_expression, const_expr=True, dt=self._module._device_cycle_time + statement.init_expression, + const_expr=True, + dt=self._module._device_cycle_time, ) except ValidationError as err: raise_qasm3_error( @@ -1337,14 +1387,33 @@ def _visit_constant_declaration( base_type = statement.type base_size = self._check_variable_type_size(statement, var_name, "constant", base_type) + angle_val_bit_string = None + if isinstance(base_type, qasm3_ast.AngleType) and not self._in_extern_function: + init_value, angle_val_bit_string = PulseValidator.validate_angle_type_value( + statement, + init_value=init_value, + base_size=base_size, + compiler_angle_width=self._module._compiler_angle_type_size, + ) val_type, _ = Qasm3ExprEvaluator.evaluate_expression( - statement.init_expression, validate_only=True + statement.init_expression, + validate_only=True, ) self._check_variable_cast_type(statement, val_type, var_name, base_type, base_size, True) variable = Variable( - var_name, base_type, base_size, [], init_value, is_constant=True, span=statement.span + var_name, + base_type, + len(angle_val_bit_string) if angle_val_bit_string else base_size, + [], + init_value, + is_constant=True, + span=statement.span, + angle_bit_string=angle_val_bit_string, ) + if isinstance(base_type, qasm3_ast.BitType) and isinstance(init_value, str): + self._validate_bitstring_literal_width(init_value, base_size, var_name, statement) + if isinstance(base_type, (qasm3_ast.DurationType, qasm3_ast.StretchType)): PulseValidator.validate_duration_literal_value(init_value, statement, base_type) if self._module._device_cycle_time: @@ -1353,12 +1422,23 @@ def _visit_constant_declaration( variable.time_unit = "ns" # cast + validation - variable.value = Qasm3Validator.validate_variable_assignment_value( - variable, init_value, op_node=statement - ) + if init_value is not None and not self._in_extern_function: + variable.value = Qasm3Validator.validate_variable_assignment_value( + variable, init_value, op_node=statement + ) self._scope_manager.add_var_in_scope(variable) + if isinstance(base_type, qasm3_ast.ComplexType) and isinstance(variable.value, complex): + statement.init_expression = PulseValidator.make_complex_binary_expression(init_value) + + if isinstance(statement.init_expression, qasm3_ast.FunctionCall): + statement.init_expression = ( + self._handle_function_init_expression(statement.init_expression, init_value) + or statement.init_expression + ) + self._handle_extern_function_cleanup(statements, statement) + if self._check_only: return [] @@ -1405,6 +1485,7 @@ def _visit_classical_declaration( base_type = statement.type dimensions = [] final_dimensions = [] + angle_val_bit_string = None if isinstance(base_type, qasm3_ast.StretchType): if statement.init_expression: @@ -1480,12 +1561,27 @@ def _visit_classical_declaration( else: try: init_value, stmts = Qasm3ExprEvaluator.evaluate_expression( - statement.init_expression, dt=self._module._device_cycle_time + statement.init_expression, + dt=self._module._device_cycle_time, ) statements.extend(stmts) + _req_type = ( + type(qasm3_ast.AngleType()) + if isinstance(base_type, qasm3_ast.AngleType) + else None + ) val_type, _ = Qasm3ExprEvaluator.evaluate_expression( - statement.init_expression, validate_only=True + statement.init_expression, + validate_only=True, + reqd_type=_req_type, ) + if isinstance(base_type, qasm3_ast.AngleType) and not self._in_extern_function: + init_value, angle_val_bit_string = PulseValidator.validate_angle_type_value( + statement, + init_value=init_value, + base_size=base_size, + compiler_angle_width=self._module._compiler_angle_type_size, + ) self._check_variable_cast_type( statement, val_type, var_name, base_type, base_size, False ) @@ -1496,15 +1592,18 @@ def _visit_classical_declaration( span=statement.span, raised_from=err, ) + if isinstance(base_type, qasm3_ast.BitType) and isinstance(init_value, str): + self._validate_bitstring_literal_width(init_value, base_size, var_name, statement) variable = Variable( var_name, base_type, - base_size, + len(angle_val_bit_string) if angle_val_bit_string else base_size, final_dimensions, init_value, is_qubit=False, span=statement.span, + angle_bit_string=angle_val_bit_string, ) if isinstance(base_type, qasm3_ast.DurationType): @@ -1531,9 +1630,10 @@ def _visit_classical_declaration( ) else: try: - variable.value = Qasm3Validator.validate_variable_assignment_value( - variable, init_value, op_node=statement - ) + if init_value is not None and not self._in_extern_function: + variable.value = Qasm3Validator.validate_variable_assignment_value( + variable, init_value, op_node=statement + ) except ValidationError as err: raise_qasm3_error( f"Invalid initialization value for variable '{var_name}'", @@ -1560,6 +1660,17 @@ def _visit_classical_declaration( statements.append(statement) self._module._add_classical_register(var_name, base_size) + self._handle_extern_function_cleanup(statements, statement) + + if isinstance(base_type, qasm3_ast.ComplexType) and isinstance(variable.value, complex): + statement.init_expression = PulseValidator.make_complex_binary_expression(init_value) + + if isinstance(statement.init_expression, qasm3_ast.FunctionCall): + statement.init_expression = ( + self._handle_function_init_expression(statement.init_expression, init_value) + or statement.init_expression + ) + if self._check_only: return [] @@ -1625,10 +1736,14 @@ def _visit_classical_assignment( lhs=lvalue, op=binary_op, rhs=rvalue # type: ignore[arg-type] ) rvalue_raw, rhs_stmts = Qasm3ExprEvaluator.evaluate_expression( - rvalue, dt=self._module._device_cycle_time + rvalue, + dt=self._module._device_cycle_time, ) # consists of scope check and index validation statements.extend(rhs_stmts) - val_type, _ = Qasm3ExprEvaluator.evaluate_expression(rvalue, validate_only=True) + val_type, _ = Qasm3ExprEvaluator.evaluate_expression( + rvalue, + validate_only=True, + ) self._check_variable_cast_type( statement, val_type, @@ -1637,13 +1752,29 @@ def _visit_classical_assignment( lvar.base_size, # type: ignore[union-attr] False, ) + if isinstance(lvar_base_type, qasm3_ast.BitType) and isinstance(rvalue_raw, str): + self._validate_bitstring_literal_width( + rvalue_raw, lvar.base_size, lvar_name, statement # type: ignore[union-attr] + ) + angle_val_bit_string = None + if isinstance(lvar_base_type, qasm3_ast.AngleType) and not self._in_extern_function: + rvalue_raw, angle_val_bit_string = PulseValidator.validate_angle_type_value( + statement, + init_value=rvalue_raw, + base_size=lvar.base_size, # type: ignore[union-attr] + compiler_angle_width=self._module._compiler_angle_type_size, + ) + lvar.angle_bit_string = angle_val_bit_string # type: ignore[union-attr] + if angle_val_bit_string: + lvar.base_size = len(angle_val_bit_string) # type: ignore[union-attr] # cast + validation rvalue_eval = None if not isinstance(rvalue_raw, np.ndarray): # rhs is a scalar - rvalue_eval = Qasm3Validator.validate_variable_assignment_value( - lvar, rvalue_raw, op_node=statement # type: ignore[arg-type] - ) + if rvalue_raw is not None and not self._in_extern_function: + rvalue_eval = Qasm3Validator.validate_variable_assignment_value( + lvar, rvalue_raw, op_node=statement # type: ignore[arg-type] + ) else: # rhs is a list rvalue_dimensions = list(rvalue_raw.shape) @@ -1692,6 +1823,21 @@ def _visit_classical_assignment( lvar.value = rvalue_eval # type: ignore[union-attr] self._scope_manager.update_var_in_scope(lvar) # type: ignore[arg-type] + if isinstance(lvar_base_type, qasm3_ast.ComplexType) and isinstance( + lvar.value, complex # type: ignore[union-attr] + ): + statement.rvalue = PulseValidator.make_complex_binary_expression( + lvar.value # type: ignore[union-attr] + ) + + if isinstance(statement.rvalue, qasm3_ast.FunctionCall): + statement.rvalue = ( + self._handle_function_init_expression(statement.rvalue, rvalue_eval) + or statement.rvalue + ) + + self._handle_extern_function_cleanup(statements, statement) + if self._check_only: return [] @@ -1946,7 +2092,9 @@ def _visit_forin_loop(self, statement: qasm3_ast.ForInLoop) -> list[qasm3_ast.St return [] return result - def _visit_subroutine_definition(self, statement: qasm3_ast.SubroutineDefinition) -> list[None]: + def _visit_subroutine_definition( + self, statement: qasm3_ast.SubroutineDefinition | qasm3_ast.ExternDeclaration + ) -> Sequence[None | qasm3_ast.ExternDeclaration]: """Visit a subroutine definition element. Reference: https://openqasm.com/language/subroutines.html#subroutines @@ -1957,6 +2105,7 @@ def _visit_subroutine_definition(self, statement: qasm3_ast.SubroutineDefinition None """ fn_name = statement.name.name + statements = [] if fn_name in CONSTANTS_MAP: raise_qasm3_error( @@ -1978,14 +2127,26 @@ def _visit_subroutine_definition(self, statement: qasm3_ast.SubroutineDefinition span=statement.span, ) + if isinstance(statement, qasm3_ast.ExternDeclaration): + if statement.name.name in self._module._extern_functions: + PulseValidator.validate_extern_declaration(self._module, statement) + self._module._extern_functions[statement.name.name] = ( + statement.arguments, + statement.return_type, + ) + + statements.append(statement) + self._subroutine_defns[fn_name] = statement + if self._check_only: + return [] - return [] + return statements # pylint: disable=too-many-locals, too-many-statements def _visit_function_call( self, statement: qasm3_ast.FunctionCall - ) -> tuple[Any, list[qasm3_ast.Statement]]: + ) -> tuple[Any | None, list[qasm3_ast.Statement | qasm3_ast.FunctionCall]]: """Visit a function call element. Args: @@ -2018,7 +2179,7 @@ def _visit_function_call( quantum_vars, classical_vars = [], [] for actual_arg, formal_arg in zip(statement.arguments, subroutine_def.arguments): - if isinstance(formal_arg, qasm3_ast.ClassicalArgument): + if isinstance(formal_arg, (qasm3_ast.ClassicalArgument, qasm3_ast.ExternArgument)): classical_vars.append( Qasm3SubroutineProcessor.process_classical_arg( formal_arg, actual_arg, fn_name, statement @@ -2052,25 +2213,34 @@ def _visit_function_call( self._function_qreg_transform_map.append(qubit_transform_map) return_statement = None - result = [] - for function_op in subroutine_def.body: - if isinstance(function_op, qasm3_ast.ReturnStatement): - return_statement = copy.copy(function_op) - break - try: - result.extend(self.visit_statement(copy.copy(function_op))) - except (TypeError, copy.Error): - result.extend(self.visit_statement(copy.deepcopy(function_op))) - return_value = None - if return_statement: - return_value, stmts = Qasm3ExprEvaluator.evaluate_expression( - return_statement.expression - ) - return_value = Qasm3Validator.validate_return_statement( - subroutine_def, return_statement, return_value + result: list[qasm3_ast.Statement | qasm3_ast.FunctionCall] = [] + if isinstance(subroutine_def, qasm3_ast.ExternDeclaration): + self._in_extern_function = True + global_scope = self._scope_manager.get_global_scope() + result.append( + PulseValidator.validate_and_process_extern_function_call( + statement, global_scope, self._module._device_cycle_time + ) ) - result.extend(stmts) + else: + for function_op in subroutine_def.body: + if isinstance(function_op, qasm3_ast.ReturnStatement): + return_statement = copy.copy(function_op) + break + try: + result.extend(self.visit_statement(copy.copy(function_op))) + except (TypeError, copy.Error): + result.extend(self.visit_statement(copy.deepcopy(function_op))) + + if return_statement: + return_value, stmts = Qasm3ExprEvaluator.evaluate_expression( + return_statement.expression, + ) + return_value = Qasm3Validator.validate_return_statement( + subroutine_def, return_statement, return_value + ) + result.extend(stmts) # remove qubit transformation map self._function_qreg_transform_map.pop() @@ -2080,7 +2250,7 @@ def _visit_function_call( self._scope_manager.decrement_scope_level() self._scope_manager.pop_scope() - if self._check_only: + if self._check_only and not self._in_extern_function: return return_value, [] return return_value, result @@ -2532,6 +2702,7 @@ def visit_statement(self, statement: qasm3_ast.Statement) -> list[qasm3_ast.Stat qasm3_ast.AliasStatement: self._visit_alias_statement, qasm3_ast.SwitchStatement: self._visit_switch_statement, qasm3_ast.SubroutineDefinition: self._visit_subroutine_definition, + qasm3_ast.ExternDeclaration: self._visit_subroutine_definition, qasm3_ast.ExpressionStatement: lambda x: self._visit_function_call(x.expression), qasm3_ast.IODeclaration: lambda x: [], qasm3_ast.BreakStatement: self._visit_break, @@ -2594,4 +2765,5 @@ def finalize(self, unrolled_stmts): if isinstance(stmt, qasm3_ast.QuantumPhase): if len(stmt.qubits) == len(self._qubit_labels): stmt.qubits = [] + Qasm3ExprEvaluator.angle_var_in_expr = None return unrolled_stmts diff --git a/tests/qasm3/declarations/test_classical.py b/tests/qasm3/declarations/test_classical.py index 72b33f93..25d83444 100644 --- a/tests/qasm3/declarations/test_classical.py +++ b/tests/qasm3/declarations/test_classical.py @@ -20,6 +20,7 @@ from pyqasm.entrypoint import loads from pyqasm.exceptions import ValidationError +from pyqasm.visitor import QasmVisitor, ScopeManager # pylint: disable=ungrouped-imports from tests.qasm3.resources.variables import ( ASSIGNMENT_TESTS, CASTING_TESTS, @@ -45,6 +46,7 @@ def test_scalar_declarations(): bool i; duration j; stretch st; + angle[8] ang1; """ loads(qasm3_string).validate() @@ -75,6 +77,10 @@ def test_const_declarations(): const duration t8 = t2/t3; const stretch st = 300ns; const stretch st2 = t2/t3; + const angle[8] ang1 = 7 * (pi / 8); + const angle[8] ang2 = 9 * (pi / 8); + const angle[8] ang3 = ang1 + ang2; + const bit[4] bit_check = "1011"; """ loads(qasm3_string).validate() @@ -98,6 +104,10 @@ def test_scalar_assignments(): duration du = 200us; duration du2; du2 = 300s; + angle[8] ang1; + ang1 = 9 * (pi / 8); + bit[4] bit_check; + bit_check = "1011"; """ loads(qasm3_string).validate() @@ -123,6 +133,9 @@ def test_scalar_value_assignment(): duration t9 = t2 - t7; duration t10 = t2 + t7; duration t11 = t2 * t7; + angle[8] ang1 = 7 * (pi / 8); + angle[8] ang2 = 9 * (pi / 8); + angle[8] ang3 = ang1 + ang2; """ b = 5.0 @@ -603,6 +616,43 @@ def test_duration_casting_error(qasm_code, error_message, error_span, caplog): assert error_span in caplog.text +@pytest.mark.parametrize( + "qasm_code,error_message,error_span", + [ + ( + """ + OPENQASM 3.0; + include "stdgates.inc"; + angle[8] ang1 = 7 * (pi / 8); + angle[7] ang2 = 9 * (pi / 8); + angle[8] ang3 = ang1 + ang2; + """, + r"All 'Angle' variables in binary expression must have the same size", + r"Error at line 6, column 12", + ), + ( + """ + OPENQASM 3.0; + include "stdgates.inc"; + angle ang1 = "1000111111"; + """, + r"BitString angle width '10' does not match with compiler angle width '8'", + r"Error at line 4, column 12", + ), + ], +) # pylint: disable-next= too-many-arguments +def test_angle_type_error(qasm_code, error_message, error_span, caplog): + with pytest.raises(ValidationError) as excinfo: + with caplog.at_level("ERROR"): + loads(qasm_code, compiler_angle_type_size=8).validate() + + first = excinfo.value.__cause__ or excinfo.value.__context__ + assert first is not None, "Expected a chained ValidationError" + msg = str(first) + assert error_message in msg + assert error_span in caplog.text + + def test_device_time_duration_(): """Test device cycle time duration""" qasm3_string = """ @@ -613,3 +663,149 @@ def test_device_time_duration_(): const duration t3 =300us; """ loads(qasm3_string, device_cycle_time=1e-9).validate() + + +def test_compiler_angle_type_size(): + """Test compiler angle type size""" + qasm3_string = """ + OPENQASM 3.0; + include "stdgates.inc"; + angle[8] ang1 = 7 * (pi / 8); + const angle[8] ang2 = 9 * (pi / 8); + angle[4] ang3 = "1010"; + """ + loads(qasm3_string, compiler_angle_type_size=4).validate() + + +def test_complex_type_variables(): + """Test complex type variables""" + qasm3_string = """ + OPENQASM 3.0; + include "stdgates.inc"; + complex c1 = -2.5 - 3.5im; + complex c2 = 3.5 + 2.5im; + complex c3 = 2.0 + c2; + complex c4 = 2.0+sin(π/2) + (3.1 * 5.5im); + complex c5 = 2.0+arcsin(π/2) + (3.1 * 5.5im); + complex c6 = 2.0+arctan(π/2) + (3.1 * 5.5im); + complex c7 = c1 * c2; + complex c8 = c1 + c2; + complex c9 = c1 - c2; + complex c10 = c1 / c2; + complex c11 = c1 ** c2; + complex c12 = sqrt(c1); + float c13 = abs(c1 * c2); + float c14 = real(c1); + float c15 = imag(c1); + float c16 = sin(π/2); + const complex c17 = -2.5 - 3.5im; + const complex c18 = 3.5 + 2.5im; + const complex c19 = 2.0 + c18; + const complex c20 = 2.0+cos(π/2) + (3.1 * 5.5im); + const complex c21 = 2.0+arccos(π/2) + (3.1 * 5.5im); + const complex c22 = c17 * c18; + const complex c23 = c17 + c18; + const complex c24 = c17 - c18; + const complex c25 = c17 / c18; + const complex c26 = c17 ** c18; + const complex c27 = sqrt(c17); + const float c28 = abs(c17 * c18); + const float c29 = real(c17); + const float c30 = imag(c17); + const float c31 = sin(π/2); + complex c32; + c32 = -2.5 - 3.5im; + complex c33; + c33 = 3.5 + 2.5im; + complex c34; + c34 = 2.0 + c33; + complex c35; + c35 = 2.0+tan(π/2) + (3.1 * 5.5im); + complex c36; + c36 = 2.0+arctan(π/2) + (3.1 * 5.5im); + complex c37; + c37 = c32 * c33; + complex c38; + c38 = c32 + c33; + complex c39; + c39 = c32 - c33; + complex c40; + c40 = c32 / c33; + complex c41; + c41 = c32 ** c33; + complex c42; + c42 = sqrt(c32); + float c43; + c43 = abs(c32 * c33); + float c44; + c44 = real(c32); + float c45; + c45 = imag(c32); + float c46; + c46 = sin(π/2); + complex[float[64]] a = 10.0 + 5.0im; + complex[float[64]] b = -2.0 - 7.0im; + complex[float[64]] c = a + b; + complex[float[64]] d = a - b; + complex[float[64]] e = a * b; + complex[float[64]] f = a / b; + complex[float[64]] g = a ** b; + complex[float] h = a + b; + complex i = sqrt(1.0 + 2.0im); + """ + + loads(qasm3_string).validate() + + +def test_pi_expression_bit_conversion(): + """Test that pi expressions are correctly converted to bit string representations""" + qasm3_string = """ + OPENQASM 3.0; + include "stdgates.inc"; + + angle[4] ang1 = pi / 2; + angle[8] ang2 = 15 * (pi / 16); + angle[4] ang3 = -pi / 2; + angle[4] ang4 = -pi; + angle[8] ang5 = (pi / 2) + (pi / 4); + angle[8] ang6 = (pi / 2) - (pi / 4); + + """ + + result = loads(qasm3_string) + result.validate() + + # Create a visitor to access the scope manager + scope_manager = ScopeManager() + visitor = QasmVisitor(result, scope_manager, check_only=True) + result.accept(visitor) + scope = scope_manager.get_global_scope() + + assert scope["ang1"].angle_bit_string == "0100" # pi/2 + assert scope["ang2"].angle_bit_string == "01111000" # 15*pi/16 + assert scope["ang3"].angle_bit_string == "1100" # -pi/2 (wraps to 3*pi/2) + assert scope["ang4"].angle_bit_string == "1000" # -pi (wraps to 1) + assert scope["ang5"].angle_bit_string == "01100000" # (pi/2) + (pi/4) = 3*pi/4 + assert scope["ang6"].angle_bit_string == "00100000" # (pi/2) - (pi/4) = pi/4 + + +@pytest.mark.parametrize( + "qasm_code,error_message,error_span", + [ + ( + """ + OPENQASM 3.0; + include "stdgates.inc"; + bit[4] i = "101"; + """, + r"Invalid bitstring literal '101' width [3] for variable 'i' of size [4]", + r"Error at line 4, column 12", + ), + ], +) # pylint: disable-next= too-many-arguments +def test_bit_string_literal_error(qasm_code, error_message, error_span, caplog): + with pytest.raises(ValidationError) as err: + with caplog.at_level("ERROR"): + loads(qasm_code).validate() + assert error_message in str(err.value) + assert error_span in caplog.text diff --git a/tests/qasm3/resources/variables.py b/tests/qasm3/resources/variables.py index 2c1e233e..2ef514f2 100644 --- a/tests/qasm3/resources/variables.py +++ b/tests/qasm3/resources/variables.py @@ -146,18 +146,6 @@ 8, "float[23] x;", ), - "unsupported_types": ( - """ - OPENQASM 3.0; - include "stdgates.inc"; - - angle x = 3.4; - """, - "Invalid initialization value for variable 'x'", - 5, - 8, - "angle x = 3.4;", - ), "imaginary_variable": ( """ OPENQASM 3.0; @@ -375,6 +363,8 @@ int ccf1 = float(runtime_u) * int(f1); uint ul1 = uint(float[64](int[16](f1))) * 2; const int un = -int(u1); + angle ang1 = angle(true); + angle ang2 = angle(false); """ ), "Bool_test": ( diff --git a/tests/qasm3/subroutines/test_subroutines.py b/tests/qasm3/subroutines/test_subroutines.py index c9e5d5e7..8eeb2825 100644 --- a/tests/qasm3/subroutines/test_subroutines.py +++ b/tests/qasm3/subroutines/test_subroutines.py @@ -19,10 +19,14 @@ import pytest -from pyqasm.entrypoint import loads +from pyqasm.entrypoint import dumps, loads from pyqasm.exceptions import ValidationError from tests.qasm3.resources.subroutines import SUBROUTINE_INCORRECT_TESTS -from tests.utils import check_single_qubit_gate_op, check_single_qubit_rotation_op +from tests.utils import ( + check_single_qubit_gate_op, + check_single_qubit_rotation_op, + check_unrolled_qasm, +) def test_function_declaration(): @@ -421,3 +425,332 @@ def test_incorrect_custom_ops(test_name, caplog): assert f"Error at line {line_num}, column {col_num}" in caplog.text assert err_line in caplog.text + + +def test_extern_function_call(): + """Test extern function call""" + qasm3_string = """ + OPENQASM 3.0; + include "stdgates.inc"; + float a = 1.0; + int b = 2; + extern func1(float, int) -> bit; + bit c = 2 * func1(a, b); + bit fc = -func1(a, b); + + bit[2] b1 = true; + angle ang1 = pi/2; + extern func2(bit[2], angle) -> complex; + const complex d = func2(b1, ang1); + const complex e = func2(b1, ang1) + 2.0; + const complex f = -func2(b1, ang1); + + duration t1 = 100ns; + bool b2 = true; + extern func3(duration, bool) -> int; + int dd; + dd = func3(t1, b2); + int ee; + ee = func3(t1, b2) + 2; + int ff; + ff = -func3(t1, b2); + int gg; + gg = func3(t1, b2) * func3(t1, b2); + + float[32] fa = 3.14; + float[32] fb = 2.71; + extern func4(float[32], float[32]) -> float[32]; + float[32] fc1 = func4(fa, fb); + + complex[float[64]] ca = 1.0 + 2.0im; + complex[float[64]] cb = 3.0 - 4.0im; + extern func5(complex[float[64]], complex[float[64]]) -> complex[float[64]]; + complex[float[64]] cc1 = func5(ca, cb); + + bit[4] bd = "0101"; + extern func6(bit[4]) -> bit[4]; + bit[4] be1 = func6(bd); + + angle[8] an = pi/4; + extern func7(angle[8]) -> angle[8]; + angle[8] af1 = func7(an); + + bool bl = false; + extern func8(bool) -> bool; + bool bf1 = func8(bl); + + int[24] ix = 42; + extern func9(int[24]) -> int[24]; + int[24] ig1 = func9(ix); + + float[64] fx = 2.718; + extern func10(float[64]) -> float[64]; + float[64] fg1 = func10(fx); + """ + + expected_qasm = """OPENQASM 3.0; + include "stdgates.inc"; + extern func1(float, int) -> bit; + bit[1] c = 2 * func1(1.0, 2); + bit[1] fc = -func1(1.0, 2); + bit[2] b1 = true; + extern func2(bit[2], angle) -> complex; + const complex d = func2(True, 1.5707963267948966); + const complex e = func2(True, 1.5707963267948966) + 2.0; + const complex f = -func2(True, 1.5707963267948966); + extern func3(duration, bool) -> int; + dd = func3(100.0ns, True); + ee = func3(100.0ns, True) + 2; + ff = -func3(100.0ns, True); + gg = func3(100.0ns, True) * func3(100.0ns, True); + extern func4(float[32], float[32]) -> float[32]; + float[32] fc1 = func4(3.14, 2.71); + extern func5(complex[float[64]], complex[float[64]]) -> complex[float[64]]; + complex[float[64]] cc1 = func5(1.0 + 2.0im, 3.0 - 4.0im); + bit[4] bd = "0101"; + extern func6(bit[4]) -> bit[4]; + bit[4] be1 = func6("0101"); + extern func7(angle[8]) -> angle[8]; + angle[8] af1 = func7(0.7853981633974483); + extern func8(bool) -> bool; + bool bf1 = func8(False); + extern func9(int[24]) -> int[24]; + int[24] ig1 = func9(42); + extern func10(float[64]) -> float[64]; + float[64] fg1 = func10(2.718); + """ + + extern_functions = { + "func1": (["float", "int"], "bit"), + "func2": (["bit[2]", "angle"], "complex"), + "func3": (["duration", "bool"], "int"), + "func4": (["float[32]", "float[32]"], "float[32]"), + "func5": (["complex[float[64]]", "complex[float[64]]"], "complex[float[64]]"), + "func6": (["bit[4]"], "bit[4]"), + "func7": (["angle[8]"], "angle[8]"), + "func8": (["bool"], "bool"), + "func9": (["int[24]"], "int[24]"), + "func10": (["float[64]"], "float[64]"), + } + + result = loads(qasm3_string, extern_functions=extern_functions) + result.validate() + result.unroll() + unrolled_qasm = dumps(result) + + check_unrolled_qasm(unrolled_qasm, expected_qasm) + + +@pytest.mark.parametrize( + "qasm_code,error_message,error_span", + [ + ( + """ + OPENQASM 3.0; + include "stdgates.inc"; + bit[4] bd = true; + float[64] fx = 2.718; + extern func6(bit[4]) -> bit[4]; + extern func10(float[64]) -> float[64]; + bit[4] be1 = func6(bd) * func10(fx); + """, + r"extern function return type mismatch in binary expression: BitType and FloatType", + r"Error at line 8, column 25", + ), + ( + """ + OPENQASM 3.0; + include "stdgates.inc"; + bit bd = true; + extern func6(bit[4]) -> bit[4]; + bit[4] be1 = func6(bd); + """, + r"Argument type mismatch in function 'func6', expected BitType[4] but got BitType[1]", + r"Error at line 6, column 25", + ), + ( + """ + OPENQASM 3.0; + include "stdgates.inc"; + bit bd = true; + extern func6(bit[4]) -> bit[4]; + bit[4] be1 = func6(fd); + """, + r"Undefined variable 'fd' used for function call 'func6'", + r"Error at line 6, column 25", + ), + ], +) # pylint: disable-next= too-many-arguments +def test_extern_function_call_error(qasm_code, error_message, error_span, caplog): + with pytest.raises(ValidationError) as excinfo: + with caplog.at_level("ERROR"): + loads( + qasm_code, + ).validate() + first = excinfo.value.__cause__ or excinfo.value.__context__ + assert first is not None, "Expected a chained ValidationError" + msg = str(first) + assert error_message in msg + assert error_span in caplog.text + + +@pytest.mark.parametrize( + "qasm_code,error_message,error_span", + [ + ( + """ + OPENQASM 3.0; + include "stdgates.inc"; + float fx = 2.718; + int ix = 42; + extern func1(float) -> bit; + bit be1 = func1(fx); + """, + r"Parameter count mismatch for 'extern' subroutine 'func1'. Expected 2 but got 1", + r"Error at line 6, column 12", + ), + ( + """ + OPENQASM 3.0; + include "stdgates.inc"; + float fx = 2.718; + int ix = 42; + extern func1(float[64], int[64]) -> bit; + bit be1 = func1(fx); + """, + r"Parameter type mismatch for 'extern' subroutine 'func1'." + r" Expected float but got float[64]", + r"Error at line 6, column 12", + ), + ( + """ + OPENQASM 3.0; + include "stdgates.inc"; + float fx = 2.718; + int ix = 42; + extern func1(float, int) -> bit[2]; + bit[2] be1 = func1(fx); + """, + r"Return type mismatch for 'extern' subroutine 'func1'. Expected bit but got bit[2]", + r"Error at line 6, column 12", + ), + ], +) # pylint: disable-next= too-many-arguments +def test_extern_function_dict_call_error(qasm_code, error_message, error_span, caplog): + with pytest.raises(ValidationError) as excinfo: + with caplog.at_level("ERROR"): + extern_functions = { + "func1": (["float", "int"], "bit"), + "func2": (["bit[2]", "angle"], "complex"), + "func3": (["uint"], "int"), + } + loads(qasm_code, extern_functions=extern_functions).validate() + msg = str(excinfo.value) + assert error_message in msg + assert error_span in caplog.text + + +@pytest.mark.parametrize( + "qasm_code,error_message,error_span", + [ + ( + """ + OPENQASM 3.0; + include "stdgates.inc"; + extern func1(float) -> bit; + bit be1 = func1(2); + """, + r"Invalid argument value for 'func1', expected 'FloatType' but got value = 2.", + r"Error at line 5, column 12", + ), + ( + """ + OPENQASM 3.0; + include "stdgates.inc"; + extern func2(uint) -> complex; + complex ce1 = func2(-22); + """, + r"Invalid argument value for 'func2', expected 'UintType' but got value = -22.", + r"Error at line 5, column 12", + ), + ( + """ + OPENQASM 3.0; + include "stdgates.inc"; + extern func3(duration) -> int; + int ie1 = func3(true); + """, + r"Invalid argument value for 'func3', expected 'DurationType' but got value = True.", + r"Error at line 5, column 12", + ), + ( + """ + OPENQASM 3.0; + include "stdgates.inc"; + extern func4(bit[4]) -> float[64]; + float[64] fe1 = func4(3.14); + """, + r"Invalid argument value for 'func4', expected 'BitType' but got value = 3.14.", + r"Error at line 5, column 28", + ), + ( + """ + OPENQASM 3.0; + include "stdgates.inc"; + extern func5(complex[float[64]]) -> complex[float[64]]; + complex[float[64]] ce1 = func5(42); + """, + r"Invalid argument value for 'func5', expected 'ComplexType' but got value = 42.", + r"Error at line 5, column 37", + ), + ( + """ + OPENQASM 3.0; + include "stdgates.inc"; + extern func6(angle[8]) -> angle[8]; + angle[8] ae1 = func6(100ns); + """, + r"Invalid argument value for 'func6', expected 'AngleType' but got value = 100.0.", + r"Error at line 5, column 27", + ), + ( + """ + OPENQASM 3.0; + include "stdgates.inc"; + extern func7(bool) -> bool; + bool be1 = func7(3.14); + """, + r"Invalid argument value for 'func7', expected 'BoolType' but got value = 3.14.", + r"Error at line 5, column 23", + ), + ( + """ + OPENQASM 3.0; + include "stdgates.inc"; + extern func8(int[24]) -> int[24]; + int[24] ie1 = func8(2.0); + """, + r"Invalid argument value for 'func8', expected 'IntType' but got value = 2.0.", + r"Error at line 5, column 12", + ), + ], +) # pylint: disable-next= too-many-arguments +def test_extern_function_value_error(qasm_code, error_message, error_span, caplog): + with pytest.raises(ValidationError) as excinfo: + with caplog.at_level("ERROR"): + extern_functions = { + "func1": (["float"], "bit"), + "func2": (["uint"], "complex"), + "func3": (["duration"], "int"), + "func4": (["bit[4]"], "float[64]"), + "func5": (["complex[float[64]]"], "complex[float[64]]"), + "func6": (["angle[8]"], "angle[8]"), + "func7": (["bool"], "bool"), + "func8": (["int[24]"], "int[24]"), + } + loads(qasm_code, extern_functions=extern_functions).validate() + first = excinfo.value.__cause__ or excinfo.value.__context__ + assert first is not None, "Expected a chained ValidationError" + msg = str(first) + assert error_message in msg + assert error_span in caplog.text diff --git a/tests/qasm3/test_expressions.py b/tests/qasm3/test_expressions.py index 9f023c5e..b07b7a63 100644 --- a/tests/qasm3/test_expressions.py +++ b/tests/qasm3/test_expressions.py @@ -75,14 +75,6 @@ def test_bit_in_expression(): def test_incorrect_expressions(caplog): - with pytest.raises(ValidationError, match=r"Invalid parameter .*"): - with caplog.at_level("ERROR"): - loads("OPENQASM 3; qubit q; rz(1 - 2 + 32im) q;").validate() - assert "Error at line 1, column 32" in caplog.text - assert "32.0im" in caplog.text - - caplog.clear() - with pytest.raises(ValidationError, match=r"Invalid parameter .*"): with caplog.at_level("ERROR"): loads("OPENQASM 3; qubit q; rx(~1.3) q;").validate()