Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions src/Microsoft.ML.Api/ComponentCreation.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
using System.IO;
using Microsoft.ML.Runtime.CommandLine;
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.EntryPoints;
using Microsoft.ML.Runtime.Model;

namespace Microsoft.ML.Runtime.Api
Expand Down Expand Up @@ -304,12 +305,20 @@ public static IDataScorerTransform CreateScorer(this IHostEnvironment env, strin
env.CheckValue(predictor, nameof(predictor));
env.CheckValueOrNull(trainSchema);

var subComponent = SubComponent.Parse<IDataScorerTransform, SignatureDataScorer>(settings);
var bindable = ScoreUtils.GetSchemaBindableMapper(env, predictor.Pred, subComponent);
ICommandLineComponentFactory scorerFactorySettings = ParseScorerSettings(settings);
var bindable = ScoreUtils.GetSchemaBindableMapper(env, predictor.Pred, scorerFactorySettings: scorerFactorySettings);
var mapper = bindable.Bind(env, data.Schema);
return CreateCore<IDataScorerTransform, SignatureDataScorer>(env, settings, data.Data, mapper, trainSchema);
}

/// <summary>
/// Builds the command-line component factory for a data scorer from its settings string,
/// delegating to <see cref="CmdParser.CreateComponentFactory"/> with the scorer factory
/// interface type and the <see cref="SignatureDataScorer"/> signature type.
/// </summary>
/// <param name="settings">The scorer settings string, passed through unchanged to the parser.</param>
/// <returns>The factory produced by <see cref="CmdParser.CreateComponentFactory"/>.</returns>
private static ICommandLineComponentFactory ParseScorerSettings(string settings)
{
    // The factory interface the resulting scorer component must satisfy.
    var scorerFactoryType = typeof(IComponentFactory<IDataView, ISchemaBoundMapper, RoleMappedSchema, IDataScorerTransform>);

    // The signature type identifying which kind of component is being requested.
    var scorerSignatureType = typeof(SignatureDataScorer);

    return CmdParser.CreateComponentFactory(scorerFactoryType, scorerSignatureType, settings);
}

/// <summary>
/// Creates a default data scorer appropriate to the predictor's prediction kind.
/// </summary>
Expand Down
361 changes: 240 additions & 121 deletions src/Microsoft.ML.Core/CommandLine/CmdParser.cs

Large diffs are not rendered by default.

28 changes: 28 additions & 0 deletions src/Microsoft.ML.Core/EntryPoints/ComponentFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -64,4 +64,32 @@ public interface IComponentFactory<in TArg1, in TArg2, out TComponent> : ICompon
{
TComponent CreateComponent(IHostEnvironment env, TArg1 argument1, TArg2 argument2);
}

/// <summary>
/// An interface for creating a component when we take three extra parameters (and an <see cref="IHostEnvironment"/>).
/// </summary>
/// <typeparam name="TArg1">The type of the first extra argument.</typeparam>
/// <typeparam name="TArg2">The type of the second extra argument.</typeparam>
/// <typeparam name="TArg3">The type of the third extra argument.</typeparam>
/// <typeparam name="TComponent">The type of the component that is created.</typeparam>
public interface IComponentFactory<in TArg1, in TArg2, in TArg3, out TComponent> : IComponentFactory
{
/// <summary>
/// Creates the component from the host environment and the three extra arguments.
/// </summary>
/// <param name="env">The host environment supplied to the component being created.</param>
/// <param name="argument1">The first extra argument.</param>
/// <param name="argument2">The second extra argument.</param>
/// <param name="argument3">The third extra argument.</param>
/// <returns>The created component.</returns>
TComponent CreateComponent(IHostEnvironment env, TArg1 argument1, TArg2 argument2, TArg3 argument3);
}

/// <summary>
/// A class for creating a component when we take three extra parameters
/// (and an <see cref="IHostEnvironment"/>) that simply wraps a delegate which
/// creates the component.
/// </summary>
public class SimpleComponentFactory<TArg1, TArg2, TArg3, TComponent> : IComponentFactory<TArg1, TArg2, TArg3, TComponent>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

