diff --git a/src/Microsoft.ML/Models/BinaryClassificationMetrics.cs b/src/Microsoft.ML/Models/BinaryClassificationMetrics.cs index 07d534142f..1714265674 100644 --- a/src/Microsoft.ML/Models/BinaryClassificationMetrics.cs +++ b/src/Microsoft.ML/Models/BinaryClassificationMetrics.cs @@ -7,6 +7,7 @@ using Microsoft.ML.Runtime.Data; using System; using System.Collections.Generic; +using static Microsoft.ML.Runtime.Data.MetricKinds; namespace Microsoft.ML.Models { @@ -35,7 +36,7 @@ internal static List FromMetrics(IHostEnvironment e var confusionMatrices = ConfusionMatrix.Create(env, confusionMatrix).GetEnumerator(); int index = 0; - foreach(var metric in metricsEnumerable) + foreach (var metric in metricsEnumerable) { if (index++ >= confusionMatriceStartIndex && !confusionMatrices.MoveNext()) @@ -57,6 +58,7 @@ internal static List FromMetrics(IHostEnvironment e Entropy = metric.Entropy, F1Score = metric.F1Score, Auprc = metric.Auprc, + RowTag = metric.RowTag, ConfusionMatrix = confusionMatrices.Current, }); @@ -162,6 +164,12 @@ internal static List FromMetrics(IHostEnvironment e /// public ConfusionMatrix ConfusionMatrix { get; private set; } + /// + /// For cross-validation, this is equal to "Fold N" for per-fold metric rows, "Overall" for the average metrics and "STD" for standard deviation. + /// For non-CV scenarios, this is equal to null + /// + public string RowTag { get; private set; } + /// /// This class contains the public fields necessary to deserialize from IDataView. /// @@ -200,6 +208,9 @@ private sealed class SerializationClass [ColumnName(BinaryClassifierEvaluator.AuPrc)] public Double Auprc; + + [ColumnName(ColumnNames.FoldIndex)] + public string RowTag; #pragma warning restore 649 // never assigned } } diff --git a/src/Microsoft.ML/Models/ClassificationMetrics.cs b/src/Microsoft.ML/Models/ClassificationMetrics.cs index 3261a0cc60..6c1c139278 100644 --- a/src/Microsoft.ML/Models/ClassificationMetrics.cs +++ b/src/Microsoft.ML/Models/ClassificationMetrics.cs @@ -6,6 +6,7 @@ using Microsoft.ML.Runtime.Api; using Microsoft.ML.Runtime.Data; using System.Collections.Generic; +using static Microsoft.ML.Runtime.Data.MetricKinds; namespace Microsoft.ML.Models { @@ -51,7 +52,8 @@ internal static List FromMetrics(IHostEnvironment env, ID LogLossReduction = metric.LogLossReduction, TopKAccuracy = metric.TopKAccuracy, PerClassLogLoss = metric.PerClassLogLoss, - ConfusionMatrix = confusionMatrices.Current + ConfusionMatrix = confusionMatrices.Current, + RowTag = metric.RowTag, }); } @@ -127,6 +129,12 @@ internal static List FromMetrics(IHostEnvironment env, ID /// public double[] PerClassLogLoss { get; private set; } + /// + /// For cross-validation, this is equal to "Fold N" for per-fold metric rows, "Overall" for the average metrics and "STD" for standard deviation. + /// For non-CV scenarios, this is equal to null + /// + public string RowTag { get; private set; } + /// /// Gets the confusion matrix, or error matrix, of the classifier. /// @@ -155,6 +163,9 @@ private sealed class SerializationClass [ColumnName(MultiClassClassifierEvaluator.PerClassLogLoss)] public double[] PerClassLogLoss; + + [ColumnName(ColumnNames.FoldIndex)] + public string RowTag; #pragma warning restore 649 // never assigned } } diff --git a/src/Microsoft.ML/Models/ClusterMetrics.cs b/src/Microsoft.ML/Models/ClusterMetrics.cs index aec0264ff0..83770389cd 100644 --- a/src/Microsoft.ML/Models/ClusterMetrics.cs +++ b/src/Microsoft.ML/Models/ClusterMetrics.cs @@ -7,6 +7,7 @@ using Microsoft.ML.Runtime.Data; using System; using System.Collections.Generic; +using static Microsoft.ML.Runtime.Data.MetricKinds; namespace Microsoft.ML.Models { @@ -38,6 +39,7 @@ internal static List FromOverallMetrics(IHostEnvironment env, ID AvgMinScore = metric.AvgMinScore, Nmi = metric.Nmi, Dbi = metric.Dbi, + RowTag = metric.RowTag, }); } @@ -73,6 +75,12 @@ internal static List FromOverallMetrics(IHostEnvironment env, ID /// public double AvgMinScore { get; private set; } + /// + /// For cross-validation, this is equal to "Fold N" for per-fold metric rows, "Overall" for the average metrics and "STD" for standard deviation. + /// For non-CV scenarios, this is equal to null + /// + public string RowTag { get; private set; } + /// /// This class contains the public fields necessary to deserialize from IDataView. /// @@ -88,6 +96,8 @@ private sealed class SerializationClass [ColumnName(Runtime.Data.ClusteringEvaluator.AvgMinScore)] public Double AvgMinScore; + [ColumnName(ColumnNames.FoldIndex)] + public string RowTag; #pragma warning restore 649 // never assigned } } diff --git a/src/Microsoft.ML/Models/RegressionMetrics.cs b/src/Microsoft.ML/Models/RegressionMetrics.cs index 64500f5e6c..bf5ba625f6 100644 --- a/src/Microsoft.ML/Models/RegressionMetrics.cs +++ b/src/Microsoft.ML/Models/RegressionMetrics.cs @@ -7,6 +7,7 @@ using Microsoft.ML.Runtime.Data; using System; using System.Collections.Generic; +using static Microsoft.ML.Runtime.Data.MetricKinds; namespace Microsoft.ML.Models { @@ -40,6 +41,7 @@ internal static List FromOverallMetrics(IHostEnvironment env, Rms = metric.Rms, LossFn = metric.LossFn, RSquared = metric.RSquared, + RowTag = metric.RowTag, }); } @@ -90,6 +92,12 @@ internal static List FromOverallMetrics(IHostEnvironment env, /// public double RSquared { get; private set; } + /// + /// For cross-validation, this is equal to "Fold N" for per-fold metric rows, "Overall" for the average metrics and "STD" for standard deviation. + /// For non-CV scenarios, this is equal to null + /// + public string RowTag { get; private set; } + /// /// This class contains the public fields necessary to deserialize from IDataView. /// @@ -110,6 +118,9 @@ private sealed class SerializationClass [ColumnName(Runtime.Data.RegressionEvaluator.RSquared)] public Double RSquared; + + [ColumnName(ColumnNames.FoldIndex)] + public string RowTag; #pragma warning restore 649 // never assigned } }