diff --git a/linker/Linker.Steps/MarkStep.cs b/linker/Linker.Steps/MarkStep.cs index fb06bc70c205..a42c5facdbcd 100644 --- a/linker/Linker.Steps/MarkStep.cs +++ b/linker/Linker.Steps/MarkStep.cs @@ -266,7 +266,7 @@ void ProcessOverride (MethodDefinition method, MethodDefinition @base) // We don't need to mark overrides until it is possible that the type could be instantiated // Note : The base type is interface check should be removed once we have base type sweeping - if (!isInstantiated && @base.DeclaringType.IsInterface) + if (@base.DeclaringType.IsInterface && !isInstantiated && !IsInterfaceImplementationMarked (method.DeclaringType, @base.DeclaringType)) return; if (!isInstantiated && !@base.IsAbstract) @@ -276,6 +276,11 @@ void ProcessOverride (MethodDefinition method, MethodDefinition @base) ProcessVirtualMethod (method); } + bool IsInterfaceImplementationMarked (TypeDefinition type, TypeDefinition interfaceType) + { + return type.HasInterface (@interfaceType, out InterfaceImplementation implementation) && Annotations.IsMarked (implementation); + } + void MarkMarshalSpec (IMarshalInfoProvider spec) { if (!spec.HasMarshalInfo) @@ -2004,6 +2009,8 @@ protected virtual void MarkMethodBody (MethodBody body) foreach (Instruction instruction in body.Instructions) MarkInstruction (instruction); + MarkInterfacesNeededByBodyStack (body); + MarkThingsUsedViaReflection (body); PostMarkMethodBody (body); @@ -2011,6 +2018,19 @@ protected virtual void MarkMethodBody (MethodBody body) partial void PostMarkMethodBody (MethodBody body); + void MarkInterfacesNeededByBodyStack (MethodBody body) + { + // If a type could be on the stack in the body and an interface it implements could be on the stack on the body + // then we need to mark that interface implementation. When this occurs it is not safe to remove the interface implementation from the type + // even if the type is never instantiated + var implementations = MethodBodyScanner.GetReferencedInterfaces (_context.Annotations, body); + if (implementations == null) + return; + + foreach (var implementation in implementations) + MarkInterfaceImplementation (implementation); + } + protected virtual void MarkThingsUsedViaReflection (MethodBody body) { MarkSomethingUsedViaReflection ("GetConstructor", MarkConstructorsUsedViaReflection, body.Instructions); diff --git a/linker/Linker.Steps/TypeMapStep.cs b/linker/Linker.Steps/TypeMapStep.cs index 44e8823301d2..4189e65a879b 100644 --- a/linker/Linker.Steps/TypeMapStep.cs +++ b/linker/Linker.Steps/TypeMapStep.cs @@ -44,6 +44,7 @@ protected virtual void MapType (TypeDefinition type) { MapVirtualMethods (type); MapInterfaceMethodsInTypeHierarchy (type); + MapBaseTypeHierarchy (type); if (!type.HasNestedTypes) return; @@ -123,6 +124,30 @@ void MapOverrides (MethodDefinition method) } } + void MapBaseTypeHierarchy (TypeDefinition type) + { + if (!type.IsClass) + return; + + var bases = new List (); + var current = type.BaseType; + + while (current != null) { + var resolved = current.Resolve (); + if (resolved == null) + break; + + // Exclude Object. That's implied and adding it to the list will just lead to lots of extra unnecessary processing + if (resolved.BaseType == null) + break; + + bases.Add (resolved); + current = resolved.BaseType; + } + + Annotations.SetClassHierarchy (type, bases); + } + void AnnotateMethods (MethodDefinition @base, MethodDefinition @override) { Annotations.AddBaseMethod (@override, @base); diff --git a/linker/Linker/Annotations.cs b/linker/Linker/Annotations.cs index 80b90e0b5bc1..488ce9cc46f1 100644 --- a/linker/Linker/Annotations.cs +++ b/linker/Linker/Annotations.cs @@ -48,6 +48,7 @@ public partial class AnnotationStore { protected readonly Dictionary> override_methods = new Dictionary> (); protected readonly Dictionary> base_methods = new Dictionary> (); protected readonly Dictionary symbol_readers = new Dictionary (); + protected readonly Dictionary> class_type_base_hierarchy = new Dictionary> (); protected readonly Dictionary> custom_annotations = new Dictionary> (); protected readonly Dictionary> resources_to_remove = new Dictionary> (); @@ -355,5 +356,17 @@ public bool SetPreservedStaticCtor (TypeDefinition type) return marked_types_with_cctor.Add (type); } + public void SetClassHierarchy (TypeDefinition type, List bases) + { + class_type_base_hierarchy [type] = bases; + } + + public List GetClassHierarchy (TypeDefinition type) + { + if (class_type_base_hierarchy.TryGetValue (type, out List bases)) + return bases; + + return null; + } } } diff --git a/linker/Linker/MethodBodyScanner.cs b/linker/Linker/MethodBodyScanner.cs new file mode 100644 index 000000000000..f875ad16282a --- /dev/null +++ b/linker/Linker/MethodBodyScanner.cs @@ -0,0 +1,119 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using Mono.Cecil; +using Mono.Cecil.Cil; + +namespace Mono.Linker { + public static class MethodBodyScanner { + public static IEnumerable GetReferencedInterfaces (AnnotationStore annotations, MethodBody body) + { + var possibleStackTypes = AllPossibleStackTypes (body.Method); + if (possibleStackTypes.Count == 0) + return null; + + var interfaceTypes = possibleStackTypes.Where (t => t.IsInterface).ToArray (); + if (interfaceTypes.Length == 0) + return null; + + var interfaceImplementations = new HashSet (); + + // If a type could be on the stack in the body and an interface it implements could be on the stack on the body + // then we need to mark that interface implementation. When this occurs it is not safe to remove the interface implementation from the type + // even if the type is never instantiated + foreach (var type in possibleStackTypes) { + // We only sweep interfaces on classes so that's why we only care about classes + if (!type.IsClass) + continue; + + AddMatchingInterfaces (interfaceImplementations, type, interfaceTypes); + var bases = annotations.GetClassHierarchy (type); + foreach (var @base in bases) { + AddMatchingInterfaces (interfaceImplementations, @base, interfaceTypes); + } + } + + return interfaceImplementations; + } + + static HashSet AllPossibleStackTypes (MethodDefinition method) + { + if (!method.HasBody) + throw new ArgumentException(); + + var body = method.Body; + var types = new HashSet (); + + foreach (VariableDefinition var in body.Variables) + AddIfResolved (types, var.VariableType); + + foreach (ExceptionHandler eh in body.ExceptionHandlers) { + if (eh.HandlerType == ExceptionHandlerType.Catch) { + AddIfResolved (types, eh.CatchType); + } + } + + foreach (Instruction instruction in body.Instructions) { + if (instruction.Operand is FieldReference fieldReference) { + AddIfResolved (types, fieldReference.Resolve ()?.FieldType); + } else if (instruction.Operand is MethodReference methodReference) { + if (methodReference is GenericInstanceMethod genericInstanceMethod) + AddFromGenericInstance (types, genericInstanceMethod); + + if (methodReference.DeclaringType is GenericInstanceType genericInstanceType) + AddFromGenericInstance (types, genericInstanceType); + + var resolvedMethod = methodReference.Resolve (); + if (resolvedMethod != null) { + if (resolvedMethod.HasParameters) { + foreach (var param in resolvedMethod.Parameters) + AddIfResolved (types, param.ParameterType); + } + + AddFromGenericParameterProvider (types, resolvedMethod); + AddFromGenericParameterProvider (types, resolvedMethod.DeclaringType); + AddIfResolved (types, resolvedMethod.ReturnType); + } + } + } + + return types; + } + + static void AddMatchingInterfaces (HashSet results, TypeDefinition type, TypeDefinition [] interfaceTypes) + { + foreach (var interfaceType in interfaceTypes) { + if (type.HasInterface (interfaceType, out InterfaceImplementation implementation)) + results.Add (implementation); + } + } + + static void AddFromGenericInstance (HashSet set, IGenericInstance instance) + { + if (!instance.HasGenericArguments) + return; + + foreach (var genericArgument in instance.GenericArguments) + AddIfResolved (set, genericArgument); + } + + static void AddFromGenericParameterProvider (HashSet set, IGenericParameterProvider provider) + { + if (!provider.HasGenericParameters) + return; + + foreach (var genericParameter in provider.GenericParameters) { + foreach (var constraint in genericParameter.Constraints) + AddIfResolved (set, constraint); + } + } + + static void AddIfResolved (HashSet set, TypeReference item) + { + var resolved = item.Resolve (); + if (resolved == null) + return; + set.Add (resolved); + } + } +} \ No newline at end of file diff --git a/linker/Linker/TypeDefinitionExtensions.cs b/linker/Linker/TypeDefinitionExtensions.cs new file mode 100644 index 000000000000..f94e340c3e9e --- /dev/null +++ b/linker/Linker/TypeDefinitionExtensions.cs @@ -0,0 +1,21 @@ +using Mono.Cecil; + +namespace Mono.Linker { + public static class TypeDefinitionExtensions { + public static bool HasInterface (this TypeDefinition type, TypeDefinition interfaceType, out InterfaceImplementation implementation) + { + implementation = null; + if (!type.HasInterfaces) + return false; + + foreach (var iface in type.Interfaces) { + if (iface.InterfaceType.Resolve () == interfaceType) { + implementation = iface; + return true; + } + } + + return false; + } + } +} \ No newline at end of file diff --git a/linker/Mono.Linker.csproj b/linker/Mono.Linker.csproj index cd547ca51f7e..2b82f38387ef 100644 --- a/linker/Mono.Linker.csproj +++ b/linker/Mono.Linker.csproj @@ -1,4 +1,4 @@ - +