From 70d761196e29744b7ac7b748ed3c959959c9d32c Mon Sep 17 00:00:00 2001 From: Pete Luferenko Date: Mon, 27 Aug 2018 08:53:07 -0700 Subject: [PATCH 01/17] Converted ImageLoaderTransform to be an Estimator/Transformer pair --- src/Microsoft.ML.Core/Data/IEstimator.cs | 16 +- .../EntryPoints/ImageAnalytics.cs | 2 +- .../ImageLoaderTransform.cs | 354 ++++++++++++++---- .../Standard/LinearClassificationTrainer.cs | 13 +- .../Standard/SdcaMultiClass.cs | 12 +- .../Standard/SdcaRegression.cs | 8 +- test/Microsoft.ML.Tests/ImagesTests.cs | 41 +- 7 files changed, 334 insertions(+), 112 deletions(-) diff --git a/src/Microsoft.ML.Core/Data/IEstimator.cs b/src/Microsoft.ML.Core/Data/IEstimator.cs index 6f21a1cb01..da1cd683b0 100644 --- a/src/Microsoft.ML.Core/Data/IEstimator.cs +++ b/src/Microsoft.ML.Core/Data/IEstimator.cs @@ -30,18 +30,18 @@ public enum VectorKind public readonly string Name; public readonly VectorKind Kind; - public readonly DataKind ItemKind; + public readonly ColumnType ItemType; public readonly bool IsKey; public readonly string[] MetadataKinds; - public Column(string name, VectorKind vecKind, DataKind itemKind, bool isKey, string[] metadataKinds = null) + public Column(string name, VectorKind vecKind, ColumnType itemType, bool isKey, string[] metadataKinds = null) { Contracts.CheckNonEmpty(name, nameof(name)); Contracts.CheckValueOrNull(metadataKinds); Name = name; Kind = vecKind; - ItemKind = itemKind; + ItemType = itemType; IsKey = isKey; MetadataKinds = metadataKinds ?? new string[0]; } @@ -51,7 +51,7 @@ public Column(string name, VectorKind vecKind, DataKind itemKind, bool isKey, st /// requirement. /// /// Namely, it returns true iff: - /// - The , , , fields match. + /// - The , , , fields match. /// - The of is a superset of our . 
/// public bool IsCompatibleWith(Column inputColumn) @@ -61,7 +61,7 @@ public bool IsCompatibleWith(Column inputColumn) return false; if (Kind != inputColumn.Kind) return false; - if (ItemKind != inputColumn.ItemKind) + if (!ItemType.Equals(inputColumn.ItemType)) return false; if (IsKey != inputColumn.IsKey) return false; @@ -72,7 +72,7 @@ public bool IsCompatibleWith(Column inputColumn) public string GetTypeString() { - string result = ItemKind.ToString(); + string result = ItemType.ToString(); if (IsKey) result = $"Key<{result}>"; if (Kind == VectorKind.Vector) @@ -110,13 +110,13 @@ public static SchemaShape Create(ISchema schema) else vecKind = Column.VectorKind.Scalar; - var kind = type.ItemType.RawKind; + var itemKind = type.ItemType.RawKind; var isKey = type.ItemType.IsKey; var metadataNames = schema.GetMetadataTypes(iCol) .Select(kvp => kvp.Key) .ToArray(); - cols.Add(new Column(schema.GetColumnName(iCol), vecKind, kind, isKey, metadataNames)); + cols.Add(new Column(schema.GetColumnName(iCol), vecKind, PrimitiveType.FromKind(itemKind), isKey, metadataNames)); } } return new SchemaShape(cols.ToArray()); diff --git a/src/Microsoft.ML.ImageAnalytics/EntryPoints/ImageAnalytics.cs b/src/Microsoft.ML.ImageAnalytics/EntryPoints/ImageAnalytics.cs index 97c613485f..e38ff0daa0 100644 --- a/src/Microsoft.ML.ImageAnalytics/EntryPoints/ImageAnalytics.cs +++ b/src/Microsoft.ML.ImageAnalytics/EntryPoints/ImageAnalytics.cs @@ -16,7 +16,7 @@ public static class ImageAnalytics public static CommonOutputs.TransformOutput ImageLoader(IHostEnvironment env, ImageLoaderTransform.Arguments input) { var h = EntryPointUtils.CheckArgsAndCreateHost(env, "ImageLoaderTransform", input); - var xf = new ImageLoaderTransform(h, input, input.Data); + var xf = ImageLoaderTransform.Create(h, input, input.Data); return new CommonOutputs.TransformOutput() { Model = new TransformModel(h, xf, input.Data), diff --git a/src/Microsoft.ML.ImageAnalytics/ImageLoaderTransform.cs 
b/src/Microsoft.ML.ImageAnalytics/ImageLoaderTransform.cs index 488c710743..d571a981b3 100644 --- a/src/Microsoft.ML.ImageAnalytics/ImageLoaderTransform.cs +++ b/src/Microsoft.ML.ImageAnalytics/ImageLoaderTransform.cs @@ -13,20 +13,81 @@ using Microsoft.ML.Runtime.EntryPoints; using Microsoft.ML.Runtime.Internal.Utilities; using Microsoft.ML.Runtime.Model; +using Microsoft.ML.Core.Data; +using System.Collections.Generic; +using System.Linq; -[assembly: LoadableClass(ImageLoaderTransform.Summary, typeof(ImageLoaderTransform), typeof(ImageLoaderTransform.Arguments), typeof(SignatureDataTransform), +[assembly: LoadableClass(ImageLoaderTransform.Summary, typeof(IDataTransform), typeof(ImageLoaderTransform), typeof(ImageLoaderTransform.Arguments), typeof(SignatureDataTransform), ImageLoaderTransform.UserName, "ImageLoaderTransform", "ImageLoader")] -[assembly: LoadableClass(ImageLoaderTransform.Summary, typeof(ImageLoaderTransform), null, typeof(SignatureLoadDataTransform), +[assembly: LoadableClass(ImageLoaderTransform.Summary, typeof(IDataTransform), typeof(ImageLoaderTransform), null, typeof(SignatureLoadDataTransform), ImageLoaderTransform.UserName, ImageLoaderTransform.LoaderSignature)] +[assembly: LoadableClass(typeof(ImageLoaderTransform), null, typeof(SignatureLoadModel), "", ImageLoaderTransform.LoaderSignature)] + +[assembly: LoadableClass(typeof(IRowMapper), typeof(ImageLoaderTransform.Mapper), null, typeof(SignatureLoadRowMapper), "", ImageLoaderTransform.LoaderSignature)] + namespace Microsoft.ML.Runtime.ImageAnalytics { - // REVIEW: Rewrite as LambdaTransform to simplify. 
+ public abstract class TrivialEstimator : IEstimator + where TTransformer : class, ITransformer + { + protected readonly IHost Host; + protected readonly TTransformer Transformer; + + protected TrivialEstimator(IHost host, TTransformer transformer) + { + Contracts.AssertValue(host); + + Host = host; + Host.CheckValue(transformer, nameof(transformer)); + Transformer = transformer; + } + + public TTransformer Fit(IDataView input) => Transformer; + + public abstract SchemaShape GetOutputSchema(SchemaShape inputSchema); + } + + public sealed class ImageLoaderEstimator : TrivialEstimator + { + private readonly ImageType _imageType; + + public ImageLoaderEstimator(IHostEnvironment env, string imageFolder, params (string input, string output)[] columns) + : this(env, new ImageLoaderTransform(env, imageFolder, columns)) + { + } + + public ImageLoaderEstimator(IHostEnvironment env, ImageLoaderTransform transformer) + : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(ImageLoaderEstimator)), transformer) + { + _imageType = new ImageType(); + } + + public override SchemaShape GetOutputSchema(SchemaShape inputSchema) + { + Host.CheckValue(inputSchema, nameof(inputSchema)); + var result = inputSchema.Columns.ToDictionary(x => x.Name); + foreach (var (input, output) in Transformer.Columns) + { + var col = inputSchema.FindColumn(input); + + if (col == null) + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", input); + if (!col.ItemType.IsText || col.Kind != SchemaShape.Column.VectorKind.Scalar) + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", input, TextType.Instance.ToString(), col.GetTypeString()); + + result[output] = new SchemaShape.Column(output, SchemaShape.Column.VectorKind.Scalar, _imageType, false); + } + + return new SchemaShape(result.Values); + } + } + /// /// Transform which takes one or many columns of type and loads them as /// - public sealed class ImageLoaderTransform : OneToOneTransformBase + public sealed class 
ImageLoaderTransform : ITransformer, ICanSaveModel { public sealed class Column : OneToOneColumn { @@ -61,118 +122,245 @@ public sealed class Arguments : TransformInputBase internal const string UserName = "Image Loader Transform"; public const string LoaderSignature = "ImageLoaderTransform"; - private static VersionInfo GetVersionInfo() - { - return new VersionInfo( - modelSignature: "IMGLOADT", - //verWrittenCur: 0x00010001, // Initial - verWrittenCur: 0x00010002, // Swith from OpenCV to Bitmap - verReadableCur: 0x00010002, - verWeCanReadBack: 0x00010002, - loaderSignature: LoaderSignature); - } - - private readonly ImageType _type; private readonly string _imageFolder; + private readonly (string input, string output)[] _columns; + private readonly IHost _host; - private const string RegistrationName = "ImageLoader"; + public IReadOnlyCollection<(string input, string output)> Columns => _columns.AsReadOnly(); - // Public constructor corresponding to SignatureDataTransform. - public ImageLoaderTransform(IHostEnvironment env, Arguments args, IDataView input) - : base(env, RegistrationName, env.CheckRef(args, nameof(args)).Column, input, TestIsText) + public ImageLoaderTransform(IHostEnvironment env, string imageFolder, params (string input, string output)[] columns) { - Host.AssertNonEmpty(Infos); - _imageFolder = args.ImageFolder; - Host.Assert(Infos.Length == Utils.Size(args.Column)); - _type = new ImageType(); - Metadata.Seal(); + Contracts.CheckValue(env, nameof(env)); + _host = env.Register(nameof(ImageLoaderTransform)); + _host.CheckValueOrNull(imageFolder); + _host.CheckValue(columns, nameof(columns)); + + _imageFolder = imageFolder; + + var newNames = new HashSet(); + foreach (var column in columns) + { + _host.CheckNonEmpty(column.input, nameof(columns)); + _host.CheckNonEmpty(column.output, nameof(columns)); + + if (!newNames.Add(column.output)) + throw Contracts.ExceptParam(nameof(columns), $"Output column '{column.output}' specified multiple times"); + 
} + _columns = columns; } - private ImageLoaderTransform(IHost host, ModelLoadContext ctx, IDataView input) - : base(host, ctx, input, TestIsText) + public static ImageLoaderTransform Create(IHostEnvironment env, ModelLoadContext ctx) { - Host.AssertValue(ctx); + Contracts.CheckValue(env, nameof(env)); + env.CheckValue(ctx, nameof(ctx)); + + ctx.CheckAtModel(GetVersionInfo()); // *** Binary format *** - // - _imageFolder = ctx.Reader.ReadString(); - _type = new ImageType(); - Metadata.Seal(); + // int: number of added columns + // for each added column + // int: id of output column name + // int: id of input column name + // int: id of image folder + + int n = ctx.Reader.ReadInt32(); + var columns = new (string input, string output)[n]; + for (int i = 0; i < n; i++) + { + string output = ctx.LoadNonEmptyString(); + string input = ctx.LoadNonEmptyString(); + columns[i] = (input, output); + } + + string imageFolder = ctx.LoadStringOrNull(); + + return new ImageLoaderTransform(env, imageFolder, columns); } - public static ImageLoaderTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input) + public ISchema GetOutputSchema(ISchema inputSchema) { - Contracts.CheckValue(env, nameof(env)); - var h = env.Register(RegistrationName); - h.CheckValue(ctx, nameof(ctx)); - h.CheckValue(input, nameof(input)); - ctx.CheckAtModel(GetVersionInfo()); - return h.Apply("Loading Model", ch => new ImageLoaderTransform(h, ctx, input)); + _host.CheckValue(inputSchema, nameof(inputSchema)); + + // Check that all the input columns are present and are scalar texts. 
+ foreach (var (input, output) in _columns) + CheckInput(_host, inputSchema, input, out int col); + + return Transform(new EmptyDataView(_host, inputSchema)).Schema; } - public override void Save(ModelSaveContext ctx) + private static void CheckInput(IExceptionContext ctx, ISchema inputSchema, string input, out int srcCol) + { + Contracts.AssertValueOrNull(ctx); + Contracts.AssertValue(inputSchema); + Contracts.AssertNonEmpty(input); + + if (!inputSchema.TryGetColumnIndex(input, out srcCol)) + throw ctx.ExceptSchemaMismatch(nameof(inputSchema), "input", input); + if (!inputSchema.GetColumnType(srcCol).IsText) + throw ctx.ExceptSchemaMismatch(nameof(inputSchema), "input", input, TextType.Instance.ToString(), inputSchema.GetColumnType(srcCol).ToString()); + } + + public IDataView Transform(IDataView input) => CreateDataTransform(input); + + public void Save(ModelSaveContext ctx) => SaveContents(ctx, _imageFolder, _columns); + + private static void SaveContents(ModelSaveContext ctx, string imageFolder, (string input, string output)[] columns) { - Host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); ctx.SetVersionInfo(GetVersionInfo()); // *** Binary format *** - // - ctx.Writer.Write(_imageFolder); - SaveBase(ctx); + // int: number of added columns + // for each added column + // int: id of output column name + // int: id of input column name + // int: id of image folder + + ctx.Writer.Write(columns.Length); + foreach (var (input, output) in columns) + { + ctx.SaveNonEmptyString(output); + ctx.SaveNonEmptyString(input); + } + ctx.SaveStringOrNull(imageFolder); + } + + private static VersionInfo GetVersionInfo() + { + return new VersionInfo( + modelSignature: "IMGLOADR", + //verWrittenCur: 0x00010001, // Initial + verWrittenCur: 0x00010002, // Switch from OpenCV to Bitmap + verReadableCur: 0x00010002, + verWeCanReadBack: 0x00010002, + loaderSignature: LoaderSignature); + } + + public static IDataTransform Create(IHostEnvironment env, ImageLoaderTransform.Arguments 
args, IDataView data) + { + return new ImageLoaderTransform(env, args.ImageFolder, args.Column.Select(x => (x.Source ?? x.Name, x.Name)).ToArray()) + .CreateDataTransform(data); } - protected override ColumnType GetColumnTypeCore(int iinfo) + private IDataTransform CreateDataTransform(IDataView input) { - Host.Check(0 <= iinfo && iinfo < Infos.Length); - return _type; + _host.CheckValue(input, nameof(input)); + + var mapper = new Mapper(_host, _imageFolder, _columns, input.Schema); + return new RowToRowMapperTransform(_host, input, mapper); } - protected override Delegate GetGetterCore(IChannel ch, IRow input, int iinfo, out Action disposer) + internal sealed class Mapper : IRowMapper { - Host.AssertValue(ch, nameof(ch)); - Host.AssertValue(input); - Host.Assert(0 <= iinfo && iinfo < Infos.Length); - disposer = null; - - var getSrc = GetSrcGetter(input, iinfo); - DvText src = default; - ValueGetter del = - (ref Bitmap dst) => + private readonly IHost _host; + private readonly string _imageFolder; + private readonly (string input, string output)[] _columns; + private readonly Dictionary _colMapNewToOld; + private readonly ISchema _inputSchema; + private readonly ImageType _imageType; + + public Mapper(IHostEnvironment env, string imageFolder, (string input, string output)[] columns, ISchema schema) + { + Contracts.CheckValue(env, nameof(env)); + _host = env.Register(nameof(Mapper)); + _host.CheckValueOrNull(imageFolder); + _host.CheckValue(columns, nameof(columns)); + _host.CheckValue(schema, nameof(schema)); + + _colMapNewToOld = new Dictionary(); + for (int i = 0; i < columns.Length; i++) { - if (dst != null) - { - dst.Dispose(); - dst = null; - } + CheckInput(_host, schema, columns[i].input, out int srcCol); + _colMapNewToOld.Add(i, srcCol); + } + + _imageFolder = imageFolder; + _columns = columns; + _inputSchema = schema; + _imageType = new ImageType(); + } + + public static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, ISchema schema) + { + 
Contracts.CheckValue(env, nameof(env)); + env.CheckValue(ctx, nameof(ctx)); + env.CheckValue(schema, nameof(schema)); + + var xf = ImageLoaderTransform.Create(env, ctx); + return new Mapper(env, xf._imageFolder, xf._columns, schema); + } + + public Delegate[] CreateGetters(IRow input, Func activeOutput, out Action disposer) + { + _host.Assert(input.Schema == _inputSchema); + var result = new Delegate[_columns.Length]; + for (int i = 0; i < _columns.Length; i++) + { + if (!activeOutput(i)) + continue; + int srcCol = _colMapNewToOld[i]; + result[i] = MakeGetter(input, i); + } + disposer = null; + return result; + } - getSrc(ref src); + private Delegate MakeGetter(IRow input, int iinfo) + { + _host.AssertValue(input); + _host.Assert(0 <= iinfo && iinfo < _columns.Length); - if (src.Length > 0) + var getSrc = input.GetGetter(_colMapNewToOld[iinfo]); + DvText src = default; + ValueGetter del = + (ref Bitmap dst) => { - // Catch exceptions and pass null through. Should also log failures... - try + if (dst != null) { - string path = src.ToString(); - if (!string.IsNullOrWhiteSpace(_imageFolder)) - path = Path.Combine(_imageFolder, path); - dst = new Bitmap(path); + dst.Dispose(); + dst = null; } - catch (Exception e) + + getSrc(ref src); + + if (src.Length > 0) { - // REVIEW: We catch everything since the documentation for new Bitmap(string) - // appears to be incorrect. When the file isn't found, it throws an ArgumentException, - // while the documentation says FileNotFoundException. Not sure what it will throw - // in other cases, like corrupted file, etc. - - // REVIEW : Log failures. - ch.Info(e.Message); - ch.Info(e.StackTrace); - dst = null; + // Catch exceptions and pass null through. Should also log failures... 
+ try + { + string path = src.ToString(); + if (!string.IsNullOrWhiteSpace(_imageFolder)) + path = Path.Combine(_imageFolder, path); + dst = new Bitmap(path); + } + catch (Exception) + { + // REVIEW: We catch everything since the documentation for new Bitmap(string) + // appears to be incorrect. When the file isn't found, it throws an ArgumentException, + // while the documentation says FileNotFoundException. Not sure what it will throw + // in other cases, like corrupted file, etc. + + // REVIEW : Log failures. + dst = null; + } } - } - }; - return del; + }; + return del; + } + + public Func GetDependencies(Func activeOutput) + { + var active = new bool[_inputSchema.ColumnCount]; + foreach (var pair in _colMapNewToOld) + if (activeOutput(pair.Key)) + active[pair.Value] = true; + return col => active[col]; + } + + public RowMapperColumnInfo[] GetOutputColumns() + => _columns.Select(x => new RowMapperColumnInfo(x.output, _imageType, null)).ToArray(); + + public void Save(ModelSaveContext ctx) => SaveContents(ctx, _imageFolder, _columns); } } } diff --git a/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationTrainer.cs b/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationTrainer.cs index 2da2075728..52d2d3aef0 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationTrainer.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationTrainer.cs @@ -1407,8 +1407,8 @@ public LinearClassificationTrainer(IHostEnvironment env, Arguments args, _positiveInstanceWeight = _args.PositiveInstanceWeight; OutputColumns = new[] { - new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Scalar, DataKind.R4, false), - new SchemaShape.Column(DefaultColumnNames.PredictedLabel, SchemaShape.Column.VectorKind.Scalar, DataKind.BL, false) + new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false), + new SchemaShape.Column(DefaultColumnNames.PredictedLabel, 
SchemaShape.Column.VectorKind.Scalar, BoolType.Instance, false) }; } @@ -1426,7 +1426,8 @@ protected override void CheckLabelCompatible(SchemaShape.Column labelCol) if (labelCol.Kind != SchemaShape.Column.VectorKind.Scalar) error(); - if (!labelCol.IsKey && labelCol.ItemKind != DataKind.R4 && labelCol.ItemKind != DataKind.R8 && labelCol.ItemKind != DataKind.BL) + + if (!labelCol.IsKey && labelCol.ItemType != NumberType.R4 && labelCol.ItemType != NumberType.R8 && !labelCol.ItemType.IsBool) error(); } @@ -1434,17 +1435,17 @@ private static SchemaShape.Column MakeWeightColumn(string weightColumn) { if (weightColumn == null) return null; - return new SchemaShape.Column(weightColumn, SchemaShape.Column.VectorKind.Scalar, DataKind.R4, false); + return new SchemaShape.Column(weightColumn, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false); } private static SchemaShape.Column MakeLabelColumn(string labelColumn) { - return new SchemaShape.Column(labelColumn, SchemaShape.Column.VectorKind.Scalar, DataKind.BL, false); + return new SchemaShape.Column(labelColumn, SchemaShape.Column.VectorKind.Scalar, BoolType.Instance, false); } private static SchemaShape.Column MakeFeatureColumn(string featureColumn) { - return new SchemaShape.Column(featureColumn, SchemaShape.Column.VectorKind.Vector, DataKind.R4, false); + return new SchemaShape.Column(featureColumn, SchemaShape.Column.VectorKind.Vector, NumberType.R4, false); } protected override TScalarPredictor CreatePredictor(VBuffer[] weights, Float[] bias) diff --git a/src/Microsoft.ML.StandardLearners/Standard/SdcaMultiClass.cs b/src/Microsoft.ML.StandardLearners/Standard/SdcaMultiClass.cs index f775be92bd..c49593332b 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/SdcaMultiClass.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/SdcaMultiClass.cs @@ -57,8 +57,8 @@ public SdcaMultiClassTrainer(IHostEnvironment env, Arguments args, _args = args; OutputColumns = new[] { - new 
SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Vector, DataKind.R4, false), - new SchemaShape.Column(DefaultColumnNames.PredictedLabel, SchemaShape.Column.VectorKind.Scalar, DataKind.U4, true) + new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Vector, NumberType.R4, false), + new SchemaShape.Column(DefaultColumnNames.PredictedLabel, SchemaShape.Column.VectorKind.Scalar, NumberType.U4, true) }; } @@ -76,7 +76,7 @@ protected override void CheckLabelCompatible(SchemaShape.Column labelCol) if (labelCol.Kind != SchemaShape.Column.VectorKind.Scalar) error(); - if (!labelCol.IsKey && labelCol.ItemKind != DataKind.R4 && labelCol.ItemKind != DataKind.R8) + if (!labelCol.IsKey && labelCol.ItemType != NumberType.R4 && labelCol.ItemType != NumberType.R8) error(); } @@ -84,17 +84,17 @@ private static SchemaShape.Column MakeWeightColumn(string weightColumn) { if (weightColumn == null) return null; - return new SchemaShape.Column(weightColumn, SchemaShape.Column.VectorKind.Scalar, DataKind.R4, false); + return new SchemaShape.Column(weightColumn, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false); } private static SchemaShape.Column MakeLabelColumn(string labelColumn) { - return new SchemaShape.Column(labelColumn, SchemaShape.Column.VectorKind.Scalar, DataKind.U4, true); + return new SchemaShape.Column(labelColumn, SchemaShape.Column.VectorKind.Scalar, NumberType.U4, true); } private static SchemaShape.Column MakeFeatureColumn(string featureColumn) { - return new SchemaShape.Column(featureColumn, SchemaShape.Column.VectorKind.Vector, DataKind.R4, false); + return new SchemaShape.Column(featureColumn, SchemaShape.Column.VectorKind.Vector, NumberType.R4, false); } /// diff --git a/src/Microsoft.ML.StandardLearners/Standard/SdcaRegression.cs b/src/Microsoft.ML.StandardLearners/Standard/SdcaRegression.cs index 163da10e7b..0b620959c3 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/SdcaRegression.cs +++ 
b/src/Microsoft.ML.StandardLearners/Standard/SdcaRegression.cs @@ -61,7 +61,7 @@ public SdcaRegressionTrainer(IHostEnvironment env, Arguments args, string featur _args = args; OutputColumns = new[] { - new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Scalar, DataKind.R4, false) + new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false) }; } @@ -73,17 +73,17 @@ private static SchemaShape.Column MakeWeightColumn(string weightColumn) { if (weightColumn == null) return null; - return new SchemaShape.Column(weightColumn, SchemaShape.Column.VectorKind.Scalar, DataKind.R4, false); + return new SchemaShape.Column(weightColumn, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false); } private static SchemaShape.Column MakeLabelColumn(string labelColumn) { - return new SchemaShape.Column(labelColumn, SchemaShape.Column.VectorKind.Scalar, DataKind.R4, false); + return new SchemaShape.Column(labelColumn, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false); } private static SchemaShape.Column MakeFeatureColumn(string featureColumn) { - return new SchemaShape.Column(featureColumn, SchemaShape.Column.VectorKind.Vector, DataKind.R4, false); + return new SchemaShape.Column(featureColumn, SchemaShape.Column.VectorKind.Vector, NumberType.R4, false); } protected override LinearRegressionPredictor CreatePredictor(VBuffer[] weights, Float[] bias) diff --git a/test/Microsoft.ML.Tests/ImagesTests.cs b/test/Microsoft.ML.Tests/ImagesTests.cs index a12032400a..e2e13a6e72 100644 --- a/test/Microsoft.ML.Tests/ImagesTests.cs +++ b/test/Microsoft.ML.Tests/ImagesTests.cs @@ -5,9 +5,11 @@ using Microsoft.ML.Runtime.Api; using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.ImageAnalytics; +using Microsoft.ML.Runtime.Model; using Microsoft.ML.TestFramework; using System.Drawing; using System.IO; +using System.Linq; using Xunit; using Xunit.Abstractions; @@ -19,6 +21,30 @@ public 
ImageTests(ITestOutputHelper output) : base(output) { } + [Fact] + public void TestEstimatorSaveLoad() + { + using (var env = new TlcEnvironment()) + { + var dataFile = GetDataPath("images/images.tsv"); + var imageFolder = Path.GetDirectoryName(dataFile); + var data = env.CreateLoader("Text{col=ImagePath:TX:0 col=Name:TX:1}", new MultiFileSource(dataFile)); + + var loader = new ImageLoaderTransform(env, imageFolder, ("ImagePath", "ImageReal")); + using (var file = env.CreateTempFile()) + { + using (var fs = file.CreateWriteStream()) + loader.SaveTo(env, fs); + var loader2 = TransformerChain.LoadFrom(env, file.OpenReadStream()); + var newCols = ((ImageLoaderTransform)loader2.LastTransformer).Columns; + var oldCols = loader.Columns; + Assert.True(newCols + .Zip(oldCols, (x, y) => x == y) + .All(x => x)); + } + } + } + [Fact] public void TestSaveImages() { @@ -27,7 +53,7 @@ public void TestSaveImages() var dataFile = GetDataPath("images/images.tsv"); var imageFolder = Path.GetDirectoryName(dataFile); var data = env.CreateLoader("Text{col=ImagePath:TX:0 col=Name:TX:1}", new MultiFileSource(dataFile)); - var images = new ImageLoaderTransform(env, new ImageLoaderTransform.Arguments() + var images = ImageLoaderTransform.Create(env, new ImageLoaderTransform.Arguments() { Column = new ImageLoaderTransform.Column[1] { @@ -36,13 +62,20 @@ public void TestSaveImages() ImageFolder = imageFolder }, data); - var cropped = new ImageResizerTransform(env, new ImageResizerTransform.Arguments() + IDataView cropped = new ImageResizerTransform(env, new ImageResizerTransform.Arguments() { Column = new ImageResizerTransform.Column[1]{ new ImageResizerTransform.Column() { Name= "ImageCropped", Source = "ImageReal", ImageHeight =100, ImageWidth = 100, Resizing = ImageResizerTransform.ResizingKind.IsoPad} } }, images); + var fh = env.CreateOutputFile("model.zip"); + using (var ch = env.Start("save")) + TrainUtils.SaveModel(env, ch, fh, null, new RoleMappedData(cropped)); + + cropped = 
ModelFileUtils.LoadPipeline(env, fh.OpenReadStream(), new MultiFileSource(dataFile)); + DeleteOutputPath("model.zip"); + cropped.Schema.TryGetColumnIndex("ImagePath", out int pathColumn); cropped.Schema.TryGetColumnIndex("ImageCropped", out int cropBitmapColumn); using (var cursor = cropped.GetRowCursor((x) => true)) @@ -73,7 +106,7 @@ public void TestGreyscaleTransformImages() var dataFile = GetDataPath("images/images.tsv"); var imageFolder = Path.GetDirectoryName(dataFile); var data = env.CreateLoader("Text{col=ImagePath:TX:0 col=Name:TX:1}", new MultiFileSource(dataFile)); - var images = new ImageLoaderTransform(env, new ImageLoaderTransform.Arguments() + var images = ImageLoaderTransform.Create(env, new ImageLoaderTransform.Arguments() { Column = new ImageLoaderTransform.Column[1] { @@ -126,7 +159,7 @@ public void TestBackAndForthConversion() var dataFile = GetDataPath("images/images.tsv"); var imageFolder = Path.GetDirectoryName(dataFile); var data = env.CreateLoader("Text{col=ImagePath:TX:0 col=Name:TX:1}", new MultiFileSource(dataFile)); - var images = new ImageLoaderTransform(env, new ImageLoaderTransform.Arguments() + var images = ImageLoaderTransform.Create(env, new ImageLoaderTransform.Arguments() { Column = new ImageLoaderTransform.Column[1] { From 1eeda489026bf7daa6fe5a15b8586d1f2add8d37 Mon Sep 17 00:00:00 2001 From: Pete Luferenko Date: Mon, 27 Aug 2018 10:07:24 -0700 Subject: [PATCH 02/17] Temp commit --- .../ImageResizerTransform.cs | 135 ++++++++++++------ 1 file changed, 92 insertions(+), 43 deletions(-) diff --git a/src/Microsoft.ML.ImageAnalytics/ImageResizerTransform.cs b/src/Microsoft.ML.ImageAnalytics/ImageResizerTransform.cs index dd1abc9181..a0a108d655 100644 --- a/src/Microsoft.ML.ImageAnalytics/ImageResizerTransform.cs +++ b/src/Microsoft.ML.ImageAnalytics/ImageResizerTransform.cs @@ -3,8 +3,11 @@ // See the LICENSE file in the project root for more information. 
using System; +using System.Collections.Generic; using System.Drawing; +using System.Linq; using System.Text; +using Microsoft.ML.Core.Data; using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.CommandLine; using Microsoft.ML.Runtime.Data; @@ -26,7 +29,7 @@ namespace Microsoft.ML.Runtime.ImageAnalytics /// /// Transform which takes one or many columns of and resize them to provided height and width. /// - public sealed class ImageResizerTransform : OneToOneTransformBase + public sealed class ImageResizerTransform : ITransformer, ICanSaveModel { public enum ResizingKind : byte { @@ -98,23 +101,30 @@ public class Arguments : TransformInputBase } /// - /// Extra information for each column (in addition to ColumnInfo). + /// Information for each column pair. /// - private sealed class ColInfoEx + public sealed class ColumnInfo { + public readonly string Input; + public readonly string Output; + public readonly int Width; public readonly int Height; public readonly ResizingKind Scale; public readonly Anchor Anchor; public readonly ColumnType Type; - public ColInfoEx(int width, int height, ResizingKind scale, Anchor anchor) + public ColumnInfo(string input, string output, int width, int height, ResizingKind scale, Anchor anchor) { + Contracts.CheckNonEmpty(input, nameof(input)); + Contracts.CheckNonEmpty(output, nameof(output)); Contracts.CheckUserArg(width > 0, nameof(Column.ImageWidth)); Contracts.CheckUserArg(height > 0, nameof(Column.ImageHeight)); Contracts.CheckUserArg(Enum.IsDefined(typeof(ResizingKind), scale), nameof(Column.Resizing)); Contracts.CheckUserArg(Enum.IsDefined(typeof(Anchor), anchor), nameof(Column.CropAnchor)); + Input = input; + Output = output; Width = width; Height = height; Scale = scale; @@ -141,57 +151,98 @@ private static VersionInfo GetVersionInfo() private const string RegistrationName = "ImageScaler"; - // This is parallel to Infos. 
- private readonly ColInfoEx[] _exes; + private readonly IHost _host; + private readonly ColumnInfo[] _columns; + + public IReadOnlyCollection Columns => _columns.AsReadOnly(); + + public ImageResizerTransform(IHostEnvironment env, string inputColumn, string outputColumn, + int imageWidth, int imageHeight, ResizingKind resizing = ResizingKind.IsoCrop, Anchor cropAnchor = Anchor.Center) + : this(env, new ColumnInfo(inputColumn, outputColumn, imageWidth, imageHeight, resizing, cropAnchor)) + { + } + + public ImageResizerTransform(IHostEnvironment env, params ColumnInfo[] columns) + { + Contracts.CheckValue(env, nameof(env)); + _host = env.Register(RegistrationName); + _host.CheckValue(columns, nameof(columns)); + + _columns = columns.ToArray(); + } // Public constructor corresponding to SignatureDataTransform. - public ImageResizerTransform(IHostEnvironment env, Arguments args, IDataView input) - : base(env, RegistrationName, env.CheckRef(args, nameof(args)).Column, input, t => t is ImageType ? null : "Expected Image type") + public IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input) { - Host.AssertNonEmpty(Infos); - Host.Assert(Infos.Length == Utils.Size(args.Column)); + Contracts.CheckValue(env, nameof(env)); + env.CheckValue(args, nameof(args)); + env.CheckValue(input, nameof(input)); + + env.CheckValue(args.Column, nameof(args.Column)); - _exes = new ColInfoEx[Infos.Length]; - for (int i = 0; i < _exes.Length; i++) + var cols = new ColumnInfo[args.Column.Length]; + for (int i = 0; i < cols.Length; i++) { var item = args.Column[i]; - _exes[i] = new ColInfoEx( + cols[i] = new ColumnInfo( + item.Source ?? item.Name, + item.Name, item.ImageWidth ?? args.ImageWidth, item.ImageHeight ?? args.ImageHeight, item.Resizing ?? args.Resizing, item.CropAnchor ?? 
args.CropAnchor); } - Metadata.Seal(); + + var transformer = new ImageResizerTransform(env, cols); + return new RowToRowMapperTransform(env, input, transformer.MakeRowMapper(input.Schema)); } - private ImageResizerTransform(IHost host, ModelLoadContext ctx, IDataView input) - : base(host, ctx, input, t => t is ImageType ? null : "Expected Image type") + public ImageResizerTransform(IHostEnvironment env, ModelLoadContext ctx) { - Host.AssertValue(ctx); + Contracts.CheckValue(env, nameof(env)); + _host = env.Register(RegistrationName); + + _host.CheckValue(ctx, nameof(ctx)); + ctx.CheckAtModel(GetVersionInfo()); // *** Binary format *** - // - // + // int: sizeof(float) + // int: number of added columns + // for each added column + // int: id of output column name + // int: id of input column name + // for each added column // int: width // int: height // byte: scaling kind - Host.AssertNonEmpty(Infos); - _exes = new ColInfoEx[Infos.Length]; - for (int i = 0; i < _exes.Length; i++) + int cbFloat = ctx.Reader.ReadInt32(); + ch.CheckDecode(cbFloat == sizeof(Single)); + + int n = ctx.Reader.ReadInt32(); + + var names = new (string input, string output)[n]; + for (int i = 0; i < n; i++) + { + var output = ctx.LoadNonEmptyString(); + var input = ctx.LoadNonEmptyString(); + names[i] = (input, output); + } + + _columns = new ColumnInfo[n]; + for (int i = 0; i < n; i++) { int width = ctx.Reader.ReadInt32(); - Host.CheckDecode(width > 0); + _host.CheckDecode(width > 0); int height = ctx.Reader.ReadInt32(); - Host.CheckDecode(height > 0); + _host.CheckDecode(height > 0); var scale = (ResizingKind)ctx.Reader.ReadByte(); - Host.CheckDecode(Enum.IsDefined(typeof(ResizingKind), scale)); + _host.CheckDecode(Enum.IsDefined(typeof(ResizingKind), scale)); var anchor = (Anchor)ctx.Reader.ReadByte(); - Host.CheckDecode(Enum.IsDefined(typeof(Anchor), anchor)); - _exes[i] = new ColInfoEx(width, height, scale, anchor); + _host.CheckDecode(Enum.IsDefined(typeof(Anchor), anchor)); + 
_columns[i] = new ColumnInfo(names[i].input, names[i].output, width, height, scale, anchor); } - Metadata.Seal(); } public static ImageResizerTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input) @@ -207,15 +258,13 @@ public static ImageResizerTransform Create(IHostEnvironment env, ModelLoadContex // *** Binary format *** // int: sizeof(Float) // - int cbFloat = ctx.Reader.ReadInt32(); - ch.CheckDecode(cbFloat == sizeof(Single)); return new ImageResizerTransform(h, ctx, input); }); } public override void Save(ModelSaveContext ctx) { - Host.CheckValue(ctx, nameof(ctx)); + _host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); ctx.SetVersionInfo(GetVersionInfo()); @@ -229,34 +278,34 @@ public override void Save(ModelSaveContext ctx) ctx.Writer.Write(sizeof(Single)); SaveBase(ctx); - Host.Assert(_exes.Length == Infos.Length); - for (int i = 0; i < _exes.Length; i++) + _host.Assert(_columns.Length == Infos.Length); + for (int i = 0; i < _columns.Length; i++) { - var ex = _exes[i]; + var ex = _columns[i]; ctx.Writer.Write(ex.Width); ctx.Writer.Write(ex.Height); - Host.Assert((ResizingKind)(byte)ex.Scale == ex.Scale); + _host.Assert((ResizingKind)(byte)ex.Scale == ex.Scale); ctx.Writer.Write((byte)ex.Scale); - Host.Assert((Anchor)(byte)ex.Anchor == ex.Anchor); + _host.Assert((Anchor)(byte)ex.Anchor == ex.Anchor); ctx.Writer.Write((byte)ex.Anchor); } } protected override ColumnType GetColumnTypeCore(int iinfo) { - Host.Check(0 <= iinfo && iinfo < Infos.Length); - return _exes[iinfo].Type; + _host.Check(0 <= iinfo && iinfo < Infos.Length); + return _columns[iinfo].Type; } protected override Delegate GetGetterCore(IChannel ch, IRow input, int iinfo, out Action disposer) { - Host.AssertValueOrNull(ch); - Host.AssertValue(input); - Host.Assert(0 <= iinfo && iinfo < Infos.Length); + _host.AssertValueOrNull(ch); + _host.AssertValue(input); + _host.Assert(0 <= iinfo && iinfo < Infos.Length); var src = default(Bitmap); var getSrc = GetSrcGetter(input, 
iinfo); - var ex = _exes[iinfo]; + var ex = _columns[iinfo]; disposer = () => @@ -361,7 +410,7 @@ protected override Delegate GetGetterCore(IChannel ch, IRow input, int iinfo, ou { g.DrawImage(src, destRectangle, srcRectangle, GraphicsUnit.Pixel); } - Host.Assert(dst.Width == ex.Width && dst.Height == ex.Height); + _host.Assert(dst.Width == ex.Width && dst.Height == ex.Height); }; return del; From 5e329100b8b7f9fafe4b8bcc9c47e15aaf88308d Mon Sep 17 00:00:00 2001 From: Pete Luferenko Date: Mon, 27 Aug 2018 11:00:21 -0700 Subject: [PATCH 03/17] Another temp checkin --- .../ImageResizerTransform.cs | 68 ++++++++++--------- 1 file changed, 37 insertions(+), 31 deletions(-) diff --git a/src/Microsoft.ML.ImageAnalytics/ImageResizerTransform.cs b/src/Microsoft.ML.ImageAnalytics/ImageResizerTransform.cs index a0a108d655..b038de826c 100644 --- a/src/Microsoft.ML.ImageAnalytics/ImageResizerTransform.cs +++ b/src/Microsoft.ML.ImageAnalytics/ImageResizerTransform.cs @@ -218,7 +218,7 @@ public ImageResizerTransform(IHostEnvironment env, ModelLoadContext ctx) // byte: scaling kind int cbFloat = ctx.Reader.ReadInt32(); - ch.CheckDecode(cbFloat == sizeof(Single)); + _host.CheckDecode(cbFloat == sizeof(Single)); int n = ctx.Reader.ReadInt32(); @@ -245,56 +245,52 @@ public ImageResizerTransform(IHostEnvironment env, ModelLoadContext ctx) } } - public static ImageResizerTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input) + public static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input) { Contracts.CheckValue(env, nameof(env)); - var h = env.Register(RegistrationName); - h.CheckValue(ctx, nameof(ctx)); - h.CheckValue(input, nameof(input)); - ctx.CheckAtModel(GetVersionInfo()); - return h.Apply("Loading Model", - ch => - { - // *** Binary format *** - // int: sizeof(Float) - // - return new ImageResizerTransform(h, ctx, input); - }); + env.CheckValue(ctx, nameof(ctx)); + env.CheckValue(input, nameof(input)); + + var transformer 
= new ImageResizerTransform(env, ctx); + return new RowToRowMapperTransform(env, input, transformer.MakeRowMapper(input.Schema)); } - public override void Save(ModelSaveContext ctx) + public void Save(ModelSaveContext ctx) { _host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); ctx.SetVersionInfo(GetVersionInfo()); // *** Binary format *** - // int: sizeof(Float) - // + // int: sizeof(float) + // int: number of added columns + // for each added column + // int: id of output column name + // int: id of input column name + // for each added column // int: width // int: height // byte: scaling kind + ctx.Writer.Write(sizeof(Single)); - SaveBase(ctx); - _host.Assert(_columns.Length == Infos.Length); + ctx.Writer.Write(_columns.Length); for (int i = 0; i < _columns.Length; i++) { - var ex = _columns[i]; - ctx.Writer.Write(ex.Width); - ctx.Writer.Write(ex.Height); - _host.Assert((ResizingKind)(byte)ex.Scale == ex.Scale); - ctx.Writer.Write((byte)ex.Scale); - _host.Assert((Anchor)(byte)ex.Anchor == ex.Anchor); - ctx.Writer.Write((byte)ex.Anchor); + ctx.SaveNonEmptyString(_columns[i].Output); + ctx.SaveNonEmptyString(_columns[i].Input); } - } - protected override ColumnType GetColumnTypeCore(int iinfo) - { - _host.Check(0 <= iinfo && iinfo < Infos.Length); - return _columns[iinfo].Type; + foreach (var col in _columns) + { + ctx.Writer.Write(col.Width); + ctx.Writer.Write(col.Height); + _host.Assert((ResizingKind)(byte)col.Scale == col.Scale); + ctx.Writer.Write((byte)col.Scale); + _host.Assert((Anchor)(byte)col.Anchor == col.Anchor); + ctx.Writer.Write((byte)col.Anchor); + } } protected override Delegate GetGetterCore(IChannel ch, IRow input, int iinfo, out Action disposer) @@ -415,5 +411,15 @@ protected override Delegate GetGetterCore(IChannel ch, IRow input, int iinfo, ou return del; } + + public ISchema GetOutputSchema(ISchema inputSchema) + { + throw new NotImplementedException(); + } + + public IDataView Transform(IDataView input) + { + throw new 
NotImplementedException(); + } } } From f0469e13c0729d03e0935e43e8125532515fc5d2 Mon Sep 17 00:00:00 2001 From: Pete Luferenko Date: Mon, 27 Aug 2018 11:59:17 -0700 Subject: [PATCH 04/17] WIP 3 --- .../EntryPoints/ImageAnalytics.cs | 2 +- .../ImageResizerTransform.cs | 307 +++++++++++------- 2 files changed, 194 insertions(+), 115 deletions(-) diff --git a/src/Microsoft.ML.ImageAnalytics/EntryPoints/ImageAnalytics.cs b/src/Microsoft.ML.ImageAnalytics/EntryPoints/ImageAnalytics.cs index e38ff0daa0..a582b121d2 100644 --- a/src/Microsoft.ML.ImageAnalytics/EntryPoints/ImageAnalytics.cs +++ b/src/Microsoft.ML.ImageAnalytics/EntryPoints/ImageAnalytics.cs @@ -29,7 +29,7 @@ public static CommonOutputs.TransformOutput ImageLoader(IHostEnvironment env, Im public static CommonOutputs.TransformOutput ImageResizer(IHostEnvironment env, ImageResizerTransform.Arguments input) { var h = EntryPointUtils.CheckArgsAndCreateHost(env, "ImageResizerTransform", input); - var xf = new ImageResizerTransform(h, input, input.Data); + var xf = ImageResizerTransform.Create(h, input, input.Data); return new CommonOutputs.TransformOutput() { Model = new TransformModel(h, xf, input.Data), diff --git a/src/Microsoft.ML.ImageAnalytics/ImageResizerTransform.cs b/src/Microsoft.ML.ImageAnalytics/ImageResizerTransform.cs index b038de826c..f99300ca85 100644 --- a/src/Microsoft.ML.ImageAnalytics/ImageResizerTransform.cs +++ b/src/Microsoft.ML.ImageAnalytics/ImageResizerTransform.cs @@ -172,7 +172,7 @@ public ImageResizerTransform(IHostEnvironment env, params ColumnInfo[] columns) } // Public constructor corresponding to SignatureDataTransform. 
- public IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input) + public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input) { Contracts.CheckValue(env, nameof(env)); env.CheckValue(args, nameof(args)); @@ -222,7 +222,7 @@ public ImageResizerTransform(IHostEnvironment env, ModelLoadContext ctx) int n = ctx.Reader.ReadInt32(); - var names = new (string input, string output)[n]; + var names = new(string input, string output)[n]; for (int i = 0; i < n; i++) { var output = ctx.LoadNonEmptyString(); @@ -255,9 +255,11 @@ public static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, return new RowToRowMapperTransform(env, input, transformer.MakeRowMapper(input.Schema)); } - public void Save(ModelSaveContext ctx) + public void Save(ModelSaveContext ctx) => SaveContents(_host, ctx, _columns); + + private static void SaveContents(IHostEnvironment env, ModelSaveContext ctx, ColumnInfo[] columns) { - _host.CheckValue(ctx, nameof(ctx)); + env.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); ctx.SetVersionInfo(GetVersionInfo()); @@ -275,151 +277,228 @@ public void Save(ModelSaveContext ctx) ctx.Writer.Write(sizeof(Single)); - ctx.Writer.Write(_columns.Length); - for (int i = 0; i < _columns.Length; i++) + ctx.Writer.Write(columns.Length); + for (int i = 0; i < columns.Length; i++) { - ctx.SaveNonEmptyString(_columns[i].Output); - ctx.SaveNonEmptyString(_columns[i].Input); + ctx.SaveNonEmptyString(columns[i].Output); + ctx.SaveNonEmptyString(columns[i].Input); } - foreach (var col in _columns) + foreach (var col in columns) { ctx.Writer.Write(col.Width); ctx.Writer.Write(col.Height); - _host.Assert((ResizingKind)(byte)col.Scale == col.Scale); + env.Assert((ResizingKind)(byte)col.Scale == col.Scale); ctx.Writer.Write((byte)col.Scale); - _host.Assert((Anchor)(byte)col.Anchor == col.Anchor); + env.Assert((Anchor)(byte)col.Anchor == col.Anchor); ctx.Writer.Write((byte)col.Anchor); } } - protected override 
Delegate GetGetterCore(IChannel ch, IRow input, int iinfo, out Action disposer) + public ISchema GetOutputSchema(ISchema inputSchema) + { + _host.CheckValue(inputSchema, nameof(inputSchema)); + + return Transform(new EmptyDataView(_host, inputSchema)).Schema; + } + + public IDataView Transform(IDataView input) + { + var mapper = MakeRowMapper(input.Schema); + return new RowToRowMapperTransform(_host, input, mapper); + } + + private IRowMapper MakeRowMapper(ISchema schema) + => new Mapper(_host, _columns, schema); + + private static void CheckInput(IExceptionContext ctx, ISchema inputSchema, string input, out int srcCol) + { + Contracts.AssertValueOrNull(ctx); + Contracts.AssertValue(inputSchema); + Contracts.AssertNonEmpty(input); + + if (!inputSchema.TryGetColumnIndex(input, out srcCol)) + throw ctx.ExceptSchemaMismatch(nameof(inputSchema), "input", input); + if (inputSchema.GetColumnType(srcCol) is ImageType) + throw ctx.ExceptSchemaMismatch(nameof(inputSchema), "input", input, "image", inputSchema.GetColumnType(srcCol).ToString()); + } + + internal sealed class Mapper : IRowMapper { - _host.AssertValueOrNull(ch); - _host.AssertValue(input); - _host.Assert(0 <= iinfo && iinfo < Infos.Length); + private readonly IHost _host; + private readonly ColumnInfo[] _columns; + private readonly ISchema _inputSchema; + private readonly Dictionary _colMapNewToOld; - var src = default(Bitmap); - var getSrc = GetSrcGetter(input, iinfo); - var ex = _columns[iinfo]; + public Mapper(IHostEnvironment env, ColumnInfo[] columns, ISchema inputSchema) + { + Contracts.AssertValue(env); + _host = env.Register(nameof(Mapper)); + _host.AssertValue(columns); + _host.AssertValue(inputSchema); - disposer = - () => + _colMapNewToOld = new Dictionary(); + for (int i = 0; i < columns.Length; i++) { - if (src != null) - { - src.Dispose(); - src = null; - } - }; + CheckInput(_host, inputSchema, columns[i].Input, out int srcCol); + _colMapNewToOld.Add(i, srcCol); + } + _columns = columns; + 
_inputSchema = inputSchema; + } - ValueGetter del = - (ref Bitmap dst) => + public Delegate[] CreateGetters(IRow input, Func activeOutput, out Action disposer) + { + _host.Assert(input.Schema == _inputSchema); + var result = new Delegate[_columns.Length]; + var disposers = new Action[_columns.Length]; + for (int i = 0; i < _columns.Length; i++) + { + if (!activeOutput(i)) + continue; + int srcCol = _colMapNewToOld[i]; + result[i] = MakeGetter(input, i, out disposers[i]); + } + disposer = () => { - if (dst != null) - dst.Dispose(); + foreach (var act in disposers) + act(); + }; + return result; + } - getSrc(ref src); - if (src == null || src.Height <= 0 || src.Width <= 0) - return; - if (src.Height == ex.Height && src.Width == ex.Width) - { - dst = src; - return; - } - - int sourceWidth = src.Width; - int sourceHeight = src.Height; - int sourceX = 0; - int sourceY = 0; - int destX = 0; - int destY = 0; - int destWidth = 0; - int destHeight = 0; - float aspect = 0; - float widthAspect = 0; - float heightAspect = 0; - - widthAspect = (float)ex.Width / sourceWidth; - heightAspect = (float)ex.Height / sourceHeight; - - if (ex.Scale == ResizingKind.IsoPad) + public Func GetDependencies(Func activeOutput) + { + var active = new bool[_inputSchema.ColumnCount]; + foreach (var pair in _colMapNewToOld) + if (activeOutput(pair.Key)) + active[pair.Value] = true; + return col => active[col]; + } + + public RowMapperColumnInfo[] GetOutputColumns() + => _columns.Select(x => new RowMapperColumnInfo(x.Output, x.Type, null)).ToArray(); + + public void Save(ModelSaveContext ctx) => SaveContents(_host, ctx, _columns); + + private Delegate MakeGetter(IRow input, int iinfo, out Action disposer) + { + _host.AssertValue(input); + _host.Assert(0 <= iinfo && iinfo < _columns.Length); + + var src = default(Bitmap); + var getSrc = input.GetGetter(_colMapNewToOld[iinfo]); + var ex = _columns[iinfo]; + + disposer = + () => { - widthAspect = (float)ex.Width / sourceWidth; - heightAspect = 
(float)ex.Height / sourceHeight; - if (heightAspect < widthAspect) + if (src != null) { - aspect = heightAspect; - destX = (int)((ex.Width - (sourceWidth * aspect)) / 2); + src.Dispose(); + src = null; } - else + }; + + ValueGetter del = + (ref Bitmap dst) => + { + if (dst != null) + dst.Dispose(); + + getSrc(ref src); + if (src == null || src.Height <= 0 || src.Width <= 0) + return; + if (src.Height == ex.Height && src.Width == ex.Width) { - aspect = widthAspect; - destY = (int)((ex.Height - (sourceHeight * aspect)) / 2); + dst = src; + return; } - destWidth = (int)(sourceWidth * aspect); - destHeight = (int)(sourceHeight * aspect); - } - else - { - if (heightAspect < widthAspect) + int sourceWidth = src.Width; + int sourceHeight = src.Height; + int sourceX = 0; + int sourceY = 0; + int destX = 0; + int destY = 0; + int destWidth = 0; + int destHeight = 0; + float aspect = 0; + float widthAspect = 0; + float heightAspect = 0; + + widthAspect = (float)ex.Width / sourceWidth; + heightAspect = (float)ex.Height / sourceHeight; + + if (ex.Scale == ResizingKind.IsoPad) { - aspect = widthAspect; - switch (ex.Anchor) + widthAspect = (float)ex.Width / sourceWidth; + heightAspect = (float)ex.Height / sourceHeight; + if (heightAspect < widthAspect) { - case Anchor.Top: - destY = 0; - break; - case Anchor.Bottom: - destY = (int)(ex.Height - (sourceHeight * aspect)); - break; - default: - destY = (int)((ex.Height - (sourceHeight * aspect)) / 2); - break; + aspect = heightAspect; + destX = (int)((ex.Width - (sourceWidth * aspect)) / 2); } + else + { + aspect = widthAspect; + destY = (int)((ex.Height - (sourceHeight * aspect)) / 2); + } + + destWidth = (int)(sourceWidth * aspect); + destHeight = (int)(sourceHeight * aspect); } else { - aspect = heightAspect; - switch (ex.Anchor) + if (heightAspect < widthAspect) { - case Anchor.Left: - destX = 0; - break; - case Anchor.Right: - destX = (int)(ex.Width - (sourceWidth * aspect)); - break; - default: - destX = (int)((ex.Width - 
(sourceWidth * aspect)) / 2); - break; + aspect = widthAspect; + switch (ex.Anchor) + { + case Anchor.Top: + destY = 0; + break; + case Anchor.Bottom: + destY = (int)(ex.Height - (sourceHeight * aspect)); + break; + default: + destY = (int)((ex.Height - (sourceHeight * aspect)) / 2); + break; + } + } + else + { + aspect = heightAspect; + switch (ex.Anchor) + { + case Anchor.Left: + destX = 0; + break; + case Anchor.Right: + destX = (int)(ex.Width - (sourceWidth * aspect)); + break; + default: + destX = (int)((ex.Width - (sourceWidth * aspect)) / 2); + break; + } } - } - - destWidth = (int)(sourceWidth * aspect); - destHeight = (int)(sourceHeight * aspect); - } - dst = new Bitmap(ex.Width, ex.Height); - var srcRectangle = new Rectangle(sourceX, sourceY, sourceWidth, sourceHeight); - var destRectangle = new Rectangle(destX, destY, destWidth, destHeight); - using (var g = Graphics.FromImage(dst)) - { - g.DrawImage(src, destRectangle, srcRectangle, GraphicsUnit.Pixel); - } - _host.Assert(dst.Width == ex.Width && dst.Height == ex.Height); - }; - return del; - } + destWidth = (int)(sourceWidth * aspect); + destHeight = (int)(sourceHeight * aspect); + } + dst = new Bitmap(ex.Width, ex.Height); + var srcRectangle = new Rectangle(sourceX, sourceY, sourceWidth, sourceHeight); + var destRectangle = new Rectangle(destX, destY, destWidth, destHeight); + using (var g = Graphics.FromImage(dst)) + { + g.DrawImage(src, destRectangle, srcRectangle, GraphicsUnit.Pixel); + } + _host.Assert(dst.Width == ex.Width && dst.Height == ex.Height); + }; - public ISchema GetOutputSchema(ISchema inputSchema) - { - throw new NotImplementedException(); + return del; + } } - public IDataView Transform(IDataView input) - { - throw new NotImplementedException(); - } } } From d2073eb58d6b1810422ecb1f02ee97ad76049e6c Mon Sep 17 00:00:00 2001 From: Pete Luferenko Date: Mon, 27 Aug 2018 10:07:24 -0700 Subject: [PATCH 05/17] Added image resizers --- .../EntryPoints/ImageAnalytics.cs | 2 +- 
.../ImageLoaderTransform.cs | 70 +-- .../ImageResizerTransform.cs | 498 ++++++++++++------ test/Microsoft.ML.Tests/ImagesTests.cs | 21 +- 4 files changed, 393 insertions(+), 198 deletions(-) diff --git a/src/Microsoft.ML.ImageAnalytics/EntryPoints/ImageAnalytics.cs b/src/Microsoft.ML.ImageAnalytics/EntryPoints/ImageAnalytics.cs index e38ff0daa0..a582b121d2 100644 --- a/src/Microsoft.ML.ImageAnalytics/EntryPoints/ImageAnalytics.cs +++ b/src/Microsoft.ML.ImageAnalytics/EntryPoints/ImageAnalytics.cs @@ -29,7 +29,7 @@ public static CommonOutputs.TransformOutput ImageLoader(IHostEnvironment env, Im public static CommonOutputs.TransformOutput ImageResizer(IHostEnvironment env, ImageResizerTransform.Arguments input) { var h = EntryPointUtils.CheckArgsAndCreateHost(env, "ImageResizerTransform", input); - var xf = new ImageResizerTransform(h, input, input.Data); + var xf = ImageResizerTransform.Create(h, input, input.Data); return new CommonOutputs.TransformOutput() { Model = new TransformModel(h, xf, input.Data), diff --git a/src/Microsoft.ML.ImageAnalytics/ImageLoaderTransform.cs b/src/Microsoft.ML.ImageAnalytics/ImageLoaderTransform.cs index d571a981b3..8bd7f2d600 100644 --- a/src/Microsoft.ML.ImageAnalytics/ImageLoaderTransform.cs +++ b/src/Microsoft.ML.ImageAnalytics/ImageLoaderTransform.cs @@ -49,41 +49,6 @@ protected TrivialEstimator(IHost host, TTransformer transformer) public abstract SchemaShape GetOutputSchema(SchemaShape inputSchema); } - public sealed class ImageLoaderEstimator : TrivialEstimator - { - private readonly ImageType _imageType; - - public ImageLoaderEstimator(IHostEnvironment env, string imageFolder, params (string input, string output)[] columns) - : this(env, new ImageLoaderTransform(env, imageFolder, columns)) - { - } - - public ImageLoaderEstimator(IHostEnvironment env, ImageLoaderTransform transformer) - : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(ImageLoaderEstimator)), transformer) - { - _imageType = new ImageType(); - } - - 
public override SchemaShape GetOutputSchema(SchemaShape inputSchema) - { - Host.CheckValue(inputSchema, nameof(inputSchema)); - var result = inputSchema.Columns.ToDictionary(x => x.Name); - foreach (var (input, output) in Transformer.Columns) - { - var col = inputSchema.FindColumn(input); - - if (input == null) - throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", input); - if (!col.ItemType.IsText || col.Kind != SchemaShape.Column.VectorKind.Scalar) - throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", input, TextType.Instance.ToString(), col.GetTypeString()); - - result[output] = new SchemaShape.Column(output, SchemaShape.Column.VectorKind.Scalar, _imageType, false); - } - - return new SchemaShape(result.Values); - } - } - /// /// Transform which takes one or many columns of type and loads them as /// @@ -363,4 +328,39 @@ public RowMapperColumnInfo[] GetOutputColumns() public void Save(ModelSaveContext ctx) => SaveContents(ctx, _imageFolder, _columns); } } + + public sealed class ImageLoaderEstimator : TrivialEstimator + { + private readonly ImageType _imageType; + + public ImageLoaderEstimator(IHostEnvironment env, string imageFolder, params (string input, string output)[] columns) + : this(env, new ImageLoaderTransform(env, imageFolder, columns)) + { + } + + public ImageLoaderEstimator(IHostEnvironment env, ImageLoaderTransform transformer) + : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(ImageLoaderEstimator)), transformer) + { + _imageType = new ImageType(); + } + + public override SchemaShape GetOutputSchema(SchemaShape inputSchema) + { + Host.CheckValue(inputSchema, nameof(inputSchema)); + var result = inputSchema.Columns.ToDictionary(x => x.Name); + foreach (var (input, output) in Transformer.Columns) + { + var col = inputSchema.FindColumn(input); + + if (col == null) + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", input); + if (!col.ItemType.IsText || col.Kind != SchemaShape.Column.VectorKind.Scalar) + 
throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", input, TextType.Instance.ToString(), col.GetTypeString()); + + result[output] = new SchemaShape.Column(output, SchemaShape.Column.VectorKind.Scalar, _imageType, false); + } + + return new SchemaShape(result.Values); + } + } } diff --git a/src/Microsoft.ML.ImageAnalytics/ImageResizerTransform.cs b/src/Microsoft.ML.ImageAnalytics/ImageResizerTransform.cs index dd1abc9181..bc662eea2c 100644 --- a/src/Microsoft.ML.ImageAnalytics/ImageResizerTransform.cs +++ b/src/Microsoft.ML.ImageAnalytics/ImageResizerTransform.cs @@ -3,8 +3,11 @@ // See the LICENSE file in the project root for more information. using System; +using System.Collections.Generic; using System.Drawing; +using System.Linq; using System.Text; +using Microsoft.ML.Core.Data; using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.CommandLine; using Microsoft.ML.Runtime.Data; @@ -20,13 +23,19 @@ [assembly: LoadableClass(ImageResizerTransform.Summary, typeof(ImageResizerTransform), null, typeof(SignatureLoadDataTransform), ImageResizerTransform.UserName, ImageResizerTransform.LoaderSignature)] +[assembly: LoadableClass(typeof(ImageResizerTransform), null, typeof(SignatureLoadModel), + ImageResizerTransform.UserName, ImageResizerTransform.LoaderSignature)] + +[assembly: LoadableClass(typeof(IRowMapper), typeof(ImageResizerTransform.Mapper), null, typeof(SignatureLoadRowMapper), + ImageResizerTransform.UserName, ImageResizerTransform.LoaderSignature)] + namespace Microsoft.ML.Runtime.ImageAnalytics { // REVIEW: Rewrite as LambdaTransform to simplify. /// /// Transform which takes one or many columns of and resize them to provided height and width. 
/// - public sealed class ImageResizerTransform : OneToOneTransformBase + public sealed class ImageResizerTransform : ITransformer, ICanSaveModel { public enum ResizingKind : byte { @@ -98,23 +107,30 @@ public class Arguments : TransformInputBase } /// - /// Extra information for each column (in addition to ColumnInfo). + /// Information for each column pair. /// - private sealed class ColInfoEx + public sealed class ColumnInfo { + public readonly string Input; + public readonly string Output; + public readonly int Width; public readonly int Height; public readonly ResizingKind Scale; public readonly Anchor Anchor; public readonly ColumnType Type; - public ColInfoEx(int width, int height, ResizingKind scale, Anchor anchor) + public ColumnInfo(string input, string output, int width, int height, ResizingKind scale, Anchor anchor) { + Contracts.CheckNonEmpty(input, nameof(input)); + Contracts.CheckNonEmpty(output, nameof(output)); Contracts.CheckUserArg(width > 0, nameof(Column.ImageWidth)); Contracts.CheckUserArg(height > 0, nameof(Column.ImageHeight)); Contracts.CheckUserArg(Enum.IsDefined(typeof(ResizingKind), scale), nameof(Column.Resizing)); Contracts.CheckUserArg(Enum.IsDefined(typeof(Anchor), anchor), nameof(Column.CropAnchor)); + Input = input; + Output = output; Width = width; Height = height; Scale = scale; @@ -141,230 +157,404 @@ private static VersionInfo GetVersionInfo() private const string RegistrationName = "ImageScaler"; - // This is parallel to Infos. 
- private readonly ColInfoEx[] _exes; + private readonly IHost _host; + private readonly ColumnInfo[] _columns; + + public IReadOnlyCollection Columns => _columns.AsReadOnly(); + + public ImageResizerTransform(IHostEnvironment env, string inputColumn, string outputColumn, + int imageWidth, int imageHeight, ResizingKind resizing = ResizingKind.IsoCrop, Anchor cropAnchor = Anchor.Center) + : this(env, new ColumnInfo(inputColumn, outputColumn, imageWidth, imageHeight, resizing, cropAnchor)) + { + } + + public ImageResizerTransform(IHostEnvironment env, params ColumnInfo[] columns) + { + Contracts.CheckValue(env, nameof(env)); + _host = env.Register(RegistrationName); + _host.CheckValue(columns, nameof(columns)); + + _columns = columns.ToArray(); + } // Public constructor corresponding to SignatureDataTransform. - public ImageResizerTransform(IHostEnvironment env, Arguments args, IDataView input) - : base(env, RegistrationName, env.CheckRef(args, nameof(args)).Column, input, t => t is ImageType ? null : "Expected Image type") + public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input) { - Host.AssertNonEmpty(Infos); - Host.Assert(Infos.Length == Utils.Size(args.Column)); + Contracts.CheckValue(env, nameof(env)); + env.CheckValue(args, nameof(args)); + env.CheckValue(input, nameof(input)); + + env.CheckValue(args.Column, nameof(args.Column)); - _exes = new ColInfoEx[Infos.Length]; - for (int i = 0; i < _exes.Length; i++) + var cols = new ColumnInfo[args.Column.Length]; + for (int i = 0; i < cols.Length; i++) { var item = args.Column[i]; - _exes[i] = new ColInfoEx( + cols[i] = new ColumnInfo( + item.Source ?? item.Name, + item.Name, item.ImageWidth ?? args.ImageWidth, item.ImageHeight ?? args.ImageHeight, item.Resizing ?? args.Resizing, item.CropAnchor ?? 
args.CropAnchor); } - Metadata.Seal(); + + var transformer = new ImageResizerTransform(env, cols); + return new RowToRowMapperTransform(env, input, transformer.MakeRowMapper(input.Schema)); } - private ImageResizerTransform(IHost host, ModelLoadContext ctx, IDataView input) - : base(host, ctx, input, t => t is ImageType ? null : "Expected Image type") + public ImageResizerTransform(IHostEnvironment env, ModelLoadContext ctx) { - Host.AssertValue(ctx); + Contracts.CheckValue(env, nameof(env)); + _host = env.Register(RegistrationName); + + _host.CheckValue(ctx, nameof(ctx)); + ctx.CheckAtModel(GetVersionInfo()); // *** Binary format *** - // - // + // int: sizeof(float) + // int: number of added columns + // for each added column + // int: id of output column name + // int: id of input column name + // for each added column // int: width // int: height // byte: scaling kind - Host.AssertNonEmpty(Infos); - _exes = new ColInfoEx[Infos.Length]; - for (int i = 0; i < _exes.Length; i++) + int cbFloat = ctx.Reader.ReadInt32(); + _host.CheckDecode(cbFloat == sizeof(Single)); + + int n = ctx.Reader.ReadInt32(); + + var names = new (string input, string output)[n]; + for (int i = 0; i < n; i++) + { + var output = ctx.LoadNonEmptyString(); + var input = ctx.LoadNonEmptyString(); + names[i] = (input, output); + } + + _columns = new ColumnInfo[n]; + for (int i = 0; i < n; i++) { int width = ctx.Reader.ReadInt32(); - Host.CheckDecode(width > 0); + _host.CheckDecode(width > 0); int height = ctx.Reader.ReadInt32(); - Host.CheckDecode(height > 0); + _host.CheckDecode(height > 0); var scale = (ResizingKind)ctx.Reader.ReadByte(); - Host.CheckDecode(Enum.IsDefined(typeof(ResizingKind), scale)); + _host.CheckDecode(Enum.IsDefined(typeof(ResizingKind), scale)); var anchor = (Anchor)ctx.Reader.ReadByte(); - Host.CheckDecode(Enum.IsDefined(typeof(Anchor), anchor)); - _exes[i] = new ColInfoEx(width, height, scale, anchor); + _host.CheckDecode(Enum.IsDefined(typeof(Anchor), anchor)); + 
_columns[i] = new ColumnInfo(names[i].input, names[i].output, width, height, scale, anchor); } - Metadata.Seal(); } - public static ImageResizerTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input) + public static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input) { Contracts.CheckValue(env, nameof(env)); - var h = env.Register(RegistrationName); - h.CheckValue(ctx, nameof(ctx)); - h.CheckValue(input, nameof(input)); - ctx.CheckAtModel(GetVersionInfo()); - return h.Apply("Loading Model", - ch => - { - // *** Binary format *** - // int: sizeof(Float) - // - int cbFloat = ctx.Reader.ReadInt32(); - ch.CheckDecode(cbFloat == sizeof(Single)); - return new ImageResizerTransform(h, ctx, input); - }); + env.CheckValue(ctx, nameof(ctx)); + env.CheckValue(input, nameof(input)); + + var transformer = new ImageResizerTransform(env, ctx); + return new RowToRowMapperTransform(env, input, transformer.MakeRowMapper(input.Schema)); } - public override void Save(ModelSaveContext ctx) + public void Save(ModelSaveContext ctx) => SaveContents(_host, ctx, _columns); + + private static void SaveContents(IHostEnvironment env, ModelSaveContext ctx, ColumnInfo[] columns) { - Host.CheckValue(ctx, nameof(ctx)); + env.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); ctx.SetVersionInfo(GetVersionInfo()); // *** Binary format *** - // int: sizeof(Float) - // + // int: sizeof(float) + // int: number of added columns + // for each added column + // int: id of output column name + // int: id of input column name + // for each added column // int: width // int: height // byte: scaling kind - ctx.Writer.Write(sizeof(Single)); - SaveBase(ctx); - Host.Assert(_exes.Length == Infos.Length); - for (int i = 0; i < _exes.Length; i++) + ctx.Writer.Write(sizeof(float)); + + ctx.Writer.Write(columns.Length); + for (int i = 0; i < columns.Length; i++) { - var ex = _exes[i]; - ctx.Writer.Write(ex.Width); - ctx.Writer.Write(ex.Height); - 
Host.Assert((ResizingKind)(byte)ex.Scale == ex.Scale); - ctx.Writer.Write((byte)ex.Scale); - Host.Assert((Anchor)(byte)ex.Anchor == ex.Anchor); - ctx.Writer.Write((byte)ex.Anchor); + ctx.SaveNonEmptyString(columns[i].Output); + ctx.SaveNonEmptyString(columns[i].Input); } + + foreach (var col in columns) + { + ctx.Writer.Write(col.Width); + ctx.Writer.Write(col.Height); + env.Assert((ResizingKind)(byte)col.Scale == col.Scale); + ctx.Writer.Write((byte)col.Scale); + env.Assert((Anchor)(byte)col.Anchor == col.Anchor); + ctx.Writer.Write((byte)col.Anchor); + } + } + + public ISchema GetOutputSchema(ISchema inputSchema) + { + _host.CheckValue(inputSchema, nameof(inputSchema)); + + // Check that all the input columns are present and are images. + foreach (var column in _columns) + CheckInput(_host, inputSchema, column.Input, out int col); + + return Transform(new EmptyDataView(_host, inputSchema)).Schema; } - protected override ColumnType GetColumnTypeCore(int iinfo) + public IDataView Transform(IDataView input) { - Host.Check(0 <= iinfo && iinfo < Infos.Length); - return _exes[iinfo].Type; + var mapper = MakeRowMapper(input.Schema); + return new RowToRowMapperTransform(_host, input, mapper); } - protected override Delegate GetGetterCore(IChannel ch, IRow input, int iinfo, out Action disposer) + private IRowMapper MakeRowMapper(ISchema schema) + => new Mapper(_host, _columns, schema); + + private static void CheckInput(IExceptionContext ctx, ISchema inputSchema, string input, out int srcCol) { - Host.AssertValueOrNull(ch); - Host.AssertValue(input); - Host.Assert(0 <= iinfo && iinfo < Infos.Length); + Contracts.AssertValueOrNull(ctx); + Contracts.AssertValue(inputSchema); + Contracts.AssertNonEmpty(input); + + if (!inputSchema.TryGetColumnIndex(input, out srcCol)) + throw ctx.ExceptSchemaMismatch(nameof(inputSchema), "input", input); + if (!(inputSchema.GetColumnType(srcCol) is ImageType)) + throw ctx.ExceptSchemaMismatch(nameof(inputSchema), "input", input, "image", 
inputSchema.GetColumnType(srcCol).ToString()); + } - var src = default(Bitmap); - var getSrc = GetSrcGetter(input, iinfo); - var ex = _exes[iinfo]; + internal sealed class Mapper : IRowMapper + { + private readonly IHost _host; + private readonly ColumnInfo[] _columns; + private readonly ISchema _inputSchema; + private readonly Dictionary _colMapNewToOld; + + public Mapper(IHostEnvironment env, ColumnInfo[] columns, ISchema inputSchema) + { + Contracts.AssertValue(env); + _host = env.Register(nameof(Mapper)); + _host.AssertValue(columns); + _host.AssertValue(inputSchema); - disposer = - () => + _colMapNewToOld = new Dictionary(); + for (int i = 0; i < columns.Length; i++) { - if (src != null) - { - src.Dispose(); - src = null; - } - }; + CheckInput(_host, inputSchema, columns[i].Input, out int srcCol); + _colMapNewToOld.Add(i, srcCol); + } + _columns = columns; + _inputSchema = inputSchema; + } - ValueGetter del = - (ref Bitmap dst) => + public Delegate[] CreateGetters(IRow input, Func activeOutput, out Action disposer) + { + _host.Assert(input.Schema == _inputSchema); + var result = new Delegate[_columns.Length]; + var disposers = new Action[_columns.Length]; + for (int i = 0; i < _columns.Length; i++) + { + if (!activeOutput(i)) + continue; + int srcCol = _colMapNewToOld[i]; + result[i] = MakeGetter(input, i, out disposers[i]); + } + disposer = () => { - if (dst != null) - dst.Dispose(); + foreach (var act in disposers) + act(); + }; + return result; + } - getSrc(ref src); - if (src == null || src.Height <= 0 || src.Width <= 0) - return; - if (src.Height == ex.Height && src.Width == ex.Width) - { - dst = src; - return; - } - - int sourceWidth = src.Width; - int sourceHeight = src.Height; - int sourceX = 0; - int sourceY = 0; - int destX = 0; - int destY = 0; - int destWidth = 0; - int destHeight = 0; - float aspect = 0; - float widthAspect = 0; - float heightAspect = 0; - - widthAspect = (float)ex.Width / sourceWidth; - heightAspect = (float)ex.Height / 
sourceHeight; - - if (ex.Scale == ResizingKind.IsoPad) + public Func GetDependencies(Func activeOutput) + { + var active = new bool[_inputSchema.ColumnCount]; + foreach (var pair in _colMapNewToOld) + if (activeOutput(pair.Key)) + active[pair.Value] = true; + return col => active[col]; + } + + public RowMapperColumnInfo[] GetOutputColumns() + => _columns.Select(x => new RowMapperColumnInfo(x.Output, x.Type, null)).ToArray(); + + public void Save(ModelSaveContext ctx) => SaveContents(_host, ctx, _columns); + + public static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, ISchema inputSchema) + { + Contracts.CheckValue(env, nameof(env)); + env.CheckValue(ctx, nameof(ctx)); + env.CheckValue(inputSchema, nameof(inputSchema)); + var transformer = new ImageResizerTransform(env, ctx); + return transformer.MakeRowMapper(inputSchema); + } + + private Delegate MakeGetter(IRow input, int iinfo, out Action disposer) + { + _host.AssertValue(input); + _host.Assert(0 <= iinfo && iinfo < _columns.Length); + + var src = default(Bitmap); + var getSrc = input.GetGetter(_colMapNewToOld[iinfo]); + var ex = _columns[iinfo]; + + disposer = + () => { - widthAspect = (float)ex.Width / sourceWidth; - heightAspect = (float)ex.Height / sourceHeight; - if (heightAspect < widthAspect) + if (src != null) { - aspect = heightAspect; - destX = (int)((ex.Width - (sourceWidth * aspect)) / 2); + src.Dispose(); + src = null; } - else + }; + + ValueGetter del = + (ref Bitmap dst) => + { + if (dst != null) + dst.Dispose(); + + getSrc(ref src); + if (src == null || src.Height <= 0 || src.Width <= 0) + return; + if (src.Height == ex.Height && src.Width == ex.Width) { - aspect = widthAspect; - destY = (int)((ex.Height - (sourceHeight * aspect)) / 2); + dst = src; + return; } - destWidth = (int)(sourceWidth * aspect); - destHeight = (int)(sourceHeight * aspect); - } - else - { - if (heightAspect < widthAspect) + int sourceWidth = src.Width; + int sourceHeight = src.Height; + int sourceX = 0; + 
int sourceY = 0; + int destX = 0; + int destY = 0; + int destWidth = 0; + int destHeight = 0; + float aspect = 0; + float widthAspect = 0; + float heightAspect = 0; + + widthAspect = (float)ex.Width / sourceWidth; + heightAspect = (float)ex.Height / sourceHeight; + + if (ex.Scale == ResizingKind.IsoPad) { - aspect = widthAspect; - switch (ex.Anchor) + widthAspect = (float)ex.Width / sourceWidth; + heightAspect = (float)ex.Height / sourceHeight; + if (heightAspect < widthAspect) { - case Anchor.Top: - destY = 0; - break; - case Anchor.Bottom: - destY = (int)(ex.Height - (sourceHeight * aspect)); - break; - default: - destY = (int)((ex.Height - (sourceHeight * aspect)) / 2); - break; + aspect = heightAspect; + destX = (int)((ex.Width - (sourceWidth * aspect)) / 2); } + else + { + aspect = widthAspect; + destY = (int)((ex.Height - (sourceHeight * aspect)) / 2); + } + + destWidth = (int)(sourceWidth * aspect); + destHeight = (int)(sourceHeight * aspect); } else { - aspect = heightAspect; - switch (ex.Anchor) + if (heightAspect < widthAspect) { - case Anchor.Left: - destX = 0; - break; - case Anchor.Right: - destX = (int)(ex.Width - (sourceWidth * aspect)); - break; - default: - destX = (int)((ex.Width - (sourceWidth * aspect)) / 2); - break; + aspect = widthAspect; + switch (ex.Anchor) + { + case Anchor.Top: + destY = 0; + break; + case Anchor.Bottom: + destY = (int)(ex.Height - (sourceHeight * aspect)); + break; + default: + destY = (int)((ex.Height - (sourceHeight * aspect)) / 2); + break; + } } + else + { + aspect = heightAspect; + switch (ex.Anchor) + { + case Anchor.Left: + destX = 0; + break; + case Anchor.Right: + destX = (int)(ex.Width - (sourceWidth * aspect)); + break; + default: + destX = (int)((ex.Width - (sourceWidth * aspect)) / 2); + break; + } + } + + destWidth = (int)(sourceWidth * aspect); + destHeight = (int)(sourceHeight * aspect); } + dst = new Bitmap(ex.Width, ex.Height); + var srcRectangle = new Rectangle(sourceX, sourceY, sourceWidth, 
sourceHeight); + var destRectangle = new Rectangle(destX, destY, destWidth, destHeight); + using (var g = Graphics.FromImage(dst)) + { + g.DrawImage(src, destRectangle, srcRectangle, GraphicsUnit.Pixel); + } + _host.Assert(dst.Width == ex.Width && dst.Height == ex.Height); + }; - destWidth = (int)(sourceWidth * aspect); - destHeight = (int)(sourceHeight * aspect); - } - dst = new Bitmap(ex.Width, ex.Height); - var srcRectangle = new Rectangle(sourceX, sourceY, sourceWidth, sourceHeight); - var destRectangle = new Rectangle(destX, destY, destWidth, destHeight); - using (var g = Graphics.FromImage(dst)) - { - g.DrawImage(src, destRectangle, srcRectangle, GraphicsUnit.Pixel); - } - Host.Assert(dst.Width == ex.Width && dst.Height == ex.Height); - }; + return del; + } + } + } + + public sealed class ImageResizerEstimator : TrivialEstimator + { + public ImageResizerEstimator(IHostEnvironment env, string inputColumn, string outputColumn, + int imageWidth, int imageHeight, ImageResizerTransform.ResizingKind resizing = ImageResizerTransform.ResizingKind.IsoCrop, ImageResizerTransform.Anchor cropAnchor = ImageResizerTransform.Anchor.Center) + : this(env, new ImageResizerTransform(env, inputColumn, outputColumn, imageWidth, imageHeight, resizing, cropAnchor)) + { + } + + public ImageResizerEstimator(IHostEnvironment env, params ImageResizerTransform.ColumnInfo[] columns) + : this(env, new ImageResizerTransform(env, columns)) + { + } + + public ImageResizerEstimator(IHostEnvironment env, ImageResizerTransform transformer) + : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(ImageResizerEstimator)), transformer) + { + } + + public override SchemaShape GetOutputSchema(SchemaShape inputSchema) + { + Host.CheckValue(inputSchema, nameof(inputSchema)); + var result = inputSchema.Columns.ToDictionary(x => x.Name); + foreach (var colInfo in Transformer.Columns) + { + var col = inputSchema.FindColumn(colInfo.Input); + + if (col == null) + throw 
Host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.Input); + if (!(col.ItemType is ImageType) || col.Kind != SchemaShape.Column.VectorKind.Scalar) + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.Input, new ImageType().ToString(), col.GetTypeString()); + + result[colInfo.Output] = new SchemaShape.Column(colInfo.Output, SchemaShape.Column.VectorKind.Scalar, colInfo.Type, false); + } - return del; + return new SchemaShape(result.Values); } } } diff --git a/test/Microsoft.ML.Tests/ImagesTests.cs b/test/Microsoft.ML.Tests/ImagesTests.cs index e2e13a6e72..d8795e10e7 100644 --- a/test/Microsoft.ML.Tests/ImagesTests.cs +++ b/test/Microsoft.ML.Tests/ImagesTests.cs @@ -30,14 +30,19 @@ public void TestEstimatorSaveLoad() var imageFolder = Path.GetDirectoryName(dataFile); var data = env.CreateLoader("Text{col=ImagePath:TX:0 col=Name:TX:1}", new MultiFileSource(dataFile)); - var loader = new ImageLoaderTransform(env, imageFolder, ("ImagePath", "ImageReal")); + var pipe = new ImageLoaderEstimator(env, imageFolder, ("ImagePath", "ImageReal")) + .Append(new ImageResizerEstimator(env, "ImageReal", "ImageReal", 100, 100)); + + var model = pipe.Fit(data); + using (var file = env.CreateTempFile()) { using (var fs = file.CreateWriteStream()) - loader.SaveTo(env, fs); - var loader2 = TransformerChain.LoadFrom(env, file.OpenReadStream()); - var newCols = ((ImageLoaderTransform)loader2.LastTransformer).Columns; - var oldCols = loader.Columns; + model.SaveTo(env, fs); + var model2 = TransformerChain.LoadFrom(env, file.OpenReadStream()); + + var newCols = ((ImageLoaderTransform)model2.First()).Columns; + var oldCols = ((ImageLoaderTransform)model.First()).Columns; Assert.True(newCols .Zip(oldCols, (x, y) => x == y) .All(x => x)); @@ -62,7 +67,7 @@ public void TestSaveImages() ImageFolder = imageFolder }, data); - IDataView cropped = new ImageResizerTransform(env, new ImageResizerTransform.Arguments() + IDataView cropped = 
ImageResizerTransform.Create(env, new ImageResizerTransform.Arguments() { Column = new ImageResizerTransform.Column[1]{ new ImageResizerTransform.Column() { Name= "ImageCropped", Source = "ImageReal", ImageHeight =100, ImageWidth = 100, Resizing = ImageResizerTransform.ResizingKind.IsoPad} @@ -114,7 +119,7 @@ public void TestGreyscaleTransformImages() }, ImageFolder = imageFolder }, data); - var cropped = new ImageResizerTransform(env, new ImageResizerTransform.Arguments() + var cropped = ImageResizerTransform.Create(env, new ImageResizerTransform.Arguments() { Column = new ImageResizerTransform.Column[1]{ new ImageResizerTransform.Column() { Name= "ImageCropped", Source = "ImageReal", ImageHeight =imageHeight, ImageWidth = imageWidth, Resizing = ImageResizerTransform.ResizingKind.IsoCrop} @@ -167,7 +172,7 @@ public void TestBackAndForthConversion() }, ImageFolder = imageFolder }, data); - var cropped = new ImageResizerTransform(env, new ImageResizerTransform.Arguments() + var cropped = ImageResizerTransform.Create(env, new ImageResizerTransform.Arguments() { Column = new ImageResizerTransform.Column[1]{ new ImageResizerTransform.Column() { Source = "ImageReal", Name= "ImageCropped", ImageHeight =imageHeight, ImageWidth = imageWidth, Resizing = ImageResizerTransform.ResizingKind.IsoCrop} From f05ec52dcb8f92f377843a360ed42764ad3c071b Mon Sep 17 00:00:00 2001 From: Pete Luferenko Date: Mon, 27 Aug 2018 15:36:13 -0700 Subject: [PATCH 06/17] Pixel extractor --- .../EntryPoints/ImageAnalytics.cs | 2 +- .../ImagePixelExtractorTransform.cs | 683 ++++++++++++------ .../ImageResizerTransform.cs | 8 +- test/Microsoft.ML.Tests/ImagesTests.cs | 24 +- 4 files changed, 478 insertions(+), 239 deletions(-) diff --git a/src/Microsoft.ML.ImageAnalytics/EntryPoints/ImageAnalytics.cs b/src/Microsoft.ML.ImageAnalytics/EntryPoints/ImageAnalytics.cs index a582b121d2..c21822be21 100644 --- a/src/Microsoft.ML.ImageAnalytics/EntryPoints/ImageAnalytics.cs +++ 
b/src/Microsoft.ML.ImageAnalytics/EntryPoints/ImageAnalytics.cs @@ -42,7 +42,7 @@ public static CommonOutputs.TransformOutput ImageResizer(IHostEnvironment env, I public static CommonOutputs.TransformOutput ImagePixelExtractor(IHostEnvironment env, ImagePixelExtractorTransform.Arguments input) { var h = EntryPointUtils.CheckArgsAndCreateHost(env, "ImagePixelExtractorTransform", input); - var xf = new ImagePixelExtractorTransform(h, input, input.Data); + var xf = ImagePixelExtractorTransform.Create(h, input, input.Data); return new CommonOutputs.TransformOutput() { Model = new TransformModel(h, xf, input.Data), diff --git a/src/Microsoft.ML.ImageAnalytics/ImagePixelExtractorTransform.cs b/src/Microsoft.ML.ImageAnalytics/ImagePixelExtractorTransform.cs index de0aa98124..af24cbf560 100644 --- a/src/Microsoft.ML.ImageAnalytics/ImagePixelExtractorTransform.cs +++ b/src/Microsoft.ML.ImageAnalytics/ImagePixelExtractorTransform.cs @@ -3,8 +3,11 @@ // See the LICENSE file in the project root for more information. 
using System; +using System.Collections.Generic; using System.Drawing; +using System.Linq; using System.Text; +using Microsoft.ML.Core.Data; using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.CommandLine; using Microsoft.ML.Runtime.Data; @@ -13,10 +16,16 @@ using Microsoft.ML.Runtime.Internal.Utilities; using Microsoft.ML.Runtime.Model; -[assembly: LoadableClass(ImagePixelExtractorTransform.Summary, typeof(ImagePixelExtractorTransform), typeof(ImagePixelExtractorTransform.Arguments), typeof(SignatureDataTransform), +[assembly: LoadableClass(ImagePixelExtractorTransform.Summary, typeof(IDataTransform), typeof(ImagePixelExtractorTransform), typeof(ImagePixelExtractorTransform.Arguments), typeof(SignatureDataTransform), ImagePixelExtractorTransform.UserName, "ImagePixelExtractorTransform", "ImagePixelExtractor")] -[assembly: LoadableClass(ImagePixelExtractorTransform.Summary, typeof(ImagePixelExtractorTransform), null, typeof(SignatureLoadDataTransform), +[assembly: LoadableClass(ImagePixelExtractorTransform.Summary, typeof(IDataTransform), typeof(ImagePixelExtractorTransform), null, typeof(SignatureLoadDataTransform), + ImagePixelExtractorTransform.UserName, ImagePixelExtractorTransform.LoaderSignature)] + +[assembly: LoadableClass(typeof(ImagePixelExtractorTransform), null, typeof(SignatureLoadModel), + ImagePixelExtractorTransform.UserName, ImagePixelExtractorTransform.LoaderSignature)] + +[assembly: LoadableClass(typeof(IRowMapper), typeof(ImagePixelExtractorTransform.Mapper), null, typeof(SignatureLoadRowMapper), ImagePixelExtractorTransform.UserName, ImagePixelExtractorTransform.LoaderSignature)] namespace Microsoft.ML.Runtime.ImageAnalytics @@ -25,7 +34,7 @@ namespace Microsoft.ML.Runtime.ImageAnalytics /// /// Transform which takes one or many columns of and convert them into vector representation. 
/// - public sealed class ImagePixelExtractorTransform : OneToOneTransformBase + public sealed class ImagePixelExtractorTransform : ITransformer, ICanSaveModel { public class Column : OneToOneColumn { @@ -110,24 +119,28 @@ public class Arguments : TransformInputBase /// Which color channels are extracted. Note that these values are serialized so should not be modified. /// [Flags] - private enum ColorBits : byte + public enum ColorBits : byte { Alpha = 0x01, Red = 0x02, Green = 0x04, Blue = 0x08, + Rgb = Red | Green | Blue, All = Alpha | Red | Green | Blue } - private sealed class ColInfoEx + public sealed class ColumnInfo { + public readonly string Input; + public readonly string Output; + public readonly ColorBits Colors; public readonly byte Planes; public readonly bool Convert; - public readonly Single Offset; - public readonly Single Scale; + public readonly float Offset; + public readonly float Scale; public readonly bool Interleave; public bool Alpha { get { return (Colors & ColorBits.Alpha) != 0; } } @@ -135,8 +148,14 @@ private sealed class ColInfoEx public bool Green { get { return (Colors & ColorBits.Green) != 0; } } public bool Blue { get { return (Colors & ColorBits.Blue) != 0; } } - public ColInfoEx(Column item, Arguments args) + internal ColumnInfo(Column item, Arguments args) { + Contracts.CheckValue(item, nameof(item)); + Contracts.CheckValue(args, nameof(args)); + + Input = item.Source ?? item.Name; + Output = item.Name; + if (item.UseAlpha ?? args.UseAlpha) { Colors |= ColorBits.Alpha; Planes++; } if (item.UseRed ?? args.UseRed) { Colors |= ColorBits.Red; Planes++; } if (item.UseGreen ?? 
args.UseGreen) { Colors |= ColorBits.Green; Planes++; } @@ -160,10 +179,57 @@ public ColInfoEx(Column item, Arguments args) } } - public ColInfoEx(ModelLoadContext ctx) + public ColumnInfo(string input, string output, ColorBits colors = ColorBits.Rgb, bool interleave = false) + : this(input, output, colors, interleave, false, 1f, 0f) + { + } + + public ColumnInfo(string input, string output, ColorBits colors = ColorBits.Rgb, bool interleave = false, float scale = 1f, float offset = 0f) + : this(input, output, colors, interleave, true, scale, offset) + { + } + + private ColumnInfo(string input, string output, ColorBits colors, bool interleave, bool convert, float scale, float offset) + { + Contracts.CheckNonEmpty(input, nameof(input)); + Contracts.CheckNonEmpty(output, nameof(output)); + + Input = input; + Output = output; + Colors = colors; + + if ((Colors & ColorBits.Alpha) == ColorBits.Alpha) Planes++; + if ((Colors & ColorBits.Red) == ColorBits.Red) Planes++; + if ((Colors & ColorBits.Green) == ColorBits.Green) Planes++; + if ((Colors & ColorBits.Blue) == ColorBits.Blue) Planes++; + Contracts.CheckParam(Planes > 0, nameof(colors), "Need to use at least one color plane"); + + Interleave = interleave; + + Convert = convert; + if (!Convert) + { + Offset = 0; + Scale = 1; + } + else + { + Offset = offset; + Scale = scale; + Contracts.CheckParam(FloatUtils.IsFinite(Offset), nameof(offset)); + Contracts.CheckParam(FloatUtils.IsFiniteNonZero(Scale), nameof(scale)); + } + } + + internal ColumnInfo(string input, string output, ModelLoadContext ctx) { + Contracts.AssertNonEmpty(input); + Contracts.AssertNonEmpty(output); Contracts.AssertValue(ctx); + Input = input; + Output = output; + // *** Binary format *** // byte: colors // byte: convert @@ -237,305 +303,468 @@ private static VersionInfo GetVersionInfo() private const string RegistrationName = "ImagePixelExtractor"; - private readonly ColInfoEx[] _exes; - private readonly VectorType[] _types; + private readonly IHost 
_host; + private readonly ColumnInfo[] _columns; + + public IReadOnlyCollection Columns => _columns.AsReadOnly(); - // Public constructor corresponding to SignatureDataTransform. - public ImagePixelExtractorTransform(IHostEnvironment env, Arguments args, IDataView input) - : base(env, RegistrationName, Contracts.CheckRef(args, nameof(args)).Column, input, - t => t is ImageType ? null : "Expected Image type") + public ImagePixelExtractorTransform(IHostEnvironment env, string inputColumn, string outputColumn, + ColorBits colors = ColorBits.Rgb, bool interleave = false) + : this(env, new ColumnInfo(inputColumn, outputColumn, colors, interleave)) { - Host.AssertNonEmpty(Infos); - Host.Assert(Infos.Length == Utils.Size(args.Column)); + } + + public ImagePixelExtractorTransform(IHostEnvironment env, params ColumnInfo[] columns) + { + Contracts.CheckValue(env, nameof(env)); + _host = env.Register(RegistrationName); + _host.CheckValue(columns, nameof(columns)); + + _columns = columns.ToArray(); + } + + // SignatureDataTransform. + public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input) + { + Contracts.CheckValue(env, nameof(env)); + env.CheckValue(args, nameof(args)); + env.CheckValue(input, nameof(input)); + + env.CheckValue(args.Column, nameof(args.Column)); - _exes = new ColInfoEx[Infos.Length]; - for (int i = 0; i < _exes.Length; i++) + var columns = new ColumnInfo[args.Column.Length]; + for (int i = 0; i < columns.Length; i++) { var item = args.Column[i]; - _exes[i] = new ColInfoEx(item, args); + columns[i] = new ColumnInfo(item, args); } - _types = ConstructTypes(true); + var transformer = new ImagePixelExtractorTransform(env, columns); + return new RowToRowMapperTransform(env, input, transformer.MakeRowMapper(input.Schema)); } - private ImagePixelExtractorTransform(IHost host, ModelLoadContext ctx, IDataView input) - : base(host, ctx, input, t => t is ImageType ? 
null : "Expected Image type") + public ImagePixelExtractorTransform(IHostEnvironment env, ModelLoadContext ctx) { - Host.AssertValue(ctx); + Contracts.CheckValue(env, nameof(env)); + _host = env.Register(RegistrationName); + _host.CheckValue(ctx, nameof(ctx)); + ctx.CheckAtModel(GetVersionInfo()); // *** Binary format *** - // - // + // int: number of added columns + // for each added column + // int: id of output column name + // int: id of input column name + // foreach added column - // ColInfoEx - Host.AssertNonEmpty(Infos); - _exes = new ColInfoEx[Infos.Length]; - for (int i = 0; i < _exes.Length; i++) - _exes[i] = new ColInfoEx(ctx); + // ColInfo + + int n = ctx.Reader.ReadInt32(); + + var names = new (string input, string output)[n]; + for (int i = 0; i < n; i++) + { + var output = ctx.LoadNonEmptyString(); + var input = ctx.LoadNonEmptyString(); + names[i] = (input, output); + } - _types = ConstructTypes(false); + _columns = new ColumnInfo[n]; + for (int i = 0; i < _columns.Length; i++) + _columns[i] = new ColumnInfo(names[i].input, names[i].output, ctx); } - public static ImagePixelExtractorTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input) + // Factory method for SignatureLoadDataTransform. 
+ public static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input) { Contracts.CheckValue(env, nameof(env)); - var h = env.Register(RegistrationName); - h.CheckValue(ctx, nameof(ctx)); - h.CheckValue(input, nameof(input)); - ctx.CheckAtModel(GetVersionInfo()); + env.CheckValue(ctx, nameof(ctx)); + env.CheckValue(input, nameof(input)); - return h.Apply("Loading Model", - ch => - { - // *** Binary format *** - // int: sizeof(Float) - // - int cbFloat = ctx.Reader.ReadInt32(); - ch.CheckDecode(cbFloat == sizeof(Single)); - return new ImagePixelExtractorTransform(h, ctx, input); - }); + var transformer = new ImagePixelExtractorTransform(env, ctx); + return new RowToRowMapperTransform(env, input, transformer.MakeRowMapper(input.Schema)); } - public override void Save(ModelSaveContext ctx) + public void Save(ModelSaveContext ctx) => SaveContents(_host, ctx, _columns); + + private static void SaveContents(IHostEnvironment env, ModelSaveContext ctx, ColumnInfo[] columns) { - Host.CheckValue(ctx, nameof(ctx)); + Contracts.AssertValue(env); + env.CheckValue(ctx, nameof(ctx)); + Contracts.AssertValue(columns); + ctx.CheckAtModel(); ctx.SetVersionInfo(GetVersionInfo()); // *** Binary format *** - // int: sizeof(Float) - // - // foreach added column - // ColInfoEx - ctx.Writer.Write(sizeof(Single)); - SaveBase(ctx); + // int: number of added columns + // for each added column + // int: id of output column name + // int: id of input column name - Host.Assert(_exes.Length == Infos.Length); - for (int i = 0; i < _exes.Length; i++) - _exes[i].Save(ctx); - } + // foreach added column + // ColInfo - private VectorType[] ConstructTypes(bool user) - { - var types = new VectorType[Infos.Length]; - for (int i = 0; i < Infos.Length; i++) + ctx.Writer.Write(columns.Length); + for (int i = 0; i < columns.Length; i++) { - var info = Infos[i]; - var ex = _exes[i]; - Host.Assert(ex.Planes > 0); - - var type = Source.Schema.GetColumnType(info.Source) as 
ImageType; - Host.Assert(type != null); - if (type.Height <= 0 || type.Width <= 0) - { - // REVIEW: Could support this case by making the destination column be variable sized. - // However, there's no mechanism to communicate the dimensions through with the pixel data. - string name = Source.Schema.GetColumnName(info.Source); - throw user ? - Host.ExceptUserArg(nameof(Arguments.Column), "Column '{0}' does not have known size", name) : - Host.Except("Column '{0}' does not have known size", name); - } - int height = type.Height; - int width = type.Width; - Host.Assert(height > 0); - Host.Assert(width > 0); - Host.Assert((long)height * width <= int.MaxValue / 4); - - if (ex.Interleave) - types[i] = new VectorType(ex.Convert ? NumberType.Float : NumberType.U1, height, width, ex.Planes); - else - types[i] = new VectorType(ex.Convert ? NumberType.Float : NumberType.U1, ex.Planes, height, width); + ctx.SaveNonEmptyString(columns[i].Output); + ctx.SaveNonEmptyString(columns[i].Input); } - Metadata.Seal(); - return types; + + for (int i = 0; i < columns.Length; i++) + columns[i].Save(ctx); } - protected override ColumnType GetColumnTypeCore(int iinfo) + private IRowMapper MakeRowMapper(ISchema schema) + => new Mapper(_host, _columns, schema); + + private static void CheckInput(IExceptionContext ctx, ISchema inputSchema, string input, out int srcCol) { - Host.Assert(0 <= iinfo & iinfo < Infos.Length); - return _types[iinfo]; + Contracts.AssertValueOrNull(ctx); + Contracts.AssertValue(inputSchema); + Contracts.AssertNonEmpty(input); + + if (!inputSchema.TryGetColumnIndex(input, out srcCol)) + throw ctx.ExceptSchemaMismatch(nameof(inputSchema), "input", input); + var imageType = inputSchema.GetColumnType(srcCol) as ImageType; + if (imageType == null) + throw ctx.ExceptSchemaMismatch(nameof(inputSchema), "input", input, "image", inputSchema.GetColumnType(srcCol).ToString()); + if (imageType.Height <= 0 || imageType.Width <= 0) + throw 
ctx.ExceptSchemaMismatch(nameof(inputSchema), "input", input, "known-size image", "unknown-size image"); + if ((long)imageType.Height * imageType.Width > int.MaxValue / 4) + throw ctx.Except("Image dimensions are too large"); } - protected override Delegate GetGetterCore(IChannel ch, IRow input, int iinfo, out Action disposer) + public ISchema GetOutputSchema(ISchema inputSchema) { - Host.AssertValueOrNull(ch); - Host.AssertValue(input); - Host.Assert(0 <= iinfo && iinfo < Infos.Length); + _host.CheckValue(inputSchema, nameof(inputSchema)); + + // Check that all the input columns are present and are images of known size. + foreach (var column in _columns) + CheckInput(_host, inputSchema, column.Input, out int col); - if (_exes[iinfo].Convert) - return GetGetterCore(input, iinfo, out disposer); - return GetGetterCore(input, iinfo, out disposer); + return Transform(new EmptyDataView(_host, inputSchema)).Schema; } - //REVIEW Rewrite it to where TValue : IConvertible - private ValueGetter> GetGetterCore(IRow input, int iinfo, out Action disposer) + public IDataView Transform(IDataView input) { - var type = _types[iinfo]; - Host.Assert(type.DimCount == 3); + _host.CheckValue(input, nameof(input)); - var ex = _exes[iinfo]; + var mapper = MakeRowMapper(input.Schema); + return new RowToRowMapperTransform(_host, input, mapper); + } + + internal sealed class Mapper : IRowMapper + { + private readonly IHost _host; + private readonly ColumnInfo[] _columns; + private readonly VectorType[] _types; + private readonly ISchema _inputSchema; + private readonly Dictionary _colMapNewToOld; - int planes = ex.Interleave ? type.GetDim(2) : type.GetDim(0); - int height = ex.Interleave ? type.GetDim(0) : type.GetDim(1); - int width = ex.Interleave ? 
type.GetDim(1) : type.GetDim(2); + public Mapper(IHostEnvironment env, ColumnInfo[] columns, ISchema inputSchema) + { + Contracts.AssertValue(env); + _host = env.Register(nameof(Mapper)); + _host.AssertValue(columns); + _host.AssertValue(inputSchema); - int size = type.ValueCount; - Host.Assert(size > 0); - Host.Assert(size == planes * height * width); - int cpix = height * width; + _colMapNewToOld = new Dictionary(); + for (int i = 0; i < columns.Length; i++) + { + CheckInput(_host, inputSchema, columns[i].Input, out int srcCol); + _colMapNewToOld.Add(i, srcCol); + } - var getSrc = GetSrcGetter(input, iinfo); - var src = default(Bitmap); + _columns = columns; + _inputSchema = inputSchema; + _types = ConstructTypes(); + } - disposer = - () => + public Delegate[] CreateGetters(IRow input, Func activeOutput, out Action disposer) + { + _host.Assert(input.Schema == _inputSchema); + var result = new Delegate[_columns.Length]; + var disposers = new Action[_columns.Length]; + for (int i = 0; i < _columns.Length; i++) { - if (src != null) - { - src.Dispose(); - src = null; - } + if (!activeOutput(i)) + continue; + int srcCol = _colMapNewToOld[i]; + result[i] = MakeGetter(input, i, out disposers[i]); + } + disposer = () => + { + foreach (var act in disposers) + act(); }; + return result; + } - return - (ref VBuffer dst) => - { - getSrc(ref src); - Contracts.AssertValueOrNull(src); + public Func GetDependencies(Func activeOutput) + { + var active = new bool[_inputSchema.ColumnCount]; + foreach (var pair in _colMapNewToOld) + if (activeOutput(pair.Key)) + active[pair.Value] = true; + return col => active[col]; + } - if (src == null) - { - dst = new VBuffer(size, 0, dst.Values, dst.Indices); - return; - } + public RowMapperColumnInfo[] GetOutputColumns() + => _columns.Select((x, idx) => new RowMapperColumnInfo(x.Output, _types[idx], null)).ToArray(); + + public void Save(ModelSaveContext ctx) => SaveContents(_host, ctx, _columns); - Host.Check(src.PixelFormat == 
System.Drawing.Imaging.PixelFormat.Format32bppArgb); - Host.Check(src.Height == height && src.Width == width); + // Factory method for SignatureLoadRowMapper. + public static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, ISchema inputSchema) + { + Contracts.CheckValue(env, nameof(env)); + env.CheckValue(ctx, nameof(ctx)); + env.CheckValue(inputSchema, nameof(inputSchema)); + var transformer = new ImagePixelExtractorTransform(env, ctx); + return transformer.MakeRowMapper(inputSchema); + } + + private Delegate MakeGetter(IRow input, int iinfo, out Action disposer) + { + _host.AssertValue(input); + _host.Assert(0 <= iinfo && iinfo < _columns.Length); + + if (_columns[iinfo].Convert) + return GetGetterCore(input, iinfo, out disposer); + return GetGetterCore(input, iinfo, out disposer); + } - var values = dst.Values; - if (Utils.Size(values) < size) - values = new TValue[size]; + //REVIEW Rewrite it to where TValue : IConvertible + private ValueGetter> GetGetterCore(IRow input, int iinfo, out Action disposer) + { + var type = _types[iinfo]; + _host.Assert(type.DimCount == 3); - Single offset = ex.Offset; - Single scale = ex.Scale; - Host.Assert(scale != 0); + var ex = _columns[iinfo]; - var vf = values as Single[]; - var vb = values as byte[]; - Host.Assert(vf != null || vb != null); - bool needScale = offset != 0 || scale != 1; - Host.Assert(!needScale || vf != null); + int planes = ex.Interleave ? type.GetDim(2) : type.GetDim(0); + int height = ex.Interleave ? type.GetDim(0) : type.GetDim(1); + int width = ex.Interleave ? 
type.GetDim(1) : type.GetDim(2); - bool a = ex.Alpha; - bool r = ex.Red; - bool g = ex.Green; - bool b = ex.Blue; + int size = type.ValueCount; + _host.Assert(size > 0); + _host.Assert(size == planes * height * width); + int cpix = height * width; - int h = height; - int w = width; + var getSrc = input.GetGetter(_colMapNewToOld[iinfo]); + var src = default(Bitmap); - if (ex.Interleave) + disposer = + () => { - int idst = 0; - for (int y = 0; y < h; ++y) - for (int x = 0; x < w; x++) - { - var pb = src.GetPixel(y, x); - if (vb != null) + if (src != null) + { + src.Dispose(); + src = null; + } + }; + + return + (ref VBuffer dst) => + { + getSrc(ref src); + Contracts.AssertValueOrNull(src); + + if (src == null) + { + dst = new VBuffer(size, 0, dst.Values, dst.Indices); + return; + } + + _host.Check(src.PixelFormat == System.Drawing.Imaging.PixelFormat.Format32bppArgb); + _host.Check(src.Height == height && src.Width == width); + + var values = dst.Values; + if (Utils.Size(values) < size) + values = new TValue[size]; + + Single offset = ex.Offset; + Single scale = ex.Scale; + _host.Assert(scale != 0); + + var vf = values as Single[]; + var vb = values as byte[]; + _host.Assert(vf != null || vb != null); + bool needScale = offset != 0 || scale != 1; + _host.Assert(!needScale || vf != null); + + bool a = ex.Alpha; + bool r = ex.Red; + bool g = ex.Green; + bool b = ex.Blue; + + int h = height; + int w = width; + + if (ex.Interleave) + { + int idst = 0; + for (int y = 0; y < h; ++y) + for (int x = 0; x < w; x++) { - if (a) { vb[idst++] = (byte)0; } - if (r) { vb[idst++] = pb.R; } - if (g) { vb[idst++] = pb.G; } - if (b) { vb[idst++] = pb.B; } + var pb = src.GetPixel(y, x); + if (vb != null) + { + if (a) { vb[idst++] = (byte)0; } + if (r) { vb[idst++] = pb.R; } + if (g) { vb[idst++] = pb.G; } + if (b) { vb[idst++] = pb.B; } + } + else if (!needScale) + { + if (a) { vf[idst++] = 0.0f; } + if (r) { vf[idst++] = pb.R; } + if (g) { vf[idst++] = pb.G; } + if (b) { vf[idst++] = 
pb.B; } + } + else + { + if (a) { vf[idst++] = 0.0f; } + if (r) { vf[idst++] = (pb.R - offset) * scale; } + if (g) { vf[idst++] = (pb.B - offset) * scale; } + if (b) { vf[idst++] = (pb.G - offset) * scale; } + } } - else if (!needScale) + _host.Assert(idst == size); + } + else + { + int idstMin = 0; + if (ex.Alpha) + { + // The image only has rgb but we need to supply alpha as well, so fake it up, + // assuming that it is 0xFF. + if (vf != null) { - if (a) { vf[idst++] = 0.0f; } - if (r) { vf[idst++] = pb.R; } - if (g) { vf[idst++] = pb.G; } - if (b) { vf[idst++] = pb.B; } + Single v = (0xFF - offset) * scale; + for (int i = 0; i < cpix; i++) + vf[i] = v; } else { - if (a) { vf[idst++] = 0.0f; } - if (r) { vf[idst++] = (pb.R - offset) * scale; } - if (g) { vf[idst++] = (pb.B - offset) * scale; } - if (b) { vf[idst++] = (pb.G - offset) * scale; } + for (int i = 0; i < cpix; i++) + vb[i] = 0xFF; } - } - Host.Assert(idst == size); - } - else - { - int idstMin = 0; - if (ex.Alpha) - { - // The image only has rgb but we need to supply alpha as well, so fake it up, - // assuming that it is 0xFF. - if (vf != null) - { - Single v = (0xFF - offset) * scale; - for (int i = 0; i < cpix; i++) - vf[i] = v; - } - else - { - for (int i = 0; i < cpix; i++) - vb[i] = 0xFF; - } - idstMin = cpix; + idstMin = cpix; // We've preprocessed alpha, avoid it in the // scan operation below. a = false; - } + } - for (int y = 0; y < h; ++y) - { - int idstBase = idstMin + y * w; + for (int y = 0; y < h; ++y) + { + int idstBase = idstMin + y * w; // Note that the bytes are in order BGR[A]. We arrange the layers in order ARGB. 
if (vb != null) - { - for (int x = 0; x < w; x++, idstBase++) { - var pb = src.GetPixel(x, y); - int idst = idstBase; - if (a) { vb[idst] = pb.A; idst += cpix; } - if (r) { vb[idst] = pb.R; idst += cpix; } - if (g) { vb[idst] = pb.G; idst += cpix; } - if (b) { vb[idst] = pb.B; idst += cpix; } + for (int x = 0; x < w; x++, idstBase++) + { + var pb = src.GetPixel(x, y); + int idst = idstBase; + if (a) { vb[idst] = pb.A; idst += cpix; } + if (r) { vb[idst] = pb.R; idst += cpix; } + if (g) { vb[idst] = pb.G; idst += cpix; } + if (b) { vb[idst] = pb.B; idst += cpix; } + } } - } - else if (!needScale) - { - for (int x = 0; x < w; x++, idstBase++) + else if (!needScale) { - var pb = src.GetPixel(x, y); - int idst = idstBase; - if (a) { vf[idst] = pb.A; idst += cpix; } - if (r) { vf[idst] = pb.R; idst += cpix; } - if (g) { vf[idst] = pb.G; idst += cpix; } - if (b) { vf[idst] = pb.B; idst += cpix; } + for (int x = 0; x < w; x++, idstBase++) + { + var pb = src.GetPixel(x, y); + int idst = idstBase; + if (a) { vf[idst] = pb.A; idst += cpix; } + if (r) { vf[idst] = pb.R; idst += cpix; } + if (g) { vf[idst] = pb.G; idst += cpix; } + if (b) { vf[idst] = pb.B; idst += cpix; } + } } - } - else - { - for (int x = 0; x < w; x++, idstBase++) + else { - var pb = src.GetPixel(x, y); - int idst = idstBase; - if (a) { vf[idst] = (pb.A - offset) * scale; idst += cpix; } - if (r) { vf[idst] = (pb.R - offset) * scale; idst += cpix; } - if (g) { vf[idst] = (pb.G - offset) * scale; idst += cpix; } - if (b) { vf[idst] = (pb.B - offset) * scale; idst += cpix; } + for (int x = 0; x < w; x++, idstBase++) + { + var pb = src.GetPixel(x, y); + int idst = idstBase; + if (a) { vf[idst] = (pb.A - offset) * scale; idst += cpix; } + if (r) { vf[idst] = (pb.R - offset) * scale; idst += cpix; } + if (g) { vf[idst] = (pb.G - offset) * scale; idst += cpix; } + if (b) { vf[idst] = (pb.B - offset) * scale; idst += cpix; } + } } } } - } - dst = new VBuffer(size, values, dst.Indices); - }; + dst = new 
VBuffer(size, values, dst.Indices); + }; + } + + private VectorType[] ConstructTypes() + { + var types = new VectorType[_columns.Length]; + for (int i = 0; i < _columns.Length; i++) + { + var column = _columns[i]; + Contracts.Assert(column.Planes > 0); + + var type = _inputSchema.GetColumnType(_colMapNewToOld[i]) as ImageType; + Contracts.Assert(type != null); + + int height = type.Height; + int width = type.Width; + Contracts.Assert(height > 0); + Contracts.Assert(width > 0); + Contracts.Assert((long)height * width <= int.MaxValue / 4); + + if (column.Interleave) + types[i] = new VectorType(column.Convert ? NumberType.Float : NumberType.U1, height, width, column.Planes); + else + types[i] = new VectorType(column.Convert ? NumberType.Float : NumberType.U1, column.Planes, height, width); + } + return types; + } + } + } + + public sealed class ImagePixelExtractorEstimator: TrivialEstimator + { + public ImagePixelExtractorEstimator(IHostEnvironment env, string inputColumn, string outputColumn, + ImagePixelExtractorTransform.ColorBits colors = ImagePixelExtractorTransform.ColorBits.Rgb, bool interleave = false) + : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(ImagePixelExtractorEstimator)), new ImagePixelExtractorTransform(env, inputColumn, outputColumn, colors, interleave)) + { + } + + public ImagePixelExtractorEstimator(IHostEnvironment env, params ImagePixelExtractorTransform.ColumnInfo[] columns) + : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(ImagePixelExtractorEstimator)), new ImagePixelExtractorTransform(env, columns)) + { + } + + public override SchemaShape GetOutputSchema(SchemaShape inputSchema) + { + Host.CheckValue(inputSchema, nameof(inputSchema)); + var result = inputSchema.Columns.ToDictionary(x => x.Name); + foreach (var colInfo in Transformer.Columns) + { + var col = inputSchema.FindColumn(colInfo.Input); + + if (col == null) + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.Input); + if (!(col.ItemType 
is ImageType) || col.Kind != SchemaShape.Column.VectorKind.Scalar) + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.Input, new ImageType().ToString(), col.GetTypeString()); + + var itemType = colInfo.Convert ? NumberType.R4 : NumberType.U1; + result[colInfo.Output] = new SchemaShape.Column(colInfo.Output, SchemaShape.Column.VectorKind.Vector, itemType, false); + } + + return new SchemaShape(result.Values); } } } diff --git a/src/Microsoft.ML.ImageAnalytics/ImageResizerTransform.cs b/src/Microsoft.ML.ImageAnalytics/ImageResizerTransform.cs index bc662eea2c..0cea5199de 100644 --- a/src/Microsoft.ML.ImageAnalytics/ImageResizerTransform.cs +++ b/src/Microsoft.ML.ImageAnalytics/ImageResizerTransform.cs @@ -177,7 +177,7 @@ public ImageResizerTransform(IHostEnvironment env, params ColumnInfo[] columns) _columns = columns.ToArray(); } - // Public constructor corresponding to SignatureDataTransform. + // Factory method for SignatureDataTransform. public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input) { Contracts.CheckValue(env, nameof(env)); @@ -251,6 +251,7 @@ public ImageResizerTransform(IHostEnvironment env, ModelLoadContext ctx) } } + // Factory method for SignatureLoadDataTransform. 
public static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input) { Contracts.CheckValue(env, nameof(env)); @@ -265,7 +266,10 @@ public static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, private static void SaveContents(IHostEnvironment env, ModelSaveContext ctx, ColumnInfo[] columns) { + Contracts.AssertValue(env); env.CheckValue(ctx, nameof(ctx)); + Contracts.AssertValue(columns); + ctx.CheckAtModel(); ctx.SetVersionInfo(GetVersionInfo()); @@ -314,6 +318,8 @@ public ISchema GetOutputSchema(ISchema inputSchema) public IDataView Transform(IDataView input) { + _host.CheckValue(input, nameof(input)); + var mapper = MakeRowMapper(input.Schema); return new RowToRowMapperTransform(_host, input, mapper); } diff --git a/test/Microsoft.ML.Tests/ImagesTests.cs b/test/Microsoft.ML.Tests/ImagesTests.cs index d8795e10e7..3f2b7626b2 100644 --- a/test/Microsoft.ML.Tests/ImagesTests.cs +++ b/test/Microsoft.ML.Tests/ImagesTests.cs @@ -31,7 +31,8 @@ public void TestEstimatorSaveLoad() var data = env.CreateLoader("Text{col=ImagePath:TX:0 col=Name:TX:1}", new MultiFileSource(dataFile)); var pipe = new ImageLoaderEstimator(env, imageFolder, ("ImagePath", "ImageReal")) - .Append(new ImageResizerEstimator(env, "ImageReal", "ImageReal", 100, 100)); + .Append(new ImageResizerEstimator(env, "ImageReal", "ImageReal", 100, 100)) + .Append(new ImagePixelExtractorEstimator(env, "ImageReal", "ImagePixels")); var model = pipe.Fit(data); @@ -74,13 +75,6 @@ public void TestSaveImages() } }, images); - var fh = env.CreateOutputFile("model.zip"); - using (var ch = env.Start("save")) - TrainUtils.SaveModel(env, ch, fh, null, new RoleMappedData(cropped)); - - cropped = ModelFileUtils.LoadPipeline(env, fh.OpenReadStream(), new MultiFileSource(dataFile)); - DeleteOutputPath("model.zip"); - cropped.Schema.TryGetColumnIndex("ImagePath", out int pathColumn); cropped.Schema.TryGetColumnIndex("ImageCropped", out int cropBitmapColumn); using (var cursor = 
cropped.GetRowCursor((x) => true)) @@ -179,20 +173,30 @@ public void TestBackAndForthConversion() } }, images); - var pixels = new ImagePixelExtractorTransform(env, new ImagePixelExtractorTransform.Arguments() + var pixels = ImagePixelExtractorTransform.Create(env, new ImagePixelExtractorTransform.Arguments() { Column = new ImagePixelExtractorTransform.Column[1]{ new ImagePixelExtractorTransform.Column() { Source= "ImageCropped", Name = "ImagePixels", UseAlpha=true} } }, cropped); - var backToBitmaps = new VectorToImageTransform(env, new VectorToImageTransform.Arguments() + IDataView backToBitmaps = new VectorToImageTransform(env, new VectorToImageTransform.Arguments() { Column = new VectorToImageTransform.Column[1]{ new VectorToImageTransform.Column() { Source= "ImagePixels", Name = "ImageRestored" , ImageHeight=imageHeight, ImageWidth=imageWidth, ContainsAlpha=true} } }, pixels); + var fname = nameof(TestBackAndForthConversion) + "_model.zip"; + + var fh = env.CreateOutputFile(fname); + using (var ch = env.Start("save")) + TrainUtils.SaveModel(env, ch, fh, null, new RoleMappedData(backToBitmaps)); + + backToBitmaps = ModelFileUtils.LoadPipeline(env, fh.OpenReadStream(), new MultiFileSource(dataFile)); + DeleteOutputPath(fname); + + backToBitmaps.Schema.TryGetColumnIndex("ImageRestored", out int bitmapColumn); backToBitmaps.Schema.TryGetColumnIndex("ImageCropped", out int cropBitmapColumn); using (var cursor = backToBitmaps.GetRowCursor((x) => true)) From a03994b764f249d0aa37193d1801be36f495f3e4 Mon Sep 17 00:00:00 2001 From: Pete Luferenko Date: Mon, 27 Aug 2018 15:48:17 -0700 Subject: [PATCH 07/17] Minor fix --- .../ImagePixelExtractorTransform.cs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/Microsoft.ML.ImageAnalytics/ImagePixelExtractorTransform.cs b/src/Microsoft.ML.ImageAnalytics/ImagePixelExtractorTransform.cs index af24cbf560..3c275fe5dc 100644 --- a/src/Microsoft.ML.ImageAnalytics/ImagePixelExtractorTransform.cs +++ 
b/src/Microsoft.ML.ImageAnalytics/ImagePixelExtractorTransform.cs @@ -356,8 +356,8 @@ public ImagePixelExtractorTransform(IHostEnvironment env, ModelLoadContext ctx) // int: id of output column name // int: id of input column name - // foreach added column - // ColInfo + // for each added column + // ColumnInfo int n = ctx.Reader.ReadInt32(); @@ -402,8 +402,8 @@ private static void SaveContents(IHostEnvironment env, ModelSaveContext ctx, Col // int: id of output column name // int: id of input column name - // foreach added column - // ColInfo + // for each added column + // ColumnInfo ctx.Writer.Write(columns.Length); for (int i = 0; i < columns.Length; i++) From 55b33afe64bd77d5a31f9ca7f6ddd60ca2deb638 Mon Sep 17 00:00:00 2001 From: Pete Luferenko Date: Mon, 27 Aug 2018 15:59:09 -0700 Subject: [PATCH 08/17] Fixed build --- src/Microsoft.ML.Data/Transforms/CopyColumnsTransform.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Microsoft.ML.Data/Transforms/CopyColumnsTransform.cs b/src/Microsoft.ML.Data/Transforms/CopyColumnsTransform.cs index 184c0226bb..3c47f84faa 100644 --- a/src/Microsoft.ML.Data/Transforms/CopyColumnsTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/CopyColumnsTransform.cs @@ -70,7 +70,7 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema) var originalColumn = inputSchema.FindColumn(column.Source); if (originalColumn != null) { - var col = new SchemaShape.Column(column.Name, originalColumn.Kind, originalColumn.ItemKind, originalColumn.IsKey, originalColumn.MetadataKinds); + var col = new SchemaShape.Column(column.Name, originalColumn.Kind, originalColumn.ItemType, originalColumn.IsKey, originalColumn.MetadataKinds); resultDic[column.Name] = col; } else From ac46be9f4c154ab46d4bc343056cb5dfd15c57a0 Mon Sep 17 00:00:00 2001 From: Pete Luferenko Date: Mon, 27 Aug 2018 17:54:58 -0700 Subject: [PATCH 09/17] PR comments --- .../Microsoft.ML.Console.csproj | 1 + src/Microsoft.ML.Core/Data/IEstimator.cs | 22 
++++++++++++ .../DataLoadSave/TrivialEstimator.cs | 35 +++++++++++++++++++ .../ImageLoaderTransform.cs | 20 ----------- 4 files changed, 58 insertions(+), 20 deletions(-) create mode 100644 src/Microsoft.ML.Data/DataLoadSave/TrivialEstimator.cs diff --git a/src/Microsoft.ML.Console/Microsoft.ML.Console.csproj b/src/Microsoft.ML.Console/Microsoft.ML.Console.csproj index f9a1b5ef27..1471c580ba 100644 --- a/src/Microsoft.ML.Console/Microsoft.ML.Console.csproj +++ b/src/Microsoft.ML.Console/Microsoft.ML.Console.csproj @@ -16,6 +16,7 @@ + diff --git a/src/Microsoft.ML.Core/Data/IEstimator.cs b/src/Microsoft.ML.Core/Data/IEstimator.cs index da1cd683b0..a27ea0e5d9 100644 --- a/src/Microsoft.ML.Core/Data/IEstimator.cs +++ b/src/Microsoft.ML.Core/Data/IEstimator.cs @@ -28,16 +28,38 @@ public enum VectorKind VariableVector } + /// + /// The column name. + /// public readonly string Name; + + /// + /// The type of the column: scalar, fixed vector or variable vector. + /// public readonly VectorKind Kind; + + /// + /// The 'raw' type of column item: must be a primitive type or a structured type. + /// public readonly ColumnType ItemType; + /// + /// The flag whether the column is actually a key. If yes, is representing + /// the underlying primitive type. + /// public readonly bool IsKey; + /// + /// The metadata kinds that are present for this column. 
+ /// public readonly string[] MetadataKinds; public Column(string name, VectorKind vecKind, ColumnType itemType, bool isKey, string[] metadataKinds = null) { Contracts.CheckNonEmpty(name, nameof(name)); Contracts.CheckValueOrNull(metadataKinds); + Contracts.CheckParam(!itemType.IsKey, nameof(itemType), "Item type cannot be a key"); + Contracts.CheckParam(!itemType.IsVector, nameof(itemType), "Item type cannot be a vector"); + + Contracts.CheckParam(!isKey || KeyType.IsValidDataKind(itemType.RawKind), nameof(itemType), "The item type must be valid for a key"); Name = name; Kind = vecKind; diff --git a/src/Microsoft.ML.Data/DataLoadSave/TrivialEstimator.cs b/src/Microsoft.ML.Data/DataLoadSave/TrivialEstimator.cs new file mode 100644 index 0000000000..29c081ac35 --- /dev/null +++ b/src/Microsoft.ML.Data/DataLoadSave/TrivialEstimator.cs @@ -0,0 +1,35 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.ML.Core.Data; + +namespace Microsoft.ML.Runtime.Data +{ + /// + /// The trivial implementation of that already has + /// the transformer and returns it on every call to . + /// + /// Concrete implementations still have to provide the schema propagation mechanism, since + /// there is no easy way to infer it from the transformer. 
+ /// + public abstract class TrivialEstimator : IEstimator + where TTransformer : class, ITransformer + { + protected readonly IHost Host; + protected readonly TTransformer Transformer; + + protected TrivialEstimator(IHost host, TTransformer transformer) + { + Contracts.AssertValue(host); + + Host = host; + Host.CheckValue(transformer, nameof(transformer)); + Transformer = transformer; + } + + public TTransformer Fit(IDataView input) => Transformer; + + public abstract SchemaShape GetOutputSchema(SchemaShape inputSchema); + } +} diff --git a/src/Microsoft.ML.ImageAnalytics/ImageLoaderTransform.cs b/src/Microsoft.ML.ImageAnalytics/ImageLoaderTransform.cs index 8bd7f2d600..aa19da7ff4 100644 --- a/src/Microsoft.ML.ImageAnalytics/ImageLoaderTransform.cs +++ b/src/Microsoft.ML.ImageAnalytics/ImageLoaderTransform.cs @@ -29,26 +29,6 @@ namespace Microsoft.ML.Runtime.ImageAnalytics { - public abstract class TrivialEstimator : IEstimator - where TTransformer : class, ITransformer - { - protected readonly IHost Host; - protected readonly TTransformer Transformer; - - protected TrivialEstimator(IHost host, TTransformer transformer) - { - Contracts.AssertValue(host); - - Host = host; - Host.CheckValue(transformer, nameof(transformer)); - Transformer = transformer; - } - - public TTransformer Fit(IDataView input) => Transformer; - - public abstract SchemaShape GetOutputSchema(SchemaShape inputSchema); - } - /// /// Transform which takes one or many columns of type and loads them as /// From 11aa1fc8a84e79ee0bc38a0ae4eb3445beabf278 Mon Sep 17 00:00:00 2001 From: Pete Luferenko Date: Mon, 27 Aug 2018 20:10:50 -0700 Subject: [PATCH 10/17] Added grayscale transform --- .../EntryPoints/ImageAnalytics.cs | 2 +- .../ImageGrayscaleTransform.cs | 310 ++++++++++++++---- .../ImageLoaderTransform.cs | 22 +- test/Microsoft.ML.Tests/ImagesTests.cs | 15 +- 4 files changed, 271 insertions(+), 78 deletions(-) diff --git a/src/Microsoft.ML.ImageAnalytics/EntryPoints/ImageAnalytics.cs 
b/src/Microsoft.ML.ImageAnalytics/EntryPoints/ImageAnalytics.cs index c21822be21..921309d7ef 100644 --- a/src/Microsoft.ML.ImageAnalytics/EntryPoints/ImageAnalytics.cs +++ b/src/Microsoft.ML.ImageAnalytics/EntryPoints/ImageAnalytics.cs @@ -55,7 +55,7 @@ public static CommonOutputs.TransformOutput ImagePixelExtractor(IHostEnvironment public static CommonOutputs.TransformOutput ImageGrayscale(IHostEnvironment env, ImageGrayscaleTransform.Arguments input) { var h = EntryPointUtils.CheckArgsAndCreateHost(env, "ImageGrayscaleTransform", input); - var xf = new ImageGrayscaleTransform(h, input, input.Data); + var xf = ImageGrayscaleTransform.Create(h, input, input.Data); return new CommonOutputs.TransformOutput() { Model = new TransformModel(h, xf, input.Data), diff --git a/src/Microsoft.ML.ImageAnalytics/ImageGrayscaleTransform.cs b/src/Microsoft.ML.ImageAnalytics/ImageGrayscaleTransform.cs index 7a267cf1b8..103888184d 100644 --- a/src/Microsoft.ML.ImageAnalytics/ImageGrayscaleTransform.cs +++ b/src/Microsoft.ML.ImageAnalytics/ImageGrayscaleTransform.cs @@ -2,22 +2,30 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
-using System; -using System.Drawing; -using System.Drawing.Imaging; -using System.Text; +using Microsoft.ML.Core.Data; using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.CommandLine; using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.EntryPoints; -using Microsoft.ML.Runtime.Internal.Utilities; -using Microsoft.ML.Runtime.Model; using Microsoft.ML.Runtime.ImageAnalytics; +using Microsoft.ML.Runtime.Model; +using System; +using System.Collections.Generic; +using System.Drawing; +using System.Drawing.Imaging; +using System.Linq; +using System.Text; -[assembly: LoadableClass(ImageGrayscaleTransform.Summary, typeof(ImageGrayscaleTransform), typeof(ImageGrayscaleTransform.Arguments), typeof(SignatureDataTransform), +[assembly: LoadableClass(ImageGrayscaleTransform.Summary, typeof(IDataTransform), typeof(ImageGrayscaleTransform), typeof(ImageGrayscaleTransform.Arguments), typeof(SignatureDataTransform), ImageGrayscaleTransform.UserName, "ImageGrayscaleTransform", "ImageGrayscale")] -[assembly: LoadableClass(ImageGrayscaleTransform.Summary, typeof(ImageGrayscaleTransform), null, typeof(SignatureLoadDataTransform), +[assembly: LoadableClass(ImageGrayscaleTransform.Summary, typeof(IDataTransform), typeof(ImageGrayscaleTransform), null, typeof(SignatureLoadDataTransform), + ImageGrayscaleTransform.UserName, ImageGrayscaleTransform.LoaderSignature)] + +[assembly: LoadableClass(typeof(ImageGrayscaleTransform), null, typeof(SignatureLoadModel), + ImageGrayscaleTransform.UserName, ImageGrayscaleTransform.LoaderSignature)] + +[assembly: LoadableClass(typeof(IRowMapper), typeof(ImageGrayscaleTransform.Mapper), null, typeof(SignatureLoadRowMapper), ImageGrayscaleTransform.UserName, ImageGrayscaleTransform.LoaderSignature)] namespace Microsoft.ML.Runtime.ImageAnalytics @@ -28,7 +36,7 @@ namespace Microsoft.ML.Runtime.ImageAnalytics /// Transform which takes one or many columns of type in IDataView and /// convert them to greyscale representation of the same image. 
/// - public sealed class ImageGrayscaleTransform : OneToOneTransformBase + public sealed class ImageGrayscaleTransform : ITransformer, ICanSaveModel { public sealed class Column : OneToOneColumn { @@ -68,51 +76,87 @@ private static VersionInfo GetVersionInfo() } private const string RegistrationName = "ImageGrayscale"; + private readonly IHost _host; + private readonly (string input, string output)[] _columns; - // Public constructor corresponding to SignatureDataTransform. - public ImageGrayscaleTransform(IHostEnvironment env, Arguments args, IDataView input) - : base(env, RegistrationName, env.CheckRef(args, nameof(args)).Column, input, t => t is ImageType ? null : "Expected Image type") + public (string input, string output)[] Columns => _columns; + + public ImageGrayscaleTransform(IHostEnvironment env, params (string input, string output)[] columns) { - Host.AssertNonEmpty(Infos); - Host.Assert(Infos.Length == Utils.Size(args.Column)); - Metadata.Seal(); + Contracts.CheckValue(env, nameof(env)); + _host = env.Register(RegistrationName); + _host.CheckValue(columns, nameof(columns)); + + _columns = columns.ToArray(); + } + + // Factory method for SignatureDataTransform. + public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input) + { + Contracts.CheckValue(env, nameof(env)); + env.CheckValue(args, nameof(args)); + env.CheckValue(input, nameof(input)); + + var transformer = new ImageGrayscaleTransform(env, args.Column.Select(x => (x.Source ?? x.Name, x.Name)).ToArray()); + return new RowToRowMapperTransform(env, input, transformer.MakeRowMapper(input.Schema)); } - private ImageGrayscaleTransform(IHost host, ModelLoadContext ctx, IDataView input) - : base(host, ctx, input, t => t is ImageType ? 
null : "Expected Image type") + public ImageGrayscaleTransform(IHostEnvironment env, ModelLoadContext ctx) { - Host.AssertValue(ctx); + Contracts.CheckValue(env, nameof(env)); + _host = env.Register(RegistrationName); + _host.CheckValue(ctx, nameof(ctx)); + // *** Binary format *** - // - Host.AssertNonEmpty(Infos); - Metadata.Seal(); + // int: number of added columns + // for each added column + // int: id of output column name + // int: id of input column name + + int n = ctx.Reader.ReadInt32(); + _columns = new (string input, string output)[n]; + for (int i = 0; i < n; i++) + { + string output = ctx.LoadNonEmptyString(); + string input = ctx.LoadNonEmptyString(); + _columns[i] = (input, output); + } } - public static ImageGrayscaleTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input) + // Factory method for SignatureLoadDataTransform. + public static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input) { Contracts.CheckValue(env, nameof(env)); - var h = env.Register(RegistrationName); - h.CheckValue(ctx, nameof(ctx)); - h.CheckValue(input, nameof(input)); - ctx.CheckAtModel(GetVersionInfo()); - return h.Apply("Loading Model", ch => new ImageGrayscaleTransform(h, ctx, input)); + env.CheckValue(ctx, nameof(ctx)); + env.CheckValue(input, nameof(input)); + + var transformer = new ImageGrayscaleTransform(env, ctx); + return new RowToRowMapperTransform(env, input, transformer.MakeRowMapper(input.Schema)); } - public override void Save(ModelSaveContext ctx) + public void Save(ModelSaveContext ctx) => SaveContents(_host, ctx, _columns); + + private static void SaveContents(IHostEnvironment env, ModelSaveContext ctx, (string input, string output)[] columns) { - Host.CheckValue(ctx, nameof(ctx)); + Contracts.AssertValue(env); + env.CheckValue(ctx, nameof(ctx)); + Contracts.AssertValue(columns); + ctx.CheckAtModel(); ctx.SetVersionInfo(GetVersionInfo()); // *** Binary format *** - // - SaveBase(ctx); - } + // int: 
number of added columns + // for each added column + // int: id of output column name + // int: id of input column name - protected override ColumnType GetColumnTypeCore(int iinfo) - { - Host.Assert(0 <= iinfo & iinfo < Infos.Length); - return Infos[iinfo].TypeSrc; + ctx.Writer.Write(columns.Length); + for (int i = 0; i < columns.Length; i++) + { + ctx.SaveNonEmptyString(columns[i].output); + ctx.SaveNonEmptyString(columns[i].input); + } } private static readonly ColorMatrix _grayscaleColorMatrix = new ColorMatrix( @@ -125,47 +169,175 @@ protected override ColumnType GetColumnTypeCore(int iinfo) new float[] {0, 0, 0, 0, 1} }); - protected override Delegate GetGetterCore(IChannel ch, IRow input, int iinfo, out Action disposer) + private IRowMapper MakeRowMapper(ISchema schema) + => new Mapper(_host, _columns, schema); + + private static void CheckInput(IExceptionContext ctx, ISchema inputSchema, string input, out int srcCol) + { + Contracts.AssertValueOrNull(ctx); + Contracts.AssertValue(inputSchema); + Contracts.AssertNonEmpty(input); + + if (!inputSchema.TryGetColumnIndex(input, out srcCol)) + throw ctx.ExceptSchemaMismatch(nameof(inputSchema), "input", input); + if (!(inputSchema.GetColumnType(srcCol) is ImageType)) + throw ctx.ExceptSchemaMismatch(nameof(inputSchema), "input", input, "image", inputSchema.GetColumnType(srcCol).ToString()); + } + + public ISchema GetOutputSchema(ISchema inputSchema) { - Host.AssertValueOrNull(ch); - Host.AssertValue(input); - Host.Assert(0 <= iinfo && iinfo < Infos.Length); + _host.CheckValue(inputSchema, nameof(inputSchema)); - var src = default(Bitmap); - var getSrc = GetSrcGetter(input, iinfo); + // Check that all the input columns are present and are images. 
+ foreach (var column in _columns) + CheckInput(_host, inputSchema, column.input, out int col); - disposer = - () => + return Transform(new EmptyDataView(_host, inputSchema)).Schema; + } + + public IDataView Transform(IDataView input) + { + _host.CheckValue(input, nameof(input)); + + var mapper = MakeRowMapper(input.Schema); + return new RowToRowMapperTransform(_host, input, mapper); + } + + internal sealed class Mapper : IRowMapper + { + private readonly IHost _host; + private readonly (string input, string output)[] _columns; + private readonly ISchema _inputSchema; + private readonly Dictionary _colMapNewToOld; + + public Mapper(IHostEnvironment env, (string input, string output)[] columns, ISchema inputSchema) + { + Contracts.AssertValue(env); + _host = env.Register(nameof(Mapper)); + _host.AssertValue(columns); + _host.AssertValue(inputSchema); + + _colMapNewToOld = new Dictionary(); + for (int i = 0; i < columns.Length; i++) { - if (src != null) - { - src.Dispose(); - src = null; - } - }; + CheckInput(_host, inputSchema, columns[i].input, out int srcCol); + _colMapNewToOld.Add(i, srcCol); + } + _columns = columns; + _inputSchema = inputSchema; + } + + public Func GetDependencies(Func activeOutput) + { + var active = new bool[_inputSchema.ColumnCount]; + foreach (var pair in _colMapNewToOld) + if (activeOutput(pair.Key)) + active[pair.Value] = true; + return col => active[col]; + } - ValueGetter del = - (ref Bitmap dst) => + public RowMapperColumnInfo[] GetOutputColumns() + => _columns.Select((x, idx) => new RowMapperColumnInfo(x.output, _inputSchema.GetColumnType(_colMapNewToOld[idx]), null)).ToArray(); + + public void Save(ModelSaveContext ctx) => SaveContents(_host, ctx, _columns); + + public static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, ISchema inputSchema) + { + Contracts.CheckValue(env, nameof(env)); + env.CheckValue(ctx, nameof(ctx)); + env.CheckValue(inputSchema, nameof(inputSchema)); + var transformer = new 
ImageGrayscaleTransform(env, ctx); + return transformer.MakeRowMapper(inputSchema); + } + + public Delegate[] CreateGetters(IRow input, Func activeOutput, out Action disposer) + { + _host.Assert(input.Schema == _inputSchema); + var result = new Delegate[_columns.Length]; + var disposers = new Action[_columns.Length]; + for (int i = 0; i < _columns.Length; i++) { - if (dst != null) - dst.Dispose(); - - getSrc(ref src); - if (src == null || src.Height <= 0 || src.Width <= 0) - return; - - dst = new Bitmap(src.Width, src.Height); - ImageAttributes attributes = new ImageAttributes(); - attributes.SetColorMatrix(_grayscaleColorMatrix); - var srcRectangle = new Rectangle(0, 0, src.Width, src.Height); - using (var g = Graphics.FromImage(dst)) - { - g.DrawImage(src, srcRectangle, 0, 0, src.Width, src.Height, GraphicsUnit.Pixel, attributes); - } - Host.Assert(dst.Width == src.Width && dst.Height == src.Height); + if (!activeOutput(i)) + continue; + int srcCol = _colMapNewToOld[i]; + result[i] = MakeGetter(input, i, out disposers[i]); + } + disposer = () => + { + foreach (var act in disposers) + act(); }; + return result; + } + + private Delegate MakeGetter(IRow input, int iinfo, out Action disposer) + { + _host.AssertValue(input); + _host.Assert(0 <= iinfo && iinfo < _columns.Length); + + var src = default(Bitmap); + var getSrc = input.GetGetter(_colMapNewToOld[iinfo]); + + disposer = + () => + { + if (src != null) + { + src.Dispose(); + src = null; + } + }; + + ValueGetter del = + (ref Bitmap dst) => + { + if (dst != null) + dst.Dispose(); + + getSrc(ref src); + if (src == null || src.Height <= 0 || src.Width <= 0) + return; + + dst = new Bitmap(src.Width, src.Height); + ImageAttributes attributes = new ImageAttributes(); + attributes.SetColorMatrix(_grayscaleColorMatrix); + var srcRectangle = new Rectangle(0, 0, src.Width, src.Height); + using (var g = Graphics.FromImage(dst)) + { + g.DrawImage(src, srcRectangle, 0, 0, src.Width, src.Height, GraphicsUnit.Pixel, 
attributes); + } + Contracts.Assert(dst.Width == src.Width && dst.Height == src.Height); + }; + + return del; + } + } + } + + public sealed class ImageGrayscaleEstimator: TrivialEstimator + { + public ImageGrayscaleEstimator(IHostEnvironment env, params (string input, string output)[] columns) + :base(Contracts.CheckRef(env, nameof(env)).Register(nameof(ImageGrayscaleEstimator)), new ImageGrayscaleTransform(env, columns)) + { + } + + public override SchemaShape GetOutputSchema(SchemaShape inputSchema) + { + Host.CheckValue(inputSchema, nameof(inputSchema)); + var result = inputSchema.Columns.ToDictionary(x => x.Name); + foreach (var colInfo in Transformer.Columns) + { + var col = inputSchema.FindColumn(colInfo.input); + + if (col == null) + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.input); + if (!(col.ItemType is ImageType) || col.Kind != SchemaShape.Column.VectorKind.Scalar) + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.input, new ImageType().ToString(), col.GetTypeString()); + + result[colInfo.output] = new SchemaShape.Column(colInfo.output, col.Kind, col.ItemType, col.IsKey, col.MetadataKinds); + } - return del; + return new SchemaShape(result.Values); } } } diff --git a/src/Microsoft.ML.ImageAnalytics/ImageLoaderTransform.cs b/src/Microsoft.ML.ImageAnalytics/ImageLoaderTransform.cs index aa19da7ff4..81006fd204 100644 --- a/src/Microsoft.ML.ImageAnalytics/ImageLoaderTransform.cs +++ b/src/Microsoft.ML.ImageAnalytics/ImageLoaderTransform.cs @@ -2,20 +2,20 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
-using System; -using System.Drawing; -using System.IO; -using System.Text; -using Microsoft.ML.Runtime.ImageAnalytics; +using Microsoft.ML.Core.Data; using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.CommandLine; using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.EntryPoints; +using Microsoft.ML.Runtime.ImageAnalytics; using Microsoft.ML.Runtime.Internal.Utilities; using Microsoft.ML.Runtime.Model; -using Microsoft.ML.Core.Data; +using System; using System.Collections.Generic; +using System.Drawing; +using System.IO; using System.Linq; +using System.Text; [assembly: LoadableClass(ImageLoaderTransform.Summary, typeof(IDataTransform), typeof(ImageLoaderTransform), typeof(ImageLoaderTransform.Arguments), typeof(SignatureDataTransform), ImageLoaderTransform.UserName, "ImageLoaderTransform", "ImageLoader")] @@ -122,6 +122,16 @@ public static ImageLoaderTransform Create(IHostEnvironment env, ModelLoadContext return new ImageLoaderTransform(env, imageFolder, columns); } + public static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input) + { + Contracts.CheckValue(env, nameof(env)); + env.CheckValue(ctx, nameof(ctx)); + env.CheckValue(input, nameof(input)); + + var transformer = Create(env, ctx); + return transformer.CreateDataTransform(input); + } + public ISchema GetOutputSchema(ISchema inputSchema) { _host.CheckValue(inputSchema, nameof(inputSchema)); diff --git a/test/Microsoft.ML.Tests/ImagesTests.cs b/test/Microsoft.ML.Tests/ImagesTests.cs index 3f2b7626b2..fe64f75bc4 100644 --- a/test/Microsoft.ML.Tests/ImagesTests.cs +++ b/test/Microsoft.ML.Tests/ImagesTests.cs @@ -32,8 +32,10 @@ public void TestEstimatorSaveLoad() var pipe = new ImageLoaderEstimator(env, imageFolder, ("ImagePath", "ImageReal")) .Append(new ImageResizerEstimator(env, "ImageReal", "ImageReal", 100, 100)) - .Append(new ImagePixelExtractorEstimator(env, "ImageReal", "ImagePixels")); + .Append(new ImagePixelExtractorEstimator(env, "ImageReal", 
"ImagePixels")) + .Append(new ImageGrayscaleEstimator(env, ("ImageReal", "ImageGray"))); + pipe.GetOutputSchema(Core.Data.SchemaShape.Create(data.Schema)); var model = pipe.Fit(data); using (var file = env.CreateTempFile()) @@ -120,13 +122,22 @@ public void TestGreyscaleTransformImages() } }, images); - var grey = new ImageGrayscaleTransform(env, new ImageGrayscaleTransform.Arguments() + IDataView grey = ImageGrayscaleTransform.Create(env, new ImageGrayscaleTransform.Arguments() { Column = new ImageGrayscaleTransform.Column[1]{ new ImageGrayscaleTransform.Column() { Name= "ImageGrey", Source = "ImageCropped"} } }, cropped); + var fname = nameof(TestGreyscaleTransformImages) + "_model.zip"; + + var fh = env.CreateOutputFile(fname); + using (var ch = env.Start("save")) + TrainUtils.SaveModel(env, ch, fh, null, new RoleMappedData(grey)); + + grey = ModelFileUtils.LoadPipeline(env, fh.OpenReadStream(), new MultiFileSource(dataFile)); + DeleteOutputPath(fname); + grey.Schema.TryGetColumnIndex("ImageGrey", out int greyColumn); using (var cursor = grey.GetRowCursor((x) => true)) { From 281e731462685c37a53d6e3fadc03f5b4ed9a300 Mon Sep 17 00:00:00 2001 From: Pete Luferenko Date: Mon, 27 Aug 2018 21:01:32 -0700 Subject: [PATCH 11/17] wip one to one --- .../Transforms/OneToOneTransformerBase.cs | 150 ++++++++++++++++++ 1 file changed, 150 insertions(+) create mode 100644 src/Microsoft.ML.Data/Transforms/OneToOneTransformerBase.cs diff --git a/src/Microsoft.ML.Data/Transforms/OneToOneTransformerBase.cs b/src/Microsoft.ML.Data/Transforms/OneToOneTransformerBase.cs new file mode 100644 index 0000000000..fbd3438273 --- /dev/null +++ b/src/Microsoft.ML.Data/Transforms/OneToOneTransformerBase.cs @@ -0,0 +1,150 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. 
+ +using System; +using System.Collections.Generic; +using System.Linq; +using Microsoft.ML.Core.Data; +using Microsoft.ML.Runtime.Model; + +namespace Microsoft.ML.Runtime.Data +{ + public abstract class OneToOneTransformerBase: ITransformer + { + protected readonly IHost Host; + protected readonly (string input, string output)[] ColumnPairs; + + protected OneToOneTransformerBase(IHost host, (string input, string output)[] columns) + { + Contracts.AssertValue(host); + Contracts.AssertValue(columns); + + Host = host; + ColumnPairs = columns; + } + + protected OneToOneTransformerBase(IHost host, ModelLoadContext ctx) + { + Host = host; + // *** Binary format *** + // int: number of added columns + // for each added column + // int: id of output column name + // int: id of input column name + + int n = ctx.Reader.ReadInt32(); + ColumnPairs = new (string input, string output)[n]; + for (int i = 0; i < n; i++) + { + string output = ctx.LoadNonEmptyString(); + string input = ctx.LoadNonEmptyString(); + ColumnPairs[i] = (input, output); + } + } + + protected void Save(ModelSaveContext ctx) => SaveContents(Host, ctx, ColumnPairs); + + private static void SaveContents(IHostEnvironment env, ModelSaveContext ctx, (string input, string output)[] columns) + { + Contracts.AssertValue(env); + env.CheckValue(ctx, nameof(ctx)); + Contracts.AssertValue(columns); + + // *** Binary format *** + // int: number of added columns + // for each added column + // int: id of output column name + // int: id of input column name + + ctx.Writer.Write(columns.Length); + for (int i = 0; i < columns.Length; i++) + { + ctx.SaveNonEmptyString(columns[i].output); + ctx.SaveNonEmptyString(columns[i].input); + } + } + + private void CheckInput(ISchema inputSchema, int col, out int srcCol) + { + Contracts.AssertValue(inputSchema); + Contracts.Assert(0 <= col && col < ColumnPairs.Length); + + if (!inputSchema.TryGetColumnIndex(ColumnPairs[col].input, out srcCol)) + throw 
Host.ExceptSchemaMismatch(nameof(inputSchema), "input", ColumnPairs[col].input); + CheckInputColumn(inputSchema, col, srcCol); + } + + protected virtual void CheckInputColumn(ISchema inputSchema, int col, int srcCol) + { + // By default, no extra checks. + } + + protected abstract MapperBase MakeRowMapper(ISchema schema); + + protected abstract class MapperBase: IRowMapper + { + protected readonly IHost Host; + protected readonly Dictionary ColMapNewToOld; + protected readonly ISchema InputSchema; + private readonly OneToOneTransformerBase _parent; + + protected MapperBase(IHost host, OneToOneTransformerBase parent, ISchema inputSchema) + { + Contracts.AssertValue(host); + Contracts.AssertValue(parent); + Contracts.AssertValue(inputSchema); + + Host = host; + _parent = parent; + + ColMapNewToOld = new Dictionary(); + for (int i = 0; i < _parent.ColumnPairs.Length; i++) + { + _parent.CheckInput(inputSchema, i, out int srcCol); + ColMapNewToOld.Add(i, srcCol); + } + InputSchema = inputSchema; + } + public Func GetDependencies(Func activeOutput) + { + var active = new bool[_inputSchema.ColumnCount]; + foreach (var pair in _colMapNewToOld) + if (activeOutput(pair.Key)) + active[pair.Value] = true; + return col => active[col]; + } + + public abstract RowMapperColumnInfo[] GetOutputColumns(); + + public void Save(ModelSaveContext ctx) => _parent.Save(ctx); + + public Delegate[] CreateGetters(IRow input, Func activeOutput, out Action disposer) + { + Contracts.Assert(input.Schema == InputSchema); + var result = new Delegate[_parent.ColumnPairs.Length]; + var disposers = new Action[_parent.ColumnPairs.Length]; + for (int i = 0; i < _parent.ColumnPairs.Length; i++) + { + if (!activeOutput(i)) + continue; + int srcCol = ColMapNewToOld[i]; + result[i] = MakeGetter(input, i, out disposers[i]); + } + if (disposers.Any(x => x != null)) + { + disposer = () => + { + foreach (var act in disposers) + act(); + }; + } + else + disposer = null; + return result; + } + + protected 
abstract Delegate MakeGetter(IRow input, int iinfo, out Action disposer); + } + } +} From baa5c34422bd36f3df00d1b6c28f52aa8449d8a7 Mon Sep 17 00:00:00 2001 From: Pete Luferenko Date: Tue, 28 Aug 2018 08:22:28 -0700 Subject: [PATCH 12/17] Finished the mockup of one-to-one transformer base class. --- .../Transforms/OneToOneTransformerBase.cs | 43 +++++++++++++------ 1 file changed, 29 insertions(+), 14 deletions(-) diff --git a/src/Microsoft.ML.Data/Transforms/OneToOneTransformerBase.cs b/src/Microsoft.ML.Data/Transforms/OneToOneTransformerBase.cs index fbd3438273..88871cf13c 100644 --- a/src/Microsoft.ML.Data/Transforms/OneToOneTransformerBase.cs +++ b/src/Microsoft.ML.Data/Transforms/OneToOneTransformerBase.cs @@ -43,13 +43,9 @@ protected OneToOneTransformerBase(IHost host, ModelLoadContext ctx) } } - protected void Save(ModelSaveContext ctx) => SaveContents(Host, ctx, ColumnPairs); - - private static void SaveContents(IHostEnvironment env, ModelSaveContext ctx, (string input, string output)[] columns) + protected void SaveColumns(ModelSaveContext ctx) { - Contracts.AssertValue(env); - env.CheckValue(ctx, nameof(ctx)); - Contracts.AssertValue(columns); + Host.CheckValue(ctx, nameof(ctx)); // *** Binary format *** // int: number of added columns @@ -57,11 +53,11 @@ private static void SaveContents(IHostEnvironment env, ModelSaveContext ctx, (st // int: id of output column name // int: id of input column name - ctx.Writer.Write(columns.Length); - for (int i = 0; i < columns.Length; i++) + ctx.Writer.Write(ColumnPairs.Length); + for (int i = 0; i < ColumnPairs.Length; i++) { - ctx.SaveNonEmptyString(columns[i].output); - ctx.SaveNonEmptyString(columns[i].input); + ctx.SaveNonEmptyString(ColumnPairs[i].output); + ctx.SaveNonEmptyString(ColumnPairs[i].input); } } @@ -77,11 +73,30 @@ private void CheckInput(ISchema inputSchema, int col, out int srcCol) protected virtual void CheckInputColumn(ISchema inputSchema, int col, int srcCol) { - // By default, no extra checks. 
+ // By default, there are no extra checks. } protected abstract MapperBase MakeRowMapper(ISchema schema); + public ISchema GetOutputSchema(ISchema inputSchema) + { + Host.CheckValue(inputSchema, nameof(inputSchema)); + + // Check that all the input columns are present and correct. + for (int i = 0; i < ColumnPairs.Length; i++) + CheckInput(inputSchema, i, out int col); + + return Transform(new EmptyDataView(Host, inputSchema)).Schema; + } + + public IDataView Transform(IDataView input) => MakeTransform(input); + + protected RowToRowMapperTransform MakeTransform(IDataView input) + { + Host.CheckValue(input, nameof(input)); + return new RowToRowMapperTransform(Host, input, MakeRowMapper(input.Schema)); + } + protected abstract class MapperBase: IRowMapper { protected readonly IHost Host; @@ -108,8 +123,8 @@ protected MapperBase(IHost host, OneToOneTransformerBase parent, ISchema inputSc } public Func GetDependencies(Func activeOutput) { - var active = new bool[_inputSchema.ColumnCount]; - foreach (var pair in _colMapNewToOld) + var active = new bool[InputSchema.ColumnCount]; + foreach (var pair in ColMapNewToOld) if (activeOutput(pair.Key)) active[pair.Value] = true; return col => active[col]; @@ -117,7 +132,7 @@ public Func GetDependencies(Func activeOutput) public abstract RowMapperColumnInfo[] GetOutputColumns(); - public void Save(ModelSaveContext ctx) => _parent.Save(ctx); + public void Save(ModelSaveContext ctx) => _parent.SaveColumns(ctx); public Delegate[] CreateGetters(IRow input, Func activeOutput, out Action disposer) { From 245c3449734860619a440dae8dcbb67fe93cda98 Mon Sep 17 00:00:00 2001 From: Pete Luferenko Date: Tue, 28 Aug 2018 09:27:41 -0700 Subject: [PATCH 13/17] Converted to inherit from base class --- .../Transforms/OneToOneTransformerBase.cs | 24 +- .../ImageGrayscaleTransform.cs | 191 +++---------- .../ImageLoaderTransform.cs | 205 ++++---------- .../ImagePixelExtractorTransform.cs | 254 ++++++----------- .../ImageResizerTransform.cs | 256 
++++++------------ 5 files changed, 258 insertions(+), 672 deletions(-) diff --git a/src/Microsoft.ML.Data/Transforms/OneToOneTransformerBase.cs b/src/Microsoft.ML.Data/Transforms/OneToOneTransformerBase.cs index 88871cf13c..0fa15bebca 100644 --- a/src/Microsoft.ML.Data/Transforms/OneToOneTransformerBase.cs +++ b/src/Microsoft.ML.Data/Transforms/OneToOneTransformerBase.cs @@ -10,7 +10,7 @@ namespace Microsoft.ML.Runtime.Data { - public abstract class OneToOneTransformerBase: ITransformer + public abstract class OneToOneTransformerBase: ITransformer, ICanSaveModel { protected readonly IHost Host; protected readonly (string input, string output)[] ColumnPairs; @@ -18,7 +18,17 @@ public abstract class OneToOneTransformerBase: ITransformer protected OneToOneTransformerBase(IHost host, (string input, string output)[] columns) { Contracts.AssertValue(host); - Contracts.AssertValue(columns); + host.CheckValue(columns, nameof(columns)); + + var newNames = new HashSet(); + foreach (var column in columns) + { + host.CheckNonEmpty(column.input, nameof(columns)); + host.CheckNonEmpty(column.output, nameof(columns)); + + if (!newNames.Add(column.output)) + throw Contracts.ExceptParam(nameof(columns), $"Output column '{column.output}' specified multiple times"); + } Host = host; ColumnPairs = columns; @@ -43,6 +53,8 @@ protected OneToOneTransformerBase(IHost host, ModelLoadContext ctx) } } + public abstract void Save(ModelSaveContext ctx); + protected void SaveColumns(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); @@ -76,7 +88,7 @@ protected virtual void CheckInputColumn(ISchema inputSchema, int col, int srcCol // By default, there are no extra checks. 
} - protected abstract MapperBase MakeRowMapper(ISchema schema); + protected abstract IRowMapper MakeRowMapper(ISchema schema); public ISchema GetOutputSchema(ISchema inputSchema) { @@ -89,9 +101,9 @@ public ISchema GetOutputSchema(ISchema inputSchema) return Transform(new EmptyDataView(Host, inputSchema)).Schema; } - public IDataView Transform(IDataView input) => MakeTransform(input); + public IDataView Transform(IDataView input) => MakeDataTransform(input); - protected RowToRowMapperTransform MakeTransform(IDataView input) + protected RowToRowMapperTransform MakeDataTransform(IDataView input) { Host.CheckValue(input, nameof(input)); return new RowToRowMapperTransform(Host, input, MakeRowMapper(input.Schema)); @@ -132,7 +144,7 @@ public Func GetDependencies(Func activeOutput) public abstract RowMapperColumnInfo[] GetOutputColumns(); - public void Save(ModelSaveContext ctx) => _parent.SaveColumns(ctx); + public void Save(ModelSaveContext ctx) => _parent.Save(ctx); public Delegate[] CreateGetters(IRow input, Func activeOutput, out Action disposer) { diff --git a/src/Microsoft.ML.ImageAnalytics/ImageGrayscaleTransform.cs b/src/Microsoft.ML.ImageAnalytics/ImageGrayscaleTransform.cs index 103888184d..21076b2d86 100644 --- a/src/Microsoft.ML.ImageAnalytics/ImageGrayscaleTransform.cs +++ b/src/Microsoft.ML.ImageAnalytics/ImageGrayscaleTransform.cs @@ -8,6 +8,7 @@ using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.EntryPoints; using Microsoft.ML.Runtime.ImageAnalytics; +using Microsoft.ML.Runtime.Internal.Utilities; using Microsoft.ML.Runtime.Model; using System; using System.Collections.Generic; @@ -25,7 +26,7 @@ [assembly: LoadableClass(typeof(ImageGrayscaleTransform), null, typeof(SignatureLoadModel), ImageGrayscaleTransform.UserName, ImageGrayscaleTransform.LoaderSignature)] -[assembly: LoadableClass(typeof(IRowMapper), typeof(ImageGrayscaleTransform.Mapper), null, typeof(SignatureLoadRowMapper), +[assembly: LoadableClass(typeof(IRowMapper), 
typeof(ImageGrayscaleTransform), null, typeof(SignatureLoadRowMapper), ImageGrayscaleTransform.UserName, ImageGrayscaleTransform.LoaderSignature)] namespace Microsoft.ML.Runtime.ImageAnalytics @@ -36,7 +37,7 @@ namespace Microsoft.ML.Runtime.ImageAnalytics /// Transform which takes one or many columns of type in IDataView and /// convert them to greyscale representation of the same image. /// - public sealed class ImageGrayscaleTransform : ITransformer, ICanSaveModel + public sealed class ImageGrayscaleTransform : OneToOneTransformerBase { public sealed class Column : OneToOneColumn { @@ -76,18 +77,12 @@ private static VersionInfo GetVersionInfo() } private const string RegistrationName = "ImageGrayscale"; - private readonly IHost _host; - private readonly (string input, string output)[] _columns; - public (string input, string output)[] Columns => _columns; + public IReadOnlyCollection<(string input, string output)> Columns => ColumnPairs.AsReadOnly(); public ImageGrayscaleTransform(IHostEnvironment env, params (string input, string output)[] columns) + : base(Contracts.CheckRef(env, nameof(env)).Register(RegistrationName), columns) { - Contracts.CheckValue(env, nameof(env)); - _host = env.Register(RegistrationName); - _host.CheckValue(columns, nameof(columns)); - - _columns = columns.ToArray(); } // Factory method for SignatureDataTransform. @@ -96,67 +91,44 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV Contracts.CheckValue(env, nameof(env)); env.CheckValue(args, nameof(args)); env.CheckValue(input, nameof(input)); + env.CheckValue(args.Column, nameof(args.Column)); - var transformer = new ImageGrayscaleTransform(env, args.Column.Select(x => (x.Source ?? x.Name, x.Name)).ToArray()); - return new RowToRowMapperTransform(env, input, transformer.MakeRowMapper(input.Schema)); + return new ImageGrayscaleTransform(env, args.Column.Select(x => (x.Source ?? 
x.Name, x.Name)).ToArray()) + .MakeDataTransform(input); } - public ImageGrayscaleTransform(IHostEnvironment env, ModelLoadContext ctx) + public static ImageGrayscaleTransform Create(IHostEnvironment env, ModelLoadContext ctx) { Contracts.CheckValue(env, nameof(env)); - _host = env.Register(RegistrationName); - _host.CheckValue(ctx, nameof(ctx)); + var host = env.Register(RegistrationName); + host.CheckValue(ctx, nameof(ctx)); + ctx.CheckAtModel(GetVersionInfo()); + return new ImageGrayscaleTransform(host, ctx); + } - // *** Binary format *** - // int: number of added columns - // for each added column - // int: id of output column name - // int: id of input column name - - int n = ctx.Reader.ReadInt32(); - _columns = new (string input, string output)[n]; - for (int i = 0; i < n; i++) - { - string output = ctx.LoadNonEmptyString(); - string input = ctx.LoadNonEmptyString(); - _columns[i] = (input, output); - } + private ImageGrayscaleTransform(IHost host, ModelLoadContext ctx) + : base(host, ctx) + { } // Factory method for SignatureLoadDataTransform. public static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input) - { - Contracts.CheckValue(env, nameof(env)); - env.CheckValue(ctx, nameof(ctx)); - env.CheckValue(input, nameof(input)); + => Create(env, ctx).MakeDataTransform(input); - var transformer = new ImageGrayscaleTransform(env, ctx); - return new RowToRowMapperTransform(env, input, transformer.MakeRowMapper(input.Schema)); - } - - public void Save(ModelSaveContext ctx) => SaveContents(_host, ctx, _columns); + // Factory method for SignatureLoadRowMapper. 
+ public static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, ISchema inputSchema) + => Create(env, ctx).MakeRowMapper(inputSchema); - private static void SaveContents(IHostEnvironment env, ModelSaveContext ctx, (string input, string output)[] columns) + public override void Save(ModelSaveContext ctx) { - Contracts.AssertValue(env); - env.CheckValue(ctx, nameof(ctx)); - Contracts.AssertValue(columns); + Host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); ctx.SetVersionInfo(GetVersionInfo()); // *** Binary format *** - // int: number of added columns - // for each added column - // int: id of output column name - // int: id of input column name - - ctx.Writer.Write(columns.Length); - for (int i = 0; i < columns.Length; i++) - { - ctx.SaveNonEmptyString(columns[i].output); - ctx.SaveNonEmptyString(columns[i].input); - } + // + base.SaveColumns(ctx); } private static readonly ColorMatrix _grayscaleColorMatrix = new ColorMatrix( @@ -169,114 +141,35 @@ private static void SaveContents(IHostEnvironment env, ModelSaveContext ctx, (st new float[] {0, 0, 0, 0, 1} }); - private IRowMapper MakeRowMapper(ISchema schema) - => new Mapper(_host, _columns, schema); + protected override IRowMapper MakeRowMapper(ISchema schema) + => new Mapper(this, schema); - private static void CheckInput(IExceptionContext ctx, ISchema inputSchema, string input, out int srcCol) + protected override void CheckInputColumn(ISchema inputSchema, int col, int srcCol) { - Contracts.AssertValueOrNull(ctx); - Contracts.AssertValue(inputSchema); - Contracts.AssertNonEmpty(input); - - if (!inputSchema.TryGetColumnIndex(input, out srcCol)) - throw ctx.ExceptSchemaMismatch(nameof(inputSchema), "input", input); if (!(inputSchema.GetColumnType(srcCol) is ImageType)) - throw ctx.ExceptSchemaMismatch(nameof(inputSchema), "input", input, "image", inputSchema.GetColumnType(srcCol).ToString()); + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", ColumnPairs[col].input, "image", 
inputSchema.GetColumnType(srcCol).ToString()); } - public ISchema GetOutputSchema(ISchema inputSchema) + private sealed class Mapper : MapperBase { - _host.CheckValue(inputSchema, nameof(inputSchema)); - - // Check that all the input columns are present and are images. - foreach (var column in _columns) - CheckInput(_host, inputSchema, column.input, out int col); + private ImageGrayscaleTransform _parent; - return Transform(new EmptyDataView(_host, inputSchema)).Schema; - } - - public IDataView Transform(IDataView input) - { - _host.CheckValue(input, nameof(input)); - - var mapper = MakeRowMapper(input.Schema); - return new RowToRowMapperTransform(_host, input, mapper); - } - - internal sealed class Mapper : IRowMapper - { - private readonly IHost _host; - private readonly (string input, string output)[] _columns; - private readonly ISchema _inputSchema; - private readonly Dictionary _colMapNewToOld; - - public Mapper(IHostEnvironment env, (string input, string output)[] columns, ISchema inputSchema) - { - Contracts.AssertValue(env); - _host = env.Register(nameof(Mapper)); - _host.AssertValue(columns); - _host.AssertValue(inputSchema); - - _colMapNewToOld = new Dictionary(); - for (int i = 0; i < columns.Length; i++) - { - CheckInput(_host, inputSchema, columns[i].input, out int srcCol); - _colMapNewToOld.Add(i, srcCol); - } - _columns = columns; - _inputSchema = inputSchema; - } - - public Func GetDependencies(Func activeOutput) + public Mapper(ImageGrayscaleTransform parent, ISchema inputSchema) + :base(parent.Host.Register(nameof(Mapper)), parent, inputSchema) { - var active = new bool[_inputSchema.ColumnCount]; - foreach (var pair in _colMapNewToOld) - if (activeOutput(pair.Key)) - active[pair.Value] = true; - return col => active[col]; + _parent = parent; } - public RowMapperColumnInfo[] GetOutputColumns() - => _columns.Select((x, idx) => new RowMapperColumnInfo(x.output, _inputSchema.GetColumnType(_colMapNewToOld[idx]), null)).ToArray(); - - public void 
Save(ModelSaveContext ctx) => SaveContents(_host, ctx, _columns); - - public static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, ISchema inputSchema) - { - Contracts.CheckValue(env, nameof(env)); - env.CheckValue(ctx, nameof(ctx)); - env.CheckValue(inputSchema, nameof(inputSchema)); - var transformer = new ImageGrayscaleTransform(env, ctx); - return transformer.MakeRowMapper(inputSchema); - } - - public Delegate[] CreateGetters(IRow input, Func activeOutput, out Action disposer) - { - _host.Assert(input.Schema == _inputSchema); - var result = new Delegate[_columns.Length]; - var disposers = new Action[_columns.Length]; - for (int i = 0; i < _columns.Length; i++) - { - if (!activeOutput(i)) - continue; - int srcCol = _colMapNewToOld[i]; - result[i] = MakeGetter(input, i, out disposers[i]); - } - disposer = () => - { - foreach (var act in disposers) - act(); - }; - return result; - } + public override RowMapperColumnInfo[] GetOutputColumns() + => _parent.ColumnPairs.Select((x, idx) => new RowMapperColumnInfo(x.output, InputSchema.GetColumnType(ColMapNewToOld[idx]), null)).ToArray(); - private Delegate MakeGetter(IRow input, int iinfo, out Action disposer) + protected override Delegate MakeGetter(IRow input, int iinfo, out Action disposer) { - _host.AssertValue(input); - _host.Assert(0 <= iinfo && iinfo < _columns.Length); + Contracts.AssertValue(input); + Contracts.Assert(0 <= iinfo && iinfo < _parent.ColumnPairs.Length); var src = default(Bitmap); - var getSrc = input.GetGetter(_colMapNewToOld[iinfo]); + var getSrc = input.GetGetter(ColMapNewToOld[iinfo]); disposer = () => @@ -314,10 +207,10 @@ private Delegate MakeGetter(IRow input, int iinfo, out Action disposer) } } - public sealed class ImageGrayscaleEstimator: TrivialEstimator + public sealed class ImageGrayscaleEstimator : TrivialEstimator { public ImageGrayscaleEstimator(IHostEnvironment env, params (string input, string output)[] columns) - :base(Contracts.CheckRef(env, 
nameof(env)).Register(nameof(ImageGrayscaleEstimator)), new ImageGrayscaleTransform(env, columns)) + : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(ImageGrayscaleEstimator)), new ImageGrayscaleTransform(env, columns)) { } diff --git a/src/Microsoft.ML.ImageAnalytics/ImageLoaderTransform.cs b/src/Microsoft.ML.ImageAnalytics/ImageLoaderTransform.cs index 81006fd204..20e1476feb 100644 --- a/src/Microsoft.ML.ImageAnalytics/ImageLoaderTransform.cs +++ b/src/Microsoft.ML.ImageAnalytics/ImageLoaderTransform.cs @@ -25,14 +25,14 @@ [assembly: LoadableClass(typeof(ImageLoaderTransform), null, typeof(SignatureLoadModel), "", ImageLoaderTransform.LoaderSignature)] -[assembly: LoadableClass(typeof(IRowMapper), typeof(ImageLoaderTransform.Mapper), null, typeof(SignatureLoadRowMapper), "", ImageLoaderTransform.LoaderSignature)] +[assembly: LoadableClass(typeof(IRowMapper), typeof(ImageLoaderTransform), null, typeof(SignatureLoadRowMapper), "", ImageLoaderTransform.LoaderSignature)] namespace Microsoft.ML.Runtime.ImageAnalytics { /// /// Transform which takes one or many columns of type and loads them as /// - public sealed class ImageLoaderTransform : ITransformer, ICanSaveModel + public sealed class ImageLoaderTransform : OneToOneTransformerBase { public sealed class Column : OneToOneColumn { @@ -67,31 +67,20 @@ public sealed class Arguments : TransformInputBase internal const string UserName = "Image Loader Transform"; public const string LoaderSignature = "ImageLoaderTransform"; - private readonly string _imageFolder; - private readonly (string input, string output)[] _columns; - private readonly IHost _host; + public readonly string ImageFolder; - public IReadOnlyCollection<(string input, string output)> Columns => _columns.AsReadOnly(); + public IReadOnlyCollection<(string input, string output)> Columns => ColumnPairs.AsReadOnly(); public ImageLoaderTransform(IHostEnvironment env, string imageFolder, params (string input, string output)[] columns) + : 
base(Contracts.CheckRef(env, nameof(env)).Register(nameof(ImageLoaderTransform)), columns) { - Contracts.CheckValue(env, nameof(env)); - _host = env.Register(nameof(ImageLoaderTransform)); - _host.CheckValueOrNull(imageFolder); - _host.CheckValue(columns, nameof(columns)); - - _imageFolder = imageFolder; - - var newNames = new HashSet(); - foreach (var column in columns) - { - _host.CheckNonEmpty(column.input, nameof(columns)); - _host.CheckNonEmpty(column.output, nameof(columns)); + ImageFolder = imageFolder; + } - if (!newNames.Add(column.output)) - throw Contracts.ExceptParam(nameof(columns), $"Output column '{column.output}' specified multiple times"); - } - _columns = columns; + public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView data) + { + return new ImageLoaderTransform(env, args.ImageFolder, args.Column.Select(x => (x.Source ?? x.Name, x.Name)).ToArray()) + .MakeDataTransform(data); } public static ImageLoaderTransform Create(IHostEnvironment env, ModelLoadContext ctx) @@ -100,84 +89,46 @@ public static ImageLoaderTransform Create(IHostEnvironment env, ModelLoadContext env.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(GetVersionInfo()); + return new ImageLoaderTransform(env.Register(nameof(ImageLoaderTransform)), ctx); + } + private ImageLoaderTransform(IHost host, ModelLoadContext ctx) + : base(host, ctx) + { // *** Binary format *** - // int: number of added columns - // for each added column - // int: id of output column name - // int: id of input column name + // // int: id of image folder - int n = ctx.Reader.ReadInt32(); - var columns = new (string input, string output)[n]; - for (int i = 0; i < n; i++) - { - string output = ctx.LoadNonEmptyString(); - string input = ctx.LoadNonEmptyString(); - columns[i] = (input, output); - } - - string imageFolder = ctx.LoadStringOrNull(); - - return new ImageLoaderTransform(env, imageFolder, columns); + ImageFolder = ctx.LoadStringOrNull(); } + // Factory method for 
SignatureLoadDataTransform. public static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input) - { - Contracts.CheckValue(env, nameof(env)); - env.CheckValue(ctx, nameof(ctx)); - env.CheckValue(input, nameof(input)); - - var transformer = Create(env, ctx); - return transformer.CreateDataTransform(input); - } - - public ISchema GetOutputSchema(ISchema inputSchema) - { - _host.CheckValue(inputSchema, nameof(inputSchema)); + => Create(env, ctx).MakeDataTransform(input); - // Check that all the input columns are present and are scalar texts. - foreach (var (input, output) in _columns) - CheckInput(_host, inputSchema, input, out int col); + // Factory method for SignatureLoadRowMapper. + public static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, ISchema inputSchema) + => Create(env, ctx).MakeRowMapper(inputSchema); - return Transform(new EmptyDataView(_host, inputSchema)).Schema; - } - - private static void CheckInput(IExceptionContext ctx, ISchema inputSchema, string input, out int srcCol) + protected override void CheckInputColumn(ISchema inputSchema, int col, int srcCol) { - Contracts.AssertValueOrNull(ctx); - Contracts.AssertValue(inputSchema); - Contracts.AssertNonEmpty(input); - - if (!inputSchema.TryGetColumnIndex(input, out srcCol)) - throw ctx.ExceptSchemaMismatch(nameof(inputSchema), "input", input); if (!inputSchema.GetColumnType(srcCol).IsText) - throw ctx.ExceptSchemaMismatch(nameof(inputSchema), "input", input, TextType.Instance.ToString(), inputSchema.GetColumnType(srcCol).ToString()); + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", ColumnPairs[col].input, TextType.Instance.ToString(), inputSchema.GetColumnType(srcCol).ToString()); } - public IDataView Transform(IDataView input) => CreateDataTransform(input); - - public void Save(ModelSaveContext ctx) => SaveContents(ctx, _imageFolder, _columns); - - private static void SaveContents(ModelSaveContext ctx, string imageFolder, (string input, 
string output)[] columns) + public override void Save(ModelSaveContext ctx) { + Host.CheckValue(ctx, nameof(ctx)); + ctx.CheckAtModel(); ctx.SetVersionInfo(GetVersionInfo()); // *** Binary format *** - // int: number of added columns - // for each added column - // int: id of output column name - // int: id of input column name + // // int: id of image folder - ctx.Writer.Write(columns.Length); - foreach (var (input, output) in columns) - { - ctx.SaveNonEmptyString(output); - ctx.SaveNonEmptyString(input); - } - ctx.SaveStringOrNull(imageFolder); + base.SaveColumns(ctx); + ctx.SaveStringOrNull(ImageFolder); } private static VersionInfo GetVersionInfo() @@ -191,81 +142,28 @@ private static VersionInfo GetVersionInfo() loaderSignature: LoaderSignature); } - public static IDataTransform Create(IHostEnvironment env, ImageLoaderTransform.Arguments args, IDataView data) - { - return new ImageLoaderTransform(env, args.ImageFolder, args.Column.Select(x => (x.Source ?? x.Name, x.Name)).ToArray()) - .CreateDataTransform(data); - } - - private IDataTransform CreateDataTransform(IDataView input) - { - _host.CheckValue(input, nameof(input)); - - var mapper = new Mapper(_host, _imageFolder, _columns, input.Schema); - return new RowToRowMapperTransform(_host, input, mapper); - } + protected override IRowMapper MakeRowMapper(ISchema schema) + => new Mapper(this, schema); - internal sealed class Mapper : IRowMapper + private sealed class Mapper : MapperBase { - private readonly IHost _host; - private readonly string _imageFolder; - private readonly (string input, string output)[] _columns; - private readonly Dictionary _colMapNewToOld; - private readonly ISchema _inputSchema; + private readonly ImageLoaderTransform _parent; private readonly ImageType _imageType; - public Mapper(IHostEnvironment env, string imageFolder, (string input, string output)[] columns, ISchema schema) + public Mapper(ImageLoaderTransform parent, ISchema inputSchema) + : 
base(parent.Host.Register(nameof(Mapper)), parent, inputSchema) { - Contracts.CheckValue(env, nameof(env)); - _host = env.Register(nameof(Mapper)); - _host.CheckValueOrNull(imageFolder); - _host.CheckValue(columns, nameof(columns)); - _host.CheckValue(schema, nameof(schema)); - - _colMapNewToOld = new Dictionary(); - for (int i = 0; i < columns.Length; i++) - { - CheckInput(_host, schema, columns[i].input, out int srcCol); - _colMapNewToOld.Add(i, srcCol); - } - - _imageFolder = imageFolder; - _columns = columns; - _inputSchema = schema; _imageType = new ImageType(); + _parent = parent; } - public static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, ISchema schema) + protected override Delegate MakeGetter(IRow input, int iinfo, out Action disposer) { - Contracts.CheckValue(env, nameof(env)); - env.CheckValue(ctx, nameof(ctx)); - env.CheckValue(schema, nameof(schema)); + Contracts.AssertValue(input); + Contracts.Assert(0 <= iinfo && iinfo < _parent.ColumnPairs.Length); - var xf = ImageLoaderTransform.Create(env, ctx); - return new Mapper(env, xf._imageFolder, xf._columns, schema); - } - - public Delegate[] CreateGetters(IRow input, Func activeOutput, out Action disposer) - { - _host.Assert(input.Schema == _inputSchema); - var result = new Delegate[_columns.Length]; - for (int i = 0; i < _columns.Length; i++) - { - if (!activeOutput(i)) - continue; - int srcCol = _colMapNewToOld[i]; - result[i] = MakeGetter(input, i); - } disposer = null; - return result; - } - - private Delegate MakeGetter(IRow input, int iinfo) - { - _host.AssertValue(input); - _host.Assert(0 <= iinfo && iinfo < _columns.Length); - - var getSrc = input.GetGetter(_colMapNewToOld[iinfo]); + var getSrc = input.GetGetter(ColMapNewToOld[iinfo]); DvText src = default; ValueGetter del = (ref Bitmap dst) => @@ -284,8 +182,8 @@ private Delegate MakeGetter(IRow input, int iinfo) try { string path = src.ToString(); - if (!string.IsNullOrWhiteSpace(_imageFolder)) - path = 
Path.Combine(_imageFolder, path); + if (!string.IsNullOrWhiteSpace(_parent.ImageFolder)) + path = Path.Combine(_parent.ImageFolder, path); dst = new Bitmap(path); } catch (Exception) @@ -303,19 +201,8 @@ private Delegate MakeGetter(IRow input, int iinfo) return del; } - public Func GetDependencies(Func activeOutput) - { - var active = new bool[_inputSchema.ColumnCount]; - foreach (var pair in _colMapNewToOld) - if (activeOutput(pair.Key)) - active[pair.Value] = true; - return col => active[col]; - } - - public RowMapperColumnInfo[] GetOutputColumns() - => _columns.Select(x => new RowMapperColumnInfo(x.output, _imageType, null)).ToArray(); - - public void Save(ModelSaveContext ctx) => SaveContents(ctx, _imageFolder, _columns); + public override RowMapperColumnInfo[] GetOutputColumns() + => _parent.ColumnPairs.Select(x => new RowMapperColumnInfo(x.output, _imageType, null)).ToArray(); } } diff --git a/src/Microsoft.ML.ImageAnalytics/ImagePixelExtractorTransform.cs b/src/Microsoft.ML.ImageAnalytics/ImagePixelExtractorTransform.cs index 3c275fe5dc..eb3048d799 100644 --- a/src/Microsoft.ML.ImageAnalytics/ImagePixelExtractorTransform.cs +++ b/src/Microsoft.ML.ImageAnalytics/ImagePixelExtractorTransform.cs @@ -25,16 +25,15 @@ [assembly: LoadableClass(typeof(ImagePixelExtractorTransform), null, typeof(SignatureLoadModel), ImagePixelExtractorTransform.UserName, ImagePixelExtractorTransform.LoaderSignature)] -[assembly: LoadableClass(typeof(IRowMapper), typeof(ImagePixelExtractorTransform.Mapper), null, typeof(SignatureLoadRowMapper), +[assembly: LoadableClass(typeof(IRowMapper), typeof(ImagePixelExtractorTransform), null, typeof(SignatureLoadRowMapper), ImagePixelExtractorTransform.UserName, ImagePixelExtractorTransform.LoaderSignature)] namespace Microsoft.ML.Runtime.ImageAnalytics { - // REVIEW: Rewrite as LambdaTransform to simplify. /// /// Transform which takes one or many columns of and convert them into vector representation. 
/// - public sealed class ImagePixelExtractorTransform : ITransformer, ICanSaveModel + public sealed class ImagePixelExtractorTransform : OneToOneTransformerBase, ICanSaveModel { public class Column : OneToOneColumn { @@ -259,7 +258,6 @@ internal ColumnInfo(string input, string output, ModelLoadContext ctx) public void Save(ModelSaveContext ctx) { Contracts.AssertValue(ctx); - #if DEBUG // This code is used in deserialization - assert that it matches what we computed above. int planes = (int)Colors; @@ -303,7 +301,6 @@ private static VersionInfo GetVersionInfo() private const string RegistrationName = "ImagePixelExtractor"; - private readonly IHost _host; private readonly ColumnInfo[] _columns; public IReadOnlyCollection Columns => _columns.AsReadOnly(); @@ -315,14 +312,17 @@ public ImagePixelExtractorTransform(IHostEnvironment env, string inputColumn, st } public ImagePixelExtractorTransform(IHostEnvironment env, params ColumnInfo[] columns) + : base(Contracts.CheckRef(env, nameof(env)).Register(RegistrationName), GetColumnPairs(columns)) { - Contracts.CheckValue(env, nameof(env)); - _host = env.Register(RegistrationName); - _host.CheckValue(columns, nameof(columns)); - _columns = columns.ToArray(); } + private static (string input, string output)[] GetColumnPairs(ColumnInfo[] columns) + { + Contracts.CheckValue(columns, nameof(columns)); + return columns.Select(x => (x.Input, x.Output)).ToArray(); + } + // SignatureDataTransform. 
public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input) { @@ -343,195 +343,93 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV return new RowToRowMapperTransform(env, input, transformer.MakeRowMapper(input.Schema)); } - public ImagePixelExtractorTransform(IHostEnvironment env, ModelLoadContext ctx) + public static ImagePixelExtractorTransform Create(IHostEnvironment env, ModelLoadContext ctx) { Contracts.CheckValue(env, nameof(env)); - _host = env.Register(RegistrationName); - _host.CheckValue(ctx, nameof(ctx)); + var host = env.Register(RegistrationName); + host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(GetVersionInfo()); + return new ImagePixelExtractorTransform(host, ctx); + } + + private ImagePixelExtractorTransform(IHost host, ModelLoadContext ctx) + : base(host, ctx) + { // *** Binary format *** - // int: number of added columns - // for each added column - // int: id of output column name - // int: id of input column name + // // for each added column // ColumnInfo - int n = ctx.Reader.ReadInt32(); - - var names = new (string input, string output)[n]; - for (int i = 0; i < n; i++) - { - var output = ctx.LoadNonEmptyString(); - var input = ctx.LoadNonEmptyString(); - names[i] = (input, output); - } - - _columns = new ColumnInfo[n]; + _columns = new ColumnInfo[ColumnPairs.Length]; for (int i = 0; i < _columns.Length; i++) - _columns[i] = new ColumnInfo(names[i].input, names[i].output, ctx); + _columns[i] = new ColumnInfo(ColumnPairs[i].input, ColumnPairs[i].output, ctx); } // Factory method for SignatureLoadDataTransform. 
public static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input) - { - Contracts.CheckValue(env, nameof(env)); - env.CheckValue(ctx, nameof(ctx)); - env.CheckValue(input, nameof(input)); - - var transformer = new ImagePixelExtractorTransform(env, ctx); - return new RowToRowMapperTransform(env, input, transformer.MakeRowMapper(input.Schema)); - } + => Create(env, ctx).MakeDataTransform(input); - public void Save(ModelSaveContext ctx) => SaveContents(_host, ctx, _columns); + // Factory method for SignatureLoadRowMapper. + public static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, ISchema inputSchema) + => Create(env, ctx).MakeRowMapper(inputSchema); - private static void SaveContents(IHostEnvironment env, ModelSaveContext ctx, ColumnInfo[] columns) + public override void Save(ModelSaveContext ctx) { - Contracts.AssertValue(env); - env.CheckValue(ctx, nameof(ctx)); - Contracts.AssertValue(columns); + Host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); ctx.SetVersionInfo(GetVersionInfo()); // *** Binary format *** - // int: number of added columns - // for each added column - // int: id of output column name - // int: id of input column name + // // for each added column // ColumnInfo - ctx.Writer.Write(columns.Length); - for (int i = 0; i < columns.Length; i++) - { - ctx.SaveNonEmptyString(columns[i].Output); - ctx.SaveNonEmptyString(columns[i].Input); - } + base.SaveColumns(ctx); - for (int i = 0; i < columns.Length; i++) - columns[i].Save(ctx); + foreach (ColumnInfo info in _columns) + info.Save(ctx); } - private IRowMapper MakeRowMapper(ISchema schema) - => new Mapper(_host, _columns, schema); + protected override IRowMapper MakeRowMapper(ISchema schema) + => new Mapper(this, schema); - private static void CheckInput(IExceptionContext ctx, ISchema inputSchema, string input, out int srcCol) + protected override void CheckInputColumn(ISchema inputSchema, int col, int srcCol) { - Contracts.AssertValueOrNull(ctx); - 
Contracts.AssertValue(inputSchema); - Contracts.AssertNonEmpty(input); - - if (!inputSchema.TryGetColumnIndex(input, out srcCol)) - throw ctx.ExceptSchemaMismatch(nameof(inputSchema), "input", input); + var inputColName = _columns[col].Input; var imageType = inputSchema.GetColumnType(srcCol) as ImageType; if (imageType == null) - throw ctx.ExceptSchemaMismatch(nameof(inputSchema), "input", input, "image", inputSchema.GetColumnType(srcCol).ToString()); + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", inputColName, "image", inputSchema.GetColumnType(srcCol).ToString()); if (imageType.Height <= 0 || imageType.Width <= 0) - throw ctx.ExceptSchemaMismatch(nameof(inputSchema), "input", input, "known-size image", "unknown-size image"); + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", inputColName, "known-size image", "unknown-size image"); if ((long)imageType.Height * imageType.Width > int.MaxValue / 4) - throw ctx.Except("Image dimensions are too large"); + throw Host.Except("Image dimensions are too large"); } - public ISchema GetOutputSchema(ISchema inputSchema) + private sealed class Mapper : MapperBase { - _host.CheckValue(inputSchema, nameof(inputSchema)); - - // Check that all the input columns are present and are images of known size. 
- foreach (var column in _columns) - CheckInput(_host, inputSchema, column.Input, out int col); - - return Transform(new EmptyDataView(_host, inputSchema)).Schema; - } - - public IDataView Transform(IDataView input) - { - _host.CheckValue(input, nameof(input)); - - var mapper = MakeRowMapper(input.Schema); - return new RowToRowMapperTransform(_host, input, mapper); - } - - internal sealed class Mapper : IRowMapper - { - private readonly IHost _host; - private readonly ColumnInfo[] _columns; + private readonly ImagePixelExtractorTransform _parent; private readonly VectorType[] _types; - private readonly ISchema _inputSchema; - private readonly Dictionary _colMapNewToOld; - public Mapper(IHostEnvironment env, ColumnInfo[] columns, ISchema inputSchema) + public Mapper(ImagePixelExtractorTransform parent, ISchema inputSchema) + : base(parent.Host.Register(nameof(Mapper)), parent, inputSchema) { - Contracts.AssertValue(env); - _host = env.Register(nameof(Mapper)); - _host.AssertValue(columns); - _host.AssertValue(inputSchema); - - _colMapNewToOld = new Dictionary(); - for (int i = 0; i < columns.Length; i++) - { - CheckInput(_host, inputSchema, columns[i].Input, out int srcCol); - _colMapNewToOld.Add(i, srcCol); - } - - _columns = columns; - _inputSchema = inputSchema; + _parent = parent; _types = ConstructTypes(); } - public Delegate[] CreateGetters(IRow input, Func activeOutput, out Action disposer) - { - _host.Assert(input.Schema == _inputSchema); - var result = new Delegate[_columns.Length]; - var disposers = new Action[_columns.Length]; - for (int i = 0; i < _columns.Length; i++) - { - if (!activeOutput(i)) - continue; - int srcCol = _colMapNewToOld[i]; - result[i] = MakeGetter(input, i, out disposers[i]); - } - disposer = () => - { - foreach (var act in disposers) - act(); - }; - return result; - } - - public Func GetDependencies(Func activeOutput) - { - var active = new bool[_inputSchema.ColumnCount]; - foreach (var pair in _colMapNewToOld) - if 
(activeOutput(pair.Key)) - active[pair.Value] = true; - return col => active[col]; - } - - public RowMapperColumnInfo[] GetOutputColumns() - => _columns.Select((x, idx) => new RowMapperColumnInfo(x.Output, _types[idx], null)).ToArray(); - - public void Save(ModelSaveContext ctx) => SaveContents(_host, ctx, _columns); - - // Factory method for SignatureLoadRowMapper. - public static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, ISchema inputSchema) - { - Contracts.CheckValue(env, nameof(env)); - env.CheckValue(ctx, nameof(ctx)); - env.CheckValue(inputSchema, nameof(inputSchema)); - var transformer = new ImagePixelExtractorTransform(env, ctx); - return transformer.MakeRowMapper(inputSchema); - } + public override RowMapperColumnInfo[] GetOutputColumns() + => _parent._columns.Select((x, idx) => new RowMapperColumnInfo(x.Output, _types[idx], null)).ToArray(); - private Delegate MakeGetter(IRow input, int iinfo, out Action disposer) + protected override Delegate MakeGetter(IRow input, int iinfo, out Action disposer) { - _host.AssertValue(input); - _host.Assert(0 <= iinfo && iinfo < _columns.Length); + Contracts.AssertValue(input); + Contracts.Assert(0 <= iinfo && iinfo < _parent._columns.Length); - if (_columns[iinfo].Convert) + if (_parent._columns[iinfo].Convert) return GetGetterCore(input, iinfo, out disposer); return GetGetterCore(input, iinfo, out disposer); } @@ -540,20 +438,20 @@ private Delegate MakeGetter(IRow input, int iinfo, out Action disposer) private ValueGetter> GetGetterCore(IRow input, int iinfo, out Action disposer) { var type = _types[iinfo]; - _host.Assert(type.DimCount == 3); + Contracts.Assert(type.DimCount == 3); - var ex = _columns[iinfo]; + var ex = _parent._columns[iinfo]; int planes = ex.Interleave ? type.GetDim(2) : type.GetDim(0); int height = ex.Interleave ? type.GetDim(0) : type.GetDim(1); int width = ex.Interleave ? 
type.GetDim(1) : type.GetDim(2); int size = type.ValueCount; - _host.Assert(size > 0); - _host.Assert(size == planes * height * width); + Contracts.Assert(size > 0); + Contracts.Assert(size == planes * height * width); int cpix = height * width; - var getSrc = input.GetGetter(_colMapNewToOld[iinfo]); + var getSrc = input.GetGetter(ColMapNewToOld[iinfo]); var src = default(Bitmap); disposer = @@ -578,22 +476,22 @@ private ValueGetter> GetGetterCore(IRow input, int iinfo return; } - _host.Check(src.PixelFormat == System.Drawing.Imaging.PixelFormat.Format32bppArgb); - _host.Check(src.Height == height && src.Width == width); + Host.Check(src.PixelFormat == System.Drawing.Imaging.PixelFormat.Format32bppArgb); + Host.Check(src.Height == height && src.Width == width); var values = dst.Values; if (Utils.Size(values) < size) values = new TValue[size]; - Single offset = ex.Offset; - Single scale = ex.Scale; - _host.Assert(scale != 0); + float offset = ex.Offset; + float scale = ex.Scale; + Contracts.Assert(scale != 0); - var vf = values as Single[]; + var vf = values as float[]; var vb = values as byte[]; - _host.Assert(vf != null || vb != null); + Contracts.Assert(vf != null || vb != null); bool needScale = offset != 0 || scale != 1; - _host.Assert(!needScale || vf != null); + Contracts.Assert(!needScale || vf != null); bool a = ex.Alpha; bool r = ex.Red; @@ -632,16 +530,16 @@ private ValueGetter> GetGetterCore(IRow input, int iinfo if (b) { vf[idst++] = (pb.G - offset) * scale; } } } - _host.Assert(idst == size); + Contracts.Assert(idst == size); } else { int idstMin = 0; if (ex.Alpha) { - // The image only has rgb but we need to supply alpha as well, so fake it up, - // assuming that it is 0xFF. - if (vf != null) + // The image only has rgb but we need to supply alpha as well, so fake it up, + // assuming that it is 0xFF. 
+ if (vf != null) { Single v = (0xFF - offset) * scale; for (int i = 0; i < cpix; i++) @@ -654,17 +552,17 @@ private ValueGetter> GetGetterCore(IRow input, int iinfo } idstMin = cpix; - // We've preprocessed alpha, avoid it in the - // scan operation below. - a = false; + // We've preprocessed alpha, avoid it in the + // scan operation below. + a = false; } for (int y = 0; y < h; ++y) { int idstBase = idstMin + y * w; - // Note that the bytes are in order BGR[A]. We arrange the layers in order ARGB. - if (vb != null) + // Note that the bytes are in order BGR[A]. We arrange the layers in order ARGB. + if (vb != null) { for (int x = 0; x < w; x++, idstBase++) { @@ -709,13 +607,13 @@ private ValueGetter> GetGetterCore(IRow input, int iinfo private VectorType[] ConstructTypes() { - var types = new VectorType[_columns.Length]; - for (int i = 0; i < _columns.Length; i++) + var types = new VectorType[_parent._columns.Length]; + for (int i = 0; i < _parent._columns.Length; i++) { - var column = _columns[i]; + var column = _parent._columns[i]; Contracts.Assert(column.Planes > 0); - var type = _inputSchema.GetColumnType(_colMapNewToOld[i]) as ImageType; + var type = InputSchema.GetColumnType(ColMapNewToOld[i]) as ImageType; Contracts.Assert(type != null); int height = type.Height; @@ -734,7 +632,7 @@ private VectorType[] ConstructTypes() } } - public sealed class ImagePixelExtractorEstimator: TrivialEstimator + public sealed class ImagePixelExtractorEstimator : TrivialEstimator { public ImagePixelExtractorEstimator(IHostEnvironment env, string inputColumn, string outputColumn, ImagePixelExtractorTransform.ColorBits colors = ImagePixelExtractorTransform.ColorBits.Rgb, bool interleave = false) diff --git a/src/Microsoft.ML.ImageAnalytics/ImageResizerTransform.cs b/src/Microsoft.ML.ImageAnalytics/ImageResizerTransform.cs index 0cea5199de..ac11c7fa8d 100644 --- a/src/Microsoft.ML.ImageAnalytics/ImageResizerTransform.cs +++ 
b/src/Microsoft.ML.ImageAnalytics/ImageResizerTransform.cs @@ -26,7 +26,7 @@ [assembly: LoadableClass(typeof(ImageResizerTransform), null, typeof(SignatureLoadModel), ImageResizerTransform.UserName, ImageResizerTransform.LoaderSignature)] -[assembly: LoadableClass(typeof(IRowMapper), typeof(ImageResizerTransform.Mapper), null, typeof(SignatureLoadRowMapper), +[assembly: LoadableClass(typeof(IRowMapper), typeof(ImageResizerTransform), null, typeof(SignatureLoadRowMapper), ImageResizerTransform.UserName, ImageResizerTransform.LoaderSignature)] namespace Microsoft.ML.Runtime.ImageAnalytics @@ -35,7 +35,7 @@ namespace Microsoft.ML.Runtime.ImageAnalytics /// /// Transform which takes one or many columns of and resize them to provided height and width. /// - public sealed class ImageResizerTransform : ITransformer, ICanSaveModel + public sealed class ImageResizerTransform : OneToOneTransformerBase { public enum ResizingKind : byte { @@ -149,15 +149,15 @@ private static VersionInfo GetVersionInfo() return new VersionInfo( modelSignature: "IMGSCALF", //verWrittenCur: 0x00010001, // Initial - verWrittenCur: 0x00010002, // Swith from OpenCV to Bitmap - verReadableCur: 0x00010002, - verWeCanReadBack: 0x00010002, + //verWrittenCur: 0x00010002, // Swith from OpenCV to Bitmap + verWrittenCur: 0x00010003, // No more sizeof(float) + verReadableCur: 0x00010003, + verWeCanReadBack: 0x00010003, loaderSignature: LoaderSignature); } private const string RegistrationName = "ImageScaler"; - private readonly IHost _host; private readonly ColumnInfo[] _columns; public IReadOnlyCollection Columns => _columns.AsReadOnly(); @@ -169,14 +169,17 @@ public ImageResizerTransform(IHostEnvironment env, string inputColumn, string ou } public ImageResizerTransform(IHostEnvironment env, params ColumnInfo[] columns) + : base(Contracts.CheckRef(env, nameof(env)).Register(RegistrationName), GetColumnPairs(columns)) { - Contracts.CheckValue(env, nameof(env)); - _host = env.Register(RegistrationName); - 
_host.CheckValue(columns, nameof(columns)); - _columns = columns.ToArray(); } + private static (string input, string output)[] GetColumnPairs(ColumnInfo[] columns) + { + Contracts.CheckValue(columns, nameof(columns)); + return columns.Select(x => (x.Input, x.Output)).ToArray(); + } + // Factory method for SignatureDataTransform. public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input) { @@ -199,221 +202,114 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV item.CropAnchor ?? args.CropAnchor); } - var transformer = new ImageResizerTransform(env, cols); - return new RowToRowMapperTransform(env, input, transformer.MakeRowMapper(input.Schema)); + return new ImageResizerTransform(env, cols).MakeDataTransform(input); } - public ImageResizerTransform(IHostEnvironment env, ModelLoadContext ctx) + public static ImageResizerTransform Create(IHostEnvironment env, ModelLoadContext ctx) { Contracts.CheckValue(env, nameof(env)); - _host = env.Register(RegistrationName); + var host = env.Register(RegistrationName); - _host.CheckValue(ctx, nameof(ctx)); + host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(GetVersionInfo()); + return new ImageResizerTransform(host, ctx); + } + + private ImageResizerTransform(IHost host, ModelLoadContext ctx) + : base(host, ctx) + { // *** Binary format *** - // int: sizeof(float) - // int: number of added columns - // for each added column - // int: id of output column name - // int: id of input column name + // // for each added column // int: width // int: height // byte: scaling kind + // byte: anchor - int cbFloat = ctx.Reader.ReadInt32(); - _host.CheckDecode(cbFloat == sizeof(Single)); - - int n = ctx.Reader.ReadInt32(); - - var names = new (string input, string output)[n]; - for (int i = 0; i < n; i++) - { - var output = ctx.LoadNonEmptyString(); - var input = ctx.LoadNonEmptyString(); - names[i] = (input, output); - } - - _columns = new ColumnInfo[n]; - for (int i = 0; i < 
n; i++) + _columns = new ColumnInfo[ColumnPairs.Length]; + for (int i = 0; i < ColumnPairs.Length; i++) { int width = ctx.Reader.ReadInt32(); - _host.CheckDecode(width > 0); + Host.CheckDecode(width > 0); int height = ctx.Reader.ReadInt32(); - _host.CheckDecode(height > 0); + Host.CheckDecode(height > 0); var scale = (ResizingKind)ctx.Reader.ReadByte(); - _host.CheckDecode(Enum.IsDefined(typeof(ResizingKind), scale)); + Host.CheckDecode(Enum.IsDefined(typeof(ResizingKind), scale)); var anchor = (Anchor)ctx.Reader.ReadByte(); - _host.CheckDecode(Enum.IsDefined(typeof(Anchor), anchor)); - _columns[i] = new ColumnInfo(names[i].input, names[i].output, width, height, scale, anchor); + Host.CheckDecode(Enum.IsDefined(typeof(Anchor), anchor)); + _columns[i] = new ColumnInfo(ColumnPairs[i].input, ColumnPairs[i].output, width, height, scale, anchor); } } // Factory method for SignatureLoadDataTransform. public static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input) - { - Contracts.CheckValue(env, nameof(env)); - env.CheckValue(ctx, nameof(ctx)); - env.CheckValue(input, nameof(input)); + => Create(env, ctx).MakeDataTransform(input); - var transformer = new ImageResizerTransform(env, ctx); - return new RowToRowMapperTransform(env, input, transformer.MakeRowMapper(input.Schema)); - } - - public void Save(ModelSaveContext ctx) => SaveContents(_host, ctx, _columns); + // Factory method for SignatureLoadRowMapper. 
+ public static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, ISchema inputSchema) + => Create(env, ctx).MakeRowMapper(inputSchema); - private static void SaveContents(IHostEnvironment env, ModelSaveContext ctx, ColumnInfo[] columns) + public override void Save(ModelSaveContext ctx) { - Contracts.AssertValue(env); - env.CheckValue(ctx, nameof(ctx)); - Contracts.AssertValue(columns); + Host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); ctx.SetVersionInfo(GetVersionInfo()); // *** Binary format *** - // int: sizeof(float) - // int: number of added columns - // for each added column - // int: id of output column name - // int: id of input column name + // // for each added column // int: width // int: height // byte: scaling kind + // byte: anchor - ctx.Writer.Write(sizeof(float)); + base.SaveColumns(ctx); - ctx.Writer.Write(columns.Length); - for (int i = 0; i < columns.Length; i++) - { - ctx.SaveNonEmptyString(columns[i].Output); - ctx.SaveNonEmptyString(columns[i].Input); - } - - foreach (var col in columns) + foreach (var col in _columns) { ctx.Writer.Write(col.Width); ctx.Writer.Write(col.Height); - env.Assert((ResizingKind)(byte)col.Scale == col.Scale); + Contracts.Assert((ResizingKind)(byte)col.Scale == col.Scale); ctx.Writer.Write((byte)col.Scale); - env.Assert((Anchor)(byte)col.Anchor == col.Anchor); + Contracts.Assert((Anchor)(byte)col.Anchor == col.Anchor); ctx.Writer.Write((byte)col.Anchor); } } - public ISchema GetOutputSchema(ISchema inputSchema) - { - _host.CheckValue(inputSchema, nameof(inputSchema)); - - // Check that all the input columns are present and are images. 
- foreach (var column in _columns) - CheckInput(_host, inputSchema, column.Input, out int col); + protected override IRowMapper MakeRowMapper(ISchema schema) + => new Mapper(this, schema); - return Transform(new EmptyDataView(_host, inputSchema)).Schema; - } - - public IDataView Transform(IDataView input) + protected override void CheckInputColumn(ISchema inputSchema, int col, int srcCol) { - _host.CheckValue(input, nameof(input)); - - var mapper = MakeRowMapper(input.Schema); - return new RowToRowMapperTransform(_host, input, mapper); - } - - private IRowMapper MakeRowMapper(ISchema schema) - => new Mapper(_host, _columns, schema); - - private static void CheckInput(IExceptionContext ctx, ISchema inputSchema, string input, out int srcCol) - { - Contracts.AssertValueOrNull(ctx); - Contracts.AssertValue(inputSchema); - Contracts.AssertNonEmpty(input); - - if (!inputSchema.TryGetColumnIndex(input, out srcCol)) - throw ctx.ExceptSchemaMismatch(nameof(inputSchema), "input", input); if (!(inputSchema.GetColumnType(srcCol) is ImageType)) - throw ctx.ExceptSchemaMismatch(nameof(inputSchema), "input", input, "image", inputSchema.GetColumnType(srcCol).ToString()); + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", _columns[col].Input, "image", inputSchema.GetColumnType(srcCol).ToString()); } - internal sealed class Mapper : IRowMapper + private sealed class Mapper : MapperBase { - private readonly IHost _host; - private readonly ColumnInfo[] _columns; - private readonly ISchema _inputSchema; - private readonly Dictionary _colMapNewToOld; + private readonly ImageResizerTransform _parent; - public Mapper(IHostEnvironment env, ColumnInfo[] columns, ISchema inputSchema) + public Mapper(ImageResizerTransform parent, ISchema inputSchema) + :base(parent.Host.Register(nameof(Mapper)), parent, inputSchema) { - Contracts.AssertValue(env); - _host = env.Register(nameof(Mapper)); - _host.AssertValue(columns); - _host.AssertValue(inputSchema); - - _colMapNewToOld = new 
Dictionary(); - for (int i = 0; i < columns.Length; i++) - { - CheckInput(_host, inputSchema, columns[i].Input, out int srcCol); - _colMapNewToOld.Add(i, srcCol); - } - _columns = columns; - _inputSchema = inputSchema; + _parent = parent; } - public Delegate[] CreateGetters(IRow input, Func activeOutput, out Action disposer) - { - _host.Assert(input.Schema == _inputSchema); - var result = new Delegate[_columns.Length]; - var disposers = new Action[_columns.Length]; - for (int i = 0; i < _columns.Length; i++) - { - if (!activeOutput(i)) - continue; - int srcCol = _colMapNewToOld[i]; - result[i] = MakeGetter(input, i, out disposers[i]); - } - disposer = () => - { - foreach (var act in disposers) - act(); - }; - return result; - } - - public Func GetDependencies(Func activeOutput) - { - var active = new bool[_inputSchema.ColumnCount]; - foreach (var pair in _colMapNewToOld) - if (activeOutput(pair.Key)) - active[pair.Value] = true; - return col => active[col]; - } - - public RowMapperColumnInfo[] GetOutputColumns() - => _columns.Select(x => new RowMapperColumnInfo(x.Output, x.Type, null)).ToArray(); - - public void Save(ModelSaveContext ctx) => SaveContents(_host, ctx, _columns); - - public static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, ISchema inputSchema) - { - Contracts.CheckValue(env, nameof(env)); - env.CheckValue(ctx, nameof(ctx)); - env.CheckValue(inputSchema, nameof(inputSchema)); - var transformer = new ImageResizerTransform(env, ctx); - return transformer.MakeRowMapper(inputSchema); - } + public override RowMapperColumnInfo[] GetOutputColumns() + => _parent._columns.Select(x => new RowMapperColumnInfo(x.Output, x.Type, null)).ToArray(); - private Delegate MakeGetter(IRow input, int iinfo, out Action disposer) + protected override Delegate MakeGetter(IRow input, int iinfo, out Action disposer) { - _host.AssertValue(input); - _host.Assert(0 <= iinfo && iinfo < _columns.Length); + Contracts.AssertValue(input); + Contracts.Assert(0 <= iinfo 
&& iinfo < _parent._columns.Length); var src = default(Bitmap); - var getSrc = input.GetGetter(_colMapNewToOld[iinfo]); - var ex = _columns[iinfo]; + var getSrc = input.GetGetter(ColMapNewToOld[iinfo]); + var info = _parent._columns[iinfo]; disposer = () => @@ -434,7 +330,7 @@ private Delegate MakeGetter(IRow input, int iinfo, out Action disposer) getSrc(ref src); if (src == null || src.Height <= 0 || src.Width <= 0) return; - if (src.Height == ex.Height && src.Width == ex.Width) + if (src.Height == info.Height && src.Width == info.Width) { dst = src; return; @@ -452,22 +348,22 @@ private Delegate MakeGetter(IRow input, int iinfo, out Action disposer) float widthAspect = 0; float heightAspect = 0; - widthAspect = (float)ex.Width / sourceWidth; - heightAspect = (float)ex.Height / sourceHeight; + widthAspect = (float)info.Width / sourceWidth; + heightAspect = (float)info.Height / sourceHeight; - if (ex.Scale == ResizingKind.IsoPad) + if (info.Scale == ResizingKind.IsoPad) { - widthAspect = (float)ex.Width / sourceWidth; - heightAspect = (float)ex.Height / sourceHeight; + widthAspect = (float)info.Width / sourceWidth; + heightAspect = (float)info.Height / sourceHeight; if (heightAspect < widthAspect) { aspect = heightAspect; - destX = (int)((ex.Width - (sourceWidth * aspect)) / 2); + destX = (int)((info.Width - (sourceWidth * aspect)) / 2); } else { aspect = widthAspect; - destY = (int)((ex.Height - (sourceHeight * aspect)) / 2); + destY = (int)((info.Height - (sourceHeight * aspect)) / 2); } destWidth = (int)(sourceWidth * aspect); @@ -478,32 +374,32 @@ private Delegate MakeGetter(IRow input, int iinfo, out Action disposer) if (heightAspect < widthAspect) { aspect = widthAspect; - switch (ex.Anchor) + switch (info.Anchor) { case Anchor.Top: destY = 0; break; case Anchor.Bottom: - destY = (int)(ex.Height - (sourceHeight * aspect)); + destY = (int)(info.Height - (sourceHeight * aspect)); break; default: - destY = (int)((ex.Height - (sourceHeight * aspect)) / 2); + 
destY = (int)((info.Height - (sourceHeight * aspect)) / 2); break; } } else { aspect = heightAspect; - switch (ex.Anchor) + switch (info.Anchor) { case Anchor.Left: destX = 0; break; case Anchor.Right: - destX = (int)(ex.Width - (sourceWidth * aspect)); + destX = (int)(info.Width - (sourceWidth * aspect)); break; default: - destX = (int)((ex.Width - (sourceWidth * aspect)) / 2); + destX = (int)((info.Width - (sourceWidth * aspect)) / 2); break; } } @@ -511,14 +407,14 @@ private Delegate MakeGetter(IRow input, int iinfo, out Action disposer) destWidth = (int)(sourceWidth * aspect); destHeight = (int)(sourceHeight * aspect); } - dst = new Bitmap(ex.Width, ex.Height); + dst = new Bitmap(info.Width, info.Height); var srcRectangle = new Rectangle(sourceX, sourceY, sourceWidth, sourceHeight); var destRectangle = new Rectangle(destX, destY, destWidth, destHeight); using (var g = Graphics.FromImage(dst)) { g.DrawImage(src, destRectangle, srcRectangle, GraphicsUnit.Pixel); } - _host.Assert(dst.Width == ex.Width && dst.Height == ex.Height); + Contracts.Assert(dst.Width == info.Width && dst.Height == info.Height); }; return del; From e7dc34822c9801c8fd4e9d0571fd99809e63ec6e Mon Sep 17 00:00:00 2001 From: Pete Luferenko Date: Tue, 28 Aug 2018 11:07:25 -0700 Subject: [PATCH 14/17] Workout test --- src/Microsoft.ML.Core/Data/IEstimator.cs | 6 +- src/Microsoft.ML.ImageAnalytics/ImageType.cs | 2 +- .../DataPipe/TestDataPipeBase.cs | 242 +++++++++++++----- test/Microsoft.ML.Tests/ImagesTests.cs | 27 +- 4 files changed, 203 insertions(+), 74 deletions(-) diff --git a/src/Microsoft.ML.Core/Data/IEstimator.cs b/src/Microsoft.ML.Core/Data/IEstimator.cs index a27ea0e5d9..509a67bb4f 100644 --- a/src/Microsoft.ML.Core/Data/IEstimator.cs +++ b/src/Microsoft.ML.Core/Data/IEstimator.cs @@ -132,13 +132,15 @@ public static SchemaShape Create(ISchema schema) else vecKind = Column.VectorKind.Scalar; - var itemKind = type.ItemType.RawKind; + ColumnType itemType = type.ItemType; + if 
(type.ItemType.IsKey) + itemType = PrimitiveType.FromKind(type.ItemType.RawKind); var isKey = type.ItemType.IsKey; var metadataNames = schema.GetMetadataTypes(iCol) .Select(kvp => kvp.Key) .ToArray(); - cols.Add(new Column(schema.GetColumnName(iCol), vecKind, PrimitiveType.FromKind(itemKind), isKey, metadataNames)); + cols.Add(new Column(schema.GetColumnName(iCol), vecKind, itemType, isKey, metadataNames)); } } return new SchemaShape(cols.ToArray()); diff --git a/src/Microsoft.ML.ImageAnalytics/ImageType.cs b/src/Microsoft.ML.ImageAnalytics/ImageType.cs index 852ea09d9d..fd31302808 100644 --- a/src/Microsoft.ML.ImageAnalytics/ImageType.cs +++ b/src/Microsoft.ML.ImageAnalytics/ImageType.cs @@ -35,7 +35,7 @@ public override bool Equals(ColumnType other) return false; if (Height != tmp.Height) return false; - return Width != tmp.Width; + return Width == tmp.Width; } public override bool Equals(object other) diff --git a/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipeBase.cs b/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipeBase.cs index 760cf9c910..b1ede303d8 100644 --- a/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipeBase.cs +++ b/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipeBase.cs @@ -5,6 +5,8 @@ using System; using System.Collections.Generic; using System.IO; +using System.Linq; +using Microsoft.ML.Core.Data; using Microsoft.ML.Runtime.CommandLine; using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.Data.IO; @@ -17,6 +19,103 @@ namespace Microsoft.ML.Runtime.RunTests { public abstract partial class TestDataPipeBase : TestDataViewBase { + /// + /// 'Workout test' for an estimator. 
+ /// Checks the following traits: + /// - the estimator is applicable to the validFitInput, and not applicable to validTransformInput and invalidInput; + /// - the fitted transformer is applicable to validFitInput and validTransformInput, and not applicable to invalidInput; + /// - fitted transformer can be saved and re-loaded into the transformer with the same behavior. + /// - schema propagation for fitted transformer conforms to schema propagation of estimator. + /// + protected void TestEstimatorCore(IEstimator estimator, + IDataView validFitInput, IDataView validTransformInput = null, IDataView invalidInput = null) + { + Contracts.AssertValue(estimator); + Contracts.AssertValue(validFitInput); + Contracts.AssertValueOrNull(validTransformInput); + Contracts.AssertValueOrNull(invalidInput); + Action mustFail = (Action action) => + { + try + { + action(); + Assert.False(true); + } + catch (ArgumentOutOfRangeException) + { + } + catch (InvalidOperationException) + { + } + }; + + // Schema propagation tests for estimator. + var outSchemaShape = estimator.GetOutputSchema(SchemaShape.Create(validFitInput.Schema)); + if (validTransformInput != null) + { + mustFail(() => estimator.GetOutputSchema(SchemaShape.Create(validTransformInput.Schema))); + mustFail(() => estimator.Fit(validTransformInput)); + } + + if (invalidInput != null) + { + mustFail(() => estimator.GetOutputSchema(SchemaShape.Create(invalidInput.Schema))); + mustFail(() => estimator.Fit(invalidInput)); + } + + var transformer = estimator.Fit(validFitInput); + // Save and reload. + string modelPath = GetOutputPath(TestName + "-model.zip"); + using (var fs = File.Create(modelPath)) + transformer.SaveTo(Env, fs); + + ITransformer loadedTransformer; + using (var fs = File.OpenRead(modelPath)) + loadedTransformer = TransformerChain.LoadFrom(Env, fs); + DeleteOutputPath(modelPath); + + // Run on train data. 
+ Action checkOnData = (IDataView data) => + { + var schema = transformer.GetOutputSchema(data.Schema); + CheckSameSchemas(schema, loadedTransformer.GetOutputSchema(data.Schema)); + var scoredTrain = transformer.Transform(data); + var scoredTrain2 = loadedTransformer.Transform(data); + CheckSameSchemas(schema, scoredTrain.Schema); + + CheckSameSchemas(scoredTrain.Schema, scoredTrain2.Schema); + CheckSameValues(scoredTrain, scoredTrain2); + }; + + checkOnData(validFitInput); + + if (validTransformInput != null) + checkOnData(validTransformInput); + + if (invalidInput != null) + { + mustFail(() => transformer.GetOutputSchema(invalidInput.Schema)); + mustFail(() => transformer.Transform(invalidInput)); + mustFail(() => loadedTransformer.GetOutputSchema(invalidInput.Schema)); + mustFail(() => loadedTransformer.Transform(invalidInput)); + } + + // Schema verification between estimator and transformer. + var scoredTrainSchemaShape = SchemaShape.Create(transformer.GetOutputSchema(validFitInput.Schema)); + CheckSameSchemaShape(outSchemaShape, scoredTrainSchemaShape); + } + + private void CheckSameSchemaShape(SchemaShape first, SchemaShape second) + { + Assert.True(first.Columns.Length == second.Columns.Length); + var sortedCols1 = first.Columns.OrderBy(x => x.Name); + var sortedCols2 = second.Columns.OrderBy(x => x.Name); + + Assert.True(sortedCols1.Zip(sortedCols2, + (x, y) => x.IsCompatibleWith(y) && y.IsCompatibleWith(x)) + .All(x => x)); + } + // REVIEW: incorporate the testing for re-apply logic here? 
/// /// Create PipeDataLoader from the given args, save it, re-load it, verify that the data of @@ -878,41 +977,44 @@ protected Func GetColumnComparer(IRow r1, IRow r2, int col, ColumnType typ { switch (type.RawKind) { - case DataKind.I1: - return GetComparerOne(r1, r2, col, (x, y) => x.RawValue == y.RawValue); - case DataKind.U1: - return GetComparerOne(r1, r2, col, (x, y) => x == y); - case DataKind.I2: - return GetComparerOne(r1, r2, col, (x, y) => x.RawValue == y.RawValue); - case DataKind.U2: - return GetComparerOne(r1, r2, col, (x, y) => x == y); - case DataKind.I4: - return GetComparerOne(r1, r2, col, (x, y) => x.RawValue == y.RawValue); - case DataKind.U4: - return GetComparerOne(r1, r2, col, (x, y) => x == y); - case DataKind.I8: - return GetComparerOne(r1, r2, col, (x, y) => x.RawValue == y.RawValue); - case DataKind.U8: - return GetComparerOne(r1, r2, col, (x, y) => x == y); - case DataKind.R4: - return GetComparerOne(r1, r2, col, (x, y) => FloatUtils.GetBits(x) == FloatUtils.GetBits(y)); - case DataKind.R8: - if (exactDoubles) - return GetComparerOne(r1, r2, col, (x, y) => FloatUtils.GetBits(x) == FloatUtils.GetBits(y)); - else - return GetComparerOne(r1, r2, col, EqualWithEps); - case DataKind.Text: - return GetComparerOne(r1, r2, col, DvText.Identical); - case DataKind.Bool: - return GetComparerOne(r1, r2, col, (x, y) => x.Equals(y)); - case DataKind.TimeSpan: - return GetComparerOne(r1, r2, col, (x, y) => x.Equals(y)); - case DataKind.DT: - return GetComparerOne(r1, r2, col, (x, y) => x.Equals(y)); - case DataKind.DZ: - return GetComparerOne(r1, r2, col, (x, y) => x.Equals(y)); - case DataKind.UG: - return GetComparerOne(r1, r2, col, (x, y) => x.Equals(y)); + case DataKind.I1: + return GetComparerOne(r1, r2, col, (x, y) => x.RawValue == y.RawValue); + case DataKind.U1: + return GetComparerOne(r1, r2, col, (x, y) => x == y); + case DataKind.I2: + return GetComparerOne(r1, r2, col, (x, y) => x.RawValue == y.RawValue); + case DataKind.U2: + return 
GetComparerOne(r1, r2, col, (x, y) => x == y); + case DataKind.I4: + return GetComparerOne(r1, r2, col, (x, y) => x.RawValue == y.RawValue); + case DataKind.U4: + return GetComparerOne(r1, r2, col, (x, y) => x == y); + case DataKind.I8: + return GetComparerOne(r1, r2, col, (x, y) => x.RawValue == y.RawValue); + case DataKind.U8: + return GetComparerOne(r1, r2, col, (x, y) => x == y); + case DataKind.R4: + return GetComparerOne(r1, r2, col, (x, y) => FloatUtils.GetBits(x) == FloatUtils.GetBits(y)); + case DataKind.R8: + if (exactDoubles) + return GetComparerOne(r1, r2, col, (x, y) => FloatUtils.GetBits(x) == FloatUtils.GetBits(y)); + else + return GetComparerOne(r1, r2, col, EqualWithEps); + case DataKind.Text: + return GetComparerOne(r1, r2, col, DvText.Identical); + case DataKind.Bool: + return GetComparerOne(r1, r2, col, (x, y) => x.Equals(y)); + case DataKind.TimeSpan: + return GetComparerOne(r1, r2, col, (x, y) => x.Equals(y)); + case DataKind.DT: + return GetComparerOne(r1, r2, col, (x, y) => x.Equals(y)); + case DataKind.DZ: + return GetComparerOne(r1, r2, col, (x, y) => x.Equals(y)); + case DataKind.UG: + return GetComparerOne(r1, r2, col, (x, y) => x.Equals(y)); + case (DataKind)0: + // We cannot compare custom types (including image). 
+ return () => true; } } else @@ -921,41 +1023,41 @@ protected Func GetColumnComparer(IRow r1, IRow r2, int col, ColumnType typ Contracts.Assert(size >= 0); switch (type.ItemType.RawKind) { - case DataKind.I1: - return GetComparerVec(r1, r2, col, size, (x, y) => x.RawValue == y.RawValue); - case DataKind.U1: - return GetComparerVec(r1, r2, col, size, (x, y) => x == y); - case DataKind.I2: - return GetComparerVec(r1, r2, col, size, (x, y) => x.RawValue == y.RawValue); - case DataKind.U2: - return GetComparerVec(r1, r2, col, size, (x, y) => x == y); - case DataKind.I4: - return GetComparerVec(r1, r2, col, size, (x, y) => x.RawValue == y.RawValue); - case DataKind.U4: - return GetComparerVec(r1, r2, col, size, (x, y) => x == y); - case DataKind.I8: - return GetComparerVec(r1, r2, col, size, (x, y) => x.RawValue == y.RawValue); - case DataKind.U8: - return GetComparerVec(r1, r2, col, size, (x, y) => x == y); - case DataKind.R4: - return GetComparerVec(r1, r2, col, size, (x, y) => FloatUtils.GetBits(x) == FloatUtils.GetBits(y)); - case DataKind.R8: - if (exactDoubles) - return GetComparerVec(r1, r2, col, size, (x, y) => FloatUtils.GetBits(x) == FloatUtils.GetBits(y)); - else - return GetComparerVec(r1, r2, col, size, EqualWithEps); - case DataKind.Text: - return GetComparerVec(r1, r2, col, size, DvText.Identical); - case DataKind.Bool: - return GetComparerVec(r1, r2, col, size, (x, y) => x.Equals(y)); - case DataKind.TimeSpan: - return GetComparerVec(r1, r2, col, size, (x, y) => x.Equals(y)); - case DataKind.DT: - return GetComparerVec(r1, r2, col, size, (x, y) => x.Equals(y)); - case DataKind.DZ: - return GetComparerVec(r1, r2, col, size, (x, y) => x.Equals(y)); - case DataKind.UG: - return GetComparerVec(r1, r2, col, size, (x, y) => x.Equals(y)); + case DataKind.I1: + return GetComparerVec(r1, r2, col, size, (x, y) => x.RawValue == y.RawValue); + case DataKind.U1: + return GetComparerVec(r1, r2, col, size, (x, y) => x == y); + case DataKind.I2: + return 
GetComparerVec(r1, r2, col, size, (x, y) => x.RawValue == y.RawValue); + case DataKind.U2: + return GetComparerVec(r1, r2, col, size, (x, y) => x == y); + case DataKind.I4: + return GetComparerVec(r1, r2, col, size, (x, y) => x.RawValue == y.RawValue); + case DataKind.U4: + return GetComparerVec(r1, r2, col, size, (x, y) => x == y); + case DataKind.I8: + return GetComparerVec(r1, r2, col, size, (x, y) => x.RawValue == y.RawValue); + case DataKind.U8: + return GetComparerVec(r1, r2, col, size, (x, y) => x == y); + case DataKind.R4: + return GetComparerVec(r1, r2, col, size, (x, y) => FloatUtils.GetBits(x) == FloatUtils.GetBits(y)); + case DataKind.R8: + if (exactDoubles) + return GetComparerVec(r1, r2, col, size, (x, y) => FloatUtils.GetBits(x) == FloatUtils.GetBits(y)); + else + return GetComparerVec(r1, r2, col, size, EqualWithEps); + case DataKind.Text: + return GetComparerVec(r1, r2, col, size, DvText.Identical); + case DataKind.Bool: + return GetComparerVec(r1, r2, col, size, (x, y) => x.Equals(y)); + case DataKind.TimeSpan: + return GetComparerVec(r1, r2, col, size, (x, y) => x.Equals(y)); + case DataKind.DT: + return GetComparerVec(r1, r2, col, size, (x, y) => x.Equals(y)); + case DataKind.DZ: + return GetComparerVec(r1, r2, col, size, (x, y) => x.Equals(y)); + case DataKind.UG: + return GetComparerVec(r1, r2, col, size, (x, y) => x.Equals(y)); } } diff --git a/test/Microsoft.ML.Tests/ImagesTests.cs b/test/Microsoft.ML.Tests/ImagesTests.cs index fe64f75bc4..39087b9606 100644 --- a/test/Microsoft.ML.Tests/ImagesTests.cs +++ b/test/Microsoft.ML.Tests/ImagesTests.cs @@ -6,6 +6,7 @@ using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.ImageAnalytics; using Microsoft.ML.Runtime.Model; +using Microsoft.ML.Runtime.RunTests; using Microsoft.ML.TestFramework; using System.Drawing; using System.IO; @@ -15,12 +16,32 @@ namespace Microsoft.ML.Tests { - public class ImageTests : BaseTestClass + public class ImageTests : TestDataPipeBase { public 
ImageTests(ITestOutputHelper output) : base(output) { } + [Fact] + public void TestEstimatorChain() + { + using (var env = new TlcEnvironment()) + { + var dataFile = GetDataPath("images/images.tsv"); + var imageFolder = Path.GetDirectoryName(dataFile); + var data = env.CreateLoader("Text{col=ImagePath:TX:0 col=Name:TX:1}", new MultiFileSource(dataFile)); + var invalidData = env.CreateLoader("Text{col=ImagePath:R4:0}", new MultiFileSource(dataFile)); + + var pipe = new ImageLoaderEstimator(env, imageFolder, ("ImagePath", "ImageReal")) + .Append(new ImageResizerEstimator(env, "ImageReal", "ImageReal", 100, 100)) + .Append(new ImagePixelExtractorEstimator(env, "ImageReal", "ImagePixels")) + .Append(new ImageGrayscaleEstimator(env, ("ImageReal", "ImageGray"))); + + TestEstimatorCore(pipe, data, null, invalidData); + } + Done(); + } + [Fact] public void TestEstimatorSaveLoad() { @@ -51,6 +72,7 @@ public void TestEstimatorSaveLoad() .All(x => x)); } } + Done(); } [Fact] @@ -95,6 +117,7 @@ public void TestSaveImages() } } } + Done(); } [Fact] @@ -157,6 +180,7 @@ public void TestGreyscaleTransformImages() } } } + Done(); } [Fact] @@ -231,6 +255,7 @@ public void TestBackAndForthConversion() } } } + Done(); } } } From d8c0648083bc48d8acbfdb809ffa4a526dc48988 Mon Sep 17 00:00:00 2001 From: Pete Luferenko Date: Wed, 29 Aug 2018 08:20:33 -0700 Subject: [PATCH 15/17] Minor changes --- .../Transforms/OneToOneTransformerBase.cs | 6 +++--- .../DataPipe/TestDataPipeBase.cs | 17 +++++++++++------ 2 files changed, 14 insertions(+), 9 deletions(-) diff --git a/src/Microsoft.ML.Data/Transforms/OneToOneTransformerBase.cs b/src/Microsoft.ML.Data/Transforms/OneToOneTransformerBase.cs index 0fa15bebca..b301eac793 100644 --- a/src/Microsoft.ML.Data/Transforms/OneToOneTransformerBase.cs +++ b/src/Microsoft.ML.Data/Transforms/OneToOneTransformerBase.cs @@ -10,7 +10,7 @@ namespace Microsoft.ML.Runtime.Data { - public abstract class OneToOneTransformerBase: ITransformer, ICanSaveModel + public 
abstract class OneToOneTransformerBase : ITransformer, ICanSaveModel { protected readonly IHost Host; protected readonly (string input, string output)[] ColumnPairs; @@ -109,7 +109,7 @@ protected RowToRowMapperTransform MakeDataTransform(IDataView input) return new RowToRowMapperTransform(Host, input, MakeRowMapper(input.Schema)); } - protected abstract class MapperBase: IRowMapper + protected abstract class MapperBase : IRowMapper { protected readonly IHost Host; protected readonly Dictionary ColMapNewToOld; @@ -172,6 +172,6 @@ public Delegate[] CreateGetters(IRow input, Func activeOutput, out Ac } protected abstract Delegate MakeGetter(IRow input, int iinfo, out Action disposer); - } } + } } diff --git a/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipeBase.cs b/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipeBase.cs index b1ede303d8..d189e627f5 100644 --- a/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipeBase.cs +++ b/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipeBase.cs @@ -41,12 +41,8 @@ protected void TestEstimatorCore(IEstimator estimator, action(); Assert.False(true); } - catch (ArgumentOutOfRangeException) - { - } - catch (InvalidOperationException) - { - } + catch (ArgumentOutOfRangeException) { } + catch (InvalidOperationException) { } }; // Schema propagation tests for estimator. @@ -78,11 +74,20 @@ protected void TestEstimatorCore(IEstimator estimator, Action checkOnData = (IDataView data) => { var schema = transformer.GetOutputSchema(data.Schema); + + // Loaded transformer needs to have the same schema propagation. CheckSameSchemas(schema, loadedTransformer.GetOutputSchema(data.Schema)); + var scoredTrain = transformer.Transform(data); var scoredTrain2 = loadedTransformer.Transform(data); + + // The schema of the transformed data must match the schema provided by schema propagation. 
CheckSameSchemas(schema, scoredTrain.Schema); + // The schema and data of scored dataset must be identical between loaded + // and original transformer. + // This in turn means that the schema of loaded transformer matches for + // Transform and GetOutputSchema calls. CheckSameSchemas(scoredTrain.Schema, scoredTrain2.Schema); CheckSameValues(scoredTrain, scoredTrain2); }; From 3e845fd13556d90a6a28e765060836d7dd45f0fd Mon Sep 17 00:00:00 2001 From: Pete Luferenko Date: Thu, 30 Aug 2018 09:28:26 -0700 Subject: [PATCH 16/17] PR comments --- src/Microsoft.ML.ImageAnalytics/ImagePixelExtractorTransform.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Microsoft.ML.ImageAnalytics/ImagePixelExtractorTransform.cs b/src/Microsoft.ML.ImageAnalytics/ImagePixelExtractorTransform.cs index eb3048d799..0bd5bf7879 100644 --- a/src/Microsoft.ML.ImageAnalytics/ImagePixelExtractorTransform.cs +++ b/src/Microsoft.ML.ImageAnalytics/ImagePixelExtractorTransform.cs @@ -33,7 +33,7 @@ namespace Microsoft.ML.Runtime.ImageAnalytics /// /// Transform which takes one or many columns of and convert them into vector representation. 
/// - public sealed class ImagePixelExtractorTransform : OneToOneTransformerBase, ICanSaveModel + public sealed class ImagePixelExtractorTransform : OneToOneTransformerBase { public class Column : OneToOneColumn { From e660eebbb6ccdc7b5442f0deee28024b37e92992 Mon Sep 17 00:00:00 2001 From: Pete Luferenko Date: Thu, 30 Aug 2018 12:26:48 -0700 Subject: [PATCH 17/17] Fixed build --- .../ScenariosWithDirectInstantiation/TensorflowTests.cs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs index ed450cd8c0..c04b35669c 100644 --- a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs +++ b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs @@ -185,7 +185,7 @@ public void TensorFlowTransformCifar() var dataFile = GetDataPath("images/images.tsv"); var imageFolder = Path.GetDirectoryName(dataFile); var data = env.CreateLoader("Text{col=ImagePath:TX:0 col=Name:TX:1}", new MultiFileSource(dataFile)); - var images = new ImageLoaderTransform(env, new ImageLoaderTransform.Arguments() + var images = ImageLoaderTransform.Create(env, new ImageLoaderTransform.Arguments() { Column = new ImageLoaderTransform.Column[1] { @@ -193,14 +193,14 @@ public void TensorFlowTransformCifar() }, ImageFolder = imageFolder }, data); - var cropped = new ImageResizerTransform(env, new ImageResizerTransform.Arguments() + var cropped = ImageResizerTransform.Create(env, new ImageResizerTransform.Arguments() { Column = new ImageResizerTransform.Column[1]{ new ImageResizerTransform.Column() { Source = "ImageReal", Name= "ImageCropped", ImageHeight =imageHeight, ImageWidth = imageWidth, Resizing = ImageResizerTransform.ResizingKind.IsoCrop} } }, images); - var pixels = new ImagePixelExtractorTransform(env, new ImagePixelExtractorTransform.Arguments() + var pixels = 
ImagePixelExtractorTransform.Create(env, new ImagePixelExtractorTransform.Arguments() { Column = new ImagePixelExtractorTransform.Column[1]{ new ImagePixelExtractorTransform.Column() { Source= "ImageCropped", Name = "Input", UseAlpha=false, InterleaveArgb=true}