From ad16cc6d0c2f8f913ad7ca5d0407d5c9f5c0fd2c Mon Sep 17 00:00:00 2001 From: "Harish S. Kulkarni" Date: Sun, 24 May 2020 13:42:48 -0700 Subject: [PATCH 1/3] Fixed onnx export for key types other than uint --- src/Microsoft.ML.Data/Transforms/Hashing.cs | 2 +- src/Microsoft.ML.OnnxTransformer/OnnxTransform.cs | 2 +- test/Microsoft.ML.Tests/OnnxConversionTest.cs | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/Microsoft.ML.Data/Transforms/Hashing.cs b/src/Microsoft.ML.Data/Transforms/Hashing.cs index 6bd61f5d82..981ae7b302 100644 --- a/src/Microsoft.ML.Data/Transforms/Hashing.cs +++ b/src/Microsoft.ML.Data/Transforms/Hashing.cs @@ -1388,7 +1388,7 @@ private bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, string srcVariable, stri // Since these numeric types are not supported by Onnxruntime, we cast them to UInt32. if (srcType == NumberDataViewType.UInt16 || srcType == NumberDataViewType.Int16 || srcType == NumberDataViewType.SByte || srcType == NumberDataViewType.Byte || - srcType == BooleanDataViewType.Instance) + srcType == BooleanDataViewType.Instance || srcType is KeyDataViewType) { castOutput = ctx.AddIntermediateVariable(NumberDataViewType.UInt32, "CastOutput", true); castNode = ctx.CreateNode("Cast", srcVariable, castOutput, ctx.GetNodeName(opType), ""); diff --git a/src/Microsoft.ML.OnnxTransformer/OnnxTransform.cs b/src/Microsoft.ML.OnnxTransformer/OnnxTransform.cs index 530e3193f3..dc7cceabf6 100644 --- a/src/Microsoft.ML.OnnxTransformer/OnnxTransform.cs +++ b/src/Microsoft.ML.OnnxTransformer/OnnxTransform.cs @@ -395,7 +395,7 @@ public Mapper(OnnxTransformer parent, DataViewSchema inputSchema) : // then throw an exception. // This is done except in the case where the ONNX model input node expects a UInt32 but the input column is actually KeyDataViewType // This is done to support a corner case originated in NimbusML. For more info, see: https://github.com/microsoft/NimbusML/issues/426 - if (!(type.GetItemType() is KeyDataViewType && inputNodeInfo.DataViewType.GetItemType().RawType == typeof(UInt32))) + if (!(type.GetItemType() is KeyDataViewType && type.GetItemType().RawType == inputNodeInfo.DataViewType.GetItemType().RawType)) throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", _parent.Inputs[i], inputNodeInfo.DataViewType.GetItemType().ToString(), type.ToString()); } diff --git a/test/Microsoft.ML.Tests/OnnxConversionTest.cs b/test/Microsoft.ML.Tests/OnnxConversionTest.cs index 6b5d032dcb..8d90cd81e7 100644 --- a/test/Microsoft.ML.Tests/OnnxConversionTest.cs +++ b/test/Microsoft.ML.Tests/OnnxConversionTest.cs @@ -1202,7 +1202,7 @@ private class HashData [Theory] [CombinatorialData] public void MurmurHashKeyTest( - [CombinatorialValues(/*DataKind.Byte, DataKind.UInt16, */DataKind.UInt32/*, DataKind.UInt64*/)]DataKind keyType) + [CombinatorialValues(DataKind.Byte, DataKind.UInt16, DataKind.UInt32, DataKind.UInt64)]DataKind keyType) { var dataFile = DeleteOutputPath("KeysToOnnx.txt"); File.WriteAllLines(dataFile, From 4bf1a5df45c12087b5fb1da83cbb9f2f65719a1a Mon Sep 17 00:00:00 2001 From: "Harish S. Kulkarni" Date: Mon, 25 May 2020 22:51:46 -0700 Subject: [PATCH 2/3] Fixed casting logic for hashing to be based on raw type in order to support key types correctly --- src/Microsoft.ML.Data/Transforms/Hashing.cs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/Microsoft.ML.Data/Transforms/Hashing.cs b/src/Microsoft.ML.Data/Transforms/Hashing.cs index 981ae7b302..5bf272b50d 100644 --- a/src/Microsoft.ML.Data/Transforms/Hashing.cs +++ b/src/Microsoft.ML.Data/Transforms/Hashing.cs @@ -1359,7 +1359,7 @@ private bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, string srcVariable, stri OnnxNode murmurNode; OnnxNode isZeroNode; - var srcType = _srcTypes[iinfo].GetItemType(); + var srcType = _srcTypes[iinfo].GetItemType().RawType; if (_parent._columns[iinfo].Combine) return false; @@ -1386,9 +1386,9 @@ private bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, string srcVariable, stri } // Since these numeric types are not supported by Onnxruntime, we cast them to UInt32. - if (srcType == NumberDataViewType.UInt16 || srcType == NumberDataViewType.Int16 || - srcType == NumberDataViewType.SByte || srcType == NumberDataViewType.Byte || - srcType == BooleanDataViewType.Instance || srcType is KeyDataViewType) + if (srcType == typeof(ushort) || srcType == typeof(short) || + srcType == typeof(sbyte) || srcType == typeof(byte) || + srcType == typeof(bool)) { castOutput = ctx.AddIntermediateVariable(NumberDataViewType.UInt32, "CastOutput", true); castNode = ctx.CreateNode("Cast", srcVariable, castOutput, ctx.GetNodeName(opType), ""); From 417b2d02fd9001577b523a428232848809354744 Mon Sep 17 00:00:00 2001 From: "Harish S. Kulkarni" Date: Mon, 25 May 2020 23:15:53 -0700 Subject: [PATCH 3/3] Addressed code review comments --- src/Microsoft.ML.OnnxTransformer/OnnxTransform.cs | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/Microsoft.ML.OnnxTransformer/OnnxTransform.cs b/src/Microsoft.ML.OnnxTransformer/OnnxTransform.cs index dc7cceabf6..9c5912f52e 100644 --- a/src/Microsoft.ML.OnnxTransformer/OnnxTransform.cs +++ b/src/Microsoft.ML.OnnxTransformer/OnnxTransform.cs @@ -389,13 +389,16 @@ public Mapper(OnnxTransformer parent, DataViewSchema inputSchema) : if (vectorType != null && vectorType.Size == 0) throw Host.Except($"Variable length input columns not supported"); - if (type.GetItemType() != inputNodeInfo.DataViewType.GetItemType()) + var itemType = type.GetItemType(); + var nodeItemType = inputNodeInfo.DataViewType.GetItemType(); + if (itemType != nodeItemType) { // If the ONNX model input node expects a type that mismatches with the type of the input IDataView column that is provided // then throw an exception. // This is done except in the case where the ONNX model input node expects a UInt32 but the input column is actually KeyDataViewType // This is done to support a corner case originated in NimbusML. For more info, see: https://github.com/microsoft/NimbusML/issues/426 - if (!(type.GetItemType() is KeyDataViewType && type.GetItemType().RawType == inputNodeInfo.DataViewType.GetItemType().RawType)) + var isKeyType = itemType is KeyDataViewType; + if (!isKeyType || itemType.RawType != nodeItemType.RawType) throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", _parent.Inputs[i], inputNodeInfo.DataViewType.GetItemType().ToString(), type.ToString()); }