Skip to content
This repository was archived by the owner on Jan 12, 2024. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions src/QsCompiler/Compiler/RewriteSteps/Monomorphization.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ namespace Microsoft.Quantum.QsCompiler.BuiltInRewriteSteps
/// </summary>
internal class Monomorphization : IRewriteStep
{
private readonly bool keepAllIntrinsics;

public string Name => "Monomorphization";

public int Priority => RewriteStepPriorities.TypeParameterElimination;
Expand All @@ -28,16 +30,22 @@ internal class Monomorphization : IRewriteStep

public bool ImplementsPostconditionVerification => true;

public Monomorphization()

/// <summary>
/// Constructor for the Monomorphization Rewrite Step.
/// </summary>
/// <param name="keepAllIntrinsics">When true, intrinsics will not be removed as part of the rewrite step.</param>
public Monomorphization(bool keepAllIntrinsics = true)
{
this.keepAllIntrinsics = keepAllIntrinsics;
this.AssemblyConstants = new Dictionary<string, string?>();
}

public bool PreconditionVerification(QsCompilation compilation) => compilation.EntryPoints.Any();

public bool Transformation(QsCompilation compilation, out QsCompilation transformed)
{
transformed = Monomorphize.Apply(compilation);
transformed = Monomorphize.Apply(compilation, this.keepAllIntrinsics);
return true;
}

Expand Down
5 changes: 2 additions & 3 deletions src/QsCompiler/Tests.Compiler/CallGraphTests.fs
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,7 @@ type CallGraphTests(output: ITestOutputHelper) =

[<Fact>]
[<Trait("Category", "Populate Call Graph")>]
member this.``Concrete Graph Trims Specializations``() =
member this.``Concrete Graph Contains All Specializations``() =
let graph = PopulateCallGraphWithExe 10 |> ConcreteCallGraph

let makeNode name spec = MakeNode name spec []
Expand All @@ -343,8 +343,7 @@ type CallGraphTests(output: ITestOutputHelper) =
AssertInConcreteGraph graph BarAdj
AssertInConcreteGraph graph BarCtl
AssertInConcreteGraph graph BarCtlAdj

AssertNotInConcreteGraph graph Unused
AssertInConcreteGraph graph Unused

[<Fact(Skip = "Double reference resolution is not yet supported")>]
[<Trait("Category", "Populate Call Graph")>]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ namespace Microsoft.Quantum.Testing.PopulateCallGraph {

// =================================

// Concrete Graph Trims Specializations
// Concrete Graph Contains All Specializations
namespace Microsoft.Quantum.Testing.PopulateCallGraph {

@ EntryPoint()
Expand Down
8 changes: 8 additions & 0 deletions src/QsCompiler/Tests.Compiler/TestUtils/Signatures.fs
Original file line number Diff line number Diff line change
Expand Up @@ -167,10 +167,18 @@ let public MonomorphizationSignatures =
GenericsNs, "GenericCallsSpecializations", [| "Double"; "String"; "Double" |], "Unit"
GenericsNs, "GenericCallsSpecializations", [| "String"; "Int"; "Unit" |], "Unit"

GenericsNs, "BasicGeneric", [| "Double"; "String" |], "Unit"
GenericsNs, "BasicGeneric", [| "String"; "Qubit[]" |], "Unit"
GenericsNs, "BasicGeneric", [| "String"; "Int" |], "Unit"
GenericsNs, "BasicGeneric", [| "Qubit[]"; "Qubit[]" |], "Unit"
GenericsNs, "BasicGeneric", [| "Qubit[]"; "Double" |], "Unit"
GenericsNs, "BasicGeneric", [| "Qubit[]"; "Unit" |], "Unit"
GenericsNs, "BasicGeneric", [| "String"; "Double" |], "Unit"
GenericsNs, "BasicGeneric", [| "Int"; "Unit" |], "Unit"

GenericsNs, "ArrayGeneric", [| "Qubit"; "Double" |], "Int"
GenericsNs, "ArrayGeneric", [| "Qubit"; "Qubit[]" |], "Int"
GenericsNs, "ArrayGeneric", [| "Qubit"; "Unit" |], "Int"
|])
// Test Case 4
(_DefaultTypes,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,16 +39,18 @@ private static class ConcreteCallGraphWalker
/// </summary>
public static void PopulateConcreteGraph(ConcreteGraphBuilder graph, QsCompilation compilation)
{
var globals = compilation.Namespaces.GlobalCallableResolutions();
var walker = new BuildGraph(graph);
var entryPointNodes = compilation.EntryPoints.Select(name => new ConcreteCallGraphNode(name, QsSpecializationKind.QsBody, TypeParameterResolutions.Empty));
var entryPointNodes = compilation.EntryPoints.SelectMany(name =>
GetSpecializationKinds(globals, name).Select(kind =>
new ConcreteCallGraphNode(name, kind, TypeParameterResolutions.Empty)));
foreach (var entryPoint in entryPointNodes)
{
// Make sure all the entry points are added to the graph
walker.SharedState.Graph.AddNode(entryPoint);
walker.SharedState.RequestStack.Push(entryPoint);
}

var globals = compilation.Namespaces.GlobalCallableResolutions();
while (walker.SharedState.RequestStack.TryPop(out var currentRequest))
{
// If there is a call to an unknown callable, throw exception
Expand Down Expand Up @@ -253,12 +255,21 @@ private void AddEdge(QsQualifiedName identifier, QsSpecializationKind kind, Type
throw new ArgumentException("AddEdge requires CurrentNode to be non-null.");
}

// Add an edge to the specific specialization kind referenced
var called = new ConcreteCallGraphNode(identifier, kind, typeRes);
var edge = new ConcreteCallGraphEdge(referenceRange);
this.Graph.AddDependency(this.CurrentNode, called, edge);
if (!this.RequestStack.Contains(called) && !this.ResolvedNodeSet.Contains(called))

// Add all the specializations of the referenced callable to the graph
var newNodes = this.GetSpecializationKinds(identifier)
.Select(specKind => new ConcreteCallGraphNode(identifier, specKind, typeRes));
foreach (var node in newNodes)
{
this.RequestStack.Push(called);
if (!this.RequestStack.Contains(node) && !this.ResolvedNodeSet.Contains(node))
{
this.Graph.AddNode(node);
this.RequestStack.Push(node);
}
}
}
}
Expand Down
29 changes: 18 additions & 11 deletions src/QsCompiler/Transformations/Monomorphization.cs
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,18 @@ namespace Microsoft.Quantum.QsCompiler.Transformations.Monomorphization
/// are found from uses of the callables.
/// This transformation also removes all callables that are not used directly or
/// indirectly from any of the marked entry point.
/// Intrinsic callables are not monomorphized or removed from the compilation.
/// Intrinsic callables are, by default, not monomorphized or removed from the compilation, but
/// may optionally be removed if unused if the keepAllIntrinsics parameter is set to false.
/// There are also some built-in callables that are also exempt from
/// being removed from non-use, as they are needed for later rewrite steps.
/// </summary>
public static class Monomorphize
{
/// <summary>
/// Performs Monomorphization on the given compilation.
/// Performs Monomorphization on the given compilation. If the keepAllIntrinsics parameter
/// is set to true, then unused intrinsics will not be removed from the resulting compilation.
/// </summary>
public static QsCompilation Apply(QsCompilation compilation)
public static QsCompilation Apply(QsCompilation compilation, bool keepAllIntrinsics = true)
{
var globals = compilation.Namespaces.GlobalCallableResolutions();
var concretizations = new List<QsCallable>();
Expand Down Expand Up @@ -66,7 +68,9 @@ public static QsCompilation Apply(QsCompilation compilation)

// Generate the concrete version of the callable
var concrete = ReplaceTypeParamImplementations.Apply(originalGlobal, node.ParamResolutions, getAccessModifiers);
concretizations.Add(concrete.WithFullName(oldName => concreteName));
concretizations.Add(
concrete.WithFullName(oldName => concreteName)
.WithSpecializations(specs => specs.Select(spec => spec.WithParent(_ => concreteName)).ToImmutableArray()));
}
else
{
Expand All @@ -89,16 +93,16 @@ public static QsCompilation Apply(QsCompilation compilation)
final.Add(ReplaceTypeParamCalls.Apply(callable, getConcreteIdentifier, intrinsicCallableSet));
}

return ResolveGenerics.Apply(compilation, final, intrinsicCallableSet);
return ResolveGenerics.Apply(compilation, final, intrinsicCallableSet, keepAllIntrinsics);
}

#region ResolveGenerics

private class ResolveGenerics : SyntaxTreeTransformation<ResolveGenerics.TransformationState>
{
public static QsCompilation Apply(QsCompilation compilation, List<QsCallable> callables, ImmutableHashSet<QsQualifiedName> intrinsicCallableSet)
public static QsCompilation Apply(QsCompilation compilation, List<QsCallable> callables, ImmutableHashSet<QsQualifiedName> intrinsicCallableSet, bool keepAllIntrinsics)
{
var filter = new ResolveGenerics(callables.ToLookup(res => res.FullName.Namespace), intrinsicCallableSet);
var filter = new ResolveGenerics(callables.ToLookup(res => res.FullName.Namespace), intrinsicCallableSet, keepAllIntrinsics);

return filter.OnCompilation(compilation);
}
Expand All @@ -107,20 +111,22 @@ public class TransformationState
{
public readonly ILookup<string, QsCallable> NamespaceCallables;
public readonly ImmutableHashSet<QsQualifiedName> IntrinsicCallableSet;
public readonly bool KeepAllIntrinsics;

public TransformationState(ILookup<string, QsCallable> namespaceCallables, ImmutableHashSet<QsQualifiedName> intrinsicCallableSet)
public TransformationState(ILookup<string, QsCallable> namespaceCallables, ImmutableHashSet<QsQualifiedName> intrinsicCallableSet, bool keepAllIntrinsics)
{
this.NamespaceCallables = namespaceCallables;
this.IntrinsicCallableSet = intrinsicCallableSet;
this.KeepAllIntrinsics = keepAllIntrinsics;
}
}

/// <summary>
/// Constructor for the ResolveGenericsSyntax class. Its transform function replaces global callables in the namespace.
/// </summary>
/// <param name="namespaceCallables">Maps namespace names to an enumerable of all global callables in that namespace.</param>
private ResolveGenerics(ILookup<string, QsCallable> namespaceCallables, ImmutableHashSet<QsQualifiedName> intrinsicCallableSet)
: base(new TransformationState(namespaceCallables, intrinsicCallableSet))
private ResolveGenerics(ILookup<string, QsCallable> namespaceCallables, ImmutableHashSet<QsQualifiedName> intrinsicCallableSet, bool keepAllIntrinsics)
: base(new TransformationState(namespaceCallables, intrinsicCallableSet, keepAllIntrinsics))
{
this.Namespaces = new NamespaceTransformation(this);
this.Statements = new StatementTransformation<TransformationState>(this, TransformationOptions.Disabled);
Expand All @@ -138,7 +144,8 @@ private bool NamespaceElementFilter(QsNamespaceElement elem)
{
if (elem is QsNamespaceElement.QsCallable call)
{
return BuiltIn.RewriteStepDependencies.Contains(call.Item.FullName) || this.SharedState.IntrinsicCallableSet.Contains(call.Item.FullName);
return BuiltIn.RewriteStepDependencies.Contains(call.Item.FullName) ||
(this.SharedState.KeepAllIntrinsics && this.SharedState.IntrinsicCallableSet.Contains(call.Item.FullName));
}
else
{
Expand Down