diff --git a/src/QsCompiler/CompilationManager/CompilationUnit.cs b/src/QsCompiler/CompilationManager/CompilationUnit.cs index 31c3fbe7b5..40daa00a71 100644 --- a/src/QsCompiler/CompilationManager/CompilationUnit.cs +++ b/src/QsCompiler/CompilationManager/CompilationUnit.cs @@ -44,11 +44,16 @@ internal Headers(string source, ?? ImmutableArray<(SpecializationDeclarationHeader, SpecializationImplementation)>.Empty; } - internal Headers(NonNullable source, IEnumerable syntaxTree) : this ( + /// + /// Initializes a set of reference headers based on the given syntax tree loaded from the specified source. + /// The source is expected to be the path to the dll from which the syntax has been loaded. + /// Returns an empty set of headers if the given syntax tree is null. + /// + public Headers(NonNullable source, IEnumerable syntaxTree) : this ( source.Value, - syntaxTree.Callables().Where(c => c.SourceFile.Value.EndsWith(".qs")).Select(CallableDeclarationHeader.New), - syntaxTree.Specializations().Where(c => c.SourceFile.Value.EndsWith(".qs")).Select(s => (SpecializationDeclarationHeader.New(s), s.Implementation)), - syntaxTree.Types().Where(c => c.SourceFile.Value.EndsWith(".qs")).Select(TypeDeclarationHeader.New)) + syntaxTree?.Callables().Where(c => c.SourceFile.Value.EndsWith(".qs")).Select(CallableDeclarationHeader.New), + syntaxTree?.Specializations().Where(c => c.SourceFile.Value.EndsWith(".qs")).Select(s => (SpecializationDeclarationHeader.New(s), s.Implementation)), + syntaxTree?.Types().Where(c => c.SourceFile.Value.EndsWith(".qs")).Select(TypeDeclarationHeader.New)) { } internal Headers(NonNullable source, IEnumerable<(string, string)> attributes) : this( diff --git a/src/QsCompiler/CompilationManager/DiagnosticTools.cs b/src/QsCompiler/CompilationManager/DiagnosticTools.cs index 1ca1a8cb3b..3144a3a7ec 100644 --- a/src/QsCompiler/CompilationManager/DiagnosticTools.cs +++ b/src/QsCompiler/CompilationManager/DiagnosticTools.cs @@ -19,7 +19,7 @@ public static class DiagnosticTools /// Returns the line and character of the given position as tuple without verifying them. /// Throws an ArgumentNullException if the given position is null. /// - internal static Tuple AsTuple(Position position) => + public static Tuple AsTuple(Position position) => position != null ? new Tuple(position.Line, position.Character) : throw new ArgumentNullException(nameof(position)); @@ -28,7 +28,7 @@ internal static Tuple AsTuple(Position position) => /// Returns a Position with the line and character given as tuple (inverse function for AsTuple). /// Throws an ArgumentNullException if the given tuple is null. /// - internal static Position AsPosition(Tuple position) => + public static Position AsPosition(Tuple position) => position != null ? new Position(position.Item1, position.Item2) : throw new ArgumentNullException(nameof(position)); diff --git a/src/QsCompiler/Compiler/PluginInterface.cs b/src/QsCompiler/Compiler/PluginInterface.cs index 41e256fb42..1b6bf81430 100644 --- a/src/QsCompiler/Compiler/PluginInterface.cs +++ b/src/QsCompiler/Compiler/PluginInterface.cs @@ -4,7 +4,9 @@ using System; using System.Collections.Generic; using Microsoft.CodeAnalysis; +using Microsoft.Quantum.QsCompiler.CompilationBuilder; using Microsoft.Quantum.QsCompiler.SyntaxTree; +using VS = Microsoft.VisualStudio.LanguageServer.Protocol; namespace Microsoft.Quantum.QsCompiler @@ -82,6 +84,27 @@ public struct Diagnostic /// The position is null if the diagnostic is not caused by a piece of source code. /// public Tuple End { get; set; } + + /// + /// Initializes a new diagnostic. + /// If a diagnostic generated by the Q# compiler is given as argument, the values are initialized accordingly. + /// + public static Diagnostic Create(VS.Diagnostic d = null, Stage stage = Stage.Unknown) => + d == null ? new Diagnostic() : new Diagnostic + { + Severity = d.Severity switch + { + VS.DiagnosticSeverity.Error => DiagnosticSeverity.Error, + VS.DiagnosticSeverity.Warning => DiagnosticSeverity.Warning, + VS.DiagnosticSeverity.Information => DiagnosticSeverity.Info, + _ => DiagnosticSeverity.Hidden + }, + Message = d.Message, + Source = d.Source, + Stage = stage, + Start = d.Range?.Start == null ? null : DiagnosticTools.AsTuple(d.Range.Start), + End = d.Range?.End == null ? null : DiagnosticTools.AsTuple(d.Range.End) + }; } /// diff --git a/src/QsCompiler/Core/SymbolResolution.fs b/src/QsCompiler/Core/SymbolResolution.fs index 5b93329d5f..dd31d39e86 100644 --- a/src/QsCompiler/Core/SymbolResolution.fs +++ b/src/QsCompiler/Core/SymbolResolution.fs @@ -433,6 +433,7 @@ module SymbolResolution = /// or if the resolved argument type does not match the expected argument type. /// The TypeId in the resolved attribute is set to Null if the unresolved Id is not a valid identifier /// or if the correct attribute cannot be determined, and is set to the corresponding type identifier otherwise. + /// Throws an ArgumentException if a tuple-valued attribute argument does not contain at least one item. let internal ResolveAttribute getAttribute (attribute : AttributeAnnotation) = let asTypedExression range (exKind, exType) = { Expression = exKind @@ -460,7 +461,9 @@ module SymbolResolution = | StringLiteral (s, exs) -> if exs.Length <> 0 then invalidExpr ex.Range, [| ex.Range |> diagnostic ErrorCode.InterpolatedStringInAttribute |] else (StringLiteral (s, ImmutableArray.Empty), String) |> asTypedExression ex.Range, [||] + | ValueTuple vs when vs.Length = 1 -> ArgExression (vs.First()) | ValueTuple vs -> + if vs.Length = 0 then ArgumentException "tuple valued attribute argument requires at least one tuple item" |> raise let innerExs, errs = aggregateInner vs let types = (innerExs |> Seq.map (fun ex -> ex.ResolvedType)).ToImmutableArray() (ValueTuple innerExs, TupleType types) |> asTypedExression ex.Range, errs diff --git a/src/QsCompiler/Core/SyntaxGenerator.fs b/src/QsCompiler/Core/SyntaxGenerator.fs index 72da5c2dad..c33e86a772 100644 --- a/src/QsCompiler/Core/SyntaxGenerator.fs +++ b/src/QsCompiler/Core/SyntaxGenerator.fs @@ -63,8 +63,13 @@ module SyntaxGenerator = /// setting the quantum dependency to the given value and assuming no type parameter resolutions. /// Sets the range information for the built expression to Null. let private AutoGeneratedExpression kind exTypeKind qDep = - let noInferredInfo = InferredExpressionInformation.New (false, quantumDep = qDep) - TypedExpression.New (kind, ImmutableDictionary.Empty, exTypeKind |> ResolvedType.New, noInferredInfo, QsRangeInfo.Null) + let inferredInfo = InferredExpressionInformation.New (false, quantumDep = qDep) + TypedExpression.New (kind, ImmutableDictionary.Empty, exTypeKind |> ResolvedType.New, inferredInfo, QsRangeInfo.Null) + + /// Creates a typed expression that represents an invalid expression of invalid type. + /// Sets the range information for the built expression to Null. + let InvalidExpression = + AutoGeneratedExpression InvalidExpr QsTypeKind.InvalidType false /// Creates a typed expression that corresponds to a Unit value. /// Sets the range information for the built expression to Null. @@ -97,6 +102,14 @@ module SyntaxGenerator = let RangeLiteral (lhs, rhs) = AutoGeneratedExpression (RangeLiteral (lhs, rhs)) QsTypeKind.Range false + /// Creates a typed expression that corresponds to a value tuple with the given items. + /// Sets the range information for the built expression to Null. + /// Does *not* strip positional information from the given items; the responsibility to do so is with the caller. + let TupleLiteral (items : TypedExpression seq) = + let qdep = items |> Seq.exists (fun item -> item.InferredInformation.HasLocalQuantumDependency) + let tupleType = items |> Seq.map (fun item -> item.ResolvedType) |> ImmutableArray.CreateRange |> TupleType + AutoGeneratedExpression (ValueTuple (items.ToImmutableArray())) tupleType qdep + // utils to for building typed expressions and iterable inversions @@ -250,12 +263,7 @@ module SyntaxGenerator = | _ -> true if ctlQs.ResolvedType.Resolution <> QubitArray.Resolution && not (ctlQs.ResolvedType |> isInvalid) then new ArgumentException "expression for the control qubits is valid but not of type Qubit[]" |> raise - let buildControlledArgument orig = - let kind = QsExpressionKind.ValueTuple ([ctlQs; orig].ToImmutableArray()) - let quantumDep = orig.InferredInformation.HasLocalQuantumDependency || ctlQs.InferredInformation.HasLocalQuantumDependency - let exInfo = InferredExpressionInformation.New (isMutable = false, quantumDep = quantumDep) - TypedExpression.New (kind, orig.TypeParameterResolutions, AddControlQubits orig.ResolvedType, exInfo, QsRangeInfo.Null) - buildControlledArgument arg + TupleLiteral [ctlQs; arg] /// Returns the name of the control qubits /// if the given argument tuple is consistent with the argument tuple of a controlled specialization. diff --git a/src/QsCompiler/DataStructures/SyntaxTree.fs b/src/QsCompiler/DataStructures/SyntaxTree.fs index 3d6e8b8fa1..b528867f4d 100644 --- a/src/QsCompiler/DataStructures/SyntaxTree.fs +++ b/src/QsCompiler/DataStructures/SyntaxTree.fs @@ -660,6 +660,7 @@ type QsSpecialization = { } with member this.AddAttribute att = {this with Attributes = this.Attributes.Add att} + member this.AddAttributes (att : _ seq) = {this with Attributes = this.Attributes.AddRange att} member this.WithImplementation impl = {this with Implementation = impl} member this.WithParent (getName : Func<_,_>) = {this with Parent = getName.Invoke(this.Parent)} @@ -699,6 +700,7 @@ type QsCallable = { } with member this.AddAttribute att = {this with Attributes = this.Attributes.Add att} + member this.AddAttributes (att : _ seq) = {this with Attributes = this.Attributes.AddRange att} member this.WithSpecializations (getSpecs : Func<_,_>) = {this with Specializations = getSpecs.Invoke(this.Specializations)} member this.WithFullName (getName : Func<_,_>) = {this with FullName = getName.Invoke(this.FullName)} @@ -740,6 +742,7 @@ type QsCustomType = { } with member this.AddAttribute att = {this with Attributes = this.Attributes.Add att} + member this.AddAttributes (att : _ seq) = {this with Attributes = this.Attributes.AddRange att} member this.WithFullName (getName : Func<_,_>) = {this with FullName = getName.Invoke(this.FullName)} diff --git a/src/QsCompiler/Tests.Compiler/TestCases/AttributeGeneration.qs b/src/QsCompiler/Tests.Compiler/TestCases/AttributeGeneration.qs new file mode 100644 index 0000000000..f6d3fc480e --- /dev/null +++ b/src/QsCompiler/Tests.Compiler/TestCases/AttributeGeneration.qs @@ -0,0 +1,21 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +namespace Microsoft.Quantum.Testing.AttributeGeneration { + + open Microsoft.Quantum.Arrays; + + function DefaultArray<'A>(size : Int) : 'A[] { + mutable arr = new 'A[size]; + for (i in IndexRange(arr)) { + set arr w/= i <- Default<'A>(); + } + return arr; + } + + operation CallDefaultArray<'A>(size : Int) : 'A[] { + return DefaultArray(size); + } + +} + diff --git a/src/QsCompiler/Tests.Compiler/Tests.Compiler.fsproj b/src/QsCompiler/Tests.Compiler/Tests.Compiler.fsproj index 94e620420f..d5f9528eff 100644 --- a/src/QsCompiler/Tests.Compiler/Tests.Compiler.fsproj +++ b/src/QsCompiler/Tests.Compiler/Tests.Compiler.fsproj @@ -105,6 +105,9 @@ PreserveNewest + + PreserveNewest + PreserveNewest diff --git a/src/QsCompiler/Tests.Compiler/TransformationTests.fs b/src/QsCompiler/Tests.Compiler/TransformationTests.fs index 2a376638b5..237589a610 100644 --- a/src/QsCompiler/Tests.Compiler/TransformationTests.fs +++ b/src/QsCompiler/Tests.Compiler/TransformationTests.fs @@ -11,6 +11,7 @@ open Microsoft.Quantum.QsCompiler open Microsoft.Quantum.QsCompiler.CompilationBuilder open Microsoft.Quantum.QsCompiler.DataTypes open Microsoft.Quantum.QsCompiler.SyntaxTree +open Microsoft.Quantum.QsCompiler.Transformations open Microsoft.Quantum.QsCompiler.Transformations.Core open Microsoft.Quantum.QsCompiler.Transformations.QsCodeOutput open Xunit @@ -18,6 +19,36 @@ open Xunit // utils for testing syntax tree transformations and the corresponding infrastructure +type private GlobalDeclarations(parent : CheckDeclarations) = + inherit NamespaceTransformation(parent, TransformationOptions.NoRebuild) + + override this.OnCallableDeclaration c = + parent.CheckCallableDeclaration c + + override this.OnTypeDeclaration t = + parent.CheckTypeDeclaration t + + override this.OnSpecializationDeclaration s = + parent.CheckSpecializationDeclaration s + +and private CheckDeclarations private (_internal_, onTypeDecl, onCallableDecl, onSpecDecl) = + inherit SyntaxTreeTransformation() + + member internal this.CheckTypeDeclaration = onTypeDecl + member internal this.CheckCallableDeclaration = onCallableDecl + member internal this.CheckSpecializationDeclaration = onSpecDecl + + new (?onTypeDecl, ?onCallableDecl, ?onSpecDecl) as this = + let onTypeDecl = defaultArg onTypeDecl id + let onCallableDecl = defaultArg onCallableDecl id + let onSpecDecl = defaultArg onSpecDecl id + CheckDeclarations("_internal_", onTypeDecl, onCallableDecl, onSpecDecl) then + this.Types <- new TypeTransformation(this, TransformationOptions.Disabled) + this.Expressions <- new ExpressionTransformation(this, TransformationOptions.Disabled) + this.Statements <- new StatementTransformation(this, TransformationOptions.Disabled) + this.Namespaces <- new GlobalDeclarations(this) + + type private Counter () = member val callsCount = 0 with get, set member val opsCount = 0 with get, set @@ -78,16 +109,16 @@ let private buildSyntaxTree code = compilationUnit.AddOrUpdateSourceFileAsync file |> ignore // spawns a task that modifies the current compilation let mutable syntaxTree = compilationUnit.Build().BuiltCompilation // will wait for any current tasks to finish CodeGeneration.GenerateFunctorSpecializations(syntaxTree, &syntaxTree) |> ignore - syntaxTree.Namespaces + syntaxTree //////////////////////////////// tests ////////////////////////////////// [] let ``basic walk`` () = - let tree = Path.Combine(Path.GetFullPath ".", "TestCases", "Transformation.qs") |> File.ReadAllText |> buildSyntaxTree + let compilation = Path.Combine(Path.GetFullPath ".", "TestCases", "Transformation.qs") |> File.ReadAllText |> buildSyntaxTree let walker = new SyntaxCounter(TransformationOptions.NoRebuild) - tree |> Seq.iter (walker.Namespaces.OnNamespace >> ignore) + compilation.Namespaces |> Seq.iter (walker.Namespaces.OnNamespace >> ignore) Assert.Equal (4, walker.Counter.udtCount) Assert.Equal (1, walker.Counter.funCount) @@ -96,11 +127,12 @@ let ``basic walk`` () = Assert.Equal (6, walker.Counter.ifsCount) Assert.Equal (20, walker.Counter.callsCount) + [] let ``basic transformation`` () = - let tree = Path.Combine(Path.GetFullPath ".", "TestCases", "Transformation.qs") |> File.ReadAllText |> buildSyntaxTree + let compilation = Path.Combine(Path.GetFullPath ".", "TestCases", "Transformation.qs") |> File.ReadAllText |> buildSyntaxTree let walker = new SyntaxCounter() - tree |> Seq.iter (walker.Namespaces.OnNamespace >> ignore) + compilation.Namespaces |> Seq.iter (walker.Namespaces.OnNamespace >> ignore) Assert.Equal (4, walker.Counter.udtCount) Assert.Equal (1, walker.Counter.funCount) @@ -109,14 +141,60 @@ let ``basic transformation`` () = Assert.Equal (6, walker.Counter.ifsCount) Assert.Equal (20, walker.Counter.callsCount) + +[] +let ``attaching attributes to callables`` () = + let WithinNamespace nsName (c : QsNamespaceElement) = c.GetFullName().Namespace.Value = nsName + let attGenNs = "Microsoft.Quantum.Testing.AttributeGeneration" + let predicate = QsCallable >> WithinNamespace attGenNs + let sources = [ + Path.Combine(Path.GetFullPath ".", "TestCases", "LinkingTests", "Core.qs") + Path.Combine(Path.GetFullPath ".", "TestCases", "AttributeGeneration.qs") + ] + let compilation = sources |> Seq.map File.ReadAllText |> String.Concat |> buildSyntaxTree + let testAttribute = AttributeUtils.BuildAttribute(BuiltIn.Test.FullName, AttributeUtils.StringArgument "QuantumSimulator") + + let checkSpec (spec : QsSpecialization) = Assert.Empty spec.Attributes; spec + let checkType (customType : QsCustomType) = + if customType |> QsCustomType |> WithinNamespace attGenNs then Assert.Empty customType.Attributes; + customType + let checkCallable limitedToNs nrAtts (callable : QsCallable) = + if limitedToNs = null || callable |> QsCallable |> WithinNamespace limitedToNs then + Assert.Equal(nrAtts, callable.Attributes.Length) + for att in callable.Attributes do + Assert.Equal(testAttribute, att) + else Assert.Empty callable.Attributes + callable + + let transformed = AttributeUtils.AddToCallables(compilation, testAttribute, predicate) + let checker = new CheckDeclarations(checkType, checkCallable attGenNs 1, checkSpec) + checker.Apply transformed |> ignore + + let transformed = AttributeUtils.AddToCallables(compilation, testAttribute, null) + let checker = new CheckDeclarations(checkType, checkCallable null 1, checkSpec) + checker.Apply transformed |> ignore + + let transformed = AttributeUtils.AddToCallables(compilation, testAttribute) + let checker = new CheckDeclarations(checkType, checkCallable null 1, checkSpec) + checker.Apply transformed |> ignore + + let transformed = AttributeUtils.AddToCallables(compilation, struct (testAttribute, new Func<_,_>(predicate)), struct(testAttribute, new Func<_,_>(predicate))) + let checker = new CheckDeclarations(checkType, checkCallable attGenNs 2, checkSpec) + checker.Apply transformed |> ignore + + let transformed = AttributeUtils.AddToCallables(compilation, testAttribute, testAttribute) + let checker = new CheckDeclarations(checkType, checkCallable null 2, checkSpec) + checker.Apply transformed |> ignore + + [] let ``generation of open statements`` () = - let tree = buildSyntaxTree @" + let compilation = buildSyntaxTree @" namespace Microsoft.Quantum.Testing { operation emptyOperation () : Unit {} }" - let ns = tree |> Seq.head + let ns = compilation.Namespaces |> Seq.head let source = ns.Elements.Single() |> function | QsCallable callable -> callable.SourceFile | QsCustomType t -> t.SourceFile @@ -133,7 +211,7 @@ let ``generation of open statements`` () = let imports = ImmutableDictionary.Empty.Add(ns.Name, openDirectives) let codeOutput = ref null - SyntaxTreeToQsharp.Apply (codeOutput, tree, struct (source, imports)) |> Assert.True + SyntaxTreeToQsharp.Apply (codeOutput, compilation.Namespaces, struct (source, imports)) |> Assert.True let lines = Utils.SplitLines (codeOutput.Value.Single().[ns.Name]) Assert.Equal(13, lines.Count()) diff --git a/src/QsCompiler/TextProcessor/QsFragmentParsing.fs b/src/QsCompiler/TextProcessor/QsFragmentParsing.fs index c4241ab761..05de2fe1eb 100644 --- a/src/QsCompiler/TextProcessor/QsFragmentParsing.fs +++ b/src/QsCompiler/TextProcessor/QsFragmentParsing.fs @@ -205,7 +205,7 @@ let private namespaceDeclaration = buildFragment namespaceDeclHeader.parse (expectedNamespaceName eof) invalid (fun _ -> NamespaceDeclaration) eof /// Uses buildFragment to parse a Q# DeclarationAttribute as QsFragment. -let private attributeDeclaration = +let private attributeAnnotation = let invalid = DeclarationAttribute (invalidSymbol, unknownExpr) let attributeId = multiSegmentSymbol ErrorCode.InvalidIdentifierName |>> asQualifiedSymbol let expectedArgs = @@ -500,6 +500,6 @@ let private expressionStatement = let internal codeFragment = let validFragment = choice (fragments |> List.map snd) - <|> attributeDeclaration + <|> attributeAnnotation <|> expressionStatement // the expressionStatement needs to be last attempt validFragment <|> buildInvalidFragment (preturn ()) diff --git a/src/QsCompiler/Transformations/Attributes.cs b/src/QsCompiler/Transformations/Attributes.cs new file mode 100644 index 0000000000..9f907afb9a --- /dev/null +++ b/src/QsCompiler/Transformations/Attributes.cs @@ -0,0 +1,146 @@ +using System; +using System.Collections.Generic; +using System.Collections.Immutable; +using System.Linq; +using Microsoft.Quantum.QsCompiler.DataTypes; +using Microsoft.Quantum.QsCompiler.SyntaxTree; + + +namespace Microsoft.Quantum.QsCompiler.Transformations +{ + using CallablePredicate = Func; + using AttributeId = QsNullable; + using QsRangeInfo = QsNullable>; + + + /// + /// Contains tools for building and adding attributes to an existing Q# compilation. + /// + public static class AttributeUtils + { + private static AttributeId BuildId(QsQualifiedName name) => + name != null + ? AttributeId.NewValue(new UserDefinedType(name.Namespace, name.Name, QsRangeInfo.Null)) + : AttributeId.Null; + + // public static methods + + /// + /// Returns a Q# attribute with the given name and argument that can be attached to a declaration. + /// The attribute id is set to Null if the given name is null. + /// The attribute argument is set to an invalid expression if the given argument is null. + /// + public static QsDeclarationAttribute BuildAttribute(QsQualifiedName name, TypedExpression arg) => + new QsDeclarationAttribute(BuildId(name), arg ?? SyntaxGenerator.InvalidExpression, null, QsComments.Empty); + + /// + /// Builds a string literal with the given content that can be used as argument to a Q# attribute. + /// The value of the string literal is set to the empty string if the given content is null. + /// + public static TypedExpression StringArgument(string content) => + SyntaxGenerator.StringLiteral(NonNullable.New(content ?? ""), ImmutableArray.Empty); + + /// + /// Builds an attribute argument with the given string valued tuple items. + /// If a given string is null, the value of the corresponding item is set to the empty string. + /// If no items are given, a suitable argument of type unit is returned. + /// + public static TypedExpression StringArguments(params string[] items) => + items == null || items.Length == 0 ? SyntaxGenerator.UnitValue : + items.Length == 1 ? StringArgument(items.Single()) : + SyntaxGenerator.TupleLiteral(items.Select(StringArgument)); + + /// + /// Adds the given attribute to all callables in the given compilation that satisfy the given predicate + /// - if the predicate is specified and not null. + /// Throws an ArgumentNullException if the given attribute or compilation is null. + /// + public static QsCompilation AddToCallables(QsCompilation compilation, QsDeclarationAttribute attribute, CallablePredicate predicate = null) => + new AddAttributes(new[] { (attribute, predicate) }).Apply(compilation); + + /// + /// Adds the given attribute(s) to all callables in the given compilation that satisfy the given predicate + /// - if the predicate is specified and not null. + /// Throws an ArgumentNullException if one of the given attributes or the compilation is null. + /// + public static QsCompilation AddToCallables(QsCompilation compilation, params (QsDeclarationAttribute, CallablePredicate)[] attributes) => + new AddAttributes(attributes).Apply(compilation); + + /// + /// Adds the given attribute(s) to all callables in the given compilation. + /// Throws an ArgumentNullException if one of the given attributes or the compilation is null. + /// + public static QsCompilation AddToCallables(QsCompilation compilation, params QsDeclarationAttribute[] attributes) => + new AddAttributes(attributes.Select(att => (att, (CallablePredicate)null))).Apply(compilation); + + /// + /// Adds the given attribute to all callables in the given namespace that satisfy the given predicate + /// - if the predicate is specified and not null. + /// Throws an ArgumentNullException if the given attribute or namespace is null. + /// + public static QsNamespace AddToCallables(QsNamespace ns, QsDeclarationAttribute attribute, CallablePredicate predicate = null) => + new AddAttributes(new[] { (attribute, predicate) }).Namespaces.OnNamespace(ns); + + /// + /// Adds the given attribute(s) to all callables in the given namespace that satisfy the given predicate + /// - if the predicate is specified and not null. + /// Throws an ArgumentNullException if one of the given attributes or the namespace is null. + /// + public static QsNamespace AddToCallables(QsNamespace ns, params (QsDeclarationAttribute, CallablePredicate)[] attributes) => + new AddAttributes(attributes).Namespaces.OnNamespace(ns); + + /// + /// Adds the given attribute(s) to all callables in the given namespace. + /// Throws an ArgumentNullException if one of the given attributes or the namespace is null. + /// + public static QsNamespace AddToCallables(QsNamespace ns, params QsDeclarationAttribute[] attributes) => + new AddAttributes(attributes.Select(att => (att, (CallablePredicate)null))).Namespaces.OnNamespace(ns); + + + // private transformation class(es) + + /// + /// Transformation to add attributes to an existing compilation. + /// + private class AddAttributes + : Core.SyntaxTreeTransformation + { + internal class TransformationState + { + internal readonly ImmutableArray<(QsDeclarationAttribute, Func)> AttributeSelection; + + /// Thrown when the given selection is null. + internal TransformationState(IEnumerable<(QsDeclarationAttribute, Func)> selections) => + this.AttributeSelection = selections?.ToImmutableArray() ?? throw new ArgumentNullException(nameof(selections)); + } + + internal AddAttributes(IEnumerable<(QsDeclarationAttribute, CallablePredicate)> attributes) + : base(new TransformationState(attributes?.Select(entry => (entry.Item1, entry.Item2 ?? (_ => true))))) + { + if (attributes == null || attributes.Any(entry => entry.Item1 == null)) throw new ArgumentNullException(nameof(attributes)); + this.Namespaces = new NamespaceTransformation(this); + this.Statements = new Core.StatementTransformation(this, Core.TransformationOptions.Disabled); + this.StatementKinds = new Core.StatementKindTransformation(this, Core.TransformationOptions.Disabled); + this.Expressions = new Core.ExpressionTransformation(this, Core.TransformationOptions.Disabled); + this.ExpressionKinds = new Core.ExpressionKindTransformation(this, Core.TransformationOptions.Disabled); + this.Types = new Core.TypeTransformation(this, Core.TransformationOptions.Disabled); + } + + + // helper classes + + private class NamespaceTransformation + : Core.NamespaceTransformation + { + public NamespaceTransformation(AddAttributes parent) + : base(parent) + { } + + public override QsCallable OnCallableDeclaration(QsCallable c) => + c.AddAttributes(SharedState.AttributeSelection + .Where(entry => entry.Item2(c)) + .Select(entry => entry.Item1)); + } + } + } +}