diff --git a/src/Microsoft.ML.Data/Transforms/Hashing.cs b/src/Microsoft.ML.Data/Transforms/Hashing.cs index 6bd61f5d82..5bf272b50d 100644 --- a/src/Microsoft.ML.Data/Transforms/Hashing.cs +++ b/src/Microsoft.ML.Data/Transforms/Hashing.cs @@ -1359,7 +1359,7 @@ private bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, string srcVariable, stri OnnxNode murmurNode; OnnxNode isZeroNode; - var srcType = _srcTypes[iinfo].GetItemType(); + var srcType = _srcTypes[iinfo].GetItemType().RawType; if (_parent._columns[iinfo].Combine) return false; @@ -1386,9 +1386,9 @@ private bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, string srcVariable, stri } // Since these numeric types are not supported by Onnxruntime, we cast them to UInt32. - if (srcType == NumberDataViewType.UInt16 || srcType == NumberDataViewType.Int16 || - srcType == NumberDataViewType.SByte || srcType == NumberDataViewType.Byte || - srcType == BooleanDataViewType.Instance) + if (srcType == typeof(ushort) || srcType == typeof(short) || + srcType == typeof(sbyte) || srcType == typeof(byte) || + srcType == typeof(bool)) { castOutput = ctx.AddIntermediateVariable(NumberDataViewType.UInt32, "CastOutput", true); castNode = ctx.CreateNode("Cast", srcVariable, castOutput, ctx.GetNodeName(opType), ""); diff --git a/src/Microsoft.ML.OnnxTransformer/OnnxTransform.cs b/src/Microsoft.ML.OnnxTransformer/OnnxTransform.cs index 530e3193f3..9c5912f52e 100644 --- a/src/Microsoft.ML.OnnxTransformer/OnnxTransform.cs +++ b/src/Microsoft.ML.OnnxTransformer/OnnxTransform.cs @@ -389,13 +389,16 @@ public Mapper(OnnxTransformer parent, DataViewSchema inputSchema) : if (vectorType != null && vectorType.Size == 0) throw Host.Except($"Variable length input columns not supported"); - if (type.GetItemType() != inputNodeInfo.DataViewType.GetItemType()) + var itemType = type.GetItemType(); + var nodeItemType = inputNodeInfo.DataViewType.GetItemType(); + if (itemType != nodeItemType) { // If the ONNX model input node expects a type that mismatches with the type of the input IDataView column that is provided // then throw an exception. // This is done except in the case where the ONNX model input node expects a UInt32 but the input column is actually KeyDataViewType // This is done to support a corner case originated in NimbusML. For more info, see: https://github.com/microsoft/NimbusML/issues/426 - if (!(type.GetItemType() is KeyDataViewType && inputNodeInfo.DataViewType.GetItemType().RawType == typeof(UInt32))) + var isKeyType = itemType is KeyDataViewType; + if (!isKeyType || itemType.RawType != nodeItemType.RawType) throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", _parent.Inputs[i], inputNodeInfo.DataViewType.GetItemType().ToString(), type.ToString()); } diff --git a/test/Microsoft.ML.Tests/OnnxConversionTest.cs b/test/Microsoft.ML.Tests/OnnxConversionTest.cs index 6b5d032dcb..8d90cd81e7 100644 --- a/test/Microsoft.ML.Tests/OnnxConversionTest.cs +++ b/test/Microsoft.ML.Tests/OnnxConversionTest.cs @@ -1202,7 +1202,7 @@ private class HashData [Theory] [CombinatorialData] public void MurmurHashKeyTest( - [CombinatorialValues(/*DataKind.Byte, DataKind.UInt16, */DataKind.UInt32/*, DataKind.UInt64*/)]DataKind keyType) + [CombinatorialValues(DataKind.Byte, DataKind.UInt16, DataKind.UInt32, DataKind.UInt64)]DataKind keyType) { var dataFile = DeleteOutputPath("KeysToOnnx.txt"); File.WriteAllLines(dataFile,