diff --git a/src/Microsoft.ML.Transforms/Text/LdaTransform.cs b/src/Microsoft.ML.Transforms/Text/LdaTransform.cs index 7a6f2e87f3..10fced0601 100644 --- a/src/Microsoft.ML.Transforms/Text/LdaTransform.cs +++ b/src/Microsoft.ML.Transforms/Text/LdaTransform.cs @@ -102,8 +102,8 @@ internal sealed class Options : TransformInputBase [Argument(ArgumentType.AtMostOnce, HelpText = "Reset the random number generator for each document", ShortName = "reset")] public bool ResetRandomGenerator = LatentDirichletAllocationEstimator.Defaults.ResetRandomGenerator; - [Argument(ArgumentType.AtMostOnce, HelpText = "Whether to output the topic-word summary in text format", ShortName = "summary")] - public bool OutputTopicWordSummary; + [Argument(ArgumentType.AtMostOnce, HelpText = "Whether to output the topic-word summary in text format when saving the model to disk", ShortName = "summary")] + public bool OutputTopicWordSummary = LatentDirichletAllocationEstimator.Defaults.OutputTopicWordSummary; } internal sealed class Column : OneToOneColumn @@ -141,6 +141,9 @@ internal sealed class Column : OneToOneColumn [Argument(ArgumentType.AtMostOnce, HelpText = "Reset the random number generator for each document", ShortName = "reset")] public bool? ResetRandomGenerator; + [Argument(ArgumentType.AtMostOnce, HelpText = "Whether to output the topic-word summary in text format when saving the model to disk", ShortName = "summary")] + public bool? OutputTopicWordSummary; + internal static Column Parse(string str) { Contracts.AssertNonEmpty(str); @@ -206,13 +209,17 @@ internal ModelParameters(IReadOnlyList> wordScoresP } } - [BestFriend] - internal ModelParameters GetLdaDetails(int iinfo) + /// + /// Method to provide details about the topics discovered by LightLDA + /// + /// index of column options pair + /// + public ModelParameters GetLdaDetails(int columnIndex) { - Contracts.Assert(0 <= iinfo && iinfo < _ldas.Length); + Contracts.Assert(0 <= columnIndex && columnIndex < _ldas.Length); - var ldaState = _ldas[iinfo]; - var mapping = _columnMappings[iinfo]; + var ldaState = _ldas[columnIndex]; + var mapping = _columnMappings[columnIndex]; return ldaState.GetLdaSummary(mapping); } @@ -630,7 +637,7 @@ private static VersionInfo GetVersionInfo() private readonly List>> _columnMappings; private const string RegistrationName = "LightLda"; - private const string WordTopicModelFilename = "word_topic_summary.txt"; + private const string WordTopicModelFilename = "word_topic_summary-{0}.txt"; internal const string Summary = "The LDA transform implements LightLDA, a state-of-the-art implementation of Latent Dirichlet Allocation."; internal const string UserName = "Latent Dirichlet Allocation Transform"; internal const string ShortName = "LightLda"; @@ -760,9 +767,50 @@ private protected override void SaveModel(ModelSaveContext ctx) for (int i = 0; i < _ldas.Length; i++) { _ldas[i].Save(ctx); + + if(_columns[i].OutputTopicWordSummary) + SaveTopicWordSummary(ctx, i); } } + private void SaveTopicWordSummary(ModelSaveContext ctx, int i) + { + var summary = GetLdaDetails(i); + + var columnName = _columns[i].Name; + + ctx.SaveTextStream(String.Format(WordTopicModelFilename, columnName), writer => + { + if (summary.WordScoresPerTopic != null) + { + int topId = 0; + foreach (var wordScores in summary.WordScoresPerTopic) + { + foreach (var wordScore in wordScores) + { + writer.WriteLine($"Topic[{topId}]: {wordScore.Word}\t{wordScore.Score}"); + } + + topId++; + } + } + + if (summary.ItemScoresPerTopic != null) + { + int topId = 0; + foreach (var itemScores in summary.ItemScoresPerTopic) + { + foreach (var itemScore in itemScores) + { + writer.WriteLine($"Topic[{topId}]: {itemScore.Item}\t{itemScore.Score}"); + } + + topId++; + } + } + }); + } + private static int GetFrequency(double value) { int result = (int)value; @@ -994,6 +1042,7 @@ internal static class Defaults public const int NumberOfSummaryTermsPerTopic = 10; public const int NumberOfBurninIterations = 10; public const bool ResetRandomGenerator = false; + public const bool OutputTopicWordSummary = false; } private readonly IHost _host; @@ -1014,6 +1063,7 @@ internal static class Defaults /// Compute log likelihood over local dataset on this iteration interval. /// The number of burn-in iterations. /// Reset the random number generator for each document. + /// Whether to output the topic-word summary in text format when saving the model to disk. internal LatentDirichletAllocationEstimator(IHostEnvironment env, string outputColumnName, string inputColumnName = null, int numberOfTopics = Defaults.NumberOfTopics, @@ -1026,10 +1076,11 @@ internal LatentDirichletAllocationEstimator(IHostEnvironment env, int numberOfSummaryTermsPerTopic = Defaults.NumberOfSummaryTermsPerTopic, int likelihoodInterval = Defaults.LikelihoodInterval, int numberOfBurninIterations = Defaults.NumberOfBurninIterations, - bool resetRandomGenerator = Defaults.ResetRandomGenerator) + bool resetRandomGenerator = Defaults.ResetRandomGenerator, + bool outputTopicWordSummary = Defaults.OutputTopicWordSummary) : this(env, new[] { new ColumnOptions(outputColumnName, inputColumnName ?? outputColumnName, numberOfTopics, alphaSum, beta, samplingStepCount, maximumNumberOfIterations, likelihoodInterval, numberOfThreads, maximumTokenCountPerDocument, - numberOfSummaryTermsPerTopic, numberOfBurninIterations, resetRandomGenerator) }) + numberOfSummaryTermsPerTopic, numberOfBurninIterations, resetRandomGenerator, outputTopicWordSummary) }) { } /// @@ -1100,6 +1151,10 @@ internal sealed class ColumnOptions /// Reset the random number generator for each document. /// public readonly bool ResetRandomGenerator; + /// + /// Whether to output the topic-word summary in text format when saving the model to disk. + /// + public readonly bool OutputTopicWordSummary; /// /// Describes how the transformer handles one column pair. @@ -1117,6 +1172,7 @@ internal sealed class ColumnOptions /// The number of words to summarize the topic. /// The number of burn-in iterations. /// Reset the random number generator for each document. + /// Whether to output the topic-word summary in text format when saving the model to disk. public ColumnOptions(string name, string inputColumnName = null, int numberOfTopics = LatentDirichletAllocationEstimator.Defaults.NumberOfTopics, @@ -1129,7 +1185,8 @@ public ColumnOptions(string name, int maximumTokenCountPerDocument = LatentDirichletAllocationEstimator.Defaults.MaximumTokenCountPerDocument, int numberOfSummaryTermsPerTopic = LatentDirichletAllocationEstimator.Defaults.NumberOfSummaryTermsPerTopic, int numberOfBurninIterations = LatentDirichletAllocationEstimator.Defaults.NumberOfBurninIterations, - bool resetRandomGenerator = LatentDirichletAllocationEstimator.Defaults.ResetRandomGenerator) + bool resetRandomGenerator = LatentDirichletAllocationEstimator.Defaults.ResetRandomGenerator, + bool outputTopicWordSummary = LatentDirichletAllocationEstimator.Defaults.OutputTopicWordSummary) { Contracts.CheckValue(name, nameof(name)); Contracts.CheckValueOrNull(inputColumnName); @@ -1155,6 +1212,7 @@ public ColumnOptions(string name, NumberOfSummaryTermsPerTopic = numberOfSummaryTermsPerTopic; NumberOfBurninIterations = numberOfBurninIterations; ResetRandomGenerator = resetRandomGenerator; + OutputTopicWordSummary = outputTopicWordSummary; } internal ColumnOptions(LatentDirichletAllocationTransformer.Column item, LatentDirichletAllocationTransformer.Options options) : @@ -1170,7 +1228,8 @@ internal ColumnOptions(LatentDirichletAllocationTransformer.Column item, LatentD item.NumMaxDocToken ?? options.NumMaxDocToken, item.NumSummaryTermPerTopic ?? options.NumSummaryTermPerTopic, item.NumBurninIterations ?? options.NumBurninIterations, - item.ResetRandomGenerator ?? options.ResetRandomGenerator) + item.ResetRandomGenerator ?? options.ResetRandomGenerator, + item.OutputTopicWordSummary ?? options.OutputTopicWordSummary) { } diff --git a/test/BaselineOutput/Common/EntryPoints/core_manifest.json b/test/BaselineOutput/Common/EntryPoints/core_manifest.json index 3780432a02..441cf257a6 100644 --- a/test/BaselineOutput/Common/EntryPoints/core_manifest.json +++ b/test/BaselineOutput/Common/EntryPoints/core_manifest.json @@ -20900,6 +20900,18 @@ "IsNullable": true, "Default": null }, + { + "Name": "OutputTopicWordSummary", + "Type": "Bool", + "Desc": "Whether to output the topic-word summary in text format when saving the model to disk", + "Aliases": [ + "summary" + ], + "Required": false, + "SortOrder": 150.0, + "IsNullable": true, + "Default": null + }, { "Name": "Name", "Type": "String", @@ -21112,7 +21124,7 @@ { "Name": "OutputTopicWordSummary", "Type": "Bool", - "Desc": "Whether to output the topic-word summary in text format", + "Desc": "Whether to output the topic-word summary in text format when saving the model to disk", "Aliases": [ "summary" ], diff --git a/test/Microsoft.ML.Functional.Tests/IntrospectiveTraining.cs b/test/Microsoft.ML.Functional.Tests/IntrospectiveTraining.cs index 31fa7277d1..eb4361f953 100644 --- a/test/Microsoft.ML.Functional.Tests/IntrospectiveTraining.cs +++ b/test/Microsoft.ML.Functional.Tests/IntrospectiveTraining.cs @@ -182,15 +182,24 @@ public void InspectLdaModelParameters() // Define the pipeline. var pipeline = mlContext.Transforms.Text.ProduceWordBags("SentimentBag", "SentimentText") - .Append(mlContext.Transforms.Text.LatentDirichletAllocation("Features", "SentimentBag", numberOfTopics: numTopics, maximumNumberOfIterations: 10)); + .Append(mlContext.Transforms.Text.LatentDirichletAllocation("Features", "SentimentBag", + numberOfTopics: numTopics, maximumNumberOfIterations: 10)); // Fit the pipeline. var model = pipeline.Fit(data); // Get the trained LDA model. - // TODO #2197: Get the topics and summaries from the model. var ldaTransform = model.LastTransformer; + // Get the topics and summaries from the model. + var ldaDetails = ldaTransform.GetLdaDetails(0); + Assert.False(ldaDetails.ItemScoresPerTopic == null && ldaDetails.WordScoresPerTopic == null); + if(ldaDetails.ItemScoresPerTopic != null) + Assert.Equal(numTopics, ldaDetails.ItemScoresPerTopic.Count); + if (ldaDetails.WordScoresPerTopic != null) + Assert.Equal(numTopics, ldaDetails.WordScoresPerTopic.Count); + + // Transform the data. var transformedData = model.Transform(data); diff --git a/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipe.cs b/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipe.cs index e32eeb8297..277d7b0651 100644 --- a/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipe.cs +++ b/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipe.cs @@ -5,6 +5,8 @@ using System; using System.Collections.Generic; using System.IO; +using System.IO.Compression; +using System.Linq; using Microsoft.ML.CommandLine; using Microsoft.ML.Data; using Microsoft.ML.Internal.Utilities; @@ -1047,6 +1049,18 @@ public void SavePipeLda() "loader=Text{col=F1V:Num:0-2}", "xf=Lda{col={name=Result src=F1V numtopic=3 alphasum=3 ns=3 reset=+ t=1} summary=+}", }, forceDense: true); + + // topic summary text file saved inside the model.zip file. + string name = TestName + ".zip"; + string modelPath = GetOutputPath("SavePipe", name); + using (var file = Env.OpenInputFile(modelPath)) + using (var strm = file.OpenReadStream()) + using (var zip = new ZipArchive(strm, ZipArchiveMode.Read)) + { + var entry = zip.Entries.First(source => source.Name == "word_topic_summary-Result.txt"); + Assert.True(entry != null); + } + Done(); }