diff --git a/src/linker/Linker.Dataflow/MethodBodyScanner.cs b/src/linker/Linker.Dataflow/MethodBodyScanner.cs
new file mode 100644
index 000000000000..b1efab6a60bd
--- /dev/null
+++ b/src/linker/Linker.Dataflow/MethodBodyScanner.cs
@@ -0,0 +1,781 @@
+using System;
+using System.Collections.Generic;
+using System.Linq;
+
+using Mono.Cecil;
+using Mono.Cecil.Cil;
+using Mono.Collections.Generic;
+
+namespace Mono.Linker.Dataflow
+{
+	/// <summary>
+	/// Tracks information about the contents of a stack slot.
+	/// A slot whose <see cref="Value"/> is null stands for an unknown value
+	/// (this is what PushUnknown pushes).
+	/// </summary>
+	class StackSlot
+	{
+		public ValueNode Value { get; set; }
+
+		/// <summary>
+		/// True if the value is on the stack as a byref
+		/// </summary>
+		public bool IsByRef { get; set; }
+
+		public StackSlot ()
+		{
+		}
+
+		public StackSlot (ValueNode value, bool isByRef = false) : this ()
+		{
+			Value = value;
+			IsByRef = isByRef;
+		}
+	}
+
+	/// <summary>
+	/// Walks the instructions of a method body once, in order, simulating the IL evaluation
+	/// stack and the method locals with <see cref="ValueNode"/>s. Calls are delegated to the
+	/// abstract HandleCall overload implemented by derived scanners.
+	/// </summary>
+	abstract partial class MethodBodyScanner
+	{
+		/// <summary>
+		/// Merge of every value seen at a 'ret' during the last <see cref="Scan"/>; null if none.
+		/// </summary>
+		internal ValueNode MethodReturnValue { private set; get; }
+
+		/// <summary>
+		/// Hook invoked when the scanner meets IL it considers malformed. No-op by default.
+		/// </summary>
+		protected virtual void WarnAboutInvalidILInMethod (MethodBody method, int ilOffset)
+		{
+		}
+
+		/// <summary>
+		/// Ensures the stack holds at least <paramref name="depthRequired"/> slots. If it does
+		/// not, warns once and pads the stack with unknown slots.
+		/// </summary>
+		private void CheckForInvalidStack (Stack<StackSlot> stack, int depthRequired, MethodBody method, int ilOffset)
+		{
+			if (stack.Count >= depthRequired)
+				return;
+
+			WarnAboutInvalidILInMethod (method, ilOffset);
+
+			// Push dummy values to avoid crashes.
+			// Analysis of this method will be incorrect.
+			while (stack.Count < depthRequired)
+				stack.Push (new StackSlot ());
+		}
+
+		/// <summary>
+		/// Warns if the stack depth at a 'ret' is not 0 for a void method or 1 otherwise.
+		/// </summary>
+		private void CheckForInvalidReturnStack (Stack<StackSlot> stack, MethodBody method, int ilOffset)
+		{
+			int numExpectedValuesOnStack = method.Method.ReturnType.MetadataType == MetadataType.Void ? 0 : 1;
+			if (stack.Count != numExpectedValuesOnStack) {
+				WarnAboutInvalidILInMethod (method, ilOffset);
+			}
+		}
+
+		private static void PushUnknown (Stack<StackSlot> stack)
+		{
+			stack.Push (new StackSlot ());
+		}
+
+		// Note: 'invalidateBody' is currently not used by this method.
+		private void PushUnknownAndWarnAboutInvalidIL (Stack<StackSlot> stack, MethodBody methodBody, int offset, bool invalidateBody)
+		{
+			WarnAboutInvalidILInMethod (methodBody, offset);
+			PushUnknown (stack);
+		}
+
+		/// <summary>
+		/// Pops <paramref name="count"/> slots (padding the stack first if it is too shallow)
+		/// and returns the slot that was on top, or null when <paramref name="count"/> is 0.
+		/// </summary>
+		private StackSlot PopUnknown (Stack<StackSlot> stack, int count, MethodBody method, int ilOffset)
+		{
+			StackSlot topOfStack = null;
+			CheckForInvalidStack (stack, count, method, ilOffset);
+
+			for (int i = 0; i < count; ++i) {
+				StackSlot slot = stack.Pop ();
+				if (i == 0)
+					topOfStack = slot;
+			}
+			return topOfStack;
+		}
+
+		/// <summary>
+		/// Merges two slots: an unknown (null-valued) slot yields the other slot as is,
+		/// otherwise a new slot holding the merge of both values is created.
+		/// </summary>
+		private static StackSlot MergeStackElement (StackSlot a, StackSlot b)
+		{
+			if (b.Value == null)
+				return a;
+
+			if (a.Value == null)
+				return b;
+
+			return new StackSlot (MergePointValue.MergeValues (a.Value, b.Value));
+		}
+
+		/// <summary>
+		/// Merges two stacks slot by slot into a new stack. Neither input is modified, so the
+		/// caller's current stack and the recorded known stacks stay intact.
+		/// The two stacks are expected to have the same depth; if they do not, a warning is
+		/// raised and the shorter one is treated as if it had unknown slots on top.
+		/// </summary>
+		private Stack<StackSlot> MergeStack (Stack<StackSlot> a, Stack<StackSlot> b, MethodBody method, int ilOffset)
+		{
+			if (a.Count != b.Count) {
+				// Analysis of this method will be incorrect.
+				WarnAboutInvalidILInMethod (method, ilOffset);
+			}
+
+			// ToArray returns the slots in pop order, i.e. top of the stack first.
+			StackSlot [] aSlots = a.ToArray ();
+			StackSlot [] bSlots = b.ToArray ();
+			int depth = Math.Max (aSlots.Length, bSlots.Length);
+
+			Stack<StackSlot> merged = new Stack<StackSlot> (depth);
+
+			// Walk from the bottom up so that pushing rebuilds the original order.
+			for (int fromBottom = 0; fromBottom < depth; ++fromBottom) {
+				StackSlot aSlot = fromBottom < aSlots.Length ? aSlots [aSlots.Length - 1 - fromBottom] : new StackSlot ();
+				StackSlot bSlot = fromBottom < bSlots.Length ? bSlots [bSlots.Length - 1 - fromBottom] : new StackSlot ();
+				merged.Push (MergeStackElement (aSlot, bSlot));
+			}
+
+			return merged;
+		}
+
+		private static void ClearStack (ref Stack<StackSlot> stack)
+		{
+			stack = null;
+		}
+
+		/// <summary>
+		/// Records <paramref name="newStack"/> as the stack expected at <paramref name="newOffset"/>,
+		/// merging it with a stack already recorded for that offset.
+		/// </summary>
+		private void NewKnownStack (Dictionary<int, Stack<StackSlot>> knownStacks, int newOffset, Stack<StackSlot> newStack, MethodBody method)
+		{
+			// No need to merge in empty stacks
+			if (newStack.Count == 0)
+				return;
+
+			if (knownStacks.TryGetValue (newOffset, out Stack<StackSlot> existingStack)) {
+				knownStacks [newOffset] = MergeStack (existingStack, newStack, method, newOffset);
+			} else {
+				// Store a private copy. Enumerating a stack yields the top first and the copy
+				// constructor pushes in enumeration order, hence the Reverse.
+				knownStacks.Add (newOffset, new Stack<StackSlot> (newStack.Reverse ()));
+			}
+		}
+
+		/// <summary>
+		/// Assigns consecutive basic block indices to instructions fed to <see cref="MoveNext"/>
+		/// in order. A new block starts at a branch target and after a control flow instruction.
+		/// </summary>
+		private struct BasicBlockIterator
+		{
+			HashSet<int> _methodBranchTargets;
+			int _currentBlockIndex;
+			bool _foundEndOfPrevBlock;
+
+			public BasicBlockIterator (MethodBody methodBody)
+			{
+				_methodBranchTargets = methodBody.ComputeBranchTargets ();
+				_currentBlockIndex = -1;
+				_foundEndOfPrevBlock = true;
+			}
+
+			public int CurrentBlockIndex {
+				get {
+					return _currentBlockIndex;
+				}
+			}
+
+			public int MoveNext (Instruction op)
+			{
+				if (_foundEndOfPrevBlock || _methodBranchTargets.Contains (op.Offset)) {
+					_currentBlockIndex++;
+					_foundEndOfPrevBlock = false;
+				}
+
+				if (op.OpCode.IsControlFlowInstruction ()) {
+					_foundEndOfPrevBlock = true;
+				}
+
+				return CurrentBlockIndex;
+			}
+		}
+
+		/// <summary>
+		/// A value together with the index of the basic block in which it was stored.
+		/// </summary>
+		public struct ValueBasicBlockPair
+		{
+			public ValueNode Value;
+			public int BasicBlockIndex;
+		}
+
+		private void StoreMethodLocalValue<KeyType> (
+			Dictionary<KeyType, ValueBasicBlockPair> valueCollection,
+			ValueNode valueToStore,
+			KeyType collectionKey,
+			int curBasicBlock)
+		{
+			ValueBasicBlockPair newValue = new ValueBasicBlockPair { BasicBlockIndex = curBasicBlock };
+
+			ValueBasicBlockPair existingValue;
+			if (valueCollection.TryGetValue (collectionKey, out existingValue)
+				&& existingValue.BasicBlockIndex == curBasicBlock) {
+				// If the previous value was stored in the current basic block, then we can safely
+				// overwrite the previous value with the new one.
+				newValue.Value = valueToStore;
+			} else {
+				// If the previous value came from a previous basic block, then some other use of
+				// the local could see the previous value, so we must merge the new value with the
+				// old value. (When there is no previous value, existingValue.Value is null.)
+				newValue.Value = MergePointValue.MergeValues (existingValue.Value, valueToStore);
+			}
+			valueCollection [collectionKey] = newValue;
+		}
+
+		/// <summary>
+		/// Scans the method body in instruction order, updating the simulated stack and locals
+		/// for each opcode and collecting <see cref="MethodReturnValue"/>.
+		/// </summary>
+		public void Scan (MethodBody methodBody)
+		{
+			MethodDefinition thisMethod = methodBody.Method;
+
+			Dictionary<VariableDefinition, ValueBasicBlockPair> locals = new Dictionary<VariableDefinition, ValueBasicBlockPair> (methodBody.Variables.Count);
+
+			// Stacks expected at branch targets and handler starts, keyed by IL offset.
+			Dictionary<int, Stack<StackSlot>> knownStacks = new Dictionary<int, Stack<StackSlot>> ();
+			// A null currentStack means the previous instruction does not fall through.
+			Stack<StackSlot> currentStack = new Stack<StackSlot> (methodBody.MaxStackSize);
+
+			ScanExceptionInformation (knownStacks, methodBody);
+
+			BasicBlockIterator blockIterator = new BasicBlockIterator (methodBody);
+
+			MethodReturnValue = null;
+			foreach (Instruction operation in methodBody.Instructions) {
+				int curBasicBlock = blockIterator.MoveNext (operation);
+
+				if (knownStacks.TryGetValue (operation.Offset, out Stack<StackSlot> knownStack)) {
+					if (currentStack == null) {
+						// The stack copy constructor reverses the stack
+						currentStack = new Stack<StackSlot> (knownStack.Reverse ());
+					} else {
+						currentStack = MergeStack (currentStack, knownStack, methodBody, operation.Offset);
+					}
+				}
+
+				if (currentStack == null) {
+					currentStack = new Stack<StackSlot> (methodBody.MaxStackSize);
+				}
+
+				switch (operation.OpCode.Code) {
+				case Code.Add:
+				case Code.Add_Ovf:
+				case Code.Add_Ovf_Un:
+				case Code.And:
+				case Code.Div:
+				case Code.Div_Un:
+				case Code.Mul:
+				case Code.Mul_Ovf:
+				case Code.Mul_Ovf_Un:
+				case Code.Or:
+				case Code.Rem:
+				case Code.Rem_Un:
+				case Code.Sub:
+				case Code.Sub_Ovf:
+				case Code.Sub_Ovf_Un:
+				case Code.Xor:
+				case Code.Cgt:
+				case Code.Cgt_Un:
+				case Code.Clt:
+				case Code.Clt_Un:
+				case Code.Ldelem_I:
+				case Code.Ldelem_I1:
+				case Code.Ldelem_I2:
+				case Code.Ldelem_I4:
+				case Code.Ldelem_I8:
+				case Code.Ldelem_R4:
+				case Code.Ldelem_R8:
+				case Code.Ldelem_U1:
+				case Code.Ldelem_U2:
+				case Code.Ldelem_U4:
+				case Code.Shl:
+				case Code.Shr:
+				case Code.Shr_Un:
+				case Code.Ldelem_Any:
+				case Code.Ldelem_Ref:
+				case Code.Ldelema:
+				case Code.Ceq:
+					PopUnknown (currentStack, 2, methodBody, operation.Offset);
+					PushUnknown (currentStack);
+					break;
+
+				case Code.Dup: {
+					// Go through PopUnknown so an empty stack (malformed IL) produces a
+					// warning instead of an exception. The same slot ends up on the stack twice.
+					StackSlot duplicated = PopUnknown (currentStack, 1, methodBody, operation.Offset);
+					currentStack.Push (duplicated);
+					currentStack.Push (duplicated);
+					break;
+				}
+
+				case Code.Ldnull:
+					currentStack.Push (new StackSlot (NullValue.Instance));
+					break;
+
+				case Code.Ldc_I4_0:
+				case Code.Ldc_I4_1:
+				case Code.Ldc_I4_2:
+				case Code.Ldc_I4_3:
+				case Code.Ldc_I4_4:
+				case Code.Ldc_I4_5:
+				case Code.Ldc_I4_6:
+				case Code.Ldc_I4_7:
+				case Code.Ldc_I4_8: {
+					// The constant is encoded in the opcode itself.
+					int value = operation.OpCode.Code - Code.Ldc_I4_0;
+					currentStack.Push (new StackSlot (new ConstIntValue (value)));
+					break;
+				}
+
+				case Code.Ldc_I4_M1:
+					currentStack.Push (new StackSlot (new ConstIntValue (-1)));
+					break;
+
+				case Code.Ldc_I4: {
+					int value = (int)operation.Operand;
+					currentStack.Push (new StackSlot (new ConstIntValue (value)));
+					break;
+				}
+
+				case Code.Ldc_I4_S: {
+					int value = (sbyte)operation.Operand;
+					currentStack.Push (new StackSlot (new ConstIntValue (value)));
+					break;
+				}
+
+				case Code.Arglist:
+				case Code.Ldftn:
+				case Code.Sizeof:
+				case Code.Ldc_I8:
+				case Code.Ldc_R4:
+				case Code.Ldc_R8:
+				case Code.Ldsfld:
+				case Code.Ldsflda:
+					PushUnknown (currentStack);
+					break;
+
+				case Code.Ldarg:
+				case Code.Ldarg_0:
+				case Code.Ldarg_1:
+				case Code.Ldarg_2:
+				case Code.Ldarg_3:
+				case Code.Ldarg_S:
+				case Code.Ldarga:
+				case Code.Ldarga_S:
+					ScanLdarg (operation, currentStack, thisMethod, methodBody);
+					break;
+
+				case Code.Ldloc:
+				case Code.Ldloc_0:
+				case Code.Ldloc_1:
+				case Code.Ldloc_2:
+				case Code.Ldloc_3:
+				case Code.Ldloc_S:
+				case Code.Ldloca:
+				case Code.Ldloca_S:
+					ScanLdloc (operation, currentStack, thisMethod, methodBody, locals);
+					break;
+
+				case Code.Ldstr:
+					currentStack.Push (new StackSlot (new KnownStringValue ((string)operation.Operand)));
+					break;
+
+				case Code.Ldtoken:
+					ScanLdtoken (operation, currentStack, thisMethod, methodBody);
+					break;
+
+				case Code.Ldind_I:
+				case Code.Ldind_I1:
+				case Code.Ldind_I2:
+				case Code.Ldind_I4:
+				case Code.Ldind_I8:
+				case Code.Ldind_R4:
+				case Code.Ldind_R8:
+				case Code.Ldind_U1:
+				case Code.Ldind_U2:
+				case Code.Ldind_U4:
+				case Code.Ldlen:
+				case Code.Ldvirtftn:
+				case Code.Localloc:
+				case Code.Refanytype:
+				case Code.Refanyval:
+				case Code.Conv_I1:
+				case Code.Conv_I2:
+				case Code.Conv_I4:
+				case Code.Conv_Ovf_I1:
+				case Code.Conv_Ovf_I1_Un:
+				case Code.Conv_Ovf_I2:
+				case Code.Conv_Ovf_I2_Un:
+				case Code.Conv_Ovf_I4:
+				case Code.Conv_Ovf_I4_Un:
+				case Code.Conv_Ovf_U:
+				case Code.Conv_Ovf_U_Un:
+				case Code.Conv_Ovf_U1:
+				case Code.Conv_Ovf_U1_Un:
+				case Code.Conv_Ovf_U2:
+				case Code.Conv_Ovf_U2_Un:
+				case Code.Conv_Ovf_U4:
+				case Code.Conv_Ovf_U4_Un:
+				case Code.Conv_U1:
+				case Code.Conv_U2:
+				case Code.Conv_U4:
+				case Code.Conv_I8:
+				case Code.Conv_Ovf_I8:
+				case Code.Conv_Ovf_I8_Un:
+				case Code.Conv_Ovf_U8:
+				case Code.Conv_Ovf_U8_Un:
+				case Code.Conv_U8:
+				case Code.Conv_I:
+				case Code.Conv_Ovf_I:
+				case Code.Conv_Ovf_I_Un:
+				case Code.Conv_U:
+				case Code.Conv_R_Un:
+				case Code.Conv_R4:
+				case Code.Conv_R8:
+				case Code.Ldind_Ref:
+				case Code.Ldobj:
+				case Code.Mkrefany:
+				case Code.Unbox:
+				case Code.Unbox_Any:
+				case Code.Box:
+				case Code.Neg:
+				case Code.Not:
+					PopUnknown (currentStack, 1, methodBody, operation.Offset);
+					PushUnknown (currentStack);
+					break;
+
+				case Code.Isinst:
+				case Code.Castclass:
+					// We can consider a NOP because the value doesn't change.
+					// It might change to NULL, but for the purposes of dataflow analysis
+					// it doesn't hurt much.
+					break;
+
+				case Code.Ldfld:
+				case Code.Ldflda:
+					// TODO: model field loads
+					PopUnknown (currentStack, 1, methodBody, operation.Offset);
+					PushUnknown (currentStack);
+					break;
+
+				case Code.Newarr: {
+					StackSlot count = PopUnknown (currentStack, 1, methodBody, operation.Offset);
+					currentStack.Push (new StackSlot (new ArrayValue (count.Value)));
+					break;
+				}
+
+				case Code.Cpblk:
+				case Code.Initblk:
+				case Code.Stelem_I:
+				case Code.Stelem_I1:
+				case Code.Stelem_I2:
+				case Code.Stelem_I4:
+				case Code.Stelem_I8:
+				case Code.Stelem_R4:
+				case Code.Stelem_R8:
+				case Code.Stelem_Any:
+				case Code.Stelem_Ref:
+					PopUnknown (currentStack, 3, methodBody, operation.Offset);
+					break;
+
+				case Code.Stfld:
+					// TODO: model field stores
+					// First pop is the value to store, second is the object stored into.
+					PopUnknown (currentStack, 1, methodBody, operation.Offset);
+					PopUnknown (currentStack, 1, methodBody, operation.Offset);
+					break;
+
+				case Code.Stsfld:
+					PopUnknown (currentStack, 1, methodBody, operation.Offset);
+					// TODO: model field stores
+					break;
+
+				case Code.Cpobj:
+				case Code.Stind_I:
+				case Code.Stind_I1:
+				case Code.Stind_I2:
+				case Code.Stind_I4:
+				case Code.Stind_I8:
+				case Code.Stind_R4:
+				case Code.Stind_R8:
+				case Code.Stind_Ref:
+				case Code.Stobj:
+					PopUnknown (currentStack, 2, methodBody, operation.Offset);
+					break;
+
+				case Code.Initobj:
+				case Code.Pop:
+					PopUnknown (currentStack, 1, methodBody, operation.Offset);
+					break;
+
+				case Code.Starg:
+				case Code.Starg_S:
+					// TODO: might want to track this and ensure ldarg reports the stored value.
+					PopUnknown (currentStack, 1, methodBody, operation.Offset);
+					break;
+
+				case Code.Stloc:
+				case Code.Stloc_S:
+				case Code.Stloc_0:
+				case Code.Stloc_1:
+				case Code.Stloc_2:
+				case Code.Stloc_3:
+					ScanStloc (operation, currentStack, methodBody, locals, curBasicBlock);
+					break;
+
+				case Code.Constrained:
+				case Code.No:
+				case Code.Readonly:
+				case Code.Tail:
+				case Code.Unaligned:
+				case Code.Volatile:
+					break;
+
+				case Code.Brfalse:
+				case Code.Brfalse_S:
+				case Code.Brtrue:
+				case Code.Brtrue_S:
+					PopUnknown (currentStack, 1, methodBody, operation.Offset);
+					NewKnownStack (knownStacks, ((Instruction)operation.Operand).Offset, currentStack, methodBody);
+					break;
+
+				case Code.Calli:
+					// TODO: currently not emitted by any mainstream compilers but we should implement
+					break;
+
+				case Code.Call:
+				case Code.Callvirt:
+				case Code.Newobj:
+					HandleCall (methodBody, operation, currentStack);
+					break;
+
+				case Code.Jmp:
+					// Not generated by mainstream compilers
+					break;
+
+				case Code.Br:
+				case Code.Br_S:
+					NewKnownStack (knownStacks, ((Instruction)operation.Operand).Offset, currentStack, methodBody);
+					ClearStack (ref currentStack);
+					break;
+
+				case Code.Leave:
+				case Code.Leave_S:
+					PopUnknown (currentStack, currentStack.Count, methodBody, operation.Offset);
+					ClearStack (ref currentStack);
+					NewKnownStack (knownStacks, ((Instruction)operation.Operand).Offset, new Stack<StackSlot> (methodBody.MaxStackSize), methodBody);
+					break;
+
+				case Code.Endfilter:
+				case Code.Endfinally:
+				case Code.Rethrow:
+				case Code.Throw:
+					PopUnknown (currentStack, currentStack.Count, methodBody, operation.Offset);
+					ClearStack (ref currentStack);
+					break;
+
+				case Code.Ret: {
+					CheckForInvalidReturnStack (currentStack, methodBody, operation.Offset);
+					// retValue is null when the stack was empty (e.g. a void method).
+					StackSlot retValue = PopUnknown (currentStack, currentStack.Count, methodBody, operation.Offset);
+					if (retValue != null)
+						MethodReturnValue = MergePointValue.MergeValues (MethodReturnValue, retValue.Value);
+
+					ClearStack (ref currentStack);
+					break;
+				}
+
+				case Code.Switch: {
+					PopUnknown (currentStack, 1, methodBody, operation.Offset);
+					Instruction [] targets = (Instruction [])operation.Operand;
+					foreach (Instruction target in targets) {
+						NewKnownStack (knownStacks, target.Offset, currentStack, methodBody);
+					}
+					break;
+				}
+
+				case Code.Beq:
+				case Code.Beq_S:
+				case Code.Bne_Un:
+				case Code.Bne_Un_S:
+				case Code.Bge:
+				case Code.Bge_S:
+				case Code.Bge_Un:
+				case Code.Bge_Un_S:
+				case Code.Bgt:
+				case Code.Bgt_S:
+				case Code.Bgt_Un:
+				case Code.Bgt_Un_S:
+				case Code.Ble:
+				case Code.Ble_S:
+				case Code.Ble_Un:
+				case Code.Ble_Un_S:
+				case Code.Blt:
+				case Code.Blt_S:
+				case Code.Blt_Un:
+				case Code.Blt_Un_S:
+					PopUnknown (currentStack, 2, methodBody, operation.Offset);
+					NewKnownStack (knownStacks, ((Instruction)operation.Operand).Offset, currentStack, methodBody);
+					break;
+				}
+			}
+		}
+
+		/// <summary>
+		/// Seeds the known stacks with a single unknown slot at the start of every catch handler
+		/// and at the filter and handler start of every filter clause.
+		/// </summary>
+		private void ScanExceptionInformation (Dictionary<int, Stack<StackSlot>> knownStacks, MethodBody methodBody)
+		{
+			foreach (ExceptionHandler exceptionClause in methodBody.ExceptionHandlers) {
+				Stack<StackSlot> catchStack = new Stack<StackSlot> (1);
+				catchStack.Push (new StackSlot ());
+
+				if (exceptionClause.HandlerType == ExceptionHandlerType.Filter) {
+					NewKnownStack (knownStacks, exceptionClause.FilterStart.Offset, catchStack, methodBody);
+					NewKnownStack (knownStacks, exceptionClause.HandlerStart.Offset, catchStack, methodBody);
+				}
+				if (exceptionClause.HandlerType == ExceptionHandlerType.Catch) {
+					NewKnownStack (knownStacks, exceptionClause.HandlerStart.Offset, catchStack, methodBody);
+				}
+			}
+		}
+
+		/// <summary>
+		/// Pushes a <see cref="MethodParameterValue"/> for the argument loaded by the instruction.
+		/// </summary>
+		private void ScanLdarg (Instruction operation, Stack<StackSlot> currentStack, MethodDefinition thisMethod, MethodBody methodBody)
+		{
+			Code code = operation.OpCode.Code;
+
+			int paramNum;
+			if (code >= Code.Ldarg_0 && code <= Code.Ldarg_3) {
+				paramNum = code - Code.Ldarg_0;
+			} else {
+				paramNum = ((ParameterDefinition)operation.Operand).Index;
+				// Shift by one for instance methods so the number lines up with ldarg.N.
+				if (!thisMethod.IsStatic)
+					paramNum += 1;
+			}
+
+			// TODO: isbyref
+			currentStack.Push (new StackSlot (new MethodParameterValue (paramNum), isByRef: false));
+		}
+
+		/// <summary>
+		/// Pushes the value tracked for the local, or an unknown slot if nothing is tracked.
+		/// </summary>
+		private void ScanLdloc (
+			Instruction operation,
+			Stack<StackSlot> currentStack,
+			MethodDefinition thisMethod,
+			MethodBody methodBody,
+			Dictionary<VariableDefinition, ValueBasicBlockPair> locals)
+		{
+			VariableDefinition localDef = GetLocalDef (operation, methodBody.Variables);
+			if (localDef == null) {
+				PushUnknownAndWarnAboutInvalidIL (currentStack, methodBody, operation.Offset, true);
+				return;
+			}
+
+			bool isByRef = (operation.OpCode.Code == Code.Ldloca || operation.OpCode.Code == Code.Ldloca_S);
+
+			// On a miss localValue is the default pair, whose Value is null.
+			ValueBasicBlockPair localValue;
+			locals.TryGetValue (localDef, out localValue);
+			if (localValue.Value != null) {
+				currentStack.Push (new StackSlot (localValue.Value, isByRef));
+			} else {
+				PushUnknown (currentStack);
+			}
+		}
+
+		/// <summary>
+		/// Pushes a <see cref="RuntimeTypeHandleValue"/> for a type token that resolves,
+		/// otherwise an unknown slot.
+		/// </summary>
+		private void ScanLdtoken (
+			Instruction operation,
+			Stack<StackSlot> currentStack,
+			MethodDefinition thisMethod,
+			MethodBody methodBody)
+		{
+			if (operation.Operand is TypeReference typeReference) {
+				var resolvedReference = typeReference.Resolve ();
+				if (resolvedReference != null) {
+					currentStack.Push (new StackSlot (new RuntimeTypeHandleValue (resolvedReference)));
+					return;
+				}
+			}
+
+			PushUnknown (currentStack);
+		}
+
+		/// <summary>
+		/// Pops the top of the stack and records it as the value of the target local.
+		/// </summary>
+		private void ScanStloc (
+			Instruction operation,
+			Stack<StackSlot> currentStack,
+			MethodBody methodBody,
+			Dictionary<VariableDefinition, ValueBasicBlockPair> locals,
+			int curBasicBlock)
+		{
+			StackSlot valueToStore = PopUnknown (currentStack, 1, methodBody, operation.Offset);
+			VariableDefinition localDef = GetLocalDef (operation, methodBody.Variables);
+			if (localDef == null) {
+				WarnAboutInvalidILInMethod (methodBody, operation.Offset);
+				return;
+			}
+
+			StoreMethodLocalValue (locals, valueToStore.Value, localDef, curBasicBlock);
+		}
+
+		/// <summary>
+		/// Returns the local variable an ldloc/ldloca/stloc instruction refers to, or null if
+		/// a short form (ldloc.N / stloc.N) names an index the method does not declare.
+		/// </summary>
+		private static VariableDefinition GetLocalDef (Instruction operation, Collection<VariableDefinition> localVariables)
+		{
+			Code code = operation.OpCode.Code;
+
+			int index;
+			if (code >= Code.Ldloc_0 && code <= Code.Ldloc_3)
+				index = code - Code.Ldloc_0;
+			else if (code >= Code.Stloc_0 && code <= Code.Stloc_3)
+				index = code - Code.Stloc_0;
+			else
+				return (VariableDefinition)operation.Operand;
+
+			// Callers treat null as invalid IL; do not let the indexer throw instead.
+			return index < localVariables.Count ? localVariables [index] : null;
+		}
+
+		/// <summary>
+		/// Pops the arguments of a call from the stack and returns them in declaration order,
+		/// 'this' first. For newobj no 'this' is popped; an unknown value is used in its place
+		/// and also returned through <paramref name="newObjValue"/>.
+		/// </summary>
+		private ValueNodeList PopCallArguments (
+			Stack<StackSlot> currentStack,
+			MethodReference methodCalled,
+			MethodBody containingMethodBody,
+			bool isNewObj, int ilOffset,
+			out ValueNode newObjValue)
+		{
+			newObjValue = null;
+
+			int countToPop = 0;
+			if (!isNewObj && methodCalled.HasThis && !methodCalled.ExplicitThis)
+				countToPop++;
+			countToPop += methodCalled.Parameters.Count;
+
+			// Arguments come off the stack last-to-first; the list is reversed at the end.
+			ValueNodeList methodParams = new ValueNodeList (countToPop);
+			for (int iParam = 0; iParam < countToPop; ++iParam) {
+				StackSlot slot = PopUnknown (currentStack, 1, containingMethodBody, ilOffset);
+				methodParams.Add (slot.Value);
+			}
+
+			if (isNewObj) {
+				newObjValue = UnknownValue.Instance;
+				methodParams.Add (newObjValue);
+			}
+			methodParams.Reverse ();
+			return methodParams;
+		}
+
+		/// <summary>
+		/// Handles call, callvirt and newobj: pops the arguments, lets the derived scanner
+		/// process the call and pushes the resulting value, if any.
+		/// </summary>
+		private void HandleCall (
+			MethodBody callingMethodBody,
+			Instruction operation,
+			Stack<StackSlot> currentStack)
+		{
+			MethodReference calledMethod = (MethodReference)operation.Operand;
+
+			bool isNewObj = (operation.OpCode.Code == Code.Newobj);
+
+			ValueNode newObjValue = null;
+			ValueNodeList methodParams = PopCallArguments (currentStack, calledMethod, callingMethodBody, isNewObj,
+				operation.Offset, out newObjValue);
+
+			ValueNode methodReturnValue = null;
+			bool handledFunction = HandleCall (
+				callingMethodBody,
+				calledMethod,
+				operation,
+				methodParams,
+				out methodReturnValue);
+
+			// Handle the return value or newobj result
+			if (!handledFunction) {
+				if (isNewObj) {
+					if (newObjValue == null)
+						PushUnknown (currentStack);
+					else
+						methodReturnValue = newObjValue;
+				} else {
+					if (calledMethod.ReturnType.MetadataType != MetadataType.Void) {
+						methodReturnValue = UnknownValue.Instance;
+					}
+				}
+			}
+
+			// Note: when the derived scanner reports the call as handled, nothing is pushed
+			// unless it supplied a non-null return value.
+			if (methodReturnValue != null)
+				currentStack.Push (new StackSlot (methodReturnValue));
+		}
+
+		/// <summary>
+		/// Implemented by derived scanners to model a call.
+		/// </summary>
+		/// <returns>True if the call was handled; false to let the scanner push a default
+		/// (unknown) result.</returns>
+		public abstract bool HandleCall (
+			MethodBody callingMethodBody,
+			MethodReference calledMethod,
+			Instruction operation,
+			ValueNodeList methodParams,
+			out ValueNode methodReturnValue);
+	}
+}
diff --git a/src/linker/Linker.Dataflow/ScannerExtensions.cs b/src/linker/Linker.Dataflow/ScannerExtensions.cs
new file mode 100644
index 000000000000..d9529feeddf7
--- /dev/null
+++ b/src/linker/Linker.Dataflow/ScannerExtensions.cs
@@ -0,0 +1,42 @@
+using System;
+using System.Collections.Generic;
+
+using Mono.Cecil.Cil;
+
+namespace Mono.Linker.Dataflow
+{
+	static class ScannerExtensions
+	{
+		/// <summary>
+		/// True for opcodes whose flow control is Branch or Cond_Branch, and for opcodes whose
+		/// flow control is Return other than 'ret' itself.
+		/// </summary>
+		public static bool IsControlFlowInstruction (in this OpCode opcode)
+		{
+			return opcode.FlowControl == FlowControl.Branch
+				|| opcode.FlowControl == FlowControl.Cond_Branch
+				|| (opcode.FlowControl == FlowControl.Return && opcode.Code != Code.Ret);
+		}
+
+		/// <summary>
+		/// Collects the IL offsets that control flow instructions jump to (including every
+		/// switch label) plus the filter and handler start of every exception handler.
+		/// </summary>
+		public static HashSet<int> ComputeBranchTargets (this MethodBody methodBody)
+		{
+			HashSet<int> branchTargets = new HashSet<int> ();
+			foreach (Instruction operation in methodBody.Instructions) {
+				if (!operation.OpCode.IsControlFlowInstruction ())
+					continue;
+				Object value = operation.Operand;
+				if (value is Instruction inst) {
+					branchTargets.Add (inst.Offset);
+				} else if (value is Instruction [] instructions) {
+					foreach (Instruction switchLabel in instructions) {
+						branchTargets.Add (switchLabel.Offset);
+					}
+				}
+			}
+			foreach (ExceptionHandler einfo in methodBody.ExceptionHandlers) {
+				if (einfo.HandlerType == ExceptionHandlerType.Filter) {
+					branchTargets.Add (einfo.FilterStart.Offset);
+				}
+				branchTargets.Add (einfo.HandlerStart.Offset);
+			}
+			return branchTargets;
+		}
+	}
+
+}
diff --git a/src/linker/Linker.Dataflow/ValueNode.cs b/src/linker/Linker.Dataflow/ValueNode.cs
new file mode 100644
index 000000000000..a35f1d468098
--- /dev/null
+++ b/src/linker/Linker.Dataflow/ValueNode.cs
@@ -0,0 +1,1181 @@
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Runtime.CompilerServices;
+using System.Text;
+
+using TypeDefinition = Mono.Cecil.TypeDefinition;
+using FieldDefinition = Mono.Cecil.FieldDefinition; + +namespace Mono.Linker.Dataflow +{ + public enum ValueNodeKind + { + Invalid, // in case the Kind field is not initialized properly + + Unknown, // unknown value, has StaticType from context + + Null, // known value + SystemType, // known value - TypeRepresented + RuntimeTypeHandle, // known value - TypeRepresented + KnownString, // known value - Contents + ConstInt, // known value - Int32 + + MethodParameter, // symbolic placeholder + + MergePoint, // structural, multiplexer - Values + GetTypeFromString, // structural, could be known value - KnownString + Array, // structural, could be known value - Array + + LoadField, // structural, could be known value - InstanceValue + } + + /// + /// A ValueNode represents a value in the IL dataflow analysis. It may not contain complete information as it is a + /// best-effort representation. Additionally, as the analysis is linear and does not account for control flow, any + /// given ValueNode may represent multiple values simultaneously. (This occurs, for example, at control flow join + /// points when both paths yield values on the IL stack or in a local.) + /// + public abstract class ValueNode : IEquatable + { + public ValueNode () + { +#if false // Helpful for debugging a cycle that has inadvertently crept into the graph + if (this.DetectCycle(new HashSet())) + { + throw new Exception("Found a cycle"); + } +#endif + } + + /// + /// The 'kind' of value node -- this represents the most-derived type and allows us to switch over and do + /// equality checks without the cost of casting. Intermediate non-leaf types in the ValueNode hierarchy should + /// be abstract. + /// + public ValueNodeKind Kind { get; protected set; } + + /// + /// Allows the enumeration of the direct children of this node. The ChildCollection struct returned here + /// supports 'foreach' without allocation. 
+ /// + public ChildCollection Children { get { return new ChildCollection (this); } } + + /// + /// This property allows you to enumerate all 'unique values' represented by a given ValueNode. The basic idea + /// is that there will be no MergePointValues in the returned ValueNodes and all structural operations will be + /// applied so that each 'unique value' can be considered on its own without regard to the structure that led to + /// it. + /// + public UniqueValueCollection UniqueValues { + get { + return new UniqueValueCollection (this); + } + } + + /// + /// This protected method is how nodes implement the UniqueValues property. It is protected because it returns + /// an IEnumerable and we want to avoid allocating an enumerator for the exceedingly common case of there being + /// only one value in the enumeration. The UniqueValueCollection returned by the UniqueValues property handles + /// this detail. + /// + protected abstract IEnumerable EvaluateUniqueValues (); + + /// + /// RepresentsExactlyOneValue is used by the UniqueValues property to allow us to bypass allocating an + /// enumerator to return just one value. If a node returns 'true' from RepresentsExactlyOneValue, it must also + /// return that one value from GetSingleUniqueValue. If it always returns 'false', it doesn't need to implement + /// GetSingleUniqueValue. + /// + protected virtual bool RepresentsExactlyOneValue { get { return false; } } + + /// + /// GetSingleUniqueValue is called if, and only if, RepresentsExactlyOneValue returns true. It allows us to + /// bypass the allocation of an enumerator for the common case of returning exactly one value. + /// + protected virtual ValueNode GetSingleUniqueValue () + { + // Not implemented because RepresentsExactlyOneValue returns false and, therefore, this method should be + // unreachable. 
+ throw new NotImplementedException (); + } + + protected abstract int NumChildren { get; } + protected abstract ValueNode ChildAt (int index); + + public abstract bool Equals (ValueNode other); + + public abstract override int GetHashCode (); + + /// + /// Each node type must implement this to stringize itself. The expectation is that it is implemented using + /// ValueNodeDump.ValueNodeToString(), passing any non-ValueNode properties of interest (e.g. + /// SystemTypeValue.TypeRepresented). Properties that are invariant on a particular node type + /// should be omitted for clarity. + /// + protected abstract string NodeToString (); + + public override string ToString () + { + return NodeToString (); + } + + public override bool Equals (object other) + { + if (!(other is ValueNode)) + return false; + + return this.Equals ((ValueNode)other); + } + + #region Specialized Collection Nested Types + /// + /// ChildCollection struct is used to wrap the operations on a node involving its children. In particular, the + /// struct implements a GetEnumerator method that is used to allow "foreach (ValueNode node in myNode.Children)" + /// without heap allocations. + /// + public struct ChildCollection : IEnumerable + { + /// + /// Enumerator for children of a ValueNode. Allows foreach(var child in node.Children) to work without + /// allocating a heap-based enumerator. + /// + public struct Enumerator : IEnumerator + { + int _index; + ValueNode _parent; + + public Enumerator (ValueNode parent) + { + _parent = parent; + _index = -1; + } + + public ValueNode Current { get { return (_parent != null) ? _parent.ChildAt (_index) : null; } } + + object System.Collections.IEnumerator.Current { get { return Current; } } + + public bool MoveNext () + { + _index++; + return (_parent != null) ? 
(_index < _parent.NumChildren) : false; + } + + public void Reset () + { + _index = -1; + } + + public void Dispose () + { + } + } + + ValueNode _parentNode; + + public ChildCollection (ValueNode parentNode) { _parentNode = parentNode; } + + // Used by C# 'foreach', when strongly typed, to avoid allocation. + public Enumerator GetEnumerator () + { + return new Enumerator (_parentNode); + } + + IEnumerator IEnumerable.GetEnumerator () + { + // note the boxing! + return (IEnumerator)new Enumerator (_parentNode); + } + System.Collections.IEnumerator System.Collections.IEnumerable.GetEnumerator () + { + // note the boxing! + return (System.Collections.IEnumerator)new Enumerator (_parentNode); + } + + public int Count { get { return (_parentNode != null) ? _parentNode.NumChildren : 0; } } + } + + /// + /// UniqueValueCollection is used to wrap calls to ValueNode.EvaluateUniqueValues. If a ValueNode represents + /// only one value, then foreach(ValueNode value in node.UniqueValues) will not allocate a heap-based enumerator. + /// + /// This is implented by having each ValueNode tell us whether or not it represents exactly one value or not. + /// If it does, we fetch it with ValueNode.GetSingleUniqueValue(), otherwise, we fall back to the usual heap- + /// based IEnumerable returned by ValueNode.EvaluateUniqueValues. + /// + public struct UniqueValueCollection : IEnumerable + { + IEnumerable _multiValueEnumerable; + ValueNode _treeNode; + + public UniqueValueCollection (ValueNode node) + { + if (node.RepresentsExactlyOneValue) { + _multiValueEnumerable = null; + _treeNode = node; + } else { + _multiValueEnumerable = node.EvaluateUniqueValues (); + _treeNode = null; + } + } + + public Enumerator GetEnumerator () + { + return new Enumerator (_treeNode, _multiValueEnumerable); + } + + IEnumerator IEnumerable.GetEnumerator () + { + if (_multiValueEnumerable != null) { + return _multiValueEnumerable.GetEnumerator (); + } + + // note the boxing! 
+ return (IEnumerator)GetEnumerator (); + } + + System.Collections.IEnumerator System.Collections.IEnumerable.GetEnumerator () + { + if (_multiValueEnumerable != null) { + return _multiValueEnumerable.GetEnumerator (); + } + + // note the boxing! + return (System.Collections.IEnumerator)GetEnumerator (); + } + + + public struct Enumerator : IEnumerator + { + IEnumerator _multiValueEnumerator; + ValueNode _singleValueNode; + int _index; + + public Enumerator (ValueNode treeNode, IEnumerable mulitValueEnumerable) + { + _singleValueNode = (treeNode != null) ? treeNode.GetSingleUniqueValue () : null; + _multiValueEnumerator = (mulitValueEnumerable != null) ? mulitValueEnumerable.GetEnumerator () : null; + _index = -1; + } + + public void Reset () + { + if (_multiValueEnumerator != null) { + _multiValueEnumerator.Reset (); + return; + } + + _index = -1; + } + + public bool MoveNext () + { + if (_multiValueEnumerator != null) + return _multiValueEnumerator.MoveNext (); + + _index++; + return (_index == 0); + } + + public ValueNode Current { + get { + if (_multiValueEnumerator != null) + return _multiValueEnumerator.Current; + + if (_index == 0) + return _singleValueNode; + + throw new InvalidOperationException (); + } + } + + object System.Collections.IEnumerator.Current { get { return Current; } } + + public void Dispose () + { + } + } + } + #endregion + } + + /// + /// LeafValueNode represents a 'leaf' in the expression tree. In other words, the node has no ValueNode children. + /// It *may* still have non-ValueNode 'properties' that are interesting. This class serves, primarily, as a way to + /// collect up the very common implmentation of NumChildren/ChildAt for leaf nodes and the "represents exactly one + /// value" optimization. These things aren't on the ValueNode base class because, otherwise, new node types + /// deriving from ValueNode may 'forget' to implement these things. 
/// So this class allows them to remain abstract in
/// ValueNode while still having a common implementation for all the leaf nodes.
///
public abstract class LeafValueNode : ValueNode
{
	// A leaf has no ValueNode children, so the count is fixed at zero ...
	protected override int NumChildren => 0;

	// ... and asking for a child at any index is an error.
	protected override ValueNode ChildAt (int index) => throw new InvalidOperationException ();

	// Every leaf stands for exactly one value: the leaf itself.
	protected override bool RepresentsExactlyOneValue => true;

	protected override ValueNode GetSingleUniqueValue () => this;

	protected override IEnumerable<ValueNode> EvaluateUniqueValues ()
	{
		// A leaf never represents more than one value, so nothing should reach this method
		// for as long as RepresentsExactlyOneValue keeps returning true.
		throw new NotImplementedException ();
	}
}

// These are extension methods because we want to allow the use of them on null 'this' pointers.
internal static class ValueNodeExtensions
{
	/// <summary>
	/// Returns true if a ValueNode graph contains a cycle.
	/// </summary>
	/// <param name="node">Node to evaluate. A null node is treated as cycle-free.</param>
	/// <param name="seenNodes">Set of nodes previously seen on the current arc. Callers may pass a non-empty set
	/// to test whether adding that set to this node would create a cycle. Contents will be modified by the walk
	/// and should not be used by the caller after returning.</param>
	/// <param name="allNodesSeen">Optional. The set of all nodes encountered during a walk after DetectCycle returns.</param>
	/// <returns>True when a node is reached again while it is still on the current arc.</returns>
	public static bool DetectCycle (this ValueNode node, HashSet<ValueNode> seenNodes, HashSet<ValueNode> allNodesSeen)
	{
		if (node == null)
			return false;

		// Add reports false when the node is already on the current arc, i.e. we walked back onto it.
		if (!seenNodes.Add (node))
			return true;

		allNodesSeen?.Add (node);

		bool cycleDetected = false;
		switch (node.Kind) {
		// Leaf nodes: nothing to descend into.
		case ValueNodeKind.Unknown:
		case ValueNodeKind.Null:
		case ValueNodeKind.SystemType:
		case ValueNodeKind.RuntimeTypeHandle:
		case ValueNodeKind.KnownString:
		case ValueNodeKind.ConstInt:
		case ValueNodeKind.MethodParameter:
			break;

		// Nodes with children.
		case ValueNodeKind.MergePoint:
			// '|=' does not short-circuit: every merged value is still walked after a cycle was found.
			foreach (ValueNode mergedValue in ((MergePointValue)node).Values)
				cycleDetected |= mergedValue.DetectCycle (seenNodes, allNodesSeen);
			break;

		case ValueNodeKind.GetTypeFromString: {
				var typeFromString = (GetTypeFromStringValue)node;
				// Both children are always visited, assembly identity first.
				bool assemblyHasCycle = typeFromString.AssemblyIdentity.DetectCycle (seenNodes, allNodesSeen);
				bool nameHasCycle = typeFromString.NameString.DetectCycle (seenNodes, allNodesSeen);
				cycleDetected = assemblyHasCycle | nameHasCycle;
				break;
			}

		case ValueNodeKind.LoadField:
			cycleDetected = ((LoadFieldValue)node).InstanceValue.DetectCycle (seenNodes, allNodesSeen);
			break;

		case ValueNodeKind.Array:
			cycleDetected = ((ArrayValue)node).Size.DetectCycle (seenNodes, allNodesSeen);
			break;

		default:
			throw new Exception (String.Format ("Unknown node kind: {0}", node.Kind));
		}

		// The node leaves the current arc once all of its children have been walked.
		seenNodes.Remove (node);

		return cycleDetected;
	}

	/// <summary>
	/// Null-tolerant access to the unique values of a node: a null node yields the unique values of
	/// <c>UnknownValue.Instance</c>.
	/// </summary>
	public static ValueNode.UniqueValueCollection UniqueValues (this ValueNode node)
		=> node == null ? new ValueNode.UniqueValueCollection (UnknownValue.Instance) : node.UniqueValues;

	/// <summary>
	/// Returns the integer carried by a ConstIntValue, or null when the node is null or any other node type.
	/// </summary>
	public static int? AsConstInt (this ValueNode node)
		=> (node as ConstIntValue)?.Value;
}

internal static class ValueNodeDump
{
	/// <summary>
	/// Formats a node as its Kind followed by a parenthesized, comma-separated argument list.
	/// </summary>
	/// <param name="node">Node to describe; a null node produces the null placeholder string.</param>
	/// <param name="args">Values listed between the parentheses; a null entry renders as the null placeholder.</param>
	internal static string ValueNodeToString (ValueNode node, params object [] args)
	{
		if (node == null)
			return "";

		string argumentList = args == null
			? ""
			: string.Join (",", args.Select (arg => arg == null ? "" : arg.ToString ()));

		return node.Kind.ToString () + "(" + argumentList + ")";
	}

	// Builds the leading whitespace for one line of DumpTree output.
	static string GetIndent (int level)
	{
		var indent = new StringBuilder (level * 2);
		int remaining = level;
		while (remaining-- > 0)
			indent.Append (" ");
		return indent.ToString ();
	}

	/// <summary>
	/// Writes the node, then each of its children one indent level deeper, one node per line.
	/// </summary>
	/// <param name="node">Root of the tree to print; may be null.</param>
	/// <param name="writer">Destination; Console.Out is used when null.</param>
	/// <param name="indentLevel">Indent level of the root line.</param>
	public static void DumpTree (this ValueNode node, System.IO.TextWriter writer = null, int indentLevel = 0)
	{
		writer = writer ?? Console.Out;

		writer.Write (GetIndent (indentLevel));
		if (node == null) {
			writer.WriteLine ("");
			return;
		}

		writer.WriteLine (node);
		int childLevel = indentLevel + 1;
		foreach (ValueNode child in node.Children)
			child.DumpTree (writer, childLevel);
	}
}

/// <summary>
/// Represents an unknown value.
/// </summary>
class UnknownValue : LeafValueNode
{
	private UnknownValue ()
	{
		Kind = ValueNodeKind.Unknown;
	}

	public static UnknownValue Instance { get; } = new UnknownValue ();

	// Unknown values carry no data: agreeing on Kind is all it takes to be equal.
	public override bool Equals (ValueNode other) => other != null && Kind == other.Kind;

	// All instances of UnknownValue are equivalent, so they share one hash code. The constant itself
	// was chosen for no particular reason.
	public override int GetHashCode () => 0x98052;

	protected override string NodeToString () => ValueNodeDump.ValueNodeToString (this);
}

/// <summary>
/// Represents a null value.
/// </summary>
class NullValue : LeafValueNode
{
	private NullValue ()
	{
		Kind = ValueNodeKind.Null;
	}

	// Null values carry no data: agreeing on Kind is all it takes to be equal.
	public override bool Equals (ValueNode other) => other != null && Kind == other.Kind;

	public static NullValue Instance { get; } = new NullValue ();

	// All instances of NullValue are equivalent, so they share one hash code. The constant itself
	// was chosen for no particular reason.
	public override int GetHashCode () => 0x90210;

	protected override string NodeToString () => ValueNodeDump.ValueNodeToString (this);
}

/// <summary>
/// This is a known System.Type value. TypeRepresented is the 'value' of the System.Type.
/// </summary>
class SystemTypeValue : LeafValueNode
{
	public SystemTypeValue (TypeDefinition typeRepresented)
	{
		Kind = ValueNodeKind.SystemType;
		TypeRepresented = typeRepresented;
	}

	public TypeDefinition TypeRepresented { get; private set; }

	// Same kind of node and the same represented type (compared with the static object.Equals).
	public override bool Equals (ValueNode other)
		=> other != null
			&& Kind == other.Kind
			&& Equals (TypeRepresented, ((SystemTypeValue)other).TypeRepresented);

	public override int GetHashCode () => HashUtils.CalcHashCode (Kind, TypeRepresented);

	protected override string NodeToString () => ValueNodeDump.ValueNodeToString (this, TypeRepresented);
}

///
/// This is the System.RuntimeTypeHandle equivalent to a node.
///
class RuntimeTypeHandleValue : LeafValueNode
{
	public RuntimeTypeHandleValue (TypeDefinition typeRepresented)
	{
		Kind = ValueNodeKind.RuntimeTypeHandle;
		TypeRepresented = typeRepresented;
	}

	// The type this handle stands for.
	public TypeDefinition TypeRepresented { get; private set; }

	// Same kind of node and the same represented type (compared with the static object.Equals).
	public override bool Equals (ValueNode other)
		=> other != null
			&& Kind == other.Kind
			&& Equals (TypeRepresented, ((RuntimeTypeHandleValue)other).TypeRepresented);

	public override int GetHashCode () => HashUtils.CalcHashCode (Kind, TypeRepresented);

	protected override string NodeToString () => ValueNodeDump.ValueNodeToString (this, TypeRepresented);
}

/// <summary>
/// A known string - such as the result of a ldstr.
/// </summary>
class KnownStringValue : LeafValueNode
{
	public KnownStringValue (string contents)
	{
		Kind = ValueNodeKind.KnownString;
		Contents = contents;
	}

	// The text of the string.
	public string Contents { get; private set; }

	// Same kind of node and the same string contents.
	public override bool Equals (ValueNode other)
		=> other != null
			&& Kind == other.Kind
			&& Contents == ((KnownStringValue)other).Contents;

	public override int GetHashCode () => HashUtils.CalcHashCode (Kind, Contents);

	// The contents are shown wrapped in double quotes.
	protected override string NodeToString () => ValueNodeDump.ValueNodeToString (this, "\"" + Contents + "\"");
}

///
/// A value that came from a method parameter - such as the result of a ldarg.
///
class MethodParameterValue : LeafValueNode
{
	/// <summary>
	/// Creates a value that stands for the method parameter with the given index.
	/// </summary>
	/// <param name="parameterIndex">Index of the parameter the value came from.</param>
	public MethodParameterValue (int parameterIndex)
	{
		Kind = ValueNodeKind.MethodParameter;
		ParameterIndex = parameterIndex;
	}

	/// <summary>
	/// Index of the parameter this value came from.
	/// </summary>
	public int ParameterIndex { get; }

	/// <summary>
	/// Two parameter values are equal when they are the same kind of node and refer to the same parameter index.
	/// </summary>
	public override bool Equals (ValueNode other)
	{
		if (other == null)
			return false;
		if (this.Kind != other.Kind)
			return false;

		// This has to be '=='. With '!=' a node compared unequal to itself and equal to every *other*
		// parameter, which also contradicted GetHashCode (it hashes ParameterIndex).
		return this.ParameterIndex == ((MethodParameterValue)other).ParameterIndex;
	}

	public override int GetHashCode ()
	{
		return HashUtils.CalcHashCode (Kind, ParameterIndex);
	}

	protected override string NodeToString ()
	{
		return ValueNodeDump.ValueNodeToString (this, ParameterIndex);
	}
}

///
/// A merge point commonly occurs due to control flow in a method body. It represents a set of values
/// from different paths through the method. It is the reason for EvaluateUniqueValues, which essentially
/// provides an enumeration over all the concrete values represented by a given ValueNode after 'erasing'
/// the merge point nodes.
///
class MergePointValue : ValueNode
{
	// Used by MergeValues. Either operand that is itself a merge point is flattened into this one
	// (AddValue does the flattening), so merge points never nest through this path.
	private MergePointValue (ValueNode one, ValueNode two)
		: this ()
	{
		AddValue (one);
		AddValue (two);
	}

	/// <summary>
	/// Creates an empty merge point.
	/// </summary>
	public MergePointValue ()
	{
		Kind = ValueNodeKind.MergePoint;
		m_values = new ValueNodeHashSet ();
	}

	/// <summary>
	/// Adds a value to the merge point. When the value is itself a merge point its members are added
	/// individually instead of the merge point node.
	/// </summary>
	public void AddValue (ValueNode node)
	{
		// we are mutating our state, so we must invalidate any cached knowledge
		//InvalidateIsOpen ();

		if (node.Kind == ValueNodeKind.MergePoint) {
			foreach (ValueNode value in ((MergePointValue)node).Values)
				m_values.Add (value);
		} else
			m_values.Add (node);

#if false
		if (this.DetectCycle(new HashSet<ValueNode>()))
		{
			throw new Exception("Found a cycle");
		}
#endif
	}

	readonly ValueNodeHashSet m_values;

	public ValueNodeHashSet Values { get { return m_values; } }

	protected override int NumChildren { get { return Values.Count; } }

	protected override ValueNode ChildAt (int index)
	{
		// Negative indexes are rejected here as well so that every invalid index fails the same way
		// (InvalidOperationException), as the other ChildAt implementations do.
		if (index >= 0 && index < NumChildren)
			return Values.ElementAt (index);
		throw new InvalidOperationException ();
	}

	/// <summary>
	/// Merges two values. A null operand yields the other operand, equal operands yield the first one,
	/// and anything else yields a new merge point holding both.
	/// </summary>
	public static ValueNode MergeValues (ValueNode one, ValueNode two)
	{
		if (one == null)
			return two;
		else if (two == null)
			return one;
		else if (one.Equals (two))
			return one;
		else
			return new MergePointValue (one, two);
	}

	protected override IEnumerable<ValueNode> EvaluateUniqueValues ()
	{
		// 'Erase' the merge point: enumerate the unique values of every member in turn.
		foreach (ValueNode value in Values) {
			foreach (ValueNode uniqueValue in value.UniqueValues) {
				yield return uniqueValue;
			}
		}
	}

	/// <summary>
	/// Two merge points are equal when they hold the same number of values and every value of this one
	/// is contained in the other.
	/// </summary>
	public override bool Equals (ValueNode other)
	{
		if (other == null)
			return false;
		if (this.Kind != other.Kind)
			return false;

		MergePointValue otherMpv = (MergePointValue)other;
		if (this.Values.Count != otherMpv.Values.Count)
			return false;

		foreach (ValueNode value in this.Values) {
			if (!otherMpv.Values.Contains (value))
				return false;
		}
		return true;
	}

	public override int GetHashCode ()
	{
		// Relies on ValueNodeHashSet.GetHashCode being independent of enumeration order; otherwise two
		// merge points that are Equals could produce different hash codes.
		return HashUtils.CalcHashCode (Kind, Values);
	}

	protected override string NodeToString ()
	{
		return ValueNodeDump.ValueNodeToString (this);
	}
}

delegate TypeDefinition TypeResolver (string assemblyString, string typeString);

/// <summary>
/// The result of a Type.GetType.
/// AssemblyIdentity is the scope in which to resolve if the type name string is not assembly-qualified.
/// </summary>
class GetTypeFromStringValue : ValueNode
{
	private readonly TypeResolver _resolver;

	public GetTypeFromStringValue (TypeResolver resolver, ValueNode assemblyIdentity, ValueNode nameString)
	{
		_resolver = resolver;
		Kind = ValueNodeKind.GetTypeFromString;
		AssemblyIdentity = assemblyIdentity;
		NameString = nameString;
	}

	public ValueNode AssemblyIdentity { get; private set; }

	public ValueNode NameString { get; private set; }

	protected override int NumChildren { get { return 2; } }

	protected override ValueNode ChildAt (int index)
	{
		if (index == 0) return AssemblyIdentity;
		if (index == 1) return NameString;
		throw new InvalidOperationException ();
	}

	/// <summary>
	/// Resolves every (known assembly string, known type-name string) combination through the resolver
	/// and yields a SystemTypeValue for each one that resolves. Yields UnknownValue.Instance when
	/// nothing could be resolved.
	/// </summary>
	protected override IEnumerable<ValueNode> EvaluateUniqueValues ()
	{
		// Collected lazily: stays null when no type name is a known string.
		HashSet<string> names = null;

		foreach (ValueNode nameStringValue in NameString.UniqueValues) {
			if (nameStringValue.Kind == ValueNodeKind.KnownString) {
				if (names == null) {
					names = new HashSet<string> ();
				}

				string typeName = ((KnownStringValue)nameStringValue).Contents;
				names.Add (typeName);
			}
		}

		bool foundAtLeastOne = false;

		if (names != null) {
			foreach (ValueNode assemblyValue in AssemblyIdentity.UniqueValues) {
				if (assemblyValue.Kind == ValueNodeKind.KnownString) {
					string assemblyName = ((KnownStringValue)assemblyValue).Contents;

					foreach (string name in names) {
						TypeDefinition typeDefinition = _resolver (assemblyName, name);
						if (typeDefinition != null) {
							foundAtLeastOne = true;
							yield return new SystemTypeValue (typeDefinition);
						}
					}
				}
			}
		}

		if (!foundAtLeastOne)
			yield return UnknownValue.Instance;
	}

	public override bool Equals (ValueNode other)
	{
		if (other == null)
			return false;
		if (this.Kind != other.Kind)
			return false;

		GetTypeFromStringValue otherGtfs = (GetTypeFromStringValue)other;

		return this.AssemblyIdentity.Equals (otherGtfs.AssemblyIdentity) &&
			this.NameString.Equals (otherGtfs.NameString);
	}

	public override int GetHashCode ()
	{
		return HashUtils.CalcHashCode (Kind, AssemblyIdentity, NameString);
	}

	protected override string NodeToString ()
	{
		return ValueNodeDump.ValueNodeToString (this, NameString);
	}
}

/// <summary>
/// A representation of a ldfld. Note that we don't have a representation of objects containing fields
/// so there isn't much that can be done with this node type yet.
/// </summary>
class LoadFieldValue : ValueNode
{
	public LoadFieldValue (ValueNode instanceValue, FieldDefinition fieldToLoad)
	{
		Kind = ValueNodeKind.LoadField;
		InstanceValue = instanceValue;
		Field = fieldToLoad;
	}

	public FieldDefinition Field { get; private set; }

	public ValueNode InstanceValue { get; private set; }

	protected override int NumChildren { get { return 1; } }

	protected override ValueNode ChildAt (int index)
	{
		if (index == 0) return InstanceValue;
		throw new InvalidOperationException ();
	}

	// The loaded field's contents are not modelled, so the node always evaluates to the single
	// unknown value.
	protected override bool RepresentsExactlyOneValue { get { return true; } }

	protected override ValueNode GetSingleUniqueValue () { return UnknownValue.Instance; }

	protected override IEnumerable<ValueNode> EvaluateUniqueValues ()
	{
		// Not implemented because RepresentsExactlyOneValue returns true and, therefore, this method
		// should be unreachable.
		throw new NotImplementedException ();
	}

	public override bool Equals (ValueNode other)
	{
		if (other == null)
			return false;
		if (this.Kind != other.Kind)
			return false;

		LoadFieldValue otherLfv = (LoadFieldValue)other;
		if (!Equals (this.Field, otherLfv.Field))
			return false;

		return this.InstanceValue.Equals (otherLfv.InstanceValue);
	}

	public override int GetHashCode ()
	{
		return HashUtils.CalcHashCode (Kind, Field, InstanceValue);
	}

	protected override string NodeToString ()
	{
		return ValueNodeDump.ValueNodeToString (this, Field);
	}
}

/// <summary>
/// Represents a ldc on an int32.
/// </summary>
class ConstIntValue : LeafValueNode
{
	public ConstIntValue (int value)
	{
		Kind = ValueNodeKind.ConstInt;
		Value = value;
	}

	public int Value { get; private set; }

	public override int GetHashCode ()
	{
		return HashUtils.CalcHashCode (Kind, Value);
	}

	public override bool Equals (ValueNode other)
	{
		if (other == null)
			return false;
		if (this.Kind != other.Kind)
			return false;

		ConstIntValue otherCiv = (ConstIntValue)other;
		return Value == otherCiv.Value;
	}

	protected override string NodeToString ()
	{
		return ValueNodeDump.ValueNodeToString (this, Value);
	}
}

/// <summary>
/// An array whose size is tracked as a ValueNode.
/// </summary>
class ArrayValue : ValueNode
{
	protected override int NumChildren => 1;

	/// <summary>
	/// Constructs an array value of the given size
	/// </summary>
	/// <param name="size">Size of the array; null is replaced by UnknownValue.Instance.</param>
	public ArrayValue (ValueNode size)
	{
		Kind = ValueNodeKind.Array;
		Size = size ?? UnknownValue.Instance;
	}

	public ValueNode Size { get; }

	public override int GetHashCode ()
	{
		return HashUtils.CalcHashCode (Kind, Size);
	}

	public override bool Equals (ValueNode other)
	{
		if (other == null)
			return false;
		if (this.Kind != other.Kind)
			return false;

		ArrayValue otherArr = (ArrayValue)other;
		return Size.Equals (otherArr.Size);
	}

	protected override string NodeToString ()
	{
		return ValueNodeDump.ValueNodeToString (this, Size);
	}

	protected override IEnumerable<ValueNode> EvaluateUniqueValues ()
	{
		// One array value per unique value the size can take.
		foreach (var sizeConst in Size.UniqueValues)
			yield return new ArrayValue (sizeConst);
	}

	protected override ValueNode ChildAt (int index)
	{
		if (index == 0) return Size;
		throw new InvalidOperationException ();
	}
}

#region ValueNode Collections
/// <summary>
/// A list of ValueNodes with element-wise, order-sensitive equality and hashing.
/// </summary>
public class ValueNodeList : List<ValueNode>
{
	public ValueNodeList ()
	{
	}

	public ValueNodeList (int capacity)
		: base (capacity)
	{
	}

	public ValueNodeList (List<ValueNode> other)
		: base (other)
	{
	}

	public override int GetHashCode ()
	{
		return HashUtils.CalcHashCodeEnumerable (this);
	}

	public override bool Equals (object other)
	{
		ValueNodeList otherList = other as ValueNodeList;
		if (otherList == null)
			return false;

		if (otherList.Count != Count)
			return false;

		for (int i = 0; i < Count; i++) {
			if (!otherList [i].Equals (this [i]))
				return false;
		}
		return true;
	}
}

/// <summary>
/// A set of ValueNodes with content-based equality and hashing that do not depend on the order in
/// which the underlying HashSet happens to enumerate its elements.
/// </summary>
class ValueNodeHashSet : HashSet<ValueNode>
{
	public override int GetHashCode ()
	{
		// XOR is commutative, so the result is the same whatever order the elements come out in.
		// (The previous order-sensitive hash could differ between two sets with identical contents.)
		int combinedElementHash = 0;
		foreach (ValueNode element in this)
			combinedElementHash ^= element == null ? 0 : element.GetHashCode ();

		return HashUtils.CalcHashCode (Count, combinedElementHash);
	}

	public override bool Equals (object other)
	{
		ValueNodeHashSet otherSet = other as ValueNodeHashSet;
		if (otherSet == null)
			return false;

		// Compare by membership, not by walking both enumerators in lock-step: HashSet makes no
		// promise that two sets with the same elements enumerate them in the same order.
		return otherSet.Count == Count && SetEquals (otherSet);
	}
}
#endregion

static class HashUtils
{
	[MethodImpl (MethodImplOptions.AggressiveInlining)]
	static int _rotl (this int value, int shift)
	{
		return (int)(((uint)value << shift) | ((uint)value >> (32 - shift)));
	}

	public static int CalcHashCodeEnumerable<T> (IEnumerable<T> list) where T : class
	{
		int length = list.Count ();

		int hash1 = 0x449b3ad6;
		int hash2 = (length << 3) + 0x55399219;

		int index = 0;

		T element1 = null;
		T element2 = null;

		foreach (T element in list) {
			if ((index++ & 1) == 0) {
				element1 = element;
				continue;
			}
			element2 = element;

			hash1 = (hash1 + _rotl (hash1, 5)) ^ element1.GetHashCode();
			hash2 = (hash2 + _rotl (hash2, 5)) ^ element2.GetHashCode();
		}

		// If we had an odd length, we haven't included the last element yet.
		if ((length & 1) != 0)
			hash1 = (hash1 + _rotl (hash1, 5)) ^ element1.GetHashCode();

		hash1 += _rotl (hash1, 8);
		hash2 += _rotl (hash2, 8);

		return hash1 ^ hash2;
	}

	public static int CalcHashCode<T1> (ValueNodeKind kind, T1 obj1)
		where T1 : class
	{
		return CalcHashCode (kind, obj1.GetHashCode());
	}

	public static int CalcHashCode (ValueNodeKind kind, int val1)
	{
		return CalcHashCode (kind.GetHashCode (), val1);
	}

	public static int CalcHashCode (int hashCode1, int hashCode2)
	{
		int length = 2;

		int hash1 = 0x449b3ad6;
		int hash2 = (length << 3) + 0x55399219;

		hash1 = (hash1 + _rotl (hash1, 5)) ^ hashCode1;
		hash2 = (hash2 + _rotl (hash2, 5)) ^ hashCode2;

		hash1 += _rotl (hash1, 8);
		hash2 += _rotl (hash2, 8);

		return hash1 ^ hash2;
	}

	public static int CalcHashCode<T1, T2> (ValueNodeKind kind, T1 obj1, T2 obj2)
		where T1 : class
		where T2 : class

	{
		return CalcHashCode (kind, obj1.GetHashCode(), obj2.GetHashCode());
	}

	static int CalcHashCode (ValueNodeKind kind, int hashCode1, int hashCode2)
	{
		int length = 3;

		int
hash1 = 0x449b3ad6; + int hash2 = (length << 3) + 0x55399219; + + hash1 = (hash1 + _rotl (hash1, 5)) ^ kind.GetHashCode (); + hash2 = (hash2 + _rotl (hash2, 5)) ^ hashCode1; + hash1 = (hash1 + _rotl (hash1, 5)) ^ hashCode2; + + hash1 += _rotl (hash1, 8); + hash2 += _rotl (hash2, 8); + + return hash1 ^ hash2; + } + } +} diff --git a/src/linker/Linker.Steps/MarkStep.cs b/src/linker/Linker.Steps/MarkStep.cs index 3112fed3df0d..4f34ee2cbd25 100644 --- a/src/linker/Linker.Steps/MarkStep.cs +++ b/src/linker/Linker.Steps/MarkStep.cs @@ -36,6 +36,7 @@ using Mono.Cecil; using Mono.Cecil.Cil; using Mono.Collections.Generic; +using Mono.Linker.Dataflow; namespace Mono.Linker.Steps { @@ -2366,6 +2367,9 @@ protected virtual void MarkReflectionLikeDependencies (MethodBody body) if (HasManuallyTrackedDependency (body)) return; + var scanner = new ReflectionMethodBodyScanner (this); + scanner.Scan (body); + var instructions = body.Instructions; ReflectionPatternDetector detector = new ReflectionPatternDetector (this, body.Method); @@ -2439,6 +2443,14 @@ public void AnalyzingPattern () #endif } + [Conditional ("DEBUG")] + public void RecordHandledPattern () + { +#if DEBUG + _patternReported = true; +#endif + } + public void RecordRecognizedPattern (T accessedItem, Action mark) where T : IMemberDefinition { @@ -3354,6 +3366,176 @@ public AttributeProviderPair (CustomAttribute attribute, ICustomAttributeProvide public CustomAttribute Attribute { get; private set; } public ICustomAttributeProvider Provider { get; private set; } } + + private class ReflectionMethodBodyScanner : Dataflow.MethodBodyScanner + { + private readonly MarkStep _markStep; + + public ReflectionMethodBodyScanner(MarkStep parent) + { + _markStep = parent; + } + + protected override void WarnAboutInvalidILInMethod (MethodBody method, int ilOffset) + { + // TODO: remove once we're ready to scan actual invalid IL + // Serves as a debug helper for now to make sure valid IL is not considered invalid. 
throw new Exception ();
}

