diff --git a/src/QsCompiler/CommandLineTool/Commands/Diagnose.cs b/src/QsCompiler/CommandLineTool/Commands/Diagnose.cs index 6fc6433f27..8692affa79 100644 --- a/src/QsCompiler/CommandLineTool/Commands/Diagnose.cs +++ b/src/QsCompiler/CommandLineTool/Commands/Diagnose.cs @@ -173,13 +173,13 @@ private static void PrintGeneratedQs(IEnumerable evaluatedTree, Com if (Options.IsCodeSnippet(file)) { var subtree = evaluatedTree.Select(ns => FilterBySourceFile.Apply(ns, file)).Where(ns => ns.Elements.Any()); - var code = new string[] { "" }.Concat(StripSnippetWrapping(subtree).Select(FormatCompilation.FormatStatement)); + var code = new string[] { "" }.Concat(StripSnippetWrapping(subtree).Select(SyntaxTreeToQsharp.Default.ToCode)); logger.Log(InformationCode.FormattedQsCode, Enumerable.Empty(), messageParam: code.ToArray()); } else { var imports = evaluatedTree.ToImmutableDictionary(ns => ns.Name, ns => compilation.OpenDirectives(file, ns.Name).ToImmutableArray()); - SyntaxTreeToQs.Apply(out List, string>> generated, evaluatedTree, (file, imports)); + SyntaxTreeToQsharp.Apply(out List, string>> generated, evaluatedTree, (file, imports)); var code = new string[] { "" }.Concat(generated.Single().Values.Select(nsCode => $"{nsCode}{Environment.NewLine}")); logger.Log(InformationCode.FormattedQsCode, Enumerable.Empty(), file.Value, messageParam: code.ToArray()); }; diff --git a/src/QsCompiler/CommandLineTool/Commands/Format.cs b/src/QsCompiler/CommandLineTool/Commands/Format.cs index 5d73c53c2c..59494a4097 100644 --- a/src/QsCompiler/CommandLineTool/Commands/Format.cs +++ b/src/QsCompiler/CommandLineTool/Commands/Format.cs @@ -60,18 +60,6 @@ public static string UpdateArrayLiterals(string fileContent) } - /// - /// Returns formatted Q# code for the given statement. - /// Throws an ArgumentNullException if the given statement is null. - /// - internal static string FormatStatement(QsStatement statement) - { - if (statement == null) throw new ArgumentNullException(nameof(statement)); - var ToCode = new ScopeToQs(); - ToCode.onStatement(statement); - return ToCode.Output; - } - /// /// Generates formatted Q# code based on the part of the syntax tree that corresponds to each file in the given compilation. /// If the id of a file is consistent with the one assigned to a code snippet, @@ -85,13 +73,13 @@ private static IEnumerable GenerateQsCode(Compilation compilation, NonNu if (Options.IsCodeSnippet(file)) { var subtree = compilation.SyntaxTree.Values.Select(ns => FilterBySourceFile.Apply(ns, file)).Where(ns => ns.Elements.Any()); - return DiagnoseCompilation.StripSnippetWrapping(subtree).Select(FormatStatement); + return DiagnoseCompilation.StripSnippetWrapping(subtree).Select(SyntaxTreeToQsharp.Default.ToCode); } else { var imports = compilation.SyntaxTree.Values .ToImmutableDictionary(ns => ns.Name, ns => compilation.OpenDirectives(file, ns.Name).ToImmutableArray()); - var success = SyntaxTreeToQs.Apply(out List, string>> generated, compilation.SyntaxTree.Values, (file, imports)); + var success = SyntaxTreeToQsharp.Apply(out List, string>> generated, compilation.SyntaxTree.Values, (file, imports)); if (!success) logger?.Log(WarningCode.UnresolvedItemsInGeneratedQs, Enumerable.Empty(), file.Value); return generated.Single().Select(entry => diff --git a/src/QsCompiler/CompilationManager/EditorSupport/CodeActions.cs b/src/QsCompiler/CompilationManager/EditorSupport/CodeActions.cs index a605170f0e..aee6fbebf4 100644 --- a/src/QsCompiler/CompilationManager/EditorSupport/CodeActions.cs +++ b/src/QsCompiler/CompilationManager/EditorSupport/CodeActions.cs @@ -218,11 +218,10 @@ private static IEnumerable OpenDirectiveSuggestions(this FileContentMa // update deprecated operation characteristics syntax - var typeToQs = new ExpressionTypeToQs(new ExpressionToQs()); string CharacteristicsAnnotation(Characteristics c) { - typeToQs.onCharacteristicsExpression(SymbolResolution.ResolveCharacteristics(c)); - return $"{Keywords.qsCharacteristics.id} {typeToQs.Output}"; + var charEx = SyntaxTreeToQsharp.CharacteristicsExpression(SymbolResolution.ResolveCharacteristics(c)); + return charEx == null ? "" : $"{Keywords.qsCharacteristics.id} {charEx}"; } var suggestionsForOpCharacteristics = deprecatedOpCharacteristics.SelectMany(d => diff --git a/src/QsCompiler/CompilationManager/EditorSupport/SymbolInformation.cs b/src/QsCompiler/CompilationManager/EditorSupport/SymbolInformation.cs index ecb0387970..501d7d73c7 100644 --- a/src/QsCompiler/CompilationManager/EditorSupport/SymbolInformation.cs +++ b/src/QsCompiler/CompilationManager/EditorSupport/SymbolInformation.cs @@ -213,7 +213,7 @@ internal static bool TryGetReferences( .Where(spec => spec.SourceFile.Value == file.FileName.Value) .SelectMany(spec => spec.Implementation is SpecializationImplementation.Provided impl && spec.Location.IsValue - ? IdentifierLocation.Find(definition.Item.Item1, impl.Item2, file.FileName, spec.Location.Item.Offset) + ? IdentifierReferences.Find(definition.Item.Item1, impl.Item2, file.FileName, spec.Location.Item.Offset) : ImmutableArray.Empty) .Distinct().Select(AsLocation); } @@ -223,7 +223,7 @@ spec.Implementation is SpecializationImplementation.Provided impl && spec.Locati var statements = implementation.StatementsAfterDeclaration(defStart.Subtract(specPos)); var scope = new QsScope(statements.ToImmutableArray(), locals); var rootOffset = DiagnosticTools.AsTuple(specPos); - referenceLocations = IdentifierLocation.Find(definition.Item.Item1, scope, file.FileName, rootOffset).Distinct().Select(AsLocation); + referenceLocations = IdentifierReferences.Find(definition.Item.Item1, scope, file.FileName, rootOffset).Distinct().Select(AsLocation); } declarationLocation = AsLocation(file.FileName, definition.Item.Item2, defRange); return true; diff --git a/src/QsCompiler/Compiler/RewriteSteps/ClassicallyControlled.cs b/src/QsCompiler/Compiler/RewriteSteps/ClassicallyControlled.cs index dd7088f28f..01410bf9ff 100644 --- a/src/QsCompiler/Compiler/RewriteSteps/ClassicallyControlled.cs +++ b/src/QsCompiler/Compiler/RewriteSteps/ClassicallyControlled.cs @@ -5,7 +5,7 @@ using System.Linq; using Microsoft.Quantum.QsCompiler.DataTypes; using Microsoft.Quantum.QsCompiler.SyntaxTree; -using Microsoft.Quantum.QsCompiler.Transformations.ClassicallyControlledTransformation; +using Microsoft.Quantum.QsCompiler.Transformations.ClassicallyControlled; namespace Microsoft.Quantum.QsCompiler.BuiltInRewriteSteps @@ -28,7 +28,7 @@ public ClassicallyControlled() public bool Transformation(QsCompilation compilation, out QsCompilation transformed) { - transformed = ClassicallyControlledTransformation.Apply(compilation); + transformed = ReplaceClassicalControl.Apply(compilation); return true; } diff --git a/src/QsCompiler/Compiler/RewriteSteps/IntrinsicResolution.cs b/src/QsCompiler/Compiler/RewriteSteps/IntrinsicResolution.cs index 01acc81c90..724efbb601 100644 --- a/src/QsCompiler/Compiler/RewriteSteps/IntrinsicResolution.cs +++ b/src/QsCompiler/Compiler/RewriteSteps/IntrinsicResolution.cs @@ -3,7 +3,7 @@ using System.Collections.Generic; using Microsoft.Quantum.QsCompiler.SyntaxTree; -using Microsoft.Quantum.QsCompiler.Transformations.IntrinsicResolutionTransformation; +using Microsoft.Quantum.QsCompiler.Transformations.IntrinsicResolution; namespace Microsoft.Quantum.QsCompiler.BuiltInRewriteSteps @@ -35,7 +35,7 @@ public IntrinsicResolution(QsCompilation environment) public bool Transformation(QsCompilation compilation, out QsCompilation transformed) { - transformed = IntrinsicResolutionTransformation.Apply(this.Environment, compilation); + transformed = ReplaceWithTargetIntrinsics.Apply(this.Environment, compilation); return true; } diff --git a/src/QsCompiler/Compiler/RewriteSteps/Monomorphization.cs b/src/QsCompiler/Compiler/RewriteSteps/Monomorphization.cs index d44fe3afef..e510b22b0e 100644 --- a/src/QsCompiler/Compiler/RewriteSteps/Monomorphization.cs +++ b/src/QsCompiler/Compiler/RewriteSteps/Monomorphization.cs @@ -5,7 +5,7 @@ using System.Linq; using Microsoft.Quantum.QsCompiler.SyntaxTree; using Microsoft.Quantum.QsCompiler.Transformations.Monomorphization; -using Microsoft.Quantum.QsCompiler.Transformations.MonomorphizationValidation; +using Microsoft.Quantum.QsCompiler.Transformations.Monomorphization.Validation; namespace Microsoft.Quantum.QsCompiler.BuiltInRewriteSteps @@ -33,7 +33,7 @@ public Monomorphization() public bool Transformation(QsCompilation compilation, out QsCompilation transformed) { - transformed = MonomorphizationTransformation.Apply(compilation); + transformed = Monomorphize.Apply(compilation); return true; } @@ -42,7 +42,7 @@ public bool PreconditionVerification(QsCompilation compilation) => public bool PostconditionVerification(QsCompilation compilation) { - try { MonomorphizationValidationTransformation.Apply(compilation); } + try { ValidateMonomorphization.Apply(compilation); } catch { return false; } return true; } diff --git a/src/QsCompiler/Core/Core.fsproj b/src/QsCompiler/Core/Core.fsproj index 9acdc86e4d..20bd2717c1 100644 --- a/src/QsCompiler/Core/Core.fsproj +++ b/src/QsCompiler/Core/Core.fsproj @@ -13,13 +13,13 @@ DelaySign.fs - + + + - - - - + + diff --git a/src/QsCompiler/Core/DeclarationHeaders.fs b/src/QsCompiler/Core/DeclarationHeaders.fs index 7ea14d39bc..ee46da56c0 100644 --- a/src/QsCompiler/Core/DeclarationHeaders.fs +++ b/src/QsCompiler/Core/DeclarationHeaders.fs @@ -162,7 +162,7 @@ type CallableDeclarationHeader = { static member FromJson json = let info = {IsMutable = false; HasLocalQuantumDependency = false} let rec setInferredInfo = function // no need to raise an error if anything needs to be set - the info above is always correct - | QsTuple ts -> (ts |> Seq.map setInferredInfo).ToImmutableArray() |> QsTuple + | QsTuple ts -> ts |> Seq.map setInferredInfo |> ImmutableArray.CreateRange |> QsTuple | QsTupleItem (decl : LocalVariableDeclaration<_>) -> QsTupleItem {decl with InferredInformation = info} // we need to make sure that all fields that could possibly be null after deserializing // due to changes of fields over time are initialized to a proper value diff --git a/src/QsCompiler/Core/ExpressionTransformation.fs b/src/QsCompiler/Core/ExpressionTransformation.fs index b51723fcf4..d4bcb1ad9b 100644 --- a/src/QsCompiler/Core/ExpressionTransformation.fs +++ b/src/QsCompiler/Core/ExpressionTransformation.fs @@ -3,377 +3,393 @@ namespace Microsoft.Quantum.QsCompiler.Transformations.Core +open System +open System.Collections.Generic open System.Collections.Immutable open System.Numerics open Microsoft.Quantum.QsCompiler.DataTypes open Microsoft.Quantum.QsCompiler.SyntaxExtensions open Microsoft.Quantum.QsCompiler.SyntaxTokens open Microsoft.Quantum.QsCompiler.SyntaxTree +open Microsoft.Quantum.QsCompiler.Transformations.Core.Utils -type private ExpressionKind = QsExpressionKind -type private ExpressionType = QsTypeKind +type private ExpressionKind = + QsExpressionKind -/// Convention: -/// All methods starting with "on" implement the transformation for an expression of a certain kind. -/// All methods starting with "before" group a set of statements, and are called before applying the transformation -/// even if the corresponding transformation routine (starting with "on") is overridden. -[] -type ExpressionKindTransformation(?enable) = - let enable = defaultArg enable true +type ExpressionKindTransformationBase internal (options : TransformationOptions, _internal_) = + + let missingTransformation name _ = new InvalidOperationException(sprintf "No %s transformation has been specified." name) |> raise + let Node = if options.Rebuild then Fold else Walk - abstract member ExpressionTransformation : TypedExpression -> TypedExpression - abstract member TypeTransformation : ResolvedType -> ResolvedType + member val internal TypeTransformationHandle = missingTransformation "type" with get, set + member val internal ExpressionTransformationHandle = missingTransformation "expression" with get, set - abstract member beforeCallLike : TypedExpression * TypedExpression -> TypedExpression * TypedExpression - default this.beforeCallLike (method, arg) = (method, arg) + member this.Types = this.TypeTransformationHandle() + member this.Expressions = this.ExpressionTransformationHandle() - abstract member beforeFunctorApplication : TypedExpression -> TypedExpression - default this.beforeFunctorApplication ex = ex + new (expressionTransformation : unit -> ExpressionTransformationBase, typeTransformation : unit -> TypeTransformationBase, options) as this = + new ExpressionKindTransformationBase(options, "_internal_") then + this.TypeTransformationHandle <- typeTransformation + this.ExpressionTransformationHandle <- expressionTransformation - abstract member beforeModifierApplication : TypedExpression -> TypedExpression - default this.beforeModifierApplication ex = ex + new (options : TransformationOptions) as this = + new ExpressionKindTransformationBase(options, "_internal_") then + let typeTransformation = new TypeTransformationBase(options) + let expressionTransformation = new ExpressionTransformationBase((fun _ -> this), (fun _ -> this.Types), options) + this.TypeTransformationHandle <- fun _ -> typeTransformation + this.ExpressionTransformationHandle <- fun _ -> expressionTransformation - abstract member beforeBinaryOperatorExpression : TypedExpression * TypedExpression -> TypedExpression * TypedExpression - default this.beforeBinaryOperatorExpression (lhs, rhs) = (lhs, rhs) + new (expressionTransformation : unit -> ExpressionTransformationBase, typeTransformation : unit -> TypeTransformationBase) = + new ExpressionKindTransformationBase(expressionTransformation, typeTransformation, TransformationOptions.Default) - abstract member beforeUnaryOperatorExpression : TypedExpression -> TypedExpression - default this.beforeUnaryOperatorExpression ex = ex + new () = new ExpressionKindTransformationBase (TransformationOptions.Default) - abstract member onIdentifier : Identifier * QsNullable> -> ExpressionKind - default this.onIdentifier (sym, tArgs) = Identifier (sym, tArgs |> QsNullable<_>.Map (fun ts -> (ts |> Seq.map this.TypeTransformation).ToImmutableArray())) + // nodes containing subexpressions or subtypes - abstract member onOperationCall : TypedExpression * TypedExpression -> ExpressionKind - default this.onOperationCall (method, arg) = CallLikeExpression (this.ExpressionTransformation method, this.ExpressionTransformation arg) + abstract member OnIdentifier : Identifier * QsNullable> -> ExpressionKind + default this.OnIdentifier (sym, tArgs) = + let tArgs = tArgs |> QsNullable<_>.Map (fun ts -> ts |> Seq.map this.Types.OnType |> ImmutableArray.CreateRange) + Identifier |> Node.BuildOr InvalidExpr (sym, tArgs) - abstract member onFunctionCall : TypedExpression * TypedExpression -> ExpressionKind - default this.onFunctionCall (method, arg) = CallLikeExpression (this.ExpressionTransformation method, this.ExpressionTransformation arg) + abstract member OnOperationCall : TypedExpression * TypedExpression -> ExpressionKind + default this.OnOperationCall (method, arg) = + let method, arg = this.Expressions.OnTypedExpression method, this.Expressions.OnTypedExpression arg + CallLikeExpression |> Node.BuildOr InvalidExpr (method, arg) - abstract member onPartialApplication : TypedExpression * TypedExpression -> ExpressionKind - default this.onPartialApplication (method, arg) = CallLikeExpression (this.ExpressionTransformation method, this.ExpressionTransformation arg) + abstract member OnFunctionCall : TypedExpression * TypedExpression -> ExpressionKind + default this.OnFunctionCall (method, arg) = + let method, arg = this.Expressions.OnTypedExpression method, this.Expressions.OnTypedExpression arg + CallLikeExpression |> Node.BuildOr InvalidExpr (method, arg) - abstract member onAdjointApplication : TypedExpression -> ExpressionKind - default this.onAdjointApplication ex = AdjointApplication (this.ExpressionTransformation ex) + abstract member OnPartialApplication : TypedExpression * TypedExpression -> ExpressionKind + default this.OnPartialApplication (method, arg) = + let method, arg = this.Expressions.OnTypedExpression method, this.Expressions.OnTypedExpression arg + CallLikeExpression |> Node.BuildOr InvalidExpr (method, arg) - abstract member onControlledApplication : TypedExpression -> ExpressionKind - default this.onControlledApplication ex = ControlledApplication (this.ExpressionTransformation ex) - - abstract member onUnwrapApplication : TypedExpression -> ExpressionKind - default this.onUnwrapApplication ex = UnwrapApplication (this.ExpressionTransformation ex) - - abstract member onUnitValue : unit -> ExpressionKind - default this.onUnitValue () = ExpressionKind.UnitValue - - abstract member onMissingExpression : unit -> ExpressionKind - default this.onMissingExpression () = MissingExpr - - abstract member onInvalidExpression : unit -> ExpressionKind - default this.onInvalidExpression () = InvalidExpr - - abstract member onValueTuple : ImmutableArray -> ExpressionKind - default this.onValueTuple vs = ValueTuple ((vs |> Seq.map this.ExpressionTransformation).ToImmutableArray()) - - abstract member onArrayItem : TypedExpression * TypedExpression -> ExpressionKind - default this.onArrayItem (arr, idx) = ArrayItem (this.ExpressionTransformation arr, this.ExpressionTransformation idx) - - abstract member onNamedItem : TypedExpression * Identifier -> ExpressionKind - default this.onNamedItem (ex, acc) = NamedItem (this.ExpressionTransformation ex, acc) - - abstract member onValueArray : ImmutableArray -> ExpressionKind - default this.onValueArray vs = ValueArray ((vs |> Seq.map this.ExpressionTransformation).ToImmutableArray()) - - abstract member onNewArray : ResolvedType * TypedExpression -> ExpressionKind - default this.onNewArray (bt, idx) = NewArray (this.TypeTransformation bt, this.ExpressionTransformation idx) - - abstract member onIntLiteral : int64 -> ExpressionKind - default this.onIntLiteral i = IntLiteral i - - abstract member onBigIntLiteral : BigInteger -> ExpressionKind - default this.onBigIntLiteral b = BigIntLiteral b - - abstract member onDoubleLiteral : double -> ExpressionKind - default this.onDoubleLiteral d = DoubleLiteral d - - abstract member onBoolLiteral : bool -> ExpressionKind - default this.onBoolLiteral b = BoolLiteral b - - abstract member onResultLiteral : QsResult -> ExpressionKind - default this.onResultLiteral r = ResultLiteral r - - abstract member onPauliLiteral : QsPauli -> ExpressionKind - default this.onPauliLiteral p = PauliLiteral p - - abstract member onStringLiteral : NonNullable * ImmutableArray -> ExpressionKind - default this.onStringLiteral (s, exs) = StringLiteral (s, (exs |> Seq.map this.ExpressionTransformation).ToImmutableArray()) - - abstract member onRangeLiteral : TypedExpression * TypedExpression -> ExpressionKind - default this.onRangeLiteral (lhs, rhs) = RangeLiteral (this.ExpressionTransformation lhs, this.ExpressionTransformation rhs) - - abstract member onCopyAndUpdateExpression : TypedExpression * TypedExpression * TypedExpression -> ExpressionKind - default this.onCopyAndUpdateExpression (lhs, accEx, rhs) = CopyAndUpdate (this.ExpressionTransformation lhs, this.ExpressionTransformation accEx, this.ExpressionTransformation rhs) - - abstract member onConditionalExpression : TypedExpression * TypedExpression * TypedExpression -> ExpressionKind - default this.onConditionalExpression (cond, ifTrue, ifFalse) = CONDITIONAL (this.ExpressionTransformation cond, this.ExpressionTransformation ifTrue, this.ExpressionTransformation ifFalse) - - abstract member onEquality : TypedExpression * TypedExpression -> ExpressionKind - default this.onEquality (lhs, rhs) = EQ (this.ExpressionTransformation lhs, this.ExpressionTransformation rhs) - - abstract member onInequality : TypedExpression * TypedExpression -> ExpressionKind - default this.onInequality (lhs, rhs) = NEQ (this.ExpressionTransformation lhs, this.ExpressionTransformation rhs) - - abstract member onLessThan : TypedExpression * TypedExpression -> ExpressionKind - default this.onLessThan (lhs, rhs) = LT (this.ExpressionTransformation lhs, this.ExpressionTransformation rhs) - - abstract member onLessThanOrEqual : TypedExpression * TypedExpression -> ExpressionKind - default this.onLessThanOrEqual (lhs, rhs) = LTE (this.ExpressionTransformation lhs, this.ExpressionTransformation rhs) - - abstract member onGreaterThan : TypedExpression * TypedExpression -> ExpressionKind - default this.onGreaterThan (lhs, rhs) = GT (this.ExpressionTransformation lhs, this.ExpressionTransformation rhs) - - abstract member onGreaterThanOrEqual : TypedExpression * TypedExpression -> ExpressionKind - default this.onGreaterThanOrEqual (lhs, rhs) = GTE (this.ExpressionTransformation lhs, this.ExpressionTransformation rhs) - - abstract member onLogicalAnd : TypedExpression * TypedExpression -> ExpressionKind - default this.onLogicalAnd (lhs, rhs) = AND (this.ExpressionTransformation lhs, this.ExpressionTransformation rhs) - - abstract member onLogicalOr : TypedExpression * TypedExpression -> ExpressionKind - default this.onLogicalOr (lhs, rhs) = OR (this.ExpressionTransformation lhs, this.ExpressionTransformation rhs) - - abstract member onAddition : TypedExpression * TypedExpression -> ExpressionKind - default this.onAddition (lhs, rhs) = ADD (this.ExpressionTransformation lhs, this.ExpressionTransformation rhs) - - abstract member onSubtraction : TypedExpression * TypedExpression -> ExpressionKind - default this.onSubtraction (lhs, rhs) = SUB (this.ExpressionTransformation lhs, this.ExpressionTransformation rhs) - - abstract member onMultiplication : TypedExpression * TypedExpression -> ExpressionKind - default this.onMultiplication (lhs, rhs) = MUL (this.ExpressionTransformation lhs, this.ExpressionTransformation rhs) - - abstract member onDivision : TypedExpression * TypedExpression -> ExpressionKind - default this.onDivision (lhs, rhs) = DIV (this.ExpressionTransformation lhs, this.ExpressionTransformation rhs) - - abstract member onExponentiate : TypedExpression * TypedExpression -> ExpressionKind - default this.onExponentiate (lhs, rhs) = POW (this.ExpressionTransformation lhs, this.ExpressionTransformation rhs) - - abstract member onModulo : TypedExpression * TypedExpression -> ExpressionKind - default this.onModulo (lhs, rhs) = MOD (this.ExpressionTransformation lhs, this.ExpressionTransformation rhs) - - abstract member onLeftShift : TypedExpression * TypedExpression -> ExpressionKind - default this.onLeftShift (lhs, rhs) = LSHIFT (this.ExpressionTransformation lhs, this.ExpressionTransformation rhs) - - abstract member onRightShift : TypedExpression * TypedExpression -> ExpressionKind - default this.onRightShift (lhs, rhs) = RSHIFT (this.ExpressionTransformation lhs, this.ExpressionTransformation rhs) - - abstract member onBitwiseExclusiveOr : TypedExpression * TypedExpression -> ExpressionKind - default this.onBitwiseExclusiveOr (lhs, rhs) = BXOR (this.ExpressionTransformation lhs, this.ExpressionTransformation rhs) - - abstract member onBitwiseOr : TypedExpression * TypedExpression -> ExpressionKind - default this.onBitwiseOr (lhs, rhs) = BOR (this.ExpressionTransformation lhs, this.ExpressionTransformation rhs) - - abstract member onBitwiseAnd : TypedExpression * TypedExpression -> ExpressionKind - default this.onBitwiseAnd (lhs, rhs) = BAND (this.ExpressionTransformation lhs, this.ExpressionTransformation rhs) - - abstract member onLogicalNot : TypedExpression -> ExpressionKind - default this.onLogicalNot ex = NOT (this.ExpressionTransformation ex) - - abstract member onNegative : TypedExpression -> ExpressionKind - default this.onNegative ex = NEG (this.ExpressionTransformation ex) - - abstract member onBitwiseNot : TypedExpression -> ExpressionKind - default this.onBitwiseNot ex = BNOT (this.ExpressionTransformation ex) - - - member private this.dispatchCallLikeExpression (method, arg) = + abstract member OnCallLikeExpression : TypedExpression * TypedExpression -> ExpressionKind + default this.OnCallLikeExpression (method, arg) = match method.ResolvedType.Resolution with - | _ when TypedExpression.IsPartialApplication (CallLikeExpression (method, arg)) -> this.onPartialApplication (method, arg) - | ExpressionType.Operation _ -> this.onOperationCall (method, arg) - | _ -> this.onFunctionCall (method, arg) - - abstract member Transform : ExpressionKind -> ExpressionKind - default this.Transform kind = - if not enable then kind else - match kind with - | Identifier (sym, tArgs) -> this.onIdentifier (sym, tArgs) - | CallLikeExpression (method,arg) -> this.dispatchCallLikeExpression ((method, arg) |> this.beforeCallLike) - | AdjointApplication ex -> this.onAdjointApplication (ex |> (this.beforeFunctorApplication >> this.beforeModifierApplication)) - | ControlledApplication ex -> this.onControlledApplication (ex |> (this.beforeFunctorApplication >> this.beforeModifierApplication)) - | UnwrapApplication ex -> this.onUnwrapApplication (ex |> this.beforeModifierApplication) - | UnitValue -> this.onUnitValue () - | MissingExpr -> this.onMissingExpression () - | InvalidExpr -> this.onInvalidExpression () - | ValueTuple vs -> this.onValueTuple vs - | ArrayItem (arr, idx) -> this.onArrayItem (arr, idx) - | NamedItem (ex, acc) -> this.onNamedItem (ex, acc) - | ValueArray vs -> this.onValueArray vs - | NewArray (bt, idx) -> this.onNewArray (bt, idx) - | IntLiteral i -> this.onIntLiteral i - | BigIntLiteral b -> this.onBigIntLiteral b - | DoubleLiteral d -> this.onDoubleLiteral d - | BoolLiteral b -> this.onBoolLiteral b - | ResultLiteral r -> this.onResultLiteral r - | PauliLiteral p -> this.onPauliLiteral p - | StringLiteral (s, exs) -> this.onStringLiteral (s, exs) - | RangeLiteral (lhs, rhs) -> this.onRangeLiteral (lhs, rhs) - | CopyAndUpdate (lhs, accEx, rhs) -> this.onCopyAndUpdateExpression (lhs, accEx, rhs) - | CONDITIONAL (cond, ifTrue, ifFalse) -> this.onConditionalExpression (cond, ifTrue, ifFalse) - | EQ (lhs,rhs) -> this.onEquality ((lhs, rhs) |> this.beforeBinaryOperatorExpression) - | NEQ (lhs,rhs) -> this.onInequality ((lhs, rhs) |> this.beforeBinaryOperatorExpression) - | LT (lhs,rhs) -> this.onLessThan ((lhs, rhs) |> this.beforeBinaryOperatorExpression) - | LTE (lhs,rhs) -> this.onLessThanOrEqual ((lhs, rhs) |> this.beforeBinaryOperatorExpression) - | GT (lhs,rhs) -> this.onGreaterThan ((lhs, rhs) |> this.beforeBinaryOperatorExpression) - | GTE (lhs,rhs) -> this.onGreaterThanOrEqual ((lhs, rhs) |> this.beforeBinaryOperatorExpression) - | AND (lhs,rhs) -> this.onLogicalAnd ((lhs, rhs) |> this.beforeBinaryOperatorExpression) - | OR (lhs,rhs) -> this.onLogicalOr ((lhs, rhs) |> this.beforeBinaryOperatorExpression) - | ADD (lhs,rhs) -> this.onAddition ((lhs, rhs) |> this.beforeBinaryOperatorExpression) - | SUB (lhs,rhs) -> this.onSubtraction ((lhs, rhs) |> this.beforeBinaryOperatorExpression) - | MUL (lhs,rhs) -> this.onMultiplication ((lhs, rhs) |> this.beforeBinaryOperatorExpression) - | DIV (lhs,rhs) -> this.onDivision ((lhs, rhs) |> this.beforeBinaryOperatorExpression) - | POW (lhs,rhs) -> this.onExponentiate ((lhs, rhs) |> this.beforeBinaryOperatorExpression) - | MOD (lhs,rhs) -> this.onModulo ((lhs, rhs) |> this.beforeBinaryOperatorExpression) - | LSHIFT (lhs,rhs) -> this.onLeftShift ((lhs, rhs) |> this.beforeBinaryOperatorExpression) - | RSHIFT (lhs,rhs) -> this.onRightShift ((lhs, rhs) |> this.beforeBinaryOperatorExpression) - | BXOR (lhs,rhs) -> this.onBitwiseExclusiveOr ((lhs, rhs) |> this.beforeBinaryOperatorExpression) - | BOR (lhs,rhs) -> this.onBitwiseOr ((lhs, rhs) |> this.beforeBinaryOperatorExpression) - | BAND (lhs,rhs) -> this.onBitwiseAnd ((lhs, rhs) |> this.beforeBinaryOperatorExpression) - | NOT ex -> this.onLogicalNot (ex |> this.beforeUnaryOperatorExpression) - | NEG ex -> this.onNegative (ex |> this.beforeUnaryOperatorExpression) - | BNOT ex -> this.onBitwiseNot (ex |> this.beforeUnaryOperatorExpression) - - -and ExpressionTypeTransformation(?enable) = - let enable = defaultArg enable true - - abstract member onRangeInformation : QsRangeInfo -> QsRangeInfo - default this.onRangeInformation r = r - - abstract member onCharacteristicsExpression : ResolvedCharacteristics -> ResolvedCharacteristics - default this.onCharacteristicsExpression fs = fs - - abstract member onCallableInformation : CallableInformation -> CallableInformation - default this.onCallableInformation opInfo = - let characteristics = this.onCharacteristicsExpression opInfo.Characteristics - let inferred = opInfo.InferredInformation - CallableInformation.New (characteristics, inferred) - - abstract member onUserDefinedType : UserDefinedType -> ExpressionType - default this.onUserDefinedType udt = - let ns, name = udt.Namespace, udt.Name - let range = this.onRangeInformation udt.Range - UserDefinedType.New (ns, name, range) |> ExpressionType.UserDefinedType - - abstract member onTypeParameter : QsTypeParameter -> ExpressionType - default this.onTypeParameter tp = - let origin = tp.Origin - let name = tp.TypeName - let range = this.onRangeInformation tp.Range - QsTypeParameter.New (origin, name, range) |> ExpressionType.TypeParameter - - abstract member onUnitType : unit -> ExpressionType - default this.onUnitType () = ExpressionType.UnitType - - abstract member onOperation : (ResolvedType * ResolvedType) * CallableInformation -> ExpressionType - default this.onOperation ((it, ot), info) = ExpressionType.Operation ((this.Transform it, this.Transform ot), this.onCallableInformation info) - - abstract member onFunction : ResolvedType * ResolvedType -> ExpressionType - default this.onFunction (it, ot) = ExpressionType.Function (this.Transform it, this.Transform ot) - - abstract member onTupleType : ImmutableArray -> ExpressionType - default this.onTupleType ts = ExpressionType.TupleType ((ts |> Seq.map this.Transform).ToImmutableArray()) - - abstract member onArrayType : ResolvedType -> ExpressionType - default this.onArrayType b = ExpressionType.ArrayType (this.Transform b) - - abstract member onQubit : unit -> ExpressionType - default this.onQubit () = ExpressionType.Qubit - - abstract member onMissingType : unit -> ExpressionType - default this.onMissingType () = ExpressionType.MissingType - - abstract member onInvalidType : unit -> ExpressionType - default this.onInvalidType () = ExpressionType.InvalidType - - abstract member onInt : unit -> ExpressionType - default this.onInt () = ExpressionType.Int - - abstract member onBigInt : unit -> ExpressionType - default this.onBigInt () = ExpressionType.BigInt - - abstract member onDouble : unit -> ExpressionType - default this.onDouble () = ExpressionType.Double - - abstract member onBool : unit -> ExpressionType - default this.onBool () = ExpressionType.Bool - - abstract member onString : unit -> ExpressionType - default this.onString () = ExpressionType.String - - abstract member onResult : unit -> ExpressionType - default this.onResult () = ExpressionType.Result - - abstract member onPauli : unit -> ExpressionType - default this.onPauli () = ExpressionType.Pauli - - abstract member onRange : unit -> ExpressionType - default this.onRange () = ExpressionType.Range - - member this.Transform (t : ResolvedType) = - if not enable then t else - match t.Resolution with - | ExpressionType.UnitType -> this.onUnitType () - | ExpressionType.Operation ((it, ot), fs) -> this.onOperation ((it, ot), fs) - | ExpressionType.Function (it, ot) -> this.onFunction (it, ot) - | ExpressionType.TupleType ts -> this.onTupleType ts - | ExpressionType.ArrayType b -> this.onArrayType b - | ExpressionType.UserDefinedType udt -> this.onUserDefinedType udt - | ExpressionType.TypeParameter tp -> this.onTypeParameter tp - | ExpressionType.Qubit -> this.onQubit () - | ExpressionType.MissingType -> this.onMissingType () - | ExpressionType.InvalidType -> this.onInvalidType () - | ExpressionType.Int -> this.onInt () - | ExpressionType.BigInt -> this.onBigInt () - | ExpressionType.Double -> this.onDouble () - | ExpressionType.Bool -> this.onBool () - | ExpressionType.String -> this.onString () - | ExpressionType.Result -> this.onResult () - | ExpressionType.Pauli -> this.onPauli () - | ExpressionType.Range -> this.onRange () - |> ResolvedType.New - - -and ExpressionTransformation(?enableKindTransformations) = - let enableKind = defaultArg enableKindTransformations true - let typeTransformation = new ExpressionTypeTransformation() - - abstract member Kind : ExpressionKindTransformation - default this.Kind = { - new ExpressionKindTransformation (enableKind) with - override x.ExpressionTransformation ex = this.Transform ex - override x.TypeTransformation t = this.Type.Transform t - } - - abstract member Type : ExpressionTypeTransformation - default this.Type = typeTransformation - - abstract member onRangeInformation : QsNullable -> QsNullable - default this.onRangeInformation r = r - - abstract member onExpressionInformation : InferredExpressionInformation -> InferredExpressionInformation - default this.onExpressionInformation info = info - - abstract member onTypeParamResolutions : ImmutableDictionary<(QsQualifiedName*NonNullable), ResolvedType> -> ImmutableDictionary<(QsQualifiedName*NonNullable), ResolvedType> - default this.onTypeParamResolutions typeParams = + | _ when TypedExpression.IsPartialApplication (CallLikeExpression (method, arg)) -> this.OnPartialApplication (method, arg) + | ExpressionType.Operation _ -> this.OnOperationCall (method, arg) + | _ -> this.OnFunctionCall (method, arg) + + abstract member OnAdjointApplication : TypedExpression -> ExpressionKind + default this.OnAdjointApplication ex = + let ex = this.Expressions.OnTypedExpression ex + AdjointApplication |> Node.BuildOr InvalidExpr ex + + abstract member OnControlledApplication : TypedExpression -> ExpressionKind + default this.OnControlledApplication ex = + let ex = this.Expressions.OnTypedExpression ex + ControlledApplication |> Node.BuildOr InvalidExpr ex + + abstract member OnUnwrapApplication : TypedExpression -> ExpressionKind + default this.OnUnwrapApplication ex = + let ex = this.Expressions.OnTypedExpression ex + UnwrapApplication |> Node.BuildOr InvalidExpr ex + + abstract member OnValueTuple : ImmutableArray -> ExpressionKind + default this.OnValueTuple vs = + let values = vs |> Seq.map this.Expressions.OnTypedExpression |> ImmutableArray.CreateRange + ValueTuple |> Node.BuildOr InvalidExpr values + + abstract member OnArrayItem : TypedExpression * TypedExpression -> ExpressionKind + default this.OnArrayItem (arr, idx) = + let arr, idx = this.Expressions.OnTypedExpression arr, this.Expressions.OnTypedExpression idx + ArrayItem |> Node.BuildOr InvalidExpr (arr, idx) + + abstract member OnNamedItem : TypedExpression * Identifier -> ExpressionKind + default this.OnNamedItem (ex, acc) = + let ex = this.Expressions.OnTypedExpression ex + NamedItem |> Node.BuildOr InvalidExpr (ex, acc) + + abstract member OnValueArray : ImmutableArray -> ExpressionKind + default this.OnValueArray vs = + let values = vs |> Seq.map this.Expressions.OnTypedExpression |> ImmutableArray.CreateRange + ValueArray |> Node.BuildOr InvalidExpr values + + abstract member OnNewArray : ResolvedType * TypedExpression -> ExpressionKind + default this.OnNewArray (bt, idx) = + let bt, idx = this.Types.OnType bt, this.Expressions.OnTypedExpression idx + NewArray |> Node.BuildOr InvalidExpr (bt, idx) + + abstract member OnStringLiteral : NonNullable * ImmutableArray -> ExpressionKind + default this.OnStringLiteral (s, exs) = + let exs = exs |> Seq.map this.Expressions.OnTypedExpression |> ImmutableArray.CreateRange + StringLiteral |> Node.BuildOr InvalidExpr (s, exs) + + abstract member OnRangeLiteral : TypedExpression * TypedExpression -> ExpressionKind + default this.OnRangeLiteral (lhs, rhs) = + let lhs, rhs = this.Expressions.OnTypedExpression lhs, this.Expressions.OnTypedExpression rhs + RangeLiteral |> Node.BuildOr InvalidExpr (lhs, rhs) + + abstract member OnCopyAndUpdateExpression : TypedExpression * TypedExpression * TypedExpression -> ExpressionKind + default this.OnCopyAndUpdateExpression (lhs, accEx, rhs) = + let lhs, accEx, rhs = this.Expressions.OnTypedExpression lhs, this.Expressions.OnTypedExpression accEx, this.Expressions.OnTypedExpression rhs + CopyAndUpdate |> Node.BuildOr InvalidExpr (lhs, accEx, rhs) + + abstract member OnConditionalExpression : TypedExpression * TypedExpression * TypedExpression -> ExpressionKind + default this.OnConditionalExpression (cond, ifTrue, ifFalse) = + let cond, ifTrue, ifFalse = this.Expressions.OnTypedExpression cond, this.Expressions.OnTypedExpression ifTrue, this.Expressions.OnTypedExpression ifFalse + CONDITIONAL |> Node.BuildOr InvalidExpr (cond, ifTrue, ifFalse) + + abstract member OnEquality : TypedExpression * TypedExpression -> ExpressionKind + default this.OnEquality (lhs, rhs) = + let lhs, rhs = this.Expressions.OnTypedExpression lhs, this.Expressions.OnTypedExpression rhs + EQ |> Node.BuildOr InvalidExpr (lhs, rhs) + + abstract member OnInequality : TypedExpression * TypedExpression -> ExpressionKind + default this.OnInequality (lhs, rhs) = + let lhs, rhs = this.Expressions.OnTypedExpression lhs, this.Expressions.OnTypedExpression rhs + NEQ |> Node.BuildOr InvalidExpr (lhs, rhs) + + abstract member OnLessThan : TypedExpression * TypedExpression -> ExpressionKind + default this.OnLessThan (lhs, rhs) = + let lhs, rhs = this.Expressions.OnTypedExpression lhs, this.Expressions.OnTypedExpression rhs + LT |> Node.BuildOr InvalidExpr (lhs, rhs) + + abstract member OnLessThanOrEqual : TypedExpression * TypedExpression -> ExpressionKind + default this.OnLessThanOrEqual (lhs, rhs) = + let lhs, rhs = this.Expressions.OnTypedExpression lhs, this.Expressions.OnTypedExpression rhs + LTE |> Node.BuildOr InvalidExpr (lhs, rhs) + + abstract member OnGreaterThan : TypedExpression * TypedExpression -> ExpressionKind + default this.OnGreaterThan (lhs, rhs) = + let lhs, rhs = this.Expressions.OnTypedExpression lhs, this.Expressions.OnTypedExpression rhs + GT |> Node.BuildOr InvalidExpr (lhs, rhs) + + abstract member OnGreaterThanOrEqual : TypedExpression * TypedExpression -> ExpressionKind + default this.OnGreaterThanOrEqual (lhs, rhs) = + let lhs, rhs = this.Expressions.OnTypedExpression lhs, this.Expressions.OnTypedExpression rhs + GTE |> Node.BuildOr InvalidExpr (lhs, rhs) + + abstract member OnLogicalAnd : TypedExpression * TypedExpression -> ExpressionKind + default this.OnLogicalAnd (lhs, rhs) = + let lhs, rhs = this.Expressions.OnTypedExpression lhs, this.Expressions.OnTypedExpression rhs + AND |> Node.BuildOr InvalidExpr (lhs, rhs) + + abstract member OnLogicalOr : TypedExpression * TypedExpression -> ExpressionKind + default this.OnLogicalOr (lhs, rhs) = + let lhs, rhs = this.Expressions.OnTypedExpression lhs, this.Expressions.OnTypedExpression rhs + OR |> Node.BuildOr InvalidExpr (lhs, rhs) + + abstract member OnAddition : TypedExpression * TypedExpression -> ExpressionKind + default this.OnAddition (lhs, rhs) = + let lhs, rhs = this.Expressions.OnTypedExpression lhs, this.Expressions.OnTypedExpression rhs + ADD |> Node.BuildOr InvalidExpr (lhs, rhs) + + abstract member OnSubtraction : TypedExpression * TypedExpression -> ExpressionKind + default this.OnSubtraction (lhs, rhs) = + let lhs, rhs = this.Expressions.OnTypedExpression lhs, this.Expressions.OnTypedExpression rhs + SUB |> Node.BuildOr InvalidExpr (lhs, rhs) + + abstract member OnMultiplication : TypedExpression * TypedExpression -> ExpressionKind + default this.OnMultiplication (lhs, rhs) = + let lhs, rhs = this.Expressions.OnTypedExpression lhs, this.Expressions.OnTypedExpression rhs + MUL |> Node.BuildOr InvalidExpr (lhs, rhs) + + abstract member OnDivision : TypedExpression * TypedExpression -> ExpressionKind + default this.OnDivision (lhs, rhs) = + let lhs, rhs = this.Expressions.OnTypedExpression lhs, this.Expressions.OnTypedExpression rhs + DIV |> Node.BuildOr InvalidExpr (lhs, rhs) + + abstract member OnExponentiate : TypedExpression * TypedExpression -> ExpressionKind + default this.OnExponentiate (lhs, rhs) = + let lhs, rhs = this.Expressions.OnTypedExpression lhs, this.Expressions.OnTypedExpression rhs + POW |> Node.BuildOr InvalidExpr (lhs, rhs) + + abstract member OnModulo : TypedExpression * TypedExpression -> ExpressionKind + default this.OnModulo (lhs, rhs) = + let lhs, rhs = this.Expressions.OnTypedExpression lhs, this.Expressions.OnTypedExpression rhs + MOD |> Node.BuildOr InvalidExpr (lhs, rhs) + + abstract member OnLeftShift : TypedExpression * TypedExpression -> ExpressionKind + default this.OnLeftShift (lhs, rhs) = + let lhs, rhs = this.Expressions.OnTypedExpression lhs, this.Expressions.OnTypedExpression rhs + LSHIFT |> Node.BuildOr InvalidExpr (lhs, rhs) + + abstract member OnRightShift : TypedExpression * TypedExpression -> ExpressionKind + default this.OnRightShift (lhs, rhs) = + let lhs, rhs = this.Expressions.OnTypedExpression lhs, this.Expressions.OnTypedExpression rhs + RSHIFT |> Node.BuildOr InvalidExpr (lhs, rhs) + + abstract member OnBitwiseExclusiveOr : TypedExpression * TypedExpression -> ExpressionKind + default this.OnBitwiseExclusiveOr (lhs, rhs) = + let lhs, rhs = this.Expressions.OnTypedExpression lhs, this.Expressions.OnTypedExpression rhs + BXOR |> Node.BuildOr InvalidExpr (lhs, rhs) + + abstract member OnBitwiseOr : TypedExpression * TypedExpression -> ExpressionKind + default this.OnBitwiseOr (lhs, rhs) = + let lhs, rhs = this.Expressions.OnTypedExpression lhs, this.Expressions.OnTypedExpression rhs + BOR |> Node.BuildOr InvalidExpr (lhs, rhs) + + abstract member OnBitwiseAnd : TypedExpression * TypedExpression -> ExpressionKind + default this.OnBitwiseAnd (lhs, rhs) = + let lhs, rhs = this.Expressions.OnTypedExpression lhs, this.Expressions.OnTypedExpression rhs + BAND |> Node.BuildOr InvalidExpr (lhs, rhs) + + abstract member OnLogicalNot : TypedExpression -> ExpressionKind + default this.OnLogicalNot ex = + let ex = this.Expressions.OnTypedExpression ex + NOT |> Node.BuildOr InvalidExpr ex + + abstract member OnNegative : TypedExpression -> ExpressionKind + default this.OnNegative ex = + let ex = this.Expressions.OnTypedExpression ex + NEG |> Node.BuildOr InvalidExpr ex + + abstract member OnBitwiseNot : TypedExpression -> ExpressionKind + default this.OnBitwiseNot ex = + let ex = this.Expressions.OnTypedExpression ex + BNOT |> Node.BuildOr InvalidExpr ex + + + // leaf nodes + + abstract member OnUnitValue : unit -> ExpressionKind + default this.OnUnitValue () = ExpressionKind.UnitValue + + abstract member OnMissingExpression : unit -> ExpressionKind + default this.OnMissingExpression () = MissingExpr + + abstract member OnInvalidExpression : unit -> ExpressionKind + default this.OnInvalidExpression () = InvalidExpr + + abstract member OnIntLiteral : int64 -> ExpressionKind + default this.OnIntLiteral i = IntLiteral i + + abstract member OnBigIntLiteral : BigInteger -> ExpressionKind + default this.OnBigIntLiteral b = BigIntLiteral b + + abstract member OnDoubleLiteral : double -> ExpressionKind + default this.OnDoubleLiteral d = DoubleLiteral d + + abstract member OnBoolLiteral : bool -> ExpressionKind + default this.OnBoolLiteral b = BoolLiteral b + + abstract member OnResultLiteral : QsResult -> ExpressionKind + default this.OnResultLiteral r = ResultLiteral r + + abstract member OnPauliLiteral : QsPauli -> ExpressionKind + default this.OnPauliLiteral p = PauliLiteral p + + + // transformation root called on each node + + abstract member OnExpressionKind : ExpressionKind -> ExpressionKind + default this.OnExpressionKind kind = + if not options.Enable then kind else + let transformed = kind |> function + | Identifier (sym, tArgs) -> this.OnIdentifier (sym, tArgs) + | CallLikeExpression (method,arg) -> this.OnCallLikeExpression (method, arg) + | AdjointApplication ex -> this.OnAdjointApplication (ex) + | ControlledApplication ex -> this.OnControlledApplication (ex) + | UnwrapApplication ex -> this.OnUnwrapApplication (ex) + | UnitValue -> this.OnUnitValue () + | MissingExpr -> this.OnMissingExpression () + | InvalidExpr -> this.OnInvalidExpression () + | ValueTuple vs -> this.OnValueTuple vs + | ArrayItem (arr, idx) -> this.OnArrayItem (arr, idx) + | NamedItem (ex, acc) -> this.OnNamedItem (ex, acc) + | ValueArray vs -> this.OnValueArray vs + | NewArray (bt, idx) -> this.OnNewArray (bt, idx) + | IntLiteral i -> this.OnIntLiteral i + | BigIntLiteral b -> this.OnBigIntLiteral b + | DoubleLiteral d -> this.OnDoubleLiteral d + | BoolLiteral b -> this.OnBoolLiteral b + | ResultLiteral r -> this.OnResultLiteral r + | PauliLiteral p -> this.OnPauliLiteral p + | StringLiteral (s, exs) -> this.OnStringLiteral (s, exs) + | RangeLiteral (lhs, rhs) -> this.OnRangeLiteral (lhs, rhs) + | CopyAndUpdate (lhs, accEx, rhs) -> this.OnCopyAndUpdateExpression (lhs, accEx, rhs) + | CONDITIONAL (cond, ifTrue, ifFalse) -> this.OnConditionalExpression (cond, ifTrue, ifFalse) + | EQ (lhs,rhs) -> this.OnEquality (lhs, rhs) + | NEQ (lhs,rhs) -> this.OnInequality (lhs, rhs) + | LT (lhs,rhs) -> this.OnLessThan (lhs, rhs) + | LTE (lhs,rhs) -> this.OnLessThanOrEqual (lhs, rhs) + | GT (lhs,rhs) -> this.OnGreaterThan (lhs, rhs) + | GTE (lhs,rhs) -> this.OnGreaterThanOrEqual (lhs, rhs) + | AND (lhs,rhs) -> this.OnLogicalAnd (lhs, rhs) + | OR (lhs,rhs) -> this.OnLogicalOr (lhs, rhs) + | ADD (lhs,rhs) -> this.OnAddition (lhs, rhs) + | SUB (lhs,rhs) -> this.OnSubtraction (lhs, rhs) + | MUL (lhs,rhs) -> this.OnMultiplication (lhs, rhs) + | DIV (lhs,rhs) -> this.OnDivision (lhs, rhs) + | POW (lhs,rhs) -> this.OnExponentiate (lhs, rhs) + | MOD (lhs,rhs) -> this.OnModulo (lhs, rhs) + | LSHIFT (lhs,rhs) -> this.OnLeftShift (lhs, rhs) + | RSHIFT (lhs,rhs) -> this.OnRightShift (lhs, rhs) + | BXOR (lhs,rhs) -> this.OnBitwiseExclusiveOr (lhs, rhs) + | BOR (lhs,rhs) -> this.OnBitwiseOr (lhs, rhs) + | BAND (lhs,rhs) -> this.OnBitwiseAnd (lhs, rhs) + | NOT ex -> this.OnLogicalNot (ex) + | NEG ex -> this.OnNegative (ex) + | BNOT ex -> this.OnBitwiseNot (ex) + id |> Node.BuildOr kind transformed + + +and ExpressionTransformationBase internal (options : TransformationOptions, _internal_) = + + let missingTransformation name _ = new InvalidOperationException(sprintf "No %s transformation has been specified." name) |> raise + let Node = if options.Rebuild then Fold else Walk + + member val internal TypeTransformationHandle = missingTransformation "type" with get, set + member val internal ExpressionKindTransformationHandle = missingTransformation "expression kind" with get, set + + member this.Types = this.TypeTransformationHandle() + member this.ExpressionKinds = this.ExpressionKindTransformationHandle() + + new (exkindTransformation : unit -> ExpressionKindTransformationBase, typeTransformation : unit -> TypeTransformationBase, options : TransformationOptions) as this = + new ExpressionTransformationBase(options, "_internal_") then + this.TypeTransformationHandle <- typeTransformation + this.ExpressionKindTransformationHandle <- exkindTransformation + + new (options : TransformationOptions) as this = + new ExpressionTransformationBase(options, "_internal_") then + let typeTransformation = new TypeTransformationBase(options) + let exprKindTransformation = new ExpressionKindTransformationBase((fun _ -> this), (fun _ -> this.Types), options) + this.TypeTransformationHandle <- fun _ -> typeTransformation + this.ExpressionKindTransformationHandle <- fun _ -> exprKindTransformation + + new (exkindTransformation : unit -> ExpressionKindTransformationBase, typeTransformation : unit -> TypeTransformationBase) = + new ExpressionTransformationBase(exkindTransformation, typeTransformation, TransformationOptions.Default) + + new () = new ExpressionTransformationBase(TransformationOptions.Default) + + + // supplementary expression information + + abstract member OnRangeInformation : QsNullable -> QsNullable + default this.OnRangeInformation r = r + + abstract member OnExpressionInformation : InferredExpressionInformation -> InferredExpressionInformation + default this.OnExpressionInformation info = info + + + // nodes containing subexpressions or subtypes + + /// If DisableRebuild is set to true, this method won't walk the types in the dictionary. + abstract member OnTypeParamResolutions : ImmutableDictionary<(QsQualifiedName*NonNullable), ResolvedType> -> ImmutableDictionary<(QsQualifiedName*NonNullable), ResolvedType> + default this.OnTypeParamResolutions typeParams = let asTypeParameter (key) = QsTypeParameter.New (fst key, snd key, Null) let filteredTypeParams = typeParams - |> Seq.map (fun kv -> this.Type.onTypeParameter (kv.Key |> asTypeParameter), kv.Value) - |> Seq.choose (function | TypeParameter tp, value -> Some ((tp.Origin, tp.TypeName), this.Type.Transform value) | _ -> None) - filteredTypeParams.ToImmutableDictionary (fst,snd) - - abstract member Transform : TypedExpression -> TypedExpression - default this.Transform (ex : TypedExpression) = - let range = this.onRangeInformation ex.Range - let typeParamResolutions = this.onTypeParamResolutions ex.TypeParameterResolutions - let kind = this.Kind.Transform ex.Expression - let exType = this.Type.Transform ex.ResolvedType - let inferredInfo = this.onExpressionInformation ex.InferredInformation - TypedExpression.New (kind, typeParamResolutions, exType, inferredInfo, range) + |> Seq.map (fun kv -> this.Types.OnTypeParameter (kv.Key |> asTypeParameter), kv.Value) + |> Seq.choose (function | TypeParameter tp, value -> Some ((tp.Origin, tp.TypeName), this.Types.OnType value) | _ -> None) + |> Seq.map (fun (key, value) -> new KeyValuePair<_,_>(key, value)) + ImmutableDictionary.CreateRange |> Node.BuildOr typeParams filteredTypeParams + + + // transformation root called on each node + + abstract member OnTypedExpression : TypedExpression -> TypedExpression + default this.OnTypedExpression (ex : TypedExpression) = + if not options.Enable then ex else + let range = this.OnRangeInformation ex.Range + let typeParamResolutions = this.OnTypeParamResolutions ex.TypeParameterResolutions + let kind = this.ExpressionKinds.OnExpressionKind ex.Expression + let exType = this.Types.OnType ex.ResolvedType + let inferredInfo = this.OnExpressionInformation ex.InferredInformation + TypedExpression.New |> Node.BuildOr ex (kind, typeParamResolutions, exType, inferredInfo, range) diff --git a/src/QsCompiler/Core/ExpressionWalker.fs b/src/QsCompiler/Core/ExpressionWalker.fs deleted file mode 100644 index ccaa4a85de..0000000000 --- a/src/QsCompiler/Core/ExpressionWalker.fs +++ /dev/null @@ -1,451 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -namespace Microsoft.Quantum.QsCompiler.Transformations.Core - -open System.Collections.Immutable -open System.Numerics -open Microsoft.Quantum.QsCompiler.DataTypes -open Microsoft.Quantum.QsCompiler.SyntaxExtensions -open Microsoft.Quantum.QsCompiler.SyntaxTokens -open Microsoft.Quantum.QsCompiler.SyntaxTree - - -/// Convention: -/// All methods starting with "on" implement the walk for an expression of a certain kind. -/// All methods starting with "before" group a set of statements, and are called before walking the set -/// even if the corresponding walk routine (starting with "on") is overridden. -/// -/// These classes differ from the "*Transformation" classes in that these classes visit every node in the -/// syntax tree, but don't create a new syntax tree, while the Transformation classes generate a new (or -/// at least partially new) tree from the old one. -/// Effectively, the Transformation classes implement fold, while the Walker classes implement iter. -[] -type ExpressionKindWalker(?enable) = - let enable = defaultArg enable true - - abstract member ExpressionWalker : TypedExpression -> unit - abstract member TypeWalker : ResolvedType -> unit - - abstract member beforeCallLike : TypedExpression * TypedExpression -> unit - default this.beforeCallLike (method, arg) = () - - abstract member beforeFunctorApplication : TypedExpression -> unit - default this.beforeFunctorApplication ex = () - - abstract member beforeModifierApplication : TypedExpression -> unit - default this.beforeModifierApplication ex = () - - abstract member beforeBinaryOperatorExpression : TypedExpression * TypedExpression -> unit - default this.beforeBinaryOperatorExpression (lhs, rhs) = () - - abstract member beforeUnaryOperatorExpression : TypedExpression -> unit - default this.beforeUnaryOperatorExpression ex = () - - - abstract member onIdentifier : Identifier * QsNullable> -> unit - default this.onIdentifier (sym, tArgs) = tArgs |> QsNullable<_>.Iter (fun ts -> (ts |> Seq.iter this.TypeWalker)) - - abstract member onOperationCall : TypedExpression * TypedExpression -> unit - default this.onOperationCall (method, arg) = - this.ExpressionWalker method - this.ExpressionWalker arg - - abstract member onFunctionCall : TypedExpression * TypedExpression -> unit - default this.onFunctionCall (method, arg) = - this.ExpressionWalker method - this.ExpressionWalker arg - - abstract member onPartialApplication : TypedExpression * TypedExpression -> unit - default this.onPartialApplication (method, arg) = - this.ExpressionWalker method - this.ExpressionWalker arg - - abstract member onAdjointApplication : TypedExpression -> unit - default this.onAdjointApplication ex = this.ExpressionWalker ex - - abstract member onControlledApplication : TypedExpression -> unit - default this.onControlledApplication ex = this.ExpressionWalker ex - - abstract member onUnwrapApplication : TypedExpression -> unit - default this.onUnwrapApplication ex = this.ExpressionWalker ex - - abstract member onUnitValue : unit -> unit - default this.onUnitValue () = () - - abstract member onMissingExpression : unit -> unit - default this.onMissingExpression () = () - - abstract member onInvalidExpression : unit -> unit - default this.onInvalidExpression () = () - - abstract member onValueTuple : ImmutableArray -> unit - default this.onValueTuple vs = vs |> Seq.iter this.ExpressionWalker - - abstract member onArrayItem : TypedExpression * TypedExpression -> unit - default this.onArrayItem (arr, idx) = - this.ExpressionWalker arr - this.ExpressionWalker idx - - abstract member onNamedItem : TypedExpression * Identifier -> unit - default this.onNamedItem (ex, acc) = this.ExpressionWalker ex - - abstract member onValueArray : ImmutableArray -> unit - default this.onValueArray vs = vs |> Seq.iter this.ExpressionWalker - - abstract member onNewArray : ResolvedType * TypedExpression -> unit - default this.onNewArray (bt, idx) = - this.TypeWalker bt - this.ExpressionWalker idx - - abstract member onIntLiteral : int64 -> unit - default this.onIntLiteral i = () - - abstract member onBigIntLiteral : BigInteger -> unit - default this.onBigIntLiteral b = () - - abstract member onDoubleLiteral : double -> unit - default this.onDoubleLiteral d = () - - abstract member onBoolLiteral : bool -> unit - default this.onBoolLiteral b = () - - abstract member onResultLiteral : QsResult -> unit - default this.onResultLiteral r = () - - abstract member onPauliLiteral : QsPauli -> unit - default this.onPauliLiteral p = () - - abstract member onStringLiteral : NonNullable * ImmutableArray -> unit - default this.onStringLiteral (s, exs) = exs |> Seq.iter this.ExpressionWalker - - abstract member onRangeLiteral : TypedExpression * TypedExpression -> unit - default this.onRangeLiteral (lhs, rhs) = - this.ExpressionWalker lhs - this.ExpressionWalker rhs - - abstract member onCopyAndUpdateExpression : TypedExpression * TypedExpression * TypedExpression -> unit - default this.onCopyAndUpdateExpression (lhs, accEx, rhs) = - this.ExpressionWalker lhs - this.ExpressionWalker accEx - this.ExpressionWalker rhs - - abstract member onConditionalExpression : TypedExpression * TypedExpression * TypedExpression -> unit - default this.onConditionalExpression (cond, ifTrue, ifFalse) = - this.ExpressionWalker cond - this.ExpressionWalker ifTrue - this.ExpressionWalker ifFalse - - abstract member onEquality : TypedExpression * TypedExpression -> unit - default this.onEquality (lhs, rhs) = - this.ExpressionWalker lhs - this.ExpressionWalker rhs - - abstract member onInequality : TypedExpression * TypedExpression -> unit - default this.onInequality (lhs, rhs) = - this.ExpressionWalker lhs - this.ExpressionWalker rhs - - abstract member onLessThan : TypedExpression * TypedExpression -> unit - default this.onLessThan (lhs, rhs) = - this.ExpressionWalker lhs - this.ExpressionWalker rhs - - abstract member onLessThanOrEqual : TypedExpression * TypedExpression -> unit - default this.onLessThanOrEqual (lhs, rhs) = - this.ExpressionWalker lhs - this.ExpressionWalker rhs - - abstract member onGreaterThan : TypedExpression * TypedExpression -> unit - default this.onGreaterThan (lhs, rhs) = - this.ExpressionWalker lhs - this.ExpressionWalker rhs - - abstract member onGreaterThanOrEqual : TypedExpression * TypedExpression -> unit - default this.onGreaterThanOrEqual (lhs, rhs) = - this.ExpressionWalker lhs - this.ExpressionWalker rhs - - abstract member onLogicalAnd : TypedExpression * TypedExpression -> unit - default this.onLogicalAnd (lhs, rhs) = - this.ExpressionWalker lhs - this.ExpressionWalker rhs - - abstract member onLogicalOr : TypedExpression * TypedExpression -> unit - default this.onLogicalOr (lhs, rhs) = - this.ExpressionWalker lhs - this.ExpressionWalker rhs - - abstract member onAddition : TypedExpression * TypedExpression -> unit - default this.onAddition (lhs, rhs) = - this.ExpressionWalker lhs - this.ExpressionWalker rhs - - abstract member onSubtraction : TypedExpression * TypedExpression -> unit - default this.onSubtraction (lhs, rhs) = - this.ExpressionWalker lhs - this.ExpressionWalker rhs - - abstract member onMultiplication : TypedExpression * TypedExpression -> unit - default this.onMultiplication (lhs, rhs) = - this.ExpressionWalker lhs - this.ExpressionWalker rhs - - abstract member onDivision : TypedExpression * TypedExpression -> unit - default this.onDivision (lhs, rhs) = - this.ExpressionWalker lhs - this.ExpressionWalker rhs - - abstract member onExponentiate : TypedExpression * TypedExpression -> unit - default this.onExponentiate (lhs, rhs) = - this.ExpressionWalker lhs - this.ExpressionWalker rhs - - abstract member onModulo : TypedExpression * TypedExpression -> unit - default this.onModulo (lhs, rhs) = - this.ExpressionWalker lhs - this.ExpressionWalker rhs - - abstract member onLeftShift : TypedExpression * TypedExpression -> unit - default this.onLeftShift (lhs, rhs) = - this.ExpressionWalker lhs - this.ExpressionWalker rhs - - abstract member onRightShift : TypedExpression * TypedExpression -> unit - default this.onRightShift (lhs, rhs) = - this.ExpressionWalker lhs - this.ExpressionWalker rhs - - abstract member onBitwiseExclusiveOr : TypedExpression * TypedExpression -> unit - default this.onBitwiseExclusiveOr (lhs, rhs) = - this.ExpressionWalker lhs - this.ExpressionWalker rhs - - abstract member onBitwiseOr : TypedExpression * TypedExpression -> unit - default this.onBitwiseOr (lhs, rhs) = - this.ExpressionWalker lhs - this.ExpressionWalker rhs - - abstract member onBitwiseAnd : TypedExpression * TypedExpression -> unit - default this.onBitwiseAnd (lhs, rhs) = - this.ExpressionWalker lhs - this.ExpressionWalker rhs - - abstract member onLogicalNot : TypedExpression -> unit - default this.onLogicalNot ex = this.ExpressionWalker ex - - abstract member onNegative : TypedExpression -> unit - default this.onNegative ex = this.ExpressionWalker ex - - abstract member onBitwiseNot : TypedExpression -> unit - default this.onBitwiseNot ex = this.ExpressionWalker ex - - - member private this.dispatchCallLikeExpression (method, arg) = - match method.ResolvedType.Resolution with - | _ when TypedExpression.IsPartialApplication (CallLikeExpression (method, arg)) -> this.onPartialApplication (method, arg) - | ExpressionType.Operation _ -> this.onOperationCall (method, arg) - | _ -> this.onFunctionCall (method, arg) - - abstract member Walk : ExpressionKind -> unit - default this.Walk kind = - if not enable then () else - match kind with - | Identifier (sym, tArgs) -> this.onIdentifier (sym, tArgs) - | CallLikeExpression (method,arg) -> this.beforeCallLike (method, arg) - this.dispatchCallLikeExpression (method, arg) - | AdjointApplication ex -> this.beforeFunctorApplication ex - this.beforeModifierApplication ex - this.onAdjointApplication ex - | ControlledApplication ex -> this.beforeFunctorApplication ex - this.beforeModifierApplication ex - this.onControlledApplication ex - | UnwrapApplication ex -> this.beforeModifierApplication ex - this.onUnwrapApplication ex - | UnitValue -> this.onUnitValue () - | MissingExpr -> this.onMissingExpression () - | InvalidExpr -> this.onInvalidExpression () - | ValueTuple vs -> this.onValueTuple vs - | ArrayItem (arr, idx) -> this.onArrayItem (arr, idx) - | NamedItem (ex, acc) -> this.onNamedItem (ex, acc) - | ValueArray vs -> this.onValueArray vs - | NewArray (bt, idx) -> this.onNewArray (bt, idx) - | IntLiteral i -> this.onIntLiteral i - | BigIntLiteral b -> this.onBigIntLiteral b - | DoubleLiteral d -> this.onDoubleLiteral d - | BoolLiteral b -> this.onBoolLiteral b - | ResultLiteral r -> this.onResultLiteral r - | PauliLiteral p -> this.onPauliLiteral p - | StringLiteral (s, exs) -> this.onStringLiteral (s, exs) - | RangeLiteral (lhs, rhs) -> this.onRangeLiteral (lhs, rhs) - | CopyAndUpdate (lhs, accEx, rhs) -> this.onCopyAndUpdateExpression (lhs, accEx, rhs) - | CONDITIONAL (cond, ifTrue, ifFalse) -> this.onConditionalExpression (cond, ifTrue, ifFalse) - | EQ (lhs,rhs) -> this.beforeBinaryOperatorExpression (lhs, rhs) - this.onEquality (lhs, rhs) - | NEQ (lhs,rhs) -> this.beforeBinaryOperatorExpression (lhs, rhs) - this.onInequality (lhs, rhs) - | LT (lhs,rhs) -> this.beforeBinaryOperatorExpression (lhs, rhs) - this.onLessThan (lhs, rhs) - | LTE (lhs,rhs) -> this.beforeBinaryOperatorExpression (lhs, rhs) - this.onLessThanOrEqual (lhs, rhs) - | GT (lhs,rhs) -> this.beforeBinaryOperatorExpression (lhs, rhs) - this.onGreaterThan (lhs, rhs) - | GTE (lhs,rhs) -> this.beforeBinaryOperatorExpression (lhs, rhs) - this.onGreaterThanOrEqual (lhs, rhs) - | AND (lhs,rhs) -> this.beforeBinaryOperatorExpression (lhs, rhs) - this.onLogicalAnd (lhs, rhs) - | OR (lhs,rhs) -> this.beforeBinaryOperatorExpression (lhs, rhs) - this.onLogicalOr (lhs, rhs) - | ADD (lhs,rhs) -> this.beforeBinaryOperatorExpression (lhs, rhs) - this.onAddition (lhs, rhs) - | SUB (lhs,rhs) -> this.beforeBinaryOperatorExpression (lhs, rhs) - this.onSubtraction (lhs, rhs) - | MUL (lhs,rhs) -> this.beforeBinaryOperatorExpression (lhs, rhs) - this.onMultiplication (lhs, rhs) - | DIV (lhs,rhs) -> this.beforeBinaryOperatorExpression (lhs, rhs) - this.onDivision (lhs, rhs) - | POW (lhs,rhs) -> this.beforeBinaryOperatorExpression (lhs, rhs) - this.onExponentiate (lhs, rhs) - | MOD (lhs,rhs) -> this.beforeBinaryOperatorExpression (lhs, rhs) - this.onModulo (lhs, rhs) - | LSHIFT (lhs,rhs) -> this.beforeBinaryOperatorExpression (lhs, rhs) - this.onLeftShift (lhs, rhs) - | RSHIFT (lhs,rhs) -> this.beforeBinaryOperatorExpression (lhs, rhs) - this.onRightShift (lhs, rhs) - | BXOR (lhs,rhs) -> this.beforeBinaryOperatorExpression (lhs, rhs) - this.onBitwiseExclusiveOr (lhs, rhs) - | BOR (lhs,rhs) -> this.beforeBinaryOperatorExpression (lhs, rhs) - this.onBitwiseOr (lhs, rhs) - | BAND (lhs,rhs) -> this.beforeBinaryOperatorExpression (lhs, rhs) - this.onBitwiseAnd (lhs, rhs) - | NOT ex -> this.beforeUnaryOperatorExpression ex - this.onLogicalNot ex - | NEG ex -> this.beforeUnaryOperatorExpression ex - this.onNegative ex - | BNOT ex -> this.beforeUnaryOperatorExpression ex - this.onBitwiseNot ex - - -and ExpressionTypeWalker(?enable) = - let enable = defaultArg enable true - - abstract member onRangeInformation : QsRangeInfo -> unit - default this.onRangeInformation r =() - - abstract member onCharacteristicsExpression : ResolvedCharacteristics -> unit - default this.onCharacteristicsExpression fs = () - - abstract member onCallableInformation : CallableInformation -> unit - default this.onCallableInformation opInfo = - this.onCharacteristicsExpression opInfo.Characteristics - - abstract member onUserDefinedType : UserDefinedType -> unit - default this.onUserDefinedType udt = - this.onRangeInformation udt.Range - - abstract member onTypeParameter : QsTypeParameter -> unit - default this.onTypeParameter tp = - this.onRangeInformation tp.Range - - abstract member onUnitType : unit -> unit - default this.onUnitType () = () - - abstract member onOperation : (ResolvedType * ResolvedType) * CallableInformation -> unit - default this.onOperation ((it, ot), info) = - this.Walk it - this.Walk ot - this.onCallableInformation info - - abstract member onFunction : ResolvedType * ResolvedType -> unit - default this.onFunction (it, ot) = - this.Walk it - this.Walk ot - - abstract member onTupleType : ImmutableArray -> unit - default this.onTupleType ts = ts |> Seq.iter this.Walk - - abstract member onArrayType : ResolvedType -> unit - default this.onArrayType b = this.Walk b - - abstract member onQubit : unit -> unit - default this.onQubit () = () - - abstract member onMissingType : unit -> unit - default this.onMissingType () = () - - abstract member onInvalidType : unit -> unit - default this.onInvalidType () = () - - abstract member onInt : unit -> unit - default this.onInt () = () - - abstract member onBigInt : unit -> unit - default this.onBigInt () = () - - abstract member onDouble : unit -> unit - default this.onDouble () = () - - abstract member onBool : unit -> unit - default this.onBool () = () - - abstract member onString : unit -> unit - default this.onString () = () - - abstract member onResult : unit -> unit - default this.onResult () = () - - abstract member onPauli : unit -> unit - default this.onPauli () = () - - abstract member onRange : unit -> unit - default this.onRange () = () - - member this.Walk (t : ResolvedType) = - if not enable then () else - match t.Resolution with - | ExpressionType.UnitType -> this.onUnitType () - | ExpressionType.Operation ((it, ot), fs) -> this.onOperation ((it, ot), fs) - | ExpressionType.Function (it, ot) -> this.onFunction (it, ot) - | ExpressionType.TupleType ts -> this.onTupleType ts - | ExpressionType.ArrayType b -> this.onArrayType b - | ExpressionType.UserDefinedType udt -> this.onUserDefinedType udt - | ExpressionType.TypeParameter tp -> this.onTypeParameter tp - | ExpressionType.Qubit -> this.onQubit () - | ExpressionType.MissingType -> this.onMissingType () - | ExpressionType.InvalidType -> this.onInvalidType () - | ExpressionType.Int -> this.onInt () - | ExpressionType.BigInt -> this.onBigInt () - | ExpressionType.Double -> this.onDouble () - | ExpressionType.Bool -> this.onBool () - | ExpressionType.String -> this.onString () - | ExpressionType.Result -> this.onResult () - | ExpressionType.Pauli -> this.onPauli () - | ExpressionType.Range -> this.onRange () - - -and ExpressionWalker(?enableKindWalkers) = - let enableKind = defaultArg enableKindWalkers true - let typeWalker = new ExpressionTypeWalker() - - abstract member Kind : ExpressionKindWalker - default this.Kind = { - new ExpressionKindWalker (enableKind) with - override x.ExpressionWalker ex = this.Walk ex - override x.TypeWalker t = this.Type.Walk t - } - - abstract member Type : ExpressionTypeWalker - default this.Type = typeWalker - - abstract member onRangeInformation : QsNullable -> unit - default this.onRangeInformation r = () - - abstract member onExpressionInformation : InferredExpressionInformation -> unit - default this.onExpressionInformation info = () - - abstract member Walk : TypedExpression -> unit - default this.Walk (ex : TypedExpression) = - this.onRangeInformation ex.Range - this.Kind.Walk ex.Expression - this.Type.Walk ex.ResolvedType - this.onExpressionInformation ex.InferredInformation diff --git a/src/QsCompiler/Core/NamespaceTransformation.fs b/src/QsCompiler/Core/NamespaceTransformation.fs new file mode 100644 index 0000000000..471d6bf6d0 --- /dev/null +++ b/src/QsCompiler/Core/NamespaceTransformation.fs @@ -0,0 +1,228 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +namespace Microsoft.Quantum.QsCompiler.Transformations.Core + +open System +open System.Collections.Immutable +open System.Linq +open Microsoft.Quantum.QsCompiler.DataTypes +open Microsoft.Quantum.QsCompiler.SyntaxExtensions +open Microsoft.Quantum.QsCompiler.SyntaxTokens +open Microsoft.Quantum.QsCompiler.SyntaxTree +open Microsoft.Quantum.QsCompiler.Transformations.Core.Utils + +type QsArgumentTuple = QsTuple> + + +type NamespaceTransformationBase internal (options : TransformationOptions, _internal_) = + + let missingTransformation name _ = new InvalidOperationException(sprintf "No %s transformation has been specified." name) |> raise + let Node = if options.Rebuild then Fold else Walk + + member val internal StatementTransformationHandle = missingTransformation "statement" with get, set + member this.Statements = this.StatementTransformationHandle() + + new (statementTransformation : unit -> StatementTransformationBase, options : TransformationOptions) as this = + new NamespaceTransformationBase(options, "_internal_") then + this.StatementTransformationHandle <- statementTransformation + + new (options : TransformationOptions) as this = + new NamespaceTransformationBase(options, "_internal_") then + let statementTransformation = new StatementTransformationBase(options) + this.StatementTransformationHandle <- fun _ -> statementTransformation + + new (statementTransformation : unit -> StatementTransformationBase) = + new NamespaceTransformationBase(statementTransformation, TransformationOptions.Default) + + new () = new NamespaceTransformationBase(TransformationOptions.Default) + + + // subconstructs used within declarations + + abstract member OnLocation : QsNullable -> QsNullable + default this.OnLocation l = l + + abstract member OnDocumentation : ImmutableArray -> ImmutableArray + default this.OnDocumentation doc = doc + + abstract member OnSourceFile : NonNullable -> NonNullable + default this.OnSourceFile f = f + + abstract member OnAttribute : QsDeclarationAttribute -> QsDeclarationAttribute + default this.OnAttribute att = att + + abstract member OnTypeItems : QsTuple -> QsTuple + default this.OnTypeItems tItem = + match tItem with + | QsTuple items as original -> + let transformed = items |> Seq.map this.OnTypeItems |> ImmutableArray.CreateRange + QsTuple |> Node.BuildOr original transformed + | QsTupleItem (Anonymous itemType) as original -> + let t = this.Statements.Expressions.Types.OnType itemType + QsTupleItem << Anonymous |> Node.BuildOr original t + | QsTupleItem (Named item) as original -> + let loc = item.Position, item.Range + let t = this.Statements.Expressions.Types.OnType item.Type + let info = this.Statements.Expressions.OnExpressionInformation item.InferredInformation + QsTupleItem << Named << LocalVariableDeclaration<_>.New info.IsMutable |> Node.BuildOr original (loc, item.VariableName, t, info.HasLocalQuantumDependency) + + abstract member OnArgumentTuple : QsArgumentTuple -> QsArgumentTuple + default this.OnArgumentTuple arg = + match arg with + | QsTuple items as original -> + let transformed = items |> Seq.map this.OnArgumentTuple |> ImmutableArray.CreateRange + QsTuple |> Node.BuildOr original transformed + | QsTupleItem item as original -> + let loc = item.Position, item.Range + let t = this.Statements.Expressions.Types.OnType item.Type + let info = this.Statements.Expressions.OnExpressionInformation item.InferredInformation + QsTupleItem << LocalVariableDeclaration<_>.New info.IsMutable |> Node.BuildOr original (loc, item.VariableName, t, info.HasLocalQuantumDependency) + + abstract member OnSignature : ResolvedSignature -> ResolvedSignature + default this.OnSignature (s : ResolvedSignature) = + let typeParams = s.TypeParameters + let argType = this.Statements.Expressions.Types.OnType s.ArgumentType + let returnType = this.Statements.Expressions.Types.OnType s.ReturnType + let info = this.Statements.Expressions.Types.OnCallableInformation s.Information + ResolvedSignature.New |> Node.BuildOr s ((argType, returnType), info, typeParams) + + + // specialization declarations and implementations + + abstract member OnProvidedImplementation : QsArgumentTuple * QsScope -> QsArgumentTuple * QsScope + default this.OnProvidedImplementation (argTuple, body) = + let argTuple = this.OnArgumentTuple argTuple + let body = this.Statements.OnScope body + argTuple, body + + abstract member OnSelfInverseDirective : unit -> unit + default this.OnSelfInverseDirective () = () + + abstract member OnInvertDirective : unit -> unit + default this.OnInvertDirective () = () + + abstract member OnDistributeDirective : unit -> unit + default this.OnDistributeDirective () = () + + abstract member OnInvalidGeneratorDirective : unit -> unit + default this.OnInvalidGeneratorDirective () = () + + abstract member OnExternalImplementation : unit -> unit + default this.OnExternalImplementation () = () + + abstract member OnIntrinsicImplementation : unit -> unit + default this.OnIntrinsicImplementation () = () + + abstract member OnGeneratedImplementation : QsGeneratorDirective -> QsGeneratorDirective + default this.OnGeneratedImplementation (directive : QsGeneratorDirective) = + match directive with + | SelfInverse -> this.OnSelfInverseDirective (); SelfInverse + | Invert -> this.OnInvertDirective(); Invert + | Distribute -> this.OnDistributeDirective(); Distribute + | InvalidGenerator -> this.OnInvalidGeneratorDirective(); InvalidGenerator + + abstract member OnSpecializationImplementation : SpecializationImplementation -> SpecializationImplementation + default this.OnSpecializationImplementation (implementation : SpecializationImplementation) = + let Build kind transformed = kind |> Node.BuildOr implementation transformed + match implementation with + | External -> this.OnExternalImplementation(); External + | Intrinsic -> this.OnIntrinsicImplementation(); Intrinsic + | Generated dir -> this.OnGeneratedImplementation dir |> Build Generated + | Provided (argTuple, body) -> this.OnProvidedImplementation (argTuple, body) |> Build Provided + + /// This method is defined for the sole purpose of eliminating code duplication for each of the specialization kinds. + /// It is hence not intended and should never be needed for public use. + member private this.OnSpecializationKind (spec : QsSpecialization) = + let source = this.OnSourceFile spec.SourceFile + let loc = this.OnLocation spec.Location + let attributes = spec.Attributes |> Seq.map this.OnAttribute |> ImmutableArray.CreateRange + let typeArgs = spec.TypeArguments |> QsNullable<_>.Map (fun args -> args |> Seq.map this.Statements.Expressions.Types.OnType |> ImmutableArray.CreateRange) + let signature = this.OnSignature spec.Signature + let impl = this.OnSpecializationImplementation spec.Implementation + let doc = this.OnDocumentation spec.Documentation + let comments = spec.Comments + QsSpecialization.New spec.Kind (source, loc) |> Node.BuildOr spec (spec.Parent, attributes, typeArgs, signature, impl, doc, comments) + + abstract member OnBodySpecialization : QsSpecialization -> QsSpecialization + default this.OnBodySpecialization spec = this.OnSpecializationKind spec + + abstract member OnAdjointSpecialization : QsSpecialization -> QsSpecialization + default this.OnAdjointSpecialization spec = this.OnSpecializationKind spec + + abstract member OnControlledSpecialization : QsSpecialization -> QsSpecialization + default this.OnControlledSpecialization spec = this.OnSpecializationKind spec + + abstract member OnControlledAdjointSpecialization : QsSpecialization -> QsSpecialization + default this.OnControlledAdjointSpecialization spec = this.OnSpecializationKind spec + + abstract member OnSpecializationDeclaration : QsSpecialization -> QsSpecialization + default this.OnSpecializationDeclaration (spec : QsSpecialization) = + match spec.Kind with + | QsSpecializationKind.QsBody -> this.OnBodySpecialization spec + | QsSpecializationKind.QsAdjoint -> this.OnAdjointSpecialization spec + | QsSpecializationKind.QsControlled -> this.OnControlledSpecialization spec + | QsSpecializationKind.QsControlledAdjoint -> this.OnControlledAdjointSpecialization spec + + + // type and callable declarations + + /// This method is defined for the sole purpose of eliminating code duplication for each of the callable kinds. + /// It is hence not intended and should never be needed for public use. + member private this.OnCallableKind (c : QsCallable) = + let source = this.OnSourceFile c.SourceFile + let loc = this.OnLocation c.Location + let attributes = c.Attributes |> Seq.map this.OnAttribute |> ImmutableArray.CreateRange + let signature = this.OnSignature c.Signature + let argTuple = this.OnArgumentTuple c.ArgumentTuple + let specializations = c.Specializations |> Seq.sortBy (fun c -> c.Kind) |> Seq.map this.OnSpecializationDeclaration |> ImmutableArray.CreateRange + let doc = this.OnDocumentation c.Documentation + let comments = c.Comments + QsCallable.New c.Kind (source, loc) |> Node.BuildOr c (c.FullName, attributes, argTuple, signature, specializations, doc, comments) + + abstract member OnOperation : QsCallable -> QsCallable + default this.OnOperation c = this.OnCallableKind c + + abstract member OnFunction : QsCallable -> QsCallable + default this.OnFunction c = this.OnCallableKind c + + abstract member OnTypeConstructor : QsCallable -> QsCallable + default this.OnTypeConstructor c = this.OnCallableKind c + + abstract member OnCallableDeclaration : QsCallable -> QsCallable + default this.OnCallableDeclaration (c : QsCallable) = + match c.Kind with + | QsCallableKind.Function -> this.OnFunction c + | QsCallableKind.Operation -> this.OnOperation c + | QsCallableKind.TypeConstructor -> this.OnTypeConstructor c + + abstract member OnTypeDeclaration : QsCustomType -> QsCustomType + default this.OnTypeDeclaration t = + let source = this.OnSourceFile t.SourceFile + let loc = this.OnLocation t.Location + let attributes = t.Attributes |> Seq.map this.OnAttribute |> ImmutableArray.CreateRange + let underlyingType = this.Statements.Expressions.Types.OnType t.Type + let typeItems = this.OnTypeItems t.TypeItems + let doc = this.OnDocumentation t.Documentation + let comments = t.Comments + QsCustomType.New (source, loc) |> Node.BuildOr t (t.FullName, attributes, typeItems, underlyingType, doc, comments) + + + // transformation roots called on each namespace or namespace element + + abstract member OnNamespaceElement : QsNamespaceElement -> QsNamespaceElement + default this.OnNamespaceElement element = + if not options.Enable then element else + match element with + | QsCustomType t -> t |> this.OnTypeDeclaration |> QsCustomType + | QsCallable c -> c |> this.OnCallableDeclaration |> QsCallable + + abstract member OnNamespace : QsNamespace -> QsNamespace + default this.OnNamespace ns = + if not options.Enable then ns else + let name = ns.Name + let doc = ns.Documentation.AsEnumerable().SelectMany(fun entry -> + entry |> Seq.map (fun doc -> entry.Key, this.OnDocumentation doc)).ToLookup(fst, snd) + let elements = ns.Elements |> Seq.map this.OnNamespaceElement |> ImmutableArray.CreateRange + QsNamespace.New |> Node.BuildOr ns (name, elements, doc) + diff --git a/src/QsCompiler/Core/StatementTransformation.fs b/src/QsCompiler/Core/StatementTransformation.fs index b902ef149d..83e5d98ce1 100644 --- a/src/QsCompiler/Core/StatementTransformation.fs +++ b/src/QsCompiler/Core/StatementTransformation.fs @@ -3,187 +3,242 @@ namespace Microsoft.Quantum.QsCompiler.Transformations.Core +open System open System.Collections.Immutable open Microsoft.Quantum.QsCompiler.DataTypes open Microsoft.Quantum.QsCompiler.SyntaxExtensions open Microsoft.Quantum.QsCompiler.SyntaxTokens open Microsoft.Quantum.QsCompiler.SyntaxTree +open Microsoft.Quantum.QsCompiler.Transformations.Core.Utils -/// Convention: -/// All methods starting with "on" implement the transformation for a statement of a certain kind. -/// All methods starting with "before" group a set of statements, and are called before applying the transformation -/// even if the corresponding transformation routine (starting with "on") is overridden. -[] -type StatementKindTransformation(?enable) = - let enable = defaultArg enable true +type StatementKindTransformationBase internal (options : TransformationOptions, _internal_) = - abstract member ScopeTransformation : QsScope -> QsScope - abstract member ExpressionTransformation : TypedExpression -> TypedExpression - abstract member TypeTransformation : ResolvedType -> ResolvedType - abstract member LocationTransformation : QsNullable -> QsNullable + let missingTransformation name _ = new InvalidOperationException(sprintf "No %s transformation has been specified." name) |> raise + let Node = if options.Rebuild then Fold else Walk - abstract member onQubitInitializer : ResolvedInitializer -> ResolvedInitializer - default this.onQubitInitializer init = - match init.Resolution with - | SingleQubitAllocation -> SingleQubitAllocation - | QubitRegisterAllocation ex -> QubitRegisterAllocation (this.ExpressionTransformation ex) - | QubitTupleAllocation is -> QubitTupleAllocation ((is |> Seq.map this.onQubitInitializer).ToImmutableArray()) - | InvalidInitializer -> InvalidInitializer - |> ResolvedInitializer.New + member val internal ExpressionTransformationHandle = missingTransformation "expression" with get, set + member val internal StatementTransformationHandle = missingTransformation "statement" with get, set - abstract member beforeVariableDeclaration : SymbolTuple -> SymbolTuple - default this.beforeVariableDeclaration syms = syms + member this.Expressions = this.ExpressionTransformationHandle() + member this.Statements = this.StatementTransformationHandle() - abstract member onSymbolTuple : SymbolTuple -> SymbolTuple - default this.onSymbolTuple syms = syms + new (statementTransformation : unit -> StatementTransformationBase, expressionTransformation : unit -> ExpressionTransformationBase, options : TransformationOptions) as this = + new StatementKindTransformationBase(options, "_internal_") then + this.ExpressionTransformationHandle <- expressionTransformation + this.StatementTransformationHandle <- statementTransformation + new (options : TransformationOptions) as this = + new StatementKindTransformationBase(options, "_internal_") then + let expressionTransformation = new ExpressionTransformationBase(options) + let statementTransformation = new StatementTransformationBase((fun _ -> this), (fun _ -> this.Expressions), options) + this.ExpressionTransformationHandle <- fun _ -> expressionTransformation + this.StatementTransformationHandle <- fun _ -> statementTransformation - abstract member onExpressionStatement : TypedExpression -> QsStatementKind - default this.onExpressionStatement ex = QsExpressionStatement (this.ExpressionTransformation ex) + new (statementTransformation : unit -> StatementTransformationBase, expressionTransformation : unit -> ExpressionTransformationBase) = + new StatementKindTransformationBase(statementTransformation, expressionTransformation, TransformationOptions.Default) - abstract member onReturnStatement : TypedExpression -> QsStatementKind - default this.onReturnStatement ex = QsReturnStatement (this.ExpressionTransformation ex) + new () = new StatementKindTransformationBase(TransformationOptions.Default) - abstract member onFailStatement : TypedExpression -> QsStatementKind - default this.onFailStatement ex = QsFailStatement (this.ExpressionTransformation ex) - abstract member onVariableDeclaration : QsBinding -> QsStatementKind - default this.onVariableDeclaration stm = - let rhs = this.ExpressionTransformation stm.Rhs - let lhs = this.onSymbolTuple stm.Lhs - QsBinding.New stm.Kind (lhs, rhs) |> QsVariableDeclaration + // subconstructs used within statements - abstract member onValueUpdate : QsValueUpdate -> QsStatementKind - default this.onValueUpdate stm = - let rhs = this.ExpressionTransformation stm.Rhs - let lhs = this.ExpressionTransformation stm.Lhs - QsValueUpdate.New (lhs, rhs) |> QsValueUpdate + abstract member OnSymbolTuple : SymbolTuple -> SymbolTuple + default this.OnSymbolTuple syms = syms - abstract member onPositionedBlock : QsNullable * QsPositionedBlock -> QsNullable * QsPositionedBlock - default this.onPositionedBlock (intro : QsNullable, block : QsPositionedBlock) = - let location = this.LocationTransformation block.Location + abstract member OnQubitInitializer : ResolvedInitializer -> ResolvedInitializer + default this.OnQubitInitializer init = + let transformed = init.Resolution |> function + | SingleQubitAllocation -> SingleQubitAllocation + | QubitRegisterAllocation ex as orig -> QubitRegisterAllocation |> Node.BuildOr orig (this.Expressions.OnTypedExpression ex) + | QubitTupleAllocation is as orig -> QubitTupleAllocation |> Node.BuildOr orig (is |> Seq.map this.OnQubitInitializer |> ImmutableArray.CreateRange) + | InvalidInitializer -> InvalidInitializer + ResolvedInitializer.New |> Node.BuildOr init transformed + + abstract member OnPositionedBlock : QsNullable * QsPositionedBlock -> QsNullable * QsPositionedBlock + default this.OnPositionedBlock (intro : QsNullable, block : QsPositionedBlock) = + let location = this.Statements.OnLocation block.Location let comments = block.Comments - let expr = intro |> QsNullable<_>.Map this.ExpressionTransformation - let body = this.ScopeTransformation block.Body - expr, QsPositionedBlock.New comments location body + let expr = intro |> QsNullable<_>.Map this.Expressions.OnTypedExpression + let body = this.Statements.OnScope block.Body + let PositionedBlock (expr, body, location, comments) = expr, QsPositionedBlock.New comments location body + PositionedBlock |> Node.BuildOr (intro, block) (expr, body, location, comments) + + + // statements containing subconstructs or expressions + + abstract member OnVariableDeclaration : QsBinding -> QsStatementKind + default this.OnVariableDeclaration stm = + let rhs = this.Expressions.OnTypedExpression stm.Rhs + let lhs = this.OnSymbolTuple stm.Lhs + QsVariableDeclaration << QsBinding.New stm.Kind |> Node.BuildOr EmptyStatement (lhs, rhs) - abstract member onConditionalStatement : QsConditionalStatement -> QsStatementKind - default this.onConditionalStatement stm = + abstract member OnValueUpdate : QsValueUpdate -> QsStatementKind + default this.OnValueUpdate stm = + let rhs = this.Expressions.OnTypedExpression stm.Rhs + let lhs = this.Expressions.OnTypedExpression stm.Lhs + QsValueUpdate << QsValueUpdate.New |> Node.BuildOr EmptyStatement (lhs, rhs) + + abstract member OnConditionalStatement : QsConditionalStatement -> QsStatementKind + default this.OnConditionalStatement stm = let cases = stm.ConditionalBlocks |> Seq.map (fun (c, b) -> - let cond, block = this.onPositionedBlock (Value c, b) + let cond, block = this.OnPositionedBlock (Value c, b) let invalidCondition () = failwith "missing condition in if-statement" - cond.ValueOrApply invalidCondition, block) - let defaultCase = stm.Default |> QsNullable<_>.Map (fun b -> this.onPositionedBlock (Null, b) |> snd) - QsConditionalStatement.New (cases, defaultCase) |> QsConditionalStatement - - abstract member onForStatement : QsForStatement -> QsStatementKind - default this.onForStatement stm = - let iterVals = this.ExpressionTransformation stm.IterationValues - let loopVar = fst stm.LoopItem |> this.onSymbolTuple - let loopVarType = this.TypeTransformation (snd stm.LoopItem) - let body = this.ScopeTransformation stm.Body - QsForStatement.New ((loopVar, loopVarType), iterVals, body) |> QsForStatement - - abstract member onWhileStatement : QsWhileStatement -> QsStatementKind - default this.onWhileStatement stm = - let condition = this.ExpressionTransformation stm.Condition - let body = this.ScopeTransformation stm.Body - QsWhileStatement.New (condition, body) |> QsWhileStatement - - abstract member onRepeatStatement : QsRepeatStatement -> QsStatementKind - default this.onRepeatStatement stm = - let repeatBlock = this.onPositionedBlock (Null, stm.RepeatBlock) |> snd - let successCondition, fixupBlock = this.onPositionedBlock (Value stm.SuccessCondition, stm.FixupBlock) + cond.ValueOrApply invalidCondition, block) |> ImmutableArray.CreateRange + let defaultCase = stm.Default |> QsNullable<_>.Map (fun b -> this.OnPositionedBlock (Null, b) |> snd) + QsConditionalStatement << QsConditionalStatement.New |> Node.BuildOr EmptyStatement (cases, defaultCase) + + abstract member OnForStatement : QsForStatement -> QsStatementKind + default this.OnForStatement stm = + let iterVals = this.Expressions.OnTypedExpression stm.IterationValues + let loopVar = fst stm.LoopItem |> this.OnSymbolTuple + let loopVarType = this.Expressions.Types.OnType (snd stm.LoopItem) + let body = this.Statements.OnScope stm.Body + QsForStatement << QsForStatement.New |> Node.BuildOr EmptyStatement ((loopVar, loopVarType), iterVals, body) + + abstract member OnWhileStatement : QsWhileStatement -> QsStatementKind + default this.OnWhileStatement stm = + let condition = this.Expressions.OnTypedExpression stm.Condition + let body = this.Statements.OnScope stm.Body + QsWhileStatement << QsWhileStatement.New |> Node.BuildOr EmptyStatement (condition, body) + + abstract member OnRepeatStatement : QsRepeatStatement -> QsStatementKind + default this.OnRepeatStatement stm = + let repeatBlock = this.OnPositionedBlock (Null, stm.RepeatBlock) |> snd + let successCondition, fixupBlock = this.OnPositionedBlock (Value stm.SuccessCondition, stm.FixupBlock) let invalidCondition () = failwith "missing success condition in repeat-statement" - QsRepeatStatement.New (repeatBlock, successCondition.ValueOrApply invalidCondition, fixupBlock) |> QsRepeatStatement + QsRepeatStatement << QsRepeatStatement.New |> Node.BuildOr EmptyStatement (repeatBlock, successCondition.ValueOrApply invalidCondition, fixupBlock) + + abstract member OnConjugation : QsConjugation -> QsStatementKind + default this.OnConjugation stm = + let outer = this.OnPositionedBlock (Null, stm.OuterTransformation) |> snd + let inner = this.OnPositionedBlock (Null, stm.InnerTransformation) |> snd + QsConjugation << QsConjugation.New |> Node.BuildOr EmptyStatement (outer, inner) + + abstract member OnExpressionStatement : TypedExpression -> QsStatementKind + default this.OnExpressionStatement ex = + let transformed = this.Expressions.OnTypedExpression ex + QsExpressionStatement |> Node.BuildOr EmptyStatement transformed + + abstract member OnReturnStatement : TypedExpression -> QsStatementKind + default this.OnReturnStatement ex = + let transformed = this.Expressions.OnTypedExpression ex + QsReturnStatement |> Node.BuildOr EmptyStatement transformed + + abstract member OnFailStatement : TypedExpression -> QsStatementKind + default this.OnFailStatement ex = + let transformed = this.Expressions.OnTypedExpression ex + QsFailStatement |> Node.BuildOr EmptyStatement transformed + + /// This method is defined for the sole purpose of eliminating code duplication for each of the specialization kinds. + /// It is hence not intended and should never be needed for public use. + member private this.OnQubitScopeKind (stm : QsQubitScope) = + let kind = stm.Kind + let rhs = this.OnQubitInitializer stm.Binding.Rhs + let lhs = this.OnSymbolTuple stm.Binding.Lhs + let body = this.Statements.OnScope stm.Body + QsQubitScope << QsQubitScope.New kind |> Node.BuildOr EmptyStatement ((lhs, rhs), body) - abstract member onConjugation : QsConjugation -> QsStatementKind - default this.onConjugation stm = - let outer = this.onPositionedBlock (Null, stm.OuterTransformation) |> snd - let inner = this.onPositionedBlock (Null, stm.InnerTransformation) |> snd - QsConjugation.New (outer, inner) |> QsConjugation + abstract member OnAllocateQubits : QsQubitScope -> QsStatementKind + default this.OnAllocateQubits stm = this.OnQubitScopeKind stm - abstract member onQubitScope : QsQubitScope -> QsStatementKind - default this.onQubitScope (stm : QsQubitScope) = - let kind = stm.Kind - let rhs = this.onQubitInitializer stm.Binding.Rhs - let lhs = this.onSymbolTuple stm.Binding.Lhs - let body = this.ScopeTransformation stm.Body - QsQubitScope.New kind ((lhs, rhs), body) |> QsQubitScope + abstract member OnBorrowQubits : QsQubitScope -> QsStatementKind + default this.OnBorrowQubits stm = this.OnQubitScopeKind stm - abstract member onAllocateQubits : QsQubitScope -> QsStatementKind - default this.onAllocateQubits stm = this.onQubitScope stm + abstract member OnQubitScope : QsQubitScope -> QsStatementKind + default this.OnQubitScope (stm : QsQubitScope) = + match stm.Kind with + | Allocate -> this.OnAllocateQubits stm + | Borrow -> this.OnBorrowQubits stm - abstract member onBorrowQubits : QsQubitScope -> QsStatementKind - default this.onBorrowQubits stm = this.onQubitScope stm + // leaf nodes + + abstract member OnEmptyStatement : unit -> QsStatementKind + default this.OnEmptyStatement () = EmptyStatement - member private this.dispatchQubitScope (stm : QsQubitScope) = - match stm.Kind with - | Allocate -> this.onAllocateQubits stm - | Borrow -> this.onBorrowQubits stm - - abstract member Transform : QsStatementKind -> QsStatementKind - default this.Transform kind = - let beforeBinding (stm : QsBinding) = { stm with Lhs = this.beforeVariableDeclaration stm.Lhs } - let beforeForStatement (stm : QsForStatement) = {stm with LoopItem = (this.beforeVariableDeclaration (fst stm.LoopItem), snd stm.LoopItem)} - let beforeQubitScope (stm : QsQubitScope) = {stm with Binding = {stm.Binding with Lhs = this.beforeVariableDeclaration stm.Binding.Lhs}} - - if not enable then kind else - match kind with - | QsExpressionStatement ex -> this.onExpressionStatement (ex) - | QsReturnStatement ex -> this.onReturnStatement (ex) - | QsFailStatement ex -> this.onFailStatement (ex) - | QsVariableDeclaration stm -> this.onVariableDeclaration (stm |> beforeBinding) - | QsValueUpdate stm -> this.onValueUpdate (stm) - | QsConditionalStatement stm -> this.onConditionalStatement (stm) - | QsForStatement stm -> this.onForStatement (stm |> beforeForStatement) - | QsWhileStatement stm -> this.onWhileStatement (stm) - | QsRepeatStatement stm -> this.onRepeatStatement (stm) - | QsConjugation stm -> this.onConjugation (stm) - | QsQubitScope stm -> this.dispatchQubitScope (stm |> beforeQubitScope) - - -and ScopeTransformation(?enableStatementKindTransformations) = - let enableStatementKind = defaultArg enableStatementKindTransformations true - let expressionsTransformation = new ExpressionTransformation() - - abstract member Expression : ExpressionTransformation - default this.Expression = expressionsTransformation - - abstract member StatementKind : StatementKindTransformation - default this.StatementKind = { - new StatementKindTransformation (enableStatementKind) with - override x.ScopeTransformation s = this.Transform s - override x.ExpressionTransformation ex = this.Expression.Transform ex - override x.TypeTransformation t = this.Expression.Type.Transform t - override x.LocationTransformation l = this.onLocation l - } - - abstract member onLocation : QsNullable -> QsNullable - default this.onLocation loc = loc - - abstract member onLocalDeclarations : LocalDeclarations -> LocalDeclarations - default this.onLocalDeclarations decl = + + // transformation root called on each statement + + abstract member OnStatementKind : QsStatementKind -> QsStatementKind + default this.OnStatementKind kind = + if not options.Enable then kind else + let transformed = kind |> function + | QsExpressionStatement ex -> this.OnExpressionStatement ex + | QsReturnStatement ex -> this.OnReturnStatement ex + | QsFailStatement ex -> this.OnFailStatement ex + | QsVariableDeclaration stm -> this.OnVariableDeclaration stm + | QsValueUpdate stm -> this.OnValueUpdate stm + | QsConditionalStatement stm -> this.OnConditionalStatement stm + | QsForStatement stm -> this.OnForStatement stm + | QsWhileStatement stm -> this.OnWhileStatement stm + | QsRepeatStatement stm -> this.OnRepeatStatement stm + | QsConjugation stm -> this.OnConjugation stm + | QsQubitScope stm -> this.OnQubitScope stm + | EmptyStatement -> this.OnEmptyStatement () + id |> Node.BuildOr kind transformed + + +and StatementTransformationBase internal (options : TransformationOptions, _internal_) = + + let missingTransformation name _ = new InvalidOperationException(sprintf "No %s transformation has been specified." name) |> raise + let Node = if options.Rebuild then Fold else Walk + + member val internal ExpressionTransformationHandle = missingTransformation "expression" with get, set + member val internal StatementKindTransformationHandle = missingTransformation "statement kind" with get, set + + member this.Expressions = this.ExpressionTransformationHandle() + member this.StatementKinds = this.StatementKindTransformationHandle() + + new (statementKindTransformation : unit -> StatementKindTransformationBase, expressionTransformation : unit -> ExpressionTransformationBase, options : TransformationOptions) as this = + new StatementTransformationBase(options, "_internal_") then + this.ExpressionTransformationHandle <- expressionTransformation + this.StatementKindTransformationHandle <- statementKindTransformation + + new (options : TransformationOptions) as this = + new StatementTransformationBase(options, "_internal_") then + let expressionTransformation = new ExpressionTransformationBase(options) + let statementTransformation = new StatementKindTransformationBase((fun _ -> this), (fun _ -> this.Expressions), options) + this.ExpressionTransformationHandle <- fun _ -> expressionTransformation + this.StatementKindTransformationHandle <- fun _ -> statementTransformation + + new (statementKindTransformation : unit -> StatementKindTransformationBase, expressionTransformation : unit -> ExpressionTransformationBase) = + new StatementTransformationBase(statementKindTransformation, expressionTransformation, TransformationOptions.Default) + + new () = new StatementTransformationBase(TransformationOptions.Default) + + + // supplementary statement information + + abstract member OnLocation : QsNullable -> QsNullable + default this.OnLocation loc = loc + + /// If DisableRebuild is set to true, this method won't walk the local variables declared by the statement. + abstract member OnLocalDeclarations : LocalDeclarations -> LocalDeclarations + default this.OnLocalDeclarations decl = let onLocalVariableDeclaration (local : LocalVariableDeclaration>) = let loc = local.Position, local.Range - let info = this.Expression.onExpressionInformation local.InferredInformation - let varType = this.Expression.Type.Transform local.Type + let info = this.Expressions.OnExpressionInformation local.InferredInformation + let varType = this.Expressions.Types.OnType local.Type LocalVariableDeclaration.New info.IsMutable (loc, local.VariableName, varType, info.HasLocalQuantumDependency) - let variableDeclarations = decl.Variables |> Seq.map onLocalVariableDeclaration |> ImmutableArray.CreateRange - LocalDeclarations.New variableDeclarations + let variableDeclarations = decl.Variables |> Seq.map onLocalVariableDeclaration + LocalDeclarations.New << ImmutableArray.CreateRange |> Node.BuildOr decl variableDeclarations + + + // transformation roots called on each statement or statement block - abstract member onStatement : QsStatement -> QsStatement - default this.onStatement stm = - let location = this.onLocation stm.Location + abstract member OnStatement : QsStatement -> QsStatement + default this.OnStatement stm = + if not options.Enable then stm else + let location = this.OnLocation stm.Location let comments = stm.Comments - let kind = this.StatementKind.Transform stm.Statement - let varDecl = this.onLocalDeclarations stm.SymbolDeclarations - QsStatement.New comments location (kind, varDecl) - - abstract member Transform : QsScope -> QsScope - default this.Transform scope = - let parentSymbols = this.onLocalDeclarations scope.KnownSymbols - let statements = scope.Statements |> Seq.map this.onStatement - QsScope.New (statements, parentSymbols) + let kind = this.StatementKinds.OnStatementKind stm.Statement + let varDecl = this.OnLocalDeclarations stm.SymbolDeclarations + QsStatement.New comments location |> Node.BuildOr stm (kind, varDecl) + + abstract member OnScope : QsScope -> QsScope + default this.OnScope scope = + if not options.Enable then scope else + let parentSymbols = this.OnLocalDeclarations scope.KnownSymbols + let statements = scope.Statements |> Seq.map this.OnStatement |> ImmutableArray.CreateRange + QsScope.New |> Node.BuildOr scope (statements, parentSymbols) diff --git a/src/QsCompiler/Core/StatementWalker.fs b/src/QsCompiler/Core/StatementWalker.fs deleted file mode 100644 index 3b7d63c4df..0000000000 --- a/src/QsCompiler/Core/StatementWalker.fs +++ /dev/null @@ -1,166 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -namespace Microsoft.Quantum.QsCompiler.Transformations.Core - -open Microsoft.Quantum.QsCompiler.DataTypes -open Microsoft.Quantum.QsCompiler.SyntaxTokens -open Microsoft.Quantum.QsCompiler.SyntaxTree - - -/// Convention: -/// All methods starting with "on" implement the walk for an expression of a certain kind. -/// All methods starting with "before" group a set of statements, and are called before walking the set -/// even if the corresponding walk routine (starting with "on") is overridden. -/// -/// These classes differ from the "*Transformation" classes in that these classes visit every node in the -/// syntax tree, but don't create a new syntax tree, while the Transformation classes generate a new (or -/// at least partially new) tree from the old one. -/// Effectively, the Transformation classes implement fold, while the Walker classes implement iter. -[] -type StatementKindWalker(?enable) = - let enable = defaultArg enable true - - abstract member ScopeWalker : QsScope -> unit - abstract member ExpressionWalker : TypedExpression -> unit - abstract member TypeWalker : ResolvedType -> unit - abstract member LocationWalker : QsNullable -> unit - - abstract member onQubitInitializer : ResolvedInitializer -> unit - default this.onQubitInitializer init = - match init.Resolution with - | SingleQubitAllocation -> () - | QubitRegisterAllocation ex -> this.ExpressionWalker ex - | QubitTupleAllocation is -> is |> Seq.iter this.onQubitInitializer - | InvalidInitializer -> () - - abstract member beforeVariableDeclaration : SymbolTuple -> unit - default this.beforeVariableDeclaration syms = () - - abstract member onSymbolTuple : SymbolTuple -> unit - default this.onSymbolTuple syms = () - - - abstract member onExpressionStatement : TypedExpression -> unit - default this.onExpressionStatement ex = this.ExpressionWalker ex - - abstract member onReturnStatement : TypedExpression -> unit - default this.onReturnStatement ex = this.ExpressionWalker ex - - abstract member onFailStatement : TypedExpression -> unit - default this.onFailStatement ex = this.ExpressionWalker ex - - abstract member onVariableDeclaration : QsBinding -> unit - default this.onVariableDeclaration stm = - this.ExpressionWalker stm.Rhs - this.onSymbolTuple stm.Lhs - - abstract member onValueUpdate : QsValueUpdate -> unit - default this.onValueUpdate stm = - this.ExpressionWalker stm.Rhs - this.ExpressionWalker stm.Lhs - - abstract member onPositionedBlock : QsNullable * QsPositionedBlock -> unit - default this.onPositionedBlock (intro : QsNullable, block : QsPositionedBlock) = - this.LocationWalker block.Location - match intro with - | Value ex -> this.ExpressionWalker ex - | Null -> () - this.ScopeWalker block.Body - - abstract member onConditionalStatement : QsConditionalStatement -> unit - default this.onConditionalStatement stm = - stm.ConditionalBlocks |> Seq.iter (fun (c, b) -> this.onPositionedBlock (Value c, b)) - stm.Default |> QsNullable<_>.Iter (fun b -> this.onPositionedBlock (Null, b)) - - abstract member onForStatement : QsForStatement -> unit - default this.onForStatement stm = - this.ExpressionWalker stm.IterationValues - fst stm.LoopItem |> this.onSymbolTuple - this.TypeWalker (snd stm.LoopItem) - this.ScopeWalker stm.Body - - abstract member onWhileStatement : QsWhileStatement -> unit - default this.onWhileStatement stm = - this.ExpressionWalker stm.Condition - this.ScopeWalker stm.Body - - abstract member onRepeatStatement : QsRepeatStatement -> unit - default this.onRepeatStatement stm = - this.onPositionedBlock (Null, stm.RepeatBlock) - this.onPositionedBlock (Value stm.SuccessCondition, stm.FixupBlock) - - abstract member onConjugation : QsConjugation -> unit - default this.onConjugation stm = - this.onPositionedBlock (Null, stm.OuterTransformation) - this.onPositionedBlock (Null, stm.InnerTransformation) - - abstract member onQubitScope : QsQubitScope -> unit - default this.onQubitScope (stm : QsQubitScope) = - this.onQubitInitializer stm.Binding.Rhs - this.onSymbolTuple stm.Binding.Lhs - this.ScopeWalker stm.Body - - abstract member onAllocateQubits : QsQubitScope -> unit - default this.onAllocateQubits stm = this.onQubitScope stm - - abstract member onBorrowQubits : QsQubitScope -> unit - default this.onBorrowQubits stm = this.onQubitScope stm - - - member private this.dispatchQubitScope (stm : QsQubitScope) = - match stm.Kind with - | Allocate -> this.onAllocateQubits stm - | Borrow -> this.onBorrowQubits stm - - abstract member Walk : QsStatementKind -> unit - default this.Walk kind = - let beforeBinding (stm : QsBinding) = this.beforeVariableDeclaration stm.Lhs - let beforeForStatement (stm : QsForStatement) = this.beforeVariableDeclaration (fst stm.LoopItem) - let beforeQubitScope (stm : QsQubitScope) = this.beforeVariableDeclaration stm.Binding.Lhs - - if not enable then () else - match kind with - | QsExpressionStatement ex -> this.onExpressionStatement ex - | QsReturnStatement ex -> this.onReturnStatement ex - | QsFailStatement ex -> this.onFailStatement ex - | QsVariableDeclaration stm -> beforeBinding stm - this.onVariableDeclaration stm - | QsValueUpdate stm -> this.onValueUpdate stm - | QsConditionalStatement stm -> this.onConditionalStatement stm - | QsForStatement stm -> beforeForStatement stm - this.onForStatement stm - | QsWhileStatement stm -> this.onWhileStatement stm - | QsRepeatStatement stm -> this.onRepeatStatement stm - | QsConjugation stm -> this.onConjugation stm - | QsQubitScope stm -> beforeQubitScope stm - this.dispatchQubitScope stm - - -and ScopeWalker(?enableStatementKindWalkers) = - let enableStatementKind = defaultArg enableStatementKindWalkers true - let expressionsWalker = new ExpressionWalker() - - abstract member Expression : ExpressionWalker - default this.Expression = expressionsWalker - - abstract member StatementKind : StatementKindWalker - default this.StatementKind = { - new StatementKindWalker (enableStatementKind) with - override x.ScopeWalker s = this.Walk s - override x.ExpressionWalker ex = this.Expression.Walk ex - override x.TypeWalker t = this.Expression.Type.Walk t - override x.LocationWalker l = this.onLocation l - } - - abstract member onLocation : QsNullable -> unit - default this.onLocation loc = () - - abstract member onStatement : QsStatement -> unit - default this.onStatement stm = - this.onLocation stm.Location - this.StatementKind.Walk stm.Statement - - abstract member Walk : QsScope -> unit - default this.Walk scope = - scope.Statements |> Seq.iter this.onStatement diff --git a/src/QsCompiler/Core/SymbolTable.fs b/src/QsCompiler/Core/SymbolTable.fs index 6f29799cef..251c45ad80 100644 --- a/src/QsCompiler/Core/SymbolTable.fs +++ b/src/QsCompiler/Core/SymbolTable.fs @@ -282,7 +282,7 @@ and Namespace private // ignore ambiguous/clashing references let FilterUnique (g : IGrouping<_,_>) = - if g.Count() > 1 then None // TODO: give warning?? + if g.Count() > 1 then None else g.Single() |> Some let typesInRefs = typesInRefs.GroupBy(fun t -> t.QualifiedName.Name) |> Seq.choose FilterUnique let callablesInRefs = callablesInRefs.GroupBy(fun c -> c.QualifiedName.Name) |> Seq.choose FilterUnique @@ -1211,7 +1211,7 @@ and NamespaceManager try this.ClearResolutions() match Namespaces.TryGetValue nsName with | true, ns when ns.Sources.Contains source -> - let validAlias = String.IsNullOrWhiteSpace alias || NonNullable.New (alias.Trim()) |> Namespaces.ContainsKey |> not // TODO: DISALLOW TWO ALIAS WITH THE SAME NAME? + let validAlias = String.IsNullOrWhiteSpace alias || NonNullable.New (alias.Trim()) |> Namespaces.ContainsKey |> not if validAlias && Namespaces.ContainsKey opened then ns.TryAddOpenDirective source (opened, openedRange) (alias, aliasRange.ValueOr openedRange) elif validAlias then [| openedRange |> QsCompilerDiagnostic.Error (ErrorCode.UnknownNamespace, [opened.Value]) |] else [| aliasRange.ValueOr openedRange |> QsCompilerDiagnostic.Error (ErrorCode.InvalidNamespaceAliasName, [alias]) |] diff --git a/src/QsCompiler/Core/SyntaxGenerator.fs b/src/QsCompiler/Core/SyntaxGenerator.fs index a9a1c40fc7..d7dfe233d9 100644 --- a/src/QsCompiler/Core/SyntaxGenerator.fs +++ b/src/QsCompiler/Core/SyntaxGenerator.fs @@ -16,35 +16,38 @@ open Microsoft.Quantum.QsCompiler.Transformations.Core // transformations used to strip range information for auto-generated syntax -type private StripPositionInfoFromType () = - inherit ExpressionTypeTransformation(true) - override this.onRangeInformation _ = Null - -type private StripPositionInfoFromExpression () = - inherit ExpressionTransformation() - let typeTransformation = new StripPositionInfoFromType() :> ExpressionTypeTransformation - override this.onRangeInformation _ = Null - override this.Type = typeTransformation - -type private StripPositionInfoFromScope() = - inherit ScopeTransformation() - let expressionTransformation = new StripPositionInfoFromExpression() - override this.onLocation _ = Null - override this.Expression = expressionTransformation :> ExpressionTransformation - -type public StripPositionInfo() = +type private StripPositionInfoFromType (parent : StripPositionInfo) = + inherit TypeTransformation(parent) + override this.OnRangeInformation _ = Null + +and private StripPositionInfoFromExpression (parent : StripPositionInfo) = + inherit ExpressionTransformation(parent) + override this.OnRangeInformation _ = Null + +and private StripPositionInfoFromStatement(parent : StripPositionInfo) = + inherit StatementTransformation(parent) + override this.OnLocation _ = Null + +and private StripPositionInfoFromNamespace(parent : StripPositionInfo) = + inherit NamespaceTransformation(parent) + override this.OnLocation _ = Null + +and public StripPositionInfo private (_internal_) = inherit SyntaxTreeTransformation() - let scopeTransformation = new StripPositionInfoFromScope() static let defaultInstance = new StripPositionInfo() - - override this.Scope = scopeTransformation :> ScopeTransformation - override this.onLocation loc = Null + + new () as this = + StripPositionInfo("_internal_") then + this.Types <- new StripPositionInfoFromType(this) + this.Expressions <- new StripPositionInfoFromExpression(this) + this.Statements <- new StripPositionInfoFromStatement(this) + this.Namespaces <- new StripPositionInfoFromNamespace(this) static member public Default = defaultInstance - static member public Apply t = defaultInstance.Scope.Expression.Type.Transform t - static member public Apply e = defaultInstance.Scope.Expression.Transform e - static member public Apply s = defaultInstance.Scope.Transform s - static member public Apply a = defaultInstance.Transform a + static member public Apply t = defaultInstance.Types.OnType t + static member public Apply e = defaultInstance.Expressions.OnTypedExpression e + static member public Apply s = defaultInstance.Statements.OnScope s + static member public Apply a = defaultInstance.Namespaces.OnNamespace a module SyntaxGenerator = diff --git a/src/QsCompiler/Core/SyntaxTreeTransformation.fs b/src/QsCompiler/Core/SyntaxTreeTransformation.fs new file mode 100644 index 0000000000..c29b877e15 --- /dev/null +++ b/src/QsCompiler/Core/SyntaxTreeTransformation.fs @@ -0,0 +1,488 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +namespace Microsoft.Quantum.QsCompiler.Transformations.Core + + +// setup for syntax tree transformations with internal state + +type SyntaxTreeTransformation<'T> private (state : 'T, _internal_ : string) = + + let mutable _Types = new TypeTransformation<'T>(TransformationOptions.Default, _internal_) + let mutable _ExpressionKinds = new ExpressionKindTransformation<'T>(TransformationOptions.Default, _internal_) + let mutable _Expressions = new ExpressionTransformation<'T>(TransformationOptions.Default, _internal_) + let mutable _StatementKinds = new StatementKindTransformation<'T>(TransformationOptions.Default, _internal_) + let mutable _Statements = new StatementTransformation<'T>(TransformationOptions.Default, _internal_) + let mutable _Namespaces = new NamespaceTransformation<'T>(TransformationOptions.Default, _internal_) + + /// Transformation invoked for all types encountered when traversing (parts of) the syntax tree. + member this.Types + with get() = _Types + and set value = _Types <- value + + /// Transformation invoked for all expression kinds encountered when traversing (parts of) the syntax tree. + member this.ExpressionKinds + with get() = _ExpressionKinds + and set value = _ExpressionKinds <- value + + /// Transformation invoked for all expressions encountered when traversing (parts of) the syntax tree. + member this.Expressions + with get() = _Expressions + and set value = _Expressions <- value + + /// Transformation invoked for all statement kinds encountered when traversing (parts of) the syntax tree. + member this.StatementKinds + with get() = _StatementKinds + and set value = _StatementKinds <- value + + /// Transformation invoked for all statements encountered when traversing (parts of) the syntax tree. + member this.Statements + with get() = _Statements + and set value = _Statements <- value + + /// Transformation invoked for all namespaces encountered when traversing (parts of) the syntax tree. + member this.Namespaces + with get() = _Namespaces + and set value = _Namespaces <- value + + + member this.SharedState = state + + new (state : 'T, options : TransformationOptions) as this = + SyntaxTreeTransformation<'T>(state, "_internal_") then + this.Types <- new TypeTransformation<'T>(this, options) + this.ExpressionKinds <- new ExpressionKindTransformation<'T>(this, options) + this.Expressions <- new ExpressionTransformation<'T>(this, options) + this.StatementKinds <- new StatementKindTransformation<'T>(this, options) + this.Statements <- new StatementTransformation<'T>(this, options) + this.Namespaces <- new NamespaceTransformation<'T>(this, options) + + new (state : 'T) = new SyntaxTreeTransformation<'T>(state, TransformationOptions.Default) + + +and TypeTransformation<'T> internal (options, _internal_) = + inherit TypeTransformationBase(options) + let mutable _Transformation : SyntaxTreeTransformation<'T> option = None // will be set to a suitable Some value once construction is complete + + /// Handle to the parent SyntaxTreeTransformation. + /// This handle is always safe to access and will be set to a suitable value + /// even if no parent transformation has been specified upon construction. + member this.Transformation + with get () = _Transformation.Value + and private set value = _Transformation <- Some value + + member this.SharedState = this.Transformation.SharedState + + new (parentTransformation : SyntaxTreeTransformation<'T>, options : TransformationOptions) as this = + new TypeTransformation<'T>(options, "_internal_") then + this.Transformation <- parentTransformation + + new (sharedState : 'T, options : TransformationOptions) as this = + TypeTransformation<'T>(options, "_internal_") then + this.Transformation <- new SyntaxTreeTransformation<'T>(sharedState, options) + this.Transformation.Types <- this + + new (parentTransformation : SyntaxTreeTransformation<'T>) = + new TypeTransformation<'T>(parentTransformation, TransformationOptions.Default) + + new (sharedState : 'T) = + new TypeTransformation<'T>(sharedState, TransformationOptions.Default) + + +and ExpressionKindTransformation<'T> internal (options, _internal_) = + inherit ExpressionKindTransformationBase(options, _internal_) + let mutable _Transformation : SyntaxTreeTransformation<'T> option = None // will be set to a suitable Some value once construction is complete + + /// Handle to the parent SyntaxTreeTransformation. + /// This handle is always safe to access and will be set to a suitable value + /// even if no parent transformation has been specified upon construction. + member this.Transformation + with get () = _Transformation.Value + and private set value = + _Transformation <- Some value + this.ExpressionTransformationHandle <- fun _ -> value.Expressions :> ExpressionTransformationBase + this.TypeTransformationHandle <- fun _ -> value.Types :> TypeTransformationBase + + member this.SharedState = this.Transformation.SharedState + + new (parentTransformation : SyntaxTreeTransformation<'T>, options : TransformationOptions) as this = + ExpressionKindTransformation<'T>(options, "_internal_") then + this.Transformation <- parentTransformation + + new (sharedState : 'T, options : TransformationOptions) as this = + ExpressionKindTransformation<'T>(options, "_internal_") then + this.Transformation <- new SyntaxTreeTransformation<'T>(sharedState, options) + this.Transformation.Types <- new TypeTransformation<'T>(this.Transformation, TransformationOptions.Disabled) + this.Transformation.ExpressionKinds <- this + + new (parentTransformation : SyntaxTreeTransformation<'T>) = + new ExpressionKindTransformation<'T>(parentTransformation, TransformationOptions.Default) + + new (sharedState : 'T) = + new ExpressionKindTransformation<'T>(sharedState, TransformationOptions.Default) + + +and ExpressionTransformation<'T> internal (options, _internal_) = + inherit ExpressionTransformationBase(options, _internal_) + let mutable _Transformation : SyntaxTreeTransformation<'T> option = None // will be set to a suitable Some value once construction is complete + + /// Handle to the parent SyntaxTreeTransformation. + /// This handle is always safe to access and will be set to a suitable value + /// even if no parent transformation has been specified upon construction. + member this.Transformation + with get () = _Transformation.Value + and private set value = + _Transformation <- Some value + this.ExpressionKindTransformationHandle <- fun _ -> value.ExpressionKinds :> ExpressionKindTransformationBase + this.TypeTransformationHandle <- fun _ -> value.Types :> TypeTransformationBase + + member this.SharedState = this.Transformation.SharedState + + new (parentTransformation : SyntaxTreeTransformation<'T>, options : TransformationOptions) as this = + ExpressionTransformation<'T>(options, "_internal_") then + this.Transformation <- parentTransformation + + new (sharedState : 'T, options : TransformationOptions) as this = + ExpressionTransformation<'T>(options, "_internal_") then + this.Transformation <- new SyntaxTreeTransformation<'T>(sharedState, options) + this.Transformation.Types <- new TypeTransformation<'T>(this.Transformation, TransformationOptions.Disabled) + this.Transformation.Expressions <- this + + new (parentTransformation : SyntaxTreeTransformation<'T>) = + new ExpressionTransformation<'T>(parentTransformation, TransformationOptions.Default) + + new (sharedState : 'T) = + new ExpressionTransformation<'T>(sharedState, TransformationOptions.Default) + + +and StatementKindTransformation<'T> internal (options, _internal_) = + inherit StatementKindTransformationBase(options, _internal_) + let mutable _Transformation : SyntaxTreeTransformation<'T> option = None // will be set to a suitable Some value once construction is complete + + /// Handle to the parent SyntaxTreeTransformation. + /// This handle is always safe to access and will be set to a suitable value + /// even if no parent transformation has been specified upon construction. + member this.Transformation + with get () = _Transformation.Value + and private set value = + _Transformation <- Some value + this.StatementTransformationHandle <- fun _ -> value.Statements :> StatementTransformationBase + this.ExpressionTransformationHandle <- fun _ -> value.Expressions :> ExpressionTransformationBase + + member this.SharedState = this.Transformation.SharedState + + new (parentTransformation : SyntaxTreeTransformation<'T>, options : TransformationOptions) as this = + StatementKindTransformation<'T>(options, "_internal_") then + this.Transformation <- parentTransformation + + new (sharedState : 'T, options : TransformationOptions) as this = + StatementKindTransformation<'T>(options, "_internal_") then + this.Transformation <- new SyntaxTreeTransformation<'T>(sharedState, options) + this.Transformation.Types <- new TypeTransformation<'T>(this.Transformation, TransformationOptions.Disabled) + this.Transformation.Expressions <- new ExpressionTransformation<'T>(this.Transformation, TransformationOptions.Disabled) + this.Transformation.StatementKinds <- this + + new (parentTransformation : SyntaxTreeTransformation<'T>) = + new StatementKindTransformation<'T>(parentTransformation, TransformationOptions.Default) + + new (sharedState : 'T) = + new StatementKindTransformation<'T>(sharedState, TransformationOptions.Default) + + +and StatementTransformation<'T> internal (options, _internal_) = + inherit StatementTransformationBase(options, _internal_) + let mutable _Transformation : SyntaxTreeTransformation<'T> option = None // will be set to a suitable Some value once construction is complete + + /// Handle to the parent SyntaxTreeTransformation. + /// This handle is always safe to access and will be set to a suitable value + /// even if no parent transformation has been specified upon construction. + member this.Transformation + with get () = _Transformation.Value + and private set value = + _Transformation <- Some value + this.StatementKindTransformationHandle <- fun _ -> value.StatementKinds :> StatementKindTransformationBase + this.ExpressionTransformationHandle <- fun _ -> value.Expressions :> ExpressionTransformationBase + + member this.SharedState = this.Transformation.SharedState + + new (parentTransformation : SyntaxTreeTransformation<'T>, options : TransformationOptions) as this = + StatementTransformation<'T>(options, "_internal_") then + this.Transformation <- parentTransformation + + new (sharedState : 'T, options : TransformationOptions) as this = + StatementTransformation<'T>(options, "_internal_") then + this.Transformation <- new SyntaxTreeTransformation<'T>(sharedState, options) + this.Transformation.Types <- new TypeTransformation<'T>(this.Transformation, TransformationOptions.Disabled) + this.Transformation.Expressions <- new ExpressionTransformation<'T>(this.Transformation, TransformationOptions.Disabled) + this.Transformation.Statements <- this + + new (parentTransformation : SyntaxTreeTransformation<'T>) = + new StatementTransformation<'T>(parentTransformation, TransformationOptions.Default) + + new (sharedState : 'T) = + new StatementTransformation<'T>(sharedState, TransformationOptions.Default) + + +and NamespaceTransformation<'T> internal (options, _internal_ : string) = + inherit NamespaceTransformationBase(options, _internal_) + let mutable _Transformation : SyntaxTreeTransformation<'T> option = None // will be set to a suitable Some value once construction is complete + + /// Handle to the parent SyntaxTreeTransformation. + /// This handle is always safe to access and will be set to a suitable value + /// even if no parent transformation has been specified upon construction. + member this.Transformation + with get () = _Transformation.Value + and private set value = + _Transformation <- Some value + this.StatementTransformationHandle <- fun _ -> value.Statements :> StatementTransformationBase + + member this.SharedState = this.Transformation.SharedState + + new (parentTransformation : SyntaxTreeTransformation<'T>, options : TransformationOptions) as this = + NamespaceTransformation<'T>(options, "_internal_") then + this.Transformation <- parentTransformation + + new (sharedState : 'T, options : TransformationOptions) as this = + NamespaceTransformation<'T>(options, "_internal_") then + this.Transformation <- new SyntaxTreeTransformation<'T>(sharedState, options) + this.Transformation.Types <- new TypeTransformation<'T>(this.Transformation, TransformationOptions.Disabled) + this.Transformation.Expressions <- new ExpressionTransformation<'T>(this.Transformation, TransformationOptions.Disabled) + this.Transformation.Statements <- new StatementTransformation<'T>(this.Transformation, TransformationOptions.Disabled) + this.Transformation.Namespaces <- this + + new (parentTransformation : SyntaxTreeTransformation<'T>) = + new NamespaceTransformation<'T>(parentTransformation, TransformationOptions.Default) + + new (sharedState : 'T) = + new NamespaceTransformation<'T>(sharedState, TransformationOptions.Default) + + +// setup for syntax tree transformations without internal state + +type SyntaxTreeTransformation private (_internal_ : string) = + + let mutable _Types = new TypeTransformation(TransformationOptions.Default, _internal_) + let mutable _ExpressionKinds = new ExpressionKindTransformation(TransformationOptions.Default, _internal_) + let mutable _Expressions = new ExpressionTransformation(TransformationOptions.Default, _internal_) + let mutable _StatementKinds = new StatementKindTransformation(TransformationOptions.Default, _internal_) + let mutable _Statements = new StatementTransformation(TransformationOptions.Default, _internal_) + let mutable _Namespaces = new NamespaceTransformation(TransformationOptions.Default, _internal_) + + member this.Types + with get() = _Types + and set value = _Types <- value + + member this.ExpressionKinds + with get() = _ExpressionKinds + and set value = _ExpressionKinds <- value + + member this.Expressions + with get() = _Expressions + and set value = _Expressions <- value + + member this.StatementKinds + with get() = _StatementKinds + and set value = _StatementKinds <- value + + member this.Statements + with get() = _Statements + and set value = _Statements <- value + + member this.Namespaces + with get() = _Namespaces + and set value = _Namespaces <- value + + + new (options : TransformationOptions) as this = + SyntaxTreeTransformation("_internal_") then + this.Types <- new TypeTransformation(this, options) + this.ExpressionKinds <- new ExpressionKindTransformation(this, options) + this.Expressions <- new ExpressionTransformation(this, options) + this.StatementKinds <- new StatementKindTransformation(this, options) + this.Statements <- new StatementTransformation(this, options) + this.Namespaces <- new NamespaceTransformation(this, options) + + new () = new SyntaxTreeTransformation(TransformationOptions.Default) + + +and TypeTransformation internal (options, _internal_) = + inherit TypeTransformationBase(options) + let mutable _Transformation : SyntaxTreeTransformation option = None // will be set to a suitable Some value once construction is complete + + /// Handle to the parent SyntaxTreeTransformation. + /// This handle is always safe to access and will be set to a suitable value + /// even if no parent transformation has been specified upon construction. + member this.Transformation + with get () = _Transformation.Value + and private set value = _Transformation <- Some value + + new (parentTransformation : SyntaxTreeTransformation, options : TransformationOptions) as this = + new TypeTransformation(options, "_internal_") then + this.Transformation <- parentTransformation + + new (options : TransformationOptions) as this = + TypeTransformation(options, "_internal_") then + this.Transformation <- new SyntaxTreeTransformation(options) + this.Transformation.Types <- this + + new (parentTransformation : SyntaxTreeTransformation) = + new TypeTransformation(parentTransformation, TransformationOptions.Default) + + new () = new TypeTransformation(TransformationOptions.Default) + + +and ExpressionKindTransformation internal (options, _internal_) = + inherit ExpressionKindTransformationBase(options, _internal_) + let mutable _Transformation : SyntaxTreeTransformation option = None // will be set to a suitable Some value once construction is complete + + /// Handle to the parent SyntaxTreeTransformation. + /// This handle is always safe to access and will be set to a suitable value + /// even if no parent transformation has been specified upon construction. + member this.Transformation + with get () = _Transformation.Value + and private set value = + _Transformation <- Some value + this.ExpressionTransformationHandle <- fun _ -> value.Expressions :> ExpressionTransformationBase + this.TypeTransformationHandle <- fun _ -> value.Types :> TypeTransformationBase + + new (parentTransformation : SyntaxTreeTransformation, options : TransformationOptions) as this = + ExpressionKindTransformation(options, "_internal_") then + this.Transformation <- parentTransformation + + new (options : TransformationOptions) as this = + ExpressionKindTransformation(options, "_internal_") then + this.Transformation <- new SyntaxTreeTransformation(options) + this.Transformation.Types <- new TypeTransformation(this.Transformation, TransformationOptions.Disabled) + this.Transformation.ExpressionKinds <- this + + new (parentTransformation : SyntaxTreeTransformation) = + new ExpressionKindTransformation(parentTransformation, TransformationOptions.Default) + + new () = new ExpressionKindTransformation(TransformationOptions.Default) + + +and ExpressionTransformation internal (options, _internal_) = + inherit ExpressionTransformationBase(options, _internal_) + let mutable _Transformation : SyntaxTreeTransformation option = None // will be set to a suitable Some value once construction is complete + + /// Handle to the parent SyntaxTreeTransformation. + /// This handle is always safe to access and will be set to a suitable value + /// even if no parent transformation has been specified upon construction. + member this.Transformation + with get () = _Transformation.Value + and private set value = + _Transformation <- Some value + this.ExpressionKindTransformationHandle <- fun _ -> value.ExpressionKinds :> ExpressionKindTransformationBase + this.TypeTransformationHandle <- fun _ -> value.Types :> TypeTransformationBase + + new (parentTransformation : SyntaxTreeTransformation, options : TransformationOptions) as this = + ExpressionTransformation(options, "_internal_") then + this.Transformation <- parentTransformation + + new (options : TransformationOptions) as this = + ExpressionTransformation(options, "_internal_") then + this.Transformation <- new SyntaxTreeTransformation(options) + this.Transformation.Types <- new TypeTransformation(this.Transformation, TransformationOptions.Disabled) + this.Transformation.Expressions <- this + + new (parentTransformation : SyntaxTreeTransformation) = + new ExpressionTransformation(parentTransformation, TransformationOptions.Default) + + new () = new ExpressionTransformation(TransformationOptions.Default) + + +and StatementKindTransformation internal (options, _internal_) = + inherit StatementKindTransformationBase(options, _internal_) + let mutable _Transformation : SyntaxTreeTransformation option = None // will be set to a suitable Some value once construction is complete + + /// Handle to the parent SyntaxTreeTransformation. + /// This handle is always safe to access and will be set to a suitable value + /// even if no parent transformation has been specified upon construction. + member this.Transformation + with get () = _Transformation.Value + and private set value = + _Transformation <- Some value + this.StatementTransformationHandle <- fun _ -> value.Statements :> StatementTransformationBase + this.ExpressionTransformationHandle <- fun _ -> value.Expressions :> ExpressionTransformationBase + + new (parentTransformation : SyntaxTreeTransformation, options : TransformationOptions) as this = + StatementKindTransformation(options, "_internal_") then + this.Transformation <- parentTransformation + + new (options : TransformationOptions) as this = + StatementKindTransformation(options, "_internal_") then + this.Transformation <- new SyntaxTreeTransformation(options) + this.Transformation.Types <- new TypeTransformation(this.Transformation, TransformationOptions.Disabled) + this.Transformation.Expressions <- new ExpressionTransformation(this.Transformation, TransformationOptions.Disabled) + this.Transformation.StatementKinds <- this + + new (parentTransformation : SyntaxTreeTransformation) = + new StatementKindTransformation(parentTransformation, TransformationOptions.Default) + + new () = new StatementKindTransformation(TransformationOptions.Default) + + +and StatementTransformation internal (options, _internal_) = + inherit StatementTransformationBase(options, _internal_) + let mutable _Transformation : SyntaxTreeTransformation option = None // will be set to a suitable Some value once construction is complete + + /// Handle to the parent SyntaxTreeTransformation. + /// This handle is always safe to access and will be set to a suitable value + /// even if no parent transformation has been specified upon construction. + member this.Transformation + with get () = _Transformation.Value + and private set value = + _Transformation <- Some value + this.StatementKindTransformationHandle <- fun _ -> value.StatementKinds :> StatementKindTransformationBase + this.ExpressionTransformationHandle <- fun _ -> value.Expressions :> ExpressionTransformationBase + + new (parentTransformation : SyntaxTreeTransformation, options : TransformationOptions) as this = + StatementTransformation(options, "_internal_") then + this.Transformation <- parentTransformation + + new (options : TransformationOptions) as this = + StatementTransformation(options, "_internal_") then + this.Transformation <- new SyntaxTreeTransformation(options) + this.Transformation.Types <- new TypeTransformation(this.Transformation, TransformationOptions.Disabled) + this.Transformation.Expressions <- new ExpressionTransformation(this.Transformation, TransformationOptions.Disabled) + this.Transformation.Statements <- this + + new (parentTransformation : SyntaxTreeTransformation) = + new StatementTransformation(parentTransformation, TransformationOptions.Default) + + new () = new StatementTransformation(TransformationOptions.Default) + + +and NamespaceTransformation internal (options, _internal_ : string) = + inherit NamespaceTransformationBase(options, _internal_) + let mutable _Transformation : SyntaxTreeTransformation option = None // will be set to a suitable Some value once construction is complete + + /// Handle to the parent SyntaxTreeTransformation. + /// This handle is always safe to access and will be set to a suitable value + /// even if no parent transformation has been specified upon construction. + member this.Transformation + with get () = _Transformation.Value + and private set value = + _Transformation <- Some value + this.StatementTransformationHandle <- fun _ -> value.Statements :> StatementTransformationBase + + new (parentTransformation : SyntaxTreeTransformation, options : TransformationOptions) as this = + NamespaceTransformation(options, "_internal_") then + this.Transformation <- parentTransformation + + new (options : TransformationOptions) as this = + NamespaceTransformation(options, "_internal_") then + this.Transformation <- new SyntaxTreeTransformation(options) + this.Transformation.Types <- new TypeTransformation(this.Transformation, TransformationOptions.Disabled) + this.Transformation.Expressions <- new ExpressionTransformation(this.Transformation, TransformationOptions.Disabled) + this.Transformation.Statements <- new StatementTransformation(this.Transformation, TransformationOptions.Disabled) + this.Transformation.Namespaces <- this + + new (parentTransformation : SyntaxTreeTransformation) = + new NamespaceTransformation(parentTransformation, TransformationOptions.Default) + + new () = new NamespaceTransformation(TransformationOptions.Default) + + + + diff --git a/src/QsCompiler/Core/TransformationOptions.fs b/src/QsCompiler/Core/TransformationOptions.fs new file mode 100644 index 0000000000..1fff425bbd --- /dev/null +++ b/src/QsCompiler/Core/TransformationOptions.fs @@ -0,0 +1,51 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +namespace Microsoft.Quantum.QsCompiler.Transformations.Core + + +/// Used to configure the behavior of the default implementations for transformations. +type TransformationOptions = internal { + /// Disables the transformation at the transformation root, + /// meaning the transformation won't recur into leaf nodes or subnodes. + Enable : bool + /// Indicates that the transformation is used to walk the syntax tree, but does not modify any of the nodes. + /// If set to true, the nodes will not be rebuilt during the transformation. + /// Setting this to true constitutes a promise that the return value of all methods will be ignored. + Rebuild : bool +} + with + + /// Default transformation setting. + /// The transformation will recur into leaf and subnodes, + /// and all nodes will be rebuilt upon transformation. + static member Default = { + Enable = true + Rebuild = true + } + + /// Disables the transformation at the transformation root, + /// meaning the transformation won't recur into leaf nodes or subnodes. + static member Disabled = { + Enable = false + Rebuild = true + } + + /// Indicates that the transformation is used to walk the syntax tree, but does not modify any of the nodes. + /// All nodes will be traversed recursively, but the nodes will not be rebuilt. + /// Setting this option constitutes a promise that the return value of all methods will be ignored. + static member NoRebuild = { + Enable = true + Rebuild = false + } + + +/// Tools for adapting the default implementations for transformations +/// based on the specified options. +module internal Utils = + type internal INode = + abstract member BuildOr<'a, 'b> : 'b -> 'a -> ('a -> 'b) -> 'b + + let Fold = { new INode with member __.BuildOr _ arg builder = builder arg} + let Walk = { new INode with member __.BuildOr original _ _ = original} + diff --git a/src/QsCompiler/Core/TreeTransformation.fs b/src/QsCompiler/Core/TreeTransformation.fs deleted file mode 100644 index 351caa9c18..0000000000 --- a/src/QsCompiler/Core/TreeTransformation.fs +++ /dev/null @@ -1,211 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -namespace Microsoft.Quantum.QsCompiler.Transformations.Core - -open System.Collections.Immutable -open System.Linq -open Microsoft.Quantum.QsCompiler.DataTypes -open Microsoft.Quantum.QsCompiler.SyntaxExtensions -open Microsoft.Quantum.QsCompiler.SyntaxTokens -open Microsoft.Quantum.QsCompiler.SyntaxTree - -type QsArgumentTuple = QsTuple> - - -/// Convention: -/// All methods starting with "on" implement the transformation syntax tree element. -/// All methods starting with "before" group a set of elements, and are called before applying the transformation -/// even if the corresponding transformation routine (starting with "on") is overridden. -type SyntaxTreeTransformation() = - let scopeTransformation = new ScopeTransformation() - - abstract member Scope : ScopeTransformation - default this.Scope = scopeTransformation - - - abstract member beforeNamespaceElement : QsNamespaceElement -> QsNamespaceElement - default this.beforeNamespaceElement e = e - - abstract member beforeCallable : QsCallable -> QsCallable - default this.beforeCallable c = c - - abstract member beforeSpecialization : QsSpecialization -> QsSpecialization - default this.beforeSpecialization spec = spec - - abstract member beforeSpecializationImplementation : SpecializationImplementation -> SpecializationImplementation - default this.beforeSpecializationImplementation impl = impl - - abstract member beforeGeneratedImplementation : QsGeneratorDirective -> QsGeneratorDirective - default this.beforeGeneratedImplementation dir = dir - - - abstract member onLocation : QsNullable -> QsNullable - default this.onLocation l = l - - abstract member onDocumentation : ImmutableArray -> ImmutableArray - default this.onDocumentation doc = doc - - abstract member onSourceFile : NonNullable -> NonNullable - default this.onSourceFile f = f - - abstract member onTypeItems : QsTuple -> QsTuple - default this.onTypeItems tItem = - match tItem with - | QsTuple items -> (items |> Seq.map this.onTypeItems).ToImmutableArray() |> QsTuple - | QsTupleItem (Anonymous itemType) -> - let t = this.Scope.Expression.Type.Transform itemType - Anonymous t |> QsTupleItem - | QsTupleItem (Named item) -> - let loc = item.Position, item.Range - let t = this.Scope.Expression.Type.Transform item.Type - let info = this.Scope.Expression.onExpressionInformation item.InferredInformation - LocalVariableDeclaration<_>.New info.IsMutable (loc, item.VariableName, t, info.HasLocalQuantumDependency) |> Named |> QsTupleItem - - abstract member onArgumentTuple : QsArgumentTuple -> QsArgumentTuple - default this.onArgumentTuple arg = - match arg with - | QsTuple items -> (items |> Seq.map this.onArgumentTuple).ToImmutableArray() |> QsTuple - | QsTupleItem item -> - let loc = item.Position, item.Range - let t = this.Scope.Expression.Type.Transform item.Type - let info = this.Scope.Expression.onExpressionInformation item.InferredInformation - LocalVariableDeclaration<_>.New info.IsMutable (loc, item.VariableName, t, info.HasLocalQuantumDependency) |> QsTupleItem - - abstract member onSignature : ResolvedSignature -> ResolvedSignature - default this.onSignature (s : ResolvedSignature) = - let typeParams = s.TypeParameters - let argType = this.Scope.Expression.Type.Transform s.ArgumentType - let returnType = this.Scope.Expression.Type.Transform s.ReturnType - let info = this.Scope.Expression.Type.onCallableInformation s.Information - ResolvedSignature.New ((argType, returnType), info, typeParams) - - - abstract member onExternalImplementation : unit -> unit - default this.onExternalImplementation () = () - - abstract member onIntrinsicImplementation : unit -> unit - default this.onIntrinsicImplementation () = () - - abstract member onProvidedImplementation : QsArgumentTuple * QsScope -> QsArgumentTuple * QsScope - default this.onProvidedImplementation (argTuple, body) = - let argTuple = this.onArgumentTuple argTuple - let body = this.Scope.Transform body - argTuple, body - - abstract member onSelfInverseDirective : unit -> unit - default this.onSelfInverseDirective () = () - - abstract member onInvertDirective : unit -> unit - default this.onInvertDirective () = () - - abstract member onDistributeDirective : unit -> unit - default this.onDistributeDirective () = () - - abstract member onInvalidGeneratorDirective : unit -> unit - default this.onInvalidGeneratorDirective () = () - - member this.dispatchGeneratedImplementation (dir : QsGeneratorDirective) = - match this.beforeGeneratedImplementation dir with - | SelfInverse -> this.onSelfInverseDirective (); SelfInverse - | Invert -> this.onInvertDirective(); Invert - | Distribute -> this.onDistributeDirective(); Distribute - | InvalidGenerator -> this.onInvalidGeneratorDirective(); InvalidGenerator - - member this.dispatchSpecializationImplementation (impl : SpecializationImplementation) = - match this.beforeSpecializationImplementation impl with - | External -> this.onExternalImplementation(); External - | Intrinsic -> this.onIntrinsicImplementation(); Intrinsic - | Generated dir -> this.dispatchGeneratedImplementation dir |> Generated - | Provided (argTuple, body) -> this.onProvidedImplementation (argTuple, body) |> Provided - - - abstract member onSpecializationImplementation : QsSpecialization -> QsSpecialization - default this.onSpecializationImplementation (spec : QsSpecialization) = - let source = this.onSourceFile spec.SourceFile - let loc = this.onLocation spec.Location - let attributes = spec.Attributes |> Seq.map this.onAttribute |> ImmutableArray.CreateRange - let typeArgs = spec.TypeArguments |> QsNullable<_>.Map (fun args -> (args |> Seq.map this.Scope.Expression.Type.Transform).ToImmutableArray()) - let signature = this.onSignature spec.Signature - let impl = this.dispatchSpecializationImplementation spec.Implementation - let doc = this.onDocumentation spec.Documentation - let comments = spec.Comments - QsSpecialization.New spec.Kind (source, loc) (spec.Parent, attributes, typeArgs, signature, impl, doc, comments) - - abstract member onBodySpecialization : QsSpecialization -> QsSpecialization - default this.onBodySpecialization spec = this.onSpecializationImplementation spec - - abstract member onAdjointSpecialization : QsSpecialization -> QsSpecialization - default this.onAdjointSpecialization spec = this.onSpecializationImplementation spec - - abstract member onControlledSpecialization : QsSpecialization -> QsSpecialization - default this.onControlledSpecialization spec = this.onSpecializationImplementation spec - - abstract member onControlledAdjointSpecialization : QsSpecialization -> QsSpecialization - default this.onControlledAdjointSpecialization spec = this.onSpecializationImplementation spec - - member this.dispatchSpecialization (spec : QsSpecialization) = - let spec = this.beforeSpecialization spec - match spec.Kind with - | QsSpecializationKind.QsBody -> this.onBodySpecialization spec - | QsSpecializationKind.QsAdjoint -> this.onAdjointSpecialization spec - | QsSpecializationKind.QsControlled -> this.onControlledSpecialization spec - | QsSpecializationKind.QsControlledAdjoint -> this.onControlledAdjointSpecialization spec - - - abstract member onType : QsCustomType -> QsCustomType - default this.onType t = - let source = this.onSourceFile t.SourceFile - let loc = this.onLocation t.Location - let attributes = t.Attributes |> Seq.map this.onAttribute |> ImmutableArray.CreateRange - let underlyingType = this.Scope.Expression.Type.Transform t.Type - let typeItems = this.onTypeItems t.TypeItems - let doc = this.onDocumentation t.Documentation - let comments = t.Comments - QsCustomType.New (source, loc) (t.FullName, attributes, typeItems, underlyingType, doc, comments) - - abstract member onCallableImplementation : QsCallable -> QsCallable - default this.onCallableImplementation (c : QsCallable) = - let source = this.onSourceFile c.SourceFile - let loc = this.onLocation c.Location - let attributes = c.Attributes |> Seq.map this.onAttribute |> ImmutableArray.CreateRange - let signature = this.onSignature c.Signature - let argTuple = this.onArgumentTuple c.ArgumentTuple - let specializations = c.Specializations |> Seq.map this.dispatchSpecialization - let doc = this.onDocumentation c.Documentation - let comments = c.Comments - QsCallable.New c.Kind (source, loc) (c.FullName, attributes, argTuple, signature, specializations, doc, comments) - - abstract member onOperation : QsCallable -> QsCallable - default this.onOperation c = this.onCallableImplementation c - - abstract member onFunction : QsCallable -> QsCallable - default this.onFunction c = this.onCallableImplementation c - - abstract member onTypeConstructor : QsCallable -> QsCallable - default this.onTypeConstructor c = this.onCallableImplementation c - - member this.dispatchCallable (c : QsCallable) = - let c = this.beforeCallable c - match c.Kind with - | QsCallableKind.Function -> this.onFunction c - | QsCallableKind.Operation -> this.onOperation c - | QsCallableKind.TypeConstructor -> this.onTypeConstructor c - - - abstract member onAttribute : QsDeclarationAttribute -> QsDeclarationAttribute - default this.onAttribute att = att - - member this.dispatchNamespaceElement element = - match this.beforeNamespaceElement element with - | QsCustomType t -> t |> this.onType |> QsCustomType - | QsCallable c -> c |> this.dispatchCallable |> QsCallable - - abstract member Transform : QsNamespace -> QsNamespace - default this.Transform ns = - let name = ns.Name - let doc = ns.Documentation.AsEnumerable().SelectMany(fun entry -> - entry |> Seq.map (fun doc -> entry.Key, this.onDocumentation doc)).ToLookup(fst, snd) - let elements = ns.Elements |> Seq.map this.dispatchNamespaceElement - QsNamespace.New (name, elements, doc) - diff --git a/src/QsCompiler/Core/TreeWalker.fs b/src/QsCompiler/Core/TreeWalker.fs deleted file mode 100644 index 0626d25192..0000000000 --- a/src/QsCompiler/Core/TreeWalker.fs +++ /dev/null @@ -1,194 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -namespace Microsoft.Quantum.QsCompiler.Transformations.Core - -open System.Collections.Immutable -open System.Linq -open Microsoft.Quantum.QsCompiler.DataTypes -open Microsoft.Quantum.QsCompiler.SyntaxTokens -open Microsoft.Quantum.QsCompiler.SyntaxTree - - - -/// Convention: -/// All methods starting with "on" implement the walk for an expression of a certain kind. -/// All methods starting with "before" group a set of statements, and are called before walking the set -/// even if the corresponding walk routine (starting with "on") is overridden. -/// -/// These classes differ from the "*Transformation" classes in that these classes visit every node in the -/// syntax tree, but don't create a new syntax tree, while the Transformation classes generate a new (or -/// at least partially new) tree from the old one. -/// Effectively, the Transformation classes implement fold, while the Walker classes implement iter. -type SyntaxTreeWalker() = - let scopeWalker = new ScopeWalker() - - abstract member Scope : ScopeWalker - default this.Scope = scopeWalker - - - abstract member beforeNamespaceElement : QsNamespaceElement -> unit - default this.beforeNamespaceElement e = () - - abstract member beforeCallable : QsCallable -> unit - default this.beforeCallable c = () - - abstract member beforeSpecialization : QsSpecialization -> unit - default this.beforeSpecialization spec = () - - abstract member beforeSpecializationImplementation : SpecializationImplementation -> unit - default this.beforeSpecializationImplementation impl = () - - abstract member beforeGeneratedImplementation : QsGeneratorDirective -> unit - default this.beforeGeneratedImplementation dir = () - - - abstract member onLocation : QsNullable -> unit - default this.onLocation l = () - - abstract member onDocumentation : ImmutableArray -> unit - default this.onDocumentation doc = () - - abstract member onSourceFile : NonNullable -> unit - default this.onSourceFile f = () - - abstract member onTypeItems : QsTuple -> unit - default this.onTypeItems tItem = - match tItem with - | QsTuple items -> items |> Seq.iter this.onTypeItems - | QsTupleItem (Anonymous itemType) -> this.Scope.Expression.Type.Walk itemType - | QsTupleItem (Named item) -> this.Scope.Expression.Type.Walk item.Type - - abstract member onArgumentTuple : QsArgumentTuple -> unit - default this.onArgumentTuple arg = - match arg with - | QsTuple items -> items |> Seq.iter this.onArgumentTuple - | QsTupleItem item -> this.Scope.Expression.Type.Walk item.Type - - abstract member onSignature : ResolvedSignature -> unit - default this.onSignature (s : ResolvedSignature) = - this.Scope.Expression.Type.Walk s.ArgumentType - this.Scope.Expression.Type.Walk s.ReturnType - this.Scope.Expression.Type.onCallableInformation s.Information - - - abstract member onExternalImplementation : unit -> unit - default this.onExternalImplementation () = () - - abstract member onIntrinsicImplementation : unit -> unit - default this.onIntrinsicImplementation () = () - - abstract member onProvidedImplementation : QsArgumentTuple * QsScope -> unit - default this.onProvidedImplementation (argTuple, body) = - this.onArgumentTuple argTuple - this.Scope.Walk body - - abstract member onSelfInverseDirective : unit -> unit - default this.onSelfInverseDirective () = () - - abstract member onInvertDirective : unit -> unit - default this.onInvertDirective () = () - - abstract member onDistributeDirective : unit -> unit - default this.onDistributeDirective () = () - - abstract member onInvalidGeneratorDirective : unit -> unit - default this.onInvalidGeneratorDirective () = () - - member this.dispatchGeneratedImplementation (dir : QsGeneratorDirective) = - this.beforeGeneratedImplementation dir - match dir with - | SelfInverse -> this.onSelfInverseDirective () - | Invert -> this.onInvertDirective() - | Distribute -> this.onDistributeDirective() - | InvalidGenerator -> this.onInvalidGeneratorDirective() - - member this.dispatchSpecializationImplementation (impl : SpecializationImplementation) = - this.beforeSpecializationImplementation impl - match impl with - | External -> this.onExternalImplementation() - | Intrinsic -> this.onIntrinsicImplementation() - | Generated dir -> this.dispatchGeneratedImplementation dir - | Provided (argTuple, body) -> this.onProvidedImplementation (argTuple, body) - - - abstract member onSpecializationImplementation : QsSpecialization -> unit - default this.onSpecializationImplementation (spec : QsSpecialization) = - this.onSourceFile spec.SourceFile - this.onLocation spec.Location - spec.Attributes |> Seq.iter this.onAttribute - spec.TypeArguments |> QsNullable<_>.Iter (fun args -> (args |> Seq.iter this.Scope.Expression.Type.Walk)) - this.onSignature spec.Signature - this.dispatchSpecializationImplementation spec.Implementation - this.onDocumentation spec.Documentation - - abstract member onBodySpecialization : QsSpecialization -> unit - default this.onBodySpecialization spec = this.onSpecializationImplementation spec - - abstract member onAdjointSpecialization : QsSpecialization -> unit - default this.onAdjointSpecialization spec = this.onSpecializationImplementation spec - - abstract member onControlledSpecialization : QsSpecialization -> unit - default this.onControlledSpecialization spec = this.onSpecializationImplementation spec - - abstract member onControlledAdjointSpecialization : QsSpecialization -> unit - default this.onControlledAdjointSpecialization spec = this.onSpecializationImplementation spec - - member this.dispatchSpecialization (spec : QsSpecialization) = - this.beforeSpecialization spec - match spec.Kind with - | QsSpecializationKind.QsBody -> this.onBodySpecialization spec - | QsSpecializationKind.QsAdjoint -> this.onAdjointSpecialization spec - | QsSpecializationKind.QsControlled -> this.onControlledSpecialization spec - | QsSpecializationKind.QsControlledAdjoint -> this.onControlledAdjointSpecialization spec - - - abstract member onType : QsCustomType -> unit - default this.onType t = - this.onSourceFile t.SourceFile - this.onLocation t.Location - t.Attributes |> Seq.iter this.onAttribute - this.Scope.Expression.Type.Walk t.Type - this.onTypeItems t.TypeItems - this.onDocumentation t.Documentation - - abstract member onCallableImplementation : QsCallable -> unit - default this.onCallableImplementation (c : QsCallable) = - this.onSourceFile c.SourceFile - this.onLocation c.Location - c.Attributes |> Seq.iter this.onAttribute - this.onSignature c.Signature - this.onArgumentTuple c.ArgumentTuple - c.Specializations |> Seq.iter this.dispatchSpecialization - this.onDocumentation c.Documentation - - abstract member onOperation : QsCallable -> unit - default this.onOperation c = this.onCallableImplementation c - - abstract member onFunction : QsCallable -> unit - default this.onFunction c = this.onCallableImplementation c - - abstract member onTypeConstructor : QsCallable -> unit - default this.onTypeConstructor c = this.onCallableImplementation c - - member this.dispatchCallable (c : QsCallable) = - this.beforeCallable c - match c.Kind with - | QsCallableKind.Function -> this.onFunction c - | QsCallableKind.Operation -> this.onOperation c - | QsCallableKind.TypeConstructor -> this.onTypeConstructor c - - - abstract member onAttribute : QsDeclarationAttribute -> unit - default this.onAttribute att = () - - member this.dispatchNamespaceElement element = - this.beforeNamespaceElement element - match element with - | QsCustomType t -> t |> this.onType - | QsCallable c -> c |> this.dispatchCallable - - abstract member Walk : QsNamespace -> unit - default this.Walk ns = - ns.Documentation.AsEnumerable() |> Seq.iter (fun grouping -> grouping |> Seq.iter this.onDocumentation) - ns.Elements |> Seq.iter this.dispatchNamespaceElement diff --git a/src/QsCompiler/Core/TypeTransformation.fs b/src/QsCompiler/Core/TypeTransformation.fs new file mode 100644 index 0000000000..8737747695 --- /dev/null +++ b/src/QsCompiler/Core/TypeTransformation.fs @@ -0,0 +1,135 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +namespace Microsoft.Quantum.QsCompiler.Transformations.Core + +open System.Collections.Immutable +open Microsoft.Quantum.QsCompiler.DataTypes +open Microsoft.Quantum.QsCompiler.SyntaxExtensions +open Microsoft.Quantum.QsCompiler.SyntaxTokens +open Microsoft.Quantum.QsCompiler.SyntaxTree +open Microsoft.Quantum.QsCompiler.Transformations.Core.Utils + +type private ExpressionType = + QsTypeKind + + +type TypeTransformationBase(options : TransformationOptions) = + + let Node = if options.Rebuild then Fold else Walk + new () = new TypeTransformationBase(TransformationOptions.Default) + + + // supplementary type information + + abstract member OnRangeInformation : QsRangeInfo -> QsRangeInfo + default this.OnRangeInformation r = r + + abstract member OnCharacteristicsExpression : ResolvedCharacteristics -> ResolvedCharacteristics + default this.OnCharacteristicsExpression fs = fs + + abstract member OnCallableInformation : CallableInformation -> CallableInformation + default this.OnCallableInformation opInfo = + let characteristics = this.OnCharacteristicsExpression opInfo.Characteristics + let inferred = opInfo.InferredInformation + CallableInformation.New |> Node.BuildOr opInfo (characteristics, inferred) + + + // nodes containing subtypes + + abstract member OnUserDefinedType : UserDefinedType -> ExpressionType + default this.OnUserDefinedType udt = + let ns, name = udt.Namespace, udt.Name + let range = this.OnRangeInformation udt.Range + ExpressionType.UserDefinedType << UserDefinedType.New |> Node.BuildOr InvalidType (ns, name, range) + + abstract member OnTypeParameter : QsTypeParameter -> ExpressionType + default this.OnTypeParameter tp = + let origin = tp.Origin + let name = tp.TypeName + let range = this.OnRangeInformation tp.Range + ExpressionType.TypeParameter << QsTypeParameter.New |> Node.BuildOr InvalidType (origin, name, range) + + abstract member OnOperation : (ResolvedType * ResolvedType) * CallableInformation -> ExpressionType + default this.OnOperation ((it, ot), info) = + let transformed = (this.OnType it, this.OnType ot), this.OnCallableInformation info + ExpressionType.Operation |> Node.BuildOr InvalidType transformed + + abstract member OnFunction : ResolvedType * ResolvedType -> ExpressionType + default this.OnFunction (it, ot) = + let transformed = this.OnType it, this.OnType ot + ExpressionType.Function |> Node.BuildOr InvalidType transformed + + abstract member OnTupleType : ImmutableArray -> ExpressionType + default this.OnTupleType ts = + let transformed = ts |> Seq.map this.OnType |> ImmutableArray.CreateRange + ExpressionType.TupleType |> Node.BuildOr InvalidType transformed + + abstract member OnArrayType : ResolvedType -> ExpressionType + default this.OnArrayType b = + ExpressionType.ArrayType |> Node.BuildOr InvalidType (this.OnType b) + + + // leaf nodes + + abstract member OnUnitType : unit -> ExpressionType + default this.OnUnitType () = ExpressionType.UnitType + + abstract member OnQubit : unit -> ExpressionType + default this.OnQubit () = ExpressionType.Qubit + + abstract member OnMissingType : unit -> ExpressionType + default this.OnMissingType () = ExpressionType.MissingType + + abstract member OnInvalidType : unit -> ExpressionType + default this.OnInvalidType () = ExpressionType.InvalidType + + abstract member OnInt : unit -> ExpressionType + default this.OnInt () = ExpressionType.Int + + abstract member OnBigInt : unit -> ExpressionType + default this.OnBigInt () = ExpressionType.BigInt + + abstract member OnDouble : unit -> ExpressionType + default this.OnDouble () = ExpressionType.Double + + abstract member OnBool : unit -> ExpressionType + default this.OnBool () = ExpressionType.Bool + + abstract member OnString : unit -> ExpressionType + default this.OnString () = ExpressionType.String + + abstract member OnResult : unit -> ExpressionType + default this.OnResult () = ExpressionType.Result + + abstract member OnPauli : unit -> ExpressionType + default this.OnPauli () = ExpressionType.Pauli + + abstract member OnRange : unit -> ExpressionType + default this.OnRange () = ExpressionType.Range + + + // transformation root called on each node + + member this.OnType (t : ResolvedType) = + if not options.Enable then t else + let transformed = t.Resolution |> function + | ExpressionType.UnitType -> this.OnUnitType () + | ExpressionType.Operation ((it, ot), fs) -> this.OnOperation ((it, ot), fs) + | ExpressionType.Function (it, ot) -> this.OnFunction (it, ot) + | ExpressionType.TupleType ts -> this.OnTupleType ts + | ExpressionType.ArrayType b -> this.OnArrayType b + | ExpressionType.UserDefinedType udt -> this.OnUserDefinedType udt + | ExpressionType.TypeParameter tp -> this.OnTypeParameter tp + | ExpressionType.Qubit -> this.OnQubit () + | ExpressionType.MissingType -> this.OnMissingType () + | ExpressionType.InvalidType -> this.OnInvalidType () + | ExpressionType.Int -> this.OnInt () + | ExpressionType.BigInt -> this.OnBigInt () + | ExpressionType.Double -> this.OnDouble () + | ExpressionType.Bool -> this.OnBool () + | ExpressionType.String -> this.OnString () + | ExpressionType.Result -> this.OnResult () + | ExpressionType.Pauli -> this.OnPauli () + | ExpressionType.Range -> this.OnRange () + ResolvedType.New |> Node.BuildOr t transformed diff --git a/src/QsCompiler/DataStructures/SyntaxExtensions.fs b/src/QsCompiler/DataStructures/SyntaxExtensions.fs index 4c71dc94c7..0e67788f4d 100644 --- a/src/QsCompiler/DataStructures/SyntaxExtensions.fs +++ b/src/QsCompiler/DataStructures/SyntaxExtensions.fs @@ -296,7 +296,8 @@ type QsStatement with | QsReturnStatement _ | QsFailStatement _ | QsVariableDeclaration _ - | QsValueUpdate _ -> Seq.empty + | QsValueUpdate _ + | EmptyStatement -> Seq.empty | QsConditionalStatement s -> (Seq.append (s.ConditionalBlocks |> Seq.collect (fun (_, b) -> b.Body.Statements)) diff --git a/src/QsCompiler/DataStructures/SyntaxTree.fs b/src/QsCompiler/DataStructures/SyntaxTree.fs index 47fa010d67..aa2c70be09 100644 --- a/src/QsCompiler/DataStructures/SyntaxTree.fs +++ b/src/QsCompiler/DataStructures/SyntaxTree.fs @@ -564,6 +564,7 @@ and QsStatementKind = | QsRepeatStatement of QsRepeatStatement | QsConjugation of QsConjugation | QsQubitScope of QsQubitScope // includes both using and borrowing scopes +| EmptyStatement and QsStatement = { diff --git a/src/QsCompiler/DocumentationParser/DocItem.cs b/src/QsCompiler/DocumentationParser/DocItem.cs index 5afd8aa0a3..b07de0d4be 100644 --- a/src/QsCompiler/DocumentationParser/DocItem.cs +++ b/src/QsCompiler/DocumentationParser/DocItem.cs @@ -26,7 +26,7 @@ internal abstract class DocItem protected readonly string replacement; /// - /// The item's kind, as a string (Utilities.OperationKind, .FunctionKind, or .UdtKind) + /// The item's kind, as a string (Utils.OperationKind, .FunctionKind, or .UdtKind) /// internal string ItemType => this.itemType; /// diff --git a/src/QsCompiler/DocumentationParser/Utils.cs b/src/QsCompiler/DocumentationParser/Utils.cs index 9c2c3896ec..1b79293474 100644 --- a/src/QsCompiler/DocumentationParser/Utils.cs +++ b/src/QsCompiler/DocumentationParser/Utils.cs @@ -123,12 +123,8 @@ internal static YamlNode BuildSequenceMappingNode(Dictionary pai /// /// The resolved type /// A string containing the source representation of the type - internal static string ResolvedTypeToString(ResolvedType t) - { - var exprTransformer = new ExpressionToQs(); - var transformer = new ExpressionTypeToQs(exprTransformer); - return transformer.Apply(t); - } + internal static string ResolvedTypeToString(ResolvedType t) => + SyntaxTreeToQsharp.Default.ToCode(t); /// /// Populates a YAML mapping node with information describing a Q# resolved type. diff --git a/src/QsCompiler/Optimizations/OptimizingTransformations/CallableInlining.fs b/src/QsCompiler/Optimizations/OptimizingTransformations/CallableInlining.fs index 600e0348f7..0b851d4630 100644 --- a/src/QsCompiler/Optimizations/OptimizingTransformations/CallableInlining.fs +++ b/src/QsCompiler/Optimizations/OptimizingTransformations/CallableInlining.fs @@ -10,7 +10,7 @@ open Microsoft.Quantum.QsCompiler.Experimental.Utils open Microsoft.Quantum.QsCompiler.SyntaxExtensions open Microsoft.Quantum.QsCompiler.SyntaxTokens open Microsoft.Quantum.QsCompiler.SyntaxTree -open Microsoft.Quantum.QsCompiler.Transformations.Core +open Microsoft.Quantum.QsCompiler.Transformations /// Represents all the functors applied to an operation call @@ -34,7 +34,7 @@ type private InliningInfo = { functors: Functors callable: QsCallable arg: TypedExpression - specArgs: QsArgumentTuple + specArgs: QsTuple> body: QsScope returnType: ResolvedType } with @@ -84,19 +84,45 @@ type private InliningInfo = { maybe { let! functors, callable, arg = InliningInfo.TrySplitCall callables expr.Expression let! specArgs, body = InliningInfo.TryGetProvidedImpl callable functors - let body = ReplaceTypeParams(expr.TypeParameterResolutions).Transform body - let returnType = ReplaceTypeParams(expr.TypeParameterResolutions).Expression.Type.Transform callable.Signature.ReturnType + let body = ReplaceTypeParams(expr.TypeParameterResolutions).Statements.OnScope body + let returnType = ReplaceTypeParams(expr.TypeParameterResolutions).Types.OnType callable.Signature.ReturnType return { functors = functors; callable = callable; arg = arg; specArgs = specArgs; body = body; returnType = returnType } } /// The SyntaxTreeTransformation used to inline callables -type CallableInlining(callables) = - inherit OptimizingTransformation() +type CallableInlining private (_private_ : string) = + inherit TransformationBase() // The current callable we're in the process of transforming - let mutable currentCallable: QsCallable option = None - let mutable renamer: VariableRenaming option = None + member val CurrentCallable: QsCallable option = None with get, set + member val Renamer: VariableRenaming option = None with get, set + + new (callables) as this = + new CallableInlining("_private_") then + this.Namespaces <- new CallableInliningNamespaces(this) + this.Statements <- new CallableInliningStatements(this, callables) + this.Expressions <- new Core.ExpressionTransformation(this, Core.TransformationOptions.Disabled) + this.Types <- new Core.TypeTransformation(this, Core.TransformationOptions.Disabled) + +/// private helper class for CallableInlining +and private CallableInliningNamespaces (parent : CallableInlining) = + inherit NamespaceTransformationBase(parent) + + override __.OnNamespace x = + let x = base.OnNamespace x + VariableRenaming().Namespaces.OnNamespace x + + override __.OnCallableDeclaration c = + let renamerVal = VariableRenaming() + let c = renamerVal.Namespaces.OnCallableDeclaration c + parent.CurrentCallable <- Some c + parent.Renamer <- Some renamerVal + base.OnCallableDeclaration c + +/// private helper class for CallableInlining +and private CallableInliningStatements (parent : CallableInlining, callables : ImmutableDictionary<_,_>) = + inherit StatementCollectorTransformation(parent) /// Recursively finds all the callables that could be inlined into the given scope. /// Includes callables that are invoked within the implementation of another call. @@ -117,7 +143,7 @@ type CallableInlining(callables) = /// Returns whether the given callable could eventually inline the given callable. /// Is used to prevent inlining recursive functions into themselves forever. - let cannotReachCallable (callables: ImmutableDictionary) (scope: QsScope) (cannotReach: QsQualifiedName) = + let cannotReachCallable (scope: QsScope) (cannotReach: QsQualifiedName) = let mySet = HashSet() findAllCalls callables scope mySet not (mySet.Contains cannotReach) @@ -126,18 +152,18 @@ type CallableInlining(callables) = maybe { let! ii = InliningInfo.TryGetInfo callables expr - let! currentCallable = currentCallable - let! renamer = renamer - renamer.clearStack() - do! check (cannotReachCallable callables ii.body currentCallable.FullName) + let! currentCallable = parent.CurrentCallable + let! renamer = parent.Renamer + renamer.RenamingStack <- [Map.empty] + do! check (cannotReachCallable ii.body currentCallable.FullName) do! check (ii.functors.controlled < 2) // TODO - support multiple Controlled functors - do! check (cannotReachCallable callables ii.body ii.callable.FullName || isLiteral callables ii.arg) + do! check (cannotReachCallable ii.body ii.callable.FullName || isLiteral callables ii.arg) let newBinding = QsBinding.New ImmutableBinding (toSymbolTuple ii.specArgs, ii.arg) let newStatements = ii.body.Statements.Insert (0, newBinding |> QsVariableDeclaration |> wrapStmt) - |> Seq.map renamer.Scope.onStatement + |> Seq.map renamer.Statements.OnStatement |> Seq.map (fun s -> s.Statement) |> ImmutableArray.CreateRange return ii, newStatements @@ -168,38 +194,24 @@ type CallableInlining(callables) = return newStatements, returnExpr } + /// Given a statement, returns a sequence of statements to replace this statement with. + /// Inlines simple calls that have exactly 0 or 1 return statements. + override __.CollectStatements stmt = + maybe { + match stmt with + | QsExpressionStatement ex -> + match safeInline ex with + | Some stmts -> + return upcast stmts + | None -> + let! stmts, returnExpr = safeInlineReturn ex + return Seq.append stmts [QsExpressionStatement returnExpr] + | QsVariableDeclaration s -> + let! stmts, returnExpr = safeInlineReturn s.Rhs + return Seq.append stmts [QsVariableDeclaration {s with Rhs = returnExpr}] + | QsValueUpdate s -> + let! stmts, returnExpr = safeInlineReturn s.Rhs + return Seq.append stmts [QsValueUpdate {s with Rhs = returnExpr}] + | _ -> return! None + } |? Seq.singleton stmt - override __.Transform x = - let x = base.Transform x - VariableRenaming().Transform x - - override __.onCallableImplementation c = - let renamerVal = VariableRenaming() - let c = renamerVal.onCallableImplementation c - currentCallable <- Some c - renamer <- Some renamerVal - base.onCallableImplementation c - - override __.Scope = upcast { new StatementCollectorTransformation() with - - /// Given a statement, returns a sequence of statements to replace this statement with. - /// Inlines simple calls that have exactly 0 or 1 return statements. - override __.TransformStatement stmt = - maybe { - match stmt with - | QsExpressionStatement ex -> - match safeInline ex with - | Some stmts -> - return upcast stmts - | None -> - let! stmts, returnExpr = safeInlineReturn ex - return Seq.append stmts [QsExpressionStatement returnExpr] - | QsVariableDeclaration s -> - let! stmts, returnExpr = safeInlineReturn s.Rhs - return Seq.append stmts [QsVariableDeclaration {s with Rhs = returnExpr}] - | QsValueUpdate s -> - let! stmts, returnExpr = safeInlineReturn s.Rhs - return Seq.append stmts [QsValueUpdate {s with Rhs = returnExpr}] - | _ -> return! None - } |? Seq.singleton stmt - } diff --git a/src/QsCompiler/Optimizations/OptimizingTransformations/ConstantPropagation.fs b/src/QsCompiler/Optimizations/OptimizingTransformations/ConstantPropagation.fs index 7fd900a6b8..7539adbbf0 100644 --- a/src/QsCompiler/Optimizations/OptimizingTransformations/ConstantPropagation.fs +++ b/src/QsCompiler/Optimizations/OptimizingTransformations/ConstantPropagation.fs @@ -3,6 +3,7 @@ namespace Microsoft.Quantum.QsCompiler.Experimental +open System.Collections.Generic open System.Collections.Immutable open Microsoft.Quantum.QsCompiler.DataTypes open Microsoft.Quantum.QsCompiler.Experimental.Evaluation @@ -10,15 +11,34 @@ open Microsoft.Quantum.QsCompiler.Experimental.Utils open Microsoft.Quantum.QsCompiler.SyntaxExtensions open Microsoft.Quantum.QsCompiler.SyntaxTokens open Microsoft.Quantum.QsCompiler.SyntaxTree -open Microsoft.Quantum.QsCompiler.Transformations.Core +open Microsoft.Quantum.QsCompiler.Transformations /// The SyntaxTreeTransformation used to evaluate constants -type ConstantPropagation(callables) = - inherit OptimizingTransformation() +type ConstantPropagation private (_private_ : string) = + inherit TransformationBase() /// The current dictionary that maps variables to the values we substitute for them - let mutable constants = Map.empty + member val Constants = new Dictionary() + + new (callables) as this = + new ConstantPropagation("_private_") then + this.Namespaces <- new ConstantPropagationNamespaces(this) + this.StatementKinds <- new ConstantPropagationStatementKinds(this, callables) + this.Expressions <- (new ExpressionEvaluator(callables, this.Constants, 1000)).Expressions + this.Types <- new Core.TypeTransformation(this, Core.TransformationOptions.Disabled) + +/// private helper class for ConstantPropagation +and private ConstantPropagationNamespaces(parent : ConstantPropagation) = + inherit NamespaceTransformationBase(parent) + + override __.OnProvidedImplementation (argTuple, body) = + parent.Constants.Clear() + base.OnProvidedImplementation (argTuple, body) + +/// private helper class for ConstantPropagation +and private ConstantPropagationStatementKinds (parent : ConstantPropagation, callables) = + inherit Core.StatementKindTransformation(parent) /// Returns whether the given expression should be propagated as a constant. /// For a statement of the form "let x = [expr];", if shouldPropagate(expr) is true, @@ -36,70 +56,51 @@ type ConstantPropagation(callables) = && Seq.forall id sub) expr.Fold folder + override so.OnVariableDeclaration stm = + let lhs = so.OnSymbolTuple stm.Lhs + let rhs = so.Expressions.OnTypedExpression stm.Rhs + if stm.Kind = ImmutableBinding then + defineVarTuple (shouldPropagate callables) parent.Constants (lhs, rhs) + QsBinding.New stm.Kind (lhs, rhs) |> QsVariableDeclaration + + override this.OnConditionalStatement stm = + let cbList, cbListEnd = + stm.ConditionalBlocks |> Seq.fold (fun s (cond, block) -> + let newCond = this.Expressions.OnTypedExpression cond + match newCond.Expression with + | BoolLiteral true -> s @ [Null, block] + | BoolLiteral false -> s + | _ -> s @ [Value cond, block] + ) [] |> List.ofSeq |> takeWhilePlus1 (fun (c, _) -> c <> Null) + let newDefault = cbListEnd |> Option.map (snd >> Value) |? stm.Default + + let cbList = cbList |> List.map (fun (c, b) -> this.OnPositionedBlock (c, b)) + let newDefault = match newDefault with Value x -> this.OnPositionedBlock (Null, x) |> snd |> Value | Null -> Null + + match cbList, newDefault with + | [], Value x -> + x.Body |> newScopeStatement + | [], Null -> + QsScope.New ([], LocalDeclarations.New []) |> newScopeStatement + | _ -> + let invalidCondition () = failwith "missing condition" + let cases = cbList |> Seq.map (fun (c, b) -> (c.ValueOrApply invalidCondition, b)) + QsConditionalStatement.New (cases, newDefault) |> QsConditionalStatement + + override this.OnQubitScope (stm : QsQubitScope) = + let kind = stm.Kind + let lhs = this.OnSymbolTuple stm.Binding.Lhs + let rhs = this.OnQubitInitializer stm.Binding.Rhs + + jointFlatten (lhs, rhs) |> Seq.iter (fun (l, r) -> + match l, r.Resolution with + | VariableName name, QubitRegisterAllocation {Expression = IntLiteral num} -> + let arrayIden = Identifier (LocalVariable name, Null) |> wrapExpr (ArrayType (ResolvedType.New Qubit)) + let elemI = fun i -> ArrayItem (arrayIden, IntLiteral (int64 i) |> wrapExpr Int) + let expr = Seq.init (safeCastInt64 num) (elemI >> wrapExpr Qubit) |> ImmutableArray.CreateRange |> ValueArray |> wrapExpr (ArrayType (ResolvedType.New Qubit)) + defineVar (fun _ -> true) parent.Constants (name.Value, expr) + | _ -> ()) + + let body = this.Statements.OnScope stm.Body + QsQubitScope.New kind ((lhs, rhs), body) |> QsQubitScope - override __.onProvidedImplementation (argTuple, body) = - constants <- Map.empty - base.onProvidedImplementation (argTuple, body) - - /// The ScopeTransformation used to evaluate constants - override __.Scope = { new ScopeTransformation() with - - /// The ExpressionTransformation used to evaluate constant expressions - override __.Expression = upcast ExpressionEvaluator(callables, constants, 1000) - - /// The StatementKindTransformation used to evaluate constants - override scope.StatementKind = { new StatementKindTransformation() with - override __.ExpressionTransformation x = scope.Expression.Transform x - override __.LocationTransformation x = x - override __.ScopeTransformation x = scope.Transform x - override __.TypeTransformation x = x - - override so.onVariableDeclaration stm = - let lhs = so.onSymbolTuple stm.Lhs - let rhs = so.ExpressionTransformation stm.Rhs - if stm.Kind = ImmutableBinding then - constants <- defineVarTuple (shouldPropagate callables) constants (lhs, rhs) - QsBinding.New stm.Kind (lhs, rhs) |> QsVariableDeclaration - - override this.onConditionalStatement stm = - let cbList, cbListEnd = - stm.ConditionalBlocks |> Seq.fold (fun s (cond, block) -> - let newCond = this.ExpressionTransformation cond - match newCond.Expression with - | BoolLiteral true -> s @ [Null, block] - | BoolLiteral false -> s - | _ -> s @ [Value cond, block] - ) [] |> List.ofSeq |> takeWhilePlus1 (fun (c, _) -> c <> Null) - let newDefault = cbListEnd |> Option.map (snd >> Value) |? stm.Default - - let cbList = cbList |> List.map (fun (c, b) -> this.onPositionedBlock (c, b)) - let newDefault = match newDefault with Value x -> this.onPositionedBlock (Null, x) |> snd |> Value | Null -> Null - - match cbList, newDefault with - | [], Value x -> - x.Body |> newScopeStatement - | [], Null -> - QsScope.New ([], LocalDeclarations.New []) |> newScopeStatement - | _ -> - let invalidCondition () = failwith "missing condition" - let cases = cbList |> Seq.map (fun (c, b) -> (c.ValueOrApply invalidCondition, b)) - QsConditionalStatement.New (cases, newDefault) |> QsConditionalStatement - - override this.onQubitScope (stm : QsQubitScope) = - let kind = stm.Kind - let lhs = this.onSymbolTuple stm.Binding.Lhs - let rhs = this.onQubitInitializer stm.Binding.Rhs - - jointFlatten (lhs, rhs) |> Seq.iter (fun (l, r) -> - match l, r.Resolution with - | VariableName name, QubitRegisterAllocation {Expression = IntLiteral num} -> - let arrayIden = Identifier (LocalVariable name, Null) |> wrapExpr (ArrayType (ResolvedType.New Qubit)) - let elemI = fun i -> ArrayItem (arrayIden, IntLiteral (int64 i) |> wrapExpr Int) - let expr = Seq.init (safeCastInt64 num) (elemI >> wrapExpr Qubit) |> ImmutableArray.CreateRange |> ValueArray |> wrapExpr (ArrayType (ResolvedType.New Qubit)) - constants <- defineVar (fun _ -> true) constants (name.Value, expr) - | _ -> ()) - - let body = this.ScopeTransformation stm.Body - QsQubitScope.New kind ((lhs, rhs), body) |> QsQubitScope - } - } diff --git a/src/QsCompiler/Optimizations/OptimizingTransformations/LoopUnrolling.fs b/src/QsCompiler/Optimizations/OptimizingTransformations/LoopUnrolling.fs index 9cded93562..dc0fe183d8 100644 --- a/src/QsCompiler/Optimizations/OptimizingTransformations/LoopUnrolling.fs +++ b/src/QsCompiler/Optimizations/OptimizingTransformations/LoopUnrolling.fs @@ -7,44 +7,51 @@ open Microsoft.Quantum.QsCompiler.Experimental.Utils open Microsoft.Quantum.QsCompiler.SyntaxExtensions open Microsoft.Quantum.QsCompiler.SyntaxTokens open Microsoft.Quantum.QsCompiler.SyntaxTree -open Microsoft.Quantum.QsCompiler.Transformations.Core +open Microsoft.Quantum.QsCompiler.Transformations /// The SyntaxTreeTransformation used to unroll loops -type LoopUnrolling(callables, maxSize) = - inherit OptimizingTransformation() - - override __.Transform x = - let x = base.Transform x - VariableRenaming().Transform x - - override __.Scope = { new ScopeTransformation() with - override scope.StatementKind = { new StatementKindTransformation() with - override __.ExpressionTransformation x = x - override __.LocationTransformation x = x - override __.ScopeTransformation x = scope.Transform x - override __.TypeTransformation x = x - - override this.onForStatement stm = - let loopVar = fst stm.LoopItem |> this.onSymbolTuple - let iterVals = this.ExpressionTransformation stm.IterationValues - let loopVarType = this.TypeTransformation (snd stm.LoopItem) - let body = this.ScopeTransformation stm.Body - maybe { - let! iterValsList = - match iterVals.Expression with - | RangeLiteral _ when isLiteral callables iterVals -> - rangeLiteralToSeq iterVals.Expression |> Seq.map (IntLiteral >> wrapExpr Int) |> List.ofSeq |> Some - | ValueArray va -> va |> List.ofSeq |> Some - | _ -> None - do! check (iterValsList.Length <= maxSize) - let iterRange = iterValsList |> List.map (fun x -> - let variableDecl = QsBinding.New ImmutableBinding (loopVar, x) |> QsVariableDeclaration |> wrapStmt - let innerScope = { stm.Body with Statements = stm.Body.Statements.Insert(0, variableDecl) } - innerScope |> newScopeStatement |> wrapStmt) - let outerScope = QsScope.New (iterRange, stm.Body.KnownSymbols) - return outerScope |> newScopeStatement |> this.Transform - } - |? (QsForStatement.New ((loopVar, loopVarType), iterVals, body) |> QsForStatement) +type LoopUnrolling private (_private_ : string) = + inherit TransformationBase() + + new (callables, maxSize) as this = + new LoopUnrolling("_private_") then + this.Namespaces <- new LoopUnrollingNamespaces(this) + this.StatementKinds <- new LoopUnrollingStatementKinds(this, callables, maxSize) + this.Expressions <- new Core.ExpressionTransformation(this, Core.TransformationOptions.Disabled) + this.Types <- new Core.TypeTransformation(this, Core.TransformationOptions.Disabled) + +/// private helper class for LoopUnrolling +and private LoopUnrollingNamespaces (parent : LoopUnrolling) = + inherit NamespaceTransformationBase(parent) + + override __.OnNamespace x = + let x = base.OnNamespace x + VariableRenaming().Namespaces.OnNamespace x + +/// private helper class for LoopUnrolling +and private LoopUnrollingStatementKinds (parent : LoopUnrolling, callables, maxSize) = + inherit Core.StatementKindTransformation(parent) + + override this.OnForStatement stm = + let loopVar = fst stm.LoopItem |> this.OnSymbolTuple + let iterVals = this.Expressions.OnTypedExpression stm.IterationValues + let loopVarType = this.Expressions.Types.OnType (snd stm.LoopItem) + let body = this.Statements.OnScope stm.Body + maybe { + let! iterValsList = + match iterVals.Expression with + | RangeLiteral _ when isLiteral callables iterVals -> + rangeLiteralToSeq iterVals.Expression |> Seq.map (IntLiteral >> wrapExpr Int) |> List.ofSeq |> Some + | ValueArray va -> va |> List.ofSeq |> Some + | _ -> None + do! check (iterValsList.Length <= maxSize) + let iterRange = iterValsList |> List.map (fun x -> + let variableDecl = QsBinding.New ImmutableBinding (loopVar, x) |> QsVariableDeclaration |> wrapStmt + let innerScope = { stm.Body with Statements = stm.Body.Statements.Insert(0, variableDecl) } + innerScope |> newScopeStatement |> wrapStmt) + let outerScope = QsScope.New (iterRange, stm.Body.KnownSymbols) + return outerScope |> newScopeStatement |> this.OnStatementKind } - } + |? (QsForStatement.New ((loopVar, loopVarType), iterVals, body) |> QsForStatement) + diff --git a/src/QsCompiler/Optimizations/OptimizingTransformations/StatementGrouping.fs b/src/QsCompiler/Optimizations/OptimizingTransformations/StatementGrouping.fs index 22933721eb..545444e4a0 100644 --- a/src/QsCompiler/Optimizations/OptimizingTransformations/StatementGrouping.fs +++ b/src/QsCompiler/Optimizations/OptimizingTransformations/StatementGrouping.fs @@ -6,26 +6,36 @@ namespace Microsoft.Quantum.QsCompiler.Experimental open Microsoft.Quantum.QsCompiler.Experimental.OptimizationTools open Microsoft.Quantum.QsCompiler.SyntaxExtensions open Microsoft.Quantum.QsCompiler.SyntaxTree -open Microsoft.Quantum.QsCompiler.Transformations.Core +open Microsoft.Quantum.QsCompiler.Transformations /// The SyntaxTreeTransformation used to reorder statements depending on how they impact the program state. -type StatementGrouping() = - inherit OptimizingTransformation() +type StatementGrouping private (_private_ : string) = + inherit TransformationBase() + + new () as this = + new StatementGrouping("_private_") then + this.Statements <- new StatementGroupingStatements(this) + this.Expressions <- new Core.ExpressionTransformation(this, Core.TransformationOptions.Disabled) + this.Types <- new Core.TypeTransformation(this, Core.TransformationOptions.Disabled) + +/// private helper class for StatementGrouping +and private StatementGroupingStatements (parent : StatementGrouping) = + inherit Core.StatementTransformation(parent) /// Returns whether a statements is purely classical. /// The statement must have no classical or quantum side effects other than defining a variable. let isPureClassical stmt = let c = SideEffectChecker() - c.onStatement stmt |> ignore - not c.hasQuantum && not c.hasMutation && not c.hasInterrupts + c.Statements.OnStatement stmt |> ignore + not c.HasQuantum && not c.HasMutation && not c.HasInterrupts /// Returns whether a statement is purely quantum. /// The statement must have no classical side effects, but can have quantum side effects. let isPureQuantum stmt = let c = SideEffectChecker() - c.onStatement stmt |> ignore - c.hasQuantum && not c.hasMutation && not c.hasInterrupts + c.Statements.OnStatement stmt |> ignore + c.HasQuantum && not c.HasMutation && not c.HasInterrupts /// Reorders a list of statements such that the pure classical statements occur before the pure quantum statements let rec reorderStatements = function @@ -35,10 +45,9 @@ type StatementGrouping() = else a :: reorderStatements (b :: tail) | x -> x + override this.OnScope scope = + let parentSymbols = scope.KnownSymbols + let statements = scope.Statements |> Seq.map this.OnStatement |> List.ofSeq |> reorderStatements + QsScope.New (statements, parentSymbols) + - override __.Scope = { new ScopeTransformation() with - override this.Transform scope = - let parentSymbols = scope.KnownSymbols - let statements = scope.Statements |> Seq.map this.onStatement |> List.ofSeq |> reorderStatements - QsScope.New (statements, parentSymbols) - } diff --git a/src/QsCompiler/Optimizations/OptimizingTransformations/StatementRemoving.fs b/src/QsCompiler/Optimizations/OptimizingTransformations/StatementRemoving.fs index 9fbcf45430..4b8c46d234 100644 --- a/src/QsCompiler/Optimizations/OptimizingTransformations/StatementRemoving.fs +++ b/src/QsCompiler/Optimizations/OptimizingTransformations/StatementRemoving.fs @@ -10,61 +10,71 @@ open Microsoft.Quantum.QsCompiler.Experimental.Utils open Microsoft.Quantum.QsCompiler.SyntaxExtensions open Microsoft.Quantum.QsCompiler.SyntaxTokens open Microsoft.Quantum.QsCompiler.SyntaxTree +open Microsoft.Quantum.QsCompiler.Transformations /// The SyntaxTreeTransformation used to remove useless statements -type StatementRemoval(removeFunctions) = - inherit OptimizingTransformation() +type StatementRemoval private (_private_ : string) = + inherit TransformationBase() - override __.Scope = upcast { new StatementCollectorTransformation() with + new (removeFunctions : bool) as this = + new StatementRemoval("_private_") then + this.Statements <- new VariableRemovalStatements(this, removeFunctions) + this.Expressions <- new Core.ExpressionTransformation(this, Core.TransformationOptions.Disabled) + this.Types <- new Core.TypeTransformation(this, Core.TransformationOptions.Disabled) - /// Given a statement, returns a sequence of statements to replace this statement with. - /// Removes useless statements, such as variable declarations with discarded values. - /// Replaces ScopeStatements and empty qubit allocations with the statements they contain. - /// Splits tuple declaration and updates into single declarations or updates for each variable. - /// Splits QsQubitScopes by replacing register allocations with single qubit allocations. - override __.TransformStatement stmt = - let c = SideEffectChecker() - c.StatementKind.Transform stmt |> ignore +/// private helper class for StatementRemoval +and private VariableRemovalStatements (parent : StatementRemoval, removeFunctions) = + inherit StatementCollectorTransformation(parent) - let c2 = MutationChecker() - c2.StatementKind.Transform stmt |> ignore + /// Given a statement, returns a sequence of statements to replace this statement with. + /// Removes useless statements, such as variable declarations with discarded values. + /// Replaces ScopeStatements and empty qubit allocations with the statements they contain. + /// Splits tuple declaration and updates into single declarations or updates for each variable. + /// Splits QsQubitScopes by replacing register allocations with single qubit allocations. + override __.CollectStatements stmt = + + let c = SideEffectChecker() + c.StatementKinds.OnStatementKind stmt |> ignore + + let c2 = MutationChecker() + c2.StatementKinds.OnStatementKind stmt |> ignore + + match stmt with + | QsVariableDeclaration {Lhs = lhs} + | QsValueUpdate {Lhs = LocalVarTuple lhs} + when isAllDiscarded lhs && not c.HasQuantum -> Seq.empty + | QsVariableDeclaration s -> + jointFlatten (s.Lhs, s.Rhs) |> Seq.map (QsBinding.New s.Kind >> QsVariableDeclaration) + | QsValueUpdate s -> + jointFlatten (s.Lhs, s.Rhs) |> Seq.map (QsValueUpdate.New >> QsValueUpdate) + | QsQubitScope s when isAllDiscarded s.Binding.Lhs -> + s.Body.Statements |> Seq.map (fun x -> x.Statement) + | QsQubitScope s -> + let mutable newStatements = [] + let myList = jointFlatten (s.Binding.Lhs, s.Binding.Rhs) |> Seq.collect (fun (l, r) -> + match l, r.Resolution with + | VariableName name, QubitRegisterAllocation {Expression = IntLiteral num} -> + let elemI = fun i -> Identifier (LocalVariable (NonNullable<_>.New (sprintf "__qsItem%d__%s__" i name.Value)), Null) + let expr = Seq.init (safeCastInt64 num) (elemI >> wrapExpr Qubit) |> ImmutableArray.CreateRange |> ValueArray |> wrapExpr (ArrayType (ResolvedType.New Qubit)) + let newStmt = QsVariableDeclaration (QsBinding.New QsBindingKind.ImmutableBinding (l, expr)) + newStatements <- wrapStmt newStmt :: newStatements + Seq.init (safeCastInt64 num) (fun i -> + VariableName (NonNullable<_>.New (sprintf "__qsItem%d__%s__" i name.Value)), + ResolvedInitializer.New SingleQubitAllocation) + | DiscardedItem, _ -> Seq.empty + | _ -> Seq.singleton (l, r)) |> List.ofSeq + match myList with + | [] -> newScopeStatement s.Body |> Seq.singleton + | [lhs, rhs] -> + let newBody = QsScope.New (s.Body.Statements.InsertRange (0, newStatements), s.Body.KnownSymbols) + QsQubitScope.New s.Kind ((lhs, rhs), newBody) |> QsQubitScope |> Seq.singleton + | _ -> + let lhs = List.map fst myList |> ImmutableArray.CreateRange |> VariableNameTuple + let rhs = List.map snd myList |> ImmutableArray.CreateRange |> QubitTupleAllocation |> ResolvedInitializer.New + let newBody = QsScope.New (s.Body.Statements.InsertRange (0, newStatements), s.Body.KnownSymbols) + QsQubitScope.New s.Kind ((lhs, rhs), newBody) |> QsQubitScope |> Seq.singleton + | ScopeStatement s -> s.Body.Statements |> Seq.map (fun x -> x.Statement) + | _ when not c.HasQuantum && c2.ExternalMutations.IsEmpty && not c.HasInterrupts && (not c.HasOutput || removeFunctions) -> Seq.empty + | a -> Seq.singleton a - match stmt with - | QsVariableDeclaration {Lhs = lhs} - | QsValueUpdate {Lhs = LocalVarTuple lhs} - when isAllDiscarded lhs && not c.hasQuantum -> Seq.empty - | QsVariableDeclaration s -> - jointFlatten (s.Lhs, s.Rhs) |> Seq.map (QsBinding.New s.Kind >> QsVariableDeclaration) - | QsValueUpdate s -> - jointFlatten (s.Lhs, s.Rhs) |> Seq.map (QsValueUpdate.New >> QsValueUpdate) - | QsQubitScope s when isAllDiscarded s.Binding.Lhs -> - s.Body.Statements |> Seq.map (fun x -> x.Statement) - | QsQubitScope s -> - let mutable newStatements = [] - let myList = jointFlatten (s.Binding.Lhs, s.Binding.Rhs) |> Seq.collect (fun (l, r) -> - match l, r.Resolution with - | VariableName name, QubitRegisterAllocation {Expression = IntLiteral num} -> - let elemI = fun i -> Identifier (LocalVariable (NonNullable<_>.New (sprintf "__qsItem%d__%s__" i name.Value)), Null) - let expr = Seq.init (safeCastInt64 num) (elemI >> wrapExpr Qubit) |> ImmutableArray.CreateRange |> ValueArray |> wrapExpr (ArrayType (ResolvedType.New Qubit)) - let newStmt = QsVariableDeclaration (QsBinding.New QsBindingKind.ImmutableBinding (l, expr)) - newStatements <- wrapStmt newStmt :: newStatements - Seq.init (safeCastInt64 num) (fun i -> - VariableName (NonNullable<_>.New (sprintf "__qsItem%d__%s__" i name.Value)), - ResolvedInitializer.New SingleQubitAllocation) - | DiscardedItem, _ -> Seq.empty - | _ -> Seq.singleton (l, r)) |> List.ofSeq - match myList with - | [] -> newScopeStatement s.Body |> Seq.singleton - | [lhs, rhs] -> - let newBody = QsScope.New (s.Body.Statements.InsertRange (0, newStatements), s.Body.KnownSymbols) - QsQubitScope.New s.Kind ((lhs, rhs), newBody) |> QsQubitScope |> Seq.singleton - | _ -> - let lhs = List.map fst myList |> ImmutableArray.CreateRange |> VariableNameTuple - let rhs = List.map snd myList |> ImmutableArray.CreateRange |> QubitTupleAllocation |> ResolvedInitializer.New - let newBody = QsScope.New (s.Body.Statements.InsertRange (0, newStatements), s.Body.KnownSymbols) - QsQubitScope.New s.Kind ((lhs, rhs), newBody) |> QsQubitScope |> Seq.singleton - | ScopeStatement s -> s.Body.Statements |> Seq.map (fun x -> x.Statement) - | _ when not c.hasQuantum && c2.externalMutations.IsEmpty && not c.hasInterrupts && (not c.hasOutput || removeFunctions) -> Seq.empty - | a -> Seq.singleton a - } diff --git a/src/QsCompiler/Optimizations/OptimizingTransformations/TransformationBase.fs b/src/QsCompiler/Optimizations/OptimizingTransformations/TransformationBase.fs index 36d1f84479..31b2d2e9e8 100644 --- a/src/QsCompiler/Optimizations/OptimizingTransformations/TransformationBase.fs +++ b/src/QsCompiler/Optimizations/OptimizingTransformations/TransformationBase.fs @@ -11,21 +11,29 @@ open Microsoft.Quantum.QsCompiler.Transformations.Core /// It provides a function called `checkChanged` which returns true, if /// the transformation leads to a change in any of the namespaces' syntax /// tree, except for changes in the namespaces' documentation string. -type OptimizingTransformation() = +type TransformationBase private (_private_) = inherit SyntaxTreeTransformation() - let mutable changed = false + member val Changed = false with get, set /// Returns whether the syntax tree has been modified since this function was last called - member internal __.checkChanged() = - let x = changed - changed <- false - x + member internal this.CheckChanged() = + let res = this.Changed + this.Changed <- false + res + + new () as this = + new TransformationBase("_private_") then + this.Namespaces <- new NamespaceTransformationBase(this) + +/// private helper class for OptimizingTransformation +and private NamespaceTransformationBase (parent : TransformationBase) = + inherit NamespaceTransformation(parent) /// Checks whether the syntax tree changed at all - override __.Transform x = - let newX = base.Transform x - if (x.Elements, x.Name) <> (newX.Elements, newX.Name) then changed <- true + override this.OnNamespace x = + let newX = base.OnNamespace x + if (x.Elements, x.Name) <> (newX.Elements, newX.Name) then parent.Changed <- true newX diff --git a/src/QsCompiler/Optimizations/OptimizingTransformations/VariableRemoving.fs b/src/QsCompiler/Optimizations/OptimizingTransformations/VariableRemoving.fs index effe779493..9dee12b9df 100644 --- a/src/QsCompiler/Optimizations/OptimizingTransformations/VariableRemoving.fs +++ b/src/QsCompiler/Optimizations/OptimizingTransformations/VariableRemoving.fs @@ -7,38 +7,45 @@ open System.Collections.Immutable open Microsoft.Quantum.QsCompiler.Experimental.OptimizationTools open Microsoft.Quantum.QsCompiler.Experimental.Utils open Microsoft.Quantum.QsCompiler.SyntaxTree -open Microsoft.Quantum.QsCompiler.Transformations.Core +open Microsoft.Quantum.QsCompiler.Transformations /// The SyntaxTreeTransformation used to remove useless statements -type VariableRemoval() = - inherit OptimizingTransformation() +type VariableRemoval(_private_) = + inherit TransformationBase() - let mutable referenceCounter = None + member val internal ReferenceCounter = None with get, set - override __.onProvidedImplementation (argTuple, body) = + new () as this = + new VariableRemoval("_private_") then + this.Namespaces <- new VariableRemovalNamespaces(this) + this.StatementKinds <- new VariableRemovalStatementKinds(this) + this.Expressions <- new Core.ExpressionTransformation(this, Core.TransformationOptions.Disabled) + this.Types <- new Core.TypeTransformation(this, Core.TransformationOptions.Disabled) + +/// private helper class for VariableRemoval +and private VariableRemovalNamespaces (parent : VariableRemoval) = + inherit NamespaceTransformationBase(parent) + + override __.OnProvidedImplementation (argTuple, body) = let r = ReferenceCounter() - r.Transform body |> ignore - referenceCounter <- Some r - base.onProvidedImplementation (argTuple, body) - - override __.Scope = { new ScopeTransformation() with - override this.StatementKind = { new StatementKindTransformation() with - override __.ExpressionTransformation x = x - override __.LocationTransformation x = x - override __.ScopeTransformation x = this.Transform x - override __.TypeTransformation x = x - - override stmtKind.onSymbolTuple syms = - match syms with - | VariableName item -> - maybe { - let! r = referenceCounter - let uses = r.getNumUses item - do! check (uses = 0) - return DiscardedItem - } |? syms - | VariableNameTuple items -> Seq.map stmtKind.onSymbolTuple items |> ImmutableArray.CreateRange |> VariableNameTuple - | InvalidItem | DiscardedItem -> syms - } - } + r.Statements.OnScope body |> ignore + parent.ReferenceCounter <- Some r + base.OnProvidedImplementation (argTuple, body) + +/// private helper class for VariableRemoval +and private VariableRemovalStatementKinds (parent : VariableRemoval) = + inherit Core.StatementKindTransformation(parent) + + override stmtKind.OnSymbolTuple syms = + match syms with + | VariableName item -> + maybe { + let! r = parent.ReferenceCounter + let uses = r.NumberOfUses item + do! check (uses = 0) + return DiscardedItem + } |? syms + | VariableNameTuple items -> Seq.map stmtKind.OnSymbolTuple items |> ImmutableArray.CreateRange |> VariableNameTuple + | InvalidItem | DiscardedItem -> syms + diff --git a/src/QsCompiler/Optimizations/PreEvaluation.fs b/src/QsCompiler/Optimizations/PreEvaluation.fs index 0be417e167..71f78c6268 100644 --- a/src/QsCompiler/Optimizations/PreEvaluation.fs +++ b/src/QsCompiler/Optimizations/PreEvaluation.fs @@ -22,25 +22,25 @@ type PreEvaluation = /// function that takes as input such a dictionary of callables. /// /// Disclaimer: This is an experimental feature. - static member WithScript (script : Func, OptimizingTransformation seq>) (arg : QsCompilation) = + static member WithScript (script : Func, TransformationBase seq>) (arg : QsCompilation) = // TODO: this should actually only evaluate everything for each entry point let rec evaluate (tree : _ list) = let mutable tree = tree - tree <- List.map (StripAllKnownSymbols().Transform) tree - tree <- List.map (VariableRenaming().Transform) tree + tree <- List.map (StripAllKnownSymbols().Namespaces.OnNamespace) tree + tree <- List.map (VariableRenaming().Namespaces.OnNamespace) tree let callables = GlobalCallableResolutions tree // needs to be constructed in every iteration let optimizers = script.Invoke callables |> Seq.toList - for opt in optimizers do tree <- List.map opt.Transform tree - if optimizers |> List.exists (fun opt -> opt.checkChanged()) then evaluate tree + for opt in optimizers do tree <- List.map opt.Namespaces.OnNamespace tree + if optimizers |> List.exists (fun opt -> opt.CheckChanged()) then evaluate tree else tree let namespaces = arg.Namespaces |> Seq.map StripPositionInfo.Apply |> List.ofSeq |> evaluate QsCompilation.New (namespaces.ToImmutableArray(), arg.EntryPoints) /// Default sequence of optimizing transformations - static member DefaultScript removeFunctions maxSize : Func<_, OptimizingTransformation seq> = + static member DefaultScript removeFunctions maxSize : Func<_, TransformationBase seq> = new Func<_,_> (fun callables -> seq { VariableRemoval() StatementRemoval(removeFunctions) diff --git a/src/QsCompiler/Optimizations/PureCircuitFinding.fs b/src/QsCompiler/Optimizations/PureCircuitFinding.fs index c1c6b369d8..67a146ced6 100644 --- a/src/QsCompiler/Optimizations/PureCircuitFinding.fs +++ b/src/QsCompiler/Optimizations/PureCircuitFinding.fs @@ -10,12 +10,35 @@ open Microsoft.Quantum.QsCompiler.Experimental.Utils open Microsoft.Quantum.QsCompiler.SyntaxExtensions open Microsoft.Quantum.QsCompiler.SyntaxTokens open Microsoft.Quantum.QsCompiler.SyntaxTree -open Microsoft.Quantum.QsCompiler.Transformations.Core +open Microsoft.Quantum.QsCompiler.Transformations /// The SyntaxTreeTransformation used to find and optimize pure circuits -type PureCircuitFinder(callables) = - inherit OptimizingTransformation() +type PureCircuitFinder private (_private_ : string) = + inherit TransformationBase() + + member val internal DistinctQubitFinder = None with get, set + + new (callables) as this = + new PureCircuitFinder("_private_") then + this.Namespaces <- new PureCircuitFinderNamespaces(this) + this.Statements <- new PureCircuitFinderStatements(this, callables) + this.Expressions <- new Core.ExpressionTransformation(this, Core.TransformationOptions.Disabled) + this.Types <- new Core.TypeTransformation(this, Core.TransformationOptions.Disabled) + +/// private helper class for PureCircuitFinder +and private PureCircuitFinderNamespaces (parent : PureCircuitFinder) = + inherit Core.NamespaceTransformation(parent) + + override __.OnCallableDeclaration c = + let r = FindDistinctQubits() + r.Namespaces.OnCallableDeclaration c |> ignore + parent.DistinctQubitFinder <- Some r + base.OnCallableDeclaration c + +/// private helper class for PureCircuitFinder +and private PureCircuitFinderStatements (parent : PureCircuitFinder, callables : ImmutableDictionary<_,_>) = + inherit Core.StatementTransformation(parent) /// Returns whether an expression is an operation call let isOperation expr = @@ -27,39 +50,32 @@ type PureCircuitFinder(callables) = | _ -> false | _ -> false - let mutable distinctQubitFinder = None + override this.OnScope scope = + let mutable circuit = ImmutableArray.Empty + let mutable newStatements = ImmutableArray.Empty + + let finishCircuit () = + if circuit.Length <> 0 then + let newCircuit = optimizeExprList callables parent.DistinctQubitFinder.Value.DistinctNames circuit + (*if newCircuit <> circuit then + printfn "Removed %d gates" (circuit.Length - newCircuit.Length) + printfn "Old: %O" (List.map (fun x -> printExpr x.Expression) circuit) + printfn "New: %O" (List.map (fun x -> printExpr x.Expression) newCircuit) + printfn ""*) + newStatements <- newStatements.AddRange (Seq.map (QsExpressionStatement >> wrapStmt) newCircuit) + circuit <- ImmutableArray.Empty + + for stmt in scope.Statements do + match stmt.Statement with + | QsExpressionStatement expr when isOperation expr -> + circuit <- circuit.Add expr + | _ -> + finishCircuit() + newStatements <- newStatements.Add (this.OnStatement stmt) + finishCircuit() + + QsScope.New (newStatements, scope.KnownSymbols) + + + - override __.onCallableImplementation c = - let r = FindDistinctQubits() - r.onCallableImplementation c |> ignore - distinctQubitFinder <- Some r - base.onCallableImplementation c - - override __.Scope = { new ScopeTransformation() with - - override this.Transform scope = - let mutable circuit = ImmutableArray.Empty - let mutable newStatements = ImmutableArray.Empty - - let finishCircuit () = - if circuit.Length <> 0 then - let newCircuit = optimizeExprList callables distinctQubitFinder.Value.distinctNames circuit - (*if newCircuit <> circuit then - printfn "Removed %d gates" (circuit.Length - newCircuit.Length) - printfn "Old: %O" (List.map (fun x -> printExpr x.Expression) circuit) - printfn "New: %O" (List.map (fun x -> printExpr x.Expression) newCircuit) - printfn ""*) - newStatements <- newStatements.AddRange (Seq.map (QsExpressionStatement >> wrapStmt) newCircuit) - circuit <- ImmutableArray.Empty - - for stmt in scope.Statements do - match stmt.Statement with - | QsExpressionStatement expr when isOperation expr -> - circuit <- circuit.Add expr - | _ -> - finishCircuit() - newStatements <- newStatements.Add (this.onStatement stmt) - finishCircuit() - - QsScope.New (newStatements, scope.KnownSymbols) - } diff --git a/src/QsCompiler/Optimizations/Utils/Evaluation.fs b/src/QsCompiler/Optimizations/Utils/Evaluation.fs index e052c745e0..e352c18b03 100644 --- a/src/QsCompiler/Optimizations/Utils/Evaluation.fs +++ b/src/QsCompiler/Optimizations/Utils/Evaluation.fs @@ -4,6 +4,7 @@ module internal Microsoft.Quantum.QsCompiler.Experimental.Evaluation open System +open System.Collections.Generic open System.Collections.Immutable open System.Numerics open Microsoft.Quantum.QsCompiler @@ -18,7 +19,7 @@ open Microsoft.Quantum.QsCompiler.Transformations.Core /// Represents the internal state of a function evaluation. /// The first element is a map that stores the current values of all the local variables. /// The second element is a counter that stores the remaining number of statements we evaluate. -type private EvalState = Map * int +type private EvalState = Dictionary * int /// Represents any interrupt to the normal control flow of a function evaluation. /// Includes return statements, errors, and (if they were added) break/continue statements. @@ -37,7 +38,7 @@ type private Imp<'t> = Imperative /// Evaluates functions by stepping through their code -type internal FunctionEvaluator(callables: ImmutableDictionary) = +type internal FunctionEvaluator(callables : IDictionary) = /// Represents a computation that decreases the remaining statements counter by 1. /// Yields an OutOfStatements interrupt if this decreases the remaining statements below 0. @@ -48,9 +49,10 @@ type internal FunctionEvaluator(callables: ImmutableDictionary = imperative { + let setVars callables entry : Imp = imperative { let! vars, counter = getState - do! putState (defineVarTuple (isLiteral callables) vars entry, counter) + defineVarTuple (isLiteral callables) vars entry + do! putState (vars, counter) } /// Casts a BoolLiteral to the corresponding bool @@ -60,15 +62,15 @@ type internal FunctionEvaluator(callables: ImmutableDictionary ArgumentException ("Not a BoolLiteral: " + x.Expression.ToString()) |> raise /// Evaluates and simplifies a single Q# expression - member internal __.evaluateExpression expr: Imp = imperative { + member internal this.EvaluateExpression expr : Imp = imperative { let! vars, counter = getState - let result = ExpressionEvaluator(callables, vars, counter / 2).Transform expr + let result = ExpressionEvaluator(callables, vars, counter / 2).Expressions.OnTypedExpression expr if isLiteral callables result then return result else yield CouldNotEvaluate ("Not a literal: " + result.Expression.ToString()) } /// Evaluates a single Q# statement - member private this.evaluateStatement (statement: QsStatement): Imp = imperative { + member private this.EvaluateStatement (statement : QsStatement) = imperative { do! incrementState match statement.Statement with @@ -77,34 +79,34 @@ type internal FunctionEvaluator(callables: ImmutableDictionary - let! value = this.evaluateExpression expr + let! value = this.EvaluateExpression expr yield Returned value | QsFailStatement expr -> - let! value = this.evaluateExpression expr + let! value = this.EvaluateExpression expr yield Failed value | QsVariableDeclaration s -> - let! value = this.evaluateExpression s.Rhs + let! value = this.EvaluateExpression s.Rhs do! setVars callables (s.Lhs, value) | QsValueUpdate s -> match s.Lhs with | LocalVarTuple vt -> - let! value = this.evaluateExpression s.Rhs + let! value = this.EvaluateExpression s.Rhs do! setVars callables (vt, value) | _ -> yield CouldNotEvaluate ("Unknown LHS of value update statement: " + s.Lhs.Expression.ToString()) | QsConditionalStatement s -> let mutable evalElseCase = true for cond, block in s.ConditionalBlocks do - let! value = this.evaluateExpression cond <&> castToBool + let! value = this.EvaluateExpression cond <&> castToBool if value then - do! this.evaluateScope block.Body + do! this.EvaluateScope block.Body evalElseCase <- false do! Break if evalElseCase then match s.Default with - | Value block -> do! this.evaluateScope block.Body + | Value block -> do! this.EvaluateScope block.Body | _ -> () | QsForStatement stmt -> - let! iterExpr = this.evaluateExpression stmt.IterationValues + let! iterExpr = this.EvaluateExpression stmt.IterationValues let! iterSeq = imperative { match iterExpr.Expression with | RangeLiteral _ when isLiteral callables iterExpr -> @@ -116,33 +118,34 @@ type internal FunctionEvaluator(callables: ImmutableDictionary - while this.evaluateExpression stmt.Condition <&> castToBool do - do! this.evaluateScope stmt.Body + while this.EvaluateExpression stmt.Condition <&> castToBool do + do! this.EvaluateScope stmt.Body | QsRepeatStatement stmt -> while true do - do! this.evaluateScope stmt.RepeatBlock.Body - let! value = this.evaluateExpression stmt.SuccessCondition <&> castToBool + do! this.EvaluateScope stmt.RepeatBlock.Body + let! value = this.EvaluateExpression stmt.SuccessCondition <&> castToBool if value then do! Break - do! this.evaluateScope stmt.FixupBlock.Body + do! this.EvaluateScope stmt.FixupBlock.Body | QsQubitScope _ -> yield CouldNotEvaluate "Cannot allocate qubits in function" | QsConjugation _ -> yield CouldNotEvaluate "Cannot conjugate in function" + | EmptyStatement -> () } /// Evaluates a list of Q# statements - member private this.evaluateScope (scope: QsScope): Imp = imperative { + member private this.EvaluateScope (scope : QsScope) = imperative { for stmt in scope.Statements do - do! this.evaluateStatement stmt + do! this.EvaluateStatement stmt } /// Evaluates the given Q# function on the given argument. /// Returns Some ([expr]) if we successfully evaluate the function as [expr]. /// Returns None if we were unable to evaluate the function. /// Throws an ArgumentException if the input is not a function, or if the function is invalid. - member internal this.evaluateFunction (name: QsQualifiedName) (arg: TypedExpression) (types: QsNullable>) (stmtsLeft: int): TypedExpression option = + member internal this.EvaluateFunction (name : QsQualifiedName) (arg : TypedExpression) (stmtsLeft : int) = let callable = callables.[name] if callable.Kind = Operation then ArgumentException "Input is not a function" |> raise @@ -151,37 +154,39 @@ type internal FunctionEvaluator(callables: ImmutableDictionary - let vars = defineVarTuple (isLiteral callables) Map.empty (toSymbolTuple specArgs, arg) - match this.evaluateScope scope (vars, stmtsLeft) with + let vars = new Dictionary<_,_>() + defineVarTuple (isLiteral callables) vars (toSymbolTuple specArgs, arg) + match this.EvaluateScope scope (vars, stmtsLeft) with | Normal _ -> None | Break _ -> None | Interrupt (Returned expr) -> Some expr | Interrupt (Failed _) -> None | Interrupt TooManyStatements -> None - | Interrupt (CouldNotEvaluate reason) -> None + | Interrupt (CouldNotEvaluate _) -> None | _ -> None /// The ExpressionTransformation used to evaluate constant expressions -and internal ExpressionEvaluator(callables: ImmutableDictionary, constants: Map, stmtsLeft: int) = - inherit ExpressionTransformation() +and internal ExpressionEvaluator private (_private_) = + inherit SyntaxTreeTransformation() - override this.Kind = upcast { new ExpressionKindEvaluator(callables, constants, stmtsLeft) with - override __.ExpressionTransformation x = this.Transform x - override __.TypeTransformation x = this.Type.Transform x } + internal new (callables : IDictionary, constants : IDictionary, stmtsLeft : int) as this = + new ExpressionEvaluator("_private_") then + this.ExpressionKinds <- new ExpressionKindEvaluator(this, callables, constants, stmtsLeft) + this.Types <- new TypeTransformation(this, TransformationOptions.Disabled) /// The ExpressionKindTransformation used to evaluate constant expressions -and [] private ExpressionKindEvaluator(callables: ImmutableDictionary, constants: Map, stmtsLeft: int) = - inherit ExpressionKindTransformation() +and private ExpressionKindEvaluator(parent, callables: IDictionary, constants: IDictionary, stmtsLeft: int) = + inherit ExpressionKindTransformation(parent) - member private this.simplify e1 = this.ExpressionTransformation e1 + member private this.simplify e1 = this.Expressions.OnTypedExpression e1 member private this.simplify (e1, e2) = - (this.ExpressionTransformation e1, this.ExpressionTransformation e2) + (this.Expressions.OnTypedExpression e1, this.Expressions.OnTypedExpression e2) member private this.simplify (e1, e2, e3) = - (this.ExpressionTransformation e1, this.ExpressionTransformation e2, this.ExpressionTransformation e3) + (this.Expressions.OnTypedExpression e1, this.Expressions.OnTypedExpression e2, this.Expressions.OnTypedExpression e3) member private this.arithBoolBinaryOp qop bigIntOp doubleOp intOp lhs rhs = let lhs, rhs = this.simplify (lhs, rhs) @@ -205,46 +210,49 @@ and [] private ExpressionKindEvaluator(callables: ImmutableDictio | IntLiteral a, IntLiteral b -> IntLiteral (intOp a b) | _ -> qop (lhs, rhs) - override this.onIdentifier (sym, tArgs) = + override this.OnIdentifier (sym, tArgs) = match sym with - | LocalVariable name -> Map.tryFind name.Value constants |> Option.map (fun x -> x.Expression) |? Identifier (sym, tArgs) + | LocalVariable name -> + match constants.TryGetValue name.Value with + | true, ex -> ex.Expression + | _ -> Identifier (sym, tArgs) | _ -> Identifier (sym, tArgs) - override this.onFunctionCall (method, arg) = + override this.OnFunctionCall (method, arg) = let method, arg = this.simplify (method, arg) maybe { match method.Expression with - | Identifier (GlobalCallable qualName, types) -> + | Identifier (GlobalCallable qualName, _) -> do! check (stmtsLeft > 0 && isLiteral callables arg) let fe = FunctionEvaluator (callables) - return! fe.evaluateFunction qualName arg types stmtsLeft |> Option.map (fun x -> x.Expression) + return! fe.EvaluateFunction qualName arg stmtsLeft |> Option.map (fun x -> x.Expression) | CallLikeExpression (baseMethod, partialArg) -> do! check (TypedExpression.IsPartialApplication method.Expression) - return this.Transform (CallLikeExpression (baseMethod, fillPartialArg (partialArg, arg))) + return this.OnExpressionKind (CallLikeExpression (baseMethod, fillPartialArg (partialArg, arg))) | _ -> return! None } |? CallLikeExpression (method, arg) - override this.onOperationCall (method, arg) = + override this.OnOperationCall (method, arg) = let method, arg = this.simplify (method, arg) maybe { match method.Expression with | CallLikeExpression (baseMethod, partialArg) -> do! check (TypedExpression.IsPartialApplication method.Expression) - return this.Transform (CallLikeExpression (baseMethod, fillPartialArg (partialArg, arg))) + return this.OnExpressionKind (CallLikeExpression (baseMethod, fillPartialArg (partialArg, arg))) | _ -> return! None } |? CallLikeExpression (method, arg) - override this.onPartialApplication (method, arg) = + override this.OnPartialApplication (method, arg) = let method, arg = this.simplify (method, arg) maybe { match method.Expression with | CallLikeExpression (baseMethod, partialArg) -> do! check (TypedExpression.IsPartialApplication method.Expression) - return this.Transform (CallLikeExpression (baseMethod, fillPartialArg (partialArg, arg))) + return this.OnExpressionKind (CallLikeExpression (baseMethod, fillPartialArg (partialArg, arg))) | _ -> return! None } |? CallLikeExpression (method, arg) - override this.onUnwrapApplication ex = + override this.OnUnwrapApplication ex = let ex = this.simplify ex match ex.Expression with | CallLikeExpression ({Expression = Identifier (GlobalCallable qualName, types)}, arg) @@ -259,7 +267,7 @@ and [] private ExpressionKindEvaluator(callables: ImmutableDictio arg.Expression | _ -> UnwrapApplication ex - override this.onArrayItem (arr, idx) = + override this.OnArrayItem (arr, idx) = let arr, idx = this.simplify (arr, idx) match arr.Expression, idx.Expression with | ValueArray va, IntLiteral i -> va.[safeCastInt64 i].Expression @@ -267,13 +275,13 @@ and [] private ExpressionKindEvaluator(callables: ImmutableDictio rangeLiteralToSeq idx.Expression |> Seq.map (fun i -> va.[safeCastInt64 i]) |> ImmutableArray.CreateRange |> ValueArray | _ -> ArrayItem (arr, idx) - override this.onNewArray (bt, idx) = + override this.OnNewArray (bt, idx) = let idx = this.simplify idx match idx.Expression with | IntLiteral i -> constructNewArray bt.Resolution (safeCastInt64 i) |? NewArray (bt, idx) | _ -> NewArray (bt, idx) - override this.onCopyAndUpdateExpression (lhs, accEx, rhs) = + override this.OnCopyAndUpdateExpression (lhs, accEx, rhs) = let lhs, accEx, rhs = this.simplify (lhs, accEx, rhs) match lhs.Expression, accEx.Expression, rhs.Expression with | ValueArray va, IntLiteral i, _ -> ValueArray (va.SetItem(safeCastInt64 i, rhs)) @@ -283,44 +291,44 @@ and [] private ExpressionKindEvaluator(callables: ImmutableDictio // TODO - handle named items in user-defined types | _ -> CopyAndUpdate (lhs, accEx, rhs) - override this.onConditionalExpression (e1, e2, e3) = + override this.OnConditionalExpression (e1, e2, e3) = let e1 = this.simplify e1 match e1.Expression with | BoolLiteral a -> if a then (this.simplify e2).Expression else (this.simplify e3).Expression | _ -> CONDITIONAL (e1, this.simplify e2, this.simplify e3) - override this.onEquality (lhs, rhs) = + override this.OnEquality (lhs, rhs) = let lhs, rhs = this.simplify (lhs, rhs) match isLiteral callables lhs && isLiteral callables rhs with | true -> BoolLiteral (lhs.Expression = rhs.Expression) | false -> EQ (lhs, rhs) - override this.onInequality (lhs, rhs) = + override this.OnInequality (lhs, rhs) = let lhs, rhs = this.simplify (lhs, rhs) match isLiteral callables lhs && isLiteral callables rhs with | true -> BoolLiteral (lhs.Expression <> rhs.Expression) | false -> NEQ (lhs, rhs) - override this.onLessThan (lhs, rhs) = + override this.OnLessThan (lhs, rhs) = this.arithBoolBinaryOp LT (<) (<) (<) lhs rhs - override this.onLessThanOrEqual (lhs, rhs) = + override this.OnLessThanOrEqual (lhs, rhs) = this.arithBoolBinaryOp LTE (<=) (<=) (<=) lhs rhs - override this.onGreaterThan (lhs, rhs) = + override this.OnGreaterThan (lhs, rhs) = this.arithBoolBinaryOp GT (>) (>) (>) lhs rhs - override this.onGreaterThanOrEqual (lhs, rhs) = + override this.OnGreaterThanOrEqual (lhs, rhs) = this.arithBoolBinaryOp GTE (>=) (>=) (>=) lhs rhs - override this.onLogicalAnd (lhs, rhs) = + override this.OnLogicalAnd (lhs, rhs) = let lhs = this.simplify lhs match lhs.Expression with | BoolLiteral true -> (this.simplify rhs).Expression | BoolLiteral false -> BoolLiteral false | _ -> AND (lhs, this.simplify rhs) - override this.onLogicalOr (lhs, rhs) = + override this.OnLogicalOr (lhs, rhs) = let lhs = this.simplify lhs match lhs.Expression with | BoolLiteral true -> BoolLiteral true @@ -332,7 +340,7 @@ and [] private ExpressionKindEvaluator(callables: ImmutableDictio // - rewrites (integers, big integers, and doubles): // 0 + x = x // x + 0 = x - override this.onAddition (lhs, rhs) = + override this.OnAddition (lhs, rhs) = let lhs, rhs = this.simplify (lhs, rhs) match lhs.Expression, rhs.Expression with | ValueArray a, ValueArray b -> ValueArray (a.AddRange b) @@ -351,7 +359,7 @@ and [] private ExpressionKindEvaluator(callables: ImmutableDictio // x - 0 = x // 0 - x = -x // x - x = 0 - override this.onSubtraction (lhs, rhs) = + override this.OnSubtraction (lhs, rhs) = let lhs, rhs = this.simplify (lhs, rhs) match lhs.Expression, rhs.Expression with | op, BigIntLiteral zero when zero.IsZero -> op @@ -374,7 +382,7 @@ and [] private ExpressionKindEvaluator(callables: ImmutableDictio // 0 * x = 0 // x * 1 = x // 1 * x = x - override this.onMultiplication (lhs, rhs) = + override this.OnMultiplication (lhs, rhs) = let lhs, rhs = this.simplify (lhs, rhs) match lhs.Expression, rhs.Expression with | _, (BigIntLiteral zero) @@ -392,7 +400,7 @@ and [] private ExpressionKindEvaluator(callables: ImmutableDictio | _ -> this.arithNumBinaryOp MUL (*) (*) (*) lhs rhs // - simplifies multiplication of two constants into single constant - override this.onDivision (lhs, rhs) = + override this.OnDivision (lhs, rhs) = let lhs, rhs = this.simplify (lhs, rhs) match lhs.Expression, rhs.Expression with | op, (BigIntLiteral one) when one.IsOne -> op @@ -400,7 +408,7 @@ and [] private ExpressionKindEvaluator(callables: ImmutableDictio | op, (IntLiteral 1L) -> op | _ -> this.arithNumBinaryOp DIV (/) (/) (/) lhs rhs - override this.onExponentiate (lhs, rhs) = + override this.OnExponentiate (lhs, rhs) = let lhs, rhs = this.simplify (lhs, rhs) match lhs.Expression, rhs.Expression with | BigIntLiteral a, IntLiteral b -> BigIntLiteral (BigInteger.Pow(a, safeCastInt64 b)) @@ -408,31 +416,31 @@ and [] private ExpressionKindEvaluator(callables: ImmutableDictio | IntLiteral a, IntLiteral b -> IntLiteral (longPow a b) | _ -> POW (lhs, rhs) - override this.onModulo (lhs, rhs) = + override this.OnModulo (lhs, rhs) = this.intBinaryOp MOD (%) (%) lhs rhs - override this.onLeftShift (lhs, rhs) = + override this.OnLeftShift (lhs, rhs) = this.intBinaryOp LSHIFT (fun l r -> l <<< safeCastBigInt r) (fun l r -> l <<< safeCastInt64 r) lhs rhs - override this.onRightShift (lhs, rhs) = + override this.OnRightShift (lhs, rhs) = this.intBinaryOp RSHIFT (fun l r -> l >>> safeCastBigInt r) (fun l r -> l >>> safeCastInt64 r) lhs rhs - override this.onBitwiseExclusiveOr (lhs, rhs) = + override this.OnBitwiseExclusiveOr (lhs, rhs) = this.intBinaryOp BXOR (^^^) (^^^) lhs rhs - override this.onBitwiseOr (lhs, rhs) = + override this.OnBitwiseOr (lhs, rhs) = this.intBinaryOp BOR (|||) (|||) lhs rhs - override this.onBitwiseAnd (lhs, rhs) = + override this.OnBitwiseAnd (lhs, rhs) = this.intBinaryOp BAND (&&&) (&&&) lhs rhs - override this.onLogicalNot expr = + override this.OnLogicalNot expr = let expr = this.simplify expr match expr.Expression with | BoolLiteral a -> BoolLiteral (not a) | _ -> NOT expr - override this.onNegative expr = + override this.OnNegative expr = let expr = this.simplify expr match expr.Expression with | BigIntLiteral a -> BigIntLiteral (-a) @@ -440,7 +448,7 @@ and [] private ExpressionKindEvaluator(callables: ImmutableDictio | IntLiteral a -> IntLiteral (-a) | _ -> NEG expr - override this.onBitwiseNot expr = + override this.OnBitwiseNot expr = let expr = this.simplify expr match expr.Expression with | IntLiteral a -> IntLiteral (~~~a) diff --git a/src/QsCompiler/Optimizations/Utils/HelperFunctions.fs b/src/QsCompiler/Optimizations/Utils/HelperFunctions.fs index 44d295f2e1..6d70ae100a 100644 --- a/src/QsCompiler/Optimizations/Utils/HelperFunctions.fs +++ b/src/QsCompiler/Optimizations/Utils/HelperFunctions.fs @@ -4,6 +4,7 @@ module internal Microsoft.Quantum.QsCompiler.Experimental.Utils open System +open System.Collections.Generic open System.Collections.Immutable open System.Numerics open Microsoft.Quantum.QsCompiler @@ -31,7 +32,7 @@ let internal check x = if x then Some () else None /// Returns whether a given expression is a literal (and thus a constant) -let rec internal isLiteral (callables: ImmutableDictionary) (expr: TypedExpression): bool = +let rec internal isLiteral (callables: IDictionary) (expr: TypedExpression): bool = let folder ex sub = match ex.Expression with | IntLiteral _ | BigIntLiteral _ | DoubleLiteral _ | BoolLiteral _ | ResultLiteral _ | PauliLiteral _ | StringLiteral _ @@ -50,19 +51,20 @@ let rec internal isLiteral (callables: ImmutableDictionary Map.add name value +let internal defineVar check (constants : IDictionary<_,_>) (name, value) = + if check value then constants.[name] <- value /// Applies the given function op on a SymbolTuple, ValueTuple pair -let rec private onTuple op constants (names, values) = +let rec private onTuple op constants (names, values) : unit = match names, values with | VariableName name, _ -> op constants (name.Value, values) | VariableNameTuple namesTuple, Tuple valuesTuple -> if namesTuple.Length <> valuesTuple.Length then ArgumentException "names and values have different lengths" |> raise - Seq.zip namesTuple valuesTuple |> Seq.fold (onTuple op) constants - | _ -> constants + for sym, value in Seq.zip namesTuple valuesTuple do + onTuple op constants (sym, value) + | _ -> () /// Returns a Constants with the given variables defined as the given values let internal defineVarTuple check = onTuple (defineVar check) diff --git a/src/QsCompiler/Optimizations/Utils/OptimizationTools.fs b/src/QsCompiler/Optimizations/Utils/OptimizationTools.fs index e1c175f293..90e087510f 100644 --- a/src/QsCompiler/Optimizations/Utils/OptimizationTools.fs +++ b/src/QsCompiler/Optimizations/Utils/OptimizationTools.fs @@ -4,192 +4,213 @@ module internal Microsoft.Quantum.QsCompiler.Experimental.OptimizationTools open System.Collections.Immutable +open Microsoft.Quantum.QsCompiler.DataTypes open Microsoft.Quantum.QsCompiler.Experimental.Utils open Microsoft.Quantum.QsCompiler.SyntaxExtensions open Microsoft.Quantum.QsCompiler.SyntaxTree -open Microsoft.Quantum.QsCompiler.Transformations.Core +open Microsoft.Quantum.QsCompiler.Transformations -/// A SyntaxTreeTransformation that finds identifiers in each implementation that represent distict values. -/// Should be called at the QsCallable level, not as the QsNamespace level, as it's meant to operate on a single callable. -type internal FindDistinctQubits() = - inherit SyntaxTreeTransformation() +/// A SyntaxTreeTransformation that finds identifiers in each implementation that represent distict qubit values. +/// Should be called at the specialization level, as it's meant to operate on a single implementation. +type internal FindDistinctQubits private (_private_) = + inherit Core.SyntaxTreeTransformation() - let mutable _distinctNames = Set.empty + member val DistinctNames : Set> = Set.empty with get, set - /// A set of identifier names that we expect to represent distinct values - member __.distinctNames = _distinctNames + internal new () as this = + new FindDistinctQubits("_private_") then + this.Namespaces <- new DistinctQubitsNamespaces(this) + this.StatementKinds <- new DistinctQubitsStatementKinds(this) + this.Expressions <- new Core.ExpressionTransformation(this, Core.TransformationOptions.Disabled) + this.Types <- new Core.TypeTransformation(this, Core.TransformationOptions.Disabled) - override __.onProvidedImplementation (argTuple, body) = +/// private helper class for FindDistinctQubits +and private DistinctQubitsNamespaces (parent : FindDistinctQubits) = + inherit Core.NamespaceTransformation(parent) + + override this.OnProvidedImplementation (argTuple, body) = argTuple |> toSymbolTuple |> flatten |> Seq.iter (function - | VariableName name -> _distinctNames <- _distinctNames.Add name + | VariableName name -> parent.DistinctNames <- parent.DistinctNames.Add name | _ -> ()) - base.onProvidedImplementation (argTuple, body) - - override __.Scope = { new ScopeTransformation() with - override this.StatementKind = { new StatementKindTransformation() with - override __.ScopeTransformation s = this.Transform s - override __.ExpressionTransformation ex = ex - override __.TypeTransformation t = t - override __.LocationTransformation l = l - - override __.onQubitScope stm = - stm.Binding.Lhs |> flatten |> Seq.iter (function - | VariableName name -> _distinctNames <- _distinctNames.Add name - | _ -> ()) - base.onQubitScope stm - } - } + base.OnProvidedImplementation (argTuple, body) +/// private helper class for FindDistinctQubits +and private DistinctQubitsStatementKinds (parent : FindDistinctQubits) = + inherit Core.StatementKindTransformation(parent) -/// A ScopeTransformation that tracks what variables the transformed code could mutate -type internal MutationChecker() = - inherit ScopeTransformation() + override this.OnQubitScope stm = + stm.Binding.Lhs |> flatten |> Seq.iter (function + | VariableName name -> parent.DistinctNames <- parent.DistinctNames.Add name + | _ -> ()) + base.OnQubitScope stm - let mutable declaredVars = Set.empty - let mutable mutatedVars = Set.empty - /// The set of variables that this code doesn't declare but does mutate - member __.externalMutations = mutatedVars - declaredVars +/// A transformation that tracks what variables the transformed code could mutate. +/// Should be called at the specialization level, as it's meant to operate on a single implementation. +type internal MutationChecker private (_private_) = + inherit Core.SyntaxTreeTransformation() - override this.StatementKind = { new StatementKindTransformation() with - override __.ScopeTransformation s = this.Transform s - override __.ExpressionTransformation ex = ex - override __.TypeTransformation t = t - override __.LocationTransformation l = l + member val DeclaredVariables : Set> = Set.empty with get, set + member val MutatedVariables : Set> = Set.empty with get, set - override __.onVariableDeclaration stm = - flatten stm.Lhs |> Seq.iter (function - | VariableName name -> declaredVars <- declaredVars.Add name - | _ -> ()) - base.onVariableDeclaration stm + /// Contains the set of variables that this code doesn't declare but does mutate. + member this.ExternalMutations = this.MutatedVariables - this.DeclaredVariables + + internal new () as this = + new MutationChecker("_private_") then + this.StatementKinds <- new MutationCheckerStatementKinds(this) + this.Expressions <- new Core.ExpressionTransformation(this, Core.TransformationOptions.Disabled) + this.Types <- new Core.TypeTransformation(this, Core.TransformationOptions.Disabled) + +/// private helper class for MutationChecker +and private MutationCheckerStatementKinds(parent : MutationChecker) = + inherit Core.StatementKindTransformation(parent) + + override this.OnVariableDeclaration stm = + flatten stm.Lhs |> Seq.iter (function + | VariableName name -> parent.DeclaredVariables <- parent.DeclaredVariables.Add name + | _ -> ()) + base.OnVariableDeclaration stm - override __.onValueUpdate stm = - match stm.Lhs with - | LocalVarTuple v -> - flatten v |> Seq.iter (function - | VariableName name -> mutatedVars <- mutatedVars.Add name - | _ -> ()) - | _ -> () - base.onValueUpdate stm - } + override this.OnValueUpdate stm = + match stm.Lhs with + | LocalVarTuple v -> + flatten v |> Seq.iter (function + | VariableName name -> parent.MutatedVariables <- parent.MutatedVariables.Add name + | _ -> ()) + | _ -> () + base.OnValueUpdate stm -/// A ScopeTransformation that counts how many times each variable is referenced -type internal ReferenceCounter() = - inherit ScopeTransformation() +/// A transformation that counts how many times each local variable is referenced. +/// Should be called at the specialization level, as it's meant to operate on a single implementation. +type internal ReferenceCounter private (_private_) = + inherit Core.SyntaxTreeTransformation() - let mutable numUses = Map.empty + member val internal UsedVariables = Map.empty with get, set /// Returns the number of times the variable with the given name is referenced - member __.getNumUses name = numUses.TryFind name |? 0 + member this.NumberOfUses name = this.UsedVariables.TryFind name |? 0 - override this.Expression = { new ExpressionTransformation() with - override expr.Kind = { new ExpressionKindTransformation() with - override __.ExpressionTransformation ex = expr.Transform ex - override __.TypeTransformation t = t + internal new () as this = + new ReferenceCounter("_private_") then + this.ExpressionKinds <- new ReferenceCounterExpressionKinds(this) + this.Types <- new Core.TypeTransformation(this, Core.TransformationOptions.Disabled) - override __.onIdentifier (sym, tArgs) = - match sym with - | LocalVariable name -> numUses <- numUses.Add (name, this.getNumUses name + 1) - | _ -> () - base.onIdentifier (sym, tArgs) - } - } +/// private helper class for ReferenceCounter +and private ReferenceCounterExpressionKinds(parent : ReferenceCounter) = + inherit Core.ExpressionKindTransformation(parent) + override this.OnIdentifier (sym, tArgs) = + match sym with + | LocalVariable name -> parent.UsedVariables <- parent.UsedVariables.Add (name, parent.NumberOfUses name + 1) + | _ -> () + base.OnIdentifier (sym, tArgs) -/// A ScopeTransformation that substitutes type parameters according to the given dictionary -type internal ReplaceTypeParams(typeParams: ImmutableDictionary<_, ResolvedType>) = - inherit ScopeTransformation() - override __.Expression = { new ExpressionTransformation() with - override __.Type = { new ExpressionTypeTransformation() with - override __.onTypeParameter tp = - let key = tp.Origin, tp.TypeName - match typeParams.TryGetValue key with - | true, t -> t.Resolution - | _ -> TypeKind.TypeParameter tp - } - } +/// private helper class for ReplaceTypeParams +type private ReplaceTypeParamsTypes(parent : Core.SyntaxTreeTransformation<_>) = + inherit Core.TypeTransformation, ResolvedType>>(parent) + override this.OnTypeParameter tp = + let key = tp.Origin, tp.TypeName + match this.SharedState.TryGetValue key with + | true, t -> t.Resolution + | _ -> TypeKind.TypeParameter tp -/// A ScopeTransformation that tracks what side effects the transformed code could cause -type internal SideEffectChecker() = - inherit ScopeTransformation() +/// A transformation that substitutes type parameters according to the given dictionary +/// Should be called at the specialization level, as it's meant to operate on a single implementation. +/// Does *not* update the type paremeter resolution dictionaries. +type internal ReplaceTypeParams private (typeParams: ImmutableDictionary<_, ResolvedType>, _private_) = + inherit Core.SyntaxTreeTransformation, ResolvedType>>(typeParams) - let mutable anyQuantum = false - let mutable anyMutation = false - let mutable anyInterrupts = false - let mutable anyOutput = false + internal new (typeParams: ImmutableDictionary<_, ResolvedType>) as this = + new ReplaceTypeParams(typeParams, "_private_") then + this.Types <- new ReplaceTypeParamsTypes(this) - /// Whether the transformed code might have any quantum side effects (such as calling operations) - member __.hasQuantum = anyQuantum - /// Whether the transformed code might change the value of any mutable variable - member __.hasMutation = anyMutation - /// Whether the transformed code has any statements that interrupt normal control flow (such as returns) - member __.hasInterrupts = anyInterrupts - /// Whether the transformed code might output any messages to the console - member __.hasOutput = anyOutput - override __.Expression = { new ExpressionTransformation() with - override expr.Kind = { new ExpressionKindTransformation() with - override __.ExpressionTransformation ex = expr.Transform ex - override __.TypeTransformation t = t +/// private helper class for SideEffectChecker +type private SideEffectCheckerExpressionKinds(parent : SideEffectChecker) = + inherit Core.ExpressionKindTransformation(parent) + + override this.OnFunctionCall (method, arg) = + parent.HasOutput <- true + base.OnFunctionCall (method, arg) + + override this.OnOperationCall (method, arg) = + parent.HasQuantum <- true + parent.HasOutput <- true + base.OnOperationCall (method, arg) - override __.onFunctionCall (method, arg) = - anyOutput <- true - base.onFunctionCall (method, arg) +/// private helper class for SideEffectChecker +and private SideEffectCheckerStatementKinds(parent : SideEffectChecker) = + inherit Core.StatementKindTransformation(parent) - override __.onOperationCall (method, arg) = - anyQuantum <- true - anyOutput <- true - base.onOperationCall (method, arg) - } - } + override this.OnValueUpdate stm = + let mutatesState = match stm.Lhs with LocalVarTuple x when isAllDiscarded x -> false | _ -> true + parent.HasMutation <- parent.HasMutation || mutatesState + base.OnValueUpdate stm - override this.StatementKind = { new StatementKindTransformation() with - override __.ScopeTransformation s = this.Transform s - override __.ExpressionTransformation ex = this.Expression.Transform ex - override __.TypeTransformation t = t - override __.LocationTransformation l = l + override this.OnReturnStatement stm = + parent.HasInterrupts <- true + base.OnReturnStatement stm - override __.onValueUpdate stm = - let mutatesState = match stm.Lhs with LocalVarTuple x when isAllDiscarded x -> false | _ -> true - anyMutation <- anyMutation || mutatesState - base.onValueUpdate stm + override this.OnFailStatement stm = + parent.HasInterrupts <- true + base.OnFailStatement stm - override __.onReturnStatement stm = - anyInterrupts <- true - base.onReturnStatement stm +/// A ScopeTransformation that tracks what side effects the transformed code could cause +and internal SideEffectChecker private (_private_) = + inherit Core.SyntaxTreeTransformation() + + /// Whether the transformed code might have any quantum side effects (such as calling operations) + member val HasQuantum = false with get, set + /// Whether the transformed code might change the value of any mutable variable + member val HasMutation = false with get, set + /// Whether the transformed code has any statements that interrupt normal control flow (such as returns) + member val HasInterrupts = false with get, set + /// Whether the transformed code might output any messages to the console + member val HasOutput = false with get, set - override __.onFailStatement stm = - anyInterrupts <- true - base.onFailStatement stm - } + internal new () as this = + new SideEffectChecker("_private_") then + this.ExpressionKinds <- new SideEffectCheckerExpressionKinds(this) + this.StatementKinds <- new SideEffectCheckerStatementKinds(this) + this.Types <- new Core.TypeTransformation(this, Core.TransformationOptions.Disabled) /// A ScopeTransformation that replaces one statement with zero or more statements -type [] internal StatementCollectorTransformation() = - inherit ScopeTransformation() +type [] internal StatementCollectorTransformation(parent : Core.SyntaxTreeTransformation) = + inherit Core.StatementTransformation(parent) - abstract member TransformStatement: QsStatementKind -> QsStatementKind seq + abstract member CollectStatements: QsStatementKind -> QsStatementKind seq - override this.Transform scope = + override this.OnScope scope = let parentSymbols = scope.KnownSymbols let statements = scope.Statements - |> Seq.map this.onStatement + |> Seq.map this.OnStatement |> Seq.map (fun x -> x.Statement) - |> Seq.collect this.TransformStatement + |> Seq.collect this.CollectStatements |> Seq.map wrapStmt QsScope.New (statements, parentSymbols) /// A SyntaxTreeTransformation that removes all known symbols from anywhere in the AST -type internal StripAllKnownSymbols() = - inherit SyntaxTreeTransformation() +type internal StripAllKnownSymbols(_private_) = + inherit Core.SyntaxTreeTransformation() + + internal new () as this = + new StripAllKnownSymbols("_private_") then + this.Statements <- new StripAllKnownSymbolsStatements(this) + this.Expressions <- new Core.ExpressionTransformation(this, Core.TransformationOptions.Disabled) + this.Types <- new Core.TypeTransformation(this, Core.TransformationOptions.Disabled) + +/// private helper class for StripAllKnownSymbols +and private StripAllKnownSymbolsStatements(parent : StripAllKnownSymbols) = + inherit Core.StatementTransformation(parent) + + override this.OnScope scope = + QsScope.New (scope.Statements |> Seq.map this.OnStatement, LocalDeclarations.Empty) - override __.Scope = { new ScopeTransformation() with - override this.Transform scope = - QsScope.New (scope.Statements |> Seq.map this.onStatement, LocalDeclarations.Empty) - } diff --git a/src/QsCompiler/Optimizations/Utils/VariableRenaming.fs b/src/QsCompiler/Optimizations/Utils/VariableRenaming.fs index ff5197413a..9cb58ee935 100644 --- a/src/QsCompiler/Optimizations/Utils/VariableRenaming.fs +++ b/src/QsCompiler/Optimizations/Utils/VariableRenaming.fs @@ -17,21 +17,9 @@ open Microsoft.Quantum.QsCompiler.Transformations.Core /// When called on a function body, will transform it such that all local variables defined /// in the function body have unique names, generating new variable names if needed. /// Autogenerated variable names have the form __qsVar[X]__[name]__. -type VariableRenaming() = +type VariableRenaming private (_private_) = inherit SyntaxTreeTransformation() - /// Returns a copy of the given variable stack inside of a new scope - let enterScope map = Map.empty :: map - - /// Returns a copy of the given variable stack outside of the current scope. - /// Throws an ArgumentException if the given variable stack is empty. - let exitScope = List.tail - - /// Returns the value associated to the given key in the given variable stack. - /// If the key is associated with multiple values, returns the one highest on the stack. - /// Returns None if the key isn't associated with any values. - let tryGet key = List.tryPick (Map.tryFind key) - /// Returns a copy of the given variable stack with the given key set to the given value. /// Throws an ArgumentException if the given variable stack is empty. let set (key, value) = function @@ -41,92 +29,111 @@ type VariableRenaming() = /// A regex that matches the original name of a mangled variable name let varNameRegex = Regex("^__qsVar\d+__(.+)__$") + /// Given a possibly-mangled variable name, returns the original variable name + let demangle varName = + let m = varNameRegex.Match varName + if m.Success then m.Groups.[1].Value else varName /// The number of times a variable is referenced - let mutable newNamesSet = Set.empty + member val internal NewNamesSet = Set.empty with get, set /// The current dictionary of new names to substitute for variables - let mutable renamingStack = [Map.empty] + member val internal RenamingStack = [Map.empty] with get, set /// Whether we should skip entering the next scope we encounter - let mutable skipScope = false + member val internal SkipScope = false with get, set + + /// Returns a copy of the given variable stack inside of a new scope + member internal this.EnterScope map = Map.empty :: map + + /// Returns a copy of the given variable stack outside of the current scope. + /// Throws an ArgumentException if the given variable stack is empty. + member internal this.ExitScope = List.tail - /// Given a possibly-mangled variable name, returns the original variable name - let demangle varName = - let m = varNameRegex.Match varName - if m.Success then m.Groups.[1].Value else varName /// Given a new variable name, generates a new unique name and updates the state accordingly - let generateUniqueName varName = + member this.GenerateUniqueName varName = let baseName = demangle varName let mutable num, newName = 0, baseName - while newNamesSet.Contains newName do + while this.NewNamesSet.Contains newName do num <- num + 1 newName <- sprintf "__qsVar%d__%s__" num baseName - newNamesSet <- newNamesSet.Add newName - renamingStack <- set (varName, newName) renamingStack + this.NewNamesSet <- this.NewNamesSet.Add newName + this.RenamingStack <- set (varName, newName) this.RenamingStack newName + member this.Clear() = + this.NewNamesSet <- Set.empty + this.RenamingStack <- [Map.empty] + + new () as this = + new VariableRenaming("_private_") then + this.Namespaces <- new VariableRenamingNamespaces(this) + this.Statements <- new VariableRenamingStatements(this) + this.StatementKinds <- new VariableRenamingStatementKinds(this) + this.ExpressionKinds <- new VariableRenamingExpressionKinds(this) + this.Types <- new TypeTransformation(this, TransformationOptions.Disabled) + +/// private helper class for VariableRenaming +and private VariableRenamingNamespaces (parent : VariableRenaming) = + inherit NamespaceTransformation(parent) + /// Processes the initial argument tuple from the function declaration let rec processArgTuple = function - | QsTupleItem {VariableName = ValidName name} -> generateUniqueName name.Value |> ignore + | QsTupleItem {VariableName = ValidName name} -> parent.GenerateUniqueName name.Value |> ignore | QsTupleItem {VariableName = InvalidName} -> () | QsTuple items -> Seq.iter processArgTuple items - member __.clearStack() = - renamingStack <- [Map.empty] + override __.OnProvidedImplementation (argTuple, body) = + parent.Clear() + do processArgTuple argTuple + base.OnProvidedImplementation (argTuple, body) + +/// private helper class for VariableRenaming +and private VariableRenamingStatements (parent : VariableRenaming) = + inherit StatementTransformation(parent) + + override this.OnScope x = + if parent.SkipScope then + parent.SkipScope <- false + base.OnScope x + else + parent.RenamingStack <- parent.EnterScope parent.RenamingStack + let result = base.OnScope x + parent.RenamingStack <- parent.ExitScope parent.RenamingStack + result + +/// private helper class for VariableRenaming +and private VariableRenamingStatementKinds (parent : VariableRenaming) = + inherit StatementKindTransformation(parent) + + override this.OnSymbolTuple syms = + match syms with + | VariableName item -> VariableName (NonNullable<_>.New (parent.GenerateUniqueName item.Value)) + | VariableNameTuple items -> Seq.map this.OnSymbolTuple items |> ImmutableArray.CreateRange |> VariableNameTuple + | InvalidItem | DiscardedItem -> syms + + override this.OnRepeatStatement stm = + parent.RenamingStack <- parent.EnterScope parent.RenamingStack + parent.SkipScope <- true + let result = base.OnRepeatStatement stm + parent.RenamingStack <- parent.ExitScope parent.RenamingStack + result + +/// private helper class for VariableRenaming +and private VariableRenamingExpressionKinds (parent : VariableRenaming) = + inherit ExpressionKindTransformation(parent) + + /// Returns the value associated to the given key in the given variable stack. + /// If the key is associated with multiple values, returns the one highest on the stack. + /// Returns None if the key isn't associated with any values. + let tryGet key = List.tryPick (Map.tryFind key) + override this.OnIdentifier (sym, tArgs) = + maybe { + let! name = + match sym with + | LocalVariable name -> Some name.Value + | _ -> None + let! newName = tryGet name parent.RenamingStack + return Identifier (LocalVariable (NonNullable<_>.New newName), tArgs) + } |? Identifier (sym, tArgs) - override __.onProvidedImplementation (argTuple, body) = - newNamesSet <- Set.empty - renamingStack <- [Map.empty] - do processArgTuple argTuple - base.onProvidedImplementation (argTuple, body) - - override __.Scope = { new ScopeTransformation() with - - override __.Transform x = - if skipScope then - skipScope <- false - base.Transform x - else - renamingStack <- enterScope renamingStack - let result = base.Transform x - renamingStack <- exitScope renamingStack - result - - override __.Expression = { new ExpressionTransformation() with - override expr.Kind = { new ExpressionKindTransformation() with - override __.ExpressionTransformation ex = expr.Transform ex - override __.TypeTransformation t = t - - override __.onIdentifier (sym, tArgs) = - maybe { - let! name = - match sym with - | LocalVariable name -> Some name.Value - | _ -> None - let! newName = tryGet name renamingStack - return Identifier (LocalVariable (NonNullable<_>.New newName), tArgs) - } |? Identifier (sym, tArgs) - } - } - - override this.StatementKind = { new StatementKindTransformation() with - override __.ExpressionTransformation x = this.Expression.Transform x - override __.LocationTransformation x = x - override __.ScopeTransformation x = this.Transform x - override __.TypeTransformation x = x - - override this.onSymbolTuple syms = - match syms with - | VariableName item -> VariableName (NonNullable<_>.New (generateUniqueName item.Value)) - | VariableNameTuple items -> Seq.map this.onSymbolTuple items |> ImmutableArray.CreateRange |> VariableNameTuple - | InvalidItem | DiscardedItem -> syms - - override __.onRepeatStatement stm = - renamingStack <- enterScope renamingStack - skipScope <- true - let result = base.onRepeatStatement stm - renamingStack <- exitScope renamingStack - result - } - } diff --git a/src/QsCompiler/SyntaxProcessor/DeclarationVerification.fs b/src/QsCompiler/SyntaxProcessor/DeclarationVerification.fs index 3abaa9e83a..05045779bc 100644 --- a/src/QsCompiler/SyntaxProcessor/DeclarationVerification.fs +++ b/src/QsCompiler/SyntaxProcessor/DeclarationVerification.fs @@ -164,7 +164,7 @@ let rec private singleAdditionalArg mismatchErr (qsSym : QsSymbol) = | sym -> sym |> singleAndOmitted |> nameAndRange false -let private StripRangeInfo = StripPositionInfo.Default.onArgumentTuple +let private StripRangeInfo = StripPositionInfo.Default.Namespaces.OnArgumentTuple /// Given the declared argument tuple of a callable, and the declared symbol tuple for the corresponding body specialization, /// verifies that the symbol tuple indeed has the expected shape for that specialization. diff --git a/src/QsCompiler/SyntaxProcessor/ExpressionVerification.fs b/src/QsCompiler/SyntaxProcessor/ExpressionVerification.fs index 25ab87dea1..62410c1e63 100644 --- a/src/QsCompiler/SyntaxProcessor/ExpressionVerification.fs +++ b/src/QsCompiler/SyntaxProcessor/ExpressionVerification.fs @@ -24,12 +24,12 @@ open Microsoft.Quantum.QsCompiler.Transformations.QsCodeOutput // utils for verifying types in expressions type private StripInferredInfoFromType () = - inherit ExpressionTypeTransformation(true) - default this.onCallableInformation opInfo = - let characteristics = this.onCharacteristicsExpression opInfo.Characteristics + inherit TypeTransformationBase() + default this.OnCallableInformation opInfo = + let characteristics = this.OnCharacteristicsExpression opInfo.Characteristics CallableInformation.New (characteristics, InferredCallableInformation.NoInformation) - override this.onRangeInformation _ = QsRangeInfo.Null -let private StripInferredInfoFromType = (new StripInferredInfoFromType()).Transform + override this.OnRangeInformation _ = QsRangeInfo.Null +let private StripInferredInfoFromType = (new StripInferredInfoFromType()).OnType /// used for type matching arguments in call-like expressions type private Variance = @@ -50,7 +50,7 @@ let private missingFunctors (target : ImmutableHashSet<_>, given) = /// Return the string representation for a ResolveType. /// User defined types are represented by their full name. -let internal toString = (new ExpressionTypeToQs(new ExpressionToQs())).Apply +let internal toString (t : ResolvedType) = SyntaxTreeToQsharp.Default.ToCode t /// Given two resolve types, determines and returns a common base type if such a type exists, /// or pushes adds a suitable error using addError and returns invalid type if a common base type does not exist. diff --git a/src/QsCompiler/SyntaxProcessor/StatementVerification.fs b/src/QsCompiler/SyntaxProcessor/StatementVerification.fs index b9ad2b37e9..98b7ef89cb 100644 --- a/src/QsCompiler/SyntaxProcessor/StatementVerification.fs +++ b/src/QsCompiler/SyntaxProcessor/StatementVerification.fs @@ -267,12 +267,12 @@ let NewConjugation (outer : QsPositionedBlock, inner : QsPositionedBlock) = | Value loc -> loc let usedInOuter = let accumulate = new AccumulateIdentifiers() - accumulate.Transform outer.Body |> ignore - accumulate.UsedLocalVariables + accumulate.Statements.OnScope outer.Body |> ignore + accumulate.SharedState.UsedLocalVariables let updatedInInner = let accumulate = new AccumulateIdentifiers() - accumulate.Transform inner.Body |> ignore - accumulate.ReassignedVariables + accumulate.Statements.OnScope inner.Body |> ignore + accumulate.SharedState.ReassignedVariables let updateErrs = updatedInInner |> Seq.filter (fun updated -> usedInOuter.Contains updated.Key) |> Seq.collect id |> Seq.map (fun loc -> (loc.Offset, loc.Range |> QsCompilerDiagnostic.Error (ErrorCode.InvalidReassignmentInApplyBlock, []))) |> Seq.toArray diff --git a/src/QsCompiler/SyntaxProcessor/SyntaxExtensions.fs b/src/QsCompiler/SyntaxProcessor/SyntaxExtensions.fs index 0bea65010e..1ae3433ed8 100644 --- a/src/QsCompiler/SyntaxProcessor/SyntaxExtensions.fs +++ b/src/QsCompiler/SyntaxProcessor/SyntaxExtensions.fs @@ -260,21 +260,24 @@ let private namespaceDocumentation (docs : ILookup, Immutabl PrintSummary allDoc markdown type private TName () = - inherit ExpressionTypeToQs(new ExpressionToQs()) - override this.onCharacteristicsExpression characteristics = + inherit SyntaxTreeToQsharp.TypeTransformation() + override this.OnCharacteristicsExpression characteristics = if characteristics.AreInvalid then this.Output <- "?"; characteristics - else base.onCharacteristicsExpression characteristics - override this.onInvalidType() = + else base.OnCharacteristicsExpression characteristics + override this.OnInvalidType() = this.Output <- "?" InvalidType - override this.onUserDefinedType udt = + override this.OnUserDefinedType udt = this.Output <- udt.Name.Value UserDefinedType udt + member this.Apply t = + this.OnType t |> ignore + this.Output let private TypeString = new TName() let private TypeName = TypeString.Apply let private CharacteristicsAnnotation (ex, format) = - ex |> TypeString.onCharacteristicsExpression |> ignore - if String.IsNullOrWhiteSpace TypeString.Output then "" else sprintf "is %s" TypeString.Output |> format + let charEx = SyntaxTreeToQsharp.CharacteristicsExpression ex + if String.IsNullOrWhiteSpace charEx then "" else sprintf "is %s" charEx |> format [] let public TypeInfo (symbolTable : NamespaceManager) (currentNS, source) (qsType : QsType) markdown = @@ -335,14 +338,14 @@ let private printCallableKind capitalize = function [] let public PrintArgumentTuple item = - SyntaxTreeToQs.ArgumentTuple (item, new Func<_,_>(TypeName)) // note: needs to match the corresponding part of the output constructed by PrintSignature below! + SyntaxTreeToQsharp.ArgumentTuple (item, new Func<_,_>(TypeName)) // note: needs to match the corresponding part of the output constructed by PrintSignature below! [] let public PrintSignature (header : CallableDeclarationHeader) = let callable = QsCallable.New header.Kind (header.SourceFile, Null) (header.QualifiedName, header.Attributes, header.ArgumentTuple, header.Signature, ImmutableArray.Empty, ImmutableArray.Empty, QsComments.Empty); - let signature = SyntaxTreeToQs.DeclarationSignature (callable, new Func<_,_>(TypeName)) + let signature = SyntaxTreeToQsharp.DeclarationSignature (callable, new Func<_,_>(TypeName)) let annotation = CharacteristicsAnnotation (header.Signature.Information.Characteristics, sprintf "%s%s" newLine) sprintf "%s%s" signature annotation @@ -390,8 +393,8 @@ let public DeclarationInfo symbolTable (locals : LocalDeclarations) (currentNS, match qsSym |> globalCallableResolution symbolTable (currentNS, source) with | Some decl, _ -> let functorSupport characteristics = - TypeString.onCharacteristicsExpression characteristics |> ignore - if String.IsNullOrWhiteSpace TypeString.Output then "(None)" else TypeString.Output + let charEx = SyntaxTreeToQsharp.CharacteristicsExpression characteristics + if String.IsNullOrWhiteSpace charEx then "(None)" else charEx let name = sprintf "%s %s" (printCallableKind false decl.Kind) decl.QualifiedName.Name.Value |> withNewLine let ns = sprintf "Namespace: %s" decl.QualifiedName.Namespace.Value |> withNewLine let input = sprintf "Input type: %s" (decl.Signature.ArgumentType |> TypeName) |> withNewLine diff --git a/src/QsCompiler/SyntaxProcessor/TreeVerification.fs b/src/QsCompiler/SyntaxProcessor/TreeVerification.fs index ce81a02b12..e12df1aaa8 100644 --- a/src/QsCompiler/SyntaxProcessor/TreeVerification.fs +++ b/src/QsCompiler/SyntaxProcessor/TreeVerification.fs @@ -73,7 +73,8 @@ let AllPathsReturnValueOrFail body = | QsStatementKind.QsExpressionStatement _ | QsStatementKind.QsFailStatement _ | QsStatementKind.QsValueUpdate _ - | QsStatementKind.QsVariableDeclaration _ -> () + | QsStatementKind.QsVariableDeclaration _ + | QsStatementKind.EmptyStatement -> () // returns true if all paths in the given scope contain a terminating (i.e. return or fail) statement let rec checkTermination (scope : QsScope) = @@ -98,7 +99,8 @@ let AllPathsReturnValueOrFail body = | QsStatementKind.QsExpressionStatement _ | QsStatementKind.QsFailStatement _ | QsStatementKind.QsValueUpdate _ - | QsStatementKind.QsVariableDeclaration _ -> true + | QsStatementKind.QsVariableDeclaration _ + | QsStatementKind.EmptyStatement -> true let returnOrFailAndAfter = Seq.toList <| scope.Statements.SkipWhile isNonTerminatingStatement if returnOrFailAndAfter.Length <> 0 then diff --git a/src/QsCompiler/TestTargets/Libraries/Library1/Library1.csproj b/src/QsCompiler/TestTargets/Libraries/Library1/Library1.csproj index 40476e0288..4b95b74dce 100644 --- a/src/QsCompiler/TestTargets/Libraries/Library1/Library1.csproj +++ b/src/QsCompiler/TestTargets/Libraries/Library1/Library1.csproj @@ -7,7 +7,7 @@ - + diff --git a/src/QsCompiler/TestTargets/Simulation/Example/Example.csproj b/src/QsCompiler/TestTargets/Simulation/Example/Example.csproj index b7d89acc9c..79ff1984b2 100644 --- a/src/QsCompiler/TestTargets/Simulation/Example/Example.csproj +++ b/src/QsCompiler/TestTargets/Simulation/Example/Example.csproj @@ -1,4 +1,4 @@ - + Detailed @@ -6,6 +6,7 @@ netcoreapp3.0 false false + 0219 diff --git a/src/QsCompiler/TestTargets/Simulation/Target/Simulation.csproj b/src/QsCompiler/TestTargets/Simulation/Target/Simulation.csproj index daee5a684b..65781cd08e 100644 --- a/src/QsCompiler/TestTargets/Simulation/Target/Simulation.csproj +++ b/src/QsCompiler/TestTargets/Simulation/Target/Simulation.csproj @@ -12,7 +12,7 @@ - + diff --git a/src/QsCompiler/Tests.Compiler/ClassicalControlTests.fs b/src/QsCompiler/Tests.Compiler/ClassicalControlTests.fs index a8ebac129f..89be97a4bc 100644 --- a/src/QsCompiler/Tests.Compiler/ClassicalControlTests.fs +++ b/src/QsCompiler/Tests.Compiler/ClassicalControlTests.fs @@ -9,7 +9,7 @@ open Microsoft.Quantum.QsCompiler open Microsoft.Quantum.QsCompiler.CompilationBuilder open Microsoft.Quantum.QsCompiler.DataTypes open Microsoft.Quantum.QsCompiler.SyntaxTree -open Microsoft.Quantum.QsCompiler.Transformations.ClassicallyControlledTransformation +open Microsoft.Quantum.QsCompiler.Transformations.ClassicallyControlled open Xunit open Microsoft.Quantum.QsCompiler.Transformations.QsCodeOutput open System.Text.RegularExpressions @@ -50,7 +50,7 @@ type ClassicalControlTests () = srcChunks.Length >= testNumber + 1 |> Assert.True let shared = srcChunks.[0] let compilationDataStructures = BuildContent <| shared + srcChunks.[testNumber] - let processedCompilation = ClassicallyControlledTransformation.Apply compilationDataStructures.BuiltCompilation + let processedCompilation = ReplaceClassicalControl.Apply compilationDataStructures.BuiltCompilation Assert.NotNull processedCompilation Signatures.SignatureCheck [Signatures.ClassicalControlNs] Signatures.ClassicalControlSignatures.[testNumber-1] processedCompilation processedCompilation @@ -61,16 +61,17 @@ type ClassicalControlTests () = let GetCtlAdjFromCallable call = call.Specializations |> Seq.find (fun x -> x.Kind = QsSpecializationKind.QsControlledAdjoint) let GetLinesFromSpecialization specialization = - let writer = new SyntaxTreeToQs() + let writer = new SyntaxTreeToQsharp() specialization |> fun x -> match x.Implementation with | Provided (_, body) -> Some body | _ -> None |> Option.get - |> writer.Scope.Transform + |> writer.Statements.OnScope |> ignore - (writer.Scope :?> ScopeToQs).Output.Split(Environment.NewLine) - |> Array.filter (fun str -> str <> String.Empty) + writer.SharedState.StatementOutputHandle + |> Seq.filter (not << String.IsNullOrWhiteSpace) + |> Seq.toArray let CheckIfLineIsCall ``namespace`` name input = let call = sprintf @"(%s\.)?%s" <| Regex.Escape ``namespace`` <| Regex.Escape name diff --git a/src/QsCompiler/Tests.Compiler/LinkingTests.fs b/src/QsCompiler/Tests.Compiler/LinkingTests.fs index 7956a74edf..e97dd4f3ce 100644 --- a/src/QsCompiler/Tests.Compiler/LinkingTests.fs +++ b/src/QsCompiler/Tests.Compiler/LinkingTests.fs @@ -12,9 +12,9 @@ open Microsoft.Quantum.QsCompiler.DataTypes open Microsoft.Quantum.QsCompiler.Diagnostics open Microsoft.Quantum.QsCompiler.SyntaxExtensions open Microsoft.Quantum.QsCompiler.SyntaxTree -open Microsoft.Quantum.QsCompiler.Transformations.IntrinsicResolutionTransformation +open Microsoft.Quantum.QsCompiler.Transformations.IntrinsicResolution open Microsoft.Quantum.QsCompiler.Transformations.Monomorphization -open Microsoft.Quantum.QsCompiler.Transformations.MonomorphizationValidation +open Microsoft.Quantum.QsCompiler.Transformations.Monomorphization.Validation open Xunit open Xunit.Abstractions @@ -71,10 +71,10 @@ type LinkingTests (output:ITestOutputHelper) = let compilationDataStructures = this.BuildContent input - let monomorphicCompilation = MonomorphizationTransformation.Apply compilationDataStructures.BuiltCompilation + let monomorphicCompilation = Monomorphize.Apply compilationDataStructures.BuiltCompilation Assert.NotNull monomorphicCompilation - MonomorphizationValidationTransformation.Apply monomorphicCompilation + ValidateMonomorphization.Apply monomorphicCompilation monomorphicCompilation @@ -83,7 +83,7 @@ type LinkingTests (output:ITestOutputHelper) = let envDS = this.BuildContent environment let sourceDS = this.BuildContent source - IntrinsicResolutionTransformation.Apply(envDS.BuiltCompilation, sourceDS.BuiltCompilation) + ReplaceWithTargetIntrinsics.Apply(envDS.BuiltCompilation, sourceDS.BuiltCompilation) member private this.RunIntrinsicResolutionTest testNumber = diff --git a/src/QsCompiler/Tests.Compiler/OptimizationTests.fs b/src/QsCompiler/Tests.Compiler/OptimizationTests.fs index fc2112e44a..f1b296247e 100644 --- a/src/QsCompiler/Tests.Compiler/OptimizationTests.fs +++ b/src/QsCompiler/Tests.Compiler/OptimizationTests.fs @@ -26,9 +26,7 @@ let private buildCompilation code = let private optimize code = let mutable compilation = buildCompilation code compilation <- PreEvaluation.All compilation - let toQs = SyntaxTreeToQs() - compilation.Namespaces |> Seq.iter (toQs.Transform >> ignore) - toQs.Output + String.Join(Environment.NewLine, compilation.Namespaces |> Seq.map SyntaxTreeToQsharp.Default.ToCode) /// Helper function that saves the compiler output as a test case (in the bin directory) let private createTestCase path = diff --git a/src/QsCompiler/Tests.Compiler/RegexTests.fs b/src/QsCompiler/Tests.Compiler/RegexTests.fs index e24318b359..afbbd53c62 100644 --- a/src/QsCompiler/Tests.Compiler/RegexTests.fs +++ b/src/QsCompiler/Tests.Compiler/RegexTests.fs @@ -83,23 +83,23 @@ let ``Strip unique variable name resolution`` () = |> List.map NonNullable.New origNames - |> List.map (fun var -> var, NameResolution.StripUniqueName var) + |> List.map (fun var -> var, UniqueVariableNames.StripUniqueName var) |> List.iter Assert.Equal origNames - |> List.map NameResolution.GenerateUniqueName - |> List.map (fun unique -> unique, NameResolution.GenerateUniqueName unique) - |> List.map (fun (unique, twiceWrapped) -> unique, NameResolution.StripUniqueName twiceWrapped) + |> List.map NameResolution.SharedState.GenerateUniqueName + |> List.map (fun unique -> unique, NameResolution.SharedState.GenerateUniqueName unique) + |> List.map (fun (unique, twiceWrapped) -> unique, UniqueVariableNames.StripUniqueName twiceWrapped) |> List.iter Assert.Equal origNames - |> List.map (fun var -> var, NameResolution.GenerateUniqueName var) - |> List.map (fun (var, unique) -> var, NameResolution.GenerateUniqueName unique) - |> List.map (fun (var, twiceWrapped) -> var, NameResolution.StripUniqueName twiceWrapped) - |> List.map (fun (var, unique) -> var, NameResolution.StripUniqueName unique) + |> List.map (fun var -> var, NameResolution.SharedState.GenerateUniqueName var) + |> List.map (fun (var, unique) -> var, NameResolution.SharedState.GenerateUniqueName unique) + |> List.map (fun (var, twiceWrapped) -> var, UniqueVariableNames.StripUniqueName twiceWrapped) + |> List.map (fun (var, unique) -> var, UniqueVariableNames.StripUniqueName unique) |> List.iter Assert.Equal origNames - |> List.map (fun var -> var, NameResolution.GenerateUniqueName var) - |> List.map (fun (var, unique) -> var, NameResolution.StripUniqueName unique) + |> List.map (fun var -> var, NameResolution.SharedState.GenerateUniqueName var) + |> List.map (fun (var, unique) -> var, UniqueVariableNames.StripUniqueName unique) |> List.iter Assert.Equal diff --git a/src/QsCompiler/Tests.Compiler/TestCases/ExecutionTests/LoggingBasedTests.qs b/src/QsCompiler/Tests.Compiler/TestCases/ExecutionTests/LoggingBasedTests.qs index 1fd70b314e..1517ebf633 100644 --- a/src/QsCompiler/Tests.Compiler/TestCases/ExecutionTests/LoggingBasedTests.qs +++ b/src/QsCompiler/Tests.Compiler/TestCases/ExecutionTests/LoggingBasedTests.qs @@ -33,10 +33,12 @@ namespace Microsoft.Quantum.Testing.ExecutionTests { ULog("V1"); within { + let dummy = 0; ULog("U3"); ULog("V3"); } apply { + let dummy = 0; ULog("Core3"); } } diff --git a/src/QsCompiler/Tests.Compiler/TestCases/OptimizerTests/Miscellaneous_output.txt b/src/QsCompiler/Tests.Compiler/TestCases/OptimizerTests/Miscellaneous_output.txt index e9158b6892..262bdb264f 100644 --- a/src/QsCompiler/Tests.Compiler/TestCases/OptimizerTests/Miscellaneous_output.txt +++ b/src/QsCompiler/Tests.Compiler/TestCases/OptimizerTests/Miscellaneous_output.txt @@ -30,7 +30,6 @@ namespace Microsoft.Quantum.Testing { operation op1 (q1 : Qubit, r1 : BigEndianRegister) : Unit is Ctl + Adj { - adjoint self; body (...) { Adjoint CNOT(q1, (r1!)[0]); @@ -42,6 +41,8 @@ namespace Microsoft.Quantum.Testing { } } + adjoint self; + ///automatically generated QsControlled specialization for Microsoft.Quantum.Testing.op1 controlled (__controlQubits__, ...) { Controlled (Adjoint CNOT)(__controlQubits__, (q1, (r1!)[0])); diff --git a/src/QsCompiler/Tests.Compiler/TestUtils/Signatures.fs b/src/QsCompiler/Tests.Compiler/TestUtils/Signatures.fs index 7a183d5d5f..e7fd7462ac 100644 --- a/src/QsCompiler/Tests.Compiler/TestUtils/Signatures.fs +++ b/src/QsCompiler/Tests.Compiler/TestUtils/Signatures.fs @@ -86,7 +86,7 @@ let public SignatureCheck checkedNamespaces targetSignatures compilation = let makeArgsString (args : ResolvedType) = match args.Resolution with | QsTypeKind.UnitType -> "()" - | _ -> args |> (ExpressionToQs () |> ExpressionTypeToQs).Apply + | _ -> args |> SyntaxTreeToQsharp.Default.ToCode let removeAt i lst = Seq.append diff --git a/src/QsCompiler/Tests.Compiler/TransformationTests.fs b/src/QsCompiler/Tests.Compiler/TransformationTests.fs index ae304aa929..2a376638b5 100644 --- a/src/QsCompiler/Tests.Compiler/TransformationTests.fs +++ b/src/QsCompiler/Tests.Compiler/TransformationTests.fs @@ -11,7 +11,7 @@ open Microsoft.Quantum.QsCompiler open Microsoft.Quantum.QsCompiler.CompilationBuilder open Microsoft.Quantum.QsCompiler.DataTypes open Microsoft.Quantum.QsCompiler.SyntaxTree -open Microsoft.Quantum.QsCompiler.Transformations +open Microsoft.Quantum.QsCompiler.Transformations.Core open Microsoft.Quantum.QsCompiler.Transformations.QsCodeOutput open Xunit @@ -26,45 +26,49 @@ type private Counter () = member val forCount = 0 with get, set member val ifsCount = 0 with get, set -type private StatementKindCounter(stm, counter : Counter) = - inherit StatementKindWalker(stm) - override this.onConditionalStatement (node:QsConditionalStatement) = - counter.ifsCount <- counter.ifsCount + 1 - base.onConditionalStatement node - - override this.onForStatement (node:QsForStatement) = - counter.forCount <- counter.forCount + 1 - base.onForStatement node - -and private StatementCounter(counter) = - inherit ScopeWalker - (Func<_,_>(fun s -> new StatementKindCounter(s :?> StatementCounter, counter)), new ExpressionCounter(counter)) - -and private ExpressionKindCounter(ex, counter : Counter) = - inherit ExpressionKindWalker(ex) - - override this.beforeCallLike (op,args) = - counter.callsCount <- counter.callsCount + 1 - base.beforeCallLike (op, args) +type private SyntaxCounter private(counter : Counter, ?options) = + inherit SyntaxTreeTransformation(defaultArg options TransformationOptions.Default) -and private ExpressionCounter(counter) = - inherit ExpressionWalker - (new Func<_,_>(fun e -> new ExpressionKindCounter(e :?> ExpressionCounter, counter))) + member this.Counter = counter + new (?options) as this = + new SyntaxCounter(new Counter()) then + this.Namespaces <- new SyntaxCounterNamespaces(this) + this.StatementKinds <- new SyntaxCounterStatementKinds(this) + this.ExpressionKinds <- new SyntaxCounterExpressionKinds(this) -type private SyntaxCounter(counter) = - inherit SyntaxTreeWalker(new StatementCounter(counter)) +and private SyntaxCounterNamespaces(parent : SyntaxCounter) = + inherit NamespaceTransformation(parent) - override this.beforeCallable (node:QsCallable) = + override this.OnCallableDeclaration (node:QsCallable) = match node.Kind with - | Operation -> counter.opsCount <- counter.opsCount + 1 - | Function -> counter.funCount <- counter.funCount + 1 + | Operation -> parent.Counter.opsCount <- parent.Counter.opsCount + 1 + | Function -> parent.Counter.funCount <- parent.Counter.funCount + 1 | TypeConstructor -> () + base.OnCallableDeclaration node + + override this.OnTypeDeclaration (udt:QsCustomType) = + parent.Counter.udtCount <- parent.Counter.udtCount + 1 + base.OnTypeDeclaration udt + +and private SyntaxCounterStatementKinds(parent : SyntaxCounter) = + inherit StatementKindTransformation(parent) + + override this.OnConditionalStatement (node:QsConditionalStatement) = + parent.Counter.ifsCount <- parent.Counter.ifsCount + 1 + base.OnConditionalStatement node - override this.onType (udt:QsCustomType) = - counter.udtCount <- counter.udtCount + 1 - base.onType udt + override this.OnForStatement (node:QsForStatement) = + parent.Counter.forCount <- parent.Counter.forCount + 1 + base.OnForStatement node + +and private SyntaxCounterExpressionKinds(parent : SyntaxCounter) = + inherit ExpressionKindTransformation(parent) + + override this.OnCallLikeExpression (op,args) = + parent.Counter.callsCount <- parent.Counter.callsCount + 1 + base.OnCallLikeExpression (op, args) let private buildSyntaxTree code = @@ -82,16 +86,28 @@ let private buildSyntaxTree code = [] let ``basic walk`` () = let tree = Path.Combine(Path.GetFullPath ".", "TestCases", "Transformation.qs") |> File.ReadAllText |> buildSyntaxTree - let counter = new Counter() - tree |> Seq.iter (SyntaxCounter(counter)).Walk + let walker = new SyntaxCounter(TransformationOptions.NoRebuild) + tree |> Seq.iter (walker.Namespaces.OnNamespace >> ignore) - Assert.Equal (4, counter.udtCount) - Assert.Equal (1, counter.funCount) - Assert.Equal (5, counter.opsCount) - Assert.Equal (7, counter.forCount) - Assert.Equal (6, counter.ifsCount) - Assert.Equal (20, counter.callsCount) + Assert.Equal (4, walker.Counter.udtCount) + Assert.Equal (1, walker.Counter.funCount) + Assert.Equal (5, walker.Counter.opsCount) + Assert.Equal (7, walker.Counter.forCount) + Assert.Equal (6, walker.Counter.ifsCount) + Assert.Equal (20, walker.Counter.callsCount) +[] +let ``basic transformation`` () = + let tree = Path.Combine(Path.GetFullPath ".", "TestCases", "Transformation.qs") |> File.ReadAllText |> buildSyntaxTree + let walker = new SyntaxCounter() + tree |> Seq.iter (walker.Namespaces.OnNamespace >> ignore) + + Assert.Equal (4, walker.Counter.udtCount) + Assert.Equal (1, walker.Counter.funCount) + Assert.Equal (5, walker.Counter.opsCount) + Assert.Equal (7, walker.Counter.forCount) + Assert.Equal (6, walker.Counter.ifsCount) + Assert.Equal (20, walker.Counter.callsCount) [] let ``generation of open statements`` () = @@ -117,7 +133,7 @@ let ``generation of open statements`` () = let imports = ImmutableDictionary.Empty.Add(ns.Name, openDirectives) let codeOutput = ref null - SyntaxTreeToQs.Apply (codeOutput, tree, struct (source, imports)) |> Assert.True + SyntaxTreeToQsharp.Apply (codeOutput, tree, struct (source, imports)) |> Assert.True let lines = Utils.SplitLines (codeOutput.Value.Single().[ns.Name]) Assert.Equal(13, lines.Count()) diff --git a/src/QsCompiler/Transformations/BasicTransformations.cs b/src/QsCompiler/Transformations/BasicTransformations.cs index 60add02b96..e8195ff9ce 100644 --- a/src/QsCompiler/Transformations/BasicTransformations.cs +++ b/src/QsCompiler/Transformations/BasicTransformations.cs @@ -7,258 +7,309 @@ using System.Linq; using Microsoft.Quantum.QsCompiler.DataTypes; using Microsoft.Quantum.QsCompiler.SyntaxTree; +using Microsoft.Quantum.QsCompiler.Transformations.Core; namespace Microsoft.Quantum.QsCompiler.Transformations.BasicTransformations { - // syntax tree transformations - - public class GetSourceFiles : - SyntaxTreeTransformation + public class GetSourceFiles + : SyntaxTreeTransformation { + public class TransformationState + { + internal readonly HashSet> SourceFiles = + new HashSet>(); + } + + + private GetSourceFiles() + : base(new TransformationState(), TransformationOptions.NoRebuild) + { + this.Namespaces = new NamespaceTransformation(this); + this.Statements = new StatementTransformation(this, TransformationOptions.Disabled); + this.Expressions = new ExpressionTransformation(this, TransformationOptions.Disabled); + this.Types = new TypeTransformation(this, TransformationOptions.Disabled); + } + + // static methods for convenience + /// /// Returns a hash set containing all source files in the given namespaces. - /// Throws an ArgumentNullException if the given sequence or any of the given namespaces is null. + /// Throws an ArgumentNullException if the given sequence or any of the given namespaces is null. /// public static ImmutableHashSet> Apply(IEnumerable namespaces) { if (namespaces == null || namespaces.Contains(null)) throw new ArgumentNullException(nameof(namespaces)); var filter = new GetSourceFiles(); - foreach(var ns in namespaces) filter.Transform(ns); - return filter.SourceFiles.ToImmutableHashSet(); + foreach (var ns in namespaces) filter.Namespaces.OnNamespace(ns); + return filter.SharedState.SourceFiles.ToImmutableHashSet(); } /// /// Returns a hash set containing all source files in the given namespace(s). - /// Throws an ArgumentNullException if any of the given namespaces is null. + /// Throws an ArgumentNullException if any of the given namespaces is null. /// - public static ImmutableHashSet> Apply(params QsNamespace[] namespaces) => + public static ImmutableHashSet> Apply(params QsNamespace[] namespaces) => Apply((IEnumerable)namespaces); - private readonly HashSet> SourceFiles; - private GetSourceFiles() : - base(new NoScopeTransformations()) => - this.SourceFiles = new HashSet>(); - public override QsSpecialization onSpecializationImplementation(QsSpecialization spec) // short cut to avoid further evaluation - { - this.onSourceFile(spec.SourceFile); - return spec; - } + // helper classes - public override NonNullable onSourceFile(NonNullable f) + private class NamespaceTransformation + : NamespaceTransformation { - this.SourceFiles.Add(f); - return base.onSourceFile(f); + + public NamespaceTransformation(SyntaxTreeTransformation parent) + : base(parent, TransformationOptions.NoRebuild) { } + + public override QsSpecialization OnSpecializationDeclaration(QsSpecialization spec) // short cut to avoid further evaluation + { + this.OnSourceFile(spec.SourceFile); + return spec; + } + + public override NonNullable OnSourceFile(NonNullable f) + { + this.SharedState.SourceFiles.Add(f); + return base.OnSourceFile(f); + } } } + /// /// Calling Transform on a syntax tree returns a new tree that only contains the type and callable declarations - /// that are defined in the source file with the identifier given upon initialization. - /// The transformation also ensures that the elements in each namespace are ordered according to + /// that are defined in the source file with the identifier given upon initialization. + /// The transformation also ensures that the elements in each namespace are ordered according to /// the location at which they are defined in the file. Auto-generated declarations will be ordered alphabetically. /// - public class FilterBySourceFile : - SyntaxTreeTransformation + public class FilterBySourceFile + : SyntaxTreeTransformation { + public class TransformationState + { + internal readonly Func, bool> Predicate; + internal readonly List<(int?, QsNamespaceElement)> Elements = + new List<(int?, QsNamespaceElement)>(); + + public TransformationState(Func, bool> predicate) => + this.Predicate = predicate ?? throw new ArgumentNullException(nameof(predicate)); + } + + + public FilterBySourceFile(Func, bool> predicate) + : base(new TransformationState(predicate)) + { + this.Namespaces = new NamespaceTransformation(this); + this.Statements = new StatementTransformation(this, TransformationOptions.Disabled); + this.Expressions = new ExpressionTransformation(this, TransformationOptions.Disabled); + this.Types = new TypeTransformation(this, TransformationOptions.Disabled); + } + + // static methods for convenience + public static QsNamespace Apply(QsNamespace ns, Func, bool> predicate) { if (ns == null) throw new ArgumentNullException(nameof(ns)); var filter = new FilterBySourceFile(predicate); - return filter.Transform(ns); + return filter.Namespaces.OnNamespace(ns); } public static QsNamespace Apply(QsNamespace ns, params NonNullable[] fileIds) { var sourcesToKeep = fileIds.Select(f => f.Value).ToImmutableHashSet(); - return FilterBySourceFile.Apply(ns, s => sourcesToKeep.Contains(s.Value)); + return Apply(ns, s => sourcesToKeep.Contains(s.Value)); } - private readonly List<(int?, QsNamespaceElement)> Elements; - private readonly Func, bool> Predicate; - public FilterBySourceFile(Func, bool> predicate) : - base(new NoScopeTransformations()) - { - this.Predicate = predicate ?? throw new ArgumentNullException(nameof(predicate)); - this.Elements = new List<(int?, QsNamespaceElement)>(); - } + // helper classes - private QsCallable AddCallableIfInSource(QsCallable c) + public class NamespaceTransformation + : NamespaceTransformation { - if (Predicate(c.SourceFile)) - { Elements.Add((c.Location.IsValue ? c.Location.Item.Offset.Item1 : (int?)null, QsNamespaceElement.NewQsCallable(c))); } - return c; - } + public NamespaceTransformation(SyntaxTreeTransformation parent) + : base(parent) { } - private QsCustomType AddTypeIfInSource(QsCustomType t) - { - if (Predicate(t.SourceFile)) - { Elements.Add((t.Location.IsValue ? t.Location.Item.Offset.Item1 : (int?)null, QsNamespaceElement.NewQsCustomType(t))); } - return t; - } + // TODO: these overrides needs to be adapted once we support external specializations - // TODO: these transformations needs to be adapted once we support external specializations - public override QsCustomType onType(QsCustomType t) => AddTypeIfInSource(t); - public override QsCallable onCallableImplementation(QsCallable c) => AddCallableIfInSource(c); + public override QsCustomType OnTypeDeclaration(QsCustomType t) + { + if (this.SharedState.Predicate(t.SourceFile)) + { this.SharedState.Elements.Add((t.Location.IsValue ? t.Location.Item.Offset.Item1 : (int?)null, QsNamespaceElement.NewQsCustomType(t))); } + return t; + } - public override QsNamespace Transform(QsNamespace ns) - { - static int SortComparison((int?, QsNamespaceElement) x, (int?, QsNamespaceElement) y) + public override QsCallable OnCallableDeclaration(QsCallable c) { - if (x.Item1.HasValue && y.Item1.HasValue) return Comparer.Default.Compare(x.Item1.Value, y.Item1.Value); - if (!x.Item1.HasValue && !y.Item1.HasValue) return Comparer.Default.Compare(x.Item2.GetFullName().ToString(), y.Item2.GetFullName().ToString()); - return x.Item1.HasValue ? -1 : 1; + if (this.SharedState.Predicate(c.SourceFile)) + { this.SharedState.Elements.Add((c.Location.IsValue ? c.Location.Item.Offset.Item1 : (int?)null, QsNamespaceElement.NewQsCallable(c))); } + return c; + } + + public override QsNamespace OnNamespace(QsNamespace ns) + { + static int SortComparison((int?, QsNamespaceElement) x, (int?, QsNamespaceElement) y) + { + if (x.Item1.HasValue && y.Item1.HasValue) return Comparer.Default.Compare(x.Item1.Value, y.Item1.Value); + if (!x.Item1.HasValue && !y.Item1.HasValue) return Comparer.Default.Compare(x.Item2.GetFullName().ToString(), y.Item2.GetFullName().ToString()); + return x.Item1.HasValue ? -1 : 1; + } + this.SharedState.Elements.Clear(); + base.OnNamespace(ns); + this.SharedState.Elements.Sort(SortComparison); + return new QsNamespace(ns.Name, this.SharedState.Elements.Select(e => e.Item2).ToImmutableArray(), ns.Documentation); } - this.Elements.Clear(); - base.Transform(ns); - this.Elements.Sort(SortComparison); - return new QsNamespace(ns.Name, this.Elements.Select(e => e.Item2).ToImmutableArray(), ns.Documentation); } } - // scope transformations - /// - /// Class that allows to transform scopes by keeping only statements whose expressions satisfy a certain criterion. - /// Calling Transform will build a new Scope that contains only the statements for which the fold of a given condition - /// over all contained expressions evaluates to true. - /// If evaluateOnSubexpressions is set to true, the fold is evaluated on all subexpressions as well. + /// Class that allows to transform scopes by keeping only statements whose expressions satisfy a certain criterion. + /// Calling Transform will build a new Scope that contains only the statements for which the fold of a given condition + /// over all contained expressions evaluates to true. + /// If evaluateOnSubexpressions is set to true, the fold is evaluated on all subexpressions as well. /// - public class SelectByFoldingOverExpressions : - ScopeTransformation> - where K : Core.StatementKindTransformation + public class SelectByFoldingOverExpressions + : SyntaxTreeTransformation { - protected readonly Func Condition; - protected readonly Func Fold; - private readonly bool Seed; + public class TransformationState + : FoldOverExpressions.IFoldingState + { + public bool Recur { get; } + public readonly bool Seed; - protected SelectByFoldingOverExpressions SubSelector; - protected virtual SelectByFoldingOverExpressions GetSubSelector() => - new SelectByFoldingOverExpressions(this.Condition, this.Fold, this.Seed, this._Expression.recur, null); + internal readonly Func Condition; + internal readonly Func ConstructFold; - public bool SatisfiesCondition => this._Expression.Result; + public bool Fold(TypedExpression ex, bool current) => + this.ConstructFold(this.Condition(ex), current); - public SelectByFoldingOverExpressions( - Func condition, Func fold, bool seed, bool evaluateOnSubexpressions = true, - Func>, K> statementKind = null) : - base(statementKind, new FoldOverExpressions((ex, current) => fold(condition(ex), current), seed, recur: evaluateOnSubexpressions)) - { - this.Condition = condition ?? throw new ArgumentNullException(nameof(condition)); - this.Fold = fold ?? throw new ArgumentNullException(nameof(condition)); - this.Seed = seed; - } + public bool FoldResult { get; set; } + public bool SatisfiesCondition => this.FoldResult; - protected new Core.StatementKindTransformation _StatementKind => base.StatementKind; - public override Core.StatementKindTransformation StatementKind - { - get + public TransformationState(Func condition, Func fold, bool seed, bool recur = true) { - this.SubSelector = GetSubSelector(); - return this.SubSelector._StatementKind; // don't spawn the next one + this.Recur = recur; + this.Seed = seed; + this.FoldResult = seed; + this.Condition = condition ?? throw new ArgumentNullException(nameof(condition)); + this.ConstructFold = fold ?? throw new ArgumentNullException(nameof(fold)); } } - public override QsStatement onStatement(QsStatement stm) + + public SelectByFoldingOverExpressions(Func condition, Func fold, bool seed, bool evaluateOnSubexpressions = true) + : base(new TransformationState(condition, fold, seed, evaluateOnSubexpressions)) { - stm = base.onStatement(stm); - this._Expression.Result = this.Fold(this._Expression.Result, this.SubSelector._Expression.Result); - return stm; + this.Types = new TypeTransformation(this, TransformationOptions.Disabled); + this.Expressions = new FoldOverExpressions(this); + this.Statements = new StatementTransformation( + state => new SelectByFoldingOverExpressions(state.Condition, state.ConstructFold, state.Seed, state.Recur), + this); } - public override QsScope Transform(QsScope scope) + + // helper classes + + public class StatementTransformation

+ : Core.StatementTransformation where P : SelectByFoldingOverExpressions { - var statements = new List(); - foreach (var statement in scope.Statements) + protected P SubSelector; + protected readonly Func CreateSelector; + + ///

+ /// The given function for creating a new subselector is expected to initialize a new internal state with the same configurations as the one given upon construction. + /// Upon initialization, the FoldResult of the internal state should be set to the specified seed rather than the FoldResult of the given constructor argument. + /// + public StatementTransformation(Func createSelector, SyntaxTreeTransformation parent) + : base(parent) => + this.CreateSelector = createSelector ?? throw new ArgumentNullException(nameof(createSelector)); + + public override QsStatement OnStatement(QsStatement stm) { - // StatementKind.Transform sets a new Subselector that walks all expressions contained in statement, - // and sets its satisfiesCondition to true if one of them satisfies the condition given on initialization - var transformed = this.onStatement(statement); - if (this.SubSelector.SatisfiesCondition) statements.Add(transformed); + this.SubSelector = this.CreateSelector(this.SharedState); + var loc = this.SubSelector.Statements.OnLocation(stm.Location); + var stmKind = this.SubSelector.StatementKinds.OnStatementKind(stm.Statement); + var varDecl = this.SubSelector.Statements.OnLocalDeclarations(stm.SymbolDeclarations); + this.SharedState.FoldResult = this.SharedState.ConstructFold( + this.SharedState.FoldResult, this.SubSelector.SharedState.FoldResult); + return new QsStatement(stmKind, varDecl, loc, stm.Comments); + } + + public override QsScope OnScope(QsScope scope) + { + var statements = new List(); + foreach (var statement in scope.Statements) + { + // StatementKind.Transform sets a new Subselector that walks all expressions contained in statement, + // and sets its satisfiesCondition to true if one of them satisfies the condition given on initialization + var transformed = this.OnStatement(statement); + if (this.SubSelector.SharedState.SatisfiesCondition) statements.Add(transformed); + } + return new QsScope(statements.ToImmutableArray(), scope.KnownSymbols); } - return new QsScope(statements.ToImmutableArray(), scope.KnownSymbols); } } + /// - /// Class that allows to transform scopes by keeping only statements that contain certain expressions. - /// Calling Transform will build a new Scope that contains only the statements - /// which contain an expression or subexpression (only if evaluateOnSubexpressions is set to true) - /// that satisfies the condition given on initialization. + /// Class that allows to transform scopes by keeping only statements that contain certain expressions. + /// Calling Transform will build a new Scope that contains only the statements + /// which contain an expression or subexpression (only if evaluateOnSubexpressions is set to true) + /// that satisfies the condition given on initialization. /// - public class SelectByAnyContainedExpression : - SelectByFoldingOverExpressions - where K : Core.StatementKindTransformation + public class SelectByAnyContainedExpression + : SelectByFoldingOverExpressions { - private readonly Func, K> GetStatementKind; - protected override SelectByFoldingOverExpressions GetSubSelector() => - new SelectByAnyContainedExpression(this.Condition, this._Expression.recur, this.GetStatementKind); - - public SelectByAnyContainedExpression( - Func condition, bool evaluateOnSubexpressions = true, - Func, K> statementKind = null) : - base(condition, (a, b) => a || b, false, evaluateOnSubexpressions, s => statementKind(s as SelectByFoldingOverExpressions)) => - this.GetStatementKind = statementKind; + public SelectByAnyContainedExpression(Func condition, bool evaluateOnSubexpressions = true) + : base(condition, (a, b) => a || b, false, evaluateOnSubexpressions) { } } /// - /// Class that allows to transform scopes by keeping only statements whose expressions satisfy a certain criterion. - /// Calling Transform will build a new Scope that contains only the statements + /// Class that allows to transform scopes by keeping only statements whose expressions satisfy a certain criterion. + /// Calling Transform will build a new Scope that contains only the statements /// for which all contained expressions or subexpressions satisfy the condition given on initialization. /// Note that subexpressions will only be verified if evaluateOnSubexpressions is set to true (default value). /// - public class SelectByAllContainedExpressions : - SelectByFoldingOverExpressions - where K : Core.StatementKindTransformation + public class SelectByAllContainedExpressions + : SelectByFoldingOverExpressions { - private readonly Func, K> GetStatementKind; - protected override SelectByFoldingOverExpressions GetSubSelector() => - new SelectByAllContainedExpressions(this.Condition, this._Expression.recur, this.GetStatementKind); - - public SelectByAllContainedExpressions( - Func condition, bool evaluateOnSubexpressions = true, - Func, K> statementKind = null) : - base(condition, (a, b) => a && b, true, evaluateOnSubexpressions, s => statementKind(s as SelectByFoldingOverExpressions)) => - this.GetStatementKind = statementKind; + public SelectByAllContainedExpressions(Func condition, bool evaluateOnSubexpressions = true) + : base(condition, (a, b) => a && b, true, evaluateOnSubexpressions) { } } - // expression transformations - /// - /// Class that evaluates a fold on Transform. - /// If recur is set to true on initialization (default value), + /// Class that evaluates a fold on upon transforming an expression. + /// If recur is set to true in the internal state of the transformation, /// the fold function given on initialization is applied to all subexpressions as well as the expression itself - - /// i.e. the fold it take starting on inner expressions (from the inside out). - /// Otherwise the set Action is only applied to the expression itself. - /// The result of the fold is accessible via the Result property. + /// i.e. the fold it take starting on inner expressions (from the inside out). + /// Otherwise the specified folder is only applied to the expression itself. + /// The result of the fold is accessible via the FoldResult property in the internal state of the transformation. + /// The transformation itself merely walks expressions and rebuilding is disabled. /// - public class FoldOverExpressions : - ExpressionTransformation>> + public class FoldOverExpressions + : ExpressionTransformation where T : FoldOverExpressions.IFoldingState { - private static readonly Func< - ExpressionTransformation>, Core.ExpressionTypeTransformation>, - ExpressionKindTransformation>> InitializeKind = - e => new ExpressionKindTransformation>(e as FoldOverExpressions); - - internal readonly bool recur; - public readonly Func Fold; - public T Result { get; set; } - - public FoldOverExpressions(Func fold, T seed, bool recur = true) : - base(recur ? InitializeKind : null) // we need to enable expression kind transformations in order to walk subexpressions + public interface IFoldingState { - this.Fold = fold ?? throw new ArgumentNullException(nameof(fold)); - this.Result = seed; - this.recur = recur; + public bool Recur { get; } + public S Fold(TypedExpression ex, S current); + public S FoldResult { get; set; } } - public override TypedExpression Transform(TypedExpression ex) + + public FoldOverExpressions(SyntaxTreeTransformation parent) + : base(parent, TransformationOptions.NoRebuild) { } + + public FoldOverExpressions(T state) + : base(state) { } + + + public override TypedExpression OnTypedExpression(TypedExpression ex) { - ex = recur ? base.Transform(ex) : ex; - this.Result = Fold(ex, this.Result); + ex = this.SharedState.Recur ? base.OnTypedExpression(ex) : ex; + this.SharedState.FoldResult = this.SharedState.Fold(ex, this.SharedState.FoldResult); return ex; } } diff --git a/src/QsCompiler/Transformations/ClassicallyControlled.cs b/src/QsCompiler/Transformations/ClassicallyControlled.cs new file mode 100644 index 0000000000..600c8fde3e --- /dev/null +++ b/src/QsCompiler/Transformations/ClassicallyControlled.cs @@ -0,0 +1,1063 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using System.Collections.Immutable; +using System.Linq; +using Microsoft.Quantum.QsCompiler.DataTypes; +using Microsoft.Quantum.QsCompiler.SyntaxTokens; +using Microsoft.Quantum.QsCompiler.SyntaxTree; +using Microsoft.Quantum.QsCompiler.Transformations.Core; +using Microsoft.Quantum.QsCompiler.Transformations.SearchAndReplace; + + +namespace Microsoft.Quantum.QsCompiler.Transformations.ClassicallyControlled +{ + using ExpressionKind = QsExpressionKind; + using ResolvedTypeKind = QsTypeKind; + using TypeArgsResolution = ImmutableArray, ResolvedType>>; + + /// + /// This transformation works in two passes. + /// 1st Pass: Hoist the contents of conditional statements into separate operations, where possible. + /// 2nd Pass: On the way down the tree, reshape conditional statements to replace Elif's and + /// top level OR and AND conditions with equivalent nested if-else statements. One the way back up + /// the tree, convert conditional statements into ApplyIf calls, where possible. + /// This relies on anything having type parameters must be a global callable. + /// + public static class ReplaceClassicalControl + { + public static QsCompilation Apply(QsCompilation compilation) + { + compilation = HoistTransformation.Apply(compilation); + + return ConvertConditions.Apply(compilation); + } + + private class ConvertConditions : SyntaxTreeTransformation + { + public static QsCompilation Apply(QsCompilation compilation) + { + var filter = new ConvertConditions(compilation); + + return new QsCompilation(compilation.Namespaces.Select(ns => filter.Namespaces.OnNamespace(ns)).ToImmutableArray(), compilation.EntryPoints); + } + + public class TransformationState + { + public readonly QsCompilation Compilation; + + public TransformationState(QsCompilation compilation) + { + Compilation = compilation; + } + } + + private ConvertConditions(QsCompilation compilation) : base(new TransformationState(compilation)) + { + this.Namespaces = new NamespaceTransformation(this); + this.Statements = new StatementTransformation(this); + this.Expressions = new ExpressionTransformation(this, TransformationOptions.Disabled); + this.Types = new TypeTransformation(this, TransformationOptions.Disabled); + } + + private class NamespaceTransformation : NamespaceTransformation + { + public NamespaceTransformation(SyntaxTreeTransformation parent) : base(parent) { } + + public override QsCallable OnFunction(QsCallable c) => c; // Prevent anything in functions from being considered + } + + private class StatementTransformation : StatementTransformation + { + public StatementTransformation(SyntaxTreeTransformation parent) : base(parent) { } + + private (bool, TypedExpression, TypedExpression) IsValidScope(QsScope scope) + { + // if the scope has exactly one statement in it and that statement is a call like expression statement + if (scope != null + && scope.Statements.Length == 1 + && scope.Statements[0].Statement is QsStatementKind.QsExpressionStatement expr + && expr.Item.ResolvedType.Resolution.IsUnitType + && expr.Item.Expression is ExpressionKind.CallLikeExpression call + && !TypedExpression.IsPartialApplication(expr.Item.Expression) + && call.Item1.Expression is ExpressionKind.Identifier) + { + // We are dissolving the application of arguments here, so the call's type argument + // resolutions have to be moved to the 'identifier' sub expression. + + var callTypeArguments = expr.Item.TypeArguments; + var idTypeArguments = call.Item1.TypeArguments; + var combinedTypeArguments = Utils.GetCombinedTypeResolution(callTypeArguments, idTypeArguments); + + // This relies on anything having type parameters must be a global callable. + var newCallIdentifier = call.Item1; + if (combinedTypeArguments.Any() + && newCallIdentifier.Expression is ExpressionKind.Identifier id + && id.Item1 is Identifier.GlobalCallable global) + { + var globalCallable = SharedState.Compilation.Namespaces + .Where(ns => ns.Name.Equals(global.Item.Namespace)) + .Callables() + .FirstOrDefault(c => c.FullName.Name.Equals(global.Item.Name)); + + QsCompilerError.Verify(globalCallable != null, $"Could not find the global reference {global.Item.Namespace.Value + "." + global.Item.Name.Value}"); + + var callableTypeParameters = globalCallable.Signature.TypeParameters + .Select(x => x as QsLocalSymbol.ValidName); + + QsCompilerError.Verify(callableTypeParameters.All(x => x != null), $"Invalid type parameter names."); + + newCallIdentifier = new TypedExpression( + ExpressionKind.NewIdentifier( + id.Item1, + QsNullable>.NewValue( + callableTypeParameters + .Select(x => combinedTypeArguments.First(y => y.Item2.Equals(x.Item)).Item3).ToImmutableArray())), + combinedTypeArguments, + call.Item1.ResolvedType, + call.Item1.InferredInformation, + call.Item1.Range); + } + + return (true, newCallIdentifier, call.Item2); + } + + return (false, null, null); + } + + #region Apply If + + private TypedExpression CreateApplyIfExpression(QsResult result, TypedExpression conditionExpression, QsScope conditionScope, QsScope defaultScope) + { + var (isCondValid, condId, condArgs) = IsValidScope(conditionScope); + var (isDefaultValid, defaultId, defaultArgs) = IsValidScope(defaultScope); + + BuiltIn controlOpInfo; + TypedExpression controlArgs; + ImmutableArray targetArgs; + + var props = ImmutableHashSet.Empty; + + if (isCondValid) + { + // Get characteristic properties from global id + if (condId.ResolvedType.Resolution is ResolvedTypeKind.Operation op) + { + props = op.Item2.Characteristics.GetProperties(); + } + + (bool adj, bool ctl) = (props.Contains(OpProperty.Adjointable), props.Contains(OpProperty.Controllable)); + + if (isDefaultValid) + { + if (adj && ctl) + { + controlOpInfo = BuiltIn.ApplyIfElseRCA; + } + else if (adj) + { + controlOpInfo = BuiltIn.ApplyIfElseRA; + } + else if (ctl) + { + controlOpInfo = BuiltIn.ApplyIfElseRC; + } + else + { + controlOpInfo = BuiltIn.ApplyIfElseR; + } + + var (zeroOpArg, oneOpArg) = (result == QsResult.Zero) + ? (Utils.CreateValueTupleExpression(condId, condArgs), Utils.CreateValueTupleExpression(defaultId, defaultArgs)) + : (Utils.CreateValueTupleExpression(defaultId, defaultArgs), Utils.CreateValueTupleExpression(condId, condArgs)); + + controlArgs = Utils.CreateValueTupleExpression(conditionExpression, zeroOpArg, oneOpArg); + + targetArgs = ImmutableArray.Create(condArgs.ResolvedType, defaultArgs.ResolvedType); + } + else if (defaultScope == null) + { + if (adj && ctl) + { + controlOpInfo = (result == QsResult.Zero) + ? BuiltIn.ApplyIfZeroCA + : BuiltIn.ApplyIfOneCA; + } + else if (adj) + { + controlOpInfo = (result == QsResult.Zero) + ? BuiltIn.ApplyIfZeroA + : BuiltIn.ApplyIfOneA; + } + else if (ctl) + { + controlOpInfo = (result == QsResult.Zero) + ? BuiltIn.ApplyIfZeroC + : BuiltIn.ApplyIfOneC; + } + else + { + controlOpInfo = (result == QsResult.Zero) + ? BuiltIn.ApplyIfZero + : BuiltIn.ApplyIfOne; + } + + controlArgs = Utils.CreateValueTupleExpression( + conditionExpression, + Utils.CreateValueTupleExpression(condId, condArgs)); + + targetArgs = ImmutableArray.Create(condArgs.ResolvedType); + } + else + { + return null; // ToDo: Diagnostic message - default body exists, but is not valid + } + + } + else + { + return null; // ToDo: Diagnostic message - cond body not valid + } + + // Build the surrounding apply-if call + var controlOpId = Utils.CreateIdentifierExpression( + Identifier.NewGlobalCallable(new QsQualifiedName(controlOpInfo.Namespace, controlOpInfo.Name)), + targetArgs + .Zip(controlOpInfo.TypeParameters, (type, param) => Tuple.Create(new QsQualifiedName(controlOpInfo.Namespace, controlOpInfo.Name), param, type)) + .ToImmutableArray(), + Utils.GetOperationType(props, controlArgs.ResolvedType)); + + // Creates identity resolutions for the call expression + var opTypeArgResolutions = targetArgs + .SelectMany(x => + x.Resolution is ResolvedTypeKind.TupleType tup + ? tup.Item + : ImmutableArray.Create(x)) + .Where(x => x.Resolution.IsTypeParameter) + .Select(x => (x.Resolution as ResolvedTypeKind.TypeParameter).Item) + .GroupBy(x => (x.Origin, x.TypeName)) + .Select(group => + { + var typeParam = group.First(); + return Tuple.Create(typeParam.Origin, typeParam.TypeName, ResolvedType.New(ResolvedTypeKind.NewTypeParameter(typeParam))); + }) + .ToImmutableArray(); + + return Utils.CreateCallLikeExpression(controlOpId, controlArgs, opTypeArgResolutions); + } + + private QsStatement CreateApplyIfStatement(QsStatement statement, QsResult result, TypedExpression conditionExpression, QsScope conditionScope, QsScope defaultScope) + { + var controlCall = CreateApplyIfExpression(result, conditionExpression, conditionScope, defaultScope); + + if (controlCall != null) + { + return new QsStatement( + QsStatementKind.NewQsExpressionStatement(controlCall), + statement.SymbolDeclarations, + QsNullable.Null, + statement.Comments); + } + else + { + // ToDo: add diagnostic message here + return statement; // If the blocks can't be converted, return the original + } + } + + #endregion + + #region Condition Reshaping Logic + + private (bool, QsConditionalStatement) ProcessElif(QsConditionalStatement cond) + { + if (cond.ConditionalBlocks.Length < 2) return (false, cond); + + var subCond = new QsConditionalStatement(cond.ConditionalBlocks.RemoveAt(0), cond.Default); + var secondCondBlock = cond.ConditionalBlocks[1].Item2; + + var subIfStatment = new QsStatement + ( + QsStatementKind.NewQsConditionalStatement(subCond), + LocalDeclarations.Empty, + secondCondBlock.Location, + secondCondBlock.Comments + ); + + var newDefault = QsNullable.NewValue(new QsPositionedBlock( + new QsScope(ImmutableArray.Create(subIfStatment), secondCondBlock.Body.KnownSymbols), + secondCondBlock.Location, + QsComments.Empty)); + + return (true, new QsConditionalStatement(ImmutableArray.Create(cond.ConditionalBlocks[0]), newDefault)); + } + + private (bool, QsConditionalStatement) ProcessOR(QsConditionalStatement cond) + { + // This method expects elif blocks to have been abstracted out + if (cond.ConditionalBlocks.Length != 1) return (false, cond); + + var (condition, block) = cond.ConditionalBlocks[0]; + + if (condition.Expression is ExpressionKind.OR orCond) + { + var subCond = new QsConditionalStatement(ImmutableArray.Create(Tuple.Create(orCond.Item2, block)), cond.Default); + var subIfStatment = new QsStatement + ( + QsStatementKind.NewQsConditionalStatement(subCond), + LocalDeclarations.Empty, + block.Location, + QsComments.Empty + ); + var newDefault = QsNullable.NewValue(new QsPositionedBlock( + new QsScope(ImmutableArray.Create(subIfStatment), block.Body.KnownSymbols), + block.Location, + QsComments.Empty)); + + return (true, new QsConditionalStatement(ImmutableArray.Create(Tuple.Create(orCond.Item1, block)), newDefault)); + } + else + { + return (false, cond); + } + } + + private (bool, QsConditionalStatement) ProcessAND(QsConditionalStatement cond) + { + // This method expects elif blocks to have been abstracted out + if (cond.ConditionalBlocks.Length != 1) return (false, cond); + + var (condition, block) = cond.ConditionalBlocks[0]; + + if (condition.Expression is ExpressionKind.AND andCond) + { + var subCond = new QsConditionalStatement(ImmutableArray.Create(Tuple.Create(andCond.Item2, block)), cond.Default); + var subIfStatment = new QsStatement + ( + QsStatementKind.NewQsConditionalStatement(subCond), + LocalDeclarations.Empty, + block.Location, + QsComments.Empty + ); + var newBlock = new QsPositionedBlock( + new QsScope(ImmutableArray.Create(subIfStatment), block.Body.KnownSymbols), + block.Location, + QsComments.Empty); + + return (true, new QsConditionalStatement(ImmutableArray.Create(Tuple.Create(andCond.Item1, newBlock)), cond.Default)); + } + else + { + return (false, cond); + } + } + + private QsStatement ReshapeConditional(QsStatement statement) + { + if (statement.Statement is QsStatementKind.QsConditionalStatement cond) + { + var stm = cond.Item; + (_, stm) = ProcessElif(stm); + bool wasOrProcessed, wasAndProcessed; + do + { + (wasOrProcessed, stm) = ProcessOR(stm); + (wasAndProcessed, stm) = ProcessAND(stm); + } while (wasOrProcessed || wasAndProcessed); + + return new QsStatement + ( + QsStatementKind.NewQsConditionalStatement(stm), + statement.SymbolDeclarations, + statement.Location, + statement.Comments + ); + } + return statement; + } + + #endregion + + #region Condition Checking Logic + + private (bool, TypedExpression, QsScope, QsScope) IsConditionWithSingleBlock(QsStatement statement) + { + if (statement.Statement is QsStatementKind.QsConditionalStatement cond && cond.Item.ConditionalBlocks.Length == 1) + { + return (true, cond.Item.ConditionalBlocks[0].Item1, cond.Item.ConditionalBlocks[0].Item2.Body, cond.Item.Default.ValueOr(null)?.Body); + } + + return (false, null, null, null); + } + + private (bool, QsResult, TypedExpression) IsConditionedOnResultLiteralExpression(TypedExpression expression) + { + if (expression.Expression is ExpressionKind.EQ eq) + { + if (eq.Item1.Expression is ExpressionKind.ResultLiteral exp1) + { + return (true, exp1.Item, eq.Item2); + } + else if (eq.Item2.Expression is ExpressionKind.ResultLiteral exp2) + { + return (true, exp2.Item, eq.Item1); + } + } + + return (false, null, null); + } + + #endregion + + public override QsScope OnScope(QsScope scope) + { + var parentSymbols = this.OnLocalDeclarations(scope.KnownSymbols); + var statements = new List(); + + foreach (var statement in scope.Statements) + { + if (statement.Statement is QsStatementKind.QsConditionalStatement) + { + var stm = ReshapeConditional(statement); + stm = this.OnStatement(stm); + + var (isCondition, cond, conditionScope, defaultScope) = IsConditionWithSingleBlock(stm); + + if (isCondition) + { + /*ToDo: this could be a separate function.*/ + var (isCompareLiteral, literal, nonLiteral) = IsConditionedOnResultLiteralExpression(cond); + if (isCompareLiteral) + { + statements.Add(CreateApplyIfStatement(stm, literal, nonLiteral, conditionScope, defaultScope)); + } + else + { + statements.Add(stm); + } + /**/ + } + } + else + { + statements.Add(this.OnStatement(statement)); + } + } + + return new QsScope(statements.ToImmutableArray(), parentSymbols); + } + } + } + } + + + /// + /// Transformation handling the first pass task of hoisting of the contents of conditional statements. + /// If blocks are first validated to see if they can safely be hoisted into their own operation. + /// Validation requirements are that there are no return statements and that there are no set statements + /// on mutables declared outside the block. Setting mutables declared inside the block is valid. + /// If the block is valid, and there is more than one statement in the block, a new operation with the + /// block's contents is generated, having all the same type parameters as the calling context + /// and all known variables at the start of the block become parameters to the new operation. + /// The contents of the conditional block are then replaced with a call to the new operation with all + /// the type parameters and known variables being forwarded to the new operation as arguments. + /// + internal static class HoistTransformation // this class can be made public once its functionality is no longer tied to the classically-controlled transformation + { + internal static QsCompilation Apply(QsCompilation compilation) => HoistContents.Apply(compilation); + + private class HoistContents : SyntaxTreeTransformation + { + public static QsCompilation Apply(QsCompilation compilation) + { + var filter = new HoistContents(); + + return new QsCompilation(compilation.Namespaces.Select(ns => filter.Namespaces.OnNamespace(ns)).ToImmutableArray(), compilation.EntryPoints); + } + + public class CallableDetails + { + public QsCallable Callable; + public QsSpecialization Adjoint; + public QsSpecialization Controlled; + public QsSpecialization ControlledAdjoint; + public QsNullable> TypeParamTypes; + + public CallableDetails(QsCallable callable) + { + Callable = callable; + Adjoint = callable.Specializations.FirstOrDefault(spec => spec.Kind == QsSpecializationKind.QsAdjoint); + Controlled = callable.Specializations.FirstOrDefault(spec => spec.Kind == QsSpecializationKind.QsControlled); + ControlledAdjoint = callable.Specializations.FirstOrDefault(spec => spec.Kind == QsSpecializationKind.QsControlledAdjoint); + TypeParamTypes = callable.Signature.TypeParameters.Any(param => param.IsValidName) + ? QsNullable>.NewValue(callable.Signature.TypeParameters + .Where(param => param.IsValidName) + .Select(param => + ResolvedType.New(ResolvedTypeKind.NewTypeParameter(new QsTypeParameter( + callable.FullName, + ((QsLocalSymbol.ValidName)param).Item, + QsNullable>.Null + )))) + .ToImmutableArray()) + : QsNullable>.Null; + } + } + + public class TransformationState + { + public bool IsValidScope = true; + public List ControlOperations = null; + public ImmutableArray>> CurrentHoistParams = + ImmutableArray>>.Empty; + public bool ContainsHoistParamRef = false; + + public CallableDetails CurrentCallable = null; + public bool InBody = false; + public bool InAdjoint = false; + public bool InControlled = false; + public bool InWithinBlock = false; + + private (ResolvedSignature, IEnumerable) MakeSpecializations( + QsQualifiedName callableName, ResolvedType argsType, SpecializationImplementation bodyImplementation) + { + QsSpecialization MakeSpec(QsSpecializationKind kind, ResolvedSignature signature, SpecializationImplementation impl) => + new QsSpecialization( + kind, + callableName, + ImmutableArray.Empty, + CurrentCallable.Callable.SourceFile, + QsNullable.Null, + QsNullable>.Null, + signature, + impl, + ImmutableArray.Empty, + QsComments.Empty); + + var adj = CurrentCallable.Adjoint; + var ctl = CurrentCallable.Controlled; + var ctlAdj = CurrentCallable.ControlledAdjoint; + + bool addAdjoint = false; + bool addControlled = false; + + if (InWithinBlock) + { + addAdjoint = true; + addControlled = false; + } + else if (InBody) + { + if (adj != null && adj.Implementation is SpecializationImplementation.Generated adjGen) addAdjoint = adjGen.Item.IsInvert; + if (ctl != null && ctl.Implementation is SpecializationImplementation.Generated ctlGen) addControlled = ctlGen.Item.IsDistribute; + if (ctlAdj != null && ctlAdj.Implementation is SpecializationImplementation.Generated ctlAdjGen) + { + addAdjoint = addAdjoint || ctlAdjGen.Item.IsInvert && ctl.Implementation.IsGenerated; + addControlled = addControlled || ctlAdjGen.Item.IsDistribute && adj.Implementation.IsGenerated; + } + } + else if (ctlAdj != null && ctlAdj.Implementation is SpecializationImplementation.Generated gen) + { + addControlled = InAdjoint && gen.Item.IsDistribute; + addAdjoint = InControlled && gen.Item.IsInvert; + } + + var props = new List(); + if (addAdjoint) props.Add(OpProperty.Adjointable); + if (addControlled) props.Add(OpProperty.Controllable); + var newSig = new ResolvedSignature( + CurrentCallable.Callable.Signature.TypeParameters, + argsType, + ResolvedType.New(ResolvedTypeKind.UnitType), + new CallableInformation(ResolvedCharacteristics.FromProperties(props), InferredCallableInformation.NoInformation)); + + var controlledSig = new ResolvedSignature( + newSig.TypeParameters, + ResolvedType.New(ResolvedTypeKind.NewTupleType(ImmutableArray.Create( + ResolvedType.New(ResolvedTypeKind.NewArrayType(ResolvedType.New(ResolvedTypeKind.Qubit))), + newSig.ArgumentType))), + newSig.ReturnType, + newSig.Information); + + var specializations = new List() { MakeSpec(QsSpecializationKind.QsBody, newSig, bodyImplementation) }; + + if (addAdjoint) + { + specializations.Add(MakeSpec( + QsSpecializationKind.QsAdjoint, + newSig, + SpecializationImplementation.NewGenerated(QsGeneratorDirective.Invert))); + } + + if (addControlled) + { + specializations.Add(MakeSpec( + QsSpecializationKind.QsControlled, + controlledSig, + SpecializationImplementation.NewGenerated(QsGeneratorDirective.Distribute))); + } + + if (addAdjoint && addControlled) + { + specializations.Add(MakeSpec( + QsSpecializationKind.QsControlledAdjoint, + controlledSig, + SpecializationImplementation.NewGenerated(QsGeneratorDirective.Distribute))); + } + + return (newSig, specializations); + } + + public (QsCallable, ResolvedType) GenerateOperation(QsScope contents) + { + var newName = UniqueVariableNames.PrependGuid(CurrentCallable.Callable.FullName); + + var knownVariables = contents.KnownSymbols.IsEmpty + ? ImmutableArray>>.Empty + : contents.KnownSymbols.Variables; + + var parameters = QsTuple>.NewQsTuple(knownVariables + .Select(var => QsTuple>.NewQsTupleItem(new LocalVariableDeclaration( + QsLocalSymbol.NewValidName(var.VariableName), + var.Type, + var.InferredInformation, + var.Position, + var.Range))) + .ToImmutableArray()); + + var paramTypes = ResolvedType.New(ResolvedTypeKind.UnitType); + if (knownVariables.Length == 1) + { + paramTypes = knownVariables.First().Type; + } + else if (knownVariables.Length > 1) + { + paramTypes = ResolvedType.New(ResolvedTypeKind.NewTupleType(knownVariables + .Select(var => var.Type) + .ToImmutableArray())); + } + + var (signature, specializations) = MakeSpecializations(newName, paramTypes, SpecializationImplementation.NewProvided(parameters, contents)); + + var controlCallable = new QsCallable( + QsCallableKind.Operation, + newName, + ImmutableArray.Empty, + CurrentCallable.Callable.SourceFile, + QsNullable.Null, + signature, + parameters, + specializations.ToImmutableArray(), + ImmutableArray.Empty, + QsComments.Empty); + + var updatedCallable = UpdateGeneratedOp.Apply(controlCallable, knownVariables, CurrentCallable.Callable.FullName, newName); + + return (updatedCallable, signature.ArgumentType); + } + } + + private HoistContents() : base(new TransformationState()) + { + this.Namespaces = new NamespaceTransformation(this); + this.StatementKinds = new StatementKindTransformation(this); + this.Expressions = new ExpressionTransformation(this); + this.ExpressionKinds = new ExpressionKindTransformation(this); + this.Types = new TypeTransformation(this, TransformationOptions.Disabled); + } + + private class NamespaceTransformation : NamespaceTransformation + { + public NamespaceTransformation(SyntaxTreeTransformation parent) : base(parent) { } + + public override QsCallable OnCallableDeclaration(QsCallable c) + { + SharedState.CurrentCallable = new CallableDetails(c); + return base.OnCallableDeclaration(c); + } + + public override QsSpecialization OnBodySpecialization(QsSpecialization spec) + { + SharedState.InBody = true; + var rtrn = base.OnBodySpecialization(spec); + SharedState.InBody = false; + return rtrn; + } + + public override QsSpecialization OnAdjointSpecialization(QsSpecialization spec) + { + SharedState.InAdjoint = true; + var rtrn = base.OnAdjointSpecialization(spec); + SharedState.InAdjoint = false; + return rtrn; + } + + public override QsSpecialization OnControlledSpecialization(QsSpecialization spec) + { + SharedState.InControlled = true; + var rtrn = base.OnControlledSpecialization(spec); + SharedState.InControlled = false; + return rtrn; + } + + public override QsCallable OnFunction(QsCallable c) => c; // Prevent anything in functions from being hoisted + + public override QsNamespace OnNamespace(QsNamespace ns) + { + // Control operations list will be populated in the transform + SharedState.ControlOperations = new List(); + return base.OnNamespace(ns) + .WithElements(elems => elems.AddRange(SharedState.ControlOperations.Select(op => QsNamespaceElement.NewQsCallable(op)))); + } + } + + private class StatementKindTransformation : StatementKindTransformation + { + public StatementKindTransformation(SyntaxTreeTransformation parent) : base(parent) { } + + private (QsCallable, QsStatement) HoistBody(QsScope body) + { + var (targetOp, originalArgumentType) = SharedState.GenerateOperation(body); + var targetOpType = ResolvedType.New(ResolvedTypeKind.NewOperation( + Tuple.Create( + originalArgumentType, + ResolvedType.New(ResolvedTypeKind.UnitType)), + targetOp.Signature.Information)); + + var targetTypeArgTypes = SharedState.CurrentCallable.TypeParamTypes; + var targetOpId = new TypedExpression + ( + ExpressionKind.NewIdentifier(Identifier.NewGlobalCallable(targetOp.FullName), targetTypeArgTypes), + targetTypeArgTypes.IsNull + ? TypeArgsResolution.Empty + : targetTypeArgTypes.Item + .Select(type => Tuple.Create(targetOp.FullName, ((ResolvedTypeKind.TypeParameter)type.Resolution).Item.TypeName, type)) + .ToImmutableArray(), + targetOpType, + new InferredExpressionInformation(false, false), + QsNullable>.Null + ); + + var knownSymbols = body.KnownSymbols.Variables; + + TypedExpression targetArgs = null; + if (knownSymbols.Any()) + { + targetArgs = Utils.CreateValueTupleExpression(knownSymbols.Select(var => Utils.CreateIdentifierExpression( + Identifier.NewLocalVariable(var.VariableName), + TypeArgsResolution.Empty, + var.Type)) + .ToArray()); + } + else + { + targetArgs = new TypedExpression + ( + ExpressionKind.UnitValue, + TypeArgsResolution.Empty, + ResolvedType.New(ResolvedTypeKind.UnitType), + new InferredExpressionInformation(false, false), + QsNullable>.Null + ); + } + + var call = new TypedExpression + ( + ExpressionKind.NewCallLikeExpression(targetOpId, targetArgs), + targetTypeArgTypes.IsNull + ? TypeArgsResolution.Empty + : targetTypeArgTypes.Item + .Select(type => Tuple.Create(SharedState.CurrentCallable.Callable.FullName, ((ResolvedTypeKind.TypeParameter)type.Resolution).Item.TypeName, type)) + .ToImmutableArray(), + ResolvedType.New(ResolvedTypeKind.UnitType), + new InferredExpressionInformation(false, true), + QsNullable>.Null + ); + + return (targetOp, new QsStatement( + QsStatementKind.NewQsExpressionStatement(call), + LocalDeclarations.Empty, + QsNullable.Null, + QsComments.Empty)); + } + + // ToDo: This logic should be externalized at some point to make the Hoisting more general + private bool IsScopeSingleCall(QsScope contents) + { + if (contents.Statements.Length != 1) return false; + + return contents.Statements[0].Statement is QsStatementKind.QsExpressionStatement expr + && expr.Item.Expression is ExpressionKind.CallLikeExpression call + && !TypedExpression.IsPartialApplication(expr.Item.Expression) + && call.Item1.Expression is ExpressionKind.Identifier; + } + + public override QsStatementKind OnConjugation(QsConjugation stm) + { + var superInWithinBlock = SharedState.InWithinBlock; + SharedState.InWithinBlock = true; + var (_, outer) = this.OnPositionedBlock(QsNullable.Null, stm.OuterTransformation); + SharedState.InWithinBlock = superInWithinBlock; + + var (_, inner) = this.OnPositionedBlock(QsNullable.Null, stm.InnerTransformation); + + return QsStatementKind.NewQsConjugation(new QsConjugation(outer, inner)); + } + + public override QsStatementKind OnReturnStatement(TypedExpression ex) + { + SharedState.IsValidScope = false; + return base.OnReturnStatement(ex); + } + + public override QsStatementKind OnValueUpdate(QsValueUpdate stm) + { + // If lhs contains an identifier found in the scope's known variables (variables from the super-scope), the scope is not valid + var lhs = this.Expressions.OnTypedExpression(stm.Lhs); + + if (SharedState.ContainsHoistParamRef) + { + SharedState.IsValidScope = false; + } + + var rhs = this.Expressions.OnTypedExpression(stm.Rhs); + return QsStatementKind.NewQsValueUpdate(new QsValueUpdate(lhs, rhs)); + } + + public override QsStatementKind OnConditionalStatement(QsConditionalStatement stm) + { + var contextValidScope = SharedState.IsValidScope; + var contextHoistParams = SharedState.CurrentHoistParams; + + var isHoistValid = true; + + var newConditionBlocks = new List>(); + var generatedOperations = new List(); + foreach (var condBlock in stm.ConditionalBlocks) + { + SharedState.IsValidScope = true; + SharedState.CurrentHoistParams = condBlock.Item2.Body.KnownSymbols.IsEmpty + ? ImmutableArray>>.Empty + : condBlock.Item2.Body.KnownSymbols.Variables; + + var (expr, block) = this.OnPositionedBlock(QsNullable.NewValue(condBlock.Item1), condBlock.Item2); + + // ToDo: Reduce the number of unnecessary generated operations by generalizing + // the condition logic for the conversion and using that condition here + //var (isExprCond, _, _) = IsConditionedOnResultLiteralExpression(expr.Item); + + if (block.Body.Statements.Length > 0 /*&& isExprCond*/ && SharedState.IsValidScope && !IsScopeSingleCall(block.Body)) // if sub-scope is valid, hoist content + { + // Hoist the scope to its own operation + var (callable, call) = HoistBody(block.Body); + block = new QsPositionedBlock( + new QsScope(ImmutableArray.Create(call), block.Body.KnownSymbols), + block.Location, + block.Comments); + newConditionBlocks.Add(Tuple.Create(expr.Item, block)); + generatedOperations.Add(callable); + } + else + { + isHoistValid = false; + break; + } + } + + var newDefault = QsNullable.Null; + if (isHoistValid && stm.Default.IsValue) + { + SharedState.IsValidScope = true; + SharedState.CurrentHoistParams = stm.Default.Item.Body.KnownSymbols.IsEmpty + ? ImmutableArray>>.Empty + : stm.Default.Item.Body.KnownSymbols.Variables; + + var (_, block) = this.OnPositionedBlock(QsNullable.Null, stm.Default.Item); + if (block.Body.Statements.Length > 0 && SharedState.IsValidScope && !IsScopeSingleCall(block.Body)) // if sub-scope is valid, hoist content + { + // Hoist the scope to its own operation + var (callable, call) = HoistBody(block.Body); + block = new QsPositionedBlock( + new QsScope(ImmutableArray.Create(call), block.Body.KnownSymbols), + block.Location, + block.Comments); + newDefault = QsNullable.NewValue(block); + generatedOperations.Add(callable); + } + else + { + isHoistValid = false; + } + } + + if (isHoistValid) + { + SharedState.ControlOperations.AddRange(generatedOperations); + } + + SharedState.CurrentHoistParams = contextHoistParams; + SharedState.IsValidScope = contextValidScope; + + return isHoistValid + ? QsStatementKind.NewQsConditionalStatement( + new QsConditionalStatement(newConditionBlocks.ToImmutableArray(), newDefault)) + : QsStatementKind.NewQsConditionalStatement( + new QsConditionalStatement(stm.ConditionalBlocks, stm.Default)); + } + + public override QsStatementKind OnStatementKind(QsStatementKind kind) + { + SharedState.ContainsHoistParamRef = false; // Every statement kind starts off false + return base.OnStatementKind(kind); + } + } + + private class ExpressionTransformation : ExpressionTransformation + { + public ExpressionTransformation(SyntaxTreeTransformation parent) : base(parent) { } + + public override TypedExpression OnTypedExpression(TypedExpression ex) + { + var contextContainsHoistParamRef = SharedState.ContainsHoistParamRef; + SharedState.ContainsHoistParamRef = false; + var rtrn = base.OnTypedExpression(ex); + + // If the sub context contains a reference, then the super context contains a reference, + // otherwise return the super context to its original value + if (!SharedState.ContainsHoistParamRef) + { + SharedState.ContainsHoistParamRef = contextContainsHoistParamRef; + } + + return rtrn; + } + } + + private class ExpressionKindTransformation : ExpressionKindTransformation + { + public ExpressionKindTransformation(SyntaxTreeTransformation parent) : base(parent) { } + + public override ExpressionKind OnIdentifier(Identifier sym, QsNullable> tArgs) + { + if (sym is Identifier.LocalVariable local && + SharedState.CurrentHoistParams.Any(param => param.VariableName.Equals(local.Item))) + { + SharedState.ContainsHoistParamRef = true; + } + return base.OnIdentifier(sym, tArgs); + } + } + } + + /// + /// Transformation that updates the contents of newly generated operations by: + /// 1. Rerouting the origins of type parameter references to the new operation + /// 2. Changes the IsMutable info on variable that used to be mutable, but are now immutable params to the operation + /// + private class UpdateGeneratedOp : SyntaxTreeTransformation + { + public static QsCallable Apply(QsCallable qsCallable, ImmutableArray>> parameters, QsQualifiedName oldName, QsQualifiedName newName) + { + var filter = new UpdateGeneratedOp(parameters, oldName, newName); + + return filter.Namespaces.OnCallableDeclaration(qsCallable); + } + + public class TransformationState + { + public bool IsRecursiveIdentifier = false; + public readonly ImmutableArray>> Parameters; + public readonly QsQualifiedName OldName; + public readonly QsQualifiedName NewName; + + public TransformationState(ImmutableArray>> parameters, QsQualifiedName oldName, QsQualifiedName newName) + { + Parameters = parameters; + OldName = oldName; + NewName = newName; + } + } + + private UpdateGeneratedOp(ImmutableArray>> parameters, QsQualifiedName oldName, QsQualifiedName newName) + : base(new TransformationState(parameters, oldName, newName)) + { + this.Expressions = new ExpressionTransformation(this); + this.ExpressionKinds = new ExpressionKindTransformation(this); + this.Types = new TypeTransformation(this); + } + + private class ExpressionTransformation : ExpressionTransformation + { + public ExpressionTransformation(SyntaxTreeTransformation parent) : base(parent) { } + + public override ImmutableDictionary>, ResolvedType> OnTypeParamResolutions(ImmutableDictionary>, ResolvedType> typeParams) + { + // Prevent keys from having their names updated + return typeParams.ToImmutableDictionary(kvp => kvp.Key, kvp => this.Types.OnType(kvp.Value)); + } + + public override TypedExpression OnTypedExpression(TypedExpression ex) + { + // Checks if expression is mutable identifier that is in parameter list + if (ex.InferredInformation.IsMutable && + ex.Expression is ExpressionKind.Identifier id && + id.Item1 is Identifier.LocalVariable variable && + SharedState.Parameters.Any(x => x.VariableName.Equals(variable))) + { + // Set the mutability to false + ex = new TypedExpression( + ex.Expression, + ex.TypeArguments, + ex.ResolvedType, + new InferredExpressionInformation(false, ex.InferredInformation.HasLocalQuantumDependency), + ex.Range); + } + + // Prevent IsRecursiveIdentifier from propagating beyond the typed expression it is referring to + var isRecursiveIdentifier = SharedState.IsRecursiveIdentifier; + var rtrn = base.OnTypedExpression(ex); + SharedState.IsRecursiveIdentifier = isRecursiveIdentifier; + return rtrn; + } + } + + private class ExpressionKindTransformation : ExpressionKindTransformation + { + public ExpressionKindTransformation(SyntaxTreeTransformation parent) : base(parent) { } + + public override ExpressionKind OnIdentifier(Identifier sym, QsNullable> tArgs) + { + var rtrn = base.OnIdentifier(sym, tArgs); + + // Then check if this is a recursive identifier + // In this context, that is a call back to the original callable from the newly generated operation + if (sym is Identifier.GlobalCallable callable && SharedState.OldName.Equals(callable.Item)) + { + // Setting this flag will prevent the rerouting logic from processing the resolved type of the recursive identifier expression. + // This is necessary because we don't want any type parameters from the original callable from being rerouted to the new generated + // operation's type parameters in the definition of the identifier. + SharedState.IsRecursiveIdentifier = true; + } + return rtrn; + } + } + + private class TypeTransformation : TypeTransformation + { + public TypeTransformation(SyntaxTreeTransformation parent) : base(parent) { } + + public override ResolvedTypeKind OnTypeParameter(QsTypeParameter tp) + { + // Reroute a type parameter's origin to the newly generated operation + if (!SharedState.IsRecursiveIdentifier && SharedState.OldName.Equals(tp.Origin)) + { + tp = new QsTypeParameter(SharedState.NewName, tp.TypeName, tp.Range); + } + + return base.OnTypeParameter(tp); + } + } + } + } +} diff --git a/src/QsCompiler/Transformations/ClassicallyControlledTransformation.cs b/src/QsCompiler/Transformations/ClassicallyControlledTransformation.cs deleted file mode 100644 index 3068bdd70e..0000000000 --- a/src/QsCompiler/Transformations/ClassicallyControlledTransformation.cs +++ /dev/null @@ -1,1094 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -using System; -using System.Collections.Generic; -using System.Collections.Immutable; -using System.Linq; -using Microsoft.Quantum.QsCompiler.DataTypes; -using Microsoft.Quantum.QsCompiler.SyntaxTokens; -using Microsoft.Quantum.QsCompiler.SyntaxTree; - - -namespace Microsoft.Quantum.QsCompiler.Transformations.ClassicallyControlledTransformation -{ - using ExpressionKind = QsExpressionKind; - using ResolvedTypeKind = QsTypeKind; - using TypeArgsResolution = ImmutableArray, ResolvedType>>; - - // This transformation works in two passes. - // 1st Pass: Hoist the contents of conditional statements into separate operations, where possible. - // 2nd Pass: On the way down the tree, reshape conditional statements to replace Elif's and - // top level OR and AND conditions with equivalent nested if-else statements. One the way back up - // the tree, convert conditional statements into ApplyIf calls, where possible. - // This relies on anything having type parameters must be a global callable. - public class ClassicallyControlledTransformation - { - public static QsCompilation Apply(QsCompilation compilation) - { - compilation = HoistTransformation.Apply(compilation); - - var filter = new ClassicallyControlledSyntax(compilation); - return new QsCompilation(compilation.Namespaces.Select(ns => filter.Transform(ns)).ToImmutableArray(), compilation.EntryPoints); - } - - private static TypedExpression CreateIdentifierExpression(Identifier id, - TypeArgsResolution typeArgsMapping, ResolvedType resolvedType) => - new TypedExpression - ( - ExpressionKind.NewIdentifier( - id, - typeArgsMapping.Any() - ? QsNullable>.NewValue(typeArgsMapping - .Select(argMapping => argMapping.Item3) // This should preserve the order of the type args - .ToImmutableArray()) - : QsNullable>.Null), - typeArgsMapping, - resolvedType, - new InferredExpressionInformation(false, false), - QsNullable>.Null - ); - - private static TypedExpression CreateValueTupleExpression(params TypedExpression[] expressions) => - new TypedExpression - ( - ExpressionKind.NewValueTuple(expressions.ToImmutableArray()), - TypeArgsResolution.Empty, - ResolvedType.New(ResolvedTypeKind.NewTupleType(expressions.Select(expr => expr.ResolvedType).ToImmutableArray())), - new InferredExpressionInformation(false, false), - QsNullable>.Null - ); - - private static (bool, QsResult, TypedExpression) IsConditionedOnResultLiteralExpression(TypedExpression expression) - { - if (expression.Expression is ExpressionKind.EQ eq) - { - if (eq.Item1.Expression is ExpressionKind.ResultLiteral exp1) - { - return (true, exp1.Item, eq.Item2); - } - else if (eq.Item2.Expression is ExpressionKind.ResultLiteral exp2) - { - return (true, exp2.Item, eq.Item1); - } - } - - return (false, null, null); - } - - private static (bool, QsResult, TypedExpression, QsScope, QsScope) IsConditionedOnResultLiteralStatement(QsStatement statement) - { - if (statement.Statement is QsStatementKind.QsConditionalStatement cond) - { - if (cond.Item.ConditionalBlocks.Length == 1 && (cond.Item.ConditionalBlocks[0].Item1.Expression is ExpressionKind.EQ expression)) - { - var scope = cond.Item.ConditionalBlocks[0].Item2.Body; - var defaultScope = cond.Item.Default.ValueOr(null)?.Body; - - var (success, literal, expr) = IsConditionedOnResultLiteralExpression(cond.Item.ConditionalBlocks[0].Item1); - - if (success) - { - return (true, literal, expr, scope, defaultScope); - } - } - } - - return (false, null, null, null, null); - } - - private ClassicallyControlledTransformation() { } - - private class ClassicallyControlledSyntax : SyntaxTreeTransformation - { - public ClassicallyControlledSyntax(QsCompilation compilation, ClassicallyControlledScope scope = null) : base(scope ?? new ClassicallyControlledScope(compilation)) { } - - public override QsCallable onFunction(QsCallable c) => c; // Prevent anything in functions from being considered - } - - private class ClassicallyControlledScope : ScopeTransformation - { - private QsCompilation _Compilation; - - public ClassicallyControlledScope(QsCompilation compilation, NoExpressionTransformations expr = null) : base(expr ?? new NoExpressionTransformations()) - { - _Compilation = compilation; - } - - private TypeArgsResolution GetCombinedType(TypeArgsResolution outer, TypeArgsResolution inner) - { - var outerDict = outer.ToDictionary(x => (x.Item1, x.Item2), x => x.Item3); - return inner.Select(innerRes => - { - if (innerRes.Item3.Resolution is ResolvedTypeKind.TypeParameter typeParam && - outerDict.TryGetValue((typeParam.Item.Origin, typeParam.Item.TypeName), out var outerRes)) - { - outerDict.Remove((typeParam.Item.Origin, typeParam.Item.TypeName)); - return Tuple.Create(innerRes.Item1, innerRes.Item2, outerRes); - } - else - { - return innerRes; - } - }).Concat(outerDict.Select(x => Tuple.Create(x.Key.Item1, x.Key.Item2, x.Value))).ToImmutableArray(); - } - - private (bool, TypedExpression, TypedExpression) IsValidScope(QsScope scope) - { - // if the scope has exactly one statement in it and that statement is a call like expression statement - if (scope != null - && scope.Statements.Length == 1 - && scope.Statements[0].Statement is QsStatementKind.QsExpressionStatement expr - && expr.Item.ResolvedType.Resolution.IsUnitType - && expr.Item.Expression is ExpressionKind.CallLikeExpression call - && !TypedExpression.IsPartialApplication(expr.Item.Expression) - && call.Item1.Expression is ExpressionKind.Identifier) - { - // We are dissolving the application of arguments here, so the call's type argument - // resolutions have to be moved to the 'identifier' sub expression. - - var callTypeArguments = expr.Item.TypeArguments; - var idTypeArguments = call.Item1.TypeArguments; - var combinedTypeArguments = GetCombinedType(callTypeArguments, idTypeArguments); - - // This relies on anything having type parameters must be a global callable. - var newExpr1 = call.Item1; - if (combinedTypeArguments.Any() - && newExpr1.Expression is ExpressionKind.Identifier id - && id.Item1 is Identifier.GlobalCallable global) - { - var globalCallable = _Compilation.Namespaces - .Where(ns => ns.Name.Equals(global.Item.Namespace)) - .Callables() - .FirstOrDefault(c => c.FullName.Name.Equals(global.Item.Name)); - - QsCompilerError.Verify(globalCallable != null, $"Could not find the global reference {global.Item.Namespace.Value + "." + global.Item.Name.Value}"); - - var callableTypeParameters = globalCallable.Signature.TypeParameters - .Select(x => x as QsLocalSymbol.ValidName); - - QsCompilerError.Verify(callableTypeParameters.All(x => x != null), $"Invalid type parameter names."); - - newExpr1 = new TypedExpression( - ExpressionKind.NewIdentifier( - id.Item1, - QsNullable>.NewValue( - callableTypeParameters - .Select(x => combinedTypeArguments.First(y => y.Item2.Equals(x.Item)).Item3).ToImmutableArray())), - combinedTypeArguments, - call.Item1.ResolvedType, - call.Item1.InferredInformation, - call.Item1.Range); - } - - return (true, newExpr1, call.Item2); - } - - return (false, null, null); - } - - private TypedExpression CreateApplyIfCall(TypedExpression id, TypedExpression args, TypeArgsResolution typeRes) => - new TypedExpression - ( - ExpressionKind.NewCallLikeExpression(id, args), - typeRes, - ResolvedType.New(ResolvedTypeKind.UnitType), - new InferredExpressionInformation(false, true), - QsNullable>.Null - ); - - private QsStatement CreateApplyIfStatement(QsStatement statement, QsResult result, TypedExpression conditionExpression, QsScope conditionScope, QsScope defaultScope) - { - var controlCall = GetApplyIfExpression(result, conditionExpression, conditionScope, defaultScope); - - if (controlCall != null) - { - return new QsStatement( - QsStatementKind.NewQsExpressionStatement(controlCall), - statement.SymbolDeclarations, - QsNullable.Null, - statement.Comments); - } - else - { - // ToDo: add diagnostic message here - return statement; // If the blocks can't be converted, return the original - } - } - - private static ResolvedType GetApplyIfResolvedType(IEnumerable props, ResolvedType argumentType) - { - var characteristics = new CallableInformation( - ResolvedCharacteristics.FromProperties(props), - InferredCallableInformation.NoInformation); - - return ResolvedType.New(ResolvedTypeKind.NewOperation( - Tuple.Create(argumentType, ResolvedType.New(ResolvedTypeKind.UnitType)), - characteristics)); - } - - private TypedExpression GetApplyIfExpression(QsResult result, TypedExpression conditionExpression, QsScope conditionScope, QsScope defaultScope) - { - var (isCondValid, condId, condArgs) = IsValidScope(conditionScope); - var (isDefaultValid, defaultId, defaultArgs) = IsValidScope(defaultScope); - - BuiltIn controlOpInfo; - TypedExpression controlArgs; - ImmutableArray targetArgs; - - var props = ImmutableHashSet.Empty; - - if (isCondValid) - { - // Get characteristic properties from global id - if (condId.ResolvedType.Resolution is ResolvedTypeKind.Operation op) - { - props = op.Item2.Characteristics.GetProperties(); - } - - (bool adj, bool ctl) = (props.Contains(OpProperty.Adjointable), props.Contains(OpProperty.Controllable)); - - if (isDefaultValid) - { - if (adj && ctl) - { - controlOpInfo = BuiltIn.ApplyIfElseRCA; - } - else if (adj) - { - controlOpInfo = BuiltIn.ApplyIfElseRA; - } - else if (ctl) - { - controlOpInfo = BuiltIn.ApplyIfElseRC; - } - else - { - controlOpInfo = BuiltIn.ApplyIfElseR; - } - - var (zeroOpArg, oneOpArg) = (result == QsResult.Zero) - ? (CreateValueTupleExpression(condId, condArgs), CreateValueTupleExpression(defaultId, defaultArgs)) - : (CreateValueTupleExpression(defaultId, defaultArgs), CreateValueTupleExpression(condId, condArgs)); - - controlArgs = CreateValueTupleExpression(conditionExpression, zeroOpArg, oneOpArg); - - targetArgs = ImmutableArray.Create(condArgs.ResolvedType, defaultArgs.ResolvedType); - } - else if (defaultScope == null) - { - if (adj && ctl) - { - controlOpInfo = (result == QsResult.Zero) - ? BuiltIn.ApplyIfZeroCA - : BuiltIn.ApplyIfOneCA; - } - else if (adj) - { - controlOpInfo = (result == QsResult.Zero) - ? BuiltIn.ApplyIfZeroA - : BuiltIn.ApplyIfOneA; - } - else if (ctl) - { - controlOpInfo = (result == QsResult.Zero) - ? BuiltIn.ApplyIfZeroC - : BuiltIn.ApplyIfOneC; - } - else - { - controlOpInfo = (result == QsResult.Zero) - ? BuiltIn.ApplyIfZero - : BuiltIn.ApplyIfOne; - } - - controlArgs = CreateValueTupleExpression( - conditionExpression, - CreateValueTupleExpression(condId, condArgs)); - - targetArgs = ImmutableArray.Create(condArgs.ResolvedType); - } - else - { - return null; // ToDo: Diagnostic message - default body exists, but is not valid - } - - } - else - { - return null; // ToDo: Diagnostic message - cond body not valid - } - - // Build the surrounding apply-if call - var controlOpId = CreateIdentifierExpression( - Identifier.NewGlobalCallable(new QsQualifiedName(controlOpInfo.Namespace, controlOpInfo.Name)), - targetArgs - .Zip(controlOpInfo.TypeParameters, (type, param) => Tuple.Create(new QsQualifiedName(controlOpInfo.Namespace, controlOpInfo.Name), param, type)) - .ToImmutableArray(), - GetApplyIfResolvedType(props, controlArgs.ResolvedType)); - - // Creates identity resolutions for the call expression - var opTypeArgResolutions = targetArgs - .SelectMany(x => - x.Resolution is ResolvedTypeKind.TupleType tup - ? tup.Item - : ImmutableArray.Create(x)) - .Where(x => x.Resolution.IsTypeParameter) - .Select(x => (x.Resolution as ResolvedTypeKind.TypeParameter).Item) - .GroupBy(x => (x.Origin, x.TypeName)) - .Select(group => - { - var typeParam = group.First(); - return Tuple.Create(typeParam.Origin, typeParam.TypeName, ResolvedType.New(ResolvedTypeKind.NewTypeParameter(typeParam))); - }) - .ToImmutableArray(); - - return CreateApplyIfCall(controlOpId, controlArgs, opTypeArgResolutions); - } - - private (bool, QsConditionalStatement) ProcessElif(QsConditionalStatement cond) - { - if (cond.ConditionalBlocks.Length < 2) return (false, cond); - - var subCond = new QsConditionalStatement(cond.ConditionalBlocks.RemoveAt(0), cond.Default); - var secondCondBlock = cond.ConditionalBlocks[1].Item2; - - var subIfStatment = new QsStatement - ( - QsStatementKind.NewQsConditionalStatement(subCond), - LocalDeclarations.Empty, - secondCondBlock.Location, - secondCondBlock.Comments - ); - - var newDefault = QsNullable.NewValue(new QsPositionedBlock( - new QsScope(ImmutableArray.Create(subIfStatment), secondCondBlock.Body.KnownSymbols), - secondCondBlock.Location, - QsComments.Empty)); - - return (true, new QsConditionalStatement(ImmutableArray.Create(cond.ConditionalBlocks[0]), newDefault)); - } - - private (bool, QsConditionalStatement) ProcessOR(QsConditionalStatement cond) - { - // This method expects elif blocks to have been abstracted out - if (cond.ConditionalBlocks.Length != 1) return (false, cond); - - var (condition, block) = cond.ConditionalBlocks[0]; - - if (condition.Expression is ExpressionKind.OR orCond) - { - var subCond = new QsConditionalStatement(ImmutableArray.Create(Tuple.Create(orCond.Item2, block)), cond.Default); - var subIfStatment = new QsStatement - ( - QsStatementKind.NewQsConditionalStatement(subCond), - LocalDeclarations.Empty, - block.Location, - QsComments.Empty - ); - var newDefault = QsNullable.NewValue(new QsPositionedBlock( - new QsScope(ImmutableArray.Create(subIfStatment), block.Body.KnownSymbols), - block.Location, - QsComments.Empty)); - - return (true, new QsConditionalStatement(ImmutableArray.Create(Tuple.Create(orCond.Item1, block)), newDefault)); - } - else - { - return (false, cond); - } - } - - private (bool, QsConditionalStatement) ProcessAND(QsConditionalStatement cond) - { - // This method expects elif blocks to have been abstracted out - if (cond.ConditionalBlocks.Length != 1) return (false, cond); - - var (condition, block) = cond.ConditionalBlocks[0]; - - if (condition.Expression is ExpressionKind.AND andCond) - { - var subCond = new QsConditionalStatement(ImmutableArray.Create(Tuple.Create(andCond.Item2, block)), cond.Default); - var subIfStatment = new QsStatement - ( - QsStatementKind.NewQsConditionalStatement(subCond), - LocalDeclarations.Empty, - block.Location, - QsComments.Empty - ); - var newBlock = new QsPositionedBlock( - new QsScope(ImmutableArray.Create(subIfStatment), block.Body.KnownSymbols), - block.Location, - QsComments.Empty); - - return (true, new QsConditionalStatement(ImmutableArray.Create(Tuple.Create(andCond.Item1, newBlock)), cond.Default)); - } - else - { - return (false, cond); - } - } - - private QsStatement ReshapeConditional(QsStatement statement) - { - if (statement.Statement is QsStatementKind.QsConditionalStatement cond) - { - var stm = cond.Item; - (_, stm) = ProcessElif(stm); - bool wasOrProcessed, wasAndProcessed; - do - { - (wasOrProcessed, stm) = ProcessOR(stm); - (wasAndProcessed, stm) = ProcessAND(stm); - } while (wasOrProcessed || wasAndProcessed); - - return new QsStatement - ( - QsStatementKind.NewQsConditionalStatement(stm), - statement.SymbolDeclarations, - statement.Location, - statement.Comments - ); - } - return statement; - } - - public override QsScope Transform(QsScope scope) - { - var parentSymbols = this.onLocalDeclarations(scope.KnownSymbols); - var statements = new List(); - - foreach (var statement in scope.Statements) - { - if (statement.Statement is QsStatementKind.QsConditionalStatement) - { - var stm = ReshapeConditional(statement); - stm = this.onStatement(stm); - - var (isCondition, result, conditionExpression, conditionScope, defaultScope) = IsConditionedOnResultLiteralStatement(stm); - - if (isCondition) - { - statements.Add(CreateApplyIfStatement(stm, result, conditionExpression, conditionScope, defaultScope)); - } - else - { - statements.Add(stm); - } - } - else - { - statements.Add(this.onStatement(statement)); - } - } - - return new QsScope(statements.ToImmutableArray(), parentSymbols); - } - } - - // Transformation that updates the contents of newly generated operations by: - // 1. Rerouting the origins of type parameter references to the new operation - // 2. Changes the IsMutable info on variable that used to be mutable, but are now immutable params to the operation - private class UpdateGeneratedOpTransformation - { - private bool _IsRecursiveIdentifier = false; - private ImmutableArray>> _Parameters; - private QsQualifiedName _OldName; - private QsQualifiedName _NewName; - - public static QsCallable Apply(QsCallable qsCallable, ImmutableArray>> parameters, QsQualifiedName oldName, QsQualifiedName newName) - { - var filter = new SyntaxTreeTransformation>( - new ScopeTransformation( - new UpdateGeneratedOpExpression( - new UpdateGeneratedOpTransformation(parameters, oldName, newName)))); - - return filter.onCallableImplementation(qsCallable); - } - - private UpdateGeneratedOpTransformation(ImmutableArray>> parameters, QsQualifiedName oldName, QsQualifiedName newName) - { - _Parameters = parameters; - _OldName = oldName; - _NewName = newName; - } - - private class UpdateGeneratedOpExpression : ExpressionTransformation - { - private UpdateGeneratedOpTransformation _super; - - public UpdateGeneratedOpExpression(UpdateGeneratedOpTransformation super) : - base(expr => new UpdateGeneratedOpExpressionKind(super, expr as UpdateGeneratedOpExpression), - expr => new UpdateGeneratedOpExpressionType(super, expr as UpdateGeneratedOpExpression)) - { _super = super; } - - public override ImmutableDictionary>, ResolvedType> onTypeParamResolutions(ImmutableDictionary>, ResolvedType> typeParams) - { - // Prevent keys from having their names updated - return typeParams.ToImmutableDictionary(kvp => kvp.Key, kvp => this.Type.Transform(kvp.Value)); - } - - public override TypedExpression Transform(TypedExpression ex) - { - // Checks if expression is mutable identifier that is in parameter list - if (ex.InferredInformation.IsMutable && - ex.Expression is ExpressionKind.Identifier id && - id.Item1 is Identifier.LocalVariable variable && - _super._Parameters.Any(x => x.VariableName.Equals(variable))) - { - // Set the mutability to false - ex = new TypedExpression( - ex.Expression, - ex.TypeArguments, - ex.ResolvedType, - new InferredExpressionInformation(false, ex.InferredInformation.HasLocalQuantumDependency), - ex.Range); - } - - // Prevent _IsRecursiveIdentifier from propagating beyond the typed expression it is referring to - var isRecursiveIdentifier = _super._IsRecursiveIdentifier; - var rtrn = base.Transform(ex); - _super._IsRecursiveIdentifier = isRecursiveIdentifier; - return rtrn; - } - } - - private class UpdateGeneratedOpExpressionKind : ExpressionKindTransformation - { - private UpdateGeneratedOpTransformation _super; - - public UpdateGeneratedOpExpressionKind(UpdateGeneratedOpTransformation super, UpdateGeneratedOpExpression expr) : base(expr) { _super = super; } - - public override ExpressionKind onIdentifier(Identifier sym, QsNullable> tArgs) - { - var rtrn = base.onIdentifier(sym, tArgs); - - // Then check if this is a recursive identifier - // In this context, that is a call back to the original callable from the newly generated operation - if (sym is Identifier.GlobalCallable callable && _super._OldName.Equals(callable.Item)) - { - // Setting this flag will prevent the rerouting logic from processing the resolved type of the recursive identifier expression. - // This is necessary because we don't want any type parameters from the original callable from being rerouted to the new generated - // operation's type parameters in the definition of the identifier. - _super._IsRecursiveIdentifier = true; - } - return rtrn; - } - } - - private class UpdateGeneratedOpExpressionType : ExpressionTypeTransformation - { - private UpdateGeneratedOpTransformation _super; - - public UpdateGeneratedOpExpressionType(UpdateGeneratedOpTransformation super, UpdateGeneratedOpExpression expr) : base(expr) { _super = super; } - - public override ResolvedTypeKind onTypeParameter(QsTypeParameter tp) - { - // Reroute a type parameter's origin to the newly generated operation - if (!_super._IsRecursiveIdentifier && _super._OldName.Equals(tp.Origin)) - { - tp = new QsTypeParameter(_super._NewName, tp.TypeName, tp.Range); - } - - return base.onTypeParameter(tp); - } - } - } - - // Transformation handling the first pass task of hoisting of the contents of conditional statements. - // If blocks are first validated to see if they can safely be hoisted into their own operation. - // Validation requirements are that there are no return statements and that there are no set statements - // on mutables declared outside the block. Setting mutables declared inside the block is valid. - // If the block is valid, and there is more than one statement in the block, a new operation with the - // block's contents is generated, having all the same type parameters as the calling context - // and all known variables at the start of the block become parameters to the new operation. - // The contents of the conditional block are then replaced with a call to the new operation with all - // the type parameters and known variables being forwarded to the new operation as arguments. - private class HoistTransformation - { - private bool _IsValidScope = true; - private List _ControlOperations; - private ImmutableArray>> _CurrentHoistParams = - ImmutableArray>>.Empty; - private bool _ContainsHoistParamRef = false; - - private class CallableDetails - { - public QsCallable Callable; - public QsSpecialization Adjoint; - public QsSpecialization Controlled; - public QsSpecialization ControlledAdjoint; - public QsNullable> TypeParamTypes; - - public CallableDetails(QsCallable callable) - { - Callable = callable; - Adjoint = callable.Specializations.FirstOrDefault(spec => spec.Kind == QsSpecializationKind.QsAdjoint); - Controlled = callable.Specializations.FirstOrDefault(spec => spec.Kind == QsSpecializationKind.QsControlled); - ControlledAdjoint = callable.Specializations.FirstOrDefault(spec => spec.Kind == QsSpecializationKind.QsControlledAdjoint); - TypeParamTypes = callable.Signature.TypeParameters.Any(param => param.IsValidName) - ? QsNullable>.NewValue(callable.Signature.TypeParameters - .Where(param => param.IsValidName) - .Select(param => - ResolvedType.New(ResolvedTypeKind.NewTypeParameter(new QsTypeParameter( - callable.FullName, - ((QsLocalSymbol.ValidName)param).Item, - QsNullable>.Null - )))) - .ToImmutableArray()) - : QsNullable>.Null; - } - } - - private CallableDetails _CurrentCallable = null; - private bool _InBody = false; - private bool _InAdjoint = false; - private bool _InControlled = false; - - private bool _InWithinBlock = false; - - public static QsCompilation Apply(QsCompilation compilation) - { - var filter = new HoistSyntax(new HoistTransformation()); - - return new QsCompilation(compilation.Namespaces.Select(ns => filter.Transform(ns)).ToImmutableArray(), compilation.EntryPoints); - } - - private (ResolvedSignature, IEnumerable) MakeSpecializations(QsQualifiedName callableName, ResolvedType argsType, SpecializationImplementation bodyImplementation) - { - QsSpecialization MakeSpec(QsSpecializationKind kind, ResolvedSignature signature, SpecializationImplementation impl) => - new QsSpecialization( - kind, - callableName, - ImmutableArray.Empty, - _CurrentCallable.Callable.SourceFile, - QsNullable.Null, - QsNullable>.Null, - signature, - impl, - ImmutableArray.Empty, - QsComments.Empty); - - var adj = _CurrentCallable.Adjoint; - var ctl = _CurrentCallable.Controlled; - var ctlAdj = _CurrentCallable.ControlledAdjoint; - - bool addAdjoint = false; - bool addControlled = false; - - if (_InWithinBlock) - { - addAdjoint = true; - addControlled = false; - } - else if (_InBody) - { - if (adj != null && adj.Implementation is SpecializationImplementation.Generated adjGen) addAdjoint = adjGen.Item.IsInvert; - if (ctl != null && ctl.Implementation is SpecializationImplementation.Generated ctlGen) addControlled = ctlGen.Item.IsDistribute; - if (ctlAdj != null && ctlAdj.Implementation is SpecializationImplementation.Generated ctlAdjGen) - { - addAdjoint = addAdjoint || ctlAdjGen.Item.IsInvert && ctl.Implementation.IsGenerated; - addControlled = addControlled || ctlAdjGen.Item.IsDistribute && adj.Implementation.IsGenerated; - } - } - else if (ctlAdj != null && ctlAdj.Implementation is SpecializationImplementation.Generated gen) - { - addControlled = _InAdjoint && gen.Item.IsDistribute; - addAdjoint = _InControlled && gen.Item.IsInvert; - } - - var props = new List(); - if (addAdjoint) props.Add(OpProperty.Adjointable); - if (addControlled) props.Add(OpProperty.Controllable); - var newSig = new ResolvedSignature( - _CurrentCallable.Callable.Signature.TypeParameters, - argsType, - ResolvedType.New(ResolvedTypeKind.UnitType), - new CallableInformation(ResolvedCharacteristics.FromProperties(props), InferredCallableInformation.NoInformation)); - - var controlledSig = new ResolvedSignature( - newSig.TypeParameters, - ResolvedType.New(ResolvedTypeKind.NewTupleType(ImmutableArray.Create( - ResolvedType.New(ResolvedTypeKind.NewArrayType(ResolvedType.New(ResolvedTypeKind.Qubit))), - newSig.ArgumentType))), - newSig.ReturnType, - newSig.Information); - - var specializations = new List() { MakeSpec(QsSpecializationKind.QsBody, newSig, bodyImplementation) }; - - if (addAdjoint) - { - specializations.Add(MakeSpec( - QsSpecializationKind.QsAdjoint, - newSig, - SpecializationImplementation.NewGenerated(QsGeneratorDirective.Invert))); - } - - if (addControlled) - { - specializations.Add(MakeSpec( - QsSpecializationKind.QsControlled, - controlledSig, - SpecializationImplementation.NewGenerated(QsGeneratorDirective.Distribute))); - } - - if (addAdjoint && addControlled) - { - specializations.Add(MakeSpec( - QsSpecializationKind.QsControlledAdjoint, - controlledSig, - SpecializationImplementation.NewGenerated(QsGeneratorDirective.Distribute))); - } - - return (newSig, specializations); - } - - private (QsCallable, ResolvedType) GenerateOperation(QsScope contents) - { - var newName = new QsQualifiedName( - _CurrentCallable.Callable.FullName.Namespace, - NonNullable.New("_" + Guid.NewGuid().ToString("N") + "_" + _CurrentCallable.Callable.FullName.Name.Value)); - - var knownVariables = contents.KnownSymbols.IsEmpty - ? ImmutableArray>>.Empty - : contents.KnownSymbols.Variables; - - var parameters = QsTuple>.NewQsTuple(knownVariables - .Select(var => QsTuple>.NewQsTupleItem(new LocalVariableDeclaration( - QsLocalSymbol.NewValidName(var.VariableName), - var.Type, - var.InferredInformation, - var.Position, - var.Range))) - .ToImmutableArray()); - - var paramTypes = ResolvedType.New(ResolvedTypeKind.UnitType); - if (knownVariables.Length == 1) - { - paramTypes = knownVariables.First().Type; - } - else if (knownVariables.Length > 1) - { - paramTypes = ResolvedType.New(ResolvedTypeKind.NewTupleType(knownVariables - .Select(var => var.Type) - .ToImmutableArray())); - } - - var (signature, specializations) = MakeSpecializations(newName, paramTypes, SpecializationImplementation.NewProvided(parameters, contents)); - - var controlCallable = new QsCallable( - QsCallableKind.Operation, - newName, - ImmutableArray.Empty, - _CurrentCallable.Callable.SourceFile, - QsNullable.Null, - signature, - parameters, - specializations.ToImmutableArray(), - ImmutableArray.Empty, - QsComments.Empty); - - var updatedCallable = UpdateGeneratedOpTransformation.Apply(controlCallable, knownVariables, _CurrentCallable.Callable.FullName, newName); - - return (updatedCallable, signature.ArgumentType); - } - - private HoistTransformation() { } - - private class HoistSyntax : SyntaxTreeTransformation> - { - private HoistTransformation _super; - - public HoistSyntax(HoistTransformation super, ScopeTransformation scope = null) : - base(scope ?? new ScopeTransformation( - scopeTransform => new HoistStatementKind(super, scopeTransform), - new HoistExpression(super))) - { _super = super; } - - public override QsCallable onCallableImplementation(QsCallable c) - { - _super._CurrentCallable = new CallableDetails(c); - return base.onCallableImplementation(c); - } - - public override QsSpecialization onBodySpecialization(QsSpecialization spec) - { - _super._InBody = true; - var rtrn = base.onBodySpecialization(spec); - _super._InBody = false; - return rtrn; - } - - public override QsSpecialization onAdjointSpecialization(QsSpecialization spec) - { - _super._InAdjoint = true; - var rtrn = base.onAdjointSpecialization(spec); - _super._InAdjoint = false; - return rtrn; - } - - public override QsSpecialization onControlledSpecialization(QsSpecialization spec) - { - _super._InControlled = true; - var rtrn = base.onControlledSpecialization(spec); - _super._InControlled = false; - return rtrn; - } - - public override QsCallable onFunction(QsCallable c) => c; // Prevent anything in functions from being hoisted - - public override QsNamespace Transform(QsNamespace ns) - { - // Control operations list will be populated in the transform - _super._ControlOperations = new List(); - return base.Transform(ns) - .WithElements(elems => elems.AddRange(_super._ControlOperations.Select(op => QsNamespaceElement.NewQsCallable(op)))); - } - } - - private class HoistStatementKind : StatementKindTransformation> - { - private HoistTransformation _super; - - public HoistStatementKind(HoistTransformation super, ScopeTransformation scope) : base(scope) { _super = super; } - - private (QsCallable, QsStatement) HoistIfContents(QsScope contents) - { - var (targetOp, originalArgumentType) = _super.GenerateOperation(contents); - var targetOpType = ResolvedType.New(ResolvedTypeKind.NewOperation( - Tuple.Create( - originalArgumentType, - ResolvedType.New(ResolvedTypeKind.UnitType)), - targetOp.Signature.Information)); - - var targetTypeArgTypes = _super._CurrentCallable.TypeParamTypes; - var targetOpId = new TypedExpression - ( - ExpressionKind.NewIdentifier(Identifier.NewGlobalCallable(targetOp.FullName), targetTypeArgTypes), - targetTypeArgTypes.IsNull - ? TypeArgsResolution.Empty - : targetTypeArgTypes.Item - .Select(type => Tuple.Create(targetOp.FullName, ((ResolvedTypeKind.TypeParameter)type.Resolution).Item.TypeName, type)) - .ToImmutableArray(), - targetOpType, - new InferredExpressionInformation(false, false), - QsNullable>.Null - ); - - var knownSymbols = contents.KnownSymbols.Variables; - - TypedExpression targetArgs = null; - if (knownSymbols.Any()) - { - targetArgs = CreateValueTupleExpression(knownSymbols.Select(var => CreateIdentifierExpression( - Identifier.NewLocalVariable(var.VariableName), - TypeArgsResolution.Empty, - var.Type)) - .ToArray()); - } - else - { - targetArgs = new TypedExpression - ( - ExpressionKind.UnitValue, - TypeArgsResolution.Empty, - ResolvedType.New(ResolvedTypeKind.UnitType), - new InferredExpressionInformation(false, false), - QsNullable>.Null - ); - } - - var call = new TypedExpression - ( - ExpressionKind.NewCallLikeExpression(targetOpId, targetArgs), - targetTypeArgTypes.IsNull - ? TypeArgsResolution.Empty - : targetTypeArgTypes.Item - .Select(type => Tuple.Create(_super._CurrentCallable.Callable.FullName, ((ResolvedTypeKind.TypeParameter)type.Resolution).Item.TypeName, type)) - .ToImmutableArray(), - ResolvedType.New(ResolvedTypeKind.UnitType), - new InferredExpressionInformation(false, true), - QsNullable>.Null - ); - - return (targetOp, new QsStatement( - QsStatementKind.NewQsExpressionStatement(call), - LocalDeclarations.Empty, - QsNullable.Null, - QsComments.Empty)); - } - - private bool IsScopeSingleCall(QsScope contents) - { - if (contents.Statements.Length != 1) return false; - - return contents.Statements[0].Statement is QsStatementKind.QsExpressionStatement expr - && expr.Item.Expression is ExpressionKind.CallLikeExpression call - && !TypedExpression.IsPartialApplication(expr.Item.Expression) - && call.Item1.Expression is ExpressionKind.Identifier; - } - - public override QsStatementKind onConjugation(QsConjugation stm) - { - var superInWithinBlock = _super._InWithinBlock; - _super._InWithinBlock = true; - var (_, outer) = this.onPositionedBlock(QsNullable.Null, stm.OuterTransformation); - _super._InWithinBlock = superInWithinBlock; - - var (_, inner) = this.onPositionedBlock(QsNullable.Null, stm.InnerTransformation); - - return QsStatementKind.NewQsConjugation(new QsConjugation(outer, inner)); - } - - public override QsStatementKind onReturnStatement(TypedExpression ex) - { - _super._IsValidScope = false; - return base.onReturnStatement(ex); - } - - public override QsStatementKind onValueUpdate(QsValueUpdate stm) - { - // If lhs contains an identifier found in the scope's known variables (variables from the super-scope), the scope is not valid - var lhs = this.ExpressionTransformation(stm.Lhs); - - if (_super._ContainsHoistParamRef) - { - _super._IsValidScope = false; - } - - var rhs = this.ExpressionTransformation(stm.Rhs); - return QsStatementKind.NewQsValueUpdate(new QsValueUpdate(lhs, rhs)); - } - - public override QsStatementKind onConditionalStatement(QsConditionalStatement stm) - { - var contextValidScope = _super._IsValidScope; - var contextHoistParams = _super._CurrentHoistParams; - - var isHoistValid = true; - - var newConditionBlocks = new List>(); - var generatedOperations = new List(); - foreach (var condBlock in stm.ConditionalBlocks) - { - _super._IsValidScope = true; - _super._CurrentHoistParams = condBlock.Item2.Body.KnownSymbols.IsEmpty - ? ImmutableArray>>.Empty - : condBlock.Item2.Body.KnownSymbols.Variables; - - var (expr, block) = this.onPositionedBlock(QsNullable.NewValue(condBlock.Item1), condBlock.Item2); - - // ToDo: Reduce the number of unnecessary generated operations by generalizing - // the condition logic for the conversion and using that condition here - //var (isExprCond, _, _) = IsConditionedOnResultLiteralExpression(expr.Item); - - if (block.Body.Statements.Length > 0 /*&& isExprCond*/ && _super._IsValidScope && !IsScopeSingleCall(block.Body)) // if sub-scope is valid, hoist content - { - // Hoist the scope to its own operation - var (callable, call) = HoistIfContents(block.Body); - block = new QsPositionedBlock( - new QsScope(ImmutableArray.Create(call), block.Body.KnownSymbols), - block.Location, - block.Comments); - newConditionBlocks.Add(Tuple.Create(expr.Item,block)); - generatedOperations.Add(callable); - } - else - { - isHoistValid = false; - break; - } - } - - var newDefault = QsNullable.Null; - if (isHoistValid && stm.Default.IsValue) - { - _super._IsValidScope = true; - _super._CurrentHoistParams = stm.Default.Item.Body.KnownSymbols.IsEmpty - ? ImmutableArray>>.Empty - : stm.Default.Item.Body.KnownSymbols.Variables; - - var (_, block) = this.onPositionedBlock(QsNullable.Null, stm.Default.Item); - if (block.Body.Statements.Length > 0 && _super._IsValidScope && !IsScopeSingleCall(block.Body)) // if sub-scope is valid, hoist content - { - // Hoist the scope to its own operation - var (callable, call) = HoistIfContents(block.Body); - block = new QsPositionedBlock( - new QsScope(ImmutableArray.Create(call), block.Body.KnownSymbols), - block.Location, - block.Comments); - newDefault = QsNullable.NewValue(block); - generatedOperations.Add(callable); - } - else - { - isHoistValid = false; - } - } - - if (isHoistValid) - { - _super._ControlOperations.AddRange(generatedOperations); - } - - _super._CurrentHoistParams = contextHoistParams; - _super._IsValidScope = contextValidScope; - - return isHoistValid - ? QsStatementKind.NewQsConditionalStatement( - new QsConditionalStatement(newConditionBlocks.ToImmutableArray(), newDefault)) - : QsStatementKind.NewQsConditionalStatement( - new QsConditionalStatement(stm.ConditionalBlocks, stm.Default)); - } - - public override QsStatementKind Transform(QsStatementKind kind) - { - _super._ContainsHoistParamRef = false; // Every statement kind starts off false - return base.Transform(kind); - } - } - - private class HoistExpression : ExpressionTransformation - { - private HoistTransformation _super; - - public HoistExpression(HoistTransformation super) : - base(expr => new HoistExpressionKind(super, expr as HoistExpression)) - { _super = super; } - - public override TypedExpression Transform(TypedExpression ex) - { - var contextContainsHoistParamRef = _super._ContainsHoistParamRef; - _super._ContainsHoistParamRef = false; - var rtrn = base.Transform(ex); - - // If the sub context contains a reference, then the super context contains a reference, - // otherwise return the super context to its original value - if (!_super._ContainsHoistParamRef) - { - _super._ContainsHoistParamRef = contextContainsHoistParamRef; - } - - return rtrn; - } - } - - private class HoistExpressionKind : ExpressionKindTransformation - { - private HoistTransformation _super; - - public HoistExpressionKind(HoistTransformation super, HoistExpression expr) : base(expr) { _super = super; } - - public override ExpressionKind onIdentifier(Identifier sym, QsNullable> tArgs) - { - if (sym is Identifier.LocalVariable local && - _super._CurrentHoistParams.Any(param => param.VariableName.Equals(local.Item))) - { - _super._ContainsHoistParamRef = true; - } - return base.onIdentifier(sym, tArgs); - } - } - } - } -} diff --git a/src/QsCompiler/Transformations/ClassicallyControlledUtils.cs b/src/QsCompiler/Transformations/ClassicallyControlledUtils.cs new file mode 100644 index 0000000000..bc030db7a7 --- /dev/null +++ b/src/QsCompiler/Transformations/ClassicallyControlledUtils.cs @@ -0,0 +1,92 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using System.Collections.Immutable; +using System.Linq; +using Microsoft.Quantum.QsCompiler.DataTypes; +using Microsoft.Quantum.QsCompiler.SyntaxTokens; +using Microsoft.Quantum.QsCompiler.SyntaxTree; + + +namespace Microsoft.Quantum.QsCompiler.Transformations.ClassicallyControlled +{ + using ExpressionKind = QsExpressionKind; + using ResolvedTypeKind = QsTypeKind; + using TypeArgsResolution = ImmutableArray, ResolvedType>>; + + /// + /// These tools are specific to the classically-controlled transformation and are not intended for wider use in their current state. + /// They rely on the specific context in which they are invoked during that transformation and are not general purpuse tools. + /// + internal static class Utils + { + internal static TypedExpression CreateIdentifierExpression(Identifier id, + TypeArgsResolution typeArgsMapping, ResolvedType resolvedType) => + new TypedExpression + ( + ExpressionKind.NewIdentifier( + id, + typeArgsMapping.Any() + ? QsNullable>.NewValue(typeArgsMapping + .Select(argMapping => argMapping.Item3) // This should preserve the order of the type args + .ToImmutableArray()) + : QsNullable>.Null), + typeArgsMapping, + resolvedType, + new InferredExpressionInformation(false, false), + QsNullable>.Null + ); + + internal static TypedExpression CreateValueTupleExpression(params TypedExpression[] expressions) => + new TypedExpression + ( + ExpressionKind.NewValueTuple(expressions.ToImmutableArray()), + TypeArgsResolution.Empty, + ResolvedType.New(ResolvedTypeKind.NewTupleType(expressions.Select(expr => expr.ResolvedType).ToImmutableArray())), + new InferredExpressionInformation(false, false), + QsNullable>.Null + ); + + internal static TypedExpression CreateCallLikeExpression(TypedExpression id, TypedExpression args, TypeArgsResolution typeRes) => + new TypedExpression + ( + ExpressionKind.NewCallLikeExpression(id, args), + typeRes, + ResolvedType.New(ResolvedTypeKind.UnitType), + new InferredExpressionInformation(false, true), + QsNullable>.Null + ); + + internal static ResolvedType GetOperationType(IEnumerable props, ResolvedType argumentType) + { + var characteristics = new CallableInformation( + ResolvedCharacteristics.FromProperties(props), + InferredCallableInformation.NoInformation); + + return ResolvedType.New(ResolvedTypeKind.NewOperation( + Tuple.Create(argumentType, ResolvedType.New(ResolvedTypeKind.UnitType)), + characteristics)); + } + + internal static TypeArgsResolution GetCombinedTypeResolution(TypeArgsResolution outer, TypeArgsResolution inner) + { + var outerDict = outer.ToDictionary(x => (x.Item1, x.Item2), x => x.Item3); + return inner.Select(innerRes => + { + if (innerRes.Item3.Resolution is ResolvedTypeKind.TypeParameter typeParam && + outerDict.TryGetValue((typeParam.Item.Origin, typeParam.Item.TypeName), out var outerRes)) + { + outerDict.Remove((typeParam.Item.Origin, typeParam.Item.TypeName)); + return Tuple.Create(innerRes.Item1, innerRes.Item2, outerRes); + } + else + { + return innerRes; + } + }) + .Concat(outerDict.Select(x => Tuple.Create(x.Key.Item1, x.Key.Item2, x.Value))).ToImmutableArray(); + } + } +} diff --git a/src/QsCompiler/Transformations/CodeOutput.cs b/src/QsCompiler/Transformations/CodeOutput.cs deleted file mode 100644 index c82cdfba06..0000000000 --- a/src/QsCompiler/Transformations/CodeOutput.cs +++ /dev/null @@ -1,1311 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -using System; -using System.Collections.Generic; -using System.Collections.Immutable; -using System.Globalization; -using System.Linq; -using System.Numerics; -using System.Text.RegularExpressions; -using Microsoft.Quantum.QsCompiler.DataTypes; -using Microsoft.Quantum.QsCompiler.SyntaxTokens; -using Microsoft.Quantum.QsCompiler.SyntaxTree; -using Microsoft.Quantum.QsCompiler.TextProcessing; -using Microsoft.Quantum.QsCompiler.Transformations.BasicTransformations; - - -namespace Microsoft.Quantum.QsCompiler.Transformations.QsCodeOutput -{ - using QsTypeKind = QsTypeKind; - using QsExpressionKind = QsExpressionKind; - - - // Q# type to Code - - /// - /// Class used to generate Q# code for Q# types. - /// Adds an Output string property to ExpressionTypeTransformation, - /// that upon calling Transform on a Q# type is set to the Q# code corresponding to that type. - /// - public class ExpressionTypeToQs : - ExpressionTypeTransformation - { - public string Output; - public const string InvalidType = "__UnknownType__"; - public const string InvalidSet = "__UnknownSet__"; - - public Action beforeInvalidType; - public Action beforeInvalidSet; - - public string Apply(ResolvedType t) - { - this.Transform(t); - return this.Output; - } - - public ExpressionTypeToQs(ExpressionToQs expression) : - base(expression) - { } - - - public override QsTypeKind onArrayType(ResolvedType b) - { - this.Output = $"{this.Apply(b)}[]"; - return QsTypeKind.NewArrayType(b); - } - - public override QsTypeKind onBool() - { - this.Output = Keywords.qsBool.id; - return QsTypeKind.Bool; - } - - public override QsTypeKind onDouble() - { - this.Output = Keywords.qsDouble.id; - return QsTypeKind.Double; - } - - public override QsTypeKind onFunction(ResolvedType it, ResolvedType ot) - { - this.Output = $"({this.Apply(it)} -> {this.Apply(ot)})"; - return QsTypeKind.NewFunction(it, ot); - } - - public override QsTypeKind onInt() - { - this.Output = Keywords.qsInt.id; - return QsTypeKind.Int; - } - - public override QsTypeKind onBigInt() - { - this.Output = Keywords.qsBigInt.id; - return QsTypeKind.BigInt; - } - - public override QsTypeKind onInvalidType() - { - this.beforeInvalidType?.Invoke(); - this.Output = InvalidType; - return QsTypeKind.InvalidType; - } - - public override QsTypeKind onMissingType() - { - this.Output = "_"; // needs to be underscore, since this is valid as type argument specifier - return QsTypeKind.MissingType; - } - - public override ResolvedCharacteristics onCharacteristicsExpression(ResolvedCharacteristics fs) - { - int CurrentPrecedence = 0; - string SetPrecedenceAndReturn(int prec, string str) - { - CurrentPrecedence = prec; - return str; - } - - string Recur(int prec, ResolvedCharacteristics ex) - { - var output = SetAnnotation(ex); - return prec < CurrentPrecedence || CurrentPrecedence == int.MaxValue ? output : $"({output})"; - } - - string BinaryOperator(Keywords.QsOperator op, ResolvedCharacteristics lhs, ResolvedCharacteristics rhs) => - SetPrecedenceAndReturn(op.prec, $"{Recur(op.prec, lhs)} {op.op} {Recur(op.prec, rhs)}"); - - string SetAnnotation(ResolvedCharacteristics characteristics) - { - if (characteristics.Expression is CharacteristicsKind.SimpleSet set) - { - string setName = null; - if (set.Item.IsAdjointable) setName = Keywords.qsAdjSet.id; - else if (set.Item.IsControllable) setName = Keywords.qsCtlSet.id; - else throw new NotImplementedException("unknown set name"); - return SetPrecedenceAndReturn(int.MaxValue, setName); - } - else if (characteristics.Expression is CharacteristicsKind.Union u) - { return BinaryOperator(Keywords.qsSetUnion, u.Item1, u.Item2); } - else if (characteristics.Expression is CharacteristicsKind.Intersection i) - { return BinaryOperator(Keywords.qsSetIntersection, i.Item1, i.Item2); } - else if (characteristics.Expression.IsInvalidSetExpr) - { - this.beforeInvalidSet?.Invoke(); - return SetPrecedenceAndReturn(int.MaxValue, InvalidSet); - } - else throw new NotImplementedException("unknown set expression"); - } - - this.Output = fs.Expression.IsEmptySet ? null : SetAnnotation(fs); - return fs; - } - - public override QsTypeKind onOperation(Tuple sign, CallableInformation info) - { - info = base.onCallableInformation(info); - var characteristics = String.IsNullOrWhiteSpace(this.Output) ? "" : $" {Keywords.qsCharacteristics.id} {this.Output}"; - this.Output = $"({this.Apply(sign.Item1)} => {this.Apply(sign.Item2)}{characteristics})"; - return QsTypeKind.NewOperation(sign, info); - } - - public override QsTypeKind onPauli() - { - this.Output = Keywords.qsPauli.id; - return QsTypeKind.Pauli; - } - - public override QsTypeKind onQubit() - { - this.Output = Keywords.qsQubit.id; - return QsTypeKind.Qubit; - } - - public override QsTypeKind onRange() - { - this.Output = Keywords.qsRange.id; - return QsTypeKind.Range; - } - - public override QsTypeKind onResult() - { - this.Output = Keywords.qsResult.id; - return QsTypeKind.Result; - } - - public override QsTypeKind onString() - { - this.Output = Keywords.qsString.id; - return QsTypeKind.String; - } - - public override QsTypeKind onTupleType(ImmutableArray ts) - { - this.Output = $"({String.Join(", ", ts.Select(this.Apply))})"; - return QsTypeKind.NewTupleType(ts); - } - - public override QsTypeKind onTypeParameter(QsTypeParameter tp) - { - this.Output = $"'{tp.TypeName.Value}"; - return QsTypeKind.NewTypeParameter(tp); - } - - public override QsTypeKind onUnitType() - { - this.Output = Keywords.qsUnit.id; - return QsTypeKind.UnitType; - } - - public override QsTypeKind onUserDefinedType(UserDefinedType udt) - { - var isInCurrentNamespace = udt.Namespace.Value == this._Expression.Context.CurrentNamespace; - var isInOpenNamespace = this._Expression.Context.OpenedNamespaces.Contains(udt.Namespace) && !this._Expression.Context.SymbolsInCurrentNamespace.Contains(udt.Name); - var hasShortName = this._Expression.Context.NamespaceShortNames.TryGetValue(udt.Namespace, out var shortName); - this.Output = isInCurrentNamespace || (isInOpenNamespace && !this._Expression.Context.AmbiguousNames.Contains(udt.Name)) - ? udt.Name.Value - : $"{(hasShortName ? shortName.Value : udt.Namespace.Value)}.{udt.Name.Value}"; - return QsTypeKind.NewUserDefinedType(udt); - } - } - - - /// - /// Class used to generate Q# code for Q# expressions. - /// Upon calling Transform, the Output property is set to the Q# code corresponding to an expression of the given kind. - /// - public class ExpressionKindToQs : - ExpressionKindTransformation - { - public string Output; - - public const string InvalidIdentifier = "__UnknownId__"; - public const string InvalidExpression = "__InvalidEx__"; - - public Action beforeInvalidIdentifier; - public Action beforeInvalidExpression; - - /// - /// allows to omit unnecessary parentheses - /// - private int CurrentPrecedence = 0; - - public string Apply(QsExpressionKind k) - { - this.Transform(k); - return this.Output; - } - - public ExpressionKindToQs(ExpressionToQs expression) : - base(expression) - { } - - - private string Type(ResolvedType t) => - this._Expression._Type.Apply(t); - - private string Recur(int prec, TypedExpression ex) - { - this._Expression.Transform(ex); - return prec < this.CurrentPrecedence || this.CurrentPrecedence == int.MaxValue // need to cover the case where prec = currentPrec = MaxValue - ? this.Output - : $"({this.Output})"; - } - - private void UnaryOperator(Keywords.QsOperator op, TypedExpression ex) - { - this.Output = Keywords.ReservedKeywords.Contains(op.op) - ? $"{op.op} {this.Recur(op.prec, ex)}" - : $"{op.op}{this.Recur(op.prec, ex)}"; - this.CurrentPrecedence = op.prec; - } - - private void BinaryOperator(Keywords.QsOperator op, TypedExpression lhs, TypedExpression rhs) - { - this.Output = $"{this.Recur(op.prec, lhs)} {op.op} {this.Recur(op.prec, rhs)}"; - this.CurrentPrecedence = op.prec; - } - - private void TernaryOperator(Keywords.QsOperator op, TypedExpression fst, TypedExpression snd, TypedExpression trd) - { - this.Output = $"{this.Recur(op.prec, fst)} {op.op} {this.Recur(op.prec, snd)} {op.cont} {this.Recur(op.prec, trd)}"; - this.CurrentPrecedence = op.prec; - } - - private QsExpressionKind CallLike(TypedExpression method, TypedExpression arg) - { - var prec = Keywords.qsCallCombinator.prec; - var argStr = arg.Expression.IsValueTuple || arg.Expression.IsUnitValue ? this.Recur(int.MinValue, arg) : $"({this.Recur(int.MinValue, arg)})"; - this.Output = $"{this.Recur(prec, method)}{argStr}"; - this.CurrentPrecedence = prec; - return QsExpressionKind.NewCallLikeExpression(method, arg); - } - - - public override QsExpressionKind onIdentifier(Identifier sym, QsNullable> tArgs) - { - if (sym is Identifier.LocalVariable loc) - { this.Output = loc.Item.Value; } - else if (sym.IsInvalidIdentifier) - { - this.beforeInvalidIdentifier?.Invoke(); - this.Output = InvalidIdentifier; - } - else if (sym is Identifier.GlobalCallable global) - { - var isInCurrentNamespace = global.Item.Namespace.Value == this._Expression.Context.CurrentNamespace; - var isInOpenNamespace = this._Expression.Context.OpenedNamespaces.Contains(global.Item.Namespace) && !this._Expression.Context.SymbolsInCurrentNamespace.Contains(global.Item.Name); - var hasShortName = this._Expression.Context.NamespaceShortNames.TryGetValue(global.Item.Namespace, out var shortName); - this.Output = isInCurrentNamespace || (isInOpenNamespace && !this._Expression.Context.AmbiguousNames.Contains(global.Item.Name)) - ? global.Item.Name.Value - : $"{(hasShortName ? shortName.Value : global.Item.Namespace.Value)}.{global.Item.Name.Value}"; - } - else throw new NotImplementedException("unknown identifier kind"); - - if (tArgs.IsValue) - { this.Output = $"{this.Output}<{ String.Join(", ", tArgs.Item.Select(this.Type))}>"; } - this.CurrentPrecedence = int.MaxValue; - return QsExpressionKind.NewIdentifier(sym, tArgs); - } - - public override QsExpressionKind onOperationCall(TypedExpression method, TypedExpression arg) => - this.CallLike(method, arg); - - public override QsExpressionKind onFunctionCall(TypedExpression method, TypedExpression arg) => - this.CallLike(method, arg); - - public override QsExpressionKind onPartialApplication(TypedExpression method, TypedExpression arg) => - this.CallLike(method, arg); - - public override QsExpressionKind onAdjointApplication(TypedExpression ex) - { - var op = Keywords.qsAdjointModifier; - this.Output = $"{op.op} {this.Recur(op.prec, ex)}"; - this.CurrentPrecedence = op.prec; - return QsExpressionKind.NewAdjointApplication(ex); - } - - public override QsExpressionKind onControlledApplication(TypedExpression ex) - { - var op = Keywords.qsControlledModifier; - this.Output = $"{op.op} {this.Recur(op.prec, ex)}"; - this.CurrentPrecedence = op.prec; - return QsExpressionKind.NewControlledApplication(ex); - } - - public override QsExpressionKind onUnwrapApplication(TypedExpression ex) - { - var op = Keywords.qsUnwrapModifier; - this.Output = $"{this.Recur(op.prec, ex)}{op.op}"; - this.CurrentPrecedence = op.prec; - return QsExpressionKind.NewUnwrapApplication(ex); - } - - public override QsExpressionKind onUnitValue() - { - this.Output = "()"; - this.CurrentPrecedence = int.MaxValue; - return QsExpressionKind.UnitValue; - } - - public override QsExpressionKind onMissingExpression() - { - this.Output = "_"; - this.CurrentPrecedence = int.MaxValue; - return QsExpressionKind.MissingExpr; - } - - public override QsExpressionKind onInvalidExpression() - { - this.beforeInvalidExpression?.Invoke(); - this.Output = InvalidExpression; - this.CurrentPrecedence = int.MaxValue; - return QsExpressionKind.InvalidExpr; - } - - public override QsExpressionKind onValueTuple(ImmutableArray vs) - { - this.Output = $"({String.Join(", ", vs.Select(v => this.Recur(int.MinValue, v)))})"; - this.CurrentPrecedence = int.MaxValue; - return QsExpressionKind.NewValueTuple(vs); - } - - public override QsExpressionKind onValueArray(ImmutableArray vs) - { - this.Output = $"[{String.Join(", ", vs.Select(v => this.Recur(int.MinValue, v)))}]"; - this.CurrentPrecedence = int.MaxValue; - return QsExpressionKind.NewValueArray(vs); - } - - public override QsExpressionKind onNewArray(ResolvedType bt, TypedExpression idx) - { - this.Output = $"{Keywords.arrayDecl.id} {this.Type(bt)}[{this.Recur(int.MinValue, idx)}]"; - this.CurrentPrecedence = int.MaxValue; - return QsExpressionKind.NewNewArray(bt, idx); - } - - public override QsExpressionKind onArrayItem(TypedExpression arr, TypedExpression idx) - { - var prec = Keywords.qsArrayAccessCombinator.prec; - this.Output = $"{this.Recur(prec,arr)}[{this.Recur(int.MinValue, idx)}]"; // Todo: generate contextual open range expression when appropriate - this.CurrentPrecedence = prec; - return QsExpressionKind.NewArrayItem(arr, idx); - } - - public override QsExpressionKind onNamedItem(TypedExpression ex, Identifier acc) - { - this.onIdentifier(acc, QsNullable>.Null); - var (op, itemName) = (Keywords.qsNamedItemCombinator, this.Output); - this.Output = $"{this.Recur(op.prec,ex)}{op.op}{itemName}"; - return base.onNamedItem(ex, acc); - } - - public override QsExpressionKind onIntLiteral(long i) - { - this.Output = i.ToString(CultureInfo.InvariantCulture); - this.CurrentPrecedence = int.MaxValue; - return QsExpressionKind.NewIntLiteral(i); - } - - public override QsExpressionKind onBigIntLiteral(BigInteger b) - { - this.Output = b.ToString("R", CultureInfo.InvariantCulture) + "L"; - this.CurrentPrecedence = int.MaxValue; - return QsExpressionKind.NewBigIntLiteral(b); - } - - public override QsExpressionKind onDoubleLiteral(double d) - { - this.Output = d.ToString("R", CultureInfo.InvariantCulture); - if ((int)d == d) this.Output = $"{this.Output}.0"; - this.CurrentPrecedence = int.MaxValue; - return QsExpressionKind.NewDoubleLiteral(d); - } - - public override QsExpressionKind onBoolLiteral(bool b) - { - if (b) this.Output = Keywords.qsTrue.id; - else this.Output = Keywords.qsFalse.id; - this.CurrentPrecedence = int.MaxValue; - return QsExpressionKind.NewBoolLiteral(b); - } - - private static readonly Regex InterpolationArg = new Regex(@"(? replace) - { - var itemNr = 0; - string ReplaceMatch(Match m) => replace?.Invoke(itemNr++); - return InterpolationArg.Replace(text, ReplaceMatch); - } - - public override QsExpressionKind onStringLiteral(NonNullable s, ImmutableArray exs) - { - string InterpolatedArg(int index) => $"{{{this.Recur(int.MinValue, exs[index])}}}"; - this.Output = exs.Length == 0 ? $"\"{s.Value}\"" : $"$\"{ReplaceInterpolatedArgs(s.Value, InterpolatedArg)}\""; - this.CurrentPrecedence = int.MaxValue; - return QsExpressionKind.NewStringLiteral(s, exs); - } - - public override QsExpressionKind onRangeLiteral(TypedExpression lhs, TypedExpression rhs) - { - var op = Keywords.qsRangeOp; - var lhsStr = lhs.Expression.IsRangeLiteral ? this.Recur(int.MinValue, lhs) : this.Recur(op.prec, lhs); - this.Output = $"{lhsStr} {op.op} {this.Recur(op.prec, rhs)}"; - this.CurrentPrecedence = op.prec; - return QsExpressionKind.NewRangeLiteral(lhs, rhs); - } - - public override QsExpressionKind onResultLiteral(QsResult r) - { - if (r.IsZero) this.Output = Keywords.qsZero.id; - else if (r.IsOne) this.Output = Keywords.qsOne.id; - else throw new NotImplementedException("unknown Result literal"); - this.CurrentPrecedence = int.MaxValue; - return QsExpressionKind.NewResultLiteral(r); - } - - public override QsExpressionKind onPauliLiteral(QsPauli p) - { - if (p.IsPauliI) this.Output = Keywords.qsPauliI.id; - else if (p.IsPauliX) this.Output = Keywords.qsPauliX.id; - else if (p.IsPauliY) this.Output = Keywords.qsPauliY.id; - else if (p.IsPauliZ) this.Output = Keywords.qsPauliZ.id; - else throw new NotImplementedException("unknown Pauli literal"); - this.CurrentPrecedence = int.MaxValue; - return QsExpressionKind.NewPauliLiteral(p); - } - - - public override QsExpressionKind onCopyAndUpdateExpression(TypedExpression lhs, TypedExpression acc, TypedExpression rhs) - { - TernaryOperator(Keywords.qsCopyAndUpdateOp, lhs, acc, rhs); - return QsExpressionKind.NewCopyAndUpdate(lhs, acc, rhs); - } - - public override QsExpressionKind onConditionalExpression(TypedExpression cond, TypedExpression ifTrue, TypedExpression ifFalse) - { - TernaryOperator(Keywords.qsConditionalOp, cond, ifTrue, ifFalse); - return QsExpressionKind.NewCONDITIONAL(cond, ifTrue, ifFalse); - } - - public override QsExpressionKind onAddition(TypedExpression lhs, TypedExpression rhs) - { - BinaryOperator(Keywords.qsADDop, lhs, rhs); - return QsExpressionKind.NewADD(lhs, rhs); - } - - public override QsExpressionKind onBitwiseAnd(TypedExpression lhs, TypedExpression rhs) - { - BinaryOperator(Keywords.qsBANDop, lhs, rhs); - return QsExpressionKind.NewBAND(lhs, rhs); - } - - public override QsExpressionKind onBitwiseExclusiveOr(TypedExpression lhs, TypedExpression rhs) - { - BinaryOperator(Keywords.qsBXORop, lhs, rhs); - return QsExpressionKind.NewBXOR(lhs, rhs); - } - - public override QsExpressionKind onBitwiseOr(TypedExpression lhs, TypedExpression rhs) - { - BinaryOperator(Keywords.qsBORop, lhs, rhs); - return QsExpressionKind.NewBOR(lhs, rhs); - } - - public override QsExpressionKind onDivision(TypedExpression lhs, TypedExpression rhs) - { - BinaryOperator(Keywords.qsDIVop, lhs, rhs); - return QsExpressionKind.NewDIV(lhs, rhs); - } - - public override QsExpressionKind onEquality(TypedExpression lhs, TypedExpression rhs) - { - BinaryOperator(Keywords.qsEQop, lhs, rhs); - return QsExpressionKind.NewEQ(lhs, rhs); - } - - public override QsExpressionKind onExponentiate(TypedExpression lhs, TypedExpression rhs) - { - BinaryOperator(Keywords.qsPOWop, lhs, rhs); - return QsExpressionKind.NewPOW(lhs, rhs); - } - - public override QsExpressionKind onGreaterThan(TypedExpression lhs, TypedExpression rhs) - { - BinaryOperator(Keywords.qsGTop, lhs, rhs); - return QsExpressionKind.NewGT(lhs, rhs); - } - - public override QsExpressionKind onGreaterThanOrEqual(TypedExpression lhs, TypedExpression rhs) - { - BinaryOperator(Keywords.qsGTEop, lhs, rhs); - return QsExpressionKind.NewGTE(lhs, rhs); - } - - public override QsExpressionKind onInequality(TypedExpression lhs, TypedExpression rhs) - { - BinaryOperator(Keywords.qsNEQop, lhs, rhs); - return QsExpressionKind.NewNEQ(lhs, rhs); - } - - public override QsExpressionKind onLeftShift(TypedExpression lhs, TypedExpression rhs) - { - BinaryOperator(Keywords.qsLSHIFTop, lhs, rhs); - return QsExpressionKind.NewLSHIFT(lhs, rhs); - } - - public override QsExpressionKind onLessThan(TypedExpression lhs, TypedExpression rhs) - { - BinaryOperator(Keywords.qsLTop, lhs, rhs); - return QsExpressionKind.NewLT(lhs, rhs); - } - - public override QsExpressionKind onLessThanOrEqual(TypedExpression lhs, TypedExpression rhs) - { - BinaryOperator(Keywords.qsLTEop, lhs, rhs); - return QsExpressionKind.NewLTE(lhs, rhs); - } - - public override QsExpressionKind onLogicalAnd(TypedExpression lhs, TypedExpression rhs) - { - BinaryOperator(Keywords.qsANDop, lhs, rhs); - return QsExpressionKind.NewAND(lhs, rhs); - } - - public override QsExpressionKind onLogicalOr(TypedExpression lhs, TypedExpression rhs) - { - BinaryOperator(Keywords.qsORop, lhs, rhs); - return QsExpressionKind.NewOR(lhs, rhs); - } - - public override QsExpressionKind onModulo(TypedExpression lhs, TypedExpression rhs) - { - BinaryOperator(Keywords.qsMODop, lhs, rhs); - return QsExpressionKind.NewMOD(lhs, rhs); - } - - public override QsExpressionKind onMultiplication(TypedExpression lhs, TypedExpression rhs) - { - BinaryOperator(Keywords.qsMULop, lhs, rhs); - return QsExpressionKind.NewMUL(lhs, rhs); - } - - public override QsExpressionKind onRightShift(TypedExpression lhs, TypedExpression rhs) - { - BinaryOperator(Keywords.qsRSHIFTop, lhs, rhs); - return QsExpressionKind.NewRSHIFT(lhs, rhs); - } - - public override QsExpressionKind onSubtraction(TypedExpression lhs, TypedExpression rhs) - { - BinaryOperator(Keywords.qsSUBop, lhs, rhs); - return QsExpressionKind.NewSUB(lhs, rhs); - } - - public override QsExpressionKind onNegative(TypedExpression ex) - { - UnaryOperator(Keywords.qsNEGop, ex); - return QsExpressionKind.NewNEG(ex); - } - - public override QsExpressionKind onLogicalNot(TypedExpression ex) - { - UnaryOperator(Keywords.qsNOTop, ex); - return QsExpressionKind.NewNOT(ex); - } - - public override QsExpressionKind onBitwiseNot(TypedExpression ex) - { - UnaryOperator(Keywords.qsBNOTop, ex); - return QsExpressionKind.NewBNOT(ex); - } - } - - - /// - /// used to pass contextual information for expression transformations - /// - public class TransformationContext - { - public string CurrentNamespace; - public ImmutableHashSet> OpenedNamespaces; - public ImmutableDictionary, NonNullable> NamespaceShortNames; // mapping namespace names to their short names - public ImmutableHashSet> SymbolsInCurrentNamespace; - public ImmutableHashSet> AmbiguousNames; - - public TransformationContext() - { - this.CurrentNamespace = null; - this.OpenedNamespaces = ImmutableHashSet>.Empty; - this.NamespaceShortNames = ImmutableDictionary, NonNullable>.Empty; - this.SymbolsInCurrentNamespace = ImmutableHashSet>.Empty; - this.AmbiguousNames = ImmutableHashSet>.Empty; - } - } - - - /// - /// Class used to generate Q# code for Q# expressions. - /// Upon calling Transform, the Output property is set to the Q# code corresponding to the given expression. - /// - public class ExpressionToQs : - ExpressionTransformation - { - internal readonly TransformationContext Context; - - public ExpressionToQs(TransformationContext context = null) : - base(e => new ExpressionKindToQs(e as ExpressionToQs), e => new ExpressionTypeToQs(e as ExpressionToQs)) => - this.Context = context ?? new TransformationContext(); - } - - - /// - /// Class used to generate Q# code for Q# statements. - /// Upon calling Transform, the _Output property of the scope transformation given on initialization - /// is set to the Q# code corresponding to a statement of the given kind. - /// - public class StatementKindToQs : - StatementKindTransformation - { - private int currentIndendation = 0; - - public const string InvalidSymbol = "__InvalidName__"; - public const string InvalidInitializer = "__InvalidInitializer__"; - - public Action beforeInvalidSymbol; - public Action beforeInvalidInitializer; - - internal StatementKindToQs(ScopeToQs scope) : - base(scope) - { } - - private void AddToOutput(string line) - { - for (var i = 0; i < currentIndendation; ++i) line = $" {line}"; - this._Scope._Output.Add(line); - } - - private void AddComments(IEnumerable comments) - { - foreach (var comment in comments) - { this.AddToOutput(String.IsNullOrWhiteSpace(comment) ? "" : $"//{comment}"); } - } - - private bool PrecededByCode => - SyntaxTreeToQs.PrecededByCode(this._Scope._Output); - - private void AddStatement(string stm) - { - var comments = this._Scope.CurrentComments; - var precededByBlockStatement = SyntaxTreeToQs.PrecededByBlock(this._Scope._Output); - - if (precededByBlockStatement || (PrecededByCode && comments.OpeningComments.Length != 0)) this.AddToOutput(""); - this.AddComments(comments.OpeningComments); - this.AddToOutput($"{stm};"); - this.AddComments(comments.ClosingComments); - if (comments.ClosingComments.Length != 0) this.AddToOutput(""); - } - - private void AddBlockStatement(string intro, QsScope statements, bool withWhiteSpace = true) - { - var comments = this._Scope.CurrentComments; - if (PrecededByCode && withWhiteSpace) this.AddToOutput(""); - this.AddComments(comments.OpeningComments); - this.AddToOutput($"{intro} {"{"}"); - ++currentIndendation; - this._Scope.Transform(statements); - this.AddComments(comments.ClosingComments); - --currentIndendation; - this.AddToOutput("}"); - } - - private string Expression(TypedExpression ex) => - this._Scope._Expression._Kind.Apply(ex.Expression); - - private string SymbolTuple(SymbolTuple sym) - { - if (sym.IsDiscardedItem) return "_"; - else if (sym is SymbolTuple.VariableName name) return name.Item.Value; - else if (sym is SymbolTuple.VariableNameTuple tuple) return $"({String.Join(", ", tuple.Item.Select(SymbolTuple))})"; - else if (sym.IsInvalidItem) - { - this.beforeInvalidSymbol?.Invoke(); - return InvalidSymbol; - } - else throw new NotImplementedException("unknown item in symbol tuple"); - } - - private string InitializerTuple(ResolvedInitializer init) - { - if (init.Resolution.IsSingleQubitAllocation) return $"{Keywords.qsQubit.id}()"; - else if (init.Resolution is QsInitializerKind.QubitRegisterAllocation reg) - { return $"{Keywords.qsQubit.id}[{Expression(reg.Item)}]"; } - else if (init.Resolution is QsInitializerKind.QubitTupleAllocation tuple) - { return $"({String.Join(", ", tuple.Item.Select(InitializerTuple))})"; } - else if (init.Resolution.IsInvalidInitializer) - { - this.beforeInvalidInitializer?.Invoke(); - return InvalidInitializer; - } - else throw new NotImplementedException("unknown qubit initializer"); - } - - - private QsStatementKind QubitScope(QsQubitScope stm) - { - var symbols = SymbolTuple(stm.Binding.Lhs); - var initializers = InitializerTuple(stm.Binding.Rhs); - string header; - if (stm.Kind.IsBorrow) header = Keywords.qsBorrowing.id; - else if (stm.Kind.IsAllocate) header = Keywords.qsUsing.id; - else throw new NotImplementedException("unknown qubit scope"); - - var intro = $"{header} ({symbols} = {initializers})"; - this.AddBlockStatement(intro, stm.Body); - return QsStatementKind.NewQsQubitScope(stm); - } - - public override QsStatementKind onAllocateQubits(QsQubitScope stm) => - this.QubitScope(stm); - - public override QsStatementKind onBorrowQubits(QsQubitScope stm) => - this.QubitScope(stm); - - public override QsStatementKind onForStatement(QsForStatement stm) - { - var symbols = SymbolTuple(stm.LoopItem.Item1); - var intro = $"{Keywords.qsFor.id} ({symbols} {Keywords.qsRangeIter.id} {Expression(stm.IterationValues)})"; - this.AddBlockStatement(intro, stm.Body); - return QsStatementKind.NewQsForStatement(stm); - } - - public override QsStatementKind onWhileStatement(QsWhileStatement stm) - { - var intro = $"{Keywords.qsWhile.id} ({Expression(stm.Condition)})"; - this.AddBlockStatement(intro, stm.Body); - return QsStatementKind.NewQsWhileStatement(stm); - } - - public override QsStatementKind onRepeatStatement(QsRepeatStatement stm) - { - this._Scope.CurrentComments = stm.RepeatBlock.Comments; - this.AddBlockStatement(Keywords.qsRepeat.id, stm.RepeatBlock.Body); - this._Scope.CurrentComments = stm.FixupBlock.Comments; - this.AddToOutput($"{Keywords.qsUntil.id} ({Expression(stm.SuccessCondition)})"); - this.AddBlockStatement(Keywords.qsRUSfixup.id, stm.FixupBlock.Body, false); - return QsStatementKind.NewQsRepeatStatement(stm); - } - - public override QsStatementKind onConditionalStatement(QsConditionalStatement stm) - { - var header = Keywords.qsIf.id; - if (PrecededByCode) this.AddToOutput(""); - foreach (var clause in stm.ConditionalBlocks) - { - this._Scope.CurrentComments = clause.Item2.Comments; - var intro = $"{header} ({Expression(clause.Item1)})"; - this.AddBlockStatement(intro, clause.Item2.Body, false); - header = Keywords.qsElif.id; - } - if (stm.Default.IsValue) - { - this._Scope.CurrentComments = stm.Default.Item.Comments; - this.AddBlockStatement(Keywords.qsElse.id, stm.Default.Item.Body, false); - } - return QsStatementKind.NewQsConditionalStatement(stm); - } - - public override QsStatementKind onConjugation(QsConjugation stm) - { - this._Scope.CurrentComments = stm.OuterTransformation.Comments; - this.AddBlockStatement(Keywords.qsWithin.id, stm.OuterTransformation.Body, true); - this._Scope.CurrentComments = stm.InnerTransformation.Comments; - this.AddBlockStatement(Keywords.qsApply.id, stm.InnerTransformation.Body, false); - return QsStatementKind.NewQsConjugation(stm); - } - - - public override QsStatementKind onExpressionStatement(TypedExpression ex) - { - this.AddStatement(Expression(ex)); - return QsStatementKind.NewQsExpressionStatement(ex); - } - - public override QsStatementKind onFailStatement(TypedExpression ex) - { - this.AddStatement($"{Keywords.qsFail.id} {Expression(ex)}"); - return QsStatementKind.NewQsFailStatement(ex); - } - - public override QsStatementKind onReturnStatement(TypedExpression ex) - { - this.AddStatement($"{Keywords.qsReturn.id} {Expression(ex)}"); - return QsStatementKind.NewQsReturnStatement(ex); - } - - public override QsStatementKind onVariableDeclaration(QsBinding stm) - { - string header; - if (stm.Kind.IsImmutableBinding) header = Keywords.qsImmutableBinding.id; - else if (stm.Kind.IsMutableBinding) header = Keywords.qsMutableBinding.id; - else throw new NotImplementedException("unknown binding kind"); - - this.AddStatement($"{header} {SymbolTuple(stm.Lhs)} = {Expression(stm.Rhs)}"); - return QsStatementKind.NewQsVariableDeclaration(stm); - } - - public override QsStatementKind onValueUpdate(QsValueUpdate stm) - { - this.AddStatement($"{Keywords.qsValueUpdate.id} {Expression(stm.Lhs)} = {Expression(stm.Rhs)}"); - return QsStatementKind.NewQsValueUpdate(stm); - } - } - - - /// - /// Class used to generate Q# code for Q# statements. - /// Upon calling Transform, the Output property is set to the Q# code corresponding to the given statement block. - /// - public class ScopeToQs : - ScopeTransformation - { - internal readonly List _Output; - public string Output => String.Join(Environment.NewLine, this._Output); - internal QsComments CurrentComments; - - public ScopeToQs(TransformationContext context = null) : - base(s => new StatementKindToQs(s as ScopeToQs), new ExpressionToQs(context)) - { - this.CurrentComments = QsComments.Empty; - this._Output = new List(); - } - - public override QsStatement onStatement(QsStatement stm) - { - this.CurrentComments = stm.Comments; - return base.onStatement(stm); - } - } - - - /// - /// Class used to generate Q# code for compiled Q# namespaces. - /// Upon calling Transform, the Output property is set to the Q# code corresponding to the given namespace. - /// - public class SyntaxTreeToQs : - SyntaxTreeTransformation - { - private QsComments CurrentComments; - private int currentIndendation = 0; - private string currentSpec; - private int nrSpecialzations; - - private readonly List _Output; - public string Output => String.Join(Environment.NewLine, this._Output); - - internal static bool PrecededByCode(IEnumerable output) => - output == null ? false : output.Any() && !String.IsNullOrWhiteSpace(output.Last().Replace("{", "")); - - internal static bool PrecededByBlock(IEnumerable output) => - output == null ? false : output.Any() && output.Last().Trim() == "}"; - - public const string ExternalImplementation = "__external__"; - public const string InvalidFunctorGenerator = "__UnknownGenerator__"; - - public Action beforeExternalImplementation; - public Action beforeInvalidFunctorGenerator; - - private void SetAllInvalid(Action action) - { - this.beforeExternalImplementation = action; - this._Scope._StatementKind.beforeInvalidInitializer = action; - this._Scope._StatementKind.beforeInvalidSymbol = action; - this._Scope._Expression._Kind.beforeInvalidIdentifier = action; - this._Scope._Expression._Kind.beforeInvalidExpression = action; - this._Scope._Expression._Type.beforeInvalidType = action; - this._Scope._Expression._Type.beforeInvalidSet = action; - } - - private void SetAllInvalid(SyntaxTreeToQs other) - { - this.beforeExternalImplementation = other.beforeExternalImplementation; - this._Scope._StatementKind.beforeInvalidInitializer = other._Scope._StatementKind.beforeInvalidInitializer; - this._Scope._StatementKind.beforeInvalidSymbol = other._Scope._StatementKind.beforeInvalidSymbol; - this._Scope._Expression._Kind.beforeInvalidIdentifier = other._Scope._Expression._Kind.beforeInvalidIdentifier; - this._Scope._Expression._Kind.beforeInvalidExpression = other._Scope._Expression._Kind.beforeInvalidExpression; - this._Scope._Expression._Type.beforeInvalidType = other._Scope._Expression._Type.beforeInvalidType; - this._Scope._Expression._Type.beforeInvalidSet = other._Scope._Expression._Type.beforeInvalidSet; - } - - /// - /// For each file in the given parameter array of open directives, - /// generates a dictionary that maps (the name of) each partial namespace contained in the file - /// to a string containing the formatted Q# code for the part of the namespace. - /// Qualified or unqualified names for types and identifiers are generated based on the given namespace and open directives. - /// Throws an ArgumentNullException if the given namespace is null. - /// -> IMPORTANT: The given namespace is expected to contain *all* elements in that namespace for the *entire* compilation unit! - /// - public static bool Apply(out List, string>> generatedCode, - IEnumerable namespaces, - params (NonNullable, ImmutableDictionary, ImmutableArray<(NonNullable, string)>>)[] openDirectives) - { - if (namespaces == null) throw new ArgumentNullException(nameof(namespaces)); - - generatedCode = new List, string>>(); - var symbolsInNS = namespaces.ToImmutableDictionary(ns => ns.Name, ns => ns.Elements - .Select(element => (element is QsNamespaceElement.QsCallable c) ? c.Item.FullName.Name.Value : null) - .Where(name => name != null).Select(name => NonNullable.New(name)).ToImmutableHashSet()); - - var success = true; - foreach (var (sourceFile, imports) in openDirectives) - { - var nsInFile = new Dictionary, string>(); - foreach (var ns in namespaces) - { - var tree = FilterBySourceFile.Apply(ns, sourceFile); - if (!tree.Elements.Any()) continue; - - // determine all symbols that occur in multiple open namespaces - var ambiguousSymbols = symbolsInNS.Where(entry => imports[ns.Name].Contains((entry.Key, null))) - .SelectMany(entry => entry.Value) - .GroupBy(name => name) - .Where(group => group.Count() > 1) - .Select(group => group.Key).ToImmutableHashSet(); - - var openedNS = imports[ns.Name].Where(o => o.Item2 == null).Select(o => o.Item1).ToImmutableHashSet(); - var nsShortNames = imports[ns.Name].Where(o => o.Item2 != null).ToImmutableDictionary(o => o.Item1, o => NonNullable.New(o.Item2)); - var context = new TransformationContext - { - CurrentNamespace = ns.Name.Value, - OpenedNamespaces = openedNS, - NamespaceShortNames = nsShortNames, - SymbolsInCurrentNamespace = symbolsInNS[ns.Name], - AmbiguousNames = ambiguousSymbols - }; - - var generator = new SyntaxTreeToQs(new ScopeToQs(context)); - var totNrInvalid = 0; - generator.SetAllInvalid(() => ++totNrInvalid); - - var docComments = ns.Documentation[sourceFile]; - generator.AddDocumentation(docComments.Count() == 1 ? docComments.Single() : ImmutableArray.Empty); // let's drop the doc if it is ambiguous - - generator.AddToOutput($"{Keywords.namespaceDeclHeader.id} {ns.Name.Value}"); - generator.AddBlock(() => - { - var explicitImports = openedNS.Where(opened => !BuiltIn.NamespacesToAutoOpen.Contains(opened)); - if (explicitImports.Any() || nsShortNames.Any()) generator.AddToOutput(""); - foreach (var nsName in explicitImports.OrderBy(name => name)) - { generator.AddDirective($"{Keywords.importDirectiveHeader.id} {nsName.Value}"); } - foreach (var kv in nsShortNames.OrderBy(pair => pair.Key)) - { generator.AddDirective($"{Keywords.importDirectiveHeader.id} {kv.Key.Value} {Keywords.importedAs.id} {kv.Value.Value}"); } - if (explicitImports.Any() || nsShortNames.Any()) generator.AddToOutput(""); - generator.ProcessNamespaceElements(tree.Elements); - }); - if (totNrInvalid > 0) success = false; - nsInFile.Add(ns.Name, generator.Output); - } - generatedCode.Add(nsInFile.ToImmutableDictionary()); - } - return success; - } - - public SyntaxTreeToQs(ScopeToQs scope = null) : - base(scope ?? new ScopeToQs()) - { - this.CurrentComments = QsComments.Empty; - this._Output = new List(); - } - - private void AddToOutput(string line) - { - for (var i = 0; i < currentIndendation; ++i) line = $" {line}"; - this._Output.Add(line); - } - - private void AddComments(IEnumerable comments) - { - foreach (var comment in comments) - { this.AddToOutput(String.IsNullOrWhiteSpace(comment) ? "" : $"//{comment}"); } - } - - private void AddDirective(string str) => - this.AddToOutput($"{str};"); - - private void AddBlock(Action processBlock) - { - var comments = this.CurrentComments; - var opening = "{"; - if (!this._Output.Any()) this.AddToOutput(opening); - else this._Output[this._Output.Count - 1] += $" {opening}"; - ++currentIndendation; - processBlock(); - this.AddComments(comments.ClosingComments); - --currentIndendation; - this.AddToOutput("}"); - } - - private string Type(ResolvedType t) => - this._Scope._Expression._Type.Apply(t); - - private static string SymbolName(QsLocalSymbol sym, Action onInvalidName) - { - if (sym is QsLocalSymbol.ValidName n) return n.Item.Value; - else if (sym.IsInvalidName) - { - onInvalidName?.Invoke(); - return StatementKindToQs.InvalidSymbol; - } - else throw new NotImplementedException("unknown case for local symbol"); - } - - private static string TypeParameters(ResolvedSignature sign, Action onInvalidName) - { - if (sign.TypeParameters.IsEmpty) return String.Empty; - return $"<{String.Join(", ", sign.TypeParameters.Select(tp => $"'{SyntaxTreeToQs.SymbolName(tp, onInvalidName)}"))}>"; - } - - private static string ArgumentTuple(QsTuple arg, - Func getItemNameAndType, Func typeTransformation, bool symbolsOnly = false) - { - if (arg is QsTuple.QsTuple t) - { return $"({String.Join(", ", t.Item.Select(a => ArgumentTuple(a, getItemNameAndType, typeTransformation, symbolsOnly)))})"; } - else if (arg is QsTuple.QsTupleItem i) - { - var (itemName, itemType) = getItemNameAndType(i.Item); - return itemName == null - ? $"{(symbolsOnly ? "_" : $"{typeTransformation(itemType)}")}" - : $"{itemName}{(symbolsOnly ? "" : $" : {typeTransformation(itemType)}")}"; - } - else throw new NotImplementedException("unknown case for argument tuple item"); - } - - public static string ArgumentTuple(QsTuple> arg, - Func typeTransformation, Action onInvalidName = null, bool symbolsOnly = false) => - ArgumentTuple(arg, item => (SymbolName(item.VariableName, onInvalidName), item.Type), typeTransformation, symbolsOnly); - - public static string DeclarationSignature(QsCallable c, Func typeTransformation, Action onInvalidName = null) - { - var argTuple = SyntaxTreeToQs.ArgumentTuple(c.ArgumentTuple, typeTransformation, onInvalidName); - return $"{c.FullName.Name.Value}{TypeParameters(c.Signature, onInvalidName)} {argTuple} : {typeTransformation(c.Signature.ReturnType)}"; - } - - private void AddDocumentation(ImmutableArray doc) - { - foreach (var line in doc) - { this.AddToOutput($"///{line}"); } - } - - private void ProcessNamespaceElements(IEnumerable elements) - { - var types = elements.Where(e => e.IsQsCustomType); - var callables = elements.Where(e => e.IsQsCallable); - - foreach (var t in types) - { this.dispatchNamespaceElement(t); } - if (types.Any()) this.AddToOutput(""); - - foreach (var c in callables) - { this.dispatchNamespaceElement(c); } - } - - - public override Tuple>, QsScope> onProvidedImplementation - (QsTuple> argTuple, QsScope body) - { - var functorArg = "(...)"; - if (this.currentSpec == Keywords.ctrlDeclHeader.id || this.currentSpec == Keywords.ctrlAdjDeclHeader.id) - { - var ctlQubitsName = SyntaxGenerator.ControlledFunctorArgument(argTuple); - if (ctlQubitsName != null) functorArg = $"({ctlQubitsName}, ...)"; - } - else if (this.currentSpec != Keywords.bodyDeclHeader.id && this.currentSpec != Keywords.adjDeclHeader.id) - { throw new NotImplementedException("the current specialization could not be determined"); } - - void ProcessContent() - { - this._Scope._Output.Clear(); - this._Scope.Transform(body); - foreach (var line in this._Scope._Output) - { this.AddToOutput(line); } - } - if (this.nrSpecialzations != 1) // todo: needs to be adapted once we support type specializations - { - this.AddToOutput($"{this.currentSpec} {functorArg}"); - this.AddBlock(ProcessContent); - } - else - { - var comments = this.CurrentComments; - ProcessContent(); - this.AddComments(comments.ClosingComments); - } - return new Tuple>, QsScope>(argTuple, body); - } - - public override void onInvalidGeneratorDirective() - { - this.beforeInvalidFunctorGenerator?.Invoke(); - this.AddDirective($"{this.currentSpec} {InvalidFunctorGenerator}"); - } - - public override void onDistributeDirective() => - this.AddDirective($"{this.currentSpec} {Keywords.distributeFunctorGenDirective.id}"); - - public override void onInvertDirective() => - this.AddDirective($"{this.currentSpec} {Keywords.invertFunctorGenDirective.id}"); - - public override void onSelfInverseDirective() => - this.AddDirective($"{this.currentSpec} {Keywords.selfFunctorGenDirective.id}"); - - public override void onIntrinsicImplementation() => - this.AddDirective($"{this.currentSpec} {Keywords.intrinsicFunctorGenDirective.id}"); - - public override void onExternalImplementation() - { - this.beforeExternalImplementation?.Invoke(); - this.AddDirective($"{this.currentSpec} {ExternalImplementation}"); - } - - public override QsSpecialization beforeSpecialization(QsSpecialization spec) - { - var precededByCode = PrecededByCode(this._Output); - var precededByBlock = PrecededByBlock(this._Output); - if (precededByCode && (precededByBlock || spec.Implementation.IsProvided || spec.Documentation.Any())) this.AddToOutput(""); - this.CurrentComments = spec.Comments; - this.AddComments(spec.Comments.OpeningComments); - if (spec.Comments.OpeningComments.Any() && spec.Documentation.Any()) this.AddToOutput(""); - this.AddDocumentation(spec.Documentation); - return spec; - } - - public override QsSpecialization onBodySpecialization(QsSpecialization spec) - { - this.currentSpec = Keywords.bodyDeclHeader.id; - return base.onBodySpecialization(spec); - } - - public override QsSpecialization onAdjointSpecialization(QsSpecialization spec) - { - this.currentSpec = Keywords.adjDeclHeader.id; - return base.onAdjointSpecialization(spec); - } - - public override QsSpecialization onControlledSpecialization(QsSpecialization spec) - { - this.currentSpec = Keywords.ctrlDeclHeader.id; - return base.onControlledSpecialization(spec); - } - - public override QsSpecialization onControlledAdjointSpecialization(QsSpecialization spec) - { - this.currentSpec = Keywords.ctrlAdjDeclHeader.id; - return base.onControlledAdjointSpecialization(spec); - } - - private QsCallable onCallable(QsCallable c, string declHeader) - { - if (!c.Kind.IsTypeConstructor) - { - this.AddToOutput(""); - this.CurrentComments = c.Comments; - this.AddComments(c.Comments.OpeningComments); - if (c.Comments.OpeningComments.Any() && c.Documentation.Any()) this.AddToOutput(""); - this.AddDocumentation(c.Documentation); - foreach (var attribute in c.Attributes) - { this.onAttribute(attribute); } - } - - var signature = SyntaxTreeToQs.DeclarationSignature(c, this.Type, this._Scope._StatementKind.beforeInvalidSymbol); - this._Scope._Expression._Type.onCharacteristicsExpression(c.Signature.Information.Characteristics); - var characteristics = this._Scope._Expression._Type.Output; - - var userDefinedSpecs = c.Specializations.Where(spec => spec.Implementation.IsProvided); - var specBundles = SpecializationBundleProperties.Bundle(spec => spec.TypeArguments, spec => spec.Kind, userDefinedSpecs); - bool NeedsToBeExplicit (QsSpecialization s) - { - if (s.Kind.IsQsBody) return true; - else if (s.Implementation is SpecializationImplementation.Generated gen) - { - if (gen.Item.IsSelfInverse) return s.Kind.IsQsAdjoint; - if (s.Kind.IsQsControlled || s.Kind.IsQsAdjoint) return false; - - var relevantUserDefinedSpecs = specBundles.TryGetValue(SpecializationBundleProperties.BundleId(s.TypeArguments), out var dict) - ? dict // there may be no user defined implementations for a certain set of type arguments, in which case there is no such entry in the dictionary - : ImmutableDictionary.Empty; - var userDefAdj = relevantUserDefinedSpecs.ContainsKey(QsSpecializationKind.QsAdjoint); - var userDefCtl = relevantUserDefinedSpecs.ContainsKey(QsSpecializationKind.QsControlled); - if (gen.Item.IsInvert) return userDefAdj || !userDefCtl; - if (gen.Item.IsDistribute) return userDefCtl && !userDefAdj; - return false; - } - else return !s.Implementation.IsIntrinsic; - } - c = c.WithSpecializations(specs => specs.Where(NeedsToBeExplicit).ToImmutableArray()); - this.nrSpecialzations = c.Specializations.Length; - - this.AddToOutput($"{declHeader} {signature}"); - if (!String.IsNullOrWhiteSpace(characteristics)) this.AddToOutput($"{Keywords.qsCharacteristics.id} {characteristics}"); - this.AddBlock(() => c.Specializations.Select(dispatchSpecialization).ToImmutableArray()); - this.AddToOutput(""); - return c; - } - - public override QsCallable onFunction(QsCallable c) => - this.onCallable(c, Keywords.fctDeclHeader.id); - - public override QsCallable onOperation(QsCallable c) => - this.onCallable(c, Keywords.opDeclHeader.id); - - public override QsCallable onTypeConstructor(QsCallable c) => c; // no code for these - public override QsCustomType onType(QsCustomType t) - { - this.AddToOutput(""); - this.CurrentComments = t.Comments; // no need to deal with closing comments (can't exist), but need to make sure CurrentComments is up to date - this.AddComments(t.Comments.OpeningComments); - if (t.Comments.OpeningComments.Any() && t.Documentation.Any()) this.AddToOutput(""); - this.AddDocumentation(t.Documentation); - foreach (var attribute in t.Attributes) - { this.onAttribute(attribute); } - - (string, ResolvedType) GetItemNameAndType (QsTypeItem item) - { - if (item is QsTypeItem.Named named) return (named.Item.VariableName.Value, named.Item.Type); - else if (item is QsTypeItem.Anonymous type) return (null, type.Item); - else throw new NotImplementedException("unknown case for type item"); - } - var udtTuple = ArgumentTuple(t.TypeItems, GetItemNameAndType, this.Type); - this.AddDirective($"{Keywords.typeDeclHeader.id} {t.FullName.Name.Value} = {udtTuple}"); - return t; - } - - public override QsDeclarationAttribute onAttribute(QsDeclarationAttribute att) - { - // do *not* set CurrentComments! - this._Scope._Expression.Transform(att.Argument); - var arg = this._Scope._Expression._Kind.Output; - var argStr = att.Argument.Expression.IsValueTuple || att.Argument.Expression.IsUnitValue ? arg : $"({arg})"; - var id = att.TypeId.IsValue - ? Identifier.NewGlobalCallable(new QsQualifiedName(att.TypeId.Item.Namespace, att.TypeId.Item.Name)) - : Identifier.InvalidIdentifier; - this._Scope._Expression._Kind.onIdentifier(id, QsNullable>.Null); - this.AddComments(att.Comments.OpeningComments); - this.AddToOutput($"@ {this._Scope._Expression._Kind.Output}{argStr}"); - return att; - } - - public override QsNamespace Transform(QsNamespace ns) - { - var scope = new ScopeToQs(new TransformationContext { CurrentNamespace = ns.Name.Value }); - var generator = new SyntaxTreeToQs(scope); - generator.SetAllInvalid(this); - - generator.AddToOutput($"{Keywords.namespaceDeclHeader.id} {ns.Name.Value}"); - generator.AddBlock(() => generator.ProcessNamespaceElements(ns.Elements)); - this._Output.AddRange(generator._Output); - return ns; - } - } -} - diff --git a/src/QsCompiler/Transformations/CodeTransformations.cs b/src/QsCompiler/Transformations/CodeTransformations.cs index b159b539a3..f4f5ccd0af 100644 --- a/src/QsCompiler/Transformations/CodeTransformations.cs +++ b/src/QsCompiler/Transformations/CodeTransformations.cs @@ -28,9 +28,9 @@ public static QsScope GenerateAdjoint(this QsScope scope) { // Since we are pulling purely classical statements up, we are potentially changing the order of declarations. // We therefore need to generate unique variable names before reordering the statements. - scope = new UniqueVariableNames().Transform(scope); + scope = new UniqueVariableNames().Statements.OnScope(scope); scope = ApplyFunctorToOperationCalls.ApplyAdjoint(scope); - scope = new ReverseOrderOfOperationCalls().Transform(scope); + scope = new ReverseOrderOfOperationCalls().Statements.OnScope(scope); return StripPositionInfo.Apply(scope); } @@ -58,9 +58,9 @@ public static bool InlineConjugations(this QsCompilation compilation, out QsComp { if (compilation == null) throw new ArgumentNullException(nameof(compilation)); var inline = new InlineConjugations(onException); - var namespaces = compilation.Namespaces.Select(inline.Transform).ToImmutableArray(); + var namespaces = compilation.Namespaces.Select(inline.Namespaces.OnNamespace).ToImmutableArray(); inlined = new QsCompilation(namespaces, compilation.EntryPoints); - return inline.Success; + return inline.SharedState.Success; } /// diff --git a/src/QsCompiler/Transformations/Conjugations.cs b/src/QsCompiler/Transformations/Conjugations.cs index 5f118b8950..0b00da1ea6 100644 --- a/src/QsCompiler/Transformations/Conjugations.cs +++ b/src/QsCompiler/Transformations/Conjugations.cs @@ -5,6 +5,7 @@ using System.Collections.Immutable; using Microsoft.Quantum.QsCompiler.SyntaxTokens; using Microsoft.Quantum.QsCompiler.SyntaxTree; +using Microsoft.Quantum.QsCompiler.Transformations.Core; using Microsoft.Quantum.QsCompiler.Transformations.SearchAndReplace; @@ -18,69 +19,90 @@ namespace Microsoft.Quantum.QsCompiler.Transformations.Conjugations /// In particular, it is only guaranteed to be valid if operation calls only occur within expression statements, and /// throws an InvalidOperationException if the outer block contains while-loops. /// - public class InlineConjugations - : SyntaxTreeTransformation + public class InlineConjugations + : SyntaxTreeTransformation { - public bool Success { get; private set; } - private readonly Action OnException; + public class TransformationState + { + public bool Success { get; internal set; } + internal readonly Action OnException; + + internal Func ResolveNames = + new UniqueVariableNames().Statements.OnScope; + + public void Reset() => + this.ResolveNames = new UniqueVariableNames().Statements.OnScope; + + public TransformationState(Action onException = null) + { + this.Success = true; + this.OnException = onException; + } + } + public InlineConjugations(Action onException = null) - : base(new InlineConjugationStatements()) - { - this.Success = true; - this.OnException = onException; + : base(new TransformationState(onException)) + { + this.Namespaces = new NamespaceTransformation(this); + this.Statements = new StatementTransformation(this); + this.Expressions = new ExpressionTransformation(this, TransformationOptions.Disabled); + this.Types = new TypeTransformation(this, TransformationOptions.Disabled); } - public override Tuple>, QsScope> onProvidedImplementation - (QsTuple> argTuple, QsScope body) + + // helper classes + + private class StatementTransformation + : StatementTransformation { - this._Scope.Reset(); - try { body = this._Scope.Transform(body); } - catch (Exception ex) + public StatementTransformation(SyntaxTreeTransformation parent) + : base(parent) { } + + + public override QsScope OnScope(QsScope scope) { - this.OnException?.Invoke(ex); - this.Success = false; + var statements = ImmutableArray.CreateBuilder(); + foreach (var statement in scope.Statements) + { + if (statement.Statement is QsStatementKind.QsConjugation conj) + { + // since we are eliminating scopes, + // we need to make sure that the variables defined within the inlined scopes do not clash with other defined variables. + var outer = this.SharedState.ResolveNames(this.OnScope(conj.Item.OuterTransformation.Body)); + var inner = this.SharedState.ResolveNames(this.OnScope(conj.Item.InnerTransformation.Body)); + var adjOuter = outer.GenerateAdjoint(); // will add a unique name wrapper + + statements.AddRange(outer.Statements); + statements.AddRange(inner.Statements); + statements.AddRange(adjOuter.Statements); + } + else statements.Add(this.OnStatement(statement)); + } + return new QsScope(statements.ToImmutableArray(), scope.KnownSymbols); } - return new Tuple>, QsScope>(argTuple, body); } - } - /// - /// Scope transformation that inlines all conjugations, thus eliminating them from a given scope. - /// The generation of the adjoint for the outer block is subject to the same limitation as any adjoint auto-generation. - /// In particular, it is only guaranteed to be valid if operation calls only occur within expression statements, and - /// throws an InvalidOperationException if the outer block contains while-loops. - /// - public class InlineConjugationStatements - : ScopeTransformation, NoExpressionTransformations> - { - private Func ResolveNames; - internal void Reset() => this.ResolveNames = new UniqueVariableNames().Transform; - public InlineConjugationStatements() - : base(s => new StatementKindTransformation(s as InlineConjugationStatements), new NoExpressionTransformations()) => - this.ResolveNames = new UniqueVariableNames().Transform; - - public override QsScope Transform(QsScope scope) + private class NamespaceTransformation + : NamespaceTransformation { - var statements = ImmutableArray.CreateBuilder(); - foreach (var statement in scope.Statements) + public NamespaceTransformation(SyntaxTreeTransformation parent) + : base(parent) { } + + + public override Tuple>, QsScope> OnProvidedImplementation + (QsTuple> argTuple, QsScope body) { - if (statement.Statement is QsStatementKind.QsConjugation conj) + this.SharedState.Reset(); + try { body = this.Transformation.Statements.OnScope(body); } + catch (Exception ex) { - // since we are eliminating scopes, - // we need to make sure that the variables defined within the inlined scopes do not clash with other defined variables. - var outer = ResolveNames(this.Transform(conj.Item.OuterTransformation.Body)); - var inner = ResolveNames(this.Transform(conj.Item.InnerTransformation.Body)); - var adjOuter = outer.GenerateAdjoint(); // will add a unique name wrapper - - statements.AddRange(outer.Statements); - statements.AddRange(inner.Statements); - statements.AddRange(adjOuter.Statements); + this.SharedState.OnException?.Invoke(ex); + this.SharedState.Success = false; } - else statements.Add(this.onStatement(statement)); + return new Tuple>, QsScope>(argTuple, body); } - return new QsScope(statements.ToImmutableArray(), scope.KnownSymbols); } } } diff --git a/src/QsCompiler/Transformations/FunctorGeneration.cs b/src/QsCompiler/Transformations/FunctorGeneration.cs index e4be60cd5c..f47871d31a 100644 --- a/src/QsCompiler/Transformations/FunctorGeneration.cs +++ b/src/QsCompiler/Transformations/FunctorGeneration.cs @@ -10,6 +10,7 @@ using Microsoft.Quantum.QsCompiler.SyntaxTokens; using Microsoft.Quantum.QsCompiler.SyntaxTree; using Microsoft.Quantum.QsCompiler.Transformations.BasicTransformations; +using Microsoft.Quantum.QsCompiler.Transformations.Core; namespace Microsoft.Quantum.QsCompiler.Transformations.FunctorGeneration @@ -19,131 +20,158 @@ namespace Microsoft.Quantum.QsCompiler.Transformations.FunctorGeneration /// with a call to the operation after application of the functor given on initialization. /// The default values used for auto-generation will be used for the additional functor arguments. ///
- public class ApplyFunctorToOperationCalls : - ScopeTransformation< - ApplyFunctorToOperationCalls.IgnoreOuterBlockInConjugations, - ExpressionTransformation > + public class ApplyFunctorToOperationCalls + : SyntaxTreeTransformation { - public ApplyFunctorToOperationCalls(QsFunctor functor) : base( - s => new IgnoreOuterBlockInConjugations(s as ApplyFunctorToOperationCalls), - new ExpressionTransformation(e => new ApplyToExpressionKind(e, functor))) - { } + public class TransformationsState + { + public readonly QsFunctor FunctorToApply; + + public TransformationsState(QsFunctor functor) => + this.FunctorToApply = functor ?? throw new ArgumentNullException(nameof(functor)); + } + + + public ApplyFunctorToOperationCalls(QsFunctor functor) + : base(new TransformationsState(functor)) + { + this.StatementKinds = new IgnoreOuterBlockInConjugations(this); + this.ExpressionKinds = new ExpressionKindTransformation(this); + this.Types = new TypeTransformation(this, TransformationOptions.Disabled); + } + + + // static methods for convenience private static readonly TypedExpression ControlQubits = SyntaxGenerator.ImmutableQubitArrayWithName(NonNullable.New(InternalUse.ControlQubitsName)); public static readonly Func ApplyAdjoint = - new ApplyFunctorToOperationCalls(QsFunctor.Adjoint).Transform; + new ApplyFunctorToOperationCalls(QsFunctor.Adjoint).Statements.OnScope; public static readonly Func ApplyControlled = - new ApplyFunctorToOperationCalls(QsFunctor.Controlled).Transform; + new ApplyFunctorToOperationCalls(QsFunctor.Controlled).Statements.OnScope; - // helper class - - /// - /// Ignores outer blocks of conjugations, transforming only the inner one. - /// - public class IgnoreOuterBlockInConjugations : - StatementKindTransformation - where S : Core.ScopeTransformation - { - public IgnoreOuterBlockInConjugations(S scope) - : base(scope) - { } - - public override QsStatementKind onConjugation(QsConjugation stm) - { - var inner = stm.InnerTransformation; - var innerLoc = this._Scope.onLocation(inner.Location); - var transformedInner = new QsPositionedBlock(this._Scope.Transform(inner.Body), innerLoc, inner.Comments); - return QsStatementKind.NewQsConjugation(new QsConjugation(stm.OuterTransformation, transformedInner)); - } - } + // helper classes /// /// Replaces each operation call with a call to the operation after application of the given functor. /// The default values used for auto-generation will be used for the additional functor arguments. /// - public class ApplyToExpressionKind : - ExpressionKindTransformation + public class ExpressionKindTransformation + : ExpressionKindTransformation { - public readonly QsFunctor FunctorToApply; - public ApplyToExpressionKind(Core.ExpressionTransformation expression, QsFunctor functor) : - base(expression) => - this.FunctorToApply = functor ?? throw new ArgumentNullException(nameof(functor)); + public ExpressionKindTransformation(SyntaxTreeTransformation parent) + : base(parent) { } + + public ExpressionKindTransformation(QsFunctor functor) + : base(new TransformationsState(functor)) { } - public override QsExpressionKind onOperationCall(TypedExpression method, TypedExpression arg) + + public override QsExpressionKind OnOperationCall(TypedExpression method, TypedExpression arg) { - if (this.FunctorToApply.IsControlled) + if (this.SharedState.FunctorToApply.IsControlled) { method = SyntaxGenerator.ControlledOperation(method); arg = SyntaxGenerator.ArgumentWithControlQubits(arg, ControlQubits); } - else if (this.FunctorToApply.IsAdjoint) + else if (this.SharedState.FunctorToApply.IsAdjoint) { method = SyntaxGenerator.AdjointOperation(method); } else throw new NotImplementedException("unsupported functor"); - return base.onOperationCall(method, arg); + return base.OnOperationCall(method, arg); } } } + /// + /// Ensures that the outer block of conjugations is ignored during transformation. + /// + public class IgnoreOuterBlockInConjugations + : StatementKindTransformation + { + public IgnoreOuterBlockInConjugations(SyntaxTreeTransformation parent, TransformationOptions options = null) + : base(parent, options ?? TransformationOptions.Default) { } + + public IgnoreOuterBlockInConjugations(T sharedInternalState, TransformationOptions options = null) + : base(sharedInternalState, options ?? TransformationOptions.Default) { } + + + public override QsStatementKind OnConjugation(QsConjugation stm) + { + var inner = stm.InnerTransformation; + var innerLoc = this.Transformation.Statements.OnLocation(inner.Location); + var transformedInner = new QsPositionedBlock(this.Transformation.Statements.OnScope(inner.Body), innerLoc, inner.Comments); + return QsStatementKind.NewQsConjugation(new QsConjugation(stm.OuterTransformation, transformedInner)); + } + } + + /// /// Scope transformation that reverses the order of execution for operation calls within a given scope, - /// unless these calls occur within the outer block of a conjugation. Outer tranformations of conjugations are left unchanged. + /// unless these calls occur within the outer block of a conjugation. Outer transformations of conjugations are left unchanged. /// Note that the transformed scope is only guaranteed to be valid if operation calls only occur within expression statements! /// Otherwise the transformation will succeed, but the generated scope is not necessarily valid. /// Throws an InvalidOperationException if the scope to transform contains while-loops. /// - internal class ReverseOrderOfOperationCalls : - SelectByAllContainedExpressions + internal class ReverseOrderOfOperationCalls + : SelectByAllContainedExpressions { - public ReverseOrderOfOperationCalls() : - base(ex => !ex.InferredInformation.HasLocalQuantumDependency, false, s => new ReverseLoops(s as ReverseOrderOfOperationCalls)) // no need to evaluate subexpressions - { } + public ReverseOrderOfOperationCalls() + : base(ex => !ex.InferredInformation.HasLocalQuantumDependency, false) // no need to evaluate subexpressions + { + // Do *not* disable transformations; the base class takes care of that! + this.StatementKinds = new ReverseLoops(this); + this.Statements = new StatementTransformation(this); + } + - protected override SelectByFoldingOverExpressions GetSubSelector() => - new ReverseOrderOfOperationCalls(); + // helper classes - public override QsScope Transform(QsScope scope) + private class StatementTransformation + : StatementTransformation { - var topStatements = ImmutableArray.CreateBuilder(); - var bottomStatements = new List(); - foreach (var statement in scope.Statements) + public StatementTransformation(ReverseOrderOfOperationCalls parent) + : base(state => new ReverseOrderOfOperationCalls(), parent) { } + + public override QsScope OnScope(QsScope scope) { - var transformed = this.onStatement(statement); - if (this.SubSelector.SatisfiesCondition) topStatements.Add(statement); - else bottomStatements.Add(transformed); + var topStatements = ImmutableArray.CreateBuilder(); + var bottomStatements = new List(); + foreach (var statement in scope.Statements) + { + var transformed = this.OnStatement(statement); + if (this.SubSelector.SharedState.SatisfiesCondition) topStatements.Add(statement); + else bottomStatements.Add(transformed); + } + bottomStatements.Reverse(); + return new QsScope(topStatements.Concat(bottomStatements).ToImmutableArray(), scope.KnownSymbols); } - bottomStatements.Reverse(); - return new QsScope(topStatements.Concat(bottomStatements).ToImmutableArray(), scope.KnownSymbols); } - - // helper class - /// - /// Helper class for the scope transformation that reverses the order of all operation calls - /// unless these calls occur within the outer block of a conjugation. Outer tranformations of conjugations are left unchanged. + /// Helper class to reverse the order of all operation calls + /// unless these calls occur within the outer block of a conjugation. + /// Outer transformations of conjugations are left unchanged. /// Throws an InvalidOperationException upon while-loops. /// - internal class ReverseLoops : - ApplyFunctorToOperationCalls.IgnoreOuterBlockInConjugations + private class ReverseLoops + : IgnoreOuterBlockInConjugations { - internal ReverseLoops(ReverseOrderOfOperationCalls scope) : - base(scope) { } + internal ReverseLoops(ReverseOrderOfOperationCalls parent) + : base(parent) { } - public override QsStatementKind onForStatement(QsForStatement stm) + public override QsStatementKind OnForStatement(QsForStatement stm) { var reversedIterable = SyntaxGenerator.ReverseIterable(stm.IterationValues); stm = new QsForStatement(stm.LoopItem, reversedIterable, stm.Body); - return base.onForStatement(stm); + return base.OnForStatement(stm); } - public override QsStatementKind onWhileStatement(QsWhileStatement stm) => + public override QsStatementKind OnWhileStatement(QsWhileStatement stm) => throw new InvalidOperationException("cannot reverse while-loops"); } } diff --git a/src/QsCompiler/Transformations/IntrinsicResolutionTransformation.cs b/src/QsCompiler/Transformations/IntrinsicResolution.cs similarity index 95% rename from src/QsCompiler/Transformations/IntrinsicResolutionTransformation.cs rename to src/QsCompiler/Transformations/IntrinsicResolution.cs index 1d667cedd2..685ac83d13 100644 --- a/src/QsCompiler/Transformations/IntrinsicResolutionTransformation.cs +++ b/src/QsCompiler/Transformations/IntrinsicResolution.cs @@ -9,9 +9,9 @@ using Microsoft.Quantum.QsCompiler.SyntaxTree; -namespace Microsoft.Quantum.QsCompiler.Transformations.IntrinsicResolutionTransformation +namespace Microsoft.Quantum.QsCompiler.Transformations.IntrinsicResolution { - public class IntrinsicResolutionTransformation + public static class ReplaceWithTargetIntrinsics { /// /// Merge the environment-specific syntax tree with the target tree. The resulting tree will @@ -58,7 +58,7 @@ private static IEnumerable MergeElements(IEnumerable>, ResolvedType>, Identifier>; using ImmutableConcretion = ImmutableDictionary>, ResolvedType>; - public static class MonomorphizationTransformation + public static class Monomorphize { - public static QsCompilation Apply(QsCompilation compilation) => ResolveGenericsSyntax.Apply(compilation); - private struct Request { public QsQualifiedName originalName; @@ -35,154 +35,174 @@ private struct Response public QsCallable concreteCallable; } - #region ResolveGenerics - - private class ResolveGenericsSyntax : - SyntaxTreeTransformation + public static QsCompilation Apply(QsCompilation compilation) { - ImmutableDictionary, IEnumerable> NamespaceCallables; + if (compilation == null || compilation.Namespaces.Contains(null)) throw new ArgumentNullException(nameof(compilation)); - public static QsCompilation Apply(QsCompilation compilation) - { - if (compilation == null || compilation.Namespaces.Contains(null)) throw new ArgumentNullException(nameof(compilation)); + var globals = compilation.Namespaces.GlobalCallableResolutions(); - var globals = compilation.Namespaces.GlobalCallableResolutions(); + var entryPoints = compilation.EntryPoints + .Select(call => new Request + { + originalName = call, + typeResolutions = ImmutableConcretion.Empty, + concreteName = call + }); - var entryPoints = compilation.EntryPoints - .Select(call => new Request - { - originalName = call, - typeResolutions = ImmutableConcretion.Empty, - concreteName = call - }); + var requests = new Stack(entryPoints); + var responses = new List(); + + while (requests.Any()) + { + Request currentRequest = requests.Pop(); - var requests = new Stack(entryPoints); - var responses = new List(); + // If there is a call to an unknown callable, throw exception + if (!globals.TryGetValue(currentRequest.originalName, out QsCallable originalGlobal)) + throw new ArgumentException($"Couldn't find definition for callable: {currentRequest.originalName.ToString()}"); - while (requests.Any()) + var currentResponse = new Response { - Request currentRequest = requests.Pop(); + originalName = currentRequest.originalName, + typeResolutions = currentRequest.typeResolutions, + concreteCallable = originalGlobal.WithFullName(name => currentRequest.concreteName) + }; - // If there is a call to an unknown callable, throw exception - if (!globals.TryGetValue(currentRequest.originalName, out QsCallable originalGlobal)) - throw new ArgumentException($"Couldn't find definition for callable: {currentRequest.originalName.Namespace.Value + "." + currentRequest.originalName.Name.Value}"); + GetConcreteIdentifierFunc getConcreteIdentifier = (globalCallable, types) => + GetConcreteIdentifier(currentResponse, requests, responses, globalCallable, types); - var currentResponse = new Response - { - originalName = currentRequest.originalName, - typeResolutions = currentRequest.typeResolutions, - concreteCallable = originalGlobal.WithFullName(name => currentRequest.concreteName) - }; + // Rewrite implementation + currentResponse = ReplaceTypeParamImplementations.Apply(currentResponse); - GetConcreteIdentifierFunc getConcreteIdentifier = (globalCallable, types) => - GetConcreteIdentifier(currentResponse, requests, responses, globalCallable, types); + // Rewrite calls + currentResponse = ReplaceTypeParamCalls.Apply(currentResponse, getConcreteIdentifier); - // Rewrite implementation - currentResponse = ReplaceTypeParamImplementationsSyntax.Apply(currentResponse); + responses.Add(currentResponse); + } - // Rewrite calls - currentResponse = ReplaceTypeParamCallsSyntax.Apply(currentResponse, getConcreteIdentifier); + return ResolveGenerics.Apply(compilation, responses); + } - responses.Add(currentResponse); - } + private static Identifier GetConcreteIdentifier( + Response currentResponse, + Stack requests, + List responses, + Identifier.GlobalCallable globalCallable, + ImmutableConcretion types) + { + QsQualifiedName concreteName = globalCallable.Item; - var filter = new ResolveGenericsSyntax(responses - .GroupBy(res => res.concreteCallable.FullName.Namespace) - .ToImmutableDictionary(group => group.Key, group => group.Select(res => res.concreteCallable))); - return new QsCompilation(compilation.Namespaces.Select(ns => filter.Transform(ns)).ToImmutableArray(), compilation.EntryPoints); + var typesHashSet = ImmutableHashSet>, ResolvedType>>.Empty; + if (types != null && !types.IsEmpty) + { + typesHashSet = types.ToImmutableHashSet(); } - /// - /// Constructor for the ResolveGenericsSyntax class. Its transform function replaces global callables in the namespace. - /// - /// Maps namespace names to an enumerable of all global callables in that namespace. - public ResolveGenericsSyntax(ImmutableDictionary, IEnumerable> namespaceCallables) : base(new NoScopeTransformations()) + string name = null; + + // Check for recursive call + if (currentResponse.originalName.Equals(globalCallable.Item) && + typesHashSet.SetEquals(currentResponse.typeResolutions)) { - NamespaceCallables = namespaceCallables; + name = currentResponse.concreteCallable.FullName.Name.Value; } - public override QsNamespace Transform(QsNamespace ns) + // Search requests for identifier + if (name == null) { - NamespaceCallables.TryGetValue(ns.Name, out IEnumerable concretesInNs); - - // Removes unused or generic callables from the namespace - // Adds in the used concrete callables - return ns.WithElements(elems => elems - .Where(elem => elem is QsNamespaceElement.QsCustomType) - .Concat(concretesInNs?.Select(call => QsNamespaceElement.NewQsCallable(call)) ?? Enumerable.Empty()) - .ToImmutableArray()); + name = requests + .Where(req => + req.originalName.Equals(globalCallable.Item) && + typesHashSet.SetEquals(req.typeResolutions)) + .Select(req => req.concreteName.Name.Value) + .FirstOrDefault(); } - private static Identifier GetConcreteIdentifier( - Response currentResponse, - Stack requests, - List responses, - Identifier.GlobalCallable globalCallable, - ImmutableConcretion types) + // Search responses for identifier + if (name == null) { - QsQualifiedName concreteName = globalCallable.Item; + name = responses + .Where(res => + res.originalName.Equals(globalCallable.Item) && + typesHashSet.SetEquals(res.typeResolutions)) + .Select(res => res.concreteCallable.FullName.Name.Value) + .FirstOrDefault(); + } - var typesHashSet = ImmutableHashSet>, ResolvedType>>.Empty; - if (types != null && !types.IsEmpty) + // If identifier can't be found, make a new request + if (name == null) + { + // If this is not a generic, do not change the name + if (!typesHashSet.IsEmpty) { - typesHashSet = types.ToImmutableHashSet(); + // Create new name + concreteName = UniqueVariableNames.PrependGuid(globalCallable.Item); } - string name = null; - - // Check for recursive call - if (currentResponse.originalName.Equals(globalCallable.Item) && - typesHashSet.SetEquals(currentResponse.typeResolutions)) + requests.Push(new Request() { - name = currentResponse.concreteCallable.FullName.Name.Value; - } + originalName = globalCallable.Item, + typeResolutions = types, + concreteName = concreteName + }); + } + else // If the identifier was found, update with the name + { + concreteName = new QsQualifiedName(globalCallable.Item.Namespace, NonNullable.New(name)); + } - // Search requests for identifier - if (name == null) - { - name = requests - .Where(req => - req.originalName.Equals(globalCallable.Item) && - typesHashSet.SetEquals(req.typeResolutions)) - .Select(req => req.concreteName.Name.Value) - .FirstOrDefault(); - } + return Identifier.NewGlobalCallable(concreteName); + } + + #region ResolveGenerics - // Search responses for identifier - if (name == null) + private class ResolveGenerics : SyntaxTreeTransformation + { + public static QsCompilation Apply(QsCompilation compilation, List responses) + { + var filter = new ResolveGenerics(responses + .GroupBy(res => res.concreteCallable.FullName.Namespace) + .ToImmutableDictionary(group => group.Key, group => group.Select(res => res.concreteCallable))); + + return new QsCompilation(compilation.Namespaces.Select(ns => filter.Namespaces.OnNamespace(ns)).ToImmutableArray(), compilation.EntryPoints); + } + + public class TransformationState + { + public readonly ImmutableDictionary, IEnumerable> NamespaceCallables; + + public TransformationState(ImmutableDictionary, IEnumerable> namespaceCallables) { - name = responses - .Where(res => - res.originalName.Equals(globalCallable.Item) && - typesHashSet.SetEquals(res.typeResolutions)) - .Select(res => res.concreteCallable.FullName.Name.Value) - .FirstOrDefault(); + this.NamespaceCallables = namespaceCallables; } + } - // If identifier can't be found, make a new request - if (name == null) - { - // If this is not a generic, do not change the name - if (!typesHashSet.IsEmpty) - { - // Create new name - name = "_" + Guid.NewGuid().ToString("N") + "_" + globalCallable.Item.Name.Value; - concreteName = new QsQualifiedName(globalCallable.Item.Namespace, NonNullable.New(name)); - } + /// + /// Constructor for the ResolveGenericsSyntax class. Its transform function replaces global callables in the namespace. + /// + /// Maps namespace names to an enumerable of all global callables in that namespace. + private ResolveGenerics(ImmutableDictionary, IEnumerable> namespaceCallables) : base(new TransformationState(namespaceCallables)) + { + this.Namespaces = new NamespaceTransformation(this); + this.Statements = new StatementTransformation(this, TransformationOptions.Disabled); + this.Expressions = new ExpressionTransformation(this, TransformationOptions.Disabled); + this.Types = new TypeTransformation(this, TransformationOptions.Disabled); + } - requests.Push(new Request() - { - originalName = globalCallable.Item, - typeResolutions = types, - concreteName = concreteName - }); - } - else // If the identifier was found, update with the name + private class NamespaceTransformation : NamespaceTransformation + { + public NamespaceTransformation(SyntaxTreeTransformation parent) : base(parent) { } + + public override QsNamespace OnNamespace(QsNamespace ns) { - concreteName = new QsQualifiedName(globalCallable.Item.Namespace, NonNullable.New(name)); + SharedState.NamespaceCallables.TryGetValue(ns.Name, out IEnumerable concretesInNs); + + // Removes unused or generic callables from the namespace + // Adds in the used concrete callables + return ns.WithElements(elems => elems + .Where(elem => !(elem is QsNamespaceElement.QsCallable)) + .Concat(concretesInNs?.Select(call => QsNamespaceElement.NewQsCallable(call)) ?? Enumerable.Empty()) + .ToImmutableArray()); } - - return Identifier.NewGlobalCallable(concreteName); } } @@ -190,63 +210,70 @@ private static Identifier GetConcreteIdentifier( #region RewriteImplementations - private class ReplaceTypeParamImplementationsSyntax : - SyntaxTreeTransformation>> + private class ReplaceTypeParamImplementations : + SyntaxTreeTransformation { public static Response Apply(Response current) { // Nothing to change if the current callable is already concrete if (current.typeResolutions == ImmutableConcretion.Empty) return current; - var filter = new ReplaceTypeParamImplementationsSyntax(current.typeResolutions); + var filter = new ReplaceTypeParamImplementations(current.typeResolutions); // Create a new response with the transformed callable return new Response { originalName = current.originalName, typeResolutions = current.typeResolutions, - concreteCallable = filter.onCallableImplementation(current.concreteCallable) + concreteCallable = filter.Namespaces.OnCallableDeclaration(current.concreteCallable) }; } - public ReplaceTypeParamImplementationsSyntax(ImmutableConcretion typeParams) : base( - new ScopeTransformation>( - new ExpressionTransformation( - ex => new ExpressionKindTransformation>(ex as ExpressionTransformation), - ex => new ReplaceTypeParamImplementationsExpressionType(typeParams, ex as ExpressionTransformation) - ))) - { } - - public override ResolvedSignature onSignature(ResolvedSignature s) + public class TransformationState { - // Remove the type parameters from the signature - s = new ResolvedSignature( - ImmutableArray.Empty, - s.ArgumentType, - s.ReturnType, - s.Information - ); - return base.onSignature(s); + public readonly ImmutableConcretion TypeParams; + + public TransformationState(ImmutableConcretion typeParams) + { + this.TypeParams = typeParams; + } } - } - private class ReplaceTypeParamImplementationsExpressionType : - ExpressionTypeTransformation> - { - ImmutableConcretion TypeParams; + private ReplaceTypeParamImplementations(ImmutableConcretion typeParams) : base(new TransformationState(typeParams)) + { + this.Namespaces = new NamespaceTransformation(this); + this.Types = new TypeTransformation(this); + } - public ReplaceTypeParamImplementationsExpressionType(ImmutableConcretion typeParams, ExpressionTransformation expr) : base(expr) + private class NamespaceTransformation : NamespaceTransformation { - TypeParams = typeParams; + public NamespaceTransformation(SyntaxTreeTransformation parent) : base(parent) { } + + public override ResolvedSignature OnSignature(ResolvedSignature s) + { + // Remove the type parameters from the signature + s = new ResolvedSignature( + ImmutableArray.Empty, + s.ArgumentType, + s.ReturnType, + s.Information + ); + return base.OnSignature(s); + } } - public override QsTypeKind onTypeParameter(QsTypeParameter tp) + private class TypeTransformation : TypeTransformation { - if (TypeParams.TryGetValue(Tuple.Create(tp.Origin, tp.TypeName), out var typeParam)) + public TypeTransformation(SyntaxTreeTransformation parent) : base(parent) { } + + public override QsTypeKind OnTypeParameter(QsTypeParameter tp) { - return typeParam.Resolution; + if (SharedState.TypeParams.TryGetValue(Tuple.Create(tp.Origin, tp.TypeName), out var typeParam)) + { + return typeParam.Resolution; + } + return QsTypeKind.NewTypeParameter(tp); } - return QsTypeKind.NewTypeParameter(tp); } } @@ -254,123 +281,115 @@ public override QsTypeKind> + private class ReplaceTypeParamCalls : + SyntaxTreeTransformation { public static Response Apply(Response current, GetConcreteIdentifierFunc getConcreteIdentifier) { - var filter = new ReplaceTypeParamCallsSyntax(new ScopeTransformation( - new ReplaceTypeParamCallsExpression(new Concretion(), getConcreteIdentifier))); + var filter = new ReplaceTypeParamCalls(getConcreteIdentifier); // Create a new response with the transformed callable return new Response { originalName = current.originalName, typeResolutions = current.typeResolutions, - concreteCallable = filter.onCallableImplementation(current.concreteCallable) + concreteCallable = filter.Namespaces.OnCallableDeclaration(current.concreteCallable) }; } - public ReplaceTypeParamCallsSyntax(ScopeTransformation scope) : base(scope) { } - - } - - private class ReplaceTypeParamCallsExpression : - ExpressionTransformation - { - private readonly Concretion CurrentParamTypes; - - public ReplaceTypeParamCallsExpression(Concretion currentParamTypes, GetConcreteIdentifierFunc getConcreteIdentifier) : - base(ex => new ReplaceTypeParamCallsExpressionKind(ex as ReplaceTypeParamCallsExpression, currentParamTypes, getConcreteIdentifier), - ex => new ReplaceTypeParamCallsExpressionType(ex as ReplaceTypeParamCallsExpression, currentParamTypes)) + public class TransformationState { - CurrentParamTypes = currentParamTypes; + public readonly Concretion CurrentParamTypes = new Concretion(); + public readonly GetConcreteIdentifierFunc GetConcreteIdentifier; + + public TransformationState(GetConcreteIdentifierFunc getConcreteIdentifier) + { + this.GetConcreteIdentifier = getConcreteIdentifier; + } } - public override TypedExpression Transform(TypedExpression ex) - { - var range = this.onRangeInformation(ex.Range); - var typeParamResolutions = this.onTypeParamResolutions(ex.TypeParameterResolutions) - .Select(kv => new Tuple, ResolvedType>(kv.Key.Item1, kv.Key.Item2, kv.Value)) - .ToImmutableArray(); - var exType = this.Type.Transform(ex.ResolvedType); - var inferredInfo = this.onExpressionInformation(ex.InferredInformation); - // Change the order so that Kind is transformed last. - // This matters because the onTypeParamResolutions method builds up type param mappings in - // the CurrentParamTypes dictionary that are then used, and removed from the - // dictionary, in the next global callable identifier found under the Kind transformations. - var kind = this.Kind.Transform(ex.Expression); - return new TypedExpression(kind, typeParamResolutions, exType, inferredInfo, range); + private ReplaceTypeParamCalls(GetConcreteIdentifierFunc getConcreteIdentifier) : base(new TransformationState(getConcreteIdentifier)) + { + this.Expressions = new ExpressionTransformation(this); + this.ExpressionKinds = new ExpressionKindTransformation(this); + this.Types = new TypeTransformation(this); } - public override ImmutableConcretion onTypeParamResolutions(ImmutableConcretion typeParams) + private class ExpressionTransformation : ExpressionTransformation { - // Merge the type params into the current dictionary - foreach (var kvp in typeParams) + public ExpressionTransformation(SyntaxTreeTransformation parent) : base(parent) { } + + public override TypedExpression OnTypedExpression(TypedExpression ex) { - CurrentParamTypes.Add(kvp.Key, kvp.Value); + var range = this.OnRangeInformation(ex.Range); + var typeParamResolutions = this.OnTypeParamResolutions(ex.TypeParameterResolutions) + .Select(kv => new Tuple, ResolvedType>(kv.Key.Item1, kv.Key.Item2, kv.Value)) + .ToImmutableArray(); + var exType = this.Types.OnType(ex.ResolvedType); + var inferredInfo = this.OnExpressionInformation(ex.InferredInformation); + // Change the order so that Kind is transformed last. + // This matters because the onTypeParamResolutions method builds up type param mappings in + // the CurrentParamTypes dictionary that are then used, and removed from the + // dictionary, in the next global callable identifier found under the Kind transformations. + var kind = this.ExpressionKinds.OnExpressionKind(ex.Expression); + return new TypedExpression(kind, typeParamResolutions, exType, inferredInfo, range); } - return ImmutableConcretion.Empty; - } - } - - private class ReplaceTypeParamCallsExpressionKind : ExpressionKindTransformation - { - private readonly GetConcreteIdentifierFunc GetConcreteIdentifier; - private Concretion CurrentParamTypes; + public override ImmutableConcretion OnTypeParamResolutions(ImmutableConcretion typeParams) + { + // Merge the type params into the current dictionary + foreach (var kvp in typeParams) + { + SharedState.CurrentParamTypes.Add(kvp.Key, kvp.Value); + } - public ReplaceTypeParamCallsExpressionKind(ReplaceTypeParamCallsExpression expr, - Concretion currentParamTypes, - GetConcreteIdentifierFunc getConcreteIdentifier) : base(expr) - { - GetConcreteIdentifier = getConcreteIdentifier; - CurrentParamTypes = currentParamTypes; + return ImmutableConcretion.Empty; + } } - public override QsExpressionKind onIdentifier(Identifier sym, QsNullable> tArgs) + private class ExpressionKindTransformation : ExpressionKindTransformation { - if (sym is Identifier.GlobalCallable global) - { - ImmutableConcretion applicableParams = CurrentParamTypes - .Where(kvp => kvp.Key.Item1.Equals(global.Item)) - .ToImmutableDictionary(kvp => kvp.Key, kvp => kvp.Value); - - // Create a new identifier - sym = GetConcreteIdentifier(global, applicableParams); - tArgs = QsNullable>.Null; + public ExpressionKindTransformation(SyntaxTreeTransformation parent) : base(parent) { } - // Remove Type Params used from the CurrentParamTypes - foreach (var key in applicableParams.Keys) + public override QsExpressionKind OnIdentifier(Identifier sym, QsNullable> tArgs) + { + if (sym is Identifier.GlobalCallable global) { - CurrentParamTypes.Remove(key); + ImmutableConcretion applicableParams = SharedState.CurrentParamTypes + .Where(kvp => kvp.Key.Item1.Equals(global.Item)) + .ToImmutableDictionary(kvp => kvp.Key, kvp => kvp.Value); + + // Create a new identifier + sym = SharedState.GetConcreteIdentifier(global, applicableParams); + tArgs = QsNullable>.Null; + + // Remove Type Params used from the CurrentParamTypes + foreach (var key in applicableParams.Keys) + { + SharedState.CurrentParamTypes.Remove(key); + } + } + else if (sym is Identifier.LocalVariable && tArgs.IsValue && tArgs.Item.Any()) + { + throw new ArgumentException($"Local variables cannot have type arguments."); } - } - else if (sym is Identifier.LocalVariable && tArgs.IsValue && tArgs.Item.Any()) - { - throw new ArgumentException($"Local variables cannot have type arguments."); - } - return base.onIdentifier(sym, tArgs); + return base.OnIdentifier(sym, tArgs); + } } - } - - private class ReplaceTypeParamCallsExpressionType : ExpressionTypeTransformation - { - private Concretion CurrentParamTypes; - public ReplaceTypeParamCallsExpressionType(ReplaceTypeParamCallsExpression expr, Concretion currentParamTypes) : base(expr) + private class TypeTransformation : TypeTransformation { - CurrentParamTypes = currentParamTypes; - } + public TypeTransformation(SyntaxTreeTransformation parent) : base(parent) { } - public override QsTypeKind onTypeParameter(QsTypeParameter tp) - { - if (CurrentParamTypes.TryGetValue(Tuple.Create(tp.Origin, tp.TypeName), out var typeParam)) + public override QsTypeKind OnTypeParameter(QsTypeParameter tp) { - return typeParam.Resolution; + if (SharedState.CurrentParamTypes.TryGetValue(Tuple.Create(tp.Origin, tp.TypeName), out var typeParam)) + { + return typeParam.Resolution; + } + return QsTypeKind.NewTypeParameter(tp); } - return QsTypeKind.NewTypeParameter(tp); } } diff --git a/src/QsCompiler/Transformations/MonomorphizationValidation.cs b/src/QsCompiler/Transformations/MonomorphizationValidation.cs index cb84f6c8a3..02819a440b 100644 --- a/src/QsCompiler/Transformations/MonomorphizationValidation.cs +++ b/src/QsCompiler/Transformations/MonomorphizationValidation.cs @@ -7,47 +7,52 @@ using Microsoft.Quantum.QsCompiler.DataTypes; using Microsoft.Quantum.QsCompiler.SyntaxTokens; using Microsoft.Quantum.QsCompiler.SyntaxTree; +using Microsoft.Quantum.QsCompiler.Transformations.Core; -namespace Microsoft.Quantum.QsCompiler.Transformations.MonomorphizationValidation +namespace Microsoft.Quantum.QsCompiler.Transformations.Monomorphization.Validation { - public class MonomorphizationValidationTransformation + public class ValidateMonomorphization : SyntaxTreeTransformation { public static void Apply(QsCompilation compilation) { - var filter = new MonomorphizationValidationSyntax(); + var filter = new ValidateMonomorphization(); foreach (var ns in compilation.Namespaces) { - filter.Transform(ns); + filter.Namespaces.OnNamespace(ns); } } - private class MonomorphizationValidationSyntax : SyntaxTreeTransformation> + public class TransformationState { } + + internal ValidateMonomorphization() : base(new TransformationState()) + { + this.Namespaces = new NamespaceTransformation(this); + this.Expressions = new ExpressionTransformation(this); + this.Types = new TypeTransformation(this); + } + + private class NamespaceTransformation : NamespaceTransformation { - public MonomorphizationValidationSyntax(ScopeTransformation scope = null) : - base(scope ?? new ScopeTransformation(new MonomorphizationValidationExpression())) { } + public NamespaceTransformation(SyntaxTreeTransformation parent) : base(parent) { } - public override ResolvedSignature onSignature(ResolvedSignature s) + public override ResolvedSignature OnSignature(ResolvedSignature s) { if (s.TypeParameters.Any()) { throw new Exception("Signatures cannot contains type parameters"); } - return base.onSignature(s); + return base.OnSignature(s); } } - private class MonomorphizationValidationExpression : - ExpressionTransformation + private class ExpressionTransformation : ExpressionTransformation { - public MonomorphizationValidationExpression() : - base(expr => new ExpressionKindTransformation(expr as MonomorphizationValidationExpression), - expr => new MonomorphizationValidationExpressionType(expr as MonomorphizationValidationExpression)) - { } + public ExpressionTransformation(SyntaxTreeTransformation parent) : base(parent) { } - public override ImmutableDictionary>, ResolvedType> onTypeParamResolutions(ImmutableDictionary>, ResolvedType> typeParams) + public override ImmutableDictionary>, ResolvedType> OnTypeParamResolutions(ImmutableDictionary>, ResolvedType> typeParams) { if (typeParams.Any()) { @@ -58,11 +63,11 @@ public override ImmutableDictionary>, } } - private class MonomorphizationValidationExpressionType : ExpressionTypeTransformation + private class TypeTransformation : TypeTransformation { - public MonomorphizationValidationExpressionType(MonomorphizationValidationExpression expr) : base(expr) { } + public TypeTransformation(SyntaxTreeTransformation parent) : base(parent) { } - public override QsTypeKind onTypeParameter(QsTypeParameter tp) + public override QsTypeKind OnTypeParameter(QsTypeParameter tp) { throw new Exception("Type Parameter types must be resolved"); } diff --git a/src/QsCompiler/Transformations/QsharpCodeOutput.cs b/src/QsCompiler/Transformations/QsharpCodeOutput.cs new file mode 100644 index 0000000000..28785776d8 --- /dev/null +++ b/src/QsCompiler/Transformations/QsharpCodeOutput.cs @@ -0,0 +1,1384 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using System.Collections.Immutable; +using System.Globalization; +using System.Linq; +using System.Numerics; +using System.Text.RegularExpressions; +using Microsoft.Quantum.QsCompiler.DataTypes; +using Microsoft.Quantum.QsCompiler.SyntaxTokens; +using Microsoft.Quantum.QsCompiler.SyntaxTree; +using Microsoft.Quantum.QsCompiler.TextProcessing; +using Microsoft.Quantum.QsCompiler.Transformations.BasicTransformations; +using Microsoft.Quantum.QsCompiler.Transformations.Core; + + +namespace Microsoft.Quantum.QsCompiler.Transformations.QsCodeOutput +{ + using QsTypeKind = QsTypeKind; + using QsExpressionKind = QsExpressionKind; + + + /// + /// Class used to represent contextual information for expression transformations. + /// + public class TransformationContext + { + public string CurrentNamespace; + public ImmutableHashSet> OpenedNamespaces; + public ImmutableDictionary, NonNullable> NamespaceShortNames; // mapping namespace names to their short names + public ImmutableHashSet> SymbolsInCurrentNamespace; + public ImmutableHashSet> AmbiguousNames; + + public TransformationContext() + { + this.CurrentNamespace = null; + this.OpenedNamespaces = ImmutableHashSet>.Empty; + this.NamespaceShortNames = ImmutableDictionary, NonNullable>.Empty; + this.SymbolsInCurrentNamespace = ImmutableHashSet>.Empty; + this.AmbiguousNames = ImmutableHashSet>.Empty; + } + } + + + /// + /// Class used to generate Q# code for compiled Q# namespaces. + /// Upon calling Transform, the Output property is set to the Q# code corresponding to the given namespace. + /// + public class SyntaxTreeToQsharp + : SyntaxTreeTransformation + { + public const string InvalidType = "__UnknownType__"; + public const string InvalidSet = "__UnknownSet__"; + public const string InvalidIdentifier = "__UnknownId__"; + public const string InvalidExpression = "__InvalidEx__"; + public const string InvalidSymbol = "__InvalidName__"; + public const string InvalidInitializer = "__InvalidInitializer__"; + public const string ExternalImplementation = "__external__"; + public const string InvalidFunctorGenerator = "__UnknownGenerator__"; + + public class TransformationState + { + public Action BeforeInvalidType = null; + public Action BeforeInvalidSet = null; + public Action BeforeInvalidIdentifier = null; + public Action BeforeInvalidExpression = null; + public Action BeforeInvalidSymbol = null; + public Action BeforeInvalidInitializer = null; + public Action BeforeExternalImplementation = null; + public Action BeforeInvalidFunctorGenerator = null; + + internal string TypeOutputHandle = null; + internal string ExpressionOutputHandle = null; + internal readonly List StatementOutputHandle = new List(); + internal readonly List NamespaceOutputHandle = new List(); + + internal QsComments StatementComments = QsComments.Empty; + internal TransformationContext Context; + internal IEnumerable NamespaceDocumentation = null; + + public TransformationState(TransformationContext context = null) => + this.Context = context ?? new TransformationContext(); + + internal static bool PrecededByCode(IEnumerable output) => + output == null ? false : output.Any() && !String.IsNullOrWhiteSpace(output.Last().Replace("{", "")); + + internal static bool PrecededByBlock(IEnumerable output) => + output == null ? false : output.Any() && output.Last().Trim() == "}"; + + internal void InvokeOnInvalid(Action action) + { + this.BeforeExternalImplementation = action; + this.BeforeInvalidInitializer = action; + this.BeforeInvalidSymbol = action; + this.BeforeInvalidIdentifier = action; + this.BeforeInvalidExpression = action; + this.BeforeInvalidType = action; + this.BeforeInvalidSet = action; + } + } + + + public SyntaxTreeToQsharp(TransformationContext context = null) + : base(new TransformationState(context), TransformationOptions.NoRebuild) + { + this.Types = new TypeTransformation(this); + this.ExpressionKinds = new ExpressionKindTransformation(this); + this.StatementKinds = new StatementKindTransformation(this); + this.Statements = new StatementTransformation(this); + this.Namespaces = new NamespaceTransformation(this); + } + + + // public methods for convenience + + public static SyntaxTreeToQsharp Default = + new SyntaxTreeToQsharp(); + + public string ToCode(ResolvedType t) + { + this.Types.OnType(t); + return this.SharedState.TypeOutputHandle; + } + + public string ToCode(QsExpressionKind k) + { + this.ExpressionKinds.OnExpressionKind(k); + return this.SharedState.ExpressionOutputHandle; + } + + public string ToCode(TypedExpression ex) => + this.ToCode(ex.Expression); + + public string ToCode(QsStatementKind stmKind) + { + var nrPreexistingLines = this.SharedState.StatementOutputHandle.Count; + this.StatementKinds.OnStatementKind(stmKind); + return String.Join(Environment.NewLine, this.SharedState.StatementOutputHandle.Skip(nrPreexistingLines)); + } + + public string ToCode(QsStatement stm) => + this.ToCode(stm); + + public string ToCode(QsNamespace ns) + { + var nrPreexistingLines = this.SharedState.NamespaceOutputHandle.Count; + this.Namespaces.OnNamespace(ns); + return String.Join(Environment.NewLine, this.SharedState.NamespaceOutputHandle.Skip(nrPreexistingLines)); + } + + public static string CharacteristicsExpression(ResolvedCharacteristics characteristics) => + TypeTransformation.CharacteristicsExpression(characteristics); + + public static string ArgumentTuple(QsTuple> arg, + Func typeTransformation, Action onInvalidName = null, bool symbolsOnly = false) => + NamespaceTransformation.ArgumentTuple(arg, item => (NamespaceTransformation.SymbolName(item.VariableName, onInvalidName), item.Type), typeTransformation, symbolsOnly); + + public static string DeclarationSignature(QsCallable c, Func typeTransformation, Action onInvalidName = null) + { + var argTuple = ArgumentTuple(c.ArgumentTuple, typeTransformation, onInvalidName); + return $"{c.FullName.Name.Value}{NamespaceTransformation.TypeParameters(c.Signature, onInvalidName)} {argTuple} : {typeTransformation(c.Signature.ReturnType)}"; + } + + + /// + /// For each file in the given parameter array of open directives, + /// generates a dictionary that maps (the name of) each partial namespace contained in the file + /// to a string containing the formatted Q# code for the part of the namespace. + /// Qualified or unqualified names for types and identifiers are generated based on the given namespace and open directives. + /// Throws an ArgumentNullException if the given namespace is null. + /// -> IMPORTANT: The given namespace is expected to contain *all* elements in that namespace for the *entire* compilation unit! + /// + public static bool Apply(out List, string>> generatedCode, + IEnumerable namespaces, + params (NonNullable, ImmutableDictionary, ImmutableArray<(NonNullable, string)>>)[] openDirectives) + { + if (namespaces == null) throw new ArgumentNullException(nameof(namespaces)); + + generatedCode = new List, string>>(); + var symbolsInNS = namespaces.ToImmutableDictionary(ns => ns.Name, ns => ns.Elements + .Select(element => (element is QsNamespaceElement.QsCallable c) ? c.Item.FullName.Name.Value : null) + .Where(name => name != null).Select(name => NonNullable.New(name)).ToImmutableHashSet()); + + var success = true; + foreach (var (sourceFile, imports) in openDirectives) + { + var nsInFile = new Dictionary, string>(); + foreach (var ns in namespaces) + { + var tree = FilterBySourceFile.Apply(ns, sourceFile); + if (!tree.Elements.Any()) continue; + + // determine all symbols that occur in multiple open namespaces + var ambiguousSymbols = symbolsInNS.Where(entry => imports[ns.Name].Contains((entry.Key, null))) + .SelectMany(entry => entry.Value) + .GroupBy(name => name) + .Where(group => group.Count() > 1) + .Select(group => group.Key).ToImmutableHashSet(); + + var openedNS = imports[ns.Name].Where(o => o.Item2 == null).Select(o => o.Item1).ToImmutableHashSet(); + var nsShortNames = imports[ns.Name].Where(o => o.Item2 != null).ToImmutableDictionary(o => o.Item1, o => NonNullable.New(o.Item2)); + var context = new TransformationContext + { + CurrentNamespace = ns.Name.Value, + OpenedNamespaces = openedNS, + NamespaceShortNames = nsShortNames, + SymbolsInCurrentNamespace = symbolsInNS[ns.Name], + AmbiguousNames = ambiguousSymbols + }; + + var totNrInvalid = 0; + var docComments = ns.Documentation[sourceFile]; + var generator = new SyntaxTreeToQsharp(context); + generator.SharedState.InvokeOnInvalid(() => ++totNrInvalid); + generator.SharedState.NamespaceDocumentation = docComments.Count() == 1 ? docComments.Single() : ImmutableArray.Empty; // let's drop the doc if it is ambiguous + generator.Namespaces.OnNamespace(tree); + + if (totNrInvalid > 0) success = false; + nsInFile.Add(ns.Name, String.Join(Environment.NewLine, generator.SharedState.NamespaceOutputHandle)); + } + generatedCode.Add(nsInFile.ToImmutableDictionary()); + } + return success; + } + + + // helper classes + + /// + /// Class used to generate Q# code for Q# types. + /// Adds an Output string property to ExpressionTypeTransformation, + /// that upon calling Transform on a Q# type is set to the Q# code corresponding to that type. + /// + public class TypeTransformation + : TypeTransformation + { + private readonly Func TypeToQs; + + protected string Output // the sole purpose of this is a shorter name ... + { + get => this.SharedState.TypeOutputHandle; + set => SharedState.TypeOutputHandle = value; + } + + public TypeTransformation(SyntaxTreeToQsharp parent) + : base(parent, TransformationOptions.NoRebuild) => + this.TypeToQs = parent.ToCode; + + public TypeTransformation() + : base(new TransformationState(), TransformationOptions.NoRebuild) => + this.TypeToQs = t => + { + this.Transformation.Types.OnType(t); + return this.SharedState.TypeOutputHandle; + }; + + + // internal static methods + + internal static string CharacteristicsExpression(ResolvedCharacteristics characteristics, Action onInvalidSet = null) + { + int CurrentPrecedence = 0; + string SetPrecedenceAndReturn(int prec, string str) + { + CurrentPrecedence = prec; + return str; + } + + string Recur(int prec, ResolvedCharacteristics ex) + { + var output = SetAnnotation(ex); + return prec < CurrentPrecedence || CurrentPrecedence == int.MaxValue ? output : $"({output})"; + } + + string BinaryOperator(Keywords.QsOperator op, ResolvedCharacteristics lhs, ResolvedCharacteristics rhs) => + SetPrecedenceAndReturn(op.prec, $"{Recur(op.prec, lhs)} {op.op} {Recur(op.prec, rhs)}"); + + string SetAnnotation(ResolvedCharacteristics charEx) + { + if (charEx.Expression is CharacteristicsKind.SimpleSet set) + { + string setName = null; + if (set.Item.IsAdjointable) setName = Keywords.qsAdjSet.id; + else if (set.Item.IsControllable) setName = Keywords.qsCtlSet.id; + else throw new NotImplementedException("unknown set name"); + return SetPrecedenceAndReturn(int.MaxValue, setName); + } + else if (charEx.Expression is CharacteristicsKind.Union u) + { return BinaryOperator(Keywords.qsSetUnion, u.Item1, u.Item2); } + else if (charEx.Expression is CharacteristicsKind.Intersection i) + { return BinaryOperator(Keywords.qsSetIntersection, i.Item1, i.Item2); } + else if (charEx.Expression.IsInvalidSetExpr) + { + onInvalidSet?.Invoke(); + return SetPrecedenceAndReturn(int.MaxValue, InvalidSet); + } + else throw new NotImplementedException("unknown set expression"); + } + + return characteristics.Expression.IsEmptySet ? null : SetAnnotation(characteristics); + } + + + // overrides + + public override QsTypeKind OnArrayType(ResolvedType b) + { + this.Output = $"{this.TypeToQs(b)}[]"; + return QsTypeKind.NewArrayType(b); + } + + public override QsTypeKind OnBool() + { + this.Output = Keywords.qsBool.id; + return QsTypeKind.Bool; + } + + public override QsTypeKind OnDouble() + { + this.Output = Keywords.qsDouble.id; + return QsTypeKind.Double; + } + + public override QsTypeKind OnFunction(ResolvedType it, ResolvedType ot) + { + this.Output = $"({this.TypeToQs(it)} -> {this.TypeToQs(ot)})"; + return QsTypeKind.NewFunction(it, ot); + } + + public override QsTypeKind OnInt() + { + this.Output = Keywords.qsInt.id; + return QsTypeKind.Int; + } + + public override QsTypeKind OnBigInt() + { + this.Output = Keywords.qsBigInt.id; + return QsTypeKind.BigInt; + } + + public override QsTypeKind OnInvalidType() + { + this.SharedState.BeforeInvalidType?.Invoke(); + this.Output = InvalidType; + return QsTypeKind.InvalidType; + } + + public override QsTypeKind OnMissingType() + { + this.Output = "_"; // needs to be underscore, since this is valid as type argument specifier + return QsTypeKind.MissingType; + } + + public override QsTypeKind OnPauli() + { + this.Output = Keywords.qsPauli.id; + return QsTypeKind.Pauli; + } + + public override QsTypeKind OnQubit() + { + this.Output = Keywords.qsQubit.id; + return QsTypeKind.Qubit; + } + + public override QsTypeKind OnRange() + { + this.Output = Keywords.qsRange.id; + return QsTypeKind.Range; + } + + public override QsTypeKind OnResult() + { + this.Output = Keywords.qsResult.id; + return QsTypeKind.Result; + } + + public override QsTypeKind OnString() + { + this.Output = Keywords.qsString.id; + return QsTypeKind.String; + } + + public override QsTypeKind OnTupleType(ImmutableArray ts) + { + this.Output = $"({String.Join(", ", ts.Select(this.TypeToQs))})"; + return QsTypeKind.NewTupleType(ts); + } + + public override QsTypeKind OnTypeParameter(QsTypeParameter tp) + { + this.Output = $"'{tp.TypeName.Value}"; + return QsTypeKind.NewTypeParameter(tp); + } + + public override QsTypeKind OnUnitType() + { + this.Output = Keywords.qsUnit.id; + return QsTypeKind.UnitType; + } + + public override QsTypeKind OnUserDefinedType(UserDefinedType udt) + { + var isInCurrentNamespace = udt.Namespace.Value == this.SharedState.Context.CurrentNamespace; + var isInOpenNamespace = this.SharedState.Context.OpenedNamespaces.Contains(udt.Namespace) && !this.SharedState.Context.SymbolsInCurrentNamespace.Contains(udt.Name); + var hasShortName = this.SharedState.Context.NamespaceShortNames.TryGetValue(udt.Namespace, out var shortName); + this.Output = isInCurrentNamespace || (isInOpenNamespace && !this.SharedState.Context.AmbiguousNames.Contains(udt.Name)) + ? udt.Name.Value + : $"{(hasShortName ? shortName.Value : udt.Namespace.Value)}.{udt.Name.Value}"; + return QsTypeKind.NewUserDefinedType(udt); + } + + public override ResolvedCharacteristics OnCharacteristicsExpression(ResolvedCharacteristics set) + { + this.Output = CharacteristicsExpression(set, onInvalidSet: this.SharedState.BeforeInvalidSet); + return set; + } + + public override QsTypeKind OnOperation(Tuple sign, CallableInformation info) + { + info = this.OnCallableInformation(info); + var characteristics = String.IsNullOrWhiteSpace(this.Output) ? "" : $" {Keywords.qsCharacteristics.id} {this.Output}"; + this.Output = $"({this.TypeToQs(sign.Item1)} => {this.TypeToQs(sign.Item2)}{characteristics})"; + return QsTypeKind.NewOperation(sign, info); + } + } + + + /// + /// Class used to generate Q# code for Q# expressions. + /// Upon calling Transform, the Output property is set to the Q# code corresponding to an expression of the given kind. + /// + public class ExpressionKindTransformation + : ExpressionKindTransformation + { + // allows to omit unnecessary parentheses + private int CurrentPrecedence = 0; + + // used to replace interpolated pieces in string literals + private static readonly Regex InterpolationArg = + new Regex(@"(? TypeToQs; + + protected string Output // the sole purpose of this is a shorter name ... + { + get => this.SharedState.ExpressionOutputHandle; + set => SharedState.ExpressionOutputHandle = value; + } + + public ExpressionKindTransformation(SyntaxTreeToQsharp parent) + : base(parent, TransformationOptions.NoRebuild) => + this.TypeToQs = parent.ToCode; + + + // private helper functions + + private static string ReplaceInterpolatedArgs(string text, Func replace) + { + var itemNr = 0; + string ReplaceMatch(Match m) => replace?.Invoke(itemNr++); + return InterpolationArg.Replace(text, ReplaceMatch); + } + + private string Recur(int prec, TypedExpression ex) + { + this.Transformation.Expressions.OnTypedExpression(ex); + return prec < this.CurrentPrecedence || this.CurrentPrecedence == int.MaxValue // need to cover the case where prec = currentPrec = MaxValue + ? this.Output + : $"({this.Output})"; + } + + private void UnaryOperator(Keywords.QsOperator op, TypedExpression ex) + { + this.Output = Keywords.ReservedKeywords.Contains(op.op) + ? $"{op.op} {this.Recur(op.prec, ex)}" + : $"{op.op}{this.Recur(op.prec, ex)}"; + this.CurrentPrecedence = op.prec; + } + + private void BinaryOperator(Keywords.QsOperator op, TypedExpression lhs, TypedExpression rhs) + { + this.Output = $"{this.Recur(op.prec, lhs)} {op.op} {this.Recur(op.prec, rhs)}"; + this.CurrentPrecedence = op.prec; + } + + private void TernaryOperator(Keywords.QsOperator op, TypedExpression fst, TypedExpression snd, TypedExpression trd) + { + this.Output = $"{this.Recur(op.prec, fst)} {op.op} {this.Recur(op.prec, snd)} {op.cont} {this.Recur(op.prec, trd)}"; + this.CurrentPrecedence = op.prec; + } + + private QsExpressionKind CallLike(TypedExpression method, TypedExpression arg) + { + var prec = Keywords.qsCallCombinator.prec; + var argStr = arg.Expression.IsValueTuple || arg.Expression.IsUnitValue ? this.Recur(int.MinValue, arg) : $"({this.Recur(int.MinValue, arg)})"; + this.Output = $"{this.Recur(prec, method)}{argStr}"; + this.CurrentPrecedence = prec; + return QsExpressionKind.NewCallLikeExpression(method, arg); + } + + + // overrides + + public override QsExpressionKind OnCopyAndUpdateExpression(TypedExpression lhs, TypedExpression acc, TypedExpression rhs) + { + TernaryOperator(Keywords.qsCopyAndUpdateOp, lhs, acc, rhs); + return QsExpressionKind.NewCopyAndUpdate(lhs, acc, rhs); + } + + public override QsExpressionKind OnConditionalExpression(TypedExpression cond, TypedExpression ifTrue, TypedExpression ifFalse) + { + TernaryOperator(Keywords.qsConditionalOp, cond, ifTrue, ifFalse); + return QsExpressionKind.NewCONDITIONAL(cond, ifTrue, ifFalse); + } + + public override QsExpressionKind OnAddition(TypedExpression lhs, TypedExpression rhs) + { + BinaryOperator(Keywords.qsADDop, lhs, rhs); + return QsExpressionKind.NewADD(lhs, rhs); + } + + public override QsExpressionKind OnBitwiseAnd(TypedExpression lhs, TypedExpression rhs) + { + BinaryOperator(Keywords.qsBANDop, lhs, rhs); + return QsExpressionKind.NewBAND(lhs, rhs); + } + + public override QsExpressionKind OnBitwiseExclusiveOr(TypedExpression lhs, TypedExpression rhs) + { + BinaryOperator(Keywords.qsBXORop, lhs, rhs); + return QsExpressionKind.NewBXOR(lhs, rhs); + } + + public override QsExpressionKind OnBitwiseOr(TypedExpression lhs, TypedExpression rhs) + { + BinaryOperator(Keywords.qsBORop, lhs, rhs); + return QsExpressionKind.NewBOR(lhs, rhs); + } + + public override QsExpressionKind OnDivision(TypedExpression lhs, TypedExpression rhs) + { + BinaryOperator(Keywords.qsDIVop, lhs, rhs); + return QsExpressionKind.NewDIV(lhs, rhs); + } + + public override QsExpressionKind OnEquality(TypedExpression lhs, TypedExpression rhs) + { + BinaryOperator(Keywords.qsEQop, lhs, rhs); + return QsExpressionKind.NewEQ(lhs, rhs); + } + + public override QsExpressionKind OnExponentiate(TypedExpression lhs, TypedExpression rhs) + { + BinaryOperator(Keywords.qsPOWop, lhs, rhs); + return QsExpressionKind.NewPOW(lhs, rhs); + } + + public override QsExpressionKind OnGreaterThan(TypedExpression lhs, TypedExpression rhs) + { + BinaryOperator(Keywords.qsGTop, lhs, rhs); + return QsExpressionKind.NewGT(lhs, rhs); + } + + public override QsExpressionKind OnGreaterThanOrEqual(TypedExpression lhs, TypedExpression rhs) + { + BinaryOperator(Keywords.qsGTEop, lhs, rhs); + return QsExpressionKind.NewGTE(lhs, rhs); + } + + public override QsExpressionKind OnInequality(TypedExpression lhs, TypedExpression rhs) + { + BinaryOperator(Keywords.qsNEQop, lhs, rhs); + return QsExpressionKind.NewNEQ(lhs, rhs); + } + + public override QsExpressionKind OnLeftShift(TypedExpression lhs, TypedExpression rhs) + { + BinaryOperator(Keywords.qsLSHIFTop, lhs, rhs); + return QsExpressionKind.NewLSHIFT(lhs, rhs); + } + + public override QsExpressionKind OnLessThan(TypedExpression lhs, TypedExpression rhs) + { + BinaryOperator(Keywords.qsLTop, lhs, rhs); + return QsExpressionKind.NewLT(lhs, rhs); + } + + public override QsExpressionKind OnLessThanOrEqual(TypedExpression lhs, TypedExpression rhs) + { + BinaryOperator(Keywords.qsLTEop, lhs, rhs); + return QsExpressionKind.NewLTE(lhs, rhs); + } + + public override QsExpressionKind OnLogicalAnd(TypedExpression lhs, TypedExpression rhs) + { + BinaryOperator(Keywords.qsANDop, lhs, rhs); + return QsExpressionKind.NewAND(lhs, rhs); + } + + public override QsExpressionKind OnLogicalOr(TypedExpression lhs, TypedExpression rhs) + { + BinaryOperator(Keywords.qsORop, lhs, rhs); + return QsExpressionKind.NewOR(lhs, rhs); + } + + public override QsExpressionKind OnModulo(TypedExpression lhs, TypedExpression rhs) + { + BinaryOperator(Keywords.qsMODop, lhs, rhs); + return QsExpressionKind.NewMOD(lhs, rhs); + } + + public override QsExpressionKind OnMultiplication(TypedExpression lhs, TypedExpression rhs) + { + BinaryOperator(Keywords.qsMULop, lhs, rhs); + return QsExpressionKind.NewMUL(lhs, rhs); + } + + public override QsExpressionKind OnRightShift(TypedExpression lhs, TypedExpression rhs) + { + BinaryOperator(Keywords.qsRSHIFTop, lhs, rhs); + return QsExpressionKind.NewRSHIFT(lhs, rhs); + } + + public override QsExpressionKind OnSubtraction(TypedExpression lhs, TypedExpression rhs) + { + BinaryOperator(Keywords.qsSUBop, lhs, rhs); + return QsExpressionKind.NewSUB(lhs, rhs); + } + + public override QsExpressionKind OnNegative(TypedExpression ex) + { + UnaryOperator(Keywords.qsNEGop, ex); + return QsExpressionKind.NewNEG(ex); + } + + public override QsExpressionKind OnLogicalNot(TypedExpression ex) + { + UnaryOperator(Keywords.qsNOTop, ex); + return QsExpressionKind.NewNOT(ex); + } + + public override QsExpressionKind OnBitwiseNot(TypedExpression ex) + { + UnaryOperator(Keywords.qsBNOTop, ex); + return QsExpressionKind.NewBNOT(ex); + } + + public override QsExpressionKind OnOperationCall(TypedExpression method, TypedExpression arg) + { + return this.CallLike(method, arg); + } + + public override QsExpressionKind OnFunctionCall(TypedExpression method, TypedExpression arg) + { + return this.CallLike(method, arg); + } + + public override QsExpressionKind OnPartialApplication(TypedExpression method, TypedExpression arg) + { + return this.CallLike(method, arg); + } + + public override QsExpressionKind OnAdjointApplication(TypedExpression ex) + { + var op = Keywords.qsAdjointModifier; + this.Output = $"{op.op} {this.Recur(op.prec, ex)}"; + this.CurrentPrecedence = op.prec; + return QsExpressionKind.NewAdjointApplication(ex); + } + + public override QsExpressionKind OnControlledApplication(TypedExpression ex) + { + var op = Keywords.qsControlledModifier; + this.Output = $"{op.op} {this.Recur(op.prec, ex)}"; + this.CurrentPrecedence = op.prec; + return QsExpressionKind.NewControlledApplication(ex); + } + + public override QsExpressionKind OnUnwrapApplication(TypedExpression ex) + { + var op = Keywords.qsUnwrapModifier; + this.Output = $"{this.Recur(op.prec, ex)}{op.op}"; + this.CurrentPrecedence = op.prec; + return QsExpressionKind.NewUnwrapApplication(ex); + } + + public override QsExpressionKind OnUnitValue() + { + this.Output = "()"; + this.CurrentPrecedence = int.MaxValue; + return QsExpressionKind.UnitValue; + } + + public override QsExpressionKind OnMissingExpression() + { + this.Output = "_"; + this.CurrentPrecedence = int.MaxValue; + return QsExpressionKind.MissingExpr; + } + + public override QsExpressionKind OnInvalidExpression() + { + this.SharedState.BeforeInvalidExpression?.Invoke(); + this.Output = InvalidExpression; + this.CurrentPrecedence = int.MaxValue; + return QsExpressionKind.InvalidExpr; + } + + public override QsExpressionKind OnValueTuple(ImmutableArray vs) + { + this.Output = $"({String.Join(", ", vs.Select(v => this.Recur(int.MinValue, v)))})"; + this.CurrentPrecedence = int.MaxValue; + return QsExpressionKind.NewValueTuple(vs); + } + + public override QsExpressionKind OnValueArray(ImmutableArray vs) + { + this.Output = $"[{String.Join(", ", vs.Select(v => this.Recur(int.MinValue, v)))}]"; + this.CurrentPrecedence = int.MaxValue; + return QsExpressionKind.NewValueArray(vs); + } + + public override QsExpressionKind OnNewArray(ResolvedType bt, TypedExpression idx) + { + this.Output = $"{Keywords.arrayDecl.id} {this.TypeToQs(bt)}[{this.Recur(int.MinValue, idx)}]"; + this.CurrentPrecedence = int.MaxValue; + return QsExpressionKind.NewNewArray(bt, idx); + } + + public override QsExpressionKind OnArrayItem(TypedExpression arr, TypedExpression idx) + { + var prec = Keywords.qsArrayAccessCombinator.prec; + this.Output = $"{this.Recur(prec, arr)}[{this.Recur(int.MinValue, idx)}]"; // Todo: generate contextual open range expression when appropriate + this.CurrentPrecedence = prec; + return QsExpressionKind.NewArrayItem(arr, idx); + } + + public override QsExpressionKind OnNamedItem(TypedExpression ex, Identifier acc) + { + this.OnIdentifier(acc, QsNullable>.Null); + var (op, itemName) = (Keywords.qsNamedItemCombinator, this.Output); + this.Output = $"{this.Recur(op.prec, ex)}{op.op}{itemName}"; + return QsExpressionKind.NewNamedItem(ex, acc); + } + + public override QsExpressionKind OnIntLiteral(long i) + { + this.Output = i.ToString(CultureInfo.InvariantCulture); + this.CurrentPrecedence = int.MaxValue; + return QsExpressionKind.NewIntLiteral(i); + } + + public override QsExpressionKind OnBigIntLiteral(BigInteger b) + { + this.Output = b.ToString("R", CultureInfo.InvariantCulture) + "L"; + this.CurrentPrecedence = int.MaxValue; + return QsExpressionKind.NewBigIntLiteral(b); + } + + public override QsExpressionKind OnDoubleLiteral(double d) + { + this.Output = d.ToString("R", CultureInfo.InvariantCulture); + if ((int)d == d) this.Output = $"{this.Output}.0"; + this.CurrentPrecedence = int.MaxValue; + return QsExpressionKind.NewDoubleLiteral(d); + } + + public override QsExpressionKind OnBoolLiteral(bool b) + { + if (b) this.Output = Keywords.qsTrue.id; + else this.Output = Keywords.qsFalse.id; + this.CurrentPrecedence = int.MaxValue; + return QsExpressionKind.NewBoolLiteral(b); + } + + public override QsExpressionKind OnStringLiteral(NonNullable s, ImmutableArray exs) + { + string InterpolatedArg(int index) => $"{{{this.Recur(int.MinValue, exs[index])}}}"; + this.Output = exs.Length == 0 ? $"\"{s.Value}\"" : $"$\"{ReplaceInterpolatedArgs(s.Value, InterpolatedArg)}\""; + this.CurrentPrecedence = int.MaxValue; + return QsExpressionKind.NewStringLiteral(s, exs); + } + + public override QsExpressionKind OnRangeLiteral(TypedExpression lhs, TypedExpression rhs) + { + var op = Keywords.qsRangeOp; + var lhsStr = lhs.Expression.IsRangeLiteral ? this.Recur(int.MinValue, lhs) : this.Recur(op.prec, lhs); + this.Output = $"{lhsStr} {op.op} {this.Recur(op.prec, rhs)}"; + this.CurrentPrecedence = op.prec; + return QsExpressionKind.NewRangeLiteral(lhs, rhs); + } + + public override QsExpressionKind OnResultLiteral(QsResult r) + { + if (r.IsZero) this.Output = Keywords.qsZero.id; + else if (r.IsOne) this.Output = Keywords.qsOne.id; + else throw new NotImplementedException("unknown Result literal"); + this.CurrentPrecedence = int.MaxValue; + return QsExpressionKind.NewResultLiteral(r); + } + + public override QsExpressionKind OnPauliLiteral(QsPauli p) + { + if (p.IsPauliI) this.Output = Keywords.qsPauliI.id; + else if (p.IsPauliX) this.Output = Keywords.qsPauliX.id; + else if (p.IsPauliY) this.Output = Keywords.qsPauliY.id; + else if (p.IsPauliZ) this.Output = Keywords.qsPauliZ.id; + else throw new NotImplementedException("unknown Pauli literal"); + this.CurrentPrecedence = int.MaxValue; + return QsExpressionKind.NewPauliLiteral(p); + } + + public override QsExpressionKind OnIdentifier(Identifier sym, QsNullable> tArgs) + { + if (sym is Identifier.LocalVariable loc) + { this.Output = loc.Item.Value; } + else if (sym.IsInvalidIdentifier) + { + this.SharedState.BeforeInvalidIdentifier?.Invoke(); + this.Output = InvalidIdentifier; + } + else if (sym is Identifier.GlobalCallable global) + { + var isInCurrentNamespace = global.Item.Namespace.Value == this.SharedState.Context.CurrentNamespace; + var isInOpenNamespace = this.SharedState.Context.OpenedNamespaces.Contains(global.Item.Namespace) && !this.SharedState.Context.SymbolsInCurrentNamespace.Contains(global.Item.Name); + var hasShortName = this.SharedState.Context.NamespaceShortNames.TryGetValue(global.Item.Namespace, out var shortName); + this.Output = isInCurrentNamespace || (isInOpenNamespace && !this.SharedState.Context.AmbiguousNames.Contains(global.Item.Name)) + ? global.Item.Name.Value + : $"{(hasShortName ? shortName.Value : global.Item.Namespace.Value)}.{global.Item.Name.Value}"; + } + else throw new NotImplementedException("unknown identifier kind"); + + if (tArgs.IsValue) + { + this.Output = $"{this.Output}<{ String.Join(", ", tArgs.Item.Select(this.TypeToQs))}>"; + } + + this.CurrentPrecedence = int.MaxValue; + return QsExpressionKind.NewIdentifier(sym, tArgs); + } + } + + + /// + /// Class used to generate Q# code for Q# statements. + /// Upon calling Transform, the _Output property of the scope transformation given on initialization + /// is set to the Q# code corresponding to a statement of the given kind. + /// + public class StatementKindTransformation + : StatementKindTransformation + { + private int CurrentIndendation = 0; + + private readonly Func ExpressionToQs; + + private bool PrecededByCode => + TransformationState.PrecededByCode(this.SharedState.StatementOutputHandle); + + private bool PrecededByBlock => + TransformationState.PrecededByBlock(this.SharedState.StatementOutputHandle); + + public StatementKindTransformation(SyntaxTreeToQsharp parent) + : base(parent, TransformationOptions.NoRebuild) => + this.ExpressionToQs = parent.ToCode; + + + // private helper functions + + private void AddToOutput(string line) + { + for (var i = 0; i < this.CurrentIndendation; ++i) line = $" {line}"; + this.SharedState.StatementOutputHandle.Add(line); + } + + private void AddComments(IEnumerable comments) + { + foreach (var comment in comments) + { this.AddToOutput(String.IsNullOrWhiteSpace(comment) ? "" : $"//{comment}"); } + } + + private void AddStatement(string stm) + { + var comments = this.SharedState.StatementComments; + if (this.PrecededByBlock || (this.PrecededByCode && comments.OpeningComments.Length != 0)) this.AddToOutput(""); + this.AddComments(comments.OpeningComments); + this.AddToOutput($"{stm};"); + this.AddComments(comments.ClosingComments); + if (comments.ClosingComments.Length != 0) this.AddToOutput(""); + } + + private void AddBlockStatement(string intro, QsScope statements, bool withWhiteSpace = true) + { + var comments = this.SharedState.StatementComments; + if (this.PrecededByCode && withWhiteSpace) this.AddToOutput(""); + this.AddComments(comments.OpeningComments); + this.AddToOutput($"{intro} {"{"}"); + ++this.CurrentIndendation; + this.Transformation.Statements.OnScope(statements); + this.AddComments(comments.ClosingComments); + --this.CurrentIndendation; + this.AddToOutput("}"); + } + + private string SymbolTuple(SymbolTuple sym) + { + if (sym.IsDiscardedItem) return "_"; + else if (sym is SymbolTuple.VariableName name) return name.Item.Value; + else if (sym is SymbolTuple.VariableNameTuple tuple) return $"({String.Join(", ", tuple.Item.Select(SymbolTuple))})"; + else if (sym.IsInvalidItem) + { + this.SharedState.BeforeInvalidSymbol?.Invoke(); + return InvalidSymbol; + } + else throw new NotImplementedException("unknown item in symbol tuple"); + } + + private string InitializerTuple(ResolvedInitializer init) + { + if (init.Resolution.IsSingleQubitAllocation) return $"{Keywords.qsQubit.id}()"; + else if (init.Resolution is QsInitializerKind.QubitRegisterAllocation reg) + { return $"{Keywords.qsQubit.id}[{this.ExpressionToQs(reg.Item)}]"; } + else if (init.Resolution is QsInitializerKind.QubitTupleAllocation tuple) + { return $"({String.Join(", ", tuple.Item.Select(InitializerTuple))})"; } + else if (init.Resolution.IsInvalidInitializer) + { + this.SharedState.BeforeInvalidInitializer?.Invoke(); + return InvalidInitializer; + } + else throw new NotImplementedException("unknown qubit initializer"); + } + + + // overrides + + public override QsStatementKind OnQubitScope(QsQubitScope stm) + { + var symbols = this.SymbolTuple(stm.Binding.Lhs); + var initializers = this.InitializerTuple(stm.Binding.Rhs); + string header; + if (stm.Kind.IsBorrow) header = Keywords.qsBorrowing.id; + else if (stm.Kind.IsAllocate) header = Keywords.qsUsing.id; + else throw new NotImplementedException("unknown qubit scope"); + + var intro = $"{header} ({symbols} = {initializers})"; + this.AddBlockStatement(intro, stm.Body); + return QsStatementKind.NewQsQubitScope(stm); + } + + public override QsStatementKind OnForStatement(QsForStatement stm) + { + var symbols = this.SymbolTuple(stm.LoopItem.Item1); + var intro = $"{Keywords.qsFor.id} ({symbols} {Keywords.qsRangeIter.id} {this.ExpressionToQs(stm.IterationValues)})"; + this.AddBlockStatement(intro, stm.Body); + return QsStatementKind.NewQsForStatement(stm); + } + + public override QsStatementKind OnWhileStatement(QsWhileStatement stm) + { + var intro = $"{Keywords.qsWhile.id} ({this.ExpressionToQs(stm.Condition)})"; + this.AddBlockStatement(intro, stm.Body); + return QsStatementKind.NewQsWhileStatement(stm); + } + + public override QsStatementKind OnRepeatStatement(QsRepeatStatement stm) + { + this.SharedState.StatementComments = stm.RepeatBlock.Comments; + this.AddBlockStatement(Keywords.qsRepeat.id, stm.RepeatBlock.Body); + this.SharedState.StatementComments = stm.FixupBlock.Comments; + this.AddToOutput($"{Keywords.qsUntil.id} ({this.ExpressionToQs(stm.SuccessCondition)})"); + this.AddBlockStatement(Keywords.qsRUSfixup.id, stm.FixupBlock.Body, false); + return QsStatementKind.NewQsRepeatStatement(stm); + } + + public override QsStatementKind OnConditionalStatement(QsConditionalStatement stm) + { + var header = Keywords.qsIf.id; + if (this.PrecededByCode) this.AddToOutput(""); + foreach (var clause in stm.ConditionalBlocks) + { + this.SharedState.StatementComments = clause.Item2.Comments; + var intro = $"{header} ({this.ExpressionToQs(clause.Item1)})"; + this.AddBlockStatement(intro, clause.Item2.Body, false); + header = Keywords.qsElif.id; + } + if (stm.Default.IsValue) + { + this.SharedState.StatementComments = stm.Default.Item.Comments; + this.AddBlockStatement(Keywords.qsElse.id, stm.Default.Item.Body, false); + } + return QsStatementKind.NewQsConditionalStatement(stm); + } + + public override QsStatementKind OnConjugation(QsConjugation stm) + { + this.SharedState.StatementComments = stm.OuterTransformation.Comments; + this.AddBlockStatement(Keywords.qsWithin.id, stm.OuterTransformation.Body, true); + this.SharedState.StatementComments = stm.InnerTransformation.Comments; + this.AddBlockStatement(Keywords.qsApply.id, stm.InnerTransformation.Body, false); + return QsStatementKind.NewQsConjugation(stm); + } + + + public override QsStatementKind OnExpressionStatement(TypedExpression ex) + { + this.AddStatement(this.ExpressionToQs(ex)); + return QsStatementKind.NewQsExpressionStatement(ex); + } + + public override QsStatementKind OnFailStatement(TypedExpression ex) + { + this.AddStatement($"{Keywords.qsFail.id} {this.ExpressionToQs(ex)}"); + return QsStatementKind.NewQsFailStatement(ex); + } + + public override QsStatementKind OnReturnStatement(TypedExpression ex) + { + this.AddStatement($"{Keywords.qsReturn.id} {this.ExpressionToQs(ex)}"); + return QsStatementKind.NewQsReturnStatement(ex); + } + + public override QsStatementKind OnVariableDeclaration(QsBinding stm) + { + string header; + if (stm.Kind.IsImmutableBinding) header = Keywords.qsImmutableBinding.id; + else if (stm.Kind.IsMutableBinding) header = Keywords.qsMutableBinding.id; + else throw new NotImplementedException("unknown binding kind"); + + this.AddStatement($"{header} {this.SymbolTuple(stm.Lhs)} = {this.ExpressionToQs(stm.Rhs)}"); + return QsStatementKind.NewQsVariableDeclaration(stm); + } + + public override QsStatementKind OnValueUpdate(QsValueUpdate stm) + { + this.AddStatement($"{Keywords.qsValueUpdate.id} {this.ExpressionToQs(stm.Lhs)} = {this.ExpressionToQs(stm.Rhs)}"); + return QsStatementKind.NewQsValueUpdate(stm); + } + } + + + /// + /// Class used to generate Q# code for Q# statements. + /// Upon calling Transform, the Output property is set to the Q# code corresponding to the given statement block. + /// + public class StatementTransformation + : StatementTransformation + { + public StatementTransformation(SyntaxTreeTransformation parent) + : base(parent, TransformationOptions.NoRebuild) { } + + + // overrides + + public override QsStatement OnStatement(QsStatement stm) + { + this.SharedState.StatementComments = stm.Comments; + return base.OnStatement(stm); + } + } + + + public class NamespaceTransformation + : NamespaceTransformation + { + private int CurrentIndendation = 0; + private string CurrentSpecialization = null; + private int NrSpecialzations = 0; + + private QsComments DeclarationComments = QsComments.Empty; + + private readonly Func TypeToQs; + + private List Output => // the sole purpose of this is a shorter name ... + this.SharedState.NamespaceOutputHandle; + + public NamespaceTransformation(SyntaxTreeToQsharp parent) + : base(parent, TransformationOptions.NoRebuild) => + this.TypeToQs = parent.ToCode; + + + // private helper functions + + private void AddToOutput(string line) + { + for (var i = 0; i < this.CurrentIndendation; ++i) line = $" {line}"; + this.Output.Add(line); + } + + private void AddComments(IEnumerable comments) + { + foreach (var comment in comments) + { this.AddToOutput(String.IsNullOrWhiteSpace(comment) ? "" : $"//{comment}"); } + } + + private void AddDirective(string str) + { + this.AddToOutput($"{str};"); + } + + private void AddDocumentation(IEnumerable doc) + { + if (doc == null) return; + foreach (var line in doc) + { this.AddToOutput($"///{line}"); } + } + + private void AddBlock(Action processBlock) + { + var comments = this.DeclarationComments; + var opening = "{"; + if (!this.Output.Any()) this.AddToOutput(opening); + else this.Output[this.Output.Count - 1] += $" {opening}"; + ++this.CurrentIndendation; + processBlock(); + this.AddComments(comments.ClosingComments); + --this.CurrentIndendation; + this.AddToOutput("}"); + } + + private void ProcessNamespaceElements(IEnumerable elements) + { + var types = elements.Where(e => e.IsQsCustomType); + var callables = elements.Where(e => e.IsQsCallable); + + foreach (var t in types) + { this.OnNamespaceElement(t); } + if (types.Any()) this.AddToOutput(""); + + foreach (var c in callables) + { this.OnNamespaceElement(c); } + } + + + // internal static methods + + internal static string SymbolName(QsLocalSymbol sym, Action onInvalidName) + { + if (sym is QsLocalSymbol.ValidName n) return n.Item.Value; + else if (sym.IsInvalidName) + { + onInvalidName?.Invoke(); + return InvalidSymbol; + } + else throw new NotImplementedException("unknown case for local symbol"); + } + + internal static string TypeParameters(ResolvedSignature sign, Action onInvalidName) + { + if (sign.TypeParameters.IsEmpty) return String.Empty; + return $"<{String.Join(", ", sign.TypeParameters.Select(tp => $"'{SymbolName(tp, onInvalidName)}"))}>"; + } + + internal static string ArgumentTuple(QsTuple arg, + Func getItemNameAndType, Func typeTransformation, bool symbolsOnly = false) + { + if (arg is QsTuple.QsTuple t) + { return $"({String.Join(", ", t.Item.Select(a => ArgumentTuple(a, getItemNameAndType, typeTransformation, symbolsOnly)))})"; } + else if (arg is QsTuple.QsTupleItem i) + { + var (itemName, itemType) = getItemNameAndType(i.Item); + return itemName == null + ? $"{(symbolsOnly ? "_" : $"{typeTransformation(itemType)}")}" + : $"{itemName}{(symbolsOnly ? "" : $" : {typeTransformation(itemType)}")}"; + } + else throw new NotImplementedException("unknown case for argument tuple item"); + } + + + // overrides + + public override Tuple>, QsScope> OnProvidedImplementation + (QsTuple> argTuple, QsScope body) + { + var functorArg = "(...)"; + if (this.CurrentSpecialization == Keywords.ctrlDeclHeader.id || this.CurrentSpecialization == Keywords.ctrlAdjDeclHeader.id) + { + var ctlQubitsName = SyntaxGenerator.ControlledFunctorArgument(argTuple); + if (ctlQubitsName != null) functorArg = $"({ctlQubitsName}, ...)"; + } + else if (this.CurrentSpecialization != Keywords.bodyDeclHeader.id && this.CurrentSpecialization != Keywords.adjDeclHeader.id) + { throw new NotImplementedException("the current specialization could not be determined"); } + + void ProcessContent() + { + this.SharedState.StatementOutputHandle.Clear(); + this.Transformation.Statements.OnScope(body); + foreach (var line in this.SharedState.StatementOutputHandle) + { this.AddToOutput(line); } + } + if (this.NrSpecialzations != 1) // todo: needs to be adapted once we support type specializations + { + this.AddToOutput($"{this.CurrentSpecialization} {functorArg}"); + this.AddBlock(ProcessContent); + } + else + { + var comments = this.DeclarationComments; + ProcessContent(); + this.AddComments(comments.ClosingComments); + } + return new Tuple>, QsScope>(argTuple, body); + } + + public override void OnInvalidGeneratorDirective() + { + this.SharedState.BeforeInvalidFunctorGenerator?.Invoke(); + this.AddDirective($"{this.CurrentSpecialization} {InvalidFunctorGenerator}"); + } + + public override void OnDistributeDirective() => + this.AddDirective($"{this.CurrentSpecialization} {Keywords.distributeFunctorGenDirective.id}"); + + public override void OnInvertDirective() => + this.AddDirective($"{this.CurrentSpecialization} {Keywords.invertFunctorGenDirective.id}"); + + public override void OnSelfInverseDirective() => + this.AddDirective($"{this.CurrentSpecialization} {Keywords.selfFunctorGenDirective.id}"); + + public override void OnIntrinsicImplementation() => + this.AddDirective($"{this.CurrentSpecialization} {Keywords.intrinsicFunctorGenDirective.id}"); + + public override void OnExternalImplementation() + { + this.SharedState.BeforeExternalImplementation?.Invoke(); + this.AddDirective($"{this.CurrentSpecialization} {ExternalImplementation}"); + } + + public override QsSpecialization OnBodySpecialization(QsSpecialization spec) + { + this.CurrentSpecialization = Keywords.bodyDeclHeader.id; + return base.OnBodySpecialization(spec); + } + + public override QsSpecialization OnAdjointSpecialization(QsSpecialization spec) + { + this.CurrentSpecialization = Keywords.adjDeclHeader.id; + return base.OnAdjointSpecialization(spec); + } + + public override QsSpecialization OnControlledSpecialization(QsSpecialization spec) + { + this.CurrentSpecialization = Keywords.ctrlDeclHeader.id; + return base.OnControlledSpecialization(spec); + } + + public override QsSpecialization OnControlledAdjointSpecialization(QsSpecialization spec) + { + this.CurrentSpecialization = Keywords.ctrlAdjDeclHeader.id; + return base.OnControlledAdjointSpecialization(spec); + } + + public override QsSpecialization OnSpecializationDeclaration(QsSpecialization spec) + { + var precededByCode = TransformationState.PrecededByCode(this.Output); + var precededByBlock = TransformationState.PrecededByBlock(this.Output); + if (precededByCode && (precededByBlock || spec.Implementation.IsProvided || spec.Documentation.Any())) this.AddToOutput(""); + this.DeclarationComments = spec.Comments; + this.AddComments(spec.Comments.OpeningComments); + if (spec.Comments.OpeningComments.Any() && spec.Documentation.Any()) this.AddToOutput(""); + this.AddDocumentation(spec.Documentation); + return base.OnSpecializationDeclaration(spec); + } + + public override QsCallable OnCallableDeclaration(QsCallable c) + { + if (c.Kind.IsTypeConstructor) return c; // no code for these + + this.AddToOutput(""); + this.DeclarationComments = c.Comments; + this.AddComments(c.Comments.OpeningComments); + if (c.Comments.OpeningComments.Any() && c.Documentation.Any()) this.AddToOutput(""); + this.AddDocumentation(c.Documentation); + foreach (var attribute in c.Attributes) + { this.OnAttribute(attribute); } + + var signature = DeclarationSignature(c, this.TypeToQs, this.SharedState.BeforeInvalidSymbol); + this.Transformation.Types.OnCharacteristicsExpression(c.Signature.Information.Characteristics); + var characteristics = this.SharedState.TypeOutputHandle; + + var userDefinedSpecs = c.Specializations.Where(spec => spec.Implementation.IsProvided); + var specBundles = SpecializationBundleProperties.Bundle(spec => spec.TypeArguments, spec => spec.Kind, userDefinedSpecs); + bool NeedsToBeExplicit(QsSpecialization s) + { + if (s.Kind.IsQsBody) return true; + else if (s.Implementation is SpecializationImplementation.Generated gen) + { + if (gen.Item.IsSelfInverse) return s.Kind.IsQsAdjoint; + if (s.Kind.IsQsControlled || s.Kind.IsQsAdjoint) return false; + + var relevantUserDefinedSpecs = specBundles.TryGetValue(SpecializationBundleProperties.BundleId(s.TypeArguments), out var dict) + ? dict // there may be no user defined implementations for a certain set of type arguments, in which case there is no such entry in the dictionary + : ImmutableDictionary.Empty; + var userDefAdj = relevantUserDefinedSpecs.ContainsKey(QsSpecializationKind.QsAdjoint); + var userDefCtl = relevantUserDefinedSpecs.ContainsKey(QsSpecializationKind.QsControlled); + if (gen.Item.IsInvert) return userDefAdj || !userDefCtl; + if (gen.Item.IsDistribute) return userDefCtl && !userDefAdj; + return false; + } + else return !s.Implementation.IsIntrinsic; + } + c = c.WithSpecializations(specs => specs.Where(NeedsToBeExplicit).ToImmutableArray()); + this.NrSpecialzations = c.Specializations.Length; + + var declHeader = + c.Kind.IsOperation ? Keywords.opDeclHeader.id : + c.Kind.IsFunction ? Keywords.fctDeclHeader.id : + throw new NotImplementedException("unknown callable kind"); + + this.AddToOutput($"{declHeader} {signature}"); + if (!String.IsNullOrWhiteSpace(characteristics)) this.AddToOutput($"{Keywords.qsCharacteristics.id} {characteristics}"); + this.AddBlock(() => c.Specializations.Select(this.OnSpecializationDeclaration).ToImmutableArray()); + this.AddToOutput(""); + return c; + } + + public override QsCustomType OnTypeDeclaration(QsCustomType t) + { + this.AddToOutput(""); + this.DeclarationComments = t.Comments; // no need to deal with closing comments (can't exist), but need to make sure DeclarationComments is up to date + this.AddComments(t.Comments.OpeningComments); + if (t.Comments.OpeningComments.Any() && t.Documentation.Any()) this.AddToOutput(""); + this.AddDocumentation(t.Documentation); + foreach (var attribute in t.Attributes) + { this.OnAttribute(attribute); } + + (string, ResolvedType) GetItemNameAndType(QsTypeItem item) + { + if (item is QsTypeItem.Named named) return (named.Item.VariableName.Value, named.Item.Type); + else if (item is QsTypeItem.Anonymous type) return (null, type.Item); + else throw new NotImplementedException("unknown case for type item"); + } + var udtTuple = ArgumentTuple(t.TypeItems, GetItemNameAndType, this.TypeToQs); + this.AddDirective($"{Keywords.typeDeclHeader.id} {t.FullName.Name.Value} = {udtTuple}"); + return t; + } + + public override QsDeclarationAttribute OnAttribute(QsDeclarationAttribute att) + { + // do *not* set DeclarationComments! + this.Transformation.Expressions.OnTypedExpression(att.Argument); + var arg = this.SharedState.ExpressionOutputHandle; + var argStr = att.Argument.Expression.IsValueTuple || att.Argument.Expression.IsUnitValue ? arg : $"({arg})"; + var id = att.TypeId.IsValue + ? Identifier.NewGlobalCallable(new QsQualifiedName(att.TypeId.Item.Namespace, att.TypeId.Item.Name)) + : Identifier.InvalidIdentifier; + this.Transformation.ExpressionKinds.OnIdentifier(id, QsNullable>.Null); + this.AddComments(att.Comments.OpeningComments); + this.AddToOutput($"@ {this.SharedState.ExpressionOutputHandle}{argStr}"); + return att; + } + + public override QsNamespace OnNamespace(QsNamespace ns) + { + if (this.SharedState.Context.CurrentNamespace != ns.Name.Value) + { + this.SharedState.Context = + new TransformationContext { CurrentNamespace = ns.Name.Value }; + this.SharedState.NamespaceDocumentation = null; + } + + this.AddDocumentation(this.SharedState.NamespaceDocumentation); + this.AddToOutput($"{Keywords.namespaceDeclHeader.id} {ns.Name.Value}"); + this.AddBlock(() => + { + var context = this.SharedState.Context; + var explicitImports = context.OpenedNamespaces.Where(opened => !BuiltIn.NamespacesToAutoOpen.Contains(opened)); + if (explicitImports.Any() || context.NamespaceShortNames.Any()) this.AddToOutput(""); + foreach (var nsName in explicitImports.OrderBy(name => name)) + { this.AddDirective($"{Keywords.importDirectiveHeader.id} {nsName.Value}"); } + foreach (var kv in context.NamespaceShortNames.OrderBy(pair => pair.Key)) + { this.AddDirective($"{Keywords.importDirectiveHeader.id} {kv.Key.Value} {Keywords.importedAs.id} {kv.Value.Value}"); } + if (explicitImports.Any() || context.NamespaceShortNames.Any()) this.AddToOutput(""); + this.ProcessNamespaceElements(ns.Elements); + }); + + return ns; + } + } + } +} + diff --git a/src/QsCompiler/Transformations/SearchAndReplace.cs b/src/QsCompiler/Transformations/SearchAndReplace.cs index bbfb22670f..e12ed5dd81 100644 --- a/src/QsCompiler/Transformations/SearchAndReplace.cs +++ b/src/QsCompiler/Transformations/SearchAndReplace.cs @@ -9,6 +9,8 @@ using Microsoft.Quantum.QsCompiler.DataTypes; using Microsoft.Quantum.QsCompiler.SyntaxTokens; using Microsoft.Quantum.QsCompiler.SyntaxTree; +using Microsoft.Quantum.QsCompiler.Transformations.Core; +using Microsoft.Quantum.QsCompiler.Transformations.QsCodeOutput; namespace Microsoft.Quantum.QsCompiler.Transformations.SearchAndReplace @@ -22,10 +24,10 @@ namespace Microsoft.Quantum.QsCompiler.Transformations.SearchAndReplace /// /// Class that allows to walk the syntax tree and find all locations where a certain identifier occurs. - /// If a set of source file names is given on initialization, the search is limited to callables and specializations in those files. + /// If a set of source file names is given on initialization, the search is limited to callables and specializations in those files. /// - public class IdentifierReferences : - SyntaxTreeTransformation + public class IdentifierReferences + : SyntaxTreeTransformation { public class Location : IEquatable { @@ -43,7 +45,7 @@ public class Location : IEquatable /// public readonly Tuple SymbolRange; - public Location(NonNullable source, Tuple declOffset, QsLocation stmLoc, Tuple range) + public Location(NonNullable source, Tuple declOffset, QsLocation stmLoc, Tuple range) { this.SourceFile = source; this.DeclarationOffset = declOffset ?? throw new ArgumentNullException(nameof(declOffset)); @@ -72,192 +74,204 @@ public override int GetHashCode() } } - public QsQualifiedName IdentifierName; - public Tuple, QsLocation> DeclarationLocation { get; private set; } - public IEnumerable Locations => this._Scope.Locations; - private readonly IImmutableSet> RelevantSourseFiles; - private bool IsRelevant(NonNullable source) => - this.RelevantSourseFiles?.Contains(source) ?? true; - - public IdentifierReferences(QsQualifiedName idName, QsLocation defaultOffset, IImmutableSet> limitToSourceFiles = null) : - base(new IdentifierLocation(idName, defaultOffset)) + /// + /// Class used to track the internal state for a transformation that finds all locations where a certain identifier occurs. + /// If no source file is specified prior to transformation, its name is set to the empty string. + /// The DeclarationOffset needs to be set prior to transformation, and in particular after defining a source file. + /// If no defaultOffset is specified upon initialization then only the locations of occurrences within statements are logged. + /// + public class TransformationState { - this.IdentifierName = idName ?? throw new ArgumentNullException(nameof(idName)); - this.RelevantSourseFiles = limitToSourceFiles; - } + public Tuple, QsLocation> DeclarationLocation { get; internal set; } + public ImmutableList Locations { get; private set; } - public override QsCustomType onType(QsCustomType t) - { - if (!this.IsRelevant(t.SourceFile) || t.Location.IsNull) return t; - if (t.FullName.Equals(this.IdentifierName)) - { this.DeclarationLocation = new Tuple, QsLocation>(t.SourceFile, t.Location.Item); } - return base.onType(t); - } + /// + /// Whenever DeclarationOffset is set, the current statement offset is set to this default value. + /// + private readonly QsLocation DefaultOffset = null; + private readonly IImmutableSet> RelevantSourseFiles = null; - public override QsCallable onCallableImplementation(QsCallable c) - { - if (!this.IsRelevant(c.SourceFile) || c.Location.IsNull) return c; - if (c.FullName.Equals(this.IdentifierName)) - { this.DeclarationLocation = new Tuple, QsLocation>(c.SourceFile, c.Location.Item); } - return base.onCallableImplementation(c); - } + internal bool IsRelevant(NonNullable source) => + this.RelevantSourseFiles?.Contains(source) ?? true; - public override QsDeclarationAttribute onAttribute(QsDeclarationAttribute att) - { - var declRoot = this._Scope.DeclarationOffset; - this._Scope.DeclarationOffset = att.Offset; - if (att.TypeId.IsValue) this._Scope._Expression._Type.onUserDefinedType(att.TypeId.Item); - this._Scope._Expression.Transform(att.Argument); - this._Scope.DeclarationOffset = declRoot; - return att; - } - public override QsSpecialization onSpecializationImplementation(QsSpecialization spec) => - this.IsRelevant(spec.SourceFile) ? base.onSpecializationImplementation(spec) : spec; + internal TransformationState(Func trackId, + QsLocation defaultOffset = null, IImmutableSet> limitToSourceFiles = null) + { + this.TrackIdentifier = trackId ?? throw new ArgumentNullException(nameof(trackId)); + this.RelevantSourseFiles = limitToSourceFiles; + this.Locations = ImmutableList.Empty; + this.DefaultOffset = defaultOffset; + } - public override QsNullable onLocation(QsNullable l) - { - this._Scope.DeclarationOffset = l.IsValue? l.Item.Offset : null; - return l; + private NonNullable CurrentSourceFile = NonNullable.New(""); + private Tuple RootOffset = null; + internal QsLocation CurrentLocation = null; + internal readonly Func TrackIdentifier; + + public Tuple DeclarationOffset + { + internal get => this.RootOffset; + set + { + this.RootOffset = value ?? throw new ArgumentNullException(nameof(value), "declaration offset cannot be null"); + this.CurrentLocation = this.DefaultOffset; + } + } + + public NonNullable Source + { + internal get => this.CurrentSourceFile; + set + { + this.CurrentSourceFile = value; + this.RootOffset = null; + this.CurrentLocation = null; + } + } + + internal void LogIdentifierLocation(Identifier id, QsRangeInfo range) + { + if (this.TrackIdentifier(id) && this.CurrentLocation?.Offset != null && range.IsValue) + { + var idLoc = new Location(this.Source, this.RootOffset, this.CurrentLocation, range.Item); + this.Locations = this.Locations.Add(idLoc); + } + } + + internal void LogIdentifierLocation(TypedExpression ex) + { + if (ex.Expression is QsExpressionKind.Identifier id) + { this.LogIdentifierLocation(id.Item1, ex.Range); } + } } - public override NonNullable onSourceFile(NonNullable f) + public IdentifierReferences(TransformationState state) + : base(state, TransformationOptions.NoRebuild) + { + this.Types = new TypeTransformation(this); + this.Expressions = new TypedExpressionWalker(this.SharedState.LogIdentifierLocation, this); + this.Statements = new StatementTransformation(this); + this.Namespaces = new NamespaceTransformation(this); + } + + public IdentifierReferences(NonNullable idName, QsLocation defaultOffset, IImmutableSet> limitToSourceFiles = null) + : this(new TransformationState(id => id is Identifier.LocalVariable varName && varName.Item.Value == idName.Value, defaultOffset, limitToSourceFiles)) { } + + public IdentifierReferences(QsQualifiedName idName, QsLocation defaultOffset, IImmutableSet> limitToSourceFiles = null) + : this(new TransformationState(id => id is Identifier.GlobalCallable cName && cName.Item.Equals(idName), defaultOffset, limitToSourceFiles)) { - this._Scope.Source = f; - return base.onSourceFile(f); + if (idName == null) throw new ArgumentNullException(nameof(idName)); } // static methods for convenience + public static IEnumerable Find(NonNullable idName, QsScope scope, + NonNullable sourceFile, Tuple rootLoc) + { + var finder = new IdentifierReferences(idName, null, ImmutableHashSet.Create(sourceFile)); + finder.SharedState.Source = sourceFile; + finder.SharedState.DeclarationOffset = rootLoc; // will throw if null + finder.Statements.OnScope(scope ?? throw new ArgumentNullException(nameof(scope))); + return finder.SharedState.Locations; + } + public static IEnumerable Find(QsQualifiedName idName, QsNamespace ns, QsLocation defaultOffset, out Tuple, QsLocation> declarationLocation, IImmutableSet> limitToSourceFiles = null) { var finder = new IdentifierReferences(idName, defaultOffset, limitToSourceFiles); - finder.Transform(ns ?? throw new ArgumentNullException(nameof(ns))); - declarationLocation = finder.DeclarationLocation; - return finder.Locations; + finder.Namespaces.OnNamespace(ns ?? throw new ArgumentNullException(nameof(ns))); + declarationLocation = finder.SharedState.DeclarationLocation; + return finder.SharedState.Locations; } - } - /// - /// Class that allows to walk a scope and find all locations where a certain identifier occurs within an expression. - /// If no source file is specified prior to transformation, its name is set to the empty string. - /// The DeclarationOffset needs to be set prior to transformation, and in particular after defining a source file. - /// If no DefaultOffset is set on initialization then only the locations of occurrences within statements are logged. - /// - public class IdentifierLocation : - ScopeTransformation> - { - public class TypeLocation : - Core.ExpressionTypeTransformation - { - private readonly QsCodeOutput.ExpressionTypeToQs CodeOutput = new QsCodeOutput.ExpressionToQs()._Type; - internal Action OnIdentifier; - public TypeLocation(Action onIdentifier = null) : - base(true) => - this.OnIdentifier = onIdentifier; + // helper classes - public override QsTypeKind onUserDefinedType(UserDefinedType udt) + private class TypeTransformation + : TypeTransformation + { + public TypeTransformation(SyntaxTreeTransformation parent) + : base(parent, TransformationOptions.NoRebuild) { } + + public override QsTypeKind OnUserDefinedType(UserDefinedType udt) { - this.OnIdentifier?.Invoke(Identifier.NewGlobalCallable(new QsQualifiedName(udt.Namespace, udt.Name)), udt.Range); + var id = Identifier.NewGlobalCallable(new QsQualifiedName(udt.Namespace, udt.Name)); + this.SharedState.LogIdentifierLocation(id, udt.Range); return QsTypeKind.NewUserDefinedType(udt); } - public override QsTypeKind onTypeParameter(QsTypeParameter tp) + public override QsTypeKind OnTypeParameter(QsTypeParameter tp) { - this.CodeOutput.onTypeParameter(tp); - var tpName = NonNullable.New(this.CodeOutput.Output ?? ""); - this.OnIdentifier?.Invoke(Identifier.NewLocalVariable(tpName), tp.Range); - return QsTypeKind.NewTypeParameter(tp); + var resT = ResolvedType.New(QsTypeKind.NewTypeParameter(tp)); + var id = Identifier.NewLocalVariable(NonNullable.New(SyntaxTreeToQsharp.Default.ToCode(resT) ?? "")); + this.SharedState.LogIdentifierLocation(id, tp.Range); + return resT.Resolution; } } - - private IdentifierLocation(Func trackId, QsLocation defaultOffset) : - base(null, new OnTypedExpression(null, _ => new TypeLocation(), recur: true)) + private class StatementTransformation + : StatementTransformation { - this.TrackIdentifier = trackId ?? throw new ArgumentNullException(nameof(trackId)); - this.Locations = ImmutableList.Empty; - this.DefaultOffset = defaultOffset; - this._Expression.OnExpression = this.OnExpression; - this._Expression._Type.OnIdentifier = this.LogIdentifierLocation; - } - - public IdentifierLocation(NonNullable idName, QsLocation defaultOffset = null) : - this(id => id is Identifier.LocalVariable varName && varName.Item.Value == idName.Value, defaultOffset) - { } + public StatementTransformation(SyntaxTreeTransformation parent) + : base(parent, TransformationOptions.NoRebuild) { } - public IdentifierLocation(QsQualifiedName idName, QsLocation defaultOffset = null) : - this(id => id is Identifier.GlobalCallable cName && cName.Item.Equals(idName), defaultOffset) - { } - - private NonNullable SourceFile; - private Tuple RootOffset; - private QsLocation CurrentLocation; - private readonly Func TrackIdentifier; - - public Tuple DeclarationOffset - { - internal get => this.RootOffset; - set + public override QsNullable OnLocation(QsNullable loc) { - this.RootOffset = value ?? throw new ArgumentNullException(nameof(value), "declaration offset cannot be null"); - this.CurrentLocation = this.DefaultOffset; + this.SharedState.CurrentLocation = loc.IsValue ? loc.Item : null; + return loc; } } - public NonNullable Source + private class NamespaceTransformation + : NamespaceTransformation { - internal get => this.SourceFile; - set - { - this.SourceFile = value; - this.RootOffset = null; - this.CurrentLocation = null; - } - } - /// - /// Whenever DeclarationOffset is set, the current statement offset is set to this default value. - /// - public readonly QsLocation DefaultOffset; - public ImmutableList Locations { get; private set; } + public NamespaceTransformation(SyntaxTreeTransformation parent) + : base(parent, TransformationOptions.NoRebuild) { } - public override QsNullable onLocation(QsNullable loc) - { - this.CurrentLocation = loc.IsValue ? loc.Item : null; - return base.onLocation(loc); - } + public override QsCustomType OnTypeDeclaration(QsCustomType t) + { + if (!this.SharedState.IsRelevant(t.SourceFile) || t.Location.IsNull) return t; + if (this.SharedState.TrackIdentifier(Identifier.NewGlobalCallable(t.FullName))) + { this.SharedState.DeclarationLocation = new Tuple, QsLocation>(t.SourceFile, t.Location.Item); } + return base.OnTypeDeclaration(t); + } - private void LogIdentifierLocation(Identifier id, QsRangeInfo range) - { - if (this.TrackIdentifier(id) && this.CurrentLocation?.Offset != null && range.IsValue) + public override QsCallable OnCallableDeclaration(QsCallable c) { - var idLoc = new IdentifierReferences.Location(this.SourceFile, this.RootOffset, this.CurrentLocation, range.Item); - this.Locations = this.Locations.Add(idLoc); + if (!this.SharedState.IsRelevant(c.SourceFile) || c.Location.IsNull) return c; + if (this.SharedState.TrackIdentifier(Identifier.NewGlobalCallable(c.FullName))) + { this.SharedState.DeclarationLocation = new Tuple, QsLocation>(c.SourceFile, c.Location.Item); } + return base.OnCallableDeclaration(c); } - } - private void OnExpression(TypedExpression ex) - { - if (ex.Expression is QsExpressionKind.Identifier id) - { this.LogIdentifierLocation(id.Item1, ex.Range); } - } + public override QsDeclarationAttribute OnAttribute(QsDeclarationAttribute att) + { + var declRoot = this.SharedState.DeclarationOffset; + this.SharedState.DeclarationOffset = att.Offset; + if (att.TypeId.IsValue) this.Transformation.Types.OnUserDefinedType(att.TypeId.Item); + this.Transformation.Expressions.OnTypedExpression(att.Argument); + this.SharedState.DeclarationOffset = declRoot; + return att; + } + public override QsSpecialization OnSpecializationDeclaration(QsSpecialization spec) => + this.SharedState.IsRelevant(spec.SourceFile) ? base.OnSpecializationDeclaration(spec) : spec; - // static methods for convenience + public override QsNullable OnLocation(QsNullable loc) + { + this.SharedState.DeclarationOffset = loc.IsValue ? loc.Item.Offset : null; + return loc; + } - public static IEnumerable Find(NonNullable idName, QsScope scope, - NonNullable sourceFile, Tuple rootLoc) - { - var finder = new IdentifierLocation(idName, null); - finder.SourceFile = sourceFile; - finder.RootOffset = rootLoc ?? throw new ArgumentNullException(nameof(rootLoc)); - finder.Transform(scope ?? throw new ArgumentNullException(nameof(scope))); - return finder.Locations; + public override NonNullable OnSourceFile(NonNullable source) + { + this.SharedState.Source = source; + return source; + } } } @@ -265,68 +279,82 @@ private void OnExpression(TypedExpression ex) // routines for finding all symbols/identifiers /// - /// Generates a look-up for all used local variables and their location in any of the transformed scopes, - /// as well as one for all local variables reassigned in any of the transformed scopes and their locations. - /// Note that the location information is relative to the root node, i.e. the start position of the containing specialization declaration. + /// Generates a look-up for all used local variables and their location in any of the transformed scopes, + /// as well as one for all local variables reassigned in any of the transformed scopes and their locations. + /// Note that the location information is relative to the root node, i.e. the start position of the containing specialization declaration. /// - public class AccumulateIdentifiers : - ScopeTransformation> + public class AccumulateIdentifiers + : SyntaxTreeTransformation { - private QsLocation StatementLocation; - private Func UpdatedExpression; + public class TransformationState + { + internal QsLocation StatementLocation = null; + internal Func UpdatedExpression; - private List<(NonNullable, QsLocation)> UpdatedLocals; - private List<(NonNullable, QsLocation)> UsedLocals; + private readonly List<(NonNullable, QsLocation)> UpdatedLocals = new List<(NonNullable, QsLocation)>(); + private readonly List<(NonNullable, QsLocation)> UsedLocals = new List<(NonNullable, QsLocation)>(); - public ILookup, QsLocation> ReassignedVariables => - this.UpdatedLocals.ToLookup(var => var.Item1, var => var.Item2); + internal TransformationState() => + this.UpdatedExpression = new TypedExpressionWalker(this.UpdatedLocal, this).OnTypedExpression; - public ILookup, QsLocation> UsedLocalVariables => - this.UsedLocals.ToLookup(var => var.Item1, var => var.Item2); + public ILookup, QsLocation> ReassignedVariables => + this.UpdatedLocals.ToLookup(var => var.Item1, var => var.Item2); + public ILookup, QsLocation> UsedLocalVariables => + this.UsedLocals.ToLookup(var => var.Item1, var => var.Item2); - public AccumulateIdentifiers() : - base( - scope => new VariableReassignments(scope as AccumulateIdentifiers), - new OnTypedExpression(recur: true)) - { - this.UpdatedLocals = new List<(NonNullable, QsLocation)>(); - this.UsedLocals = new List<(NonNullable, QsLocation)>(); - this._Expression.OnExpression = this.onLocal(this.UsedLocals); - this.UpdatedExpression = new OnTypedExpression(this.onLocal(this.UpdatedLocals), recur: true).Transform; - } - private Action onLocal(List<(NonNullable, QsLocation)> accumulate) => (TypedExpression ex) => - { - if (ex.Expression is QsExpressionKind.Identifier id && - id.Item1 is Identifier.LocalVariable var) + private Action Add(List<(NonNullable, QsLocation)> accumulate) => (TypedExpression ex) => { - var range = ex.Range.IsValue ? ex.Range.Item : this.StatementLocation.Range; - accumulate.Add((var.Item, new QsLocation(this.StatementLocation.Offset, range))); - } - }; + if (ex.Expression is QsExpressionKind.Identifier id && + id.Item1 is Identifier.LocalVariable var) + { + var range = ex.Range.IsValue ? ex.Range.Item : this.StatementLocation.Range; + accumulate.Add((var.Item, new QsLocation(this.StatementLocation.Offset, range))); + } + }; + + internal Action UsedLocal => Add(this.UsedLocals); + internal Action UpdatedLocal => Add(this.UpdatedLocals); + } + - public override QsStatement onStatement(QsStatement stm) + public AccumulateIdentifiers() + : base(new TransformationState(), TransformationOptions.NoRebuild) { - this.StatementLocation = stm.Location.IsNull ? null : stm.Location.Item; - this.StatementKind.Transform(stm.Statement); - return stm; + this.Statements = new StatementTransformation(this); + this.StatementKinds = new StatementKindTransformation(this); + this.Expressions = new TypedExpressionWalker(this.SharedState.UsedLocal, this); + this.Types = new TypeTransformation(this, TransformationOptions.Disabled); } - // helper class + // helper classes + + private class StatementTransformation + : StatementTransformation + { + public StatementTransformation(SyntaxTreeTransformation parent) + : base(parent, TransformationOptions.NoRebuild) { } + + public override QsStatement OnStatement(QsStatement stm) + { + this.SharedState.StatementLocation = stm.Location.IsNull ? null : stm.Location.Item; + this.StatementKinds.OnStatementKind(stm.Statement); + return stm; + } + } - public class VariableReassignments : - StatementKindTransformation + private class StatementKindTransformation + : StatementKindTransformation { - public VariableReassignments(AccumulateIdentifiers scope) - : base(scope) - { } + public StatementKindTransformation(SyntaxTreeTransformation parent) + : base(parent, TransformationOptions.NoRebuild) { } - public override QsStatementKind onValueUpdate(QsValueUpdate stm) + public override QsStatementKind OnValueUpdate(QsValueUpdate stm) { - this._Scope.UpdatedExpression(stm.Lhs); - this.ExpressionTransformation(stm.Rhs); + this.SharedState.UpdatedExpression(stm.Lhs); + this.Expressions.OnTypedExpression(stm.Rhs); return QsStatementKind.NewQsValueUpdate(stm); } } @@ -337,79 +365,84 @@ public override QsStatementKind onValueUpdate(QsValueUpdate stm) /// /// Upon transformation, assigns each defined variable a unique name, independent on the scope, and replaces all references to it accordingly. - /// The original variable name can be recovered by using the static method StripUniqueName. - /// This class is *not* threadsafe. + /// The original variable name can be recovered by using the static method StripUniqueName. + /// This class is *not* threadsafe. /// public class UniqueVariableNames - : ScopeTransformation> + : SyntaxTreeTransformation { - private const string Prefix = "qsVar"; - private const string OrigVarName = "origVarName"; - private static Regex WrappedVarName = new Regex($"^__{Prefix}[0-9]*__(?<{OrigVarName}>.*)__$"); + public class TransformationState + { + private int VariableNr = 0; + private Dictionary, NonNullable> UniqueNames = + new Dictionary, NonNullable>(); - private int VariableNr; - private Dictionary, NonNullable> UniqueNames; + internal QsExpressionKind AdaptIdentifier(Identifier sym, QsNullable> tArgs) => + sym is Identifier.LocalVariable varName && this.UniqueNames.TryGetValue(varName.Item, out var unique) + ? QsExpressionKind.NewIdentifier(Identifier.NewLocalVariable(unique), tArgs) + : QsExpressionKind.NewIdentifier(sym, tArgs); - /// - /// Will overwrite the dictionary entry mapping a variable name to the corresponding unique name if the key already exists. - /// - internal NonNullable GenerateUniqueName(NonNullable varName) - { - var unique = NonNullable.New($"__{Prefix}{this.VariableNr++}__{varName.Value}__"); - this.UniqueNames[varName] = unique; - return unique; + /// + /// Will overwrite the dictionary entry mapping a variable name to the corresponding unique name if the key already exists. + /// + internal NonNullable GenerateUniqueName(NonNullable varName) + { + var unique = NonNullable.New($"__{Prefix}{this.VariableNr++}__{varName.Value}__"); + this.UniqueNames[varName] = unique; + return unique; + } } - public NonNullable StripUniqueName(NonNullable uniqueName) + + private const string Prefix = "qsVar"; + private const string OrigVarName = "origVarName"; + private static readonly Regex WrappedVarName = new Regex($"^__{Prefix}[0-9]*__(?<{OrigVarName}>.*)__$"); + + public UniqueVariableNames() + : base(new TransformationState()) { - var matched = WrappedVarName.Match(uniqueName.Value).Groups[OrigVarName]; - return matched.Success ? NonNullable.New(matched.Value) : uniqueName; + this.StatementKinds = new StatementKindTransformation(this); + this.ExpressionKinds = new ExpressionKindTransformation(this); + this.Types = new TypeTransformation(this, TransformationOptions.Disabled); } - private QsExpressionKind AdaptIdentifier(Identifier sym, QsNullable> tArgs) => - sym is Identifier.LocalVariable varName && this.UniqueNames.TryGetValue(varName.Item, out var unique) - ? QsExpressionKind.NewIdentifier(Identifier.NewLocalVariable(unique), tArgs) - : QsExpressionKind.NewIdentifier(sym, tArgs); - public UniqueVariableNames(int initVarNr = 0) : - base(s => new ReplaceDeclarations(s as UniqueVariableNames), - new ExpressionTransformation(e => - new ReplaceIdentifiers(e as ExpressionTransformation))) + // static methods for convenience + + internal static QsQualifiedName PrependGuid(QsQualifiedName original) => + new QsQualifiedName(original.Namespace, NonNullable.New("_" + Guid.NewGuid().ToString("N") + "_" + original.Name.Value)); + + public static NonNullable StripUniqueName(NonNullable uniqueName) { - this._Expression._Kind.ReplaceId = this.AdaptIdentifier; - this.VariableNr = initVarNr; - this.UniqueNames = new Dictionary, NonNullable>(); + var matched = WrappedVarName.Match(uniqueName.Value).Groups[OrigVarName]; + return matched.Success ? NonNullable.New(matched.Value) : uniqueName; } // helper classes - public class ReplaceDeclarations - : StatementKindTransformation + private class StatementKindTransformation + : StatementKindTransformation { - public ReplaceDeclarations(UniqueVariableNames scope) - : base(scope) { } + public StatementKindTransformation(SyntaxTreeTransformation parent) + : base(parent) { } - public override SymbolTuple onSymbolTuple(SymbolTuple syms) => + public override SymbolTuple OnSymbolTuple(SymbolTuple syms) => syms is SymbolTuple.VariableNameTuple tuple - ? SymbolTuple.NewVariableNameTuple(tuple.Item.Select(this.onSymbolTuple).ToImmutableArray()) + ? SymbolTuple.NewVariableNameTuple(tuple.Item.Select(this.OnSymbolTuple).ToImmutableArray()) : syms is SymbolTuple.VariableName varName - ? SymbolTuple.NewVariableName(this._Scope.GenerateUniqueName(varName.Item)) + ? SymbolTuple.NewVariableName(this.SharedState.GenerateUniqueName(varName.Item)) : syms; } - public class ReplaceIdentifiers - : ExpressionKindTransformation> + private class ExpressionKindTransformation + : ExpressionKindTransformation { - internal Func>, QsExpressionKind> ReplaceId; + public ExpressionKindTransformation(SyntaxTreeTransformation parent) + : base(parent) { } - public ReplaceIdentifiers(ExpressionTransformation expression, - Func>, QsExpressionKind> replaceId = null) - : base(expression) => - this.ReplaceId = replaceId ?? ((sym, tArgs) => QsExpressionKind.NewIdentifier(sym, tArgs)); - - public override QsExpressionKind onIdentifier(Identifier sym, QsNullable> tArgs) => - this.ReplaceId(sym, tArgs); + public override QsExpressionKind OnIdentifier(Identifier sym, QsNullable> tArgs) => + this.SharedState.AdaptIdentifier(sym, tArgs); } } @@ -417,28 +450,26 @@ public override QsExpressionKind onIdentifier(Identifier sym, QsNullable - /// Recursively applies the specified action OnExpression to each identifier expression upon transformation. - /// Does nothing upon transformation if no action is specified. + /// Upon transformation, applies the specified action to each expression and subexpression. + /// The action to apply is specified upon construction, and will be applied before recurring into subexpressions. + /// The transformation merely walks expressions and rebuilding is disabled. ///
- public class OnTypedExpression : - ExpressionTransformation>, T> - where T : Core.ExpressionTypeTransformation + public class TypedExpressionWalker + : ExpressionTransformation { - private readonly bool recur; + public TypedExpressionWalker(Action onExpression, SyntaxTreeTransformation parent) + : base(parent, TransformationOptions.NoRebuild) => + this.OnExpression = onExpression ?? throw new ArgumentNullException(nameof(onExpression)); - public OnTypedExpression(Action onExpression = null, Func, T> typeTransformation = null, bool recur = false) : - base(e => new ExpressionKindTransformation>(e as OnTypedExpression), - e => typeTransformation?.Invoke(e as OnTypedExpression)) - { - this.OnExpression = onExpression; - this.recur = recur; - } + public TypedExpressionWalker(Action onExpression, T internalState = default) + : base(internalState, TransformationOptions.NoRebuild) => + this.OnExpression = onExpression ?? throw new ArgumentNullException(nameof(onExpression)); - public Action OnExpression; - public override TypedExpression Transform(TypedExpression ex) + public readonly Action OnExpression; + public override TypedExpression OnTypedExpression(TypedExpression ex) { - this.OnExpression?.Invoke(ex); - return this.recur ? base.Transform(ex) : ex; + this.OnExpression(ex); + return base.OnTypedExpression(ex); } } } diff --git a/src/QsCompiler/Transformations/TransformationDefinitions.cs b/src/QsCompiler/Transformations/TransformationDefinitions.cs deleted file mode 100644 index fb8a850a57..0000000000 --- a/src/QsCompiler/Transformations/TransformationDefinitions.cs +++ /dev/null @@ -1,203 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -using System; -using Microsoft.Quantum.QsCompiler.DataTypes; -using Microsoft.Quantum.QsCompiler.SyntaxTree; - - -namespace Microsoft.Quantum.QsCompiler.Transformations -{ - // syntax tree transformations - - public class SyntaxTreeTransformation : - Core.SyntaxTreeTransformation - where S : Core.ScopeTransformation - { - public readonly S _Scope; - public override Core.ScopeTransformation Scope => this._Scope ?? base.Scope; - - public SyntaxTreeTransformation(S scope) : - base() => - this._Scope = scope; - } - - - // scope transformations - - /// - /// Base class for all StatementKindTransformations. - /// - public class StatementKindTransformation : - Core.StatementKindTransformation - where S : Core.ScopeTransformation - { - public readonly S _Scope; - - public StatementKindTransformation(S scope) : - base(true) => - this._Scope = scope ?? throw new ArgumentNullException(nameof(scope)); - - public override QsScope ScopeTransformation(QsScope value) => - this._Scope.Transform(value); - - public override TypedExpression ExpressionTransformation(TypedExpression value) => - this._Scope.Expression.Transform(value); - - public override ResolvedType TypeTransformation(ResolvedType value) => - this._Scope.Expression.Type.Transform(value); - - public override QsNullable LocationTransformation(QsNullable value) => - this._Scope.onLocation(value); - } - - /// - /// Base class for all ScopeTransformations. - /// - public class ScopeTransformation : - Core.ScopeTransformation - where K : Core.StatementKindTransformation - where E : Core.ExpressionTransformation - { - public readonly K _StatementKind; - private readonly Core.StatementKindTransformation DefaultStatementKind; - public override Core.StatementKindTransformation StatementKind => _StatementKind ?? DefaultStatementKind; - - public readonly E _Expression; - private readonly Core.ExpressionTransformation DefaultExpression; - public override Core.ExpressionTransformation Expression => _Expression ?? DefaultExpression; - - public ScopeTransformation(Func, K> statementKind, E expression) : - base(expression != null) // default kind transformations are enabled only if there are expression transformations - { - this.DefaultStatementKind = base.StatementKind; - this._StatementKind = statementKind?.Invoke(this); - - this.DefaultExpression = new NoExpressionTransformations(); // disable by default - this._Expression = expression; - } - } - - /// - /// Given an expression transformation, Transform applies the given transformation to all expressions in a scope. - /// - public class ScopeTransformation : - ScopeTransformation - where E : Core.ExpressionTransformation - { - public ScopeTransformation(E expression) : - base(null, expression) { } - } - - /// - /// Does not do any transformations, and can be use as no-op if a ScopeTransformation is required as argument. - /// - public class NoScopeTransformations : - ScopeTransformation - { - public NoScopeTransformations() : - base(null) { } - - public override QsScope Transform(QsScope scope) => scope; - } - - - // expression transformations - - /// - /// Base class for all ExpressionTypeTransformations. - /// - public class ExpressionTypeTransformation : - Core.ExpressionTypeTransformation - where E : Core.ExpressionTransformation - { - public readonly E _Expression; - - public ExpressionTypeTransformation(E expression) : - base(true) => - this._Expression = expression ?? throw new ArgumentNullException(nameof(expression)); - } - - /// - /// Base class for all ExpressionKindTransformations. - /// - public class ExpressionKindTransformation : - Core.ExpressionKindTransformation - where E : Core.ExpressionTransformation - { - public readonly E _Expression; - - public ExpressionKindTransformation(E expression) : - base(true) => - this._Expression = expression ?? throw new ArgumentNullException(nameof(expression)); - - public override TypedExpression ExpressionTransformation(TypedExpression value) => - this._Expression.Transform(value); - - public override ResolvedType TypeTransformation(ResolvedType value) => - this._Expression.Type.Transform(value); - } - - /// - /// Base class for all ExpressionTransformations. - /// - public class ExpressionTransformation : - Core.ExpressionTransformation - where K : Core.ExpressionKindTransformation - where T : Core.ExpressionTypeTransformation - { - public readonly K _Kind; - private readonly Core.ExpressionKindTransformation DefaultKind; - public override Core.ExpressionKindTransformation Kind => _Kind ?? DefaultKind; - - public readonly T _Type; - private readonly Core.ExpressionTypeTransformation DefaultType; - public override Core.ExpressionTypeTransformation Type => _Type ?? DefaultType; - - public ExpressionTransformation(Func, K> kind, Func, T> type) : - base(false) // disable transformations by default - { - this.DefaultKind = base.Kind; - this._Kind = kind?.Invoke(this); - - this.DefaultType = new Core.ExpressionTypeTransformation(false); // disabled by default - this._Type = type?.Invoke(this); - } - } - - /// - /// Given an expression kind transformation, Transform applies the given transformation to the Kind of every expression. - /// - public class ExpressionTransformation : - ExpressionTransformation - where K : Core.ExpressionKindTransformation - { - public ExpressionTransformation(Func, K> kind) : - base(kind, null) { } - } - - /// - /// ExpressionTransformation where expression kind transformations are set to their default - - /// i.e. subexpressions are walked, but no transformation is done on the kind itself. - /// - public class DefaultExpressionTransformation : - ExpressionTransformation> - { - public DefaultExpressionTransformation() : - base(e => new ExpressionKindTransformation(e as DefaultExpressionTransformation)) - { } - } - - /// - /// Disables all expression transformations, and can be use as no-op if an ExpressionTransformation is required as argument. - /// - public class NoExpressionTransformations : - ExpressionTransformation - { - public NoExpressionTransformations() : - base(null) { } - - public override TypedExpression Transform(TypedExpression ex) => ex; - } -} - diff --git a/src/QsCompiler/Transformations/WalkerDefinitions.cs b/src/QsCompiler/Transformations/WalkerDefinitions.cs deleted file mode 100644 index 90c4eb1c65..0000000000 --- a/src/QsCompiler/Transformations/WalkerDefinitions.cs +++ /dev/null @@ -1,207 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -using System; -using Microsoft.Quantum.QsCompiler.DataTypes; -using Microsoft.Quantum.QsCompiler.SyntaxTree; - - -namespace Microsoft.Quantum.QsCompiler.Transformations -{ - // syntax tree walkers - - public class SyntaxTreeWalker : - Core.SyntaxTreeWalker - where S : Core.ScopeWalker - { - public readonly S _Scope; - public override Core.ScopeWalker Scope => this._Scope ?? base.Scope; - - public SyntaxTreeWalker(S scope) : - base() => - this._Scope = scope; - } - - - // scope walkers - - /// - /// Base class for all StatementKindWalkers. - /// - public class StatementKindWalker : - Core.StatementKindWalker - where S : Core.ScopeWalker - { - public readonly S _Scope; - - public StatementKindWalker(S scope) : - base(true) => - this._Scope = scope ?? throw new ArgumentNullException(nameof(scope)); - - public override void ScopeWalker(QsScope value) => - this._Scope.Walk(value); - - public override void ExpressionWalker(TypedExpression value) => - this._Scope.Expression.Walk(value); - - public override void TypeWalker(ResolvedType value) => - this._Scope.Expression.Type.Walk(value); - - public override void LocationWalker(QsNullable value) => - this._Scope.onLocation(value); - } - - /// - /// Base class for all ScopeWalkers. - /// - public class ScopeWalker : - Core.ScopeWalker - where K : Core.StatementKindWalker - where E : Core.ExpressionWalker - { - public readonly K _StatementKind; - private readonly Core.StatementKindWalker DefaultStatementKind; - public override Core.StatementKindWalker StatementKind => _StatementKind ?? DefaultStatementKind; - - public readonly E _Expression; - private readonly Core.ExpressionWalker DefaultExpression; - public override Core.ExpressionWalker Expression => _Expression ?? DefaultExpression; - - public ScopeWalker(Func, K> statementKind, E expression) : - base(expression != null) // default kind Walkers are enabled only if there are expression Walkers - { - this.DefaultStatementKind = base.StatementKind; - this._StatementKind = statementKind?.Invoke(this); - - this.DefaultExpression = new NoExpressionWalkers(); // disable by default - this._Expression = expression; - } - } - - /// - /// Given an expression Walker, Walk applies the given Walker to all expressions in a scope. - /// - public class ScopeWalker : - ScopeWalker - where E : Core.ExpressionWalker - { - public ScopeWalker(E expression) : - base(null, expression) - { } - } - - /// - /// Does not do any Walkers, and can be use as no-op if a ScopeWalker is required as argument. - /// - public class NoScopeWalkers : - ScopeWalker - { - public NoScopeWalkers() : - base(null) - { } - - public override void Walk(QsScope scope) {} - } - - - // expression Walkers - - /// - /// Base class for all ExpressionTypeWalkers. - /// - public class ExpressionTypeWalker : - Core.ExpressionTypeWalker - where E : Core.ExpressionWalker - { - public readonly E _Expression; - - public ExpressionTypeWalker(E expression) : - base(true) => - this._Expression = expression ?? throw new ArgumentNullException(nameof(expression)); - } - - /// - /// Base class for all ExpressionKindWalkers. - /// - public class ExpressionKindWalker : - Core.ExpressionKindWalker - where E : Core.ExpressionWalker - { - public readonly E _Expression; - - public ExpressionKindWalker(E expression) : - base(true) => - this._Expression = expression ?? throw new ArgumentNullException(nameof(expression)); - - public override void ExpressionWalker(TypedExpression value) => - this._Expression.Walk(value); - - public override void TypeWalker(ResolvedType value) => - this._Expression.Type.Walk(value); - } - - /// - /// Base class for all ExpressionWalkers. - /// - public class ExpressionWalker : - Core.ExpressionWalker - where K : Core.ExpressionKindWalker - where T : Core.ExpressionTypeWalker - { - public readonly K _Kind; - private readonly Core.ExpressionKindWalker DefaultKind; - public override Core.ExpressionKindWalker Kind => _Kind ?? DefaultKind; - - public readonly T _Type; - private readonly Core.ExpressionTypeWalker DefaultType; - public override Core.ExpressionTypeWalker Type => _Type ?? DefaultType; - - public ExpressionWalker(Func, K> kind, Func, T> type) : - base(false) // disable Walkers by default - { - this.DefaultKind = base.Kind; - this._Kind = kind?.Invoke(this); - - this.DefaultType = new Core.ExpressionTypeWalker(false); // disabled by default - this._Type = type?.Invoke(this); - } - } - - /// - /// Given an expression kind Walker, Walk applies the given Walker to the Kind of every expression. - /// - public class ExpressionWalker : - ExpressionWalker - where K : Core.ExpressionKindWalker - { - public ExpressionWalker(Func, K> kind) : - base(kind, null) - { } - } - - /// - /// ExpressionWalker where expression kind Walkers are set to their default - - /// i.e. subexpressions are walked, but no Walker is done on the kind itself. - /// - public class DefaultExpressionWalker : - ExpressionWalker> - { - public DefaultExpressionWalker() : - base(e => new ExpressionKindWalker(e as DefaultExpressionWalker)) - { } - } - - /// - /// Disables all expression Walkers, and can be use as no-op if an ExpressionWalker is required as argument. - /// - public class NoExpressionWalkers : - ExpressionWalker - { - public NoExpressionWalkers() : - base(null) - { } - - public override void Walk(TypedExpression ex) { } - } -} -