diff --git a/src/Microsoft.ML.Console/Microsoft.ML.Console.csproj b/src/Microsoft.ML.Console/Microsoft.ML.Console.csproj
index f9a1b5ef27..1471c580ba 100644
--- a/src/Microsoft.ML.Console/Microsoft.ML.Console.csproj
+++ b/src/Microsoft.ML.Console/Microsoft.ML.Console.csproj
@@ -16,6 +16,7 @@
+
diff --git a/src/Microsoft.ML.Core/Data/IEstimator.cs b/src/Microsoft.ML.Core/Data/IEstimator.cs
index 6f21a1cb01..509a67bb4f 100644
--- a/src/Microsoft.ML.Core/Data/IEstimator.cs
+++ b/src/Microsoft.ML.Core/Data/IEstimator.cs
@@ -28,20 +28,42 @@ public enum VectorKind
VariableVector
}
+ ///
+ /// The column name.
+ ///
public readonly string Name;
+
+ ///
+ /// The type of the column: scalar, fixed vector or variable vector.
+ ///
public readonly VectorKind Kind;
- public readonly DataKind ItemKind;
+
+ ///
+ /// The 'raw' type of column item: must be a primitive type or a structured type.
+ ///
+ public readonly ColumnType ItemType;
+ ///
+ /// The flag whether the column is actually a key. If yes, is representing
+ /// the underlying primitive type.
+ ///
public readonly bool IsKey;
+ ///
+ /// The metadata kinds that are present for this column.
+ ///
public readonly string[] MetadataKinds;
- public Column(string name, VectorKind vecKind, DataKind itemKind, bool isKey, string[] metadataKinds = null)
+ public Column(string name, VectorKind vecKind, ColumnType itemType, bool isKey, string[] metadataKinds = null)
{
Contracts.CheckNonEmpty(name, nameof(name));
Contracts.CheckValueOrNull(metadataKinds);
+ Contracts.CheckParam(!itemType.IsKey, nameof(itemType), "Item type cannot be a key");
+ Contracts.CheckParam(!itemType.IsVector, nameof(itemType), "Item type cannot be a vector");
+
+ Contracts.CheckParam(!isKey || KeyType.IsValidDataKind(itemType.RawKind), nameof(itemType), "The item type must be valid for a key");
Name = name;
Kind = vecKind;
- ItemKind = itemKind;
+ ItemType = itemType;
IsKey = isKey;
MetadataKinds = metadataKinds ?? new string[0];
}
@@ -51,7 +73,7 @@ public Column(string name, VectorKind vecKind, DataKind itemKind, bool isKey, st
/// requirement.
///
/// Namely, it returns true iff:
- /// - The , , , fields match.
+ /// - The , , , fields match.
/// - The of is a superset of our .
///
public bool IsCompatibleWith(Column inputColumn)
@@ -61,7 +83,7 @@ public bool IsCompatibleWith(Column inputColumn)
return false;
if (Kind != inputColumn.Kind)
return false;
- if (ItemKind != inputColumn.ItemKind)
+ if (!ItemType.Equals(inputColumn.ItemType))
return false;
if (IsKey != inputColumn.IsKey)
return false;
@@ -72,7 +94,7 @@ public bool IsCompatibleWith(Column inputColumn)
public string GetTypeString()
{
- string result = ItemKind.ToString();
+ string result = ItemType.ToString();
if (IsKey)
result = $"Key<{result}>";
if (Kind == VectorKind.Vector)
@@ -110,13 +132,15 @@ public static SchemaShape Create(ISchema schema)
else
vecKind = Column.VectorKind.Scalar;
- var kind = type.ItemType.RawKind;
+ ColumnType itemType = type.ItemType;
+ if (type.ItemType.IsKey)
+ itemType = PrimitiveType.FromKind(type.ItemType.RawKind);
var isKey = type.ItemType.IsKey;
var metadataNames = schema.GetMetadataTypes(iCol)
.Select(kvp => kvp.Key)
.ToArray();
- cols.Add(new Column(schema.GetColumnName(iCol), vecKind, kind, isKey, metadataNames));
+ cols.Add(new Column(schema.GetColumnName(iCol), vecKind, itemType, isKey, metadataNames));
}
}
return new SchemaShape(cols.ToArray());
diff --git a/src/Microsoft.ML.Data/DataLoadSave/TrivialEstimator.cs b/src/Microsoft.ML.Data/DataLoadSave/TrivialEstimator.cs
new file mode 100644
index 0000000000..29c081ac35
--- /dev/null
+++ b/src/Microsoft.ML.Data/DataLoadSave/TrivialEstimator.cs
@@ -0,0 +1,35 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using Microsoft.ML.Core.Data;
+
+namespace Microsoft.ML.Runtime.Data
+{
+ ///
+ /// The trivial implementation of that already has
+ /// the transformer and returns it on every call to .
+ ///
+ /// Concrete implementations still have to provide the schema propagation mechanism, since
+ /// there is no easy way to infer it from the transformer.
+ ///
+ public abstract class TrivialEstimator : IEstimator
+ where TTransformer : class, ITransformer
+ {
+ protected readonly IHost Host;
+ protected readonly TTransformer Transformer;
+
+ protected TrivialEstimator(IHost host, TTransformer transformer)
+ {
+ Contracts.AssertValue(host);
+
+ Host = host;
+ Host.CheckValue(transformer, nameof(transformer));
+ Transformer = transformer;
+ }
+
+ public TTransformer Fit(IDataView input) => Transformer;
+
+ public abstract SchemaShape GetOutputSchema(SchemaShape inputSchema);
+ }
+}
diff --git a/src/Microsoft.ML.Data/Transforms/CopyColumnsTransform.cs b/src/Microsoft.ML.Data/Transforms/CopyColumnsTransform.cs
index 184c0226bb..3c47f84faa 100644
--- a/src/Microsoft.ML.Data/Transforms/CopyColumnsTransform.cs
+++ b/src/Microsoft.ML.Data/Transforms/CopyColumnsTransform.cs
@@ -70,7 +70,7 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema)
var originalColumn = inputSchema.FindColumn(column.Source);
if (originalColumn != null)
{
- var col = new SchemaShape.Column(column.Name, originalColumn.Kind, originalColumn.ItemKind, originalColumn.IsKey, originalColumn.MetadataKinds);
+ var col = new SchemaShape.Column(column.Name, originalColumn.Kind, originalColumn.ItemType, originalColumn.IsKey, originalColumn.MetadataKinds);
resultDic[column.Name] = col;
}
else
diff --git a/src/Microsoft.ML.Data/Transforms/OneToOneTransformerBase.cs b/src/Microsoft.ML.Data/Transforms/OneToOneTransformerBase.cs
new file mode 100644
index 0000000000..b301eac793
--- /dev/null
+++ b/src/Microsoft.ML.Data/Transforms/OneToOneTransformerBase.cs
@@ -0,0 +1,177 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using Microsoft.ML.Core.Data;
+using Microsoft.ML.Runtime.Model;
+
+namespace Microsoft.ML.Runtime.Data
+{
+ public abstract class OneToOneTransformerBase : ITransformer, ICanSaveModel
+ {
+ protected readonly IHost Host;
+ protected readonly (string input, string output)[] ColumnPairs;
+
+ protected OneToOneTransformerBase(IHost host, (string input, string output)[] columns)
+ {
+ Contracts.AssertValue(host);
+ host.CheckValue(columns, nameof(columns));
+
+ var newNames = new HashSet();
+ foreach (var column in columns)
+ {
+ host.CheckNonEmpty(column.input, nameof(columns));
+ host.CheckNonEmpty(column.output, nameof(columns));
+
+ if (!newNames.Add(column.output))
+ throw Contracts.ExceptParam(nameof(columns), $"Output column '{column.output}' specified multiple times");
+ }
+
+ Host = host;
+ ColumnPairs = columns;
+ }
+
+ protected OneToOneTransformerBase(IHost host, ModelLoadContext ctx)
+ {
+ Host = host;
+ // *** Binary format ***
+ // int: number of added columns
+ // for each added column
+ // int: id of output column name
+ // int: id of input column name
+
+ int n = ctx.Reader.ReadInt32();
+ ColumnPairs = new (string input, string output)[n];
+ for (int i = 0; i < n; i++)
+ {
+ string output = ctx.LoadNonEmptyString();
+ string input = ctx.LoadNonEmptyString();
+ ColumnPairs[i] = (input, output);
+ }
+ }
+
+ public abstract void Save(ModelSaveContext ctx);
+
+ protected void SaveColumns(ModelSaveContext ctx)
+ {
+ Host.CheckValue(ctx, nameof(ctx));
+
+ // *** Binary format ***
+ // int: number of added columns
+ // for each added column
+ // int: id of output column name
+ // int: id of input column name
+
+ ctx.Writer.Write(ColumnPairs.Length);
+ for (int i = 0; i < ColumnPairs.Length; i++)
+ {
+ ctx.SaveNonEmptyString(ColumnPairs[i].output);
+ ctx.SaveNonEmptyString(ColumnPairs[i].input);
+ }
+ }
+
+ private void CheckInput(ISchema inputSchema, int col, out int srcCol)
+ {
+ Contracts.AssertValue(inputSchema);
+ Contracts.Assert(0 <= col && col < ColumnPairs.Length);
+
+ if (!inputSchema.TryGetColumnIndex(ColumnPairs[col].input, out srcCol))
+ throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", ColumnPairs[col].input);
+ CheckInputColumn(inputSchema, col, srcCol);
+ }
+
+ protected virtual void CheckInputColumn(ISchema inputSchema, int col, int srcCol)
+ {
+ // By default, there are no extra checks.
+ }
+
+ protected abstract IRowMapper MakeRowMapper(ISchema schema);
+
+ public ISchema GetOutputSchema(ISchema inputSchema)
+ {
+ Host.CheckValue(inputSchema, nameof(inputSchema));
+
+ // Check that all the input columns are present and correct.
+ for (int i = 0; i < ColumnPairs.Length; i++)
+ CheckInput(inputSchema, i, out int col);
+
+ return Transform(new EmptyDataView(Host, inputSchema)).Schema;
+ }
+
+ public IDataView Transform(IDataView input) => MakeDataTransform(input);
+
+ protected RowToRowMapperTransform MakeDataTransform(IDataView input)
+ {
+ Host.CheckValue(input, nameof(input));
+ return new RowToRowMapperTransform(Host, input, MakeRowMapper(input.Schema));
+ }
+
+ protected abstract class MapperBase : IRowMapper
+ {
+ protected readonly IHost Host;
+ protected readonly Dictionary ColMapNewToOld;
+ protected readonly ISchema InputSchema;
+ private readonly OneToOneTransformerBase _parent;
+
+ protected MapperBase(IHost host, OneToOneTransformerBase parent, ISchema inputSchema)
+ {
+ Contracts.AssertValue(host);
+ Contracts.AssertValue(parent);
+ Contracts.AssertValue(inputSchema);
+
+ Host = host;
+ _parent = parent;
+
+ ColMapNewToOld = new Dictionary();
+ for (int i = 0; i < _parent.ColumnPairs.Length; i++)
+ {
+ _parent.CheckInput(inputSchema, i, out int srcCol);
+ ColMapNewToOld.Add(i, srcCol);
+ }
+ InputSchema = inputSchema;
+ }
+ public Func GetDependencies(Func activeOutput)
+ {
+ var active = new bool[InputSchema.ColumnCount];
+ foreach (var pair in ColMapNewToOld)
+ if (activeOutput(pair.Key))
+ active[pair.Value] = true;
+ return col => active[col];
+ }
+
+ public abstract RowMapperColumnInfo[] GetOutputColumns();
+
+ public void Save(ModelSaveContext ctx) => _parent.Save(ctx);
+
+ public Delegate[] CreateGetters(IRow input, Func activeOutput, out Action disposer)
+ {
+ Contracts.Assert(input.Schema == InputSchema);
+ var result = new Delegate[_parent.ColumnPairs.Length];
+ var disposers = new Action[_parent.ColumnPairs.Length];
+ for (int i = 0; i < _parent.ColumnPairs.Length; i++)
+ {
+ if (!activeOutput(i))
+ continue;
+ int srcCol = ColMapNewToOld[i];
+ result[i] = MakeGetter(input, i, out disposers[i]);
+ }
+ if (disposers.Any(x => x != null))
+ {
+ disposer = () =>
+ {
+ foreach (var act in disposers)
+ act();
+ };
+ }
+ else
+ disposer = null;
+ return result;
+ }
+
+ protected abstract Delegate MakeGetter(IRow input, int iinfo, out Action disposer);
+ }
+ }
+}
diff --git a/src/Microsoft.ML.ImageAnalytics/EntryPoints/ImageAnalytics.cs b/src/Microsoft.ML.ImageAnalytics/EntryPoints/ImageAnalytics.cs
index 97c613485f..921309d7ef 100644
--- a/src/Microsoft.ML.ImageAnalytics/EntryPoints/ImageAnalytics.cs
+++ b/src/Microsoft.ML.ImageAnalytics/EntryPoints/ImageAnalytics.cs
@@ -16,7 +16,7 @@ public static class ImageAnalytics
public static CommonOutputs.TransformOutput ImageLoader(IHostEnvironment env, ImageLoaderTransform.Arguments input)
{
var h = EntryPointUtils.CheckArgsAndCreateHost(env, "ImageLoaderTransform", input);
- var xf = new ImageLoaderTransform(h, input, input.Data);
+ var xf = ImageLoaderTransform.Create(h, input, input.Data);
return new CommonOutputs.TransformOutput()
{
Model = new TransformModel(h, xf, input.Data),
@@ -29,7 +29,7 @@ public static CommonOutputs.TransformOutput ImageLoader(IHostEnvironment env, Im
public static CommonOutputs.TransformOutput ImageResizer(IHostEnvironment env, ImageResizerTransform.Arguments input)
{
var h = EntryPointUtils.CheckArgsAndCreateHost(env, "ImageResizerTransform", input);
- var xf = new ImageResizerTransform(h, input, input.Data);
+ var xf = ImageResizerTransform.Create(h, input, input.Data);
return new CommonOutputs.TransformOutput()
{
Model = new TransformModel(h, xf, input.Data),
@@ -42,7 +42,7 @@ public static CommonOutputs.TransformOutput ImageResizer(IHostEnvironment env, I
public static CommonOutputs.TransformOutput ImagePixelExtractor(IHostEnvironment env, ImagePixelExtractorTransform.Arguments input)
{
var h = EntryPointUtils.CheckArgsAndCreateHost(env, "ImagePixelExtractorTransform", input);
- var xf = new ImagePixelExtractorTransform(h, input, input.Data);
+ var xf = ImagePixelExtractorTransform.Create(h, input, input.Data);
return new CommonOutputs.TransformOutput()
{
Model = new TransformModel(h, xf, input.Data),
@@ -55,7 +55,7 @@ public static CommonOutputs.TransformOutput ImagePixelExtractor(IHostEnvironment
public static CommonOutputs.TransformOutput ImageGrayscale(IHostEnvironment env, ImageGrayscaleTransform.Arguments input)
{
var h = EntryPointUtils.CheckArgsAndCreateHost(env, "ImageGrayscaleTransform", input);
- var xf = new ImageGrayscaleTransform(h, input, input.Data);
+ var xf = ImageGrayscaleTransform.Create(h, input, input.Data);
return new CommonOutputs.TransformOutput()
{
Model = new TransformModel(h, xf, input.Data),
diff --git a/src/Microsoft.ML.ImageAnalytics/ImageGrayscaleTransform.cs b/src/Microsoft.ML.ImageAnalytics/ImageGrayscaleTransform.cs
index 7a267cf1b8..21076b2d86 100644
--- a/src/Microsoft.ML.ImageAnalytics/ImageGrayscaleTransform.cs
+++ b/src/Microsoft.ML.ImageAnalytics/ImageGrayscaleTransform.cs
@@ -2,22 +2,31 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
-using System;
-using System.Drawing;
-using System.Drawing.Imaging;
-using System.Text;
+using Microsoft.ML.Core.Data;
using Microsoft.ML.Runtime;
using Microsoft.ML.Runtime.CommandLine;
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.EntryPoints;
+using Microsoft.ML.Runtime.ImageAnalytics;
using Microsoft.ML.Runtime.Internal.Utilities;
using Microsoft.ML.Runtime.Model;
-using Microsoft.ML.Runtime.ImageAnalytics;
+using System;
+using System.Collections.Generic;
+using System.Drawing;
+using System.Drawing.Imaging;
+using System.Linq;
+using System.Text;
-[assembly: LoadableClass(ImageGrayscaleTransform.Summary, typeof(ImageGrayscaleTransform), typeof(ImageGrayscaleTransform.Arguments), typeof(SignatureDataTransform),
+[assembly: LoadableClass(ImageGrayscaleTransform.Summary, typeof(IDataTransform), typeof(ImageGrayscaleTransform), typeof(ImageGrayscaleTransform.Arguments), typeof(SignatureDataTransform),
ImageGrayscaleTransform.UserName, "ImageGrayscaleTransform", "ImageGrayscale")]
-[assembly: LoadableClass(ImageGrayscaleTransform.Summary, typeof(ImageGrayscaleTransform), null, typeof(SignatureLoadDataTransform),
+[assembly: LoadableClass(ImageGrayscaleTransform.Summary, typeof(IDataTransform), typeof(ImageGrayscaleTransform), null, typeof(SignatureLoadDataTransform),
+ ImageGrayscaleTransform.UserName, ImageGrayscaleTransform.LoaderSignature)]
+
+[assembly: LoadableClass(typeof(ImageGrayscaleTransform), null, typeof(SignatureLoadModel),
+ ImageGrayscaleTransform.UserName, ImageGrayscaleTransform.LoaderSignature)]
+
+[assembly: LoadableClass(typeof(IRowMapper), typeof(ImageGrayscaleTransform), null, typeof(SignatureLoadRowMapper),
ImageGrayscaleTransform.UserName, ImageGrayscaleTransform.LoaderSignature)]
namespace Microsoft.ML.Runtime.ImageAnalytics
@@ -28,7 +37,7 @@ namespace Microsoft.ML.Runtime.ImageAnalytics
/// Transform which takes one or many columns of type in IDataView and
/// convert them to greyscale representation of the same image.
///
- public sealed class ImageGrayscaleTransform : OneToOneTransformBase
+ public sealed class ImageGrayscaleTransform : OneToOneTransformerBase
{
public sealed class Column : OneToOneColumn
{
@@ -69,50 +78,57 @@ private static VersionInfo GetVersionInfo()
private const string RegistrationName = "ImageGrayscale";
- // Public constructor corresponding to SignatureDataTransform.
- public ImageGrayscaleTransform(IHostEnvironment env, Arguments args, IDataView input)
- : base(env, RegistrationName, env.CheckRef(args, nameof(args)).Column, input, t => t is ImageType ? null : "Expected Image type")
+ public IReadOnlyCollection<(string input, string output)> Columns => ColumnPairs.AsReadOnly();
+
+ public ImageGrayscaleTransform(IHostEnvironment env, params (string input, string output)[] columns)
+ : base(Contracts.CheckRef(env, nameof(env)).Register(RegistrationName), columns)
{
- Host.AssertNonEmpty(Infos);
- Host.Assert(Infos.Length == Utils.Size(args.Column));
- Metadata.Seal();
}
- private ImageGrayscaleTransform(IHost host, ModelLoadContext ctx, IDataView input)
- : base(host, ctx, input, t => t is ImageType ? null : "Expected Image type")
+ // Factory method for SignatureDataTransform.
+ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input)
{
- Host.AssertValue(ctx);
- // *** Binary format ***
- //
- Host.AssertNonEmpty(Infos);
- Metadata.Seal();
+ Contracts.CheckValue(env, nameof(env));
+ env.CheckValue(args, nameof(args));
+ env.CheckValue(input, nameof(input));
+ env.CheckValue(args.Column, nameof(args.Column));
+
+ return new ImageGrayscaleTransform(env, args.Column.Select(x => (x.Source ?? x.Name, x.Name)).ToArray())
+ .MakeDataTransform(input);
}
- public static ImageGrayscaleTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input)
+ public static ImageGrayscaleTransform Create(IHostEnvironment env, ModelLoadContext ctx)
{
Contracts.CheckValue(env, nameof(env));
- var h = env.Register(RegistrationName);
- h.CheckValue(ctx, nameof(ctx));
- h.CheckValue(input, nameof(input));
+ var host = env.Register(RegistrationName);
+ host.CheckValue(ctx, nameof(ctx));
ctx.CheckAtModel(GetVersionInfo());
- return h.Apply("Loading Model", ch => new ImageGrayscaleTransform(h, ctx, input));
+ return new ImageGrayscaleTransform(host, ctx);
}
+ private ImageGrayscaleTransform(IHost host, ModelLoadContext ctx)
+ : base(host, ctx)
+ {
+ }
+
+ // Factory method for SignatureLoadDataTransform.
+ public static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input)
+ => Create(env, ctx).MakeDataTransform(input);
+
+ // Factory method for SignatureLoadRowMapper.
+ public static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, ISchema inputSchema)
+ => Create(env, ctx).MakeRowMapper(inputSchema);
+
public override void Save(ModelSaveContext ctx)
{
Host.CheckValue(ctx, nameof(ctx));
+
ctx.CheckAtModel();
ctx.SetVersionInfo(GetVersionInfo());
// *** Binary format ***
//
- SaveBase(ctx);
- }
-
- protected override ColumnType GetColumnTypeCore(int iinfo)
- {
- Host.Assert(0 <= iinfo & iinfo < Infos.Length);
- return Infos[iinfo].TypeSrc;
+ base.SaveColumns(ctx);
}
private static readonly ColorMatrix _grayscaleColorMatrix = new ColorMatrix(
@@ -125,47 +141,96 @@ protected override ColumnType GetColumnTypeCore(int iinfo)
new float[] {0, 0, 0, 0, 1}
});
- protected override Delegate GetGetterCore(IChannel ch, IRow input, int iinfo, out Action disposer)
+ protected override IRowMapper MakeRowMapper(ISchema schema)
+ => new Mapper(this, schema);
+
+ protected override void CheckInputColumn(ISchema inputSchema, int col, int srcCol)
{
- Host.AssertValueOrNull(ch);
- Host.AssertValue(input);
- Host.Assert(0 <= iinfo && iinfo < Infos.Length);
+ if (!(inputSchema.GetColumnType(srcCol) is ImageType))
+ throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", ColumnPairs[col].input, "image", inputSchema.GetColumnType(srcCol).ToString());
+ }
- var src = default(Bitmap);
- var getSrc = GetSrcGetter(input, iinfo);
+ private sealed class Mapper : MapperBase
+ {
+ private ImageGrayscaleTransform _parent;
- disposer =
- () =>
- {
- if (src != null)
- {
- src.Dispose();
- src = null;
- }
- };
+ public Mapper(ImageGrayscaleTransform parent, ISchema inputSchema)
+ :base(parent.Host.Register(nameof(Mapper)), parent, inputSchema)
+ {
+ _parent = parent;
+ }
- ValueGetter del =
- (ref Bitmap dst) =>
- {
- if (dst != null)
- dst.Dispose();
-
- getSrc(ref src);
- if (src == null || src.Height <= 0 || src.Width <= 0)
- return;
-
- dst = new Bitmap(src.Width, src.Height);
- ImageAttributes attributes = new ImageAttributes();
- attributes.SetColorMatrix(_grayscaleColorMatrix);
- var srcRectangle = new Rectangle(0, 0, src.Width, src.Height);
- using (var g = Graphics.FromImage(dst))
+ public override RowMapperColumnInfo[] GetOutputColumns()
+ => _parent.ColumnPairs.Select((x, idx) => new RowMapperColumnInfo(x.output, InputSchema.GetColumnType(ColMapNewToOld[idx]), null)).ToArray();
+
+ protected override Delegate MakeGetter(IRow input, int iinfo, out Action disposer)
+ {
+ Contracts.AssertValue(input);
+ Contracts.Assert(0 <= iinfo && iinfo < _parent.ColumnPairs.Length);
+
+ var src = default(Bitmap);
+ var getSrc = input.GetGetter(ColMapNewToOld[iinfo]);
+
+ disposer =
+ () =>
+ {
+ if (src != null)
+ {
+ src.Dispose();
+ src = null;
+ }
+ };
+
+ ValueGetter del =
+ (ref Bitmap dst) =>
{
- g.DrawImage(src, srcRectangle, 0, 0, src.Width, src.Height, GraphicsUnit.Pixel, attributes);
- }
- Host.Assert(dst.Width == src.Width && dst.Height == src.Height);
- };
+ if (dst != null)
+ dst.Dispose();
+
+ getSrc(ref src);
+ if (src == null || src.Height <= 0 || src.Width <= 0)
+ return;
+
+ dst = new Bitmap(src.Width, src.Height);
+ ImageAttributes attributes = new ImageAttributes();
+ attributes.SetColorMatrix(_grayscaleColorMatrix);
+ var srcRectangle = new Rectangle(0, 0, src.Width, src.Height);
+ using (var g = Graphics.FromImage(dst))
+ {
+ g.DrawImage(src, srcRectangle, 0, 0, src.Width, src.Height, GraphicsUnit.Pixel, attributes);
+ }
+ Contracts.Assert(dst.Width == src.Width && dst.Height == src.Height);
+ };
+
+ return del;
+ }
+ }
+ }
+
+ public sealed class ImageGrayscaleEstimator : TrivialEstimator
+ {
+ public ImageGrayscaleEstimator(IHostEnvironment env, params (string input, string output)[] columns)
+ : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(ImageGrayscaleEstimator)), new ImageGrayscaleTransform(env, columns))
+ {
+ }
+
+ public override SchemaShape GetOutputSchema(SchemaShape inputSchema)
+ {
+ Host.CheckValue(inputSchema, nameof(inputSchema));
+ var result = inputSchema.Columns.ToDictionary(x => x.Name);
+ foreach (var colInfo in Transformer.Columns)
+ {
+ var col = inputSchema.FindColumn(colInfo.input);
+
+ if (col == null)
+ throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.input);
+ if (!(col.ItemType is ImageType) || col.Kind != SchemaShape.Column.VectorKind.Scalar)
+ throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.input, new ImageType().ToString(), col.GetTypeString());
+
+ result[colInfo.output] = new SchemaShape.Column(colInfo.output, col.Kind, col.ItemType, col.IsKey, col.MetadataKinds);
+ }
- return del;
+ return new SchemaShape(result.Values);
}
}
}
diff --git a/src/Microsoft.ML.ImageAnalytics/ImageLoaderTransform.cs b/src/Microsoft.ML.ImageAnalytics/ImageLoaderTransform.cs
index 488c710743..20e1476feb 100644
--- a/src/Microsoft.ML.ImageAnalytics/ImageLoaderTransform.cs
+++ b/src/Microsoft.ML.ImageAnalytics/ImageLoaderTransform.cs
@@ -2,31 +2,37 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
-using System;
-using System.Drawing;
-using System.IO;
-using System.Text;
-using Microsoft.ML.Runtime.ImageAnalytics;
+using Microsoft.ML.Core.Data;
using Microsoft.ML.Runtime;
using Microsoft.ML.Runtime.CommandLine;
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.EntryPoints;
+using Microsoft.ML.Runtime.ImageAnalytics;
using Microsoft.ML.Runtime.Internal.Utilities;
using Microsoft.ML.Runtime.Model;
+using System;
+using System.Collections.Generic;
+using System.Drawing;
+using System.IO;
+using System.Linq;
+using System.Text;
-[assembly: LoadableClass(ImageLoaderTransform.Summary, typeof(ImageLoaderTransform), typeof(ImageLoaderTransform.Arguments), typeof(SignatureDataTransform),
+[assembly: LoadableClass(ImageLoaderTransform.Summary, typeof(IDataTransform), typeof(ImageLoaderTransform), typeof(ImageLoaderTransform.Arguments), typeof(SignatureDataTransform),
ImageLoaderTransform.UserName, "ImageLoaderTransform", "ImageLoader")]
-[assembly: LoadableClass(ImageLoaderTransform.Summary, typeof(ImageLoaderTransform), null, typeof(SignatureLoadDataTransform),
+[assembly: LoadableClass(ImageLoaderTransform.Summary, typeof(IDataTransform), typeof(ImageLoaderTransform), null, typeof(SignatureLoadDataTransform),
ImageLoaderTransform.UserName, ImageLoaderTransform.LoaderSignature)]
+[assembly: LoadableClass(typeof(ImageLoaderTransform), null, typeof(SignatureLoadModel), "", ImageLoaderTransform.LoaderSignature)]
+
+[assembly: LoadableClass(typeof(IRowMapper), typeof(ImageLoaderTransform), null, typeof(SignatureLoadRowMapper), "", ImageLoaderTransform.LoaderSignature)]
+
namespace Microsoft.ML.Runtime.ImageAnalytics
{
- // REVIEW: Rewrite as LambdaTransform to simplify.
///
/// Transform which takes one or many columns of type and loads them as
///
- public sealed class ImageLoaderTransform : OneToOneTransformBase
+ public sealed class ImageLoaderTransform : OneToOneTransformerBase
{
public sealed class Column : OneToOneColumn
{
@@ -61,118 +67,177 @@ public sealed class Arguments : TransformInputBase
internal const string UserName = "Image Loader Transform";
public const string LoaderSignature = "ImageLoaderTransform";
- private static VersionInfo GetVersionInfo()
- {
- return new VersionInfo(
- modelSignature: "IMGLOADT",
- //verWrittenCur: 0x00010001, // Initial
- verWrittenCur: 0x00010002, // Swith from OpenCV to Bitmap
- verReadableCur: 0x00010002,
- verWeCanReadBack: 0x00010002,
- loaderSignature: LoaderSignature);
- }
+ public readonly string ImageFolder;
- private readonly ImageType _type;
- private readonly string _imageFolder;
+ public IReadOnlyCollection<(string input, string output)> Columns => ColumnPairs.AsReadOnly();
- private const string RegistrationName = "ImageLoader";
+ public ImageLoaderTransform(IHostEnvironment env, string imageFolder, params (string input, string output)[] columns)
+ : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(ImageLoaderTransform)), columns)
+ {
+ ImageFolder = imageFolder;
+ }
- // Public constructor corresponding to SignatureDataTransform.
- public ImageLoaderTransform(IHostEnvironment env, Arguments args, IDataView input)
- : base(env, RegistrationName, env.CheckRef(args, nameof(args)).Column, input, TestIsText)
+ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView data)
{
- Host.AssertNonEmpty(Infos);
- _imageFolder = args.ImageFolder;
- Host.Assert(Infos.Length == Utils.Size(args.Column));
- _type = new ImageType();
- Metadata.Seal();
+ return new ImageLoaderTransform(env, args.ImageFolder, args.Column.Select(x => (x.Source ?? x.Name, x.Name)).ToArray())
+ .MakeDataTransform(data);
}
- private ImageLoaderTransform(IHost host, ModelLoadContext ctx, IDataView input)
- : base(host, ctx, input, TestIsText)
+ public static ImageLoaderTransform Create(IHostEnvironment env, ModelLoadContext ctx)
{
- Host.AssertValue(ctx);
+ Contracts.CheckValue(env, nameof(env));
+ env.CheckValue(ctx, nameof(ctx));
+
+ ctx.CheckAtModel(GetVersionInfo());
+ return new ImageLoaderTransform(env.Register(nameof(ImageLoaderTransform)), ctx);
+ }
+ private ImageLoaderTransform(IHost host, ModelLoadContext ctx)
+ : base(host, ctx)
+ {
// *** Binary format ***
//
- _imageFolder = ctx.Reader.ReadString();
- _type = new ImageType();
- Metadata.Seal();
+ // int: id of image folder
+
+ ImageFolder = ctx.LoadStringOrNull();
}
- public static ImageLoaderTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input)
+ // Factory method for SignatureLoadDataTransform.
+ public static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input)
+ => Create(env, ctx).MakeDataTransform(input);
+
+ // Factory method for SignatureLoadRowMapper.
+ public static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, ISchema inputSchema)
+ => Create(env, ctx).MakeRowMapper(inputSchema);
+
+ protected override void CheckInputColumn(ISchema inputSchema, int col, int srcCol)
{
- Contracts.CheckValue(env, nameof(env));
- var h = env.Register(RegistrationName);
- h.CheckValue(ctx, nameof(ctx));
- h.CheckValue(input, nameof(input));
- ctx.CheckAtModel(GetVersionInfo());
- return h.Apply("Loading Model", ch => new ImageLoaderTransform(h, ctx, input));
+ if (!inputSchema.GetColumnType(srcCol).IsText)
+ throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", ColumnPairs[col].input, TextType.Instance.ToString(), inputSchema.GetColumnType(srcCol).ToString());
}
public override void Save(ModelSaveContext ctx)
{
Host.CheckValue(ctx, nameof(ctx));
+
ctx.CheckAtModel();
ctx.SetVersionInfo(GetVersionInfo());
// *** Binary format ***
//
- ctx.Writer.Write(_imageFolder);
- SaveBase(ctx);
+ // int: id of image folder
+
+ base.SaveColumns(ctx);
+ ctx.SaveStringOrNull(ImageFolder);
}
- protected override ColumnType GetColumnTypeCore(int iinfo)
+ private static VersionInfo GetVersionInfo()
{
- Host.Check(0 <= iinfo && iinfo < Infos.Length);
- return _type;
+ return new VersionInfo(
+ modelSignature: "IMGLOADR",
+ //verWrittenCur: 0x00010001, // Initial
+ verWrittenCur: 0x00010002, // Swith from OpenCV to Bitmap
+ verReadableCur: 0x00010002,
+ verWeCanReadBack: 0x00010002,
+ loaderSignature: LoaderSignature);
}
- protected override Delegate GetGetterCore(IChannel ch, IRow input, int iinfo, out Action disposer)
+ protected override IRowMapper MakeRowMapper(ISchema schema)
+ => new Mapper(this, schema);
+
+ private sealed class Mapper : MapperBase
{
- Host.AssertValue(ch, nameof(ch));
- Host.AssertValue(input);
- Host.Assert(0 <= iinfo && iinfo < Infos.Length);
- disposer = null;
-
- var getSrc = GetSrcGetter(input, iinfo);
- DvText src = default;
- ValueGetter del =
- (ref Bitmap dst) =>
- {
- if (dst != null)
- {
- dst.Dispose();
- dst = null;
- }
+ private readonly ImageLoaderTransform _parent;
+ private readonly ImageType _imageType;
- getSrc(ref src);
+ public Mapper(ImageLoaderTransform parent, ISchema inputSchema)
+ : base(parent.Host.Register(nameof(Mapper)), parent, inputSchema)
+ {
+ _imageType = new ImageType();
+ _parent = parent;
+ }
- if (src.Length > 0)
+ protected override Delegate MakeGetter(IRow input, int iinfo, out Action disposer)
+ {
+ Contracts.AssertValue(input);
+ Contracts.Assert(0 <= iinfo && iinfo < _parent.ColumnPairs.Length);
+
+ disposer = null;
+ var getSrc = input.GetGetter(ColMapNewToOld[iinfo]);
+ DvText src = default;
+ ValueGetter del =
+ (ref Bitmap dst) =>
{
- // Catch exceptions and pass null through. Should also log failures...
- try
+ if (dst != null)
{
- string path = src.ToString();
- if (!string.IsNullOrWhiteSpace(_imageFolder))
- path = Path.Combine(_imageFolder, path);
- dst = new Bitmap(path);
+ dst.Dispose();
+ dst = null;
}
- catch (Exception e)
+
+ getSrc(ref src);
+
+ if (src.Length > 0)
{
- // REVIEW: We catch everything since the documentation for new Bitmap(string)
- // appears to be incorrect. When the file isn't found, it throws an ArgumentException,
- // while the documentation says FileNotFoundException. Not sure what it will throw
- // in other cases, like corrupted file, etc.
-
- // REVIEW : Log failures.
- ch.Info(e.Message);
- ch.Info(e.StackTrace);
- dst = null;
+ // Catch exceptions and pass null through. Should also log failures...
+ try
+ {
+ string path = src.ToString();
+ if (!string.IsNullOrWhiteSpace(_parent.ImageFolder))
+ path = Path.Combine(_parent.ImageFolder, path);
+ dst = new Bitmap(path);
+ }
+ catch (Exception)
+ {
+ // REVIEW: We catch everything since the documentation for new Bitmap(string)
+ // appears to be incorrect. When the file isn't found, it throws an ArgumentException,
+ // while the documentation says FileNotFoundException. Not sure what it will throw
+ // in other cases, like corrupted file, etc.
+
+ // REVIEW : Log failures.
+ dst = null;
+ }
}
- }
- };
- return del;
+ };
+ return del;
+ }
+
+ public override RowMapperColumnInfo[] GetOutputColumns()
+ => _parent.ColumnPairs.Select(x => new RowMapperColumnInfo(x.output, _imageType, null)).ToArray();
+ }
+ }
+
+ public sealed class ImageLoaderEstimator : TrivialEstimator
+ {
+ private readonly ImageType _imageType;
+
+ public ImageLoaderEstimator(IHostEnvironment env, string imageFolder, params (string input, string output)[] columns)
+ : this(env, new ImageLoaderTransform(env, imageFolder, columns))
+ {
+ }
+
+ public ImageLoaderEstimator(IHostEnvironment env, ImageLoaderTransform transformer)
+ : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(ImageLoaderEstimator)), transformer)
+ {
+ _imageType = new ImageType();
+ }
+
+ public override SchemaShape GetOutputSchema(SchemaShape inputSchema)
+ {
+ Host.CheckValue(inputSchema, nameof(inputSchema));
+ var result = inputSchema.Columns.ToDictionary(x => x.Name);
+ foreach (var (input, output) in Transformer.Columns)
+ {
+ var col = inputSchema.FindColumn(input);
+
+ if (col == null)
+ throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", input);
+ if (!col.ItemType.IsText || col.Kind != SchemaShape.Column.VectorKind.Scalar)
+ throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", input, TextType.Instance.ToString(), col.GetTypeString());
+
+ result[output] = new SchemaShape.Column(output, SchemaShape.Column.VectorKind.Scalar, _imageType, false);
+ }
+
+ return new SchemaShape(result.Values);
}
}
}
diff --git a/src/Microsoft.ML.ImageAnalytics/ImagePixelExtractorTransform.cs b/src/Microsoft.ML.ImageAnalytics/ImagePixelExtractorTransform.cs
index de0aa98124..0bd5bf7879 100644
--- a/src/Microsoft.ML.ImageAnalytics/ImagePixelExtractorTransform.cs
+++ b/src/Microsoft.ML.ImageAnalytics/ImagePixelExtractorTransform.cs
@@ -3,8 +3,11 @@
// See the LICENSE file in the project root for more information.
using System;
+using System.Collections.Generic;
using System.Drawing;
+using System.Linq;
using System.Text;
+using Microsoft.ML.Core.Data;
using Microsoft.ML.Runtime;
using Microsoft.ML.Runtime.CommandLine;
using Microsoft.ML.Runtime.Data;
@@ -13,19 +16,24 @@
using Microsoft.ML.Runtime.Internal.Utilities;
using Microsoft.ML.Runtime.Model;
-[assembly: LoadableClass(ImagePixelExtractorTransform.Summary, typeof(ImagePixelExtractorTransform), typeof(ImagePixelExtractorTransform.Arguments), typeof(SignatureDataTransform),
+[assembly: LoadableClass(ImagePixelExtractorTransform.Summary, typeof(IDataTransform), typeof(ImagePixelExtractorTransform), typeof(ImagePixelExtractorTransform.Arguments), typeof(SignatureDataTransform),
ImagePixelExtractorTransform.UserName, "ImagePixelExtractorTransform", "ImagePixelExtractor")]
-[assembly: LoadableClass(ImagePixelExtractorTransform.Summary, typeof(ImagePixelExtractorTransform), null, typeof(SignatureLoadDataTransform),
+[assembly: LoadableClass(ImagePixelExtractorTransform.Summary, typeof(IDataTransform), typeof(ImagePixelExtractorTransform), null, typeof(SignatureLoadDataTransform),
+ ImagePixelExtractorTransform.UserName, ImagePixelExtractorTransform.LoaderSignature)]
+
+[assembly: LoadableClass(typeof(ImagePixelExtractorTransform), null, typeof(SignatureLoadModel),
+ ImagePixelExtractorTransform.UserName, ImagePixelExtractorTransform.LoaderSignature)]
+
+[assembly: LoadableClass(typeof(IRowMapper), typeof(ImagePixelExtractorTransform), null, typeof(SignatureLoadRowMapper),
ImagePixelExtractorTransform.UserName, ImagePixelExtractorTransform.LoaderSignature)]
namespace Microsoft.ML.Runtime.ImageAnalytics
{
- // REVIEW: Rewrite as LambdaTransform to simplify.
///
/// Transform which takes one or many columns of and convert them into vector representation.
///
- public sealed class ImagePixelExtractorTransform : OneToOneTransformBase
+ public sealed class ImagePixelExtractorTransform : OneToOneTransformerBase
{
public class Column : OneToOneColumn
{
@@ -110,24 +118,28 @@ public class Arguments : TransformInputBase
/// Which color channels are extracted. Note that these values are serialized so should not be modified.
///
[Flags]
- private enum ColorBits : byte
+ public enum ColorBits : byte
{
Alpha = 0x01,
Red = 0x02,
Green = 0x04,
Blue = 0x08,
+ Rgb = Red | Green | Blue,
All = Alpha | Red | Green | Blue
}
- private sealed class ColInfoEx
+ public sealed class ColumnInfo
{
+ public readonly string Input;
+ public readonly string Output;
+
public readonly ColorBits Colors;
public readonly byte Planes;
public readonly bool Convert;
- public readonly Single Offset;
- public readonly Single Scale;
+ public readonly float Offset;
+ public readonly float Scale;
public readonly bool Interleave;
public bool Alpha { get { return (Colors & ColorBits.Alpha) != 0; } }
@@ -135,8 +147,14 @@ private sealed class ColInfoEx
public bool Green { get { return (Colors & ColorBits.Green) != 0; } }
public bool Blue { get { return (Colors & ColorBits.Blue) != 0; } }
- public ColInfoEx(Column item, Arguments args)
+ internal ColumnInfo(Column item, Arguments args)
{
+ Contracts.CheckValue(item, nameof(item));
+ Contracts.CheckValue(args, nameof(args));
+
+ Input = item.Source ?? item.Name;
+ Output = item.Name;
+
if (item.UseAlpha ?? args.UseAlpha) { Colors |= ColorBits.Alpha; Planes++; }
if (item.UseRed ?? args.UseRed) { Colors |= ColorBits.Red; Planes++; }
if (item.UseGreen ?? args.UseGreen) { Colors |= ColorBits.Green; Planes++; }
@@ -160,10 +178,57 @@ public ColInfoEx(Column item, Arguments args)
}
}
- public ColInfoEx(ModelLoadContext ctx)
+ public ColumnInfo(string input, string output, ColorBits colors = ColorBits.Rgb, bool interleave = false)
+ : this(input, output, colors, interleave, false, 1f, 0f)
+ {
+ }
+
+ public ColumnInfo(string input, string output, ColorBits colors = ColorBits.Rgb, bool interleave = false, float scale = 1f, float offset = 0f)
+ : this(input, output, colors, interleave, true, scale, offset)
+ {
+ }
+
+ private ColumnInfo(string input, string output, ColorBits colors, bool interleave, bool convert, float scale, float offset)
+ {
+ Contracts.CheckNonEmpty(input, nameof(input));
+ Contracts.CheckNonEmpty(output, nameof(output));
+
+ Input = input;
+ Output = output;
+ Colors = colors;
+
+ if ((Colors & ColorBits.Alpha) == ColorBits.Alpha) Planes++;
+ if ((Colors & ColorBits.Red) == ColorBits.Red) Planes++;
+ if ((Colors & ColorBits.Green) == ColorBits.Green) Planes++;
+ if ((Colors & ColorBits.Blue) == ColorBits.Blue) Planes++;
+ Contracts.CheckParam(Planes > 0, nameof(colors), "Need to use at least one color plane");
+
+ Interleave = interleave;
+
+ Convert = convert;
+ if (!Convert)
+ {
+ Offset = 0;
+ Scale = 1;
+ }
+ else
+ {
+ Offset = offset;
+ Scale = scale;
+ Contracts.CheckParam(FloatUtils.IsFinite(Offset), nameof(offset));
+ Contracts.CheckParam(FloatUtils.IsFiniteNonZero(Scale), nameof(scale));
+ }
+ }
+
+ internal ColumnInfo(string input, string output, ModelLoadContext ctx)
{
+ Contracts.AssertNonEmpty(input);
+ Contracts.AssertNonEmpty(output);
Contracts.AssertValue(ctx);
+ Input = input;
+ Output = output;
+
// *** Binary format ***
// byte: colors
// byte: convert
@@ -193,7 +258,6 @@ public ColInfoEx(ModelLoadContext ctx)
public void Save(ModelSaveContext ctx)
{
Contracts.AssertValue(ctx);
-
#if DEBUG
// This code is used in deserialization - assert that it matches what we computed above.
int planes = (int)Colors;
@@ -237,305 +301,368 @@ private static VersionInfo GetVersionInfo()
private const string RegistrationName = "ImagePixelExtractor";
- private readonly ColInfoEx[] _exes;
- private readonly VectorType[] _types;
+ private readonly ColumnInfo[] _columns;
+
+ public IReadOnlyCollection Columns => _columns.AsReadOnly();
+
+ public ImagePixelExtractorTransform(IHostEnvironment env, string inputColumn, string outputColumn,
+ ColorBits colors = ColorBits.Rgb, bool interleave = false)
+ : this(env, new ColumnInfo(inputColumn, outputColumn, colors, interleave))
+ {
+ }
+
+ public ImagePixelExtractorTransform(IHostEnvironment env, params ColumnInfo[] columns)
+ : base(Contracts.CheckRef(env, nameof(env)).Register(RegistrationName), GetColumnPairs(columns))
+ {
+ _columns = columns.ToArray();
+ }
- // Public constructor corresponding to SignatureDataTransform.
- public ImagePixelExtractorTransform(IHostEnvironment env, Arguments args, IDataView input)
- : base(env, RegistrationName, Contracts.CheckRef(args, nameof(args)).Column, input,
- t => t is ImageType ? null : "Expected Image type")
+ private static (string input, string output)[] GetColumnPairs(ColumnInfo[] columns)
{
- Host.AssertNonEmpty(Infos);
- Host.Assert(Infos.Length == Utils.Size(args.Column));
+ Contracts.CheckValue(columns, nameof(columns));
+ return columns.Select(x => (x.Input, x.Output)).ToArray();
+ }
- _exes = new ColInfoEx[Infos.Length];
- for (int i = 0; i < _exes.Length; i++)
+ // SignatureDataTransform.
+ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input)
+ {
+ Contracts.CheckValue(env, nameof(env));
+ env.CheckValue(args, nameof(args));
+ env.CheckValue(input, nameof(input));
+
+ env.CheckValue(args.Column, nameof(args.Column));
+
+ var columns = new ColumnInfo[args.Column.Length];
+ for (int i = 0; i < columns.Length; i++)
{
var item = args.Column[i];
- _exes[i] = new ColInfoEx(item, args);
+ columns[i] = new ColumnInfo(item, args);
}
- _types = ConstructTypes(true);
+ var transformer = new ImagePixelExtractorTransform(env, columns);
+ return new RowToRowMapperTransform(env, input, transformer.MakeRowMapper(input.Schema));
}
- private ImagePixelExtractorTransform(IHost host, ModelLoadContext ctx, IDataView input)
- : base(host, ctx, input, t => t is ImageType ? null : "Expected Image type")
+ public static ImagePixelExtractorTransform Create(IHostEnvironment env, ModelLoadContext ctx)
{
- Host.AssertValue(ctx);
+ Contracts.CheckValue(env, nameof(env));
+ var host = env.Register(RegistrationName);
+ host.CheckValue(ctx, nameof(ctx));
+ ctx.CheckAtModel(GetVersionInfo());
- // *** Binary format ***
- //
- //
- // foreach added column
- // ColInfoEx
- Host.AssertNonEmpty(Infos);
- _exes = new ColInfoEx[Infos.Length];
- for (int i = 0; i < _exes.Length; i++)
- _exes[i] = new ColInfoEx(ctx);
-
- _types = ConstructTypes(false);
+ return new ImagePixelExtractorTransform(host, ctx);
}
- public static ImagePixelExtractorTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input)
+ private ImagePixelExtractorTransform(IHost host, ModelLoadContext ctx)
+ : base(host, ctx)
{
- Contracts.CheckValue(env, nameof(env));
- var h = env.Register(RegistrationName);
- h.CheckValue(ctx, nameof(ctx));
- h.CheckValue(input, nameof(input));
- ctx.CheckAtModel(GetVersionInfo());
+ // *** Binary format ***
+ //
- return h.Apply("Loading Model",
- ch =>
- {
- // *** Binary format ***
- // int: sizeof(Float)
- //
- int cbFloat = ctx.Reader.ReadInt32();
- ch.CheckDecode(cbFloat == sizeof(Single));
- return new ImagePixelExtractorTransform(h, ctx, input);
- });
+ // for each added column
+ // ColumnInfo
+
+ _columns = new ColumnInfo[ColumnPairs.Length];
+ for (int i = 0; i < _columns.Length; i++)
+ _columns[i] = new ColumnInfo(ColumnPairs[i].input, ColumnPairs[i].output, ctx);
}
+ // Factory method for SignatureLoadDataTransform.
+ public static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input)
+ => Create(env, ctx).MakeDataTransform(input);
+
+ // Factory method for SignatureLoadRowMapper.
+ public static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, ISchema inputSchema)
+ => Create(env, ctx).MakeRowMapper(inputSchema);
+
public override void Save(ModelSaveContext ctx)
{
Host.CheckValue(ctx, nameof(ctx));
+
ctx.CheckAtModel();
ctx.SetVersionInfo(GetVersionInfo());
// *** Binary format ***
- // int: sizeof(Float)
//
- // foreach added column
- // ColInfoEx
- ctx.Writer.Write(sizeof(Single));
- SaveBase(ctx);
-
- Host.Assert(_exes.Length == Infos.Length);
- for (int i = 0; i < _exes.Length; i++)
- _exes[i].Save(ctx);
- }
- private VectorType[] ConstructTypes(bool user)
- {
- var types = new VectorType[Infos.Length];
- for (int i = 0; i < Infos.Length; i++)
- {
- var info = Infos[i];
- var ex = _exes[i];
- Host.Assert(ex.Planes > 0);
+ // for each added column
+ // ColumnInfo
- var type = Source.Schema.GetColumnType(info.Source) as ImageType;
- Host.Assert(type != null);
- if (type.Height <= 0 || type.Width <= 0)
- {
- // REVIEW: Could support this case by making the destination column be variable sized.
- // However, there's no mechanism to communicate the dimensions through with the pixel data.
- string name = Source.Schema.GetColumnName(info.Source);
- throw user ?
- Host.ExceptUserArg(nameof(Arguments.Column), "Column '{0}' does not have known size", name) :
- Host.Except("Column '{0}' does not have known size", name);
- }
- int height = type.Height;
- int width = type.Width;
- Host.Assert(height > 0);
- Host.Assert(width > 0);
- Host.Assert((long)height * width <= int.MaxValue / 4);
-
- if (ex.Interleave)
- types[i] = new VectorType(ex.Convert ? NumberType.Float : NumberType.U1, height, width, ex.Planes);
- else
- types[i] = new VectorType(ex.Convert ? NumberType.Float : NumberType.U1, ex.Planes, height, width);
- }
- Metadata.Seal();
- return types;
+ base.SaveColumns(ctx);
+
+ foreach (ColumnInfo info in _columns)
+ info.Save(ctx);
}
- protected override ColumnType GetColumnTypeCore(int iinfo)
+ protected override IRowMapper MakeRowMapper(ISchema schema)
+ => new Mapper(this, schema);
+
+ protected override void CheckInputColumn(ISchema inputSchema, int col, int srcCol)
{
- Host.Assert(0 <= iinfo & iinfo < Infos.Length);
- return _types[iinfo];
+ var inputColName = _columns[col].Input;
+ var imageType = inputSchema.GetColumnType(srcCol) as ImageType;
+ if (imageType == null)
+ throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", inputColName, "image", inputSchema.GetColumnType(srcCol).ToString());
+ if (imageType.Height <= 0 || imageType.Width <= 0)
+ throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", inputColName, "known-size image", "unknown-size image");
+ if ((long)imageType.Height * imageType.Width > int.MaxValue / 4)
+ throw Host.Except("Image dimensions are too large");
}
- protected override Delegate GetGetterCore(IChannel ch, IRow input, int iinfo, out Action disposer)
+ private sealed class Mapper : MapperBase
{
- Host.AssertValueOrNull(ch);
- Host.AssertValue(input);
- Host.Assert(0 <= iinfo && iinfo < Infos.Length);
+ private readonly ImagePixelExtractorTransform _parent;
+ private readonly VectorType[] _types;
- if (_exes[iinfo].Convert)
- return GetGetterCore(input, iinfo, out disposer);
- return GetGetterCore(input, iinfo, out disposer);
- }
+ public Mapper(ImagePixelExtractorTransform parent, ISchema inputSchema)
+ : base(parent.Host.Register(nameof(Mapper)), parent, inputSchema)
+ {
+ _parent = parent;
+ _types = ConstructTypes();
+ }
- //REVIEW Rewrite it to where TValue : IConvertible
- private ValueGetter> GetGetterCore(IRow input, int iinfo, out Action disposer)
- {
- var type = _types[iinfo];
- Host.Assert(type.DimCount == 3);
+ public override RowMapperColumnInfo[] GetOutputColumns()
+ => _parent._columns.Select((x, idx) => new RowMapperColumnInfo(x.Output, _types[idx], null)).ToArray();
- var ex = _exes[iinfo];
+ protected override Delegate MakeGetter(IRow input, int iinfo, out Action disposer)
+ {
+ Contracts.AssertValue(input);
+ Contracts.Assert(0 <= iinfo && iinfo < _parent._columns.Length);
- int planes = ex.Interleave ? type.GetDim(2) : type.GetDim(0);
- int height = ex.Interleave ? type.GetDim(0) : type.GetDim(1);
- int width = ex.Interleave ? type.GetDim(1) : type.GetDim(2);
+ if (_parent._columns[iinfo].Convert)
+ return GetGetterCore(input, iinfo, out disposer);
+ return GetGetterCore(input, iinfo, out disposer);
+ }
- int size = type.ValueCount;
- Host.Assert(size > 0);
- Host.Assert(size == planes * height * width);
- int cpix = height * width;
+ //REVIEW Rewrite it to where TValue : IConvertible
+ private ValueGetter> GetGetterCore(IRow input, int iinfo, out Action disposer)
+ {
+ var type = _types[iinfo];
+ Contracts.Assert(type.DimCount == 3);
- var getSrc = GetSrcGetter(input, iinfo);
- var src = default(Bitmap);
+ var ex = _parent._columns[iinfo];
- disposer =
- () =>
- {
- if (src != null)
- {
- src.Dispose();
- src = null;
- }
- };
+ int planes = ex.Interleave ? type.GetDim(2) : type.GetDim(0);
+ int height = ex.Interleave ? type.GetDim(0) : type.GetDim(1);
+ int width = ex.Interleave ? type.GetDim(1) : type.GetDim(2);
- return
- (ref VBuffer dst) =>
- {
- getSrc(ref src);
- Contracts.AssertValueOrNull(src);
+ int size = type.ValueCount;
+ Contracts.Assert(size > 0);
+ Contracts.Assert(size == planes * height * width);
+ int cpix = height * width;
+
+ var getSrc = input.GetGetter(ColMapNewToOld[iinfo]);
+ var src = default(Bitmap);
- if (src == null)
+ disposer =
+ () =>
{
- dst = new VBuffer(size, 0, dst.Values, dst.Indices);
- return;
- }
+ if (src != null)
+ {
+ src.Dispose();
+ src = null;
+ }
+ };
- Host.Check(src.PixelFormat == System.Drawing.Imaging.PixelFormat.Format32bppArgb);
- Host.Check(src.Height == height && src.Width == width);
+ return
+ (ref VBuffer dst) =>
+ {
+ getSrc(ref src);
+ Contracts.AssertValueOrNull(src);
- var values = dst.Values;
- if (Utils.Size(values) < size)
- values = new TValue[size];
+ if (src == null)
+ {
+ dst = new VBuffer(size, 0, dst.Values, dst.Indices);
+ return;
+ }
- Single offset = ex.Offset;
- Single scale = ex.Scale;
- Host.Assert(scale != 0);
+ Host.Check(src.PixelFormat == System.Drawing.Imaging.PixelFormat.Format32bppArgb);
+ Host.Check(src.Height == height && src.Width == width);
- var vf = values as Single[];
- var vb = values as byte[];
- Host.Assert(vf != null || vb != null);
- bool needScale = offset != 0 || scale != 1;
- Host.Assert(!needScale || vf != null);
+ var values = dst.Values;
+ if (Utils.Size(values) < size)
+ values = new TValue[size];
- bool a = ex.Alpha;
- bool r = ex.Red;
- bool g = ex.Green;
- bool b = ex.Blue;
+ float offset = ex.Offset;
+ float scale = ex.Scale;
+ Contracts.Assert(scale != 0);
- int h = height;
- int w = width;
+ var vf = values as float[];
+ var vb = values as byte[];
+ Contracts.Assert(vf != null || vb != null);
+ bool needScale = offset != 0 || scale != 1;
+ Contracts.Assert(!needScale || vf != null);
- if (ex.Interleave)
- {
- int idst = 0;
- for (int y = 0; y < h; ++y)
- for (int x = 0; x < w; x++)
- {
- var pb = src.GetPixel(y, x);
- if (vb != null)
+ bool a = ex.Alpha;
+ bool r = ex.Red;
+ bool g = ex.Green;
+ bool b = ex.Blue;
+
+ int h = height;
+ int w = width;
+
+ if (ex.Interleave)
+ {
+ int idst = 0;
+ for (int y = 0; y < h; ++y)
+ for (int x = 0; x < w; x++)
{
- if (a) { vb[idst++] = (byte)0; }
- if (r) { vb[idst++] = pb.R; }
- if (g) { vb[idst++] = pb.G; }
- if (b) { vb[idst++] = pb.B; }
+ var pb = src.GetPixel(y, x);
+ if (vb != null)
+ {
+ if (a) { vb[idst++] = (byte)0; }
+ if (r) { vb[idst++] = pb.R; }
+ if (g) { vb[idst++] = pb.G; }
+ if (b) { vb[idst++] = pb.B; }
+ }
+ else if (!needScale)
+ {
+ if (a) { vf[idst++] = 0.0f; }
+ if (r) { vf[idst++] = pb.R; }
+ if (g) { vf[idst++] = pb.G; }
+ if (b) { vf[idst++] = pb.B; }
+ }
+ else
+ {
+ if (a) { vf[idst++] = 0.0f; }
+ if (r) { vf[idst++] = (pb.R - offset) * scale; }
+ if (g) { vf[idst++] = (pb.B - offset) * scale; }
+ if (b) { vf[idst++] = (pb.G - offset) * scale; }
+ }
}
- else if (!needScale)
+ Contracts.Assert(idst == size);
+ }
+ else
+ {
+ int idstMin = 0;
+ if (ex.Alpha)
+ {
+ // The image only has rgb but we need to supply alpha as well, so fake it up,
+ // assuming that it is 0xFF.
+ if (vf != null)
{
- if (a) { vf[idst++] = 0.0f; }
- if (r) { vf[idst++] = pb.R; }
- if (g) { vf[idst++] = pb.G; }
- if (b) { vf[idst++] = pb.B; }
+ Single v = (0xFF - offset) * scale;
+ for (int i = 0; i < cpix; i++)
+ vf[i] = v;
}
else
{
- if (a) { vf[idst++] = 0.0f; }
- if (r) { vf[idst++] = (pb.R - offset) * scale; }
- if (g) { vf[idst++] = (pb.B - offset) * scale; }
- if (b) { vf[idst++] = (pb.G - offset) * scale; }
+ for (int i = 0; i < cpix; i++)
+ vb[i] = 0xFF;
}
- }
- Host.Assert(idst == size);
- }
- else
- {
- int idstMin = 0;
- if (ex.Alpha)
- {
- // The image only has rgb but we need to supply alpha as well, so fake it up,
- // assuming that it is 0xFF.
- if (vf != null)
- {
- Single v = (0xFF - offset) * scale;
- for (int i = 0; i < cpix; i++)
- vf[i] = v;
- }
- else
- {
- for (int i = 0; i < cpix; i++)
- vb[i] = 0xFF;
- }
- idstMin = cpix;
+ idstMin = cpix;
- // We've preprocessed alpha, avoid it in the
- // scan operation below.
- a = false;
- }
-
- for (int y = 0; y < h; ++y)
- {
- int idstBase = idstMin + y * w;
+ // We've preprocessed alpha, avoid it in the
+ // scan operation below.
+ a = false;
+ }
- // Note that the bytes are in order BGR[A]. We arrange the layers in order ARGB.
- if (vb != null)
+ for (int y = 0; y < h; ++y)
{
- for (int x = 0; x < w; x++, idstBase++)
+ int idstBase = idstMin + y * w;
+
+ // Note that the bytes are in order BGR[A]. We arrange the layers in order ARGB.
+ if (vb != null)
{
- var pb = src.GetPixel(x, y);
- int idst = idstBase;
- if (a) { vb[idst] = pb.A; idst += cpix; }
- if (r) { vb[idst] = pb.R; idst += cpix; }
- if (g) { vb[idst] = pb.G; idst += cpix; }
- if (b) { vb[idst] = pb.B; idst += cpix; }
+ for (int x = 0; x < w; x++, idstBase++)
+ {
+ var pb = src.GetPixel(x, y);
+ int idst = idstBase;
+ if (a) { vb[idst] = pb.A; idst += cpix; }
+ if (r) { vb[idst] = pb.R; idst += cpix; }
+ if (g) { vb[idst] = pb.G; idst += cpix; }
+ if (b) { vb[idst] = pb.B; idst += cpix; }
+ }
}
- }
- else if (!needScale)
- {
- for (int x = 0; x < w; x++, idstBase++)
+ else if (!needScale)
{
- var pb = src.GetPixel(x, y);
- int idst = idstBase;
- if (a) { vf[idst] = pb.A; idst += cpix; }
- if (r) { vf[idst] = pb.R; idst += cpix; }
- if (g) { vf[idst] = pb.G; idst += cpix; }
- if (b) { vf[idst] = pb.B; idst += cpix; }
+ for (int x = 0; x < w; x++, idstBase++)
+ {
+ var pb = src.GetPixel(x, y);
+ int idst = idstBase;
+ if (a) { vf[idst] = pb.A; idst += cpix; }
+ if (r) { vf[idst] = pb.R; idst += cpix; }
+ if (g) { vf[idst] = pb.G; idst += cpix; }
+ if (b) { vf[idst] = pb.B; idst += cpix; }
+ }
}
- }
- else
- {
- for (int x = 0; x < w; x++, idstBase++)
+ else
{
- var pb = src.GetPixel(x, y);
- int idst = idstBase;
- if (a) { vf[idst] = (pb.A - offset) * scale; idst += cpix; }
- if (r) { vf[idst] = (pb.R - offset) * scale; idst += cpix; }
- if (g) { vf[idst] = (pb.G - offset) * scale; idst += cpix; }
- if (b) { vf[idst] = (pb.B - offset) * scale; idst += cpix; }
+ for (int x = 0; x < w; x++, idstBase++)
+ {
+ var pb = src.GetPixel(x, y);
+ int idst = idstBase;
+ if (a) { vf[idst] = (pb.A - offset) * scale; idst += cpix; }
+ if (r) { vf[idst] = (pb.R - offset) * scale; idst += cpix; }
+ if (g) { vf[idst] = (pb.G - offset) * scale; idst += cpix; }
+ if (b) { vf[idst] = (pb.B - offset) * scale; idst += cpix; }
+ }
}
}
}
- }
- dst = new VBuffer(size, values, dst.Indices);
- };
+ dst = new VBuffer(size, values, dst.Indices);
+ };
+ }
+
+ private VectorType[] ConstructTypes()
+ {
+ var types = new VectorType[_parent._columns.Length];
+ for (int i = 0; i < _parent._columns.Length; i++)
+ {
+ var column = _parent._columns[i];
+ Contracts.Assert(column.Planes > 0);
+
+ var type = InputSchema.GetColumnType(ColMapNewToOld[i]) as ImageType;
+ Contracts.Assert(type != null);
+
+ int height = type.Height;
+ int width = type.Width;
+ Contracts.Assert(height > 0);
+ Contracts.Assert(width > 0);
+ Contracts.Assert((long)height * width <= int.MaxValue / 4);
+
+ if (column.Interleave)
+ types[i] = new VectorType(column.Convert ? NumberType.Float : NumberType.U1, height, width, column.Planes);
+ else
+ types[i] = new VectorType(column.Convert ? NumberType.Float : NumberType.U1, column.Planes, height, width);
+ }
+ return types;
+ }
+ }
+ }
+
+ public sealed class ImagePixelExtractorEstimator : TrivialEstimator
+ {
+ public ImagePixelExtractorEstimator(IHostEnvironment env, string inputColumn, string outputColumn,
+ ImagePixelExtractorTransform.ColorBits colors = ImagePixelExtractorTransform.ColorBits.Rgb, bool interleave = false)
+ : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(ImagePixelExtractorEstimator)), new ImagePixelExtractorTransform(env, inputColumn, outputColumn, colors, interleave))
+ {
+ }
+
+ public ImagePixelExtractorEstimator(IHostEnvironment env, params ImagePixelExtractorTransform.ColumnInfo[] columns)
+ : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(ImagePixelExtractorEstimator)), new ImagePixelExtractorTransform(env, columns))
+ {
+ }
+
+ public override SchemaShape GetOutputSchema(SchemaShape inputSchema)
+ {
+ Host.CheckValue(inputSchema, nameof(inputSchema));
+ var result = inputSchema.Columns.ToDictionary(x => x.Name);
+ foreach (var colInfo in Transformer.Columns)
+ {
+ var col = inputSchema.FindColumn(colInfo.Input);
+
+ if (col == null)
+ throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.Input);
+ if (!(col.ItemType is ImageType) || col.Kind != SchemaShape.Column.VectorKind.Scalar)
+ throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.Input, new ImageType().ToString(), col.GetTypeString());
+
+ var itemType = colInfo.Convert ? NumberType.R4 : NumberType.U1;
+ result[colInfo.Output] = new SchemaShape.Column(colInfo.Output, SchemaShape.Column.VectorKind.Vector, itemType, false);
+ }
+
+ return new SchemaShape(result.Values);
}
}
}
diff --git a/src/Microsoft.ML.ImageAnalytics/ImageResizerTransform.cs b/src/Microsoft.ML.ImageAnalytics/ImageResizerTransform.cs
index dd1abc9181..ac11c7fa8d 100644
--- a/src/Microsoft.ML.ImageAnalytics/ImageResizerTransform.cs
+++ b/src/Microsoft.ML.ImageAnalytics/ImageResizerTransform.cs
@@ -3,8 +3,11 @@
// See the LICENSE file in the project root for more information.
using System;
+using System.Collections.Generic;
using System.Drawing;
+using System.Linq;
using System.Text;
+using Microsoft.ML.Core.Data;
using Microsoft.ML.Runtime;
using Microsoft.ML.Runtime.CommandLine;
using Microsoft.ML.Runtime.Data;
@@ -20,13 +23,19 @@
[assembly: LoadableClass(ImageResizerTransform.Summary, typeof(ImageResizerTransform), null, typeof(SignatureLoadDataTransform),
ImageResizerTransform.UserName, ImageResizerTransform.LoaderSignature)]
+[assembly: LoadableClass(typeof(ImageResizerTransform), null, typeof(SignatureLoadModel),
+ ImageResizerTransform.UserName, ImageResizerTransform.LoaderSignature)]
+
+[assembly: LoadableClass(typeof(IRowMapper), typeof(ImageResizerTransform), null, typeof(SignatureLoadRowMapper),
+ ImageResizerTransform.UserName, ImageResizerTransform.LoaderSignature)]
+
namespace Microsoft.ML.Runtime.ImageAnalytics
{
// REVIEW: Rewrite as LambdaTransform to simplify.
///
/// Transform which takes one or many columns of and resize them to provided height and width.
///
- public sealed class ImageResizerTransform : OneToOneTransformBase
+ public sealed class ImageResizerTransform : OneToOneTransformerBase
{
public enum ResizingKind : byte
{
@@ -98,23 +107,30 @@ public class Arguments : TransformInputBase
}
///
- /// Extra information for each column (in addition to ColumnInfo).
+ /// Information for each column pair.
///
- private sealed class ColInfoEx
+ public sealed class ColumnInfo
{
+ public readonly string Input;
+ public readonly string Output;
+
public readonly int Width;
public readonly int Height;
public readonly ResizingKind Scale;
public readonly Anchor Anchor;
public readonly ColumnType Type;
- public ColInfoEx(int width, int height, ResizingKind scale, Anchor anchor)
+ public ColumnInfo(string input, string output, int width, int height, ResizingKind scale, Anchor anchor)
{
+ Contracts.CheckNonEmpty(input, nameof(input));
+ Contracts.CheckNonEmpty(output, nameof(output));
Contracts.CheckUserArg(width > 0, nameof(Column.ImageWidth));
Contracts.CheckUserArg(height > 0, nameof(Column.ImageHeight));
Contracts.CheckUserArg(Enum.IsDefined(typeof(ResizingKind), scale), nameof(Column.Resizing));
Contracts.CheckUserArg(Enum.IsDefined(typeof(Anchor), anchor), nameof(Column.CropAnchor));
+ Input = input;
+ Output = output;
Width = width;
Height = height;
Scale = scale;
@@ -133,53 +149,87 @@ private static VersionInfo GetVersionInfo()
return new VersionInfo(
modelSignature: "IMGSCALF",
//verWrittenCur: 0x00010001, // Initial
- verWrittenCur: 0x00010002, // Swith from OpenCV to Bitmap
- verReadableCur: 0x00010002,
- verWeCanReadBack: 0x00010002,
+ //verWrittenCur: 0x00010002, // Swith from OpenCV to Bitmap
+ verWrittenCur: 0x00010003, // No more sizeof(float)
+ verReadableCur: 0x00010003,
+ verWeCanReadBack: 0x00010003,
loaderSignature: LoaderSignature);
}
private const string RegistrationName = "ImageScaler";
- // This is parallel to Infos.
- private readonly ColInfoEx[] _exes;
+ private readonly ColumnInfo[] _columns;
+
+ public IReadOnlyCollection Columns => _columns.AsReadOnly();
+
+ public ImageResizerTransform(IHostEnvironment env, string inputColumn, string outputColumn,
+ int imageWidth, int imageHeight, ResizingKind resizing = ResizingKind.IsoCrop, Anchor cropAnchor = Anchor.Center)
+ : this(env, new ColumnInfo(inputColumn, outputColumn, imageWidth, imageHeight, resizing, cropAnchor))
+ {
+ }
+
+ public ImageResizerTransform(IHostEnvironment env, params ColumnInfo[] columns)
+ : base(Contracts.CheckRef(env, nameof(env)).Register(RegistrationName), GetColumnPairs(columns))
+ {
+ _columns = columns.ToArray();
+ }
+
+ private static (string input, string output)[] GetColumnPairs(ColumnInfo[] columns)
+ {
+ Contracts.CheckValue(columns, nameof(columns));
+ return columns.Select(x => (x.Input, x.Output)).ToArray();
+ }
- // Public constructor corresponding to SignatureDataTransform.
- public ImageResizerTransform(IHostEnvironment env, Arguments args, IDataView input)
- : base(env, RegistrationName, env.CheckRef(args, nameof(args)).Column, input, t => t is ImageType ? null : "Expected Image type")
+ // Factory method for SignatureDataTransform.
+ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input)
{
- Host.AssertNonEmpty(Infos);
- Host.Assert(Infos.Length == Utils.Size(args.Column));
+ Contracts.CheckValue(env, nameof(env));
+ env.CheckValue(args, nameof(args));
+ env.CheckValue(input, nameof(input));
+
+ env.CheckValue(args.Column, nameof(args.Column));
- _exes = new ColInfoEx[Infos.Length];
- for (int i = 0; i < _exes.Length; i++)
+ var cols = new ColumnInfo[args.Column.Length];
+ for (int i = 0; i < cols.Length; i++)
{
var item = args.Column[i];
- _exes[i] = new ColInfoEx(
+ cols[i] = new ColumnInfo(
+ item.Source ?? item.Name,
+ item.Name,
item.ImageWidth ?? args.ImageWidth,
item.ImageHeight ?? args.ImageHeight,
item.Resizing ?? args.Resizing,
item.CropAnchor ?? args.CropAnchor);
}
- Metadata.Seal();
+
+ return new ImageResizerTransform(env, cols).MakeDataTransform(input);
}
- private ImageResizerTransform(IHost host, ModelLoadContext ctx, IDataView input)
- : base(host, ctx, input, t => t is ImageType ? null : "Expected Image type")
+ public static ImageResizerTransform Create(IHostEnvironment env, ModelLoadContext ctx)
{
- Host.AssertValue(ctx);
+ Contracts.CheckValue(env, nameof(env));
+ var host = env.Register(RegistrationName);
+ host.CheckValue(ctx, nameof(ctx));
+ ctx.CheckAtModel(GetVersionInfo());
+
+ return new ImageResizerTransform(host, ctx);
+ }
+
+ private ImageResizerTransform(IHost host, ModelLoadContext ctx)
+ : base(host, ctx)
+ {
// *** Binary format ***
- //
//
+
// for each added column
// int: width
// int: height
// byte: scaling kind
- Host.AssertNonEmpty(Infos);
+ // byte: anchor
- _exes = new ColInfoEx[Infos.Length];
- for (int i = 0; i < _exes.Length; i++)
+ _columns = new ColumnInfo[ColumnPairs.Length];
+ for (int i = 0; i < ColumnPairs.Length; i++)
{
int width = ctx.Reader.ReadInt32();
Host.CheckDecode(width > 0);
@@ -189,182 +239,224 @@ private ImageResizerTransform(IHost host, ModelLoadContext ctx, IDataView input)
Host.CheckDecode(Enum.IsDefined(typeof(ResizingKind), scale));
var anchor = (Anchor)ctx.Reader.ReadByte();
Host.CheckDecode(Enum.IsDefined(typeof(Anchor), anchor));
- _exes[i] = new ColInfoEx(width, height, scale, anchor);
+ _columns[i] = new ColumnInfo(ColumnPairs[i].input, ColumnPairs[i].output, width, height, scale, anchor);
}
- Metadata.Seal();
}
- public static ImageResizerTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input)
- {
- Contracts.CheckValue(env, nameof(env));
- var h = env.Register(RegistrationName);
- h.CheckValue(ctx, nameof(ctx));
- h.CheckValue(input, nameof(input));
- ctx.CheckAtModel(GetVersionInfo());
- return h.Apply("Loading Model",
- ch =>
- {
- // *** Binary format ***
- // int: sizeof(Float)
- //
- int cbFloat = ctx.Reader.ReadInt32();
- ch.CheckDecode(cbFloat == sizeof(Single));
- return new ImageResizerTransform(h, ctx, input);
- });
- }
+ // Factory method for SignatureLoadDataTransform.
+ public static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input)
+ => Create(env, ctx).MakeDataTransform(input);
+
+ // Factory method for SignatureLoadRowMapper.
+ public static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, ISchema inputSchema)
+ => Create(env, ctx).MakeRowMapper(inputSchema);
public override void Save(ModelSaveContext ctx)
{
Host.CheckValue(ctx, nameof(ctx));
+
ctx.CheckAtModel();
ctx.SetVersionInfo(GetVersionInfo());
// *** Binary format ***
- // int: sizeof(Float)
//
+
// for each added column
// int: width
// int: height
// byte: scaling kind
- ctx.Writer.Write(sizeof(Single));
- SaveBase(ctx);
+ // byte: anchor
+
+ base.SaveColumns(ctx);
- Host.Assert(_exes.Length == Infos.Length);
- for (int i = 0; i < _exes.Length; i++)
+ foreach (var col in _columns)
{
- var ex = _exes[i];
- ctx.Writer.Write(ex.Width);
- ctx.Writer.Write(ex.Height);
- Host.Assert((ResizingKind)(byte)ex.Scale == ex.Scale);
- ctx.Writer.Write((byte)ex.Scale);
- Host.Assert((Anchor)(byte)ex.Anchor == ex.Anchor);
- ctx.Writer.Write((byte)ex.Anchor);
+ ctx.Writer.Write(col.Width);
+ ctx.Writer.Write(col.Height);
+ Contracts.Assert((ResizingKind)(byte)col.Scale == col.Scale);
+ ctx.Writer.Write((byte)col.Scale);
+ Contracts.Assert((Anchor)(byte)col.Anchor == col.Anchor);
+ ctx.Writer.Write((byte)col.Anchor);
}
}
- protected override ColumnType GetColumnTypeCore(int iinfo)
+ protected override IRowMapper MakeRowMapper(ISchema schema)
+ => new Mapper(this, schema);
+
+ protected override void CheckInputColumn(ISchema inputSchema, int col, int srcCol)
{
- Host.Check(0 <= iinfo && iinfo < Infos.Length);
- return _exes[iinfo].Type;
+ if (!(inputSchema.GetColumnType(srcCol) is ImageType))
+ throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", _columns[col].Input, "image", inputSchema.GetColumnType(srcCol).ToString());
}
- protected override Delegate GetGetterCore(IChannel ch, IRow input, int iinfo, out Action disposer)
+ private sealed class Mapper : MapperBase
{
- Host.AssertValueOrNull(ch);
- Host.AssertValue(input);
- Host.Assert(0 <= iinfo && iinfo < Infos.Length);
-
- var src = default(Bitmap);
- var getSrc = GetSrcGetter(input, iinfo);
- var ex = _exes[iinfo];
-
- disposer =
- () =>
- {
- if (src != null)
- {
- src.Dispose();
- src = null;
- }
- };
-
- ValueGetter del =
- (ref Bitmap dst) =>
- {
- if (dst != null)
- dst.Dispose();
-
- getSrc(ref src);
- if (src == null || src.Height <= 0 || src.Width <= 0)
- return;
- if (src.Height == ex.Height && src.Width == ex.Width)
- {
- dst = src;
- return;
- }
-
- int sourceWidth = src.Width;
- int sourceHeight = src.Height;
- int sourceX = 0;
- int sourceY = 0;
- int destX = 0;
- int destY = 0;
- int destWidth = 0;
- int destHeight = 0;
- float aspect = 0;
- float widthAspect = 0;
- float heightAspect = 0;
-
- widthAspect = (float)ex.Width / sourceWidth;
- heightAspect = (float)ex.Height / sourceHeight;
-
- if (ex.Scale == ResizingKind.IsoPad)
+ private readonly ImageResizerTransform _parent;
+
+ public Mapper(ImageResizerTransform parent, ISchema inputSchema)
+ :base(parent.Host.Register(nameof(Mapper)), parent, inputSchema)
+ {
+ _parent = parent;
+ }
+
+ public override RowMapperColumnInfo[] GetOutputColumns()
+ => _parent._columns.Select(x => new RowMapperColumnInfo(x.Output, x.Type, null)).ToArray();
+
+ protected override Delegate MakeGetter(IRow input, int iinfo, out Action disposer)
+ {
+ Contracts.AssertValue(input);
+ Contracts.Assert(0 <= iinfo && iinfo < _parent._columns.Length);
+
+ var src = default(Bitmap);
+ var getSrc = input.GetGetter(ColMapNewToOld[iinfo]);
+ var info = _parent._columns[iinfo];
+
+ disposer =
+ () =>
{
- widthAspect = (float)ex.Width / sourceWidth;
- heightAspect = (float)ex.Height / sourceHeight;
- if (heightAspect < widthAspect)
+ if (src != null)
{
- aspect = heightAspect;
- destX = (int)((ex.Width - (sourceWidth * aspect)) / 2);
+ src.Dispose();
+ src = null;
}
- else
+ };
+
+ ValueGetter del =
+ (ref Bitmap dst) =>
+ {
+ if (dst != null)
+ dst.Dispose();
+
+ getSrc(ref src);
+ if (src == null || src.Height <= 0 || src.Width <= 0)
+ return;
+ if (src.Height == info.Height && src.Width == info.Width)
{
- aspect = widthAspect;
- destY = (int)((ex.Height - (sourceHeight * aspect)) / 2);
+ dst = src;
+ return;
}
- destWidth = (int)(sourceWidth * aspect);
- destHeight = (int)(sourceHeight * aspect);
- }
- else
- {
- if (heightAspect < widthAspect)
+ int sourceWidth = src.Width;
+ int sourceHeight = src.Height;
+ int sourceX = 0;
+ int sourceY = 0;
+ int destX = 0;
+ int destY = 0;
+ int destWidth = 0;
+ int destHeight = 0;
+ float aspect = 0;
+ float widthAspect = 0;
+ float heightAspect = 0;
+
+ widthAspect = (float)info.Width / sourceWidth;
+ heightAspect = (float)info.Height / sourceHeight;
+
+ if (info.Scale == ResizingKind.IsoPad)
{
- aspect = widthAspect;
- switch (ex.Anchor)
+ widthAspect = (float)info.Width / sourceWidth;
+ heightAspect = (float)info.Height / sourceHeight;
+ if (heightAspect < widthAspect)
{
- case Anchor.Top:
- destY = 0;
- break;
- case Anchor.Bottom:
- destY = (int)(ex.Height - (sourceHeight * aspect));
- break;
- default:
- destY = (int)((ex.Height - (sourceHeight * aspect)) / 2);
- break;
+ aspect = heightAspect;
+ destX = (int)((info.Width - (sourceWidth * aspect)) / 2);
}
+ else
+ {
+ aspect = widthAspect;
+ destY = (int)((info.Height - (sourceHeight * aspect)) / 2);
+ }
+
+ destWidth = (int)(sourceWidth * aspect);
+ destHeight = (int)(sourceHeight * aspect);
}
else
{
- aspect = heightAspect;
- switch (ex.Anchor)
+ if (heightAspect < widthAspect)
+ {
+ aspect = widthAspect;
+ switch (info.Anchor)
+ {
+ case Anchor.Top:
+ destY = 0;
+ break;
+ case Anchor.Bottom:
+ destY = (int)(info.Height - (sourceHeight * aspect));
+ break;
+ default:
+ destY = (int)((info.Height - (sourceHeight * aspect)) / 2);
+ break;
+ }
+ }
+ else
{
- case Anchor.Left:
- destX = 0;
- break;
- case Anchor.Right:
- destX = (int)(ex.Width - (sourceWidth * aspect));
- break;
- default:
- destX = (int)((ex.Width - (sourceWidth * aspect)) / 2);
- break;
+ aspect = heightAspect;
+ switch (info.Anchor)
+ {
+ case Anchor.Left:
+ destX = 0;
+ break;
+ case Anchor.Right:
+ destX = (int)(info.Width - (sourceWidth * aspect));
+ break;
+ default:
+ destX = (int)((info.Width - (sourceWidth * aspect)) / 2);
+ break;
+ }
}
+
+ destWidth = (int)(sourceWidth * aspect);
+ destHeight = (int)(sourceHeight * aspect);
+ }
+ dst = new Bitmap(info.Width, info.Height);
+ var srcRectangle = new Rectangle(sourceX, sourceY, sourceWidth, sourceHeight);
+ var destRectangle = new Rectangle(destX, destY, destWidth, destHeight);
+ using (var g = Graphics.FromImage(dst))
+ {
+ g.DrawImage(src, destRectangle, srcRectangle, GraphicsUnit.Pixel);
}
+ Contracts.Assert(dst.Width == info.Width && dst.Height == info.Height);
+ };
- destWidth = (int)(sourceWidth * aspect);
- destHeight = (int)(sourceHeight * aspect);
- }
- dst = new Bitmap(ex.Width, ex.Height);
- var srcRectangle = new Rectangle(sourceX, sourceY, sourceWidth, sourceHeight);
- var destRectangle = new Rectangle(destX, destY, destWidth, destHeight);
- using (var g = Graphics.FromImage(dst))
- {
- g.DrawImage(src, destRectangle, srcRectangle, GraphicsUnit.Pixel);
- }
- Host.Assert(dst.Width == ex.Width && dst.Height == ex.Height);
- };
+ return del;
+ }
+ }
+ }
+
+ public sealed class ImageResizerEstimator : TrivialEstimator
+ {
+ public ImageResizerEstimator(IHostEnvironment env, string inputColumn, string outputColumn,
+ int imageWidth, int imageHeight, ImageResizerTransform.ResizingKind resizing = ImageResizerTransform.ResizingKind.IsoCrop, ImageResizerTransform.Anchor cropAnchor = ImageResizerTransform.Anchor.Center)
+ : this(env, new ImageResizerTransform(env, inputColumn, outputColumn, imageWidth, imageHeight, resizing, cropAnchor))
+ {
+ }
+
+ public ImageResizerEstimator(IHostEnvironment env, params ImageResizerTransform.ColumnInfo[] columns)
+ : this(env, new ImageResizerTransform(env, columns))
+ {
+ }
+
+ public ImageResizerEstimator(IHostEnvironment env, ImageResizerTransform transformer)
+ : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(ImageResizerEstimator)), transformer)
+ {
+ }
+
+ public override SchemaShape GetOutputSchema(SchemaShape inputSchema)
+ {
+ Host.CheckValue(inputSchema, nameof(inputSchema));
+ var result = inputSchema.Columns.ToDictionary(x => x.Name);
+ foreach (var colInfo in Transformer.Columns)
+ {
+ var col = inputSchema.FindColumn(colInfo.Input);
+
+ if (col == null)
+ throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.Input);
+ if (!(col.ItemType is ImageType) || col.Kind != SchemaShape.Column.VectorKind.Scalar)
+ throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.Input, new ImageType().ToString(), col.GetTypeString());
+
+ result[colInfo.Output] = new SchemaShape.Column(colInfo.Output, SchemaShape.Column.VectorKind.Scalar, colInfo.Type, false);
+ }
- return del;
+ return new SchemaShape(result.Values);
}
}
}
diff --git a/src/Microsoft.ML.ImageAnalytics/ImageType.cs b/src/Microsoft.ML.ImageAnalytics/ImageType.cs
index 852ea09d9d..fd31302808 100644
--- a/src/Microsoft.ML.ImageAnalytics/ImageType.cs
+++ b/src/Microsoft.ML.ImageAnalytics/ImageType.cs
@@ -35,7 +35,7 @@ public override bool Equals(ColumnType other)
return false;
if (Height != tmp.Height)
return false;
- return Width != tmp.Width;
+ return Width == tmp.Width;
}
public override bool Equals(object other)
diff --git a/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationTrainer.cs b/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationTrainer.cs
index 2da2075728..52d2d3aef0 100644
--- a/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationTrainer.cs
+++ b/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationTrainer.cs
@@ -1407,8 +1407,8 @@ public LinearClassificationTrainer(IHostEnvironment env, Arguments args,
_positiveInstanceWeight = _args.PositiveInstanceWeight;
OutputColumns = new[]
{
- new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Scalar, DataKind.R4, false),
- new SchemaShape.Column(DefaultColumnNames.PredictedLabel, SchemaShape.Column.VectorKind.Scalar, DataKind.BL, false)
+ new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false),
+ new SchemaShape.Column(DefaultColumnNames.PredictedLabel, SchemaShape.Column.VectorKind.Scalar, BoolType.Instance, false)
};
}
@@ -1426,7 +1426,8 @@ protected override void CheckLabelCompatible(SchemaShape.Column labelCol)
if (labelCol.Kind != SchemaShape.Column.VectorKind.Scalar)
error();
- if (!labelCol.IsKey && labelCol.ItemKind != DataKind.R4 && labelCol.ItemKind != DataKind.R8 && labelCol.ItemKind != DataKind.BL)
+
+ if (!labelCol.IsKey && labelCol.ItemType != NumberType.R4 && labelCol.ItemType != NumberType.R8 && !labelCol.ItemType.IsBool)
error();
}
@@ -1434,17 +1435,17 @@ private static SchemaShape.Column MakeWeightColumn(string weightColumn)
{
if (weightColumn == null)
return null;
- return new SchemaShape.Column(weightColumn, SchemaShape.Column.VectorKind.Scalar, DataKind.R4, false);
+ return new SchemaShape.Column(weightColumn, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false);
}
private static SchemaShape.Column MakeLabelColumn(string labelColumn)
{
- return new SchemaShape.Column(labelColumn, SchemaShape.Column.VectorKind.Scalar, DataKind.BL, false);
+ return new SchemaShape.Column(labelColumn, SchemaShape.Column.VectorKind.Scalar, BoolType.Instance, false);
}
private static SchemaShape.Column MakeFeatureColumn(string featureColumn)
{
- return new SchemaShape.Column(featureColumn, SchemaShape.Column.VectorKind.Vector, DataKind.R4, false);
+ return new SchemaShape.Column(featureColumn, SchemaShape.Column.VectorKind.Vector, NumberType.R4, false);
}
protected override TScalarPredictor CreatePredictor(VBuffer[] weights, Float[] bias)
diff --git a/src/Microsoft.ML.StandardLearners/Standard/SdcaMultiClass.cs b/src/Microsoft.ML.StandardLearners/Standard/SdcaMultiClass.cs
index f775be92bd..c49593332b 100644
--- a/src/Microsoft.ML.StandardLearners/Standard/SdcaMultiClass.cs
+++ b/src/Microsoft.ML.StandardLearners/Standard/SdcaMultiClass.cs
@@ -57,8 +57,8 @@ public SdcaMultiClassTrainer(IHostEnvironment env, Arguments args,
_args = args;
OutputColumns = new[]
{
- new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Vector, DataKind.R4, false),
- new SchemaShape.Column(DefaultColumnNames.PredictedLabel, SchemaShape.Column.VectorKind.Scalar, DataKind.U4, true)
+ new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Vector, NumberType.R4, false),
+ new SchemaShape.Column(DefaultColumnNames.PredictedLabel, SchemaShape.Column.VectorKind.Scalar, NumberType.U4, true)
};
}
@@ -76,7 +76,7 @@ protected override void CheckLabelCompatible(SchemaShape.Column labelCol)
if (labelCol.Kind != SchemaShape.Column.VectorKind.Scalar)
error();
- if (!labelCol.IsKey && labelCol.ItemKind != DataKind.R4 && labelCol.ItemKind != DataKind.R8)
+ if (!labelCol.IsKey && labelCol.ItemType != NumberType.R4 && labelCol.ItemType != NumberType.R8)
error();
}
@@ -84,17 +84,17 @@ private static SchemaShape.Column MakeWeightColumn(string weightColumn)
{
if (weightColumn == null)
return null;
- return new SchemaShape.Column(weightColumn, SchemaShape.Column.VectorKind.Scalar, DataKind.R4, false);
+ return new SchemaShape.Column(weightColumn, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false);
}
private static SchemaShape.Column MakeLabelColumn(string labelColumn)
{
- return new SchemaShape.Column(labelColumn, SchemaShape.Column.VectorKind.Scalar, DataKind.U4, true);
+ return new SchemaShape.Column(labelColumn, SchemaShape.Column.VectorKind.Scalar, NumberType.U4, true);
}
private static SchemaShape.Column MakeFeatureColumn(string featureColumn)
{
- return new SchemaShape.Column(featureColumn, SchemaShape.Column.VectorKind.Vector, DataKind.R4, false);
+ return new SchemaShape.Column(featureColumn, SchemaShape.Column.VectorKind.Vector, NumberType.R4, false);
}
///
diff --git a/src/Microsoft.ML.StandardLearners/Standard/SdcaRegression.cs b/src/Microsoft.ML.StandardLearners/Standard/SdcaRegression.cs
index 163da10e7b..0b620959c3 100644
--- a/src/Microsoft.ML.StandardLearners/Standard/SdcaRegression.cs
+++ b/src/Microsoft.ML.StandardLearners/Standard/SdcaRegression.cs
@@ -61,7 +61,7 @@ public SdcaRegressionTrainer(IHostEnvironment env, Arguments args, string featur
_args = args;
OutputColumns = new[]
{
- new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Scalar, DataKind.R4, false)
+ new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false)
};
}
@@ -73,17 +73,17 @@ private static SchemaShape.Column MakeWeightColumn(string weightColumn)
{
if (weightColumn == null)
return null;
- return new SchemaShape.Column(weightColumn, SchemaShape.Column.VectorKind.Scalar, DataKind.R4, false);
+ return new SchemaShape.Column(weightColumn, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false);
}
private static SchemaShape.Column MakeLabelColumn(string labelColumn)
{
- return new SchemaShape.Column(labelColumn, SchemaShape.Column.VectorKind.Scalar, DataKind.R4, false);
+ return new SchemaShape.Column(labelColumn, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false);
}
private static SchemaShape.Column MakeFeatureColumn(string featureColumn)
{
- return new SchemaShape.Column(featureColumn, SchemaShape.Column.VectorKind.Vector, DataKind.R4, false);
+ return new SchemaShape.Column(featureColumn, SchemaShape.Column.VectorKind.Vector, NumberType.R4, false);
}
protected override LinearRegressionPredictor CreatePredictor(VBuffer[] weights, Float[] bias)
diff --git a/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipeBase.cs b/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipeBase.cs
index 760cf9c910..d189e627f5 100644
--- a/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipeBase.cs
+++ b/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipeBase.cs
@@ -5,6 +5,8 @@
using System;
using System.Collections.Generic;
using System.IO;
+using System.Linq;
+using Microsoft.ML.Core.Data;
using Microsoft.ML.Runtime.CommandLine;
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.Data.IO;
@@ -17,6 +19,108 @@ namespace Microsoft.ML.Runtime.RunTests
{
public abstract partial class TestDataPipeBase : TestDataViewBase
{
+ ///
+ /// 'Workout test' for an estimator.
+ /// Checks the following traits:
+ /// - the estimator is applicable to the validFitInput, and not applicable to validTransformInput and invalidInput;
+ /// - the fitted transformer is applicable to validFitInput and validTransformInput, and not applicable to invalidInput;
+ /// - fitted transformer can be saved and re-loaded into the transformer with the same behavior.
+ /// - schema propagation for fitted transformer conforms to schema propagation of estimator.
+ ///
+ protected void TestEstimatorCore(IEstimator estimator,
+ IDataView validFitInput, IDataView validTransformInput = null, IDataView invalidInput = null)
+ {
+ Contracts.AssertValue(estimator);
+ Contracts.AssertValue(validFitInput);
+ Contracts.AssertValueOrNull(validTransformInput);
+ Contracts.AssertValueOrNull(invalidInput);
+ Action mustFail = (Action action) =>
+ {
+ try
+ {
+ action();
+ Assert.False(true);
+ }
+ catch (ArgumentOutOfRangeException) { }
+ catch (InvalidOperationException) { }
+ };
+
+ // Schema propagation tests for estimator.
+ var outSchemaShape = estimator.GetOutputSchema(SchemaShape.Create(validFitInput.Schema));
+ if (validTransformInput != null)
+ {
+ mustFail(() => estimator.GetOutputSchema(SchemaShape.Create(validTransformInput.Schema)));
+ mustFail(() => estimator.Fit(validTransformInput));
+ }
+
+ if (invalidInput != null)
+ {
+ mustFail(() => estimator.GetOutputSchema(SchemaShape.Create(invalidInput.Schema)));
+ mustFail(() => estimator.Fit(invalidInput));
+ }
+
+ var transformer = estimator.Fit(validFitInput);
+ // Save and reload.
+ string modelPath = GetOutputPath(TestName + "-model.zip");
+ using (var fs = File.Create(modelPath))
+ transformer.SaveTo(Env, fs);
+
+ ITransformer loadedTransformer;
+ using (var fs = File.OpenRead(modelPath))
+ loadedTransformer = TransformerChain.LoadFrom(Env, fs);
+ DeleteOutputPath(modelPath);
+
+ // Run on train data.
+ Action checkOnData = (IDataView data) =>
+ {
+ var schema = transformer.GetOutputSchema(data.Schema);
+
+ // Loaded transformer needs to have the same schema propagation.
+ CheckSameSchemas(schema, loadedTransformer.GetOutputSchema(data.Schema));
+
+ var scoredTrain = transformer.Transform(data);
+ var scoredTrain2 = loadedTransformer.Transform(data);
+
+ // The schema of the transformed data must match the schema provided by schema propagation.
+ CheckSameSchemas(schema, scoredTrain.Schema);
+
+ // The schema and data of scored dataset must be identical between loaded
+ // and original transformer.
+ // This in turn means that the schema of loaded transformer matches for
+ // Transform and GetOutputSchema calls.
+ CheckSameSchemas(scoredTrain.Schema, scoredTrain2.Schema);
+ CheckSameValues(scoredTrain, scoredTrain2);
+ };
+
+ checkOnData(validFitInput);
+
+ if (validTransformInput != null)
+ checkOnData(validTransformInput);
+
+ if (invalidInput != null)
+ {
+ mustFail(() => transformer.GetOutputSchema(invalidInput.Schema));
+ mustFail(() => transformer.Transform(invalidInput));
+ mustFail(() => loadedTransformer.GetOutputSchema(invalidInput.Schema));
+ mustFail(() => loadedTransformer.Transform(invalidInput));
+ }
+
+ // Schema verification between estimator and transformer.
+ var scoredTrainSchemaShape = SchemaShape.Create(transformer.GetOutputSchema(validFitInput.Schema));
+ CheckSameSchemaShape(outSchemaShape, scoredTrainSchemaShape);
+ }
+
+ private void CheckSameSchemaShape(SchemaShape first, SchemaShape second)
+ {
+ Assert.True(first.Columns.Length == second.Columns.Length);
+ var sortedCols1 = first.Columns.OrderBy(x => x.Name);
+ var sortedCols2 = second.Columns.OrderBy(x => x.Name);
+
+ Assert.True(sortedCols1.Zip(sortedCols2,
+ (x, y) => x.IsCompatibleWith(y) && y.IsCompatibleWith(x))
+ .All(x => x));
+ }
+
// REVIEW: incorporate the testing for re-apply logic here?
///
/// Create PipeDataLoader from the given args, save it, re-load it, verify that the data of
@@ -878,41 +982,44 @@ protected Func GetColumnComparer(IRow r1, IRow r2, int col, ColumnType typ
{
switch (type.RawKind)
{
- case DataKind.I1:
- return GetComparerOne(r1, r2, col, (x, y) => x.RawValue == y.RawValue);
- case DataKind.U1:
- return GetComparerOne(r1, r2, col, (x, y) => x == y);
- case DataKind.I2:
- return GetComparerOne(r1, r2, col, (x, y) => x.RawValue == y.RawValue);
- case DataKind.U2:
- return GetComparerOne(r1, r2, col, (x, y) => x == y);
- case DataKind.I4:
- return GetComparerOne(r1, r2, col, (x, y) => x.RawValue == y.RawValue);
- case DataKind.U4:
- return GetComparerOne(r1, r2, col, (x, y) => x == y);
- case DataKind.I8:
- return GetComparerOne(r1, r2, col, (x, y) => x.RawValue == y.RawValue);
- case DataKind.U8:
- return GetComparerOne(r1, r2, col, (x, y) => x == y);
- case DataKind.R4:
- return GetComparerOne(r1, r2, col, (x, y) => FloatUtils.GetBits(x) == FloatUtils.GetBits(y));
- case DataKind.R8:
- if (exactDoubles)
- return GetComparerOne(r1, r2, col, (x, y) => FloatUtils.GetBits(x) == FloatUtils.GetBits(y));
- else
- return GetComparerOne(r1, r2, col, EqualWithEps);
- case DataKind.Text:
- return GetComparerOne(r1, r2, col, DvText.Identical);
- case DataKind.Bool:
- return GetComparerOne(r1, r2, col, (x, y) => x.Equals(y));
- case DataKind.TimeSpan:
- return GetComparerOne(r1, r2, col, (x, y) => x.Equals(y));
- case DataKind.DT:
- return GetComparerOne(r1, r2, col, (x, y) => x.Equals(y));
- case DataKind.DZ:
- return GetComparerOne(r1, r2, col, (x, y) => x.Equals(y));
- case DataKind.UG:
- return GetComparerOne(r1, r2, col, (x, y) => x.Equals(y));
+ case DataKind.I1:
+ return GetComparerOne(r1, r2, col, (x, y) => x.RawValue == y.RawValue);
+ case DataKind.U1:
+ return GetComparerOne(r1, r2, col, (x, y) => x == y);
+ case DataKind.I2:
+ return GetComparerOne(r1, r2, col, (x, y) => x.RawValue == y.RawValue);
+ case DataKind.U2:
+ return GetComparerOne(r1, r2, col, (x, y) => x == y);
+ case DataKind.I4:
+ return GetComparerOne(r1, r2, col, (x, y) => x.RawValue == y.RawValue);
+ case DataKind.U4:
+ return GetComparerOne(r1, r2, col, (x, y) => x == y);
+ case DataKind.I8:
+ return GetComparerOne(r1, r2, col, (x, y) => x.RawValue == y.RawValue);
+ case DataKind.U8:
+ return GetComparerOne(r1, r2, col, (x, y) => x == y);
+ case DataKind.R4:
+ return GetComparerOne(r1, r2, col, (x, y) => FloatUtils.GetBits(x) == FloatUtils.GetBits(y));
+ case DataKind.R8:
+ if (exactDoubles)
+ return GetComparerOne(r1, r2, col, (x, y) => FloatUtils.GetBits(x) == FloatUtils.GetBits(y));
+ else
+ return GetComparerOne(r1, r2, col, EqualWithEps);
+ case DataKind.Text:
+ return GetComparerOne(r1, r2, col, DvText.Identical);
+ case DataKind.Bool:
+ return GetComparerOne(r1, r2, col, (x, y) => x.Equals(y));
+ case DataKind.TimeSpan:
+ return GetComparerOne(r1, r2, col, (x, y) => x.Equals(y));
+ case DataKind.DT:
+ return GetComparerOne(r1, r2, col, (x, y) => x.Equals(y));
+ case DataKind.DZ:
+ return GetComparerOne(r1, r2, col, (x, y) => x.Equals(y));
+ case DataKind.UG:
+ return GetComparerOne(r1, r2, col, (x, y) => x.Equals(y));
+ case (DataKind)0:
+ // We cannot compare custom types (including image).
+ return () => true;
}
}
else
@@ -921,41 +1028,41 @@ protected Func GetColumnComparer(IRow r1, IRow r2, int col, ColumnType typ
Contracts.Assert(size >= 0);
switch (type.ItemType.RawKind)
{
- case DataKind.I1:
- return GetComparerVec(r1, r2, col, size, (x, y) => x.RawValue == y.RawValue);
- case DataKind.U1:
- return GetComparerVec(r1, r2, col, size, (x, y) => x == y);
- case DataKind.I2:
- return GetComparerVec(r1, r2, col, size, (x, y) => x.RawValue == y.RawValue);
- case DataKind.U2:
- return GetComparerVec(r1, r2, col, size, (x, y) => x == y);
- case DataKind.I4:
- return GetComparerVec(r1, r2, col, size, (x, y) => x.RawValue == y.RawValue);
- case DataKind.U4:
- return GetComparerVec(r1, r2, col, size, (x, y) => x == y);
- case DataKind.I8:
- return GetComparerVec(r1, r2, col, size, (x, y) => x.RawValue == y.RawValue);
- case DataKind.U8:
- return GetComparerVec(r1, r2, col, size, (x, y) => x == y);
- case DataKind.R4:
- return GetComparerVec(r1, r2, col, size, (x, y) => FloatUtils.GetBits(x) == FloatUtils.GetBits(y));
- case DataKind.R8:
- if (exactDoubles)
- return GetComparerVec(r1, r2, col, size, (x, y) => FloatUtils.GetBits(x) == FloatUtils.GetBits(y));
- else
- return GetComparerVec(r1, r2, col, size, EqualWithEps);
- case DataKind.Text:
- return GetComparerVec(r1, r2, col, size, DvText.Identical);
- case DataKind.Bool:
- return GetComparerVec(r1, r2, col, size, (x, y) => x.Equals(y));
- case DataKind.TimeSpan:
- return GetComparerVec(r1, r2, col, size, (x, y) => x.Equals(y));
- case DataKind.DT:
- return GetComparerVec(r1, r2, col, size, (x, y) => x.Equals(y));
- case DataKind.DZ:
- return GetComparerVec(r1, r2, col, size, (x, y) => x.Equals(y));
- case DataKind.UG:
- return GetComparerVec(r1, r2, col, size, (x, y) => x.Equals(y));
+ case DataKind.I1:
+ return GetComparerVec(r1, r2, col, size, (x, y) => x.RawValue == y.RawValue);
+ case DataKind.U1:
+ return GetComparerVec(r1, r2, col, size, (x, y) => x == y);
+ case DataKind.I2:
+ return GetComparerVec(r1, r2, col, size, (x, y) => x.RawValue == y.RawValue);
+ case DataKind.U2:
+ return GetComparerVec(r1, r2, col, size, (x, y) => x == y);
+ case DataKind.I4:
+ return GetComparerVec(r1, r2, col, size, (x, y) => x.RawValue == y.RawValue);
+ case DataKind.U4:
+ return GetComparerVec(r1, r2, col, size, (x, y) => x == y);
+ case DataKind.I8:
+ return GetComparerVec(r1, r2, col, size, (x, y) => x.RawValue == y.RawValue);
+ case DataKind.U8:
+ return GetComparerVec(r1, r2, col, size, (x, y) => x == y);
+ case DataKind.R4:
+ return GetComparerVec(r1, r2, col, size, (x, y) => FloatUtils.GetBits(x) == FloatUtils.GetBits(y));
+ case DataKind.R8:
+ if (exactDoubles)
+ return GetComparerVec(r1, r2, col, size, (x, y) => FloatUtils.GetBits(x) == FloatUtils.GetBits(y));
+ else
+ return GetComparerVec(r1, r2, col, size, EqualWithEps);
+ case DataKind.Text:
+ return GetComparerVec(r1, r2, col, size, DvText.Identical);
+ case DataKind.Bool:
+ return GetComparerVec(r1, r2, col, size, (x, y) => x.Equals(y));
+ case DataKind.TimeSpan:
+ return GetComparerVec(r1, r2, col, size, (x, y) => x.Equals(y));
+ case DataKind.DT:
+ return GetComparerVec(r1, r2, col, size, (x, y) => x.Equals(y));
+ case DataKind.DZ:
+ return GetComparerVec(r1, r2, col, size, (x, y) => x.Equals(y));
+ case DataKind.UG:
+ return GetComparerVec(r1, r2, col, size, (x, y) => x.Equals(y));
}
}
diff --git a/test/Microsoft.ML.Tests/ImagesTests.cs b/test/Microsoft.ML.Tests/ImagesTests.cs
index a12032400a..39087b9606 100644
--- a/test/Microsoft.ML.Tests/ImagesTests.cs
+++ b/test/Microsoft.ML.Tests/ImagesTests.cs
@@ -5,20 +5,76 @@
using Microsoft.ML.Runtime.Api;
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.ImageAnalytics;
+using Microsoft.ML.Runtime.Model;
+using Microsoft.ML.Runtime.RunTests;
using Microsoft.ML.TestFramework;
using System.Drawing;
using System.IO;
+using System.Linq;
using Xunit;
using Xunit.Abstractions;
namespace Microsoft.ML.Tests
{
- public class ImageTests : BaseTestClass
+ public class ImageTests : TestDataPipeBase
{
public ImageTests(ITestOutputHelper output) : base(output)
{
}
+ [Fact]
+ public void TestEstimatorChain()
+ {
+ using (var env = new TlcEnvironment())
+ {
+ var dataFile = GetDataPath("images/images.tsv");
+ var imageFolder = Path.GetDirectoryName(dataFile);
+ var data = env.CreateLoader("Text{col=ImagePath:TX:0 col=Name:TX:1}", new MultiFileSource(dataFile));
+ var invalidData = env.CreateLoader("Text{col=ImagePath:R4:0}", new MultiFileSource(dataFile));
+
+ var pipe = new ImageLoaderEstimator(env, imageFolder, ("ImagePath", "ImageReal"))
+ .Append(new ImageResizerEstimator(env, "ImageReal", "ImageReal", 100, 100))
+ .Append(new ImagePixelExtractorEstimator(env, "ImageReal", "ImagePixels"))
+ .Append(new ImageGrayscaleEstimator(env, ("ImageReal", "ImageGray")));
+
+ TestEstimatorCore(pipe, data, null, invalidData);
+ }
+ Done();
+ }
+
+ [Fact]
+ public void TestEstimatorSaveLoad()
+ {
+ using (var env = new TlcEnvironment())
+ {
+ var dataFile = GetDataPath("images/images.tsv");
+ var imageFolder = Path.GetDirectoryName(dataFile);
+ var data = env.CreateLoader("Text{col=ImagePath:TX:0 col=Name:TX:1}", new MultiFileSource(dataFile));
+
+ var pipe = new ImageLoaderEstimator(env, imageFolder, ("ImagePath", "ImageReal"))
+ .Append(new ImageResizerEstimator(env, "ImageReal", "ImageReal", 100, 100))
+ .Append(new ImagePixelExtractorEstimator(env, "ImageReal", "ImagePixels"))
+ .Append(new ImageGrayscaleEstimator(env, ("ImageReal", "ImageGray")));
+
+ pipe.GetOutputSchema(Core.Data.SchemaShape.Create(data.Schema));
+ var model = pipe.Fit(data);
+
+ using (var file = env.CreateTempFile())
+ {
+ using (var fs = file.CreateWriteStream())
+ model.SaveTo(env, fs);
+ var model2 = TransformerChain.LoadFrom(env, file.OpenReadStream());
+
+ var newCols = ((ImageLoaderTransform)model2.First()).Columns;
+ var oldCols = ((ImageLoaderTransform)model.First()).Columns;
+ Assert.True(newCols
+ .Zip(oldCols, (x, y) => x == y)
+ .All(x => x));
+ }
+ }
+ Done();
+ }
+
[Fact]
public void TestSaveImages()
{
@@ -27,7 +83,7 @@ public void TestSaveImages()
var dataFile = GetDataPath("images/images.tsv");
var imageFolder = Path.GetDirectoryName(dataFile);
var data = env.CreateLoader("Text{col=ImagePath:TX:0 col=Name:TX:1}", new MultiFileSource(dataFile));
- var images = new ImageLoaderTransform(env, new ImageLoaderTransform.Arguments()
+ var images = ImageLoaderTransform.Create(env, new ImageLoaderTransform.Arguments()
{
Column = new ImageLoaderTransform.Column[1]
{
@@ -36,7 +92,7 @@ public void TestSaveImages()
ImageFolder = imageFolder
}, data);
- var cropped = new ImageResizerTransform(env, new ImageResizerTransform.Arguments()
+ IDataView cropped = ImageResizerTransform.Create(env, new ImageResizerTransform.Arguments()
{
Column = new ImageResizerTransform.Column[1]{
new ImageResizerTransform.Column() { Name= "ImageCropped", Source = "ImageReal", ImageHeight =100, ImageWidth = 100, Resizing = ImageResizerTransform.ResizingKind.IsoPad}
@@ -61,6 +117,7 @@ public void TestSaveImages()
}
}
}
+ Done();
}
[Fact]
@@ -73,7 +130,7 @@ public void TestGreyscaleTransformImages()
var dataFile = GetDataPath("images/images.tsv");
var imageFolder = Path.GetDirectoryName(dataFile);
var data = env.CreateLoader("Text{col=ImagePath:TX:0 col=Name:TX:1}", new MultiFileSource(dataFile));
- var images = new ImageLoaderTransform(env, new ImageLoaderTransform.Arguments()
+ var images = ImageLoaderTransform.Create(env, new ImageLoaderTransform.Arguments()
{
Column = new ImageLoaderTransform.Column[1]
{
@@ -81,20 +138,29 @@ public void TestGreyscaleTransformImages()
},
ImageFolder = imageFolder
}, data);
- var cropped = new ImageResizerTransform(env, new ImageResizerTransform.Arguments()
+ var cropped = ImageResizerTransform.Create(env, new ImageResizerTransform.Arguments()
{
Column = new ImageResizerTransform.Column[1]{
new ImageResizerTransform.Column() { Name= "ImageCropped", Source = "ImageReal", ImageHeight =imageHeight, ImageWidth = imageWidth, Resizing = ImageResizerTransform.ResizingKind.IsoCrop}
}
}, images);
- var grey = new ImageGrayscaleTransform(env, new ImageGrayscaleTransform.Arguments()
+ IDataView grey = ImageGrayscaleTransform.Create(env, new ImageGrayscaleTransform.Arguments()
{
Column = new ImageGrayscaleTransform.Column[1]{
new ImageGrayscaleTransform.Column() { Name= "ImageGrey", Source = "ImageCropped"}
}
}, cropped);
+ var fname = nameof(TestGreyscaleTransformImages) + "_model.zip";
+
+ var fh = env.CreateOutputFile(fname);
+ using (var ch = env.Start("save"))
+ TrainUtils.SaveModel(env, ch, fh, null, new RoleMappedData(grey));
+
+ grey = ModelFileUtils.LoadPipeline(env, fh.OpenReadStream(), new MultiFileSource(dataFile));
+ DeleteOutputPath(fname);
+
grey.Schema.TryGetColumnIndex("ImageGrey", out int greyColumn);
using (var cursor = grey.GetRowCursor((x) => true))
{
@@ -114,6 +180,7 @@ public void TestGreyscaleTransformImages()
}
}
}
+ Done();
}
[Fact]
@@ -126,7 +193,7 @@ public void TestBackAndForthConversion()
var dataFile = GetDataPath("images/images.tsv");
var imageFolder = Path.GetDirectoryName(dataFile);
var data = env.CreateLoader("Text{col=ImagePath:TX:0 col=Name:TX:1}", new MultiFileSource(dataFile));
- var images = new ImageLoaderTransform(env, new ImageLoaderTransform.Arguments()
+ var images = ImageLoaderTransform.Create(env, new ImageLoaderTransform.Arguments()
{
Column = new ImageLoaderTransform.Column[1]
{
@@ -134,27 +201,37 @@ public void TestBackAndForthConversion()
},
ImageFolder = imageFolder
}, data);
- var cropped = new ImageResizerTransform(env, new ImageResizerTransform.Arguments()
+ var cropped = ImageResizerTransform.Create(env, new ImageResizerTransform.Arguments()
{
Column = new ImageResizerTransform.Column[1]{
new ImageResizerTransform.Column() { Source = "ImageReal", Name= "ImageCropped", ImageHeight =imageHeight, ImageWidth = imageWidth, Resizing = ImageResizerTransform.ResizingKind.IsoCrop}
}
}, images);
- var pixels = new ImagePixelExtractorTransform(env, new ImagePixelExtractorTransform.Arguments()
+ var pixels = ImagePixelExtractorTransform.Create(env, new ImagePixelExtractorTransform.Arguments()
{
Column = new ImagePixelExtractorTransform.Column[1]{
new ImagePixelExtractorTransform.Column() { Source= "ImageCropped", Name = "ImagePixels", UseAlpha=true}
}
}, cropped);
- var backToBitmaps = new VectorToImageTransform(env, new VectorToImageTransform.Arguments()
+ IDataView backToBitmaps = new VectorToImageTransform(env, new VectorToImageTransform.Arguments()
{
Column = new VectorToImageTransform.Column[1]{
new VectorToImageTransform.Column() { Source= "ImagePixels", Name = "ImageRestored" , ImageHeight=imageHeight, ImageWidth=imageWidth, ContainsAlpha=true}
}
}, pixels);
+ var fname = nameof(TestBackAndForthConversion) + "_model.zip";
+
+ var fh = env.CreateOutputFile(fname);
+ using (var ch = env.Start("save"))
+ TrainUtils.SaveModel(env, ch, fh, null, new RoleMappedData(backToBitmaps));
+
+ backToBitmaps = ModelFileUtils.LoadPipeline(env, fh.OpenReadStream(), new MultiFileSource(dataFile));
+ DeleteOutputPath(fname);
+
+
backToBitmaps.Schema.TryGetColumnIndex("ImageRestored", out int bitmapColumn);
backToBitmaps.Schema.TryGetColumnIndex("ImageCropped", out int cropBitmapColumn);
using (var cursor = backToBitmaps.GetRowCursor((x) => true))
@@ -178,6 +255,7 @@ public void TestBackAndForthConversion()
}
}
}
+ Done();
}
}
}
diff --git a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs
index ed450cd8c0..c04b35669c 100644
--- a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs
+++ b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs
@@ -185,7 +185,7 @@ public void TensorFlowTransformCifar()
var dataFile = GetDataPath("images/images.tsv");
var imageFolder = Path.GetDirectoryName(dataFile);
var data = env.CreateLoader("Text{col=ImagePath:TX:0 col=Name:TX:1}", new MultiFileSource(dataFile));
- var images = new ImageLoaderTransform(env, new ImageLoaderTransform.Arguments()
+ var images = ImageLoaderTransform.Create(env, new ImageLoaderTransform.Arguments()
{
Column = new ImageLoaderTransform.Column[1]
{
@@ -193,14 +193,14 @@ public void TensorFlowTransformCifar()
},
ImageFolder = imageFolder
}, data);
- var cropped = new ImageResizerTransform(env, new ImageResizerTransform.Arguments()
+ var cropped = ImageResizerTransform.Create(env, new ImageResizerTransform.Arguments()
{
Column = new ImageResizerTransform.Column[1]{
new ImageResizerTransform.Column() { Source = "ImageReal", Name= "ImageCropped", ImageHeight =imageHeight, ImageWidth = imageWidth, Resizing = ImageResizerTransform.ResizingKind.IsoCrop}
}
}, images);
- var pixels = new ImagePixelExtractorTransform(env, new ImagePixelExtractorTransform.Arguments()
+ var pixels = ImagePixelExtractorTransform.Create(env, new ImagePixelExtractorTransform.Arguments()
{
Column = new ImagePixelExtractorTransform.Column[1]{
new ImagePixelExtractorTransform.Column() { Source= "ImageCropped", Name = "Input", UseAlpha=false, InterleaveArgb=true}