From da06e4d5aec595fd5e65bb8ff33b8df4884c6005 Mon Sep 17 00:00:00 2001 From: Tom Finley Date: Wed, 11 Jul 2018 15:55:44 -0700 Subject: [PATCH 01/13] Conversion of ITrainer.Train returns predictor, accepts TrainContext * `ITrainer.Train` returns a predictor. There is no `CreatePredictor` method on the interface. * `ITrainer.Train` always accepts a `TrainContext`. Dataset type is no longer a generic parameter. This context object replaces the functionality previously offered by the combination of `ITrainer`, `IValidatingTrainer`, `IIncrementalTrainer`, and `IIncrementalValidatingTrainer`, which is now captured in one `ITrainer.Train` method with differently configured contexts. * All trainers updated to these two new idioms. Many trainers correspondingly improved to no longer be stateful objects. (The exceptions are those that are just too far gone to be done with less than herculean effort at refactoring them to no longer use instance fields for their computation. Most notably, LBFGS and FastTree based trainers.) * Utility code meant to deal with the complexity of the aforementioned `IT/IVT/IIT/IIVT` idiom reduced considerably. * Opportunistic improvements to `ITrainer` implementors where observed. 
--- src/Microsoft.ML.Core/Prediction/ITrainer.cs | 115 +++---- .../Prediction/TrainContext.cs | 50 +++ .../Commands/TrainCommand.cs | 70 +--- src/Microsoft.ML.Data/Training/TrainerBase.cs | 44 +-- .../OutputCombiners/BaseStacking.cs | 7 +- .../OutputCombiners/MultiStacking.cs | 2 +- .../OutputCombiners/RegressionStacking.cs | 2 +- .../OutputCombiners/Stacking.cs | 2 +- .../Trainer/Binary/EnsembleTrainer.cs | 25 +- .../Trainer/EnsembleTrainerBase.cs | 85 +++-- .../MulticlassDataPartitionEnsembleTrainer.cs | 9 +- .../Regression/RegressionEnsembleTrainer.cs | 11 +- src/Microsoft.ML.FastTree/FastTree.cs | 15 +- .../FastTreeClassification.cs | 14 +- src/Microsoft.ML.FastTree/FastTreeRanking.cs | 11 +- .../FastTreeRegression.cs | 22 +- src/Microsoft.ML.FastTree/FastTreeTweedie.cs | 12 +- src/Microsoft.ML.FastTree/GamTrainer.cs | 22 +- .../RandomForestClassification.cs | 21 +- .../RandomForestRegression.cs | 20 +- .../KMeansPlusPlusTrainer.cs | 57 ++-- .../LightGbmBinaryTrainer.cs | 8 +- .../LightGbmMulticlassTrainer.cs | 2 +- .../LightGbmRankingTrainer.cs | 10 +- .../LightGbmRegressionTrainer.cs | 8 +- .../LightGbmTrainerBase.cs | 105 +++--- src/Microsoft.ML.PCA/PcaTrainer.cs | 63 ++-- .../FactorizationMachineTrainer.cs | 61 +--- .../Standard/LinearClassificationTrainer.cs | 302 ++++++++---------- .../LogisticRegression/LbfgsPredictorBase.cs | 28 +- .../LogisticRegression/LogisticRegression.cs | 6 +- .../MulticlassLogisticRegression.cs | 8 +- .../MultiClass/MetaMulticlassTrainer.cs | 22 +- .../MultiClass/MultiClassNaiveBayesTrainer.cs | 18 +- .../Standard/MultiClass/Ova.cs | 7 +- .../Standard/MultiClass/Pkpd.cs | 5 +- .../Standard/OlsLinearRegression.cs | 103 +++--- .../Standard/Online/AveragedPerceptron.cs | 2 +- .../Standard/Online/LinearSvm.cs | 2 +- .../Standard/Online/OnlineGradientDescent.cs | 2 +- .../Standard/Online/OnlineLinear.cs | 52 +-- .../PoissonRegression/PoissonRegression.cs | 6 +- .../Standard/SdcaMultiClass.cs | 78 ++--- .../Standard/SdcaRegression.cs | 
35 +- .../Standard/Simple/SimpleTrainers.cs | 91 ++---- .../Algorithms/SmacSweeper.cs | 5 +- .../LearnerFeatureSelection.cs | 4 +- .../UnitTests/TestEntryPoints.cs | 6 +- .../IrisPlantClassificationTests.cs | 3 +- .../SentimentPredictionTests.cs | 3 +- 50 files changed, 654 insertions(+), 1007 deletions(-) create mode 100644 src/Microsoft.ML.Core/Prediction/TrainContext.cs diff --git a/src/Microsoft.ML.Core/Prediction/ITrainer.cs b/src/Microsoft.ML.Core/Prediction/ITrainer.cs index 6e04a30e6f..972e6fb3a8 100644 --- a/src/Microsoft.ML.Core/Prediction/ITrainer.cs +++ b/src/Microsoft.ML.Core/Prediction/ITrainer.cs @@ -2,9 +2,8 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using System; using System.Collections.Generic; -using System.IO; +using Microsoft.ML.Runtime.Data; namespace Microsoft.ML.Runtime { @@ -56,15 +55,9 @@ public interface ITrainerEx : ITrainer /// Whether this trainer could benefit from a cached view of the data. /// bool WantCaching { get; } - } - - public interface ITrainerHost - { - Random Rand { get; } - int Verbosity { get; } - TextWriter StdOut { get; } - TextWriter StdErr { get; } + bool SupportsValidation { get; } + bool SupportsIncrementalTraining { get; } } // The Trainer (of Factory) can optionally implement this. @@ -77,7 +70,8 @@ public interface IModelCombiner public delegate void SignatureModelCombiner(PredictionKind kind); /// - /// Weakly typed interface for a trainer "session" that produces a predictor. + /// The base interface for a trainers. Implementors should not implement this interface directly, + /// but rather implement the more specific . /// public interface ITrainer { @@ -87,91 +81,60 @@ public interface ITrainer PredictionKind PredictionKind { get; } /// - /// Returns the trained predictor. - /// REVIEW: Consider removing this. 
- /// - IPredictor CreatePredictor(); - } - - /// - /// Interface implemented by the MetalinearLearners base class. - /// Used to distinguish the MetaLinear Learners from the other learners - /// - public interface IMetaLinearTrainer - { - - } - - public interface ITrainer : ITrainer - { - /// - /// Trains a predictor using the specified dataset. + /// Trains a predictor. /// - /// Training dataset - void Train(TDataSet data); + /// A context containing at least the training data + /// The trained predictor + /// + IPredictor Train(TrainContext context); } /// - /// Strongly typed generic interface for a trainer. A trainer object takes - /// supervision data and produces a predictor. + /// Strongly typed generic interface for a trainer. A trainer object takes training data + /// and produces a predictor. /// - /// Type of the training dataset /// Type of predictor produced - public interface ITrainer : ITrainer + public interface ITrainer : ITrainer where TPredictor : IPredictor { /// - /// Returns the trained predictor. + /// Trains a predictor. /// - /// Trained predictor ready to make predictions - new TPredictor CreatePredictor(); + /// A context containing at least the training data + /// The trained predictor + new TPredictor Train(TrainContext context); } - /// - /// Trainers that want data to do their own validation implement this interface. - /// - public interface IValidatingTrainer : ITrainer + public static class TrainerExtensions { /// - /// Trains a predictor using the specified dataset. + /// Convenience train extension for the case where one has only a training set with no auxiliary information. + /// Equivalent to calling + /// on a constructed with . /// - /// Training dataset - /// Validation dataset - void Train(TDataSet data, TDataSet validData); - } + /// The trainer + /// The training data. 
+ /// The trained predictor + public static IPredictor Train(this ITrainer trainer, RoleMappedData trainData) + => trainer.Train(new TrainContext(trainData)); - public interface IIncrementalTrainer : ITrainer - { - /// - /// Trains a predictor using the specified dataset and a trained predictor. - /// - /// Training dataset - /// A trained predictor - void Train(TDataSet data, TPredictor predictor); - } - - public interface IIncrementalValidatingTrainer : ITrainer - { /// - /// Trains a predictor using the specified dataset and a trained predictor. + /// Convenience train extension for the case where one has only a training set with no auxiliary information. + /// Equivalent to calling + /// on a constructed with . /// - /// Training dataset - /// Validation dataset - /// A trained predictor - void Train(TDataSet data, TDataSet validData, TPredictor predictor); + /// The trainer + /// The training data. + /// The trained predictor + public static TPredictor Train(this ITrainer trainer, RoleMappedData trainData) where TPredictor : IPredictor + => trainer.Train(new TrainContext(trainData)); } -#if FUTURE - public interface IMultiTrainer : - IMultiTrainer - { - } - - public interface IMultiTrainer : - ITrainer + /// + /// Interface implemented by the MetalinearLearners base class. + /// Used to distinguish the MetaLinear Learners from the other learners + /// + public interface IMetaLinearTrainer { - void UpdatePredictor(TDataBatch trainInstance); - IPredictor GetCurrentPredictor(); } -#endif } diff --git a/src/Microsoft.ML.Core/Prediction/TrainContext.cs b/src/Microsoft.ML.Core/Prediction/TrainContext.cs new file mode 100644 index 0000000000..a85e1f8fdf --- /dev/null +++ b/src/Microsoft.ML.Core/Prediction/TrainContext.cs @@ -0,0 +1,50 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. 
+ +using Microsoft.ML.Runtime.Data; + +namespace Microsoft.ML.Runtime +{ + /// + /// Instances of this class are meant to be constructed and passed to trainers. + /// + public sealed class TrainContext + { + /// + /// The training set. Cannot be null. + /// + public RoleMappedData Train { get; } + + /// + /// The validation set. Can be null. + /// + public RoleMappedData Validation { get; } + + /// + /// The initial + /// + public IPredictor InitialPredictor { get; } + + + /// + /// Constructor, given a training set and optional other arguments. + /// + /// Will be set to , must be specified + /// Will be set to if specified + /// Will be set to if specified + public TrainContext(RoleMappedData train, RoleMappedData valid = null, IPredictor initPredictor = null) + { + Contracts.CheckValue(train, nameof(train)); + Contracts.CheckValueOrNull(valid); + Contracts.CheckValueOrNull(initPredictor); + + // REVIEW: Should there be code here to ensure that the role mappings between the two are compatible? + // That is, all the role mappings are the same and the columns between them have identical types? 
+ + Train = train; + Validation = valid; + InitialPredictor = initPredictor; + } + } +} diff --git a/src/Microsoft.ML.Data/Commands/TrainCommand.cs b/src/Microsoft.ML.Data/Commands/TrainCommand.cs index b5e3964157..86c8d63397 100644 --- a/src/Microsoft.ML.Data/Commands/TrainCommand.cs +++ b/src/Microsoft.ML.Data/Commands/TrainCommand.cs @@ -252,77 +252,27 @@ private static IPredictor TrainCore(IHostEnvironment env, IChannel ch, RoleMappe ch.CheckValueOrNull(validData); ch.CheckValueOrNull(inpPredictor); - var trainerRmd = trainer as ITrainer; - if (trainerRmd == null) - throw ch.ExceptUserArg(nameof(TrainCommand.Arguments.Trainer), "Trainer '{0}' does not accept known training data type", name); - - Action, object, object, object> trainCoreAction = TrainCore; - IPredictor predictor; AddCacheIfWanted(env, ch, trainer, ref data, cacheData); ch.Trace("Training"); if (validData != null) AddCacheIfWanted(env, ch, trainer, ref validData, cacheData); - var genericExam = trainCoreAction.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod( - typeof(RoleMappedData), - inpPredictor != null ? 
inpPredictor.GetType() : typeof(IPredictor)); - Action trainExam = trainerRmd.Train; - genericExam.Invoke(null, new object[] { ch, trainerRmd, trainExam, data, validData, inpPredictor }); - - ch.Trace("Constructing predictor"); - predictor = trainerRmd.CreatePredictor(); + var trainerEx = trainer as ITrainerEx; + if (inpPredictor != null && trainerEx?.SupportsIncrementalTraining != true) + { + ch.Warning("Ignoring " + nameof(TrainCommand.Arguments.InputModelFile) + + ": Trainer does not support incremental training."); + inpPredictor = null; + } + ch.Assert(validData == null || CanUseValidationData(trainer)); + var predictor = trainer.Train(new TrainContext(data, validData, inpPredictor)); return CalibratorUtils.TrainCalibratorIfNeeded(env, ch, calibrator, maxCalibrationExamples, trainer, predictor, data); } public static bool CanUseValidationData(ITrainer trainer) { Contracts.CheckValue(trainer, nameof(trainer)); - - if (trainer is ITrainer) - return trainer is IValidatingTrainer; - - return false; - } - - private static void TrainCore(IChannel ch, ITrainer trainer, Action train, TDataSet data, TDataSet validData = null, TPredictor predictor = null) - where TDataSet : class - where TPredictor : class - { - const string inputModelArg = nameof(TrainCommand.Arguments.InputModelFile); - if (validData != null) - { - if (predictor != null) - { - var incValidTrainer = trainer as IIncrementalValidatingTrainer; - if (incValidTrainer != null) - { - incValidTrainer.Train(data, validData, predictor); - return; - } - - ch.Warning("Ignoring " + inputModelArg + ": Trainer is not an incremental trainer."); - } - - var validTrainer = trainer as IValidatingTrainer; - ch.AssertValue(validTrainer); - validTrainer.Train(data, validData); - } - else - { - if (predictor != null) - { - var incTrainer = trainer as IIncrementalTrainer; - if (incTrainer != null) - { - incTrainer.Train(data, predictor); - return; - } - - ch.Warning("Ignoring " + inputModelArg + ": Trainer is not an 
incremental trainer."); - } - - train(data); - } + return (trainer as ITrainerEx)?.SupportsValidation ?? false; } public static bool TryLoadPredictor(IChannel ch, IHostEnvironment env, string inputModelFile, out IPredictor inputPredictor) diff --git a/src/Microsoft.ML.Data/Training/TrainerBase.cs b/src/Microsoft.ML.Data/Training/TrainerBase.cs index 90f8b64a7c..f24f77292c 100644 --- a/src/Microsoft.ML.Data/Training/TrainerBase.cs +++ b/src/Microsoft.ML.Data/Training/TrainerBase.cs @@ -4,11 +4,12 @@ namespace Microsoft.ML.Runtime.Training { - public abstract class TrainerBase : ITrainer, ITrainerEx + public abstract class TrainerBase : ITrainer, ITrainerEx + where TPredictor : IPredictor { public const string NoTrainingInstancesMessage = "No valid training instances found, all instances have missing features."; - protected readonly IHost Host; + protected IHost Host { get; } public string Name { get; } public abstract PredictionKind PredictionKind { get; } @@ -16,47 +17,20 @@ public abstract class TrainerBase : ITrainer, ITrainerEx public abstract bool NeedCalibration { get; } public abstract bool WantCaching { get; } + public virtual bool SupportsValidation => false; + public virtual bool SupportsIncrementalTraining => false; + protected TrainerBase(IHostEnvironment env, string name) { Contracts.CheckValue(env, nameof(env)); - Contracts.CheckNonEmpty(name, nameof(name)); + env.CheckNonEmpty(name, nameof(name)); Name = name; Host = env.Register(name); } - IPredictor ITrainer.CreatePredictor() - { - return CreatePredictorCore(); - } - - protected abstract IPredictor CreatePredictorCore(); - } - - public abstract class TrainerBase : TrainerBase - where TPredictor : IPredictor - { - protected TrainerBase(IHostEnvironment env, string name) - : base(env, name) - { - } - - public abstract TPredictor CreatePredictor(); - - protected sealed override IPredictor CreatePredictorCore() - { - return CreatePredictor(); - } - } - - public abstract class TrainerBase : TrainerBase, 
ITrainer - where TPredictor : IPredictor - { - protected TrainerBase(IHostEnvironment env, string name) - : base(env, name) - { - } + IPredictor ITrainer.Train(TrainContext context) => Train(context); - public abstract void Train(TDataSet data); + public abstract TPredictor Train(TrainContext context); } } diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/BaseStacking.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/BaseStacking.cs index f49e3af81c..f38aa2329f 100644 --- a/src/Microsoft.ML.Ensemble/OutputCombiners/BaseStacking.cs +++ b/src/Microsoft.ML.Ensemble/OutputCombiners/BaseStacking.cs @@ -27,10 +27,10 @@ public abstract class ArgumentsBase [Argument(ArgumentType.Multiple, HelpText = "Base predictor for meta learning", ShortName = "bp", SortOrder = 50, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly)] [TGUI(Label = "Base predictor")] - public SubComponent>, TSigBase> BasePredictorType; + public SubComponent>, TSigBase> BasePredictorType; } - protected readonly SubComponent>, TSigBase> BasePredictorType; + protected readonly SubComponent>, TSigBase> BasePredictorType; protected readonly IHost Host; protected IPredictorProducing Meta; @@ -190,8 +190,7 @@ public void Train(List>> models, var trainer = BasePredictorType.CreateInstance(host); if (trainer is ITrainerEx ex && ex.NeedNormalization) ch.Warning("The trainer specified for stacking wants normalization, but we do not currently allow this."); - trainer.Train(rmd); - Meta = trainer.CreatePredictor(); + Meta = trainer.Train(rmd); CheckMeta(); ch.Done(); diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/MultiStacking.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/MultiStacking.cs index 588dd89508..2ef74c8169 100644 --- a/src/Microsoft.ML.Ensemble/OutputCombiners/MultiStacking.cs +++ b/src/Microsoft.ML.Ensemble/OutputCombiners/MultiStacking.cs @@ -43,7 +43,7 @@ public sealed class Arguments : ArgumentsBase, ISupportMulticlassOutputCombinerF public Arguments() { // REVIEW: Perhaps we can have 
a better non-parametetric learner. - BasePredictorType = new SubComponent, SignatureMultiClassClassifierTrainer>( + BasePredictorType = new SubComponent, SignatureMultiClassClassifierTrainer>( "OVA", "p=FastTreeBinaryClassification"); } } diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/RegressionStacking.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/RegressionStacking.cs index aeb011a51b..0b5f8e6057 100644 --- a/src/Microsoft.ML.Ensemble/OutputCombiners/RegressionStacking.cs +++ b/src/Microsoft.ML.Ensemble/OutputCombiners/RegressionStacking.cs @@ -39,7 +39,7 @@ public sealed class Arguments : ArgumentsBase, ISupportRegressionOutputCombinerF { public Arguments() { - BasePredictorType = new SubComponent, SignatureRegressorTrainer>("FastTreeRegression"); + BasePredictorType = new SubComponent, SignatureRegressorTrainer>("FastTreeRegression"); } public IRegressionOutputCombiner CreateComponent(IHostEnvironment env) => new RegressionStacking(env, this); diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/Stacking.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/Stacking.cs index afd6e3f958..f3481e9936 100644 --- a/src/Microsoft.ML.Ensemble/OutputCombiners/Stacking.cs +++ b/src/Microsoft.ML.Ensemble/OutputCombiners/Stacking.cs @@ -37,7 +37,7 @@ public sealed class Arguments : ArgumentsBase, ISupportBinaryOutputCombinerFacto { public Arguments() { - BasePredictorType = new SubComponent, SignatureBinaryClassifierTrainer>("FastTreeBinaryClassification"); + BasePredictorType = new SubComponent, SignatureBinaryClassifierTrainer>("FastTreeBinaryClassification"); } public IBinaryOutputCombiner CreateComponent(IHostEnvironment env) => new Stacking(env, this); diff --git a/src/Microsoft.ML.Ensemble/Trainer/Binary/EnsembleTrainer.cs b/src/Microsoft.ML.Ensemble/Trainer/Binary/EnsembleTrainer.cs index 80fe4cbdee..7b3d7397e8 100644 --- a/src/Microsoft.ML.Ensemble/Trainer/Binary/EnsembleTrainer.cs +++ b/src/Microsoft.ML.Ensemble/Trainer/Binary/EnsembleTrainer.cs @@ -46,7 
+46,7 @@ public sealed class Arguments : ArgumentsBase public Arguments() { - BasePredictors = new[] { new SubComponent, SignatureBinaryClassifierTrainer>("LinearSVM") }; + BasePredictors = new[] { new SubComponent, SignatureBinaryClassifierTrainer>("LinearSVM") }; } } @@ -60,16 +60,13 @@ public EnsembleTrainer(IHostEnvironment env, Arguments args) Combiner = args.OutputCombiner.CreateComponent(Host); } - public override PredictionKind PredictionKind - { - get { return PredictionKind.BinaryClassification; } - } + public override PredictionKind PredictionKind => PredictionKind.BinaryClassification; - public override TScalarPredictor CreatePredictor() + protected internal override TScalarPredictor CreatePredictor(List> models) { - if (Models.All(m => m.Predictor is TDistPredictor)) - return new EnsembleDistributionPredictor(Host, PredictionKind, CreateModels(), Combiner); - return new EnsemblePredictor(Host, PredictionKind, CreateModels(), Combiner); + if (models.All(m => m.Predictor is TDistPredictor)) + return new EnsembleDistributionPredictor(Host, PredictionKind, CreateModels(models), Combiner); + return new EnsemblePredictor(Host, PredictionKind, CreateModels(models), Combiner); } public TScalarPredictor CombineModels(IEnumerable models) @@ -77,19 +74,13 @@ public TScalarPredictor CombineModels(IEnumerable models) var combiner = _outputCombiner.CreateComponent(Host); var p = models.First(); - TScalarPredictor predictor = null; if (p is TDistPredictor) { - predictor = new EnsembleDistributionPredictor(Host, p.PredictionKind, + return new EnsembleDistributionPredictor(Host, p.PredictionKind, models.Select(k => new FeatureSubsetModel((TDistPredictor)k)).ToArray(), combiner); } - else - { - predictor = new EnsemblePredictor(Host, p.PredictionKind, + return new EnsemblePredictor(Host, p.PredictionKind, models.Select(k => new FeatureSubsetModel(k)).ToArray(), combiner); - } - - return predictor; } } } \ No newline at end of file diff --git 
a/src/Microsoft.ML.Ensemble/Trainer/EnsembleTrainerBase.cs b/src/Microsoft.ML.Ensemble/Trainer/EnsembleTrainerBase.cs index 776b1f5f53..810eada855 100644 --- a/src/Microsoft.ML.Ensemble/Trainer/EnsembleTrainerBase.cs +++ b/src/Microsoft.ML.Ensemble/Trainer/EnsembleTrainerBase.cs @@ -20,7 +20,7 @@ namespace Microsoft.ML.Runtime.Ensemble { using Stopwatch = System.Diagnostics.Stopwatch; - public abstract class EnsembleTrainerBase : TrainerBase + public abstract class EnsembleTrainerBase : TrainerBase where TPredictor : class, IPredictorProducing where TSelector : class, ISubModelSelector where TCombiner : class, IOutputCombiner @@ -54,27 +54,22 @@ public abstract class ArgumentsBase : LearnerInputBaseWithLabel public bool ShowMetrics; [Argument(ArgumentType.Multiple, HelpText = "Base predictor type", ShortName = "bp,basePredictorTypes", SortOrder = 1, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly)] - public SubComponent>, TSig>[] BasePredictors; + public SubComponent>, TSig>[] BasePredictors; } private const int DefaultNumModels = 50; /// Command-line arguments - protected readonly ArgumentsBase Args; - protected readonly int NumModels; + protected internal readonly ArgumentsBase Args; + protected internal readonly int NumModels; /// Ensemble members - protected readonly ITrainer>[] Trainers; + protected internal readonly ITrainer>[] Trainers; private readonly ISubsetSelector _subsetSelector; - protected ISubModelSelector SubModelSelector; - protected IOutputCombiner Combiner; + protected internal ISubModelSelector SubModelSelector; + protected internal IOutputCombiner Combiner; - protected List>> Models; - - private readonly bool _needNorm; - private readonly bool _needCalibration; - - internal EnsembleTrainerBase(ArgumentsBase args, IHostEnvironment env, string name) + protected internal EnsembleTrainerBase(ArgumentsBase args, IHostEnvironment env, string name) : base(env, name) { Args = args; @@ -93,41 +88,36 @@ internal 
EnsembleTrainerBase(ArgumentsBase args, IHostEnvironment env, string na _subsetSelector = Args.SamplingType.CreateComponent(Host); - Trainers = new ITrainer>[NumModels]; + Trainers = new ITrainer>[NumModels]; for (int i = 0; i < Trainers.Length; i++) Trainers[i] = Args.BasePredictors[i % Args.BasePredictors.Length].CreateInstance(Host); - _needNorm = Trainers.Any( - t => - { - return t is ITrainerEx nn && nn.NeedNormalization; - }); - _needCalibration = Trainers.Any( - t => - { - return t is ITrainerEx nn && nn.NeedCalibration; - }); + NeedNormalization = Trainers.Any(t => t is ITrainerEx nn && nn.NeedNormalization); + NeedCalibration = Trainers.Any(t => t is ITrainerEx nn && nn.NeedCalibration); ch.Done(); } } - public override bool NeedNormalization => _needNorm; + public override bool NeedNormalization { get; } - public override bool NeedCalibration => _needCalibration; + public override bool NeedCalibration { get; } // No matter the internal predictors, we are performing multiple passes over the data // so it is probably appropriate to always cache. 
public override bool WantCaching => true; - public override void Train(RoleMappedData data) + public sealed override TPredictor Train(TrainContext context) { + Host.CheckValue(context, nameof(context)); + using (var ch = Host.Start("Training")) { - TrainCore(ch, data); + var pred = TrainCore(ch, context.Train); ch.Done(); + return pred; } } - private void TrainCore(IChannel ch, RoleMappedData data) + private TPredictor TrainCore(IChannel ch, RoleMappedData data) { Host.AssertValue(ch); ch.AssertValue(data); @@ -143,6 +133,7 @@ private void TrainCore(IChannel ch, RoleMappedData data) validationDataSetProportion = Math.Max(validationDataSetProportion, stackingTrainer.ValidationDatasetProportion); var needMetrics = Args.ShowMetrics || Combiner is IWeightedAverager; + var Models = new List>>(); _subsetSelector.Initialize(data, NumModels, Args.BatchSize, validationDataSetProportion); int batchNumber = 1; @@ -150,7 +141,7 @@ private void TrainCore(IChannel ch, RoleMappedData data) { // 2. Core train ch.Info("Training {0} learners for the batch {1}", Trainers.Length, batchNumber++); - var models = new FeatureSubsetModel>[Trainers.Length]; + var batchModels = new FeatureSubsetModel>[Trainers.Length]; Parallel.ForEach(_subsetSelector.GetSubsets(batch, Host.Rand), new ParallelOptions() { MaxDegreeOfParallelism = Args.TrainParallel ? 
-1 : 1 }, @@ -162,26 +153,24 @@ private void TrainCore(IChannel ch, RoleMappedData data) { if (EnsureMinimumFeaturesSelected(subset)) { - Trainers[(int)index].Train(subset.Data); - var model = new FeatureSubsetModel>( - Trainers[(int)index].CreatePredictor(), + Trainers[(int)index].Train(subset.Data), subset.SelectedFeatures, null); SubModelSelector.CalculateMetrics(model, _subsetSelector, subset, batch, needMetrics); - models[(int)index] = model; + batchModels[(int)index] = model; } } catch (Exception ex) { - ch.Assert(models[(int)index] == null); + ch.Assert(batchModels[(int)index] == null); ch.Warning(ex.Sensitivity(), "Trainer {0} of {1} was not learned properly due to the exception '{2}' and will not be added to models.", index + 1, Trainers.Length, ex.Message); } ch.Info("Trainer {0} of {1} finished in {2}", index + 1, Trainers.Length, sw.Elapsed); }); - var modelsList = models.Where(m => m != null).ToList(); + var modelsList = batchModels.Where(m => m != null).ToList(); if (Args.ShowMetrics) PrintMetrics(ch, modelsList); @@ -190,15 +179,17 @@ private void TrainCore(IChannel ch, RoleMappedData data) if (stackingTrainer != null) stackingTrainer.Train(modelsList, _subsetSelector.GetTestData(null, batch), Host); - foreach (var model in modelsList) - Utils.Add(ref Models, model); + Models.AddRange(modelsList); int modelSize = Utils.Size(Models); if (modelSize < Utils.Size(Trainers)) ch.Warning("{0} of {1} trainings failed.", Utils.Size(Trainers) - modelSize, Utils.Size(Trainers)); ch.Check(modelSize > 0, "Ensemble training resulted in no valid models."); } + return CreatePredictor(Models); } + protected internal abstract TPredictor CreatePredictor(List>> models); + private bool EnsureMinimumFeaturesSelected(Subset subset) { if (subset.SelectedFeatures == null) @@ -212,7 +203,7 @@ private bool EnsureMinimumFeaturesSelected(Subset subset) return false; } - protected virtual void PrintMetrics(IChannel ch, List>> models) + protected internal virtual void 
PrintMetrics(IChannel ch, List>> models) { // REVIEW: The formatting of this method is bizarre and seemingly not even self-consistent // w.r.t. its usage of |. Is this intentional? @@ -225,17 +216,17 @@ protected virtual void PrintMetrics(IChannel ch, List string.Format("| {0} |", m.Value))), model.Predictor.GetType().Name); } - protected FeatureSubsetModel[] CreateModels() where T : IPredictor + protected internal static FeatureSubsetModel[] CreateModels(List>> models) where T : IPredictor { - var models = new FeatureSubsetModel[Models.Count]; - for (int i = 0; i < Models.Count; i++) + var subsetModels = new FeatureSubsetModel[models.Count]; + for (int i = 0; i < models.Count; i++) { - models[i] = new FeatureSubsetModel( - (T)Models[i].Predictor, - Models[i].SelectedFeatures, - Models[i].Metrics); + subsetModels[i] = new FeatureSubsetModel( + (T)models[i].Predictor, + models[i].SelectedFeatures, + models[i].Metrics); } - return models; + return subsetModels; } } } diff --git a/src/Microsoft.ML.Ensemble/Trainer/Multiclass/MulticlassDataPartitionEnsembleTrainer.cs b/src/Microsoft.ML.Ensemble/Trainer/Multiclass/MulticlassDataPartitionEnsembleTrainer.cs index 0e6b4f6a53..0e4f0a043e 100644 --- a/src/Microsoft.ML.Ensemble/Trainer/Multiclass/MulticlassDataPartitionEnsembleTrainer.cs +++ b/src/Microsoft.ML.Ensemble/Trainer/Multiclass/MulticlassDataPartitionEnsembleTrainer.cs @@ -47,7 +47,7 @@ public sealed class Arguments : ArgumentsBase public Arguments() { - BasePredictors = new[] { new SubComponent, SignatureMultiClassClassifierTrainer>("MultiClassLogisticRegression") }; + BasePredictors = new[] { new SubComponent, SignatureMultiClassClassifierTrainer>("MultiClassLogisticRegression") }; } } @@ -61,12 +61,11 @@ public MulticlassDataPartitionEnsembleTrainer(IHostEnvironment env, Arguments ar Combiner = args.OutputCombiner.CreateComponent(Host); } - public override PredictionKind PredictionKind { get { return PredictionKind.MultiClassClassification; } } + public override 
PredictionKind PredictionKind => PredictionKind.MultiClassClassification; - public override EnsembleMultiClassPredictor CreatePredictor() + protected internal override EnsembleMultiClassPredictor CreatePredictor(List> models) { - var combiner = Combiner; - return new EnsembleMultiClassPredictor(Host, CreateModels(), combiner as IMultiClassOutputCombiner); + return new EnsembleMultiClassPredictor(Host, CreateModels(models), Combiner as IMultiClassOutputCombiner); } public TVectorPredictor CombineModels(IEnumerable models) diff --git a/src/Microsoft.ML.Ensemble/Trainer/Regression/RegressionEnsembleTrainer.cs b/src/Microsoft.ML.Ensemble/Trainer/Regression/RegressionEnsembleTrainer.cs index 322c1e02a1..bf7671d60d 100644 --- a/src/Microsoft.ML.Ensemble/Trainer/Regression/RegressionEnsembleTrainer.cs +++ b/src/Microsoft.ML.Ensemble/Trainer/Regression/RegressionEnsembleTrainer.cs @@ -41,7 +41,7 @@ public sealed class Arguments : ArgumentsBase public Arguments() { - BasePredictors = new[] { new SubComponent, SignatureRegressorTrainer>("OnlineGradientDescent") }; + BasePredictors = new[] { new SubComponent, SignatureRegressorTrainer>("OnlineGradientDescent") }; } } @@ -55,14 +55,11 @@ public RegressionEnsembleTrainer(IHostEnvironment env, Arguments args) Combiner = args.OutputCombiner.CreateComponent(Host); } - public override PredictionKind PredictionKind - { - get { return PredictionKind.Regression; } - } + public override PredictionKind PredictionKind => PredictionKind.Regression; - public override TScalarPredictor CreatePredictor() + protected internal override TScalarPredictor CreatePredictor(List> models) { - return new EnsemblePredictor(Host, PredictionKind, CreateModels(), Combiner); + return new EnsemblePredictor(Host, PredictionKind, CreateModels(models), Combiner); } public TScalarPredictor CombineModels(IEnumerable models) diff --git a/src/Microsoft.ML.FastTree/FastTree.cs b/src/Microsoft.ML.FastTree/FastTree.cs index e7432139c5..25f44712e6 100644 --- 
a/src/Microsoft.ML.FastTree/FastTree.cs +++ b/src/Microsoft.ML.FastTree/FastTree.cs @@ -44,8 +44,7 @@ internal static class FastTreeShared } public abstract class FastTreeTrainerBase : - TrainerBase, - IValidatingTrainer + TrainerBase where TArgs : TreeArgs, new() where TPredictor : IPredictorProducing { @@ -88,6 +87,8 @@ public abstract class FastTreeTrainerBase : public bool HasCategoricalFeatures => Utils.Size(CategoricalFeatures) > 0; + public override bool SupportsValidation => true; + protected internal FastTreeTrainerBase(IHostEnvironment env, TArgs args) : base(env, RegisterName) { @@ -125,14 +126,6 @@ protected internal FastTreeTrainerBase(IHostEnvironment env, TArgs args) protected abstract ObjectiveFunctionBase ConstructObjFunc(IChannel ch); - public void Train(RoleMappedData trainData, RoleMappedData validationData) - { - // REVIEW: Idiotic. This should be reversed... the other train method should - // be put in here, rather than having this "hidden argument" through an instance field. 
- ValidData = validationData; - Train(trainData); - } - protected virtual Float GetMaxLabel() { return Float.PositiveInfinity; @@ -1887,7 +1880,7 @@ private void MakeBoundariesAndCheckLabels(out long missingInstances, out long to missingInstances = cursor.BadFeaturesRowCount; } - ch.Check(totalInstances > 0, TrainerBase.NoTrainingInstancesMessage); + ch.Check(totalInstances > 0, TrainerBase.NoTrainingInstancesMessage); if (missingInstances > 0) ch.Warning("Skipped {0} instances with missing features during training", missingInstances); diff --git a/src/Microsoft.ML.FastTree/FastTreeClassification.cs b/src/Microsoft.ML.FastTree/FastTreeClassification.cs index 18f61e6dbe..db266ac71c 100644 --- a/src/Microsoft.ML.FastTree/FastTreeClassification.cs +++ b/src/Microsoft.ML.FastTree/FastTreeClassification.cs @@ -118,10 +118,14 @@ public FastTreeBinaryClassificationTrainer(IHostEnvironment env, Arguments args) public override bool NeedCalibration => false; - public override PredictionKind PredictionKind { get { return PredictionKind.BinaryClassification; } } + public override PredictionKind PredictionKind => PredictionKind.BinaryClassification; - public override void Train(RoleMappedData trainData) + public override IPredictorWithFeatureWeights Train(TrainContext context) { + Host.CheckValue(context, nameof(context)); + var trainData = context.Train; + ValidData = context.Validation; + using (var ch = Host.Start("Training")) { ch.CheckValue(trainData, nameof(trainData)); @@ -133,12 +137,6 @@ public override void Train(RoleMappedData trainData) TrainCore(ch); ch.Done(); } - } - - public override IPredictorWithFeatureWeights CreatePredictor() - { - Host.Check(TrainedEnsemble != null, - "The predictor cannot be created before training is complete"); // The FastTree binary classification boosting is naturally calibrated to // output probabilities when transformed using a scaled logistic function, diff --git a/src/Microsoft.ML.FastTree/FastTreeRanking.cs 
b/src/Microsoft.ML.FastTree/FastTreeRanking.cs index e0b6933726..620a6dedb2 100644 --- a/src/Microsoft.ML.FastTree/FastTreeRanking.cs +++ b/src/Microsoft.ML.FastTree/FastTreeRanking.cs @@ -65,8 +65,12 @@ protected override float GetMaxLabel() return GetLabelGains().Length - 1; } - public override void Train(RoleMappedData trainData) + public override FastTreeRankingPredictor Train(TrainContext context) { + Host.CheckValue(context, nameof(context)); + var trainData = context.Train; + ValidData = context.Validation; + using (var ch = Host.Start("Training")) { var maxLabel = GetLabelGains().Length - 1; @@ -75,11 +79,6 @@ public override void Train(RoleMappedData trainData) FeatureCount = trainData.Schema.Feature.Type.ValueCount; ch.Done(); } - } - - public override FastTreeRankingPredictor CreatePredictor() - { - Host.Check(TrainedEnsemble != null, "The predictor cannot be created before training is complete"); return new FastTreeRankingPredictor(Host, TrainedEnsemble, FeatureCount, InnerArgs); } diff --git a/src/Microsoft.ML.FastTree/FastTreeRegression.cs b/src/Microsoft.ML.FastTree/FastTreeRegression.cs index 308437440a..719bdc781c 100644 --- a/src/Microsoft.ML.FastTree/FastTreeRegression.cs +++ b/src/Microsoft.ML.FastTree/FastTreeRegression.cs @@ -43,23 +43,23 @@ public sealed partial class FastTreeRegressionTrainer : BoostingFastTreeTrainerB private Test _trainRegressionTest; private Test _testRegressionTest; + public override bool NeedCalibration => false; + + public override PredictionKind PredictionKind => PredictionKind.Regression; + public FastTreeRegressionTrainer(IHostEnvironment env, Arguments args) : base(env, args) { } - public override bool NeedCalibration + public override FastTreeRegressionPredictor Train(TrainContext context) { - get { return false; } - } - - public override PredictionKind PredictionKind { get { return PredictionKind.Regression; } } + Host.CheckValue(context, nameof(context)); + var trainData = context.Train; + ValidData = 
context.Validation; - public override void Train(RoleMappedData trainData) - { using (var ch = Host.Start("Training")) { - ch.CheckValue(trainData, nameof(trainData)); trainData.CheckRegressionLabel(); trainData.CheckFeatureFloatVector(); trainData.CheckOptFloatWeight(); @@ -68,12 +68,6 @@ public override void Train(RoleMappedData trainData) TrainCore(ch); ch.Done(); } - } - - public override FastTreeRegressionPredictor CreatePredictor() - { - Host.Check(TrainedEnsemble != null, - "The predictor cannot be created before training is complete"); return new FastTreeRegressionPredictor(Host, TrainedEnsemble, FeatureCount, InnerArgs); } diff --git a/src/Microsoft.ML.FastTree/FastTreeTweedie.cs b/src/Microsoft.ML.FastTree/FastTreeTweedie.cs index b43c499a44..d26c02adc1 100644 --- a/src/Microsoft.ML.FastTree/FastTreeTweedie.cs +++ b/src/Microsoft.ML.FastTree/FastTreeTweedie.cs @@ -55,8 +55,12 @@ public FastTreeTweedieTrainer(IHostEnvironment env, Arguments args) Host.CheckUserArg(1 <= Args.Index && Args.Index <= 2, nameof(Args.Index), "Must be in the range [1, 2]"); } - public override void Train(RoleMappedData trainData) + public override FastTreeTweediePredictor Train(TrainContext context) { + Host.CheckValue(context, nameof(context)); + var trainData = context.Train; + ValidData = context.Validation; + using (var ch = Host.Start("Training")) { ch.CheckValue(trainData, nameof(trainData)); @@ -68,12 +72,6 @@ public override void Train(RoleMappedData trainData) TrainCore(ch); ch.Done(); } - } - - public override FastTreeTweediePredictor CreatePredictor() - { - Host.Check(TrainedEnsemble != null, - "The predictor cannot be created before training is complete"); return new FastTreeTweediePredictor(Host, TrainedEnsemble, FeatureCount, InnerArgs); } diff --git a/src/Microsoft.ML.FastTree/GamTrainer.cs b/src/Microsoft.ML.FastTree/GamTrainer.cs index 931dd2335b..dac3eb11d5 100644 --- a/src/Microsoft.ML.FastTree/GamTrainer.cs +++ b/src/Microsoft.ML.FastTree/GamTrainer.cs @@ 
-75,7 +75,7 @@ internal override void CheckLabel(RoleMappedData data) data.CheckRegressionLabel(); } - public override RegressionGamPredictor CreatePredictor() + protected internal override RegressionGamPredictor CreatePredictor() { return new RegressionGamPredictor(Host, InputLength, TrainSet, BinEffects, FeatureMap); } @@ -137,7 +137,7 @@ private bool[] ConvertTargetsToBool(double[] targets) return boolArray; } - public override BinaryClassGamPredictor CreatePredictor() + protected internal override BinaryClassGamPredictor CreatePredictor() { return new BinaryClassGamPredictor(Host, InputLength, TrainSet, BinEffects, FeatureMap); } @@ -152,9 +152,7 @@ protected override ObjectiveFunctionBase CreateObjectiveFunction() /// /// Generalized Additive Model Learner. /// - public abstract partial class GamTrainerBase : - TrainerBase, - ITrainer + public abstract partial class GamTrainerBase : TrainerBase where TArgs : GamTrainerBase.ArgumentsBase, new() where TPredictor : GamPredictorBase { @@ -233,7 +231,7 @@ public abstract class ArgumentsBase : LearnerInputBaseWithWeight public override bool WantCaching => false; - public GamTrainerBase(IHostEnvironment env, TArgs args) + protected internal GamTrainerBase(IHostEnvironment env, TArgs args) : base(env, RegisterName) { Contracts.CheckValue(env, nameof(env)); @@ -264,18 +262,22 @@ public GamTrainerBase(IHostEnvironment env, TArgs args) InitializeThreads(numThreads); } - public override void Train(RoleMappedData trainData) + public sealed override TPredictor Train(TrainContext context) { using (var ch = Host.Start("Training")) { - ch.CheckValue(trainData, nameof(trainData)); - ConvertData(trainData); - InputLength = trainData.Schema.Feature.Type.ValueCount; + ch.CheckValue(context, nameof(context)); + ConvertData(context.Train); + InputLength = context.Train.Schema.Feature.Type.ValueCount; TrainCore(ch); + var pred = CreatePredictor(); ch.Done(); + return pred; } } + protected internal abstract TPredictor 
CreatePredictor(); + internal abstract void CheckLabel(RoleMappedData data); private void ConvertData(RoleMappedData trainData) diff --git a/src/Microsoft.ML.FastTree/RandomForestClassification.cs b/src/Microsoft.ML.FastTree/RandomForestClassification.cs index 8cd62ceb77..8affec5a2e 100644 --- a/src/Microsoft.ML.FastTree/RandomForestClassification.cs +++ b/src/Microsoft.ML.FastTree/RandomForestClassification.cs @@ -134,15 +134,15 @@ public FastForestClassification(IHostEnvironment env, Arguments args) { } - public override bool NeedCalibration - { - get { return true; } - } + public override bool NeedCalibration => true; + public override PredictionKind PredictionKind => PredictionKind.BinaryClassification; - public override PredictionKind PredictionKind { get { return PredictionKind.BinaryClassification; } } - - public override void Train(RoleMappedData trainData) + public override IPredictorWithFeatureWeights Train(TrainContext context) { + Host.CheckValue(context, nameof(context)); + var trainData = context.Train; + ValidData = context.Validation; + using (var ch = Host.Start("Training")) { ch.CheckValue(trainData, nameof(trainData)); @@ -154,13 +154,6 @@ public override void Train(RoleMappedData trainData) TrainCore(ch); ch.Done(); } - } - - public override IPredictorWithFeatureWeights CreatePredictor() - { - Host.Check(TrainedEnsemble != null, - "The predictor cannot be created before training is complete"); - // LogitBoost is naturally calibrated to // output probabilities when transformed using // the logistic function, so if we have trained no diff --git a/src/Microsoft.ML.FastTree/RandomForestRegression.cs b/src/Microsoft.ML.FastTree/RandomForestRegression.cs index f501037df3..d1265a06bd 100644 --- a/src/Microsoft.ML.FastTree/RandomForestRegression.cs +++ b/src/Microsoft.ML.FastTree/RandomForestRegression.cs @@ -158,15 +158,16 @@ public FastForestRegression(IHostEnvironment env, Arguments args) { } - public override bool NeedCalibration - { - get { return 
false; } - } + public override bool NeedCalibration => false; - public override PredictionKind PredictionKind { get { return PredictionKind.Regression; } } + public override PredictionKind PredictionKind => PredictionKind.Regression; - public override void Train(RoleMappedData trainData) + public override FastForestRegressionPredictor Train(TrainContext context) { + Host.CheckValue(context, nameof(context)); + var trainData = context.Train; + ValidData = context.Validation; + using (var ch = Host.Start("Training")) { ch.CheckValue(trainData, nameof(trainData)); @@ -178,13 +179,6 @@ public override void Train(RoleMappedData trainData) TrainCore(ch); ch.Done(); } - } - - public override FastForestRegressionPredictor CreatePredictor() - { - Host.Check(TrainedEnsemble != null, - "The predictor cannot be created before training is complete"); - return new FastForestRegressionPredictor(Host, TrainedEnsemble, FeatureCount, InnerArgs, Args.QuantileSampleCount); } diff --git a/src/Microsoft.ML.KMeansClustering/KMeansPlusPlusTrainer.cs b/src/Microsoft.ML.KMeansClustering/KMeansPlusPlusTrainer.cs index dce7be48d2..8571ca10b7 100644 --- a/src/Microsoft.ML.KMeansClustering/KMeansPlusPlusTrainer.cs +++ b/src/Microsoft.ML.KMeansClustering/KMeansPlusPlusTrainer.cs @@ -29,7 +29,7 @@ namespace Microsoft.ML.Runtime.KMeans { /// - public class KMeansPlusPlusTrainer : TrainerBase + public class KMeansPlusPlusTrainer : TrainerBase { public const string LoadNameValue = "KMeansPlusPlus"; internal const string UserNameValue = "KMeans++ Clustering"; @@ -74,11 +74,6 @@ public class Arguments : UnsupervisedLearnerInputBaseWithWeight } private readonly int _k; - private int _dimensionality; - - // The coordinates of the final centroids at the end of the training. During training - // it holds the centroids of the previous iteration. 
- private readonly VBuffer[] _centroids; private readonly int _maxIterations; // max number of iterations to train private readonly Float _convergenceThreshold; // convergence thresholds @@ -101,8 +96,6 @@ public KMeansPlusPlusTrainer(IHostEnvironment env, Arguments args) Host.CheckUserArg(args.OptTol > 0, nameof(args.OptTol), "Tolerance must be positive"); _convergenceThreshold = args.OptTol; - _centroids = new VBuffer[_k]; - Host.CheckUserArg(args.AccelMemBudgetMb > 0, nameof(args.AccelMemBudgetMb), "Must be positive"); _accelMemBudgetMb = args.AccelMemBudgetMb; @@ -118,29 +111,35 @@ public KMeansPlusPlusTrainer(IHostEnvironment env, Arguments args) public override bool WantCaching => true; public override PredictionKind PredictionKind => PredictionKind.Clustering; - public override void Train(RoleMappedData data) + public override KMeansPredictor Train(TrainContext context) { - Host.CheckValue(data, nameof(data)); + Host.CheckValue(context, nameof(context)); + var data = context.Train; - data.CheckFeatureFloatVector(out _dimensionality); - Contracts.Assert(_dimensionality > 0); + data.CheckFeatureFloatVector(out int dimensionality); + Contracts.Assert(dimensionality > 0); using (var ch = Host.Start("Training")) { - TrainCore(ch, data); + var pred = TrainCore(ch, data, dimensionality); ch.Done(); + return pred; } } - private void TrainCore(IChannel ch, RoleMappedData data) + private KMeansPredictor TrainCore(IChannel ch, RoleMappedData data, int dimensionality) { Host.AssertValue(ch); ch.AssertValue(data); - // REVIEW: In high-dimensionality cases this is less than ideal - // and we should consider using sparse buffers. + // REVIEW: In high-dimensionality cases this is less than ideal and we should consider + // using sparse buffers for the centroids. + + // The coordinates of the final centroids at the end of the training. During training + // it holds the centroids of the previous iteration. 
+ var centroids = new VBuffer[_k]; for (int i = 0; i < _k; i++) - _centroids[i] = VBufferUtils.CreateDense(_dimensionality); + centroids[i] = VBufferUtils.CreateDense(dimensionality); ch.Info("Initializing centroids"); long missingFeatureCount; @@ -154,29 +153,29 @@ private void TrainCore(IChannel ch, RoleMappedData data) // pay attention to their incoming set of centroids and incrementally train. if (_initAlgorithm == InitAlgorithm.KMeansPlusPlus) { - KMeansPlusPlusInit.Initialize(Host, _numThreads, ch, cursorFactory, _k, _dimensionality, - _centroids, out missingFeatureCount, out totalTrainingInstances); + KMeansPlusPlusInit.Initialize(Host, _numThreads, ch, cursorFactory, _k, dimensionality, + centroids, out missingFeatureCount, out totalTrainingInstances); } else if (_initAlgorithm == InitAlgorithm.Random) { KMeansRandomInit.Initialize(Host, _numThreads, ch, cursorFactory, _k, - _centroids, out missingFeatureCount, out totalTrainingInstances); + centroids, out missingFeatureCount, out totalTrainingInstances); } else { // Defaulting to KMeans|| initialization. 
- KMeansBarBarInitialization.Initialize(Host, _numThreads, ch, cursorFactory, _k, _dimensionality, - _centroids, _accelMemBudgetMb, out missingFeatureCount, out totalTrainingInstances); + KMeansBarBarInitialization.Initialize(Host, _numThreads, ch, cursorFactory, _k, dimensionality, + centroids, _accelMemBudgetMb, out missingFeatureCount, out totalTrainingInstances); } - KMeansUtils.VerifyModelConsistency(_centroids); + KMeansUtils.VerifyModelConsistency(centroids); ch.Info("Centroids initialized, starting main trainer"); KMeansLloydsYinYangTrain.Train( - Host, _numThreads, ch, cursorFactory, totalTrainingInstances, _k, _dimensionality, _maxIterations, - _accelMemBudgetMb, _convergenceThreshold, _centroids); + Host, _numThreads, ch, cursorFactory, totalTrainingInstances, _k, dimensionality, _maxIterations, + _accelMemBudgetMb, _convergenceThreshold, centroids); - KMeansUtils.VerifyModelConsistency(_centroids); + KMeansUtils.VerifyModelConsistency(centroids); ch.Info("Model trained successfully on {0} instances", totalTrainingInstances); if (missingFeatureCount > 0) { @@ -184,11 +183,7 @@ private void TrainCore(IChannel ch, RoleMappedData data) "{0} instances with missing features detected and ignored. Consider using MissingHandler.", missingFeatureCount); } - } - - public override KMeansPredictor CreatePredictor() - { - return new KMeansPredictor(Host, _k, _centroids, copyIn: true); + return new KMeansPredictor(Host, _k, centroids, copyIn: true); } private static int ComputeNumThreads(IHost host, int? 
argNumThreads) diff --git a/src/Microsoft.ML.LightGBM/LightGbmBinaryTrainer.cs b/src/Microsoft.ML.LightGBM/LightGbmBinaryTrainer.cs index 54cd523e72..4d778cf1a2 100644 --- a/src/Microsoft.ML.LightGBM/LightGbmBinaryTrainer.cs +++ b/src/Microsoft.ML.LightGBM/LightGbmBinaryTrainer.cs @@ -43,11 +43,11 @@ private static VersionInfo GetVersionInfo() loaderSignature: LoaderSignature); } - protected override uint VerNumFeaturesSerialized { get { return 0x00010002; } } + protected override uint VerNumFeaturesSerialized => 0x00010002; - protected override uint VerDefaultValueSerialized { get { return 0x00010004; } } + protected override uint VerDefaultValueSerialized => 0x00010004; - protected override uint VerCategoricalSplitSerialized { get { return 0x00010005; } } + protected override uint VerCategoricalSplitSerialized => 0x00010005; internal LightGbmBinaryPredictor(IHostEnvironment env, FastTree.Internal.Ensemble trainedEnsemble, int featureCount, string innerArgs) : base(env, RegistrationName, trainedEnsemble, featureCount, innerArgs) @@ -94,7 +94,7 @@ public LightGbmBinaryTrainer(IHostEnvironment env, LightGbmArguments args) { } - public override IPredictorWithFeatureWeights CreatePredictor() + protected internal override IPredictorWithFeatureWeights CreatePredictor() { Host.Check(TrainedEnsemble != null, "The predictor cannot be created before training is complete"); var innerArgs = LightGbmInterfaceUtils.JoinParameters(Options); diff --git a/src/Microsoft.ML.LightGBM/LightGbmMulticlassTrainer.cs b/src/Microsoft.ML.LightGBM/LightGbmMulticlassTrainer.cs index 2a84bad0e8..2f82d505a8 100644 --- a/src/Microsoft.ML.LightGBM/LightGbmMulticlassTrainer.cs +++ b/src/Microsoft.ML.LightGBM/LightGbmMulticlassTrainer.cs @@ -53,7 +53,7 @@ private LightGbmBinaryPredictor CreateBinaryPredictor(int classID, string innerA return new LightGbmBinaryPredictor(Host, GetBinaryEnsemble(classID), FeatureCount, innerArgs); } - public override OvaPredictor CreatePredictor() + protected internal 
override OvaPredictor CreatePredictor() { Host.Check(TrainedEnsemble != null, "The predictor cannot be created before training is complete."); diff --git a/src/Microsoft.ML.LightGBM/LightGbmRankingTrainer.cs b/src/Microsoft.ML.LightGBM/LightGbmRankingTrainer.cs index 4a1d1634a8..64c837fd9c 100644 --- a/src/Microsoft.ML.LightGBM/LightGbmRankingTrainer.cs +++ b/src/Microsoft.ML.LightGBM/LightGbmRankingTrainer.cs @@ -41,11 +41,11 @@ private static VersionInfo GetVersionInfo() loaderSignature: LoaderSignature); } - protected override uint VerNumFeaturesSerialized { get { return 0x00010002; } } + protected override uint VerNumFeaturesSerialized => 0x00010002; - protected override uint VerDefaultValueSerialized { get { return 0x00010004; } } + protected override uint VerDefaultValueSerialized => 0x00010004; - protected override uint VerCategoricalSplitSerialized { get { return 0x00010005; } } + protected override uint VerCategoricalSplitSerialized => 0x00010005; internal LightGbmRankingPredictor(IHostEnvironment env, FastTree.Internal.Ensemble trainedEnsemble, int featureCount, string innerArgs) : base(env, RegistrationName, trainedEnsemble, featureCount, innerArgs) @@ -68,7 +68,7 @@ public static LightGbmRankingPredictor Create(IHostEnvironment env, ModelLoadCon return new LightGbmRankingPredictor(env, ctx); } - public override PredictionKind PredictionKind { get { return PredictionKind.Ranking; } } + public override PredictionKind PredictionKind => PredictionKind.Ranking; } /// @@ -103,7 +103,7 @@ protected override void CheckDataValid(IChannel ch, RoleMappedData data) } } - public override LightGbmRankingPredictor CreatePredictor() + protected internal override LightGbmRankingPredictor CreatePredictor() { Host.Check(TrainedEnsemble != null, "The predictor cannot be created before training is complete"); var innerArgs = LightGbmInterfaceUtils.JoinParameters(Options); diff --git a/src/Microsoft.ML.LightGBM/LightGbmRegressionTrainer.cs 
b/src/Microsoft.ML.LightGBM/LightGbmRegressionTrainer.cs index 6ae3da792a..120bb1fd69 100644 --- a/src/Microsoft.ML.LightGBM/LightGbmRegressionTrainer.cs +++ b/src/Microsoft.ML.LightGBM/LightGbmRegressionTrainer.cs @@ -41,11 +41,11 @@ private static VersionInfo GetVersionInfo() loaderSignature: LoaderSignature); } - protected override uint VerNumFeaturesSerialized { get { return 0x00010002; } } + protected override uint VerNumFeaturesSerialized => 0x00010002; - protected override uint VerDefaultValueSerialized { get { return 0x00010004; } } + protected override uint VerDefaultValueSerialized => 0x00010004; - protected override uint VerCategoricalSplitSerialized { get { return 0x00010005; } } + protected override uint VerCategoricalSplitSerialized => 0x00010005; internal LightGbmRegressionPredictor(IHostEnvironment env, FastTree.Internal.Ensemble trainedEnsemble, int featureCount, string innerArgs) : base(env, RegistrationName, trainedEnsemble, featureCount, innerArgs) @@ -86,7 +86,7 @@ public LightGbmRegressorTrainer(IHostEnvironment env, LightGbmArguments args) { } - public override LightGbmRegressionPredictor CreatePredictor() + protected internal override LightGbmRegressionPredictor CreatePredictor() { Host.Check(TrainedEnsemble != null, "The predictor cannot be created before training is complete"); diff --git a/src/Microsoft.ML.LightGBM/LightGbmTrainerBase.cs b/src/Microsoft.ML.LightGBM/LightGbmTrainerBase.cs index c778c4ee23..0d9a9005ea 100644 --- a/src/Microsoft.ML.LightGBM/LightGbmTrainerBase.cs +++ b/src/Microsoft.ML.LightGBM/LightGbmTrainerBase.cs @@ -24,9 +24,7 @@ internal static class LightGbmShared /// /// Base class for all training with LightGBM. 
/// - public abstract class LightGbmTrainerBase : - ITrainer, - IValidatingTrainer + public abstract class LightGbmTrainerBase : TrainerBase where TPredictor : IPredictorProducing { private sealed class CategoricalMetaData @@ -39,86 +37,70 @@ private sealed class CategoricalMetaData public bool[] IsCategoricalFeature; } - #region members - private readonly IHostEnvironment _env; - private readonly PredictionKind _predictionKind; - - protected readonly IHost Host; - protected readonly LightGbmArguments Args; + protected internal readonly LightGbmArguments Args; /// /// Stores argumments as objects to convert them to invariant string type in the end so that /// the code is culture agnostic. When retrieving key value from this dictionary as string /// please convert to string invariant by string.Format(CultureInfo.InvariantCulture, "{0}", Option[key]). /// - protected readonly Dictionary Options; - protected readonly IParallel ParallelTraining; + protected internal readonly Dictionary Options; + protected internal readonly IParallel ParallelTraining; // Store _featureCount and _trainedEnsemble to construct predictor. 
- protected int FeatureCount; - protected FastTree.Internal.Ensemble TrainedEnsemble; + protected internal int FeatureCount; + protected internal FastTree.Internal.Ensemble TrainedEnsemble; - #endregion + public override bool NeedNormalization => false; + public override bool NeedCalibration => false; + public override bool WantCaching => false; + public override bool SupportsValidation => true; - protected LightGbmTrainerBase(IHostEnvironment env, LightGbmArguments args, PredictionKind predictionKind, string name) + protected internal LightGbmTrainerBase(IHostEnvironment env, LightGbmArguments args, PredictionKind predictionKind, string name) + : base(env, name) { - Contracts.CheckValue(env, nameof(env)); - env.CheckNonWhiteSpace(name, nameof(name)); - - Host = env.Register(name); Host.CheckValue(args, nameof(args)); Args = args; Options = Args.ToDictionary(Host); - _predictionKind = predictionKind; - _env = env; + PredictionKind = predictionKind; ParallelTraining = Args.ParallelTrainer != null ? 
Args.ParallelTrainer.CreateComponent(env) : new SingleTrainer(); InitParallelTraining(); } - public void Train(RoleMappedData data) + public override TPredictor Train(TrainContext context) { - Dataset dtrain; - CategoricalMetaData catMetaData; - using (var ch = Host.Start("Loading data for LightGBM")) - { - using (var pch = Host.StartProgressChannel("Loading data for LightGBM")) - dtrain = LoadTrainingData(ch, data, out catMetaData); - ch.Done(); - } - using (var ch = Host.Start("Training with LightGBM")) - { - using (var pch = Host.StartProgressChannel("Training with LightGBM")) - TrainCore(ch, pch, dtrain, catMetaData); - ch.Done(); - } - dtrain.Dispose(); - DisposeParallelTraining(); - } + Host.CheckValue(context, nameof(context)); - public void Train(RoleMappedData data, RoleMappedData validData) - { - Dataset dtrain; - Dataset dvalid; + Dataset dtrain = null; + Dataset dvalid = null; CategoricalMetaData catMetaData; - using (var ch = Host.Start("Loading data for LightGBM")) + try { - using (var pch = Host.StartProgressChannel("Loading data for LightGBM")) + using (var ch = Host.Start("Loading data for LightGBM")) + { + using (var pch = Host.StartProgressChannel("Loading data for LightGBM")) + { + dtrain = LoadTrainingData(ch, context.Train, out catMetaData); + if (context.Validation != null) + dvalid = LoadValidationData(ch, dtrain, context.Validation, catMetaData); + } + ch.Done(); + } + using (var ch = Host.Start("Training with LightGBM")) { - dtrain = LoadTrainingData(ch, data, out catMetaData); - dvalid = LoadValidationData(ch, dtrain, validData, catMetaData); + using (var pch = Host.StartProgressChannel("Training with LightGBM")) + TrainCore(ch, pch, dtrain, catMetaData, dvalid); + ch.Done(); } - ch.Done(); } - using (var ch = Host.Start("Training with LightGBM")) + finally { - using (var pch = Host.StartProgressChannel("Training with LightGBM")) - TrainCore(ch, pch, dtrain, catMetaData, dvalid); - ch.Done(); + dtrain?.Dispose(); + dvalid?.Dispose(); + 
DisposeParallelTraining(); } - dtrain.Dispose(); - dvalid.Dispose(); - DisposeParallelTraining(); + return CreatePredictor(); } private void InitParallelTraining() @@ -178,7 +160,7 @@ protected virtual void GetDefaultParameters(IChannel ch, int numRow, bool hasCat private FloatLabelCursor.Factory CreateCursorFactory(RoleMappedData data) { var loadFlags = CursOpt.AllLabels | CursOpt.AllWeights | CursOpt.Features; - if (_predictionKind == PredictionKind.Ranking) + if (PredictionKind == PredictionKind.Ranking) loadFlags |= CursOpt.Group; var factory = new FloatLabelCursor.Factory(data, loadFlags); @@ -392,7 +374,7 @@ private void GetMetainfo(IChannel ch, FloatLabelCursor.Factory factory, List labelList = new List(); bool hasWeights = factory.Data.Schema.Weight != null; bool hasGroup = false; - if (_predictionKind == PredictionKind.Ranking) + if (PredictionKind == PredictionKind.Ranking) { ch.Check(factory.Data.Schema != null, "The data for ranking task should have group field."); hasGroup = true; @@ -870,14 +852,9 @@ private static int GetNumSampleRow(int numRow, int numCol) return ret; } - public PredictionKind PredictionKind => _predictionKind; - - IPredictor ITrainer.CreatePredictor() - { - return CreatePredictor(); - } + public override PredictionKind PredictionKind { get; } - public abstract TPredictor CreatePredictor(); + protected internal abstract TPredictor CreatePredictor(); /// /// This function will be called before training. It will check the label/group and add parameters for specific applications. 
diff --git a/src/Microsoft.ML.PCA/PcaTrainer.cs b/src/Microsoft.ML.PCA/PcaTrainer.cs index 23e7351a86..c8a2f6a105 100644 --- a/src/Microsoft.ML.PCA/PcaTrainer.cs +++ b/src/Microsoft.ML.PCA/PcaTrainer.cs @@ -41,7 +41,7 @@ namespace Microsoft.ML.Runtime.PCA /// /// This PCA can be made into Kernel PCA by using Random Fourier Features transform /// - public sealed class RandomizedPcaTrainer : TrainerBase + public sealed class RandomizedPcaTrainer : TrainerBase { public const string LoadNameValue = "pcaAnomaly"; internal const string UserNameValue = "PCA Anomaly Detector"; @@ -69,13 +69,10 @@ public class Arguments : UnsupervisedLearnerInputBaseWithWeight public int? Seed; } - private int _dimension; private readonly int _rank; private readonly int _oversampling; private readonly bool _center; private readonly int _seed; - private VBuffer[] _eigenvectors; // top eigenvectors of the covariance matrix - private VBuffer _mean; public RandomizedPcaTrainer(IHostEnvironment env, Arguments args) : base(env, LoadNameValue) @@ -90,65 +87,52 @@ public RandomizedPcaTrainer(IHostEnvironment env, Arguments args) _seed = args.Seed ?? Host.Rand.Next(); } - public override bool NeedNormalization - { - get { return true; } - } - - public override bool NeedCalibration - { - get { return false; } - } + public override bool NeedNormalization => true; - public override bool WantCaching - { - // Two passes, only. Probably not worth caching. - get { return false; } - } + public override bool NeedCalibration => false; - public override PcaPredictor CreatePredictor() - { - return new PcaPredictor(Host, _rank, _eigenvectors, ref _mean); - } + // Two passes, only. Probably not worth caching. 
+ public override bool WantCaching => false; - public override PredictionKind PredictionKind { get { return PredictionKind.AnomalyDetection; } } + public override PredictionKind PredictionKind => PredictionKind.AnomalyDetection; //Note: the notations used here are the same as in http://web.stanford.edu/group/mmds/slides2010/Martinsson.pdf (pg. 9) - public override void Train(RoleMappedData data) + public override PcaPredictor Train(TrainContext context) { - Host.CheckValue(data, nameof(data)); + Host.CheckValue(context, nameof(context)); - data.CheckFeatureFloatVector(out _dimension); + context.Train.CheckFeatureFloatVector(out int dimension); using (var ch = Host.Start("Training")) { - TrainCore(ch, data); + var pred = TrainCore(ch, context.Train, dimension); ch.Done(); + return pred; } } - private void TrainCore(IChannel ch, RoleMappedData data) + private PcaPredictor TrainCore(IChannel ch, RoleMappedData data, int dimension) { Host.AssertValue(ch); ch.AssertValue(data); - if (_rank > _dimension) - throw ch.Except("Rank ({0}) cannot be larger than the original dimension ({1})", _rank, _dimension); - int oversampledRank = Math.Min(_rank + _oversampling, _dimension); + if (_rank > dimension) + throw ch.Except("Rank ({0}) cannot be larger than the original dimension ({1})", _rank, dimension); + int oversampledRank = Math.Min(_rank + _oversampling, dimension); //exact: (size of the 2 big matrices + other minor allocations) / (2^30) - Double memoryUsageEstimate = 2.0 * _dimension * oversampledRank * sizeof(Float) / 1e9; + Double memoryUsageEstimate = 2.0 * dimension * oversampledRank * sizeof(Float) / 1e9; if (memoryUsageEstimate > 2) ch.Info("Estimate memory usage: {0:G2} GB. If running out of memory, reduce rank and oversampling factor.", memoryUsageEstimate); - var y = Zeros(oversampledRank, _dimension); - _mean = _center ? 
VBufferUtils.CreateDense(_dimension) : VBufferUtils.CreateEmpty(_dimension); + var y = Zeros(oversampledRank, dimension); + var mean = _center ? VBufferUtils.CreateDense(dimension) : VBufferUtils.CreateEmpty(dimension); - var omega = GaussianMatrix(oversampledRank, _dimension, _seed); + var omega = GaussianMatrix(oversampledRank, dimension, _seed); var cursorFactory = new FeatureFloatVectorCursor.Factory(data, CursOpt.Features | CursOpt.Weight); long numBad; - Project(Host, cursorFactory, ref _mean, omega, y, out numBad); + Project(Host, cursorFactory, ref mean, omega, y, out numBad); if (numBad > 0) ch.Warning("Skipped {0} instances with missing features/weights during training", numBad); @@ -166,7 +150,7 @@ private void TrainCore(IChannel ch, RoleMappedData data) var q = y; // q in QR decomposition. var b = omega; // reuse the memory allocated by Omega. - Project(Host, cursorFactory, ref _mean, q, b, out numBad); + Project(Host, cursorFactory, ref mean, q, b, out numBad); //Compute B2 = B' * B var b2 = new Float[oversampledRank * oversampledRank]; @@ -179,8 +163,9 @@ private void TrainCore(IChannel ch, RoleMappedData data) Float[] smallEigenvalues;// eigenvectors and eigenvalues of the small matrix B2. 
Float[] smallEigenvectors; EigenUtils.EigenDecomposition(b2, out smallEigenvalues, out smallEigenvectors); - PostProcess(b, smallEigenvalues, smallEigenvectors, _dimension, oversampledRank); - _eigenvectors = b; + PostProcess(b, smallEigenvalues, smallEigenvectors, dimension, oversampledRank); + + return new PcaPredictor(Host, _rank, b, ref mean); } private static VBuffer[] Zeros(int k, int d) diff --git a/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineTrainer.cs b/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineTrainer.cs index b967bd9f95..290c269b8b 100644 --- a/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineTrainer.cs +++ b/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineTrainer.cs @@ -30,9 +30,7 @@ namespace Microsoft.ML.Runtime.FactorizationMachine [3] https://github.com/wschin/fast-ffm/blob/master/fast-ffm.pdf */ /// - public sealed class FieldAwareFactorizationMachineTrainer : TrainerBase, - IIncrementalTrainer, IValidatingTrainer, - IIncrementalValidatingTrainer + public sealed class FieldAwareFactorizationMachineTrainer : TrainerBase { public const string Summary = "Train a field-aware factorization machine for binary classification"; public const string UserName = "Field-aware Factorization Machine"; @@ -89,7 +87,6 @@ public sealed class Arguments : LearnerInputBaseWithLabel private readonly bool _shuffle; private readonly bool _verbose; private readonly float _radius; - private FieldAwareFactorizationMachinePredictor _pred; public FieldAwareFactorizationMachineTrainer(IHostEnvironment env, Arguments args) : base(env, LoadName) { @@ -216,7 +213,7 @@ private static double CalculateAvgLoss(IChannel ch, RoleMappedData data, bool no return loss / exampleCount; } - private void TrainCore(IChannel ch, IProgressChannel pch, RoleMappedData data, RoleMappedData validData, FieldAwareFactorizationMachinePredictor predictor) + private 
FieldAwareFactorizationMachinePredictor TrainCore(IChannel ch, IProgressChannel pch, RoleMappedData data, RoleMappedData validData, FieldAwareFactorizationMachinePredictor predictor) { Host.AssertValue(ch); Host.AssertValue(pch); @@ -346,63 +343,25 @@ private void TrainCore(IChannel ch, IProgressChannel pch, RoleMappedData data, R ch.Warning($"Skipped {badExampleCount} examples with bad label/weight/features in training set"); if (validBadExampleCount != 0) ch.Warning($"Skipped {validBadExampleCount} examples with bad label/weight/features in validation set"); - _pred = new FieldAwareFactorizationMachinePredictor(Host, _norm, fieldCount, totalFeatureCount, _latentDim, linearWeights, latentWeightsAligned); + return new FieldAwareFactorizationMachinePredictor(Host, _norm, fieldCount, totalFeatureCount, _latentDim, linearWeights, latentWeightsAligned); } - public override void Train(RoleMappedData data) + public override FieldAwareFactorizationMachinePredictor Train(TrainContext context) { - Host.CheckValue(data, nameof(data)); - using (var ch = Host.Start("Training")) - using (var pch = Host.StartProgressChannel("Training")) - { - TrainCore(ch, pch, data, null, null); - ch.Done(); - } - } - - public void Train(RoleMappedData data, RoleMappedData validData) - { - Host.CheckValue(data, nameof(data)); - Host.CheckValue(validData, nameof(validData)); - using (var ch = Host.Start("Training")) - using (var pch = Host.StartProgressChannel("Training")) - { - TrainCore(ch, pch, data, validData, null); - ch.Done(); - } - } + Host.CheckValue(context, nameof(context)); + var initPredictor = context.InitialPredictor as FieldAwareFactorizationMachinePredictor; + Host.CheckParam(context.InitialPredictor == null || initPredictor != null, nameof(context), + "Initial predictor should have been " + nameof(FieldAwareFactorizationMachinePredictor)); - public void Train(RoleMappedData data, FieldAwareFactorizationMachinePredictor predictor) - { - Host.CheckValue(data, nameof(data)); - 
Host.CheckValue(predictor, nameof(predictor)); using (var ch = Host.Start("Training")) using (var pch = Host.StartProgressChannel("Training")) { - TrainCore(ch, pch, data, null, predictor); + var pred = TrainCore(ch, pch, context.Train, context.Validation, initPredictor); ch.Done(); + return pred; } } - public void Train(RoleMappedData data, RoleMappedData validData, FieldAwareFactorizationMachinePredictor predictor) - { - Host.CheckValue(data, nameof(data)); - Host.CheckValue(data, nameof(validData)); - Host.CheckValue(predictor, nameof(predictor)); - using (var ch = Host.Start("Training")) - using (var pch = Host.StartProgressChannel("Training")) - { - TrainCore(ch, pch, data, validData, predictor); - ch.Done(); - } - } - - public override FieldAwareFactorizationMachinePredictor CreatePredictor() - { - Host.Check(_pred != null, nameof(Train) + " has not yet been called"); - return _pred; - } - [TlcModule.EntryPoint(Name = "Trainers.FieldAwareFactorizationMachineBinaryClassifier", Desc = Summary, UserName = UserName, diff --git a/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationTrainer.cs b/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationTrainer.cs index 554babf1ce..29dbbb3447 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationTrainer.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationTrainer.cs @@ -44,12 +44,9 @@ namespace Microsoft.ML.Runtime.Learners using Stopwatch = System.Diagnostics.Stopwatch; using TScalarPredictor = IPredictorWithFeatureWeights; - public abstract class LinearTrainerBase : TrainerBase + public abstract class LinearTrainerBase : TrainerBase where TPredictor : IPredictor { - protected int NumFeatures; - protected VBuffer[] Weights; - protected Float[] Bias; protected bool NeedShuffle; public override bool NeedNormalization => true; @@ -66,40 +63,39 @@ protected LinearTrainerBase(IHostEnvironment env, string name) { } - protected void TrainEx(RoleMappedData data, 
LinearPredictor predictor) + public override TPredictor Train(TrainContext context) { + Host.CheckValue(context, nameof(context)); + TPredictor pred; using (var ch = Host.Start("Training")) { - ch.AssertValue(data, nameof(data)); - ch.AssertValueOrNull(predictor); - var preparedData = PrepareDataFromTrainingExamples(ch, data); - TrainCore(ch, preparedData, predictor); + var preparedData = PrepareDataFromTrainingExamples(ch, context.Train, out int weightSetCount); + var initPred = context.InitialPredictor; + var linInitPred = (initPred as CalibratedPredictorBase)?.SubPredictor as LinearPredictor; + linInitPred = linInitPred ?? initPred as LinearPredictor; + Host.CheckParam(context.InitialPredictor == null || linInitPred != null, nameof(context), + "Initial predictor was not a linear predictor."); + pred = TrainCore(ch, preparedData, linInitPred, weightSetCount); ch.Done(); } + return pred; } - public override void Train(RoleMappedData examples) - { - Host.CheckValue(examples, nameof(examples)); - TrainEx(examples, null); - } - - protected abstract void TrainCore(IChannel ch, RoleMappedData data, LinearPredictor predictor); - - /// - /// Gets the size of weights and bias array. For binary classification and regression, this is 1. - /// For multi-class classification, this equals the number of classes. - /// - protected abstract int WeightArraySize { get; } + protected abstract TPredictor TrainCore(IChannel ch, RoleMappedData data, LinearPredictor predictor, int weightSetCount); /// /// This method ensures that the data meets the requirements of this trainer and its /// subclasses, injects necessary transforms, and throws if it couldn't meet them. /// - protected RoleMappedData PrepareDataFromTrainingExamples(IChannel ch, RoleMappedData examples) + /// The channel + /// The training examples + /// Gets the length of weights and bias array. For binary classification and regression, + /// this is 1. 
For multi-class classification, this equals the number of classes on the label. + /// A potentially modified version of + protected RoleMappedData PrepareDataFromTrainingExamples(IChannel ch, RoleMappedData examples, out int weightSetCount) { ch.AssertValue(examples); - CheckLabel(examples); + CheckLabel(examples, out weightSetCount); examples.CheckFeatureFloatVector(); var idvToShuffle = examples.Data; IDataView idvToFeedTrain; @@ -120,17 +116,17 @@ protected RoleMappedData PrepareDataFromTrainingExamples(IChannel ch, RoleMapped var roles = examples.Schema.GetColumnRoleNames(); var examplesToFeedTrain = new RoleMappedData(idvToFeedTrain, roles); - ch.Assert(examplesToFeedTrain.Schema.Label != null); - ch.Assert(examplesToFeedTrain.Schema.Feature != null); + ch.AssertValue(examplesToFeedTrain.Schema.Label); + ch.AssertValue(examplesToFeedTrain.Schema.Feature); if (examples.Schema.Weight != null) - ch.Assert(examplesToFeedTrain.Schema.Weight != null); + ch.AssertValue(examplesToFeedTrain.Schema.Weight); - NumFeatures = examplesToFeedTrain.Schema.Feature.Type.VectorSize; - ch.Check(NumFeatures > 0, "Training set has 0 instances, aborting training."); + int numFeatures = examplesToFeedTrain.Schema.Feature.Type.VectorSize; + ch.Check(numFeatures > 0, "Training set has no features, aborting training."); return examplesToFeedTrain; } - protected abstract void CheckLabel(RoleMappedData examples); + protected abstract void CheckLabel(RoleMappedData examples, out int weightSetCount); protected Float WDot(ref VBuffer features, ref VBuffer weights, Float bias) { @@ -165,13 +161,13 @@ public abstract class ArgumentsBase : LearnerInputBaseWithLabel { [Argument(ArgumentType.AtMostOnce, HelpText = "L2 regularizer constant. 
By default the l2 constant is automatically inferred based on data set.", NullName = "", ShortName = "l2", SortOrder = 1)] [TGUI(Label = "L2 Regularizer Constant", SuggestedSweeps = ",1e-7,1e-6,1e-5,1e-4,1e-3,1e-2")] - [TlcModule.SweepableDiscreteParamAttribute("L2Const", new object[] { "", 1e-7f, 1e-6f, 1e-5f, 1e-4f, 1e-3f, 1e-2f })] + [TlcModule.SweepableDiscreteParam("L2Const", new object[] { "", 1e-7f, 1e-6f, 1e-5f, 1e-4f, 1e-3f, 1e-2f })] public Float? L2Const; // REVIEW: make the default positive when we know how to consume a sparse model [Argument(ArgumentType.AtMostOnce, HelpText = "L1 soft threshold (L1/L2). Note that it is easier to control and sweep using the threshold parameter than the raw L1-regularizer constant. By default the l1 threshold is automatically inferred based on data set.", NullName = "", ShortName = "l1", SortOrder = 2)] [TGUI(Label = "L1 Soft Threshold", SuggestedSweeps = ",0,0.25,0.5,0.75,1")] - [TlcModule.SweepableDiscreteParamAttribute("L1Threshold", new object[] { "", 0f, 0.25f, 0.5f, 0.75f, 1f })] + [TlcModule.SweepableDiscreteParam("L1Threshold", new object[] { "", 0f, 0.25f, 0.5f, 0.75f, 1f })] public Float? L1Threshold; [Argument(ArgumentType.AtMostOnce, HelpText = "Degree of lock-free parallelism. Defaults to automatic. 
Determinism not guaranteed.", NullName = "", ShortName = "nt,t,threads", SortOrder = 50)] @@ -180,16 +176,16 @@ public abstract class ArgumentsBase : LearnerInputBaseWithLabel [Argument(ArgumentType.AtMostOnce, HelpText = "The tolerance for the ratio between duality gap and primal loss for convergence checking.", ShortName = "tol")] [TGUI(SuggestedSweeps = "0.001, 0.01, 0.1, 0.2")] - [TlcModule.SweepableDiscreteParamAttribute("ConvergenceTolerance", new object[] { 0.001f, 0.01f, 0.1f, 0.2f })] + [TlcModule.SweepableDiscreteParam("ConvergenceTolerance", new object[] { 0.001f, 0.01f, 0.1f, 0.2f })] public Float ConvergenceTolerance = 0.1f; [Argument(ArgumentType.AtMostOnce, HelpText = "Maximum number of iterations; set to 1 to simulate online learning. Defaults to automatic.", NullName = "", ShortName = "iter")] [TGUI(Label = "Max number of iterations", SuggestedSweeps = ",10,20,100")] - [TlcModule.SweepableDiscreteParamAttribute("MaxIterations", new object[] { "", 10, 20, 100 })] + [TlcModule.SweepableDiscreteParam("MaxIterations", new object[] { "", 10, 20, 100 })] public int? MaxIterations; [Argument(ArgumentType.AtMostOnce, HelpText = "Shuffle data every epoch?", ShortName = "shuf")] - [TlcModule.SweepableDiscreteParamAttribute("Shuffle", null, isBool: true)] + [TlcModule.SweepableDiscreteParam("Shuffle", null, isBool: true)] public bool Shuffle = true; [Argument(ArgumentType.AtMostOnce, HelpText = "Convergence check frequency (in terms of number of iterations). Set as negative or zero for not checking at all. 
If left blank, it defaults to check after every 'numThreads' iterations.", NullName = "", ShortName = "checkFreq")] @@ -197,7 +193,7 @@ public abstract class ArgumentsBase : LearnerInputBaseWithLabel [Argument(ArgumentType.AtMostOnce, HelpText = "The learning rate for adjusting bias from being regularized.", ShortName = "blr")] [TGUI(SuggestedSweeps = "0, 0.01, 0.1, 1")] - [TlcModule.SweepableDiscreteParamAttribute("BiasLearningRate", new object[] { 0.0f, 0.01f, 0.1f, 1f })] + [TlcModule.SweepableDiscreteParam("BiasLearningRate", new object[] { 0.0f, 0.01f, 0.1f, 1f })] public Float BiasLearningRate = 0; internal virtual void Check(IHostEnvironment env) @@ -217,6 +213,7 @@ internal virtual void Check(IHostEnvironment env) "could drastically slow down the convergence. So using l2Const = {1} instead.", L2Const); L2Const = L2LowerBound; + ch.Done(); } } } @@ -243,12 +240,9 @@ protected enum MetricKind private readonly ArgumentsBase _args; protected ISupportSdcaLoss Loss; - public override bool NeedNormalization - { - get { return true; } - } + public override bool NeedNormalization => true; - protected override bool ShuffleData { get { return _args.Shuffle; } } + protected override bool ShuffleData => _args.Shuffle; protected SdcaTrainerBase(ArgumentsBase args, IHostEnvironment env, string name) : base(env, name) @@ -257,13 +251,13 @@ protected SdcaTrainerBase(ArgumentsBase args, IHostEnvironment env, string name) _args.Check(env); } - protected override void TrainCore(IChannel ch, RoleMappedData data, LinearPredictor predictor) + protected sealed override TPredictor TrainCore(IChannel ch, RoleMappedData data, LinearPredictor predictor, int weightSetCount) { Contracts.Assert(predictor == null, "SDCA based trainers don't support continuous training."); - Contracts.Assert(NumFeatures > 0, "Number of features must be assigned prior to passing into TrainCore."); - int weightArraySize = WeightArraySize; - Contracts.Assert(weightArraySize >= 1); - long maxTrainingExamples = 
MaxDualTableSize / weightArraySize; + Contracts.Assert(weightSetCount >= 1); + + int numFeatures = data.Schema.Feature.Type.VectorSize; + long maxTrainingExamples = MaxDualTableSize / weightSetCount; var cursorFactory = new FloatLabelCursor.Factory(data, CursOpt.Label | CursOpt.Features | CursOpt.Weight | CursOpt.Id); int numThreads; if (_args.NumThreads.HasValue) @@ -301,8 +295,7 @@ protected override void TrainCore(IChannel ch, RoleMappedData data, LinearPredic ch.Assert(checkFrequency > 0); var pOptions = new ParallelOptions { MaxDegreeOfParallelism = numThreads }; - var converged = false; - var watch = new Stopwatch(); + bool converged = false; // Getting the total count of rows in data. Ignore rows with bad label and feature values. long count = 0; @@ -398,24 +391,24 @@ protected override void TrainCore(IChannel ch, RoleMappedData data, LinearPredic Contracts.Assert(_args.L2Const.HasValue); if (_args.L1Threshold == null) - _args.L1Threshold = TuneDefaultL1(ch, NumFeatures); + _args.L1Threshold = TuneDefaultL1(ch, numFeatures); ch.Assert(_args.L1Threshold.HasValue); var l1Threshold = _args.L1Threshold.Value; var l1ThresholdZero = l1Threshold == 0; - VBuffer[] weights = new VBuffer[weightArraySize]; - VBuffer[] bestWeights = new VBuffer[weightArraySize]; - VBuffer[] l1IntermediateWeights = l1ThresholdZero ? null : new VBuffer[weightArraySize]; - Float[] biasReg = new Float[weightArraySize]; - Float[] bestBiasReg = new Float[weightArraySize]; - Float[] biasUnreg = new Float[weightArraySize]; - Float[] bestBiasUnreg = new Float[weightArraySize]; - Float[] l1IntermediateBias = l1ThresholdZero ? null : new Float[weightArraySize]; - - for (int i = 0; i < weightArraySize; i++) + var weights = new VBuffer[weightSetCount]; + var bestWeights = new VBuffer[weightSetCount]; + var l1IntermediateWeights = l1ThresholdZero ? 
null : new VBuffer[weightSetCount]; + var biasReg = new Float[weightSetCount]; + var bestBiasReg = new Float[weightSetCount]; + var biasUnreg = new Float[weightSetCount]; + var bestBiasUnreg = new Float[weightSetCount]; + var l1IntermediateBias = l1ThresholdZero ? null : new Float[weightSetCount]; + + for (int i = 0; i < weightSetCount; i++) { - weights[i] = VBufferUtils.CreateDense(NumFeatures); - bestWeights[i] = VBufferUtils.CreateDense(NumFeatures); + weights[i] = VBufferUtils.CreateDense(numFeatures); + bestWeights[i] = VBufferUtils.CreateDense(numFeatures); biasReg[i] = 0; bestBiasReg[i] = 0; biasUnreg[i] = 0; @@ -423,7 +416,7 @@ protected override void TrainCore(IChannel ch, RoleMappedData data, LinearPredic if (!l1ThresholdZero) { - l1IntermediateWeights[i] = VBufferUtils.CreateDense(NumFeatures); + l1IntermediateWeights[i] = VBufferUtils.CreateDense(numFeatures); l1IntermediateBias[i] = 0; } } @@ -441,7 +434,7 @@ protected override void TrainCore(IChannel ch, RoleMappedData data, LinearPredic if (idToIdx == null) { Contracts.Assert(!needLookup); - long dualsLength = ((long)idLoMax + 1) * WeightArraySize; + long dualsLength = ((long)idLoMax + 1) * weightSetCount; if (dualsLength <= Utils.ArrayMaxSize) { // The dual variables fit into a standard float[]. @@ -465,7 +458,7 @@ protected override void TrainCore(IChannel ch, RoleMappedData data, LinearPredic { // Similar logic as above when using the id-to-index lookup. 
Contracts.Assert(needLookup); - long dualsLength = count * WeightArraySize; + long dualsLength = count * weightSetCount; if (dualsLength <= Utils.ArrayMaxSize) { duals = new StandardArrayDualsTable((int)dualsLength); @@ -497,8 +490,6 @@ protected override void TrainCore(IChannel ch, RoleMappedData data, LinearPredic ch.Assert(_args.MaxIterations.HasValue); var maxIterations = _args.MaxIterations.Value; - watch.Start(); - var rands = new IRandom[maxIterations]; for (int i = 0; i < maxIterations; i++) rands[i] = RandomUtils.Create(Host.Rand.Next()); @@ -506,9 +497,9 @@ protected override void TrainCore(IChannel ch, RoleMappedData data, LinearPredic // If we favor storing the invariants, precompute the invariants now. if (invariants != null) { - Contracts.Assert((idToIdx == null & ((long)idLoMax + 1) * WeightArraySize <= Utils.ArrayMaxSize) | (idToIdx != null & count * WeightArraySize <= Utils.ArrayMaxSize)); - Func getIndexFromIdAndRow = GetIndexFromIdAndRowGetter(idToIdx); - int invariantCoeff = WeightArraySize == 1 ? 1 : 2; + Contracts.Assert((idToIdx == null & ((long)idLoMax + 1) * weightSetCount <= Utils.ArrayMaxSize) | (idToIdx != null & count * weightSetCount <= Utils.ArrayMaxSize)); + Func getIndexFromIdAndRow = GetIndexFromIdAndRowGetter(idToIdx, biasReg.Length); + int invariantCoeff = weightSetCount == 1 ? 
1 : 2; using (var cursor = cursorFactory.Create()) using (var pch = Host.StartProgressChannel("SDCA invariants initialization")) { @@ -599,23 +590,25 @@ protected override void TrainCore(IChannel ch, RoleMappedData data, LinearPredic } } - Bias = new Float[weightArraySize]; + var bias = new Float[weightSetCount]; if (bestIter > 0) { ch.Info("Using best model from iteration {0}.", bestIter); - Weights = bestWeights; - for (int i = 0; i < weightArraySize; i++) - Bias[i] = bestBiasReg[i] + bestBiasUnreg[i]; + weights = bestWeights; + for (int i = 0; i < weightSetCount; i++) + bias[i] = bestBiasReg[i] + bestBiasUnreg[i]; } else { ch.Info("Using model from last iteration."); - Weights = weights; - for (int i = 0; i < weightArraySize; i++) - Bias[i] = biasReg[i] + biasUnreg[i]; + for (int i = 0; i < weightSetCount; i++) + bias[i] = biasReg[i] + biasUnreg[i]; } + return CreatePredictor(weights, bias); } + protected abstract TPredictor CreatePredictor(VBuffer[] weights, Float[] bias); + // Assign an upper bound for number of iterations based on data set size first. // This ensures SDCA will not run forever... // Based on empirical estimation of max iterations needed. @@ -746,7 +739,7 @@ protected virtual void TrainWithoutLock(IProgressChannelProvider progress, Float if (pch != null) pch.SetHeader(new ProgressHeader("examples"), e => e.SetProgress(0, rowCount)); - Func getIndexFromId = GetIndexFromIdGetter(idToIdx); + Func getIndexFromId = GetIndexFromIdGetter(idToIdx, biasReg.Length); while (cursor.MoveNext()) { long idx = getIndexFromId(cursor.Id); @@ -901,7 +894,7 @@ protected virtual bool CheckConvergence( using (var cursor = cursorFactory.Create()) { long row = 0; - Func getIndexFromIdAndRow = GetIndexFromIdAndRowGetter(idToIdx); + Func getIndexFromIdAndRow = GetIndexFromIdAndRowGetter(idToIdx, biasReg.Length); // Iterates through data to compute loss function. 
while (cursor.MoveNext()) { @@ -992,8 +985,7 @@ public StandardArrayDualsTable(int length) _duals = new Float[length]; } - public override Float this[long index] - { + public override Float this[long index] { get { return _duals[(int)index]; } set { _duals[(int)index] = value; } } @@ -1019,14 +1011,11 @@ public BigArrayDualsTable(long length) _duals = new BigArray(length); } - public override Float this[long index] - { - get - { + public override Float this[long index] { + get { return _duals[index]; } - set - { + set { _duals[index] = value; } } @@ -1042,10 +1031,10 @@ public override void ApplyAt(long index, Visitor manip) /// Returns a function delegate to retrieve index from id. /// This is to avoid redundant conditional branches in the tight loop of training. /// - protected Func GetIndexFromIdGetter(IdToIdxLookup idToIdx) + protected Func GetIndexFromIdGetter(IdToIdxLookup idToIdx, int biasLength) { Contracts.AssertValueOrNull(idToIdx); - long maxTrainingExamples = MaxDualTableSize / WeightArraySize; + long maxTrainingExamples = MaxDualTableSize / biasLength; if (idToIdx == null) { return (UInt128 id) => @@ -1073,10 +1062,10 @@ protected Func GetIndexFromIdGetter(IdToIdxLookup idToIdx) /// Only works if the cursor is not shuffled. /// This is to avoid redundant conditional branches in the tight loop of training. 
/// - protected Func GetIndexFromIdAndRowGetter(IdToIdxLookup idToIdx) + protected Func GetIndexFromIdAndRowGetter(IdToIdxLookup idToIdx, int biasLength) { Contracts.AssertValueOrNull(idToIdx); - long maxTrainingExamples = MaxDualTableSize / WeightArraySize; + long maxTrainingExamples = MaxDualTableSize / biasLength; if (idToIdx == null) { return (UInt128 id, long row) => @@ -1368,7 +1357,7 @@ public void Add(Double summand) } } - public sealed class LinearClassificationTrainer : SdcaTrainerBase, ITrainer, ITrainerEx + public sealed class LinearClassificationTrainer : SdcaTrainerBase, ITrainerEx { public const string LoadNameValue = "SDCA"; public const string UserNameValue = "Fast Linear (SA-SDCA)"; @@ -1404,8 +1393,6 @@ internal override void Check(IHostEnvironment env) public override bool NeedCalibration => !(_loss is LogLoss); - protected override int WeightArraySize => 1; - public LinearClassificationTrainer(IHostEnvironment env, Arguments args) : base(args, env, LoadNameValue) { @@ -1416,48 +1403,42 @@ public LinearClassificationTrainer(IHostEnvironment env, Arguments args) _positiveInstanceWeight = _args.PositiveInstanceWeight; } - public override IPredictor CreatePredictor() + protected override TScalarPredictor CreatePredictor(VBuffer[] weights, Float[] bias) { - Contracts.Assert(WeightArraySize == 1); - Contracts.Assert(Utils.Size(Weights) == 1); - Contracts.Assert(Utils.Size(Bias) == 1); - Host.Check(Weights[0].Length > 0); - VBuffer maybeSparseWeights = VBufferUtils.CreateEmpty(Weights[0].Length); - VBufferUtils.CreateMaybeSparseCopy(ref Weights[0], ref maybeSparseWeights, Conversions.Instance.GetIsDefaultPredicate(NumberType.Float)); - var predictor = new LinearBinaryPredictor(Host, ref maybeSparseWeights, Bias[0]); + Host.CheckParam(Utils.Size(weights) == 1, nameof(weights)); + Host.CheckParam(Utils.Size(bias) == 1, nameof(bias)); + Host.CheckParam(weights[0].Length > 0, nameof(weights)); + + VBuffer maybeSparseWeights = default; + 
VBufferUtils.CreateMaybeSparseCopy(ref weights[0], ref maybeSparseWeights, + Conversions.Instance.GetIsDefaultPredicate(NumberType.Float)); + var predictor = new LinearBinaryPredictor(Host, ref maybeSparseWeights, bias[0]); if (!(_loss is LogLoss)) return predictor; return new ParameterMixingCalibratedPredictor(Host, predictor, new PlattCalibrator(Host, -1, 0)); } - TScalarPredictor ITrainer.CreatePredictor() - { - var predictor = CreatePredictor() as TScalarPredictor; - Contracts.AssertValue(predictor); - return predictor; - } - protected override Float GetInstanceWeight(FloatLabelCursor cursor) { return cursor.Label > 0 ? cursor.Weight * _positiveInstanceWeight : cursor.Weight; } - protected override void CheckLabel(RoleMappedData examples) + protected override void CheckLabel(RoleMappedData examples, out int weightSetCount) { examples.CheckBinaryLabel(); + weightSetCount = 1; } } public sealed class StochasticGradientDescentClassificationTrainer : - LinearTrainerBase, - IIncrementalTrainer, - ITrainer, - ITrainerEx + LinearTrainerBase { public const string LoadNameValue = "BinarySGD"; public const string UserNameValue = "Hogwild SGD (binary)"; public const string ShortName = "HogwildSGD"; + public override bool SupportsIncrementalTraining => true; + public sealed class Arguments : LearnerInputBaseWithWeight { [Argument(ArgumentType.Multiple, HelpText = "Loss Function", ShortName = "loss", SortOrder = 50)] @@ -1465,7 +1446,7 @@ public sealed class Arguments : LearnerInputBaseWithWeight [Argument(ArgumentType.AtMostOnce, HelpText = "L2 regularizer constant", ShortName = "l2", SortOrder = 50)] [TGUI(Label = "L2 Regularizer Constant", SuggestedSweeps = "1e-7,5e-7,1e-6,5e-6,1e-5")] - [TlcModule.SweepableDiscreteParamAttribute("L2Const", new object[] { 1e-7f, 5e-7f, 1e-6f, 5e-6f, 1e-5f })] + [TlcModule.SweepableDiscreteParam("L2Const", new object[] { 1e-7f, 5e-7f, 1e-6f, 5e-6f, 1e-5f })] public Float L2Const = (Float)1e-6; [Argument(ArgumentType.AtMostOnce, HelpText = 
"Degree of lock-free parallelism. Defaults to automatic depending on data sparseness. Determinism not guaranteed.", ShortName = "nt,t,threads", SortOrder = 50)] @@ -1474,12 +1455,12 @@ public sealed class Arguments : LearnerInputBaseWithWeight [Argument(ArgumentType.AtMostOnce, HelpText = "Exponential moving averaged improvement tolerance for convergence", ShortName = "tol")] [TGUI(SuggestedSweeps = "1e-2,1e-3,1e-4,1e-5")] - [TlcModule.SweepableDiscreteParamAttribute("ConvergenceTolerance", new object[] { 1e-2f, 1e-3f, 1e-4f, 1e-5f })] + [TlcModule.SweepableDiscreteParam("ConvergenceTolerance", new object[] { 1e-2f, 1e-3f, 1e-4f, 1e-5f })] public Double ConvergenceTolerance = 1e-4; [Argument(ArgumentType.AtMostOnce, HelpText = "Maximum number of iterations; set to 1 to simulate online learning.", ShortName = "iter")] [TGUI(Label = "Max number of iterations", SuggestedSweeps = "1,5,10,20")] - [TlcModule.SweepableDiscreteParamAttribute("MaxIterations", new object[] { 1, 5, 10, 20 })] + [TlcModule.SweepableDiscreteParam("MaxIterations", new object[] { 1, 5, 10, 20 })] public int MaxIterations = 20; [Argument(ArgumentType.AtMostOnce, HelpText = "Initial learning rate (only used by SGD)", ShortName = "ilr,lr")] @@ -1487,7 +1468,7 @@ public sealed class Arguments : LearnerInputBaseWithWeight public Double InitLearningRate = 0.01; [Argument(ArgumentType.AtMostOnce, HelpText = "Shuffle data every epoch?", ShortName = "shuf")] - [TlcModule.SweepableDiscreteParamAttribute("Shuffle", null, isBool: true)] + [TlcModule.SweepableDiscreteParam("Shuffle", null, isBool: true)] public bool Shuffle = true; [Argument(ArgumentType.AtMostOnce, HelpText = "Apply weight to the positive class, for imbalanced data", ShortName = "piw")] @@ -1502,15 +1483,23 @@ public sealed class Arguments : LearnerInputBaseWithWeight [Argument(ArgumentType.AtMostOnce, HelpText = "The maximum number of examples to use when training the calibrator", Visibility = 
ArgumentAttribute.VisibilityType.EntryPointsOnly)] public int MaxCalibrationExamples = 1000000; - public void Check(ITrainerHost host) + internal void Check(IHostEnvironment env) { - Contracts.CheckUserArg(L2Const >= 0, nameof(L2Const), "L2 constant must be non-negative."); - Contracts.CheckUserArg(InitLearningRate > 0, nameof(InitLearningRate), "Initial learning rate must be positive."); - Contracts.CheckUserArg(MaxIterations > 0, nameof(MaxIterations), "Max number of iterations must be positive."); - Contracts.CheckUserArg(PositiveInstanceWeight > 0, nameof(PositiveInstanceWeight), "Weight for positive instances must be positive"); + Contracts.CheckValue(env, nameof(env)); + env.CheckUserArg(L2Const >= 0, nameof(L2Const), "Must be non-negative."); + env.CheckUserArg(InitLearningRate > 0, nameof(InitLearningRate), "Must be positive."); + env.CheckUserArg(MaxIterations > 0, nameof(MaxIterations), "Must be positive."); + env.CheckUserArg(PositiveInstanceWeight > 0, nameof(PositiveInstanceWeight), "Must be positive"); if (InitLearningRate * L2Const >= 1) - host.StdOut.WriteLine("Learning rate {0} set too high; reducing to {1}", InitLearningRate, InitLearningRate = (Float)0.5 / L2Const); + { + using (var ch = env.Start("Argument Adjustment")) + { + ch.Warning("{0} {1} set too high; reducing to {2}", nameof(InitLearningRate), + InitLearningRate, InitLearningRate = (Float)0.5 / L2Const); + ch.Done(); + } + } if (ConvergenceTolerance <= 0) ConvergenceTolerance = Float.Epsilon; @@ -1520,63 +1509,33 @@ public void Check(ITrainerHost host) private readonly IClassificationLoss _loss; private readonly Arguments _args; - protected override bool ShuffleData { get { return _args.Shuffle; } } - - protected override int WeightArraySize { get { return 1; } } + protected override bool ShuffleData => _args.Shuffle; - public override PredictionKind PredictionKind { get { return PredictionKind.BinaryClassification; } } + public override PredictionKind => 
PredictionKind.BinaryClassification; - public override bool NeedCalibration - { - get { return !(_loss is LogLoss); } - } + public override bool NeedCalibration => !(_loss is LogLoss); public StochasticGradientDescentClassificationTrainer(IHostEnvironment env, Arguments args) : base(env, LoadNameValue) { + args.Check(env); _loss = args.LossFunction.CreateComponent(env); NeedShuffle = args.Shuffle; _args = args; } - public override IPredictor CreatePredictor() - { - Contracts.Assert(WeightArraySize == 1); - Contracts.Assert(Utils.Size(Weights) == 1); - Contracts.Assert(Utils.Size(Bias) == 1); - Host.Check(Weights[0].Length > 0); - VBuffer maybeSparseWeights = VBufferUtils.CreateEmpty(Weights[0].Length); - VBufferUtils.CreateMaybeSparseCopy(ref Weights[0], ref maybeSparseWeights, Conversions.Instance.GetIsDefaultPredicate(NumberType.Float)); - var predictor = new LinearBinaryPredictor(Host, ref maybeSparseWeights, Bias[0]); - if (!(_loss is LogLoss)) - return predictor; - return new ParameterMixingCalibratedPredictor(Host, predictor, new PlattCalibrator(Host, -1, 0)); - } - - TScalarPredictor ITrainer.CreatePredictor() - { - var predictor = CreatePredictor() as TScalarPredictor; - Contracts.AssertValue(predictor); - return predictor; - } - - public void Train(RoleMappedData data, IPredictor predictor) - { - Host.CheckValue(data, nameof(data)); - Host.CheckValue(predictor, nameof(predictor)); - LinearPredictor pred = (predictor as CalibratedPredictorBase)?.SubPredictor as LinearPredictor; - pred = pred ?? 
predictor as LinearPredictor; - Host.CheckParam(pred != null, nameof(predictor), "Not a linear predictor."); - TrainEx(data, pred); - } - //For complexity analysis, we assume that // - The number of features is N // - Average number of non-zero per instance is k - protected override void TrainCore(IChannel ch, RoleMappedData data, LinearPredictor predictor) + protected override TScalarPredictor TrainCore(IChannel ch, RoleMappedData data, LinearPredictor predictor, int weightSetCount) { - ch.Assert(NumFeatures > 0, "Number of features must be assigned prior to passing into TrainCore."); + Contracts.AssertValue(data); + Contracts.Assert(weightSetCount == 1); + Contracts.AssertValueOrNull(predictor); + + int numFeatures = data.Schema.Feature.Type.VectorSize; var cursorFactory = new FloatLabelCursor.Factory(data, CursOpt.Label | CursOpt.Features | CursOpt.Weight); + int numThreads; if (_args.NumThreads.HasValue) { @@ -1603,7 +1562,7 @@ protected override void TrainCore(IChannel ch, RoleMappedData data, LinearPredic bias = predictor.Bias; } else - weights = VBufferUtils.CreateDense(NumFeatures); + weights = VBufferUtils.CreateDense(numFeatures); var weightsSync = new object(); double weightScaling = 1; @@ -1742,15 +1701,18 @@ protected override void TrainCore(IChannel ch, RoleMappedData data, LinearPredic VectorUtils.ScaleBy(ref weights, (Float)weightScaling); // restore the true weights - Weights = new VBuffer[1]; - Bias = new Float[1]; - Weights[0] = weights; - Bias[0] = bias; + VBuffer maybeSparseWeights = default; + VBufferUtils.CreateMaybeSparseCopy(ref weights, ref maybeSparseWeights, Conversions.Instance.GetIsDefaultPredicate(NumberType.Float)); + var pred = new LinearBinaryPredictor(Host, ref maybeSparseWeights, bias); + if (!(_loss is LogLoss)) + return pred; + return new ParameterMixingCalibratedPredictor(Host, pred, new PlattCalibrator(Host, -1, 0)); } - protected override void CheckLabel(RoleMappedData examples) + protected override void 
CheckLabel(RoleMappedData examples, out int weightSetCount) { examples.CheckBinaryLabel(); + weightSetCount = 1; } [TlcModule.EntryPoint(Name = "Trainers.StochasticGradientDescentBinaryClassifier", Desc = "Train an Hogwild SGD binary model.", UserName = UserNameValue, ShortName = ShortName)] diff --git a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LbfgsPredictorBase.cs b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LbfgsPredictorBase.cs index 89f4866228..4c37c4b191 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LbfgsPredictorBase.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LbfgsPredictorBase.cs @@ -17,9 +17,7 @@ namespace Microsoft.ML.Runtime.Learners { - public abstract class LbfgsTrainerBase : - TrainerBase, - IIncrementalTrainer + public abstract class LbfgsTrainerBase : TrainerBase where TPredictor : class, IPredictorProducing { public abstract class ArgumentsBase : LearnerInputBaseWithWeight @@ -181,6 +179,7 @@ internal LbfgsTrainerBase(ArgumentsBase args, IHostEnvironment env, string name, protected virtual int ClassCount => 1; protected int BiasCount => ClassCount; protected int WeightCount => ClassCount * NumFeatures; + public sealed override bool SupportsIncrementalTraining => true; protected virtual Optimizer InitializeOptimizer(IChannel ch, FloatLabelCursor.Factory cursorFactory, out VBuffer init, out ITerminationCriterion terminationCriterion) @@ -289,28 +288,23 @@ protected virtual VBuffer InitializeWeightsSgd(IChannel ch, FloatLabelCur protected abstract VBuffer InitializeWeightsFromPredictor(TPredictor srcPredictor); - public void Train(RoleMappedData data, TPredictor predictor) - { - Contracts.CheckValue(data, nameof(data)); - Contracts.CheckValue(predictor, nameof(predictor)); - - _srcPredictor = predictor; - Train(data); - } - protected abstract void CheckLabel(RoleMappedData data); protected virtual void PreTrainingProcessInstance(Float label, ref VBuffer 
feat, Float weight) { } + protected abstract TPredictor CreatePredictor(); + /// /// The basic training calls the optimizer /// - public override void Train(RoleMappedData data) + public override TPredictor Train(TrainContext context) { - Contracts.CheckValue(data, nameof(data)); + Contracts.CheckValue(context, nameof(context)); + var data = context.Train; + _srcPredictor = context.Train as TPredictor; data.CheckFeatureFloatVector(out NumFeatures); CheckLabel(data); data.CheckOptFloatWeight(); @@ -318,16 +312,18 @@ public override void Train(RoleMappedData data) if (NumFeatures >= Utils.ArrayMaxSize / ClassCount) { throw Contracts.ExceptParam(nameof(data), - String.Format("The number of model parameters which is equal to ('# of features' + 1) * '# of classes' should be less than or equal to {0}.", Utils.ArrayMaxSize)); + "The number of model parameters which is equal to ('# of features' + 1) * '# of classes' should be less than or equal to {0}.", Utils.ArrayMaxSize); } using (var ch = Host.Start("Training")) { TrainCore(ch, data); + var pred = CreatePredictor(); ch.Done(); + return pred; } } - + private void TrainCore(IChannel ch, RoleMappedData data) { Host.AssertValue(ch); diff --git a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LogisticRegression.cs b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LogisticRegression.cs index 3cf97ea801..994e5f278c 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LogisticRegression.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LogisticRegression.cs @@ -55,9 +55,9 @@ public LogisticRegression(IHostEnvironment env, Arguments args) _posWeight = 0; } - public override bool NeedCalibration { get { return false; } } + public override bool NeedCalibration => false; - public override PredictionKind PredictionKind { get { return PredictionKind.BinaryClassification; } } + public override PredictionKind PredictionKind => PredictionKind.BinaryClassification; 
protected override void CheckLabel(RoleMappedData data) { @@ -373,7 +373,7 @@ protected override VBuffer InitializeWeightsFromPredictor(ParameterMixing return InitializeWeights(pred.Weights2, new[] { pred.Bias }); } - public override ParameterMixingCalibratedPredictor CreatePredictor() + protected override ParameterMixingCalibratedPredictor CreatePredictor() { // Logistic regression is naturally calibrated to // output probabilities when transformed using diff --git a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs index 66c1d41084..cceb8ea408 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs @@ -67,16 +67,16 @@ public sealed class Arguments : ArgumentsBase private LinearModelStatistics _stats; - protected override int ClassCount { get { return _numClasses; } } + protected override int ClassCount => _numClasses; public MulticlassLogisticRegression(IHostEnvironment env, Arguments args) : base(args, env, LoadNameValue, Contracts.CheckRef(args, nameof(args)).ShowTrainingStats) { } - public override bool NeedCalibration { get { return false; } } + public override bool NeedCalibration => false; - public override PredictionKind PredictionKind { get { return PredictionKind.MultiClassClassification; } } + public override PredictionKind PredictionKind => PredictionKind.MultiClassClassification; protected override void CheckLabel(RoleMappedData data) { @@ -203,7 +203,7 @@ protected override VBuffer InitializeWeightsFromPredictor(MulticlassLogis return InitializeWeights(srcPredictor.DenseWeightsEnumerable(), srcPredictor.BiasesEnumerable()); } - public override MulticlassLogisticRegressionPredictor CreatePredictor() + protected override MulticlassLogisticRegressionPredictor 
CreatePredictor() { if (_numClasses < 1) throw Contracts.Except("Cannot create a multiclass predictor with {0} classes", _numClasses); diff --git a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MetaMulticlassTrainer.cs b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MetaMulticlassTrainer.cs index 9a3552f74b..2bfdedfe6c 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MetaMulticlassTrainer.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MetaMulticlassTrainer.cs @@ -13,9 +13,9 @@ namespace Microsoft.ML.Runtime.Learners { - using TScalarTrainer = ITrainer>; + using TScalarTrainer = ITrainer>; - public abstract class MetaMulticlassTrainer : TrainerBase + public abstract class MetaMulticlassTrainer : TrainerBase where TPred : IPredictor where TArgs : MetaMulticlassTrainer.ArgumentsBase { @@ -38,7 +38,6 @@ public abstract class ArgumentsBase protected readonly TArgs Args; private TScalarTrainer _trainer; - private TPred _pred; public sealed override PredictionKind PredictionKind => PredictionKind.MultiClassClassification; public sealed override bool NeedNormalization { get; } @@ -96,9 +95,11 @@ protected TScalarTrainer GetTrainer() protected abstract TPred TrainCore(IChannel ch, RoleMappedData data, int count); - public override void Train(RoleMappedData data) + public override TPred Train(TrainContext context) { - Host.CheckValue(data, nameof(data)); + Host.CheckValue(context, nameof(context)); + var data = context.Train; + data.CheckFeatureFloatVector(); int count; @@ -107,16 +108,11 @@ public override void Train(RoleMappedData data) using (var ch = Host.Start("Training")) { - _pred = TrainCore(ch, data, count); - ch.Check(_pred != null, "Training did not result in a predictor"); + var pred = TrainCore(ch, data, count); + ch.Check(pred != null, "Training did not result in a predictor"); ch.Done(); + return pred; } } - - public override TPred CreatePredictor() - { - Host.Check(_pred != null, nameof(CreatePredictor) + " 
called before " + nameof(Train)); - return _pred; - } } } diff --git a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MultiClassNaiveBayesTrainer.cs b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MultiClassNaiveBayesTrainer.cs index 94efc8cf05..a0d92194d2 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MultiClassNaiveBayesTrainer.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MultiClassNaiveBayesTrainer.cs @@ -26,7 +26,7 @@ namespace Microsoft.ML.Runtime.Learners { - public sealed class MultiClassNaiveBayesTrainer : TrainerBase + public sealed class MultiClassNaiveBayesTrainer : TrainerBase { public const string LoadName = "MultiClassNaiveBayes"; internal const string UserName = "Multiclass Naive Bayes"; @@ -43,8 +43,6 @@ public sealed class Arguments : LearnerInputBaseWithLabel { } - private MultiClassNaiveBayesPredictor _predictor; - public override PredictionKind PredictionKind => PredictionKind.MultiClassClassification; public override bool NeedNormalization => false; @@ -58,9 +56,10 @@ public MultiClassNaiveBayesTrainer(IHostEnvironment env, Arguments args) { } - public override void Train(RoleMappedData data) + public override MultiClassNaiveBayesPredictor Train(TrainContext context) { - Host.CheckValue(data, nameof(data)); + Host.CheckValue(context, nameof(context)); + var data = context.Train; Host.Check(data.Schema.Label != null, "Missing Label column"); Host.Check(data.Schema.Label.Type == NumberType.Float || data.Schema.Label.Type is KeyType, "Invalid type for Label column, only floats and known-size keys are supported"); @@ -89,6 +88,7 @@ public override void Train(RoleMappedData data) if (cursor.Row.Position > int.MaxValue) { ch.Warning("Stopping training because maximum number of rows have been traversed"); + ch.Done(); break; } @@ -118,16 +118,12 @@ public override void Train(RoleMappedData data) examplesProcessed += 1; } + ch.Done(); } Array.Resize(ref labelHistogram, labelCount); Array.Resize(ref 
featureHistogram, labelCount); - _predictor = new MultiClassNaiveBayesPredictor(Host, labelHistogram, featureHistogram, featureCount); - } - - public override MultiClassNaiveBayesPredictor CreatePredictor() - { - return _predictor; + return new MultiClassNaiveBayesPredictor(Host, labelHistogram, featureHistogram, featureCount); } [TlcModule.EntryPoint(Name = "Trainers.NaiveBayesClassifier", diff --git a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Ova.cs b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Ova.cs index 7b5fcc8a93..62e3a79631 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Ova.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Ova.cs @@ -34,7 +34,7 @@ namespace Microsoft.ML.Runtime.Learners { using CR = RoleMappedSchema.ColumnRole; using TScalarPredictor = IPredictorProducing; - using TScalarTrainer = ITrainer>; + using TScalarTrainer = ITrainer>; public sealed class Ova : MetaMulticlassTrainer { @@ -81,9 +81,10 @@ private TScalarPredictor TrainOne(IChannel ch, TScalarTrainer trainer, RoleMappe .Prepend(CR.Label.Bind(dstName)); var td = new RoleMappedData(view, roles); - trainer.Train(td); + // REVIEW: In principle we could support validation sets and the like via the train context, but + // this is currently unsupported. 
+ var predictor = trainer.Train(td); - var predictor = trainer.CreatePredictor(); if (Args.UseProbabilities) { ICalibratorTrainer calibrator; diff --git a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Pkpd.cs b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Pkpd.cs index cf1e7c062b..1c4700ccdd 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Pkpd.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Pkpd.cs @@ -26,7 +26,7 @@ namespace Microsoft.ML.Runtime.Learners { - using TScalarTrainer = ITrainer>; + using TScalarTrainer = ITrainer>; using TScalarPredictor = IPredictorProducing; using TDistPredictor = IDistPredictorProducing; using CR = RoleMappedSchema.ColumnRole; @@ -78,14 +78,13 @@ private TDistPredictor TrainOne(IChannel ch, TScalarTrainer trainer, RoleMappedD .Prepend(CR.Label.Bind(dstName)); var td = new RoleMappedData(view, roles); - trainer.Train(td); + var predictor = trainer.Train(td); ICalibratorTrainer calibrator; if (!Args.Calibrator.IsGood()) calibrator = null; else calibrator = Args.Calibrator.CreateInstance(Host); - TScalarPredictor predictor = trainer.CreatePredictor(); var res = CalibratorUtils.TrainCalibratorIfNeeded(Host, ch, calibrator, Args.MaxCalibrationExamples, trainer, predictor, td); var dist = res as TDistPredictor; diff --git a/src/Microsoft.ML.StandardLearners/Standard/OlsLinearRegression.cs b/src/Microsoft.ML.StandardLearners/Standard/OlsLinearRegression.cs index db271ff858..e413fea94a 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/OlsLinearRegression.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/OlsLinearRegression.cs @@ -30,7 +30,7 @@ namespace Microsoft.ML.Runtime.Learners { - public sealed class OlsLinearRegressionTrainer : TrainerBase + public sealed class OlsLinearRegressionTrainer : TrainerBase { public sealed class Arguments : LearnerInputBaseWithWeight { @@ -57,17 +57,6 @@ It assumes that the conditional mean of the dependent variable follows a linear By minimizing 
the squares of the difference between observed values and the predictions, the parameters of the regressor can be estimated. "; - private VBuffer _weights; - private Float _bias; - - // These have length equal to the number of model parameters, i.e., one for bias plus length of weights. - private Double[] _standardErrors; - private Double[] _tValues; - private Double[] _pValues; - - private Double _rSquared; - private Double _rSquaredAdjusted; - private readonly Float _l2Weight; private readonly bool _perParameterSignificance; @@ -80,24 +69,11 @@ public OlsLinearRegressionTrainer(IHostEnvironment env, Arguments args) _perParameterSignificance = args.PerParameterSignificance; } - public override bool NeedNormalization - { - get { return true; } - } - - public override bool NeedCalibration - { - get { return false; } - } - - public override bool WantCaching - { - // Two passes, only. Probably not worth caching. - get { return false; } - } - - public override PredictionKind PredictionKind - { get { return PredictionKind.Regression; } } + public override bool NeedNormalization => true; + public override bool NeedCalibration => false; + // Two passes, only. Probably not worth caching. 
+ public override bool WantCaching => false; + public override PredictionKind PredictionKind => PredictionKind.Regression; /// /// In several calculations, we calculate probabilities or other quantities that should range @@ -107,15 +83,14 @@ public override PredictionKind PredictionKind /// The quantity that should be clamped from 0 to 1 /// Either p, or 0 or 1 if it was outside the range 0 to 1 private static Double ProbClamp(Double p) - { - return Math.Max(0, Math.Min(p, 1)); - } + => Math.Max(0, Math.Min(p, 1)); - public override void Train(RoleMappedData examples) + public override OlsLinearRegressionPredictor Train(TrainContext context) { using (var ch = Host.Start("Training")) { - ch.CheckValue(examples, nameof(examples)); + ch.CheckValue(context, nameof(context)); + var examples = context.Train; ch.CheckParam(examples.Schema.Feature != null, nameof(examples), "Need a feature column"); ch.CheckParam(examples.Schema.Label != null, nameof(examples), "Need a label column"); @@ -133,12 +108,13 @@ public override void Train(RoleMappedData examples) var cursorFactory = new FloatLabelCursor.Factory(examples, CursOpt.Label | CursOpt.Features); - TrainCore(ch, cursorFactory, typeFeat.VectorSize); + var pred = TrainCore(ch, cursorFactory, typeFeat.VectorSize); ch.Done(); + return pred; } } - private void TrainCore(IChannel ch, FloatLabelCursor.Factory cursorFactory, int featureCount) + private OlsLinearRegressionPredictor TrainCore(IChannel ch, FloatLabelCursor.Factory cursorFactory, int featureCount) { Host.AssertValue(ch); ch.AssertValue(cursorFactory); @@ -262,26 +238,21 @@ private void TrainCore(IChannel ch, FloatLabelCursor.Factory cursorFactory, int var weights = VBufferUtils.CreateDense(beta.Length - 1); for (int i = 1; i < beta.Length; ++i) weights.Values[i - 1] = (Float)beta[i]; - _weights = weights; - _bias = (Float)beta[0]; - _standardErrors = _tValues = _pValues = null; + var bias = (Float)beta[0]; if (!(_l2Weight > 0) && m == n) { // We would expect the 
solution to the problem to be exact in this case. - _rSquared = 1; - _rSquaredAdjusted = Float.NaN; ch.Info("Number of examples equals number of parameters, solution is exact but no statistics can be derived"); - ch.Done(); - return; + return new OlsLinearRegressionPredictor(Host, ref weights, bias, null, null, null, 1, Float.NaN); } Double rss = 0; // residual sum of squares Double tss = 0; // total sum of squares using (var cursor = cursorFactory.Create()) { - var lrPredictor = new LinearRegressionPredictor(Host, ref _weights, _bias); + var lrPredictor = new LinearRegressionPredictor(Host, ref weights, bias); var lrMap = lrPredictor.GetMapper, Float>(); - Float yh = default(Float); + Float yh = default; while (cursor.MoveNext()) { var features = cursor.Features; @@ -292,27 +263,28 @@ private void TrainCore(IChannel ch, FloatLabelCursor.Factory cursorFactory, int tss += ydm * ydm; } } - _rSquared = ProbClamp(1 - (rss / tss)); + var rSquared = ProbClamp(1 - (rss / tss)); // R^2 adjusted differs from the normal formula on account of the bias term, by Said's reckoning. + double rSquaredAdjusted; if (n > m) { - _rSquaredAdjusted = ProbClamp(1 - (1 - _rSquared) * (n - 1) / (n - m)); + rSquaredAdjusted = ProbClamp(1 - (1 - rSquared) * (n - 1) / (n - m)); ch.Info("Coefficient of determination R2 = {0:g}, or {1:g} (adjusted)", - _rSquared, _rSquaredAdjusted); + rSquared, rSquaredAdjusted); } else - _rSquaredAdjusted = Double.NaN; + rSquaredAdjusted = Double.NaN; // The per parameter significance is compute intensive and may not be required for all practitioners. // Also we can't estimate it, unless we can estimate the variance, which requires more examples than // parameters. 
if (!_perParameterSignificance || m >= n) - return; + return new OlsLinearRegressionPredictor(Host, ref weights, bias, null, null, null, rSquared, rSquaredAdjusted); - ch.Assert(!Double.IsNaN(_rSquaredAdjusted)); - _standardErrors = new Double[m]; - _tValues = new Double[m]; - _pValues = new Double[m]; + ch.Assert(!Double.IsNaN(rSquaredAdjusted)); + var standardErrors = new Double[m]; + var tValues = new Double[m]; + var pValues = new Double[m]; // Invert X'X: Mkl.Pptri(Mkl.Layout.RowMajor, Mkl.UpLo.Lo, m, xtx); var s2 = rss / (n - m); // estimate of variance of y @@ -320,7 +292,7 @@ private void TrainCore(IChannel ch, FloatLabelCursor.Factory cursorFactory, int for (int i = 0; i < m; i++) { // Initialize with inverse Hessian. - _standardErrors[i] = (Single)xtx[i * (i + 1) / 2 + i]; + standardErrors[i] = (Single)xtx[i * (i + 1) / 2 + i]; } if (_l2Weight > 0) @@ -334,9 +306,9 @@ private void TrainCore(IChannel ch, FloatLabelCursor.Factory cursorFactory, int { var entry = (Single)xtx[ioffset]; var adjustment = -reg * entry * entry; - _standardErrors[iRow] -= adjustment; + standardErrors[iRow] -= adjustment; if (0 < iCol && iCol < iRow) - _standardErrors[iCol] -= adjustment; + standardErrors[iCol] -= adjustment; ioffset++; } } @@ -347,17 +319,14 @@ private void TrainCore(IChannel ch, FloatLabelCursor.Factory cursorFactory, int for (int i = 0; i < m; i++) { // sqrt of diagonal entries of s2 * inverse(X'X + reg * I) * X'X * inverse(X'X + reg * I). 
- _standardErrors[i] = Math.Sqrt(s2 * _standardErrors[i]); - ch.Check(FloatUtils.IsFinite(_standardErrors[i]), "Non-finite standard error detected from OLS solution"); - _tValues[i] = beta[i] / _standardErrors[i]; - _pValues[i] = (Float)MathUtils.TStatisticToPValue(_tValues[i], n - m); - ch.Check(0 <= _pValues[i] && _pValues[i] <= 1, "p-Value calculated outside expected [0,1] range"); + standardErrors[i] = Math.Sqrt(s2 * standardErrors[i]); + ch.Check(FloatUtils.IsFinite(standardErrors[i]), "Non-finite standard error detected from OLS solution"); + tValues[i] = beta[i] / standardErrors[i]; + pValues[i] = (Float)MathUtils.TStatisticToPValue(tValues[i], n - m); + ch.Check(0 <= pValues[i] && pValues[i] <= 1, "p-Value calculated outside expected [0,1] range"); } - } - public override OlsLinearRegressionPredictor CreatePredictor() - { - return new OlsLinearRegressionPredictor(Host, ref _weights, _bias, _standardErrors, _tValues, _pValues, _rSquared, _rSquaredAdjusted); + return new OlsLinearRegressionPredictor(Host, ref weights, bias, standardErrors, tValues, pValues, rSquared, rSquaredAdjusted); } internal static class Mkl diff --git a/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedPerceptron.cs b/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedPerceptron.cs index aa5ecb67a5..a3316eb959 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedPerceptron.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedPerceptron.cs @@ -68,7 +68,7 @@ protected override void CheckLabel(RoleMappedData data) data.CheckBinaryLabel(); } - public override LinearBinaryPredictor CreatePredictor() + protected override LinearBinaryPredictor CreatePredictor() { Contracts.Assert(WeightsScale == 1); diff --git a/src/Microsoft.ML.StandardLearners/Standard/Online/LinearSvm.cs b/src/Microsoft.ML.StandardLearners/Standard/Online/LinearSvm.cs index d2b8f0b30f..6e31e55623 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/Online/LinearSvm.cs +++ 
b/src/Microsoft.ML.StandardLearners/Standard/Online/LinearSvm.cs @@ -221,7 +221,7 @@ private void UpdateWeights(ref VBuffer weightsUpdate, Float weightsUpdate } } - public override TPredictor CreatePredictor() + protected override TPredictor CreatePredictor() { Contracts.Assert(WeightsScale == 1); return new LinearBinaryPredictor(Host, ref Weights, Bias); diff --git a/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineGradientDescent.cs b/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineGradientDescent.cs index f345466e19..f0f50dded3 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineGradientDescent.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineGradientDescent.cs @@ -70,7 +70,7 @@ protected override void CheckLabel(RoleMappedData data) data.CheckRegressionLabel(); } - public override TPredictor CreatePredictor() + protected override TPredictor CreatePredictor() { Contracts.Assert(WeightsScale == 1); VBuffer weights = default(VBuffer); diff --git a/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineLinear.cs b/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineLinear.cs index bcd4b33d58..62896ba789 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineLinear.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineLinear.cs @@ -21,7 +21,7 @@ public abstract class OnlineLinearArguments : LearnerInputBaseWithLabel { [Argument(ArgumentType.AtMostOnce, HelpText = "Number of iterations", ShortName = "iter", SortOrder = 50)] [TGUI(Label = "Number of Iterations", Description = "Number of training iterations through data", SuggestedSweeps = "1,10,100")] - [TlcModule.SweepableLongParamAttribute("NumIterations", 1, 100, stepSize:10, isLogScale:true)] + [TlcModule.SweepableLongParamAttribute("NumIterations", 1, 100, stepSize: 10, isLogScale: true)] public int NumIterations = 1; [Argument(ArgumentType.AtMostOnce, HelpText = "Initial Weights and bias, comma-separated", ShortName = "initweights")] 
@@ -34,16 +34,14 @@ public abstract class OnlineLinearArguments : LearnerInputBaseWithLabel public Float InitWtsDiameter = 0; [Argument(ArgumentType.AtMostOnce, HelpText = "Whether to shuffle for each training iteration", ShortName = "shuf")] - [TlcModule.SweepableDiscreteParamAttribute("Shuffle", new object[] {false, true})] + [TlcModule.SweepableDiscreteParamAttribute("Shuffle", new object[] { false, true })] public bool Shuffle = true; [Argument(ArgumentType.AtMostOnce, HelpText = "Size of cache when trained in Scope", ShortName = "cache")] public int StreamingCacheSize = 1000000; } - public abstract class OnlineLinearTrainer : - TrainerBase, - IIncrementalTrainer + public abstract class OnlineLinearTrainer : TrainerBase where TArguments : OnlineLinearArguments where TPredictor : IPredictorProducing { @@ -83,17 +81,11 @@ protected OnlineLinearTrainer(TArguments args, IHostEnvironment env, string name Args = args; } - public override bool NeedNormalization - { - get { return true; } - } + public override bool NeedNormalization => true; - public override bool WantCaching - { - // REVIEW: This could return true if there are more than 0 iterations, - // if we got around the whole shuffling issue. - get { return true; } - } + // REVIEW: This could return true if there are more than 0 iterations, + // if we got around the whole shuffling issue. + public override bool WantCaching => true; /// /// Propagates the _weightsScale to the weights vector. @@ -119,18 +111,20 @@ protected void ScaleWeightsIfNeeded() ScaleWeights(); } - private void TrainEx(RoleMappedData data, LinearPredictor predictor) + public override TPredictor Train(TrainContext context) { - Contracts.AssertValue(data, nameof(data)); - Contracts.AssertValueOrNull(predictor); + Host.CheckValue(context, nameof(context)); + var initPredictor = context.InitialPredictor; + var initLinearPred = initPredictor as LinearPredictor ?? 
(initPredictor as CalibratedPredictorBase)?.SubPredictor as LinearPredictor; + Host.CheckParam(initPredictor == null || initLinearPred != null, nameof(context), "Not a linear predictor."); + var data = context.Train; - int numFeatures; - data.CheckFeatureFloatVector(out numFeatures); + data.CheckFeatureFloatVector(out int numFeatures); CheckLabel(data); using (var ch = Host.Start("Training")) { - InitCore(ch, numFeatures, predictor); + InitCore(ch, numFeatures, initLinearPred); // InitCore should set the number of features field. Contracts.Assert(NumFeatures > 0); @@ -150,23 +144,11 @@ private void TrainEx(RoleMappedData data, LinearPredictor predictor) ch.Done(); } - } - public override void Train(RoleMappedData data) - { - Host.CheckValue(data, nameof(data)); - TrainEx(data, null); + return CreatePredictor(); } - public void Train(RoleMappedData data, IPredictor predictor) - { - Host.CheckValue(data, nameof(data)); - Host.CheckValue(predictor, nameof(predictor)); - LinearPredictor pred = (predictor as CalibratedPredictorBase)?.SubPredictor as LinearPredictor; - pred = pred ?? 
predictor as LinearPredictor; - Host.CheckParam(pred != null, nameof(predictor), "Not a linear predictor."); - TrainEx(data, pred); - } + protected abstract TPredictor CreatePredictor(); protected abstract void CheckLabel(RoleMappedData data); diff --git a/src/Microsoft.ML.StandardLearners/Standard/PoissonRegression/PoissonRegression.cs b/src/Microsoft.ML.StandardLearners/Standard/PoissonRegression/PoissonRegression.cs index 9322c2cc75..a818165f01 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/PoissonRegression/PoissonRegression.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/PoissonRegression/PoissonRegression.cs @@ -45,9 +45,9 @@ public PoissonRegression(IHostEnvironment env, Arguments args) { } - public override bool NeedCalibration { get { return false; } } + public override bool NeedCalibration => false; - public override PredictionKind PredictionKind { get { return PredictionKind.Regression; } } + public override PredictionKind PredictionKind => PredictionKind.Regression; protected override void CheckLabel(RoleMappedData data) { @@ -106,7 +106,7 @@ protected override Float AccumulateOneGradient(ref VBuffer feat, Float la return -(y * dot - lambda) * weight; } - public override PoissonRegressionPredictor CreatePredictor() + protected override PoissonRegressionPredictor CreatePredictor() { VBuffer weights = default(VBuffer); CurrentWeights.CopyTo(ref weights, 1, CurrentWeights.Length - 1); diff --git a/src/Microsoft.ML.StandardLearners/Standard/SdcaMultiClass.cs b/src/Microsoft.ML.StandardLearners/Standard/SdcaMultiClass.cs index 20bc349a7c..2a2b01ac2c 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/SdcaMultiClass.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/SdcaMultiClass.cs @@ -45,21 +45,10 @@ public sealed class Arguments : ArgumentsBase private readonly ISupportSdcaClassificationLoss _loss; private readonly Arguments _args; - private int _numClasses; - public override PredictionKind PredictionKind - { - get { return 
PredictionKind.MultiClassClassification; } - } + public override PredictionKind PredictionKind => PredictionKind.MultiClassClassification; - protected override int WeightArraySize - { - get - { - Contracts.Assert(_numClasses > 0, "_numClasses should already have been initialized when this property is called."); - return _numClasses; - } - } + public override bool NeedCalibration => false; public SdcaMultiClassTrainer(IHostEnvironment env, Arguments args) : base(args, env, LoadNameValue) @@ -70,8 +59,6 @@ public SdcaMultiClassTrainer(IHostEnvironment env, Arguments args) _args = args; } - public override bool NeedCalibration { get { return false; } } - /// protected override void TrainWithoutLock(IProgressChannelProvider progress, FloatLabelCursor.Factory cursorFactory, IRandom rand, IdToIdxLookup idToIdx, int numThreads, DualsTableBase duals, Float[] biasReg, Float[] invariants, Float lambdaNInv, @@ -82,11 +69,9 @@ protected override void TrainWithoutLock(IProgressChannelProvider progress, Floa Contracts.AssertValueOrNull(idToIdx); Contracts.AssertValueOrNull(invariants); Contracts.AssertValueOrNull(featureNormSquared); - int weightArraySize = WeightArraySize; - Contracts.Assert(weightArraySize == _numClasses); - Contracts.Assert(Utils.Size(weights) == weightArraySize); - Contracts.Assert(Utils.Size(biasReg) == weightArraySize); - Contracts.Assert(Utils.Size(biasUnreg) == weightArraySize); + int numClasses = Utils.Size(weights); + Contracts.Assert(Utils.Size(biasReg) == numClasses); + Contracts.Assert(Utils.Size(biasUnreg) == numClasses); int maxUpdateTrials = 2 * numThreads; var l1Threshold = _args.L1Threshold.Value; @@ -101,11 +86,11 @@ protected override void TrainWithoutLock(IProgressChannelProvider progress, Floa if (pch != null) pch.SetHeader(new ProgressHeader("examples"), e => e.SetProgress(0, rowCount)); - Func getIndexFromId = GetIndexFromIdGetter(idToIdx); + Func getIndexFromId = GetIndexFromIdGetter(idToIdx, biasReg.Length); while (cursor.MoveNext()) { 
long idx = getIndexFromId(cursor.Id); - long dualIndexInitPos = idx * weightArraySize; + long dualIndexInitPos = idx * numClasses; var features = cursor.Features; var label = (int)cursor.Label; Float invariant; @@ -139,7 +124,7 @@ protected override void TrainWithoutLock(IProgressChannelProvider progress, Floa Float labelAdjustment = 0; // Iterates through all classes. - for (int iClass = 0; iClass < _numClasses; iClass++) + for (int iClass = 0; iClass < numClasses; iClass++) { // Skip the dual/weights/bias update for label class. Will be taken care of at the end. if (iClass == label) @@ -161,9 +146,7 @@ protected override void TrainWithoutLock(IProgressChannelProvider progress, Floa dualUpdate -= adjustment; bool success = false; duals.ApplyAt(dualIndex, (long index, ref Float value) => - { - success = Interlocked.CompareExchange(ref value, dual + dualUpdate, dual) == dual; - }); + success = Interlocked.CompareExchange(ref value, dual + dualUpdate, dual) == dual); if (success) { @@ -251,24 +234,23 @@ protected override bool CheckConvergence( { Contracts.AssertValue(weights); Contracts.AssertValue(duals); - Contracts.Assert(weights.Length == _numClasses); - Contracts.Assert(duals.Length >= _numClasses * count); + int numClasses = weights.Length; + Contracts.Assert(duals.Length >= numClasses * count); Contracts.AssertValueOrNull(idToIdx); - int weightArraySize = WeightArraySize; - Contracts.Assert(weightArraySize == _numClasses); - Contracts.Assert(Utils.Size(weights) == weightArraySize); - Contracts.Assert(Utils.Size(biasReg) == weightArraySize); - Contracts.Assert(Utils.Size(biasUnreg) == weightArraySize); + Contracts.Assert(Utils.Size(weights) == numClasses); + Contracts.Assert(Utils.Size(biasReg) == numClasses); + Contracts.Assert(Utils.Size(biasUnreg) == numClasses); Contracts.Assert(Utils.Size(metrics) == 6); var reportedValues = new Double?[metrics.Length + 1]; reportedValues[metrics.Length] = iter; var lossSum = new CompensatedSum(); var dualLossSum = new 
CompensatedSum(); + int numFeatures = weights[0].Length; using (var cursor = cursorFactory.Create()) { long row = 0; - Func getIndexFromIdAndRow = GetIndexFromIdAndRowGetter(idToIdx); + Func getIndexFromIdAndRow = GetIndexFromIdAndRowGetter(idToIdx, biasReg.Length); // Iterates through data to compute loss function. while (cursor.MoveNext()) { @@ -279,8 +261,8 @@ protected override bool CheckConvergence( Double subLoss = 0; Double subDualLoss = 0; long idx = getIndexFromIdAndRow(cursor.Id, row); - long dualIndex = idx * _numClasses; - for (int iClass = 0; iClass < _numClasses; iClass++) + long dualIndex = idx * numClasses; + for (int iClass = 0; iClass < numClasses; iClass++) { if (iClass == label) { @@ -290,7 +272,7 @@ protected override bool CheckConvergence( var currentClassOutput = WDot(ref features, ref weights[iClass], biasReg[iClass] + biasUnreg[iClass]); subLoss += _loss.Loss(labelOutput - currentClassOutput, 1); - Contracts.Assert(dualIndex == iClass + idx * _numClasses); + Contracts.Assert(dualIndex == iClass + idx * numClasses); var dual = duals[dualIndex++]; subDualLoss += _loss.DualLoss(1, dual); } @@ -300,7 +282,7 @@ protected override bool CheckConvergence( row++; } - Host.Assert(idToIdx == null || row * WeightArraySize == duals.Length); + Host.Assert(idToIdx == null || row * numClasses == duals.Length); } Contracts.Assert(_args.L2Const.HasValue); @@ -311,7 +293,7 @@ protected override bool CheckConvergence( Double weightsL1Norm = 0; Double weightsL2NormSquared = 0; Double biasRegularizationAdjustment = 0; - for (int iClass = 0; iClass < _numClasses; iClass++) + for (int iClass = 0; iClass < numClasses; iClass++) { weightsL1Norm += VectorUtils.L1Norm(ref weights[iClass]) + Math.Abs(biasReg[iClass]); weightsL2NormSquared += VectorUtils.NormSquared(weights[iClass]) + biasReg[iClass] * biasReg[iClass]; @@ -330,13 +312,14 @@ protected override bool CheckConvergence( metrics[(int)MetricKind.DualityGap] = dualityGap; metrics[(int)MetricKind.BiasUnreg] = 
biasUnreg[0]; metrics[(int)MetricKind.BiasReg] = biasReg[0]; - metrics[(int)MetricKind.L1Sparsity] = _args.L1Threshold == 0 ? 1 : (Double)weights.Sum(weight => weight.Values.Count(w => w != 0)) / (_numClasses * NumFeatures); + metrics[(int)MetricKind.L1Sparsity] = _args.L1Threshold == 0 ? 1 : weights.Sum( + weight => weight.Values.Count(w => w != 0)) / (numClasses * numFeatures); bool converged = dualityGap / newLoss < _args.ConvergenceTolerance; if (metrics[(int)MetricKind.Loss] < bestPrimalLoss) { - for (int iClass = 0; iClass < _numClasses; iClass++) + for (int iClass = 0; iClass < numClasses; iClass++) { // Maintain a copy of weights and bias with best primal loss thus far. // This is some extra work and uses extra memory, but it seems worth doing it. @@ -358,14 +341,19 @@ protected override bool CheckConvergence( return converged; } - public override TVectorPredictor CreatePredictor() + protected override TVectorPredictor CreatePredictor(VBuffer[] weights, Float[] bias) { - return new MulticlassLogisticRegressionPredictor(Host, Weights, Bias, _numClasses, NumFeatures, null, stats: null); + Host.CheckValue(weights, nameof(weights)); + Host.CheckValue(bias, nameof(bias)); + Host.CheckParam(weights.Length > 0, nameof(weights)); + Host.CheckParam(weights.Length == bias.Length, nameof(weights)); + + return new MulticlassLogisticRegressionPredictor(Host, weights, bias, bias.Length, weights[0].Length, null, stats: null); } - protected override void CheckLabel(RoleMappedData examples) + protected override void CheckLabel(RoleMappedData examples, out int weightSetCount) { - examples.CheckMultiClassLabel(out _numClasses); + examples.CheckMultiClassLabel(out weightSetCount); } protected override Float[] InitializeFeatureNormSquared(int length) diff --git a/src/Microsoft.ML.StandardLearners/Standard/SdcaRegression.cs b/src/Microsoft.ML.StandardLearners/Standard/SdcaRegression.cs index 512818bba7..e5ab48c0af 100644 --- 
a/src/Microsoft.ML.StandardLearners/Standard/SdcaRegression.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/SdcaRegression.cs @@ -26,7 +26,7 @@ namespace Microsoft.ML.Runtime.Learners using TScalarPredictor = IPredictorWithFeatureWeights; /// - public sealed class SdcaRegressionTrainer : SdcaTrainerBase, ITrainer, ITrainerEx + public sealed class SdcaRegressionTrainer : SdcaTrainerBase { public const string LoadNameValue = "SDCAR"; public const string UserNameValue = "Fast Linear Regression (SA-SDCA)"; @@ -51,11 +51,9 @@ public Arguments() private readonly ISupportSdcaRegressionLoss _loss; private readonly Arguments _args; - public override PredictionKind PredictionKind { get { return PredictionKind.Regression; } } + public override PredictionKind PredictionKind => PredictionKind.Regression; - public override bool NeedCalibration { get { return false; } } - - protected override int WeightArraySize { get { return 1; } } + public override bool NeedCalibration => false; public SdcaRegressionTrainer(IHostEnvironment env, Arguments args) : base(args, env, LoadNameValue) @@ -66,22 +64,16 @@ public SdcaRegressionTrainer(IHostEnvironment env, Arguments args) _args = args; } - public override IPredictor CreatePredictor() - { - Contracts.Assert(WeightArraySize == 1); - Contracts.Assert(Utils.Size(Weights) == 1); - Contracts.Assert(Utils.Size(Bias) == 1); - Host.Check(Weights[0].Length > 0); - VBuffer maybeSparseWeights = VBufferUtils.CreateEmpty(Weights[0].Length); - VBufferUtils.CreateMaybeSparseCopy(ref Weights[0], ref maybeSparseWeights, Conversions.Instance.GetIsDefaultPredicate(NumberType.Float)); - return new LinearRegressionPredictor(Host, ref maybeSparseWeights, Bias[0]); - } - - TScalarPredictor ITrainer.CreatePredictor() + protected override TScalarPredictor CreatePredictor(VBuffer[] weights, Float[] bias) { - var predictor = CreatePredictor() as TScalarPredictor; - Contracts.AssertValue(predictor); - return predictor; + Host.CheckParam(Utils.Size(weights) == 
1, nameof(weights)); + Host.CheckParam(Utils.Size(bias) == 1, nameof(bias)); + Host.CheckParam(weights[0].Length > 0, nameof(weights)); + + VBuffer maybeSparseWeights = default; + VBufferUtils.CreateMaybeSparseCopy(ref weights[0], ref maybeSparseWeights, + Conversions.Instance.GetIsDefaultPredicate(NumberType.Float)); + return new LinearRegressionPredictor(Host, ref maybeSparseWeights, bias[0]); } protected override Float GetInstanceWeight(FloatLabelCursor cursor) @@ -89,9 +81,10 @@ protected override Float GetInstanceWeight(FloatLabelCursor cursor) return cursor.Weight; } - protected override void CheckLabel(RoleMappedData examples) + protected override void CheckLabel(RoleMappedData examples, out int weightSetCount) { examples.CheckRegressionLabel(); + weightSetCount = 1; } // REVIEW: No extra benefits from using more threads in training. diff --git a/src/Microsoft.ML.StandardLearners/Standard/Simple/SimpleTrainers.cs b/src/Microsoft.ML.StandardLearners/Standard/Simple/SimpleTrainers.cs index 364e29b877..0c00b4a9c8 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/Simple/SimpleTrainers.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/Simple/SimpleTrainers.cs @@ -38,7 +38,7 @@ namespace Microsoft.ML.Runtime.Learners /// /// A trainer that trains a predictor that returns random values /// - public sealed class RandomTrainer : TrainerBase + public sealed class RandomTrainer : TrainerBase { internal const string LoadNameValue = "RandomPredictor"; internal const string UserNameValue = "Random Predictor"; @@ -54,29 +54,19 @@ public class Arguments public bool BooleanArg = false; } - private Arguments _args; - public RandomTrainer(IHostEnvironment env, Arguments args) : base(env, LoadNameValue) { - _args = args; } - public override PredictionKind PredictionKind - { get { return PredictionKind.BinaryClassification; } } - public override bool NeedNormalization - { get { return false; } } - public override bool NeedCalibration - { get { return false; } } - public 
override bool WantCaching - { get { return false; } } - - public override void Train(RoleMappedData data) - { - } + public override PredictionKind PredictionKind => PredictionKind.BinaryClassification; + public override bool NeedNormalization => false; + public override bool NeedCalibration => false; + public override bool WantCaching => false; - public override RandomPredictor CreatePredictor() + public override RandomPredictor Train(TrainContext context) { + Host.CheckValue(context, nameof(context)); return new RandomPredictor(Host, Host.Rand.Next()); } } @@ -107,16 +97,10 @@ private static VersionInfo GetVersionInfo() private readonly object _instanceLock; private readonly Random _random; - private readonly ColumnType _inputType; - - public override PredictionKind PredictionKind - { get { return PredictionKind.BinaryClassification; } } - public ColumnType InputType - { get { return _inputType; } } - public ColumnType OutputType - { get { return NumberType.Float; } } - public ColumnType DistType - { get { return NumberType.Float; } } + public override PredictionKind PredictionKind => PredictionKind.BinaryClassification; + public ColumnType InputType { get; } + public ColumnType OutputType => NumberType.Float; + public ColumnType DistType => NumberType.Float; public RandomPredictor(IHostEnvironment env, int seed) : base(env, LoaderSignature) @@ -126,7 +110,7 @@ public RandomPredictor(IHostEnvironment env, int seed) _instanceLock = new object(); _random = new Random(_seed); - _inputType = new VectorType(NumberType.Float); + InputType = new VectorType(NumberType.Float); } /// @@ -211,7 +195,7 @@ private void MapDist(ref VBuffer src, ref Float score, ref Float prob) } // Learns the prior distribution for 0/1 class labels and just outputs that. 
- public sealed class PriorTrainer : TrainerBase + public sealed class PriorTrainer : TrainerBase { internal const string LoadNameValue = "PriorPredictor"; internal const string UserNameValue = "Prior Predictor"; @@ -220,26 +204,21 @@ public sealed class Arguments { } - private Float _prob; - - public override PredictionKind PredictionKind - { get { return PredictionKind.BinaryClassification; } } - public override bool NeedNormalization - { get { return false; } } - public override bool NeedCalibration - { get { return false; } } - public override bool WantCaching - { get { return false; } } + public override PredictionKind PredictionKind => PredictionKind.BinaryClassification; + public override bool NeedNormalization => false; + public override bool NeedCalibration => false; + public override bool WantCaching => false; public PriorTrainer(IHostEnvironment env, Arguments args) : base(env, LoadNameValue) { - _prob = Float.NaN; } - public override void Train(RoleMappedData data) + public override PriorPredictor Train(TrainContext context) { - Contracts.CheckValue(data, nameof(data)); + Contracts.CheckValue(context, nameof(context)); + var data = context.Train; + data.CheckBinaryLabel(); Contracts.CheckParam(data.Schema.Label != null, nameof(data), "Missing Label column"); Contracts.CheckParam(data.Schema.Label.Type == NumberType.Float, nameof(data), "Invalid type for Label column"); @@ -248,11 +227,11 @@ public override void Train(RoleMappedData data) int col = data.Schema.Label.Index; int colWeight = -1; - if (data.Schema.Weight != null && data.Schema.Weight.Type == NumberType.Float) + if (data.Schema.Weight?.Type == NumberType.Float) colWeight = data.Schema.Weight.Index; using (var cursor = data.Data.GetRowCursor(c => c == col || c == colWeight)) { - var getLab = cursor.GetGetter(col); + var getLab = cursor.GetLabelFloatGetter(data); var getWeight = colWeight >= 0 ? 
cursor.GetGetter(colWeight) : null; Float lab = default(Float); Float weight = 1; @@ -274,13 +253,8 @@ public override void Train(RoleMappedData data) } } - if (pos + neg > 0) - _prob = (Float)(pos / (pos + neg)); - } - - public override PriorPredictor CreatePredictor() - { - return new PriorPredictor(Host, _prob); + Float prob = prob = pos + neg > 0 ? (Float)(pos / (pos + neg)) : Float.NaN; + return new PriorPredictor(Host, prob); } } @@ -304,8 +278,6 @@ private static VersionInfo GetVersionInfo() private readonly Float _prob; private readonly Float _raw; - private readonly ColumnType _inputType; - public PriorPredictor(IHostEnvironment env, Float prob) : base(env, LoaderSignature) { @@ -314,7 +286,7 @@ public PriorPredictor(IHostEnvironment env, Float prob) _prob = prob; _raw = 2 * _prob - 1; // This could be other functions -- logodds for instance - _inputType = new VectorType(NumberType.Float); + InputType = new VectorType(NumberType.Float); } private PriorPredictor(IHostEnvironment env, ModelLoadContext ctx) @@ -328,7 +300,7 @@ private PriorPredictor(IHostEnvironment env, ModelLoadContext ctx) _raw = 2 * _prob - 1; - _inputType = new VectorType(NumberType.Float); + InputType = new VectorType(NumberType.Float); } public static PriorPredictor Create(IHostEnvironment env, ModelLoadContext ctx) @@ -353,12 +325,9 @@ protected override void SaveCore(ModelSaveContext ctx) public override PredictionKind PredictionKind { get { return PredictionKind.BinaryClassification; } } - public ColumnType InputType - { get { return _inputType; } } - public ColumnType OutputType - { get { return NumberType.Float; } } - public ColumnType DistType - { get { return NumberType.Float; } } + public ColumnType InputType { get; } + public ColumnType OutputType => NumberType.Float; + public ColumnType DistType => NumberType.Float; public ValueMapper GetMapper() { diff --git a/src/Microsoft.ML.Sweeper/Algorithms/SmacSweeper.cs b/src/Microsoft.ML.Sweeper/Algorithms/SmacSweeper.cs index 
391b102cb0..2351454709 100644 --- a/src/Microsoft.ML.Sweeper/Algorithms/SmacSweeper.cs +++ b/src/Microsoft.ML.Sweeper/Algorithms/SmacSweeper.cs @@ -142,9 +142,8 @@ private FastForestRegressionPredictor FitModel(IEnumerable previousR args.MinDocumentsInLeafs = _args.NMinForSplit; // Train random forest. - FastForestRegression trainer = new FastForestRegression(_host, args); - trainer.Train(data); - FastForestRegressionPredictor predictor = trainer.CreatePredictor(); + var trainer = new FastForestRegression(_host, args); + var predictor = trainer.Train(data); // Return random forest predictor. ch.Done(); diff --git a/src/Microsoft.ML.Transforms/LearnerFeatureSelection.cs b/src/Microsoft.ML.Transforms/LearnerFeatureSelection.cs index 637d75250b..c2b2bead79 100644 --- a/src/Microsoft.ML.Transforms/LearnerFeatureSelection.cs +++ b/src/Microsoft.ML.Transforms/LearnerFeatureSelection.cs @@ -33,8 +33,8 @@ public sealed class Arguments public int? NumSlotsToKeep; [Argument(ArgumentType.Multiple, HelpText = "Filter", ShortName = "f", SortOrder = 1)] - public SubComponent>, SignatureFeatureScorerTrainer> Filter = - new SubComponent>, SignatureFeatureScorerTrainer>("SDCA"); + public SubComponent>, SignatureFeatureScorerTrainer> Filter = + new SubComponent>, SignatureFeatureScorerTrainer>("SDCA"); [Argument(ArgumentType.LastOccurenceWins, HelpText = "Column to use for features", ShortName = "feat,col", SortOrder = 3, Purpose = SpecialPurpose.ColumnName)] public string FeatureColumn = DefaultColumnNames.Features; diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs index afd488915e..b0bc269164 100644 --- a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs +++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs @@ -685,8 +685,7 @@ public void EntryPointCalibrate() // This tests that the SchemaBindableCalibratedPredictor doesn't get confused if its sub-predictor is already calibrated. 
var fastForest = new FastForestClassification(Env, new FastForestClassification.Arguments()); var rmd = new RoleMappedData(splitOutput.TrainData[0], "Label", "Features"); - fastForest.Train(rmd); - var ffModel = new PredictorModel(Env, rmd, splitOutput.TrainData[0], fastForest.CreatePredictor()); + var ffModel = new PredictorModel(Env, rmd, splitOutput.TrainData[0], fastForest.Train(rmd)); var calibratedFfModel = Calibrate.Platt(Env, new Calibrate.NoArgumentsInput() { Data = splitOutput.TestData[0], UncalibratedPredictorModel = ffModel }).PredictorModel; var twiceCalibratedFfModel = Calibrate.Platt(Env, @@ -1219,9 +1218,8 @@ public void EntryPointMulticlassPipelineEnsemble() var mlr = new MulticlassLogisticRegression(Env, new MulticlassLogisticRegression.Arguments()); var rmd = new RoleMappedData(data, "Label", "Features"); - mlr.Train(rmd); - predictorModels[i] = new PredictorModel(Env, rmd, data, mlr.CreatePredictor()); + predictorModels[i] = new PredictorModel(Env, rmd, data, mlr.Train(rmd)); var transformModel = new TransformModel(Env, data, splitOutput.TrainData[i]); predictorModels[i] = ModelOperations.CombineTwoModels(Env, diff --git a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/IrisPlantClassificationTests.cs b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/IrisPlantClassificationTests.cs index 20649b5b25..95852a6e81 100644 --- a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/IrisPlantClassificationTests.cs +++ b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/IrisPlantClassificationTests.cs @@ -75,10 +75,9 @@ public void TrainAndPredictIrisModelUsingDirectInstantiationTest() // Explicity adding CacheDataView since caching is not working though trainer has 'Caching' On/Auto var cached = new CacheDataView(env, trans, prefetch: null); var trainRoles = new RoleMappedData(cached, label: "Label", feature: "Features"); - trainer.Train(trainRoles); + var pred = trainer.Train(trainRoles); // Get scorer and evaluate the 
predictions from test data - var pred = trainer.CreatePredictor(); IDataScorerTransform testDataScorer = GetScorer(env, trans, pred, testDataPath); var metrics = Evaluate(env, testDataScorer); CompareMatrics(metrics); diff --git a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/SentimentPredictionTests.cs b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/SentimentPredictionTests.cs index da208cb3f0..87c3ab4beb 100644 --- a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/SentimentPredictionTests.cs +++ b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/SentimentPredictionTests.cs @@ -79,10 +79,9 @@ public void TrainAndPredictSentimentModelWithDirectionInstantiationTest() }); var trainRoles = new RoleMappedData(trans, label: "Label", feature: "Features"); - trainer.Train(trainRoles); + var pred = trainer.Train(trainRoles); // Get scorer and evaluate the predictions from test data - var pred = trainer.CreatePredictor(); IDataScorerTransform testDataScorer = GetScorer(env, trans, pred, testDataPath); var metrics = EvaluateBinary(env, testDataScorer); ValidateBinaryMetrics(metrics); From fa028207e637827fb72cf7bc9c0ed4815c336f39 Mon Sep 17 00:00:00 2001 From: Tom Finley Date: Thu, 12 Jul 2018 09:10:58 -0700 Subject: [PATCH 02/13] TrainContext changes --- .../Prediction/TrainContext.cs | 27 ++++++++++++------- .../Trainer/EnsembleTrainerBase.cs | 2 +- .../FastTreeClassification.cs | 12 ++++----- src/Microsoft.ML.FastTree/FastTreeRanking.cs | 12 ++++----- .../FastTreeRegression.cs | 12 ++++----- src/Microsoft.ML.FastTree/FastTreeTweedie.cs | 19 ++++++------- src/Microsoft.ML.FastTree/GamTrainer.cs | 4 +-- .../RandomForestClassification.cs | 12 ++++----- .../RandomForestRegression.cs | 12 ++++----- .../KMeansPlusPlusTrainer.cs | 2 +- .../LightGbmTrainerBase.cs | 6 ++--- src/Microsoft.ML.PCA/PcaTrainer.cs | 4 +-- .../FactorizationMachineTrainer.cs | 2 +- .../Standard/LinearClassificationTrainer.cs | 2 +- 
.../LogisticRegression/LbfgsPredictorBase.cs | 4 +-- .../MultiClass/MetaMulticlassTrainer.cs | 2 +- .../MultiClass/MultiClassNaiveBayesTrainer.cs | 2 +- .../Standard/OlsLinearRegression.cs | 2 +- .../Standard/Online/OnlineLinear.cs | 2 +- .../Standard/Simple/SimpleTrainers.cs | 2 +- 20 files changed, 73 insertions(+), 69 deletions(-) diff --git a/src/Microsoft.ML.Core/Prediction/TrainContext.cs b/src/Microsoft.ML.Core/Prediction/TrainContext.cs index a85e1f8fdf..87515c85f4 100644 --- a/src/Microsoft.ML.Core/Prediction/TrainContext.cs +++ b/src/Microsoft.ML.Core/Prediction/TrainContext.cs @@ -7,22 +7,29 @@ namespace Microsoft.ML.Runtime { /// - /// Instances of this class are meant to be constructed and passed to trainers. + /// Holds information relevant to trainers. Instances of this class are meant to be constructed and passed + /// into or . + /// This holds at least a training set, as well as optionally a predictor. /// public sealed class TrainContext { /// /// The training set. Cannot be null. /// - public RoleMappedData Train { get; } + public RoleMappedData TrainingSet { get; } /// - /// The validation set. Can be null. + /// The validation set. Can be null. Note that passing a non-null validation set into + /// a trainer that does not support validation sets should not be considered an error condition. It + /// should simply be ignored in that case. /// - public RoleMappedData Validation { get; } + public RoleMappedData ValidationSet { get; } /// - /// The initial + /// The initial predictor, for incremental training. Note that if a implementor + /// does not support incremental training, then it can ignore it similarly to how one would ignore + /// . However, if the trainer does support incremental training and there + /// is something wrong with a non-null value of this, then the trainer ought to throw an exception.
/// public IPredictor InitialPredictor { get; } @@ -30,9 +37,9 @@ public sealed class TrainContext /// /// Constructor, given a training set and optional other arguments. /// - /// Will be set to , must be specified - /// Will be set to if specified - /// Will be set to if specified + /// Will set to this value. This must be specified + /// Will set to this value if specified + /// Will set to this value if specified public TrainContext(RoleMappedData train, RoleMappedData valid = null, IPredictor initPredictor = null) { Contracts.CheckValue(train, nameof(train)); @@ -42,8 +49,8 @@ public TrainContext(RoleMappedData train, RoleMappedData valid = null, IPredicto // REVIEW: Should there be code here to ensure that the role mappings between the two are compatible? // That is, all the role mappings are the same and the columns between them have identical types? - Train = train; - Validation = valid; + TrainingSet = train; + ValidationSet = valid; InitialPredictor = initPredictor; } } diff --git a/src/Microsoft.ML.Ensemble/Trainer/EnsembleTrainerBase.cs b/src/Microsoft.ML.Ensemble/Trainer/EnsembleTrainerBase.cs index 810eada855..7ab2395569 100644 --- a/src/Microsoft.ML.Ensemble/Trainer/EnsembleTrainerBase.cs +++ b/src/Microsoft.ML.Ensemble/Trainer/EnsembleTrainerBase.cs @@ -111,7 +111,7 @@ public sealed override TPredictor Train(TrainContext context) using (var ch = Host.Start("Training")) { - var pred = TrainCore(ch, context.Train); + var pred = TrainCore(ch, context.TrainingSet); ch.Done(); return pred; } diff --git a/src/Microsoft.ML.FastTree/FastTreeClassification.cs b/src/Microsoft.ML.FastTree/FastTreeClassification.cs index db266ac71c..bfc77b56cb 100644 --- a/src/Microsoft.ML.FastTree/FastTreeClassification.cs +++ b/src/Microsoft.ML.FastTree/FastTreeClassification.cs @@ -62,11 +62,11 @@ private static VersionInfo GetVersionInfo() loaderSignature: LoaderSignature); } - protected override uint VerNumFeaturesSerialized { get { return 0x00010002; } } + protected 
override uint VerNumFeaturesSerialized => 0x00010002; - protected override uint VerDefaultValueSerialized { get { return 0x00010004; } } + protected override uint VerDefaultValueSerialized => 0x00010004; - protected override uint VerCategoricalSplitSerialized { get { return 0x00010005; } } + protected override uint VerCategoricalSplitSerialized => 0x00010005; internal FastTreeBinaryPredictor(IHostEnvironment env, Ensemble trainedEnsemble, int featureCount, string innerArgs) : base(env, RegistrationName, trainedEnsemble, featureCount, innerArgs) @@ -97,7 +97,7 @@ public static IPredictorProducing Create(IHostEnvironment env, ModelLoadC return new SchemaBindableCalibratedPredictor(env, predictor, calibrator); } - public override PredictionKind PredictionKind { get { return PredictionKind.BinaryClassification; } } + public override PredictionKind PredictionKind => PredictionKind.BinaryClassification; } /// @@ -123,8 +123,8 @@ public FastTreeBinaryClassificationTrainer(IHostEnvironment env, Arguments args) public override IPredictorWithFeatureWeights Train(TrainContext context) { Host.CheckValue(context, nameof(context)); - var trainData = context.Train; - ValidData = context.Validation; + var trainData = context.TrainingSet; + ValidData = context.ValidationSet; using (var ch = Host.Start("Training")) { diff --git a/src/Microsoft.ML.FastTree/FastTreeRanking.cs b/src/Microsoft.ML.FastTree/FastTreeRanking.cs index 620a6dedb2..9e0e533bfe 100644 --- a/src/Microsoft.ML.FastTree/FastTreeRanking.cs +++ b/src/Microsoft.ML.FastTree/FastTreeRanking.cs @@ -68,8 +68,8 @@ protected override float GetMaxLabel() public override FastTreeRankingPredictor Train(TrainContext context) { Host.CheckValue(context, nameof(context)); - var trainData = context.Train; - ValidData = context.Validation; + var trainData = context.TrainingSet; + ValidData = context.ValidationSet; using (var ch = Host.Start("Training")) { @@ -1062,11 +1062,11 @@ private static VersionInfo GetVersionInfo() 
loaderSignature: LoaderSignature); } - protected override uint VerNumFeaturesSerialized { get { return 0x00010002; } } + protected override uint VerNumFeaturesSerialized => 0x00010002; - protected override uint VerDefaultValueSerialized { get { return 0x00010004; } } + protected override uint VerDefaultValueSerialized => 0x00010004; - protected override uint VerCategoricalSplitSerialized { get { return 0x00010005; } } + protected override uint VerCategoricalSplitSerialized => 0x00010005; internal FastTreeRankingPredictor(IHostEnvironment env, Ensemble trainedEnsemble, int featureCount, string innerArgs) : base(env, RegistrationName, trainedEnsemble, featureCount, innerArgs) @@ -1089,7 +1089,7 @@ public static FastTreeRankingPredictor Create(IHostEnvironment env, ModelLoadCon return new FastTreeRankingPredictor(env, ctx); } - public override PredictionKind PredictionKind { get { return PredictionKind.Ranking; } } + public override PredictionKind PredictionKind => PredictionKind.Ranking; } public static partial class FastTree diff --git a/src/Microsoft.ML.FastTree/FastTreeRegression.cs b/src/Microsoft.ML.FastTree/FastTreeRegression.cs index 719bdc781c..b9d4b22ddb 100644 --- a/src/Microsoft.ML.FastTree/FastTreeRegression.cs +++ b/src/Microsoft.ML.FastTree/FastTreeRegression.cs @@ -55,8 +55,8 @@ public FastTreeRegressionTrainer(IHostEnvironment env, Arguments args) public override FastTreeRegressionPredictor Train(TrainContext context) { Host.CheckValue(context, nameof(context)); - var trainData = context.Train; - ValidData = context.Validation; + var trainData = context.TrainingSet; + ValidData = context.ValidationSet; using (var ch = Host.Start("Training")) { @@ -408,11 +408,11 @@ private static VersionInfo GetVersionInfo() loaderSignature: LoaderSignature); } - protected override uint VerNumFeaturesSerialized { get { return 0x00010002; } } + protected override uint VerNumFeaturesSerialized => 0x00010002; - protected override uint VerDefaultValueSerialized { get { 
return 0x00010004; } } + protected override uint VerDefaultValueSerialized => 0x00010004; - protected override uint VerCategoricalSplitSerialized { get { return 0x00010005; } } + protected override uint VerCategoricalSplitSerialized => 0x00010005; internal FastTreeRegressionPredictor(IHostEnvironment env, Ensemble trainedEnsemble, int featureCount, string innerArgs) : base(env, RegistrationName, trainedEnsemble, featureCount, innerArgs) @@ -438,7 +438,7 @@ public static FastTreeRegressionPredictor Create(IHostEnvironment env, ModelLoad return new FastTreeRegressionPredictor(env, ctx); } - public override PredictionKind PredictionKind { get { return PredictionKind.Regression; } } + public override PredictionKind PredictionKind => PredictionKind.Regression; } public static partial class FastTree diff --git a/src/Microsoft.ML.FastTree/FastTreeTweedie.cs b/src/Microsoft.ML.FastTree/FastTreeTweedie.cs index d26c02adc1..bc6d57b1b9 100644 --- a/src/Microsoft.ML.FastTree/FastTreeTweedie.cs +++ b/src/Microsoft.ML.FastTree/FastTreeTweedie.cs @@ -42,12 +42,9 @@ public sealed partial class FastTreeTweedieTrainer : BoostingFastTreeTrainerBase private Test _trainRegressionTest; private Test _testRegressionTest; - public override bool NeedCalibration - { - get { return false; } - } + public override bool NeedCalibration => false; - public override PredictionKind PredictionKind { get { return PredictionKind.Regression; } } + public override PredictionKind PredictionKind => PredictionKind.Regression; public FastTreeTweedieTrainer(IHostEnvironment env, Arguments args) : base(env, args) @@ -58,8 +55,8 @@ public FastTreeTweedieTrainer(IHostEnvironment env, Arguments args) public override FastTreeTweediePredictor Train(TrainContext context) { Host.CheckValue(context, nameof(context)); - var trainData = context.Train; - ValidData = context.Validation; + var trainData = context.TrainingSet; + ValidData = context.ValidationSet; using (var ch = Host.Start("Training")) { @@ -407,11 +404,11 
@@ private static VersionInfo GetVersionInfo() loaderSignature: LoaderSignature); } - protected override uint VerNumFeaturesSerialized { get { return 0x00010001; } } + protected override uint VerNumFeaturesSerialized => 0x00010001; - protected override uint VerDefaultValueSerialized { get { return 0x00010002; } } + protected override uint VerDefaultValueSerialized => 0x00010002; - protected override uint VerCategoricalSplitSerialized { get { return 0x00010003; } } + protected override uint VerCategoricalSplitSerialized => 0x00010003; internal FastTreeTweediePredictor(IHostEnvironment env, Ensemble trainedEnsemble, int featureCount, string innerArgs) : base(env, RegistrationName, trainedEnsemble, featureCount, innerArgs) @@ -450,7 +447,7 @@ protected override void Map(ref VBuffer src, ref float dst) dst = MathUtils.ExpSlow(dst); } - public override PredictionKind PredictionKind { get { return PredictionKind.Regression; } } + public override PredictionKind PredictionKind => PredictionKind.Regression; } public static partial class FastTree diff --git a/src/Microsoft.ML.FastTree/GamTrainer.cs b/src/Microsoft.ML.FastTree/GamTrainer.cs index dac3eb11d5..9b94b0f866 100644 --- a/src/Microsoft.ML.FastTree/GamTrainer.cs +++ b/src/Microsoft.ML.FastTree/GamTrainer.cs @@ -267,8 +267,8 @@ public sealed override TPredictor Train(TrainContext context) using (var ch = Host.Start("Training")) { ch.CheckValue(context, nameof(context)); - ConvertData(context.Train); - InputLength = context.Train.Schema.Feature.Type.ValueCount; + ConvertData(context.TrainingSet); + InputLength = context.TrainingSet.Schema.Feature.Type.ValueCount; TrainCore(ch); var pred = CreatePredictor(); ch.Done(); diff --git a/src/Microsoft.ML.FastTree/RandomForestClassification.cs b/src/Microsoft.ML.FastTree/RandomForestClassification.cs index 8affec5a2e..31f3fd6116 100644 --- a/src/Microsoft.ML.FastTree/RandomForestClassification.cs +++ b/src/Microsoft.ML.FastTree/RandomForestClassification.cs @@ -67,13 +67,13 @@ 
private static VersionInfo GetVersionInfo() loaderSignature: LoaderSignature); } - protected override uint VerNumFeaturesSerialized { get { return 0x00010003; } } + protected override uint VerNumFeaturesSerialized => 0x00010003; - protected override uint VerDefaultValueSerialized { get { return 0x00010005; } } + protected override uint VerDefaultValueSerialized => 0x00010005; - protected override uint VerCategoricalSplitSerialized { get { return 0x00010006; } } + protected override uint VerCategoricalSplitSerialized => 0x00010006; - public override PredictionKind PredictionKind { get { return PredictionKind.BinaryClassification; } } + public override PredictionKind PredictionKind => PredictionKind.BinaryClassification; internal FastForestClassificationPredictor(IHostEnvironment env, Ensemble trainedEnsemble, int featureCount, string innerArgs) @@ -140,8 +140,8 @@ public FastForestClassification(IHostEnvironment env, Arguments args) public override IPredictorWithFeatureWeights Train(TrainContext context) { Host.CheckValue(context, nameof(context)); - var trainData = context.Train; - ValidData = context.Validation; + var trainData = context.TrainingSet; + ValidData = context.ValidationSet; using (var ch = Host.Start("Training")) { diff --git a/src/Microsoft.ML.FastTree/RandomForestRegression.cs b/src/Microsoft.ML.FastTree/RandomForestRegression.cs index d1265a06bd..9b96c01312 100644 --- a/src/Microsoft.ML.FastTree/RandomForestRegression.cs +++ b/src/Microsoft.ML.FastTree/RandomForestRegression.cs @@ -53,11 +53,11 @@ private static VersionInfo GetVersionInfo() loaderSignature: LoaderSignature); } - protected override uint VerNumFeaturesSerialized { get { return 0x00010003; } } + protected override uint VerNumFeaturesSerialized => 0x00010003; - protected override uint VerDefaultValueSerialized { get { return 0x00010005; } } + protected override uint VerDefaultValueSerialized => 0x00010005; - protected override uint VerCategoricalSplitSerialized { get { return 
0x00010006; } } + protected override uint VerCategoricalSplitSerialized => 0x00010006; internal FastForestRegressionPredictor(IHostEnvironment env, Ensemble trainedEnsemble, int featureCount, string innerArgs, int samplesCount) @@ -99,7 +99,7 @@ public static FastForestRegressionPredictor Create(IHostEnvironment env, ModelLo return new FastForestRegressionPredictor(env, ctx); } - public override PredictionKind PredictionKind { get { return PredictionKind.Regression; } } + public override PredictionKind PredictionKind => PredictionKind.Regression; protected override void Map(ref VBuffer src, ref Float dst) { @@ -165,8 +165,8 @@ public FastForestRegression(IHostEnvironment env, Arguments args) public override FastForestRegressionPredictor Train(TrainContext context) { Host.CheckValue(context, nameof(context)); - var trainData = context.Train; - ValidData = context.Validation; + var trainData = context.TrainingSet; + ValidData = context.ValidationSet; using (var ch = Host.Start("Training")) { diff --git a/src/Microsoft.ML.KMeansClustering/KMeansPlusPlusTrainer.cs b/src/Microsoft.ML.KMeansClustering/KMeansPlusPlusTrainer.cs index 8571ca10b7..88a7f9546e 100644 --- a/src/Microsoft.ML.KMeansClustering/KMeansPlusPlusTrainer.cs +++ b/src/Microsoft.ML.KMeansClustering/KMeansPlusPlusTrainer.cs @@ -114,7 +114,7 @@ public KMeansPlusPlusTrainer(IHostEnvironment env, Arguments args) public override KMeansPredictor Train(TrainContext context) { Host.CheckValue(context, nameof(context)); - var data = context.Train; + var data = context.TrainingSet; data.CheckFeatureFloatVector(out int dimensionality); Contracts.Assert(dimensionality > 0); diff --git a/src/Microsoft.ML.LightGBM/LightGbmTrainerBase.cs b/src/Microsoft.ML.LightGBM/LightGbmTrainerBase.cs index 0d9a9005ea..6220e447ac 100644 --- a/src/Microsoft.ML.LightGBM/LightGbmTrainerBase.cs +++ b/src/Microsoft.ML.LightGBM/LightGbmTrainerBase.cs @@ -81,9 +81,9 @@ public override TPredictor Train(TrainContext context) { using (var pch 
= Host.StartProgressChannel("Loading data for LightGBM")) { - dtrain = LoadTrainingData(ch, context.Train, out catMetaData); - if (context.Validation != null) - dvalid = LoadValidationData(ch, dtrain, context.Validation, catMetaData); + dtrain = LoadTrainingData(ch, context.TrainingSet, out catMetaData); + if (context.ValidationSet != null) + dvalid = LoadValidationData(ch, dtrain, context.ValidationSet, catMetaData); } ch.Done(); } diff --git a/src/Microsoft.ML.PCA/PcaTrainer.cs b/src/Microsoft.ML.PCA/PcaTrainer.cs index c8a2f6a105..ddc0a09778 100644 --- a/src/Microsoft.ML.PCA/PcaTrainer.cs +++ b/src/Microsoft.ML.PCA/PcaTrainer.cs @@ -101,11 +101,11 @@ public override PcaPredictor Train(TrainContext context) { Host.CheckValue(context, nameof(context)); - context.Train.CheckFeatureFloatVector(out int dimension); + context.TrainingSet.CheckFeatureFloatVector(out int dimension); using (var ch = Host.Start("Training")) { - var pred = TrainCore(ch, context.Train, dimension); + var pred = TrainCore(ch, context.TrainingSet, dimension); ch.Done(); return pred; } diff --git a/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineTrainer.cs b/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineTrainer.cs index 290c269b8b..b616da1a70 100644 --- a/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineTrainer.cs +++ b/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineTrainer.cs @@ -356,7 +356,7 @@ public override FieldAwareFactorizationMachinePredictor Train(TrainContext conte using (var ch = Host.Start("Training")) using (var pch = Host.StartProgressChannel("Training")) { - var pred = TrainCore(ch, pch, context.Train, context.Validation, initPredictor); + var pred = TrainCore(ch, pch, context.TrainingSet, context.ValidationSet, initPredictor); ch.Done(); return pred; } diff --git a/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationTrainer.cs 
b/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationTrainer.cs index 29dbbb3447..098d0dc0c9 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationTrainer.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationTrainer.cs @@ -69,7 +69,7 @@ public override TPredictor Train(TrainContext context) TPredictor pred; using (var ch = Host.Start("Training")) { - var preparedData = PrepareDataFromTrainingExamples(ch, context.Train, out int weightSetCount); + var preparedData = PrepareDataFromTrainingExamples(ch, context.TrainingSet, out int weightSetCount); var initPred = context.InitialPredictor; var linInitPred = (initPred as CalibratedPredictorBase)?.SubPredictor as LinearPredictor; linInitPred = linInitPred ?? initPred as LinearPredictor; diff --git a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LbfgsPredictorBase.cs b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LbfgsPredictorBase.cs index 4c37c4b191..f8b186ec4d 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LbfgsPredictorBase.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LbfgsPredictorBase.cs @@ -303,8 +303,8 @@ public override TPredictor Train(TrainContext context) { Contracts.CheckValue(context, nameof(context)); - var data = context.Train; - _srcPredictor = context.Train as TPredictor; + var data = context.TrainingSet; + _srcPredictor = context.TrainingSet as TPredictor; data.CheckFeatureFloatVector(out NumFeatures); CheckLabel(data); data.CheckOptFloatWeight(); diff --git a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MetaMulticlassTrainer.cs b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MetaMulticlassTrainer.cs index 2bfdedfe6c..bfd98df683 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MetaMulticlassTrainer.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MetaMulticlassTrainer.cs @@ -98,7 +98,7 @@ protected TScalarTrainer 
GetTrainer() public override TPred Train(TrainContext context) { Host.CheckValue(context, nameof(context)); - var data = context.Train; + var data = context.TrainingSet; data.CheckFeatureFloatVector(); diff --git a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MultiClassNaiveBayesTrainer.cs b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MultiClassNaiveBayesTrainer.cs index a0d92194d2..cd565d0459 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MultiClassNaiveBayesTrainer.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MultiClassNaiveBayesTrainer.cs @@ -59,7 +59,7 @@ public MultiClassNaiveBayesTrainer(IHostEnvironment env, Arguments args) public override MultiClassNaiveBayesPredictor Train(TrainContext context) { Host.CheckValue(context, nameof(context)); - var data = context.Train; + var data = context.TrainingSet; Host.Check(data.Schema.Label != null, "Missing Label column"); Host.Check(data.Schema.Label.Type == NumberType.Float || data.Schema.Label.Type is KeyType, "Invalid type for Label column, only floats and known-size keys are supported"); diff --git a/src/Microsoft.ML.StandardLearners/Standard/OlsLinearRegression.cs b/src/Microsoft.ML.StandardLearners/Standard/OlsLinearRegression.cs index e413fea94a..ab435ad5d5 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/OlsLinearRegression.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/OlsLinearRegression.cs @@ -90,7 +90,7 @@ public override OlsLinearRegressionPredictor Train(TrainContext context) using (var ch = Host.Start("Training")) { ch.CheckValue(context, nameof(context)); - var examples = context.Train; + var examples = context.TrainingSet; ch.CheckParam(examples.Schema.Feature != null, nameof(examples), "Need a feature column"); ch.CheckParam(examples.Schema.Label != null, nameof(examples), "Need a label column"); diff --git a/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineLinear.cs 
b/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineLinear.cs index 62896ba789..d4901a6bb4 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineLinear.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineLinear.cs @@ -117,7 +117,7 @@ public override TPredictor Train(TrainContext context) var initPredictor = context.InitialPredictor; var initLinearPred = initPredictor as LinearPredictor ?? (initPredictor as CalibratedPredictorBase)?.SubPredictor as LinearPredictor; Host.CheckParam(initPredictor == null || initLinearPred != null, nameof(context), "Not a linear predictor."); - var data = context.Train; + var data = context.TrainingSet; data.CheckFeatureFloatVector(out int numFeatures); CheckLabel(data); diff --git a/src/Microsoft.ML.StandardLearners/Standard/Simple/SimpleTrainers.cs b/src/Microsoft.ML.StandardLearners/Standard/Simple/SimpleTrainers.cs index 0c00b4a9c8..b01aea3b59 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/Simple/SimpleTrainers.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/Simple/SimpleTrainers.cs @@ -217,7 +217,7 @@ public PriorTrainer(IHostEnvironment env, Arguments args) public override PriorPredictor Train(TrainContext context) { Contracts.CheckValue(context, nameof(context)); - var data = context.Train; + var data = context.TrainingSet; data.CheckBinaryLabel(); Contracts.CheckParam(data.Schema.Label != null, nameof(data), "Missing Label column"); Contracts.CheckParam(data.Schema.Label.Type == NumberType.Float, nameof(data), "Invalid type for Label column"); From 0bff1ffe8d51e6e7229698f235ce03c76a6fbf64 Mon Sep 17 00:00:00 2001 From: Tom Finley Date: Thu, 12 Jul 2018 09:20:25 -0700 Subject: [PATCH 03/13] LightGBM changes --- src/Microsoft.ML.LightGBM/LightGbmBinaryTrainer.cs | 9 ++++----- src/Microsoft.ML.LightGBM/LightGbmMulticlassTrainer.cs | 3 ++- src/Microsoft.ML.LightGBM/LightGbmRankingTrainer.cs | 9 ++++----- src/Microsoft.ML.LightGBM/LightGbmRegressionTrainer.cs | 9 ++++----- 
src/Microsoft.ML.LightGBM/LightGbmTrainerBase.cs | 5 +---- 5 files changed, 15 insertions(+), 20 deletions(-) diff --git a/src/Microsoft.ML.LightGBM/LightGbmBinaryTrainer.cs b/src/Microsoft.ML.LightGBM/LightGbmBinaryTrainer.cs index 4d778cf1a2..2c007b19b1 100644 --- a/src/Microsoft.ML.LightGBM/LightGbmBinaryTrainer.cs +++ b/src/Microsoft.ML.LightGBM/LightGbmBinaryTrainer.cs @@ -44,10 +44,9 @@ private static VersionInfo GetVersionInfo() } protected override uint VerNumFeaturesSerialized => 0x00010002; - protected override uint VerDefaultValueSerialized => 0x00010004; - protected override uint VerCategoricalSplitSerialized => 0x00010005; + public override PredictionKind PredictionKind => PredictionKind.BinaryClassification; internal LightGbmBinaryPredictor(IHostEnvironment env, FastTree.Internal.Ensemble trainedEnsemble, int featureCount, string innerArgs) : base(env, RegistrationName, trainedEnsemble, featureCount, innerArgs) @@ -77,8 +76,6 @@ public static IPredictorProducing Create(IHostEnvironment env, ModelLoadC return predictor; return new CalibratedPredictor(env, predictor, calibrator); } - - public override PredictionKind PredictionKind { get { return PredictionKind.BinaryClassification; } } } /// @@ -89,8 +86,10 @@ public sealed class LightGbmBinaryTrainer : LightGbmTrainerBase PredictionKind.BinaryClassification; + public LightGbmBinaryTrainer(IHostEnvironment env, LightGbmArguments args) - : base(env, args, PredictionKind.BinaryClassification, "LGBBINCL") + : base(env, args, LoadNameValue) { } diff --git a/src/Microsoft.ML.LightGBM/LightGbmMulticlassTrainer.cs b/src/Microsoft.ML.LightGBM/LightGbmMulticlassTrainer.cs index 2f82d505a8..aaa9b39a70 100644 --- a/src/Microsoft.ML.LightGBM/LightGbmMulticlassTrainer.cs +++ b/src/Microsoft.ML.LightGBM/LightGbmMulticlassTrainer.cs @@ -29,9 +29,10 @@ public sealed class LightGbmMulticlassTrainer : LightGbmTrainerBase PredictionKind.MultiClassClassification; public LightGbmMulticlassTrainer(IHostEnvironment env, 
LightGbmArguments args) - : base(env, args, PredictionKind.MultiClassClassification, "LightGBMMulticlass") + : base(env, args, LoadNameValue) { _numClass = -1; } diff --git a/src/Microsoft.ML.LightGBM/LightGbmRankingTrainer.cs b/src/Microsoft.ML.LightGBM/LightGbmRankingTrainer.cs index 64c837fd9c..56d159659f 100644 --- a/src/Microsoft.ML.LightGBM/LightGbmRankingTrainer.cs +++ b/src/Microsoft.ML.LightGBM/LightGbmRankingTrainer.cs @@ -42,10 +42,9 @@ private static VersionInfo GetVersionInfo() } protected override uint VerNumFeaturesSerialized => 0x00010002; - protected override uint VerDefaultValueSerialized => 0x00010004; - protected override uint VerCategoricalSplitSerialized => 0x00010005; + public override PredictionKind PredictionKind => PredictionKind.Ranking; internal LightGbmRankingPredictor(IHostEnvironment env, FastTree.Internal.Ensemble trainedEnsemble, int featureCount, string innerArgs) : base(env, RegistrationName, trainedEnsemble, featureCount, innerArgs) @@ -67,8 +66,6 @@ public static LightGbmRankingPredictor Create(IHostEnvironment env, ModelLoadCon { return new LightGbmRankingPredictor(env, ctx); } - - public override PredictionKind PredictionKind => PredictionKind.Ranking; } /// @@ -78,8 +75,10 @@ public sealed class LightGbmRankingTrainer : LightGbmTrainerBase PredictionKind.Ranking; + public LightGbmRankingTrainer(IHostEnvironment env, LightGbmArguments args) - : base(env, args, PredictionKind.Ranking, "LightGBMRanking") + : base(env, args, LoadNameValue) { } diff --git a/src/Microsoft.ML.LightGBM/LightGbmRegressionTrainer.cs b/src/Microsoft.ML.LightGBM/LightGbmRegressionTrainer.cs index 120bb1fd69..fba74c0c8f 100644 --- a/src/Microsoft.ML.LightGBM/LightGbmRegressionTrainer.cs +++ b/src/Microsoft.ML.LightGBM/LightGbmRegressionTrainer.cs @@ -42,10 +42,9 @@ private static VersionInfo GetVersionInfo() } protected override uint VerNumFeaturesSerialized => 0x00010002; - protected override uint VerDefaultValueSerialized => 0x00010004; - protected 
override uint VerCategoricalSplitSerialized => 0x00010005; + public override PredictionKind PredictionKind => PredictionKind.Regression; internal LightGbmRegressionPredictor(IHostEnvironment env, FastTree.Internal.Ensemble trainedEnsemble, int featureCount, string innerArgs) : base(env, RegistrationName, trainedEnsemble, featureCount, innerArgs) @@ -70,8 +69,6 @@ public static LightGbmRegressionPredictor Create(IHostEnvironment env, ModelLoad ctx.CheckAtModel(GetVersionInfo()); return new LightGbmRegressionPredictor(env, ctx); } - - public override PredictionKind PredictionKind { get { return PredictionKind.Regression; } } } public sealed class LightGbmRegressorTrainer : LightGbmTrainerBase @@ -81,8 +78,10 @@ public sealed class LightGbmRegressorTrainer : LightGbmTrainerBase PredictionKind.Regression; + public LightGbmRegressorTrainer(IHostEnvironment env, LightGbmArguments args) - : base(env, args, PredictionKind.Regression, "LightGBMRegressor") + : base(env, args, LoadNameValue) { } diff --git a/src/Microsoft.ML.LightGBM/LightGbmTrainerBase.cs b/src/Microsoft.ML.LightGBM/LightGbmTrainerBase.cs index 6220e447ac..a7392acb95 100644 --- a/src/Microsoft.ML.LightGBM/LightGbmTrainerBase.cs +++ b/src/Microsoft.ML.LightGBM/LightGbmTrainerBase.cs @@ -56,14 +56,13 @@ private sealed class CategoricalMetaData public override bool WantCaching => false; public override bool SupportsValidation => true; - protected internal LightGbmTrainerBase(IHostEnvironment env, LightGbmArguments args, PredictionKind predictionKind, string name) + protected internal LightGbmTrainerBase(IHostEnvironment env, LightGbmArguments args, string name) : base(env, name) { Host.CheckValue(args, nameof(args)); Args = args; Options = Args.ToDictionary(Host); - PredictionKind = predictionKind; ParallelTraining = Args.ParallelTrainer != null ? 
Args.ParallelTrainer.CreateComponent(env) : new SingleTrainer(); InitParallelTraining(); } @@ -852,8 +851,6 @@ private static int GetNumSampleRow(int numRow, int numCol) return ret; } - public override PredictionKind PredictionKind { get; } - protected internal abstract TPredictor CreatePredictor(); /// From c16012e78db2b241d44222ee2d126d30bf48800a Mon Sep 17 00:00:00 2001 From: Tom Finley Date: Thu, 12 Jul 2018 14:21:40 -0700 Subject: [PATCH 04/13] Rearrangements, review comments. --- src/Microsoft.ML.Core/Prediction/ITrainer.cs | 93 +++++++++++--------- 1 file changed, 52 insertions(+), 41 deletions(-) diff --git a/src/Microsoft.ML.Core/Prediction/ITrainer.cs b/src/Microsoft.ML.Core/Prediction/ITrainer.cs index 972e6fb3a8..0fcb04979e 100644 --- a/src/Microsoft.ML.Core/Prediction/ITrainer.cs +++ b/src/Microsoft.ML.Core/Prediction/ITrainer.cs @@ -26,47 +26,6 @@ namespace Microsoft.ML.Runtime public delegate void SignatureSequenceTrainer(); public delegate void SignatureMatrixRecommendingTrainer(); - /// - /// Interface to provide extra information about a trainer. - /// - public interface ITrainerEx : ITrainer - { - // REVIEW: Ideally trainers should be able to communicate - // something about the type of data they are capable of being trained - // on, e.g., what ColumnKinds they want, how many of each, of what type, - // etc. This interface seems like the most natural conduit for that sort - // of extra information. - - // REVIEW: Can we please have consistent naming here? - // 'Need' vs. 'Want' looks arbitrary to me, and it's grammatically more correct to - // be 'Needs' / 'Wants' anyway. - - /// - /// Whether the trainer needs to see data in normalized form. - /// - bool NeedNormalization { get; } - - /// - /// Whether the trainer needs calibration to produce probabilities. - /// - bool NeedCalibration { get; } - - /// - /// Whether this trainer could benefit from a cached view of the data. 
- /// - bool WantCaching { get; } - - bool SupportsValidation { get; } - bool SupportsIncrementalTraining { get; } - } - - // The Trainer (of Factory) can optionally implement this. - public interface IModelCombiner - where TPredictor : IPredictor - { - TPredictor CombineModels(IEnumerable models); - } - public delegate void SignatureModelCombiner(PredictionKind kind); /// @@ -130,6 +89,58 @@ public static TPredictor Train(this ITrainer trainer, Ro => trainer.Train(new TrainContext(trainData)); } + /// + /// Interface to provide extra information about a trainer. + /// + public interface ITrainerEx : ITrainer + { + // REVIEW: Ideally trainers should be able to communicate + // something about the type of data they are capable of being trained + // on, e.g., what ColumnKinds they want, how many of each, of what type, + // etc. This interface seems like the most natural conduit for that sort + // of extra information. + + // REVIEW: Can we please have consistent naming here? + // 'Need' vs. 'Want' looks arbitrary to me, and it's grammatically more correct to + // be 'Needs' / 'Wants' anyway. + + /// + /// Whether the trainer needs to see data in normalized form. + /// + bool NeedNormalization { get; } + + /// + /// Whether the trainer needs calibration to produce probabilities. + /// + bool NeedCalibration { get; } + + /// + /// Whether this trainer could benefit from a cached view of the data. + /// + bool WantCaching { get; } + + /// + /// Whether the trainer supports validation sets via . + /// Not implementing this interface and returning true from this property is an indication + /// the trainer does not support that. + /// + bool SupportsValidation { get; } + + /// + /// Whether the trainer can support incremental trainers via . + /// Not implementing this interface and returning true from this property is an indication + /// the trainer does not support that. 
+ /// + bool SupportsIncrementalTraining { get; } + } + + // A trainer can optionally implement this to indicate it can combine multiple models into a single predictor. + public interface IModelCombiner + where TPredictor : IPredictor + { + TPredictor CombineModels(IEnumerable models); + } + /// /// Interface implemented by the MetalinearLearners base class. /// Used to distinguish the MetaLinear Learners from the other learners From e864d212186e76f21e6b37060877ea863870ab1f Mon Sep 17 00:00:00 2001 From: Tom Finley Date: Thu, 12 Jul 2018 16:01:18 -0700 Subject: [PATCH 05/13] private protected stuff, other comments --- src/Microsoft.ML.Core/Utilities/ObjectPool.cs | 2 +- .../Commands/TrainCommand.cs | 4 +-- .../Training/EarlyStoppingCriteria.cs | 4 +-- src/Microsoft.ML.Data/Training/TrainerBase.cs | 6 ++++- .../SubModelSelector/BaseDiverseSelector.cs | 2 +- .../Trainer/Binary/EnsembleTrainer.cs | 2 +- .../Trainer/EnsembleTrainerBase.cs | 26 +++++++++---------- .../MulticlassDataPartitionEnsembleTrainer.cs | 2 +- .../Regression/RegressionEnsembleTrainer.cs | 2 +- src/Microsoft.ML.FastTree/FastTree.cs | 6 ++--- src/Microsoft.ML.FastTree/GamTrainer.cs | 10 +++---- src/Microsoft.ML.FastTree/Training/Test.cs | 4 +-- .../LightGbmBinaryTrainer.cs | 2 +- .../LightGbmMulticlassTrainer.cs | 2 +- .../LightGbmRankingTrainer.cs | 2 +- .../LightGbmRegressionTrainer.cs | 2 +- .../LightGbmTrainerBase.cs | 14 +++++----- .../Standard/LinearClassificationTrainer.cs | 24 ++++++++--------- 18 files changed, 59 insertions(+), 57 deletions(-) diff --git a/src/Microsoft.ML.Core/Utilities/ObjectPool.cs b/src/Microsoft.ML.Core/Utilities/ObjectPool.cs index 46486dc937..4a65286551 100644 --- a/src/Microsoft.ML.Core/Utilities/ObjectPool.cs +++ b/src/Microsoft.ML.Core/Utilities/ObjectPool.cs @@ -39,7 +39,7 @@ public abstract class ObjectPoolBase public int Count => _pool.Count; public int NumCreated { get { return _numCreated; } } - protected internal ObjectPoolBase() + private protected 
ObjectPoolBase() { _pool = new ConcurrentBag(); } diff --git a/src/Microsoft.ML.Data/Commands/TrainCommand.cs b/src/Microsoft.ML.Data/Commands/TrainCommand.cs index 86c8d63397..2debbc0c69 100644 --- a/src/Microsoft.ML.Data/Commands/TrainCommand.cs +++ b/src/Microsoft.ML.Data/Commands/TrainCommand.cs @@ -235,10 +235,10 @@ public static IPredictor Train(IHostEnvironment env, IChannel ch, RoleMappedData } public static IPredictor Train(IHostEnvironment env, IChannel ch, RoleMappedData data, ITrainer trainer, string name, RoleMappedData validData, - SubComponent calibrator, int maxCalibrationExamples, bool? cacheData, IPredictor inpPredictor = null) + SubComponent calibrator, int maxCalibrationExamples, bool? cacheData, IPredictor inputPredictor = null) { ICalibratorTrainer caliTrainer = !calibrator.IsGood() ? null : calibrator.CreateInstance(env); - return TrainCore(env, ch, data, trainer, name, validData, caliTrainer, maxCalibrationExamples, cacheData, inpPredictor); + return TrainCore(env, ch, data, trainer, name, validData, caliTrainer, maxCalibrationExamples, cacheData, inputPredictor); } private static IPredictor TrainCore(IHostEnvironment env, IChannel ch, RoleMappedData data, ITrainer trainer, string name, RoleMappedData validData, diff --git a/src/Microsoft.ML.Data/Training/EarlyStoppingCriteria.cs b/src/Microsoft.ML.Data/Training/EarlyStoppingCriteria.cs index 13cdb126ee..285db8bfe1 100644 --- a/src/Microsoft.ML.Data/Training/EarlyStoppingCriteria.cs +++ b/src/Microsoft.ML.Data/Training/EarlyStoppingCriteria.cs @@ -139,9 +139,9 @@ public class Arguments : ArgumentsBase public int WindowSize = 5; } - protected internal Queue PastScores; + protected Queue PastScores; - internal MovingWindowEarlyStoppingCriterion(Arguments args, bool lowerIsBetter) + private protected MovingWindowEarlyStoppingCriterion(Arguments args, bool lowerIsBetter) : base(args, lowerIsBetter) { Contracts.CheckUserArg(0 <= Args.Threshold && args.Threshold <= 1, nameof(args.Threshold), "Must 
be in range [0,1]."); diff --git a/src/Microsoft.ML.Data/Training/TrainerBase.cs b/src/Microsoft.ML.Data/Training/TrainerBase.cs index f24f77292c..f9bbefbcb0 100644 --- a/src/Microsoft.ML.Data/Training/TrainerBase.cs +++ b/src/Microsoft.ML.Data/Training/TrainerBase.cs @@ -7,7 +7,11 @@ namespace Microsoft.ML.Runtime.Training public abstract class TrainerBase : ITrainer, ITrainerEx where TPredictor : IPredictor { - public const string NoTrainingInstancesMessage = "No valid training instances found, all instances have missing features."; + /// + /// A standard string to use in errors or warnings by subclasses, to communicate the idea that no valid + /// instances were able to be found. + /// + protected const string NoTrainingInstancesMessage = "No valid training instances found, all instances have missing features."; protected IHost Host { get; } diff --git a/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BaseDiverseSelector.cs b/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BaseDiverseSelector.cs index 56f7e1edfc..5e31b2c8f5 100644 --- a/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BaseDiverseSelector.cs +++ b/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BaseDiverseSelector.cs @@ -23,7 +23,7 @@ public abstract class DiverseSelectorArguments : ArgumentsBase private readonly IComponentFactory> _diversityMetricType; private ConcurrentDictionary>, TOutput[]> _predictions; - protected internal BaseDiverseSelector(IHostEnvironment env, DiverseSelectorArguments args, string name, + private protected BaseDiverseSelector(IHostEnvironment env, DiverseSelectorArguments args, string name, IComponentFactory> diversityMetricType) : base(args, env, name) { diff --git a/src/Microsoft.ML.Ensemble/Trainer/Binary/EnsembleTrainer.cs b/src/Microsoft.ML.Ensemble/Trainer/Binary/EnsembleTrainer.cs index 7b3d7397e8..ab2dff5045 100644 --- a/src/Microsoft.ML.Ensemble/Trainer/Binary/EnsembleTrainer.cs +++ b/src/Microsoft.ML.Ensemble/Trainer/Binary/EnsembleTrainer.cs @@ 
-62,7 +62,7 @@ public EnsembleTrainer(IHostEnvironment env, Arguments args) public override PredictionKind PredictionKind => PredictionKind.BinaryClassification; - protected internal override TScalarPredictor CreatePredictor(List> models) + private protected override TScalarPredictor CreatePredictor(List> models) { if (models.All(m => m.Predictor is TDistPredictor)) return new EnsembleDistributionPredictor(Host, PredictionKind, CreateModels(models), Combiner); diff --git a/src/Microsoft.ML.Ensemble/Trainer/EnsembleTrainerBase.cs b/src/Microsoft.ML.Ensemble/Trainer/EnsembleTrainerBase.cs index 7ab2395569..83ec285e96 100644 --- a/src/Microsoft.ML.Ensemble/Trainer/EnsembleTrainerBase.cs +++ b/src/Microsoft.ML.Ensemble/Trainer/EnsembleTrainerBase.cs @@ -59,17 +59,17 @@ public abstract class ArgumentsBase : LearnerInputBaseWithLabel private const int DefaultNumModels = 50; /// Command-line arguments - protected internal readonly ArgumentsBase Args; - protected internal readonly int NumModels; + private protected readonly ArgumentsBase Args; + private protected readonly int NumModels; /// Ensemble members - protected internal readonly ITrainer>[] Trainers; + private protected readonly ITrainer>[] Trainers; private readonly ISubsetSelector _subsetSelector; - protected internal ISubModelSelector SubModelSelector; - protected internal IOutputCombiner Combiner; + private protected ISubModelSelector SubModelSelector; + private protected IOutputCombiner Combiner; - protected internal EnsembleTrainerBase(ArgumentsBase args, IHostEnvironment env, string name) + private protected EnsembleTrainerBase(ArgumentsBase args, IHostEnvironment env, string name) : base(env, name) { Args = args; @@ -133,7 +133,7 @@ private TPredictor TrainCore(IChannel ch, RoleMappedData data) validationDataSetProportion = Math.Max(validationDataSetProportion, stackingTrainer.ValidationDatasetProportion); var needMetrics = Args.ShowMetrics || Combiner is IWeightedAverager; - var Models = new List>>(); + 
var models = new List>>(); _subsetSelector.Initialize(data, NumModels, Args.BatchSize, validationDataSetProportion); int batchNumber = 1; @@ -179,16 +179,16 @@ private TPredictor TrainCore(IChannel ch, RoleMappedData data) if (stackingTrainer != null) stackingTrainer.Train(modelsList, _subsetSelector.GetTestData(null, batch), Host); - Models.AddRange(modelsList); - int modelSize = Utils.Size(Models); + models.AddRange(modelsList); + int modelSize = Utils.Size(models); if (modelSize < Utils.Size(Trainers)) ch.Warning("{0} of {1} trainings failed.", Utils.Size(Trainers) - modelSize, Utils.Size(Trainers)); ch.Check(modelSize > 0, "Ensemble training resulted in no valid models."); } - return CreatePredictor(Models); + return CreatePredictor(models); } - protected internal abstract TPredictor CreatePredictor(List>> models); + private protected abstract TPredictor CreatePredictor(List>> models); private bool EnsureMinimumFeaturesSelected(Subset subset) { @@ -203,7 +203,7 @@ private bool EnsureMinimumFeaturesSelected(Subset subset) return false; } - protected internal virtual void PrintMetrics(IChannel ch, List>> models) + private protected virtual void PrintMetrics(IChannel ch, List>> models) { // REVIEW: The formatting of this method is bizarre and seemingly not even self-consistent // w.r.t. its usage of |. Is this intentional? 
@@ -216,7 +216,7 @@ protected internal virtual void PrintMetrics(IChannel ch, List string.Format("| {0} |", m.Value))), model.Predictor.GetType().Name); } - protected internal static FeatureSubsetModel[] CreateModels(List>> models) where T : IPredictor + private protected static FeatureSubsetModel[] CreateModels(List>> models) where T : IPredictor { var subsetModels = new FeatureSubsetModel[models.Count]; for (int i = 0; i < models.Count; i++) diff --git a/src/Microsoft.ML.Ensemble/Trainer/Multiclass/MulticlassDataPartitionEnsembleTrainer.cs b/src/Microsoft.ML.Ensemble/Trainer/Multiclass/MulticlassDataPartitionEnsembleTrainer.cs index 0e4f0a043e..4421cd5838 100644 --- a/src/Microsoft.ML.Ensemble/Trainer/Multiclass/MulticlassDataPartitionEnsembleTrainer.cs +++ b/src/Microsoft.ML.Ensemble/Trainer/Multiclass/MulticlassDataPartitionEnsembleTrainer.cs @@ -63,7 +63,7 @@ public MulticlassDataPartitionEnsembleTrainer(IHostEnvironment env, Arguments ar public override PredictionKind PredictionKind => PredictionKind.MultiClassClassification; - protected internal override EnsembleMultiClassPredictor CreatePredictor(List> models) + private protected override EnsembleMultiClassPredictor CreatePredictor(List> models) { return new EnsembleMultiClassPredictor(Host, CreateModels(models), Combiner as IMultiClassOutputCombiner); } diff --git a/src/Microsoft.ML.Ensemble/Trainer/Regression/RegressionEnsembleTrainer.cs b/src/Microsoft.ML.Ensemble/Trainer/Regression/RegressionEnsembleTrainer.cs index bf7671d60d..1cc36f20cd 100644 --- a/src/Microsoft.ML.Ensemble/Trainer/Regression/RegressionEnsembleTrainer.cs +++ b/src/Microsoft.ML.Ensemble/Trainer/Regression/RegressionEnsembleTrainer.cs @@ -57,7 +57,7 @@ public RegressionEnsembleTrainer(IHostEnvironment env, Arguments args) public override PredictionKind PredictionKind => PredictionKind.Regression; - protected internal override TScalarPredictor CreatePredictor(List> models) + private protected override TScalarPredictor 
CreatePredictor(List> models) { return new EnsemblePredictor(Host, PredictionKind, CreateModels(models), Combiner); } diff --git a/src/Microsoft.ML.FastTree/FastTree.cs b/src/Microsoft.ML.FastTree/FastTree.cs index 25f44712e6..86df5b5893 100644 --- a/src/Microsoft.ML.FastTree/FastTree.cs +++ b/src/Microsoft.ML.FastTree/FastTree.cs @@ -89,7 +89,7 @@ public abstract class FastTreeTrainerBase : public override bool SupportsValidation => true; - protected internal FastTreeTrainerBase(IHostEnvironment env, TArgs args) + private protected FastTreeTrainerBase(IHostEnvironment env, TArgs args) : base(env, RegisterName) { Host.CheckValue(args, nameof(args)); @@ -1880,7 +1880,7 @@ private void MakeBoundariesAndCheckLabels(out long missingInstances, out long to missingInstances = cursor.BadFeaturesRowCount; } - ch.Check(totalInstances > 0, TrainerBase.NoTrainingInstancesMessage); + ch.Check(totalInstances > 0, "All instances skipped due to missing features."); if (missingInstances > 0) ch.Warning("Skipped {0} instances with missing features during training", missingInstances); @@ -2806,7 +2806,7 @@ public abstract class FastTreePredictionWrapper : public bool CanSavePfa => true; public bool CanSaveOnnx => true; - protected internal FastTreePredictionWrapper(IHostEnvironment env, string name, Ensemble trainedEnsemble, int numFeatures, string innerArgs) + protected FastTreePredictionWrapper(IHostEnvironment env, string name, Ensemble trainedEnsemble, int numFeatures, string innerArgs) : base(env, name) { Host.CheckValue(trainedEnsemble, nameof(trainedEnsemble)); diff --git a/src/Microsoft.ML.FastTree/GamTrainer.cs b/src/Microsoft.ML.FastTree/GamTrainer.cs index 9b94b0f866..366e73a4f0 100644 --- a/src/Microsoft.ML.FastTree/GamTrainer.cs +++ b/src/Microsoft.ML.FastTree/GamTrainer.cs @@ -75,7 +75,7 @@ internal override void CheckLabel(RoleMappedData data) data.CheckRegressionLabel(); } - protected internal override RegressionGamPredictor CreatePredictor() + private protected 
override RegressionGamPredictor CreatePredictor() { return new RegressionGamPredictor(Host, InputLength, TrainSet, BinEffects, FeatureMap); } @@ -137,7 +137,7 @@ private bool[] ConvertTargetsToBool(double[] targets) return boolArray; } - protected internal override BinaryClassGamPredictor CreatePredictor() + private protected override BinaryClassGamPredictor CreatePredictor() { return new BinaryClassGamPredictor(Host, InputLength, TrainSet, BinEffects, FeatureMap); } @@ -231,7 +231,7 @@ public abstract class ArgumentsBase : LearnerInputBaseWithWeight public override bool WantCaching => false; - protected internal GamTrainerBase(IHostEnvironment env, TArgs args) + private protected GamTrainerBase(IHostEnvironment env, TArgs args) : base(env, RegisterName) { Contracts.CheckValue(env, nameof(env)); @@ -276,7 +276,7 @@ public sealed override TPredictor Train(TrainContext context) } } - protected internal abstract TPredictor CreatePredictor(); + private protected abstract TPredictor CreatePredictor(); internal abstract void CheckLabel(RoleMappedData data); @@ -571,7 +571,7 @@ public abstract class GamPredictorBase : PredictorBase, public ColumnType OutputType => NumberType.Float; - protected internal GamPredictorBase(IHostEnvironment env, string name, int inputLength, Dataset trainSet, double[][] binEffects, int[] featureMap) + private protected GamPredictorBase(IHostEnvironment env, string name, int inputLength, Dataset trainSet, double[][] binEffects, int[] featureMap) : base(env, name) { Host.CheckValue(trainSet, nameof(trainSet)); diff --git a/src/Microsoft.ML.FastTree/Training/Test.cs b/src/Microsoft.ML.FastTree/Training/Test.cs index f000a1fa32..f1a264a5e4 100644 --- a/src/Microsoft.ML.FastTree/Training/Test.cs +++ b/src/Microsoft.ML.FastTree/Training/Test.cs @@ -207,8 +207,8 @@ public class TestHistory : Test protected IList History; protected int Iteration { get; private set; } - public TestResult BestResult { get; protected internal set; } - public int 
BestIteration { get; protected internal set; } + public TestResult BestResult { get; private protected set; } + public int BestIteration { get; private protected set; } // scenarioWithoutHistory - simple test scenario we want to track the history and look for best iteration // lossIndex - index of lossFunction in case Test returns more than one loss (default should be 0) diff --git a/src/Microsoft.ML.LightGBM/LightGbmBinaryTrainer.cs b/src/Microsoft.ML.LightGBM/LightGbmBinaryTrainer.cs index 2c007b19b1..c25b196b0e 100644 --- a/src/Microsoft.ML.LightGBM/LightGbmBinaryTrainer.cs +++ b/src/Microsoft.ML.LightGBM/LightGbmBinaryTrainer.cs @@ -93,7 +93,7 @@ public LightGbmBinaryTrainer(IHostEnvironment env, LightGbmArguments args) { } - protected internal override IPredictorWithFeatureWeights CreatePredictor() + private protected override IPredictorWithFeatureWeights CreatePredictor() { Host.Check(TrainedEnsemble != null, "The predictor cannot be created before training is complete"); var innerArgs = LightGbmInterfaceUtils.JoinParameters(Options); diff --git a/src/Microsoft.ML.LightGBM/LightGbmMulticlassTrainer.cs b/src/Microsoft.ML.LightGBM/LightGbmMulticlassTrainer.cs index aaa9b39a70..97649feb2a 100644 --- a/src/Microsoft.ML.LightGBM/LightGbmMulticlassTrainer.cs +++ b/src/Microsoft.ML.LightGBM/LightGbmMulticlassTrainer.cs @@ -54,7 +54,7 @@ private LightGbmBinaryPredictor CreateBinaryPredictor(int classID, string innerA return new LightGbmBinaryPredictor(Host, GetBinaryEnsemble(classID), FeatureCount, innerArgs); } - protected internal override OvaPredictor CreatePredictor() + private protected override OvaPredictor CreatePredictor() { Host.Check(TrainedEnsemble != null, "The predictor cannot be created before training is complete."); diff --git a/src/Microsoft.ML.LightGBM/LightGbmRankingTrainer.cs b/src/Microsoft.ML.LightGBM/LightGbmRankingTrainer.cs index 56d159659f..2f44a0ba9d 100644 --- a/src/Microsoft.ML.LightGBM/LightGbmRankingTrainer.cs +++ 
b/src/Microsoft.ML.LightGBM/LightGbmRankingTrainer.cs @@ -102,7 +102,7 @@ protected override void CheckDataValid(IChannel ch, RoleMappedData data) } } - protected internal override LightGbmRankingPredictor CreatePredictor() + private protected override LightGbmRankingPredictor CreatePredictor() { Host.Check(TrainedEnsemble != null, "The predictor cannot be created before training is complete"); var innerArgs = LightGbmInterfaceUtils.JoinParameters(Options); diff --git a/src/Microsoft.ML.LightGBM/LightGbmRegressionTrainer.cs b/src/Microsoft.ML.LightGBM/LightGbmRegressionTrainer.cs index fba74c0c8f..db2f1f268a 100644 --- a/src/Microsoft.ML.LightGBM/LightGbmRegressionTrainer.cs +++ b/src/Microsoft.ML.LightGBM/LightGbmRegressionTrainer.cs @@ -85,7 +85,7 @@ public LightGbmRegressorTrainer(IHostEnvironment env, LightGbmArguments args) { } - protected internal override LightGbmRegressionPredictor CreatePredictor() + private protected override LightGbmRegressionPredictor CreatePredictor() { Host.Check(TrainedEnsemble != null, "The predictor cannot be created before training is complete"); diff --git a/src/Microsoft.ML.LightGBM/LightGbmTrainerBase.cs b/src/Microsoft.ML.LightGBM/LightGbmTrainerBase.cs index a7392acb95..e540d8f1ae 100644 --- a/src/Microsoft.ML.LightGBM/LightGbmTrainerBase.cs +++ b/src/Microsoft.ML.LightGBM/LightGbmTrainerBase.cs @@ -37,26 +37,26 @@ private sealed class CategoricalMetaData public bool[] IsCategoricalFeature; } - protected internal readonly LightGbmArguments Args; + private protected readonly LightGbmArguments Args; /// /// Stores argumments as objects to convert them to invariant string type in the end so that /// the code is culture agnostic. When retrieving key value from this dictionary as string /// please convert to string invariant by string.Format(CultureInfo.InvariantCulture, "{0}", Option[key]). 
/// - protected internal readonly Dictionary Options; - protected internal readonly IParallel ParallelTraining; + private protected readonly Dictionary Options; + private protected readonly IParallel ParallelTraining; // Store _featureCount and _trainedEnsemble to construct predictor. - protected internal int FeatureCount; - protected internal FastTree.Internal.Ensemble TrainedEnsemble; + private protected int FeatureCount; + private protected FastTree.Internal.Ensemble TrainedEnsemble; public override bool NeedNormalization => false; public override bool NeedCalibration => false; public override bool WantCaching => false; public override bool SupportsValidation => true; - protected internal LightGbmTrainerBase(IHostEnvironment env, LightGbmArguments args, string name) + private protected LightGbmTrainerBase(IHostEnvironment env, LightGbmArguments args, string name) : base(env, name) { Host.CheckValue(args, nameof(args)); @@ -851,7 +851,7 @@ private static int GetNumSampleRow(int numRow, int numCol) return ret; } - protected internal abstract TPredictor CreatePredictor(); + private protected abstract TPredictor CreatePredictor(); /// /// This function will be called before training. It will check the label/group and add parameters for specific applications. 
diff --git a/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationTrainer.cs b/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationTrainer.cs index 098d0dc0c9..314af27fc6 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationTrainer.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationTrainer.cs @@ -985,9 +985,10 @@ public StandardArrayDualsTable(int length) _duals = new Float[length]; } - public override Float this[long index] { - get { return _duals[(int)index]; } - set { _duals[(int)index] = value; } + public override Float this[long index] + { + get => _duals[(int)index]; + set => _duals[(int)index] = value; } public override void ApplyAt(long index, Visitor manip) @@ -1003,7 +1004,7 @@ private sealed class BigArrayDualsTable : DualsTableBase { private BigArray _duals; - public override long Length { get { return _duals.Length; } } + public override long Length => _duals.Length; public BigArrayDualsTable(long length) { @@ -1011,13 +1012,10 @@ public BigArrayDualsTable(long length) _duals = new BigArray(length); } - public override Float this[long index] { - get { - return _duals[index]; - } - set { - _duals[index] = value; - } + public override Float this[long index] + { + get => _duals[index]; + set => _duals[index] = value; } public override void ApplyAt(long index, Visitor manip) @@ -1104,7 +1102,7 @@ protected Func GetIndexFromIdAndRowGetter(IdToIdxLookup idT /// the table growing operation initializes a new larger bucket and rehash the existing entries to /// the new bucket. Such operation has an expected complexity proportional to the size. /// - protected internal sealed class IdToIdxLookup + protected sealed class IdToIdxLookup { // Utilizing this struct gives better cache behavior than using parallel arrays. private struct Entry @@ -1131,7 +1129,7 @@ public Entry(long itNext, UInt128 value) /// /// Gets the count of id entries. 
/// - public long Count { get { return _count; } } + public long Count => _count; /// /// Initializes an instance of the class with the specified size. From f5766fc5c48d1a6b3ff3b42f91baa69a9aa065df Mon Sep 17 00:00:00 2001 From: Tom Finley Date: Thu, 12 Jul 2018 16:33:55 -0700 Subject: [PATCH 06/13] Iris test now fixed at 1 thread to try to increase stability --- .../IrisPlantClassificationTests.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/IrisPlantClassificationTests.cs b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/IrisPlantClassificationTests.cs index 95852a6e81..0535f2b15d 100644 --- a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/IrisPlantClassificationTests.cs +++ b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/IrisPlantClassificationTests.cs @@ -70,7 +70,7 @@ public void TrainAndPredictIrisModelUsingDirectInstantiationTest() trans = NormalizeTransform.CreateMinMaxNormalizer(env, trans, "Features"); // Train - var trainer = new SdcaMultiClassTrainer(env, new SdcaMultiClassTrainer.Arguments()); + var trainer = new SdcaMultiClassTrainer(env, new SdcaMultiClassTrainer.Arguments() { NumThreads = 1 } ); // Explicity adding CacheDataView since caching is not working though trainer has 'Caching' On/Auto var cached = new CacheDataView(env, trans, prefetch: null); From da2acdb3af50dc548a2fffa3c51f48bec237cf55 Mon Sep 17 00:00:00 2001 From: Tom Finley Date: Fri, 13 Jul 2018 05:39:32 -0700 Subject: [PATCH 07/13] TrainerInfo introduction, ITrainerEx destruction --- src/Microsoft.ML.Core/Prediction/ITrainer.cs | 51 ++----------- .../Prediction/TrainerInfo.cs | 75 +++++++++++++++++++ .../Commands/CrossValidationCommand.cs | 2 +- .../Commands/TrainCommand.cs | 29 +++---- .../Commands/TrainTestCommand.cs | 2 +- .../EntryPoints/InputBase.cs | 3 +- .../Prediction/Calibrator.cs | 3 +- src/Microsoft.ML.Data/Training/TrainerBase.cs | 9 +-- 
.../Transforms/NormalizeTransform.cs | 29 ++----- .../OutputCombiners/BaseStacking.cs | 2 +- .../Trainer/EnsembleTrainerBase.cs | 17 ++--- src/Microsoft.ML.FastTree/FastTree.cs | 10 +-- .../FastTreeClassification.cs | 2 - src/Microsoft.ML.FastTree/FastTreeRanking.cs | 2 - .../FastTreeRegression.cs | 2 - src/Microsoft.ML.FastTree/FastTreeTweedie.cs | 2 - src/Microsoft.ML.FastTree/GamTrainer.cs | 8 +- .../RandomForestClassification.cs | 6 +- .../RandomForestRegression.cs | 2 - .../KMeansPlusPlusTrainer.cs | 9 +-- .../LightGbmTrainerBase.cs | 6 +- src/Microsoft.ML.PCA/PcaTrainer.cs | 14 ++-- .../FactorizationMachineTrainer.cs | 5 +- .../Standard/LinearClassificationTrainer.cs | 19 ++--- .../LogisticRegression/LbfgsPredictorBase.cs | 13 ++-- .../LogisticRegression/LogisticRegression.cs | 2 - .../MulticlassLogisticRegression.cs | 2 - .../MultiClass/MetaMulticlassTrainer.cs | 12 +-- .../MultiClass/MultiClassNaiveBayesTrainer.cs | 8 +- .../Standard/OlsLinearRegression.cs | 11 ++- .../Standard/Online/AveragedPerceptron.cs | 7 +- .../Standard/Online/LinearSvm.cs | 7 +- .../Standard/Online/OnlineGradientDescent.cs | 5 -- .../Standard/Online/OnlineLinear.cs | 20 ++--- .../PoissonRegression/PoissonRegression.cs | 2 - .../Standard/SdcaMultiClass.cs | 4 +- .../Standard/SdcaRegression.cs | 2 - .../Standard/Simple/SimpleTrainers.cs | 16 ++-- 38 files changed, 182 insertions(+), 238 deletions(-) create mode 100644 src/Microsoft.ML.Core/Prediction/TrainerInfo.cs diff --git a/src/Microsoft.ML.Core/Prediction/ITrainer.cs b/src/Microsoft.ML.Core/Prediction/ITrainer.cs index 0fcb04979e..31a45d7183 100644 --- a/src/Microsoft.ML.Core/Prediction/ITrainer.cs +++ b/src/Microsoft.ML.Core/Prediction/ITrainer.cs @@ -34,6 +34,12 @@ namespace Microsoft.ML.Runtime /// public interface ITrainer { + /// + /// Auxiliary information about the trainer in terms of its capabilities + /// and requirements. 
+ /// + TrainerInfo Info { get; } + /// /// Return the type of prediction task for the produced predictor. /// @@ -89,51 +95,6 @@ public static TPredictor Train(this ITrainer trainer, Ro => trainer.Train(new TrainContext(trainData)); } - /// - /// Interface to provide extra information about a trainer. - /// - public interface ITrainerEx : ITrainer - { - // REVIEW: Ideally trainers should be able to communicate - // something about the type of data they are capable of being trained - // on, e.g., what ColumnKinds they want, how many of each, of what type, - // etc. This interface seems like the most natural conduit for that sort - // of extra information. - - // REVIEW: Can we please have consistent naming here? - // 'Need' vs. 'Want' looks arbitrary to me, and it's grammatically more correct to - // be 'Needs' / 'Wants' anyway. - - /// - /// Whether the trainer needs to see data in normalized form. - /// - bool NeedNormalization { get; } - - /// - /// Whether the trainer needs calibration to produce probabilities. - /// - bool NeedCalibration { get; } - - /// - /// Whether this trainer could benefit from a cached view of the data. - /// - bool WantCaching { get; } - - /// - /// Whether the trainer supports validation sets via . - /// Not implementing this interface and returning true from this property is an indication - /// the trainer does not support that. - /// - bool SupportsValidation { get; } - - /// - /// Whether the trainer can support incremental trainers via . - /// Not implementing this interface and returning true from this property is an indication - /// the trainer does not support that. - /// - bool SupportsIncrementalTraining { get; } - } - // A trainer can optionally implement this to indicate it can combine multiple models into a single predictor. 
public interface IModelCombiner where TPredictor : IPredictor diff --git a/src/Microsoft.ML.Core/Prediction/TrainerInfo.cs b/src/Microsoft.ML.Core/Prediction/TrainerInfo.cs new file mode 100644 index 0000000000..0514285974 --- /dev/null +++ b/src/Microsoft.ML.Core/Prediction/TrainerInfo.cs @@ -0,0 +1,75 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +namespace Microsoft.ML.Runtime +{ + /// + /// Instances of this class posses information about trainers, in terms of their requirements and capabilities. + /// The intended usage is as the value for . + /// + public sealed class TrainerInfo + { + // REVIEW: Ideally trainers should be able to communicate + // something about the type of data they are capable of being trained + // on, e.g., what ColumnKinds they want, how many of each, of what type, + // etc. This interface seems like the most natural conduit for that sort + // of extra information. + + // REVIEW: Can we please have consistent naming here? + // 'Need' vs. 'Want' looks arbitrary to me, and it's grammatically more correct to + // be 'Needs' / 'Wants' anyway. + + /// + /// Whether the trainer needs to see data in normalized form. Only non-parametric learners will tend to produce + /// normalization here. + /// + public bool NeedNormalization { get; } + + /// + /// Whether the trainer needs calibration to produce probabilities. As a general rule only trainers that produce + /// binary classifier predictors that also do not have a natural probabilistic interpretation should have a + /// true value here. + /// + public bool NeedCalibration { get; } + + /// + /// Whether this trainer could benefit from a cached view of the data. Trainers that have few passes over the + /// data, or that need to build their own custom data structure over the data, will have a false here. 
+ /// + public bool WantCaching { get; } + + /// + /// Whether the trainer supports validation sets via . Not implementing + /// this interface and returning true from this property is an indication the trainer does not support + /// that. + /// + public bool SupportsValidation { get; } + + /// + /// Whether the trainer can support incremental trainers via . Not + /// implementing this interface and returning true from this property is an indication the trainer does + /// not support that. + /// + public bool SupportsIncrementalTraining { get; } + + /// + /// Initializes with the given parameters. The parameters have default values for the most typical values + /// for most classical trainers. + /// + /// The value for the property + /// The value for the property + /// The value for the property + /// The value for the property + /// The value for the property + public TrainerInfo(bool normalization = true, bool calibration = false, bool caching = true, + bool supportValid = false, bool supportIncrementalTrain = false) + { + NeedNormalization = normalization; + NeedCalibration = calibration; + WantCaching = caching; + SupportsValidation = supportValid; + SupportsIncrementalTraining = supportIncrementalTrain; + } + } +} diff --git a/src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs b/src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs index 23b1c601c9..26ec32d3fe 100644 --- a/src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs +++ b/src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs @@ -538,7 +538,7 @@ private FoldResult RunFold(int fold) if (_getValidationDataView != null) { ch.Assert(_applyTransformsToValidationData != null); - if (!TrainUtils.CanUseValidationData(trainer)) + if (!trainer.Info.SupportsValidation) ch.Warning("Trainer does not accept validation dataset."); else { diff --git a/src/Microsoft.ML.Data/Commands/TrainCommand.cs b/src/Microsoft.ML.Data/Commands/TrainCommand.cs index 2debbc0c69..1c25275c3e 100644 --- 
a/src/Microsoft.ML.Data/Commands/TrainCommand.cs +++ b/src/Microsoft.ML.Data/Commands/TrainCommand.cs @@ -163,7 +163,7 @@ private void RunCore(IChannel ch, string cmd) RoleMappedData validData = null; if (!string.IsNullOrWhiteSpace(Args.ValidationFile)) { - if (!TrainUtils.CanUseValidationData(trainer)) + if (!trainer.Info.SupportsValidation) { ch.Warning("Ignoring validationFile: Trainer does not accept validation dataset."); } @@ -242,7 +242,7 @@ public static IPredictor Train(IHostEnvironment env, IChannel ch, RoleMappedData } private static IPredictor TrainCore(IHostEnvironment env, IChannel ch, RoleMappedData data, ITrainer trainer, string name, RoleMappedData validData, - ICalibratorTrainer calibrator, int maxCalibrationExamples, bool? cacheData, IPredictor inpPredictor = null) + ICalibratorTrainer calibrator, int maxCalibrationExamples, bool? cacheData, IPredictor inputPredictor = null) { Contracts.CheckValue(env, nameof(env)); env.CheckValue(ch, nameof(ch)); @@ -250,31 +250,24 @@ private static IPredictor TrainCore(IHostEnvironment env, IChannel ch, RoleMappe ch.CheckValue(trainer, nameof(trainer)); ch.CheckNonEmpty(name, nameof(name)); ch.CheckValueOrNull(validData); - ch.CheckValueOrNull(inpPredictor); + ch.CheckValueOrNull(inputPredictor); AddCacheIfWanted(env, ch, trainer, ref data, cacheData); ch.Trace("Training"); if (validData != null) AddCacheIfWanted(env, ch, trainer, ref validData, cacheData); - var trainerEx = trainer as ITrainerEx; - if (inpPredictor != null && trainerEx?.SupportsIncrementalTraining != true) + if (inputPredictor != null && !trainer.Info.SupportsIncrementalTraining) { ch.Warning("Ignoring " + nameof(TrainCommand.Arguments.InputModelFile) + ": Trainer does not support incremental training."); - inpPredictor = null; + inputPredictor = null; } - ch.Assert(validData == null || CanUseValidationData(trainer)); - var predictor = trainer.Train(new TrainContext(data, validData, inpPredictor)); + ch.Assert(validData == null || 
trainer.Info.SupportsValidation); + var predictor = trainer.Train(new TrainContext(data, validData, inputPredictor)); return CalibratorUtils.TrainCalibratorIfNeeded(env, ch, calibrator, maxCalibrationExamples, trainer, predictor, data); } - public static bool CanUseValidationData(ITrainer trainer) - { - Contracts.CheckValue(trainer, nameof(trainer)); - return (trainer as ITrainerEx)?.SupportsValidation ?? false; - } - public static bool TryLoadPredictor(IChannel ch, IHostEnvironment env, string inputModelFile, out IPredictor inputPredictor) { Contracts.AssertValue(env); @@ -388,9 +381,8 @@ public static void SaveDataPipe(IHostEnvironment env, RepositoryWriter repositor IDataView pipeStart; var xfs = BacktrackPipe(dataPipe, out pipeStart); - IDataLoader loader; Action saveAction; - if (!blankLoader && (loader = pipeStart as IDataLoader) != null) + if (!blankLoader && pipeStart is IDataLoader loader) saveAction = loader.Save; else { @@ -460,7 +452,7 @@ public static bool AddNormalizerIfNeeded(IHostEnvironment env, IChannel ch, ITra if (autoNorm != NormalizeOption.Yes) { DvBool isNormalized = DvBool.False; - if (trainer.NeedNormalization() != true || schema.IsNormalized(featCol)) + if (!trainer.Info.NeedNormalization || schema.IsNormalized(featCol)) { ch.Info("Not adding a normalizer."); return false; @@ -491,8 +483,7 @@ private static bool AddCacheIfWanted(IHostEnvironment env, IChannel ch, ITrainer ch.AssertValue(trainer, nameof(trainer)); ch.AssertValue(data, nameof(data)); - ITrainerEx trainerEx = trainer as ITrainerEx; - bool shouldCache = cacheData ?? (!(data.Data is BinaryLoader) && (trainerEx == null || trainerEx.WantCaching)); + bool shouldCache = cacheData ?? 
!(data.Data is BinaryLoader) && trainer.Info.WantCaching; if (shouldCache) { diff --git a/src/Microsoft.ML.Data/Commands/TrainTestCommand.cs b/src/Microsoft.ML.Data/Commands/TrainTestCommand.cs index 7c4249c6ee..03ee7cdf12 100644 --- a/src/Microsoft.ML.Data/Commands/TrainTestCommand.cs +++ b/src/Microsoft.ML.Data/Commands/TrainTestCommand.cs @@ -152,7 +152,7 @@ private void RunCore(IChannel ch, string cmd) RoleMappedData validData = null; if (!string.IsNullOrWhiteSpace(Args.ValidationFile)) { - if (!TrainUtils.CanUseValidationData(trainer)) + if (!trainer.Info.SupportsValidation) { ch.Warning("Ignoring validationFile: Trainer does not accept validation dataset."); } diff --git a/src/Microsoft.ML.Data/EntryPoints/InputBase.cs b/src/Microsoft.ML.Data/EntryPoints/InputBase.cs index bc45a929d4..94b67af670 100644 --- a/src/Microsoft.ML.Data/EntryPoints/InputBase.cs +++ b/src/Microsoft.ML.Data/EntryPoints/InputBase.cs @@ -164,9 +164,8 @@ public static TOut Train(IHost host, TArg input, } case CachingOptions.Auto: { - ITrainerEx trainerEx = trainer as ITrainerEx; // REVIEW: we should switch to hybrid caching in future. 
- if (!(input.TrainingData is BinaryLoader) && (trainerEx == null || trainerEx.WantCaching)) + if (!(input.TrainingData is BinaryLoader) && trainer.Info.WantCaching) // default to Memory so mml is on par with maml cachingType = Cache.CachingType.Memory; break; diff --git a/src/Microsoft.ML.Data/Prediction/Calibrator.cs b/src/Microsoft.ML.Data/Prediction/Calibrator.cs index 835ba5d99a..1328bb006a 100644 --- a/src/Microsoft.ML.Data/Prediction/Calibrator.cs +++ b/src/Microsoft.ML.Data/Prediction/Calibrator.cs @@ -687,8 +687,7 @@ public static class CalibratorUtils private static bool NeedCalibration(IHostEnvironment env, IChannel ch, ICalibratorTrainer calibrator, ITrainer trainer, IPredictor predictor, RoleMappedSchema schema) { - var trainerEx = trainer as ITrainerEx; - if (trainerEx == null || !trainerEx.NeedCalibration) + if (!trainer.Info.NeedCalibration) { ch.Info("Not training a calibrator because it is not needed."); return false; diff --git a/src/Microsoft.ML.Data/Training/TrainerBase.cs b/src/Microsoft.ML.Data/Training/TrainerBase.cs index f9bbefbcb0..ca2f2c7b64 100644 --- a/src/Microsoft.ML.Data/Training/TrainerBase.cs +++ b/src/Microsoft.ML.Data/Training/TrainerBase.cs @@ -4,7 +4,7 @@ namespace Microsoft.ML.Runtime.Training { - public abstract class TrainerBase : ITrainer, ITrainerEx + public abstract class TrainerBase : ITrainer where TPredictor : IPredictor { /// @@ -17,12 +17,7 @@ public abstract class TrainerBase : ITrainer, ITrainerEx public string Name { get; } public abstract PredictionKind PredictionKind { get; } - public abstract bool NeedNormalization { get; } - public abstract bool NeedCalibration { get; } - public abstract bool WantCaching { get; } - - public virtual bool SupportsValidation => false; - public virtual bool SupportsIncrementalTraining => false; + public abstract TrainerInfo Info { get; } protected TrainerBase(IHostEnvironment env, string name) { diff --git a/src/Microsoft.ML.Data/Transforms/NormalizeTransform.cs 
b/src/Microsoft.ML.Data/Transforms/NormalizeTransform.cs index f3c560c8dc..7a4738d5e4 100644 --- a/src/Microsoft.ML.Data/Transforms/NormalizeTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/NormalizeTransform.cs @@ -205,9 +205,8 @@ private NormalizeTransform(IHost host, ArgumentsBase args, IDataView input, /// /// The host environment to use to potentially instantiate the transform /// The role-mapped data that is potentially going to be modified by this method. - /// The trainer to query with . - /// This method will not modify if the return from that is null or - /// false. + /// The trainer to query as to whether it wants normalization. If the + /// 's is true /// True if the normalizer was applied and was modified public static bool CreateIfNeeded(IHostEnvironment env, ref RoleMappedData data, ITrainer trainer) { @@ -215,14 +214,12 @@ public static bool CreateIfNeeded(IHostEnvironment env, ref RoleMappedData data, env.CheckValue(data, nameof(data)); env.CheckValue(trainer, nameof(trainer)); - // If this is false or null, we do not want to normalize. - if (trainer.NeedNormalization() != true) - return false; - // If this is true or null, we do not want to normalize. - if (data.Schema.FeaturesAreNormalized() != false) + // If the trainer does not need normalization, or if the features either don't exist + // or are not normalized, return false. + if (!trainer.Info.NeedNormalization || data.Schema.FeaturesAreNormalized() != false) return false; var featInfo = data.Schema.Feature; - env.AssertValue(featInfo); // Should be defined, if FEaturesAreNormalized returned a definite value. + env.AssertValue(featInfo); // Should be defined, if FeaturesAreNormalized returned a definite value. 
var view = CreateMinMaxNormalizer(env, data.Data, name: featInfo.Name); data = new RoleMappedData(view, data.Schema.GetColumnRoleNames()); @@ -363,20 +360,6 @@ protected override Delegate GetGetterCore(IChannel ch, IRow input, int iinfo, ou public static class NormalizeUtils { - /// - /// Tells whether the trainer wants normalization. - /// - /// This method works via testing whether the trainer implements the optional interface - /// , via the Boolean property. - /// If does not implement that interface, then we return null - /// The trainer to query - /// Whether the trainer wants normalization - public static bool? NeedNormalization(this ITrainer trainer) - { - Contracts.CheckValue(trainer, nameof(trainer)); - return (trainer as ITrainerEx)?.NeedNormalization; - } - /// /// Returns whether the feature column in the schema is indicated to be normalized. If the features column is not /// specified on the schema, then this will return null. diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/BaseStacking.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/BaseStacking.cs index f38aa2329f..f30174a31d 100644 --- a/src/Microsoft.ML.Ensemble/OutputCombiners/BaseStacking.cs +++ b/src/Microsoft.ML.Ensemble/OutputCombiners/BaseStacking.cs @@ -188,7 +188,7 @@ public void Train(List>> models, var rmd = new RoleMappedData(view, DefaultColumnNames.Label, DefaultColumnNames.Features); var trainer = BasePredictorType.CreateInstance(host); - if (trainer is ITrainerEx ex && ex.NeedNormalization) + if (trainer.Info.NeedNormalization) ch.Warning("The trainer specified for stacking wants normalization, but we do not currently allow this."); Meta = trainer.Train(rmd); CheckMeta(); diff --git a/src/Microsoft.ML.Ensemble/Trainer/EnsembleTrainerBase.cs b/src/Microsoft.ML.Ensemble/Trainer/EnsembleTrainerBase.cs index 83ec285e96..0a350ef8ee 100644 --- a/src/Microsoft.ML.Ensemble/Trainer/EnsembleTrainerBase.cs +++ b/src/Microsoft.ML.Ensemble/Trainer/EnsembleTrainerBase.cs @@ -69,6 +69,8 @@ 
public abstract class ArgumentsBase : LearnerInputBaseWithLabel private protected ISubModelSelector SubModelSelector; private protected IOutputCombiner Combiner; + public override TrainerInfo Info { get; } + private protected EnsembleTrainerBase(ArgumentsBase args, IHostEnvironment env, string name) : base(env, name) { @@ -91,20 +93,15 @@ private protected EnsembleTrainerBase(ArgumentsBase args, IHostEnvironment env, Trainers = new ITrainer>[NumModels]; for (int i = 0; i < Trainers.Length; i++) Trainers[i] = Args.BasePredictors[i % Args.BasePredictors.Length].CreateInstance(Host); - NeedNormalization = Trainers.Any(t => t is ITrainerEx nn && nn.NeedNormalization); - NeedCalibration = Trainers.Any(t => t is ITrainerEx nn && nn.NeedCalibration); + // We infer normalization and calibration preferences from the trainers. However, even if the internal trainers + // don't need caching we are performing multiple passes over the data, so it is probably appropriate to always cache. + Info = new TrainerInfo( + normalization: Trainers.Any(t => t.Info.NeedNormalization), + calibration: Trainers.Any(t => t.Info.NeedCalibration)); ch.Done(); } } - public override bool NeedNormalization { get; } - - public override bool NeedCalibration { get; } - - // No matter the internal predictors, we are performing multiple passes over the data - // so it is probably appropriate to always cache. 
- public override bool WantCaching => true; - public sealed override TPredictor Train(TrainContext context) { Host.CheckValue(context, nameof(context)); diff --git a/src/Microsoft.ML.FastTree/FastTree.cs b/src/Microsoft.ML.FastTree/FastTree.cs index 86df5b5893..eb7c95bb54 100644 --- a/src/Microsoft.ML.FastTree/FastTree.cs +++ b/src/Microsoft.ML.FastTree/FastTree.cs @@ -81,19 +81,19 @@ public abstract class FastTreeTrainerBase : protected string InnerArgs => CmdParser.GetSettings(Host, Args, new TArgs()); - public override bool NeedNormalization => false; - - public override bool WantCaching => false; + public override TrainerInfo Info { get; } public bool HasCategoricalFeatures => Utils.Size(CategoricalFeatures) > 0; - public override bool SupportsValidation => true; - private protected FastTreeTrainerBase(IHostEnvironment env, TArgs args) : base(env, RegisterName) { Host.CheckValue(args, nameof(args)); Args = args; + // The discretization step renders this trainer non-parametric, and therefore it does not need normalization. + // Also since it builds its own internal discretized columnar structures, it cannot benefit from caching. + // Finally, even the binary classifiers, being logitboost, tend to not benefit from external calibration. + Info = new TrainerInfo(normalization: false, caching: false, calibration: this is FastForestClassification); int numThreads = Args.NumThreads ?? 
Environment.ProcessorCount; if (Host.ConcurrencyFactor > 0 && numThreads > Host.ConcurrencyFactor) { diff --git a/src/Microsoft.ML.FastTree/FastTreeClassification.cs b/src/Microsoft.ML.FastTree/FastTreeClassification.cs index bfc77b56cb..4eca3e15fb 100644 --- a/src/Microsoft.ML.FastTree/FastTreeClassification.cs +++ b/src/Microsoft.ML.FastTree/FastTreeClassification.cs @@ -116,8 +116,6 @@ public FastTreeBinaryClassificationTrainer(IHostEnvironment env, Arguments args) { } - public override bool NeedCalibration => false; - public override PredictionKind PredictionKind => PredictionKind.BinaryClassification; public override IPredictorWithFeatureWeights Train(TrainContext context) diff --git a/src/Microsoft.ML.FastTree/FastTreeRanking.cs b/src/Microsoft.ML.FastTree/FastTreeRanking.cs index 9e0e533bfe..0e44553dee 100644 --- a/src/Microsoft.ML.FastTree/FastTreeRanking.cs +++ b/src/Microsoft.ML.FastTree/FastTreeRanking.cs @@ -51,8 +51,6 @@ public sealed partial class FastTreeRankingTrainer : BoostingFastTreeTrainerBase private Test _specialTrainSetTest; private TestHistory _firstTestSetHistory; - public override bool NeedCalibration => false; - public override PredictionKind PredictionKind => PredictionKind.Ranking; public FastTreeRankingTrainer(IHostEnvironment env, Arguments args) diff --git a/src/Microsoft.ML.FastTree/FastTreeRegression.cs b/src/Microsoft.ML.FastTree/FastTreeRegression.cs index b9d4b22ddb..1e78f4c473 100644 --- a/src/Microsoft.ML.FastTree/FastTreeRegression.cs +++ b/src/Microsoft.ML.FastTree/FastTreeRegression.cs @@ -43,8 +43,6 @@ public sealed partial class FastTreeRegressionTrainer : BoostingFastTreeTrainerB private Test _trainRegressionTest; private Test _testRegressionTest; - public override bool NeedCalibration => false; - public override PredictionKind PredictionKind => PredictionKind.Regression; public FastTreeRegressionTrainer(IHostEnvironment env, Arguments args) diff --git a/src/Microsoft.ML.FastTree/FastTreeTweedie.cs 
b/src/Microsoft.ML.FastTree/FastTreeTweedie.cs index bc6d57b1b9..6c1b56b1eb 100644 --- a/src/Microsoft.ML.FastTree/FastTreeTweedie.cs +++ b/src/Microsoft.ML.FastTree/FastTreeTweedie.cs @@ -42,8 +42,6 @@ public sealed partial class FastTreeTweedieTrainer : BoostingFastTreeTrainerBase private Test _trainRegressionTest; private Test _testRegressionTest; - public override bool NeedCalibration => false; - public override PredictionKind PredictionKind => PredictionKind.Regression; public FastTreeTweedieTrainer(IHostEnvironment env, Arguments args) diff --git a/src/Microsoft.ML.FastTree/GamTrainer.cs b/src/Microsoft.ML.FastTree/GamTrainer.cs index 366e73a4f0..8ae1b3a7dc 100644 --- a/src/Microsoft.ML.FastTree/GamTrainer.cs +++ b/src/Microsoft.ML.FastTree/GamTrainer.cs @@ -107,7 +107,6 @@ public sealed class Arguments : ArgumentsBase internal const string ShortName = "gam"; public override PredictionKind PredictionKind => PredictionKind.BinaryClassification; - public override bool NeedCalibration => true; public BinaryClassificationGamTrainer(IHostEnvironment env, Arguments args) : base(env, args) { } @@ -225,11 +224,7 @@ public abstract class ArgumentsBase : LearnerInputBaseWithWeight protected double[][] BinEffects; protected int[] FeatureMap; - public override bool NeedCalibration => false; - - public override bool NeedNormalization => false; - - public override bool WantCaching => false; + public override TrainerInfo Info { get; } private protected GamTrainerBase(IHostEnvironment env, TArgs args) : base(env, RegisterName) @@ -245,6 +240,7 @@ private protected GamTrainerBase(IHostEnvironment env, TArgs args) Host.CheckParam(0 < args.NumIterations, nameof(args.NumIterations), "Must be positive."); Args = args; + Info = new TrainerInfo(normalization: false, calibration: this is BinaryClassificationGamTrainer, caching: false); _gainConfidenceInSquaredStandardDeviations = Math.Pow(ProbabilityFunctions.Probit(1 - (1 - Args.GainConfidenceLevel) * 0.5), 2); _entropyCoefficient 
= Args.EntropyCoefficient * 1e-6; int numThreads = args.NumThreads ?? Environment.ProcessorCount; diff --git a/src/Microsoft.ML.FastTree/RandomForestClassification.cs b/src/Microsoft.ML.FastTree/RandomForestClassification.cs index 31f3fd6116..1e04fe6234 100644 --- a/src/Microsoft.ML.FastTree/RandomForestClassification.cs +++ b/src/Microsoft.ML.FastTree/RandomForestClassification.cs @@ -129,14 +129,14 @@ public sealed class Arguments : FastForestArgumentsBase private bool[] _trainSetLabels; + public override PredictionKind PredictionKind => PredictionKind.BinaryClassification; + public FastForestClassification(IHostEnvironment env, Arguments args) : base(env, args) { + } - public override bool NeedCalibration => true; - public override PredictionKind PredictionKind => PredictionKind.BinaryClassification; - public override IPredictorWithFeatureWeights Train(TrainContext context) { Host.CheckValue(context, nameof(context)); diff --git a/src/Microsoft.ML.FastTree/RandomForestRegression.cs b/src/Microsoft.ML.FastTree/RandomForestRegression.cs index 9b96c01312..7c45af3e54 100644 --- a/src/Microsoft.ML.FastTree/RandomForestRegression.cs +++ b/src/Microsoft.ML.FastTree/RandomForestRegression.cs @@ -158,8 +158,6 @@ public FastForestRegression(IHostEnvironment env, Arguments args) { } - public override bool NeedCalibration => false; - public override PredictionKind PredictionKind => PredictionKind.Regression; public override FastForestRegressionPredictor Train(TrainContext context) diff --git a/src/Microsoft.ML.KMeansClustering/KMeansPlusPlusTrainer.cs b/src/Microsoft.ML.KMeansClustering/KMeansPlusPlusTrainer.cs index 88a7f9546e..3049e6a54c 100644 --- a/src/Microsoft.ML.KMeansClustering/KMeansPlusPlusTrainer.cs +++ b/src/Microsoft.ML.KMeansClustering/KMeansPlusPlusTrainer.cs @@ -82,6 +82,9 @@ public class Arguments : UnsupervisedLearnerInputBaseWithWeight private readonly InitAlgorithm _initAlgorithm; private readonly int _numThreads; + public override TrainerInfo Info { 
get; } + public override PredictionKind PredictionKind => PredictionKind.Clustering; + public KMeansPlusPlusTrainer(IHostEnvironment env, Arguments args) : base(env, LoadNameValue) { @@ -104,13 +107,9 @@ public KMeansPlusPlusTrainer(IHostEnvironment env, Arguments args) Host.CheckUserArg(!args.NumThreads.HasValue || args.NumThreads > 0, nameof(args.NumThreads), "Must be either null or a positive integer."); _numThreads = ComputeNumThreads(Host, args.NumThreads); + Info = new TrainerInfo(); } - public override bool NeedNormalization => true; - public override bool NeedCalibration => false; - public override bool WantCaching => true; - public override PredictionKind PredictionKind => PredictionKind.Clustering; - public override KMeansPredictor Train(TrainContext context) { Host.CheckValue(context, nameof(context)); diff --git a/src/Microsoft.ML.LightGBM/LightGbmTrainerBase.cs b/src/Microsoft.ML.LightGBM/LightGbmTrainerBase.cs index e540d8f1ae..5f0877146d 100644 --- a/src/Microsoft.ML.LightGBM/LightGbmTrainerBase.cs +++ b/src/Microsoft.ML.LightGBM/LightGbmTrainerBase.cs @@ -51,10 +51,8 @@ private sealed class CategoricalMetaData private protected int FeatureCount; private protected FastTree.Internal.Ensemble TrainedEnsemble; - public override bool NeedNormalization => false; - public override bool NeedCalibration => false; - public override bool WantCaching => false; - public override bool SupportsValidation => true; + private static TrainerInfo _info = new TrainerInfo(normalization: false, caching: false, supportValid: true); + public override TrainerInfo Info => _info; private protected LightGbmTrainerBase(IHostEnvironment env, LightGbmArguments args, string name) : base(env, name) diff --git a/src/Microsoft.ML.PCA/PcaTrainer.cs b/src/Microsoft.ML.PCA/PcaTrainer.cs index ddc0a09778..b06aec335a 100644 --- a/src/Microsoft.ML.PCA/PcaTrainer.cs +++ b/src/Microsoft.ML.PCA/PcaTrainer.cs @@ -74,6 +74,9 @@ public class Arguments : UnsupervisedLearnerInputBaseWithWeight 
private readonly bool _center; private readonly int _seed; + public override PredictionKind PredictionKind => PredictionKind.AnomalyDetection; + public override TrainerInfo Info { get; } + public RandomizedPcaTrainer(IHostEnvironment env, Arguments args) : base(env, LoadNameValue) { @@ -81,21 +84,14 @@ public RandomizedPcaTrainer(IHostEnvironment env, Arguments args) Host.CheckUserArg(args.Rank > 0, nameof(args.Rank), "Rank must be positive"); Host.CheckUserArg(args.Oversampling >= 0, nameof(args.Oversampling), "Oversampling must be non-negative"); + // Two passes, only. Probably not worth caching. + Info = new TrainerInfo(caching: false); _rank = args.Rank; _center = args.Center; _oversampling = args.Oversampling; _seed = args.Seed ?? Host.Rand.Next(); } - public override bool NeedNormalization => true; - - public override bool NeedCalibration => false; - - // Two passes, only. Probably not worth caching. - public override bool WantCaching => false; - - public override PredictionKind PredictionKind => PredictionKind.AnomalyDetection; - //Note: the notations used here are the same as in http://web.stanford.edu/group/mmds/slides2010/Martinsson.pdf (pg. 
9) public override PcaPredictor Train(TrainContext context) { diff --git a/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineTrainer.cs b/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineTrainer.cs index b616da1a70..58c28b8712 100644 --- a/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineTrainer.cs +++ b/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineTrainer.cs @@ -74,9 +74,7 @@ public sealed class Arguments : LearnerInputBaseWithLabel } public override PredictionKind PredictionKind => PredictionKind.BinaryClassification; - public override bool NeedNormalization => true; - public override bool NeedCalibration => false; - public override bool WantCaching => true; + public override TrainerInfo Info { get; } private readonly int _latentDim; private readonly int _latentDimAligned; private readonly float _lambdaLinear; @@ -105,6 +103,7 @@ public FieldAwareFactorizationMachineTrainer(IHostEnvironment env, Arguments arg _shuffle = args.Shuffle; _verbose = args.Verbose; _radius = args.Radius; + Info = new TrainerInfo(); } private void InitializeTrainingState(int fieldCount, int featureCount, FieldAwareFactorizationMachinePredictor predictor, out float[] linearWeights, diff --git a/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationTrainer.cs b/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationTrainer.cs index 314af27fc6..55742820e0 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationTrainer.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationTrainer.cs @@ -49,18 +49,17 @@ public abstract class LinearTrainerBase : TrainerBase { protected bool NeedShuffle; - public override bool NeedNormalization => true; - - public override bool WantCaching => true; + public override TrainerInfo Info { get; } /// /// Whether data is to be shuffled every epoch. 
/// protected abstract bool ShuffleData { get; } - protected LinearTrainerBase(IHostEnvironment env, string name) + private protected LinearTrainerBase(IHostEnvironment env, string name) : base(env, name) { + Info = new TrainerInfo(); } public override TPredictor Train(TrainContext context) @@ -240,8 +239,6 @@ protected enum MetricKind private readonly ArgumentsBase _args; protected ISupportSdcaLoss Loss; - public override bool NeedNormalization => true; - protected override bool ShuffleData => _args.Shuffle; protected SdcaTrainerBase(ArgumentsBase args, IHostEnvironment env, string name) @@ -1355,7 +1352,7 @@ public void Add(Double summand) } } - public sealed class LinearClassificationTrainer : SdcaTrainerBase, ITrainerEx + public sealed class LinearClassificationTrainer : SdcaTrainerBase { public const string LoadNameValue = "SDCA"; public const string UserNameValue = "Fast Linear (SA-SDCA)"; @@ -1389,13 +1386,14 @@ internal override void Check(IHostEnvironment env) public override PredictionKind PredictionKind => PredictionKind.BinaryClassification; - public override bool NeedCalibration => !(_loss is LogLoss); + public override TrainerInfo Info { get; } public LinearClassificationTrainer(IHostEnvironment env, Arguments args) : base(args, env, LoadNameValue) { _loss = args.LossFunction.CreateComponent(env); base.Loss = _loss; + Info = new TrainerInfo(calibration: !(_loss is LogLoss)); NeedShuffle = args.Shuffle; _args = args; _positiveInstanceWeight = _args.PositiveInstanceWeight; @@ -1435,8 +1433,6 @@ public sealed class StochasticGradientDescentClassificationTrainer : public const string UserNameValue = "Hogwild SGD (binary)"; public const string ShortName = "HogwildSGD"; - public override bool SupportsIncrementalTraining => true; - public sealed class Arguments : LearnerInputBaseWithWeight { [Argument(ArgumentType.Multiple, HelpText = "Loss Function", ShortName = "loss", SortOrder = 50)] @@ -1511,13 +1507,14 @@ internal void Check(IHostEnvironment env) 
public override PredictionKind PredictionKind => PredictionKind.BinaryClassification; - public override bool NeedCalibration => !(_loss is LogLoss); + public override TrainerInfo Info { get; } public StochasticGradientDescentClassificationTrainer(IHostEnvironment env, Arguments args) : base(env, LoadNameValue) { args.Check(env); _loss = args.LossFunction.CreateComponent(env); + Info = new TrainerInfo(calibration: !(_loss is LogLoss), supportIncrementalTrain: true); NeedShuffle = args.Shuffle; _args = args; } diff --git a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LbfgsPredictorBase.cs b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LbfgsPredictorBase.cs index f8b186ec4d..616a1cbe8d 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LbfgsPredictorBase.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LbfgsPredictorBase.cs @@ -132,6 +132,8 @@ public abstract class ArgumentsBase : LearnerInputBaseWithWeight private VBuffer[] _localGradients; private Float[] _localLosses; + public override TrainerInfo Info { get; } + internal LbfgsTrainerBase(ArgumentsBase args, IHostEnvironment env, string name, bool showTrainingStats = false) : base(env, name) { @@ -158,6 +160,9 @@ internal LbfgsTrainerBase(ArgumentsBase args, IHostEnvironment env, string name, DenseOptimizer = args.DenseOptimizer; ShowTrainingStats = showTrainingStats; EnforceNonNegativity = args.EnforceNonNegativity; + // REVIEW: It's pointless to request caching when we're going to load everything into + // memory, that is, when using multiple threads. So should caching not be requested? 
+ Info = new TrainerInfo(caching: true, supportIncrementalTrain: true); if (EnforceNonNegativity && ShowTrainingStats) { @@ -170,17 +175,9 @@ internal LbfgsTrainerBase(ArgumentsBase args, IHostEnvironment env, string name, } } - public override bool NeedNormalization => true; - - // REVIEW: It's pointless to request caching when we're going to load everything into - // memory, that is, when using multiple threads. - public override bool WantCaching => true; - protected virtual int ClassCount => 1; protected int BiasCount => ClassCount; protected int WeightCount => ClassCount * NumFeatures; - public sealed override bool SupportsIncrementalTraining => true; - protected virtual Optimizer InitializeOptimizer(IChannel ch, FloatLabelCursor.Factory cursorFactory, out VBuffer init, out ITerminationCriterion terminationCriterion) { diff --git a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LogisticRegression.cs b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LogisticRegression.cs index 994e5f278c..f720b960a4 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LogisticRegression.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LogisticRegression.cs @@ -55,8 +55,6 @@ public LogisticRegression(IHostEnvironment env, Arguments args) _posWeight = 0; } - public override bool NeedCalibration => false; - public override PredictionKind PredictionKind => PredictionKind.BinaryClassification; protected override void CheckLabel(RoleMappedData data) diff --git a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs index cceb8ea408..5f7712843f 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs @@ -74,8 +74,6 @@ public 
MulticlassLogisticRegression(IHostEnvironment env, Arguments args) { } - public override bool NeedCalibration => false; - public override PredictionKind PredictionKind => PredictionKind.MultiClassClassification; protected override void CheckLabel(RoleMappedData data) diff --git a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MetaMulticlassTrainer.cs b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MetaMulticlassTrainer.cs index bfd98df683..52cd025370 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MetaMulticlassTrainer.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MetaMulticlassTrainer.cs @@ -40,12 +40,7 @@ public abstract class ArgumentsBase private TScalarTrainer _trainer; public sealed override PredictionKind PredictionKind => PredictionKind.MultiClassClassification; - public sealed override bool NeedNormalization { get; } - public sealed override bool NeedCalibration => false; - - // No matter what the internal predictor, we're performing many passes - // simply by virtue of this being a meta-trainer. - public sealed override bool WantCaching => true; + public override TrainerInfo Info { get; } internal MetaMulticlassTrainer(IHostEnvironment env, TArgs args, string name) : base(env, name) @@ -55,8 +50,9 @@ internal MetaMulticlassTrainer(IHostEnvironment env, TArgs args, string name) Host.CheckUserArg(Args.PredictorType.IsGood(), nameof(Args.PredictorType)); // Create the first trainer so errors in the args surface early. _trainer = Args.PredictorType.CreateInstance(Host); - var ex = _trainer as ITrainerEx; - NeedNormalization = ex != null && ex.NeedNormalization; + // Regarding caching, no matter what the internal predictor, we're performing many passes + // simply by virtue of this being a meta-trainer, so we will still cache. 
+ Info = new TrainerInfo(normalization: _trainer.Info.NeedNormalization); } protected IDataView MapLabelsCore(ColumnType type, RefPredicate equalsTarget, RoleMappedData data, string dstName) diff --git a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MultiClassNaiveBayesTrainer.cs b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MultiClassNaiveBayesTrainer.cs index cd565d0459..67762d935d 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MultiClassNaiveBayesTrainer.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MultiClassNaiveBayesTrainer.cs @@ -45,15 +45,13 @@ public sealed class Arguments : LearnerInputBaseWithLabel public override PredictionKind PredictionKind => PredictionKind.MultiClassClassification; - public override bool NeedNormalization => false; - - public override bool NeedCalibration => false; - - public override bool WantCaching => false; + public override TrainerInfo Info { get; } public MultiClassNaiveBayesTrainer(IHostEnvironment env, Arguments args) : base(env, LoadName) { + Host.CheckValue(args, nameof(args)); + Info = new TrainerInfo(normalization: false, caching: false); } public override MultiClassNaiveBayesPredictor Train(TrainContext context) diff --git a/src/Microsoft.ML.StandardLearners/Standard/OlsLinearRegression.cs b/src/Microsoft.ML.StandardLearners/Standard/OlsLinearRegression.cs index ab435ad5d5..95005f7d39 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/OlsLinearRegression.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/OlsLinearRegression.cs @@ -60,6 +60,9 @@ It assumes that the conditional mean of the dependent variable follows a linear private readonly Float _l2Weight; private readonly bool _perParameterSignificance; + public override PredictionKind PredictionKind => PredictionKind.Regression; + public override TrainerInfo Info { get; } + public OlsLinearRegressionTrainer(IHostEnvironment env, Arguments args) : base(env, LoadNameValue) { @@ -67,14 +70,10 @@ public 
OlsLinearRegressionTrainer(IHostEnvironment env, Arguments args) Host.CheckUserArg(args.L2Weight >= 0, nameof(args.L2Weight), "L2 regularization term cannot be negative"); _l2Weight = args.L2Weight; _perParameterSignificance = args.PerParameterSignificance; + // Two passes, only. Probably not worth caching. + Info = new TrainerInfo(caching: false); } - public override bool NeedNormalization => true; - public override bool NeedCalibration => false; - // Two passes, only. Probably not worth caching. - public override bool WantCaching => false; - public override PredictionKind PredictionKind => PredictionKind.Regression; - /// /// In several calculations, we calculate probabilities or other quantities that should range /// from 0 to 1, but because of numerical imprecision may, in entirely innocent circumstances, diff --git a/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedPerceptron.cs b/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedPerceptron.cs index a3316eb959..a2e09b6905 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedPerceptron.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedPerceptron.cs @@ -49,17 +49,14 @@ public class Arguments : AveragedLinearArguments public int MaxCalibrationExamples = 1000000; } + protected override bool NeedCalibration => true; + public AveragedPerceptronTrainer(IHostEnvironment env, Arguments args) : base(args, env, UserNameValue) { LossFunction = Args.LossFunction.CreateComponent(env); } - public override bool NeedCalibration - { - get { return true; } - } - public override PredictionKind PredictionKind { get { return PredictionKind.BinaryClassification; } } protected override void CheckLabel(RoleMappedData data) diff --git a/src/Microsoft.ML.StandardLearners/Standard/Online/LinearSvm.cs b/src/Microsoft.ML.StandardLearners/Standard/Online/LinearSvm.cs index 6e31e55623..d435539e95 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/Online/LinearSvm.cs +++ 
b/src/Microsoft.ML.StandardLearners/Standard/Online/LinearSvm.cs @@ -80,6 +80,8 @@ public sealed class Arguments : OnlineLinearArguments private Float _weightsUpdateScale; private Float _biasUpdate; + protected override bool NeedCalibration => true; + public LinearSvm(IHostEnvironment env, Arguments args) : base(args, env, UserNameValue) { @@ -87,11 +89,6 @@ public LinearSvm(IHostEnvironment env, Arguments args) Contracts.CheckUserArg(args.BatchSize > 0, nameof(args.BatchSize), UserErrorPositive); } - public override bool NeedCalibration - { - get { return true; } - } - public override PredictionKind PredictionKind { get { return PredictionKind.BinaryClassification; } } protected override void CheckLabel(RoleMappedData data) diff --git a/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineGradientDescent.cs b/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineGradientDescent.cs index f0f50dded3..5cbdc01478 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineGradientDescent.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineGradientDescent.cs @@ -58,11 +58,6 @@ public OnlineGradientDescentTrainer(IHostEnvironment env, Arguments args) LossFunction = args.LossFunction.CreateComponent(env); } - public override bool NeedCalibration - { - get { return false; } - } - public override PredictionKind PredictionKind { get { return PredictionKind.Regression; } } protected override void CheckLabel(RoleMappedData data) diff --git a/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineLinear.cs b/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineLinear.cs index d4901a6bb4..60fe7f9705 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineLinear.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineLinear.cs @@ -70,6 +70,10 @@ public abstract class OnlineLinearTrainer : TrainerBase< protected const string UserErrorPositive = "must be positive"; protected const string UserErrorNonNegative = "must be 
non-negative"; + public override TrainerInfo Info { get; } + + protected virtual bool NeedCalibration => false; + protected OnlineLinearTrainer(TArguments args, IHostEnvironment env, string name) : base(env, name) { @@ -79,16 +83,12 @@ protected OnlineLinearTrainer(TArguments args, IHostEnvironment env, string name Contracts.CheckUserArg(args.StreamingCacheSize > 0, nameof(args.StreamingCacheSize), UserErrorPositive); Args = args; + // REVIEW: Caching could be false for one iteration, if we got around the whole shuffling issue. + Info = new TrainerInfo(calibration: NeedCalibration); } - public override bool NeedNormalization => true; - - // REVIEW: This could return true if there are more than 0 iterations, - // if we got around the whole shuffling issue. - public override bool WantCaching => true; - /// - /// Propagates the _weightsScale to the weights vector. + /// Propagates the to the vector. /// protected void ScaleWeights() { @@ -100,9 +100,9 @@ protected void ScaleWeights() } /// - /// Conditionally propagates the _weightsScale to the weights vector when - /// it reaches a scale where additions to weights would start dropping too much - /// precision. ("Too much" is mostly empirically defined.) + /// Conditionally propagates the to the vector + /// when it reaches a scale where additions to weights would start dropping too much precision. + /// ("Too much" is mostly empirically defined.) 
/// protected void ScaleWeightsIfNeeded() { diff --git a/src/Microsoft.ML.StandardLearners/Standard/PoissonRegression/PoissonRegression.cs b/src/Microsoft.ML.StandardLearners/Standard/PoissonRegression/PoissonRegression.cs index a818165f01..94dbb42325 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/PoissonRegression/PoissonRegression.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/PoissonRegression/PoissonRegression.cs @@ -45,8 +45,6 @@ public PoissonRegression(IHostEnvironment env, Arguments args) { } - public override bool NeedCalibration => false; - public override PredictionKind PredictionKind => PredictionKind.Regression; protected override void CheckLabel(RoleMappedData data) diff --git a/src/Microsoft.ML.StandardLearners/Standard/SdcaMultiClass.cs b/src/Microsoft.ML.StandardLearners/Standard/SdcaMultiClass.cs index 2a2b01ac2c..facceffd70 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/SdcaMultiClass.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/SdcaMultiClass.cs @@ -30,7 +30,7 @@ namespace Microsoft.ML.Runtime.Learners // SDCA linear multiclass trainer. 
/// - public class SdcaMultiClassTrainer : SdcaTrainerBase, ITrainerEx + public class SdcaMultiClassTrainer : SdcaTrainerBase { public const string LoadNameValue = "SDCAMC"; public const string UserNameValue = "Fast Linear Multi-class Classification (SA-SDCA)"; @@ -48,8 +48,6 @@ public sealed class Arguments : ArgumentsBase public override PredictionKind PredictionKind => PredictionKind.MultiClassClassification; - public override bool NeedCalibration => false; - public SdcaMultiClassTrainer(IHostEnvironment env, Arguments args) : base(args, env, LoadNameValue) { diff --git a/src/Microsoft.ML.StandardLearners/Standard/SdcaRegression.cs b/src/Microsoft.ML.StandardLearners/Standard/SdcaRegression.cs index e5ab48c0af..466625a679 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/SdcaRegression.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/SdcaRegression.cs @@ -53,8 +53,6 @@ public Arguments() public override PredictionKind PredictionKind => PredictionKind.Regression; - public override bool NeedCalibration => false; - public SdcaRegressionTrainer(IHostEnvironment env, Arguments args) : base(args, env, LoadNameValue) { diff --git a/src/Microsoft.ML.StandardLearners/Standard/Simple/SimpleTrainers.cs b/src/Microsoft.ML.StandardLearners/Standard/Simple/SimpleTrainers.cs index b01aea3b59..7e81b0ba27 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/Simple/SimpleTrainers.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/Simple/SimpleTrainers.cs @@ -54,16 +54,16 @@ public class Arguments public bool BooleanArg = false; } + public override PredictionKind PredictionKind => PredictionKind.BinaryClassification; + public override TrainerInfo Info { get; } + public RandomTrainer(IHostEnvironment env, Arguments args) : base(env, LoadNameValue) { + Host.CheckValue(args, nameof(args)); + Info = new TrainerInfo(normalization: false, caching: false); } - public override PredictionKind PredictionKind => PredictionKind.BinaryClassification; - public override bool 
NeedNormalization => false; - public override bool NeedCalibration => false; - public override bool WantCaching => false; - public override RandomPredictor Train(TrainContext context) { Host.CheckValue(context, nameof(context)); @@ -205,13 +205,13 @@ public sealed class Arguments } public override PredictionKind PredictionKind => PredictionKind.BinaryClassification; - public override bool NeedNormalization => false; - public override bool NeedCalibration => false; - public override bool WantCaching => false; + public override TrainerInfo Info { get; } public PriorTrainer(IHostEnvironment env, Arguments args) : base(env, LoadNameValue) { + Host.CheckValue(args, nameof(args)); + Info = new TrainerInfo(normalization: false, caching: false); } public override PriorPredictor Train(TrainContext context) From 82f44e8a3a5fbabb724b8ba6f92e1f758b58e56e Mon Sep 17 00:00:00 2001 From: Tom Finley Date: Fri, 13 Jul 2018 07:41:09 -0700 Subject: [PATCH 08/13] No more superclass referencing subclass --- src/Microsoft.ML.FastTree/FastTree.cs | 4 +++- src/Microsoft.ML.FastTree/GamTrainer.cs | 4 +++- src/Microsoft.ML.FastTree/RandomForestClassification.cs | 2 +- 3 files changed, 7 insertions(+), 3 deletions(-) diff --git a/src/Microsoft.ML.FastTree/FastTree.cs b/src/Microsoft.ML.FastTree/FastTree.cs index eb7c95bb54..29fe439e0a 100644 --- a/src/Microsoft.ML.FastTree/FastTree.cs +++ b/src/Microsoft.ML.FastTree/FastTree.cs @@ -85,6 +85,8 @@ public abstract class FastTreeTrainerBase : public bool HasCategoricalFeatures => Utils.Size(CategoricalFeatures) > 0; + private protected virtual bool NeedCalibration => false; + private protected FastTreeTrainerBase(IHostEnvironment env, TArgs args) : base(env, RegisterName) { @@ -93,7 +95,7 @@ private protected FastTreeTrainerBase(IHostEnvironment env, TArgs args) // The discretization step renders this trainer non-parametric, and therefore it does not need normalization. 
// Also since it builds its own internal discretized columnar structures, it cannot benefit from caching. // Finally, even the binary classifiers, being logitboost, tend to not benefit from external calibration. - Info = new TrainerInfo(normalization: false, caching: false, calibration: this is FastForestClassification); + Info = new TrainerInfo(normalization: false, caching: false, calibration: NeedCalibration); int numThreads = Args.NumThreads ?? Environment.ProcessorCount; if (Host.ConcurrencyFactor > 0 && numThreads > Host.ConcurrencyFactor) { diff --git a/src/Microsoft.ML.FastTree/GamTrainer.cs b/src/Microsoft.ML.FastTree/GamTrainer.cs index 8ae1b3a7dc..3b3ca9e92f 100644 --- a/src/Microsoft.ML.FastTree/GamTrainer.cs +++ b/src/Microsoft.ML.FastTree/GamTrainer.cs @@ -107,6 +107,7 @@ public sealed class Arguments : ArgumentsBase internal const string ShortName = "gam"; public override PredictionKind PredictionKind => PredictionKind.BinaryClassification; + private protected override bool NeedCalibration => true; public BinaryClassificationGamTrainer(IHostEnvironment env, Arguments args) : base(env, args) { } @@ -225,6 +226,7 @@ public abstract class ArgumentsBase : LearnerInputBaseWithWeight protected int[] FeatureMap; public override TrainerInfo Info { get; } + private protected virtual bool NeedCalibration => false; private protected GamTrainerBase(IHostEnvironment env, TArgs args) : base(env, RegisterName) @@ -240,7 +242,7 @@ private protected GamTrainerBase(IHostEnvironment env, TArgs args) Host.CheckParam(0 < args.NumIterations, nameof(args.NumIterations), "Must be positive."); Args = args; - Info = new TrainerInfo(normalization: false, calibration: this is BinaryClassificationGamTrainer, caching: false); + Info = new TrainerInfo(normalization: false, calibration: NeedCalibration, caching: false); _gainConfidenceInSquaredStandardDeviations = Math.Pow(ProbabilityFunctions.Probit(1 - (1 - Args.GainConfidenceLevel) * 0.5), 2); _entropyCoefficient = 
Args.EntropyCoefficient * 1e-6; int numThreads = args.NumThreads ?? Environment.ProcessorCount; diff --git a/src/Microsoft.ML.FastTree/RandomForestClassification.cs b/src/Microsoft.ML.FastTree/RandomForestClassification.cs index 1e04fe6234..612313fb93 100644 --- a/src/Microsoft.ML.FastTree/RandomForestClassification.cs +++ b/src/Microsoft.ML.FastTree/RandomForestClassification.cs @@ -130,11 +130,11 @@ public sealed class Arguments : FastForestArgumentsBase private bool[] _trainSetLabels; public override PredictionKind PredictionKind => PredictionKind.BinaryClassification; + private protected override bool NeedCalibration => true; public FastForestClassification(IHostEnvironment env, Arguments args) : base(env, args) { - } public override IPredictorWithFeatureWeights Train(TrainContext context) From 390589450bc47399f5454be37d4817409144459f Mon Sep 17 00:00:00 2001 From: Tom Finley Date: Fri, 13 Jul 2018 10:01:08 -0700 Subject: [PATCH 09/13] No abbreviations in TrainContext, use static readonly field where convenient. --- .../Prediction/TrainContext.cs | 20 +++++++++---------- .../LightGbmTrainerBase.cs | 2 +- src/Microsoft.ML.PCA/PcaTrainer.cs | 7 ++++--- .../Standard/LinearClassificationTrainer.cs | 4 ++-- .../LogisticRegression/LbfgsPredictorBase.cs | 8 ++++---- .../MultiClass/MultiClassNaiveBayesTrainer.cs | 4 ++-- .../Standard/OlsLinearRegression.cs | 7 ++++--- .../Standard/Simple/SimpleTrainers.cs | 10 ++++++---- 8 files changed, 33 insertions(+), 29 deletions(-) diff --git a/src/Microsoft.ML.Core/Prediction/TrainContext.cs b/src/Microsoft.ML.Core/Prediction/TrainContext.cs index 87515c85f4..3464aa4bc9 100644 --- a/src/Microsoft.ML.Core/Prediction/TrainContext.cs +++ b/src/Microsoft.ML.Core/Prediction/TrainContext.cs @@ -37,21 +37,21 @@ public sealed class TrainContext /// /// Constructor, given a training set and optional other arguments. /// - /// Will set to this value. 
This must be specified - /// Will set to this value if specified - /// Will set to this value if specified - public TrainContext(RoleMappedData train, RoleMappedData valid = null, IPredictor initPredictor = null) + /// Will set to this value. This must be specified + /// Will set to this value if specified + /// Will set to this value if specified + public TrainContext(RoleMappedData trainingSet, RoleMappedData validationSet = null, IPredictor initialPredictor = null) { - Contracts.CheckValue(train, nameof(train)); - Contracts.CheckValueOrNull(valid); - Contracts.CheckValueOrNull(initPredictor); + Contracts.CheckValue(trainingSet, nameof(trainingSet)); + Contracts.CheckValueOrNull(validationSet); + Contracts.CheckValueOrNull(initialPredictor); // REVIEW: Should there be code here to ensure that the role mappings between the two are compatible? // That is, all the role mappings are the same and the columns between them have identical types? - TrainingSet = train; - ValidationSet = valid; - InitialPredictor = initPredictor; + TrainingSet = trainingSet; + ValidationSet = validationSet; + InitialPredictor = initialPredictor; } } } diff --git a/src/Microsoft.ML.LightGBM/LightGbmTrainerBase.cs b/src/Microsoft.ML.LightGBM/LightGbmTrainerBase.cs index 5f0877146d..83e0f7803b 100644 --- a/src/Microsoft.ML.LightGBM/LightGbmTrainerBase.cs +++ b/src/Microsoft.ML.LightGBM/LightGbmTrainerBase.cs @@ -51,7 +51,7 @@ private sealed class CategoricalMetaData private protected int FeatureCount; private protected FastTree.Internal.Ensemble TrainedEnsemble; - private static TrainerInfo _info = new TrainerInfo(normalization: false, caching: false, supportValid: true); + private static readonly TrainerInfo _info = new TrainerInfo(normalization: false, caching: false, supportValid: true); public override TrainerInfo Info => _info; private protected LightGbmTrainerBase(IHostEnvironment env, LightGbmArguments args, string name) diff --git a/src/Microsoft.ML.PCA/PcaTrainer.cs 
b/src/Microsoft.ML.PCA/PcaTrainer.cs index b06aec335a..bebaa49691 100644 --- a/src/Microsoft.ML.PCA/PcaTrainer.cs +++ b/src/Microsoft.ML.PCA/PcaTrainer.cs @@ -75,7 +75,10 @@ public class Arguments : UnsupervisedLearnerInputBaseWithWeight private readonly int _seed; public override PredictionKind PredictionKind => PredictionKind.AnomalyDetection; - public override TrainerInfo Info { get; } + + // The training performs two passes, only. Probably not worth caching. + private static readonly TrainerInfo _info = new TrainerInfo(caching: false); + public override TrainerInfo Info => _info; public RandomizedPcaTrainer(IHostEnvironment env, Arguments args) : base(env, LoadNameValue) @@ -84,8 +87,6 @@ public RandomizedPcaTrainer(IHostEnvironment env, Arguments args) Host.CheckUserArg(args.Rank > 0, nameof(args.Rank), "Rank must be positive"); Host.CheckUserArg(args.Oversampling >= 0, nameof(args.Oversampling), "Oversampling must be non-negative"); - // Two passes, only. Probably not worth caching. - Info = new TrainerInfo(caching: false); _rank = args.Rank; _center = args.Center; _oversampling = args.Oversampling; diff --git a/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationTrainer.cs b/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationTrainer.cs index 55742820e0..5ae866c5d7 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationTrainer.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationTrainer.cs @@ -49,7 +49,8 @@ public abstract class LinearTrainerBase : TrainerBase { protected bool NeedShuffle; - public override TrainerInfo Info { get; } + private static readonly TrainerInfo _info = new TrainerInfo(); + public override TrainerInfo Info => _info; /// /// Whether data is to be shuffled every epoch. 
@@ -59,7 +60,6 @@ public abstract class LinearTrainerBase : TrainerBase private protected LinearTrainerBase(IHostEnvironment env, string name) : base(env, name) { - Info = new TrainerInfo(); } public override TPredictor Train(TrainContext context) diff --git a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LbfgsPredictorBase.cs b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LbfgsPredictorBase.cs index 616a1cbe8d..537364a56a 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LbfgsPredictorBase.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LbfgsPredictorBase.cs @@ -132,7 +132,10 @@ public abstract class ArgumentsBase : LearnerInputBaseWithWeight private VBuffer[] _localGradients; private Float[] _localLosses; - public override TrainerInfo Info { get; } + // REVIEW: It's pointless to request caching when we're going to load everything into + // memory, that is, when using multiple threads. So should caching not be requested? + private static readonly TrainerInfo _info = new TrainerInfo(caching: true, supportIncrementalTrain: true); + public override TrainerInfo Info => _info; internal LbfgsTrainerBase(ArgumentsBase args, IHostEnvironment env, string name, bool showTrainingStats = false) : base(env, name) @@ -160,9 +163,6 @@ internal LbfgsTrainerBase(ArgumentsBase args, IHostEnvironment env, string name, DenseOptimizer = args.DenseOptimizer; ShowTrainingStats = showTrainingStats; EnforceNonNegativity = args.EnforceNonNegativity; - // REVIEW: It's pointless to request caching when we're going to load everything into - // memory, that is, when using multiple threads. So should caching not be requested? 
- Info = new TrainerInfo(caching: true, supportIncrementalTrain: true); if (EnforceNonNegativity && ShowTrainingStats) { diff --git a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MultiClassNaiveBayesTrainer.cs b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MultiClassNaiveBayesTrainer.cs index 67762d935d..94bac1d9ec 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MultiClassNaiveBayesTrainer.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MultiClassNaiveBayesTrainer.cs @@ -45,13 +45,13 @@ public sealed class Arguments : LearnerInputBaseWithLabel public override PredictionKind PredictionKind => PredictionKind.MultiClassClassification; - public override TrainerInfo Info { get; } + private static readonly TrainerInfo _info = new TrainerInfo(normalization: false, caching: false); + public override TrainerInfo Info => _info; public MultiClassNaiveBayesTrainer(IHostEnvironment env, Arguments args) : base(env, LoadName) { Host.CheckValue(args, nameof(args)); - Info = new TrainerInfo(normalization: false, caching: false); } public override MultiClassNaiveBayesPredictor Train(TrainContext context) diff --git a/src/Microsoft.ML.StandardLearners/Standard/OlsLinearRegression.cs b/src/Microsoft.ML.StandardLearners/Standard/OlsLinearRegression.cs index 95005f7d39..7f47271f68 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/OlsLinearRegression.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/OlsLinearRegression.cs @@ -61,7 +61,10 @@ It assumes that the conditional mean of the dependent variable follows a linear private readonly bool _perParameterSignificance; public override PredictionKind PredictionKind => PredictionKind.Regression; - public override TrainerInfo Info { get; } + + // The training performs two passes, only. Probably not worth caching. 
+ private static readonly TrainerInfo _info = new TrainerInfo(caching: false); + public override TrainerInfo Info => _info; public OlsLinearRegressionTrainer(IHostEnvironment env, Arguments args) : base(env, LoadNameValue) @@ -70,8 +73,6 @@ public OlsLinearRegressionTrainer(IHostEnvironment env, Arguments args) Host.CheckUserArg(args.L2Weight >= 0, nameof(args.L2Weight), "L2 regularization term cannot be negative"); _l2Weight = args.L2Weight; _perParameterSignificance = args.PerParameterSignificance; - // Two passes, only. Probably not worth caching. - Info = new TrainerInfo(caching: false); } /// diff --git a/src/Microsoft.ML.StandardLearners/Standard/Simple/SimpleTrainers.cs b/src/Microsoft.ML.StandardLearners/Standard/Simple/SimpleTrainers.cs index 7e81b0ba27..abfff554c9 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/Simple/SimpleTrainers.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/Simple/SimpleTrainers.cs @@ -55,13 +55,14 @@ public class Arguments } public override PredictionKind PredictionKind => PredictionKind.BinaryClassification; - public override TrainerInfo Info { get; } + + private static readonly TrainerInfo _info = new TrainerInfo(normalization: false, caching: false); + public override TrainerInfo Info => _info; public RandomTrainer(IHostEnvironment env, Arguments args) : base(env, LoadNameValue) { Host.CheckValue(args, nameof(args)); - Info = new TrainerInfo(normalization: false, caching: false); } public override RandomPredictor Train(TrainContext context) @@ -205,13 +206,14 @@ public sealed class Arguments } public override PredictionKind PredictionKind => PredictionKind.BinaryClassification; - public override TrainerInfo Info { get; } + + private static readonly TrainerInfo _info = new TrainerInfo(normalization: false, caching: false); + public override TrainerInfo Info => _info; public PriorTrainer(IHostEnvironment env, Arguments args) : base(env, LoadNameValue) { Host.CheckValue(args, nameof(args)); - Info = new 
TrainerInfo(normalization: false, caching: false); } public override PriorPredictor Train(TrainContext context) From 9d7eb02e684c09ff5e79919562329d9356546f95 Mon Sep 17 00:00:00 2001 From: Tom Finley Date: Fri, 13 Jul 2018 10:18:46 -0700 Subject: [PATCH 10/13] Remove `IMetaLinearTrainer` --- src/Microsoft.ML.Core/Prediction/ITrainer.cs | 8 -------- .../ExperimentsGenerator.cs | 3 +-- 2 files changed, 1 insertion(+), 10 deletions(-) diff --git a/src/Microsoft.ML.Core/Prediction/ITrainer.cs b/src/Microsoft.ML.Core/Prediction/ITrainer.cs index 31a45d7183..b38a742d9a 100644 --- a/src/Microsoft.ML.Core/Prediction/ITrainer.cs +++ b/src/Microsoft.ML.Core/Prediction/ITrainer.cs @@ -101,12 +101,4 @@ public interface IModelCombiner { TPredictor CombineModels(IEnumerable models); } - - /// - /// Interface implemented by the MetalinearLearners base class. - /// Used to distinguish the MetaLinear Learners from the other learners - /// - public interface IMetaLinearTrainer - { - } } diff --git a/src/Microsoft.ML.PipelineInference/ExperimentsGenerator.cs b/src/Microsoft.ML.PipelineInference/ExperimentsGenerator.cs index 3c7441c87f..6b184e7b61 100644 --- a/src/Microsoft.ML.PipelineInference/ExperimentsGenerator.cs +++ b/src/Microsoft.ML.PipelineInference/ExperimentsGenerator.cs @@ -110,8 +110,7 @@ public static List GenerateCandidates(IHostEnvironment env, string dataFi //get all the trainers for this task, and generate the initial set of candidates. // Exclude the hidden learners, and the metalinear learners. 
- var trainers = ComponentCatalog.GetAllDerivedClasses(typeof(ITrainer), predictorType) - .Where(cls => !cls.IsHidden && !typeof(IMetaLinearTrainer).IsAssignableFrom(cls.Type)); + var trainers = ComponentCatalog.GetAllDerivedClasses(typeof(ITrainer), predictorType).Where(cls => !cls.IsHidden); var loaderSubComponent = new SubComponent("TextLoader", loaderSettings); string loader = $" loader={loaderSubComponent}"; From 9d34675f71f750409804785aa179c64ff8fb8091 Mon Sep 17 00:00:00 2001 From: Tom Finley Date: Tue, 17 Jul 2018 09:50:53 -0700 Subject: [PATCH 11/13] Make FT Test.cs just a tiny bit less terrible --- src/Microsoft.ML.FastTree/Training/Test.cs | 35 +++++++++------------- 1 file changed, 14 insertions(+), 21 deletions(-) diff --git a/src/Microsoft.ML.FastTree/Training/Test.cs b/src/Microsoft.ML.FastTree/Training/Test.cs index f1a264a5e4..c0c54d7bb9 100644 --- a/src/Microsoft.ML.FastTree/Training/Test.cs +++ b/src/Microsoft.ML.FastTree/Training/Test.cs @@ -10,10 +10,8 @@ namespace Microsoft.ML.Runtime.FastTree.Internal { - public class TestResult : IComparable + public sealed class TestResult : IComparable { - private double _finalValue; - public enum ValueOperator : int { None = 0, // the final value will be the raw value, @@ -36,33 +34,31 @@ public enum ValueOperator : int // the raw value should be the same constant for all test results. } - public string LossFunctionName { get; private set; } + public string LossFunctionName { get; } /// /// Raw value used for calculating final test result value. /// - public double RawValue { get; private set; } + public double RawValue { get; } /// /// The factor used for calculating final test result value. /// - public double Factor { get; private set; } + public double Factor { get; } /// /// The operator used for calculating final test result value. 
/// Final value = Operator(RawValue, Factor) /// - public ValueOperator Operator { get; private set; } + public ValueOperator Operator { get; } /// /// Indicates that the lower value of this metric is better /// This is used for early stopping (with TestHistory and TestWindowWithTolerance) /// - public bool LowerIsBetter { get; private set; } + public bool LowerIsBetter { get; } - public double FinalValue { - get { return _finalValue; } - } + public double FinalValue { get; } public TestResult(string lossFunctionName, double rawValue, double factor, bool lowerIsBetter, ValueOperator valueOperator) { @@ -72,7 +68,7 @@ public TestResult(string lossFunctionName, double rawValue, double factor, bool Operator = valueOperator; LowerIsBetter = lowerIsBetter; - CalculateFinalValue(); + FinalValue = CalculateFinalValue(); } public int CompareTo(TestResult o) @@ -124,7 +120,7 @@ public static TestResult FromByteArray(byte[] buffer, ref int offset) (ValueOperator)valueOperator); } - private void CalculateFinalValue() + private double CalculateFinalValue() { switch (Operator) { @@ -133,14 +129,11 @@ private void CalculateFinalValue() case ValueOperator.Min: case ValueOperator.None: case ValueOperator.Sum: - _finalValue = RawValue; - break; + return RawValue; case ValueOperator.Average: - _finalValue = RawValue / Factor; - break; + return RawValue / Factor; case ValueOperator.SqrtAverage: - _finalValue = Math.Sqrt(RawValue / Factor); - break; + return Math.Sqrt(RawValue / Factor); default: throw Contracts.Except("Unsupported value operator: {0}", Operator); } @@ -157,7 +150,7 @@ public abstract class Test //The method returns one or more losses on a given Dataset public abstract IEnumerable ComputeTests(double[] scores); - public Test(ScoreTracker scoreTracker) + private protected Test(ScoreTracker scoreTracker) { ScoreTracker = scoreTracker; if (ScoreTracker != null) @@ -213,7 +206,7 @@ public class TestHistory : Test // scenarioWithoutHistory - simple test scenario we want to 
track the history and look for best iteration // lossIndex - index of lossFunction in case Test returns more than one loss (default should be 0) // lower is better: are we looking for minimum or maximum of loss function? - public TestHistory(Test scenarioWithoutHistory, int lossIndex) + private protected TestHistory(Test scenarioWithoutHistory, int lossIndex) : base(null) { History = new List(); From 63ec9fa3127c88de6c444842427730bfd304d74f Mon Sep 17 00:00:00 2001 From: Tom Finley Date: Tue, 17 Jul 2018 09:51:30 -0700 Subject: [PATCH 12/13] TrainerInfo naming comment removed --- src/Microsoft.ML.Core/Prediction/TrainerInfo.cs | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/Microsoft.ML.Core/Prediction/TrainerInfo.cs b/src/Microsoft.ML.Core/Prediction/TrainerInfo.cs index 0514285974..cce728e09a 100644 --- a/src/Microsoft.ML.Core/Prediction/TrainerInfo.cs +++ b/src/Microsoft.ML.Core/Prediction/TrainerInfo.cs @@ -16,10 +16,6 @@ public sealed class TrainerInfo // etc. This interface seems like the most natural conduit for that sort // of extra information. - // REVIEW: Can we please have consistent naming here? - // 'Need' vs. 'Want' looks arbitrary to me, and it's grammatically more correct to - // be 'Needs' / 'Wants' anyway. - /// /// Whether the trainer needs to see data in normalized form. Only non-parametric learners will tend to produce /// normalization here. 
From 837733a21f0b460bfdc1c98927020057496ef279 Mon Sep 17 00:00:00 2001 From: Tom Finley Date: Tue, 17 Jul 2018 10:14:22 -0700 Subject: [PATCH 13/13] Do you even build bro --- src/Microsoft.ML.FastTree/Training/Test.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Microsoft.ML.FastTree/Training/Test.cs b/src/Microsoft.ML.FastTree/Training/Test.cs index c0c54d7bb9..4e72f46372 100644 --- a/src/Microsoft.ML.FastTree/Training/Test.cs +++ b/src/Microsoft.ML.FastTree/Training/Test.cs @@ -206,7 +206,7 @@ public class TestHistory : Test // scenarioWithoutHistory - simple test scenario we want to track the history and look for best iteration // lossIndex - index of lossFunction in case Test returns more than one loss (default should be 0) // lower is better: are we looking for minimum or maximum of loss function? - private protected TestHistory(Test scenarioWithoutHistory, int lossIndex) + internal TestHistory(Test scenarioWithoutHistory, int lossIndex) : base(null) { History = new List();