From c1bb9f4ebfb2e1b360c59bd7913b8756e767783d Mon Sep 17 00:00:00 2001 From: Eric Erhardt Date: Mon, 6 Aug 2018 16:46:36 -0500 Subject: [PATCH 1/6] Move Scorers and Calibrators to use IComponentFactory. Also, PartitionedFileLoader is now SubComponent-free. --- .../CommandLine/CmdParser.cs | 57 ++++++-- .../EntryPoints/ComponentFactory.cs | 28 ++++ .../Commands/CrossValidationCommand.cs | 20 +-- .../Commands/ScoreCommand.cs | 124 ++++++++++++++---- src/Microsoft.ML.Data/Commands/TestCommand.cs | 5 +- .../Commands/TrainCommand.cs | 9 +- .../Commands/TrainTestCommand.cs | 9 +- .../DataLoadSave/PartitionedFileLoader.cs | 8 +- .../EntryPoints/ScoreModel.cs | 12 +- .../Prediction/Calibrator.cs | 6 +- .../Transforms/TrainAndScoreTransform.cs | 12 +- src/Microsoft.ML.Ensemble/PipelineEnsemble.cs | 2 +- .../TreeEnsembleFeaturizer.cs | 8 +- .../MultiClass/MetaMulticlassTrainer.cs | 2 +- 14 files changed, 221 insertions(+), 81 deletions(-) diff --git a/src/Microsoft.ML.Core/CommandLine/CmdParser.cs b/src/Microsoft.ML.Core/CommandLine/CmdParser.cs index 191321213f..f91e9c90ba 100644 --- a/src/Microsoft.ML.Core/CommandLine/CmdParser.cs +++ b/src/Microsoft.ML.Core/CommandLine/CmdParser.cs @@ -1856,17 +1856,13 @@ public static IComponentFactory CreateComponentFactory( factoryType.IsGenericType); Type componentFactoryType; - if (factoryType.GenericTypeArguments.Length == 1) + switch (factoryType.GenericTypeArguments.Length) { - componentFactoryType = typeof(ComponentFactory<>); - } - else if (factoryType.GenericTypeArguments.Length == 2) - { - componentFactoryType = typeof(ComponentFactory<,>); - } - else - { - throw Contracts.ExceptNotImpl("ComponentFactoryFactory can only create components with 1 or 2 type args."); + case 1: componentFactoryType = typeof(ComponentFactory<>); break; + case 2: componentFactoryType = typeof(ComponentFactory<,>); break; + case 3: componentFactoryType = typeof(ComponentFactory<,,>); break; + case 4: componentFactoryType = typeof(ComponentFactory<,,,>); break; + default: throw Contracts.ExceptNotImpl("ComponentFactoryFactory can only create component factories with 4 or less type args."); } return (IComponentFactory)Activator.CreateInstance( @@ -1950,6 +1946,47 @@ public TComponent CreateComponent(IHostEnvironment env, TArg1 argument1) argument1); } } + + private class ComponentFactory : ComponentFactory, IComponentFactory + where TComponent : class + { + public ComponentFactory(Type signatureType, string name, string[] settings) + : base(signatureType, name, settings) + { + } + + public TComponent CreateComponent(IHostEnvironment env, TArg1 argument1, TArg2 argument2) + { + return ComponentCatalog.CreateInstance( + env, + SignatureType, + Name, + GetSettingsString(), + argument1, + argument2); + } + } + + private class ComponentFactory : ComponentFactory, IComponentFactory + where TComponent : class + { + public ComponentFactory(Type signatureType, string name, string[] settings) + : base(signatureType, name, settings) + { + } + + public TComponent CreateComponent(IHostEnvironment env, TArg1 argument1, TArg2 argument2, TArg3 argument3) + { + return ComponentCatalog.CreateInstance( + env, + SignatureType, + Name, + GetSettingsString(), + argument1, + argument2, + argument3); + } + } } private bool ReportMissingRequiredArgument(CmdParser owner, ArgValue val) diff --git a/src/Microsoft.ML.Core/EntryPoints/ComponentFactory.cs b/src/Microsoft.ML.Core/EntryPoints/ComponentFactory.cs index d69a9d0b93..2c8aee9fb4 100644 --- a/src/Microsoft.ML.Core/EntryPoints/ComponentFactory.cs +++ b/src/Microsoft.ML.Core/EntryPoints/ComponentFactory.cs @@ -64,4 +64,32 @@ public interface IComponentFactory : ICompon { TComponent CreateComponent(IHostEnvironment env, TArg1 argument1, TArg2 argument2); } + + /// + /// An interface for creating a component when we take three extra parameters (and an ). + /// + public interface IComponentFactory : IComponentFactory + { + TComponent CreateComponent(IHostEnvironment env, TArg1 argument1, TArg2 argument2, TArg3 argument3); + } + + /// + /// A class for creating a component when we take three extra parameters + /// (and an ) that simply wraps a delegate which + /// creates the component. + /// + public class SimpleComponentFactory : IComponentFactory + { + private Func _factory; + + public SimpleComponentFactory(Func factory) + { + _factory = factory; + } + + public TComponent CreateComponent(IHostEnvironment env, TArg1 argument1, TArg2 argument2, TArg3 argument3) + { + return _factory(env, argument1, argument2, argument3); + } + } } diff --git a/src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs b/src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs index affd949064..39f2f78173 100644 --- a/src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs +++ b/src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs @@ -28,8 +28,8 @@ public sealed class Arguments : DataCommand.ArgumentsBase [Argument(ArgumentType.Multiple, HelpText = "Trainer to use", ShortName = "tr")] public SubComponent Trainer = new SubComponent("AveragedPerceptron"); - [Argument(ArgumentType.Multiple, HelpText = "Scorer to use", NullName = "", SortOrder = 101)] - public SubComponent Scorer; + [Argument(ArgumentType.Multiple, HelpText = "Scorer to use", NullName = "", SortOrder = 101, SignatureType = typeof(SignatureDataScorer))] + public IComponentFactory Scorer; [Argument(ArgumentType.Multiple, HelpText = "Evaluator to use", ShortName = "eval", NullName = "", SortOrder = 102)] public SubComponent Evaluator; @@ -76,8 +76,8 @@ public sealed class Arguments : DataCommand.ArgumentsBase [Argument(ArgumentType.AtMostOnce, IsInputFileName = true, HelpText = "The validation data file", ShortName = "valid")] public string ValidationFile; - [Argument(ArgumentType.Multiple, HelpText = "Output calibrator", ShortName = "cali", NullName = "")] - public SubComponent Calibrator = new SubComponent("PlattCalibration"); + [Argument(ArgumentType.Multiple, HelpText = "Output calibrator", ShortName = "cali", NullName = "", SignatureType = typeof(SignatureCalibrator))] + public IComponentFactory Calibrator = new PlattCalibratorTrainerFactory(); [Argument(ArgumentType.LastOccurenceWins, HelpText = "Number of instances to train the calibrator", ShortName = "numcali")] public int MaxCalibrationExamples = 1000000000; @@ -383,9 +383,9 @@ public FoldResult(Dictionary metrics, ISchema scoreSchema, Ro private readonly string _splitColumn; private readonly int _numFolds; private readonly SubComponent _trainer; - private readonly SubComponent _scorer; + private readonly IComponentFactory _scorer; private readonly SubComponent _evaluator; - private readonly SubComponent _calibrator; + private readonly IComponentFactory _calibrator; private readonly int _maxCalibrationExamples; private readonly bool _useThreads; private readonly bool? _cacheData; @@ -423,7 +423,7 @@ public FoldHelper( Arguments args, Func createExamples, Func applyTransformsToTestData, - SubComponent scorer, + IComponentFactory scorer, SubComponent evaluator, Func getValidationDataView = null, Func applyTransformsToValidationData = null, @@ -559,11 +559,11 @@ private FoldResult RunFold(int fold) // Score. ch.Trace("Scoring and evaluating"); - var bindable = ScoreUtils.GetSchemaBindableMapper(host, predictor, _scorer); + var bindable = ScoreUtils.GetSchemaBindableMapper(host, predictor, _scorer as ICommandLineComponentFactory); ch.AssertValue(bindable); var mapper = bindable.Bind(host, testData.Schema); - var scorerComp = _scorer.IsGood() ? _scorer : ScoreUtils.GetScorerComponent(mapper); - IDataScorerTransform scorePipe = scorerComp.CreateInstance(host, testData.Data, mapper, trainData.Schema); + var scorerComp = _scorer ?? ScoreUtils.GetScorerComponent(mapper); + IDataScorerTransform scorePipe = scorerComp.CreateComponent(host, testData.Data, mapper, trainData.Schema); // Save per-fold model. string modelFileName = ConstructPerFoldName(_outputModelFile, fold); diff --git a/src/Microsoft.ML.Data/Commands/ScoreCommand.cs b/src/Microsoft.ML.Data/Commands/ScoreCommand.cs index f69c35231d..76fd587fe1 100644 --- a/src/Microsoft.ML.Data/Commands/ScoreCommand.cs +++ b/src/Microsoft.ML.Data/Commands/ScoreCommand.cs @@ -51,8 +51,8 @@ public sealed class Arguments : DataCommand.ArgumentsBase ShortName = "col", SortOrder = 10)] public KeyValuePair[] CustomColumn; - [Argument(ArgumentType.Multiple, HelpText = "Scorer to use")] - public SubComponent Scorer; + [Argument(ArgumentType.Multiple, HelpText = "Scorer to use", SignatureType = typeof(SignatureDataScorer))] + public IComponentFactory Scorer; [Argument(ArgumentType.Multiple, HelpText = "The data saver to use")] public SubComponent Saver; @@ -105,7 +105,7 @@ private void RunCore(IChannel ch) ch.Trace("Creating pipeline"); var scorer = Args.Scorer; - var bindable = ScoreUtils.GetSchemaBindableMapper(Host, predictor, scorer); + var bindable = ScoreUtils.GetSchemaBindableMapper(Host, predictor, scorer as ICommandLineComponentFactory); ch.AssertValue(bindable); // REVIEW: We probably ought to prefer role mappings from the training schema. @@ -117,11 +117,11 @@ private void RunCore(IChannel ch) var schema = new RoleMappedSchema(loader.Schema, label: null, feature: feat, group: group, custom: customCols, opt: true); var mapper = bindable.Bind(Host, schema); - if (!scorer.IsGood()) + if (scorer == null) scorer = ScoreUtils.GetScorerComponent(mapper); loader = CompositeDataLoader.ApplyTransform(Host, loader, "Scorer", scorer.ToString(), - (env, view) => scorer.CreateInstance(env, view, mapper, trainSchema)); + (env, view) => scorer.CreateComponent(env, view, mapper, trainSchema)); loader = CompositeDataLoader.Create(Host, loader, Args.PostTransform); @@ -226,12 +226,18 @@ public static class ScoreUtils public static IDataScorerTransform GetScorer(IPredictor predictor, RoleMappedData data, IHostEnvironment env, RoleMappedSchema trainSchema) { var sc = GetScorerComponentAndMapper(predictor, null, data.Schema, env, out var mapper); - return sc.CreateInstance(env, data.Data, mapper, trainSchema); + return sc.CreateComponent(env, data.Data, mapper, trainSchema); } - public static IDataScorerTransform GetScorer(SubComponent scorer, - IPredictor predictor, IDataView input, string featureColName, string groupColName, - IEnumerable> customColumns, IHostEnvironment env, RoleMappedSchema trainSchema) + public static IDataScorerTransform GetScorer( + IComponentFactory scorer, + IPredictor predictor, + IDataView input, + string featureColName, + string groupColName, + IEnumerable> customColumns, + IHostEnvironment env, + RoleMappedSchema trainSchema) { Contracts.CheckValue(env, nameof(env)); env.CheckValueOrNull(scorer); @@ -244,23 +250,23 @@ public static IDataScorerTransform GetScorer(SubComponent - /// Determines the scorer subcomponent (if the given one is null or empty), and creates the schema bound mapper. + /// Determines the scorer component factory (if the given one is null or empty), and creates the schema bound mapper. /// - private static SubComponent GetScorerComponentAndMapper( - IPredictor predictor, SubComponent scorer, + private static IComponentFactory GetScorerComponentAndMapper( + IPredictor predictor, IComponentFactory scorerFactory, RoleMappedSchema schema, IHostEnvironment env, out ISchemaBoundMapper mapper) { Contracts.AssertValue(env); - var bindable = GetSchemaBindableMapper(env, predictor, scorer); + var bindable = GetSchemaBindableMapper(env, predictor, scorerFactory as ICommandLineComponentFactory); env.AssertValue(bindable); mapper = bindable.Bind(env, schema); - if (scorer.IsGood()) - return scorer; + if (scorerFactory != null) + return scorerFactory; return GetScorerComponent(mapper); } @@ -269,24 +275,50 @@ private static SubComponent GetScorer /// metadata on the first column of the mapper. If that text is found and maps to a scorer loadable class, /// that component is used. Otherwise, the GenericScorer is used. /// - public static SubComponent GetScorerComponent(ISchemaBoundMapper mapper) + /// The schema bound mapper to get the default scorer.. + /// An optional suffix to append to the default column names. + public static IComponentFactory GetScorerComponent( + ISchemaBoundMapper mapper, + string suffix = null) { Contracts.AssertValue(mapper); - string loadName = null; + ComponentCatalog.LoadableClassInfo info = null; DvText scoreKind = default; if (mapper.OutputSchema.ColumnCount > 0 && mapper.OutputSchema.TryGetMetadata(TextType.Instance, MetadataUtils.Kinds.ScoreColumnKind, 0, ref scoreKind) && scoreKind.HasChars) { - loadName = scoreKind.ToString(); - var info = ComponentCatalog.GetLoadableClassInfo(loadName); + var loadName = scoreKind.ToString(); + info = ComponentCatalog.GetLoadableClassInfo(loadName); if (info == null || !typeof(IDataScorerTransform).IsAssignableFrom(info.Type)) - loadName = null; + info = null; } - if (loadName == null) - loadName = GenericScorer.LoadName; - return new SubComponent(loadName); + return new SimpleComponentFactory( + (env, data, innerMapper, trainSchema) => + { + if (info == null) + { + return new GenericScorer( + env, + new GenericScorer.Arguments() { Suffix = suffix }, + data, + innerMapper, + trainSchema); + } + else + { + object args = info.CreateArguments(); + if (args is ScorerArgumentsBase scorerArgs) + { + scorerArgs.Suffix = suffix; + } + return (IDataScorerTransform)info.CreateInstance( + env, + args, + new object[] { data, innerMapper, trainSchema }); + } + }); } /// @@ -321,6 +353,38 @@ public static ISchemaBindableMapper GetSchemaBindableMapper(IHostEnvironment env return new SchemaBindablePredictorWrapper(predictor); } + /// + /// Given a predictor and an optional scorer factory settings, produces a compatible ISchemaBindableMapper. + /// First, it tries to instantiate the bindable mapper using the + /// (this will only succeed if there's a registered BindableMapper creation method with load name equal to the one + /// of the scorer). + /// If the above fails, it checks whether the predictor implements + /// directly. + /// If this also isn't true, it will create a 'matching' standard mapper. + /// + public static ISchemaBindableMapper GetSchemaBindableMapper(IHostEnvironment env, IPredictor predictor, + ICommandLineComponentFactory scorerFactorySettings) + { + Contracts.CheckValue(env, nameof(env)); + env.CheckValue(predictor, nameof(predictor)); + env.CheckValueOrNull(scorerFactorySettings); + + // See if we can instantiate a mapper using scorer arguments. + if (scorerFactorySettings != null && TryCreateBindableFromScorer(env, predictor, scorerFactorySettings, out var bindable)) + return bindable; + + // The easy case is that the predictor implements the interface. + bindable = predictor as ISchemaBindableMapper; + if (bindable != null) + return bindable; + + // Use one of the standard wrappers. + if (predictor is IValueMapperDist) + return new SchemaBindableBinaryPredictorWrapper(predictor); + + return new SchemaBindablePredictorWrapper(predictor); + } + private static bool TryCreateBindableFromScorer(IHostEnvironment env, IPredictor predictor, SubComponent scorerSettings, out ISchemaBindableMapper bindable) { @@ -332,5 +396,17 @@ private static bool TryCreateBindableFromScorer(IHostEnvironment env, IPredictor var mapperComponent = new SubComponent(scorerSettings.Kind, scorerSettings.Settings); return ComponentCatalog.TryCreateInstance(env, out bindable, mapperComponent, predictor); } + + private static bool TryCreateBindableFromScorer(IHostEnvironment env, IPredictor predictor, + ICommandLineComponentFactory scorerSettings, out ISchemaBindableMapper bindable) + { + Contracts.AssertValue(env); + env.AssertValue(predictor); + env.AssertValue(scorerSettings); + + // Try to find a mapper factory method with the same loadname as the scorer settings. + return ComponentCatalog.TryCreateInstance( + env, out bindable, scorerSettings.Name, scorerSettings.GetSettingsString(), predictor); + } } } \ No newline at end of file diff --git a/src/Microsoft.ML.Data/Commands/TestCommand.cs b/src/Microsoft.ML.Data/Commands/TestCommand.cs index d0ebbd5a05..a487204c1b 100644 --- a/src/Microsoft.ML.Data/Commands/TestCommand.cs +++ b/src/Microsoft.ML.Data/Commands/TestCommand.cs @@ -7,6 +7,7 @@ using Microsoft.ML.Runtime.Command; using Microsoft.ML.Runtime.CommandLine; using Microsoft.ML.Runtime.Data; +using Microsoft.ML.Runtime.EntryPoints; using Microsoft.ML.Runtime.Internal.Utilities; [assembly: LoadableClass(TestCommand.Summary, typeof(TestCommand), typeof(TestCommand.Arguments), typeof(SignatureCommand), @@ -37,8 +38,8 @@ public sealed class Arguments : DataCommand.ArgumentsBase [Argument(ArgumentType.LastOccurenceWins, HelpText = "Columns with custom kinds declared through key assignments, e.g., col[Kind]=Name to assign column named 'Name' kind 'Kind'", ShortName = "col", SortOrder = 10)] public KeyValuePair[] CustomColumn; - [Argument(ArgumentType.Multiple, HelpText = "Scorer to use", NullName = "", SortOrder = 101)] - public SubComponent Scorer; + [Argument(ArgumentType.Multiple, HelpText = "Scorer to use", NullName = "", SortOrder = 101, SignatureType = typeof(SignatureDataScorer))] + public IComponentFactory Scorer; [Argument(ArgumentType.Multiple, HelpText = "Evaluator to use", ShortName = "eval", NullName = "", SortOrder = 102)] public SubComponent Evaluator; diff --git a/src/Microsoft.ML.Data/Commands/TrainCommand.cs b/src/Microsoft.ML.Data/Commands/TrainCommand.cs index 69370ad3ef..0cb5bfb4e9 100644 --- a/src/Microsoft.ML.Data/Commands/TrainCommand.cs +++ b/src/Microsoft.ML.Data/Commands/TrainCommand.cs @@ -13,6 +13,7 @@ using Microsoft.ML.Runtime.CommandLine; using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.Data.IO; +using Microsoft.ML.Runtime.EntryPoints; using Microsoft.ML.Runtime.Internal.Calibration; using Microsoft.ML.Runtime.Internal.Utilities; using Microsoft.ML.Runtime.Model; @@ -68,8 +69,8 @@ public sealed class Arguments : DataCommand.ArgumentsBase [Argument(ArgumentType.LastOccurenceWins, HelpText = "Whether we should cache input training data", ShortName = "cache")] public bool? CacheData; - [Argument(ArgumentType.Multiple, HelpText = "Output calibrator", ShortName = "cali", NullName = "")] - public SubComponent Calibrator = new SubComponent("PlattCalibration"); + [Argument(ArgumentType.Multiple, HelpText = "Output calibrator", ShortName = "cali", NullName = "", SignatureType = typeof(SignatureCalibrator))] + public IComponentFactory Calibrator = new PlattCalibratorTrainerFactory(); [Argument(ArgumentType.LastOccurenceWins, HelpText = "Number of instances to train the calibrator", ShortName = "numcali")] public int MaxCalibrationExamples = 1000000000; @@ -235,9 +236,9 @@ public static IPredictor Train(IHostEnvironment env, IChannel ch, RoleMappedData } public static IPredictor Train(IHostEnvironment env, IChannel ch, RoleMappedData data, ITrainer trainer, string name, RoleMappedData validData, - SubComponent calibrator, int maxCalibrationExamples, bool? cacheData, IPredictor inputPredictor = null) + IComponentFactory calibrator, int maxCalibrationExamples, bool? cacheData, IPredictor inputPredictor = null) { - ICalibratorTrainer caliTrainer = !calibrator.IsGood() ? null : calibrator.CreateInstance(env); + ICalibratorTrainer caliTrainer = calibrator?.CreateComponent(env); return TrainCore(env, ch, data, trainer, name, validData, caliTrainer, maxCalibrationExamples, cacheData, inputPredictor); } diff --git a/src/Microsoft.ML.Data/Commands/TrainTestCommand.cs b/src/Microsoft.ML.Data/Commands/TrainTestCommand.cs index 03ee7cdf12..aea57b8882 100644 --- a/src/Microsoft.ML.Data/Commands/TrainTestCommand.cs +++ b/src/Microsoft.ML.Data/Commands/TrainTestCommand.cs @@ -7,6 +7,7 @@ using Microsoft.ML.Runtime.Command; using Microsoft.ML.Runtime.CommandLine; using Microsoft.ML.Runtime.Data; +using Microsoft.ML.Runtime.EntryPoints; using Microsoft.ML.Runtime.Internal.Calibration; using Microsoft.ML.Runtime.Internal.Utilities; using Microsoft.ML.Runtime.Model; @@ -26,8 +27,8 @@ public sealed class Arguments : DataCommand.ArgumentsBase [Argument(ArgumentType.Multiple, HelpText = "Trainer to use", ShortName = "tr")] public SubComponent Trainer = new SubComponent("AveragedPerceptron"); - [Argument(ArgumentType.Multiple, HelpText = "Scorer to use", NullName = "", SortOrder = 101)] - public SubComponent Scorer; + [Argument(ArgumentType.Multiple, HelpText = "Scorer to use", NullName = "", SortOrder = 101, SignatureType = typeof(SignatureDataScorer))] + public IComponentFactory Scorer; [Argument(ArgumentType.Multiple, HelpText = "Evaluator to use", ShortName = "eval", NullName = "", SortOrder = 102)] public SubComponent Evaluator; @@ -62,8 +63,8 @@ public sealed class Arguments : DataCommand.ArgumentsBase [Argument(ArgumentType.LastOccurenceWins, HelpText = "Whether we should cache input training data", ShortName = "cache")] public bool? CacheData; - [Argument(ArgumentType.Multiple, HelpText = "Output calibrator", ShortName = "cali", NullName = "")] - public SubComponent Calibrator = new SubComponent("PlattCalibration"); + [Argument(ArgumentType.Multiple, HelpText = "Output calibrator", ShortName = "cali", NullName = "", SignatureType = typeof(SignatureCalibrator))] + public IComponentFactory Calibrator = new PlattCalibratorTrainerFactory(); [Argument(ArgumentType.LastOccurenceWins, HelpText = "Number of instances to train the calibrator", ShortName = "numcali")] public int MaxCalibrationExamples = 1000000000; diff --git a/src/Microsoft.ML.Data/DataLoadSave/PartitionedFileLoader.cs b/src/Microsoft.ML.Data/DataLoadSave/PartitionedFileLoader.cs index 10bf816dc1..70be2bcb92 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/PartitionedFileLoader.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/PartitionedFileLoader.cs @@ -13,6 +13,7 @@ using Microsoft.ML.Runtime.Data.Conversion; using Microsoft.ML.Runtime.Data.IO; using Microsoft.ML.Runtime.Data.Utilities; +using Microsoft.ML.Runtime.EntryPoints; using Microsoft.ML.Runtime.Internal.Utilities; using Microsoft.ML.Runtime.Model; @@ -74,8 +75,8 @@ public class Arguments [Argument(ArgumentType.AtMostOnce, HelpText = "Path parser to extract column name/value pairs from the file path.", ShortName = "parser")] public IPartitionedPathParserFactory PathParserFactory = new ParquetPartitionedPathParserFactory(); - [Argument(ArgumentType.Multiple, HelpText = "The data loader.")] - public SubComponent Loader; + [Argument(ArgumentType.Multiple, HelpText = "The data loader.", SignatureType = typeof(SignatureDataLoader))] + public IComponentFactory Loader; } public sealed class Column @@ -173,6 +174,7 @@ public PartitionedFileLoader(IHostEnvironment env, Arguments args, IMultiStreamS Contracts.CheckValue(env, nameof(env)); _host = env.Register(RegistrationName); _host.CheckValue(args, nameof(args)); + _host.CheckValue(args.Loader, nameof(args.Loader)); _host.CheckValue(files, nameof(files)); _pathParser = args.PathParserFactory.CreateComponent(_host); @@ -180,7 +182,7 @@ public PartitionedFileLoader(IHostEnvironment env, Arguments args, IMultiStreamS _files = files; - var subLoader = args.Loader.CreateInstance(_host, _files); + var subLoader = args.Loader.CreateComponent(_host, _files); _subLoaderBytes = SaveLoaderToBytes(subLoader); string relativePath = GetRelativePath(args.BasePath, files); diff --git a/src/Microsoft.ML.Data/EntryPoints/ScoreModel.cs b/src/Microsoft.ML.Data/EntryPoints/ScoreModel.cs index 312a92bccc..bd50dd2cc1 100644 --- a/src/Microsoft.ML.Data/EntryPoints/ScoreModel.cs +++ b/src/Microsoft.ML.Data/EntryPoints/ScoreModel.cs @@ -79,14 +79,12 @@ public static Output Score(IHostEnvironment env, Input input) using (var ch = host.Start("Creating scoring pipeline")) { ch.Trace("Creating pipeline"); - var bindable = ScoreUtils.GetSchemaBindableMapper(host, predictor, scorerSettings: null); + var bindable = ScoreUtils.GetSchemaBindableMapper(host, predictor, scorerFactorySettings: null); ch.AssertValue(bindable); var mapper = bindable.Bind(host, data.Schema); - var scorer = ScoreUtils.GetScorerComponent(mapper); - Contracts.Assert(string.IsNullOrEmpty(scorer.SubComponentSettings)); - scorer.SubComponentSettings = string.Format("suffix={{{0}}}", input.Suffix); - scoredPipe = scorer.CreateInstance(host, data.Data, mapper, input.PredictorModel.GetTrainingSchema(host)); + var scorer = ScoreUtils.GetScorerComponent(mapper, input.Suffix); + scoredPipe = scorer.CreateComponent(host, data.Data, mapper, input.PredictorModel.GetTrainingSchema(host)); ch.Done(); } @@ -132,12 +130,12 @@ public static Output MakeScoringTransform(IHostEnvironment env, ModelInput input using (var ch = host.Start("Creating scoring pipeline")) { ch.Trace("Creating pipeline"); - var bindable = ScoreUtils.GetSchemaBindableMapper(host, predictor, scorerSettings: null); + var bindable = ScoreUtils.GetSchemaBindableMapper(host, predictor, scorerFactorySettings: null); ch.AssertValue(bindable); var mapper = bindable.Bind(host, data.Schema); var scorer = ScoreUtils.GetScorerComponent(mapper); - scoredPipe = scorer.CreateInstance(host, data.Data, mapper, input.PredictorModel.GetTrainingSchema(host)); + scoredPipe = scorer.CreateComponent(host, data.Data, mapper, input.PredictorModel.GetTrainingSchema(host)); ch.Done(); } diff --git a/src/Microsoft.ML.Data/Prediction/Calibrator.cs b/src/Microsoft.ML.Data/Prediction/Calibrator.cs index 237afb400e..93441dc74c 100644 --- a/src/Microsoft.ML.Data/Prediction/Calibrator.cs +++ b/src/Microsoft.ML.Data/Prediction/Calibrator.cs @@ -622,14 +622,14 @@ private static VersionInfo GetVersionInfo() public SchemaBindableCalibratedPredictor(IHostEnvironment env, IPredictorProducing predictor, ICalibrator calibrator) : base(env, LoaderSignature, predictor, calibrator) { - _bindable = ScoreUtils.GetSchemaBindableMapper(Host, SubPredictor, null); + _bindable = ScoreUtils.GetSchemaBindableMapper(Host, SubPredictor, scorerFactorySettings: null); _whatTheFeature = SubPredictor as IWhatTheFeatureValueMapper; } private SchemaBindableCalibratedPredictor(IHostEnvironment env, ModelLoadContext ctx) : base(env, LoaderSignature, GetPredictor(env, ctx), GetCalibrator(env, ctx)) { - _bindable = ScoreUtils.GetSchemaBindableMapper(Host, SubPredictor, null); + _bindable = ScoreUtils.GetSchemaBindableMapper(Host, SubPredictor, scorerFactorySettings: null); _whatTheFeature = SubPredictor as IWhatTheFeatureValueMapper; } @@ -717,7 +717,7 @@ private static bool NeedCalibration(IHostEnvironment env, IChannel ch, ICalibrat return false; } - var bindable = ScoreUtils.GetSchemaBindableMapper(env, predictor, null); + var bindable = ScoreUtils.GetSchemaBindableMapper(env, predictor, scorerFactorySettings: null); var bound = bindable.Bind(env, schema); var outputSchema = bound.OutputSchema; int scoreCol; diff --git a/src/Microsoft.ML.Data/Transforms/TrainAndScoreTransform.cs b/src/Microsoft.ML.Data/Transforms/TrainAndScoreTransform.cs index 458e83aaed..6af5216d93 100644 --- a/src/Microsoft.ML.Data/Transforms/TrainAndScoreTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/TrainAndScoreTransform.cs @@ -36,8 +36,8 @@ public sealed class Arguments : TransformInputBase ShortName = "col", SortOrder = 101, Purpose = SpecialPurpose.ColumnSelector)] public KeyValuePair[] CustomColumn; - [Argument(ArgumentType.Multiple, HelpText = "Scorer to use", NullName = "")] - public SubComponent Scorer; + [Argument(ArgumentType.Multiple, HelpText = "Scorer to use", NullName = "", SignatureType = typeof(SignatureDataScorer))] + public IComponentFactory Scorer; [Argument(ArgumentType.AtMostOnce, IsInputFileName = true, HelpText = "Predictor model file used in scoring", ShortName = "in", SortOrder = 2)] @@ -145,14 +145,14 @@ public abstract class ArgumentsBase : ArgumentsBase public sealed class Arguments : ArgumentsBase { - [Argument(ArgumentType.Multiple, HelpText = "Output calibrator", ShortName = "cali", NullName = "")] - public SubComponent Calibrator = new SubComponent("PlattCalibration"); + [Argument(ArgumentType.Multiple, HelpText = "Output calibrator", ShortName = "cali", NullName = "", SignatureType = typeof(SignatureCalibrator))] + public IComponentFactory Calibrator = new PlattCalibratorTrainerFactory(); [Argument(ArgumentType.LastOccurenceWins, HelpText = "Number of instances to train the calibrator", ShortName = "numcali")] public int MaxCalibrationExamples = 1000000000; - [Argument(ArgumentType.Multiple, HelpText = "Scorer to use", NullName = "")] - public SubComponent Scorer; + [Argument(ArgumentType.Multiple, HelpText = "Scorer to use", NullName = "", SignatureType = typeof(SignatureDataScorer))] + public IComponentFactory Scorer; } internal const string Summary = "Trains a predictor, or loads it from a file, and runs it on the data."; diff --git a/src/Microsoft.ML.Ensemble/PipelineEnsemble.cs b/src/Microsoft.ML.Ensemble/PipelineEnsemble.cs index 3ac78ed91e..b482b874b6 100644 --- a/src/Microsoft.ML.Ensemble/PipelineEnsemble.cs +++ b/src/Microsoft.ML.Ensemble/PipelineEnsemble.cs @@ -67,7 +67,7 @@ public BoundBase(SchemaBindablePipelineEnsembleBase parent, RoleMappedSchema sch Parent.PredictorModels[i].PrepareData(Parent.Host, emptyDv, out RoleMappedData rmd, out IPredictor predictor); // Get the predictor as a bindable mapper, and bind it to the RoleMappedSchema found above. - var bindable = ScoreUtils.GetSchemaBindableMapper(Parent.Host, Parent.PredictorModels[i].Predictor, null); + var bindable = ScoreUtils.GetSchemaBindableMapper(Parent.Host, Parent.PredictorModels[i].Predictor, scorerFactorySettings: null); Mappers[i] = bindable.Bind(Parent.Host, rmd.Schema) as ISchemaBoundRowMapper; if (Mappers[i] == null) throw Parent.Host.Except("Predictor {0} is not a row to row mapper", i); diff --git a/src/Microsoft.ML.FastTree/TreeEnsembleFeaturizer.cs b/src/Microsoft.ML.FastTree/TreeEnsembleFeaturizer.cs index f404f3ae95..1223fb2313 100644 --- a/src/Microsoft.ML.FastTree/TreeEnsembleFeaturizer.cs +++ b/src/Microsoft.ML.FastTree/TreeEnsembleFeaturizer.cs @@ -666,11 +666,6 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV ch.Assert(args.Trainer.IsGood()); ch.Trace("Creating TrainAndScoreTransform"); - string scorerSettings = CmdParser.GetSettings(ch, scorerArgs, - new TreeEnsembleFeaturizerBindableMapper.Arguments()); - var scorer = - new SubComponent( - TreeEnsembleFeaturizerBindableMapper.LoadNameShort, scorerSettings); var trainScoreArgs = new TrainAndScoreTransform.Arguments(); args.CopyTo(trainScoreArgs); @@ -678,7 +673,8 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV args.Trainer.Settings); var labelInput = AppendLabelTransform(host, ch, input, trainScoreArgs.LabelColumn, args.LabelPermutationSeed); - trainScoreArgs.Scorer = scorer; + trainScoreArgs.Scorer = new SimpleComponentFactory( + (e, data, mapper, trainSchema) => Create(e, scorerArgs, data, mapper, trainSchema)); var scoreXf = TrainAndScoreTransform.Create(host, trainScoreArgs, labelInput); if (input == labelInput) return scoreXf; diff --git a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MetaMulticlassTrainer.cs b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MetaMulticlassTrainer.cs index 23a81f78ee..05aee405ca 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MetaMulticlassTrainer.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MetaMulticlassTrainer.cs @@ -7,10 +7,10 @@ using Microsoft.ML.Runtime.CommandLine; using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.Data.Conversion; +using Microsoft.ML.Runtime.EntryPoints; using Microsoft.ML.Runtime.Internal.Calibration; using Microsoft.ML.Runtime.Internal.Internallearn; using Microsoft.ML.Runtime.Training; -using Microsoft.ML.Runtime.EntryPoints; namespace Microsoft.ML.Runtime.Learners { From 92f5b7a89381df80c17992ddbefd37f7da7554f6 Mon Sep 17 00:00:00 2001 From: Eric Erhardt Date: Thu, 9 Aug 2018 15:50:07 -0500 Subject: [PATCH 2/6] fix test issue --- .../Commands/CrossValidationCommand.cs | 3 +- .../Commands/ScoreCommand.cs | 38 +++++++++++++------ src/Microsoft.ML.Data/Commands/TestCommand.cs | 1 + .../Commands/TrainTestCommand.cs | 1 + .../EntryPoints/ScoreModel.cs | 4 +- .../Prediction/Calibrator.cs | 6 +-- .../Transforms/TrainAndScoreTransform.cs | 10 ++--- src/Microsoft.ML.Ensemble/PipelineEnsemble.cs | 2 +- .../TreeEnsembleFeaturizer.cs | 9 ++++- 9 files changed, 49 insertions(+), 25 deletions(-) diff --git a/src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs b/src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs index 39f2f78173..3055afe917 100644 --- a/src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs +++ b/src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs @@ -559,7 +559,8 @@ private FoldResult RunFold(int fold) // Score. ch.Trace("Scoring and evaluating"); - var bindable = ScoreUtils.GetSchemaBindableMapper(host, predictor, _scorer as ICommandLineComponentFactory); + ch.Assert(_scorer == null || _scorer is ICommandLineComponentFactory, "CrossValidationCommand should only be used from the command line."); + var bindable = ScoreUtils.GetSchemaBindableMapper(host, predictor, scorerFactorySettings: _scorer as ICommandLineComponentFactory); ch.AssertValue(bindable); var mapper = bindable.Bind(host, testData.Schema); var scorerComp = _scorer ?? ScoreUtils.GetScorerComponent(mapper); diff --git a/src/Microsoft.ML.Data/Commands/ScoreCommand.cs b/src/Microsoft.ML.Data/Commands/ScoreCommand.cs index 76fd587fe1..322aa06833 100644 --- a/src/Microsoft.ML.Data/Commands/ScoreCommand.cs +++ b/src/Microsoft.ML.Data/Commands/ScoreCommand.cs @@ -105,7 +105,8 @@ private void RunCore(IChannel ch) ch.Trace("Creating pipeline"); var scorer = Args.Scorer; - var bindable = ScoreUtils.GetSchemaBindableMapper(Host, predictor, scorer as ICommandLineComponentFactory); + ch.Assert(scorer == null || scorer is ICommandLineComponentFactory, "ScoreCommand should only be used from the command line."); + var bindable = ScoreUtils.GetSchemaBindableMapper(Host, predictor, scorerFactorySettings: scorer as ICommandLineComponentFactory); ch.AssertValue(bindable); // REVIEW: We probably ought to prefer role mappings from the training schema. @@ -225,7 +226,7 @@ public static class ScoreUtils { public static IDataScorerTransform GetScorer(IPredictor predictor, RoleMappedData data, IHostEnvironment env, RoleMappedSchema trainSchema) { - var sc = GetScorerComponentAndMapper(predictor, null, data.Schema, env, out var mapper); + var sc = GetScorerComponentAndMapper(predictor, null, data.Schema, env, null, out var mapper); return sc.CreateComponent(env, data.Data, mapper, trainSchema); } @@ -237,7 +238,8 @@ public static IDataScorerTransform GetScorer( string groupColName, IEnumerable> customColumns, IHostEnvironment env, - RoleMappedSchema trainSchema) + RoleMappedSchema trainSchema, + IComponentFactory mapperFactory = null) { Contracts.CheckValue(env, nameof(env)); env.CheckValueOrNull(scorer); @@ -249,7 +251,7 @@ public static IDataScorerTransform GetScorer( env.CheckValueOrNull(trainSchema); var schema = new RoleMappedSchema(input.Schema, label: null, feature: featureColName, group: groupColName, custom: customColumns, opt: true); - var sc = GetScorerComponentAndMapper(predictor, scorer, schema, env, out var mapper); + var sc = GetScorerComponentAndMapper(predictor, scorer, schema, env, mapperFactory, out var mapper); return sc.CreateComponent(env, input, mapper, trainSchema); } @@ -257,12 +259,16 @@ public static IDataScorerTransform GetScorer( /// Determines the scorer component factory (if the given one is null or empty), and creates the schema bound mapper. /// private static IComponentFactory GetScorerComponentAndMapper( - IPredictor predictor, IComponentFactory scorerFactory, - RoleMappedSchema schema, IHostEnvironment env, out ISchemaBoundMapper mapper) + IPredictor predictor, + IComponentFactory scorerFactory, + RoleMappedSchema schema, + IHostEnvironment env, + IComponentFactory mapperFactory, + out ISchemaBoundMapper mapper) { Contracts.AssertValue(env); - var bindable = GetSchemaBindableMapper(env, predictor, scorerFactory as ICommandLineComponentFactory); + var bindable = GetSchemaBindableMapper(env, predictor, mapperFactory, scorerFactory as ICommandLineComponentFactory); env.AssertValue(bindable); mapper = bindable.Bind(env, schema); if (scorerFactory != null) @@ -354,21 +360,31 @@ public static ISchemaBindableMapper GetSchemaBindableMapper(IHostEnvironment env } /// - /// Given a predictor and an optional scorer factory settings, produces a compatible ISchemaBindableMapper. - /// First, it tries to instantiate the bindable mapper using the + /// Given a predictor, an optional mapper factory, and an optional scorer factory settings, + /// produces a compatible ISchemaBindableMapper. + /// First, it tries to instantiate the bindable mapper using the mapper factory. + /// Next, it tries to instantiate the bindable mapper using the /// (this will only succeed if there's a registered BindableMapper creation method with load name equal to the one /// of the scorer). /// If the above fails, it checks whether the predictor implements /// directly. /// If this also isn't true, it will create a 'matching' standard mapper. /// - public static ISchemaBindableMapper GetSchemaBindableMapper(IHostEnvironment env, IPredictor predictor, - ICommandLineComponentFactory scorerFactorySettings) + public static ISchemaBindableMapper GetSchemaBindableMapper( + IHostEnvironment env, + IPredictor predictor, + IComponentFactory mapperFactory = null, + ICommandLineComponentFactory scorerFactorySettings = null) { Contracts.CheckValue(env, nameof(env)); env.CheckValue(predictor, nameof(predictor)); + env.CheckValueOrNull(mapperFactory); env.CheckValueOrNull(scorerFactorySettings); + // if the mapperFactory was supplied, use it + if (mapperFactory != null) + return mapperFactory.CreateComponent(env, predictor); + // See if we can instantiate a mapper using scorer arguments. if (scorerFactorySettings != null && TryCreateBindableFromScorer(env, predictor, scorerFactorySettings, out var bindable)) return bindable; diff --git a/src/Microsoft.ML.Data/Commands/TestCommand.cs b/src/Microsoft.ML.Data/Commands/TestCommand.cs index a487204c1b..5a4b7671aa 100644 --- a/src/Microsoft.ML.Data/Commands/TestCommand.cs +++ b/src/Microsoft.ML.Data/Commands/TestCommand.cs @@ -108,6 +108,7 @@ private void RunCore(IChannel ch) // Score. ch.Trace("Scoring and evaluating"); + ch.Assert(Args.Scorer == null || Args.Scorer is ICommandLineComponentFactory, "TestCommand should only be used from the command line."); IDataScorerTransform scorePipe = ScoreUtils.GetScorer(Args.Scorer, predictor, loader, features, group, customCols, Host, trainSchema); // Evaluate. diff --git a/src/Microsoft.ML.Data/Commands/TrainTestCommand.cs b/src/Microsoft.ML.Data/Commands/TrainTestCommand.cs index aea57b8882..bc99d60eb3 100644 --- a/src/Microsoft.ML.Data/Commands/TrainTestCommand.cs +++ b/src/Microsoft.ML.Data/Commands/TrainTestCommand.cs @@ -183,6 +183,7 @@ private void RunCore(IChannel ch, string cmd) // Score. ch.Trace("Scoring and evaluating"); + ch.Assert(Args.Scorer == null || Args.Scorer is ICommandLineComponentFactory, "TrainTestCommand should only be used from the command line."); IDataScorerTransform scorePipe = ScoreUtils.GetScorer(Args.Scorer, predictor, testPipe, features, group, customCols, Host, data.Schema); // Evaluate. diff --git a/src/Microsoft.ML.Data/EntryPoints/ScoreModel.cs b/src/Microsoft.ML.Data/EntryPoints/ScoreModel.cs index bd50dd2cc1..260ff37f26 100644 --- a/src/Microsoft.ML.Data/EntryPoints/ScoreModel.cs +++ b/src/Microsoft.ML.Data/EntryPoints/ScoreModel.cs @@ -79,7 +79,7 @@ public static Output Score(IHostEnvironment env, Input input) using (var ch = host.Start("Creating scoring pipeline")) { ch.Trace("Creating pipeline"); - var bindable = ScoreUtils.GetSchemaBindableMapper(host, predictor, scorerFactorySettings: null); + var bindable = ScoreUtils.GetSchemaBindableMapper(host, predictor); ch.AssertValue(bindable); var mapper = bindable.Bind(host, data.Schema); @@ -130,7 +130,7 @@ public static Output MakeScoringTransform(IHostEnvironment env, ModelInput input using (var ch = host.Start("Creating scoring pipeline")) { ch.Trace("Creating pipeline"); - var bindable = ScoreUtils.GetSchemaBindableMapper(host, predictor, scorerFactorySettings: null); + var bindable = ScoreUtils.GetSchemaBindableMapper(host, predictor); ch.AssertValue(bindable); var mapper = bindable.Bind(host, data.Schema); diff --git a/src/Microsoft.ML.Data/Prediction/Calibrator.cs b/src/Microsoft.ML.Data/Prediction/Calibrator.cs index 93441dc74c..0e4b9f07e9 100644 --- a/src/Microsoft.ML.Data/Prediction/Calibrator.cs +++ b/src/Microsoft.ML.Data/Prediction/Calibrator.cs @@ -622,14 +622,14 @@ private static VersionInfo GetVersionInfo() public SchemaBindableCalibratedPredictor(IHostEnvironment env, IPredictorProducing predictor, ICalibrator calibrator) : base(env, LoaderSignature, predictor, calibrator) { - _bindable = ScoreUtils.GetSchemaBindableMapper(Host, SubPredictor, scorerFactorySettings: null); + _bindable = ScoreUtils.GetSchemaBindableMapper(Host, SubPredictor); _whatTheFeature = SubPredictor as IWhatTheFeatureValueMapper; } private SchemaBindableCalibratedPredictor(IHostEnvironment env, ModelLoadContext ctx) : base(env, LoaderSignature, GetPredictor(env, ctx), GetCalibrator(env, ctx)) { - _bindable = ScoreUtils.GetSchemaBindableMapper(Host, SubPredictor, scorerFactorySettings: null); + _bindable = ScoreUtils.GetSchemaBindableMapper(Host, SubPredictor); _whatTheFeature = SubPredictor as IWhatTheFeatureValueMapper; } @@ -717,7 +717,7 @@ private static bool NeedCalibration(IHostEnvironment env, IChannel ch, ICalibrat return false; } - var bindable = ScoreUtils.GetSchemaBindableMapper(env, predictor, scorerFactorySettings: null); + var bindable = ScoreUtils.GetSchemaBindableMapper(env, predictor); var bound = bindable.Bind(env, schema); var outputSchema = bound.OutputSchema; int scoreCol; diff --git a/src/Microsoft.ML.Data/Transforms/TrainAndScoreTransform.cs b/src/Microsoft.ML.Data/Transforms/TrainAndScoreTransform.cs index 6af5216d93..142fd2d97c 100644 --- a/src/Microsoft.ML.Data/Transforms/TrainAndScoreTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/TrainAndScoreTransform.cs @@ -191,10 +191,10 @@ public static IDataTransform Create(IHostEnvironment env, GroupColumn = groupColumn }; - return Create(env, args, trainer, input); + return Create(env, args, trainer, input, null); } - public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input) + public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input, IComponentFactory mapperFactory = null) { Contracts.CheckValue(env, nameof(env)); env.CheckValue(args, nameof(args)); @@ -202,10 +202,10 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV "Trainer cannot be null. If your model is already trained, please use ScoreTransform instead."); env.CheckValue(input, nameof(input)); - return Create(env, args, args.Trainer.CreateInstance(env), input); + return Create(env, args, args.Trainer.CreateInstance(env), input, mapperFactory); } - private static IDataTransform Create(IHostEnvironment env, Arguments args, ITrainer trainer, IDataView input) + private static IDataTransform Create(IHostEnvironment env, Arguments args, ITrainer trainer, IDataView input, IComponentFactory mapperFactory) { Contracts.AssertValue(env, nameof(env)); env.AssertValue(args, nameof(args)); @@ -226,7 +226,7 @@ private static IDataTransform Create(IHostEnvironment env, Arguments args, ITrai ch.Done(); - return ScoreUtils.GetScorer(args.Scorer, predictor, input, feat, group, customCols, env, data.Schema); + return ScoreUtils.GetScorer(args.Scorer, predictor, input, feat, group, customCols, env, data.Schema, mapperFactory); } } diff --git a/src/Microsoft.ML.Ensemble/PipelineEnsemble.cs b/src/Microsoft.ML.Ensemble/PipelineEnsemble.cs index b482b874b6..e4d6cb0d9e 100644 --- a/src/Microsoft.ML.Ensemble/PipelineEnsemble.cs +++ b/src/Microsoft.ML.Ensemble/PipelineEnsemble.cs @@ -67,7 +67,7 @@ public BoundBase(SchemaBindablePipelineEnsembleBase parent, RoleMappedSchema sch Parent.PredictorModels[i].PrepareData(Parent.Host, emptyDv, out RoleMappedData rmd, out IPredictor predictor); // Get the predictor as a bindable mapper, and bind it to the RoleMappedSchema found above. - var bindable = ScoreUtils.GetSchemaBindableMapper(Parent.Host, Parent.PredictorModels[i].Predictor, scorerFactorySettings: null); + var bindable = ScoreUtils.GetSchemaBindableMapper(Parent.Host, Parent.PredictorModels[i].Predictor); Mappers[i] = bindable.Bind(Parent.Host, rmd.Schema) as ISchemaBoundRowMapper; if (Mappers[i] == null) throw Parent.Host.Except("Predictor {0} is not a row to row mapper", i); diff --git a/src/Microsoft.ML.FastTree/TreeEnsembleFeaturizer.cs b/src/Microsoft.ML.FastTree/TreeEnsembleFeaturizer.cs index 1223fb2313..88ff3eaf62 100644 --- a/src/Microsoft.ML.FastTree/TreeEnsembleFeaturizer.cs +++ b/src/Microsoft.ML.FastTree/TreeEnsembleFeaturizer.cs @@ -672,10 +672,15 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV trainScoreArgs.Trainer = new SubComponent(args.Trainer.Kind, args.Trainer.Settings); - var labelInput = AppendLabelTransform(host, ch, input, trainScoreArgs.LabelColumn, args.LabelPermutationSeed); trainScoreArgs.Scorer = new SimpleComponentFactory( (e, data, mapper, trainSchema) => Create(e, scorerArgs, data, mapper, trainSchema)); - var scoreXf = TrainAndScoreTransform.Create(host, trainScoreArgs, labelInput); + + var mapperFactory = new SimpleComponentFactory( + (e, predictor) => new TreeEnsembleFeaturizerBindableMapper(e, scorerArgs, predictor)); + + var labelInput = AppendLabelTransform(host, ch, input, trainScoreArgs.LabelColumn, args.LabelPermutationSeed); + var scoreXf = TrainAndScoreTransform.Create(host, trainScoreArgs, labelInput, mapperFactory); + if (input == labelInput) return scoreXf; return (IDataTransform)ApplyTransformUtils.ApplyAllTransformsToData(host, scoreXf, input, labelInput); From 446c72cba14b6a12a969cf2561d9be7f5d9371cc Mon Sep 17 00:00:00 2001 From: Eric Erhardt Date: Thu, 9 Aug 2018 17:15:00 -0500 Subject: [PATCH 3/6] Remove last SubComponent usage from ScoreCommand. --- src/Microsoft.ML.Api/ComponentCreation.cs | 13 +- .../CommandLine/CmdParser.cs | 333 ++++++++++-------- .../Commands/ScoreCommand.cs | 44 --- 3 files changed, 195 insertions(+), 195 deletions(-) diff --git a/src/Microsoft.ML.Api/ComponentCreation.cs b/src/Microsoft.ML.Api/ComponentCreation.cs index 3080a8197c..0a1e1cd605 100644 --- a/src/Microsoft.ML.Api/ComponentCreation.cs +++ b/src/Microsoft.ML.Api/ComponentCreation.cs @@ -6,6 +6,7 @@ using System.IO; using Microsoft.ML.Runtime.CommandLine; using Microsoft.ML.Runtime.Data; +using Microsoft.ML.Runtime.EntryPoints; using Microsoft.ML.Runtime.Model; namespace Microsoft.ML.Runtime.Api @@ -304,12 +305,20 @@ public static IDataScorerTransform CreateScorer(this IHostEnvironment env, strin env.CheckValue(predictor, nameof(predictor)); env.CheckValueOrNull(trainSchema); - var subComponent = SubComponent.Parse(settings); - var bindable = ScoreUtils.GetSchemaBindableMapper(env, predictor.Pred, subComponent); + ICommandLineComponentFactory scorerFactorySettings = ParseScorerSettings(settings); + var bindable = ScoreUtils.GetSchemaBindableMapper(env, predictor.Pred, scorerFactorySettings: scorerFactorySettings); var mapper = bindable.Bind(env, data.Schema); return CreateCore(env, settings, data.Data, mapper, trainSchema); } + private static ICommandLineComponentFactory ParseScorerSettings(string settings) + { + return CmdParser.CreateComponentFactory( + typeof(IComponentFactory), + typeof(SignatureDataScorer), + settings); + } + /// /// Creates a default data scorer appropriate to the predictor's prediction kind. /// diff --git a/src/Microsoft.ML.Core/CommandLine/CmdParser.cs b/src/Microsoft.ML.Core/CommandLine/CmdParser.cs index f91e9c90ba..b5634c4e4f 100644 --- a/src/Microsoft.ML.Core/CommandLine/CmdParser.cs +++ b/src/Microsoft.ML.Core/CommandLine/CmdParser.cs @@ -1297,6 +1297,41 @@ private static bool IsValidItemType(Type type) typeBase.IsEnum; } + /// + /// Creates an ICommandLineComponentFactory given the factory type, signature type, + /// and a command line string. + /// + public static ICommandLineComponentFactory CreateComponentFactory( + Type factoryType, + Type signatureType, + string settings) + { + ParseComponentStrings(settings, out string name, out string args); + + string[] argsArray = string.IsNullOrEmpty(args) ? Array.Empty() : new string[] { args }; + + return ComponentFactoryFactory.CreateComponentFactory(factoryType, signatureType, name, argsArray); + } + + private static void ParseComponentStrings(string str, out string kind, out string args) + { + kind = args = null; + if (string.IsNullOrWhiteSpace(str)) + return; + str = str.Trim(); + int ich = str.IndexOf('{'); + if (ich < 0) + { + kind = str; + return; + } + if (ich == 0 || str[str.Length - 1] != '}') + throw Contracts.Except("Invalid Component string: mismatched braces, or empty component name."); + + kind = str.Substring(0, ich); + args = CmdLexer.UnquoteValue(str.Substring(ich)); + } + private sealed class ArgValue { public readonly string FirstValue; @@ -1840,155 +1875,6 @@ public bool Finish(CmdParser owner, ArgValue val, object destination) return error; } - /// - /// A factory class for creating IComponentFactory instances. - /// - private static class ComponentFactoryFactory - { - public static IComponentFactory CreateComponentFactory( - Type factoryType, - Type signatureType, - string name, - string[] settings) - { - Contracts.Check(factoryType != null && - typeof(IComponentFactory).IsAssignableFrom(factoryType) && - factoryType.IsGenericType); - - Type componentFactoryType; - switch (factoryType.GenericTypeArguments.Length) - { - case 1: componentFactoryType = typeof(ComponentFactory<>); break; - case 2: componentFactoryType = typeof(ComponentFactory<,>); break; - case 3: componentFactoryType = typeof(ComponentFactory<,,>); break; - case 4: componentFactoryType = typeof(ComponentFactory<,,,>); break; - default: throw Contracts.ExceptNotImpl("ComponentFactoryFactory can only create component factories with 4 or less type args."); - } - - return (IComponentFactory)Activator.CreateInstance( - componentFactoryType.MakeGenericType(factoryType.GenericTypeArguments), - signatureType, - name, - settings); - } - - private abstract class ComponentFactory : ICommandLineComponentFactory - { - public Type SignatureType { get; } - public string Name { get; } - private string[] Settings { get; } - - protected ComponentFactory(Type signatureType, string name, string[] settings) - { - SignatureType = signatureType; - Name = name; - - if (settings == null || (settings.Length == 1 && string.IsNullOrEmpty(settings[0]))) - { - settings = Array.Empty(); - } - Settings = settings; - } - - public string GetSettingsString() - { - return CombineSettings(Settings); - } - - public override string ToString() - { - if (string.IsNullOrEmpty(Name) && Settings.Length == 0) - return "{}"; - - if (Settings.Length == 0) - return Name; - - string str = CombineSettings(Settings); - StringBuilder sb = new StringBuilder(); - CmdQuoter.QuoteValue(str, sb, true); - return Name + sb.ToString(); - } - } - - private class ComponentFactory : ComponentFactory, IComponentFactory - where TComponent : class - { - public ComponentFactory(Type signatureType, string name, string[] settings) - : base(signatureType, name, settings) - { - } - - public TComponent CreateComponent(IHostEnvironment env) - { - return ComponentCatalog.CreateInstance( - env, - SignatureType, - Name, - GetSettingsString()); - } - } - - private class ComponentFactory : ComponentFactory, IComponentFactory - where TComponent : class - { - public ComponentFactory(Type signatureType, string name, string[] settings) - : base(signatureType, name, settings) - { - } - - public TComponent CreateComponent(IHostEnvironment env, TArg1 argument1) - { - return ComponentCatalog.CreateInstance( - env, - SignatureType, - Name, - GetSettingsString(), - argument1); - } - } - - private class ComponentFactory : ComponentFactory, IComponentFactory - where TComponent : class - { - public ComponentFactory(Type signatureType, string name, string[] settings) - : base(signatureType, name, settings) - { - } - - public TComponent CreateComponent(IHostEnvironment env, TArg1 argument1, TArg2 argument2) - { - return ComponentCatalog.CreateInstance( - env, - SignatureType, - Name, - GetSettingsString(), - argument1, - argument2); - } - } - - private class ComponentFactory : ComponentFactory, IComponentFactory - where TComponent : class - { - public ComponentFactory(Type signatureType, string name, string[] settings) - : base(signatureType, name, settings) - { - } - - public TComponent CreateComponent(IHostEnvironment env, TArg1 argument1, TArg2 argument2, TArg3 argument3) - { - return ComponentCatalog.CreateInstance( - env, - SignatureType, - Name, - GetSettingsString(), - argument1, - argument2, - argument3); - } - } - } - private bool ReportMissingRequiredArgument(CmdParser owner, ArgValue val) { if (!IsRequired || val != null) @@ -2624,5 +2510,154 @@ public bool IsCustomItemType { get { return _infoCustom != null; } } } + + /// + /// A factory class for creating IComponentFactory instances. + /// + private static class ComponentFactoryFactory + { + public static ICommandLineComponentFactory CreateComponentFactory( + Type factoryType, + Type signatureType, + string name, + string[] settings) + { + Contracts.Check(factoryType != null && + typeof(IComponentFactory).IsAssignableFrom(factoryType) && + factoryType.IsGenericType); + + Type componentFactoryType; + switch (factoryType.GenericTypeArguments.Length) + { + case 1: componentFactoryType = typeof(ComponentFactory<>); break; + case 2: componentFactoryType = typeof(ComponentFactory<,>); break; + case 3: componentFactoryType = typeof(ComponentFactory<,,>); break; + case 4: componentFactoryType = typeof(ComponentFactory<,,,>); break; + default: throw Contracts.ExceptNotImpl("ComponentFactoryFactory can only create component factories with 4 or less type args."); + } + + return (ICommandLineComponentFactory)Activator.CreateInstance( + componentFactoryType.MakeGenericType(factoryType.GenericTypeArguments), + signatureType, + name, + settings); + } + + private abstract class ComponentFactory : ICommandLineComponentFactory + { + public Type SignatureType { get; } + public string Name { get; } + private string[] Settings { get; } + + protected ComponentFactory(Type signatureType, string name, string[] settings) + { + SignatureType = signatureType; + Name = name; + + if (settings == null || (settings.Length == 1 && string.IsNullOrEmpty(settings[0]))) + { + settings = Array.Empty(); + } + Settings = settings; + } + + public string GetSettingsString() + { + return CombineSettings(Settings); + } + + public override string ToString() + { + if (string.IsNullOrEmpty(Name) && Settings.Length == 0) + return "{}"; + + if (Settings.Length == 0) + return Name; + + string str = CombineSettings(Settings); + StringBuilder sb = new StringBuilder(); + CmdQuoter.QuoteValue(str, sb, true); + return Name + sb.ToString(); + } + } + + private class ComponentFactory : ComponentFactory, IComponentFactory + where TComponent : class + { + public ComponentFactory(Type signatureType, string name, string[] settings) + : base(signatureType, name, settings) + { + } + + public TComponent CreateComponent(IHostEnvironment env) + { + return ComponentCatalog.CreateInstance( + env, + SignatureType, + Name, + GetSettingsString()); + } + } + + private class ComponentFactory : ComponentFactory, IComponentFactory + where TComponent : class + { + public ComponentFactory(Type signatureType, string name, string[] settings) + : base(signatureType, name, settings) + { + } + + public TComponent CreateComponent(IHostEnvironment env, TArg1 argument1) + { + return ComponentCatalog.CreateInstance( + env, + SignatureType, + Name, + GetSettingsString(), + argument1); + } + } + + private class ComponentFactory : ComponentFactory, IComponentFactory + where TComponent : class + { + public ComponentFactory(Type signatureType, string name, string[] settings) + : base(signatureType, name, settings) + { + } + + public TComponent CreateComponent(IHostEnvironment env, TArg1 argument1, TArg2 argument2) + { + return ComponentCatalog.CreateInstance( + env, + SignatureType, + Name, + GetSettingsString(), + argument1, + argument2); + } + } + + private class ComponentFactory : ComponentFactory, IComponentFactory + where TComponent : class + { + public ComponentFactory(Type signatureType, string name, string[] settings) + : base(signatureType, name, settings) + { + } + + public TComponent CreateComponent(IHostEnvironment env, TArg1 argument1, TArg2 argument2, TArg3 argument3) + { + return ComponentCatalog.CreateInstance( + env, + SignatureType, + Name, + GetSettingsString(), + argument1, + argument2, + argument3); + } + } + } } } \ No newline at end of file diff --git a/src/Microsoft.ML.Data/Commands/ScoreCommand.cs b/src/Microsoft.ML.Data/Commands/ScoreCommand.cs index 322aa06833..b0d629a002 100644 --- a/src/Microsoft.ML.Data/Commands/ScoreCommand.cs +++ b/src/Microsoft.ML.Data/Commands/ScoreCommand.cs @@ -327,38 +327,6 @@ public static IComponentFactory - /// Given a predictor and an optional scorer SubComponent, produces a compatible ISchemaBindableMapper. - /// First, it tries to instantiate the bindable mapper using the - /// (this will only succeed if there's a registered BindableMapper creation method with load name equal to the one - /// of the scorer). - /// If the above fails, it checks whether the predictor implements - /// directly. - /// If this also isn't true, it will create a 'matching' standard mapper. - /// - public static ISchemaBindableMapper GetSchemaBindableMapper(IHostEnvironment env, IPredictor predictor, - SubComponent scorerSettings) - { - Contracts.CheckValue(env, nameof(env)); - env.CheckValue(predictor, nameof(predictor)); - env.CheckValueOrNull(scorerSettings); - - // See if we can instantiate a mapper using scorer arguments. - if (scorerSettings.IsGood() && TryCreateBindableFromScorer(env, predictor, scorerSettings, out var bindable)) - return bindable; - - // The easy case is that the predictor implements the interface. - bindable = predictor as ISchemaBindableMapper; - if (bindable != null) - return bindable; - - // Use one of the standard wrappers. - if (predictor is IValueMapperDist) - return new SchemaBindableBinaryPredictorWrapper(predictor); - - return new SchemaBindablePredictorWrapper(predictor); - } - /// /// Given a predictor, an optional mapper factory, and an optional scorer factory settings, /// produces a compatible ISchemaBindableMapper. @@ -401,18 +369,6 @@ public static ISchemaBindableMapper GetSchemaBindableMapper( return new SchemaBindablePredictorWrapper(predictor); } - private static bool TryCreateBindableFromScorer(IHostEnvironment env, IPredictor predictor, - SubComponent scorerSettings, out ISchemaBindableMapper bindable) - { - Contracts.AssertValue(env); - env.AssertValue(predictor); - env.Assert(scorerSettings.IsGood()); - - // Try to find a mapper factory method with the same loadname as the scorer settings. - var mapperComponent = new SubComponent(scorerSettings.Kind, scorerSettings.Settings); - return ComponentCatalog.TryCreateInstance(env, out bindable, mapperComponent, predictor); - } - private static bool TryCreateBindableFromScorer(IHostEnvironment env, IPredictor predictor, ICommandLineComponentFactory scorerSettings, out ISchemaBindableMapper bindable) { From c0f1c3b2853b5dc16b44f643a1cbea122277f52e Mon Sep 17 00:00:00 2001 From: Eric Erhardt Date: Thu, 9 Aug 2018 18:20:12 -0500 Subject: [PATCH 4/6] Keep the Create method's signature so DI can find it. --- .../Transforms/TrainAndScoreTransform.cs | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/src/Microsoft.ML.Data/Transforms/TrainAndScoreTransform.cs b/src/Microsoft.ML.Data/Transforms/TrainAndScoreTransform.cs index 142fd2d97c..f4a8a5722f 100644 --- a/src/Microsoft.ML.Data/Transforms/TrainAndScoreTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/TrainAndScoreTransform.cs @@ -194,13 +194,25 @@ public static IDataTransform Create(IHostEnvironment env, return Create(env, args, trainer, input, null); } - public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input, IComponentFactory mapperFactory = null) + public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input) + { + Contracts.CheckValue(env, nameof(env)); + env.CheckValue(args, nameof(args)); + env.CheckUserArg(args.Trainer.IsGood(), nameof(args.Trainer), + "Trainer cannot be null. If your model is already trained, please use ScoreTransform instead."); + env.CheckValue(input, nameof(input)); + + return Create(env, args, args.Trainer.CreateInstance(env), input, null); + } + + public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input, IComponentFactory mapperFactory) { Contracts.CheckValue(env, nameof(env)); env.CheckValue(args, nameof(args)); env.CheckUserArg(args.Trainer.IsGood(), nameof(args.Trainer), "Trainer cannot be null. If your model is already trained, please use ScoreTransform instead."); env.CheckValue(input, nameof(input)); + env.CheckValueOrNull(mapperFactory); return Create(env, args, args.Trainer.CreateInstance(env), input, mapperFactory); } From f32af763c00f7ede931e369c404a3f3217bb87ca Mon Sep 17 00:00:00 2001 From: Eric Erhardt Date: Tue, 14 Aug 2018 17:13:34 -0500 Subject: [PATCH 5/6] Respond to PR feedback. --- .../Commands/ScoreCommand.cs | 60 ++++++++++--------- 1 file changed, 33 insertions(+), 27 deletions(-) diff --git a/src/Microsoft.ML.Data/Commands/ScoreCommand.cs b/src/Microsoft.ML.Data/Commands/ScoreCommand.cs index b0d629a002..635d751d2e 100644 --- a/src/Microsoft.ML.Data/Commands/ScoreCommand.cs +++ b/src/Microsoft.ML.Data/Commands/ScoreCommand.cs @@ -20,6 +20,8 @@ namespace Microsoft.ML.Runtime.Data { + using TScorerFactory = IComponentFactory; + public interface IDataScorerTransform : IDataTransform, ITransformTemplate { } @@ -52,7 +54,7 @@ public sealed class Arguments : DataCommand.ArgumentsBase public KeyValuePair[] CustomColumn; [Argument(ArgumentType.Multiple, HelpText = "Scorer to use", SignatureType = typeof(SignatureDataScorer))] - public IComponentFactory Scorer; + public TScorerFactory Scorer; [Argument(ArgumentType.Multiple, HelpText = "The data saver to use")] public SubComponent Saver; @@ -231,7 +233,7 @@ public static IDataScorerTransform GetScorer(IPredictor predictor, RoleMappedDat } public static IDataScorerTransform GetScorer( - IComponentFactory scorer, + TScorerFactory scorer, IPredictor predictor, IDataView input, string featureColName, @@ -258,9 +260,9 @@ public static IDataScorerTransform GetScorer( /// /// Determines the scorer component factory (if the given one is null or empty), and creates the schema bound mapper. /// - private static IComponentFactory GetScorerComponentAndMapper( + private static TScorerFactory GetScorerComponentAndMapper( IPredictor predictor, - IComponentFactory scorerFactory, + TScorerFactory scorerFactory, RoleMappedSchema schema, IHostEnvironment env, IComponentFactory mapperFactory, @@ -283,7 +285,7 @@ private static IComponentFactory /// The schema bound mapper to get the default scorer.. /// An optional suffix to append to the default column names. - public static IComponentFactory GetScorerComponent( + public static TScorerFactory GetScorerComponent( ISchemaBoundMapper mapper, string suffix = null) { @@ -300,31 +302,35 @@ public static IComponentFactory( - (env, data, innerMapper, trainSchema) => + + Func factoryFunc; + if (info == null) + { + factoryFunc = (env, data, innerMapper, trainSchema) => + new GenericScorer( + env, + new GenericScorer.Arguments() { Suffix = suffix }, + data, + innerMapper, + trainSchema); + } + else + { + factoryFunc = (env, data, innerMapper, trainSchema) => { - if (info == null) + object args = info.CreateArguments(); + if (args is ScorerArgumentsBase scorerArgs) { - return new GenericScorer( - env, - new GenericScorer.Arguments() { Suffix = suffix }, - data, - innerMapper, - trainSchema); + scorerArgs.Suffix = suffix; } - else - { - object args = info.CreateArguments(); - if (args is ScorerArgumentsBase scorerArgs) - { - scorerArgs.Suffix = suffix; - } - return (IDataScorerTransform)info.CreateInstance( - env, - args, - new object[] { data, innerMapper, trainSchema }); - } - }); + return (IDataScorerTransform)info.CreateInstance( + env, + args, + new object[] { data, innerMapper, trainSchema }); + }; + } + + return new SimpleComponentFactory(factoryFunc); } /// From 36888625b6fefd89adfd2f900d704c60cffba348 Mon Sep 17 00:00:00 2001 From: Eric Erhardt Date: Tue, 14 Aug 2018 17:13:58 -0500 Subject: [PATCH 6/6] Change CmdParser ComponentFactoryFactory to not throw an exception during parsing. --- .../CommandLine/CmdParser.cs | 75 +++++++++++++++---- 1 file changed, 61 insertions(+), 14 deletions(-) diff --git a/src/Microsoft.ML.Core/CommandLine/CmdParser.cs b/src/Microsoft.ML.Core/CommandLine/CmdParser.cs index b5634c4e4f..bd37a96f7b 100644 --- a/src/Microsoft.ML.Core/CommandLine/CmdParser.cs +++ b/src/Microsoft.ML.Core/CommandLine/CmdParser.cs @@ -1743,12 +1743,20 @@ public bool Finish(CmdParser owner, ArgValue val, object destination) settings = values.Select(x => (string)x.Value).ToArray(); Contracts.Check(_signatureType != null, "ComponentFactory Arguments need a SignatureType set."); - var factory = ComponentFactoryFactory.CreateComponentFactory( + if (ComponentFactoryFactory.TryCreateComponentFactory( ItemType, _signatureType, name, - settings); - Field.SetValue(destination, factory); + settings, + out ICommandLineComponentFactory factory)) + { + Field.SetValue(destination, factory); + } + else + { + owner.Report("There was an error creating the ComponentFactory. Ensure '{0}' is configured correctly.", LongName); + error = true; + } } else if (IsMultiSubComponent) { @@ -1813,12 +1821,20 @@ public bool Finish(CmdParser owner, ArgValue val, object destination) string[] settings = null; if (i < values.Count && IsCurlyGroup((string)values[i].Value) && string.IsNullOrEmpty(values[i].Key)) settings = new string[] { (string)values[i++].Value }; - var factory = ComponentFactoryFactory.CreateComponentFactory( + if (ComponentFactoryFactory.TryCreateComponentFactory( ItemValueType, _signatureType, name, - settings); - comList.Add(new KeyValuePair(tag, factory)); + settings, + out ICommandLineComponentFactory factory)) + { + comList.Add(new KeyValuePair(tag, factory)); + } + else + { + owner.Report("There was an error creating the ComponentFactory. Ensure '{0}' is configured correctly.", LongName); + error = true; + } } var arr = Array.CreateInstance(ItemType, comList.Count); @@ -1840,12 +1856,20 @@ public bool Finish(CmdParser owner, ArgValue val, object destination) string[] settings = null; if (i < values.Count && IsCurlyGroup((string)values[i].Value)) settings = new string[] { (string)values[i++].Value }; - var factory = ComponentFactoryFactory.CreateComponentFactory( + if (ComponentFactoryFactory.TryCreateComponentFactory( ItemValueType, _signatureType, name, - settings); - comList.Add(factory); + settings, + out ICommandLineComponentFactory factory)) + { + comList.Add(factory); + } + else + { + owner.Report("There was an error creating the ComponentFactory. Ensure '{0}' is configured correctly.", LongName); + error = true; + } } var arr = Array.CreateInstance(ItemValueType, comList.Count); @@ -2522,9 +2546,29 @@ public static ICommandLineComponentFactory CreateComponentFactory( string name, string[] settings) { - Contracts.Check(factoryType != null && - typeof(IComponentFactory).IsAssignableFrom(factoryType) && - factoryType.IsGenericType); + if (!TryCreateComponentFactory(factoryType, signatureType, name, settings, out ICommandLineComponentFactory factory)) + { + throw Contracts.ExceptNotImpl("ComponentFactoryFactory can only create IComponentFactory<> types with 4 or less type args."); + } + + return factory; + } + + public static bool TryCreateComponentFactory( + Type factoryType, + Type signatureType, + string name, + string[] settings, + out ICommandLineComponentFactory factory) + { + + if (factoryType == null || + !typeof(IComponentFactory).IsAssignableFrom(factoryType) || + !factoryType.IsGenericType) + { + factory = null; + return false; + } Type componentFactoryType; switch (factoryType.GenericTypeArguments.Length) @@ -2533,14 +2577,17 @@ public static ICommandLineComponentFactory CreateComponentFactory( case 2: componentFactoryType = typeof(ComponentFactory<,>); break; case 3: componentFactoryType = typeof(ComponentFactory<,,>); break; case 4: componentFactoryType = typeof(ComponentFactory<,,,>); break; - default: throw Contracts.ExceptNotImpl("ComponentFactoryFactory can only create component factories with 4 or less type args."); + default: + factory = null; + return false; } - return (ICommandLineComponentFactory)Activator.CreateInstance( + factory = (ICommandLineComponentFactory)Activator.CreateInstance( componentFactoryType.MakeGenericType(factoryType.GenericTypeArguments), signatureType, name, settings); + return true; } private abstract class ComponentFactory : ICommandLineComponentFactory