diff --git a/src/Microsoft.ML.Data/Model/ModelOperationsCatalog.cs b/src/Microsoft.ML.Data/Model/ModelOperationsCatalog.cs index dc4cb551d0..144c4b56f3 100644 --- a/src/Microsoft.ML.Data/Model/ModelOperationsCatalog.cs +++ b/src/Microsoft.ML.Data/Model/ModelOperationsCatalog.cs @@ -33,7 +33,7 @@ internal ModelOperationsCatalog(IHostEnvironment env) /// /// The trained model to be saved. /// A writeable, seekable stream to save to. - public void Save(IDataLoader model, Stream stream) + public void SaveDataLoader(IDataLoader model, Stream stream) { _env.CheckValue(model, nameof(model)); _env.CheckValue(stream, nameof(stream)); @@ -50,10 +50,10 @@ public void Save(IDataLoader model, Stream stream) /// /// The trained model to be saved. /// Path where model should be saved. - public void Save(IDataLoader model, string filePath) + public void SaveDataLoader(IDataLoader model, string filePath) { using (var stream = File.Create(filePath)) - Save(model, stream); + SaveDataLoader(model, stream); } /// @@ -63,7 +63,7 @@ public void Save(IDataLoader model, string filePath) /// The trained model to be saved /// A writeable, seekable stream to save to. public void Save(IDataLoader loader, ITransformer model, Stream stream) => - Save(new CompositeDataLoader(loader, new TransformerChain(model)), stream); + SaveDataLoader(new CompositeDataLoader(loader, new TransformerChain(model)), stream); /// /// Save a transformer model and the loader used to create its input data to the file. @@ -182,12 +182,25 @@ public ITransformer Load(Stream stream, out DataViewSchema inputSchema) } /// - /// Load the model and its input schema from the stream. + /// Load the model and its input schema from the file. + /// + /// Path to model. + /// Will contain the input schema for the model. If the model was saved using older APIs + /// it may not contain an input schema, in this case will be null. + /// The loaded model. + public ITransformer Load(string modelPath, out DataViewSchema inputSchema) + { + using (var stream = File.OpenRead(modelPath)) + return Load(stream, out inputSchema); + } + + /// + /// Load the model from the stream. /// /// A readable, seekable stream to load from. /// A model of type containing the loader /// and the transformer chain. - public IDataLoader Load(Stream stream) + public IDataLoader LoadDataLoader(Stream stream) { _env.CheckValue(stream, nameof(stream)); @@ -205,6 +218,18 @@ public IDataLoader Load(Stream stream) } } + /// + /// Load the model from the file. + /// + /// Path to model. + /// A model of type containing the loader + /// and the transformer chain. + public IDataLoader LoadDataLoader(string modelPath) + { + using (var stream = File.OpenRead(modelPath)) + return LoadDataLoader(stream); + } + /// /// Load a transformer model and a data loader model from the stream. /// @@ -215,7 +240,7 @@ public ITransformer LoadWithDataLoader(Stream stream, out IDataLoader composite) { loader = composite.Loader; @@ -224,6 +249,18 @@ public ITransformer LoadWithDataLoader(Stream stream, out IDataLoader(); } + /// + /// Load a transformer model and a data loader model from the file. + /// + /// Path to model. + /// The data loader from the model stream. + /// The transformer model from the model stream. + public ITransformer LoadWithDataLoader(string modelPath, out IDataLoader loader) + { + using (var stream = File.OpenRead(modelPath)) + return LoadWithDataLoader(stream, out loader); + } + /// /// Create a prediction engine for one-time prediction. /// diff --git a/test/Microsoft.ML.Functional.Tests/ModelLoading.cs b/test/Microsoft.ML.Functional.Tests/ModelLoading.cs index 193ddedad5..8549c5eb29 100644 --- a/test/Microsoft.ML.Functional.Tests/ModelLoading.cs +++ b/test/Microsoft.ML.Functional.Tests/ModelLoading.cs @@ -61,31 +61,26 @@ public void LoadModelAndExtractPredictor() string modelAndSchemaPath = GetOutputPath(FullTestName + "-model-schema.zip"); _ml.Model.Save(transformerModel, data.Schema, modelAndSchemaPath); string compositeLoaderModelPath = GetOutputPath(FullTestName + "-composite-model.zip"); - _ml.Model.Save(compositeLoaderModel, compositeLoaderModelPath); + _ml.Model.SaveDataLoader(compositeLoaderModel, compositeLoaderModelPath); string loaderAndTransformerModelPath = GetOutputPath(FullTestName + "-loader-transformer.zip"); _ml.Model.Save(loader, transformerModel, loaderAndTransformerModelPath); ITransformer loadedTransformerModel; IDataLoader loadedCompositeLoader; ITransformer loadedTransformerModel1; - using (var fs = File.OpenRead(modelAndSchemaPath)) - loadedTransformerModel = _ml.Model.Load(fs, out var loadedSchema); - using (var fs = File.OpenRead(compositeLoaderModelPath)) - { - // This model can be loaded either as a composite data loader, - // a transformer model + an input schema, or a transformer model + a data loader. - var t = _ml.Model.LoadWithDataLoader(fs, out IDataLoader l); - var t1 = _ml.Model.Load(fs, out var s); - loadedCompositeLoader = _ml.Model.Load(fs); - } - using (var fs = File.OpenRead(loaderAndTransformerModelPath)) - { - // This model can be loaded either as a composite data loader, - // a transformer model + an input schema, or a transformer model + a data loader. - var t = _ml.Model.Load(fs, out var s); - var c = _ml.Model.Load(fs); - loadedTransformerModel1 = _ml.Model.LoadWithDataLoader(fs, out IDataLoader l); - } + loadedTransformerModel = _ml.Model.Load(modelAndSchemaPath, out var loadedSchema); + + // This model can be loaded either as a composite data loader, + // a transformer model + an input schema, or a transformer model + a data loader. + var t = _ml.Model.LoadWithDataLoader(compositeLoaderModelPath, out IDataLoader l); + var t1 = _ml.Model.Load(compositeLoaderModelPath, out var s); + loadedCompositeLoader = _ml.Model.LoadDataLoader(compositeLoaderModelPath); + + // This model can be loaded either as a composite data loader, + // a transformer model + an input schema, or a transformer model + a data loader. + var tt = _ml.Model.Load(loaderAndTransformerModelPath, out var ss); + var c = _ml.Model.LoadDataLoader(loaderAndTransformerModelPath); + loadedTransformerModel1 = _ml.Model.LoadWithDataLoader(loaderAndTransformerModelPath, out IDataLoader ll); var gam = ((loadedTransformerModel as ISingleFeaturePredictionTransformer).Model as CalibratedModelParametersBase).SubModel @@ -125,11 +120,8 @@ public void SaveAndLoadModelWithLoader() IDataLoader loadedModel; ITransformer loadedModelWithoutLoader; DataViewSchema loadedSchema; - using (var fs = File.OpenRead(modelPath)) - { - loadedModel = _ml.Model.Load(fs); - loadedModelWithoutLoader = _ml.Model.Load(fs, out loadedSchema); - } + loadedModel = _ml.Model.LoadDataLoader(modelPath); + loadedModelWithoutLoader = _ml.Model.Load(modelPath, out loadedSchema); // Without deserializing the loader from the model we lose the slot names. data = _ml.Data.LoadFromEnumerable(new[] { new InputData() }); @@ -171,8 +163,7 @@ public void LoadSchemaAndCreateNewData() ITransformer loadedModel; DataViewSchema loadedSchema; - using (var fs = File.OpenRead(modelPath)) - loadedModel = _ml.Model.Load(fs, out loadedSchema); + loadedModel = _ml.Model.Load(modelPath, out loadedSchema); // Without using the schema from the model we lose the slot names. data = _ml.Data.LoadFromEnumerable(new[] { new InputData() }); @@ -190,7 +181,7 @@ public void SaveTextLoaderAndLoad() var loader = _ml.Data.CreateTextLoader(hasHeader: true, dataSample: file); string modelPath = GetOutputPath(FullTestName + "-model.zip"); - _ml.Model.Save(loader, modelPath); + _ml.Model.SaveDataLoader(loader, modelPath); Load(modelPath, out var loadedWithSchema, out var loadedSchema, out var loadedLoader, out var loadedWithLoader, out var loadedLoaderWithTransformer); @@ -220,7 +211,7 @@ public void SaveCompositeLoaderAndLoad() var model = composite.Fit(file); string modelPath = GetOutputPath(FullTestName + "-model.zip"); - _ml.Model.Save(model, modelPath); + _ml.Model.SaveDataLoader(model, modelPath); Load(modelPath, out var loadedWithSchema, out var loadedSchema, out var loadedLoader, out var loadedWithLoader, out var loadedLoaderWithTransformer); @@ -298,26 +289,24 @@ private void Load(string filename, out ITransformer loadedWithSchema, out DataVi out IDataLoader loadedLoader, out ITransformer loadedWithLoader, out IDataLoader loadedLoaderWithTransformer) { - using (var fs = File.OpenRead(filename)) + + try + { + loadedLoader = _ml.Model.LoadDataLoader(filename); + } + catch (Exception) + { + loadedLoader = null; + } + loadedWithSchema = _ml.Model.Load(filename, out loadedSchema); + try + { + loadedWithLoader = _ml.Model.LoadWithDataLoader(filename, out loadedLoaderWithTransformer); + } + catch (Exception) { - try - { - loadedLoader = _ml.Model.Load(fs); - } - catch (Exception) - { - loadedLoader = null; - } - loadedWithSchema = _ml.Model.Load(fs, out loadedSchema); - try - { - loadedWithLoader = _ml.Model.LoadWithDataLoader(fs, out loadedLoaderWithTransformer); - } - catch (Exception) - { - loadedWithLoader = null; - loadedLoaderWithTransformer = null; - } + loadedWithLoader = null; + loadedLoaderWithTransformer = null; } }