diff --git a/src/Microsoft.ML.Console/Microsoft.ML.Console.csproj b/src/Microsoft.ML.Console/Microsoft.ML.Console.csproj index f9a1b5ef27..1471c580ba 100644 --- a/src/Microsoft.ML.Console/Microsoft.ML.Console.csproj +++ b/src/Microsoft.ML.Console/Microsoft.ML.Console.csproj @@ -16,6 +16,7 @@ + diff --git a/src/Microsoft.ML.Core/Data/IEstimator.cs b/src/Microsoft.ML.Core/Data/IEstimator.cs index 6f21a1cb01..509a67bb4f 100644 --- a/src/Microsoft.ML.Core/Data/IEstimator.cs +++ b/src/Microsoft.ML.Core/Data/IEstimator.cs @@ -28,20 +28,42 @@ public enum VectorKind VariableVector } + /// + /// The column name. + /// public readonly string Name; + + /// + /// The type of the column: scalar, fixed vector or variable vector. + /// public readonly VectorKind Kind; - public readonly DataKind ItemKind; + + /// + /// The 'raw' type of column item: must be a primitive type or a structured type. + /// + public readonly ColumnType ItemType; + /// + /// The flag whether the column is actually a key. If yes, is representing + /// the underlying primitive type. + /// public readonly bool IsKey; + /// + /// The metadata kinds that are present for this column. + /// public readonly string[] MetadataKinds; - public Column(string name, VectorKind vecKind, DataKind itemKind, bool isKey, string[] metadataKinds = null) + public Column(string name, VectorKind vecKind, ColumnType itemType, bool isKey, string[] metadataKinds = null) { Contracts.CheckNonEmpty(name, nameof(name)); Contracts.CheckValueOrNull(metadataKinds); + Contracts.CheckParam(!itemType.IsKey, nameof(itemType), "Item type cannot be a key"); + Contracts.CheckParam(!itemType.IsVector, nameof(itemType), "Item type cannot be a vector"); + + Contracts.CheckParam(!isKey || KeyType.IsValidDataKind(itemType.RawKind), nameof(itemType), "The item type must be valid for a key"); Name = name; Kind = vecKind; - ItemKind = itemKind; + ItemType = itemType; IsKey = isKey; MetadataKinds = metadataKinds ?? 
new string[0]; } @@ -51,7 +73,7 @@ public Column(string name, VectorKind vecKind, DataKind itemKind, bool isKey, st /// requirement. /// /// Namely, it returns true iff: - /// - The , , , fields match. + /// - The , , , fields match. /// - The of is a superset of our . /// public bool IsCompatibleWith(Column inputColumn) @@ -61,7 +83,7 @@ public bool IsCompatibleWith(Column inputColumn) return false; if (Kind != inputColumn.Kind) return false; - if (ItemKind != inputColumn.ItemKind) + if (!ItemType.Equals(inputColumn.ItemType)) return false; if (IsKey != inputColumn.IsKey) return false; @@ -72,7 +94,7 @@ public bool IsCompatibleWith(Column inputColumn) public string GetTypeString() { - string result = ItemKind.ToString(); + string result = ItemType.ToString(); if (IsKey) result = $"Key<{result}>"; if (Kind == VectorKind.Vector) @@ -110,13 +132,15 @@ public static SchemaShape Create(ISchema schema) else vecKind = Column.VectorKind.Scalar; - var kind = type.ItemType.RawKind; + ColumnType itemType = type.ItemType; + if (type.ItemType.IsKey) + itemType = PrimitiveType.FromKind(type.ItemType.RawKind); var isKey = type.ItemType.IsKey; var metadataNames = schema.GetMetadataTypes(iCol) .Select(kvp => kvp.Key) .ToArray(); - cols.Add(new Column(schema.GetColumnName(iCol), vecKind, kind, isKey, metadataNames)); + cols.Add(new Column(schema.GetColumnName(iCol), vecKind, itemType, isKey, metadataNames)); } } return new SchemaShape(cols.ToArray()); diff --git a/src/Microsoft.ML.Data/DataLoadSave/TrivialEstimator.cs b/src/Microsoft.ML.Data/DataLoadSave/TrivialEstimator.cs new file mode 100644 index 0000000000..29c081ac35 --- /dev/null +++ b/src/Microsoft.ML.Data/DataLoadSave/TrivialEstimator.cs @@ -0,0 +1,35 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. 
using Microsoft.ML.Core.Data;

namespace Microsoft.ML.Runtime.Data
{
    /// <summary>
    /// The trivial implementation of <see cref="IEstimator{TTransformer}"/> that already has
    /// the transformer and returns it on every call to <see cref="Fit(IDataView)"/>.
    ///
    /// Concrete implementations still have to provide the schema propagation mechanism, since
    /// there is no easy way to infer it from the transformer.
    /// </summary>
    /// <typeparam name="TTransformer">The type of the transformer this estimator wraps.</typeparam>
    public abstract class TrivialEstimator<TTransformer> : IEstimator<TTransformer>
        where TTransformer : class, ITransformer
    {
        /// <summary>The host supplied at construction; used for argument checks.</summary>
        protected readonly IHost Host;

        /// <summary>The transformer instance returned by every call to <see cref="Fit(IDataView)"/>.</summary>
        protected readonly TTransformer Transformer;

        /// <summary>
        /// Stores the host and the transformer to be returned by <see cref="Fit(IDataView)"/>.
        /// </summary>
        /// <param name="host">The host of the derived estimator. Asserted to be non-null.</param>
        /// <param name="transformer">The transformer to wrap. Checked to be non-null.</param>
        protected TrivialEstimator(IHost host, TTransformer transformer)
        {
            Contracts.AssertValue(host);

            Host = host;
            Host.CheckValue(transformer, nameof(transformer));
            Transformer = transformer;
        }

        /// <summary>
        /// Returns the wrapped transformer. The <paramref name="input"/> data is not read.
        /// </summary>
        public TTransformer Fit(IDataView input) => Transformer;

        /// <summary>
        /// Schema propagation; must be supplied by the concrete estimator.
        /// </summary>
        public abstract SchemaShape GetOutputSchema(SchemaShape inputSchema);
    }
}
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Generic;
using System.Linq;
using Microsoft.ML.Core.Data;
using Microsoft.ML.Runtime.Model;

namespace Microsoft.ML.Runtime.Data
{
    /// <summary>
    /// Base class for transformers that produce each output column from exactly one input column.
    /// It owns the (input, output) column-name pairs, their serialization, the input schema
    /// validation, and the wiring of a row mapper into a <see cref="RowToRowMapperTransform"/>.
    /// </summary>
    public abstract class OneToOneTransformerBase : ITransformer, ICanSaveModel
    {
        protected readonly IHost Host;

        /// <summary>The (input column name, output column name) pairs, in output order.</summary>
        protected readonly (string input, string output)[] ColumnPairs;

        /// <summary>
        /// Creates the transformer from explicit column pairs.
        /// </summary>
        /// <param name="host">The host of the derived transformer.</param>
        /// <param name="columns">
        /// The column pairs. Every name must be non-empty, and each output name may appear only once;
        /// a repeated output name results in a parameter exception.
        /// </param>
        protected OneToOneTransformerBase(IHost host, (string input, string output)[] columns)
        {
            Contracts.AssertValue(host);
            host.CheckValue(columns, nameof(columns));

            var newNames = new HashSet<string>();
            foreach (var column in columns)
            {
                host.CheckNonEmpty(column.input, nameof(columns));
                host.CheckNonEmpty(column.output, nameof(columns));

                if (!newNames.Add(column.output))
                    throw Contracts.ExceptParam(nameof(columns), $"Output column '{column.output}' specified multiple times");
            }

            Host = host;
            ColumnPairs = columns;
        }

        /// <summary>
        /// Creates the transformer by reading the column pairs written by <see cref="SaveColumns"/>.
        /// </summary>
        protected OneToOneTransformerBase(IHost host, ModelLoadContext ctx)
        {
            Contracts.AssertValue(host);
            host.AssertValue(ctx);

            Host = host;
            // *** Binary format ***
            // int: number of added columns
            // for each added column
            //   int: id of output column name
            //   int: id of input column name

            int n = ctx.Reader.ReadInt32();
            ColumnPairs = new (string input, string output)[n];
            for (int i = 0; i < n; i++)
            {
                // The output name is stored first, then the input name.
                string output = ctx.LoadNonEmptyString();
                string input = ctx.LoadNonEmptyString();
                ColumnPairs[i] = (input, output);
            }
        }

        public abstract void Save(ModelSaveContext ctx);

        /// <summary>
        /// Writes the column pairs in the format read by the loading constructor.
        /// </summary>
        protected void SaveColumns(ModelSaveContext ctx)
        {
            Host.CheckValue(ctx, nameof(ctx));

            // *** Binary format ***
            // int: number of added columns
            // for each added column
            //   int: id of output column name
            //   int: id of input column name

            ctx.Writer.Write(ColumnPairs.Length);
            for (int i = 0; i < ColumnPairs.Length; i++)
            {
                ctx.SaveNonEmptyString(ColumnPairs[i].output);
                ctx.SaveNonEmptyString(ColumnPairs[i].input);
            }
        }

        /// <summary>
        /// Resolves the input column of pair <paramref name="col"/> in <paramref name="inputSchema"/>
        /// and runs <see cref="CheckInputColumn"/> on it. Throws a schema mismatch exception when the
        /// column is not present.
        /// </summary>
        private void CheckInput(ISchema inputSchema, int col, out int srcCol)
        {
            Contracts.AssertValue(inputSchema);
            Contracts.Assert(0 <= col && col < ColumnPairs.Length);

            if (!inputSchema.TryGetColumnIndex(ColumnPairs[col].input, out srcCol))
                throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", ColumnPairs[col].input);
            CheckInputColumn(inputSchema, col, srcCol);
        }

        /// <summary>
        /// Extension point for derived classes to validate the resolved input column
        /// (for example, its type).
        /// </summary>
        /// <param name="inputSchema">The schema being validated.</param>
        /// <param name="col">Index of the pair in <see cref="ColumnPairs"/>.</param>
        /// <param name="srcCol">Index of the input column in <paramref name="inputSchema"/>.</param>
        protected virtual void CheckInputColumn(ISchema inputSchema, int col, int srcCol)
        {
            // By default, there are no extra checks.
        }

        protected abstract IRowMapper MakeRowMapper(ISchema schema);

        /// <summary>
        /// Validates every input column against <paramref name="inputSchema"/> and returns the schema
        /// obtained by applying the transform to an empty data view with that schema.
        /// </summary>
        public ISchema GetOutputSchema(ISchema inputSchema)
        {
            Host.CheckValue(inputSchema, nameof(inputSchema));

            // Check that all the input columns are present and correct.
            for (int i = 0; i < ColumnPairs.Length; i++)
                CheckInput(inputSchema, i, out int col);

            return Transform(new EmptyDataView(Host, inputSchema)).Schema;
        }

        public IDataView Transform(IDataView input) => MakeDataTransform(input);

        /// <summary>
        /// Wraps the row mapper for <paramref name="input"/>'s schema into a
        /// <see cref="RowToRowMapperTransform"/> over <paramref name="input"/>.
        /// </summary>
        protected RowToRowMapperTransform MakeDataTransform(IDataView input)
        {
            Host.CheckValue(input, nameof(input));
            return new RowToRowMapperTransform(Host, input, MakeRowMapper(input.Schema));
        }

        /// <summary>
        /// Base row mapper for one-to-one transformers. It resolves every pair's input column
        /// against the input schema once, at construction, and delegates getter creation to
        /// <see cref="MakeGetter"/>.
        /// </summary>
        protected abstract class MapperBase : IRowMapper
        {
            protected readonly IHost Host;

            /// <summary>
            /// Maps the index of each output column (its position in the parent's column pairs)
            /// to the index of its source column in <see cref="InputSchema"/>.
            /// </summary>
            protected readonly Dictionary<int, int> ColMapNewToOld;
            protected readonly ISchema InputSchema;
            private readonly OneToOneTransformerBase _parent;

            protected MapperBase(IHost host, OneToOneTransformerBase parent, ISchema inputSchema)
            {
                Contracts.AssertValue(host);
                Contracts.AssertValue(parent);
                Contracts.AssertValue(inputSchema);

                Host = host;
                _parent = parent;

                ColMapNewToOld = new Dictionary<int, int>();
                for (int i = 0; i < _parent.ColumnPairs.Length; i++)
                {
                    _parent.CheckInput(inputSchema, i, out int srcCol);
                    ColMapNewToOld.Add(i, srcCol);
                }
                InputSchema = inputSchema;
            }

            /// <summary>
            /// Returns a predicate over input column indices that is true exactly for the source
            /// columns of the outputs selected by <paramref name="activeOutput"/>.
            /// </summary>
            public Func<int, bool> GetDependencies(Func<int, bool> activeOutput)
            {
                var active = new bool[InputSchema.ColumnCount];
                foreach (var pair in ColMapNewToOld)
                    if (activeOutput(pair.Key))
                        active[pair.Value] = true;
                return col => active[col];
            }

            public abstract RowMapperColumnInfo[] GetOutputColumns();

            public void Save(ModelSaveContext ctx) => _parent.Save(ctx);

            /// <summary>
            /// Creates one getter per active output column; entries for inactive columns stay null.
            /// </summary>
            /// <param name="input">The input row; its schema is asserted to be <see cref="InputSchema"/>.</param>
            /// <param name="activeOutput">Selects which output columns need a getter.</param>
            /// <param name="disposer">
            /// Set to an action that runs every disposer produced by <see cref="MakeGetter"/>,
            /// or to null when none of the getters produced one.
            /// </param>
            public Delegate[] CreateGetters(IRow input, Func<int, bool> activeOutput, out Action disposer)
            {
                Contracts.Assert(input.Schema == InputSchema);
                var result = new Delegate[_parent.ColumnPairs.Length];
                var disposers = new Action[_parent.ColumnPairs.Length];
                for (int i = 0; i < _parent.ColumnPairs.Length; i++)
                {
                    if (!activeOutput(i))
                        continue;
                    result[i] = MakeGetter(input, i, out disposers[i]);
                }
                if (disposers.Any(x => x != null))
                {
                    disposer = () =>
                    {
                        // Entries are null for inactive columns and for getters that need no
                        // cleanup, so each one must be null-checked before it is invoked.
                        foreach (var act in disposers)
                            act?.Invoke();
                    };
                }
                else
                    disposer = null;
                return result;
            }

            /// <summary>
            /// Creates the getter for output column <paramref name="iinfo"/>. Implementations set
            /// <paramref name="disposer"/> to null when the getter holds nothing to release.
            /// </summary>
            protected abstract Delegate MakeGetter(IRow input, int iinfo, out Action disposer);
        }
    }
}
ImageResizer(IHostEnvironment env, ImageResizerTransform.Arguments input) { var h = EntryPointUtils.CheckArgsAndCreateHost(env, "ImageResizerTransform", input); - var xf = new ImageResizerTransform(h, input, input.Data); + var xf = ImageResizerTransform.Create(h, input, input.Data); return new CommonOutputs.TransformOutput() { Model = new TransformModel(h, xf, input.Data), @@ -42,7 +42,7 @@ public static CommonOutputs.TransformOutput ImageResizer(IHostEnvironment env, I public static CommonOutputs.TransformOutput ImagePixelExtractor(IHostEnvironment env, ImagePixelExtractorTransform.Arguments input) { var h = EntryPointUtils.CheckArgsAndCreateHost(env, "ImagePixelExtractorTransform", input); - var xf = new ImagePixelExtractorTransform(h, input, input.Data); + var xf = ImagePixelExtractorTransform.Create(h, input, input.Data); return new CommonOutputs.TransformOutput() { Model = new TransformModel(h, xf, input.Data), @@ -55,7 +55,7 @@ public static CommonOutputs.TransformOutput ImagePixelExtractor(IHostEnvironment public static CommonOutputs.TransformOutput ImageGrayscale(IHostEnvironment env, ImageGrayscaleTransform.Arguments input) { var h = EntryPointUtils.CheckArgsAndCreateHost(env, "ImageGrayscaleTransform", input); - var xf = new ImageGrayscaleTransform(h, input, input.Data); + var xf = ImageGrayscaleTransform.Create(h, input, input.Data); return new CommonOutputs.TransformOutput() { Model = new TransformModel(h, xf, input.Data), diff --git a/src/Microsoft.ML.ImageAnalytics/ImageGrayscaleTransform.cs b/src/Microsoft.ML.ImageAnalytics/ImageGrayscaleTransform.cs index 7a267cf1b8..21076b2d86 100644 --- a/src/Microsoft.ML.ImageAnalytics/ImageGrayscaleTransform.cs +++ b/src/Microsoft.ML.ImageAnalytics/ImageGrayscaleTransform.cs @@ -2,22 +2,31 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
-using System; -using System.Drawing; -using System.Drawing.Imaging; -using System.Text; +using Microsoft.ML.Core.Data; using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.CommandLine; using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.EntryPoints; +using Microsoft.ML.Runtime.ImageAnalytics; using Microsoft.ML.Runtime.Internal.Utilities; using Microsoft.ML.Runtime.Model; -using Microsoft.ML.Runtime.ImageAnalytics; +using System; +using System.Collections.Generic; +using System.Drawing; +using System.Drawing.Imaging; +using System.Linq; +using System.Text; -[assembly: LoadableClass(ImageGrayscaleTransform.Summary, typeof(ImageGrayscaleTransform), typeof(ImageGrayscaleTransform.Arguments), typeof(SignatureDataTransform), +[assembly: LoadableClass(ImageGrayscaleTransform.Summary, typeof(IDataTransform), typeof(ImageGrayscaleTransform), typeof(ImageGrayscaleTransform.Arguments), typeof(SignatureDataTransform), ImageGrayscaleTransform.UserName, "ImageGrayscaleTransform", "ImageGrayscale")] -[assembly: LoadableClass(ImageGrayscaleTransform.Summary, typeof(ImageGrayscaleTransform), null, typeof(SignatureLoadDataTransform), +[assembly: LoadableClass(ImageGrayscaleTransform.Summary, typeof(IDataTransform), typeof(ImageGrayscaleTransform), null, typeof(SignatureLoadDataTransform), + ImageGrayscaleTransform.UserName, ImageGrayscaleTransform.LoaderSignature)] + +[assembly: LoadableClass(typeof(ImageGrayscaleTransform), null, typeof(SignatureLoadModel), + ImageGrayscaleTransform.UserName, ImageGrayscaleTransform.LoaderSignature)] + +[assembly: LoadableClass(typeof(IRowMapper), typeof(ImageGrayscaleTransform), null, typeof(SignatureLoadRowMapper), ImageGrayscaleTransform.UserName, ImageGrayscaleTransform.LoaderSignature)] namespace Microsoft.ML.Runtime.ImageAnalytics @@ -28,7 +37,7 @@ namespace Microsoft.ML.Runtime.ImageAnalytics /// Transform which takes one or many columns of type in IDataView and /// convert them to greyscale representation of the same image. 
/// - public sealed class ImageGrayscaleTransform : OneToOneTransformBase + public sealed class ImageGrayscaleTransform : OneToOneTransformerBase { public sealed class Column : OneToOneColumn { @@ -69,50 +78,57 @@ private static VersionInfo GetVersionInfo() private const string RegistrationName = "ImageGrayscale"; - // Public constructor corresponding to SignatureDataTransform. - public ImageGrayscaleTransform(IHostEnvironment env, Arguments args, IDataView input) - : base(env, RegistrationName, env.CheckRef(args, nameof(args)).Column, input, t => t is ImageType ? null : "Expected Image type") + public IReadOnlyCollection<(string input, string output)> Columns => ColumnPairs.AsReadOnly(); + + public ImageGrayscaleTransform(IHostEnvironment env, params (string input, string output)[] columns) + : base(Contracts.CheckRef(env, nameof(env)).Register(RegistrationName), columns) { - Host.AssertNonEmpty(Infos); - Host.Assert(Infos.Length == Utils.Size(args.Column)); - Metadata.Seal(); } - private ImageGrayscaleTransform(IHost host, ModelLoadContext ctx, IDataView input) - : base(host, ctx, input, t => t is ImageType ? null : "Expected Image type") + // Factory method for SignatureDataTransform. + public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input) { - Host.AssertValue(ctx); - // *** Binary format *** - // - Host.AssertNonEmpty(Infos); - Metadata.Seal(); + Contracts.CheckValue(env, nameof(env)); + env.CheckValue(args, nameof(args)); + env.CheckValue(input, nameof(input)); + env.CheckValue(args.Column, nameof(args.Column)); + + return new ImageGrayscaleTransform(env, args.Column.Select(x => (x.Source ?? 
x.Name, x.Name)).ToArray()) + .MakeDataTransform(input); } - public static ImageGrayscaleTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input) + public static ImageGrayscaleTransform Create(IHostEnvironment env, ModelLoadContext ctx) { Contracts.CheckValue(env, nameof(env)); - var h = env.Register(RegistrationName); - h.CheckValue(ctx, nameof(ctx)); - h.CheckValue(input, nameof(input)); + var host = env.Register(RegistrationName); + host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(GetVersionInfo()); - return h.Apply("Loading Model", ch => new ImageGrayscaleTransform(h, ctx, input)); + return new ImageGrayscaleTransform(host, ctx); } + private ImageGrayscaleTransform(IHost host, ModelLoadContext ctx) + : base(host, ctx) + { + } + + // Factory method for SignatureLoadDataTransform. + public static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input) + => Create(env, ctx).MakeDataTransform(input); + + // Factory method for SignatureLoadRowMapper. 
+ public static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, ISchema inputSchema) + => Create(env, ctx).MakeRowMapper(inputSchema); + public override void Save(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); + ctx.CheckAtModel(); ctx.SetVersionInfo(GetVersionInfo()); // *** Binary format *** // - SaveBase(ctx); - } - - protected override ColumnType GetColumnTypeCore(int iinfo) - { - Host.Assert(0 <= iinfo & iinfo < Infos.Length); - return Infos[iinfo].TypeSrc; + base.SaveColumns(ctx); } private static readonly ColorMatrix _grayscaleColorMatrix = new ColorMatrix( @@ -125,47 +141,96 @@ protected override ColumnType GetColumnTypeCore(int iinfo) new float[] {0, 0, 0, 0, 1} }); - protected override Delegate GetGetterCore(IChannel ch, IRow input, int iinfo, out Action disposer) + protected override IRowMapper MakeRowMapper(ISchema schema) + => new Mapper(this, schema); + + protected override void CheckInputColumn(ISchema inputSchema, int col, int srcCol) { - Host.AssertValueOrNull(ch); - Host.AssertValue(input); - Host.Assert(0 <= iinfo && iinfo < Infos.Length); + if (!(inputSchema.GetColumnType(srcCol) is ImageType)) + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", ColumnPairs[col].input, "image", inputSchema.GetColumnType(srcCol).ToString()); + } - var src = default(Bitmap); - var getSrc = GetSrcGetter(input, iinfo); + private sealed class Mapper : MapperBase + { + private ImageGrayscaleTransform _parent; - disposer = - () => - { - if (src != null) - { - src.Dispose(); - src = null; - } - }; + public Mapper(ImageGrayscaleTransform parent, ISchema inputSchema) + :base(parent.Host.Register(nameof(Mapper)), parent, inputSchema) + { + _parent = parent; + } - ValueGetter del = - (ref Bitmap dst) => - { - if (dst != null) - dst.Dispose(); - - getSrc(ref src); - if (src == null || src.Height <= 0 || src.Width <= 0) - return; - - dst = new Bitmap(src.Width, src.Height); - ImageAttributes attributes = new ImageAttributes(); - 
attributes.SetColorMatrix(_grayscaleColorMatrix); - var srcRectangle = new Rectangle(0, 0, src.Width, src.Height); - using (var g = Graphics.FromImage(dst)) + public override RowMapperColumnInfo[] GetOutputColumns() + => _parent.ColumnPairs.Select((x, idx) => new RowMapperColumnInfo(x.output, InputSchema.GetColumnType(ColMapNewToOld[idx]), null)).ToArray(); + + protected override Delegate MakeGetter(IRow input, int iinfo, out Action disposer) + { + Contracts.AssertValue(input); + Contracts.Assert(0 <= iinfo && iinfo < _parent.ColumnPairs.Length); + + var src = default(Bitmap); + var getSrc = input.GetGetter(ColMapNewToOld[iinfo]); + + disposer = + () => + { + if (src != null) + { + src.Dispose(); + src = null; + } + }; + + ValueGetter del = + (ref Bitmap dst) => { - g.DrawImage(src, srcRectangle, 0, 0, src.Width, src.Height, GraphicsUnit.Pixel, attributes); - } - Host.Assert(dst.Width == src.Width && dst.Height == src.Height); - }; + if (dst != null) + dst.Dispose(); + + getSrc(ref src); + if (src == null || src.Height <= 0 || src.Width <= 0) + return; + + dst = new Bitmap(src.Width, src.Height); + ImageAttributes attributes = new ImageAttributes(); + attributes.SetColorMatrix(_grayscaleColorMatrix); + var srcRectangle = new Rectangle(0, 0, src.Width, src.Height); + using (var g = Graphics.FromImage(dst)) + { + g.DrawImage(src, srcRectangle, 0, 0, src.Width, src.Height, GraphicsUnit.Pixel, attributes); + } + Contracts.Assert(dst.Width == src.Width && dst.Height == src.Height); + }; + + return del; + } + } + } + + public sealed class ImageGrayscaleEstimator : TrivialEstimator + { + public ImageGrayscaleEstimator(IHostEnvironment env, params (string input, string output)[] columns) + : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(ImageGrayscaleEstimator)), new ImageGrayscaleTransform(env, columns)) + { + } + + public override SchemaShape GetOutputSchema(SchemaShape inputSchema) + { + Host.CheckValue(inputSchema, nameof(inputSchema)); + var result = 
inputSchema.Columns.ToDictionary(x => x.Name); + foreach (var colInfo in Transformer.Columns) + { + var col = inputSchema.FindColumn(colInfo.input); + + if (col == null) + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.input); + if (!(col.ItemType is ImageType) || col.Kind != SchemaShape.Column.VectorKind.Scalar) + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.input, new ImageType().ToString(), col.GetTypeString()); + + result[colInfo.output] = new SchemaShape.Column(colInfo.output, col.Kind, col.ItemType, col.IsKey, col.MetadataKinds); + } - return del; + return new SchemaShape(result.Values); } } } diff --git a/src/Microsoft.ML.ImageAnalytics/ImageLoaderTransform.cs b/src/Microsoft.ML.ImageAnalytics/ImageLoaderTransform.cs index 488c710743..20e1476feb 100644 --- a/src/Microsoft.ML.ImageAnalytics/ImageLoaderTransform.cs +++ b/src/Microsoft.ML.ImageAnalytics/ImageLoaderTransform.cs @@ -2,31 +2,37 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
-using System; -using System.Drawing; -using System.IO; -using System.Text; -using Microsoft.ML.Runtime.ImageAnalytics; +using Microsoft.ML.Core.Data; using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.CommandLine; using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.EntryPoints; +using Microsoft.ML.Runtime.ImageAnalytics; using Microsoft.ML.Runtime.Internal.Utilities; using Microsoft.ML.Runtime.Model; +using System; +using System.Collections.Generic; +using System.Drawing; +using System.IO; +using System.Linq; +using System.Text; -[assembly: LoadableClass(ImageLoaderTransform.Summary, typeof(ImageLoaderTransform), typeof(ImageLoaderTransform.Arguments), typeof(SignatureDataTransform), +[assembly: LoadableClass(ImageLoaderTransform.Summary, typeof(IDataTransform), typeof(ImageLoaderTransform), typeof(ImageLoaderTransform.Arguments), typeof(SignatureDataTransform), ImageLoaderTransform.UserName, "ImageLoaderTransform", "ImageLoader")] -[assembly: LoadableClass(ImageLoaderTransform.Summary, typeof(ImageLoaderTransform), null, typeof(SignatureLoadDataTransform), +[assembly: LoadableClass(ImageLoaderTransform.Summary, typeof(IDataTransform), typeof(ImageLoaderTransform), null, typeof(SignatureLoadDataTransform), ImageLoaderTransform.UserName, ImageLoaderTransform.LoaderSignature)] +[assembly: LoadableClass(typeof(ImageLoaderTransform), null, typeof(SignatureLoadModel), "", ImageLoaderTransform.LoaderSignature)] + +[assembly: LoadableClass(typeof(IRowMapper), typeof(ImageLoaderTransform), null, typeof(SignatureLoadRowMapper), "", ImageLoaderTransform.LoaderSignature)] + namespace Microsoft.ML.Runtime.ImageAnalytics { - // REVIEW: Rewrite as LambdaTransform to simplify. 
/// /// Transform which takes one or many columns of type and loads them as /// - public sealed class ImageLoaderTransform : OneToOneTransformBase + public sealed class ImageLoaderTransform : OneToOneTransformerBase { public sealed class Column : OneToOneColumn { @@ -61,118 +67,177 @@ public sealed class Arguments : TransformInputBase internal const string UserName = "Image Loader Transform"; public const string LoaderSignature = "ImageLoaderTransform"; - private static VersionInfo GetVersionInfo() - { - return new VersionInfo( - modelSignature: "IMGLOADT", - //verWrittenCur: 0x00010001, // Initial - verWrittenCur: 0x00010002, // Swith from OpenCV to Bitmap - verReadableCur: 0x00010002, - verWeCanReadBack: 0x00010002, - loaderSignature: LoaderSignature); - } + public readonly string ImageFolder; - private readonly ImageType _type; - private readonly string _imageFolder; + public IReadOnlyCollection<(string input, string output)> Columns => ColumnPairs.AsReadOnly(); - private const string RegistrationName = "ImageLoader"; + public ImageLoaderTransform(IHostEnvironment env, string imageFolder, params (string input, string output)[] columns) + : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(ImageLoaderTransform)), columns) + { + ImageFolder = imageFolder; + } - // Public constructor corresponding to SignatureDataTransform. - public ImageLoaderTransform(IHostEnvironment env, Arguments args, IDataView input) - : base(env, RegistrationName, env.CheckRef(args, nameof(args)).Column, input, TestIsText) + public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView data) { - Host.AssertNonEmpty(Infos); - _imageFolder = args.ImageFolder; - Host.Assert(Infos.Length == Utils.Size(args.Column)); - _type = new ImageType(); - Metadata.Seal(); + return new ImageLoaderTransform(env, args.ImageFolder, args.Column.Select(x => (x.Source ?? 
x.Name, x.Name)).ToArray()) + .MakeDataTransform(data); } - private ImageLoaderTransform(IHost host, ModelLoadContext ctx, IDataView input) - : base(host, ctx, input, TestIsText) + public static ImageLoaderTransform Create(IHostEnvironment env, ModelLoadContext ctx) { - Host.AssertValue(ctx); + Contracts.CheckValue(env, nameof(env)); + env.CheckValue(ctx, nameof(ctx)); + + ctx.CheckAtModel(GetVersionInfo()); + return new ImageLoaderTransform(env.Register(nameof(ImageLoaderTransform)), ctx); + } + private ImageLoaderTransform(IHost host, ModelLoadContext ctx) + : base(host, ctx) + { // *** Binary format *** // - _imageFolder = ctx.Reader.ReadString(); - _type = new ImageType(); - Metadata.Seal(); + // int: id of image folder + + ImageFolder = ctx.LoadStringOrNull(); } - public static ImageLoaderTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input) + // Factory method for SignatureLoadDataTransform. + public static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input) + => Create(env, ctx).MakeDataTransform(input); + + // Factory method for SignatureLoadRowMapper. 
+ public static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, ISchema inputSchema) + => Create(env, ctx).MakeRowMapper(inputSchema); + + protected override void CheckInputColumn(ISchema inputSchema, int col, int srcCol) { - Contracts.CheckValue(env, nameof(env)); - var h = env.Register(RegistrationName); - h.CheckValue(ctx, nameof(ctx)); - h.CheckValue(input, nameof(input)); - ctx.CheckAtModel(GetVersionInfo()); - return h.Apply("Loading Model", ch => new ImageLoaderTransform(h, ctx, input)); + if (!inputSchema.GetColumnType(srcCol).IsText) + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", ColumnPairs[col].input, TextType.Instance.ToString(), inputSchema.GetColumnType(srcCol).ToString()); } public override void Save(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); + ctx.CheckAtModel(); ctx.SetVersionInfo(GetVersionInfo()); // *** Binary format *** // - ctx.Writer.Write(_imageFolder); - SaveBase(ctx); + // int: id of image folder + + base.SaveColumns(ctx); + ctx.SaveStringOrNull(ImageFolder); } - protected override ColumnType GetColumnTypeCore(int iinfo) + private static VersionInfo GetVersionInfo() { - Host.Check(0 <= iinfo && iinfo < Infos.Length); - return _type; + return new VersionInfo( + modelSignature: "IMGLOADR", + //verWrittenCur: 0x00010001, // Initial + verWrittenCur: 0x00010002, // Swith from OpenCV to Bitmap + verReadableCur: 0x00010002, + verWeCanReadBack: 0x00010002, + loaderSignature: LoaderSignature); } - protected override Delegate GetGetterCore(IChannel ch, IRow input, int iinfo, out Action disposer) + protected override IRowMapper MakeRowMapper(ISchema schema) + => new Mapper(this, schema); + + private sealed class Mapper : MapperBase { - Host.AssertValue(ch, nameof(ch)); - Host.AssertValue(input); - Host.Assert(0 <= iinfo && iinfo < Infos.Length); - disposer = null; - - var getSrc = GetSrcGetter(input, iinfo); - DvText src = default; - ValueGetter del = - (ref Bitmap dst) => - { - if (dst != null) - { - 
dst.Dispose(); - dst = null; - } + private readonly ImageLoaderTransform _parent; + private readonly ImageType _imageType; - getSrc(ref src); + public Mapper(ImageLoaderTransform parent, ISchema inputSchema) + : base(parent.Host.Register(nameof(Mapper)), parent, inputSchema) + { + _imageType = new ImageType(); + _parent = parent; + } - if (src.Length > 0) + protected override Delegate MakeGetter(IRow input, int iinfo, out Action disposer) + { + Contracts.AssertValue(input); + Contracts.Assert(0 <= iinfo && iinfo < _parent.ColumnPairs.Length); + + disposer = null; + var getSrc = input.GetGetter(ColMapNewToOld[iinfo]); + DvText src = default; + ValueGetter del = + (ref Bitmap dst) => { - // Catch exceptions and pass null through. Should also log failures... - try + if (dst != null) { - string path = src.ToString(); - if (!string.IsNullOrWhiteSpace(_imageFolder)) - path = Path.Combine(_imageFolder, path); - dst = new Bitmap(path); + dst.Dispose(); + dst = null; } - catch (Exception e) + + getSrc(ref src); + + if (src.Length > 0) { - // REVIEW: We catch everything since the documentation for new Bitmap(string) - // appears to be incorrect. When the file isn't found, it throws an ArgumentException, - // while the documentation says FileNotFoundException. Not sure what it will throw - // in other cases, like corrupted file, etc. - - // REVIEW : Log failures. - ch.Info(e.Message); - ch.Info(e.StackTrace); - dst = null; + // Catch exceptions and pass null through. Should also log failures... + try + { + string path = src.ToString(); + if (!string.IsNullOrWhiteSpace(_parent.ImageFolder)) + path = Path.Combine(_parent.ImageFolder, path); + dst = new Bitmap(path); + } + catch (Exception) + { + // REVIEW: We catch everything since the documentation for new Bitmap(string) + // appears to be incorrect. When the file isn't found, it throws an ArgumentException, + // while the documentation says FileNotFoundException. 
Not sure what it will throw + // in other cases, like corrupted file, etc. + + // REVIEW : Log failures. + dst = null; + } } - } - }; - return del; + }; + return del; + } + + public override RowMapperColumnInfo[] GetOutputColumns() + => _parent.ColumnPairs.Select(x => new RowMapperColumnInfo(x.output, _imageType, null)).ToArray(); + } + } + + public sealed class ImageLoaderEstimator : TrivialEstimator + { + private readonly ImageType _imageType; + + public ImageLoaderEstimator(IHostEnvironment env, string imageFolder, params (string input, string output)[] columns) + : this(env, new ImageLoaderTransform(env, imageFolder, columns)) + { + } + + public ImageLoaderEstimator(IHostEnvironment env, ImageLoaderTransform transformer) + : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(ImageLoaderEstimator)), transformer) + { + _imageType = new ImageType(); + } + + public override SchemaShape GetOutputSchema(SchemaShape inputSchema) + { + Host.CheckValue(inputSchema, nameof(inputSchema)); + var result = inputSchema.Columns.ToDictionary(x => x.Name); + foreach (var (input, output) in Transformer.Columns) + { + var col = inputSchema.FindColumn(input); + + if (col == null) + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", input); + if (!col.ItemType.IsText || col.Kind != SchemaShape.Column.VectorKind.Scalar) + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", input, TextType.Instance.ToString(), col.GetTypeString()); + + result[output] = new SchemaShape.Column(output, SchemaShape.Column.VectorKind.Scalar, _imageType, false); + } + + return new SchemaShape(result.Values); } } } diff --git a/src/Microsoft.ML.ImageAnalytics/ImagePixelExtractorTransform.cs b/src/Microsoft.ML.ImageAnalytics/ImagePixelExtractorTransform.cs index de0aa98124..0bd5bf7879 100644 --- a/src/Microsoft.ML.ImageAnalytics/ImagePixelExtractorTransform.cs +++ b/src/Microsoft.ML.ImageAnalytics/ImagePixelExtractorTransform.cs @@ -3,8 +3,11 @@ // See the LICENSE file in the 
project root for more information. using System; +using System.Collections.Generic; using System.Drawing; +using System.Linq; using System.Text; +using Microsoft.ML.Core.Data; using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.CommandLine; using Microsoft.ML.Runtime.Data; @@ -13,19 +16,24 @@ using Microsoft.ML.Runtime.Internal.Utilities; using Microsoft.ML.Runtime.Model; -[assembly: LoadableClass(ImagePixelExtractorTransform.Summary, typeof(ImagePixelExtractorTransform), typeof(ImagePixelExtractorTransform.Arguments), typeof(SignatureDataTransform), +[assembly: LoadableClass(ImagePixelExtractorTransform.Summary, typeof(IDataTransform), typeof(ImagePixelExtractorTransform), typeof(ImagePixelExtractorTransform.Arguments), typeof(SignatureDataTransform), ImagePixelExtractorTransform.UserName, "ImagePixelExtractorTransform", "ImagePixelExtractor")] -[assembly: LoadableClass(ImagePixelExtractorTransform.Summary, typeof(ImagePixelExtractorTransform), null, typeof(SignatureLoadDataTransform), +[assembly: LoadableClass(ImagePixelExtractorTransform.Summary, typeof(IDataTransform), typeof(ImagePixelExtractorTransform), null, typeof(SignatureLoadDataTransform), + ImagePixelExtractorTransform.UserName, ImagePixelExtractorTransform.LoaderSignature)] + +[assembly: LoadableClass(typeof(ImagePixelExtractorTransform), null, typeof(SignatureLoadModel), + ImagePixelExtractorTransform.UserName, ImagePixelExtractorTransform.LoaderSignature)] + +[assembly: LoadableClass(typeof(IRowMapper), typeof(ImagePixelExtractorTransform), null, typeof(SignatureLoadRowMapper), ImagePixelExtractorTransform.UserName, ImagePixelExtractorTransform.LoaderSignature)] namespace Microsoft.ML.Runtime.ImageAnalytics { - // REVIEW: Rewrite as LambdaTransform to simplify. /// /// Transform which takes one or many columns of and convert them into vector representation. 
/// - public sealed class ImagePixelExtractorTransform : OneToOneTransformBase + public sealed class ImagePixelExtractorTransform : OneToOneTransformerBase { public class Column : OneToOneColumn { @@ -110,24 +118,28 @@ public class Arguments : TransformInputBase /// Which color channels are extracted. Note that these values are serialized so should not be modified. /// [Flags] - private enum ColorBits : byte + public enum ColorBits : byte { Alpha = 0x01, Red = 0x02, Green = 0x04, Blue = 0x08, + Rgb = Red | Green | Blue, All = Alpha | Red | Green | Blue } - private sealed class ColInfoEx + public sealed class ColumnInfo { + public readonly string Input; + public readonly string Output; + public readonly ColorBits Colors; public readonly byte Planes; public readonly bool Convert; - public readonly Single Offset; - public readonly Single Scale; + public readonly float Offset; + public readonly float Scale; public readonly bool Interleave; public bool Alpha { get { return (Colors & ColorBits.Alpha) != 0; } } @@ -135,8 +147,14 @@ private sealed class ColInfoEx public bool Green { get { return (Colors & ColorBits.Green) != 0; } } public bool Blue { get { return (Colors & ColorBits.Blue) != 0; } } - public ColInfoEx(Column item, Arguments args) + internal ColumnInfo(Column item, Arguments args) { + Contracts.CheckValue(item, nameof(item)); + Contracts.CheckValue(args, nameof(args)); + + Input = item.Source ?? item.Name; + Output = item.Name; + if (item.UseAlpha ?? args.UseAlpha) { Colors |= ColorBits.Alpha; Planes++; } if (item.UseRed ?? args.UseRed) { Colors |= ColorBits.Red; Planes++; } if (item.UseGreen ?? 
args.UseGreen) { Colors |= ColorBits.Green; Planes++; } @@ -160,10 +178,57 @@ public ColInfoEx(Column item, Arguments args) } } - public ColInfoEx(ModelLoadContext ctx) + public ColumnInfo(string input, string output, ColorBits colors = ColorBits.Rgb, bool interleave = false) + : this(input, output, colors, interleave, false, 1f, 0f) + { + } + + public ColumnInfo(string input, string output, ColorBits colors = ColorBits.Rgb, bool interleave = false, float scale = 1f, float offset = 0f) + : this(input, output, colors, interleave, true, scale, offset) + { + } + + private ColumnInfo(string input, string output, ColorBits colors, bool interleave, bool convert, float scale, float offset) + { + Contracts.CheckNonEmpty(input, nameof(input)); + Contracts.CheckNonEmpty(output, nameof(output)); + + Input = input; + Output = output; + Colors = colors; + + if ((Colors & ColorBits.Alpha) == ColorBits.Alpha) Planes++; + if ((Colors & ColorBits.Red) == ColorBits.Red) Planes++; + if ((Colors & ColorBits.Green) == ColorBits.Green) Planes++; + if ((Colors & ColorBits.Blue) == ColorBits.Blue) Planes++; + Contracts.CheckParam(Planes > 0, nameof(colors), "Need to use at least one color plane"); + + Interleave = interleave; + + Convert = convert; + if (!Convert) + { + Offset = 0; + Scale = 1; + } + else + { + Offset = offset; + Scale = scale; + Contracts.CheckParam(FloatUtils.IsFinite(Offset), nameof(offset)); + Contracts.CheckParam(FloatUtils.IsFiniteNonZero(Scale), nameof(scale)); + } + } + + internal ColumnInfo(string input, string output, ModelLoadContext ctx) { + Contracts.AssertNonEmpty(input); + Contracts.AssertNonEmpty(output); Contracts.AssertValue(ctx); + Input = input; + Output = output; + // *** Binary format *** // byte: colors // byte: convert @@ -193,7 +258,6 @@ public ColInfoEx(ModelLoadContext ctx) public void Save(ModelSaveContext ctx) { Contracts.AssertValue(ctx); - #if DEBUG // This code is used in deserialization - assert that it matches what we computed above. 
int planes = (int)Colors; @@ -237,305 +301,368 @@ private static VersionInfo GetVersionInfo() private const string RegistrationName = "ImagePixelExtractor"; - private readonly ColInfoEx[] _exes; - private readonly VectorType[] _types; + private readonly ColumnInfo[] _columns; + + public IReadOnlyCollection Columns => _columns.AsReadOnly(); + + public ImagePixelExtractorTransform(IHostEnvironment env, string inputColumn, string outputColumn, + ColorBits colors = ColorBits.Rgb, bool interleave = false) + : this(env, new ColumnInfo(inputColumn, outputColumn, colors, interleave)) + { + } + + public ImagePixelExtractorTransform(IHostEnvironment env, params ColumnInfo[] columns) + : base(Contracts.CheckRef(env, nameof(env)).Register(RegistrationName), GetColumnPairs(columns)) + { + _columns = columns.ToArray(); + } - // Public constructor corresponding to SignatureDataTransform. - public ImagePixelExtractorTransform(IHostEnvironment env, Arguments args, IDataView input) - : base(env, RegistrationName, Contracts.CheckRef(args, nameof(args)).Column, input, - t => t is ImageType ? null : "Expected Image type") + private static (string input, string output)[] GetColumnPairs(ColumnInfo[] columns) { - Host.AssertNonEmpty(Infos); - Host.Assert(Infos.Length == Utils.Size(args.Column)); + Contracts.CheckValue(columns, nameof(columns)); + return columns.Select(x => (x.Input, x.Output)).ToArray(); + } - _exes = new ColInfoEx[Infos.Length]; - for (int i = 0; i < _exes.Length; i++) + // SignatureDataTransform. 
+ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input) + { + Contracts.CheckValue(env, nameof(env)); + env.CheckValue(args, nameof(args)); + env.CheckValue(input, nameof(input)); + + env.CheckValue(args.Column, nameof(args.Column)); + + var columns = new ColumnInfo[args.Column.Length]; + for (int i = 0; i < columns.Length; i++) { var item = args.Column[i]; - _exes[i] = new ColInfoEx(item, args); + columns[i] = new ColumnInfo(item, args); } - _types = ConstructTypes(true); + var transformer = new ImagePixelExtractorTransform(env, columns); + return new RowToRowMapperTransform(env, input, transformer.MakeRowMapper(input.Schema)); } - private ImagePixelExtractorTransform(IHost host, ModelLoadContext ctx, IDataView input) - : base(host, ctx, input, t => t is ImageType ? null : "Expected Image type") + public static ImagePixelExtractorTransform Create(IHostEnvironment env, ModelLoadContext ctx) { - Host.AssertValue(ctx); + Contracts.CheckValue(env, nameof(env)); + var host = env.Register(RegistrationName); + host.CheckValue(ctx, nameof(ctx)); + ctx.CheckAtModel(GetVersionInfo()); - // *** Binary format *** - // - // - // foreach added column - // ColInfoEx - Host.AssertNonEmpty(Infos); - _exes = new ColInfoEx[Infos.Length]; - for (int i = 0; i < _exes.Length; i++) - _exes[i] = new ColInfoEx(ctx); - - _types = ConstructTypes(false); + return new ImagePixelExtractorTransform(host, ctx); } - public static ImagePixelExtractorTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input) + private ImagePixelExtractorTransform(IHost host, ModelLoadContext ctx) + : base(host, ctx) { - Contracts.CheckValue(env, nameof(env)); - var h = env.Register(RegistrationName); - h.CheckValue(ctx, nameof(ctx)); - h.CheckValue(input, nameof(input)); - ctx.CheckAtModel(GetVersionInfo()); + // *** Binary format *** + // - return h.Apply("Loading Model", - ch => - { - // *** Binary format *** - // int: sizeof(Float) - // - int cbFloat = 
ctx.Reader.ReadInt32(); - ch.CheckDecode(cbFloat == sizeof(Single)); - return new ImagePixelExtractorTransform(h, ctx, input); - }); + // for each added column + // ColumnInfo + + _columns = new ColumnInfo[ColumnPairs.Length]; + for (int i = 0; i < _columns.Length; i++) + _columns[i] = new ColumnInfo(ColumnPairs[i].input, ColumnPairs[i].output, ctx); } + // Factory method for SignatureLoadDataTransform. + public static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input) + => Create(env, ctx).MakeDataTransform(input); + + // Factory method for SignatureLoadRowMapper. + public static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, ISchema inputSchema) + => Create(env, ctx).MakeRowMapper(inputSchema); + public override void Save(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); + ctx.CheckAtModel(); ctx.SetVersionInfo(GetVersionInfo()); // *** Binary format *** - // int: sizeof(Float) // - // foreach added column - // ColInfoEx - ctx.Writer.Write(sizeof(Single)); - SaveBase(ctx); - - Host.Assert(_exes.Length == Infos.Length); - for (int i = 0; i < _exes.Length; i++) - _exes[i].Save(ctx); - } - private VectorType[] ConstructTypes(bool user) - { - var types = new VectorType[Infos.Length]; - for (int i = 0; i < Infos.Length; i++) - { - var info = Infos[i]; - var ex = _exes[i]; - Host.Assert(ex.Planes > 0); + // for each added column + // ColumnInfo - var type = Source.Schema.GetColumnType(info.Source) as ImageType; - Host.Assert(type != null); - if (type.Height <= 0 || type.Width <= 0) - { - // REVIEW: Could support this case by making the destination column be variable sized. - // However, there's no mechanism to communicate the dimensions through with the pixel data. - string name = Source.Schema.GetColumnName(info.Source); - throw user ? 
- Host.ExceptUserArg(nameof(Arguments.Column), "Column '{0}' does not have known size", name) : - Host.Except("Column '{0}' does not have known size", name); - } - int height = type.Height; - int width = type.Width; - Host.Assert(height > 0); - Host.Assert(width > 0); - Host.Assert((long)height * width <= int.MaxValue / 4); - - if (ex.Interleave) - types[i] = new VectorType(ex.Convert ? NumberType.Float : NumberType.U1, height, width, ex.Planes); - else - types[i] = new VectorType(ex.Convert ? NumberType.Float : NumberType.U1, ex.Planes, height, width); - } - Metadata.Seal(); - return types; + base.SaveColumns(ctx); + + foreach (ColumnInfo info in _columns) + info.Save(ctx); } - protected override ColumnType GetColumnTypeCore(int iinfo) + protected override IRowMapper MakeRowMapper(ISchema schema) + => new Mapper(this, schema); + + protected override void CheckInputColumn(ISchema inputSchema, int col, int srcCol) { - Host.Assert(0 <= iinfo & iinfo < Infos.Length); - return _types[iinfo]; + var inputColName = _columns[col].Input; + var imageType = inputSchema.GetColumnType(srcCol) as ImageType; + if (imageType == null) + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", inputColName, "image", inputSchema.GetColumnType(srcCol).ToString()); + if (imageType.Height <= 0 || imageType.Width <= 0) + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", inputColName, "known-size image", "unknown-size image"); + if ((long)imageType.Height * imageType.Width > int.MaxValue / 4) + throw Host.Except("Image dimensions are too large"); } - protected override Delegate GetGetterCore(IChannel ch, IRow input, int iinfo, out Action disposer) + private sealed class Mapper : MapperBase { - Host.AssertValueOrNull(ch); - Host.AssertValue(input); - Host.Assert(0 <= iinfo && iinfo < Infos.Length); + private readonly ImagePixelExtractorTransform _parent; + private readonly VectorType[] _types; - if (_exes[iinfo].Convert) - return GetGetterCore(input, iinfo, out disposer); 
- return GetGetterCore(input, iinfo, out disposer); - } + public Mapper(ImagePixelExtractorTransform parent, ISchema inputSchema) + : base(parent.Host.Register(nameof(Mapper)), parent, inputSchema) + { + _parent = parent; + _types = ConstructTypes(); + } - //REVIEW Rewrite it to where TValue : IConvertible - private ValueGetter> GetGetterCore(IRow input, int iinfo, out Action disposer) - { - var type = _types[iinfo]; - Host.Assert(type.DimCount == 3); + public override RowMapperColumnInfo[] GetOutputColumns() + => _parent._columns.Select((x, idx) => new RowMapperColumnInfo(x.Output, _types[idx], null)).ToArray(); - var ex = _exes[iinfo]; + protected override Delegate MakeGetter(IRow input, int iinfo, out Action disposer) + { + Contracts.AssertValue(input); + Contracts.Assert(0 <= iinfo && iinfo < _parent._columns.Length); - int planes = ex.Interleave ? type.GetDim(2) : type.GetDim(0); - int height = ex.Interleave ? type.GetDim(0) : type.GetDim(1); - int width = ex.Interleave ? type.GetDim(1) : type.GetDim(2); + if (_parent._columns[iinfo].Convert) + return GetGetterCore(input, iinfo, out disposer); + return GetGetterCore(input, iinfo, out disposer); + } - int size = type.ValueCount; - Host.Assert(size > 0); - Host.Assert(size == planes * height * width); - int cpix = height * width; + //REVIEW Rewrite it to where TValue : IConvertible + private ValueGetter> GetGetterCore(IRow input, int iinfo, out Action disposer) + { + var type = _types[iinfo]; + Contracts.Assert(type.DimCount == 3); - var getSrc = GetSrcGetter(input, iinfo); - var src = default(Bitmap); + var ex = _parent._columns[iinfo]; - disposer = - () => - { - if (src != null) - { - src.Dispose(); - src = null; - } - }; + int planes = ex.Interleave ? type.GetDim(2) : type.GetDim(0); + int height = ex.Interleave ? type.GetDim(0) : type.GetDim(1); + int width = ex.Interleave ? 
type.GetDim(1) : type.GetDim(2); - return - (ref VBuffer dst) => - { - getSrc(ref src); - Contracts.AssertValueOrNull(src); + int size = type.ValueCount; + Contracts.Assert(size > 0); + Contracts.Assert(size == planes * height * width); + int cpix = height * width; + + var getSrc = input.GetGetter(ColMapNewToOld[iinfo]); + var src = default(Bitmap); - if (src == null) + disposer = + () => { - dst = new VBuffer(size, 0, dst.Values, dst.Indices); - return; - } + if (src != null) + { + src.Dispose(); + src = null; + } + }; - Host.Check(src.PixelFormat == System.Drawing.Imaging.PixelFormat.Format32bppArgb); - Host.Check(src.Height == height && src.Width == width); + return + (ref VBuffer dst) => + { + getSrc(ref src); + Contracts.AssertValueOrNull(src); - var values = dst.Values; - if (Utils.Size(values) < size) - values = new TValue[size]; + if (src == null) + { + dst = new VBuffer(size, 0, dst.Values, dst.Indices); + return; + } - Single offset = ex.Offset; - Single scale = ex.Scale; - Host.Assert(scale != 0); + Host.Check(src.PixelFormat == System.Drawing.Imaging.PixelFormat.Format32bppArgb); + Host.Check(src.Height == height && src.Width == width); - var vf = values as Single[]; - var vb = values as byte[]; - Host.Assert(vf != null || vb != null); - bool needScale = offset != 0 || scale != 1; - Host.Assert(!needScale || vf != null); + var values = dst.Values; + if (Utils.Size(values) < size) + values = new TValue[size]; - bool a = ex.Alpha; - bool r = ex.Red; - bool g = ex.Green; - bool b = ex.Blue; + float offset = ex.Offset; + float scale = ex.Scale; + Contracts.Assert(scale != 0); - int h = height; - int w = width; + var vf = values as float[]; + var vb = values as byte[]; + Contracts.Assert(vf != null || vb != null); + bool needScale = offset != 0 || scale != 1; + Contracts.Assert(!needScale || vf != null); - if (ex.Interleave) - { - int idst = 0; - for (int y = 0; y < h; ++y) - for (int x = 0; x < w; x++) - { - var pb = src.GetPixel(y, x); - if (vb != null) + 
bool a = ex.Alpha; + bool r = ex.Red; + bool g = ex.Green; + bool b = ex.Blue; + + int h = height; + int w = width; + + if (ex.Interleave) + { + int idst = 0; + for (int y = 0; y < h; ++y) + for (int x = 0; x < w; x++) { - if (a) { vb[idst++] = (byte)0; } - if (r) { vb[idst++] = pb.R; } - if (g) { vb[idst++] = pb.G; } - if (b) { vb[idst++] = pb.B; } + var pb = src.GetPixel(y, x); + if (vb != null) + { + if (a) { vb[idst++] = (byte)0; } + if (r) { vb[idst++] = pb.R; } + if (g) { vb[idst++] = pb.G; } + if (b) { vb[idst++] = pb.B; } + } + else if (!needScale) + { + if (a) { vf[idst++] = 0.0f; } + if (r) { vf[idst++] = pb.R; } + if (g) { vf[idst++] = pb.G; } + if (b) { vf[idst++] = pb.B; } + } + else + { + if (a) { vf[idst++] = 0.0f; } + if (r) { vf[idst++] = (pb.R - offset) * scale; } + if (g) { vf[idst++] = (pb.B - offset) * scale; } + if (b) { vf[idst++] = (pb.G - offset) * scale; } + } } - else if (!needScale) + Contracts.Assert(idst == size); + } + else + { + int idstMin = 0; + if (ex.Alpha) + { + // The image only has rgb but we need to supply alpha as well, so fake it up, + // assuming that it is 0xFF. + if (vf != null) { - if (a) { vf[idst++] = 0.0f; } - if (r) { vf[idst++] = pb.R; } - if (g) { vf[idst++] = pb.G; } - if (b) { vf[idst++] = pb.B; } + Single v = (0xFF - offset) * scale; + for (int i = 0; i < cpix; i++) + vf[i] = v; } else { - if (a) { vf[idst++] = 0.0f; } - if (r) { vf[idst++] = (pb.R - offset) * scale; } - if (g) { vf[idst++] = (pb.B - offset) * scale; } - if (b) { vf[idst++] = (pb.G - offset) * scale; } + for (int i = 0; i < cpix; i++) + vb[i] = 0xFF; } - } - Host.Assert(idst == size); - } - else - { - int idstMin = 0; - if (ex.Alpha) - { - // The image only has rgb but we need to supply alpha as well, so fake it up, - // assuming that it is 0xFF. 
- if (vf != null) - { - Single v = (0xFF - offset) * scale; - for (int i = 0; i < cpix; i++) - vf[i] = v; - } - else - { - for (int i = 0; i < cpix; i++) - vb[i] = 0xFF; - } - idstMin = cpix; + idstMin = cpix; - // We've preprocessed alpha, avoid it in the - // scan operation below. - a = false; - } - - for (int y = 0; y < h; ++y) - { - int idstBase = idstMin + y * w; + // We've preprocessed alpha, avoid it in the + // scan operation below. + a = false; + } - // Note that the bytes are in order BGR[A]. We arrange the layers in order ARGB. - if (vb != null) + for (int y = 0; y < h; ++y) { - for (int x = 0; x < w; x++, idstBase++) + int idstBase = idstMin + y * w; + + // Note that the bytes are in order BGR[A]. We arrange the layers in order ARGB. + if (vb != null) { - var pb = src.GetPixel(x, y); - int idst = idstBase; - if (a) { vb[idst] = pb.A; idst += cpix; } - if (r) { vb[idst] = pb.R; idst += cpix; } - if (g) { vb[idst] = pb.G; idst += cpix; } - if (b) { vb[idst] = pb.B; idst += cpix; } + for (int x = 0; x < w; x++, idstBase++) + { + var pb = src.GetPixel(x, y); + int idst = idstBase; + if (a) { vb[idst] = pb.A; idst += cpix; } + if (r) { vb[idst] = pb.R; idst += cpix; } + if (g) { vb[idst] = pb.G; idst += cpix; } + if (b) { vb[idst] = pb.B; idst += cpix; } + } } - } - else if (!needScale) - { - for (int x = 0; x < w; x++, idstBase++) + else if (!needScale) { - var pb = src.GetPixel(x, y); - int idst = idstBase; - if (a) { vf[idst] = pb.A; idst += cpix; } - if (r) { vf[idst] = pb.R; idst += cpix; } - if (g) { vf[idst] = pb.G; idst += cpix; } - if (b) { vf[idst] = pb.B; idst += cpix; } + for (int x = 0; x < w; x++, idstBase++) + { + var pb = src.GetPixel(x, y); + int idst = idstBase; + if (a) { vf[idst] = pb.A; idst += cpix; } + if (r) { vf[idst] = pb.R; idst += cpix; } + if (g) { vf[idst] = pb.G; idst += cpix; } + if (b) { vf[idst] = pb.B; idst += cpix; } + } } - } - else - { - for (int x = 0; x < w; x++, idstBase++) + else { - var pb = src.GetPixel(x, y); - 
int idst = idstBase; - if (a) { vf[idst] = (pb.A - offset) * scale; idst += cpix; } - if (r) { vf[idst] = (pb.R - offset) * scale; idst += cpix; } - if (g) { vf[idst] = (pb.G - offset) * scale; idst += cpix; } - if (b) { vf[idst] = (pb.B - offset) * scale; idst += cpix; } + for (int x = 0; x < w; x++, idstBase++) + { + var pb = src.GetPixel(x, y); + int idst = idstBase; + if (a) { vf[idst] = (pb.A - offset) * scale; idst += cpix; } + if (r) { vf[idst] = (pb.R - offset) * scale; idst += cpix; } + if (g) { vf[idst] = (pb.G - offset) * scale; idst += cpix; } + if (b) { vf[idst] = (pb.B - offset) * scale; idst += cpix; } + } } } } - } - dst = new VBuffer(size, values, dst.Indices); - }; + dst = new VBuffer(size, values, dst.Indices); + }; + } + + private VectorType[] ConstructTypes() + { + var types = new VectorType[_parent._columns.Length]; + for (int i = 0; i < _parent._columns.Length; i++) + { + var column = _parent._columns[i]; + Contracts.Assert(column.Planes > 0); + + var type = InputSchema.GetColumnType(ColMapNewToOld[i]) as ImageType; + Contracts.Assert(type != null); + + int height = type.Height; + int width = type.Width; + Contracts.Assert(height > 0); + Contracts.Assert(width > 0); + Contracts.Assert((long)height * width <= int.MaxValue / 4); + + if (column.Interleave) + types[i] = new VectorType(column.Convert ? NumberType.Float : NumberType.U1, height, width, column.Planes); + else + types[i] = new VectorType(column.Convert ? 
NumberType.Float : NumberType.U1, column.Planes, height, width); + } + return types; + } + } + } + + public sealed class ImagePixelExtractorEstimator : TrivialEstimator + { + public ImagePixelExtractorEstimator(IHostEnvironment env, string inputColumn, string outputColumn, + ImagePixelExtractorTransform.ColorBits colors = ImagePixelExtractorTransform.ColorBits.Rgb, bool interleave = false) + : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(ImagePixelExtractorEstimator)), new ImagePixelExtractorTransform(env, inputColumn, outputColumn, colors, interleave)) + { + } + + public ImagePixelExtractorEstimator(IHostEnvironment env, params ImagePixelExtractorTransform.ColumnInfo[] columns) + : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(ImagePixelExtractorEstimator)), new ImagePixelExtractorTransform(env, columns)) + { + } + + public override SchemaShape GetOutputSchema(SchemaShape inputSchema) + { + Host.CheckValue(inputSchema, nameof(inputSchema)); + var result = inputSchema.Columns.ToDictionary(x => x.Name); + foreach (var colInfo in Transformer.Columns) + { + var col = inputSchema.FindColumn(colInfo.Input); + + if (col == null) + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.Input); + if (!(col.ItemType is ImageType) || col.Kind != SchemaShape.Column.VectorKind.Scalar) + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.Input, new ImageType().ToString(), col.GetTypeString()); + + var itemType = colInfo.Convert ? 
NumberType.R4 : NumberType.U1; + result[colInfo.Output] = new SchemaShape.Column(colInfo.Output, SchemaShape.Column.VectorKind.Vector, itemType, false); + } + + return new SchemaShape(result.Values); } } } diff --git a/src/Microsoft.ML.ImageAnalytics/ImageResizerTransform.cs b/src/Microsoft.ML.ImageAnalytics/ImageResizerTransform.cs index dd1abc9181..ac11c7fa8d 100644 --- a/src/Microsoft.ML.ImageAnalytics/ImageResizerTransform.cs +++ b/src/Microsoft.ML.ImageAnalytics/ImageResizerTransform.cs @@ -3,8 +3,11 @@ // See the LICENSE file in the project root for more information. using System; +using System.Collections.Generic; using System.Drawing; +using System.Linq; using System.Text; +using Microsoft.ML.Core.Data; using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.CommandLine; using Microsoft.ML.Runtime.Data; @@ -20,13 +23,19 @@ [assembly: LoadableClass(ImageResizerTransform.Summary, typeof(ImageResizerTransform), null, typeof(SignatureLoadDataTransform), ImageResizerTransform.UserName, ImageResizerTransform.LoaderSignature)] +[assembly: LoadableClass(typeof(ImageResizerTransform), null, typeof(SignatureLoadModel), + ImageResizerTransform.UserName, ImageResizerTransform.LoaderSignature)] + +[assembly: LoadableClass(typeof(IRowMapper), typeof(ImageResizerTransform), null, typeof(SignatureLoadRowMapper), + ImageResizerTransform.UserName, ImageResizerTransform.LoaderSignature)] + namespace Microsoft.ML.Runtime.ImageAnalytics { // REVIEW: Rewrite as LambdaTransform to simplify. /// /// Transform which takes one or many columns of and resize them to provided height and width. /// - public sealed class ImageResizerTransform : OneToOneTransformBase + public sealed class ImageResizerTransform : OneToOneTransformerBase { public enum ResizingKind : byte { @@ -98,23 +107,30 @@ public class Arguments : TransformInputBase } /// - /// Extra information for each column (in addition to ColumnInfo). + /// Information for each column pair. 
/// - private sealed class ColInfoEx + public sealed class ColumnInfo { + public readonly string Input; + public readonly string Output; + public readonly int Width; public readonly int Height; public readonly ResizingKind Scale; public readonly Anchor Anchor; public readonly ColumnType Type; - public ColInfoEx(int width, int height, ResizingKind scale, Anchor anchor) + public ColumnInfo(string input, string output, int width, int height, ResizingKind scale, Anchor anchor) { + Contracts.CheckNonEmpty(input, nameof(input)); + Contracts.CheckNonEmpty(output, nameof(output)); Contracts.CheckUserArg(width > 0, nameof(Column.ImageWidth)); Contracts.CheckUserArg(height > 0, nameof(Column.ImageHeight)); Contracts.CheckUserArg(Enum.IsDefined(typeof(ResizingKind), scale), nameof(Column.Resizing)); Contracts.CheckUserArg(Enum.IsDefined(typeof(Anchor), anchor), nameof(Column.CropAnchor)); + Input = input; + Output = output; Width = width; Height = height; Scale = scale; @@ -133,53 +149,87 @@ private static VersionInfo GetVersionInfo() return new VersionInfo( modelSignature: "IMGSCALF", //verWrittenCur: 0x00010001, // Initial - verWrittenCur: 0x00010002, // Swith from OpenCV to Bitmap - verReadableCur: 0x00010002, - verWeCanReadBack: 0x00010002, + //verWrittenCur: 0x00010002, // Switch from OpenCV to Bitmap + verWrittenCur: 0x00010003, // No more sizeof(float) + verReadableCur: 0x00010003, + verWeCanReadBack: 0x00010003, loaderSignature: LoaderSignature); } private const string RegistrationName = "ImageScaler"; - // This is parallel to Infos. 
- private readonly ColInfoEx[] _exes; + private readonly ColumnInfo[] _columns; + + public IReadOnlyCollection Columns => _columns.AsReadOnly(); + + public ImageResizerTransform(IHostEnvironment env, string inputColumn, string outputColumn, + int imageWidth, int imageHeight, ResizingKind resizing = ResizingKind.IsoCrop, Anchor cropAnchor = Anchor.Center) + : this(env, new ColumnInfo(inputColumn, outputColumn, imageWidth, imageHeight, resizing, cropAnchor)) + { + } + + public ImageResizerTransform(IHostEnvironment env, params ColumnInfo[] columns) + : base(Contracts.CheckRef(env, nameof(env)).Register(RegistrationName), GetColumnPairs(columns)) + { + _columns = columns.ToArray(); + } + + private static (string input, string output)[] GetColumnPairs(ColumnInfo[] columns) + { + Contracts.CheckValue(columns, nameof(columns)); + return columns.Select(x => (x.Input, x.Output)).ToArray(); + } - // Public constructor corresponding to SignatureDataTransform. - public ImageResizerTransform(IHostEnvironment env, Arguments args, IDataView input) - : base(env, RegistrationName, env.CheckRef(args, nameof(args)).Column, input, t => t is ImageType ? null : "Expected Image type") + // Factory method for SignatureDataTransform. + public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input) { - Host.AssertNonEmpty(Infos); - Host.Assert(Infos.Length == Utils.Size(args.Column)); + Contracts.CheckValue(env, nameof(env)); + env.CheckValue(args, nameof(args)); + env.CheckValue(input, nameof(input)); + + env.CheckValue(args.Column, nameof(args.Column)); - _exes = new ColInfoEx[Infos.Length]; - for (int i = 0; i < _exes.Length; i++) + var cols = new ColumnInfo[args.Column.Length]; + for (int i = 0; i < cols.Length; i++) { var item = args.Column[i]; - _exes[i] = new ColInfoEx( + cols[i] = new ColumnInfo( + item.Source ?? item.Name, + item.Name, item.ImageWidth ?? args.ImageWidth, item.ImageHeight ?? args.ImageHeight, item.Resizing ?? 
args.Resizing, item.CropAnchor ?? args.CropAnchor); } - Metadata.Seal(); + + return new ImageResizerTransform(env, cols).MakeDataTransform(input); } - private ImageResizerTransform(IHost host, ModelLoadContext ctx, IDataView input) - : base(host, ctx, input, t => t is ImageType ? null : "Expected Image type") + public static ImageResizerTransform Create(IHostEnvironment env, ModelLoadContext ctx) { - Host.AssertValue(ctx); + Contracts.CheckValue(env, nameof(env)); + var host = env.Register(RegistrationName); + host.CheckValue(ctx, nameof(ctx)); + ctx.CheckAtModel(GetVersionInfo()); + + return new ImageResizerTransform(host, ctx); + } + + private ImageResizerTransform(IHost host, ModelLoadContext ctx) + : base(host, ctx) + { // *** Binary format *** - // // + // for each added column // int: width // int: height // byte: scaling kind - Host.AssertNonEmpty(Infos); + // byte: anchor - _exes = new ColInfoEx[Infos.Length]; - for (int i = 0; i < _exes.Length; i++) + _columns = new ColumnInfo[ColumnPairs.Length]; + for (int i = 0; i < ColumnPairs.Length; i++) { int width = ctx.Reader.ReadInt32(); Host.CheckDecode(width > 0); @@ -189,182 +239,224 @@ private ImageResizerTransform(IHost host, ModelLoadContext ctx, IDataView input) Host.CheckDecode(Enum.IsDefined(typeof(ResizingKind), scale)); var anchor = (Anchor)ctx.Reader.ReadByte(); Host.CheckDecode(Enum.IsDefined(typeof(Anchor), anchor)); - _exes[i] = new ColInfoEx(width, height, scale, anchor); + _columns[i] = new ColumnInfo(ColumnPairs[i].input, ColumnPairs[i].output, width, height, scale, anchor); } - Metadata.Seal(); } - public static ImageResizerTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input) - { - Contracts.CheckValue(env, nameof(env)); - var h = env.Register(RegistrationName); - h.CheckValue(ctx, nameof(ctx)); - h.CheckValue(input, nameof(input)); - ctx.CheckAtModel(GetVersionInfo()); - return h.Apply("Loading Model", - ch => - { - // *** Binary format *** - // int: sizeof(Float) - // 
- int cbFloat = ctx.Reader.ReadInt32(); - ch.CheckDecode(cbFloat == sizeof(Single)); - return new ImageResizerTransform(h, ctx, input); - }); - } + // Factory method for SignatureLoadDataTransform. + public static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input) + => Create(env, ctx).MakeDataTransform(input); + + // Factory method for SignatureLoadRowMapper. + public static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, ISchema inputSchema) + => Create(env, ctx).MakeRowMapper(inputSchema); public override void Save(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); + ctx.CheckAtModel(); ctx.SetVersionInfo(GetVersionInfo()); // *** Binary format *** - // int: sizeof(Float) // + // for each added column // int: width // int: height // byte: scaling kind - ctx.Writer.Write(sizeof(Single)); - SaveBase(ctx); + // byte: anchor + + base.SaveColumns(ctx); - Host.Assert(_exes.Length == Infos.Length); - for (int i = 0; i < _exes.Length; i++) + foreach (var col in _columns) { - var ex = _exes[i]; - ctx.Writer.Write(ex.Width); - ctx.Writer.Write(ex.Height); - Host.Assert((ResizingKind)(byte)ex.Scale == ex.Scale); - ctx.Writer.Write((byte)ex.Scale); - Host.Assert((Anchor)(byte)ex.Anchor == ex.Anchor); - ctx.Writer.Write((byte)ex.Anchor); + ctx.Writer.Write(col.Width); + ctx.Writer.Write(col.Height); + Contracts.Assert((ResizingKind)(byte)col.Scale == col.Scale); + ctx.Writer.Write((byte)col.Scale); + Contracts.Assert((Anchor)(byte)col.Anchor == col.Anchor); + ctx.Writer.Write((byte)col.Anchor); } } - protected override ColumnType GetColumnTypeCore(int iinfo) + protected override IRowMapper MakeRowMapper(ISchema schema) + => new Mapper(this, schema); + + protected override void CheckInputColumn(ISchema inputSchema, int col, int srcCol) { - Host.Check(0 <= iinfo && iinfo < Infos.Length); - return _exes[iinfo].Type; + if (!(inputSchema.GetColumnType(srcCol) is ImageType)) + throw 
Host.ExceptSchemaMismatch(nameof(inputSchema), "input", _columns[col].Input, "image", inputSchema.GetColumnType(srcCol).ToString()); } - protected override Delegate GetGetterCore(IChannel ch, IRow input, int iinfo, out Action disposer) + private sealed class Mapper : MapperBase { - Host.AssertValueOrNull(ch); - Host.AssertValue(input); - Host.Assert(0 <= iinfo && iinfo < Infos.Length); - - var src = default(Bitmap); - var getSrc = GetSrcGetter(input, iinfo); - var ex = _exes[iinfo]; - - disposer = - () => - { - if (src != null) - { - src.Dispose(); - src = null; - } - }; - - ValueGetter del = - (ref Bitmap dst) => - { - if (dst != null) - dst.Dispose(); - - getSrc(ref src); - if (src == null || src.Height <= 0 || src.Width <= 0) - return; - if (src.Height == ex.Height && src.Width == ex.Width) - { - dst = src; - return; - } - - int sourceWidth = src.Width; - int sourceHeight = src.Height; - int sourceX = 0; - int sourceY = 0; - int destX = 0; - int destY = 0; - int destWidth = 0; - int destHeight = 0; - float aspect = 0; - float widthAspect = 0; - float heightAspect = 0; - - widthAspect = (float)ex.Width / sourceWidth; - heightAspect = (float)ex.Height / sourceHeight; - - if (ex.Scale == ResizingKind.IsoPad) + private readonly ImageResizerTransform _parent; + + public Mapper(ImageResizerTransform parent, ISchema inputSchema) + :base(parent.Host.Register(nameof(Mapper)), parent, inputSchema) + { + _parent = parent; + } + + public override RowMapperColumnInfo[] GetOutputColumns() + => _parent._columns.Select(x => new RowMapperColumnInfo(x.Output, x.Type, null)).ToArray(); + + protected override Delegate MakeGetter(IRow input, int iinfo, out Action disposer) + { + Contracts.AssertValue(input); + Contracts.Assert(0 <= iinfo && iinfo < _parent._columns.Length); + + var src = default(Bitmap); + var getSrc = input.GetGetter(ColMapNewToOld[iinfo]); + var info = _parent._columns[iinfo]; + + disposer = + () => { - widthAspect = (float)ex.Width / sourceWidth; - heightAspect = 
(float)ex.Height / sourceHeight; - if (heightAspect < widthAspect) + if (src != null) { - aspect = heightAspect; - destX = (int)((ex.Width - (sourceWidth * aspect)) / 2); + src.Dispose(); + src = null; } - else + }; + + ValueGetter del = + (ref Bitmap dst) => + { + if (dst != null) + dst.Dispose(); + + getSrc(ref src); + if (src == null || src.Height <= 0 || src.Width <= 0) + return; + if (src.Height == info.Height && src.Width == info.Width) { - aspect = widthAspect; - destY = (int)((ex.Height - (sourceHeight * aspect)) / 2); + dst = src; + return; } - destWidth = (int)(sourceWidth * aspect); - destHeight = (int)(sourceHeight * aspect); - } - else - { - if (heightAspect < widthAspect) + int sourceWidth = src.Width; + int sourceHeight = src.Height; + int sourceX = 0; + int sourceY = 0; + int destX = 0; + int destY = 0; + int destWidth = 0; + int destHeight = 0; + float aspect = 0; + float widthAspect = 0; + float heightAspect = 0; + + widthAspect = (float)info.Width / sourceWidth; + heightAspect = (float)info.Height / sourceHeight; + + if (info.Scale == ResizingKind.IsoPad) { - aspect = widthAspect; - switch (ex.Anchor) + widthAspect = (float)info.Width / sourceWidth; + heightAspect = (float)info.Height / sourceHeight; + if (heightAspect < widthAspect) { - case Anchor.Top: - destY = 0; - break; - case Anchor.Bottom: - destY = (int)(ex.Height - (sourceHeight * aspect)); - break; - default: - destY = (int)((ex.Height - (sourceHeight * aspect)) / 2); - break; + aspect = heightAspect; + destX = (int)((info.Width - (sourceWidth * aspect)) / 2); } + else + { + aspect = widthAspect; + destY = (int)((info.Height - (sourceHeight * aspect)) / 2); + } + + destWidth = (int)(sourceWidth * aspect); + destHeight = (int)(sourceHeight * aspect); } else { - aspect = heightAspect; - switch (ex.Anchor) + if (heightAspect < widthAspect) + { + aspect = widthAspect; + switch (info.Anchor) + { + case Anchor.Top: + destY = 0; + break; + case Anchor.Bottom: + destY = (int)(info.Height - 
(sourceHeight * aspect)); + break; + default: + destY = (int)((info.Height - (sourceHeight * aspect)) / 2); + break; + } + } + else { - case Anchor.Left: - destX = 0; - break; - case Anchor.Right: - destX = (int)(ex.Width - (sourceWidth * aspect)); - break; - default: - destX = (int)((ex.Width - (sourceWidth * aspect)) / 2); - break; + aspect = heightAspect; + switch (info.Anchor) + { + case Anchor.Left: + destX = 0; + break; + case Anchor.Right: + destX = (int)(info.Width - (sourceWidth * aspect)); + break; + default: + destX = (int)((info.Width - (sourceWidth * aspect)) / 2); + break; + } } + + destWidth = (int)(sourceWidth * aspect); + destHeight = (int)(sourceHeight * aspect); + } + dst = new Bitmap(info.Width, info.Height); + var srcRectangle = new Rectangle(sourceX, sourceY, sourceWidth, sourceHeight); + var destRectangle = new Rectangle(destX, destY, destWidth, destHeight); + using (var g = Graphics.FromImage(dst)) + { + g.DrawImage(src, destRectangle, srcRectangle, GraphicsUnit.Pixel); } + Contracts.Assert(dst.Width == info.Width && dst.Height == info.Height); + }; - destWidth = (int)(sourceWidth * aspect); - destHeight = (int)(sourceHeight * aspect); - } - dst = new Bitmap(ex.Width, ex.Height); - var srcRectangle = new Rectangle(sourceX, sourceY, sourceWidth, sourceHeight); - var destRectangle = new Rectangle(destX, destY, destWidth, destHeight); - using (var g = Graphics.FromImage(dst)) - { - g.DrawImage(src, destRectangle, srcRectangle, GraphicsUnit.Pixel); - } - Host.Assert(dst.Width == ex.Width && dst.Height == ex.Height); - }; + return del; + } + } + } + + public sealed class ImageResizerEstimator : TrivialEstimator + { + public ImageResizerEstimator(IHostEnvironment env, string inputColumn, string outputColumn, + int imageWidth, int imageHeight, ImageResizerTransform.ResizingKind resizing = ImageResizerTransform.ResizingKind.IsoCrop, ImageResizerTransform.Anchor cropAnchor = ImageResizerTransform.Anchor.Center) + : this(env, new 
ImageResizerTransform(env, inputColumn, outputColumn, imageWidth, imageHeight, resizing, cropAnchor)) + { + } + + public ImageResizerEstimator(IHostEnvironment env, params ImageResizerTransform.ColumnInfo[] columns) + : this(env, new ImageResizerTransform(env, columns)) + { + } + + public ImageResizerEstimator(IHostEnvironment env, ImageResizerTransform transformer) + : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(ImageResizerEstimator)), transformer) + { + } + + public override SchemaShape GetOutputSchema(SchemaShape inputSchema) + { + Host.CheckValue(inputSchema, nameof(inputSchema)); + var result = inputSchema.Columns.ToDictionary(x => x.Name); + foreach (var colInfo in Transformer.Columns) + { + var col = inputSchema.FindColumn(colInfo.Input); + + if (col == null) + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.Input); + if (!(col.ItemType is ImageType) || col.Kind != SchemaShape.Column.VectorKind.Scalar) + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.Input, new ImageType().ToString(), col.GetTypeString()); + + result[colInfo.Output] = new SchemaShape.Column(colInfo.Output, SchemaShape.Column.VectorKind.Scalar, colInfo.Type, false); + } - return del; + return new SchemaShape(result.Values); } } } diff --git a/src/Microsoft.ML.ImageAnalytics/ImageType.cs b/src/Microsoft.ML.ImageAnalytics/ImageType.cs index 852ea09d9d..fd31302808 100644 --- a/src/Microsoft.ML.ImageAnalytics/ImageType.cs +++ b/src/Microsoft.ML.ImageAnalytics/ImageType.cs @@ -35,7 +35,7 @@ public override bool Equals(ColumnType other) return false; if (Height != tmp.Height) return false; - return Width != tmp.Width; + return Width == tmp.Width; } public override bool Equals(object other) diff --git a/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationTrainer.cs b/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationTrainer.cs index 2da2075728..52d2d3aef0 100644 --- 
a/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationTrainer.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationTrainer.cs @@ -1407,8 +1407,8 @@ public LinearClassificationTrainer(IHostEnvironment env, Arguments args, _positiveInstanceWeight = _args.PositiveInstanceWeight; OutputColumns = new[] { - new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Scalar, DataKind.R4, false), - new SchemaShape.Column(DefaultColumnNames.PredictedLabel, SchemaShape.Column.VectorKind.Scalar, DataKind.BL, false) + new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false), + new SchemaShape.Column(DefaultColumnNames.PredictedLabel, SchemaShape.Column.VectorKind.Scalar, BoolType.Instance, false) }; } @@ -1426,7 +1426,8 @@ protected override void CheckLabelCompatible(SchemaShape.Column labelCol) if (labelCol.Kind != SchemaShape.Column.VectorKind.Scalar) error(); - if (!labelCol.IsKey && labelCol.ItemKind != DataKind.R4 && labelCol.ItemKind != DataKind.R8 && labelCol.ItemKind != DataKind.BL) + + if (!labelCol.IsKey && labelCol.ItemType != NumberType.R4 && labelCol.ItemType != NumberType.R8 && !labelCol.ItemType.IsBool) error(); } @@ -1434,17 +1435,17 @@ private static SchemaShape.Column MakeWeightColumn(string weightColumn) { if (weightColumn == null) return null; - return new SchemaShape.Column(weightColumn, SchemaShape.Column.VectorKind.Scalar, DataKind.R4, false); + return new SchemaShape.Column(weightColumn, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false); } private static SchemaShape.Column MakeLabelColumn(string labelColumn) { - return new SchemaShape.Column(labelColumn, SchemaShape.Column.VectorKind.Scalar, DataKind.BL, false); + return new SchemaShape.Column(labelColumn, SchemaShape.Column.VectorKind.Scalar, BoolType.Instance, false); } private static SchemaShape.Column MakeFeatureColumn(string featureColumn) { - return new SchemaShape.Column(featureColumn, 
SchemaShape.Column.VectorKind.Vector, DataKind.R4, false); + return new SchemaShape.Column(featureColumn, SchemaShape.Column.VectorKind.Vector, NumberType.R4, false); } protected override TScalarPredictor CreatePredictor(VBuffer[] weights, Float[] bias) diff --git a/src/Microsoft.ML.StandardLearners/Standard/SdcaMultiClass.cs b/src/Microsoft.ML.StandardLearners/Standard/SdcaMultiClass.cs index f775be92bd..c49593332b 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/SdcaMultiClass.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/SdcaMultiClass.cs @@ -57,8 +57,8 @@ public SdcaMultiClassTrainer(IHostEnvironment env, Arguments args, _args = args; OutputColumns = new[] { - new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Vector, DataKind.R4, false), - new SchemaShape.Column(DefaultColumnNames.PredictedLabel, SchemaShape.Column.VectorKind.Scalar, DataKind.U4, true) + new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Vector, NumberType.R4, false), + new SchemaShape.Column(DefaultColumnNames.PredictedLabel, SchemaShape.Column.VectorKind.Scalar, NumberType.U4, true) }; } @@ -76,7 +76,7 @@ protected override void CheckLabelCompatible(SchemaShape.Column labelCol) if (labelCol.Kind != SchemaShape.Column.VectorKind.Scalar) error(); - if (!labelCol.IsKey && labelCol.ItemKind != DataKind.R4 && labelCol.ItemKind != DataKind.R8) + if (!labelCol.IsKey && labelCol.ItemType != NumberType.R4 && labelCol.ItemType != NumberType.R8) error(); } @@ -84,17 +84,17 @@ private static SchemaShape.Column MakeWeightColumn(string weightColumn) { if (weightColumn == null) return null; - return new SchemaShape.Column(weightColumn, SchemaShape.Column.VectorKind.Scalar, DataKind.R4, false); + return new SchemaShape.Column(weightColumn, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false); } private static SchemaShape.Column MakeLabelColumn(string labelColumn) { - return new SchemaShape.Column(labelColumn, 
SchemaShape.Column.VectorKind.Scalar, DataKind.U4, true); + return new SchemaShape.Column(labelColumn, SchemaShape.Column.VectorKind.Scalar, NumberType.U4, true); } private static SchemaShape.Column MakeFeatureColumn(string featureColumn) { - return new SchemaShape.Column(featureColumn, SchemaShape.Column.VectorKind.Vector, DataKind.R4, false); + return new SchemaShape.Column(featureColumn, SchemaShape.Column.VectorKind.Vector, NumberType.R4, false); } /// diff --git a/src/Microsoft.ML.StandardLearners/Standard/SdcaRegression.cs b/src/Microsoft.ML.StandardLearners/Standard/SdcaRegression.cs index 163da10e7b..0b620959c3 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/SdcaRegression.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/SdcaRegression.cs @@ -61,7 +61,7 @@ public SdcaRegressionTrainer(IHostEnvironment env, Arguments args, string featur _args = args; OutputColumns = new[] { - new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Scalar, DataKind.R4, false) + new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false) }; } @@ -73,17 +73,17 @@ private static SchemaShape.Column MakeWeightColumn(string weightColumn) { if (weightColumn == null) return null; - return new SchemaShape.Column(weightColumn, SchemaShape.Column.VectorKind.Scalar, DataKind.R4, false); + return new SchemaShape.Column(weightColumn, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false); } private static SchemaShape.Column MakeLabelColumn(string labelColumn) { - return new SchemaShape.Column(labelColumn, SchemaShape.Column.VectorKind.Scalar, DataKind.R4, false); + return new SchemaShape.Column(labelColumn, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false); } private static SchemaShape.Column MakeFeatureColumn(string featureColumn) { - return new SchemaShape.Column(featureColumn, SchemaShape.Column.VectorKind.Vector, DataKind.R4, false); + return new SchemaShape.Column(featureColumn, 
SchemaShape.Column.VectorKind.Vector, NumberType.R4, false); } protected override LinearRegressionPredictor CreatePredictor(VBuffer[] weights, Float[] bias) diff --git a/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipeBase.cs b/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipeBase.cs index 760cf9c910..d189e627f5 100644 --- a/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipeBase.cs +++ b/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipeBase.cs @@ -5,6 +5,8 @@ using System; using System.Collections.Generic; using System.IO; +using System.Linq; +using Microsoft.ML.Core.Data; using Microsoft.ML.Runtime.CommandLine; using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.Data.IO; @@ -17,6 +19,108 @@ namespace Microsoft.ML.Runtime.RunTests { public abstract partial class TestDataPipeBase : TestDataViewBase { + /// + /// 'Workout test' for an estimator. + /// Checks the following traits: + /// - the estimator is applicable to the validFitInput, and not applicable to validTransformInput and invalidInput; + /// - the fitted transformer is applicable to validFitInput and validTransformInput, and not applicable to invalidInput; + /// - fitted transformer can be saved and re-loaded into the transformer with the same behavior. + /// - schema propagation for fitted transformer conforms to schema propagation of estimator. + /// + protected void TestEstimatorCore(IEstimator estimator, + IDataView validFitInput, IDataView validTransformInput = null, IDataView invalidInput = null) + { + Contracts.AssertValue(estimator); + Contracts.AssertValue(validFitInput); + Contracts.AssertValueOrNull(validTransformInput); + Contracts.AssertValueOrNull(invalidInput); + Action mustFail = (Action action) => + { + try + { + action(); + Assert.False(true); + } + catch (ArgumentOutOfRangeException) { } + catch (InvalidOperationException) { } + }; + + // Schema propagation tests for estimator. 
+ var outSchemaShape = estimator.GetOutputSchema(SchemaShape.Create(validFitInput.Schema)); + if (validTransformInput != null) + { + mustFail(() => estimator.GetOutputSchema(SchemaShape.Create(validTransformInput.Schema))); + mustFail(() => estimator.Fit(validTransformInput)); + } + + if (invalidInput != null) + { + mustFail(() => estimator.GetOutputSchema(SchemaShape.Create(invalidInput.Schema))); + mustFail(() => estimator.Fit(invalidInput)); + } + + var transformer = estimator.Fit(validFitInput); + // Save and reload. + string modelPath = GetOutputPath(TestName + "-model.zip"); + using (var fs = File.Create(modelPath)) + transformer.SaveTo(Env, fs); + + ITransformer loadedTransformer; + using (var fs = File.OpenRead(modelPath)) + loadedTransformer = TransformerChain.LoadFrom(Env, fs); + DeleteOutputPath(modelPath); + + // Run on train data. + Action checkOnData = (IDataView data) => + { + var schema = transformer.GetOutputSchema(data.Schema); + + // Loaded transformer needs to have the same schema propagation. + CheckSameSchemas(schema, loadedTransformer.GetOutputSchema(data.Schema)); + + var scoredTrain = transformer.Transform(data); + var scoredTrain2 = loadedTransformer.Transform(data); + + // The schema of the transformed data must match the schema provided by schema propagation. + CheckSameSchemas(schema, scoredTrain.Schema); + + // The schema and data of scored dataset must be identical between loaded + // and original transformer. + // This in turn means that the schema of loaded transformer matches for + // Transform and GetOutputSchema calls. 
+ CheckSameSchemas(scoredTrain.Schema, scoredTrain2.Schema); + CheckSameValues(scoredTrain, scoredTrain2); + }; + + checkOnData(validFitInput); + + if (validTransformInput != null) + checkOnData(validTransformInput); + + if (invalidInput != null) + { + mustFail(() => transformer.GetOutputSchema(invalidInput.Schema)); + mustFail(() => transformer.Transform(invalidInput)); + mustFail(() => loadedTransformer.GetOutputSchema(invalidInput.Schema)); + mustFail(() => loadedTransformer.Transform(invalidInput)); + } + + // Schema verification between estimator and transformer. + var scoredTrainSchemaShape = SchemaShape.Create(transformer.GetOutputSchema(validFitInput.Schema)); + CheckSameSchemaShape(outSchemaShape, scoredTrainSchemaShape); + } + + private void CheckSameSchemaShape(SchemaShape first, SchemaShape second) + { + Assert.True(first.Columns.Length == second.Columns.Length); + var sortedCols1 = first.Columns.OrderBy(x => x.Name); + var sortedCols2 = second.Columns.OrderBy(x => x.Name); + + Assert.True(sortedCols1.Zip(sortedCols2, + (x, y) => x.IsCompatibleWith(y) && y.IsCompatibleWith(x)) + .All(x => x)); + } + // REVIEW: incorporate the testing for re-apply logic here? 
/// /// Create PipeDataLoader from the given args, save it, re-load it, verify that the data of @@ -878,41 +982,44 @@ protected Func GetColumnComparer(IRow r1, IRow r2, int col, ColumnType typ { switch (type.RawKind) { - case DataKind.I1: - return GetComparerOne(r1, r2, col, (x, y) => x.RawValue == y.RawValue); - case DataKind.U1: - return GetComparerOne(r1, r2, col, (x, y) => x == y); - case DataKind.I2: - return GetComparerOne(r1, r2, col, (x, y) => x.RawValue == y.RawValue); - case DataKind.U2: - return GetComparerOne(r1, r2, col, (x, y) => x == y); - case DataKind.I4: - return GetComparerOne(r1, r2, col, (x, y) => x.RawValue == y.RawValue); - case DataKind.U4: - return GetComparerOne(r1, r2, col, (x, y) => x == y); - case DataKind.I8: - return GetComparerOne(r1, r2, col, (x, y) => x.RawValue == y.RawValue); - case DataKind.U8: - return GetComparerOne(r1, r2, col, (x, y) => x == y); - case DataKind.R4: - return GetComparerOne(r1, r2, col, (x, y) => FloatUtils.GetBits(x) == FloatUtils.GetBits(y)); - case DataKind.R8: - if (exactDoubles) - return GetComparerOne(r1, r2, col, (x, y) => FloatUtils.GetBits(x) == FloatUtils.GetBits(y)); - else - return GetComparerOne(r1, r2, col, EqualWithEps); - case DataKind.Text: - return GetComparerOne(r1, r2, col, DvText.Identical); - case DataKind.Bool: - return GetComparerOne(r1, r2, col, (x, y) => x.Equals(y)); - case DataKind.TimeSpan: - return GetComparerOne(r1, r2, col, (x, y) => x.Equals(y)); - case DataKind.DT: - return GetComparerOne(r1, r2, col, (x, y) => x.Equals(y)); - case DataKind.DZ: - return GetComparerOne(r1, r2, col, (x, y) => x.Equals(y)); - case DataKind.UG: - return GetComparerOne(r1, r2, col, (x, y) => x.Equals(y)); + case DataKind.I1: + return GetComparerOne(r1, r2, col, (x, y) => x.RawValue == y.RawValue); + case DataKind.U1: + return GetComparerOne(r1, r2, col, (x, y) => x == y); + case DataKind.I2: + return GetComparerOne(r1, r2, col, (x, y) => x.RawValue == y.RawValue); + case DataKind.U2: + return 
GetComparerOne(r1, r2, col, (x, y) => x == y); + case DataKind.I4: + return GetComparerOne(r1, r2, col, (x, y) => x.RawValue == y.RawValue); + case DataKind.U4: + return GetComparerOne(r1, r2, col, (x, y) => x == y); + case DataKind.I8: + return GetComparerOne(r1, r2, col, (x, y) => x.RawValue == y.RawValue); + case DataKind.U8: + return GetComparerOne(r1, r2, col, (x, y) => x == y); + case DataKind.R4: + return GetComparerOne(r1, r2, col, (x, y) => FloatUtils.GetBits(x) == FloatUtils.GetBits(y)); + case DataKind.R8: + if (exactDoubles) + return GetComparerOne(r1, r2, col, (x, y) => FloatUtils.GetBits(x) == FloatUtils.GetBits(y)); + else + return GetComparerOne(r1, r2, col, EqualWithEps); + case DataKind.Text: + return GetComparerOne(r1, r2, col, DvText.Identical); + case DataKind.Bool: + return GetComparerOne(r1, r2, col, (x, y) => x.Equals(y)); + case DataKind.TimeSpan: + return GetComparerOne(r1, r2, col, (x, y) => x.Equals(y)); + case DataKind.DT: + return GetComparerOne(r1, r2, col, (x, y) => x.Equals(y)); + case DataKind.DZ: + return GetComparerOne(r1, r2, col, (x, y) => x.Equals(y)); + case DataKind.UG: + return GetComparerOne(r1, r2, col, (x, y) => x.Equals(y)); + case (DataKind)0: + // We cannot compare custom types (including image). 
+ return () => true; } } else @@ -921,41 +1028,41 @@ protected Func GetColumnComparer(IRow r1, IRow r2, int col, ColumnType typ Contracts.Assert(size >= 0); switch (type.ItemType.RawKind) { - case DataKind.I1: - return GetComparerVec(r1, r2, col, size, (x, y) => x.RawValue == y.RawValue); - case DataKind.U1: - return GetComparerVec(r1, r2, col, size, (x, y) => x == y); - case DataKind.I2: - return GetComparerVec(r1, r2, col, size, (x, y) => x.RawValue == y.RawValue); - case DataKind.U2: - return GetComparerVec(r1, r2, col, size, (x, y) => x == y); - case DataKind.I4: - return GetComparerVec(r1, r2, col, size, (x, y) => x.RawValue == y.RawValue); - case DataKind.U4: - return GetComparerVec(r1, r2, col, size, (x, y) => x == y); - case DataKind.I8: - return GetComparerVec(r1, r2, col, size, (x, y) => x.RawValue == y.RawValue); - case DataKind.U8: - return GetComparerVec(r1, r2, col, size, (x, y) => x == y); - case DataKind.R4: - return GetComparerVec(r1, r2, col, size, (x, y) => FloatUtils.GetBits(x) == FloatUtils.GetBits(y)); - case DataKind.R8: - if (exactDoubles) - return GetComparerVec(r1, r2, col, size, (x, y) => FloatUtils.GetBits(x) == FloatUtils.GetBits(y)); - else - return GetComparerVec(r1, r2, col, size, EqualWithEps); - case DataKind.Text: - return GetComparerVec(r1, r2, col, size, DvText.Identical); - case DataKind.Bool: - return GetComparerVec(r1, r2, col, size, (x, y) => x.Equals(y)); - case DataKind.TimeSpan: - return GetComparerVec(r1, r2, col, size, (x, y) => x.Equals(y)); - case DataKind.DT: - return GetComparerVec(r1, r2, col, size, (x, y) => x.Equals(y)); - case DataKind.DZ: - return GetComparerVec(r1, r2, col, size, (x, y) => x.Equals(y)); - case DataKind.UG: - return GetComparerVec(r1, r2, col, size, (x, y) => x.Equals(y)); + case DataKind.I1: + return GetComparerVec(r1, r2, col, size, (x, y) => x.RawValue == y.RawValue); + case DataKind.U1: + return GetComparerVec(r1, r2, col, size, (x, y) => x == y); + case DataKind.I2: + return 
GetComparerVec(r1, r2, col, size, (x, y) => x.RawValue == y.RawValue); + case DataKind.U2: + return GetComparerVec(r1, r2, col, size, (x, y) => x == y); + case DataKind.I4: + return GetComparerVec(r1, r2, col, size, (x, y) => x.RawValue == y.RawValue); + case DataKind.U4: + return GetComparerVec(r1, r2, col, size, (x, y) => x == y); + case DataKind.I8: + return GetComparerVec(r1, r2, col, size, (x, y) => x.RawValue == y.RawValue); + case DataKind.U8: + return GetComparerVec(r1, r2, col, size, (x, y) => x == y); + case DataKind.R4: + return GetComparerVec(r1, r2, col, size, (x, y) => FloatUtils.GetBits(x) == FloatUtils.GetBits(y)); + case DataKind.R8: + if (exactDoubles) + return GetComparerVec(r1, r2, col, size, (x, y) => FloatUtils.GetBits(x) == FloatUtils.GetBits(y)); + else + return GetComparerVec(r1, r2, col, size, EqualWithEps); + case DataKind.Text: + return GetComparerVec(r1, r2, col, size, DvText.Identical); + case DataKind.Bool: + return GetComparerVec(r1, r2, col, size, (x, y) => x.Equals(y)); + case DataKind.TimeSpan: + return GetComparerVec(r1, r2, col, size, (x, y) => x.Equals(y)); + case DataKind.DT: + return GetComparerVec(r1, r2, col, size, (x, y) => x.Equals(y)); + case DataKind.DZ: + return GetComparerVec(r1, r2, col, size, (x, y) => x.Equals(y)); + case DataKind.UG: + return GetComparerVec(r1, r2, col, size, (x, y) => x.Equals(y)); } } diff --git a/test/Microsoft.ML.Tests/ImagesTests.cs b/test/Microsoft.ML.Tests/ImagesTests.cs index a12032400a..39087b9606 100644 --- a/test/Microsoft.ML.Tests/ImagesTests.cs +++ b/test/Microsoft.ML.Tests/ImagesTests.cs @@ -5,20 +5,76 @@ using Microsoft.ML.Runtime.Api; using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.ImageAnalytics; +using Microsoft.ML.Runtime.Model; +using Microsoft.ML.Runtime.RunTests; using Microsoft.ML.TestFramework; using System.Drawing; using System.IO; +using System.Linq; using Xunit; using Xunit.Abstractions; namespace Microsoft.ML.Tests { - public class ImageTests : BaseTestClass 
+ public class ImageTests : TestDataPipeBase { public ImageTests(ITestOutputHelper output) : base(output) { } + [Fact] + public void TestEstimatorChain() + { + using (var env = new TlcEnvironment()) + { + var dataFile = GetDataPath("images/images.tsv"); + var imageFolder = Path.GetDirectoryName(dataFile); + var data = env.CreateLoader("Text{col=ImagePath:TX:0 col=Name:TX:1}", new MultiFileSource(dataFile)); + var invalidData = env.CreateLoader("Text{col=ImagePath:R4:0}", new MultiFileSource(dataFile)); + + var pipe = new ImageLoaderEstimator(env, imageFolder, ("ImagePath", "ImageReal")) + .Append(new ImageResizerEstimator(env, "ImageReal", "ImageReal", 100, 100)) + .Append(new ImagePixelExtractorEstimator(env, "ImageReal", "ImagePixels")) + .Append(new ImageGrayscaleEstimator(env, ("ImageReal", "ImageGray"))); + + TestEstimatorCore(pipe, data, null, invalidData); + } + Done(); + } + + [Fact] + public void TestEstimatorSaveLoad() + { + using (var env = new TlcEnvironment()) + { + var dataFile = GetDataPath("images/images.tsv"); + var imageFolder = Path.GetDirectoryName(dataFile); + var data = env.CreateLoader("Text{col=ImagePath:TX:0 col=Name:TX:1}", new MultiFileSource(dataFile)); + + var pipe = new ImageLoaderEstimator(env, imageFolder, ("ImagePath", "ImageReal")) + .Append(new ImageResizerEstimator(env, "ImageReal", "ImageReal", 100, 100)) + .Append(new ImagePixelExtractorEstimator(env, "ImageReal", "ImagePixels")) + .Append(new ImageGrayscaleEstimator(env, ("ImageReal", "ImageGray"))); + + pipe.GetOutputSchema(Core.Data.SchemaShape.Create(data.Schema)); + var model = pipe.Fit(data); + + using (var file = env.CreateTempFile()) + { + using (var fs = file.CreateWriteStream()) + model.SaveTo(env, fs); + var model2 = TransformerChain.LoadFrom(env, file.OpenReadStream()); + + var newCols = ((ImageLoaderTransform)model2.First()).Columns; + var oldCols = ((ImageLoaderTransform)model.First()).Columns; + Assert.True(newCols + .Zip(oldCols, (x, y) => x == y) + .All(x => 
x)); + } + } + Done(); + } + [Fact] public void TestSaveImages() { @@ -27,7 +83,7 @@ public void TestSaveImages() var dataFile = GetDataPath("images/images.tsv"); var imageFolder = Path.GetDirectoryName(dataFile); var data = env.CreateLoader("Text{col=ImagePath:TX:0 col=Name:TX:1}", new MultiFileSource(dataFile)); - var images = new ImageLoaderTransform(env, new ImageLoaderTransform.Arguments() + var images = ImageLoaderTransform.Create(env, new ImageLoaderTransform.Arguments() { Column = new ImageLoaderTransform.Column[1] { @@ -36,7 +92,7 @@ public void TestSaveImages() ImageFolder = imageFolder }, data); - var cropped = new ImageResizerTransform(env, new ImageResizerTransform.Arguments() + IDataView cropped = ImageResizerTransform.Create(env, new ImageResizerTransform.Arguments() { Column = new ImageResizerTransform.Column[1]{ new ImageResizerTransform.Column() { Name= "ImageCropped", Source = "ImageReal", ImageHeight =100, ImageWidth = 100, Resizing = ImageResizerTransform.ResizingKind.IsoPad} @@ -61,6 +117,7 @@ public void TestSaveImages() } } } + Done(); } [Fact] @@ -73,7 +130,7 @@ public void TestGreyscaleTransformImages() var dataFile = GetDataPath("images/images.tsv"); var imageFolder = Path.GetDirectoryName(dataFile); var data = env.CreateLoader("Text{col=ImagePath:TX:0 col=Name:TX:1}", new MultiFileSource(dataFile)); - var images = new ImageLoaderTransform(env, new ImageLoaderTransform.Arguments() + var images = ImageLoaderTransform.Create(env, new ImageLoaderTransform.Arguments() { Column = new ImageLoaderTransform.Column[1] { @@ -81,20 +138,29 @@ public void TestGreyscaleTransformImages() }, ImageFolder = imageFolder }, data); - var cropped = new ImageResizerTransform(env, new ImageResizerTransform.Arguments() + var cropped = ImageResizerTransform.Create(env, new ImageResizerTransform.Arguments() { Column = new ImageResizerTransform.Column[1]{ new ImageResizerTransform.Column() { Name= "ImageCropped", Source = "ImageReal", ImageHeight =imageHeight, 
ImageWidth = imageWidth, Resizing = ImageResizerTransform.ResizingKind.IsoCrop} } }, images); - var grey = new ImageGrayscaleTransform(env, new ImageGrayscaleTransform.Arguments() + IDataView grey = ImageGrayscaleTransform.Create(env, new ImageGrayscaleTransform.Arguments() { Column = new ImageGrayscaleTransform.Column[1]{ new ImageGrayscaleTransform.Column() { Name= "ImageGrey", Source = "ImageCropped"} } }, cropped); + var fname = nameof(TestGreyscaleTransformImages) + "_model.zip"; + + var fh = env.CreateOutputFile(fname); + using (var ch = env.Start("save")) + TrainUtils.SaveModel(env, ch, fh, null, new RoleMappedData(grey)); + + grey = ModelFileUtils.LoadPipeline(env, fh.OpenReadStream(), new MultiFileSource(dataFile)); + DeleteOutputPath(fname); + grey.Schema.TryGetColumnIndex("ImageGrey", out int greyColumn); using (var cursor = grey.GetRowCursor((x) => true)) { @@ -114,6 +180,7 @@ public void TestGreyscaleTransformImages() } } } + Done(); } [Fact] @@ -126,7 +193,7 @@ public void TestBackAndForthConversion() var dataFile = GetDataPath("images/images.tsv"); var imageFolder = Path.GetDirectoryName(dataFile); var data = env.CreateLoader("Text{col=ImagePath:TX:0 col=Name:TX:1}", new MultiFileSource(dataFile)); - var images = new ImageLoaderTransform(env, new ImageLoaderTransform.Arguments() + var images = ImageLoaderTransform.Create(env, new ImageLoaderTransform.Arguments() { Column = new ImageLoaderTransform.Column[1] { @@ -134,27 +201,37 @@ public void TestBackAndForthConversion() }, ImageFolder = imageFolder }, data); - var cropped = new ImageResizerTransform(env, new ImageResizerTransform.Arguments() + var cropped = ImageResizerTransform.Create(env, new ImageResizerTransform.Arguments() { Column = new ImageResizerTransform.Column[1]{ new ImageResizerTransform.Column() { Source = "ImageReal", Name= "ImageCropped", ImageHeight =imageHeight, ImageWidth = imageWidth, Resizing = ImageResizerTransform.ResizingKind.IsoCrop} } }, images); - var pixels = new 
ImagePixelExtractorTransform(env, new ImagePixelExtractorTransform.Arguments() + var pixels = ImagePixelExtractorTransform.Create(env, new ImagePixelExtractorTransform.Arguments() { Column = new ImagePixelExtractorTransform.Column[1]{ new ImagePixelExtractorTransform.Column() { Source= "ImageCropped", Name = "ImagePixels", UseAlpha=true} } }, cropped); - var backToBitmaps = new VectorToImageTransform(env, new VectorToImageTransform.Arguments() + IDataView backToBitmaps = new VectorToImageTransform(env, new VectorToImageTransform.Arguments() { Column = new VectorToImageTransform.Column[1]{ new VectorToImageTransform.Column() { Source= "ImagePixels", Name = "ImageRestored" , ImageHeight=imageHeight, ImageWidth=imageWidth, ContainsAlpha=true} } }, pixels); + var fname = nameof(TestBackAndForthConversion) + "_model.zip"; + + var fh = env.CreateOutputFile(fname); + using (var ch = env.Start("save")) + TrainUtils.SaveModel(env, ch, fh, null, new RoleMappedData(backToBitmaps)); + + backToBitmaps = ModelFileUtils.LoadPipeline(env, fh.OpenReadStream(), new MultiFileSource(dataFile)); + DeleteOutputPath(fname); + + backToBitmaps.Schema.TryGetColumnIndex("ImageRestored", out int bitmapColumn); backToBitmaps.Schema.TryGetColumnIndex("ImageCropped", out int cropBitmapColumn); using (var cursor = backToBitmaps.GetRowCursor((x) => true)) @@ -178,6 +255,7 @@ public void TestBackAndForthConversion() } } } + Done(); } } } diff --git a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs index ed450cd8c0..c04b35669c 100644 --- a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs +++ b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs @@ -185,7 +185,7 @@ public void TensorFlowTransformCifar() var dataFile = GetDataPath("images/images.tsv"); var imageFolder = Path.GetDirectoryName(dataFile); var data = 
env.CreateLoader("Text{col=ImagePath:TX:0 col=Name:TX:1}", new MultiFileSource(dataFile)); - var images = new ImageLoaderTransform(env, new ImageLoaderTransform.Arguments() + var images = ImageLoaderTransform.Create(env, new ImageLoaderTransform.Arguments() { Column = new ImageLoaderTransform.Column[1] { @@ -193,14 +193,14 @@ public void TensorFlowTransformCifar() }, ImageFolder = imageFolder }, data); - var cropped = new ImageResizerTransform(env, new ImageResizerTransform.Arguments() + var cropped = ImageResizerTransform.Create(env, new ImageResizerTransform.Arguments() { Column = new ImageResizerTransform.Column[1]{ new ImageResizerTransform.Column() { Source = "ImageReal", Name= "ImageCropped", ImageHeight =imageHeight, ImageWidth = imageWidth, Resizing = ImageResizerTransform.ResizingKind.IsoCrop} } }, images); - var pixels = new ImagePixelExtractorTransform(env, new ImagePixelExtractorTransform.Arguments() + var pixels = ImagePixelExtractorTransform.Create(env, new ImagePixelExtractorTransform.Arguments() { Column = new ImagePixelExtractorTransform.Column[1]{ new ImagePixelExtractorTransform.Column() { Source= "ImageCropped", Name = "Input", UseAlpha=false, InterleaveArgb=true}