diff --git a/docs/code/MlNetCookBook.md b/docs/code/MlNetCookBook.md
index 716e8fac12..727ed415e2 100644
--- a/docs/code/MlNetCookBook.md
+++ b/docs/code/MlNetCookBook.md
@@ -383,19 +383,18 @@ var metrics = mlContext.Regression.Evaluate(model.Transform(testData), labelColu
Assuming that the model metrics look good to you, it's time to 'operationalize' the model. This is where ML.NET really shines: the `model` object you just built is ready for immediate consumption. It applies all the same steps it 'learned' during training, and it can be persisted and reused in different environments.
-Here's what you do to save the model to a file, and reload it (potentially in a different context).
+Here's how to save the model, along with its input schema, to a file and reload it (potentially in a different context).
```csharp
-using (var stream = File.Create(modelPath))
-{
- mlContext.Model.Save(model, stream);
-}
+// Saving and loading happens to transformers. We save the input schema with this model.
+mlContext.Model.Save(model, trainData.Schema, modelPath);
// Potentially, the lines below can be in a different process altogether.
-ITransformer loadedModel;
-using (var stream = File.OpenRead(modelPath))
- loadedModel = mlContext.Model.Load(stream);
+// When you load the model, it's a non-specific ITransformer. We also recover
+// the original schema.
+ITransformer loadedModel = mlContext.Model.Load(modelPath, out var schema);
```
+
## How do I use the model to make one prediction?
Since any ML.NET model is a transformer, you can of course use `model.Transform` to apply the model to the 'data view' and obtain predictions this way.
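
For example, here is a minimal sketch of batch scoring with the reloaded model, assuming `testData` is an `IDataView` whose schema is compatible with what the model was trained on:

```csharp
// A minimal sketch of applying the reloaded model to a batch of data.
// testData is assumed to be an IDataView matching the model's expected input.
IDataView predictions = loadedModel.Transform(testData);
```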
@@ -1018,7 +1017,5 @@ using (var fs = File.Create(modelPath))
newContext.ComponentCatalog.RegisterAssembly(typeof(CustomMappings).Assembly);
// Now we can load the model.
-ITransformer loadedModel;
-using (var fs = File.OpenRead(modelPath))
- loadedModel = newContext.Model.Load(fs);
+ITransformer loadedModel = newContext.Model.Load(modelPath, out var schema);
```
diff --git a/docs/code/experimental/MlNetCookBookStaticApi.md b/docs/code/experimental/MlNetCookBookStaticApi.md
index fcdb2c45ae..086e3b8e3b 100644
--- a/docs/code/experimental/MlNetCookBookStaticApi.md
+++ b/docs/code/experimental/MlNetCookBookStaticApi.md
@@ -396,18 +396,13 @@ This is where ML.NET really shines: the `model` object you just built is ready f
Here's what you do to save the model to a file, and reload it (potentially in a different context).
```csharp
-using (var stream = File.Create(modelPath))
-{
- // Saving and loading happens to 'dynamic' models, so the static typing is lost in the process.
- mlContext.Model.Save(model.AsDynamic, stream);
-}
+// Saving and loading happens to 'dynamic' models, so the static typing is lost in the process.
+mlContext.Model.Save(model.AsDynamic, trainData.AsDynamic.Schema, modelPath);
// Potentially, the lines below can be in a different process altogether.
// When you load the model, it's a 'dynamic' transformer.
-ITransformer loadedModel;
-using (var stream = File.OpenRead(modelPath))
- loadedModel = mlContext.Model.Load(stream);
+ITransformer loadedModel = mlContext.Model.Load(modelPath, out var schema);
```
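
The recovered `schema` describes the input the model expects. As a quick sanity check, you can enumerate its columns (a sketch; the actual columns depend on your pipeline):

```csharp
// A small sketch: list the columns of the recovered input schema.
foreach (var column in schema)
    Console.WriteLine($"{column.Name}: {column.Type}");
```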
## How do I use the model to make one prediction?
diff --git a/src/Microsoft.ML.Data/Model/ModelOperationsCatalog.cs b/src/Microsoft.ML.Data/Model/ModelOperationsCatalog.cs
index dc4cb551d0..311fad8888 100644
--- a/src/Microsoft.ML.Data/Model/ModelOperationsCatalog.cs
+++ b/src/Microsoft.ML.Data/Model/ModelOperationsCatalog.cs
@@ -29,63 +29,60 @@ internal ModelOperationsCatalog(IHostEnvironment env)
}
///
- /// Save the model to the stream.
+ /// Save a transformer model and the loader used to create its input data to the stream.
///
- /// The trained model to be saved.
+ /// The trained model to be saved. Note that this can be null, as a shorthand
+ /// for an empty transformer chain. Upon loading, the returned value will then
+ /// be an empty TransformerChain.
+ /// The loader that was used to create data to train the model.
/// A writeable, seekable stream to save to.
- public void Save(IDataLoader<IMultiStreamSource> model, Stream stream)
+ public void Save(ITransformer model, IDataLoader<IMultiStreamSource> loader, Stream stream)
{
- _env.CheckValue(model, nameof(model));
+ _env.CheckValue(loader, nameof(loader));
+ _env.CheckValueOrNull(model);
_env.CheckValue(stream, nameof(stream));
+ // For the sake of consistency of this API specifically, we wrap any transformer we are asked
+ // to save in a single-element transformer chain.
+ var chainedModel = model == null ? null : new TransformerChain<ITransformer>(model);
+ var compositeLoader = new CompositeDataLoader<IMultiStreamSource, ITransformer>(loader, chainedModel);
+
using (var rep = RepositoryWriter.CreateNew(stream))
{
- ModelSaveContext.SaveModel(rep, model, null);
+ ModelSaveContext.SaveModel(rep, compositeLoader, null);
rep.Commit();
}
}
- ///
- /// Save the model to the file.
- ///
- /// The trained model to be saved.
- /// Path where model should be saved.
- public void Save(IDataLoader<IMultiStreamSource> model, string filePath)
- {
- using (var stream = File.Create(filePath))
- Save(model, stream);
- }
-
- ///
- /// Save a transformer model and the loader used to create its input data to the stream.
- ///
- /// The loader that was used to create data to train the model
- /// The trained model to be saved
- /// A writeable, seekable stream to save to.
- public void Save(IDataLoader<IMultiStreamSource> loader, ITransformer model, Stream stream) =>
- Save(new CompositeDataLoader<IMultiStreamSource, ITransformer>(loader, new TransformerChain<ITransformer>(model)), stream);
-
///
/// Save a transformer model and the loader used to create its input data to the file.
///
- /// The loader that was used to create data to train the model
- /// The trained model to be saved
+ /// The trained model to be saved. Note that this can be null, as a shorthand
+ /// for an empty transformer chain. Upon loading, the returned value will then
+ /// be an empty TransformerChain.
+ /// The loader that was used to create data to train the model.
/// Path where model should be saved.
- public void Save(IDataLoader<IMultiStreamSource> loader, ITransformer model, string filePath)
+ public void Save(ITransformer model, IDataLoader<IMultiStreamSource> loader, string filePath)
{
+ _env.CheckValueOrNull(model);
+ _env.CheckValue(loader, nameof(loader));
+ _env.CheckNonEmpty(filePath, nameof(filePath));
+
using (var stream = File.Create(filePath))
- Save(loader, model, stream);
+ Save(model, loader, stream);
}
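
For illustration, a minimal usage sketch of these two `Save` overloads (the `mlContext`, `model`, and `loader` variables and the file names are assumed to exist in the caller's training code):

```csharp
// A usage sketch. mlContext, model and loader are assumed to come from the caller's code.
// Save the transformer together with the loader that produced its training data.
mlContext.Model.Save(model, loader, "model-and-loader.zip");

// Passing null for the model is shorthand for an empty transformer chain,
// i.e. saving only the loader.
mlContext.Model.Save(null, loader, "loader-only.zip");
```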
///
/// Save a transformer model and the schema of the data that was used to train it to the stream.
///
- /// The trained model to be saved.
- /// The schema of the input to the transformer. This can be null.
+ /// The trained model to be saved. Note that this can be null, as a shorthand
+ /// for an empty transformer chain. Upon loading, the returned value will then
+ /// be an empty TransformerChain.
+ /// The schema of the input to the transformer. This can be null.
/// A writeable, seekable stream to save to.
public void Save(ITransformer model, DataViewSchema inputSchema, Stream stream)
{
- _env.CheckValue(model, nameof(model));
+ _env.CheckValueOrNull(model);
_env.CheckValueOrNull(inputSchema);
_env.CheckValue(stream, nameof(stream));
@@ -100,11 +97,17 @@ public void Save(ITransformer model, DataViewSchema inputSchema, Stream stream)
///
/// Save a transformer model and the schema of the data that was used to train it to the file.
///
- /// The trained model to be saved.
- /// The schema of the input to the transformer. This can be null.
+ /// The trained model to be saved. Note that this can be null, as a shorthand
+ /// for an empty transformer chain. Upon loading, the returned value will then
+ /// be an empty TransformerChain.
+ /// The schema of the input to the transformer. This can be null.
/// Path where model should be saved.
public void Save(ITransformer model, DataViewSchema inputSchema, string filePath)
{
+ _env.CheckValueOrNull(model);
+ _env.CheckValueOrNull(inputSchema);
+ _env.CheckNonEmpty(filePath, nameof(filePath));
+
using (var stream = File.Create(filePath))
Save(model, inputSchema, stream);
}
@@ -126,11 +129,11 @@ private void SaveInputSchema(DataViewSchema inputSchema, RepositoryWriter rep)
}
///
- /// Load the model and its input schema from the stream.
+ /// Load the model and its input schema from a stream.
///
/// A readable, seekable stream to load from.
- /// Will contain the input schema for the model. If the model was saved using older APIs
- /// it may not contain an input schema, in this case it will be null.
+ /// Will contain the input schema for the model. If the model was saved without
+ /// any description of the input, there will be no input schema, and this parameter will be null.
/// The loaded model.
public ITransformer Load(Stream stream, out DataViewSchema inputSchema)
{
@@ -171,23 +174,67 @@ public ITransformer Load(Stream stream, out DataViewSchema inputSchema)
throw _env.Except(ex, "Could not load legacy format model");
}
}
- if (dataLoader is CompositeDataLoader<IMultiStreamSource, ITransformer> composite)
- {
- inputSchema = composite.Loader.GetOutputSchema();
- return composite.Transformer;
- }
+ var transformer = DecomposeLoader(ref dataLoader);
inputSchema = dataLoader.GetOutputSchema();
- return new TransformerChain<ITransformer>();
+ return transformer;
+ }
+ }
+
+ ///
+ /// Load the model and its input schema from a file.
+ ///
+ /// Path to a file where the model should be read from.
+ /// Will contain the input schema for the model. If the model was saved without
+ /// any description of the input, there will be no input schema, and this parameter will be null.
+ /// The loaded model.
+ public ITransformer Load(string filePath, out DataViewSchema inputSchema)
+ {
+ _env.CheckNonEmpty(filePath, nameof(filePath));
+
+ using (var stream = File.OpenRead(filePath))
+ return Load(stream, out inputSchema);
+ }
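
A short usage sketch of this file-based overload ("model.zip" is a placeholder path, and `mlContext` is assumed to be an existing `MLContext`):

```csharp
// A usage sketch: reload a model from disk and check whether an input schema was stored.
ITransformer loadedModel = mlContext.Model.Load("model.zip", out DataViewSchema inputSchema);
if (inputSchema == null)
    Console.WriteLine("No input schema was stored with this model.");
```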
+
+ ///
+ /// Given a loader, try to "decompose" it into a source loader and its transform, if any.
+ /// If necessary, an empty chain will be created to stand in for the trivial transformation; this
+ /// method should never return null.
+ ///
+ private ITransformer DecomposeLoader(ref IDataLoader<IMultiStreamSource> loader)
+ {
+ _env.AssertValue(loader);
+
+ if (loader is CompositeDataLoader<IMultiStreamSource, ITransformer> composite)
+ {
+ loader = composite.Loader;
+ var chain = composite.Transformer;
+ // The save method corresponding to this load method wraps the input ITransformer in a
+ // single-element transformer chain. If the chain found here has exactly one element, we
+ // assume it was saved that way and return that single transformer.
+ var accessor = (ITransformerChainAccessor)chain;
+ if (accessor.Transformers.Length == 1)
+ return accessor.Transformers[0];
+ // If the chain has some length other than one due to, say, legacy model saving, just return
+ // the chain itself. That cannot happen with the save method above, since the chain it saves
+ // always has length one, but older APIs behaved differently, so we retain flexibility for
+ // those formats. (Those formats are by no means incorrect; they just are not what the API in
+ // this particular class produces.)
+ return chain;
}
+ // Maybe we have no transformer stored. Rather than return null, we prefer to return the
+ // empty "trivial" transformer chain.
return new TransformerChain<ITransformer>();
}
///
- /// Load the model and its input schema from the stream.
+ /// Load a transformer model and a data loader model from a stream.
///
/// A readable, seekable stream to load from.
- /// A model of type CompositeDataLoader containing the loader
- /// and the transformer chain.
- public IDataLoader<IMultiStreamSource> Load(Stream stream)
+ /// The data loader from the model stream. Note that if there is no data loader,
+ /// this method will throw an exception. The scenario where no loader is stored in the stream should
+ /// be handled instead using the Load method.
+ /// The transformer model from the model stream.
+ public ITransformer LoadWithDataLoader(Stream stream, out IDataLoader<IMultiStreamSource> loader)
{
_env.CheckValue(stream, nameof(stream));
@@ -195,33 +242,32 @@ public IDataLoader Load(Stream stream)
{
try
{
- ModelLoadContext.LoadModel<IDataLoader<IMultiStreamSource>, SignatureLoadModel>(_env, out var model, rep, null);
- return model;
+ ModelLoadContext.LoadModel<IDataLoader<IMultiStreamSource>, SignatureLoadModel>(_env, out loader, rep, null);
+ return DecomposeLoader(ref loader);
}
catch (Exception ex)
{
- throw _env.Except(ex, "Model does not contain an IDataLoader");
+ throw _env.Except(ex, "Model does not contain an " + nameof(IDataLoader) +
+ ". Perhaps this was saved with an " + nameof(DataViewSchema) + ", or even no information on its input at all. " +
+ "Consider using the " + nameof(Load) + " method instead.");
}
}
}
///
- /// Load a transformer model and a data loader model from the stream.
+ /// Load a transformer model and a data loader model from a file.
///
- /// A readable, seekable stream to load from.
- /// The data loader from the model stream.
- /// The transformer model from the model stream.
- public ITransformer LoadWithDataLoader(Stream stream, out IDataLoader<IMultiStreamSource> loader)
+ /// Path to a file where the model should be read from.
+ /// The data loader from the model file. Note that if there is no data loader,
+ /// this method will throw an exception. The scenario where no loader is stored in the file should
+ /// be handled instead using the Load method.
+ /// The transformer model from the model file.
+ public ITransformer LoadWithDataLoader(string filePath, out IDataLoader<IMultiStreamSource> loader)
{
- _env.CheckValue(stream, nameof(stream));
+ _env.CheckNonEmpty(filePath, nameof(filePath));
- loader = Load(stream);
- if (loader is CompositeDataLoader<IMultiStreamSource, ITransformer> composite)
- {
- loader = composite.Loader;
- return composite.Transformer;
- }
- return new TransformerChain<ITransformer>();
+ using (var stream = File.OpenRead(filePath))
+ return LoadWithDataLoader(stream, out loader);
}
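
For illustration, a hypothetical round trip through this overload, with placeholder file and data paths:

```csharp
// A usage sketch: recover both the transformer and the loader, then score new data.
// "model-and-loader.zip" and "new-data.tsv" are placeholder paths.
ITransformer model = mlContext.Model.LoadWithDataLoader("model-and-loader.zip", out var loader);
IDataView data = loader.Load(new MultiFileSource("new-data.tsv"));
IDataView predictions = model.Transform(data);
```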
///
diff --git a/test/Microsoft.ML.Functional.Tests/ModelLoading.cs b/test/Microsoft.ML.Functional.Tests/ModelLoading.cs
index 193ddedad5..4810ad2a09 100644
--- a/test/Microsoft.ML.Functional.Tests/ModelLoading.cs
+++ b/test/Microsoft.ML.Functional.Tests/ModelLoading.cs
@@ -8,7 +8,6 @@
using Microsoft.ML.Calibrators;
using Microsoft.ML.Data;
using Microsoft.ML.RunTests;
-using Microsoft.ML.TestFramework;
using Microsoft.ML.Trainers.FastTree;
using Microsoft.ML.Transforms;
using Xunit;
@@ -16,22 +15,12 @@
namespace Microsoft.ML.Functional.Tests
{
- public partial class ModelLoadingTests : BaseTestClass
+ public partial class ModelLoadingTests : TestDataPipeBase
{
- private MLContext _ml;
-
public ModelLoadingTests(ITestOutputHelper output) : base(output)
{
}
- protected override void Initialize()
- {
- base.Initialize();
-
- _ml = new MLContext(42);
- _ml.AddStandardComponents();
- }
-
private class InputData
{
[LoadColumn(0)]
@@ -45,104 +34,133 @@ private class InputData
public void LoadModelAndExtractPredictor()
{
var file = new MultiFileSource(GetDataPath(TestDatasets.adult.trainFilename));
- var loader = _ml.Data.CreateTextLoader(hasHeader: true, dataSample: file);
+ var loader = ML.Data.CreateTextLoader(hasHeader: true, dataSample: file);
var data = loader.Load(file);
// Pipeline.
- var pipeline = _ml.BinaryClassification.Trainers.Gam();
+ var pipeline = ML.BinaryClassification.Trainers.Gam();
// Define the same pipeline starting with the loader.
- var pipeline1 = loader.Append(_ml.BinaryClassification.Trainers.Gam());
-
+ var pipeline1 = loader.Append(ML.BinaryClassification.Trainers.Gam());
+
// Train.
var transformerModel = pipeline.Fit(data);
var compositeLoaderModel = pipeline1.Fit(file);
- // Save and reload.
+ // Save and reload the "same" model with some differences in structure.
+
+ // In this case we are saving the transformer model, but *not* the loader, just the schema from that loader.
string modelAndSchemaPath = GetOutputPath(FullTestName + "-model-schema.zip");
- _ml.Model.Save(transformerModel, data.Schema, modelAndSchemaPath);
+ ML.Model.Save(transformerModel, data.Schema, modelAndSchemaPath);
+
+ // In this case we have combined the loader with the transformer model to form a "composite" loader, and are just
+ // saving that one loader to this file.
string compositeLoaderModelPath = GetOutputPath(FullTestName + "-composite-model.zip");
- _ml.Model.Save(compositeLoaderModel, compositeLoaderModelPath);
+ ML.Model.Save(null, compositeLoaderModel, compositeLoaderModelPath);
+
+ // In this case we are saving the transformer model, as well as the associated data loader.
string loaderAndTransformerModelPath = GetOutputPath(FullTestName + "-loader-transformer.zip");
- _ml.Model.Save(loader, transformerModel, loaderAndTransformerModelPath);
+ ML.Model.Save(transformerModel, loader, loaderAndTransformerModelPath);
ITransformer loadedTransformerModel;
IDataLoader<IMultiStreamSource> loadedCompositeLoader;
ITransformer loadedTransformerModel1;
using (var fs = File.OpenRead(modelAndSchemaPath))
- loadedTransformerModel = _ml.Model.Load(fs, out var loadedSchema);
+ loadedTransformerModel = ML.Model.Load(fs, out var loadedSchema);
using (var fs = File.OpenRead(compositeLoaderModelPath))
{
// This model can be loaded either as a composite data loader,
// a transformer model + an input schema, or a transformer model + a data loader.
- var t = _ml.Model.LoadWithDataLoader(fs, out IDataLoader<IMultiStreamSource> l);
- var t1 = _ml.Model.Load(fs, out var s);
- loadedCompositeLoader = _ml.Model.Load(fs);
+ var t = ML.Model.LoadWithDataLoader(fs, out loadedCompositeLoader);
+ // This is a bit strange: it effectively tests that we can reload twice from the same
+ // stream, opened only once. As far as I know that is not really a requirement of the
+ // design or API, but we are nonetheless testing it. If this winds up failing, I'm not
+ // sure we should really insist on it as a design requirement.
+ var t1 = ML.Model.Load(fs, out var s);
+
+ CheckSameSchemas(loadedCompositeLoader.GetOutputSchema(), s);
+ // We combined the GAM with the loader, so the remaining chain should just be empty.
+ Assert.Empty(Assert.IsType<TransformerChain<ITransformer>>(t));
+ Assert.Empty(Assert.IsType<TransformerChain<ITransformer>>(t1));
}
using (var fs = File.OpenRead(loaderAndTransformerModelPath))
{
// This model can be loaded either as a composite data loader,
// a transformer model + an input schema, or a transformer model + a data loader.
- var t = _ml.Model.Load(fs, out var s);
- var c = _ml.Model.Load(fs);
- loadedTransformerModel1 = _ml.Model.LoadWithDataLoader(fs, out IDataLoader<IMultiStreamSource> l);
+ var t = ML.Model.Load(fs, out var s);
+ CheckSameSchemas(loader.GetOutputSchema(), s);
+
+ loadedTransformerModel1 = ML.Model.LoadWithDataLoader(fs, out var l);
+ }
+
+ void AssertIsGam(ITransformer trans)
+ {
+ Assert.IsType(
+ Assert.IsAssignableFrom(
+ Assert.IsAssignableFrom>(trans).Model).SubModel);
}
- var gam = ((loadedTransformerModel as ISingleFeaturePredictionTransformer