diff --git a/src/QsCompiler/Compiler/RewriteSteps/Monomorphization.cs b/src/QsCompiler/Compiler/RewriteSteps/Monomorphization.cs index 1c4b387e16..7802e9ee4c 100644 --- a/src/QsCompiler/Compiler/RewriteSteps/Monomorphization.cs +++ b/src/QsCompiler/Compiler/RewriteSteps/Monomorphization.cs @@ -14,6 +14,8 @@ namespace Microsoft.Quantum.QsCompiler.BuiltInRewriteSteps /// internal class Monomorphization : IRewriteStep { + private readonly bool keepAllIntrinsics; + public string Name => "Monomorphization"; public int Priority => RewriteStepPriorities.TypeParameterElimination; @@ -28,8 +30,14 @@ internal class Monomorphization : IRewriteStep public bool ImplementsPostconditionVerification => true; - public Monomorphization() + + /// + /// Constructor for the Monomorphization Rewrite Step. + /// + /// When true, intrinsics will not be removed as part of the rewrite step. + public Monomorphization(bool keepAllIntrinsics = true) { + this.keepAllIntrinsics = keepAllIntrinsics; this.AssemblyConstants = new Dictionary(); } @@ -37,7 +45,7 @@ public Monomorphization() public bool Transformation(QsCompilation compilation, out QsCompilation transformed) { - transformed = Monomorphize.Apply(compilation); + transformed = Monomorphize.Apply(compilation, this.keepAllIntrinsics); return true; } diff --git a/src/QsCompiler/Tests.Compiler/CallGraphTests.fs b/src/QsCompiler/Tests.Compiler/CallGraphTests.fs index 64a6ae96a6..dd260aa419 100644 --- a/src/QsCompiler/Tests.Compiler/CallGraphTests.fs +++ b/src/QsCompiler/Tests.Compiler/CallGraphTests.fs @@ -324,7 +324,7 @@ type CallGraphTests(output: ITestOutputHelper) = [] [] - member this.``Concrete Graph Trims Specializations``() = + member this.``Concrete Graph Contains All Specializations``() = let graph = PopulateCallGraphWithExe 10 |> ConcreteCallGraph let makeNode name spec = MakeNode name spec [] @@ -343,8 +343,7 @@ type CallGraphTests(output: ITestOutputHelper) = AssertInConcreteGraph graph BarAdj AssertInConcreteGraph graph BarCtl AssertInConcreteGraph graph BarCtlAdj - - AssertNotInConcreteGraph graph Unused + AssertInConcreteGraph graph Unused [] [] diff --git a/src/QsCompiler/Tests.Compiler/TestCases/PopulateCallGraph.qs b/src/QsCompiler/Tests.Compiler/TestCases/PopulateCallGraph.qs index e492499fb8..4dd3a55203 100644 --- a/src/QsCompiler/Tests.Compiler/TestCases/PopulateCallGraph.qs +++ b/src/QsCompiler/Tests.Compiler/TestCases/PopulateCallGraph.qs @@ -152,7 +152,7 @@ namespace Microsoft.Quantum.Testing.PopulateCallGraph { // ================================= -// Concrete Graph Trims Specializations +// Concrete Graph Contains All Specializations namespace Microsoft.Quantum.Testing.PopulateCallGraph { @ EntryPoint() diff --git a/src/QsCompiler/Tests.Compiler/TestUtils/Signatures.fs b/src/QsCompiler/Tests.Compiler/TestUtils/Signatures.fs index b77788381f..876e2cd53a 100644 --- a/src/QsCompiler/Tests.Compiler/TestUtils/Signatures.fs +++ b/src/QsCompiler/Tests.Compiler/TestUtils/Signatures.fs @@ -167,10 +167,18 @@ let public MonomorphizationSignatures = GenericsNs, "GenericCallsSpecializations", [| "Double"; "String"; "Double" |], "Unit" GenericsNs, "GenericCallsSpecializations", [| "String"; "Int"; "Unit" |], "Unit" + GenericsNs, "BasicGeneric", [| "Double"; "String" |], "Unit" GenericsNs, "BasicGeneric", [| "String"; "Qubit[]" |], "Unit" GenericsNs, "BasicGeneric", [| "String"; "Int" |], "Unit" + GenericsNs, "BasicGeneric", [| "Qubit[]"; "Qubit[]" |], "Unit" + GenericsNs, "BasicGeneric", [| "Qubit[]"; "Double" |], "Unit" + GenericsNs, "BasicGeneric", [| "Qubit[]"; "Unit" |], "Unit" + GenericsNs, "BasicGeneric", [| "String"; "Double" |], "Unit" + GenericsNs, "BasicGeneric", [| "Int"; "Unit" |], "Unit" GenericsNs, "ArrayGeneric", [| "Qubit"; "Double" |], "Int" + GenericsNs, "ArrayGeneric", [| "Qubit"; "Qubit[]" |], "Int" + GenericsNs, "ArrayGeneric", [| "Qubit"; "Unit" |], "Int" |]) // Test Case 4 (_DefaultTypes, diff --git a/src/QsCompiler/Transformations/CallGraph/ConcreteCallGraphWalker.cs b/src/QsCompiler/Transformations/CallGraph/ConcreteCallGraphWalker.cs index c7f8727031..b0df74b2ec 100644 --- a/src/QsCompiler/Transformations/CallGraph/ConcreteCallGraphWalker.cs +++ b/src/QsCompiler/Transformations/CallGraph/ConcreteCallGraphWalker.cs @@ -39,8 +39,11 @@ private static class ConcreteCallGraphWalker /// public static void PopulateConcreteGraph(ConcreteGraphBuilder graph, QsCompilation compilation) { + var globals = compilation.Namespaces.GlobalCallableResolutions(); var walker = new BuildGraph(graph); - var entryPointNodes = compilation.EntryPoints.Select(name => new ConcreteCallGraphNode(name, QsSpecializationKind.QsBody, TypeParameterResolutions.Empty)); + var entryPointNodes = compilation.EntryPoints.SelectMany(name => + GetSpecializationKinds(globals, name).Select(kind => + new ConcreteCallGraphNode(name, kind, TypeParameterResolutions.Empty))); foreach (var entryPoint in entryPointNodes) { // Make sure all the entry points are added to the graph @@ -48,7 +51,6 @@ public static void PopulateConcreteGraph(ConcreteGraphBuilder graph, QsCompilati walker.SharedState.RequestStack.Push(entryPoint); } - var globals = compilation.Namespaces.GlobalCallableResolutions(); while (walker.SharedState.RequestStack.TryPop(out var currentRequest)) { // If there is a call to an unknown callable, throw exception @@ -253,12 +255,21 @@ private void AddEdge(QsQualifiedName identifier, QsSpecializationKind kind, Type throw new ArgumentException("AddEdge requires CurrentNode to be non-null."); } + // Add an edge to the specific specialization kind referenced var called = new ConcreteCallGraphNode(identifier, kind, typeRes); var edge = new ConcreteCallGraphEdge(referenceRange); this.Graph.AddDependency(this.CurrentNode, called, edge); - if (!this.RequestStack.Contains(called) && !this.ResolvedNodeSet.Contains(called)) + + // Add all the specializations of the referenced callable to the graph + var newNodes = this.GetSpecializationKinds(identifier) + .Select(specKind => new ConcreteCallGraphNode(identifier, specKind, typeRes)); + foreach (var node in newNodes) { - this.RequestStack.Push(called); + if (!this.RequestStack.Contains(node) && !this.ResolvedNodeSet.Contains(node)) + { + this.Graph.AddNode(node); + this.RequestStack.Push(node); + } } } } diff --git a/src/QsCompiler/Transformations/Monomorphization.cs b/src/QsCompiler/Transformations/Monomorphization.cs index 662086395e..df9b5284ab 100644 --- a/src/QsCompiler/Transformations/Monomorphization.cs +++ b/src/QsCompiler/Transformations/Monomorphization.cs @@ -25,16 +25,18 @@ namespace Microsoft.Quantum.QsCompiler.Transformations.Monomorphization /// are found from uses of the callables. /// This transformation also removes all callables that are not used directly or /// indirectly from any of the marked entry point. - /// Intrinsic callables are not monomorphized or removed from the compilation. + /// Intrinsic callables are, by default, not monomorphized or removed from the compilation, but + /// may optionally be removed if unused if the keepAllIntrinsics parameter is set to false. /// There are also some built-in callables that are also exempt from /// being removed from non-use, as they are needed for later rewrite steps. /// public static class Monomorphize { /// - /// Performs Monomorphization on the given compilation. + /// Performs Monomorphization on the given compilation. If the keepAllIntrinsics parameter + /// is set to true, then unused intrinsics will not be removed from the resulting compilation. /// - public static QsCompilation Apply(QsCompilation compilation) + public static QsCompilation Apply(QsCompilation compilation, bool keepAllIntrinsics = true) { var globals = compilation.Namespaces.GlobalCallableResolutions(); var concretizations = new List(); @@ -66,7 +68,9 @@ public static QsCompilation Apply(QsCompilation compilation) // Generate the concrete version of the callable var concrete = ReplaceTypeParamImplementations.Apply(originalGlobal, node.ParamResolutions, getAccessModifiers); - concretizations.Add(concrete.WithFullName(oldName => concreteName)); + concretizations.Add( + concrete.WithFullName(oldName => concreteName) + .WithSpecializations(specs => specs.Select(spec => spec.WithParent(_ => concreteName)).ToImmutableArray())); } else { @@ -89,16 +93,16 @@ public static QsCompilation Apply(QsCompilation compilation) final.Add(ReplaceTypeParamCalls.Apply(callable, getConcreteIdentifier, intrinsicCallableSet)); } - return ResolveGenerics.Apply(compilation, final, intrinsicCallableSet); + return ResolveGenerics.Apply(compilation, final, intrinsicCallableSet, keepAllIntrinsics); } #region ResolveGenerics private class ResolveGenerics : SyntaxTreeTransformation { - public static QsCompilation Apply(QsCompilation compilation, List callables, ImmutableHashSet intrinsicCallableSet) + public static QsCompilation Apply(QsCompilation compilation, List callables, ImmutableHashSet intrinsicCallableSet, bool keepAllIntrinsics) { - var filter = new ResolveGenerics(callables.ToLookup(res => res.FullName.Namespace), intrinsicCallableSet); + var filter = new ResolveGenerics(callables.ToLookup(res => res.FullName.Namespace), intrinsicCallableSet, keepAllIntrinsics); return filter.OnCompilation(compilation); } @@ -107,11 +111,13 @@ public class TransformationState { public readonly ILookup NamespaceCallables; public readonly ImmutableHashSet IntrinsicCallableSet; + public readonly bool KeepAllIntrinsics; - public TransformationState(ILookup namespaceCallables, ImmutableHashSet intrinsicCallableSet) + public TransformationState(ILookup namespaceCallables, ImmutableHashSet intrinsicCallableSet, bool keepAllIntrinsics) { this.NamespaceCallables = namespaceCallables; this.IntrinsicCallableSet = intrinsicCallableSet; + this.KeepAllIntrinsics = keepAllIntrinsics; } } @@ -119,8 +125,8 @@ public TransformationState(ILookup namespaceCallables, Immut /// Constructor for the ResolveGenericsSyntax class. Its transform function replaces global callables in the namespace. /// /// Maps namespace names to an enumerable of all global callables in that namespace. - private ResolveGenerics(ILookup namespaceCallables, ImmutableHashSet intrinsicCallableSet) - : base(new TransformationState(namespaceCallables, intrinsicCallableSet)) + private ResolveGenerics(ILookup namespaceCallables, ImmutableHashSet intrinsicCallableSet, bool keepAllIntrinsics) + : base(new TransformationState(namespaceCallables, intrinsicCallableSet, keepAllIntrinsics)) { this.Namespaces = new NamespaceTransformation(this); this.Statements = new StatementTransformation(this, TransformationOptions.Disabled); @@ -138,7 +144,8 @@ private bool NamespaceElementFilter(QsNamespaceElement elem) { if (elem is QsNamespaceElement.QsCallable call) { - return BuiltIn.RewriteStepDependencies.Contains(call.Item.FullName) || this.SharedState.IntrinsicCallableSet.Contains(call.Item.FullName); + return BuiltIn.RewriteStepDependencies.Contains(call.Item.FullName) || + (this.SharedState.KeepAllIntrinsics && this.SharedState.IntrinsicCallableSet.Contains(call.Item.FullName)); } else {