SimpleComponentFactory [](start = 17, length = 22)

Thanks @eerhardt. As we discussed in #681 perhaps we can hide these implementations of the interface, but we'll do that after we merge this I suppose.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I will take care of all the SimpleComponentFactory classes being hidden in #681.

{
// The delegate that performs the actual component creation. It is assigned exactly once,
// in the constructor, so it is marked readonly.
private readonly Func<IHostEnvironment, TArg1, TArg2, TArg3, TComponent> _factory;

/// <summary>
/// Initializes the factory with the delegate that creates the component.
/// </summary>
/// <param name="factory">The delegate invoked by <see cref="CreateComponent"/>.</param>
/// <exception cref="ArgumentNullException"><paramref name="factory"/> is null.</exception>
public SimpleComponentFactory(Func<IHostEnvironment, TArg1, TArg2, TArg3, TComponent> factory)
{
    // Fail at construction time rather than with a NullReferenceException
    // on the first call to CreateComponent.
    if (factory == null)
    {
        throw new ArgumentNullException(nameof(factory));
    }

    _factory = factory;
}

/// <summary>
/// Creates the component by forwarding all arguments to the wrapped delegate.
/// </summary>
/// <param name="env">The host environment passed to the delegate.</param>
/// <param name="argument1">The first extra argument passed to the delegate.</param>
/// <param name="argument2">The second extra argument passed to the delegate.</param>
/// <param name="argument3">The third extra argument passed to the delegate.</param>
/// <returns>The value returned by the wrapped delegate.</returns>
public TComponent CreateComponent(IHostEnvironment env, TArg1 argument1, TArg2 argument2, TArg3 argument3)
{
    return _factory(env, argument1, argument2, argument3);
}
}
}
21 changes: 11 additions & 10 deletions src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ public sealed class Arguments : DataCommand.ArgumentsBase
[Argument(ArgumentType.Multiple, HelpText = "Trainer to use", ShortName = "tr")]
public SubComponent<ITrainer, SignatureTrainer> Trainer = new SubComponent<ITrainer, SignatureTrainer>("AveragedPerceptron");

[Argument(ArgumentType.Multiple, HelpText = "Scorer to use", NullName = "<Auto>", SortOrder = 101)]
public SubComponent<IDataScorerTransform, SignatureDataScorer> Scorer;
[Argument(ArgumentType.Multiple, HelpText = "Scorer to use", NullName = "<Auto>", SortOrder = 101, SignatureType = typeof(SignatureDataScorer))]
Copy link
Contributor

@Zruty0 Zruty0 Aug 9, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

SignatureDataScorer [](start = 134, length = 19)

I wonder if the information contained in SignatureType is identical to the information contained in typeof(Scorer) ?
As in, aren't these the type parameters you want to invoke with? #Closed

Copy link
Member Author

@eerhardt eerhardt Aug 14, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The parameter information definitely overlaps, but this isn't duplicate information.

For example, if there are two signature delegates:

public delegate void SignatureSweeper();
public delegate void SignatureTrainer();

We can't discern which one we are trying to load from the component catalog based solely on the signature (neither of these signature delegates takes any parameters other than the Environment). #Closed

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, I see. So the signature is needed for subselecting a list.


In reply to: 210097300 [](ancestors = 210097300)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The signature is needed in order to find the component in the ComponentCatalog. We can have multiple components that share the same name but differ in their "signature type", so the components are keyed off both a name and a signature type:

