From a4e0947be8b11fe3197f30e1699ec824d1641f6d Mon Sep 17 00:00:00 2001 From: kerams Date: Sun, 26 Jul 2020 00:48:59 +0200 Subject: [PATCH] Optimize reflection of F# types, part 2 (#9784) * Compile PreComputeRecordConstructor, PreComputeRecordReader, PreComputeRecordFieldReader * Compile PreComputeUnionConstructor, PreComputeUnionTagReader, PreComputeUnionReader --- src/fsharp/FSharp.Core/reflect.fs | 134 ++++++++++++++++-- .../FSharpReflection.fs | 45 ++++++ 2 files changed, 166 insertions(+), 13 deletions(-) diff --git a/src/fsharp/FSharp.Core/reflect.fs b/src/fsharp/FSharp.Core/reflect.fs index 56bcc317d0c..aad98b162f2 100644 --- a/src/fsharp/FSharp.Core/reflect.fs +++ b/src/fsharp/FSharp.Core/reflect.fs @@ -64,6 +64,9 @@ module internal Impl = | null -> None | prop -> Some(fun (obj: obj) -> prop.GetValue (obj, instancePropertyFlags ||| bindingFlags, null, null, null)) + //----------------------------------------------------------------- + // EXPRESSION TREE COMPILATION + let compilePropGetterFunc (prop: PropertyInfo) = let param = Expression.Parameter (typeof, "param") @@ -77,6 +80,84 @@ module internal Impl = param) expr.Compile () + let compileRecordOrUnionCaseReaderFunc (typ, props: PropertyInfo[]) = + let param = Expression.Parameter (typeof, "param") + let typedParam = Expression.Variable typ + + let expr = + Expression.Lambda> ( + Expression.Block ( + [ typedParam ], + Expression.Assign (typedParam, Expression.Convert (param, typ)), + Expression.NewArrayInit (typeof, [ + for prop in props -> + Expression.Convert (Expression.Property (typedParam, prop), typeof) :> Expression + ]) + ), + param) + expr.Compile () + + let compileRecordConstructorFunc (ctorInfo: ConstructorInfo) = + let ctorParams = ctorInfo.GetParameters () + let paramArray = Expression.Parameter (typeof, "paramArray") + + let expr = + Expression.Lambda> ( + Expression.Convert ( + Expression.New ( + ctorInfo, + [ + for paramIndex in 0 .. ctorParams.Length - 1 do + let p = ctorParams.[paramIndex] + + Expression.Convert ( + Expression.ArrayAccess (paramArray, Expression.Constant paramIndex), + p.ParameterType + ) :> Expression + ] + ), + typeof), + paramArray + ) + expr.Compile () + + let compileUnionCaseConstructorFunc (methodInfo: MethodInfo) = + let methodParams = methodInfo.GetParameters () + let paramArray = Expression.Parameter (typeof, "param") + + let expr = + Expression.Lambda> ( + Expression.Convert ( + Expression.Call ( + methodInfo, + [ + for paramIndex in 0 .. methodParams.Length - 1 do + let p = methodParams.[paramIndex] + + Expression.Convert ( + Expression.ArrayAccess (paramArray, Expression.Constant paramIndex), + p.ParameterType + ) :> Expression + ] + ), + typeof), + paramArray + ) + expr.Compile () + + let compileUnionTagReaderFunc (info: Choice) = + let param = Expression.Parameter (typeof, "param") + let tag = + match info with + | Choice1Of2 info -> Expression.Call (info, Expression.Convert (param, info.DeclaringType)) :> Expression + | Choice2Of2 info -> Expression.Property (Expression.Convert (param, info.DeclaringType), info) :> _ + + let expr = + Expression.Lambda> ( + tag, + param) + expr.Compile () + //----------------------------------------------------------------- // ATTRIBUTE DECOMPILATION @@ -275,6 +356,12 @@ module internal Impl = let props = fieldsPropsOfUnionCase (typ, tag, bindingFlags) (fun (obj: obj) -> props |> Array.map (fun prop -> prop.GetValue (obj, bindingFlags, null, null, null))) + let getUnionCaseRecordReaderCompiled (typ: Type, tag: int, bindingFlags) = + let props = fieldsPropsOfUnionCase (typ, tag, bindingFlags) + let caseTyp = getUnionCaseTyp (typ, tag, bindingFlags) + let caseTyp = if isNull caseTyp then typ else caseTyp + compileRecordOrUnionCaseReaderFunc(caseTyp, props).Invoke + let getUnionTagReader (typ: Type, bindingFlags) : (obj -> int) = if isOptionType typ then (fun (obj: obj) -> match obj with null -> 0 | _ -> 1) @@ -286,9 +373,22 @@ module internal Impl = match getInstancePropertyReader (typ, "Tag", bindingFlags) with | Some reader -> (fun (obj: obj) -> reader obj :?> int) | None -> - (fun (obj: obj) -> - let m2b = typ.GetMethod("GetTag", BindingFlags.Static ||| bindingFlags, null, [| typ |], null) - m2b.Invoke(null, [|obj|]) :?> int) + let m2b = typ.GetMethod("GetTag", BindingFlags.Static ||| bindingFlags, null, [| typ |], null) + (fun (obj: obj) -> m2b.Invoke(null, [|obj|]) :?> int) + + let getUnionTagReaderCompiled (typ: Type, bindingFlags) : (obj -> int) = + if isOptionType typ then + (fun (obj: obj) -> match obj with null -> 0 | _ -> 1) + else + let tagMap = getUnionTypeTagNameMap (typ, bindingFlags) + if tagMap.Length <= 1 then + (fun (_obj: obj) -> 0) + else + match getInstancePropertyInfo (typ, "Tag", bindingFlags) with + | null -> + let m2b = typ.GetMethod("GetTag", BindingFlags.Static ||| bindingFlags, null, [| typ |], null) + compileUnionTagReaderFunc(Choice1Of2 m2b).Invoke + | info -> compileUnionTagReaderFunc(Choice2Of2 info).Invoke let getUnionTagMemberInfo (typ: Type, bindingFlags) = match getInstancePropertyInfo (typ, "Tag", bindingFlags) with @@ -314,6 +414,10 @@ module internal Impl = (fun args -> meth.Invoke(null, BindingFlags.Static ||| BindingFlags.InvokeMethod ||| bindingFlags, null, args, null)) + let getUnionCaseConstructorCompiled (typ: Type, tag: int, bindingFlags) = + let meth = getUnionCaseConstructorMethod (typ, tag, bindingFlags) + compileUnionCaseConstructorFunc(meth).Invoke + let checkUnionType (unionType, bindingFlags) = checkNonNull "unionType" unionType if not (isUnionType (unionType, bindingFlags)) then @@ -599,9 +703,9 @@ module internal Impl = let props = fieldPropsOfRecordType(typ, bindingFlags) (fun (obj: obj) -> props |> Array.map (fun prop -> prop.GetValue (obj, null))) - let getRecordReaderFromFuncs(typ: Type, bindingFlags) = - let props = fieldPropsOfRecordType(typ, bindingFlags) |> Array.map compilePropGetterFunc - (fun (obj: obj) -> props |> Array.map (fun prop -> prop.Invoke obj)) + let getRecordReaderCompiled(typ: Type, bindingFlags) = + let props = fieldPropsOfRecordType(typ, bindingFlags) + compileRecordOrUnionCaseReaderFunc(typ, props).Invoke let getRecordConstructorMethod(typ: Type, bindingFlags) = let props = fieldPropsOfRecordType(typ, bindingFlags) @@ -616,6 +720,10 @@ module internal Impl = (fun (args: obj[]) -> ctor.Invoke(BindingFlags.InvokeMethod ||| BindingFlags.Instance ||| bindingFlags, null, args, null)) + let getRecordConstructorCompiled(typ: Type, bindingFlags) = + let ctor = getRecordConstructorMethod(typ, bindingFlags) + compileRecordConstructorFunc(ctor).Invoke + /// EXCEPTION DECOMPILATION // Check the base type - if it is also an F# type then // for the moment we know it is a Discriminated Union @@ -817,19 +925,19 @@ type FSharpValue = invalidArg "record" (SR.GetString (SR.objIsNotARecord)) getRecordReader (typ, bindingFlags) record - static member PreComputeRecordFieldReader(info: PropertyInfo) = + static member PreComputeRecordFieldReader(info: PropertyInfo): obj -> obj = checkNonNull "info" info - (fun (obj: obj) -> info.GetValue (obj, null)) + compilePropGetterFunc(info).Invoke static member PreComputeRecordReader(recordType: Type, ?bindingFlags) : (obj -> obj[]) = let bindingFlags = defaultArg bindingFlags BindingFlags.Public checkRecordType ("recordType", recordType, bindingFlags) - getRecordReaderFromFuncs (recordType, bindingFlags) + getRecordReaderCompiled (recordType, bindingFlags) static member PreComputeRecordConstructor(recordType: Type, ?bindingFlags) = let bindingFlags = defaultArg bindingFlags BindingFlags.Public checkRecordType ("recordType", recordType, bindingFlags) - getRecordConstructor (recordType, bindingFlags) + getRecordConstructorCompiled (recordType, bindingFlags) static member PreComputeRecordConstructorInfo(recordType: Type, ?bindingFlags) = let bindingFlags = defaultArg bindingFlags BindingFlags.Public @@ -894,7 +1002,7 @@ type FSharpValue = static member PreComputeUnionConstructor (unionCase: UnionCaseInfo, ?bindingFlags) = let bindingFlags = defaultArg bindingFlags BindingFlags.Public checkNonNull "unionCase" unionCase - getUnionCaseConstructor (unionCase.DeclaringType, unionCase.Tag, bindingFlags) + getUnionCaseConstructorCompiled (unionCase.DeclaringType, unionCase.Tag, bindingFlags) static member PreComputeUnionConstructorInfo(unionCase: UnionCaseInfo, ?bindingFlags) = let bindingFlags = defaultArg bindingFlags BindingFlags.Public @@ -926,7 +1034,7 @@ type FSharpValue = checkNonNull "unionType" unionType let unionType = getTypeOfReprType (unionType, bindingFlags) checkUnionType (unionType, bindingFlags) - getUnionTagReader (unionType, bindingFlags) + getUnionTagReaderCompiled (unionType, bindingFlags) static member PreComputeUnionTagMemberInfo(unionType: Type, ?bindingFlags) = let bindingFlags = defaultArg bindingFlags BindingFlags.Public @@ -939,7 +1047,7 @@ type FSharpValue = let bindingFlags = defaultArg bindingFlags BindingFlags.Public checkNonNull "unionCase" unionCase let typ = unionCase.DeclaringType - getUnionCaseRecordReader (typ, unionCase.Tag, bindingFlags) + getUnionCaseRecordReaderCompiled (typ, unionCase.Tag, bindingFlags) static member GetExceptionFields (exn: obj, ?bindingFlags) = let bindingFlags = defaultArg bindingFlags BindingFlags.Public diff --git a/tests/FSharp.Core.UnitTests/FSharp.Core/Microsoft.FSharp.Reflection/FSharpReflection.fs b/tests/FSharp.Core.UnitTests/FSharp.Core/Microsoft.FSharp.Reflection/FSharpReflection.fs index 6bd49f8db4b..a68724f97f2 100644 --- a/tests/FSharp.Core.UnitTests/FSharp.Core/Microsoft.FSharp.Reflection/FSharpReflection.fs +++ b/tests/FSharp.Core.UnitTests/FSharp.Core/Microsoft.FSharp.Reflection/FSharpReflection.fs @@ -103,6 +103,14 @@ type FSharpValueTests() = let discStructUnionCaseB = DiscStructUnionType.B(1) let discStructUnionCaseC = DiscStructUnionType.C(1.0, "stringparam") + let optionSome = Some(3) + let optionNone: int option = None + + let voptionSome = ValueSome("stringparam") + let voptionNone: string voption = ValueNone + + let list1 = [ 1; 2 ] + let list2: int list = [] let fsharpDelegate1 = new FSharpDelegate(fun (x:int) -> "delegate1") let fsharpDelegate2 = new FSharpDelegate(fun (x:int) -> "delegate2") @@ -738,6 +746,24 @@ type FSharpValueTests() = let (discUnionInfo, discvaluearray) = FSharpValue.GetUnionFields(discUnionRecCaseB, typeof>) let discUnionReader = FSharpValue.PreComputeUnionReader(discUnionInfo) Assert.AreEqual(discUnionReader(box(discUnionRecCaseB)) , [| box 1; box(Some(discUnionCaseB)) |]) + + // Option + let (optionCaseInfo, _) = FSharpValue.GetUnionFields(optionSome, typeof) + let optionReader = FSharpValue.PreComputeUnionReader(optionCaseInfo) + Assert.AreEqual(optionReader(box(optionSome)), [| box 3 |]) + + let (optionCaseInfo, _) = FSharpValue.GetUnionFields(optionNone, typeof) + let optionReader = FSharpValue.PreComputeUnionReader(optionCaseInfo) + Assert.AreEqual(optionReader(box(optionNone)), [| |]) + + // List + let (listCaseInfo, _) = FSharpValue.GetUnionFields(list1, typeof) + let listReader = FSharpValue.PreComputeUnionReader(listCaseInfo) + Assert.AreEqual(listReader(box(list1)), [| box 1; box [ 2 ] |]) + + let (listCaseInfo, _) = FSharpValue.GetUnionFields(list2, typeof) + let listReader = FSharpValue.PreComputeUnionReader(listCaseInfo) + Assert.AreEqual(listReader(box(list2)), [| |]) [] member __.PreComputeStructUnionReader() = @@ -751,6 +777,15 @@ type FSharpValueTests() = let (discUnionInfo, discvaluearray) = FSharpValue.GetUnionFields(discStructUnionCaseB, typeof>) let discUnionReader = FSharpValue.PreComputeUnionReader(discUnionInfo) Assert.AreEqual(discUnionReader(box(discStructUnionCaseB)) , [| box 1|]) + + // Value Option + let (voptionCaseInfo, _) = FSharpValue.GetUnionFields(voptionSome, typeof) + let voptionReader = FSharpValue.PreComputeUnionReader(voptionCaseInfo) + Assert.AreEqual(voptionReader(box(voptionSome)), [| box "stringparam" |]) + + let (voptionCaseInfo, _) = FSharpValue.GetUnionFields(voptionNone, typeof) + let voptionReader = FSharpValue.PreComputeUnionReader(voptionCaseInfo) + Assert.AreEqual(voptionReader(box(voptionNone)), [| |]) [] member __.PreComputeUnionTagMemberInfo() = @@ -790,6 +825,16 @@ type FSharpValueTests() = // DiscUnion let discUnionTagReader = FSharpValue.PreComputeUnionTagReader(typeof>) Assert.AreEqual(discUnionTagReader(box(discUnionCaseB)), 1) + + // Option + let optionTagReader = FSharpValue.PreComputeUnionTagReader(typeof) + Assert.AreEqual(optionTagReader(box(optionSome)), 1) + Assert.AreEqual(optionTagReader(box(optionNone)), 0) + + // Value Option + let voptionTagReader = FSharpValue.PreComputeUnionTagReader(typeof) + Assert.AreEqual(voptionTagReader(box(voptionSome)), 1) + Assert.AreEqual(voptionTagReader(box(voptionNone)), 0) // null value CheckThrowsArgumentException(fun () ->FSharpValue.PreComputeUnionTagReader(null)|> ignore)