diff --git a/src/Microsoft.ML.FastTree/BoostingFastTree.cs b/src/Microsoft.ML.FastTree/BoostingFastTree.cs index f211a87cb9..ad4cb08b05 100644 --- a/src/Microsoft.ML.FastTree/BoostingFastTree.cs +++ b/src/Microsoft.ML.FastTree/BoostingFastTree.cs @@ -68,7 +68,8 @@ private protected override TreeLearner ConstructTreeLearner(IChannel ch) FastTreeTrainerOptions.HistogramPoolSize, FastTreeTrainerOptions.Seed, FastTreeTrainerOptions.FeatureFractionPerSplit, FastTreeTrainerOptions.FilterZeroLambdas, FastTreeTrainerOptions.AllowEmptyTrees, FastTreeTrainerOptions.GainConfidenceLevel, FastTreeTrainerOptions.MaximumCategoricalGroupCountPerNode, FastTreeTrainerOptions.MaximumCategoricalSplitPointCount, BsrMaxTreeOutput(), ParallelTraining, - FastTreeTrainerOptions.MinimumExampleFractionForCategoricalSplit, FastTreeTrainerOptions.Bundling, FastTreeTrainerOptions.MinimumExamplesForCategoricalSplit, FastTreeTrainerOptions.Bias); + FastTreeTrainerOptions.MinimumExampleFractionForCategoricalSplit, FastTreeTrainerOptions.Bundling, FastTreeTrainerOptions.MinimumExamplesForCategoricalSplit, + FastTreeTrainerOptions.Bias, Host); } private protected override OptimizationAlgorithm ConstructOptimizationAlgorithm(IChannel ch) diff --git a/src/Microsoft.ML.FastTree/FastTree.cs b/src/Microsoft.ML.FastTree/FastTree.cs index 895e79c466..cdd6084a5b 100644 --- a/src/Microsoft.ML.FastTree/FastTree.cs +++ b/src/Microsoft.ML.FastTree/FastTree.cs @@ -1334,6 +1334,7 @@ private ValueMapper, VBuffer> GetCopier(DataViewType ite private Dataset Construct(RoleMappedData examples, ref int numExamples, int maxBins, IParallelTraining parallelTraining) { + Host.CheckAlive(); Host.AssertValue(examples); Host.Assert(examples.Schema.Feature.HasValue); @@ -1414,6 +1415,7 @@ private Dataset Construct(RoleMappedData examples, ref int numExamples, int maxB pch.SetHeader(new ProgressHeader("features"), e => e.SetProgress(0, iFeature, features.Length)); while (cursor.MoveNext()) { + Host.CheckAlive(); iFeature = cursor.SlotIndex; if (!localConstructBinFeatures[iFeature]) continue; @@ -1489,6 +1491,8 @@ private Dataset Construct(RoleMappedData examples, ref int numExamples, int maxB int catRangeIndex = 0; for (iFeature = 0; iFeature < NumFeatures;) { + Host.CheckAlive(); + if (catRangeIndex < CategoricalFeatureIndices.Length && CategoricalFeatureIndices[catRangeIndex] == iFeature) { @@ -1565,6 +1569,7 @@ private Dataset Construct(RoleMappedData examples, ref int numExamples, int maxB { for (int i = 0; i < NumFeatures; i++) { + Host.CheckAlive(); GetFeatureValues(cursor, i, getter, ref temp, ref doubleTemp, copier); double[] upperBounds = BinUpperBounds[i]; Host.AssertValue(upperBounds); @@ -1919,6 +1924,7 @@ private void InitializeBins(int maxBins, IParallelTraining parallelTraining) List trivialFeatures = new List(); for (iFeature = 0; iFeature < NumFeatures; iFeature++) { + Host.CheckAlive(); if (!localConstructBinFeatures[iFeature]) continue; // The following strange call will actually sparsify. @@ -2230,6 +2236,7 @@ private IEnumerable CreateFlocksCore(IChannel ch, IProgressCha for (; iFeature < featureLim; ++iFeature) { + Host.CheckAlive(); double[] bup = BinUpperBounds[iFeature]; Contracts.Assert(Utils.Size(bup) > 0); if (bup.Length == 1) diff --git a/src/Microsoft.ML.FastTree/RandomForest.cs b/src/Microsoft.ML.FastTree/RandomForest.cs index d4ddc3b1a9..029f0bf5f9 100644 --- a/src/Microsoft.ML.FastTree/RandomForest.cs +++ b/src/Microsoft.ML.FastTree/RandomForest.cs @@ -68,7 +68,8 @@ private protected override TreeLearner ConstructTreeLearner(IChannel ch) FastTreeTrainerOptions.HistogramPoolSize, FastTreeTrainerOptions.Seed, FastTreeTrainerOptions.FeatureFractionPerSplit, FastTreeTrainerOptions.AllowEmptyTrees, FastTreeTrainerOptions.GainConfidenceLevel, FastTreeTrainerOptions.MaximumCategoricalGroupCountPerNode, FastTreeTrainerOptions.MaximumCategoricalSplitPointCount, _quantileEnabled, FastTreeTrainerOptions.NumberOfQuantileSamples, ParallelTraining, - FastTreeTrainerOptions.MinimumExampleFractionForCategoricalSplit, FastTreeTrainerOptions.Bundling, FastTreeTrainerOptions.MinimumExamplesForCategoricalSplit, FastTreeTrainerOptions.Bias); + FastTreeTrainerOptions.MinimumExampleFractionForCategoricalSplit, FastTreeTrainerOptions.Bundling, FastTreeTrainerOptions.MinimumExamplesForCategoricalSplit, + FastTreeTrainerOptions.Bias, Host); } internal abstract class RandomForestObjectiveFunction : ObjectiveFunctionBase diff --git a/src/Microsoft.ML.FastTree/Training/TreeLearners/FastForestLeastSquaresTreeLearner.cs b/src/Microsoft.ML.FastTree/Training/TreeLearners/FastForestLeastSquaresTreeLearner.cs index 81c4063729..9c12888f23 100644 --- a/src/Microsoft.ML.FastTree/Training/TreeLearners/FastForestLeastSquaresTreeLearner.cs +++ b/src/Microsoft.ML.FastTree/Training/TreeLearners/FastForestLeastSquaresTreeLearner.cs @@ -15,10 +15,10 @@ internal class RandomForestLeastSquaresTreeLearner : LeastSquaresRegressionTreeL public RandomForestLeastSquaresTreeLearner(Dataset trainData, int numLeaves, int minDocsInLeaf, Double entropyCoefficient, Double featureFirstUsePenalty, Double featureReusePenalty, Double softmaxTemperature, int histogramPoolSize, int randomSeed, Double splitFraction, bool allowEmptyTrees, Double gainConfidenceLevel, int maxCategoricalGroupsPerNode, int maxCategoricalSplitPointsPerNode, bool quantileEnabled, int quantileSampleCount, IParallelTraining parallelTraining, - double minDocsPercentageForCategoricalSplit, Bundle bundling, int minDocsForCategoricalSplit, double bias) + double minDocsPercentageForCategoricalSplit, Bundle bundling, int minDocsForCategoricalSplit, double bias, IHost host) : base(trainData, numLeaves, minDocsInLeaf, entropyCoefficient, featureFirstUsePenalty, featureReusePenalty, softmaxTemperature, histogramPoolSize, randomSeed, splitFraction, false, allowEmptyTrees, gainConfidenceLevel, maxCategoricalGroupsPerNode, maxCategoricalSplitPointsPerNode, -1, parallelTraining, - minDocsPercentageForCategoricalSplit, bundling, minDocsForCategoricalSplit, bias) + minDocsPercentageForCategoricalSplit, bundling, minDocsForCategoricalSplit, bias, host) { _quantileSampleCount = quantileSampleCount; _quantileEnabled = quantileEnabled; diff --git a/src/Microsoft.ML.FastTree/Training/TreeLearners/LeastSquaresRegressionTreeLearner.cs b/src/Microsoft.ML.FastTree/Training/TreeLearners/LeastSquaresRegressionTreeLearner.cs index cf382ee46f..7853ee88bc 100644 --- a/src/Microsoft.ML.FastTree/Training/TreeLearners/LeastSquaresRegressionTreeLearner.cs +++ b/src/Microsoft.ML.FastTree/Training/TreeLearners/LeastSquaresRegressionTreeLearner.cs @@ -69,6 +69,8 @@ internal class LeastSquaresRegressionTreeLearner : TreeLearner, ILeafSplitStatis protected readonly bool FilterZeros; protected readonly double BsrMaxTreeOutput; + protected readonly IHost Host; + // size of reserved memory private readonly long _sizeOfReservedMemory; @@ -114,12 +116,13 @@ internal class LeastSquaresRegressionTreeLearner : TreeLearner, ILeafSplitStatis /// /// /// + /// Host public LeastSquaresRegressionTreeLearner(Dataset trainData, int numLeaves, int minDocsInLeaf, double entropyCoefficient, double featureFirstUsePenalty, double featureReusePenalty, double softmaxTemperature, int histogramPoolSize, int randomSeed, double splitFraction, bool filterZeros, bool allowEmptyTrees, double gainConfidenceLevel, int maxCategoricalGroupsPerNode, int maxCategoricalSplitPointPerNode, double bsrMaxTreeOutput, IParallelTraining parallelTraining, double minDocsPercentageForCategoricalSplit, - Bundle bundling, int minDocsForCategoricalSplit, double bias) + Bundle bundling, int minDocsForCategoricalSplit, double bias, IHost host) : base(trainData, numLeaves) { MinDocsInLeaf = minDocsInLeaf; @@ -135,6 +138,7 @@ public LeastSquaresRegressionTreeLearner(Dataset trainData, int numLeaves, int m MinDocsForCategoricalSplit = minDocsForCategoricalSplit; Bundling = bundling; Bias = bias; + Host = host; _calculateLeafSplitCandidates = ThreadTaskManager.MakeTask( FindBestThresholdForFlockThreadWorker, TrainData.NumFlocks); @@ -148,6 +152,7 @@ public LeastSquaresRegressionTreeLearner(Dataset trainData, int numLeaves, int m histogramPool[i] = new SufficientStatsBase[TrainData.NumFlocks]; for (int j = 0; j < TrainData.NumFlocks; j++) { + Host.CheckAlive(); var ss = histogramPool[i][j] = TrainData.Flocks[j].CreateSufficientStats(HasWeights); _sizeOfReservedMemory += ss.SizeInBytes(); } @@ -498,6 +503,7 @@ protected virtual void SetBestFeatureForLeaf(LeafSplitCandidates leafSplitCandid /// private void FindBestThresholdForFlockThreadWorker(int flock) { + Host.CheckAlive(); int featureMin = TrainData.FlockToFirstFeature(flock); int featureLim = featureMin + TrainData.Flocks[flock].Count; // Check if any feature is active. @@ -649,6 +655,8 @@ public double CalculateSplittedLeafOutput(int count, double sumTargets, double s protected virtual void FindBestThresholdFromHistogram(SufficientStatsBase histogram, LeafSplitCandidates leafSplitCandidates, int flock) { + Host.CheckAlive(); + // Cache histograms for the parallel interface. int featureMin = TrainData.FlockToFirstFeature(flock); int featureLim = featureMin + TrainData.Flocks[flock].Count;