From 034c81c01cc5a746b51f24f51f6fae9afa078b0e Mon Sep 17 00:00:00 2001 From: Wei-Sheng Chin Date: Thu, 11 Apr 2019 10:15:43 -0700 Subject: [PATCH 01/24] Make type registration framework Register image type --- .../Data/SchemaDefinition.cs | 55 +++++++++-------- .../DataView/InternalSchemaDefinition.cs | 2 +- src/Microsoft.ML.Data/DataView/TypedCursor.cs | 4 ++ src/Microsoft.ML.Data/Utils/ApiUtils.cs | 3 +- src/Microsoft.ML.DataView/TypeManager.cs | 59 +++++++++++++++++++ src/Microsoft.ML.ImageAnalytics/ImageType.cs | 6 ++ test/Microsoft.ML.Tests/ImagesTests.cs | 47 +++++++++++++++ 7 files changed, 149 insertions(+), 27 deletions(-) create mode 100644 src/Microsoft.ML.DataView/TypeManager.cs diff --git a/src/Microsoft.ML.Data/Data/SchemaDefinition.cs b/src/Microsoft.ML.Data/Data/SchemaDefinition.cs index e08960ffaf..1147ee4106 100644 --- a/src/Microsoft.ML.Data/Data/SchemaDefinition.cs +++ b/src/Microsoft.ML.Data/Data/SchemaDefinition.cs @@ -392,37 +392,42 @@ public static SchemaDefinition Create(Type userType, Direction direction = Direc InternalSchemaDefinition.GetVectorAndItemType(memberInfo, out bool isVector, out Type dataType); - PrimitiveDataViewType itemType; - var keyAttr = memberInfo.GetCustomAttribute(); - if (keyAttr != null) - { - if (!KeyDataViewType.IsValidDataType(dataType)) - throw Contracts.ExceptParam(nameof(userType), "Member {0} marked with KeyType attribute, but does not appear to be a valid kind of data for a key type", memberInfo.Name); - if (keyAttr.KeyCount == null) - itemType = new KeyDataViewType(dataType, dataType.ToMaxInt()); - else - itemType = new KeyDataViewType(dataType, keyAttr.KeyCount.Count.GetValueOrDefault()); - } - else - itemType = ColumnTypeExtensions.PrimitiveTypeFromType(dataType); - // Get the column type. DataViewType columnType; - var vectorAttr = memberInfo.GetCustomAttribute(); - if (vectorAttr != null && !isVector) - throw Contracts.ExceptParam(nameof(userType), $"Member {memberInfo.Name} marked with {nameof(VectorTypeAttribute)}, but does not appear to be a vector type", memberInfo.Name); - if (isVector) + if (TypeManager.GetDataViewTypeOrNull(dataType) == null) { - int[] dims = vectorAttr?.Dims; - if (dims != null && dims.Any(d => d < 0)) - throw Contracts.ExceptParam(nameof(userType), "Some of member {0}'s dimension lengths are negative"); - if (Utils.Size(dims) == 0) - columnType = new VectorDataViewType(itemType, 0); + PrimitiveDataViewType itemType; + var keyAttr = memberInfo.GetCustomAttribute(); + if (keyAttr != null) + { + if (!KeyDataViewType.IsValidDataType(dataType)) + throw Contracts.ExceptParam(nameof(userType), "Member {0} marked with KeyType attribute, but does not appear to be a valid kind of data for a key type", memberInfo.Name); + if (keyAttr.KeyCount == null) + itemType = new KeyDataViewType(dataType, dataType.ToMaxInt()); + else + itemType = new KeyDataViewType(dataType, keyAttr.KeyCount.Count.GetValueOrDefault()); + } + else + itemType = ColumnTypeExtensions.PrimitiveTypeFromType(dataType); + + var vectorAttr = memberInfo.GetCustomAttribute(); + if (vectorAttr != null && !isVector) + throw Contracts.ExceptParam(nameof(userType), $"Member {memberInfo.Name} marked with {nameof(VectorTypeAttribute)}, but does not appear to be a vector type", memberInfo.Name); + if (isVector) + { + int[] dims = vectorAttr?.Dims; + if (dims != null && dims.Any(d => d < 0)) + throw Contracts.ExceptParam(nameof(userType), "Some of member {0}'s dimension lengths are negative"); + if (Utils.Size(dims) == 0) + columnType = new VectorDataViewType(itemType, 0); + else + columnType = new VectorDataViewType(itemType, dims); + } else - columnType = new VectorDataViewType(itemType, dims); + columnType = itemType; } else - columnType = itemType; + columnType = TypeManager.GetDataViewTypeOrNull(dataType); cols.Add(new Column(memberInfo.Name, columnType, name)); } diff --git a/src/Microsoft.ML.Data/DataView/InternalSchemaDefinition.cs b/src/Microsoft.ML.Data/DataView/InternalSchemaDefinition.cs index 14e97cf100..bcc06a415c 100644 --- a/src/Microsoft.ML.Data/DataView/InternalSchemaDefinition.cs +++ b/src/Microsoft.ML.Data/DataView/InternalSchemaDefinition.cs @@ -187,7 +187,7 @@ public static void GetVectorAndItemType(Type rawType, string name, out bool isVe if (itemType == typeof(string)) itemType = typeof(ReadOnlyMemory); - else if (!itemType.TryGetDataKind(out _)) + else if (!itemType.TryGetDataKind(out _) && TypeManager.GetDataViewTypeOrNull(itemType) == null) throw Contracts.ExceptParam(nameof(rawType), "Could not determine an IDataView type for member {0}", name); } diff --git a/src/Microsoft.ML.Data/DataView/TypedCursor.cs b/src/Microsoft.ML.Data/DataView/TypedCursor.cs index b5ec10fd2a..1de4bba4c2 100644 --- a/src/Microsoft.ML.Data/DataView/TypedCursor.cs +++ b/src/Microsoft.ML.Data/DataView/TypedCursor.cs @@ -319,6 +319,10 @@ private Action GenerateSetter(DataViewRow input, int index, InternalSchema del = CreateDirectSetter; } + else if (TypeManager.GetRawTypeOrNull(colType) != null) + { + del = CreateDirectSetter; + } else { // REVIEW: Is this even possible? diff --git a/src/Microsoft.ML.Data/Utils/ApiUtils.cs b/src/Microsoft.ML.Data/Utils/ApiUtils.cs index 6704cbda1e..b3e47b999f 100644 --- a/src/Microsoft.ML.Data/Utils/ApiUtils.cs +++ b/src/Microsoft.ML.Data/Utils/ApiUtils.cs @@ -23,7 +23,8 @@ private static OpCode GetAssignmentOpCode(Type t) if (t == typeof(ReadOnlyMemory) || t == typeof(string) || t.IsArray || (t.IsGenericType && t.GetGenericTypeDefinition() == typeof(VBuffer<>)) || (t.IsGenericType && t.GetGenericTypeDefinition() == typeof(Nullable<>)) || - t == typeof(DateTime) || t == typeof(DateTimeOffset) || t == typeof(TimeSpan) || t == typeof(DataViewRowId)) + t == typeof(DateTime) || t == typeof(DateTimeOffset) || t == typeof(TimeSpan) || + t == typeof(DataViewRowId) || TypeManager.GetDataViewTypeOrNull(t) != null) { return OpCodes.Stobj; } diff --git a/src/Microsoft.ML.DataView/TypeManager.cs b/src/Microsoft.ML.DataView/TypeManager.cs new file mode 100644 index 0000000000..248518b4a1 --- /dev/null +++ b/src/Microsoft.ML.DataView/TypeManager.cs @@ -0,0 +1,59 @@ +using System; +using System.Collections.Concurrent; +using System.Collections.Generic; +using Microsoft.ML.Internal.DataView; + +namespace Microsoft.ML.Data +{ + public static class TypeManager + { + // Types have been used in ML.NET type systems. They can have multiple-to-one type mapping. + // For example, UInt32 and Key can be mapped to uint. We enforce one-to-one mapping for all + // user-registered types. + private static HashSet _notAllowedRawTypes; + private static ConcurrentDictionary _rawTypeToDataViewTypeMap; + private static ConcurrentDictionary _dataViewTypeToRawTypeMap; + + /// + /// Constructor to initialize type mappings. + /// + static TypeManager() + { + _notAllowedRawTypes = new HashSet() { + typeof(Boolean), typeof(SByte), typeof(Byte), + typeof(Int16), typeof(UInt16), typeof(Int32), typeof(UInt32), + typeof(Int64), typeof(UInt64), typeof(string), typeof(ReadOnlySpan) + }; + _rawTypeToDataViewTypeMap = new ConcurrentDictionary(); + _dataViewTypeToRawTypeMap = new ConcurrentDictionary(); + } + + public static DataViewType GetDataViewTypeOrNull(Type type) + { + if (_rawTypeToDataViewTypeMap.ContainsKey(type)) + return _rawTypeToDataViewTypeMap[type]; + else + return null; + } + + public static Type GetRawTypeOrNull(DataViewType type) + { + if (_dataViewTypeToRawTypeMap.ContainsKey(type)) + return _dataViewTypeToRawTypeMap[type]; + else + return null; + } + + public static void Register(Type rawType, DataViewType dataViewType) + { + if (_notAllowedRawTypes.Contains(rawType)) + throw Contracts.ExceptParam(nameof(rawType), $"Type {rawType} has been registered as ML.NET's default type. " + + $"so it can't not be registered again."); + if (_rawTypeToDataViewTypeMap.ContainsKey(rawType)) + throw Contracts.ExceptParam(nameof(rawType), $"Repeated type registration. The raw type {rawType} " + + $"has been associated with {_rawTypeToDataViewTypeMap[rawType]}."); + _rawTypeToDataViewTypeMap[rawType] = dataViewType; + _dataViewTypeToRawTypeMap[dataViewType] = rawType; + } + } +} diff --git a/src/Microsoft.ML.ImageAnalytics/ImageType.cs b/src/Microsoft.ML.ImageAnalytics/ImageType.cs index 3b50354177..fb66facf21 100644 --- a/src/Microsoft.ML.ImageAnalytics/ImageType.cs +++ b/src/Microsoft.ML.ImageAnalytics/ImageType.cs @@ -13,6 +13,12 @@ public sealed class ImageDataViewType : StructuredDataViewType { public readonly int Height; public readonly int Width; + + static ImageDataViewType() + { + TypeManager.Register(typeof(Bitmap), new ImageDataViewType()); + } + public ImageDataViewType(int height, int width) : base(typeof(Bitmap)) { diff --git a/test/Microsoft.ML.Tests/ImagesTests.cs b/test/Microsoft.ML.Tests/ImagesTests.cs index ef1abf93a9..afb43acf3e 100644 --- a/test/Microsoft.ML.Tests/ImagesTests.cs +++ b/test/Microsoft.ML.Tests/ImagesTests.cs @@ -184,6 +184,53 @@ public void TestGreyscaleTransformImages() Done(); } + [Fact] + public void TestGrayScaleInMemory() + { + var imagesDataFile = SamplesUtils.DatasetUtils.DownloadImages(); + + var data = ML.Data.CreateTextLoader(new TextLoader.Options() + { + Columns = new[] + { + new TextLoader.Column("ImagePath", DataKind.String, 0), + new TextLoader.Column("Name", DataKind.String, 1), + } + }).Load(imagesDataFile); + + var imagesFolder = Path.GetDirectoryName(imagesDataFile); + // Image loading and conversion pipeline. + var pipeline = ML.Transforms.LoadImages("ImageObject", imagesFolder, "ImagePath") + .Append(ML.Transforms.ConvertToGrayscale("Grayscale", "ImageObject")); + + var transformedData = pipeline.Fit(data).Transform(data); + + var transformedDataPoints = ML.Data.CreateEnumerable(transformedData, true).ToList(); + + foreach (var datapoint in transformedDataPoints) + { + var image = datapoint.Grayscale; + Assert.NotNull(image); + for (int x = 0; x < image.Width; x++) + { + for (int y = 0; y < image.Height; y++) + { + var pixel = image.GetPixel(x, y); + // greyscale image has same values for R, G and B. + Assert.True(pixel.R == pixel.G && pixel.G == pixel.B); + } + } + } + } + + private class TransformedImageDataPoint + { + public string ImagePath { get; set; } + public string Name { get; set; } + public Bitmap ImageObject { get; set; } + public Bitmap Grayscale { get; set; } + } + [Fact] public void TestBackAndForthConversionWithAlphaInterleave() { From 8aface60208fc1b92486380e4839443c16d95477 Mon Sep 17 00:00:00 2001 From: Wei-Sheng Chin Date: Thu, 16 May 2019 16:05:35 -0700 Subject: [PATCH 02/24] Add Bitmap to IDataView --- .../DataView/DataViewConstructionUtils.cs | 4 +++ test/Microsoft.ML.Tests/ImagesTests.cs | 31 +++++++++++++++++-- 2 files changed, 32 insertions(+), 3 deletions(-) diff --git a/src/Microsoft.ML.Data/DataView/DataViewConstructionUtils.cs b/src/Microsoft.ML.Data/DataView/DataViewConstructionUtils.cs index 7fde918a09..549099d226 100644 --- a/src/Microsoft.ML.Data/DataView/DataViewConstructionUtils.cs +++ b/src/Microsoft.ML.Data/DataView/DataViewConstructionUtils.cs @@ -269,6 +269,10 @@ private Delegate CreateGetter(DataViewType colType, InternalSchemaDefinition.Col return Utils.MarshalInvoke(delForKey, keyRawType, peek, colType); } } + else if (TypeManager.GetRawTypeOrNull(colType) != null) + { + del = CreateDirectGetterDelegate; + } else { // REVIEW: Is this even possible? diff --git a/test/Microsoft.ML.Tests/ImagesTests.cs b/test/Microsoft.ML.Tests/ImagesTests.cs index afb43acf3e..3dbb850cb3 100644 --- a/test/Microsoft.ML.Tests/ImagesTests.cs +++ b/test/Microsoft.ML.Tests/ImagesTests.cs @@ -203,17 +203,17 @@ public void TestGrayScaleInMemory() var pipeline = ML.Transforms.LoadImages("ImageObject", imagesFolder, "ImagePath") .Append(ML.Transforms.ConvertToGrayscale("Grayscale", "ImageObject")); + // Test path: image files -> IDataView -> Enumerable of Bitmaps. var transformedData = pipeline.Fit(data).Transform(data); - var transformedDataPoints = ML.Data.CreateEnumerable(transformedData, true).ToList(); foreach (var datapoint in transformedDataPoints) { var image = datapoint.Grayscale; Assert.NotNull(image); - for (int x = 0; x < image.Width; x++) + for (int x = 0; x < image.Width; ++x) { - for (int y = 0; y < image.Height; y++) + for (int y = 0; y < image.Height; ++y) { var pixel = image.GetPixel(x, y); // greyscale image has same values for R, G and B. @@ -221,6 +221,31 @@ public void TestGrayScaleInMemory() } } } + + // Test path: Enumerable of Bitmaps -> IDataView -> Enumerable of Bitmaps. + var imagesInDataView = ML.Data.LoadFromEnumerable(transformedDataPoints); + var imagesObtainedFromDataView = ML.Data.CreateEnumerable(imagesInDataView, true).ToList(); + + for (int i = 0; i < transformedDataPoints.Count; ++i) + { + var expectedImage = transformedDataPoints[i].Grayscale; + var obtainedImage = imagesObtainedFromDataView[i].Grayscale; + + Assert.Equal(expectedImage.Width, obtainedImage.Width); + Assert.Equal(expectedImage.Height, obtainedImage.Height); + for (int x = 0; x < expectedImage.Width; ++x) + { + for (int y = 0; y < expectedImage.Height; ++y) + { + var expectedPixel = expectedImage.GetPixel(x, y); + var obtainedPixel = obtainedImage.GetPixel(x, y); + + Assert.Equal(expectedPixel.R, obtainedPixel.R); + Assert.Equal(expectedPixel.G, obtainedPixel.G); + Assert.Equal(expectedPixel.B, obtainedPixel.B); + } + } + } } private class TransformedImageDataPoint From e915cf86c8f42178f239cb912d13d9bb98156207 Mon Sep 17 00:00:00 2001 From: Wei-Sheng Chin Date: Fri, 17 May 2019 16:56:47 -0700 Subject: [PATCH 03/24] Address some comments --- .../Data/SchemaDefinition.cs | 4 +- .../DataView/DataViewConstructionUtils.cs | 2 +- .../DataView/InternalSchemaDefinition.cs | 2 +- src/Microsoft.ML.Data/DataView/TypedCursor.cs | 2 +- src/Microsoft.ML.Data/Utils/ApiUtils.cs | 2 +- src/Microsoft.ML.DataView/DataViewType.cs | 2 +- src/Microsoft.ML.DataView/TypeManager.cs | 148 ++++++++++++++---- src/Microsoft.ML.ImageAnalytics/ImageType.cs | 2 +- 8 files changed, 122 insertions(+), 42 deletions(-) diff --git a/src/Microsoft.ML.Data/Data/SchemaDefinition.cs b/src/Microsoft.ML.Data/Data/SchemaDefinition.cs index 1147ee4106..d327315ae0 100644 --- a/src/Microsoft.ML.Data/Data/SchemaDefinition.cs +++ b/src/Microsoft.ML.Data/Data/SchemaDefinition.cs @@ -394,7 +394,7 @@ public static SchemaDefinition Create(Type userType, Direction direction = Direc // Get the column type. DataViewType columnType; - if (TypeManager.GetDataViewTypeOrNull(dataType) == null) + if (!DataViewTypeManager.Knows(dataType)) { PrimitiveDataViewType itemType; var keyAttr = memberInfo.GetCustomAttribute(); @@ -427,7 +427,7 @@ public static SchemaDefinition Create(Type userType, Direction direction = Direc columnType = itemType; } else - columnType = TypeManager.GetDataViewTypeOrNull(dataType); + columnType = DataViewTypeManager.GetDataViewType(dataType); cols.Add(new Column(memberInfo.Name, columnType, name)); } diff --git a/src/Microsoft.ML.Data/DataView/DataViewConstructionUtils.cs b/src/Microsoft.ML.Data/DataView/DataViewConstructionUtils.cs index 549099d226..8e5e9ba578 100644 --- a/src/Microsoft.ML.Data/DataView/DataViewConstructionUtils.cs +++ b/src/Microsoft.ML.Data/DataView/DataViewConstructionUtils.cs @@ -269,7 +269,7 @@ private Delegate CreateGetter(DataViewType colType, InternalSchemaDefinition.Col return Utils.MarshalInvoke(delForKey, keyRawType, peek, colType); } } - else if (TypeManager.GetRawTypeOrNull(colType) != null) + else if (DataViewTypeManager.Knows(colType)) { del = CreateDirectGetterDelegate; } diff --git a/src/Microsoft.ML.Data/DataView/InternalSchemaDefinition.cs b/src/Microsoft.ML.Data/DataView/InternalSchemaDefinition.cs index bcc06a415c..22347880b7 100644 --- a/src/Microsoft.ML.Data/DataView/InternalSchemaDefinition.cs +++ b/src/Microsoft.ML.Data/DataView/InternalSchemaDefinition.cs @@ -187,7 +187,7 @@ public static void GetVectorAndItemType(Type rawType, string name, out bool isVe if (itemType == typeof(string)) itemType = typeof(ReadOnlyMemory); - else if (!itemType.TryGetDataKind(out _) && TypeManager.GetDataViewTypeOrNull(itemType) == null) + else if (!itemType.TryGetDataKind(out _) && !DataViewTypeManager.Knows(itemType)) throw Contracts.ExceptParam(nameof(rawType), "Could not determine an IDataView type for member {0}", name); } diff --git a/src/Microsoft.ML.Data/DataView/TypedCursor.cs b/src/Microsoft.ML.Data/DataView/TypedCursor.cs index 1de4bba4c2..f7b79feb21 100644 --- a/src/Microsoft.ML.Data/DataView/TypedCursor.cs +++ b/src/Microsoft.ML.Data/DataView/TypedCursor.cs @@ -319,7 +319,7 @@ private Action GenerateSetter(DataViewRow input, int index, InternalSchema del = CreateDirectSetter; } - else if (TypeManager.GetRawTypeOrNull(colType) != null) + else if (DataViewTypeManager.Knows(colType)) { del = CreateDirectSetter; } diff --git a/src/Microsoft.ML.Data/Utils/ApiUtils.cs b/src/Microsoft.ML.Data/Utils/ApiUtils.cs index b3e47b999f..67f79b93ba 100644 --- a/src/Microsoft.ML.Data/Utils/ApiUtils.cs +++ b/src/Microsoft.ML.Data/Utils/ApiUtils.cs @@ -24,7 +24,7 @@ private static OpCode GetAssignmentOpCode(Type t) (t.IsGenericType && t.GetGenericTypeDefinition() == typeof(VBuffer<>)) || (t.IsGenericType && t.GetGenericTypeDefinition() == typeof(Nullable<>)) || t == typeof(DateTime) || t == typeof(DateTimeOffset) || t == typeof(TimeSpan) || - t == typeof(DataViewRowId) || TypeManager.GetDataViewTypeOrNull(t) != null) + t == typeof(DataViewRowId) || DataViewTypeManager.Knows(t)) { return OpCodes.Stobj; } diff --git a/src/Microsoft.ML.DataView/DataViewType.cs b/src/Microsoft.ML.DataView/DataViewType.cs index 153ba02261..643ccbec22 100644 --- a/src/Microsoft.ML.DataView/DataViewType.cs +++ b/src/Microsoft.ML.DataView/DataViewType.cs @@ -21,7 +21,7 @@ public abstract class DataViewType : IEquatable /// /// Constructor for extension types, which must be either or . /// - private protected DataViewType(Type rawType) + protected DataViewType(Type rawType) { RawType = rawType ?? throw new ArgumentNullException(nameof(rawType)); } diff --git a/src/Microsoft.ML.DataView/TypeManager.cs b/src/Microsoft.ML.DataView/TypeManager.cs index 248518b4a1..2aa06633c3 100644 --- a/src/Microsoft.ML.DataView/TypeManager.cs +++ b/src/Microsoft.ML.DataView/TypeManager.cs @@ -1,59 +1,139 @@ -using System; -using System.Collections.Concurrent; +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; using System.Collections.Generic; +using System.Threading; using Microsoft.ML.Internal.DataView; namespace Microsoft.ML.Data { - public static class TypeManager + /// + /// A singleton class for managing the map between ML.NET and C# . + /// To support custom column type in , the column's underlying type (e.g., a C# class's type) + /// should be registered with a class derived from . + /// + public static class DataViewTypeManager { // Types have been used in ML.NET type systems. They can have multiple-to-one type mapping. - // For example, UInt32 and Key can be mapped to uint. We enforce one-to-one mapping for all + // For example, UInt32 and Key can be mapped to uint. This class enforces one-to-one mapping for all // user-registered types. - private static HashSet _notAllowedRawTypes; - private static ConcurrentDictionary _rawTypeToDataViewTypeMap; - private static ConcurrentDictionary _dataViewTypeToRawTypeMap; + private static HashSet _bannedRawTypes = new HashSet() + { + typeof(Boolean), typeof(SByte), typeof(Byte), + typeof(Int16), typeof(UInt16), typeof(Int32), typeof(UInt32), + typeof(Int64), typeof(UInt64), typeof(Single), typeof(Double), + typeof(string), typeof(ReadOnlySpan) + }; + + private static Dictionary _rawTypeToDataViewTypeMap = new Dictionary(); + private static Dictionary _dataViewTypeToRawTypeMap = new Dictionary(); + private static SpinLock _lock = new SpinLock(); /// - /// Constructor to initialize type mappings. + /// Returns the registered for . /// - static TypeManager() + public static DataViewType GetDataViewType(Type type) { - _notAllowedRawTypes = new HashSet() { - typeof(Boolean), typeof(SByte), typeof(Byte), - typeof(Int16), typeof(UInt16), typeof(Int32), typeof(UInt32), - typeof(Int64), typeof(UInt64), typeof(string), typeof(ReadOnlySpan) - }; - _rawTypeToDataViewTypeMap = new ConcurrentDictionary(); - _dataViewTypeToRawTypeMap = new ConcurrentDictionary(); + bool ownLock = false; + DataViewType dataViewType = null; + try + { + _lock.Enter(ref ownLock); + if (!_rawTypeToDataViewTypeMap.ContainsKey(type)) + throw Contracts.ExceptParam(nameof(type), $"The raw type {type} is not registered with a DataView type."); + dataViewType = _rawTypeToDataViewTypeMap[type]; + } + finally + { + if (ownLock) _lock.Exit(); + } + return dataViewType; } - public static DataViewType GetDataViewTypeOrNull(Type type) + /// + /// If has been registered with a , this function returns . + /// Otherwise, this function returns . + /// + public static bool Knows(Type type) { - if (_rawTypeToDataViewTypeMap.ContainsKey(type)) - return _rawTypeToDataViewTypeMap[type]; - else - return null; + bool ownLock = false; + bool answer = false; + try + { + _lock.Enter(ref ownLock); + if (_rawTypeToDataViewTypeMap.ContainsKey(type)) + answer = true; + } + finally + { + if (ownLock) _lock.Exit(); + } + return answer; } - public static Type GetRawTypeOrNull(DataViewType type) + /// + /// If has been registered with a , this function returns . + /// Otherwise, this function returns . + /// + public static bool Knows(DataViewType type) { - if (_dataViewTypeToRawTypeMap.ContainsKey(type)) - return _dataViewTypeToRawTypeMap[type]; - else - return null; + bool ownLock = false; + bool answer = false; + try + { + _lock.Enter(ref ownLock); + if (_dataViewTypeToRawTypeMap.ContainsKey(type)) + answer = true; + } + finally + { + if (ownLock) _lock.Exit(); + } + return answer; } + /// + /// This function tells that should be representation of data in in + /// ML.NET's type system. The registered must be a standard C# object's type. + /// + /// Native type in C#. + /// The corresponding type of in ML.NET's type system. public static void Register(Type rawType, DataViewType dataViewType) { - if (_notAllowedRawTypes.Contains(rawType)) - throw Contracts.ExceptParam(nameof(rawType), $"Type {rawType} has been registered as ML.NET's default type. " + - $"so it can't not be registered again."); - if (_rawTypeToDataViewTypeMap.ContainsKey(rawType)) - throw Contracts.ExceptParam(nameof(rawType), $"Repeated type registration. The raw type {rawType} " + - $"has been associated with {_rawTypeToDataViewTypeMap[rawType]}."); - _rawTypeToDataViewTypeMap[rawType] = dataViewType; - _dataViewTypeToRawTypeMap[dataViewType] = rawType; + bool ownLock = false; + + try + { + _lock.Enter(ref ownLock); + + if (_bannedRawTypes.Contains(rawType)) + throw Contracts.ExceptParam(nameof(rawType), $"Type {rawType} has been registered as ML.NET's default supported type, " + + $"so it can't not be registered again."); + + // Registering the same pair of (rawType, dataViewType) multiple times is ok. However, a raw type can be associated + // with only one DataView type. + if (_rawTypeToDataViewTypeMap.ContainsKey(rawType) && _rawTypeToDataViewTypeMap[rawType] != dataViewType) + throw Contracts.ExceptParam(nameof(rawType), $"Repeated type register. The raw type {rawType} " + + $"has been associated with {_rawTypeToDataViewTypeMap[rawType]} so it cannot be associated with {dataViewType}."); + + // Registering the same pair of (rawType, dataViewType) multiple times is ok. However, a DataView type can be associated + // with only one raw type. + if (_dataViewTypeToRawTypeMap.ContainsKey(dataViewType) && _dataViewTypeToRawTypeMap[dataViewType] != rawType) + throw Contracts.ExceptParam(nameof(dataViewType), $"Repeated type register. The DataView type {dataViewType} " + + $"has been associated with {_dataViewTypeToRawTypeMap[dataViewType]} so it cannot be associated with {rawType}."); + + if (!_rawTypeToDataViewTypeMap.ContainsKey(rawType)) + _rawTypeToDataViewTypeMap.Add(rawType, dataViewType); + + if (!_dataViewTypeToRawTypeMap.ContainsKey(dataViewType)) + _dataViewTypeToRawTypeMap.Add(dataViewType, rawType); + } + finally + { + if (ownLock) _lock.Exit(); + } } } } diff --git a/src/Microsoft.ML.ImageAnalytics/ImageType.cs b/src/Microsoft.ML.ImageAnalytics/ImageType.cs index fb66facf21..ec02a3da29 100644 --- a/src/Microsoft.ML.ImageAnalytics/ImageType.cs +++ b/src/Microsoft.ML.ImageAnalytics/ImageType.cs @@ -16,7 +16,7 @@ public sealed class ImageDataViewType : StructuredDataViewType static ImageDataViewType() { - TypeManager.Register(typeof(Bitmap), new ImageDataViewType()); + DataViewTypeManager.Register(typeof(Bitmap), new ImageDataViewType()); } public ImageDataViewType(int height, int width) From 01f69da3dee70d4ad723589e947427402cfce01e Mon Sep 17 00:00:00 2001 From: Wei-Sheng Chin Date: Fri, 17 May 2019 18:05:30 -0700 Subject: [PATCH 04/24] One more test Add missing file --- src/Microsoft.ML.DataView/TypeManager.cs | 4 +- .../UnitTests/TestCustomTypeRegister.cs | 180 ++++++++++++++++++ 2 files changed, 182 insertions(+), 2 deletions(-) create mode 100644 test/Microsoft.ML.Core.Tests/UnitTests/TestCustomTypeRegister.cs diff --git a/src/Microsoft.ML.DataView/TypeManager.cs b/src/Microsoft.ML.DataView/TypeManager.cs index 2aa06633c3..23ed1dcc22 100644 --- a/src/Microsoft.ML.DataView/TypeManager.cs +++ b/src/Microsoft.ML.DataView/TypeManager.cs @@ -41,9 +41,9 @@ public static DataViewType GetDataViewType(Type type) try { _lock.Enter(ref ownLock); - if (!_rawTypeToDataViewTypeMap.ContainsKey(type)) + + if (!_rawTypeToDataViewTypeMap.TryGetValue(type, out dataViewType)) throw Contracts.ExceptParam(nameof(type), $"The raw type {type} is not registered with a DataView type."); - dataViewType = _rawTypeToDataViewTypeMap[type]; } finally { diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestCustomTypeRegister.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestCustomTypeRegister.cs new file mode 100644 index 0000000000..41b0bfe830 --- /dev/null +++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestCustomTypeRegister.cs @@ -0,0 +1,180 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading; +using Microsoft.ML.Data; +using Microsoft.ML.Transforms; +using Xunit; +using Xunit.Abstractions; + +namespace Microsoft.ML.RunTests +{ + public class TestCustomTypeRegister : TestDataViewBase + { + public TestCustomTypeRegister(ITestOutputHelper helper) + : base(helper) + { + } + + /// + /// A custom type which ML.NET doesn't know yet. Its value will be loaded as a DataView column in this test. + /// + private class Body + { + public int Age { get; set; } + public float Height { get; set; } + public float Weight { get; set; } + + /// + /// Type register should happen before the creation of the first . Otherwise, ML.NET might not recognize + /// that is typed to in ML.NET's internal type system. + /// + static Body() + { + DataViewTypeManager.Register(typeof(Body), DataViewBodyType.Instance); + } + + public Body() + { + Age = 0; + Height = 0; + Weight = 0; + } + + public Body(int age, float height, float weight) + { + Age = age; + Height = height; + Weight = weight; + } + } + + /// + /// A custom class with a type which ML.NET doesn't know yet. Its value will be loaded as a DataView row in this test. + /// + private class Hero + { + public string Name { get; set; } + public Body One { get; set; } + + public Hero() + { + Name = "Earth"; + One = new Body(10000000, 500000, 800000); + } + + public Hero(string name, int age, float height, float weight) + { + Name = name; + One = new Body(age, height, weight); + } + } + + /// + /// Type of in ML.NET. + /// + private class DataViewBodyType : DataViewType + { + private static volatile DataViewBodyType _instance; + + /// + /// The singleton instance of this type. + /// + public static DataViewBodyType Instance + { + get + { + return _instance ?? + Interlocked.CompareExchange(ref _instance, new DataViewBodyType(), null) ?? + _instance; + } + } + + private DataViewBodyType() : base(typeof(Body)) + { + } + + public override bool Equals(DataViewType other) + { + if (other == this) + return true; + return false; + } + } + + /// + /// Pass in as a column in and load back. + /// + [Fact] + public void RegisterCustomType() + { + var tribe = new List(){ new Hero("Earth", 10, 5.8f, 100.0f), new Hero("Mars", 20, 6.8f, 120.8f) }; + + var tribeDataView = ML.Data.LoadFromEnumerable(tribe); + var tribeEnumerable = ML.Data.CreateEnumerable(tribeDataView, false).ToList(); + + for (int i = 0; i < tribe.Count; ++i) + { + Assert.Equal(tribe[i].Name, tribeEnumerable[i].Name); + Assert.Equal(tribe[i].One.Age, tribeEnumerable[i].One.Age); + Assert.Equal(tribe[i].One.Height, tribeEnumerable[i].One.Height); + Assert.Equal(tribe[i].One.Weight, tribeEnumerable[i].One.Weight); + } + } + + private class SuperHero + { + public string SuperName { get; set; } + public Body SuperOne { get; set; } + + public SuperHero() + { + SuperName = "IronMan"; + SuperOne = new Body(); + } + } + + [CustomMappingFactoryAttribute("LambdaHero")] + private class MyLambda : CustomMappingFactory + { + public static void Grow(Hero input, SuperHero output) + { + output.SuperName = "Sr. " + input.Name; + output.SuperOne.Age = input.One.Age + 9999; + output.SuperOne.Height = input.One.Height * 10; + output.SuperOne.Weight = input.One.Weight * 10; + } + + public override Action GetMapping() + { + return Grow; + } + } + + [Fact] + public void ModifyCustomType() + { + var tribe = new List(){ new Hero("Earth", 10, 5.8f, 100.0f) }; + + var tribeDataView = ML.Data.LoadFromEnumerable(tribe); + + var heroEstimator = new CustomMappingEstimator(ML, MyLambda.Grow, "LambdaHero"); + + var tribeTransformed = heroEstimator.Fit(tribeDataView).Transform(tribeDataView); + + var tribeEnumerable = ML.Data.CreateEnumerable(tribeTransformed, false).ToList(); + + for (int i = 0; i < tribe.Count; ++i) + { + Assert.Equal("Sr. " + tribe[i].Name, tribeEnumerable[i].SuperName); + Assert.Equal(tribe[i].One.Age + 9999, tribeEnumerable[i].SuperOne.Age); + Assert.Equal(tribe[i].One.Height * 10, tribeEnumerable[i].SuperOne.Height); + Assert.Equal(tribe[i].One.Weight * 10, tribeEnumerable[i].SuperOne.Weight); + } + } + } +} From 369839e72da22f67f8f77caee126e35c74b148ca Mon Sep 17 00:00:00 2001 From: Wei-Sheng Chin Date: Tue, 21 May 2019 17:10:17 -0700 Subject: [PATCH 05/24] Address more comments --- .../DataView/InternalSchemaDefinition.cs | 6 +- src/Microsoft.ML.DataView/DataViewType.cs | 2 +- src/Microsoft.ML.DataView/TypeManager.cs | 175 +++++++++++------- src/Microsoft.ML.ImageAnalytics/ImageType.cs | 51 ++++- .../UnitTests/TestCustomTypeRegister.cs | 4 +- test/Microsoft.ML.Tests/ImagesTests.cs | 13 +- 6 files changed, 172 insertions(+), 79 deletions(-) diff --git a/src/Microsoft.ML.Data/DataView/InternalSchemaDefinition.cs b/src/Microsoft.ML.Data/DataView/InternalSchemaDefinition.cs index 22347880b7..ca017b8721 100644 --- a/src/Microsoft.ML.Data/DataView/InternalSchemaDefinition.cs +++ b/src/Microsoft.ML.Data/DataView/InternalSchemaDefinition.cs @@ -185,9 +185,13 @@ public static void GetVectorAndItemType(Type rawType, string name, out bool isVe isVector = false; } + // The internal type of string is ReadOnlyMemory. That is, string will be stored as ReadOnlyMemory in IDataView. if (itemType == typeof(string)) itemType = typeof(ReadOnlyMemory); - else if (!itemType.TryGetDataKind(out _) && !DataViewTypeManager.Knows(itemType)) + + // Check if the itemType extracted from rawType is supported by ML.NET's type system. + // It must be one of either ML.NET's pre-defined types or custom types registered by the user. + if (!itemType.TryGetDataKind(out _) && !DataViewTypeManager.Knows(itemType)) throw Contracts.ExceptParam(nameof(rawType), "Could not determine an IDataView type for member {0}", name); } diff --git a/src/Microsoft.ML.DataView/DataViewType.cs b/src/Microsoft.ML.DataView/DataViewType.cs index 643ccbec22..153ba02261 100644 --- a/src/Microsoft.ML.DataView/DataViewType.cs +++ b/src/Microsoft.ML.DataView/DataViewType.cs @@ -21,7 +21,7 @@ public abstract class DataViewType : IEquatable /// /// Constructor for extension types, which must be either or . /// - protected DataViewType(Type rawType) + private protected DataViewType(Type rawType) { RawType = rawType ?? throw new ArgumentNullException(nameof(rawType)); } diff --git a/src/Microsoft.ML.DataView/TypeManager.cs b/src/Microsoft.ML.DataView/TypeManager.cs index 23ed1dcc22..f5fd07fd47 100644 --- a/src/Microsoft.ML.DataView/TypeManager.cs +++ b/src/Microsoft.ML.DataView/TypeManager.cs @@ -4,7 +4,7 @@ using System; using System.Collections.Generic; -using System.Threading; +using System.Reflection; using Microsoft.ML.Internal.DataView; namespace Microsoft.ML.Data @@ -27,71 +27,107 @@ public static class DataViewTypeManager typeof(string), typeof(ReadOnlySpan) }; - private static Dictionary _rawTypeToDataViewTypeMap = new Dictionary(); - private static Dictionary _dataViewTypeToRawTypeMap = new Dictionary(); - private static SpinLock _lock = new SpinLock(); + /// + /// Mapping from ID to a . The ID is the ID of in ML.NET's type system. + /// + private static Dictionary _idToTypeMap = new Dictionary(); + + /// + /// Mapping from ID to a instance. The ID is the ID of instance in ML.NET's type system. + /// + private static Dictionary _idToDataViewTypeMap = new Dictionary(); + + /// + /// Mapping from hashing ID of a and its s to hashing ID of a . + /// + private static Dictionary _typeIdToDataViewTypeIdMap = new Dictionary(); + + /// + /// Mapping from hashing ID of a to hashing ID of a and its s. + /// + private static Dictionary _dataViewTypeIdToTypeIdMap = new Dictionary(); + + private static object _lock = new object(); /// - /// Returns the registered for . + /// This function computes a hashing ID from and attributes attached to it. + /// If a type is defined as a member in a , can be obtained by calling + /// . /// - public static DataViewType GetDataViewType(Type type) + /// + private static int ComputeHashCode(Type rawType, params Attribute[] rawTypeAttributes) { - bool ownLock = false; - DataViewType dataViewType = null; - try - { - _lock.Enter(ref ownLock); + var code = rawType.GetHashCode(); + for (int i = 0; i < rawTypeAttributes.Length; ++i) + code = Hashing.CombineHash(code, rawTypeAttributes[i].GetHashCode()); + return code; + } - if (!_rawTypeToDataViewTypeMap.TryGetValue(type, out dataViewType)) - throw Contracts.ExceptParam(nameof(type), $"The raw type {type} is not registered with a DataView type."); - } - finally + /// + /// This function hashes a and its own hashing code together. + /// + private static int ComputeHashCode(DataViewType dataViewType) => Hashing.CombineHash(dataViewType.GetType().GetHashCode(), dataViewType.GetHashCode()); + + /// + /// Returns the registered for and its . + /// + public static DataViewType GetDataViewType(Type rawType, params Attribute[] rawTypeAttributes) + { + // Overall flow: + // type (Type) + attrs ----> type ID ----------------> associated DataViewType's ID ----------------> DataViewType + // (hashing) (dictionary look-up) (dictionary look-up) + lock (_lock) { - if (ownLock) _lock.Exit(); + // Compute the ID of type with extra attributes. + var typeId = ComputeHashCode(rawType, rawTypeAttributes); + + // Get the DataViewType's ID which typeID is mapped into. + if (!_typeIdToDataViewTypeIdMap.TryGetValue(typeId, out int dataViewTypeId)) + throw Contracts.ExceptParam(nameof(rawType), $"The raw type {rawType} with attributes {rawTypeAttributes} is not registered with a DataView type."); + + // Retrieve the actual DataViewType identified by dataViewTypeId. + return _idToDataViewTypeMap[dataViewTypeId]; } - return dataViewType; } /// - /// If has been registered with a , this function returns . + /// If has been registered with a , this function returns . /// Otherwise, this function returns . /// - public static bool Knows(Type type) + public static bool Knows(Type rawType, params Attribute[] rawTypeAttributes) { - bool ownLock = false; - bool answer = false; - try + lock (_lock) { - _lock.Enter(ref ownLock); - if (_rawTypeToDataViewTypeMap.ContainsKey(type)) - answer = true; - } - finally - { - if (ownLock) _lock.Exit(); + // Compute the ID of type with extra attributes. + var typeId = ComputeHashCode(rawType, rawTypeAttributes); + + // Check if this ID has been associated with a DataViewType. + // Note that the dictionary below contains (typeId, type) pairs (key is typeId, and value is type). + if (_idToTypeMap.ContainsKey(typeId)) + return true; + else + return false; } - return answer; } /// - /// If has been registered with a , this function returns . + /// If has been registered with a , this function returns . /// Otherwise, this function returns . /// - public static bool Knows(DataViewType type) + public static bool Knows(DataViewType dataViewType) { - bool ownLock = false; - bool answer = false; - try + lock (_lock) { - _lock.Enter(ref ownLock); - if (_dataViewTypeToRawTypeMap.ContainsKey(type)) - answer = true; - } - finally - { - if (ownLock) _lock.Exit(); + // Compute the ID of the input DataViewType. + var dataViewTypeId = ComputeHashCode(dataViewType); + + // Check if this the ID has been associated with a DataViewType. + // Note that the dictionary below contains (dataViewTypeId, type) pairs (key is dataViewTypeId, and value is type). + if (_idToDataViewTypeMap.ContainsKey(dataViewTypeId)) + return true; + else + return false; } - return answer; } /// @@ -100,39 +136,48 @@ public static bool Knows(DataViewType type) /// /// Native type in C#. /// The corresponding type of in ML.NET's type system. - public static void Register(Type rawType, DataViewType dataViewType) + /// The s attached to . + public static void Register(DataViewType dataViewType, Type rawType, params Attribute[] rawTypeAttributes) { - bool ownLock = false; - - try + lock (_lock) { - _lock.Enter(ref ownLock); - if (_bannedRawTypes.Contains(rawType)) throw Contracts.ExceptParam(nameof(rawType), $"Type {rawType} has been registered as ML.NET's default supported type, " + $"so it can't not be registered again."); - // Registering the same pair of (rawType, dataViewType) multiple times is ok. However, a raw type can be associated - // with only one DataView type. - if (_rawTypeToDataViewTypeMap.ContainsKey(rawType) && _rawTypeToDataViewTypeMap[rawType] != dataViewType) + int rawTypeId = ComputeHashCode(rawType, rawTypeAttributes); + int dataViewTypeId = ComputeHashCode(dataViewType); + + if (_typeIdToDataViewTypeIdMap.ContainsKey(rawTypeId) && _typeIdToDataViewTypeIdMap[rawTypeId] == dataViewTypeId && + _dataViewTypeIdToTypeIdMap.ContainsKey(dataViewTypeId) && _dataViewTypeIdToTypeIdMap[dataViewTypeId] == rawTypeId) + // This type pair has been registered. Note that registering one data type pair multiple times is allowed. + return; + + if (_typeIdToDataViewTypeIdMap.ContainsKey(rawTypeId) && _typeIdToDataViewTypeIdMap[rawTypeId] != dataViewTypeId) + { + // There is a pair of (rawTypeId, anotherDataViewTypeId) in _typeIdToDataViewTypeId so we cannot register + // (rawTypeId, dataViewTypeId) again. The assumption here is that one rawTypeId can only be associated + // with one dataViewTypeId. + var associatedDataViewType = _idToDataViewTypeMap[_typeIdToDataViewTypeIdMap[rawTypeId]]; throw Contracts.ExceptParam(nameof(rawType), $"Repeated type register. The raw type {rawType} " + - $"has been associated with {_rawTypeToDataViewTypeMap[rawType]} so it cannot be associated with {dataViewType}."); + $"has been associated with {associatedDataViewType} so it cannot be associated with {dataViewType}."); + } - // Registering the same pair of (rawType, dataViewType) multiple times is ok. However, a DataView type can be associated - // with only one raw type. - if (_dataViewTypeToRawTypeMap.ContainsKey(dataViewType) && _dataViewTypeToRawTypeMap[dataViewType] != rawType) + if (_dataViewTypeIdToTypeIdMap.ContainsKey(dataViewTypeId) && _dataViewTypeIdToTypeIdMap[dataViewTypeId] != rawTypeId) + { + // There is a pair of (dataViewTypeId, anotherRawTypeId) in _dataViewTypeIdToTypeId so we cannot register + // (dataViewTypeId, rawTypeId) again. The assumption here is that one dataViewTypeId can only be associated + // with one rawTypeId. + var associatedRawType = _idToTypeMap[_dataViewTypeIdToTypeIdMap[dataViewTypeId]]; throw Contracts.ExceptParam(nameof(dataViewType), $"Repeated type register. The DataView type {dataViewType} " + - $"has been associated with {_dataViewTypeToRawTypeMap[dataViewType]} so it cannot be associated with {rawType}."); + $"has been associated with {associatedRawType} so it cannot be associated with {rawType}."); + } - if (!_rawTypeToDataViewTypeMap.ContainsKey(rawType)) - _rawTypeToDataViewTypeMap.Add(rawType, dataViewType); + _typeIdToDataViewTypeIdMap.Add(rawTypeId, dataViewTypeId); + _dataViewTypeIdToTypeIdMap.Add(dataViewTypeId, rawTypeId); - if (!_dataViewTypeToRawTypeMap.ContainsKey(dataViewType)) - _dataViewTypeToRawTypeMap.Add(dataViewType, rawType); - } - finally - { - if (ownLock) _lock.Exit(); + _idToDataViewTypeMap[dataViewTypeId] = dataViewType; + _idToTypeMap[rawTypeId] = rawType; } } } diff --git a/src/Microsoft.ML.ImageAnalytics/ImageType.cs b/src/Microsoft.ML.ImageAnalytics/ImageType.cs index ec02a3da29..c15319b208 100644 --- a/src/Microsoft.ML.ImageAnalytics/ImageType.cs +++ b/src/Microsoft.ML.ImageAnalytics/ImageType.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using System; using System.Drawing; using Microsoft.ML.Data; using Microsoft.ML.Internal.Utilities; @@ -9,15 +10,51 @@ namespace Microsoft.ML.Transforms.Image { - public sealed class ImageDataViewType : StructuredDataViewType + /// + /// Allows a member to be marked as a , primarily allowing one to set + /// the dimensionality of the resulting array. + /// + [AttributeUsage(AttributeTargets.Field | AttributeTargets.Property, AllowMultiple = false, Inherited = true)] + public sealed class ImageTypeAttribute : Attribute { - public readonly int Height; - public readonly int Width; + /// + /// The height of the image type. + /// + internal int Height { get; } + + /// + /// The width of the image type. + /// + internal int Width { get; } + + /// + /// Create an image type without knowing its height and width. + /// + public ImageTypeAttribute() + { + } + + /// + /// Create an image type with known height and width. + /// + public ImageTypeAttribute(int height, int width) + { + Contracts.CheckParam(width > 0, nameof(width), "Should be positive number"); + Contracts.CheckParam(height > 0, nameof(height), "Should be positive number"); + Height = height; + Width = width; + } - static ImageDataViewType() + public override int GetHashCode() { - DataViewTypeManager.Register(typeof(Bitmap), new ImageDataViewType()); + return Hashing.CombineHash(Height.GetHashCode(), Width.GetHashCode()); } + } + + public sealed class ImageDataViewType : StructuredDataViewType + { + public readonly int Height; + public readonly int Width; public ImageDataViewType(int height, int width) : base(typeof(Bitmap)) @@ -25,12 +62,16 @@ public ImageDataViewType(int height, int width) Contracts.CheckParam(height > 0, nameof(height), "Must be positive."); Contracts.CheckParam(width > 0, nameof(width), " Must be positive."); Contracts.CheckParam((long)height * width <= int.MaxValue / 4, nameof(height), nameof(height) + " * " + nameof(width) + " is too large."); + Height = height; Width = width; + + DataViewTypeManager.Register(this, typeof(Bitmap), new ImageTypeAttribute(height, width)); } public ImageDataViewType() : base(typeof(Bitmap)) { + DataViewTypeManager.Register(this, typeof(Bitmap)); } public override bool Equals(DataViewType other) diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestCustomTypeRegister.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestCustomTypeRegister.cs index 41b0bfe830..2db7463da6 100644 --- a/test/Microsoft.ML.Core.Tests/UnitTests/TestCustomTypeRegister.cs +++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestCustomTypeRegister.cs @@ -35,7 +35,7 @@ private class Body /// static Body() { - DataViewTypeManager.Register(typeof(Body), DataViewBodyType.Instance); + DataViewTypeManager.Register(DataViewBodyType.Instance, typeof(Body)); } public Body() @@ -77,7 +77,7 @@ public Hero(string name, int age, float height, float weight) /// /// Type of in ML.NET. /// - private class DataViewBodyType : DataViewType + private class DataViewBodyType : StructuredDataViewType { private static volatile DataViewBodyType _instance; diff --git a/test/Microsoft.ML.Tests/ImagesTests.cs b/test/Microsoft.ML.Tests/ImagesTests.cs index 3dbb850cb3..e507705fb1 100644 --- a/test/Microsoft.ML.Tests/ImagesTests.cs +++ b/test/Microsoft.ML.Tests/ImagesTests.cs @@ -205,7 +205,7 @@ public void TestGrayScaleInMemory() // Test path: image files -> IDataView -> Enumerable of Bitmaps. var transformedData = pipeline.Fit(data).Transform(data); - var transformedDataPoints = ML.Data.CreateEnumerable(transformedData, true).ToList(); + var transformedDataPoints = ML.Data.CreateEnumerable(transformedData, false); foreach (var datapoint in transformedDataPoints) { @@ -224,12 +224,15 @@ public void TestGrayScaleInMemory() // Test path: Enumerable of Bitmaps -> IDataView -> Enumerable of Bitmaps. var imagesInDataView = ML.Data.LoadFromEnumerable(transformedDataPoints); - var imagesObtainedFromDataView = ML.Data.CreateEnumerable(imagesInDataView, true).ToList(); + var imagesObtainedFromDataView = ML.Data.CreateEnumerable(imagesInDataView, false); - for (int i = 0; i < transformedDataPoints.Count; ++i) + var expectedImages = new[] { transformedDataPoints.First().Grayscale, transformedDataPoints.Last().Grayscale } ; + var obtainedImages = new[] { imagesObtainedFromDataView.First().Grayscale, imagesObtainedFromDataView.Last().Grayscale }; + + for (int i = 0; i < expectedImages.Length; ++i) { - var expectedImage = transformedDataPoints[i].Grayscale; - var obtainedImage = imagesObtainedFromDataView[i].Grayscale; + var expectedImage = expectedImages[i]; + var obtainedImage = obtainedImages[i]; Assert.Equal(expectedImage.Width, obtainedImage.Width); Assert.Equal(expectedImage.Height, obtainedImage.Height); From 10ead85edb8eb4009ef557567c70a9e62cc36836 Mon Sep 17 00:00:00 2001 From: Wei-Sheng Chin Date: Wed, 22 May 2019 10:15:36 -0700 Subject: [PATCH 06/24] Register type with attributes --- .../Data/SchemaDefinition.cs | 4 +- .../DataView/InternalSchemaDefinition.cs | 9 +- src/Microsoft.ML.Data/Utils/ApiUtils.cs | 13 +- .../UnitTests/TestCustomTypeRegister.cs | 152 +++++++++++++++++- 4 files changed, 164 insertions(+), 14 deletions(-) diff --git a/src/Microsoft.ML.Data/Data/SchemaDefinition.cs b/src/Microsoft.ML.Data/Data/SchemaDefinition.cs index d327315ae0..ff6e0628d2 100644 --- a/src/Microsoft.ML.Data/Data/SchemaDefinition.cs +++ b/src/Microsoft.ML.Data/Data/SchemaDefinition.cs @@ -394,7 +394,7 @@ public static SchemaDefinition Create(Type userType, Direction direction = Direc // Get the column type. DataViewType columnType; - if (!DataViewTypeManager.Knows(dataType)) + if (!DataViewTypeManager.Knows(dataType, memberInfo.GetCustomAttributes().ToArray())) { PrimitiveDataViewType itemType; var keyAttr = memberInfo.GetCustomAttribute(); @@ -427,7 +427,7 @@ public static SchemaDefinition Create(Type userType, Direction direction = Direc columnType = itemType; } else - columnType = DataViewTypeManager.GetDataViewType(dataType); + columnType = DataViewTypeManager.GetDataViewType(dataType, memberInfo.GetCustomAttributes().ToArray()); cols.Add(new Column(memberInfo.Name, columnType, name)); } diff --git a/src/Microsoft.ML.Data/DataView/InternalSchemaDefinition.cs b/src/Microsoft.ML.Data/DataView/InternalSchemaDefinition.cs index ca017b8721..c662c4aab2 100644 --- a/src/Microsoft.ML.Data/DataView/InternalSchemaDefinition.cs +++ b/src/Microsoft.ML.Data/DataView/InternalSchemaDefinition.cs @@ -147,11 +147,11 @@ public static void GetVectorAndItemType(MemberInfo memberInfo, out bool isVector switch (memberInfo) { case FieldInfo fieldInfo: - GetVectorAndItemType(fieldInfo.FieldType, fieldInfo.Name, out isVector, out itemType); + GetVectorAndItemType(fieldInfo.FieldType, fieldInfo.Name, out isVector, out itemType, fieldInfo.GetCustomAttributes().ToArray()); break; case PropertyInfo propertyInfo: - GetVectorAndItemType(propertyInfo.PropertyType, propertyInfo.Name, out isVector, out itemType); + GetVectorAndItemType(propertyInfo.PropertyType, propertyInfo.Name, out isVector, out itemType, propertyInfo.GetCustomAttributes().ToArray()); break; default: @@ -171,7 +171,8 @@ public static void GetVectorAndItemType(MemberInfo memberInfo, out bool isVector /// /// The corresponding RawType of the type, or items of this type if vector. /// - public static void GetVectorAndItemType(Type rawType, string name, out bool isVector, out Type itemType) + /// Attribute of . + public static void GetVectorAndItemType(Type rawType, string name, out bool isVector, out Type itemType, params Attribute[] attributes) { // Determine whether this is a vector, and also determine the raw item type. isVector = true; @@ -191,7 +192,7 @@ public static void GetVectorAndItemType(Type rawType, string name, out bool isVe // Check if the itemType extracted from rawType is supported by ML.NET's type system. // It must be one of either ML.NET's pre-defined types or custom types registered by the user. - if (!itemType.TryGetDataKind(out _) && !DataViewTypeManager.Knows(itemType)) + if (!itemType.TryGetDataKind(out _) && !DataViewTypeManager.Knows(itemType, attributes)) throw Contracts.ExceptParam(nameof(rawType), "Could not determine an IDataView type for member {0}", name); } diff --git a/src/Microsoft.ML.Data/Utils/ApiUtils.cs b/src/Microsoft.ML.Data/Utils/ApiUtils.cs index 67f79b93ba..c4c73c60db 100644 --- a/src/Microsoft.ML.Data/Utils/ApiUtils.cs +++ b/src/Microsoft.ML.Data/Utils/ApiUtils.cs @@ -3,6 +3,7 @@ // See the LICENSE file in the project root for more information. using System; +using System.Linq; using System.Reflection; using System.Reflection.Emit; using Microsoft.ML.Data; @@ -16,7 +17,7 @@ namespace Microsoft.ML internal static class ApiUtils { - private static OpCode GetAssignmentOpCode(Type t) + private static OpCode GetAssignmentOpCode(Type t, params Attribute[] attributes) { // REVIEW: This should be a Dictionary based solution. // DvTypes, strings, arrays, all nullable types, VBuffers and RowId. @@ -24,7 +25,7 @@ private static OpCode GetAssignmentOpCode(Type t) (t.IsGenericType && t.GetGenericTypeDefinition() == typeof(VBuffer<>)) || (t.IsGenericType && t.GetGenericTypeDefinition() == typeof(Nullable<>)) || t == typeof(DateTime) || t == typeof(DateTimeOffset) || t == typeof(TimeSpan) || - t == typeof(DataViewRowId) || DataViewTypeManager.Knows(t)) + t == typeof(DataViewRowId) || DataViewTypeManager.Knows(t, attributes)) { return OpCodes.Stobj; } @@ -57,7 +58,7 @@ internal static Delegate GeneratePeek(InternalSchemaDefinition.Colum case FieldInfo fieldInfo: Type fieldType = fieldInfo.FieldType; - var assignmentOpCode = GetAssignmentOpCode(fieldType); + var assignmentOpCode = GetAssignmentOpCode(fieldType, fieldInfo.GetCustomAttributes().ToArray()); Func func = GeneratePeek; var methInfo = func.GetMethodInfo().GetGenericMethodDefinition() .MakeGenericMethod(typeof(TOwn), typeof(TRow), fieldType); @@ -66,7 +67,7 @@ internal static Delegate GeneratePeek(InternalSchemaDefinition.Colum case PropertyInfo propertyInfo: Type propertyType = propertyInfo.PropertyType; - var assignmentOpCodeProp = GetAssignmentOpCode(propertyType); + var assignmentOpCodeProp = GetAssignmentOpCode(propertyType, propertyInfo.GetCustomAttributes().ToArray()); Func funcProp = GeneratePeek; var methInfoProp = funcProp.GetMethodInfo().GetGenericMethodDefinition() .MakeGenericMethod(typeof(TOwn), typeof(TRow), propertyType); @@ -133,7 +134,7 @@ internal static Delegate GeneratePoke(InternalSchemaDefinition.Colum case FieldInfo fieldInfo: Type fieldType = fieldInfo.FieldType; - var assignmentOpCode = GetAssignmentOpCode(fieldType); + var assignmentOpCode = GetAssignmentOpCode(fieldType, fieldInfo.GetCustomAttributes().ToArray()); Func func = GeneratePoke; var methInfo = func.GetMethodInfo().GetGenericMethodDefinition() .MakeGenericMethod(typeof(TOwn), typeof(TRow), fieldType); @@ -142,7 +143,7 @@ internal static Delegate GeneratePoke(InternalSchemaDefinition.Colum case PropertyInfo propertyInfo: Type propertyType = propertyInfo.PropertyType; - var assignmentOpCodeProp = GetAssignmentOpCode(propertyType); + var assignmentOpCodeProp = GetAssignmentOpCode(propertyType, propertyInfo.GetCustomAttributes().ToArray()); Func funcProp = GeneratePoke; var methInfoProp = funcProp.GetMethodInfo().GetGenericMethodDefinition() .MakeGenericMethod(typeof(TOwn), typeof(TRow), propertyType); diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestCustomTypeRegister.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestCustomTypeRegister.cs index 2db7463da6..d971600b3f 100644 --- a/test/Microsoft.ML.Core.Tests/UnitTests/TestCustomTypeRegister.cs +++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestCustomTypeRegister.cs @@ -112,7 +112,7 @@ public override bool Equals(DataViewType other) [Fact] public void RegisterCustomType() { - var tribe = new List(){ new Hero("Earth", 10, 5.8f, 100.0f), new Hero("Mars", 20, 6.8f, 120.8f) }; + var tribe = new List() { new Hero("Earth", 10, 5.8f, 100.0f), new Hero("Mars", 20, 6.8f, 120.8f) }; var tribeDataView = ML.Data.LoadFromEnumerable(tribe); var tribeEnumerable = ML.Data.CreateEnumerable(tribeDataView, false).ToList(); @@ -158,7 +158,7 @@ public override Action GetMapping() [Fact] public void ModifyCustomType() { - var tribe = new List(){ new Hero("Earth", 10, 5.8f, 100.0f) }; + var tribe = new List() { new Hero("Earth", 10, 5.8f, 100.0f) }; var tribeDataView = ML.Data.LoadFromEnumerable(tribe); @@ -176,5 +176,153 @@ public void ModifyCustomType() Assert.Equal(tribe[i].One.Weight * 10, tribeEnumerable[i].SuperOne.Weight); } } + + /// + /// A custom type which ML.NET doesn't know yet. Its value will be loaded as a DataView column in this test. + /// + private class AlienBody + { + public int Age { get; set; } + public float Height { get; set; } + public float Weight { get; set; } + public int HandCount { get; set; } + + /// + /// Type register should happen before the creation of the first . Otherwise, ML.NET might not recognize + /// that is typed to in ML.NET's internal type system. + /// + static AlienBody() + { + } + + public AlienBody() + { + Age = 0; + Height = 0; + Weight = 0; + HandCount = 0; + } + + public AlienBody(int age, float height, float weight, int handCount) + { + Age = age; + Height = height; + Weight = weight; + HandCount = handCount; + } + } + + [AttributeUsage(AttributeTargets.Field | AttributeTargets.Property, AllowMultiple = false, Inherited = true)] + private sealed class AlienTypeAttributeAttribute : Attribute + { + public int Id { get; } + + /// + /// Create an image type with known height and width. + /// + public AlienTypeAttributeAttribute(int id) + { + Id = id; + } + } + + /// + /// A custom class with a type which ML.NET doesn't know yet. Its value will be loaded as a DataView row in this test. + /// + private class AlienHero + { + public string Name { get; set; } + + [AlienTypeAttribute(100)] + public AlienBody One { get; set; } + + [AlienTypeAttribute(200)] + public AlienBody Two { get; set; } + + public AlienHero() + { + Name = "Earth"; + One = new AlienBody(10000000, 500000, 800000, 100); + Two = new AlienBody(10, 9, 8, 7); + } + } + + /// + /// Type of in ML.NET. + /// + private class DataViewAlienBodyType : StructuredDataViewType + { + public int Id { get; } + + public DataViewAlienBodyType(int id) : base(typeof(AlienBody)) + { + Id = id; + } + + public override bool Equals(DataViewType other) + { + if (other is DataViewAlienBodyType) + return ((DataViewAlienBodyType)other).Id == Id; + else + return false; + } + } + + /// + /// The output type of processing . + /// + private class SuperAlienHero + { + public string Name { get; set; } + + [AlienTypeAttribute(007)] + public AlienBody Merged { get; set; } + + public SuperAlienHero() + { + Name = "Earth"; + Merged = new AlienBody(0, 0, 0, 0); + } + } + + [CustomMappingFactoryAttribute("LambdaAlienHero")] + private class AlienLambda : CustomMappingFactory + { + public static void MergeBody(AlienHero input, SuperAlienHero output) + { + output.Name = "Super " + input.Name; + output.Merged.Age = input.One.Age + input.Two.Age; + output.Merged.Height = input.One.Height + input.Two.Height; + output.Merged.Weight = input.One.Weight + input.Two.Weight; + } + + public override Action GetMapping() + { + return MergeBody; + } + } + + [Fact] + public void RegisterTypeWithAttribute() + { + var tribe = new List() { new AlienHero() }; + + DataViewTypeManager.Register(new DataViewAlienBodyType(100), typeof(AlienBody), new AlienTypeAttributeAttribute(100)); + DataViewTypeManager.Register(new DataViewAlienBodyType(200), typeof(AlienBody), new AlienTypeAttributeAttribute(200)); + DataViewTypeManager.Register(new DataViewAlienBodyType(007), typeof(AlienBody), new AlienTypeAttributeAttribute(007)); + + var tribeDataView = ML.Data.LoadFromEnumerable(tribe); + + var heroEstimator = new CustomMappingEstimator(ML, AlienLambda.MergeBody, "LambdaAlienHero"); + + var tribeTransformed = heroEstimator.Fit(tribeDataView).Transform(tribeDataView); + + var tribeEnumerable = ML.Data.CreateEnumerable(tribeTransformed, false).ToList(); + + Assert.Equal(tribeEnumerable[0].Name, "Super " + tribe[0].Name); + Assert.Equal(tribeEnumerable[0].Merged.Age, tribe[0].One.Age + tribe[0].Two.Age); + Assert.Equal(tribeEnumerable[0].Merged.Height, tribe[0].One.Height + tribe[0].Two.Height); + Assert.Equal(tribeEnumerable[0].Merged.Weight, tribe[0].One.Weight + tribe[0].Two.Weight); + } } } From d8bff2c16e28114dde38c08a54992c8c2684b655 Mon Sep 17 00:00:00 2001 From: Wei-Sheng Chin Date: Wed, 22 May 2019 11:02:13 -0700 Subject: [PATCH 07/24] Remove a conceptually duplicate test --- .../UnitTests/TestCustomTypeRegister.cs | 166 +----------------- 1 file changed, 7 insertions(+), 159 deletions(-) diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestCustomTypeRegister.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestCustomTypeRegister.cs index d971600b3f..8d65996030 100644 --- a/test/Microsoft.ML.Core.Tests/UnitTests/TestCustomTypeRegister.cs +++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestCustomTypeRegister.cs @@ -20,163 +20,6 @@ public TestCustomTypeRegister(ITestOutputHelper helper) { } - /// - /// A custom type which ML.NET doesn't know yet. Its value will be loaded as a DataView column in this test. - /// - private class Body - { - public int Age { get; set; } - public float Height { get; set; } - public float Weight { get; set; } - - /// - /// Type register should happen before the creation of the first . Otherwise, ML.NET might not recognize - /// that is typed to in ML.NET's internal type system. - /// - static Body() - { - DataViewTypeManager.Register(DataViewBodyType.Instance, typeof(Body)); - } - - public Body() - { - Age = 0; - Height = 0; - Weight = 0; - } - - public Body(int age, float height, float weight) - { - Age = age; - Height = height; - Weight = weight; - } - } - - /// - /// A custom class with a type which ML.NET doesn't know yet. Its value will be loaded as a DataView row in this test. - /// - private class Hero - { - public string Name { get; set; } - public Body One { get; set; } - - public Hero() - { - Name = "Earth"; - One = new Body(10000000, 500000, 800000); - } - - public Hero(string name, int age, float height, float weight) - { - Name = name; - One = new Body(age, height, weight); - } - } - - /// - /// Type of in ML.NET. - /// - private class DataViewBodyType : StructuredDataViewType - { - private static volatile DataViewBodyType _instance; - - /// - /// The singleton instance of this type. - /// - public static DataViewBodyType Instance - { - get - { - return _instance ?? - Interlocked.CompareExchange(ref _instance, new DataViewBodyType(), null) ?? - _instance; - } - } - - private DataViewBodyType() : base(typeof(Body)) - { - } - - public override bool Equals(DataViewType other) - { - if (other == this) - return true; - return false; - } - } - - /// - /// Pass in as a column in and load back. - /// - [Fact] - public void RegisterCustomType() - { - var tribe = new List() { new Hero("Earth", 10, 5.8f, 100.0f), new Hero("Mars", 20, 6.8f, 120.8f) }; - - var tribeDataView = ML.Data.LoadFromEnumerable(tribe); - var tribeEnumerable = ML.Data.CreateEnumerable(tribeDataView, false).ToList(); - - for (int i = 0; i < tribe.Count; ++i) - { - Assert.Equal(tribe[i].Name, tribeEnumerable[i].Name); - Assert.Equal(tribe[i].One.Age, tribeEnumerable[i].One.Age); - Assert.Equal(tribe[i].One.Height, tribeEnumerable[i].One.Height); - Assert.Equal(tribe[i].One.Weight, tribeEnumerable[i].One.Weight); - } - } - - private class SuperHero - { - public string SuperName { get; set; } - public Body SuperOne { get; set; } - - public SuperHero() - { - SuperName = "IronMan"; - SuperOne = new Body(); - } - } - - [CustomMappingFactoryAttribute("LambdaHero")] - private class MyLambda : CustomMappingFactory - { - public static void Grow(Hero input, SuperHero output) - { - output.SuperName = "Sr. " + input.Name; - output.SuperOne.Age = input.One.Age + 9999; - output.SuperOne.Height = input.One.Height * 10; - output.SuperOne.Weight = input.One.Weight * 10; - } - - public override Action GetMapping() - { - return Grow; - } - } - - [Fact] - public void ModifyCustomType() - { - var tribe = new List() { new Hero("Earth", 10, 5.8f, 100.0f) }; - - var tribeDataView = ML.Data.LoadFromEnumerable(tribe); - - var heroEstimator = new CustomMappingEstimator(ML, MyLambda.Grow, "LambdaHero"); - - var tribeTransformed = heroEstimator.Fit(tribeDataView).Transform(tribeDataView); - - var tribeEnumerable = ML.Data.CreateEnumerable(tribeTransformed, false).ToList(); - - for (int i = 0; i < tribe.Count; ++i) - { - Assert.Equal("Sr. " + tribe[i].Name, tribeEnumerable[i].SuperName); - Assert.Equal(tribe[i].One.Age + 9999, tribeEnumerable[i].SuperOne.Age); - Assert.Equal(tribe[i].One.Height * 10, tribeEnumerable[i].SuperOne.Height); - Assert.Equal(tribe[i].One.Weight * 10, tribeEnumerable[i].SuperOne.Weight); - } - } - /// /// A custom type which ML.NET doesn't know yet. Its value will be loaded as a DataView column in this test. /// @@ -228,6 +71,7 @@ public AlienTypeAttributeAttribute(int id) /// /// A custom class with a type which ML.NET doesn't know yet. Its value will be loaded as a DataView row in this test. + /// It will be the input of . /// private class AlienHero { @@ -248,7 +92,7 @@ public AlienHero() } /// - /// Type of in ML.NET. + /// Type of in ML.NET's type system. /// private class DataViewAlienBodyType : StructuredDataViewType { @@ -269,7 +113,7 @@ public override bool Equals(DataViewType other) } /// - /// The output type of processing . + /// The output type of processing using . /// private class SuperAlienHero { @@ -285,6 +129,10 @@ public SuperAlienHero() } } + /// + /// A mapping from to . It is used to create a + /// in . + /// [CustomMappingFactoryAttribute("LambdaAlienHero")] private class AlienLambda : CustomMappingFactory { From 09210a2818ff848fa228ee7e477ded8b6278b128 Mon Sep 17 00:00:00 2001 From: Wei-Sheng Chin Date: Wed, 22 May 2019 11:23:52 -0700 Subject: [PATCH 08/24] Polish test --- src/Microsoft.ML.DataView/TypeManager.cs | 8 ++-- .../UnitTests/TestCustomTypeRegister.cs | 41 +++++++++---------- 2 files changed, 24 insertions(+), 25 deletions(-) diff --git a/src/Microsoft.ML.DataView/TypeManager.cs b/src/Microsoft.ML.DataView/TypeManager.cs index f5fd07fd47..8470f4f224 100644 --- a/src/Microsoft.ML.DataView/TypeManager.cs +++ b/src/Microsoft.ML.DataView/TypeManager.cs @@ -24,7 +24,9 @@ public static class DataViewTypeManager typeof(Boolean), typeof(SByte), typeof(Byte), typeof(Int16), typeof(UInt16), typeof(Int32), typeof(UInt32), typeof(Int64), typeof(UInt64), typeof(Single), typeof(Double), - typeof(string), typeof(ReadOnlySpan) + typeof(string), typeof(ReadOnlySpan), typeof(ReadOnlyMemory), + typeof(VBuffer<>), typeof(Nullable<>), typeof(DateTime), typeof(DateTimeOffset), + typeof(TimeSpan), typeof(DataViewRowId) }; /// @@ -176,8 +178,8 @@ public static void Register(DataViewType dataViewType, Type rawType, params Attr _typeIdToDataViewTypeIdMap.Add(rawTypeId, dataViewTypeId); _dataViewTypeIdToTypeIdMap.Add(dataViewTypeId, rawTypeId); - _idToDataViewTypeMap[dataViewTypeId] = dataViewType; - _idToTypeMap[rawTypeId] = rawType; + _idToDataViewTypeMap.Add(dataViewTypeId, dataViewType); + _idToTypeMap.Add(rawTypeId, rawType); } } } diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestCustomTypeRegister.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestCustomTypeRegister.cs index 8d65996030..8662656ba7 100644 --- a/test/Microsoft.ML.Core.Tests/UnitTests/TestCustomTypeRegister.cs +++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestCustomTypeRegister.cs @@ -5,7 +5,6 @@ using System; using System.Collections.Generic; using System.Linq; -using System.Threading; using Microsoft.ML.Data; using Microsoft.ML.Transforms; using Xunit; @@ -30,22 +29,6 @@ private class AlienBody public float Weight { get; set; } public int HandCount { get; set; } - /// - /// Type register should happen before the creation of the first . Otherwise, ML.NET might not recognize - /// that is typed to in ML.NET's internal type system. - /// - static AlienBody() - { - } - - public AlienBody() - { - Age = 0; - Height = 0; - Weight = 0; - HandCount = 0; - } - public AlienBody(int age, float height, float weight, int handCount) { Age = age; @@ -72,6 +55,9 @@ public AlienTypeAttributeAttribute(int id) /// /// A custom class with a type which ML.NET doesn't know yet. Its value will be loaded as a DataView row in this test. /// It will be the input of . + /// + /// and would be mapped to different types inside ML.NET type system because they + /// have different s. /// private class AlienHero { @@ -85,9 +71,18 @@ private class AlienHero public AlienHero() { - Name = "Earth"; - One = new AlienBody(10000000, 500000, 800000, 100); - Two = new AlienBody(10, 9, 8, 7); + Name = "Unknown"; + One = new AlienBody(0, 0, 0, 0); + Two = new AlienBody(0, 0, 0, 0); + } + + public AlienHero(string name, + int age, float height, float weight, int handCount, + int anotherAge, float anotherHeight, float anotherWeight, int anotherHandCount) + { + Name = "Unknown"; + One = new AlienBody(age, height, weight, handCount); + Two = new AlienBody(anotherAge, anotherHeight, anotherWeight, anotherHandCount); } } @@ -124,7 +119,7 @@ private class SuperAlienHero public SuperAlienHero() { - Name = "Earth"; + Name = "Unknown"; Merged = new AlienBody(0, 0, 0, 0); } } @@ -142,6 +137,7 @@ public static void MergeBody(AlienHero input, SuperAlienHero output) output.Merged.Age = input.One.Age + input.Two.Age; output.Merged.Height = input.One.Height + input.Two.Height; output.Merged.Weight = input.One.Weight + input.Two.Weight; + output.Merged.HandCount = input.One.HandCount + input.Two.HandCount; } public override Action GetMapping() @@ -153,7 +149,7 @@ public override Action GetMapping() [Fact] public void RegisterTypeWithAttribute() { - var tribe = new List() { new AlienHero() }; + var tribe = new List() { new AlienHero("ML.NET", 2, 1000, 2000, 3000, 4000, 5000, 6000, 7000) }; DataViewTypeManager.Register(new DataViewAlienBodyType(100), typeof(AlienBody), new AlienTypeAttributeAttribute(100)); DataViewTypeManager.Register(new DataViewAlienBodyType(200), typeof(AlienBody), new AlienTypeAttributeAttribute(200)); @@ -171,6 +167,7 @@ public void RegisterTypeWithAttribute() Assert.Equal(tribeEnumerable[0].Merged.Age, tribe[0].One.Age + tribe[0].Two.Age); Assert.Equal(tribeEnumerable[0].Merged.Height, tribe[0].One.Height + tribe[0].Two.Height); Assert.Equal(tribeEnumerable[0].Merged.Weight, tribe[0].One.Weight + tribe[0].Two.Weight); + Assert.Equal(tribeEnumerable[0].Merged.HandCount, tribe[0].One.HandCount + tribe[0].Two.HandCount); } } } From 3fff64c05094837a33a8e26159e0c24913272a78 Mon Sep 17 00:00:00 2001 From: Wei-Sheng Chin Date: Wed, 22 May 2019 13:27:53 -0700 Subject: [PATCH 09/24] Polish a test which uses in-memory Bitmap --- src/Microsoft.ML.ImageAnalytics/ImageType.cs | 8 +- test/Microsoft.ML.Tests/ImagesTests.cs | 91 +++++++++----------- 2 files changed, 47 insertions(+), 52 deletions(-) diff --git a/src/Microsoft.ML.ImageAnalytics/ImageType.cs b/src/Microsoft.ML.ImageAnalytics/ImageType.cs index c15319b208..0a9c4a0806 100644 --- a/src/Microsoft.ML.ImageAnalytics/ImageType.cs +++ b/src/Microsoft.ML.ImageAnalytics/ImageType.cs @@ -56,6 +56,11 @@ public sealed class ImageDataViewType : StructuredDataViewType public readonly int Height; public readonly int Width; + static ImageDataViewType() + { + DataViewTypeManager.Register(new ImageDataViewType(), typeof(Bitmap)); + } + public ImageDataViewType(int height, int width) : base(typeof(Bitmap)) { @@ -65,13 +70,10 @@ public ImageDataViewType(int height, int width) Height = height; Width = width; - - DataViewTypeManager.Register(this, typeof(Bitmap), new ImageTypeAttribute(height, width)); } public ImageDataViewType() : base(typeof(Bitmap)) { - DataViewTypeManager.Register(this, typeof(Bitmap)); } public override bool Equals(DataViewType other) diff --git a/test/Microsoft.ML.Tests/ImagesTests.cs b/test/Microsoft.ML.Tests/ImagesTests.cs index e507705fb1..9be5146292 100644 --- a/test/Microsoft.ML.Tests/ImagesTests.cs +++ b/test/Microsoft.ML.Tests/ImagesTests.cs @@ -187,76 +187,69 @@ public void TestGreyscaleTransformImages() [Fact] public void TestGrayScaleInMemory() { - var imagesDataFile = SamplesUtils.DatasetUtils.DownloadImages(); + // Create an image list. + var images = new List(){ new ImageDataPoint(10, 10, Color.Blue), new ImageDataPoint(10, 10, Color.Red) }; - var data = ML.Data.CreateTextLoader(new TextLoader.Options() - { - Columns = new[] - { - new TextLoader.Column("ImagePath", DataKind.String, 0), - new TextLoader.Column("Name", DataKind.String, 1), - } - }).Load(imagesDataFile); + // Convert the list of data points to an IDataView object, which is consumable by ML.NET API. + var data = ML.Data.LoadFromEnumerable(images); - var imagesFolder = Path.GetDirectoryName(imagesDataFile); - // Image loading and conversion pipeline. - var pipeline = ML.Transforms.LoadImages("ImageObject", imagesFolder, "ImagePath") - .Append(ML.Transforms.ConvertToGrayscale("Grayscale", "ImageObject")); + // Convert image to gray scale. + var pipeline = ML.Transforms.ConvertToGrayscale("GrayImage", "Image"); // Test path: image files -> IDataView -> Enumerable of Bitmaps. var transformedData = pipeline.Fit(data).Transform(data); - var transformedDataPoints = ML.Data.CreateEnumerable(transformedData, false); - foreach (var datapoint in transformedDataPoints) + // Load images in DataView back to Enumerable. + var transformedDataPoints = ML.Data.CreateEnumerable(transformedData, false); + + foreach (var dataPoint in transformedDataPoints) { - var image = datapoint.Grayscale; - Assert.NotNull(image); - for (int x = 0; x < image.Width; ++x) + var image = dataPoint.Image; + var grayImage = dataPoint.GrayImage; + + Assert.NotNull(grayImage); + + Assert.Equal(image.Width, grayImage.Width); + Assert.Equal(image.Height, grayImage.Height); + + for (int x = 0; x < grayImage.Width; ++x) { - for (int y = 0; y < image.Height; ++y) + for (int y = 0; y < grayImage.Height; ++y) { - var pixel = image.GetPixel(x, y); + var pixel = grayImage.GetPixel(x, y); // greyscale image has same values for R, G and B. Assert.True(pixel.R == pixel.G && pixel.G == pixel.B); } } } + } - // Test path: Enumerable of Bitmaps -> IDataView -> Enumerable of Bitmaps. - var imagesInDataView = ML.Data.LoadFromEnumerable(transformedDataPoints); - var imagesObtainedFromDataView = ML.Data.CreateEnumerable(imagesInDataView, false); + private class ImageDataPoint + { + [ImageType(10, 10)] + public Bitmap Image { get; set; } - var expectedImages = new[] { transformedDataPoints.First().Grayscale, transformedDataPoints.Last().Grayscale } ; - var obtainedImages = new[] { imagesObtainedFromDataView.First().Grayscale, imagesObtainedFromDataView.Last().Grayscale }; + [ImageType(10, 10)] + public Bitmap GrayImage { get; set; } - for (int i = 0; i < expectedImages.Length; ++i) + static ImageDataPoint() { - var expectedImage = expectedImages[i]; - var obtainedImage = obtainedImages[i]; - - Assert.Equal(expectedImage.Width, obtainedImage.Width); - Assert.Equal(expectedImage.Height, obtainedImage.Height); - for (int x = 0; x < expectedImage.Width; ++x) - { - for (int y = 0; y < expectedImage.Height; ++y) - { - var expectedPixel = expectedImage.GetPixel(x, y); - var obtainedPixel = obtainedImage.GetPixel(x, y); + DataViewTypeManager.Register(new ImageDataViewType(10, 10), typeof(Bitmap), new ImageTypeAttribute(10, 10)); + } - Assert.Equal(expectedPixel.R, obtainedPixel.R); - Assert.Equal(expectedPixel.G, obtainedPixel.G); - Assert.Equal(expectedPixel.B, obtainedPixel.B); - } - } + public ImageDataPoint() + { + Image = null; + GrayImage = null; } - } - private class TransformedImageDataPoint - { - public string ImagePath { get; set; } - public string Name { get; set; } - public Bitmap ImageObject { get; set; } - public Bitmap Grayscale { get; set; } + public ImageDataPoint(int width, int height, Color color) + { + Image = new Bitmap(width, height); + for (int i = 0; i < width; ++i) + for (int j = 0; j < height; ++j) + Image.SetPixel(i, j, color); + } } [Fact] From cce425ae6e3ec98a96a90ad060cd0b2ddfa89569 Mon Sep 17 00:00:00 2001 From: Wei-Sheng Chin Date: Wed, 22 May 2019 13:40:37 -0700 Subject: [PATCH 10/24] Test prediction engine using in-memory images --- .../UnitTests/TestCustomTypeRegister.cs | 14 +++++++++++- test/Microsoft.ML.Tests/ImagesTests.cs | 22 ++++++++++++++++++- 2 files changed, 34 insertions(+), 2 deletions(-) diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestCustomTypeRegister.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestCustomTypeRegister.cs index 8662656ba7..2b7e8bab90 100644 --- a/test/Microsoft.ML.Core.Tests/UnitTests/TestCustomTypeRegister.cs +++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestCustomTypeRegister.cs @@ -159,7 +159,9 @@ public void RegisterTypeWithAttribute() var heroEstimator = new CustomMappingEstimator(ML, AlienLambda.MergeBody, "LambdaAlienHero"); - var tribeTransformed = heroEstimator.Fit(tribeDataView).Transform(tribeDataView); + var model = heroEstimator.Fit(tribeDataView); + + var tribeTransformed = model.Transform(tribeDataView); var tribeEnumerable = ML.Data.CreateEnumerable(tribeTransformed, false).ToList(); @@ -168,6 +170,16 @@ public void RegisterTypeWithAttribute() Assert.Equal(tribeEnumerable[0].Merged.Height, tribe[0].One.Height + tribe[0].Two.Height); Assert.Equal(tribeEnumerable[0].Merged.Weight, tribe[0].One.Weight + tribe[0].Two.Weight); Assert.Equal(tribeEnumerable[0].Merged.HandCount, tribe[0].One.HandCount + tribe[0].Two.HandCount); + + var engine = ML.Model.CreatePredictionEngine(model); + var alien = new AlienHero("TEN.LM", 1, 2, 3, 4, 5, 6, 7, 8); + var superAlien = engine.Predict(alien); + + Assert.Equal(superAlien.Name, "Super " + alien.Name); + Assert.Equal(superAlien.Merged.Age, alien.One.Age + alien.Two.Age); + Assert.Equal(superAlien.Merged.Height, alien.One.Height + alien.Two.Height); + Assert.Equal(superAlien.Merged.Weight, alien.One.Weight + alien.Two.Weight); + Assert.Equal(superAlien.Merged.HandCount, alien.One.HandCount + alien.Two.HandCount); } } } diff --git a/test/Microsoft.ML.Tests/ImagesTests.cs b/test/Microsoft.ML.Tests/ImagesTests.cs index 9be5146292..1b1729896d 100644 --- a/test/Microsoft.ML.Tests/ImagesTests.cs +++ b/test/Microsoft.ML.Tests/ImagesTests.cs @@ -196,8 +196,11 @@ public void TestGrayScaleInMemory() // Convert image to gray scale. var pipeline = ML.Transforms.ConvertToGrayscale("GrayImage", "Image"); + // Fit the model. + var model = pipeline.Fit(data); + // Test path: image files -> IDataView -> Enumerable of Bitmaps. - var transformedData = pipeline.Fit(data).Transform(data); + var transformedData = model.Transform(data); // Load images in DataView back to Enumerable. var transformedDataPoints = ML.Data.CreateEnumerable(transformedData, false); @@ -222,6 +225,23 @@ public void TestGrayScaleInMemory() } } } + + var engine = ML.Model.CreatePredictionEngine(model); + var singleImage = new ImageDataPoint(17, 36, Color.Pink); + var transformedSingleImage = engine.Predict(singleImage); + + Assert.Equal(singleImage.Image.Height, transformedSingleImage.GrayImage.Height); + Assert.Equal(singleImage.Image.Width, transformedSingleImage.GrayImage.Width); + + for (int x = 0; x < transformedSingleImage.GrayImage.Width; ++x) + { + for (int y = 0; y < transformedSingleImage.GrayImage.Height; ++y) + { + var pixel = transformedSingleImage.GrayImage.GetPixel(x, y); + // greyscale image has same values for R, G and B. + Assert.True(pixel.R == pixel.G && pixel.G == pixel.B); + } + } } private class ImageDataPoint From 2a4977cccb2998e0162870c162b4481c979dceaa Mon Sep 17 00:00:00 2001 From: Wei-Sheng Chin Date: Wed, 22 May 2019 14:11:47 -0700 Subject: [PATCH 11/24] Add more tests --- .../UnitTests/TestCustomTypeRegister.cs | 81 +++++++++++++++++++ 1 file changed, 81 insertions(+) diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestCustomTypeRegister.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestCustomTypeRegister.cs index 2b7e8bab90..421581e4df 100644 --- a/test/Microsoft.ML.Core.Tests/UnitTests/TestCustomTypeRegister.cs +++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestCustomTypeRegister.cs @@ -105,6 +105,11 @@ public override bool Equals(DataViewType other) else return false; } + + public override int GetHashCode() + { + return Id.GetHashCode(); + } } /// @@ -151,10 +156,35 @@ public void RegisterTypeWithAttribute() { var tribe = new List() { new AlienHero("ML.NET", 2, 1000, 2000, 3000, 4000, 5000, 6000, 7000) }; + // Type manager doesn't know any of those custom types, so all calls to it should return false. + Assert.False(DataViewTypeManager.Knows(new DataViewAlienBodyType(100))); + Assert.False(DataViewTypeManager.Knows(new DataViewAlienBodyType(200))); + Assert.False(DataViewTypeManager.Knows(new DataViewAlienBodyType(007))); + Assert.False(DataViewTypeManager.Knows(typeof(AlienBody), new AlienTypeAttributeAttribute(100))); + Assert.False(DataViewTypeManager.Knows(typeof(AlienBody), new AlienTypeAttributeAttribute(200))); + Assert.False(DataViewTypeManager.Knows(typeof(AlienBody), new AlienTypeAttributeAttribute(007))); + + // Register those custom types. DataViewTypeManager.Register(new DataViewAlienBodyType(100), typeof(AlienBody), new AlienTypeAttributeAttribute(100)); DataViewTypeManager.Register(new DataViewAlienBodyType(200), typeof(AlienBody), new AlienTypeAttributeAttribute(200)); DataViewTypeManager.Register(new DataViewAlienBodyType(007), typeof(AlienBody), new AlienTypeAttributeAttribute(007)); + // Type manager now knows those those custom types, so all calls to it should return true. + Assert.True(DataViewTypeManager.Knows(new DataViewAlienBodyType(100))); + Assert.True(DataViewTypeManager.Knows(new DataViewAlienBodyType(200))); + Assert.True(DataViewTypeManager.Knows(new DataViewAlienBodyType(007))); + Assert.True(DataViewTypeManager.Knows(typeof(AlienBody), new AlienTypeAttributeAttribute(100))); + Assert.True(DataViewTypeManager.Knows(typeof(AlienBody), new AlienTypeAttributeAttribute(200))); + Assert.True(DataViewTypeManager.Knows(typeof(AlienBody), new AlienTypeAttributeAttribute(007))); + + // Check if the custom type (AlienBody with its attributes) is registered correctly with a DataView type (DataViewAlienBodyType). + Assert.Equal(new DataViewAlienBodyType(100), + DataViewTypeManager.GetDataViewType(typeof(AlienBody), new AlienTypeAttributeAttribute(100))); + Assert.Equal(new DataViewAlienBodyType(200), + DataViewTypeManager.GetDataViewType(typeof(AlienBody), new AlienTypeAttributeAttribute(200))); + Assert.Equal(new DataViewAlienBodyType(007), + DataViewTypeManager.GetDataViewType(typeof(AlienBody), new AlienTypeAttributeAttribute(007))); + var tribeDataView = ML.Data.LoadFromEnumerable(tribe); var heroEstimator = new CustomMappingEstimator(ML, AlienLambda.MergeBody, "LambdaAlienHero"); @@ -181,5 +211,56 @@ public void RegisterTypeWithAttribute() Assert.Equal(superAlien.Merged.Weight, alien.One.Weight + alien.Two.Weight); Assert.Equal(superAlien.Merged.HandCount, alien.One.HandCount + alien.Two.HandCount); } + + [Fact] + public void TestTypeManager() + { + // Semantically identical DataViewTypes should produce the same hash code. + var a = new DataViewAlienBodyType(9527); + var aCode = a.GetHashCode(); + var b = new DataViewAlienBodyType(9527); + var bCode = b.GetHashCode(); + + Assert.Equal(aCode, bCode); + + // Semantically identical attributes should produce the same hash code. + var c = new AlienTypeAttributeAttribute(1228); + var cCode = c.GetHashCode(); + var d = new AlienTypeAttributeAttribute(1228); + var dCode = d.GetHashCode(); + + Assert.Equal(cCode, dCode); + + // Check register the same type pair is ok. + DataViewTypeManager.Register(a, typeof(AlienBody)); + DataViewTypeManager.Register(a, typeof(AlienBody)); + + bool isWrong = false; + try + { + // "a" has been registered with AlienBody without any attribute, so the user can't + // register "a" again with AlienBody with different attribute. + DataViewTypeManager.Register(a, typeof(AlienBody), c); + } + catch + { + isWrong = true; + } + Assert.True(isWrong); + + + bool isWrongAgain = false; + try + { + // AlienBody has been registered with "a," so user can't register it with + // "new DataViewAlienBodyType(5566)" again. + DataViewTypeManager.Register(new DataViewAlienBodyType(5566), typeof(AlienBody)); + } + catch + { + isWrongAgain = true; + } + Assert.True(isWrongAgain); + } } } From 2851aa56652a429539144274da2b1df1e9c3ee26 Mon Sep 17 00:00:00 2001 From: Wei-Sheng Chin Date: Wed, 22 May 2019 16:13:46 -0700 Subject: [PATCH 12/24] Address comments --- src/Microsoft.ML.DataView/TypeManager.cs | 15 ++++--- src/Microsoft.ML.ImageAnalytics/ImageType.cs | 18 ++++++++- .../UnitTests/TestCustomTypeRegister.cs | 39 +++++++++++-------- test/Microsoft.ML.Tests/ImagesTests.cs | 2 +- 4 files changed, 48 insertions(+), 26 deletions(-) diff --git a/src/Microsoft.ML.DataView/TypeManager.cs b/src/Microsoft.ML.DataView/TypeManager.cs index 8470f4f224..8a895f97aa 100644 --- a/src/Microsoft.ML.DataView/TypeManager.cs +++ b/src/Microsoft.ML.DataView/TypeManager.cs @@ -57,11 +57,14 @@ public static class DataViewTypeManager /// . /// /// - private static int ComputeHashCode(Type rawType, params Attribute[] rawTypeAttributes) + private static int ComputeHashCode(Type rawType, IEnumerable rawTypeAttributes) { + if (rawTypeAttributes == null) + return rawType.GetHashCode(); + var code = rawType.GetHashCode(); - for (int i = 0; i < rawTypeAttributes.Length; ++i) - code = Hashing.CombineHash(code, rawTypeAttributes[i].GetHashCode()); + foreach (var attr in rawTypeAttributes) + code = Hashing.CombineHash(code, attr.GetHashCode()); return code; } @@ -73,7 +76,7 @@ private static int ComputeHashCode(Type rawType, params Attribute[] rawTypeAttri /// /// Returns the registered for and its . /// - public static DataViewType GetDataViewType(Type rawType, params Attribute[] rawTypeAttributes) + public static DataViewType GetDataViewType(Type rawType, IEnumerable rawTypeAttributes) { // Overall flow: // type (Type) + attrs ----> type ID ----------------> associated DataViewType's ID ----------------> DataViewType @@ -96,7 +99,7 @@ public static DataViewType GetDataViewType(Type rawType, params Attribute[] rawT /// If has been registered with a , this function returns . /// Otherwise, this function returns . /// - public static bool Knows(Type rawType, params Attribute[] rawTypeAttributes) + public static bool Knows(Type rawType, IEnumerable rawTypeAttributes) { lock (_lock) { @@ -139,7 +142,7 @@ public static bool Knows(DataViewType dataViewType) /// Native type in C#. /// The corresponding type of in ML.NET's type system. /// The s attached to . - public static void Register(DataViewType dataViewType, Type rawType, params Attribute[] rawTypeAttributes) + public static void Register(DataViewType dataViewType, Type rawType, IEnumerable rawTypeAttributes = null) { lock (_lock) { diff --git a/src/Microsoft.ML.ImageAnalytics/ImageType.cs b/src/Microsoft.ML.ImageAnalytics/ImageType.cs index 0a9c4a0806..709b2edc7f 100644 --- a/src/Microsoft.ML.ImageAnalytics/ImageType.cs +++ b/src/Microsoft.ML.ImageAnalytics/ImageType.cs @@ -11,8 +11,8 @@ namespace Microsoft.ML.Transforms.Image { /// - /// Allows a member to be marked as a , primarily allowing one to set - /// the dimensionality of the resulting array. + /// Allows a member to be marked as a , primarily allowing one to set + /// the shape of an image field. /// [AttributeUsage(AttributeTargets.Field | AttributeTargets.Property, AllowMultiple = false, Inherited = true)] public sealed class ImageTypeAttribute : Attribute @@ -45,10 +45,24 @@ public ImageTypeAttribute(int height, int width) Width = width; } + /// + /// Images with the same width and height should produce the same hash code. + /// public override int GetHashCode() { return Hashing.CombineHash(Height.GetHashCode(), Width.GetHashCode()); } + + /// + /// Images with the same width and height should equal. + /// + public override bool Equals(object other) + { + if (other is ImageDataViewType) + return Height == ((ImageDataViewType)other).Height && Width == ((ImageDataViewType)other).Width; + else + return false; + } } public sealed class ImageDataViewType : StructuredDataViewType diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestCustomTypeRegister.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestCustomTypeRegister.cs index 421581e4df..9c21962715 100644 --- a/test/Microsoft.ML.Core.Tests/UnitTests/TestCustomTypeRegister.cs +++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestCustomTypeRegister.cs @@ -160,31 +160,32 @@ public void RegisterTypeWithAttribute() Assert.False(DataViewTypeManager.Knows(new DataViewAlienBodyType(100))); Assert.False(DataViewTypeManager.Knows(new DataViewAlienBodyType(200))); Assert.False(DataViewTypeManager.Knows(new DataViewAlienBodyType(007))); - Assert.False(DataViewTypeManager.Knows(typeof(AlienBody), new AlienTypeAttributeAttribute(100))); - Assert.False(DataViewTypeManager.Knows(typeof(AlienBody), new AlienTypeAttributeAttribute(200))); - Assert.False(DataViewTypeManager.Knows(typeof(AlienBody), new AlienTypeAttributeAttribute(007))); + Assert.False(DataViewTypeManager.Knows(typeof(AlienBody), new[] { new AlienTypeAttributeAttribute(100) })); + Assert.False(DataViewTypeManager.Knows(typeof(AlienBody), new[] { new AlienTypeAttributeAttribute(200) })); + Assert.False(DataViewTypeManager.Knows(typeof(AlienBody), new[] { new AlienTypeAttributeAttribute(007) })); // Register those custom types. - DataViewTypeManager.Register(new DataViewAlienBodyType(100), typeof(AlienBody), new AlienTypeAttributeAttribute(100)); - DataViewTypeManager.Register(new DataViewAlienBodyType(200), typeof(AlienBody), new AlienTypeAttributeAttribute(200)); - DataViewTypeManager.Register(new DataViewAlienBodyType(007), typeof(AlienBody), new AlienTypeAttributeAttribute(007)); + DataViewTypeManager.Register(new DataViewAlienBodyType(100), typeof(AlienBody), new[] { new AlienTypeAttributeAttribute(100) }); + DataViewTypeManager.Register(new DataViewAlienBodyType(200), typeof(AlienBody), new[] { new AlienTypeAttributeAttribute(200) }); + DataViewTypeManager.Register(new DataViewAlienBodyType(007), typeof(AlienBody), new[] { new AlienTypeAttributeAttribute(007) }); - // Type manager now knows those those custom types, so all calls to it should return true. + // Type manager now knows those custom types, so all calls to it should return true. Assert.True(DataViewTypeManager.Knows(new DataViewAlienBodyType(100))); Assert.True(DataViewTypeManager.Knows(new DataViewAlienBodyType(200))); Assert.True(DataViewTypeManager.Knows(new DataViewAlienBodyType(007))); - Assert.True(DataViewTypeManager.Knows(typeof(AlienBody), new AlienTypeAttributeAttribute(100))); - Assert.True(DataViewTypeManager.Knows(typeof(AlienBody), new AlienTypeAttributeAttribute(200))); - Assert.True(DataViewTypeManager.Knows(typeof(AlienBody), new AlienTypeAttributeAttribute(007))); + Assert.True(DataViewTypeManager.Knows(typeof(AlienBody), new[] { new AlienTypeAttributeAttribute(100) })); + Assert.True(DataViewTypeManager.Knows(typeof(AlienBody), new[] { new AlienTypeAttributeAttribute(200) })); + Assert.True(DataViewTypeManager.Knows(typeof(AlienBody), new[] { new AlienTypeAttributeAttribute(007) })); // Check if the custom type (AlienBody with its attributes) is registered correctly with a DataView type (DataViewAlienBodyType). Assert.Equal(new DataViewAlienBodyType(100), - DataViewTypeManager.GetDataViewType(typeof(AlienBody), new AlienTypeAttributeAttribute(100))); + DataViewTypeManager.GetDataViewType(typeof(AlienBody), new[] { new AlienTypeAttributeAttribute(100) })); Assert.Equal(new DataViewAlienBodyType(200), - DataViewTypeManager.GetDataViewType(typeof(AlienBody), new AlienTypeAttributeAttribute(200))); + DataViewTypeManager.GetDataViewType(typeof(AlienBody), new[] { new AlienTypeAttributeAttribute(200) })); Assert.Equal(new DataViewAlienBodyType(007), - DataViewTypeManager.GetDataViewType(typeof(AlienBody), new AlienTypeAttributeAttribute(007))); + DataViewTypeManager.GetDataViewType(typeof(AlienBody), new[] { new AlienTypeAttributeAttribute(007) })); + // Build a ML.NET pipeline and make prediction. var tribeDataView = ML.Data.LoadFromEnumerable(tribe); var heroEstimator = new CustomMappingEstimator(ML, AlienLambda.MergeBody, "LambdaAlienHero"); @@ -195,16 +196,19 @@ public void RegisterTypeWithAttribute() var tribeEnumerable = ML.Data.CreateEnumerable(tribeTransformed, false).ToList(); + // Make sure the pipeline output is correct. Assert.Equal(tribeEnumerable[0].Name, "Super " + tribe[0].Name); Assert.Equal(tribeEnumerable[0].Merged.Age, tribe[0].One.Age + tribe[0].Two.Age); Assert.Equal(tribeEnumerable[0].Merged.Height, tribe[0].One.Height + tribe[0].Two.Height); Assert.Equal(tribeEnumerable[0].Merged.Weight, tribe[0].One.Weight + tribe[0].Two.Weight); Assert.Equal(tribeEnumerable[0].Merged.HandCount, tribe[0].One.HandCount + tribe[0].Two.HandCount); + // Build prediction engine from the trained pipeline. var engine = ML.Model.CreatePredictionEngine(model); var alien = new AlienHero("TEN.LM", 1, 2, 3, 4, 5, 6, 7, 8); var superAlien = engine.Predict(alien); + // Make sure the prediction engine produces expected result. Assert.Equal(superAlien.Name, "Super " + alien.Name); Assert.Equal(superAlien.Merged.Age, alien.One.Age + alien.Two.Age); Assert.Equal(superAlien.Merged.Height, alien.One.Height + alien.Two.Height); @@ -231,16 +235,17 @@ public void TestTypeManager() Assert.Equal(cCode, dCode); - // Check register the same type pair is ok. + // Check registering the same type pair is OK. DataViewTypeManager.Register(a, typeof(AlienBody)); DataViewTypeManager.Register(a, typeof(AlienBody)); + // Make sure registering the same type twice throws. bool isWrong = false; try { // "a" has been registered with AlienBody without any attribute, so the user can't - // register "a" again with AlienBody with different attribute. - DataViewTypeManager.Register(a, typeof(AlienBody), c); + // register "a" again with AlienBody plus the attribute "c." + DataViewTypeManager.Register(a, typeof(AlienBody), new[] { c }); } catch { @@ -248,7 +253,7 @@ public void TestTypeManager() } Assert.True(isWrong); - + // Make sure registering the same type twice throws. bool isWrongAgain = false; try { diff --git a/test/Microsoft.ML.Tests/ImagesTests.cs b/test/Microsoft.ML.Tests/ImagesTests.cs index 1b1729896d..8fd12b950c 100644 --- a/test/Microsoft.ML.Tests/ImagesTests.cs +++ b/test/Microsoft.ML.Tests/ImagesTests.cs @@ -254,7 +254,7 @@ private class ImageDataPoint static ImageDataPoint() { - DataViewTypeManager.Register(new ImageDataViewType(10, 10), typeof(Bitmap), new ImageTypeAttribute(10, 10)); + DataViewTypeManager.Register(new ImageDataViewType(10, 10), typeof(Bitmap), new[] { new ImageTypeAttribute(10, 10) }); } public ImageDataPoint() From 4b7ec09b73cafff00b086df87ac8237e2e7e27da Mon Sep 17 00:00:00 2001 From: Wei-Sheng Chin Date: Wed, 22 May 2019 17:14:58 -0700 Subject: [PATCH 13/24] Address comments --- src/Microsoft.ML.Data/Data/SchemaDefinition.cs | 4 ++-- src/Microsoft.ML.Data/DataView/InternalSchemaDefinition.cs | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/Microsoft.ML.Data/Data/SchemaDefinition.cs b/src/Microsoft.ML.Data/Data/SchemaDefinition.cs index ff6e0628d2..31a35529bb 100644 --- a/src/Microsoft.ML.Data/Data/SchemaDefinition.cs +++ b/src/Microsoft.ML.Data/Data/SchemaDefinition.cs @@ -394,7 +394,7 @@ public static SchemaDefinition Create(Type userType, Direction direction = Direc // Get the column type. DataViewType columnType; - if (!DataViewTypeManager.Knows(dataType, memberInfo.GetCustomAttributes().ToArray())) + if (!DataViewTypeManager.Knows(dataType, memberInfo.GetCustomAttributes())) { PrimitiveDataViewType itemType; var keyAttr = memberInfo.GetCustomAttribute(); @@ -427,7 +427,7 @@ public static SchemaDefinition Create(Type userType, Direction direction = Direc columnType = itemType; } else - columnType = DataViewTypeManager.GetDataViewType(dataType, memberInfo.GetCustomAttributes().ToArray()); + columnType = DataViewTypeManager.GetDataViewType(dataType, memberInfo.GetCustomAttributes()); cols.Add(new Column(memberInfo.Name, columnType, name)); } diff --git a/src/Microsoft.ML.Data/DataView/InternalSchemaDefinition.cs b/src/Microsoft.ML.Data/DataView/InternalSchemaDefinition.cs index c662c4aab2..1aa2301033 100644 --- a/src/Microsoft.ML.Data/DataView/InternalSchemaDefinition.cs +++ b/src/Microsoft.ML.Data/DataView/InternalSchemaDefinition.cs @@ -147,11 +147,11 @@ public static void GetVectorAndItemType(MemberInfo memberInfo, out bool isVector switch (memberInfo) { case FieldInfo fieldInfo: - GetVectorAndItemType(fieldInfo.FieldType, fieldInfo.Name, out isVector, out itemType, fieldInfo.GetCustomAttributes().ToArray()); + GetVectorAndItemType(fieldInfo.FieldType, fieldInfo.Name, out isVector, out itemType, fieldInfo.GetCustomAttributes()); break; case PropertyInfo propertyInfo: - GetVectorAndItemType(propertyInfo.PropertyType, propertyInfo.Name, out isVector, out itemType, propertyInfo.GetCustomAttributes().ToArray()); + GetVectorAndItemType(propertyInfo.PropertyType, propertyInfo.Name, out isVector, out itemType, propertyInfo.GetCustomAttributes()); break; default: @@ -172,7 +172,7 @@ public static void GetVectorAndItemType(MemberInfo memberInfo, out bool isVector /// The corresponding RawType of the type, or items of this type if vector. /// /// Attribute of . - public static void GetVectorAndItemType(Type rawType, string name, out bool isVector, out Type itemType, params Attribute[] attributes) + public static void GetVectorAndItemType(Type rawType, string name, out bool isVector, out Type itemType, IEnumerable attributes=null) { // Determine whether this is a vector, and also determine the raw item type. isVector = true; From 264f60fa8b85a3dae31d7b69ff143022dbde4912 Mon Sep 17 00:00:00 2001 From: Wei-Sheng Chin Date: Thu, 23 May 2019 09:11:17 -0700 Subject: [PATCH 14/24] Allow auto-register --- .../Data/SchemaDefinition.cs | 10 +++ src/Microsoft.ML.DataView/DataViewType.cs | 12 ++++ src/Microsoft.ML.DataView/TypeManager.cs | 4 +- src/Microsoft.ML.ImageAnalytics/ImageType.cs | 8 ++- .../UnitTests/TestCustomTypeRegister.cs | 72 +++++++++---------- test/Microsoft.ML.Tests/ImagesTests.cs | 5 -- 6 files changed, 65 insertions(+), 46 deletions(-) diff --git a/src/Microsoft.ML.Data/Data/SchemaDefinition.cs b/src/Microsoft.ML.Data/Data/SchemaDefinition.cs index 31a35529bb..1b13c3fcc0 100644 --- a/src/Microsoft.ML.Data/Data/SchemaDefinition.cs +++ b/src/Microsoft.ML.Data/Data/SchemaDefinition.cs @@ -382,6 +382,16 @@ public static SchemaDefinition Create(Type userType, Direction direction = Direc if (memberInfo.GetCustomAttribute() != null) continue; + var customTypeAttributes = memberInfo.GetCustomAttributes().Where(x => x is DataViewTypeAttribute); + if (customTypeAttributes.Count() > 1) + throw Contracts.ExceptParam(nameof(userType), "Member {0} cannot be marked with multiple attributes, {1}, derived from {2}.", + memberInfo.Name, customTypeAttributes, typeof(DataViewTypeAttribute)); + else if (customTypeAttributes.Count() == 1) + { + var customTypeAttribute = (DataViewTypeAttribute)customTypeAttributes.First(); + customTypeAttribute.Register(); + } + var mappingNameAttr = memberInfo.GetCustomAttribute(); string name = mappingNameAttr?.Name ?? memberInfo.Name; // Disallow duplicate names, because the field enumeration order is not actually diff --git a/src/Microsoft.ML.DataView/DataViewType.cs b/src/Microsoft.ML.DataView/DataViewType.cs index 153ba02261..c0e527ea6a 100644 --- a/src/Microsoft.ML.DataView/DataViewType.cs +++ b/src/Microsoft.ML.DataView/DataViewType.cs @@ -461,4 +461,16 @@ public override bool Equals(DataViewType other) public override string ToString() => "TimeSpan"; } + + public abstract class DataViewTypeAttribute : Attribute + { + protected DataViewTypeAttribute() : base() + { + } + + /// + /// A function implicitly invoked by ML.NET when processing a custom type. It binds a DataViewType to a custome type plus its attributes. + /// + public abstract void Register(); + } } \ No newline at end of file diff --git a/src/Microsoft.ML.DataView/TypeManager.cs b/src/Microsoft.ML.DataView/TypeManager.cs index 8a895f97aa..fbc1c9731c 100644 --- a/src/Microsoft.ML.DataView/TypeManager.cs +++ b/src/Microsoft.ML.DataView/TypeManager.cs @@ -76,7 +76,7 @@ private static int ComputeHashCode(Type rawType, IEnumerable rawTypeA /// /// Returns the registered for and its . /// - public static DataViewType GetDataViewType(Type rawType, IEnumerable rawTypeAttributes) + public static DataViewType GetDataViewType(Type rawType, IEnumerable rawTypeAttributes = null) { // Overall flow: // type (Type) + attrs ----> type ID ----------------> associated DataViewType's ID ----------------> DataViewType @@ -99,7 +99,7 @@ public static DataViewType GetDataViewType(Type rawType, IEnumerable /// If has been registered with a , this function returns . /// Otherwise, this function returns . /// - public static bool Knows(Type rawType, IEnumerable rawTypeAttributes) + public static bool Knows(Type rawType, IEnumerable rawTypeAttributes = null) { lock (_lock) { diff --git a/src/Microsoft.ML.ImageAnalytics/ImageType.cs b/src/Microsoft.ML.ImageAnalytics/ImageType.cs index 709b2edc7f..ab2e2f9ca5 100644 --- a/src/Microsoft.ML.ImageAnalytics/ImageType.cs +++ b/src/Microsoft.ML.ImageAnalytics/ImageType.cs @@ -15,7 +15,7 @@ namespace Microsoft.ML.Transforms.Image /// the shape of an image field. /// [AttributeUsage(AttributeTargets.Field | AttributeTargets.Property, AllowMultiple = false, Inherited = true)] - public sealed class ImageTypeAttribute : Attribute + public sealed class ImageTypeAttribute : DataViewTypeAttribute { /// /// The height of the image type. @@ -63,6 +63,11 @@ public override bool Equals(object other) else return false; } + + public override void Register() + { + DataViewTypeManager.Register(new ImageDataViewType(Height, Width), typeof(Bitmap), new[] { this }); + } } public sealed class ImageDataViewType : StructuredDataViewType @@ -72,7 +77,6 @@ public sealed class ImageDataViewType : StructuredDataViewType static ImageDataViewType() { - DataViewTypeManager.Register(new ImageDataViewType(), typeof(Bitmap)); } public ImageDataViewType(int height, int width) diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestCustomTypeRegister.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestCustomTypeRegister.cs index 9c21962715..a3498fc098 100644 --- a/test/Microsoft.ML.Core.Tests/UnitTests/TestCustomTypeRegister.cs +++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestCustomTypeRegister.cs @@ -39,7 +39,7 @@ public AlienBody(int age, float height, float weight, int handCount) } [AttributeUsage(AttributeTargets.Field | AttributeTargets.Property, AllowMultiple = false, Inherited = true)] - private sealed class AlienTypeAttributeAttribute : Attribute + private sealed class AlienTypeAttributeAttribute : DataViewTypeAttribute { public int Id { get; } @@ -50,6 +50,14 @@ public AlienTypeAttributeAttribute(int id) { Id = id; } + + /// + /// A function implicitly invoked by ML.NET when processing a custom type. It binds a DataViewType to a custome type plus its attributes. + /// + public override void Register() + { + DataViewTypeManager.Register(new DataViewAlienBodyType(Id), typeof(AlienBody), new[] { this }); + } } /// @@ -154,46 +162,14 @@ public override Action GetMapping() [Fact] public void RegisterTypeWithAttribute() { + // Build in-memory data. var tribe = new List() { new AlienHero("ML.NET", 2, 1000, 2000, 3000, 4000, 5000, 6000, 7000) }; - // Type manager doesn't know any of those custom types, so all calls to it should return false. - Assert.False(DataViewTypeManager.Knows(new DataViewAlienBodyType(100))); - Assert.False(DataViewTypeManager.Knows(new DataViewAlienBodyType(200))); - Assert.False(DataViewTypeManager.Knows(new DataViewAlienBodyType(007))); - Assert.False(DataViewTypeManager.Knows(typeof(AlienBody), new[] { new AlienTypeAttributeAttribute(100) })); - Assert.False(DataViewTypeManager.Knows(typeof(AlienBody), new[] { new AlienTypeAttributeAttribute(200) })); - Assert.False(DataViewTypeManager.Knows(typeof(AlienBody), new[] { new AlienTypeAttributeAttribute(007) })); - - // Register those custom types. - DataViewTypeManager.Register(new DataViewAlienBodyType(100), typeof(AlienBody), new[] { new AlienTypeAttributeAttribute(100) }); - DataViewTypeManager.Register(new DataViewAlienBodyType(200), typeof(AlienBody), new[] { new AlienTypeAttributeAttribute(200) }); - DataViewTypeManager.Register(new DataViewAlienBodyType(007), typeof(AlienBody), new[] { new AlienTypeAttributeAttribute(007) }); - - // Type manager now knows those custom types, so all calls to it should return true. - Assert.True(DataViewTypeManager.Knows(new DataViewAlienBodyType(100))); - Assert.True(DataViewTypeManager.Knows(new DataViewAlienBodyType(200))); - Assert.True(DataViewTypeManager.Knows(new DataViewAlienBodyType(007))); - Assert.True(DataViewTypeManager.Knows(typeof(AlienBody), new[] { new AlienTypeAttributeAttribute(100) })); - Assert.True(DataViewTypeManager.Knows(typeof(AlienBody), new[] { new AlienTypeAttributeAttribute(200) })); - Assert.True(DataViewTypeManager.Knows(typeof(AlienBody), new[] { new AlienTypeAttributeAttribute(007) })); - - // Check if the custom type (AlienBody with its attributes) is registered correctly with a DataView type (DataViewAlienBodyType). - Assert.Equal(new DataViewAlienBodyType(100), - DataViewTypeManager.GetDataViewType(typeof(AlienBody), new[] { new AlienTypeAttributeAttribute(100) })); - Assert.Equal(new DataViewAlienBodyType(200), - DataViewTypeManager.GetDataViewType(typeof(AlienBody), new[] { new AlienTypeAttributeAttribute(200) })); - Assert.Equal(new DataViewAlienBodyType(007), - DataViewTypeManager.GetDataViewType(typeof(AlienBody), new[] { new AlienTypeAttributeAttribute(007) })); - // Build a ML.NET pipeline and make prediction. var tribeDataView = ML.Data.LoadFromEnumerable(tribe); - var heroEstimator = new CustomMappingEstimator(ML, AlienLambda.MergeBody, "LambdaAlienHero"); - var model = heroEstimator.Fit(tribeDataView); - var tribeTransformed = model.Transform(tribeDataView); - var tribeEnumerable = ML.Data.CreateEnumerable(tribeTransformed, false).ToList(); // Make sure the pipeline output is correct. @@ -236,8 +212,18 @@ public void TestTypeManager() Assert.Equal(cCode, dCode); // Check registering the same type pair is OK. + // Note that "a" and "b" should be identical. DataViewTypeManager.Register(a, typeof(AlienBody)); DataViewTypeManager.Register(a, typeof(AlienBody)); + DataViewTypeManager.Register(b, typeof(AlienBody)); + DataViewTypeManager.Register(b, typeof(AlienBody)); + + // Check if register of (a, typeof(AlienBody)) successes. + Assert.True(DataViewTypeManager.Knows(a)); + Assert.True(DataViewTypeManager.Knows(b)); + Assert.True(DataViewTypeManager.Knows(typeof(AlienBody))); + Assert.Equal(a, DataViewTypeManager.GetDataViewType(typeof(AlienBody))); + Assert.Equal(b, DataViewTypeManager.GetDataViewType(typeof(AlienBody))); // Make sure registering the same type twice throws. bool isWrong = false; @@ -254,7 +240,7 @@ public void TestTypeManager() Assert.True(isWrong); // Make sure registering the same type twice throws. - bool isWrongAgain = false; + isWrong = false; try { // AlienBody has been registered with "a," so user can't register it with @@ -263,9 +249,21 @@ public void TestTypeManager() } catch { - isWrongAgain = true; + isWrong = true; } - Assert.True(isWrongAgain); + Assert.True(isWrong); + + // Register a type with attribute. + var e = new DataViewAlienBodyType(7788); + var f = new AlienTypeAttributeAttribute(8877); + DataViewTypeManager.Register(e, typeof(AlienBody), new[] { f }); + Assert.True(DataViewTypeManager.Knows(e)); + Assert.True(DataViewTypeManager.Knows(typeof(AlienBody), new[] { f })); + Assert.True(DataViewTypeManager.Knows(typeof(AlienBody), new[] { f })); + // "e" is associated with typeof(AlienBody) with "f," so the call below should return true. + Assert.Equal(e, DataViewTypeManager.GetDataViewType(typeof(AlienBody), new[] { f })); + // "a" is associated with typeof(AlienBody) without any attribute, so the call below should return false. + Assert.NotEqual(a, DataViewTypeManager.GetDataViewType(typeof(AlienBody), new[] { f })); } } } diff --git a/test/Microsoft.ML.Tests/ImagesTests.cs b/test/Microsoft.ML.Tests/ImagesTests.cs index 8fd12b950c..77a913fd5c 100644 --- a/test/Microsoft.ML.Tests/ImagesTests.cs +++ b/test/Microsoft.ML.Tests/ImagesTests.cs @@ -252,11 +252,6 @@ private class ImageDataPoint [ImageType(10, 10)] public Bitmap GrayImage { get; set; } - static ImageDataPoint() - { - DataViewTypeManager.Register(new ImageDataViewType(10, 10), typeof(Bitmap), new[] { new ImageTypeAttribute(10, 10) }); - } - public ImageDataPoint() { Image = null; From 7370a478bf2451d11a995b257e860d367cc43b29 Mon Sep 17 00:00:00 2001 From: Wei-Sheng Chin Date: Thu, 23 May 2019 13:00:51 -0700 Subject: [PATCH 15/24] Make key unique --- src/Microsoft.ML.DataView/DataViewType.cs | 29 +++- src/Microsoft.ML.DataView/TypeManager.cs | 127 +++++++++++++++--- src/Microsoft.ML.DataView/VectorType.cs | 7 +- src/Microsoft.ML.ImageAnalytics/ImageType.cs | 7 +- .../UnitTests/TestCustomTypeRegister.cs | 2 +- 5 files changed, 134 insertions(+), 38 deletions(-) diff --git a/src/Microsoft.ML.DataView/DataViewType.cs b/src/Microsoft.ML.DataView/DataViewType.cs index c0e527ea6a..e491daf2dd 100644 --- a/src/Microsoft.ML.DataView/DataViewType.cs +++ b/src/Microsoft.ML.DataView/DataViewType.cs @@ -37,11 +37,25 @@ private protected DataViewType(Type rawType) /// public Type RawType { get; } - // IEquatable interface recommends also to override base class implementations of - // Object.Equals(Object) and GetHashCode. In classes below where Equals(ColumnType other) - // is effectively a referencial comparison, there is no need to override base class implementations - // of Object.Equals(Object) (and GetHashCode) since its also a referencial comparison. + /// + /// Return if is equivalent to and otherwise. + /// + /// Another to be compared with . public abstract bool Equals(DataViewType other); + + /// + /// Produce the hashing code of . It's the implementation of . + /// + public abstract int GetDataViewTypeHashCode(); + + public override bool Equals(object obj) + { + if (obj is DataViewType) + return Equals((DataViewType)obj); + return false; + } + + public override int GetHashCode() => GetDataViewTypeHashCode(); } /// @@ -76,6 +90,11 @@ protected PrimitiveDataViewType(Type rawType) if (typeof(IDisposable).GetTypeInfo().IsAssignableFrom(RawType.GetTypeInfo())) throw new ArgumentException("A " + nameof(PrimitiveDataViewType) + " cannot have a disposable " + nameof(RawType), nameof(rawType)); } + + /// + /// All primitive s are singltons, so we only need one hash code for each of them. + /// + public override int GetDataViewTypeHashCode() => 0; } /// @@ -469,7 +488,7 @@ protected DataViewTypeAttribute() : base() } /// - /// A function implicitly invoked by ML.NET when processing a custom type. It binds a DataViewType to a custome type plus its attributes. + /// A function implicitly invoked by ML.NET when processing a custom type. It binds a DataViewType to a custom type plus its attributes. /// public abstract void Register(); } diff --git a/src/Microsoft.ML.DataView/TypeManager.cs b/src/Microsoft.ML.DataView/TypeManager.cs index fbc1c9731c..078cdb8521 100644 --- a/src/Microsoft.ML.DataView/TypeManager.cs +++ b/src/Microsoft.ML.DataView/TypeManager.cs @@ -4,6 +4,7 @@ using System; using System.Collections.Generic; +using System.Linq; using System.Reflection; using Microsoft.ML.Internal.DataView; @@ -32,31 +33,25 @@ public static class DataViewTypeManager /// /// Mapping from ID to a . The ID is the ID of in ML.NET's type system. /// - private static Dictionary _idToTypeMap = new Dictionary(); + private static Dictionary _idToTypeMap = new Dictionary(); /// /// Mapping from ID to a instance. The ID is the ID of instance in ML.NET's type system. /// - private static Dictionary _idToDataViewTypeMap = new Dictionary(); + private static Dictionary _idToDataViewTypeMap = new Dictionary(); /// /// Mapping from hashing ID of a and its s to hashing ID of a . /// - private static Dictionary _typeIdToDataViewTypeIdMap = new Dictionary(); + private static Dictionary _typeIdToDataViewTypeIdMap = new Dictionary(); /// /// Mapping from hashing ID of a to hashing ID of a and its s. /// - private static Dictionary _dataViewTypeIdToTypeIdMap = new Dictionary(); + private static Dictionary _dataViewTypeIdToTypeIdMap = new Dictionary(); private static object _lock = new object(); - /// - /// This function computes a hashing ID from and attributes attached to it. - /// If a type is defined as a member in a , can be obtained by calling - /// . - /// - /// private static int ComputeHashCode(Type rawType, IEnumerable rawTypeAttributes) { if (rawTypeAttributes == null) @@ -84,12 +79,14 @@ public static DataViewType GetDataViewType(Type rawType, IEnumerable lock (_lock) { // Compute the ID of type with extra attributes. - var typeId = ComputeHashCode(rawType, rawTypeAttributes); + var typeId = new TypeWithAttributesId(rawType, rawTypeAttributes); // Get the DataViewType's ID which typeID is mapped into. - if (!_typeIdToDataViewTypeIdMap.TryGetValue(typeId, out int dataViewTypeId)) + if (!_typeIdToDataViewTypeIdMap.TryGetValue(typeId, out DataViewTypeId dataViewTypeId)) throw Contracts.ExceptParam(nameof(rawType), $"The raw type {rawType} with attributes {rawTypeAttributes} is not registered with a DataView type."); + var x = _idToDataViewTypeMap.Keys.First(); + var y = x.Equals(dataViewTypeId); // Retrieve the actual DataViewType identified by dataViewTypeId. return _idToDataViewTypeMap[dataViewTypeId]; } @@ -104,7 +101,7 @@ public static bool Knows(Type rawType, IEnumerable rawTypeAttributes lock (_lock) { // Compute the ID of type with extra attributes. - var typeId = ComputeHashCode(rawType, rawTypeAttributes); + var typeId = new TypeWithAttributesId(rawType, rawTypeAttributes); // Check if this ID has been associated with a DataViewType. // Note that the dictionary below contains (typeId, type) pairs (key is typeId, and value is type). @@ -124,7 +121,7 @@ public static bool Knows(DataViewType dataViewType) lock (_lock) { // Compute the ID of the input DataViewType. - var dataViewTypeId = ComputeHashCode(dataViewType); + var dataViewTypeId = new DataViewTypeId(dataViewType); // Check if this the ID has been associated with a DataViewType. // Note that the dictionary below contains (dataViewTypeId, type) pairs (key is dataViewTypeId, and value is type). @@ -150,15 +147,15 @@ public static void Register(DataViewType dataViewType, Type rawType, IEnumerable throw Contracts.ExceptParam(nameof(rawType), $"Type {rawType} has been registered as ML.NET's default supported type, " + $"so it can't not be registered again."); - int rawTypeId = ComputeHashCode(rawType, rawTypeAttributes); - int dataViewTypeId = ComputeHashCode(dataViewType); + var rawTypeId = new TypeWithAttributesId(rawType, rawTypeAttributes); + var dataViewTypeId = new DataViewTypeId(dataViewType); - if (_typeIdToDataViewTypeIdMap.ContainsKey(rawTypeId) && _typeIdToDataViewTypeIdMap[rawTypeId] == dataViewTypeId && - _dataViewTypeIdToTypeIdMap.ContainsKey(dataViewTypeId) && _dataViewTypeIdToTypeIdMap[dataViewTypeId] == rawTypeId) + if (_typeIdToDataViewTypeIdMap.ContainsKey(rawTypeId) && _typeIdToDataViewTypeIdMap[rawTypeId].Equals(dataViewTypeId) && + _dataViewTypeIdToTypeIdMap.ContainsKey(dataViewTypeId) && _dataViewTypeIdToTypeIdMap[dataViewTypeId].Equals(rawTypeId)) // This type pair has been registered. Note that registering one data type pair multiple times is allowed. return; - if (_typeIdToDataViewTypeIdMap.ContainsKey(rawTypeId) && _typeIdToDataViewTypeIdMap[rawTypeId] != dataViewTypeId) + if (_typeIdToDataViewTypeIdMap.ContainsKey(rawTypeId) && !_typeIdToDataViewTypeIdMap[rawTypeId].Equals(dataViewTypeId)) { // There is a pair of (rawTypeId, anotherDataViewTypeId) in _typeIdToDataViewTypeId so we cannot register // (rawTypeId, dataViewTypeId) again. The assumption here is that one rawTypeId can only be associated @@ -168,7 +165,7 @@ public static void Register(DataViewType dataViewType, Type rawType, IEnumerable $"has been associated with {associatedDataViewType} so it cannot be associated with {dataViewType}."); } - if (_dataViewTypeIdToTypeIdMap.ContainsKey(dataViewTypeId) && _dataViewTypeIdToTypeIdMap[dataViewTypeId] != rawTypeId) + if (_dataViewTypeIdToTypeIdMap.ContainsKey(dataViewTypeId) && !_dataViewTypeIdToTypeIdMap[dataViewTypeId].Equals(rawTypeId)) { // There is a pair of (dataViewTypeId, anotherRawTypeId) in _dataViewTypeIdToTypeId so we cannot register // (dataViewTypeId, rawTypeId) again. The assumption here is that one dataViewTypeId can only be associated @@ -185,5 +182,95 @@ public static void Register(DataViewType dataViewType, Type rawType, IEnumerable _idToTypeMap.Add(rawTypeId, rawType); } } + + /// + /// An instance of represents an unique key of its and . + /// + private class TypeWithAttributesId + { + public Type TargetType { get; } + public IEnumerable AssociatedAttributes { get; } + + public TypeWithAttributesId(Type rawType, IEnumerable attributes) + { + TargetType = rawType; + AssociatedAttributes = attributes; + } + + /// + /// This function computes a hashing ID from and attributes attached to it. + /// If a type is defined as a member in a , can be obtained by calling + /// . + /// + public override int GetHashCode() + { + if (AssociatedAttributes == null) + return TargetType.GetHashCode(); + + var code = TargetType.GetHashCode(); + foreach (var attr in AssociatedAttributes) + code = Hashing.CombineHash(code, attr.GetHashCode()); + return code; + } + + public override bool Equals(object obj) + { + if (obj is TypeWithAttributesId) + { + var other = (TypeWithAttributesId)obj; + var sameType = TargetType.Equals(other.TargetType); + + var sameAttributeConfig = true; + + if (AssociatedAttributes == null && other.AssociatedAttributes == null) + sameAttributeConfig = true; + else if (AssociatedAttributes == null && other.AssociatedAttributes != null) + sameAttributeConfig = false; + else if (AssociatedAttributes != null && other.AssociatedAttributes == null) + sameAttributeConfig = false; + else + { + var zipped = AssociatedAttributes.Zip(other.AssociatedAttributes, (attr, otherAttr) => (attr, otherAttr)); + foreach (var (attr, otherAttr) in zipped) + { + if (!attr.Equals(otherAttr)) + sameAttributeConfig = false; + } + } + + return sameType && sameAttributeConfig; + } + return false; + } + } + + /// + /// An instance of represents an unique key of its . + /// + private class DataViewTypeId + { + public DataViewType TargetType { get; } + + public DataViewTypeId(DataViewType type) + { + TargetType = type; + } + + public override bool Equals(object obj) + { + if (obj is DataViewTypeId) + { + var other = (DataViewTypeId)obj; + return TargetType.Equals(other.TargetType); + } + + return false; + } + + public override int GetHashCode() + { + return TargetType.GetHashCode(); + } + } } } diff --git a/src/Microsoft.ML.DataView/VectorType.cs b/src/Microsoft.ML.DataView/VectorType.cs index 8402b27661..1052cf9446 100644 --- a/src/Microsoft.ML.DataView/VectorType.cs +++ b/src/Microsoft.ML.DataView/VectorType.cs @@ -138,12 +138,7 @@ public override bool Equals(DataViewType other) return true; } - public override bool Equals(object other) - { - return other is DataViewType tmp && Equals(tmp); - } - - public override int GetHashCode() + public override int GetDataViewTypeHashCode() { int hash = Hashing.CombineHash(ItemType.GetHashCode(), Size); hash = Hashing.CombineHash(hash, Dimensions.Length); diff --git a/src/Microsoft.ML.ImageAnalytics/ImageType.cs b/src/Microsoft.ML.ImageAnalytics/ImageType.cs index ab2e2f9ca5..56794d6f62 100644 --- a/src/Microsoft.ML.ImageAnalytics/ImageType.cs +++ b/src/Microsoft.ML.ImageAnalytics/ImageType.cs @@ -105,12 +105,7 @@ public override bool Equals(DataViewType other) return Width == tmp.Width; } - public override bool Equals(object other) - { - return other is DataViewType tmp && Equals(tmp); - } - - public override int GetHashCode() + public override int GetDataViewTypeHashCode() { return Hashing.CombineHash(Height.GetHashCode(), Width.GetHashCode()); } diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestCustomTypeRegister.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestCustomTypeRegister.cs index a3498fc098..6475807183 100644 --- a/test/Microsoft.ML.Core.Tests/UnitTests/TestCustomTypeRegister.cs +++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestCustomTypeRegister.cs @@ -114,7 +114,7 @@ public override bool Equals(DataViewType other) return false; } - public override int GetHashCode() + public override int GetDataViewTypeHashCode() { return Id.GetHashCode(); } From 977c422d4512799398b1d0d52d63336fdcace3cb Mon Sep 17 00:00:00 2001 From: Wei-Sheng Chin Date: Thu, 23 May 2019 15:51:35 -0700 Subject: [PATCH 16/24] Address comments --- .../Data/SchemaDefinition.cs | 7 ++-- src/Microsoft.ML.Data/Utils/ApiUtils.cs | 3 +- src/Microsoft.ML.DataView/DataViewType.cs | 38 +++++++++++++++++-- ...{TypeManager.cs => DataViewTypeManager.cs} | 2 - src/Microsoft.ML.ImageAnalytics/ImageType.cs | 22 ++++------- .../UnitTests/TestCustomTypeRegister.cs | 23 +++++++---- 6 files changed, 64 insertions(+), 31 deletions(-) rename src/Microsoft.ML.DataView/{TypeManager.cs => DataViewTypeManager.cs} (99%) diff --git a/src/Microsoft.ML.Data/Data/SchemaDefinition.cs b/src/Microsoft.ML.Data/Data/SchemaDefinition.cs index 1b13c3fcc0..157d35d940 100644 --- a/src/Microsoft.ML.Data/Data/SchemaDefinition.cs +++ b/src/Microsoft.ML.Data/Data/SchemaDefinition.cs @@ -382,7 +382,8 @@ public static SchemaDefinition Create(Type userType, Direction direction = Direc if (memberInfo.GetCustomAttribute() != null) continue; - var customTypeAttributes = memberInfo.GetCustomAttributes().Where(x => x is DataViewTypeAttribute); + var customAttributes = memberInfo.GetCustomAttributes(); + var customTypeAttributes = customAttributes.Where(x => x is DataViewTypeAttribute); if (customTypeAttributes.Count() > 1) throw Contracts.ExceptParam(nameof(userType), "Member {0} cannot be marked with multiple attributes, {1}, derived from {2}.", memberInfo.Name, customTypeAttributes, typeof(DataViewTypeAttribute)); @@ -404,7 +405,7 @@ public static SchemaDefinition Create(Type userType, Direction direction = Direc // Get the column type. DataViewType columnType; - if (!DataViewTypeManager.Knows(dataType, memberInfo.GetCustomAttributes())) + if (!DataViewTypeManager.Knows(dataType, customAttributes)) { PrimitiveDataViewType itemType; var keyAttr = memberInfo.GetCustomAttribute(); @@ -437,7 +438,7 @@ public static SchemaDefinition Create(Type userType, Direction direction = Direc columnType = itemType; } else - columnType = DataViewTypeManager.GetDataViewType(dataType, memberInfo.GetCustomAttributes()); + columnType = DataViewTypeManager.GetDataViewType(dataType, customAttributes); cols.Add(new Column(memberInfo.Name, columnType, name)); } diff --git a/src/Microsoft.ML.Data/Utils/ApiUtils.cs b/src/Microsoft.ML.Data/Utils/ApiUtils.cs index c4c73c60db..5e53824019 100644 --- a/src/Microsoft.ML.Data/Utils/ApiUtils.cs +++ b/src/Microsoft.ML.Data/Utils/ApiUtils.cs @@ -3,6 +3,7 @@ // See the LICENSE file in the project root for more information. using System; +using System.Collections.Generic; using System.Linq; using System.Reflection; using System.Reflection.Emit; @@ -17,7 +18,7 @@ namespace Microsoft.ML internal static class ApiUtils { - private static OpCode GetAssignmentOpCode(Type t, params Attribute[] attributes) + private static OpCode GetAssignmentOpCode(Type t, IEnumerable attributes) { // REVIEW: This should be a Dictionary based solution. // DvTypes, strings, arrays, all nullable types, VBuffers and RowId. diff --git a/src/Microsoft.ML.DataView/DataViewType.cs b/src/Microsoft.ML.DataView/DataViewType.cs index e491daf2dd..fd5c619fcb 100644 --- a/src/Microsoft.ML.DataView/DataViewType.cs +++ b/src/Microsoft.ML.DataView/DataViewType.cs @@ -481,15 +481,47 @@ public override bool Equals(DataViewType other) public override string ToString() => "TimeSpan"; } + /// + /// should be used to decorated class properties and fields, if that class' instances will be loaded as ML.NET . + /// The function will be called to register a for a with its s. + /// Whenever a value typed to the registered and its s, its would be the + /// associated . + /// public abstract class DataViewTypeAttribute : Attribute { - protected DataViewTypeAttribute() : base() + /// + /// A function implicitly invoked by ML.NET when processing a custom type. It binds a DataViewType to a custom type plus its attributes. + /// + public abstract void Register(); + + /// + /// Return if is equivalent to and otherwise. + /// + /// Another to be compared with . + public abstract bool Equals(DataViewTypeAttribute other); + + /// + /// Produce the hashing code of . It's the implementation of . + /// + public abstract int GetDataViewTypeAttributeHashCode(); + + /// + /// Return if is equivalent to and otherwise. + /// Derived classes should implement their comparison logics by overriding . + /// + /// An to be compared with . + public sealed override bool Equals(object obj) { + if (obj is DataViewTypeAttribute) + return Equals((DataViewTypeAttribute)obj); + return false; } /// - /// A function implicitly invoked by ML.NET when processing a custom type. It binds a DataViewType to a custom type plus its attributes. + /// Returns hash code of . + /// Derived classes override to implement their own hashing algorithm. Notice that equivalent attributes should + /// produce the same hash code. /// - public abstract void Register(); + public sealed override int GetHashCode() => GetDataViewTypeAttributeHashCode(); } } \ No newline at end of file diff --git a/src/Microsoft.ML.DataView/TypeManager.cs b/src/Microsoft.ML.DataView/DataViewTypeManager.cs similarity index 99% rename from src/Microsoft.ML.DataView/TypeManager.cs rename to src/Microsoft.ML.DataView/DataViewTypeManager.cs index 078cdb8521..75c004ebe7 100644 --- a/src/Microsoft.ML.DataView/TypeManager.cs +++ b/src/Microsoft.ML.DataView/DataViewTypeManager.cs @@ -85,8 +85,6 @@ public static DataViewType GetDataViewType(Type rawType, IEnumerable if (!_typeIdToDataViewTypeIdMap.TryGetValue(typeId, out DataViewTypeId dataViewTypeId)) throw Contracts.ExceptParam(nameof(rawType), $"The raw type {rawType} with attributes {rawTypeAttributes} is not registered with a DataView type."); - var x = _idToDataViewTypeMap.Keys.First(); - var y = x.Equals(dataViewTypeId); // Retrieve the actual DataViewType identified by dataViewTypeId. return _idToDataViewTypeMap[dataViewTypeId]; } diff --git a/src/Microsoft.ML.ImageAnalytics/ImageType.cs b/src/Microsoft.ML.ImageAnalytics/ImageType.cs index 56794d6f62..e34e934ad9 100644 --- a/src/Microsoft.ML.ImageAnalytics/ImageType.cs +++ b/src/Microsoft.ML.ImageAnalytics/ImageType.cs @@ -46,22 +46,18 @@ public ImageTypeAttribute(int height, int width) } /// - /// Images with the same width and height should produce the same hash code. + /// Images with the same width and height should equal. /// - public override int GetHashCode() + public override bool Equals(DataViewTypeAttribute other) { - return Hashing.CombineHash(Height.GetHashCode(), Width.GetHashCode()); + if (other is ImageTypeAttribute) + return Height == ((ImageTypeAttribute)other).Height && Width == ((ImageTypeAttribute)other).Width; + return false; } - /// - /// Images with the same width and height should equal. - /// - public override bool Equals(object other) + public override int GetDataViewTypeAttributeHashCode() { - if (other is ImageDataViewType) - return Height == ((ImageDataViewType)other).Height && Width == ((ImageDataViewType)other).Width; - else - return false; + return Hashing.CombineHash(Height.GetHashCode(), Width.GetHashCode()); } public override void Register() @@ -75,10 +71,6 @@ public sealed class ImageDataViewType : StructuredDataViewType public readonly int Height; public readonly int Width; - static ImageDataViewType() - { - } - public ImageDataViewType(int height, int width) : base(typeof(Bitmap)) { diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestCustomTypeRegister.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestCustomTypeRegister.cs index 6475807183..f5080a6c92 100644 --- a/test/Microsoft.ML.Core.Tests/UnitTests/TestCustomTypeRegister.cs +++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestCustomTypeRegister.cs @@ -41,14 +41,14 @@ public AlienBody(int age, float height, float weight, int handCount) [AttributeUsage(AttributeTargets.Field | AttributeTargets.Property, AllowMultiple = false, Inherited = true)] private sealed class AlienTypeAttributeAttribute : DataViewTypeAttribute { - public int Id { get; } + public int RaceId { get; } /// /// Create an image type with known height and width. /// public AlienTypeAttributeAttribute(int id) { - Id = id; + RaceId = id; } /// @@ -56,8 +56,17 @@ public AlienTypeAttributeAttribute(int id) /// public override void Register() { - DataViewTypeManager.Register(new DataViewAlienBodyType(Id), typeof(AlienBody), new[] { this }); + DataViewTypeManager.Register(new DataViewAlienBodyType(RaceId), typeof(AlienBody), new[] { this }); } + + public override bool Equals(DataViewTypeAttribute other) + { + if (other is AlienTypeAttributeAttribute) + return RaceId == ((AlienTypeAttributeAttribute)other).RaceId; + return false; + } + + public override int GetDataViewTypeAttributeHashCode() => RaceId.GetHashCode(); } /// @@ -99,24 +108,24 @@ public AlienHero(string name, /// private class DataViewAlienBodyType : StructuredDataViewType { - public int Id { get; } + public int RaceId { get; } public DataViewAlienBodyType(int id) : base(typeof(AlienBody)) { - Id = id; + RaceId = id; } public override bool Equals(DataViewType other) { if (other is DataViewAlienBodyType) - return ((DataViewAlienBodyType)other).Id == Id; + return ((DataViewAlienBodyType)other).RaceId == RaceId; else return false; } public override int GetDataViewTypeHashCode() { - return Id.GetHashCode(); + return RaceId.GetHashCode(); } } From c020196345c707f3beededddefe2f3b2e257e23e Mon Sep 17 00:00:00 2001 From: Wei-Sheng Chin Date: Thu, 23 May 2019 16:20:40 -0700 Subject: [PATCH 17/24] Address some comments --- .../DataView/InternalSchemaDefinition.cs | 3 +- src/Microsoft.ML.Data/Utils/ApiUtils.cs | 8 +- src/Microsoft.ML.DataView/DataViewType.cs | 51 +-------- .../DataViewTypeManager.cs | 101 ++++++++---------- src/Microsoft.ML.DataView/VectorType.cs | 2 +- src/Microsoft.ML.ImageAnalytics/ImageType.cs | 11 +- .../UnitTests/TestCustomTypeRegister.cs | 4 +- 7 files changed, 64 insertions(+), 116 deletions(-) diff --git a/src/Microsoft.ML.Data/DataView/InternalSchemaDefinition.cs b/src/Microsoft.ML.Data/DataView/InternalSchemaDefinition.cs index 1aa2301033..dfc1d2e52e 100644 --- a/src/Microsoft.ML.Data/DataView/InternalSchemaDefinition.cs +++ b/src/Microsoft.ML.Data/DataView/InternalSchemaDefinition.cs @@ -189,10 +189,9 @@ public static void GetVectorAndItemType(Type rawType, string name, out bool isVe // The internal type of string is ReadOnlyMemory. That is, string will be stored as ReadOnlyMemory in IDataView. if (itemType == typeof(string)) itemType = typeof(ReadOnlyMemory); - // Check if the itemType extracted from rawType is supported by ML.NET's type system. // It must be one of either ML.NET's pre-defined types or custom types registered by the user. - if (!itemType.TryGetDataKind(out _) && !DataViewTypeManager.Knows(itemType, attributes)) + else if (!itemType.TryGetDataKind(out _) && !DataViewTypeManager.Knows(itemType, attributes)) throw Contracts.ExceptParam(nameof(rawType), "Could not determine an IDataView type for member {0}", name); } diff --git a/src/Microsoft.ML.Data/Utils/ApiUtils.cs b/src/Microsoft.ML.Data/Utils/ApiUtils.cs index 5e53824019..117e1981cb 100644 --- a/src/Microsoft.ML.Data/Utils/ApiUtils.cs +++ b/src/Microsoft.ML.Data/Utils/ApiUtils.cs @@ -59,7 +59,7 @@ internal static Delegate GeneratePeek(InternalSchemaDefinition.Colum case FieldInfo fieldInfo: Type fieldType = fieldInfo.FieldType; - var assignmentOpCode = GetAssignmentOpCode(fieldType, fieldInfo.GetCustomAttributes().ToArray()); + var assignmentOpCode = GetAssignmentOpCode(fieldType, fieldInfo.GetCustomAttributes()); Func func = GeneratePeek; var methInfo = func.GetMethodInfo().GetGenericMethodDefinition() .MakeGenericMethod(typeof(TOwn), typeof(TRow), fieldType); @@ -68,7 +68,7 @@ internal static Delegate GeneratePeek(InternalSchemaDefinition.Colum case PropertyInfo propertyInfo: Type propertyType = propertyInfo.PropertyType; - var assignmentOpCodeProp = GetAssignmentOpCode(propertyType, propertyInfo.GetCustomAttributes().ToArray()); + var assignmentOpCodeProp = GetAssignmentOpCode(propertyType, propertyInfo.GetCustomAttributes()); Func funcProp = GeneratePeek; var methInfoProp = funcProp.GetMethodInfo().GetGenericMethodDefinition() .MakeGenericMethod(typeof(TOwn), typeof(TRow), propertyType); @@ -135,7 +135,7 @@ internal static Delegate GeneratePoke(InternalSchemaDefinition.Colum case FieldInfo fieldInfo: Type fieldType = fieldInfo.FieldType; - var assignmentOpCode = GetAssignmentOpCode(fieldType, fieldInfo.GetCustomAttributes().ToArray()); + var assignmentOpCode = GetAssignmentOpCode(fieldType, fieldInfo.GetCustomAttributes()); Func func = GeneratePoke; var methInfo = func.GetMethodInfo().GetGenericMethodDefinition() .MakeGenericMethod(typeof(TOwn), typeof(TRow), fieldType); @@ -144,7 +144,7 @@ internal static Delegate GeneratePoke(InternalSchemaDefinition.Colum case PropertyInfo propertyInfo: Type propertyType = propertyInfo.PropertyType; - var assignmentOpCodeProp = GetAssignmentOpCode(propertyType, propertyInfo.GetCustomAttributes().ToArray()); + var assignmentOpCodeProp = GetAssignmentOpCode(propertyType, propertyInfo.GetCustomAttributes()); Func funcProp = GeneratePoke; var methInfoProp = funcProp.GetMethodInfo().GetGenericMethodDefinition() .MakeGenericMethod(typeof(TOwn), typeof(TRow), propertyType); diff --git a/src/Microsoft.ML.DataView/DataViewType.cs b/src/Microsoft.ML.DataView/DataViewType.cs index fd5c619fcb..d84fcfa03a 100644 --- a/src/Microsoft.ML.DataView/DataViewType.cs +++ b/src/Microsoft.ML.DataView/DataViewType.cs @@ -42,20 +42,6 @@ private protected DataViewType(Type rawType) /// /// Another to be compared with . public abstract bool Equals(DataViewType other); - - /// - /// Produce the hashing code of . It's the implementation of . - /// - public abstract int GetDataViewTypeHashCode(); - - public override bool Equals(object obj) - { - if (obj is DataViewType) - return Equals((DataViewType)obj); - return false; - } - - public override int GetHashCode() => GetDataViewTypeHashCode(); } /// @@ -90,11 +76,6 @@ protected PrimitiveDataViewType(Type rawType) if (typeof(IDisposable).GetTypeInfo().IsAssignableFrom(RawType.GetTypeInfo())) throw new ArgumentException("A " + nameof(PrimitiveDataViewType) + " cannot have a disposable " + nameof(RawType), nameof(rawType)); } - - /// - /// All primitive s are singltons, so we only need one hash code for each of them. - /// - public override int GetDataViewTypeHashCode() => 0; } /// @@ -484,10 +465,10 @@ public override bool Equals(DataViewType other) /// /// should be used to decorated class properties and fields, if that class' instances will be loaded as ML.NET . /// The function will be called to register a for a with its s. - /// Whenever a value typed to the registered and its s, its would be the - /// associated . + /// Whenever a value typed to the registered and its s, that value's type (i.e., a ) + /// in would be the associated . /// - public abstract class DataViewTypeAttribute : Attribute + public abstract class DataViewTypeAttribute : Attribute, IEquatable { /// /// A function implicitly invoked by ML.NET when processing a custom type. It binds a DataViewType to a custom type plus its attributes. @@ -497,31 +478,7 @@ public abstract class DataViewTypeAttribute : Attribute /// /// Return if is equivalent to and otherwise. /// - /// Another to be compared with . + /// Another to be compared with . public abstract bool Equals(DataViewTypeAttribute other); - - /// - /// Produce the hashing code of . It's the implementation of . - /// - public abstract int GetDataViewTypeAttributeHashCode(); - - /// - /// Return if is equivalent to and otherwise. - /// Derived classes should implement their comparison logics by overriding . - /// - /// An to be compared with . - public sealed override bool Equals(object obj) - { - if (obj is DataViewTypeAttribute) - return Equals((DataViewTypeAttribute)obj); - return false; - } - - /// - /// Returns hash code of . - /// Derived classes override to implement their own hashing algorithm. Notice that equivalent attributes should - /// produce the same hash code. - /// - public sealed override int GetHashCode() => GetDataViewTypeAttributeHashCode(); } } \ No newline at end of file diff --git a/src/Microsoft.ML.DataView/DataViewTypeManager.cs b/src/Microsoft.ML.DataView/DataViewTypeManager.cs index 75c004ebe7..74011f96d0 100644 --- a/src/Microsoft.ML.DataView/DataViewTypeManager.cs +++ b/src/Microsoft.ML.DataView/DataViewTypeManager.cs @@ -17,9 +17,11 @@ namespace Microsoft.ML.Data /// public static class DataViewTypeManager { - // Types have been used in ML.NET type systems. They can have multiple-to-one type mapping. - // For example, UInt32 and Key can be mapped to uint. This class enforces one-to-one mapping for all - // user-registered types. + /// + /// Types have been used in ML.NET type systems. They can have multiple-to-one type mapping. + /// For example, UInt32 and Key can be mapped to uint. This class enforces one-to-one mapping for all + /// user-registered types. + /// private static HashSet _bannedRawTypes = new HashSet() { typeof(Boolean), typeof(SByte), typeof(Byte), @@ -50,23 +52,10 @@ public static class DataViewTypeManager /// private static Dictionary _dataViewTypeIdToTypeIdMap = new Dictionary(); - private static object _lock = new object(); - - private static int ComputeHashCode(Type rawType, IEnumerable rawTypeAttributes) - { - if (rawTypeAttributes == null) - return rawType.GetHashCode(); - - var code = rawType.GetHashCode(); - foreach (var attr in rawTypeAttributes) - code = Hashing.CombineHash(code, attr.GetHashCode()); - return code; - } - /// - /// This function hashes a and its own hashing code together. + /// The lock that one should acquire if the state of will be accessed or modified. /// - private static int ComputeHashCode(DataViewType dataViewType) => Hashing.CombineHash(dataViewType.GetType().GetHashCode(), dataViewType.GetHashCode()); + private static object _lock = new object(); /// /// Returns the registered for and its . @@ -182,53 +171,39 @@ public static void Register(DataViewType dataViewType, Type rawType, IEnumerable } /// - /// An instance of represents an unique key of its and . + /// An instance of represents an unique key of its and . /// private class TypeWithAttributesId { - public Type TargetType { get; } - public IEnumerable AssociatedAttributes { get; } + private Type _targetType; + private IEnumerable _associatedAttributes; public TypeWithAttributesId(Type rawType, IEnumerable attributes) { - TargetType = rawType; - AssociatedAttributes = attributes; - } - - /// - /// This function computes a hashing ID from and attributes attached to it. - /// If a type is defined as a member in a , can be obtained by calling - /// . - /// - public override int GetHashCode() - { - if (AssociatedAttributes == null) - return TargetType.GetHashCode(); - - var code = TargetType.GetHashCode(); - foreach (var attr in AssociatedAttributes) - code = Hashing.CombineHash(code, attr.GetHashCode()); - return code; + _targetType = rawType; + _associatedAttributes = attributes; } public override bool Equals(object obj) { - if (obj is TypeWithAttributesId) + if (obj is TypeWithAttributesId other) { - var other = (TypeWithAttributesId)obj; - var sameType = TargetType.Equals(other.TargetType); - + // Flag of having the same type. + var sameType = _targetType.Equals(other._targetType); + // Flag of having the attribute configurations. var sameAttributeConfig = true; - if (AssociatedAttributes == null && other.AssociatedAttributes == null) + if (_associatedAttributes == null && other._associatedAttributes == null) sameAttributeConfig = true; - else if (AssociatedAttributes == null && other.AssociatedAttributes != null) + else if (_associatedAttributes == null && other._associatedAttributes != null) sameAttributeConfig = false; - else if (AssociatedAttributes != null && other.AssociatedAttributes == null) + else if (_associatedAttributes != null && other._associatedAttributes == null) + sameAttributeConfig = false; + else if (_associatedAttributes.Count() != other._associatedAttributes.Count()) sameAttributeConfig = false; else { - var zipped = AssociatedAttributes.Zip(other.AssociatedAttributes, (attr, otherAttr) => (attr, otherAttr)); + var zipped = _associatedAttributes.Zip(other._associatedAttributes, (attr, otherAttr) => (attr, otherAttr)); foreach (var (attr, otherAttr) in zipped) { if (!attr.Equals(otherAttr)) @@ -240,34 +215,48 @@ public override bool Equals(object obj) } return false; } + + /// + /// This function computes a hashing ID from and attributes attached to it. + /// If a type is defined as a member in a , can be obtained by calling + /// . + /// + public override int GetHashCode() + { + if (_associatedAttributes == null) + return _targetType.GetHashCode(); + + var code = _targetType.GetHashCode(); + foreach (var attr in _associatedAttributes) + code = Hashing.CombineHash(code, attr.GetHashCode()); + return code; + } + } /// - /// An instance of represents an unique key of its . + /// An instance of represents an unique key of its . /// private class DataViewTypeId { - public DataViewType TargetType { get; } + private DataViewType _targetType; public DataViewTypeId(DataViewType type) { - TargetType = type; + _targetType = type; } public override bool Equals(object obj) { - if (obj is DataViewTypeId) - { - var other = (DataViewTypeId)obj; - return TargetType.Equals(other.TargetType); - } + if (obj is DataViewTypeId other) + return _targetType.Equals(other._targetType); return false; } public override int GetHashCode() { - return TargetType.GetHashCode(); + return _targetType.GetHashCode(); } } } diff --git a/src/Microsoft.ML.DataView/VectorType.cs b/src/Microsoft.ML.DataView/VectorType.cs index 1052cf9446..ff2e707e63 100644 --- a/src/Microsoft.ML.DataView/VectorType.cs +++ b/src/Microsoft.ML.DataView/VectorType.cs @@ -138,7 +138,7 @@ public override bool Equals(DataViewType other) return true; } - public override int GetDataViewTypeHashCode() + public override int GetHashCode() { int hash = Hashing.CombineHash(ItemType.GetHashCode(), Size); hash = Hashing.CombineHash(hash, Dimensions.Length); diff --git a/src/Microsoft.ML.ImageAnalytics/ImageType.cs b/src/Microsoft.ML.ImageAnalytics/ImageType.cs index e34e934ad9..96a46a23c5 100644 --- a/src/Microsoft.ML.ImageAnalytics/ImageType.cs +++ b/src/Microsoft.ML.ImageAnalytics/ImageType.cs @@ -50,12 +50,15 @@ public ImageTypeAttribute(int height, int width) /// public override bool Equals(DataViewTypeAttribute other) { - if (other is ImageTypeAttribute) - return Height == ((ImageTypeAttribute)other).Height && Width == ((ImageTypeAttribute)other).Width; + if (other is ImageTypeAttribute otherImage) + return Height == otherImage.Height && Width == otherImage.Width; return false; } - public override int GetDataViewTypeAttributeHashCode() + /// + /// Produce the same hash code for all images with the same height and the same width. + /// + public override int GetHashCode() { return Hashing.CombineHash(Height.GetHashCode(), Width.GetHashCode()); } @@ -97,7 +100,7 @@ public override bool Equals(DataViewType other) return Width == tmp.Width; } - public override int GetDataViewTypeHashCode() + public override int GetHashCode() { return Hashing.CombineHash(Height.GetHashCode(), Width.GetHashCode()); } diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestCustomTypeRegister.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestCustomTypeRegister.cs index f5080a6c92..b9a005a271 100644 --- a/test/Microsoft.ML.Core.Tests/UnitTests/TestCustomTypeRegister.cs +++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestCustomTypeRegister.cs @@ -66,7 +66,7 @@ public override bool Equals(DataViewTypeAttribute other) return false; } - public override int GetDataViewTypeAttributeHashCode() => RaceId.GetHashCode(); + public override int GetHashCode() => RaceId.GetHashCode(); } /// @@ -123,7 +123,7 @@ public override bool Equals(DataViewType other) return false; } - public override int GetDataViewTypeHashCode() + public override int GetHashCode() { return RaceId.GetHashCode(); } From 5ab21f39fffe7530445d32686c62ee595778d2a7 Mon Sep 17 00:00:00 2001 From: Wei-Sheng Chin Date: Thu, 23 May 2019 17:12:33 -0700 Subject: [PATCH 18/24] Remove redundant dictionaries Polish --- .../DataView/DataViewConstructionUtils.cs | 2 +- .../DataView/InternalSchemaDefinition.cs | 14 +++--- src/Microsoft.ML.DataView/DataViewType.cs | 5 +++ .../DataViewTypeManager.cs | 45 +++++++------------ src/Microsoft.ML.DataView/VectorType.cs | 5 +++ src/Microsoft.ML.ImageAnalytics/ImageType.cs | 1 - 6 files changed, 34 insertions(+), 38 deletions(-) diff --git a/src/Microsoft.ML.Data/DataView/DataViewConstructionUtils.cs b/src/Microsoft.ML.Data/DataView/DataViewConstructionUtils.cs index 8e5e9ba578..6df350c20c 100644 --- a/src/Microsoft.ML.Data/DataView/DataViewConstructionUtils.cs +++ b/src/Microsoft.ML.Data/DataView/DataViewConstructionUtils.cs @@ -847,7 +847,7 @@ public AnnotationInfo(string kind, T value, DataViewType annotationType = null) Contracts.Assert(value != null); bool isVector; Type itemType; - InternalSchemaDefinition.GetVectorAndItemType(typeof(T), "annotation value", out isVector, out itemType); + InternalSchemaDefinition.GetVectorAndItemType("annotation value", typeof(T), null, out isVector, out itemType); if (annotationType == null) { diff --git a/src/Microsoft.ML.Data/DataView/InternalSchemaDefinition.cs b/src/Microsoft.ML.Data/DataView/InternalSchemaDefinition.cs index dfc1d2e52e..7e6981140f 100644 --- a/src/Microsoft.ML.Data/DataView/InternalSchemaDefinition.cs +++ b/src/Microsoft.ML.Data/DataView/InternalSchemaDefinition.cs @@ -119,7 +119,7 @@ public void AssertRep() Contracts.Assert(Generator.GetMethodInfo().ReturnType == typeof(void)); // Checks that the return type of the generator is compatible with ColumnType. - GetVectorAndItemType(ComputedReturnType, "return type", out bool isVector, out Type itemType); + GetVectorAndItemType("return type", ComputedReturnType, null, out bool isVector, out Type itemType); Contracts.Assert(isVector == ColumnType is VectorDataViewType); Contracts.Assert(itemType == ColumnType.GetItemType().RawType); } @@ -147,11 +147,11 @@ public static void GetVectorAndItemType(MemberInfo memberInfo, out bool isVector switch (memberInfo) { case FieldInfo fieldInfo: - GetVectorAndItemType(fieldInfo.FieldType, fieldInfo.Name, out isVector, out itemType, fieldInfo.GetCustomAttributes()); + GetVectorAndItemType(fieldInfo.Name, fieldInfo.FieldType, fieldInfo.GetCustomAttributes(), out isVector, out itemType); break; case PropertyInfo propertyInfo: - GetVectorAndItemType(propertyInfo.PropertyType, propertyInfo.Name, out isVector, out itemType, propertyInfo.GetCustomAttributes()); + GetVectorAndItemType(propertyInfo.Name, propertyInfo.PropertyType, propertyInfo.GetCustomAttributes(), out isVector, out itemType); break; default: @@ -165,14 +165,14 @@ public static void GetVectorAndItemType(MemberInfo memberInfo, out bool isVector /// and also the associated data type for this type. If a valid data type could not /// be determined, this will throw. /// - /// The type of the variable to inspect. /// The name of the variable to inspect. + /// The type of the variable to inspect. + /// Attribute of . It can be if attributes don't exist. /// Whether this appears to be a vector type. /// /// The corresponding RawType of the type, or items of this type if vector. /// - /// Attribute of . - public static void GetVectorAndItemType(Type rawType, string name, out bool isVector, out Type itemType, IEnumerable attributes=null) + public static void GetVectorAndItemType(string name, Type rawType, IEnumerable attributes, out bool isVector, out Type itemType) { // Determine whether this is a vector, and also determine the raw item type. isVector = true; @@ -246,7 +246,7 @@ public static InternalSchemaDefinition Create(Type userType, SchemaDefinition us var parameterType = col.ReturnType; if (parameterType == null) throw Contracts.ExceptParam(nameof(userSchemaDefinition), "No return parameter found in computed column."); - GetVectorAndItemType(parameterType, "returnType", out isVector, out dataItemType); + GetVectorAndItemType("returnType", parameterType, null, out isVector, out dataItemType); } // Infer the column name. var colName = string.IsNullOrEmpty(col.ColumnName) ? col.MemberName : col.ColumnName; diff --git a/src/Microsoft.ML.DataView/DataViewType.cs b/src/Microsoft.ML.DataView/DataViewType.cs index d84fcfa03a..9a12cfe981 100644 --- a/src/Microsoft.ML.DataView/DataViewType.cs +++ b/src/Microsoft.ML.DataView/DataViewType.cs @@ -37,6 +37,10 @@ private protected DataViewType(Type rawType) /// public Type RawType { get; } + // IEquatable interface recommends also to override base class implementations of + // Object.Equals(Object) and GetHashCode. In classes below where Equals(ColumnType other) + // is effectively a referencial comparison, there is no need to override base class implementations + // of Object.Equals(Object) (and GetHashCode) since its also a referencial comparison. /// /// Return if is equivalent to and otherwise. /// @@ -468,6 +472,7 @@ public override bool Equals(DataViewType other) /// Whenever a value typed to the registered and its s, that value's type (i.e., a ) /// in would be the associated . /// + [AttributeUsage(AttributeTargets.Field | AttributeTargets.Property, AllowMultiple = false, Inherited = true)] public abstract class DataViewTypeAttribute : Attribute, IEquatable { /// diff --git a/src/Microsoft.ML.DataView/DataViewTypeManager.cs b/src/Microsoft.ML.DataView/DataViewTypeManager.cs index 74011f96d0..a177d70963 100644 --- a/src/Microsoft.ML.DataView/DataViewTypeManager.cs +++ b/src/Microsoft.ML.DataView/DataViewTypeManager.cs @@ -32,16 +32,6 @@ public static class DataViewTypeManager typeof(TimeSpan), typeof(DataViewRowId) }; - /// - /// Mapping from ID to a . The ID is the ID of in ML.NET's type system. - /// - private static Dictionary _idToTypeMap = new Dictionary(); - - /// - /// Mapping from ID to a instance. The ID is the ID of instance in ML.NET's type system. - /// - private static Dictionary _idToDataViewTypeMap = new Dictionary(); - /// /// Mapping from hashing ID of a and its s to hashing ID of a . /// @@ -75,7 +65,7 @@ public static DataViewType GetDataViewType(Type rawType, IEnumerable throw Contracts.ExceptParam(nameof(rawType), $"The raw type {rawType} with attributes {rawTypeAttributes} is not registered with a DataView type."); // Retrieve the actual DataViewType identified by dataViewTypeId. - return _idToDataViewTypeMap[dataViewTypeId]; + return dataViewTypeId.TargetType; } } @@ -92,7 +82,7 @@ public static bool Knows(Type rawType, IEnumerable rawTypeAttributes // Check if this ID has been associated with a DataViewType. // Note that the dictionary below contains (typeId, type) pairs (key is typeId, and value is type). - if (_idToTypeMap.ContainsKey(typeId)) + if (_typeIdToDataViewTypeIdMap.ContainsKey(typeId)) return true; else return false; @@ -112,7 +102,7 @@ public static bool Knows(DataViewType dataViewType) // Check if this the ID has been associated with a DataViewType. // Note that the dictionary below contains (dataViewTypeId, type) pairs (key is dataViewTypeId, and value is type). - if (_idToDataViewTypeMap.ContainsKey(dataViewTypeId)) + if (_dataViewTypeIdToTypeIdMap.ContainsKey(dataViewTypeId)) return true; else return false; @@ -147,7 +137,7 @@ public static void Register(DataViewType dataViewType, Type rawType, IEnumerable // There is a pair of (rawTypeId, anotherDataViewTypeId) in _typeIdToDataViewTypeId so we cannot register // (rawTypeId, dataViewTypeId) again. The assumption here is that one rawTypeId can only be associated // with one dataViewTypeId. - var associatedDataViewType = _idToDataViewTypeMap[_typeIdToDataViewTypeIdMap[rawTypeId]]; + var associatedDataViewType = _typeIdToDataViewTypeIdMap[rawTypeId].TargetType; throw Contracts.ExceptParam(nameof(rawType), $"Repeated type register. The raw type {rawType} " + $"has been associated with {associatedDataViewType} so it cannot be associated with {dataViewType}."); } @@ -157,30 +147,27 @@ public static void Register(DataViewType dataViewType, Type rawType, IEnumerable // There is a pair of (dataViewTypeId, anotherRawTypeId) in _dataViewTypeIdToTypeId so we cannot register // (dataViewTypeId, rawTypeId) again. The assumption here is that one dataViewTypeId can only be associated // with one rawTypeId. - var associatedRawType = _idToTypeMap[_dataViewTypeIdToTypeIdMap[dataViewTypeId]]; + var associatedRawType = _dataViewTypeIdToTypeIdMap[dataViewTypeId].TargetType; throw Contracts.ExceptParam(nameof(dataViewType), $"Repeated type register. The DataView type {dataViewType} " + $"has been associated with {associatedRawType} so it cannot be associated with {rawType}."); } _typeIdToDataViewTypeIdMap.Add(rawTypeId, dataViewTypeId); _dataViewTypeIdToTypeIdMap.Add(dataViewTypeId, rawTypeId); - - _idToDataViewTypeMap.Add(dataViewTypeId, dataViewType); - _idToTypeMap.Add(rawTypeId, rawType); } } /// - /// An instance of represents an unique key of its and . + /// An instance of represents an unique key of its and . /// private class TypeWithAttributesId { - private Type _targetType; + public Type TargetType { get; } private IEnumerable _associatedAttributes; public TypeWithAttributesId(Type rawType, IEnumerable attributes) { - _targetType = rawType; + TargetType = rawType; _associatedAttributes = attributes; } @@ -189,7 +176,7 @@ public override bool Equals(object obj) if (obj is TypeWithAttributesId other) { // Flag of having the same type. - var sameType = _targetType.Equals(other._targetType); + var sameType = TargetType.Equals(other.TargetType); // Flag of having the attribute configurations. var sameAttributeConfig = true; @@ -224,9 +211,9 @@ public override bool Equals(object obj) public override int GetHashCode() { if (_associatedAttributes == null) - return _targetType.GetHashCode(); + return TargetType.GetHashCode(); - var code = _targetType.GetHashCode(); + var code = TargetType.GetHashCode(); foreach (var attr in _associatedAttributes) code = Hashing.CombineHash(code, attr.GetHashCode()); return code; @@ -235,28 +222,28 @@ public override int GetHashCode() } /// - /// An instance of represents an unique key of its . + /// An instance of represents an unique key of its . /// private class DataViewTypeId { - private DataViewType _targetType; + public DataViewType TargetType { get; } public DataViewTypeId(DataViewType type) { - _targetType = type; + TargetType = type; } public override bool Equals(object obj) { if (obj is DataViewTypeId other) - return _targetType.Equals(other._targetType); + return TargetType.Equals(other.TargetType); return false; } public override int GetHashCode() { - return _targetType.GetHashCode(); + return TargetType.GetHashCode(); } } } diff --git a/src/Microsoft.ML.DataView/VectorType.cs b/src/Microsoft.ML.DataView/VectorType.cs index ff2e707e63..8402b27661 100644 --- a/src/Microsoft.ML.DataView/VectorType.cs +++ b/src/Microsoft.ML.DataView/VectorType.cs @@ -138,6 +138,11 @@ public override bool Equals(DataViewType other) return true; } + public override bool Equals(object other) + { + return other is DataViewType tmp && Equals(tmp); + } + public override int GetHashCode() { int hash = Hashing.CombineHash(ItemType.GetHashCode(), Size); diff --git a/src/Microsoft.ML.ImageAnalytics/ImageType.cs b/src/Microsoft.ML.ImageAnalytics/ImageType.cs index 96a46a23c5..6a30084f81 100644 --- a/src/Microsoft.ML.ImageAnalytics/ImageType.cs +++ b/src/Microsoft.ML.ImageAnalytics/ImageType.cs @@ -14,7 +14,6 @@ namespace Microsoft.ML.Transforms.Image /// Allows a member to be marked as a , primarily allowing one to set /// the shape of an image field. /// - [AttributeUsage(AttributeTargets.Field | AttributeTargets.Property, AllowMultiple = false, Inherited = true)] public sealed class ImageTypeAttribute : DataViewTypeAttribute { /// From 55b94895c6d1a2c25fbbec971a49f4cfa6dfa97c Mon Sep 17 00:00:00 2001 From: Wei-Sheng Chin Date: Thu, 23 May 2019 18:14:40 -0700 Subject: [PATCH 19/24] Polish a bit --- src/Microsoft.ML.DataView/DataViewTypeManager.cs | 5 +---- .../UnitTests/TestCustomTypeRegister.cs | 10 ++++------ 2 files changed, 5 insertions(+), 10 deletions(-) diff --git a/src/Microsoft.ML.DataView/DataViewTypeManager.cs b/src/Microsoft.ML.DataView/DataViewTypeManager.cs index a177d70963..1d0ab1cd6b 100644 --- a/src/Microsoft.ML.DataView/DataViewTypeManager.cs +++ b/src/Microsoft.ML.DataView/DataViewTypeManager.cs @@ -19,7 +19,7 @@ public static class DataViewTypeManager { /// /// Types have been used in ML.NET type systems. They can have multiple-to-one type mapping. - /// For example, UInt32 and Key can be mapped to uint. This class enforces one-to-one mapping for all + /// For example, UInt32 and Key can be mapped to . This class enforces one-to-one mapping for all /// user-registered types. /// private static HashSet _bannedRawTypes = new HashSet() @@ -52,9 +52,6 @@ public static class DataViewTypeManager /// public static DataViewType GetDataViewType(Type rawType, IEnumerable rawTypeAttributes = null) { - // Overall flow: - // type (Type) + attrs ----> type ID ----------------> associated DataViewType's ID ----------------> DataViewType - // (hashing) (dictionary look-up) (dictionary look-up) lock (_lock) { // Compute the ID of type with extra attributes. diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestCustomTypeRegister.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestCustomTypeRegister.cs index b9a005a271..4077810d63 100644 --- a/test/Microsoft.ML.Core.Tests/UnitTests/TestCustomTypeRegister.cs +++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestCustomTypeRegister.cs @@ -38,7 +38,6 @@ public AlienBody(int age, float height, float weight, int handCount) } } - [AttributeUsage(AttributeTargets.Field | AttributeTargets.Property, AllowMultiple = false, Inherited = true)] private sealed class AlienTypeAttributeAttribute : DataViewTypeAttribute { public int RaceId { get; } @@ -52,7 +51,7 @@ public AlienTypeAttributeAttribute(int id) } /// - /// A function implicitly invoked by ML.NET when processing a custom type. It binds a DataViewType to a custome type plus its attributes. + /// A function implicitly invoked by ML.NET when processing a custom type. It binds a DataViewType to a custom type plus its attributes. /// public override void Register() { @@ -117,10 +116,9 @@ public DataViewAlienBodyType(int id) : base(typeof(AlienBody)) public override bool Equals(DataViewType other) { - if (other is DataViewAlienBodyType) - return ((DataViewAlienBodyType)other).RaceId == RaceId; - else - return false; + if (other is DataViewAlienBodyType otherAlien) + return otherAlien.RaceId == RaceId; + return false; } public override int GetHashCode() From b4e5a618d775df17228c5d73c86820429a062408 Mon Sep 17 00:00:00 2001 From: Wei-Sheng Chin Date: Fri, 24 May 2019 09:24:59 -0700 Subject: [PATCH 20/24] Address one comment and polish code --- .../DataViewTypeManager.cs | 131 ++++++++---------- .../UnitTests/TestCustomTypeRegister.cs | 13 +- 2 files changed, 64 insertions(+), 80 deletions(-) diff --git a/src/Microsoft.ML.DataView/DataViewTypeManager.cs b/src/Microsoft.ML.DataView/DataViewTypeManager.cs index 1d0ab1cd6b..d39d034814 100644 --- a/src/Microsoft.ML.DataView/DataViewTypeManager.cs +++ b/src/Microsoft.ML.DataView/DataViewTypeManager.cs @@ -35,12 +35,12 @@ public static class DataViewTypeManager /// /// Mapping from hashing ID of a and its s to hashing ID of a . /// - private static Dictionary _typeIdToDataViewTypeIdMap = new Dictionary(); + private static Dictionary _rawTypeToDataViewTypeMap = new Dictionary(); /// /// Mapping from hashing ID of a to hashing ID of a and its s. /// - private static Dictionary _dataViewTypeIdToTypeIdMap = new Dictionary(); + private static Dictionary _dataViewTypeToRawTypeMap = new Dictionary(); /// /// The lock that one should acquire if the state of will be accessed or modified. @@ -48,38 +48,38 @@ public static class DataViewTypeManager private static object _lock = new object(); /// - /// Returns the registered for and its . + /// Returns the registered for and its . /// - public static DataViewType GetDataViewType(Type rawType, IEnumerable rawTypeAttributes = null) + public static DataViewType GetDataViewType(Type type, IEnumerable typeAttributes = null) { lock (_lock) { // Compute the ID of type with extra attributes. - var typeId = new TypeWithAttributesId(rawType, rawTypeAttributes); + var rawType = new TypeWithAttributes(type, typeAttributes); // Get the DataViewType's ID which typeID is mapped into. - if (!_typeIdToDataViewTypeIdMap.TryGetValue(typeId, out DataViewTypeId dataViewTypeId)) - throw Contracts.ExceptParam(nameof(rawType), $"The raw type {rawType} with attributes {rawTypeAttributes} is not registered with a DataView type."); + if (!_rawTypeToDataViewTypeMap.TryGetValue(rawType, out DataViewType dataViewType)) + throw Contracts.ExceptParam(nameof(type), $"The raw type {type} with attributes {typeAttributes} is not registered with a DataView type."); - // Retrieve the actual DataViewType identified by dataViewTypeId. - return dataViewTypeId.TargetType; + // Retrieve the actual DataViewType identified by dataViewType. + return dataViewType; } } /// - /// If has been registered with a , this function returns . + /// If has been registered with a , this function returns . /// Otherwise, this function returns . /// - public static bool Knows(Type rawType, IEnumerable rawTypeAttributes = null) + public static bool Knows(Type type, IEnumerable typeAttributes = null) { lock (_lock) { // Compute the ID of type with extra attributes. - var typeId = new TypeWithAttributesId(rawType, rawTypeAttributes); + var rawType = new TypeWithAttributes(type, typeAttributes); // Check if this ID has been associated with a DataViewType. - // Note that the dictionary below contains (typeId, type) pairs (key is typeId, and value is type). - if (_typeIdToDataViewTypeIdMap.ContainsKey(typeId)) + // Note that the dictionary below contains (rawType, dataViewType) pairs (key type is TypeWithAttributes, and value type is DataViewType). + if (_rawTypeToDataViewTypeMap.ContainsKey(rawType)) return true; else return false; @@ -94,12 +94,9 @@ public static bool Knows(DataViewType dataViewType) { lock (_lock) { - // Compute the ID of the input DataViewType. - var dataViewTypeId = new DataViewTypeId(dataViewType); - // Check if this the ID has been associated with a DataViewType. - // Note that the dictionary below contains (dataViewTypeId, type) pairs (key is dataViewTypeId, and value is type). - if (_dataViewTypeIdToTypeIdMap.ContainsKey(dataViewTypeId)) + // Note that the dictionary below contains (dataViewType, rawType) pairs (key type is DataViewType, and value type is TypeWithAttributes). + if (_dataViewTypeToRawTypeMap.ContainsKey(dataViewType)) return true; else return false; @@ -107,70 +104,78 @@ public static bool Knows(DataViewType dataViewType) } /// - /// This function tells that should be representation of data in in - /// ML.NET's type system. The registered must be a standard C# object's type. + /// This function tells that should be representation of data in in + /// ML.NET's type system. The registered must be a standard C# object's type. /// - /// Native type in C#. - /// The corresponding type of in ML.NET's type system. - /// The s attached to . - public static void Register(DataViewType dataViewType, Type rawType, IEnumerable rawTypeAttributes = null) + /// Native type in C#. + /// The corresponding type of in ML.NET's type system. + /// The s attached to . + public static void Register(DataViewType dataViewType, Type type, IEnumerable typeAttributes = null) { lock (_lock) { - if (_bannedRawTypes.Contains(rawType)) - throw Contracts.ExceptParam(nameof(rawType), $"Type {rawType} has been registered as ML.NET's default supported type, " + + if (_bannedRawTypes.Contains(type)) + throw Contracts.ExceptParam(nameof(type), $"Type {type} has been registered as ML.NET's default supported type, " + $"so it can't not be registered again."); - var rawTypeId = new TypeWithAttributesId(rawType, rawTypeAttributes); - var dataViewTypeId = new DataViewTypeId(dataViewType); + var rawType = new TypeWithAttributes(type, typeAttributes); - if (_typeIdToDataViewTypeIdMap.ContainsKey(rawTypeId) && _typeIdToDataViewTypeIdMap[rawTypeId].Equals(dataViewTypeId) && - _dataViewTypeIdToTypeIdMap.ContainsKey(dataViewTypeId) && _dataViewTypeIdToTypeIdMap[dataViewTypeId].Equals(rawTypeId)) + if (_rawTypeToDataViewTypeMap.ContainsKey(rawType) && _rawTypeToDataViewTypeMap[rawType].Equals(dataViewType) && + _dataViewTypeToRawTypeMap.ContainsKey(dataViewType) && _dataViewTypeToRawTypeMap[dataViewType].Equals(rawType)) // This type pair has been registered. Note that registering one data type pair multiple times is allowed. return; - if (_typeIdToDataViewTypeIdMap.ContainsKey(rawTypeId) && !_typeIdToDataViewTypeIdMap[rawTypeId].Equals(dataViewTypeId)) + if (_rawTypeToDataViewTypeMap.ContainsKey(rawType) && !_rawTypeToDataViewTypeMap[rawType].Equals(dataViewType)) { - // There is a pair of (rawTypeId, anotherDataViewTypeId) in _typeIdToDataViewTypeId so we cannot register - // (rawTypeId, dataViewTypeId) again. The assumption here is that one rawTypeId can only be associated - // with one dataViewTypeId. - var associatedDataViewType = _typeIdToDataViewTypeIdMap[rawTypeId].TargetType; - throw Contracts.ExceptParam(nameof(rawType), $"Repeated type register. The raw type {rawType} " + + // There is a pair of (rawType, anotherDataViewType) in _typeToDataViewType so we cannot register + // (rawType, dataViewType) again. The assumption here is that one rawType can only be associated + // with one dataViewType. + var associatedDataViewType = _rawTypeToDataViewTypeMap[rawType]; + throw Contracts.ExceptParam(nameof(type), $"Repeated type register. The raw type {type} " + $"has been associated with {associatedDataViewType} so it cannot be associated with {dataViewType}."); } - if (_dataViewTypeIdToTypeIdMap.ContainsKey(dataViewTypeId) && !_dataViewTypeIdToTypeIdMap[dataViewTypeId].Equals(rawTypeId)) + if (_dataViewTypeToRawTypeMap.ContainsKey(dataViewType) && !_dataViewTypeToRawTypeMap[dataViewType].Equals(rawType)) { - // There is a pair of (dataViewTypeId, anotherRawTypeId) in _dataViewTypeIdToTypeId so we cannot register - // (dataViewTypeId, rawTypeId) again. The assumption here is that one dataViewTypeId can only be associated - // with one rawTypeId. - var associatedRawType = _dataViewTypeIdToTypeIdMap[dataViewTypeId].TargetType; + // There is a pair of (dataViewType, anotherRawType) in _dataViewTypeToType so we cannot register + // (dataViewType, rawType) again. The assumption here is that one dataViewType can only be associated + // with one rawType. + var associatedRawType = _dataViewTypeToRawTypeMap[dataViewType].TargetType; throw Contracts.ExceptParam(nameof(dataViewType), $"Repeated type register. The DataView type {dataViewType} " + - $"has been associated with {associatedRawType} so it cannot be associated with {rawType}."); + $"has been associated with {associatedRawType} so it cannot be associated with {type}."); } - _typeIdToDataViewTypeIdMap.Add(rawTypeId, dataViewTypeId); - _dataViewTypeIdToTypeIdMap.Add(dataViewTypeId, rawTypeId); + _rawTypeToDataViewTypeMap.Add(rawType, dataViewType); + _dataViewTypeToRawTypeMap.Add(dataViewType, rawType); } } /// - /// An instance of represents an unique key of its and . + /// An instance of represents an unique key of its and . /// - private class TypeWithAttributesId + private class TypeWithAttributes { + /// + /// The underlying type. + /// public Type TargetType { get; } + + /// + /// The underlying type's attributes. Together with , uniquely defines + /// a key when using as the key type in . Note that the + /// uniqueness is determined by and below. + /// private IEnumerable _associatedAttributes; - public TypeWithAttributesId(Type rawType, IEnumerable attributes) + public TypeWithAttributes(Type type, IEnumerable attributes) { - TargetType = rawType; + TargetType = type; _associatedAttributes = attributes; } public override bool Equals(object obj) { - if (obj is TypeWithAttributesId other) + if (obj is TypeWithAttributes other) { // Flag of having the same type. var sameType = TargetType.Equals(other.TargetType); @@ -217,31 +222,5 @@ public override int GetHashCode() } } - - /// - /// An instance of represents an unique key of its . - /// - private class DataViewTypeId - { - public DataViewType TargetType { get; } - - public DataViewTypeId(DataViewType type) - { - TargetType = type; - } - - public override bool Equals(object obj) - { - if (obj is DataViewTypeId other) - return TargetType.Equals(other.TargetType); - - return false; - } - - public override int GetHashCode() - { - return TargetType.GetHashCode(); - } - } } } diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestCustomTypeRegister.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestCustomTypeRegister.cs index 4077810d63..6cdc48bc01 100644 --- a/test/Microsoft.ML.Core.Tests/UnitTests/TestCustomTypeRegister.cs +++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestCustomTypeRegister.cs @@ -38,16 +38,19 @@ public AlienBody(int age, float height, float weight, int handCount) } } + /// + /// applied to class members. + /// private sealed class AlienTypeAttributeAttribute : DataViewTypeAttribute { public int RaceId { get; } /// - /// Create an image type with known height and width. + /// Create an from to a . /// - public AlienTypeAttributeAttribute(int id) + public AlienTypeAttributeAttribute(int raceId) { - RaceId = id; + RaceId = raceId; } /// @@ -73,7 +76,8 @@ public override bool Equals(DataViewTypeAttribute other) /// It will be the input of . /// /// and would be mapped to different types inside ML.NET type system because they - /// have different s. + /// have different s. For example, the column type of would + /// be . /// private class AlienHero { @@ -104,6 +108,7 @@ public AlienHero(string name, /// /// Type of in ML.NET's type system. + /// It usually shows up as among . /// private class DataViewAlienBodyType : StructuredDataViewType { From ef611515b5595ba8d3e61b14f4d69b1f6d28b574 Mon Sep 17 00:00:00 2001 From: Wei-Sheng Chin Date: Tue, 28 May 2019 11:18:38 -0700 Subject: [PATCH 21/24] Update src/Microsoft.ML.DataView/DataViewTypeManager.cs --- src/Microsoft.ML.DataView/DataViewTypeManager.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Microsoft.ML.DataView/DataViewTypeManager.cs b/src/Microsoft.ML.DataView/DataViewTypeManager.cs index d39d034814..1855d90b7c 100644 --- a/src/Microsoft.ML.DataView/DataViewTypeManager.cs +++ b/src/Microsoft.ML.DataView/DataViewTypeManager.cs @@ -33,7 +33,7 @@ public static class DataViewTypeManager }; /// - /// Mapping from hashing ID of a and its s to hashing ID of a . + /// Mapping from a plus its s to a . /// private static Dictionary _rawTypeToDataViewTypeMap = new Dictionary(); From 18bf87a93de5d5c73f9e3e7d13aa55dda4eedd5a Mon Sep 17 00:00:00 2001 From: Wei-Sheng Chin Date: Tue, 28 May 2019 11:19:36 -0700 Subject: [PATCH 22/24] Update src/Microsoft.ML.DataView/DataViewTypeManager.cs --- src/Microsoft.ML.DataView/DataViewTypeManager.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Microsoft.ML.DataView/DataViewTypeManager.cs b/src/Microsoft.ML.DataView/DataViewTypeManager.cs index 1855d90b7c..0ed1ab8fb9 100644 --- a/src/Microsoft.ML.DataView/DataViewTypeManager.cs +++ b/src/Microsoft.ML.DataView/DataViewTypeManager.cs @@ -38,7 +38,7 @@ public static class DataViewTypeManager private static Dictionary _rawTypeToDataViewTypeMap = new Dictionary(); /// - /// Mapping from hashing ID of a to hashing ID of a and its s. + /// Mapping from a to a plus its s. /// private static Dictionary _dataViewTypeToRawTypeMap = new Dictionary(); From fb0440ad48e2756f5f6ec573c7150c5f2e3a29f8 Mon Sep 17 00:00:00 2001 From: Wei-Sheng Chin Date: Tue, 28 May 2019 14:23:41 -0700 Subject: [PATCH 23/24] Add samples --- .../CustomMappingWithInMemoryCustomType.cs | 179 ++++++++++++++++++ .../ConvertToGrayScaleInMemory.cs | 85 +++++++++ .../ExtensionsCatalog.cs | 1 + .../CustomMappingCatalog.cs | 1 + .../UnitTests/TestCustomTypeRegister.cs | 8 +- 5 files changed, 270 insertions(+), 4 deletions(-) create mode 100644 docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/CustomMappingWithInMemoryCustomType.cs create mode 100644 docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/ImageAnalytics/ConvertToGrayScaleInMemory.cs diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/CustomMappingWithInMemoryCustomType.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/CustomMappingWithInMemoryCustomType.cs new file mode 100644 index 0000000000..688c5d1fe5 --- /dev/null +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/CustomMappingWithInMemoryCustomType.cs @@ -0,0 +1,179 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using Microsoft.ML; +using Microsoft.ML.Data; + +namespace Samples.Dynamic +{ + class CustomMappingWithInMemoryCustomType + { + static public void Example() + { + var mlContext = new MLContext(); + // Build in-memory data. + var tribe = new List() { new AlienHero("ML.NET", 2, 1000, 2000, 3000, 4000, 5000, 6000, 7000) }; + + // Build a ML.NET pipeline and make prediction. + var tribeDataView = mlContext.Data.LoadFromEnumerable(tribe); + var pipeline = mlContext.Transforms.CustomMapping(AlienFusionProcess.GetMapping(), contractName: null); + var model = pipeline.Fit(tribeDataView); + var tribeTransformed = model.Transform(tribeDataView); + + // Print out prediction produced by the model. + var firstAlien = mlContext.Data.CreateEnumerable(tribeTransformed, false).First(); + Console.WriteLine($"We got a super alien with name {firstAlien.Name}, age {firstAlien.Merged.Age}, " + + $"height {firstAlien.Merged.Height}, weight {firstAlien.Merged.Weight}, and {firstAlien.Merged.HandCount} hands."); + + // Expected output: + // We got a super alien with name Super Unknown, age 4002, height 6000, weight 8000, and 10000 hands. + + // Create a prediction engine and print out its prediction. + var engine = mlContext.Model.CreatePredictionEngine(model); + var alien = new AlienHero("TEN.LM", 1, 2, 3, 4, 5, 6, 7, 8); + var superAlien = engine.Predict(alien); + Console.Write($"We got a super alien with name {superAlien.Name}, age {superAlien.Merged.Age}, " + + $"height {superAlien.Merged.Height}, weight {superAlien.Merged.Weight}, and {superAlien.Merged.HandCount} hands."); + + // Expected output: + // We got a super alien with name Super Unknown, age 6, height 8, weight 10, and 12 hands. + } + + // A custom type which ML.NET doesn't know yet. Its value will be loaded as a DataView column in this test. + private class AlienBody + { + public int Age { get; set; } + public float Height { get; set; } + public float Weight { get; set; } + public int HandCount { get; set; } + + public AlienBody(int age, float height, float weight, int handCount) + { + Age = age; + Height = height; + Weight = weight; + HandCount = handCount; + } + } + + // DataViewTypeAttribute applied to class AlienBody members. + private sealed class AlienTypeAttributeAttribute : DataViewTypeAttribute + { + public int RaceId { get; } + + // Create an DataViewTypeAttribute> from raceId to a AlienBody. + public AlienTypeAttributeAttribute(int raceId) + { + RaceId = raceId; + } + + // A function implicitly invoked by ML.NET when processing a custom type. + // It binds a DataViewType to a custom type plus its attributes. + public override void Register() + { + DataViewTypeManager.Register(new DataViewAlienBodyType(RaceId), typeof(AlienBody), new[] { this }); + } + + public override bool Equals(DataViewTypeAttribute other) + { + if (other is AlienTypeAttributeAttribute) + return RaceId == ((AlienTypeAttributeAttribute)other).RaceId; + return false; + } + + public override int GetHashCode() => RaceId.GetHashCode(); + } + + // A custom class with a type which ML.NET doesn't know yet. Its value will be loaded as a DataView row in this test. + // It will be the input of AlienFusionProcess.MergeBody(AlienHero, SuperAlienHero). + // + // The members One> and Two" would be mapped to different types inside ML.NET type system because they + // have different AlienTypeAttributeAttribute's. For example, the column type of One would be DataViewAlienBodyType + // with RaceId=100. + // + private class AlienHero + { + public string Name { get; set; } + + [AlienTypeAttribute(100)] + public AlienBody One { get; set; } + + [AlienTypeAttribute(200)] + public AlienBody Two { get; set; } + + public AlienHero() + { + Name = "Unknown"; + One = new AlienBody(0, 0, 0, 0); + Two = new AlienBody(0, 0, 0, 0); + } + + public AlienHero(string name, + int age, float height, float weight, int handCount, + int anotherAge, float anotherHeight, float anotherWeight, int anotherHandCount) + { + Name = "Unknown"; + One = new AlienBody(age, height, weight, handCount); + Two = new AlienBody(anotherAge, anotherHeight, anotherWeight, anotherHandCount); + } + } + + // Type of AlienBody in ML.NET's type system. + // It usually shows up as DataViewSchema.Column.Type among IDataView.Schema. + private class DataViewAlienBodyType : StructuredDataViewType + { + public int RaceId { get; } + + public DataViewAlienBodyType(int id) : base(typeof(AlienBody)) + { + RaceId = id; + } + + public override bool Equals(DataViewType other) + { + if (other is DataViewAlienBodyType otherAlien) + return otherAlien.RaceId == RaceId; + return false; + } + + public override int GetHashCode() + { + return RaceId.GetHashCode(); + } + } + + // The output type of processing AlienHero using AlienFusionProcess.MergeBody(AlienHero, SuperAlienHero). + private class SuperAlienHero + { + public string Name { get; set; } + + [AlienTypeAttribute(007)] + public AlienBody Merged { get; set; } + + public SuperAlienHero() + { + Name = "Unknown"; + Merged = new AlienBody(0, 0, 0, 0); + } + } + + // The implementation of custom mapping is MergeBody. It accepts AlienHero and produces SuperAlienHero. + private class AlienFusionProcess + { + public static void MergeBody(AlienHero input, SuperAlienHero output) + { + output.Name = "Super " + input.Name; + output.Merged.Age = input.One.Age + input.Two.Age; + output.Merged.Height = input.One.Height + input.Two.Height; + output.Merged.Weight = input.One.Weight + input.Two.Weight; + output.Merged.HandCount = input.One.HandCount + input.Two.HandCount; + } + + public static Action GetMapping() + { + return MergeBody; + } + } + + } +} diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/ImageAnalytics/ConvertToGrayScaleInMemory.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/ImageAnalytics/ConvertToGrayScaleInMemory.cs new file mode 100644 index 0000000000..883dfa5dc1 --- /dev/null +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/ImageAnalytics/ConvertToGrayScaleInMemory.cs @@ -0,0 +1,85 @@ +using System; +using System.Drawing; +using Microsoft.ML; +using Microsoft.ML.Transforms.Image; + +namespace Samples.Dynamic +{ + class ConvertToGrayScaleInMemory + { + static public void Example() + { + var mlContext = new MLContext(); + // Create an image list. + var images = new[] { new ImageDataPoint(2, 3, Color.Blue), new ImageDataPoint(2, 3, Color.Red) }; + + // Convert the list of data points to an IDataView object, which is consumable by ML.NET API. + var data = mlContext.Data.LoadFromEnumerable(images); + + // Convert image to gray scale. + var pipeline = mlContext.Transforms.ConvertToGrayscale("GrayImage", "Image"); + + // Fit the model. + var model = pipeline.Fit(data); + + // Test path: image files -> IDataView -> Enumerable of Bitmaps. + var transformedData = model.Transform(data); + + // Load images in DataView back to Enumerable. + var transformedDataPoints = mlContext.Data.CreateEnumerable(transformedData, false); + + // Print out input and output pixels. + foreach (var dataPoint in transformedDataPoints) + { + var image = dataPoint.Image; + var grayImage = dataPoint.GrayImage; + for (int x = 0; x < grayImage.Width; ++x) + { + for (int y = 0; y < grayImage.Height; ++y) + { + var pixel = image.GetPixel(x, y); + var grayPixel = grayImage.GetPixel(x, y); + Console.WriteLine($"The original pixel is {pixel} and its pixel in gray is {grayPixel}"); + } + } + } + + // Expected output: + // The original pixel is Color[A = 255, R = 0, G = 0, B = 255] and its pixel in gray is Color[A = 255, R = 28, G = 28, B = 28] + // The original pixel is Color[A = 255, R = 0, G = 0, B = 255] and its pixel in gray is Color[A = 255, R = 28, G = 28, B = 28] + // The original pixel is Color[A = 255, R = 0, G = 0, B = 255] and its pixel in gray is Color[A = 255, R = 28, G = 28, B = 28] + // The original pixel is Color[A = 255, R = 0, G = 0, B = 255] and its pixel in gray is Color[A = 255, R = 28, G = 28, B = 28] + // The original pixel is Color[A = 255, R = 0, G = 0, B = 255] and its pixel in gray is Color[A = 255, R = 28, G = 28, B = 28] + // The original pixel is Color[A = 255, R = 0, G = 0, B = 255] and its pixel in gray is Color[A = 255, R = 28, G = 28, B = 28] + // The original pixel is Color[A = 255, R = 255, G = 0, B = 0] and its pixel in gray is Color[A = 255, R = 77, G = 77, B = 77] + // The original pixel is Color[A = 255, R = 255, G = 0, B = 0] and its pixel in gray is Color[A = 255, R = 77, G = 77, B = 77] + // The original pixel is Color[A = 255, R = 255, G = 0, B = 0] and its pixel in gray is Color[A = 255, R = 77, G = 77, B = 77] + // The original pixel is Color[A = 255, R = 255, G = 0, B = 0] and its pixel in gray is Color[A = 255, R = 77, G = 77, B = 77] + // The original pixel is Color[A = 255, R = 255, G = 0, B = 0] and its pixel in gray is Color[A = 255, R = 77, G = 77, B = 77] + // The original pixel is Color[A = 255, R = 255, G = 0, B = 0] and its pixel in gray is Color[A = 255, R = 77, G = 77, B = 77] + } + + private class ImageDataPoint + { + [ImageType(3, 4)] + public Bitmap Image { get; set; } + + [ImageType(3, 4)] + public Bitmap GrayImage { get; set; } + + public ImageDataPoint() + { + Image = null; + GrayImage = null; + } + + public ImageDataPoint(int width, int height, Color color) + { + Image = new Bitmap(width, height); + for (int i = 0; i < width; ++i) + for (int j = 0; j < height; ++j) + Image.SetPixel(i, j, color); + } + } + } +} diff --git a/src/Microsoft.ML.ImageAnalytics/ExtensionsCatalog.cs b/src/Microsoft.ML.ImageAnalytics/ExtensionsCatalog.cs index 2ffd5ba97c..78cbc74874 100644 --- a/src/Microsoft.ML.ImageAnalytics/ExtensionsCatalog.cs +++ b/src/Microsoft.ML.ImageAnalytics/ExtensionsCatalog.cs @@ -26,6 +26,7 @@ public static class ImageEstimatorsCatalog /// /// /// public static ImageGrayscalingEstimator ConvertToGrayscale(this TransformsCatalog catalog, string outputColumnName, string inputColumnName = null) diff --git a/src/Microsoft.ML.Transforms/CustomMappingCatalog.cs b/src/Microsoft.ML.Transforms/CustomMappingCatalog.cs index 9dd54c37f8..819a518188 100644 --- a/src/Microsoft.ML.Transforms/CustomMappingCatalog.cs +++ b/src/Microsoft.ML.Transforms/CustomMappingCatalog.cs @@ -38,6 +38,7 @@ public static class CustomMappingCatalog /// /// public static CustomMappingEstimator CustomMapping(this TransformsCatalog catalog, Action mapAction, string contractName, diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestCustomTypeRegister.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestCustomTypeRegister.cs index 6cdc48bc01..b51d8952d5 100644 --- a/test/Microsoft.ML.Core.Tests/UnitTests/TestCustomTypeRegister.cs +++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestCustomTypeRegister.cs @@ -73,7 +73,7 @@ public override bool Equals(DataViewTypeAttribute other) /// /// A custom class with a type which ML.NET doesn't know yet. Its value will be loaded as a DataView row in this test. - /// It will be the input of . + /// It will be the input of . /// /// and would be mapped to different types inside ML.NET type system because they /// have different s. For example, the column type of would @@ -133,7 +133,7 @@ public override int GetHashCode() } /// - /// The output type of processing using . + /// The output type of processing using . /// private class SuperAlienHero { @@ -154,7 +154,7 @@ public SuperAlienHero() /// in . /// [CustomMappingFactoryAttribute("LambdaAlienHero")] - private class AlienLambda : CustomMappingFactory + private class AlienFusionProcess : CustomMappingFactory { public static void MergeBody(AlienHero input, SuperAlienHero output) { @@ -179,7 +179,7 @@ public void RegisterTypeWithAttribute() // Build a ML.NET pipeline and make prediction. var tribeDataView = ML.Data.LoadFromEnumerable(tribe); - var heroEstimator = new CustomMappingEstimator(ML, AlienLambda.MergeBody, "LambdaAlienHero"); + var heroEstimator = new CustomMappingEstimator(ML, AlienFusionProcess.MergeBody, "LambdaAlienHero"); var model = heroEstimator.Fit(tribeDataView); var tribeTransformed = model.Transform(tribeDataView); var tribeEnumerable = ML.Data.CreateEnumerable(tribeTransformed, false).ToList(); From 164908fc22101605cd337ac64b1aabfa73d2b641 Mon Sep 17 00:00:00 2001 From: Wei-Sheng Chin Date: Tue, 28 May 2019 14:58:36 -0700 Subject: [PATCH 24/24] Internalize some methods --- .../Data}/DataViewTypeManager.cs | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) rename src/{Microsoft.ML.DataView => Microsoft.ML.Data/Data}/DataViewTypeManager.cs (97%) diff --git a/src/Microsoft.ML.DataView/DataViewTypeManager.cs b/src/Microsoft.ML.Data/Data/DataViewTypeManager.cs similarity index 97% rename from src/Microsoft.ML.DataView/DataViewTypeManager.cs rename to src/Microsoft.ML.Data/Data/DataViewTypeManager.cs index 0ed1ab8fb9..fda18633ff 100644 --- a/src/Microsoft.ML.DataView/DataViewTypeManager.cs +++ b/src/Microsoft.ML.Data/Data/DataViewTypeManager.cs @@ -6,7 +6,8 @@ using System.Collections.Generic; using System.Linq; using System.Reflection; -using Microsoft.ML.Internal.DataView; +using Microsoft.ML.Internal.CpuMath.Core; +using Microsoft.ML.Internal.Utilities; namespace Microsoft.ML.Data { @@ -50,7 +51,7 @@ public static class DataViewTypeManager /// /// Returns the registered for and its . /// - public static DataViewType GetDataViewType(Type type, IEnumerable typeAttributes = null) + internal static DataViewType GetDataViewType(Type type, IEnumerable typeAttributes = null) { lock (_lock) { @@ -70,7 +71,7 @@ public static DataViewType GetDataViewType(Type type, IEnumerable typ /// If has been registered with a , this function returns . /// Otherwise, this function returns . /// - public static bool Knows(Type type, IEnumerable typeAttributes = null) + internal static bool Knows(Type type, IEnumerable typeAttributes = null) { lock (_lock) { @@ -90,7 +91,7 @@ public static bool Knows(Type type, IEnumerable typeAttributes = null /// If has been registered with a , this function returns . /// Otherwise, this function returns . /// - public static bool Knows(DataViewType dataViewType) + internal static bool Knows(DataViewType dataViewType) { lock (_lock) {