/// <summary>
/// Gives the scanner a chance to treat a call as a known reflection intrinsic. Calls are recognized by
/// method name plus the simple name of the declaring type: Type.GetTypeFromHandle, Type.MakeGenericType
/// and the non-generic Activator.CreateInstance overloads whose first parameter is not a string.
/// </summary>
/// <param name="callingMethodBody">Body of the method that contains the call.</param>
/// <param name="calledMethod">The method being called.</param>
/// <param name="operation">The call instruction; its offset is handed to the reflection pattern context.</param>
/// <param name="methodParams">Values of the arguments passed to the call.</param>
/// <param name="methodReturnValue">Value of the call's result. When a handler leaves it null and the called
/// method does not return void, it is set to UnknownValue.Instance.</param>
/// <returns>True when the call was handled as an intrinsic; false for every other method.</returns>
public override bool HandleCall (MethodBody callingMethodBody, MethodReference calledMethod, Instruction operation, ValueNodeList methodParams, out ValueNode methodReturnValue)
{
	var reflectionContext = new ReflectionPatternContext (_markStep._context, callingMethodBody.Method, calledMethod.Resolve (), operation.Offset);

	try {

		methodReturnValue = null;

		// NOTE(review): declaring types are matched by simple name only ("Type", "Activator"); no
		// namespace check is visible here.
		switch (calledMethod.Name) {
			case "GetTypeFromHandle" when calledMethod.DeclaringType.Name == "Type": {
				// Infrastructure piece to support "typeof(Foo)"
				var typeHnd = methodParams[0] as RuntimeTypeHandleValue;
				if (typeHnd != null)
					methodReturnValue = new SystemTypeValue (typeHnd.TypeRepresented);
			}
			break;

			case "MakeGenericType" when calledMethod.DeclaringType.Name == "Type": {
				// Don't care about the actual arguments, but we don't want to lose track of the type
				// in case this is e.g. Activator.CreateInstance(typeof(Foo<>).MakeGenericType(...));
				methodReturnValue = methodParams [0];
			}
			break;

			//
			// static CreateInstance (System.Type type)
			// static CreateInstance (System.Type type, bool nonPublic)
			// static CreateInstance (System.Type type, params object?[]? args)
			// static CreateInstance (System.Type type, object?[]? args, object?[]? activationAttributes)
			// static CreateInstance (System.Type type, System.Reflection.BindingFlags bindingAttr, System.Reflection.Binder? binder, object?[]? args, System.Globalization.CultureInfo? culture)
			// static CreateInstance (System.Type type, System.Reflection.BindingFlags bindingAttr, System.Reflection.Binder? binder, object?[]? args, System.Globalization.CultureInfo? culture, object?[]? activationAttributes) { throw null; }
			//
			case "CreateInstance" when !calledMethod.ContainsGenericParameter
				&& calledMethod.DeclaringType.Name == "Activator"
				&& calledMethod.Parameters.Count >= 1
				&& calledMethod.Parameters[0].ParameterType.MetadataType != MetadataType.String: {

				var parameters = calledMethod.Parameters;

				reflectionContext.AnalyzingPattern ();

				// Constructor arity to look for; stays null (any arity) when it cannot be determined.
				int? ctorParameterCount = null;
				// Constructors are instance methods; visibility flags are added per overload below.
				BindingFlags bindingFlags = BindingFlags.Instance;
				if (parameters.Count > 1) {
					if (parameters [1].ParameterType.MetadataType == MetadataType.Boolean) {
						// The overload that takes a "nonPublic" bool
						// Assume the widest visibility unless the argument is a known constant.
						bool nonPublic = true;
						if (methodParams [1] is ConstIntValue constInt) {
							nonPublic = constInt.Value != 0;
						}

						if (nonPublic)
							bindingFlags |= BindingFlags.NonPublic | BindingFlags.Public;
						else
							bindingFlags |= BindingFlags.Public;
						ctorParameterCount = 0;
					} else {
						// Overload that has the parameters as the second or fourth argument
						int argsParam = parameters.Count == 2 || parameters.Count == 3 ? 1 : 3;

						// The arity is only known when the args array has a constant size.
						if (methodParams.Count > argsParam &&
							methodParams [argsParam] is ArrayValue arrayValue &&
							arrayValue.Size.AsConstInt () != null) {
							ctorParameterCount = arrayValue.Size.AsConstInt ();
						}

						if (parameters.Count > 3) {
							// Overloads with an explicit BindingFlags argument: use it when it is a
							// known constant, otherwise fall back to both visibilities.
							if (methodParams [1].AsConstInt () != null)
								bindingFlags |= (BindingFlags)methodParams [1].AsConstInt ();
							else
								bindingFlags |= BindingFlags.NonPublic | BindingFlags.Public;
						} else {
							bindingFlags |= BindingFlags.Public;
						}
					}
				}
				else {
					// The overload with a single System.Type argument
					ctorParameterCount = 0;
					bindingFlags |= BindingFlags.Public;
				}

				// Go over all types we've seen
				foreach (var value in methodParams[0].UniqueValues ()) {
					if (value is SystemTypeValue systemTypeValue) {
						MarkMethodsFromReflectionCall (ref reflectionContext, systemTypeValue.TypeRepresented, ".ctor", bindingFlags, ctorParameterCount);
					} else if (value == NullValue.Instance) {
						// Nothing to report. This is likely just a value on some unreachable branch.
						reflectionContext.RecordHandledPattern ();
					} else if (value is MethodParameterValue methodParameterValue) {
						// This is the case where the value comes from a method parameter.
						// TODO: If the parameter is annotated, we're good. If it's not annotated, we should warn.
						reflectionContext.RecordUnrecognizedPattern ($"Activator call '{calledMethod.FullName}' inside '{callingMethodBody.Method.FullName}' was detected with 1st argument expression which cannot be analyzed");
					} else {
						// Not known where the value is coming from
						reflectionContext.RecordUnrecognizedPattern ($"Activator call '{calledMethod.FullName}' inside '{callingMethodBody.Method.FullName}' was detected with 1st argument expression which cannot be analyzed");
					}
				}
			}
			break;
			default:
				return false;
		}
	}
	finally {
		// Runs on every path out of the try block, including the 'return false' above.
		reflectionContext.Dispose ();
	}

	// If we get here, we handled this as an intrinsic. As a convenience, if the code above
	// didn't set the return value (and the method has a return value), we will set it to be an
	// unknown value with the return type of the method.
	if (methodReturnValue == null) {
		if (calledMethod.ReturnType.MetadataType != MetadataType.Void) {
			methodReturnValue = UnknownValue.Instance;
		}
	}

	return true;
}

