diff --git a/.github/workflows/tests-ast.yml b/.github/workflows/tests-ast.yml index 090d248..67c5878 100644 --- a/.github/workflows/tests-ast.yml +++ b/.github/workflows/tests-ast.yml @@ -14,8 +14,8 @@ jobs: fail-fast: false matrix: # Just using minimum and maximum to avoid exploding the matrix. - python-version: ['3.7', '3.10'] - antlr-version: ['4.7', '4.11'] + python-version: ['3.7', '3.12'] + antlr-version: ['4.7', '4.13'] defaults: run: working-directory: source/openpulse diff --git a/source/grammar/openpulseParser.g4 b/source/grammar/openpulseParser.g4 index 5b465b2..16c0fcd 100644 --- a/source/grammar/openpulseParser.g4 +++ b/source/grammar/openpulseParser.g4 @@ -38,6 +38,7 @@ openpulseStatement: | quantumDeclarationStatement | resetStatement | returnStatement + | switchStatement | whileStatement ) ; diff --git a/source/grammar/qasm3Lexer.g4 b/source/grammar/qasm3Lexer.g4 index be3b5a9..7daaeac 100644 --- a/source/grammar/qasm3Lexer.g4 +++ b/source/grammar/qasm3Lexer.g4 @@ -14,8 +14,8 @@ lexer grammar qasm3Lexer; /* Language keywords. */ OPENQASM: 'OPENQASM' -> pushMode(VERSION_IDENTIFIER); -INCLUDE: 'include'; -DEFCALGRAMMAR: 'defcalgrammar'; +INCLUDE: 'include' -> pushMode(ARBITRARY_STRING); +DEFCALGRAMMAR: 'defcalgrammar' -> pushMode(ARBITRARY_STRING); DEF: 'def'; CAL: 'cal' -> mode(CAL_PRELUDE); DEFCAL: 'defcal' -> mode(DEFCAL_PRELUDE); @@ -33,6 +33,9 @@ RETURN: 'return'; FOR: 'for'; WHILE: 'while'; IN: 'in'; +SWITCH: 'switch'; +CASE: 'case'; +DEFAULT: 'default'; PRAGMA: '#'? 'pragma' -> pushMode(EAT_TO_LINE_END); AnnotationKeyword: '@' Identifier -> pushMode(EAT_TO_LINE_END); @@ -123,7 +126,7 @@ ComparisonOperator: '>' | '<' | '>=' | '<='; BitshiftOperator: '>>' | '<<'; IMAG: 'im'; -ImaginaryLiteral: (DecimalIntegerLiteral | FloatLiteral) ' '* IMAG; +ImaginaryLiteral: (DecimalIntegerLiteral | FloatLiteral) [ \t]* IMAG; BinaryIntegerLiteral: ('0b' | '0B') ([01] '_'?)* [01]; OctalIntegerLiteral: '0o' ([0-7] '_'?)* [0-7]; @@ -149,15 +152,9 @@ FloatLiteral: fragment TimeUnit: 'dt' | 'ns' | 'us' | 'µs' | 'ms' | 's'; // represents explicit time value in SI or backend units -TimingLiteral: (DecimalIntegerLiteral | FloatLiteral) TimeUnit; - +TimingLiteral: (DecimalIntegerLiteral | FloatLiteral) [ \t]* TimeUnit; BitstringLiteral: '"' ([01] '_'?)* [01] '"'; -// allow ``"str"`` and ``'str'`` -StringLiteral - : '"' ~["\r\t\n]+? '"' - | '\'' ~['\r\t\n]+? '\'' - ; // Ignore whitespace between tokens, and define C++-style comments. Whitespace: [ \t]+ -> skip ; @@ -173,6 +170,13 @@ mode VERSION_IDENTIFIER; VERSION_IDENTIFER_WHITESPACE: [ \t\r\n]+ -> skip; VersionSpecifier: [0-9]+ ('.' [0-9]+)? -> popMode; +// An include statement's path or defcalgrammar target is potentially ambiguous +// with `BitstringLiteral`. +mode ARBITRARY_STRING; + ARBITRARY_STRING_WHITESPACE: [ \t\r\n]+ -> skip; + // allow ``"str"`` and ``'str'``; + StringLiteral: ('"' ~["\r\t\n]+? '"' | '\'' ~['\r\t\n]+? '\'') -> popMode; + // A different lexer mode to swap to when we need handle tokens on a line basis // rather than the default arbitrary-whitespace-based tokenisation. This is diff --git a/source/grammar/qasm3Parser.g4 b/source/grammar/qasm3Parser.g4 index 1fec69b..439845d 100644 --- a/source/grammar/qasm3Parser.g4 +++ b/source/grammar/qasm3Parser.g4 @@ -4,7 +4,7 @@ options { tokenVocab = qasm3Lexer; } -program: version? statement* EOF; +program: version? statementOrScope* EOF; version: OPENQASM VersionSpecifier SEMICOLON; // A statement is any valid single statement of an OpenQASM 3 program, with the @@ -43,11 +43,12 @@ statement: | quantumDeclarationStatement | resetStatement | returnStatement + | switchStatement | whileStatement ) ; annotation: AnnotationKeyword RemainingLineContent?; -scope: LBRACE statement* RBRACE; +scope: LBRACE statementOrScope* RBRACE; pragma: PRAGMA RemainingLineContent; statementOrScope: statement | scope; @@ -67,6 +68,11 @@ forStatement: FOR scalarType Identifier IN (setExpression | LBRACKET rangeExpres ifStatement: IF LPAREN expression RPAREN if_body=statementOrScope (ELSE else_body=statementOrScope)?; returnStatement: RETURN (expression | measureExpression)? SEMICOLON; whileStatement: WHILE LPAREN expression RPAREN body=statementOrScope; +switchStatement: SWITCH LPAREN expression RPAREN LBRACE switchCaseItem* RBRACE; +switchCaseItem: + CASE expressionList scope + | DEFAULT scope +; // Quantum directive statements. barrierStatement: BARRIER gateOperandList? SEMICOLON; diff --git a/source/openpulse/ANTLR_VERSIONS.txt b/source/openpulse/ANTLR_VERSIONS.txt index cb19581..a2dee2f 100644 --- a/source/openpulse/ANTLR_VERSIONS.txt +++ b/source/openpulse/ANTLR_VERSIONS.txt @@ -3,3 +3,5 @@ 4.9.2 4.10.1 4.11.1 +4.12.0 +4.13.0 diff --git a/source/openpulse/openpulse/__init__.py b/source/openpulse/openpulse/__init__.py index 31dd615..1f7420e 100644 --- a/source/openpulse/openpulse/__init__.py +++ b/source/openpulse/openpulse/__init__.py @@ -14,7 +14,14 @@ the :obj:`~parser.parse` function. """ +__all__ = [ + "ast", + "parser", + "spec", + "parse", +] + __version__ = "0.5.0" -from . import ast, parser +from . import ast, parser, spec from .parser import parse diff --git a/source/openpulse/openpulse/parser.py b/source/openpulse/openpulse/parser.py index 5a4e65e..79c704f 100644 --- a/source/openpulse/openpulse/parser.py +++ b/source/openpulse/openpulse/parser.py @@ -24,7 +24,7 @@ ] from contextlib import contextmanager -from typing import List +from typing import List, Union try: from antlr4 import CommonTokenStream, InputStream, ParserRuleContext @@ -129,6 +129,13 @@ def _in_loop(self): for scope in reversed(self._current_context()) ) + def _parse_scoped_statements( + self, node: Union[qasm3Parser.ScopeContext, qasm3Parser.StatementOrScopeContext] + ) -> List[ast.Statement]: + with self._push_scope(node.parentCtx): + block = self.visit(node) + return block.statements if isinstance(block, ast.CompoundStatement) else [block] + @span def _visitPulseType(self, ctx: openpulseParser.ScalarTypeContext): if ctx.WAVEFORM(): @@ -305,6 +312,7 @@ def visitOpenpulseStatement(self, ctx: openpulseParser.OpenpulseStatementContext OpenPulseNodeVisitor.visitStatementOrScope = QASMNodeVisitor.visitStatementOrScope OpenPulseNodeVisitor.visitUnaryExpression = QASMNodeVisitor.visitUnaryExpression OpenPulseNodeVisitor.visitWhileStatement = QASMNodeVisitor.visitWhileStatement +OpenPulseNodeVisitor.visitSwitchStatement = QASMNodeVisitor.visitSwitchStatement class CalParser(QASMVisitor[None]): diff --git a/source/openpulse/openpulse/spec.py b/source/openpulse/openpulse/spec.py new file mode 100644 index 0000000..8052e62 --- /dev/null +++ b/source/openpulse/openpulse/spec.py @@ -0,0 +1,20 @@ +""" +===================================================== +Supported Specification Metadata (``openpulse.spec``) +===================================================== + +.. currentmodule:: openpulse.spec + +Metadata on the specifications supported by this package. + +.. autodata:: supported_versions +""" + +__all__ = ["supported_versions"] + +#: A list of specification versions supported by this +#: package. Each version is a :code:`str`, e.g. :code:`'3.0'`. +supported_versions = [ + "3.0", + "3.1", +] diff --git a/source/openpulse/requirements.txt b/source/openpulse/requirements.txt index ff8208b..0a07eb8 100644 --- a/source/openpulse/requirements.txt +++ b/source/openpulse/requirements.txt @@ -1,2 +1,2 @@ antlr4-python3-runtime -openqasm3>=0.5,<1.0 +openqasm3>=1.0.0,<2.0 diff --git a/source/openpulse/setup.cfg b/source/openpulse/setup.cfg index 6d1435e..8ac00e1 100644 --- a/source/openpulse/setup.cfg +++ b/source/openpulse/setup.cfg @@ -27,7 +27,7 @@ include_package_data = True install_requires = antlr4-python3-runtime # __ANTLR_VERSIONS__ importlib_metadata; python_version<'3.10' - openqasm3[parser]>=0.5,<1.0 + openqasm3[parser]>=1.0.0,<2.0 [options.packages.find] exclude = tests* diff --git a/source/openpulse/tests/test_openpulse_parser.py b/source/openpulse/tests/test_openpulse_parser.py index b3efd3d..a80a03a 100644 --- a/source/openpulse/tests/test_openpulse_parser.py +++ b/source/openpulse/tests/test_openpulse_parser.py @@ -13,6 +13,7 @@ ClassicalAssignment, ClassicalDeclaration, ComplexType, + CompoundStatement, DurationType, ExpressionStatement, ExternArgument, @@ -32,6 +33,7 @@ QuantumBarrier, RangeDefinition, ReturnStatement, + SwitchStatement, UnaryExpression, UnaryOperator, WaveformType, @@ -329,6 +331,43 @@ def test_pragma_in_cal_block(): assert _remove_spans(program) == expected +def test_switch_in_cal_block(): + p = """ + cal { + switch (x) { + case 1, 2 {} + case 3 {} + default {} + } + } + """ + + program = parse(p) + expected = Program( + statements=[ + CalibrationStatement( + body=[ + SwitchStatement( + target=Identifier(name="x"), + cases=[ + ( + [IntegerLiteral(value=1), IntegerLiteral(value=2)], + CompoundStatement(statements=[]), + ), + ( + [IntegerLiteral(value=3)], + CompoundStatement(statements=[]), + ), + ], + default=CompoundStatement(statements=[]), + ) + ] + ) + ] + ) + assert _remove_spans(program) == expected + + @pytest.mark.parametrize( "p", [ diff --git a/source/openpulse/tests/test_openpulse_printer.py b/source/openpulse/tests/test_openpulse_printer.py index 3ef4c80..0fcc7be 100644 --- a/source/openpulse/tests/test_openpulse_printer.py +++ b/source/openpulse/tests/test_openpulse_printer.py @@ -30,6 +30,18 @@ } """, """ + cal { + switch (x) { + case 1, 2 { + } + case 3 { + } + default { + } + } + } + """, + """ defcal rz(angle[20] theta) q { return shift_phase(drive(q), -theta); } @@ -44,6 +56,18 @@ } """, """ + defcal x q { + switch (x) { + case 1, 2 { + } + case 3 { + } + default { + } + } + } + """, + """ if (i == 0) { do_if_zero(); } @@ -80,6 +104,16 @@ def my_subroutine(int[32] i, qubit q) -> bit { return measure q; } """, + """ + switch (x) { + case 1, 2 { + } + case 3 { + } + default { + } + } + """, "extern my_extern(float[32], duration) -> duration;", "H $0;", "v = (x + y) * z;", # test add parens when needed diff --git a/source/openpulse/tests/test_spec.py b/source/openpulse/tests/test_spec.py new file mode 100644 index 0000000..f2a6cc9 --- /dev/null +++ b/source/openpulse/tests/test_spec.py @@ -0,0 +1,17 @@ +import re + + +from openpulse.spec import supported_versions + + +class TestSupportedVersions: + SPEC_VERSION_RE = r"^(?P[0-9]+)\.(?P[0-9]+)$" + + def test_supported_versions(self): + assert supported_versions == ["3.0", "3.1"] + + def test_types(self): + assert all(type(x) is str for x in supported_versions) # noqa + + def test_version_formats(self): + assert all(re.match(self.SPEC_VERSION_RE, x) for x in supported_versions)