-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Image transforms become Estimators #753
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
20 commits
Select commit
Hold shift + click to select a range
70d7611
Converted ImageLoaderTransform to be an Estimator/Transformer pair
1eeda48
Temp commit
5e32910
Another temp checkin
f0469e1
WIP 3
d2073eb
Added image resizers
f05ec52
Pixel extractor
a03994b
Minor fix
ef5cb62
Merge remote-tracking branch 'upstream/master' into feature/image-est…
55b33af
Fixed build
ac46be9
PR comments
013bfa4
Merge
11aa1fc
Added grayscale transform
281e731
wip one to one
baa5c34
Finished the mockup of one-to-one transformer base class.
245c344
Converted to inherit from base class
e7dc348
Workout test
d8c0648
Minor changes
3e845fd
PR comments
fe5374f
Merge remote-tracking branch 'upstream/master' into feature/image-est…
e660eeb
Fixed build
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,35 @@ | ||
| // Licensed to the .NET Foundation under one or more agreements. | ||
| // The .NET Foundation licenses this file to you under the MIT license. | ||
| // See the LICENSE file in the project root for more information. | ||
|
|
||
| using Microsoft.ML.Core.Data; | ||
|
|
||
| namespace Microsoft.ML.Runtime.Data | ||
| { | ||
| /// <summary> | ||
| /// The trivial implementation of <see cref="IEstimator{TTransformer}"/> that already has | ||
| /// the transformer and returns it on every call to <see cref="Fit(IDataView)"/>. | ||
| /// | ||
| /// Concrete implementations still have to provide the schema propagation mechanism, since | ||
| /// there is no easy way to infer it from the transformer. | ||
| /// </summary> | ||
| public abstract class TrivialEstimator<TTransformer> : IEstimator<TTransformer> | ||
| where TTransformer : class, ITransformer | ||
| { | ||
| protected readonly IHost Host; | ||
| protected readonly TTransformer Transformer; | ||
|
|
||
| protected TrivialEstimator(IHost host, TTransformer transformer) | ||
| { | ||
| Contracts.AssertValue(host); | ||
|
|
||
| Host = host; | ||
| Host.CheckValue(transformer, nameof(transformer)); | ||
| Transformer = transformer; | ||
| } | ||
|
|
||
| public TTransformer Fit(IDataView input) => Transformer; | ||
|
|
||
| public abstract SchemaShape GetOutputSchema(SchemaShape inputSchema); | ||
| } | ||
| } |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
177 changes: 177 additions & 0 deletions
177
src/Microsoft.ML.Data/Transforms/OneToOneTransformerBase.cs
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,177 @@ | ||
| // Licensed to the .NET Foundation under one or more agreements. | ||
| // The .NET Foundation licenses this file to you under the MIT license. | ||
| // See the LICENSE file in the project root for more information. | ||
|
|
||
| using System; | ||
| using System.Collections.Generic; | ||
| using System.Linq; | ||
| using Microsoft.ML.Core.Data; | ||
| using Microsoft.ML.Runtime.Model; | ||
|
|
||
| namespace Microsoft.ML.Runtime.Data | ||
| { | ||
| public abstract class OneToOneTransformerBase : ITransformer, ICanSaveModel | ||
| { | ||
| protected readonly IHost Host; | ||
| protected readonly (string input, string output)[] ColumnPairs; | ||
|
|
||
| protected OneToOneTransformerBase(IHost host, (string input, string output)[] columns) | ||
| { | ||
| Contracts.AssertValue(host); | ||
| host.CheckValue(columns, nameof(columns)); | ||
|
|
||
| var newNames = new HashSet<string>(); | ||
| foreach (var column in columns) | ||
| { | ||
| host.CheckNonEmpty(column.input, nameof(columns)); | ||
| host.CheckNonEmpty(column.output, nameof(columns)); | ||
|
|
||
| if (!newNames.Add(column.output)) | ||
| throw Contracts.ExceptParam(nameof(columns), $"Output column '{column.output}' specified multiple times"); | ||
| } | ||
|
|
||
| Host = host; | ||
| ColumnPairs = columns; | ||
| } | ||
|
|
||
| protected OneToOneTransformerBase(IHost host, ModelLoadContext ctx) | ||
| { | ||
| Host = host; | ||
| // *** Binary format *** | ||
| // int: number of added columns | ||
| // for each added column | ||
| // int: id of output column name | ||
| // int: id of input column name | ||
|
|
||
| int n = ctx.Reader.ReadInt32(); | ||
| ColumnPairs = new (string input, string output)[n]; | ||
| for (int i = 0; i < n; i++) | ||
| { | ||
| string output = ctx.LoadNonEmptyString(); | ||
| string input = ctx.LoadNonEmptyString(); | ||
| ColumnPairs[i] = (input, output); | ||
| } | ||
| } | ||
|
|
||
| public abstract void Save(ModelSaveContext ctx); | ||
|
|
||
| protected void SaveColumns(ModelSaveContext ctx) | ||
| { | ||
| Host.CheckValue(ctx, nameof(ctx)); | ||
|
|
||
| // *** Binary format *** | ||
| // int: number of added columns | ||
| // for each added column | ||
| // int: id of output column name | ||
| // int: id of input column name | ||
|
|
||
| ctx.Writer.Write(ColumnPairs.Length); | ||
| for (int i = 0; i < ColumnPairs.Length; i++) | ||
| { | ||
| ctx.SaveNonEmptyString(ColumnPairs[i].output); | ||
| ctx.SaveNonEmptyString(ColumnPairs[i].input); | ||
| } | ||
| } | ||
|
|
||
| private void CheckInput(ISchema inputSchema, int col, out int srcCol) | ||
| { | ||
| Contracts.AssertValue(inputSchema); | ||
| Contracts.Assert(0 <= col && col < ColumnPairs.Length); | ||
|
|
||
| if (!inputSchema.TryGetColumnIndex(ColumnPairs[col].input, out srcCol)) | ||
| throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", ColumnPairs[col].input); | ||
| CheckInputColumn(inputSchema, col, srcCol); | ||
| } | ||
|
|
||
| protected virtual void CheckInputColumn(ISchema inputSchema, int col, int srcCol) | ||
| { | ||
| // By default, there are no extra checks. | ||
| } | ||
|
|
||
| protected abstract IRowMapper MakeRowMapper(ISchema schema); | ||
|
|
||
| public ISchema GetOutputSchema(ISchema inputSchema) | ||
| { | ||
| Host.CheckValue(inputSchema, nameof(inputSchema)); | ||
|
|
||
| // Check that all the input columns are present and correct. | ||
| for (int i = 0; i < ColumnPairs.Length; i++) | ||
| CheckInput(inputSchema, i, out int col); | ||
|
|
||
| return Transform(new EmptyDataView(Host, inputSchema)).Schema; | ||
| } | ||
|
|
||
| public IDataView Transform(IDataView input) => MakeDataTransform(input); | ||
|
|
||
| protected RowToRowMapperTransform MakeDataTransform(IDataView input) | ||
| { | ||
| Host.CheckValue(input, nameof(input)); | ||
| return new RowToRowMapperTransform(Host, input, MakeRowMapper(input.Schema)); | ||
| } | ||
|
|
||
| protected abstract class MapperBase : IRowMapper | ||
| { | ||
| protected readonly IHost Host; | ||
| protected readonly Dictionary<int, int> ColMapNewToOld; | ||
| protected readonly ISchema InputSchema; | ||
| private readonly OneToOneTransformerBase _parent; | ||
|
|
||
| protected MapperBase(IHost host, OneToOneTransformerBase parent, ISchema inputSchema) | ||
| { | ||
| Contracts.AssertValue(host); | ||
| Contracts.AssertValue(parent); | ||
| Contracts.AssertValue(inputSchema); | ||
|
|
||
| Host = host; | ||
| _parent = parent; | ||
|
|
||
| ColMapNewToOld = new Dictionary<int, int>(); | ||
| for (int i = 0; i < _parent.ColumnPairs.Length; i++) | ||
| { | ||
| _parent.CheckInput(inputSchema, i, out int srcCol); | ||
| ColMapNewToOld.Add(i, srcCol); | ||
| } | ||
| InputSchema = inputSchema; | ||
| } | ||
| public Func<int, bool> GetDependencies(Func<int, bool> activeOutput) | ||
| { | ||
| var active = new bool[InputSchema.ColumnCount]; | ||
| foreach (var pair in ColMapNewToOld) | ||
| if (activeOutput(pair.Key)) | ||
| active[pair.Value] = true; | ||
| return col => active[col]; | ||
| } | ||
|
|
||
| public abstract RowMapperColumnInfo[] GetOutputColumns(); | ||
|
|
||
| public void Save(ModelSaveContext ctx) => _parent.Save(ctx); | ||
|
|
||
| public Delegate[] CreateGetters(IRow input, Func<int, bool> activeOutput, out Action disposer) | ||
| { | ||
| Contracts.Assert(input.Schema == InputSchema); | ||
| var result = new Delegate[_parent.ColumnPairs.Length]; | ||
| var disposers = new Action[_parent.ColumnPairs.Length]; | ||
| for (int i = 0; i < _parent.ColumnPairs.Length; i++) | ||
| { | ||
| if (!activeOutput(i)) | ||
| continue; | ||
| int srcCol = ColMapNewToOld[i]; | ||
| result[i] = MakeGetter(input, i, out disposers[i]); | ||
| } | ||
| if (disposers.Any(x => x != null)) | ||
| { | ||
| disposer = () => | ||
| { | ||
| foreach (var act in disposers) | ||
| act(); | ||
| }; | ||
| } | ||
| else | ||
| disposer = null; | ||
| return result; | ||
| } | ||
|
|
||
| protected abstract Delegate MakeGetter(IRow input, int iinfo, out Action disposer); | ||
| } | ||
| } | ||
| } | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
make this public and call it in GetOutputSchema in each estimator. #ByDesign
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Cannot do this, because GetOutputSchema in the estimator operates over
SchemaShapeIn reply to: 213863333 [](ancestors = 213863333)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if only ISchema and SchemaShape were relatives....
In reply to: 213863552 [](ancestors = 213863552,213863333)