/// <summary>
/// Records every method of <paramref name="declaringType"/> that matches the given name, binding flags and
/// (optionally) parameter count as a recognized reflection pattern and marks it as indirectly called.
/// Records an unrecognized pattern when no method matches.
/// </summary>
/// <param name="reflectionContext">Context on which the recognized or unrecognized pattern is recorded.</param>
/// <param name="declaringType">Type whose methods are searched.</param>
/// <param name="name">Exact method name to match.</param>
/// <param name="bindingFlags">Static/instance and public/non-public filter. A filter is only applied when
/// exactly one flag of the respective pair is set.</param>
/// <param name="parametersCount">Required number of parameters, or null to accept any count.</param>
void MarkMethodsFromReflectionCall (ref ReflectionPatternContext reflectionContext, TypeDefinition declaringType, string name, BindingFlags? bindingFlags, int? parametersCount = null)
{
	bool foundMatch = false;
	foreach (var method in declaringType.Methods) {
		var mname = method.Name;

		if (mname != name) {
			continue;
		}

		// Static-only requested, but the method is an instance method.
		if ((bindingFlags & (BindingFlags.Instance | BindingFlags.Static)) == BindingFlags.Static && !method.IsStatic)
			continue;

		// Instance-only requested, but the method is static.
		if ((bindingFlags & (BindingFlags.Instance | BindingFlags.Static)) == BindingFlags.Instance && method.IsStatic)
			continue;

		// Public-only requested, but the method is not public.
		if ((bindingFlags & (BindingFlags.Public | BindingFlags.NonPublic)) == BindingFlags.Public && !method.IsPublic)
			continue;

		// Non-public-only requested, but the method is public.
		if ((bindingFlags & (BindingFlags.Public | BindingFlags.NonPublic)) == BindingFlags.NonPublic && method.IsPublic)
			continue;

		if (parametersCount != null && parametersCount != method.Parameters.Count)
			continue;

		foundMatch = true;
		reflectionContext.RecordRecognizedPattern (method, () => _markStep.MarkIndirectlyCalledMethod (method));
	}

	if (!foundMatch)
		reflectionContext.RecordUnrecognizedPattern ($"Reflection call '{reflectionContext.MethodCalled.FullName}' inside '{reflectionContext.MethodCalling.FullName}' could not resolve method `{name}` on type `{declaringType.FullName}`.");
}
}
}

// Make our own copy of the BindingFlags enum, so that we don't depend on System.Reflection.