diff --git a/src/Microsoft.ML.Data/Model/ModelOperationsCatalog.cs b/src/Microsoft.ML.Data/Model/ModelOperationsCatalog.cs
index dc4cb551d0..144c4b56f3 100644
--- a/src/Microsoft.ML.Data/Model/ModelOperationsCatalog.cs
+++ b/src/Microsoft.ML.Data/Model/ModelOperationsCatalog.cs
@@ -33,7 +33,7 @@ internal ModelOperationsCatalog(IHostEnvironment env)
///
/// The trained model to be saved.
/// A writeable, seekable stream to save to.
- public void Save(IDataLoader model, Stream stream)
+ public void SaveDataLoader(IDataLoader model, Stream stream)
{
_env.CheckValue(model, nameof(model));
_env.CheckValue(stream, nameof(stream));
@@ -50,10 +50,10 @@ public void Save(IDataLoader model, Stream stream)
///
/// The trained model to be saved.
/// Path where model should be saved.
- public void Save(IDataLoader model, string filePath)
+ public void SaveDataLoader(IDataLoader model, string filePath)
{
using (var stream = File.Create(filePath))
- Save(model, stream);
+ SaveDataLoader(model, stream);
}
///
@@ -63,7 +63,7 @@ public void Save(IDataLoader model, string filePath)
/// The trained model to be saved
/// A writeable, seekable stream to save to.
public void Save(IDataLoader loader, ITransformer model, Stream stream) =>
- Save(new CompositeDataLoader(loader, new TransformerChain(model)), stream);
+ SaveDataLoader(new CompositeDataLoader(loader, new TransformerChain(model)), stream);
///
/// Save a transformer model and the loader used to create its input data to the file.
@@ -182,12 +182,25 @@ public ITransformer Load(Stream stream, out DataViewSchema inputSchema)
}
///
- /// Load the model and its input schema from the stream.
+ /// Load the model and its input schema from the file.
+ ///
+ /// Path to model.
+ /// Will contain the input schema for the model. If the model was saved using older APIs
+ /// it may not contain an input schema, in this case will be null.
+ /// The loaded model.
+ public ITransformer Load(string modelPath, out DataViewSchema inputSchema)
+ {
+ using (var stream = File.OpenRead(modelPath))
+ return Load(stream, out inputSchema);
+ }
+
+ ///
+ /// Load the model from the stream.
///
/// A readable, seekable stream to load from.
/// A model of type containing the loader
/// and the transformer chain.
- public IDataLoader Load(Stream stream)
+ public IDataLoader LoadDataLoader(Stream stream)
{
_env.CheckValue(stream, nameof(stream));
@@ -205,6 +218,18 @@ public IDataLoader Load(Stream stream)
}
}
+ ///
+ /// Load the model from the file.
+ ///
+ /// Path to model.
+ /// A model of type containing the loader
+ /// and the transformer chain.
+ public IDataLoader LoadDataLoader(string modelPath)
+ {
+ using (var stream = File.OpenRead(modelPath))
+ return LoadDataLoader(stream);
+ }
+
///
/// Load a transformer model and a data loader model from the stream.
///
@@ -215,7 +240,7 @@ public ITransformer LoadWithDataLoader(Stream stream, out IDataLoader composite)
{
loader = composite.Loader;
@@ -224,6 +249,18 @@ public ITransformer LoadWithDataLoader(Stream stream, out IDataLoader();
}
+ ///
+ /// Load a transformer model and a data loader model from the file.
+ ///
+ /// Path to model.
+ /// The data loader from the model stream.
+ /// The transformer model from the model stream.
+ public ITransformer LoadWithDataLoader(string modelPath, out IDataLoader loader)
+ {
+ using (var stream = File.OpenRead(modelPath))
+ return LoadWithDataLoader(stream, out loader);
+ }
+
///
/// Create a prediction engine for one-time prediction.
///
diff --git a/test/Microsoft.ML.Functional.Tests/ModelLoading.cs b/test/Microsoft.ML.Functional.Tests/ModelLoading.cs
index 193ddedad5..8549c5eb29 100644
--- a/test/Microsoft.ML.Functional.Tests/ModelLoading.cs
+++ b/test/Microsoft.ML.Functional.Tests/ModelLoading.cs
@@ -61,31 +61,26 @@ public void LoadModelAndExtractPredictor()
string modelAndSchemaPath = GetOutputPath(FullTestName + "-model-schema.zip");
_ml.Model.Save(transformerModel, data.Schema, modelAndSchemaPath);
string compositeLoaderModelPath = GetOutputPath(FullTestName + "-composite-model.zip");
- _ml.Model.Save(compositeLoaderModel, compositeLoaderModelPath);
+ _ml.Model.SaveDataLoader(compositeLoaderModel, compositeLoaderModelPath);
string loaderAndTransformerModelPath = GetOutputPath(FullTestName + "-loader-transformer.zip");
_ml.Model.Save(loader, transformerModel, loaderAndTransformerModelPath);
ITransformer loadedTransformerModel;
IDataLoader loadedCompositeLoader;
ITransformer loadedTransformerModel1;
- using (var fs = File.OpenRead(modelAndSchemaPath))
- loadedTransformerModel = _ml.Model.Load(fs, out var loadedSchema);
- using (var fs = File.OpenRead(compositeLoaderModelPath))
- {
- // This model can be loaded either as a composite data loader,
- // a transformer model + an input schema, or a transformer model + a data loader.
- var t = _ml.Model.LoadWithDataLoader(fs, out IDataLoader l);
- var t1 = _ml.Model.Load(fs, out var s);
- loadedCompositeLoader = _ml.Model.Load(fs);
- }
- using (var fs = File.OpenRead(loaderAndTransformerModelPath))
- {
- // This model can be loaded either as a composite data loader,
- // a transformer model + an input schema, or a transformer model + a data loader.
- var t = _ml.Model.Load(fs, out var s);
- var c = _ml.Model.Load(fs);
- loadedTransformerModel1 = _ml.Model.LoadWithDataLoader(fs, out IDataLoader l);
- }
+ loadedTransformerModel = _ml.Model.Load(modelAndSchemaPath, out var loadedSchema);
+
+ // This model can be loaded either as a composite data loader,
+ // a transformer model + an input schema, or a transformer model + a data loader.
+ var t = _ml.Model.LoadWithDataLoader(compositeLoaderModelPath, out IDataLoader l);
+ var t1 = _ml.Model.Load(compositeLoaderModelPath, out var s);
+ loadedCompositeLoader = _ml.Model.LoadDataLoader(compositeLoaderModelPath);
+
+ // This model can be loaded either as a composite data loader,
+ // a transformer model + an input schema, or a transformer model + a data loader.
+ var tt = _ml.Model.Load(loaderAndTransformerModelPath, out var ss);
+ var c = _ml.Model.LoadDataLoader(loaderAndTransformerModelPath);
+ loadedTransformerModel1 = _ml.Model.LoadWithDataLoader(loaderAndTransformerModelPath, out IDataLoader ll);
var gam = ((loadedTransformerModel as ISingleFeaturePredictionTransformer