/// <summary>
/// Used for dictionary lookup based on signature and name.
/// </summary>
internal struct Key : IEquatable<Key>
{
public readonly string Name;
public readonly Type Signature;

private static LoadableClassInfo FindClassCore(LoadableClassInfo.Key key)

public IComponentFactory<IDataView, ISchemaBoundMapper, RoleMappedSchema, IDataScorerTransform> Scorer;

[Argument(ArgumentType.Multiple, HelpText = "Evaluator to use", ShortName = "eval", NullName = "<Auto>", SortOrder = 102)]
public SubComponent<IMamlEvaluator, SignatureMamlEvaluator> Evaluator;
Expand Down Expand Up @@ -76,8 +76,8 @@ public sealed class Arguments : DataCommand.ArgumentsBase
[Argument(ArgumentType.AtMostOnce, IsInputFileName = true, HelpText = "The validation data file", ShortName = "valid")]
public string ValidationFile;

[Argument(ArgumentType.Multiple, HelpText = "Output calibrator", ShortName = "cali", NullName = "<None>")]
public SubComponent<ICalibratorTrainer, SignatureCalibrator> Calibrator = new SubComponent<ICalibratorTrainer, SignatureCalibrator>("PlattCalibration");
[Argument(ArgumentType.Multiple, HelpText = "Output calibrator", ShortName = "cali", NullName = "<None>", SignatureType = typeof(SignatureCalibrator))]
public IComponentFactory<ICalibratorTrainer> Calibrator = new PlattCalibratorTrainerFactory();

[Argument(ArgumentType.LastOccurenceWins, HelpText = "Number of instances to train the calibrator", ShortName = "numcali")]
public int MaxCalibrationExamples = 1000000000;
Expand Down Expand Up @@ -383,9 +383,9 @@ public FoldResult(Dictionary<string, IDataView> metrics, ISchema scoreSchema, Ro
private readonly string _splitColumn;
private readonly int _numFolds;
private readonly SubComponent<ITrainer, SignatureTrainer> _trainer;
private readonly SubComponent<IDataScorerTransform, SignatureDataScorer> _scorer;
private readonly IComponentFactory<IDataView, ISchemaBoundMapper, RoleMappedSchema, IDataScorerTransform> _scorer;
private readonly SubComponent<IMamlEvaluator, SignatureMamlEvaluator> _evaluator;
private readonly SubComponent<ICalibratorTrainer, SignatureCalibrator> _calibrator;
private readonly IComponentFactory<ICalibratorTrainer> _calibrator;
private readonly int _maxCalibrationExamples;
private readonly bool _useThreads;
private readonly bool? _cacheData;
Expand Down Expand Up @@ -423,7 +423,7 @@ public FoldHelper(
Arguments args,
Func<IHostEnvironment, IChannel, IDataView, ITrainer, RoleMappedData> createExamples,
Func<IHostEnvironment, IChannel, IDataView, RoleMappedData, IDataView, RoleMappedData> applyTransformsToTestData,
SubComponent<IDataScorerTransform, SignatureDataScorer> scorer,
IComponentFactory<IDataView, ISchemaBoundMapper, RoleMappedSchema, IDataScorerTransform> scorer,
SubComponent<IMamlEvaluator, SignatureMamlEvaluator> evaluator,
Func<IDataView> getValidationDataView = null,
Func<IHostEnvironment, IChannel, IDataView, RoleMappedData, IDataView, RoleMappedData> applyTransformsToValidationData = null,
Expand Down Expand Up @@ -559,11 +559,12 @@ private FoldResult RunFold(int fold)

// Score.
ch.Trace("Scoring and evaluating");
var bindable = ScoreUtils.GetSchemaBindableMapper(host, predictor, _scorer);
ch.Assert(_scorer == null || _scorer is ICommandLineComponentFactory, "CrossValidationCommand should only be used from the command line.");
var bindable = ScoreUtils.GetSchemaBindableMapper(host, predictor, scorerFactorySettings: _scorer as ICommandLineComponentFactory);
ch.AssertValue(bindable);
var mapper = bindable.Bind(host, testData.Schema);
var scorerComp = _scorer.IsGood() ? _scorer : ScoreUtils.GetScorerComponent(mapper);
IDataScorerTransform scorePipe = scorerComp.CreateInstance(host, testData.Data, mapper, trainData.Schema);
var scorerComp = _scorer ?? ScoreUtils.GetScorerComponent(mapper);
IDataScorerTransform scorePipe = scorerComp.CreateComponent(host, testData.Data, mapper, trainData.Schema);

// Save per-fold model.
string modelFileName = ConstructPerFoldName(_outputModelFile, fold);
Expand Down
Loading