From 73c2aa851ab24d3a8ea156f37c81a7a6cbe0437f Mon Sep 17 00:00:00 2001 From: Pete Luferenko Date: Mon, 17 Sep 2018 13:04:12 -0700 Subject: [PATCH 1/6] Transform wrappers and a reference implementation for tokenizers --- .../DataLoadSave/FakeSchema.cs | 107 +++++++++ .../DataLoadSave/TransformWrapper.cs | 152 ++++++++++++ .../Text/WrappedTextTransformers.cs | 127 ++++++++++ .../SingleDebug/Text/tokenized.tsv | 12 + .../SingleRelease/Text/tokenized.tsv | 12 + .../Scenarios/Api/Estimators/Wrappers.cs | 219 ------------------ .../Transformers/TextFeaturizerTests.cs | 35 +++ 7 files changed, 445 insertions(+), 219 deletions(-) create mode 100644 src/Microsoft.ML.Data/DataLoadSave/FakeSchema.cs create mode 100644 src/Microsoft.ML.Data/DataLoadSave/TransformWrapper.cs create mode 100644 src/Microsoft.ML.Transforms/Text/WrappedTextTransformers.cs create mode 100644 test/BaselineOutput/SingleDebug/Text/tokenized.tsv create mode 100644 test/BaselineOutput/SingleRelease/Text/tokenized.tsv diff --git a/src/Microsoft.ML.Data/DataLoadSave/FakeSchema.cs b/src/Microsoft.ML.Data/DataLoadSave/FakeSchema.cs new file mode 100644 index 0000000000..e930002c33 --- /dev/null +++ b/src/Microsoft.ML.Data/DataLoadSave/FakeSchema.cs @@ -0,0 +1,107 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.ML.Core.Data; +using Microsoft.ML.Runtime; +using Microsoft.ML.Runtime.Data; +using Microsoft.ML.Runtime.Internal.Utilities; +using System.Collections.Generic; +using System.Linq; + +namespace Microsoft.ML.Data.DataLoadSave +{ + + /// + /// A fake schema that is manufactured out of a SchemaShape. + /// It will pretend that all vector sizes are equal to 10, all key value counts are equal to 10, + /// and all values are defaults (for metadata). 
+ /// + internal sealed class FakeSchema : ISchema + { + private readonly IHostEnvironment _env; + private readonly SchemaShape _shape; + private readonly Dictionary _colMap; + + public FakeSchema(IHostEnvironment env, SchemaShape inputShape) + { + _env = env; + _shape = inputShape; + _colMap = Enumerable.Range(0, _shape.Columns.Length) + .ToDictionary(idx => _shape.Columns[idx].Name, idx => idx); + } + + public int ColumnCount => _shape.Columns.Length; + + public string GetColumnName(int col) + { + _env.Check(0 <= col && col < ColumnCount); + return _shape.Columns[col].Name; + } + + public ColumnType GetColumnType(int col) + { + _env.Check(0 <= col && col < ColumnCount); + var inputCol = _shape.Columns[col]; + return MakeColumnType(inputCol); + } + + public bool TryGetColumnIndex(string name, out int col) => _colMap.TryGetValue(name, out col); + + private static ColumnType MakeColumnType(SchemaShape.Column inputCol) + { + ColumnType curType = inputCol.ItemType; + if (inputCol.IsKey) + curType = new KeyType(curType.AsPrimitive.RawKind, 0, 10); + if (inputCol.Kind == SchemaShape.Column.VectorKind.VariableVector) + curType = new VectorType(curType.AsPrimitive, 0); + else if (inputCol.Kind == SchemaShape.Column.VectorKind.Vector) + curType = new VectorType(curType.AsPrimitive, 10); + return curType; + } + + public void GetMetadata(string kind, int col, ref TValue value) + { + _env.Check(0 <= col && col < ColumnCount); + var inputCol = _shape.Columns[col]; + var metaShape = inputCol.Metadata; + if (metaShape == null || !metaShape.TryFindColumn(kind, out var metaColumn)) + throw _env.ExceptGetMetadata(); + + var colType = MakeColumnType(metaColumn); + _env.Check(colType.RawType.Equals(typeof(TValue))); + + if (colType.IsVector) + { + // This as an atypical use of VBuffer: we create it in GetMetadataVec, and then pass through + // via boxing to be returned out of this method. This is intentional. + value = (TValue)Utils.MarshalInvoke(GetMetadataVec, colType.ItemType.RawType); + } + else + value = default; + } + + private object GetMetadataVec() => new VBuffer(10, 0, null, null); + + public ColumnType GetMetadataTypeOrNull(string kind, int col) + { + _env.Check(0 <= col && col < ColumnCount); + var inputCol = _shape.Columns[col]; + var metaShape = inputCol.Metadata; + if (metaShape == null || !metaShape.TryFindColumn(kind, out var metaColumn)) + return null; + return MakeColumnType(metaColumn); + } + + public IEnumerable> GetMetadataTypes(int col) + { + _env.Check(0 <= col && col < ColumnCount); + var inputCol = _shape.Columns[col]; + var metaShape = inputCol.Metadata; + if (metaShape == null) + return Enumerable.Empty>(); + + return metaShape.Columns.Select(c => new KeyValuePair(c.Name, MakeColumnType(c))); + } + } +} diff --git a/src/Microsoft.ML.Data/DataLoadSave/TransformWrapper.cs b/src/Microsoft.ML.Data/DataLoadSave/TransformWrapper.cs new file mode 100644 index 0000000000..64130ed80e --- /dev/null +++ b/src/Microsoft.ML.Data/DataLoadSave/TransformWrapper.cs @@ -0,0 +1,152 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. 
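// FakeSchema above exists so that an output SchemaShape can be computed for wrapped transformers
// that only expose a real schema once they have been constructed over data. A minimal sketch of
// that pattern (mirroring TrainedWrapperEstimatorBase further down; `env`, `inputShape` and
// `estimator` stand for an IHostEnvironment, a SchemaShape and a wrapper estimator available at
// the call site):
//
//     var fakeSchema = new FakeSchema(env, inputShape);      // all vector/key sizes pretend to be 10
//     var emptyData = new EmptyDataView(env, fakeSchema);    // zero rows over the fake schema
//     var transformer = estimator.Fit(emptyData);            // constructs the underlying transform
//     var outShape = SchemaShape.Create(transformer.GetOutputSchema(fakeSchema));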
+ +using Microsoft.ML.Core.Data; +using Microsoft.ML.Data.DataLoadSave; +using Microsoft.ML.Runtime; +using Microsoft.ML.Runtime.Data; +using Microsoft.ML.Runtime.Data.IO; +using Microsoft.ML.Runtime.Model; +using System.Collections.Generic; + +[assembly: LoadableClass(typeof(TransformWrapper), null, typeof(SignatureLoadModel), + "Transform wrapper", TransformWrapper.LoaderSignature)] + +namespace Microsoft.ML.Runtime.Data +{ + // REVIEW: this class is public, as long as the Wrappers.cs in tests still rely on it. + // It needs to become internal. + public sealed class TransformWrapper : ITransformer, ICanSaveModel + { + public const string LoaderSignature = "TransformWrapper"; + private const string TransformDirTemplate = "Step_{0:000}"; + + private readonly IHost _host; + private readonly IDataView _xf; + + public TransformWrapper(IHostEnvironment env, IDataView xf) + { + Contracts.CheckValue(env, nameof(env)); + _host = env.Register(nameof(TransformWrapper)); + _host.CheckValue(xf, nameof(xf)); + _xf = xf; + } + + public ISchema GetOutputSchema(ISchema inputSchema) + { + _host.CheckValue(inputSchema, nameof(inputSchema)); + + var dv = new EmptyDataView(_host, inputSchema); + var output = ApplyTransformUtils.ApplyAllTransformsToData(_host, _xf, dv); + return output.Schema; + } + + public void Save(ModelSaveContext ctx) + { + ctx.CheckAtModel(); + ctx.SetVersionInfo(GetVersionInfo()); + + var dataPipe = _xf; + var transforms = new List(); + while (dataPipe is IDataTransform xf) + { + // REVIEW: a malicious user could construct a loop in the Source chain, that would + // cause this method to iterate forever (and throw something when the list overflows). There's + // no way to insulate from ALL malicious behavior. + transforms.Add(xf); + dataPipe = xf.Source; + Contracts.AssertValue(dataPipe); + } + transforms.Reverse(); + + ctx.SaveSubModel("Loader", c => BinaryLoader.SaveInstance(_host, c, dataPipe.Schema)); + + ctx.Writer.Write(transforms.Count); + for (int i = 0; i < transforms.Count; i++) + { + var dirName = string.Format(TransformDirTemplate, i); + ctx.SaveModel(transforms[i], dirName); + } + } + + private static VersionInfo GetVersionInfo() + { + return new VersionInfo( + modelSignature: "XF WRPR", + verWrittenCur: 0x00010001, // Initial + verReadableCur: 0x00010001, + verWeCanReadBack: 0x00010001, + loaderSignature: LoaderSignature); + } + + // Factory for SignatureLoadModel. + public TransformWrapper(IHostEnvironment env, ModelLoadContext ctx) + { + Contracts.CheckValue(env, nameof(env)); + _host = env.Register(nameof(TransformWrapper)); + _host.CheckValue(ctx, nameof(ctx)); + + ctx.CheckAtModel(GetVersionInfo()); + int n = ctx.Reader.ReadInt32(); + _host.CheckDecode(n >= 0); + + ctx.LoadModel(env, out var loader, "Loader", new MultiFileSource(null)); + + IDataView data = loader; + for (int i = 0; i < n; i++) + { + var dirName = string.Format(TransformDirTemplate, i); + ctx.LoadModel(env, out var xf, dirName, data); + data = xf; + } + + _xf = data; + } + + public IDataView Transform(IDataView input) => ApplyTransformUtils.ApplyAllTransformsToData(_host, _xf, input); + } + + /// + /// Estimator for trained wrapped transformers. 
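// A minimal sketch of how a TransformWrapper is produced and consumed (mirroring the
// MakeTransformer helpers in WrappedTextTransformers.cs below; `env`, `schema`, `args` and
// `realData` are assumed to be set up by the caller):
//
//     var emptyData = new EmptyDataView(env, schema);                  // zero-row data with the input schema
//     var xf = new DelimitedTokenizeTransform(env, args, emptyData);   // the underlying IDataTransform
//     ITransformer wrapper = new TransformWrapper(env, xf);            // wrap the transform chain
//     IDataView tokenized = wrapper.Transform(realData);               // re-apply the chain to real data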
+ /// + internal abstract class TrainedWrapperEstimatorBase : IEstimator + { + private readonly IHost _host; + + protected TrainedWrapperEstimatorBase(IHost host) + { + Contracts.CheckValue(host, nameof(host)); + _host = host; + } + + public abstract TransformWrapper Fit(IDataView input); + + public SchemaShape GetOutputSchema(SchemaShape inputSchema) + { + _host.CheckValue(inputSchema, nameof(inputSchema)); + + var fakeSchema = new FakeSchema(_host, inputSchema); + var transformer = Fit(new EmptyDataView(_host, fakeSchema)); + return SchemaShape.Create(transformer.GetOutputSchema(fakeSchema)); + } + } + + /// + /// Estimator for untrained wrapped transformers. + /// + public abstract class TrivialWrapperEstimator : TrivialEstimator + { + protected TrivialWrapperEstimator(IHost host, TransformWrapper transformer) + : base(host, transformer) + { + } + + public override SchemaShape GetOutputSchema(SchemaShape inputSchema) + { + Host.CheckValue(inputSchema, nameof(inputSchema)); + var fakeSchema = new FakeSchema(Host, inputSchema); + return SchemaShape.Create(Transformer.GetOutputSchema(fakeSchema)); + } + } +} diff --git a/src/Microsoft.ML.Transforms/Text/WrappedTextTransformers.cs b/src/Microsoft.ML.Transforms/Text/WrappedTextTransformers.cs new file mode 100644 index 0000000000..96b9797439 --- /dev/null +++ b/src/Microsoft.ML.Transforms/Text/WrappedTextTransformers.cs @@ -0,0 +1,127 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.ML.Runtime; +using Microsoft.ML.Runtime.Data; +using Microsoft.ML.Runtime.TextAnalytics; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; + +namespace Microsoft.ML.Transforms +{ + /// + /// Word tokenizer splits text into tokens using the delimiter. + /// For each text input, the output column is a variable vector of text. + /// + public sealed class WordTokenizer : TrivialWrapperEstimator + { + /// + /// Tokenize incoming text in and output the tokens as . + /// + /// The environment. + /// The column containing text to tokenize. + /// The column containing output tokens. Null means is replaced. + /// Any advanced settings to be applied. + public WordTokenizer(IHostEnvironment env, string inputColumn, string outputColumn = null, + Action advancedSettings = null) + : this(env, new[] { (inputColumn, outputColumn ?? inputColumn) }, advancedSettings) + { + } + + /// + /// Tokenize incoming text in input columns and output the tokens as output columns. + /// + /// The environment. + /// Pairs of columns to run the tokenization on. + /// Any advanced settings to be applied. + public WordTokenizer(IHostEnvironment env, (string input, string output)[] columns, + Action advancedSettings = null) + : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(WordTokenizer)), MakeTransformer(env, columns, advancedSettings)) + { + } + + private static TransformWrapper MakeTransformer(IHostEnvironment env, (string input, string output)[] columns, Action advancedSettings) + { + Contracts.AssertValue(env); + env.CheckNonEmpty(columns, nameof(columns)); + env.CheckValueOrNull(advancedSettings); + foreach (var (input, output) in columns) + { + env.CheckValue(input, nameof(input)); + env.CheckValue(output, nameof(input)); + } + + // Create arguments. 
+ var args = new DelimitedTokenizeTransform.Arguments + { + Column = columns.Select(x => new DelimitedTokenizeTransform.Column { Source = x.input, Name = x.output }).ToArray() + }; + advancedSettings?.Invoke(args); + + // Create a valid instance of data. + var schema = new SimpleSchema(env, columns.Select(x => new KeyValuePair(x.input, TextType.Instance)).ToArray()); + var emptyData = new EmptyDataView(env, schema); + + return new TransformWrapper(env, new DelimitedTokenizeTransform(env, args, emptyData)); + } + } + + /// + /// Character tokenizer splits text into sequences of characters using a sliding window. + /// + public sealed class CharacterTokenizer: TrivialWrapperEstimator + { + /// + /// Tokenize incoming text in and output the tokens as . + /// + /// The environment. + /// The column containing text to tokenize. + /// The column containing output tokens. Null means is replaced. + /// Any advanced settings to be applied. + public CharacterTokenizer(IHostEnvironment env, string inputColumn, string outputColumn = null, + Action advancedSettings = null) + : this(env, new[] { (inputColumn, outputColumn ?? inputColumn) }, advancedSettings) + { + } + + /// + /// Tokenize incoming text in input columns and output the tokens as output columns. + /// + /// The environment. + /// Pairs of columns to run the tokenization on. + /// Any advanced settings to be applied. + public CharacterTokenizer(IHostEnvironment env, (string input, string output)[] columns, + Action advancedSettings = null) + : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(WordTokenizer)), MakeTransformer(env, columns, advancedSettings)) + { + } + + private static TransformWrapper MakeTransformer(IHostEnvironment env, (string input, string output)[] columns, Action advancedSettings) + { + Contracts.AssertValue(env); + env.CheckNonEmpty(columns, nameof(columns)); + env.CheckValueOrNull(advancedSettings); + foreach (var (input, output) in columns) + { + env.CheckValue(input, nameof(input)); + env.CheckValue(output, nameof(input)); + } + + // Create arguments. + var args = new CharTokenizeTransform.Arguments + { + Column = columns.Select(x => new CharTokenizeTransform.Column { Source = x.input, Name = x.output }).ToArray() + }; + advancedSettings?.Invoke(args); + + // Create a valid instance of data. + var schema = new SimpleSchema(env, columns.Select(x => new KeyValuePair(x.input, TextType.Instance)).ToArray()); + var emptyData = new EmptyDataView(env, schema); + + return new TransformWrapper(env, new CharTokenizeTransform(env, args, emptyData)); + } + } +} diff --git a/test/BaselineOutput/SingleDebug/Text/tokenized.tsv b/test/BaselineOutput/SingleDebug/Text/tokenized.tsv new file mode 100644 index 0000000000..eccf8ca0e6 --- /dev/null +++ b/test/BaselineOutput/SingleDebug/Text/tokenized.tsv @@ -0,0 +1,12 @@ +#@ TextLoader{ +#@ header+ +#@ sep=tab +#@ col=text:TX:0 +#@ col=words:TX:1-** +#@ col={name=chars type=TX src={ min=-1 var=+}} +#@ } +text +==RUDE== Dude, you are rude upload that carl picture back, or else. ==RUDE== Dude, you are rude upload that carl picture back, or else. <␂> = = R U D E = = <␠> D u d e , <␠> y o u <␠> a r e <␠> r u d e <␠> u p l o a d <␠> t h a t <␠> c a r l <␠> p i c t u r e <␠> b a c k , <␠> o r <␠> e l s e . <␃> +== OK! == IM GOING TO VANDALIZE WILD ONES WIKI THEN!!! == OK! == IM GOING TO VANDALIZE WILD ONES WIKI THEN!!! <␂> = = <␠> O K ! <␠> = = <␠> <␠> I M <␠> G O I N G <␠> T O <␠> V A N D A L I Z E <␠> W I L D <␠> O N E S <␠> W I K I <␠> T H E N ! ! ! 
<␠> <␠> <␠> <␃> +Stop trolling, zapatancas, calling me a liar merely demonstartes that you arer Zapatancas. You may choose to chase every legitimate editor from this site and ignore me but I am an editor with a record that isnt 99% trolling and therefore my wishes are not to be completely ignored by a sockpuppet like yourself. The consensus is overwhelmingly against you and your trollin g lover Zapatancas, Stop trolling, zapatancas, calling me a liar merely demonstartes that you arer Zapatancas. You may choose to chase every legitimate editor from this site and ignore me but I am an editor with a record that isnt 99% trolling and therefore my wishes are not to be completely ignored by a sockpuppet like yourself. The consensus is overwhelmingly against you and your trollin g lover Zapatancas, <␂> S t o p <␠> t r o l l i n g , <␠> z a p a t a n c a s , <␠> c a l l i n g <␠> m e <␠> a <␠> l i a r <␠> m e r e l y <␠> d e m o n s t a r t e s <␠> t h a t <␠> y o u <␠> a r e r <␠> Z a p a t a n c a s . <␠> Y o u <␠> m a y <␠> c h o o s e <␠> t o <␠> c h a s e <␠> e v e r y <␠> l e g i t i m a t e <␠> e d i t o r <␠> f r o m <␠> t h i s <␠> s i t e <␠> a n d <␠> i g n o r e <␠> m e <␠> b u t <␠> I <␠> a m <␠> a n <␠> e d i t o r <␠> w i t h <␠> a <␠> r e c o r d <␠> t h a t <␠> i s n t <␠> 9 9 % <␠> t r o l l i n g <␠> a n d <␠> t h e r e f o r e <␠> m y <␠> w i s h e s <␠> a r e <␠> n o t <␠> t o <␠> b e <␠> c o m p l e t e l y <␠> i g n o r e d <␠> b y <␠> a <␠> s o c k p u p p e t <␠> l i k e <␠> y o u r s e l f . <␠> T h e <␠> c o n s e n s u s <␠> i s <␠> o v e r w h e l m i n g l y <␠> a g a i n s t <␠> y o u <␠> a n d <␠> y o u r <␠> t r o l l i n <␠> g <␠> l o v e r <␠> Z a p a t a n c a s , <␠> <␠> <␃> +==You're cool== You seem like a really cool guy... *bursts out laughing at sarcasm*. ==You're cool== You seem like a really cool guy... *bursts out laughing at sarcasm*. <␂> = = Y o u ' r e <␠> c o o l = = <␠> <␠> Y o u <␠> s e e m <␠> l i k e <␠> a <␠> r e a l l y <␠> c o o l <␠> g u y . . . <␠> * b u r s t s <␠> o u t <␠> l a u g h i n g <␠> a t <␠> s a r c a s m * . <␃> diff --git a/test/BaselineOutput/SingleRelease/Text/tokenized.tsv b/test/BaselineOutput/SingleRelease/Text/tokenized.tsv new file mode 100644 index 0000000000..eccf8ca0e6 --- /dev/null +++ b/test/BaselineOutput/SingleRelease/Text/tokenized.tsv @@ -0,0 +1,12 @@ +#@ TextLoader{ +#@ header+ +#@ sep=tab +#@ col=text:TX:0 +#@ col=words:TX:1-** +#@ col={name=chars type=TX src={ min=-1 var=+}} +#@ } +text +==RUDE== Dude, you are rude upload that carl picture back, or else. ==RUDE== Dude, you are rude upload that carl picture back, or else. <␂> = = R U D E = = <␠> D u d e , <␠> y o u <␠> a r e <␠> r u d e <␠> u p l o a d <␠> t h a t <␠> c a r l <␠> p i c t u r e <␠> b a c k , <␠> o r <␠> e l s e . <␃> +== OK! == IM GOING TO VANDALIZE WILD ONES WIKI THEN!!! == OK! == IM GOING TO VANDALIZE WILD ONES WIKI THEN!!! <␂> = = <␠> O K ! <␠> = = <␠> <␠> I M <␠> G O I N G <␠> T O <␠> V A N D A L I Z E <␠> W I L D <␠> O N E S <␠> W I K I <␠> T H E N ! ! ! <␠> <␠> <␠> <␃> +Stop trolling, zapatancas, calling me a liar merely demonstartes that you arer Zapatancas. You may choose to chase every legitimate editor from this site and ignore me but I am an editor with a record that isnt 99% trolling and therefore my wishes are not to be completely ignored by a sockpuppet like yourself. 
The consensus is overwhelmingly against you and your trollin g lover Zapatancas, Stop trolling, zapatancas, calling me a liar merely demonstartes that you arer Zapatancas. You may choose to chase every legitimate editor from this site and ignore me but I am an editor with a record that isnt 99% trolling and therefore my wishes are not to be completely ignored by a sockpuppet like yourself. The consensus is overwhelmingly against you and your trollin g lover Zapatancas, <␂> S t o p <␠> t r o l l i n g , <␠> z a p a t a n c a s , <␠> c a l l i n g <␠> m e <␠> a <␠> l i a r <␠> m e r e l y <␠> d e m o n s t a r t e s <␠> t h a t <␠> y o u <␠> a r e r <␠> Z a p a t a n c a s . <␠> Y o u <␠> m a y <␠> c h o o s e <␠> t o <␠> c h a s e <␠> e v e r y <␠> l e g i t i m a t e <␠> e d i t o r <␠> f r o m <␠> t h i s <␠> s i t e <␠> a n d <␠> i g n o r e <␠> m e <␠> b u t <␠> I <␠> a m <␠> a n <␠> e d i t o r <␠> w i t h <␠> a <␠> r e c o r d <␠> t h a t <␠> i s n t <␠> 9 9 % <␠> t r o l l i n g <␠> a n d <␠> t h e r e f o r e <␠> m y <␠> w i s h e s <␠> a r e <␠> n o t <␠> t o <␠> b e <␠> c o m p l e t e l y <␠> i g n o r e d <␠> b y <␠> a <␠> s o c k p u p p e t <␠> l i k e <␠> y o u r s e l f . <␠> T h e <␠> c o n s e n s u s <␠> i s <␠> o v e r w h e l m i n g l y <␠> a g a i n s t <␠> y o u <␠> a n d <␠> y o u r <␠> t r o l l i n <␠> g <␠> l o v e r <␠> Z a p a t a n c a s , <␠> <␠> <␃> +==You're cool== You seem like a really cool guy... *bursts out laughing at sarcasm*. ==You're cool== You seem like a really cool guy... *bursts out laughing at sarcasm*. <␂> = = Y o u ' r e <␠> c o o l = = <␠> <␠> Y o u <␠> s e e m <␠> l i k e <␠> a <␠> r e a l l y <␠> c o o l <␠> g u y . . . <␠> * b u r s t s <␠> o u t <␠> l a u g h i n g <␠> a t <␠> s a r c a s m * . <␃> diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Wrappers.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Wrappers.cs index 86799ce445..c5a6a40703 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Wrappers.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Wrappers.cs @@ -6,29 +6,16 @@ using Microsoft.ML.Legacy.Models; using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.Api; -using Microsoft.ML.Runtime.CommandLine; using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.Data.IO; -using Microsoft.ML.Runtime.Internal.Internallearn; -using Microsoft.ML.Runtime.Learners; using Microsoft.ML.Runtime.Model; -using Microsoft.ML.Runtime.Training; -using Microsoft.ML.Tests.Scenarios.Api; using System; using System.Collections.Generic; using System.IO; using System.Linq; -[assembly: LoadableClass(typeof(TransformWrapper), null, typeof(SignatureLoadModel), - "Transform wrapper", TransformWrapper.LoaderSignature)] -[assembly: LoadableClass(typeof(LoaderWrapper), null, typeof(SignatureLoadModel), - "Loader wrapper", LoaderWrapper.LoaderSignature)] - namespace Microsoft.ML.Tests.Scenarios.Api { - using TScalarPredictor = IPredictorProducing; - using TWeightsPredictor = IPredictorWithFeatureWeights; - public sealed class LoaderWrapper : IDataReader, ICanSaveModel { public const string LoaderSignature = "LoaderWrapper"; @@ -93,212 +80,6 @@ public LoaderWrapper(IHostEnvironment env, ModelLoadContext ctx) } } - public class TransformWrapper : ITransformer, ICanSaveModel - { - public const string LoaderSignature = "TransformWrapper"; - private const string TransformDirTemplate = "Step_{0:000}"; - - protected readonly IHostEnvironment _env; - protected readonly IDataView _xf; - - public TransformWrapper(IHostEnvironment 
env, IDataView xf) - { - _env = env; - _xf = xf; - } - - public ISchema GetOutputSchema(ISchema inputSchema) - { - var dv = new EmptyDataView(_env, inputSchema); - var output = ApplyTransformUtils.ApplyAllTransformsToData(_env, _xf, dv); - return output.Schema; - } - - public void Save(ModelSaveContext ctx) - { - ctx.CheckAtModel(); - ctx.SetVersionInfo(GetVersionInfo()); - - var dataPipe = _xf; - var transforms = new List(); - while (dataPipe is IDataTransform xf) - { - // REVIEW: a malicious user could construct a loop in the Source chain, that would - // cause this method to iterate forever (and throw something when the list overflows). There's - // no way to insulate from ALL malicious behavior. - transforms.Add(xf); - dataPipe = xf.Source; - Contracts.AssertValue(dataPipe); - } - transforms.Reverse(); - - ctx.SaveSubModel("Loader", c => BinaryLoader.SaveInstance(_env, c, dataPipe.Schema)); - - ctx.Writer.Write(transforms.Count); - for (int i = 0; i < transforms.Count; i++) - { - var dirName = string.Format(TransformDirTemplate, i); - ctx.SaveModel(transforms[i], dirName); - } - } - - private static VersionInfo GetVersionInfo() - { - return new VersionInfo( - modelSignature: "XF WRPR", - verWrittenCur: 0x00010001, // Initial - verReadableCur: 0x00010001, - verWeCanReadBack: 0x00010001, - loaderSignature: LoaderSignature); - } - - public TransformWrapper(IHostEnvironment env, ModelLoadContext ctx) - { - ctx.CheckAtModel(GetVersionInfo()); - int n = ctx.Reader.ReadInt32(); - - ctx.LoadModel(env, out var loader, "Loader", new MultiFileSource(null)); - - IDataView data = loader; - for (int i = 0; i < n; i++) - { - var dirName = string.Format(TransformDirTemplate, i); - ctx.LoadModel(env, out var xf, dirName, data); - data = xf; - } - - _env = env; - _xf = data; - } - - public IDataView Transform(IDataView input) => ApplyTransformUtils.ApplyAllTransformsToData(_env, _xf, input); - } - - public class ScorerWrapper : TransformWrapper, IPredictionTransformer - where TModel : IPredictor - { - protected readonly string _featureColumn; - - public ScorerWrapper(IHostEnvironment env, IDataView scorer, TModel trainedModel, string featureColumn) - : base(env, scorer) - { - _featureColumn = featureColumn; - Model = trainedModel; - } - - public TModel Model { get; } - - public string FeatureColumn => _featureColumn; - - public ColumnType FeatureColumnType => throw _env.ExceptNotSupp(); - } - - public class BinaryScorerWrapper : ScorerWrapper - where TModel : IPredictor - { - public BinaryScorerWrapper(IHostEnvironment env, TModel model, ISchema inputSchema, string featureColumn, BinaryClassifierScorer.Arguments args) - : base(env, MakeScorer(env, inputSchema, featureColumn, model, args), model, featureColumn) - { - } - - private static IDataView MakeScorer(IHostEnvironment env, ISchema schema, string featureColumn, TModel model, BinaryClassifierScorer.Arguments args) - { - var settings = $"Binary{{{CmdParser.GetSettings(env, args, new BinaryClassifierScorer.Arguments())}}}"; - - var scorerFactorySettings = CmdParser.CreateComponentFactory( - typeof(IComponentFactory), - typeof(SignatureDataScorer), - settings); - - var bindable = ScoreUtils.GetSchemaBindableMapper(env, model, scorerFactorySettings: scorerFactorySettings); - var edv = new EmptyDataView(env, schema); - var data = new RoleMappedData(edv, "Label", featureColumn, opt: true); - - return new BinaryClassifierScorer(env, args, data.Data, bindable.Bind(env, data.Schema), data.Schema); - } - - public BinaryScorerWrapper 
Clone(BinaryClassifierScorer.Arguments scorerArgs) - { - var scorer = _xf as IDataScorerTransform; - return new BinaryScorerWrapper(_env, Model, scorer.Source.Schema, _featureColumn, scorerArgs); - } - } - - public abstract class TrainerBase : ITrainerEstimator - where TTransformer : ScorerWrapper - where TModel : IPredictor - { - protected readonly IHostEnvironment _env; - protected readonly string _featureCol; - protected readonly string _labelCol; - - public abstract PredictionKind PredictionKind { get; } - - public TrainerInfo Info { get; } - - protected TrainerBase(IHostEnvironment env, TrainerInfo trainerInfo, string featureColumn, string labelColumn) - { - _env = env; - _featureCol = featureColumn; - _labelCol = labelColumn; - Info = trainerInfo; - } - - public TTransformer Fit(IDataView input) - { - return TrainTransformer(input); - } - - protected TTransformer TrainTransformer(IDataView trainSet, - IDataView validationSet = null, IPredictor initPredictor = null) - { - var cachedTrain = Info.WantCaching ? new CacheDataView(_env, trainSet, prefetch: null) : trainSet; - - var trainRoles = new RoleMappedData(cachedTrain, label: _labelCol, feature: _featureCol); - var emptyData = new EmptyDataView(_env, trainSet.Schema); - IDataView normalizer = emptyData; - - if (Info.NeedNormalization && trainRoles.Schema.FeaturesAreNormalized() == false) - { - var view = NormalizeTransform.CreateMinMaxNormalizer(_env, trainRoles.Data, name: trainRoles.Schema.Feature.Name); - normalizer = ApplyTransformUtils.ApplyAllTransformsToData(_env, view, emptyData, cachedTrain); - - trainRoles = new RoleMappedData(view, trainRoles.Schema.GetColumnRoleNames()); - } - - RoleMappedData validRoles; - - if (validationSet == null) - validRoles = null; - else - { - var cachedValid = Info.WantCaching ? 
new CacheDataView(_env, validationSet, prefetch: null) : validationSet; - cachedValid = ApplyTransformUtils.ApplyAllTransformsToData(_env, normalizer, cachedValid); - validRoles = new RoleMappedData(cachedValid, label: _labelCol, feature: _featureCol); - } - - var pred = TrainCore(new TrainContext(trainRoles, validRoles, initPredictor)); - - var scoreRoles = new RoleMappedData(normalizer, label: _labelCol, feature: _featureCol); - return MakeScorer(pred, scoreRoles); - } - - public SchemaShape GetOutputSchema(SchemaShape inputSchema) - { - throw new NotImplementedException(); - } - - protected abstract TModel TrainCore(TrainContext trainContext); - - protected abstract TTransformer MakeScorer(TModel predictor, RoleMappedData data); - - protected ScorerWrapper MakeScorerBasic(TModel predictor, RoleMappedData data) - { - var scorer = ScoreUtils.GetScorer(predictor, data, _env, data.Schema); - return (TTransformer)(new ScorerWrapper(_env, scorer, predictor, data.Schema.Feature.Name)); - } - } - public sealed class MyBinaryClassifierEvaluator { private readonly IHostEnvironment _env; diff --git a/test/Microsoft.ML.Tests/Transformers/TextFeaturizerTests.cs b/test/Microsoft.ML.Tests/Transformers/TextFeaturizerTests.cs index 74e0fe39ec..9946996bb3 100644 --- a/test/Microsoft.ML.Tests/Transformers/TextFeaturizerTests.cs +++ b/test/Microsoft.ML.Tests/Transformers/TextFeaturizerTests.cs @@ -6,6 +6,7 @@ using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.Data.IO; using Microsoft.ML.Runtime.RunTests; +using Microsoft.ML.Transforms; using System.IO; using Xunit; using Xunit.Abstractions; @@ -53,5 +54,39 @@ public void TextFeaturizerWorkout() CheckEquality("Text", "featurized.tsv"); Done(); } + + [Fact] + public void TextTokenizationWorkout() + { + string sentimentDataPath = GetDataPath("wikipedia-detox-250-line-data.tsv"); + var data = TextLoader.CreateReader(Env, ctx => ( + label: ctx.LoadBool(0), + text: ctx.LoadText(1)), hasHeader: true) + .Read(new MultiFileSource(sentimentDataPath)); + + var invalidData = TextLoader.CreateReader(Env, ctx => ( + label: ctx.LoadBool(0), + text: ctx.LoadFloat(1)), hasHeader: true) + .Read(new MultiFileSource(sentimentDataPath)); + + var est = new WordTokenizer(Env, "text", "words") + .Append(new CharacterTokenizer(Env, "text", "chars")) + .Append(new KeyToValueEstimator(Env, "chars")); + TestEstimatorCore(est, data.AsDynamic, invalidInput: invalidData.AsDynamic); + + var outputPath = GetOutputPath("Text", "tokenized.tsv"); + using (var ch = Env.Start("save")) + { + var saver = new TextSaver(Env, new TextSaver.Arguments { Silent = true }); + IDataView savedData = TakeFilter.Create(Env, est.Fit(data.AsDynamic).Transform(data.AsDynamic), 4); + savedData = new ChooseColumnsTransform(Env, savedData, "text", "words", "chars"); + + using (var fs = File.Create(outputPath)) + DataSaverUtils.SaveDataView(ch, saver, savedData, fs, keepHidden: true); + } + + CheckEquality("Text", "tokenized.tsv"); + Done(); + } } } From d145d079615544f032088565e610e0759b3a5e5d Mon Sep 17 00:00:00 2001 From: Pete Luferenko Date: Mon, 17 Sep 2018 14:15:48 -0700 Subject: [PATCH 2/6] Added pigsty extensions --- .../Text/TextStaticExtensions.cs | 116 ++++++++++++++++++ .../Text/WrappedTextTransformers.cs | 45 +++---- 2 files changed, 136 insertions(+), 25 deletions(-) create mode 100644 src/Microsoft.ML.Transforms/Text/TextStaticExtensions.cs diff --git a/src/Microsoft.ML.Transforms/Text/TextStaticExtensions.cs b/src/Microsoft.ML.Transforms/Text/TextStaticExtensions.cs new file mode 100644 
index 0000000000..75d4d097b4 --- /dev/null +++ b/src/Microsoft.ML.Transforms/Text/TextStaticExtensions.cs @@ -0,0 +1,116 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.ML.Core.Data; +using Microsoft.ML.Data.StaticPipe.Runtime; +using Microsoft.ML.Runtime; +using Microsoft.ML.Runtime.Data; +using System; +using System.Collections.Generic; + +namespace Microsoft.ML.Transforms.Text +{ + /// + /// Extensions for statically typed word tokenizer. + /// + public static class WordTokenizerExtensions + { + private sealed class OutPipelineColumn : VarVector + { + public readonly Scalar Input; + + public OutPipelineColumn(Scalar input, string separators) + : base(new Reconciler(separators), input) + { + Input = input; + } + } + + private sealed class Reconciler : EstimatorReconciler + { + private readonly string _separators; + + public Reconciler(string separators) + { + _separators = separators; + } + + public override IEstimator Reconcile(IHostEnvironment env, + PipelineColumn[] toOutput, + IReadOnlyDictionary inputNames, + IReadOnlyDictionary outputNames, + IReadOnlyCollection usedNames) + { + Contracts.Assert(toOutput.Length == 1); + + var pairs = new List<(string input, string output)>(); + foreach (var outCol in toOutput) + pairs.Add((inputNames[((OutPipelineColumn)outCol).Input], outputNames[outCol])); + + return new WordTokenizer(env, pairs.ToArray(), _separators); + } + } + + /// + /// Tokenize incoming text using and output the tokens. + /// + /// The column to apply to. + /// The separators to use (comma separated). + public static VarVector TokenizeText(this Scalar input, string separators = "space") => new OutPipelineColumn(input, separators); + } + + /// + /// Extensions for statically typed character tokenizer. + /// + public static class CharacterTokenizerExtensions + { + private sealed class OutPipelineColumn : VarVector + { + public readonly Scalar Input; + + public OutPipelineColumn(Scalar input, bool useMarkerChars) + : base(new Reconciler(useMarkerChars), input) + { + Input = input; + } + } + + private sealed class Reconciler : EstimatorReconciler, IEquatable + { + private readonly bool _useMarker; + + public Reconciler(bool useMarkerChars) + { + _useMarker = useMarkerChars; + } + + public bool Equals(Reconciler other) + { + return _useMarker == other._useMarker; + } + + public override IEstimator Reconcile(IHostEnvironment env, + PipelineColumn[] toOutput, + IReadOnlyDictionary inputNames, + IReadOnlyDictionary outputNames, + IReadOnlyCollection usedNames) + { + Contracts.Assert(toOutput.Length == 1); + + var pairs = new List<(string input, string output)>(); + foreach (var outCol in toOutput) + pairs.Add((inputNames[((OutPipelineColumn)outCol).Input], outputNames[outCol])); + + return new CharacterTokenizer(env, pairs.ToArray(), _useMarker); + } + } + + /// + /// Tokenize incoming text into a sequence of characters. + /// + /// The column to apply to. + /// Whether to use marker characters to separate words. 
+ public static VarVector TokenizeIntoCharacters(this Scalar input, bool useMarkerCharacters = true) => new OutPipelineColumn(input, useMarkerCharacters); + } +} diff --git a/src/Microsoft.ML.Transforms/Text/WrappedTextTransformers.cs b/src/Microsoft.ML.Transforms/Text/WrappedTextTransformers.cs index 96b9797439..654a9eb51f 100644 --- a/src/Microsoft.ML.Transforms/Text/WrappedTextTransformers.cs +++ b/src/Microsoft.ML.Transforms/Text/WrappedTextTransformers.cs @@ -24,10 +24,9 @@ public sealed class WordTokenizer : TrivialWrapperEstimator /// The environment. /// The column containing text to tokenize. /// The column containing output tokens. Null means is replaced. - /// Any advanced settings to be applied. - public WordTokenizer(IHostEnvironment env, string inputColumn, string outputColumn = null, - Action advancedSettings = null) - : this(env, new[] { (inputColumn, outputColumn ?? inputColumn) }, advancedSettings) + /// The separators to use (comma separated). + public WordTokenizer(IHostEnvironment env, string inputColumn, string outputColumn = null, string separators = "space") + : this(env, new[] { (inputColumn, outputColumn ?? inputColumn) }, separators) { } @@ -36,18 +35,16 @@ public WordTokenizer(IHostEnvironment env, string inputColumn, string outputColu /// /// The environment. /// Pairs of columns to run the tokenization on. - /// Any advanced settings to be applied. - public WordTokenizer(IHostEnvironment env, (string input, string output)[] columns, - Action advancedSettings = null) - : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(WordTokenizer)), MakeTransformer(env, columns, advancedSettings)) + /// The separators to use (comma separated). + public WordTokenizer(IHostEnvironment env, (string input, string output)[] columns, string separators = "space") + : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(WordTokenizer)), MakeTransformer(env, columns, separators)) { } - private static TransformWrapper MakeTransformer(IHostEnvironment env, (string input, string output)[] columns, Action advancedSettings) + private static TransformWrapper MakeTransformer(IHostEnvironment env, (string input, string output)[] columns, string separators) { Contracts.AssertValue(env); env.CheckNonEmpty(columns, nameof(columns)); - env.CheckValueOrNull(advancedSettings); foreach (var (input, output) in columns) { env.CheckValue(input, nameof(input)); @@ -55,11 +52,12 @@ private static TransformWrapper MakeTransformer(IHostEnvironment env, (string in } // Create arguments. + // REVIEW: enable multiple separators via something other than parsing strings. var args = new DelimitedTokenizeTransform.Arguments { - Column = columns.Select(x => new DelimitedTokenizeTransform.Column { Source = x.input, Name = x.output }).ToArray() + Column = columns.Select(x => new DelimitedTokenizeTransform.Column { Source = x.input, Name = x.output }).ToArray(), + TermSeparators = separators }; - advancedSettings?.Invoke(args); // Create a valid instance of data. var schema = new SimpleSchema(env, columns.Select(x => new KeyValuePair(x.input, TextType.Instance)).ToArray()); @@ -72,7 +70,7 @@ private static TransformWrapper MakeTransformer(IHostEnvironment env, (string in /// /// Character tokenizer splits text into sequences of characters using a sliding window. /// - public sealed class CharacterTokenizer: TrivialWrapperEstimator + public sealed class CharacterTokenizer : TrivialWrapperEstimator { /// /// Tokenize incoming text in and output the tokens as . 
@@ -80,10 +78,9 @@ public sealed class CharacterTokenizer: TrivialWrapperEstimator /// The environment. /// The column containing text to tokenize. /// The column containing output tokens. Null means is replaced. - /// Any advanced settings to be applied. - public CharacterTokenizer(IHostEnvironment env, string inputColumn, string outputColumn = null, - Action advancedSettings = null) - : this(env, new[] { (inputColumn, outputColumn ?? inputColumn) }, advancedSettings) + /// Whether to use marker characters to separate words. + public CharacterTokenizer(IHostEnvironment env, string inputColumn, string outputColumn = null, bool useMarkerCharacters = true) + : this (env, new[] { (inputColumn, outputColumn ?? inputColumn) }, useMarkerCharacters) { } @@ -92,18 +89,16 @@ public CharacterTokenizer(IHostEnvironment env, string inputColumn, string outpu /// /// The environment. /// Pairs of columns to run the tokenization on. - /// Any advanced settings to be applied. - public CharacterTokenizer(IHostEnvironment env, (string input, string output)[] columns, - Action advancedSettings = null) - : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(WordTokenizer)), MakeTransformer(env, columns, advancedSettings)) + /// Whether to use marker characters to separate words. + public CharacterTokenizer(IHostEnvironment env, (string input, string output)[] columns, bool useMarkerCharacters = true) + : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(WordTokenizer)), MakeTransformer(env, columns, useMarkerCharacters)) { } - private static TransformWrapper MakeTransformer(IHostEnvironment env, (string input, string output)[] columns, Action advancedSettings) + private static TransformWrapper MakeTransformer(IHostEnvironment env, (string input, string output)[] columns, bool useMarkerChars) { Contracts.AssertValue(env); env.CheckNonEmpty(columns, nameof(columns)); - env.CheckValueOrNull(advancedSettings); foreach (var (input, output) in columns) { env.CheckValue(input, nameof(input)); @@ -113,9 +108,9 @@ private static TransformWrapper MakeTransformer(IHostEnvironment env, (string in // Create arguments. var args = new CharTokenizeTransform.Arguments { - Column = columns.Select(x => new CharTokenizeTransform.Column { Source = x.input, Name = x.output }).ToArray() + Column = columns.Select(x => new CharTokenizeTransform.Column { Source = x.input, Name = x.output }).ToArray(), + UseMarkerChars = useMarkerChars }; - advancedSettings?.Invoke(args); // Create a valid instance of data. 
var schema = new SimpleSchema(env, columns.Select(x => new KeyValuePair(x.input, TextType.Instance)).ToArray()); From e8ef7ef562b6d2a6b53efd6c4d224966cf3f8f25 Mon Sep 17 00:00:00 2001 From: Pete Luferenko Date: Mon, 17 Sep 2018 14:26:52 -0700 Subject: [PATCH 3/6] Added pigsty test --- .../Text/TextStaticExtensions.cs | 4 +-- .../StaticPipeTests.cs | 31 +++++++++++++++++++ 2 files changed, 33 insertions(+), 2 deletions(-) diff --git a/src/Microsoft.ML.Transforms/Text/TextStaticExtensions.cs b/src/Microsoft.ML.Transforms/Text/TextStaticExtensions.cs index 75d4d097b4..4e0821241b 100644 --- a/src/Microsoft.ML.Transforms/Text/TextStaticExtensions.cs +++ b/src/Microsoft.ML.Transforms/Text/TextStaticExtensions.cs @@ -65,7 +65,7 @@ public override IEstimator Reconcile(IHostEnvironment env, /// public static class CharacterTokenizerExtensions { - private sealed class OutPipelineColumn : VarVector + private sealed class OutPipelineColumn : VarVector> { public readonly Scalar Input; @@ -111,6 +111,6 @@ public override IEstimator Reconcile(IHostEnvironment env, /// /// The column to apply to. /// Whether to use marker characters to separate words. - public static VarVector TokenizeIntoCharacters(this Scalar input, bool useMarkerCharacters = true) => new OutPipelineColumn(input, useMarkerCharacters); + public static VarVector> TokenizeIntoCharacters(this Scalar input, bool useMarkerCharacters = true) => new OutPipelineColumn(input, useMarkerCharacters); } } diff --git a/test/Microsoft.ML.StaticPipelineTesting/StaticPipeTests.cs b/test/Microsoft.ML.StaticPipelineTesting/StaticPipeTests.cs index 0095d2f4cd..3efb70c029 100644 --- a/test/Microsoft.ML.StaticPipelineTesting/StaticPipeTests.cs +++ b/test/Microsoft.ML.StaticPipelineTesting/StaticPipeTests.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
+using Microsoft.ML.Transforms.Text; using Microsoft.ML.Data.StaticPipe; using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.Data; @@ -403,5 +404,35 @@ public void ConcatWith() Assert.Equal(NumberType.Float, types[2].ItemType); Assert.Equal(NumberType.Float, types[3].ItemType); } + + [Fact] + public void Tokenize() + { + var env = new ConsoleEnvironment(seed: 0); + var dataPath = GetDataPath("wikipedia-detox-250-line-data.tsv"); + var reader = TextLoader.CreateReader(env, ctx => ( + label: ctx.LoadBool(0), + text: ctx.LoadText(1)), hasHeader: true); + var dataSource = new MultiFileSource(dataPath); + var data = reader.Read(dataSource); + + var est = data.MakeNewEstimator() + .Append(r => ( + r.label, + tokens: r.text.TokenizeText(), + chars: r.text.TokenizeIntoCharacters())); + + var tdata = est.Fit(data).Transform(data); + var schema = tdata.AsDynamic.Schema; + + Assert.True(schema.TryGetColumnIndex("tokens", out int tokensCol)); + var type = schema.GetColumnType(tokensCol); + Assert.True(type.IsVector && !type.IsKnownSizeVector && type.ItemType.IsText); + + Assert.True(schema.TryGetColumnIndex("chars", out int charsCol)); + type = schema.GetColumnType(charsCol); + Assert.True(type.IsVector && !type.IsKnownSizeVector && type.ItemType.IsKey); + Assert.True(type.ItemType.AsKey.RawKind == DataKind.U2); + } } } From 13505875c4380e29a63d17e0cdfebad5f27a59a8 Mon Sep 17 00:00:00 2001 From: Pete Luferenko Date: Mon, 17 Sep 2018 20:24:45 -0700 Subject: [PATCH 4/6] Fixed most important PR comments --- test/Microsoft.ML.StaticPipelineTesting/StaticPipeTests.cs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/test/Microsoft.ML.StaticPipelineTesting/StaticPipeTests.cs b/test/Microsoft.ML.StaticPipelineTesting/StaticPipeTests.cs index 3efb70c029..4b8194afe1 100644 --- a/test/Microsoft.ML.StaticPipelineTesting/StaticPipeTests.cs +++ b/test/Microsoft.ML.StaticPipelineTesting/StaticPipeTests.cs @@ -2,13 +2,12 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
-using Microsoft.ML.Transforms.Text; using Microsoft.ML.Data.StaticPipe; -using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.Data.IO; using Microsoft.ML.Runtime.Internal.Utilities; using Microsoft.ML.TestFramework; +using Microsoft.ML.Transforms.Text; using System; using System.Collections.Generic; using System.Collections.Immutable; From 450efdad48ddb43b9471df73271e84480af6b29b Mon Sep 17 00:00:00 2001 From: Pete Luferenko Date: Tue, 18 Sep 2018 10:12:52 -0700 Subject: [PATCH 5/6] PR comments --- src/Microsoft.ML.Data/DataLoadSave/FakeSchema.cs | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/Microsoft.ML.Data/DataLoadSave/FakeSchema.cs b/src/Microsoft.ML.Data/DataLoadSave/FakeSchema.cs index e930002c33..d94219a453 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/FakeSchema.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/FakeSchema.cs @@ -19,6 +19,9 @@ namespace Microsoft.ML.Data.DataLoadSave /// internal sealed class FakeSchema : ISchema { + private const int AllVectorSizes = 10; + private const int AllKeySizes = 10; + private readonly IHostEnvironment _env; private readonly SchemaShape _shape; private readonly Dictionary _colMap; @@ -52,11 +55,11 @@ private static ColumnType MakeColumnType(SchemaShape.Column inputCol) { ColumnType curType = inputCol.ItemType; if (inputCol.IsKey) - curType = new KeyType(curType.AsPrimitive.RawKind, 0, 10); + curType = new KeyType(curType.AsPrimitive.RawKind, 0, AllKeySizes); if (inputCol.Kind == SchemaShape.Column.VectorKind.VariableVector) curType = new VectorType(curType.AsPrimitive, 0); else if (inputCol.Kind == SchemaShape.Column.VectorKind.Vector) - curType = new VectorType(curType.AsPrimitive, 10); + curType = new VectorType(curType.AsPrimitive, AllVectorSizes); return curType; } @@ -81,7 +84,7 @@ public void GetMetadata(string kind, int col, ref TValue value) value = default; } - private object GetMetadataVec() => new VBuffer(10, 0, null, null); + private object GetMetadataVec() => new VBuffer(AllVectorSizes, 0, null, null); public ColumnType GetMetadataTypeOrNull(string kind, int col) { From 4acbd6b53b2d3e7c47d8c400c1f7db1b4b8c3ff0 Mon Sep 17 00:00:00 2001 From: Pete Luferenko Date: Tue, 18 Sep 2018 14:24:37 -0700 Subject: [PATCH 6/6] PR comments --- src/Microsoft.ML.Transforms/Text/WrappedTextTransformers.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Microsoft.ML.Transforms/Text/WrappedTextTransformers.cs b/src/Microsoft.ML.Transforms/Text/WrappedTextTransformers.cs index 654a9eb51f..8d4330212d 100644 --- a/src/Microsoft.ML.Transforms/Text/WrappedTextTransformers.cs +++ b/src/Microsoft.ML.Transforms/Text/WrappedTextTransformers.cs @@ -91,7 +91,7 @@ public CharacterTokenizer(IHostEnvironment env, string inputColumn, string outpu /// Pairs of columns to run the tokenization on. /// Whether to use marker characters to separate words. public CharacterTokenizer(IHostEnvironment env, (string input, string output)[] columns, bool useMarkerCharacters = true) - : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(WordTokenizer)), MakeTransformer(env, columns, useMarkerCharacters)) + : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(CharacterTokenizer)), MakeTransformer(env, columns, useMarkerCharacters)) { }
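
Taken together, the new components can be exercised either through the dynamic API or through the statically typed pipeline. A condensed sketch of both, following the TextTokenizationWorkout and Tokenize tests added in this patch series (Env, data and the label/text loader columns are as set up in those tests):

    // Dynamic API: wrapped tokenizer estimators chained with Append.
    var est = new WordTokenizer(Env, "text", "words")
        .Append(new CharacterTokenizer(Env, "text", "chars"))
        .Append(new KeyToValueEstimator(Env, "chars"));
    var tokenized = est.Fit(data.AsDynamic).Transform(data.AsDynamic);

    // Static API: the TokenizeText / TokenizeIntoCharacters extensions.
    var staticEst = data.MakeNewEstimator()
        .Append(r => (
            r.label,
            tokens: r.text.TokenizeText(),
            chars: r.text.TokenizeIntoCharacters()));
    var tdata = staticEst.Fit(data).Transform(data);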