diff --git a/src/Microsoft.ML.Core/Utilities/Hashing.cs b/src/Microsoft.ML.Core/Utilities/Hashing.cs index 4438293317..f551c6ccac 100644 --- a/src/Microsoft.ML.Core/Utilities/Hashing.cs +++ b/src/Microsoft.ML.Core/Utilities/Hashing.cs @@ -3,6 +3,7 @@ // See the LICENSE file in the project root for more information. using System; +using System.Buffers.Binary; using System.Runtime.CompilerServices; using System.Text; using Microsoft.ML.Runtime; @@ -174,6 +175,87 @@ public static uint MurmurHash(uint hash, ReadOnlySpan span, bool toUpper = return hash; } + public static uint MurmurHashV2(uint hash, ReadOnlySpan span, bool toUpper = false) + { + // Byte length (in pseudo UTF-8 form). + int len = 0; + + // Current bits, value and count. + ulong cur = 0; + int bits = 0; + for (int ich = 0; ich < span.Length; ich++) + { + Contracts.Assert((bits & 0x7) == 0); + Contracts.Assert((uint)bits <= 24); + Contracts.Assert(cur <= 0x00FFFFFF); + + uint ch = toUpper ? char.ToUpperInvariant(span[ich]) : span[ich]; + if (ch <= 0x007F) + { + cur |= ch << bits; + bits += 8; + } + else if (ch <= 0x07FF) + { + cur |= (ulong)((ch & 0x003F) | ((ch << 2) & 0x1F00) | 0xC080) << bits; + cur = (cur & 0xFF) << 8 | cur >> 8; + bits += 16; + } + else if (ch <= 0xFFFF) + { + cur |= (ulong)((ch & 0x003F) | ((ch << 2) & 0x3F00) | ((ch << 4) & 0x0F0000) | 0xE08080) << bits; + cur = (cur & 0xFF) << 16 | ((cur >> 8) & 0xFF) << 8 | cur >> 16; + bits += 24; + } + else + { + Contracts.Assert(ch <= 0x10FFFF); + cur |= (ulong)((ch & 0x003F) | ((ch << 2) & 0x3F00) | ((ch << 4) & 0x3F0000) | ((ch << 6) & 0x07000000) | 0xF0808080) << bits; + cur = (cur & 0xFF) << 24 | ((cur >> 8) & 0xFF) << 16 | ((cur >> 16) & 0xFF) << 8 | cur >> 24; + bits += 32; + } + + if (bits >= 32) + { + hash = MurmurRound(hash, (uint)cur); + cur = cur >> 32; + bits -= 32; + len += 4; + } + } + Contracts.Assert((bits & 0x7) == 0); + Contracts.Assert((uint)bits <= 24); + Contracts.Assert(cur <= 0x00FFFFFF); + + if (bits > 0) + { + len += bits / 8; + } + + // tail processing + uint k1 = 0; + switch (len & 3) + { + case 3: + k1 ^= (uint)(((cur >> 16) & 0xFF) << 16); + goto case 2; + case 2: + k1 ^= (uint)((cur >> 8) & 0xFF) << 8; + goto case 1; + case 1: + k1 ^= (uint)(cur & 0xFF); + k1 *= 0xCC9E2D51; k1 = Rotate(k1, 15); + k1 *= 0x1B873593; + hash ^= k1; + break; + } + + // Final mixing ritual for the hash. + hash = MixHashV2(hash, len); + + return hash; + } + /// /// Implements the murmur hash 3 algorithm, using a mock UTF-8 encoding. /// The UTF-8 conversion ignores the possibilities of unicode planes other than the 0th. @@ -284,6 +366,18 @@ public static uint MixHash(uint hash) return hash; } + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static uint MixHashV2(uint hash, int len) + { + hash ^= (uint)len; + hash ^= hash >> 16; + hash *= 0x85ebca6b; + hash ^= hash >> 13; + hash *= 0xc2b2ae35; + hash ^= hash >> 16; + return hash; + } + [MethodImpl(MethodImplOptions.AggressiveInlining)] private static uint Rotate(uint x, int r) { diff --git a/src/Microsoft.ML.Data/Transforms/Hashing.cs b/src/Microsoft.ML.Data/Transforms/Hashing.cs index 6f8251a51d..9ac10aacfc 100644 --- a/src/Microsoft.ML.Data/Transforms/Hashing.cs +++ b/src/Microsoft.ML.Data/Transforms/Hashing.cs @@ -11,6 +11,7 @@ using Microsoft.ML.CommandLine; using Microsoft.ML.Data; using Microsoft.ML.Internal.Utilities; +using Microsoft.ML.Model.OnnxConverter; using Microsoft.ML.Runtime; using Microsoft.ML.Transforms; @@ -122,8 +123,9 @@ private static VersionInfo GetVersionInfo() return new VersionInfo( modelSignature: "HASHTRNS", // verWrittenCur: 0x00010001, // Initial - verWrittenCur: 0x00010002, // Invert hash key values, hash fix - verReadableCur: 0x00010002, + //verWrittenCur: 0x00010002, // Invert hash key values, hash fix + verWrittenCur: 0x00010003, + verReadableCur: 0x00010003, verWeCanReadBack: 0x00010002, loaderSignature: LoaderSignature, loaderAssemblyName: typeof(HashingTransformer).Assembly.FullName); @@ -245,9 +247,15 @@ private Delegate GetGetterCore(DataViewRow input, int iinfo, out Action disposer disposer = null; input.Schema.TryGetColumnIndex(_columns[iinfo].InputColumnName, out int srcCol); var srcType = input.Schema[srcCol].Type; - if (!(srcType is VectorDataViewType vectorType)) - return ComposeGetterOne(input, iinfo, srcCol, srcType); - return ComposeGetterVec(input, iinfo, srcCol, vectorType); + if (GetVersionInfo().VerWrittenCur == 0x00010002) + { + if (!(srcType is VectorDataViewType vectorType)) + return ComposeGetterOne(input, iinfo, srcCol, srcType); + return ComposeGetterVec(input, iinfo, srcCol, vectorType); + } + if (!(srcType is VectorDataViewType vectorType2)) + return ComposeGetterOneV2(input, iinfo, srcCol, srcType); + return ComposeGetterVecV2(input, iinfo, srcCol, vectorType2); } private protected override IRowMapper MakeRowMapper(DataViewSchema schema) => new Mapper(this, schema); @@ -378,6 +386,58 @@ private ValueGetter ComposeGetterOne(DataViewRow input, int iinfo, int src return MakeScalarHashGetter(input, srcCol, seed, mask); } + private ValueGetter ComposeGetterOneV2(DataViewRow input, int iinfo, int srcCol, DataViewType srcType) + { + Host.Assert(HashingEstimator.IsColumnTypeValid(srcType)); + + var mask = (1U << _columns[iinfo].NumberOfBits) - 1; + uint seed = _columns[iinfo].Seed; + // In case of single valued input column, hash in 0 for the slot index. + if (_columns[iinfo].UseOrderedHashing) + seed = Hashing.MurmurRound(seed, 0); + + if (srcType is KeyDataViewType) + { + if (srcType.RawType == typeof(uint)) + return MakeScalarHashGetter(input, srcCol, seed, mask); + else if (srcType.RawType == typeof(ulong)) + return MakeScalarHashGetter(input, srcCol, seed, mask); + else if (srcType.RawType == typeof(ushort)) + return MakeScalarHashGetter(input, srcCol, seed, mask); + + Host.Assert(srcType.RawType == typeof(byte)); + return MakeScalarHashGetter(input, srcCol, seed, mask); + } + + if (srcType.RawType == typeof(ReadOnlyMemory)) + return MakeScalarHashGetter, HashTextV2>(input, srcCol, seed, mask); + else if (srcType.RawType == typeof(float)) + return MakeScalarHashGetter(input, srcCol, seed, mask); + else if (srcType.RawType == typeof(double)) + return MakeScalarHashGetter(input, srcCol, seed, mask); + else if (srcType.RawType == typeof(sbyte)) + return MakeScalarHashGetter(input, srcCol, seed, mask); + else if (srcType.RawType == typeof(short)) + return MakeScalarHashGetter(input, srcCol, seed, mask); + else if (srcType.RawType == typeof(int)) + return MakeScalarHashGetter(input, srcCol, seed, mask); + else if (srcType.RawType == typeof(long)) + return MakeScalarHashGetter(input, srcCol, seed, mask); + else if (srcType.RawType == typeof(byte)) + return MakeScalarHashGetter(input, srcCol, seed, mask); + else if (srcType.RawType == typeof(ushort)) + return MakeScalarHashGetter(input, srcCol, seed, mask); + else if (srcType.RawType == typeof(uint)) + return MakeScalarHashGetter(input, srcCol, seed, mask); + else if (srcType.RawType == typeof(ulong)) + return MakeScalarHashGetter(input, srcCol, seed, mask); + else if (srcType.RawType == typeof(DataViewRowId)) + return MakeScalarHashGetter(input, srcCol, seed, mask); + + Host.Assert(srcType.RawType == typeof(bool)); + return MakeScalarHashGetter(input, srcCol, seed, mask); + } + private ValueGetter> ComposeGetterVec(DataViewRow input, int iinfo, int srcCol, VectorDataViewType srcType) { Host.Assert(HashingEstimator.IsColumnTypeValid(srcType.ItemType)); @@ -425,6 +485,53 @@ private ValueGetter> ComposeGetterVec(DataViewRow input, int iinfo return ComposeGetterVecCore, HashText>(input, iinfo, srcCol, srcType); } + private ValueGetter> ComposeGetterVecV2(DataViewRow input, int iinfo, int srcCol, VectorDataViewType srcType) + { + Host.Assert(HashingEstimator.IsColumnTypeValid(srcType.ItemType)); + + Type rawType = srcType.ItemType.RawType; + if (srcType.ItemType is KeyDataViewType) + { + if (rawType == typeof(byte)) + return ComposeGetterVecCore(input, iinfo, srcCol, srcType); + else if (rawType == typeof(ushort)) + return ComposeGetterVecCore(input, iinfo, srcCol, srcType); + else if (rawType == typeof(uint)) + return ComposeGetterVecCore(input, iinfo, srcCol, srcType); + + Host.Assert(rawType == typeof(ulong)); + return ComposeGetterVecCore(input, iinfo, srcCol, srcType); + } + + if (rawType == typeof(byte)) + return ComposeGetterVecCore(input, iinfo, srcCol, srcType); + else if (rawType == typeof(ushort)) + return ComposeGetterVecCore(input, iinfo, srcCol, srcType); + else if (rawType == typeof(uint)) + return ComposeGetterVecCore(input, iinfo, srcCol, srcType); + else if (rawType == typeof(ulong)) + return ComposeGetterVecCore(input, iinfo, srcCol, srcType); + else if (rawType == typeof(DataViewRowId)) + return ComposeGetterVecCore(input, iinfo, srcCol, srcType); + else if (rawType == typeof(sbyte)) + return ComposeGetterVecCore(input, iinfo, srcCol, srcType); + else if (rawType == typeof(short)) + return ComposeGetterVecCore(input, iinfo, srcCol, srcType); + else if (rawType == typeof(int)) + return ComposeGetterVecCore(input, iinfo, srcCol, srcType); + else if (rawType == typeof(long)) + return ComposeGetterVecCore(input, iinfo, srcCol, srcType); + else if (rawType == typeof(float)) + return ComposeGetterVecCore(input, iinfo, srcCol, srcType); + else if (rawType == typeof(double)) + return ComposeGetterVecCore(input, iinfo, srcCol, srcType); + else if (rawType == typeof(bool)) + return ComposeGetterVecCore(input, iinfo, srcCol, srcType); + + Host.Assert(srcType.ItemType == TextDataViewType.Instance); + return ComposeGetterVecCore, HashTextV2>(input, iinfo, srcCol, srcType); + } + private ValueGetter> ComposeGetterVecCore(DataViewRow input, int iinfo, int srcCol, VectorDataViewType srcType) where THash : struct, IHasher { @@ -472,6 +579,13 @@ public uint HashCore(uint seed, uint mask, in float value) => float.IsNaN(value) ? 0 : (Hashing.MixHash(Hashing.MurmurRound(seed, FloatUtils.GetBits(value == 0 ? 0 : value))) & mask) + 1; } + private readonly struct HashFloatV2 : IHasher + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public uint HashCore(uint seed, uint mask, in float value) + => float.IsNaN(value) ? 0 : (Hashing.MixHashV2(Hashing.MurmurRound(seed, FloatUtils.GetBits(value == 0 ? 0 : value)), 4) & mask); + } + private readonly struct HashDouble : IHasher { [MethodImpl(MethodImplOptions.AggressiveInlining)] @@ -490,6 +604,23 @@ public uint HashCore(uint seed, uint mask, in double value) } } + private readonly struct HashDoubleV2 : IHasher + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public uint HashCore(uint seed, uint mask, in double value) + { + if (double.IsNaN(value)) + return 0; + + ulong v = FloatUtils.GetBits(value == 0 ? 0 : value); + var hash = Hashing.MurmurRound(seed, Utils.GetLo(v)); + var hi = Utils.GetHi(v); + if (hi != 0) + hash = Hashing.MurmurRound(hash, hi); + return (Hashing.MixHashV2(hash, 4) & mask); + } + } + private readonly struct HashText : IHasher> { [MethodImpl(MethodImplOptions.AggressiveInlining)] @@ -497,6 +628,13 @@ public uint HashCore(uint seed, uint mask, in ReadOnlyMemory value) => value.IsEmpty ? 0 : (Hashing.MurmurHash(seed, value.Span.Trim(' ')) & mask) + 1; } + private readonly struct HashTextV2 : IHasher> + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public uint HashCore(uint seed, uint mask, in ReadOnlyMemory value) + => value.IsEmpty ? 0 : (Hashing.MurmurHashV2(seed, value.Span)); + } + private readonly struct HashKey1 : IHasher { [MethodImpl(MethodImplOptions.AggressiveInlining)] @@ -504,6 +642,13 @@ public uint HashCore(uint seed, uint mask, in byte value) => value == 0 ? 0 : (Hashing.MixHash(Hashing.MurmurRound(seed, value)) & mask) + 1; } + private readonly struct HashKey1V2 : IHasher + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public uint HashCore(uint seed, uint mask, in byte value) + => value == 0 ? 0 : (Hashing.MixHashV2(Hashing.MurmurRound(seed, value), 4) & mask) + 1; + } + private readonly struct HashKey2 : IHasher { [MethodImpl(MethodImplOptions.AggressiveInlining)] @@ -511,6 +656,13 @@ public uint HashCore(uint seed, uint mask, in ushort value) => value == 0 ? 0 : (Hashing.MixHash(Hashing.MurmurRound(seed, value)) & mask) + 1; } + private readonly struct HashKey2V2 : IHasher + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public uint HashCore(uint seed, uint mask, in ushort value) + => value == 0 ? 0 : (Hashing.MixHashV2(Hashing.MurmurRound(seed, value), 4) & mask) + 1; + } + private readonly struct HashKey4 : IHasher { [MethodImpl(MethodImplOptions.AggressiveInlining)] @@ -518,6 +670,13 @@ public uint HashCore(uint seed, uint mask, in uint value) => value == 0 ? 0 : (Hashing.MixHash(Hashing.MurmurRound(seed, value)) & mask) + 1; } + private readonly struct HashKey4V2 : IHasher + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public uint HashCore(uint seed, uint mask, in uint value) + => value == 0 ? 0 : (Hashing.MixHashV2(Hashing.MurmurRound(seed, value), 4) & mask) + 1; + } + private readonly struct HashKey8 : IHasher { [MethodImpl(MethodImplOptions.AggressiveInlining)] @@ -533,6 +692,21 @@ public uint HashCore(uint seed, uint mask, in ulong value) } } + private readonly struct HashKey8V2 : IHasher + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public uint HashCore(uint seed, uint mask, in ulong value) + { + if (value == 0) + return 0; + var hash = Hashing.MurmurRound(seed, Utils.GetLo(value)); + var hi = Utils.GetHi(value); + if (hi != 0) + hash = Hashing.MurmurRound(hash, hi); + return (Hashing.MixHashV2(hash, 4) & mask) + 1; + } + } + private readonly struct HashU1 : IHasher { [MethodImpl(MethodImplOptions.AggressiveInlining)] @@ -540,6 +714,13 @@ public uint HashCore(uint seed, uint mask, in byte value) => (Hashing.MixHash(Hashing.MurmurRound(seed, value)) & mask) + 1; } + private readonly struct HashU1V2 : IHasher + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public uint HashCore(uint seed, uint mask, in byte value) + => (Hashing.MixHashV2(Hashing.MurmurRound(seed, value), 4) & mask); + } + private readonly struct HashU2 : IHasher { [MethodImpl(MethodImplOptions.AggressiveInlining)] @@ -547,6 +728,13 @@ public uint HashCore(uint seed, uint mask, in ushort value) => (Hashing.MixHash(Hashing.MurmurRound(seed, value)) & mask) + 1; } + private readonly struct HashU2V2 : IHasher + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public uint HashCore(uint seed, uint mask, in ushort value) + => (Hashing.MixHashV2(Hashing.MurmurRound(seed, value), 4) & mask); + } + private readonly struct HashU4 : IHasher { [MethodImpl(MethodImplOptions.AggressiveInlining)] @@ -554,6 +742,13 @@ public uint HashCore(uint seed, uint mask, in uint value) => (Hashing.MixHash(Hashing.MurmurRound(seed, value)) & mask) + 1; } + private readonly struct HashU4V2 : IHasher + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public uint HashCore(uint seed, uint mask, in uint value) + => (Hashing.MixHashV2(Hashing.MurmurRound(seed, value), 4)) & mask; + } + private readonly struct HashU8 : IHasher { [MethodImpl(MethodImplOptions.AggressiveInlining)] @@ -567,6 +762,19 @@ public uint HashCore(uint seed, uint mask, in ulong value) } } + private readonly struct HashU8V2 : IHasher + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public uint HashCore(uint seed, uint mask, in ulong value) + { + var hash = Hashing.MurmurRound(seed, Utils.GetLo(value)); + var hi = Utils.GetHi(value); + if (hi != 0) + hash = Hashing.MurmurRound(hash, hi); + return (Hashing.MixHashV2(hash, 4) & mask) + 1; + } + } + private readonly struct HashU16 : IHasher { [MethodImpl(MethodImplOptions.AggressiveInlining)] @@ -594,6 +802,13 @@ public uint HashCore(uint seed, uint mask, in bool value) => (Hashing.MixHash(Hashing.MurmurRound(seed, value ? 1u : 0u)) & mask) + 1; } + private readonly struct HashBoolV2 : IHasher + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public uint HashCore(uint seed, uint mask, in bool value) + => (Hashing.MixHashV2(Hashing.MurmurRound(seed, value ? 1u : 0u), 4) & mask); + } + private readonly struct HashI1 : IHasher { [MethodImpl(MethodImplOptions.AggressiveInlining)] @@ -601,6 +816,13 @@ public uint HashCore(uint seed, uint mask, in sbyte value) => (Hashing.MixHash(Hashing.MurmurRound(seed, (uint)value)) & mask) + 1; } + private readonly struct HashI1V2 : IHasher + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public uint HashCore(uint seed, uint mask, in sbyte value) + => (Hashing.MixHashV2(Hashing.MurmurRound(seed, (uint)value), 4) & mask); + } + private readonly struct HashI2 : IHasher { [MethodImpl(MethodImplOptions.AggressiveInlining)] @@ -608,6 +830,13 @@ public uint HashCore(uint seed, uint mask, in short value) => (Hashing.MixHash(Hashing.MurmurRound(seed, (uint)value)) & mask) + 1; } + private readonly struct HashI2V2 : IHasher + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public uint HashCore(uint seed, uint mask, in short value) + => (Hashing.MixHashV2(Hashing.MurmurRound(seed, (uint)value), 4) & mask); + } + private readonly struct HashI4 : IHasher { [MethodImpl(MethodImplOptions.AggressiveInlining)] @@ -615,6 +844,13 @@ public uint HashCore(uint seed, uint mask, in int value) => (Hashing.MixHash(Hashing.MurmurRound(seed, (uint)value)) & mask) + 1; } + private readonly struct HashI4V2 : IHasher + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public uint HashCore(uint seed, uint mask, in int value) + => (Hashing.MixHashV2(Hashing.MurmurRound(seed, (uint)value), 4) & mask); + } + private readonly struct HashI8 : IHasher { [MethodImpl(MethodImplOptions.AggressiveInlining)] @@ -628,6 +864,19 @@ public uint HashCore(uint seed, uint mask, in long value) } } + private readonly struct HashI8V2 : IHasher + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public uint HashCore(uint seed, uint mask, in long value) + { + var hash = Hashing.MurmurRound(seed, Utils.GetLo((ulong)value)); + var hi = Utils.GetHi((ulong)value); + if (hi != 0) + hash = Hashing.MurmurRound(hash, hi); + return (Hashing.MixHashV2(hash, 4) & mask); + } + } + private static ValueGetter MakeScalarHashGetter(DataViewRow input, int srcCol, uint seed, uint mask) where THash : struct, IHasher { @@ -780,7 +1029,7 @@ private static ValueGetter> MakeVectorOrderedHashGetter( }; } - private sealed class Mapper : OneToOneMapperBase + private sealed class Mapper : OneToOneMapperBase, ISaveAsOnnx { private sealed class ColInfo { @@ -834,6 +1083,66 @@ private void AddMetaKeyValues(int i, DataViewSchema.Annotations.Builder builder) } protected override Delegate MakeGetter(DataViewRow input, int iinfo, Func activeOutput, out Action disposer) => _parent.GetGetterCore(input, iinfo, out disposer); + + private bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, string srcVariable, string dstVariable) + { + string opType; + OnnxNode murmurNode; + + opType = "MurmurHash3"; + if (_types[iinfo].RawType == typeof(KeyDataViewType)) + { + string murmurOutput = ctx.AddIntermediateVariable(_types[iinfo], "MurmurOutput", true); + murmurNode = ctx.CreateNode(opType, srcVariable, murmurOutput, ctx.GetNodeName(opType), "com.microsoft"); + + opType = "Cast"; + string castOutput = ctx.AddIntermediateVariable(_types[iinfo], "CastOutput", true); + var castNode = ctx.CreateNode(opType, murmurOutput, castOutput, ctx.GetNodeName(opType), ""); + var t = NumberDataViewType.Int64.RawType; + castNode.AddAttribute("to", t); + + opType = "Add"; + string addOutput = ctx.AddIntermediateVariable(_types[iinfo], "AddOutput", true); + string one = ctx.AddInitializer(1); + var addNode = ctx.CreateNode(opType, new[] { castOutput, one }, new[] { addOutput }, ctx.GetNodeName(opType), ""); + + opType = "Cast"; + var castNodeFinal = ctx.CreateNode(opType, addOutput, dstVariable, ctx.GetNodeName(opType), ""); + var tFinal = NumberDataViewType.UInt32.RawType; + castNodeFinal.AddAttribute("to", tFinal); + } + else + { + murmurNode = ctx.CreateNode(opType, srcVariable, dstVariable, ctx.GetNodeName(opType), "com.microsoft"); + } + + murmurNode.AddAttribute("positive", 1); + var seed = _parent._columns[iinfo].Seed; + murmurNode.AddAttribute("seed", seed); + return true; + } + + void ISaveAsOnnx.SaveAsOnnx(OnnxContext ctx) + { + Host.CheckValue(ctx, nameof(ctx)); + for (int iinfo = 0; iinfo < _parent._columns.Length; ++iinfo) + { + var colName = _parent._columns[iinfo].Name; + string inputColumnName = InputSchema[colName].Name; + if (!ctx.ContainsColumn(inputColumnName)) + { + ctx.RemoveColumn(inputColumnName, false); + continue; + } + + if (!SaveAsOnnxCore(ctx, iinfo, ctx.GetVariableName(inputColumnName), ctx.AddIntermediateVariable(_types[iinfo], inputColumnName))) + { + ctx.RemoveColumn(inputColumnName, true); + } + } + } + + bool ICanSaveOnnx.CanSaveOnnx(OnnxContext ctx) => true; } private abstract class InvertHashHelper diff --git a/test/Microsoft.ML.Benchmarks/Text/MulticlassHashClassification.cs b/test/Microsoft.ML.Benchmarks/Text/MulticlassHashClassification.cs new file mode 100644 index 0000000000..044eaa8cf6 --- /dev/null +++ b/test/Microsoft.ML.Benchmarks/Text/MulticlassHashClassification.cs @@ -0,0 +1,46 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.IO; +using BenchmarkDotNet.Attributes; +using Microsoft.ML.Data; +using Microsoft.ML.Trainers.LightGbm; +using Microsoft.ML.RunTests; +using Microsoft.ML.TestFramework; +using Microsoft.ML.Trainers; +using Microsoft.ML.Transforms; +using Microsoft.ML.TestFrameworkCommon; + +namespace Microsoft.ML.Benchmarks +{ + [Config(typeof(TrainConfig))] + public class MulticlassHashClassificationTrain + { + private string _dataPath_Wiki; + + [GlobalSetup] + public void SetupTrainingSpeedTests() + { + _dataPath_Wiki = BaseTestClass.GetDataPath(TestDatasets.WikiDetox.trainFilename); + + if (!File.Exists(_dataPath_Wiki)) + throw new FileNotFoundException(string.Format(Errors.DatasetNotFound, _dataPath_Wiki)); + } + + [Benchmark] + public void CV_Multiclass_WikiDetox_BigramsAndTrichar_LightGBMMulticlass() + { + string cmd = @"CV k=5 data=" + _dataPath_Wiki + + " loader=TextLoader{quote=- sparse=- col=Label:R4:0 col=rev_id:TX:1 col=comment:TX:2 col=logged_in:BL:4 col=ns:TX:5 col=sample:TX:6 col=split:TX:7 col=year:R4:3 header=+}" + + " xf=Convert{col=logged_in type=R4}" + + " xf=CategoricalTransform{col=ns}" + + " xf=TextTransform{col=FeaturesText:comment wordExtractor=NGramExtractorTransform{ngram=2}}" + + " xf=Concat{col=Features:FeaturesText,logged_in,ns}" + + " tr=LightGBMMulticlass{iter=10}"; + + var environment = EnvironmentFactory.CreateClassificationEnvironment(); + cmd.ExecuteMamlCommand(environment); + } + } +} diff --git a/test/Microsoft.ML.Tests/OnnxConversionTest.cs b/test/Microsoft.ML.Tests/OnnxConversionTest.cs index 12688e1392..a891dc3dc3 100644 --- a/test/Microsoft.ML.Tests/OnnxConversionTest.cs +++ b/test/Microsoft.ML.Tests/OnnxConversionTest.cs @@ -1087,6 +1087,98 @@ public void NgramOnnxConnversionTest( Done(); } + private class HashData + { + public ReadOnlyMemory Education { get; set; } + } + + [Fact] + public void MurmurHashStringTest() + { + var mlContext = new MLContext(); + + var samples = new[] + { + new HashData {Education = "alibaba".AsMemory()}, + new HashData {Education = "baba".AsMemory()}, + new HashData {Education = "U+123".AsMemory()}, + new HashData {Education = "djldaoiejffjauhglehdlgh".AsMemory()}, + new HashData {Education = "~".AsMemory()}, + }; + + IDataView data = mlContext.Data.LoadFromEnumerable(samples); + + var hashEstimator = new HashingEstimator(Env, "Education"); + var model = hashEstimator.Fit(data); + var hashTransformedData = model.Transform(data); + var onnxModel = mlContext.Model.ConvertToOnnxProtobuf(model, data); + + var onnxFileName = "MurmurHashV2.onnx"; + var onnxTextName = "MurmurHashV2.txt"; + var onnxModelPath = GetOutputPath(onnxFileName); + var onnxTextPath = GetOutputPath(onnxTextName); + + SaveOnnxModel(onnxModel, onnxModelPath, onnxTextPath); + + if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows) && Environment.Is64BitProcess) + { + // Evaluate the saved ONNX model using the data used to train the ML.NET pipeline. + string[] inputNames = onnxModel.Graph.Input.Select(valueInfoProto => valueInfoProto.Name).ToArray(); + string[] outputNames = onnxModel.Graph.Output.Select(valueInfoProto => valueInfoProto.Name).ToArray(); + var onnxEstimator = mlContext.Transforms.ApplyOnnxModel(outputNames, inputNames, onnxModelPath); + var onnxTransformer = onnxEstimator.Fit(data); + var onnxResult = onnxTransformer.Transform(data); + CompareSelectedScalarColumns("Education", outputNames[0], hashTransformedData, onnxResult); + } + Done(); + } + + private class HashNumData + { + public uint Education { get; set; } + } + + [Fact] + public void MurmurHashUIntTest() + { + var mlContext = new MLContext(); + + var samples = new[] + { + new HashNumData {Education = 12}, + new HashNumData {Education = 456}, + new HashNumData {Education = 2}, + new HashNumData {Education = 34556789}, + new HashNumData {Education = 7896}, + }; + + IDataView data = mlContext.Data.LoadFromEnumerable(samples); + + var hashEstimator = new HashingEstimator(Env, "Education"); + var model = hashEstimator.Fit(data); + var hashTransformedData = model.Transform(data); + var onnxModel = mlContext.Model.ConvertToOnnxProtobuf(model, data); + + var onnxFileName = "MurmurHashV2.onnx"; + var onnxTextName = "MurmurHashV2.txt"; + var onnxModelPath = GetOutputPath(onnxFileName); + var onnxTextPath = GetOutputPath(onnxTextName); + + SaveOnnxModel(onnxModel, onnxModelPath, onnxTextPath); + + if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows) && Environment.Is64BitProcess) + { + // Evaluate the saved ONNX model using the data used to train the ML.NET pipeline. + string[] inputNames = onnxModel.Graph.Input.Select(valueInfoProto => valueInfoProto.Name).ToArray(); + string[] outputNames = onnxModel.Graph.Output.Select(valueInfoProto => valueInfoProto.Name).ToArray(); + var onnxEstimator = mlContext.Transforms.ApplyOnnxModel(outputNames, inputNames, onnxModelPath); + var onnxTransformer = onnxEstimator.Fit(data); + var onnxResult = onnxTransformer.Transform(data); + CompareSelectedScalarColumns("Education", outputNames[0], hashTransformedData, onnxResult); + } + Done(); + } + [Fact] public void OptionalColumnOnnxTest() { @@ -1284,6 +1376,7 @@ private void CompareResults(string leftColumnName, string rightColumnName, IData CompareSelectedVectorColumns(leftColumnName, rightColumnName, left, right); } + private void CompareSelectedVectorColumns(string leftColumnName, string rightColumnName, IDataView left, IDataView right) { var leftColumn = left.Schema[leftColumnName];