Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/Microsoft.ML.FastTree/BoostingFastTree.cs
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,8 @@ private protected override TreeLearner ConstructTreeLearner(IChannel ch)
FastTreeTrainerOptions.HistogramPoolSize, FastTreeTrainerOptions.Seed, FastTreeTrainerOptions.FeatureFractionPerSplit, FastTreeTrainerOptions.FilterZeroLambdas,
FastTreeTrainerOptions.AllowEmptyTrees, FastTreeTrainerOptions.GainConfidenceLevel, FastTreeTrainerOptions.MaximumCategoricalGroupCountPerNode,
FastTreeTrainerOptions.MaximumCategoricalSplitPointCount, BsrMaxTreeOutput(), ParallelTraining,
FastTreeTrainerOptions.MinimumExampleFractionForCategoricalSplit, FastTreeTrainerOptions.Bundling, FastTreeTrainerOptions.MinimumExamplesForCategoricalSplit, FastTreeTrainerOptions.Bias);
FastTreeTrainerOptions.MinimumExampleFractionForCategoricalSplit, FastTreeTrainerOptions.Bundling, FastTreeTrainerOptions.MinimumExamplesForCategoricalSplit,
FastTreeTrainerOptions.Bias, Host);
}

private protected override OptimizationAlgorithm ConstructOptimizationAlgorithm(IChannel ch)
Expand Down
7 changes: 7 additions & 0 deletions src/Microsoft.ML.FastTree/FastTree.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1334,6 +1334,7 @@ private ValueMapper<VBuffer<T1>, VBuffer<T2>> GetCopier<T1, T2>(DataViewType ite

private Dataset Construct(RoleMappedData examples, ref int numExamples, int maxBins, IParallelTraining parallelTraining)
{
Host.CheckAlive();
Host.AssertValue(examples);
Host.Assert(examples.Schema.Feature.HasValue);

Expand Down Expand Up @@ -1414,6 +1415,7 @@ private Dataset Construct(RoleMappedData examples, ref int numExamples, int maxB
pch.SetHeader(new ProgressHeader("features"), e => e.SetProgress(0, iFeature, features.Length));
while (cursor.MoveNext())
{
Host.CheckAlive();
iFeature = cursor.SlotIndex;
if (!localConstructBinFeatures[iFeature])
continue;
Expand Down Expand Up @@ -1489,6 +1491,8 @@ private Dataset Construct(RoleMappedData examples, ref int numExamples, int maxB
int catRangeIndex = 0;
for (iFeature = 0; iFeature < NumFeatures;)
{
Host.CheckAlive();

if (catRangeIndex < CategoricalFeatureIndices.Length &&
CategoricalFeatureIndices[catRangeIndex] == iFeature)
{
Expand Down Expand Up @@ -1565,6 +1569,7 @@ private Dataset Construct(RoleMappedData examples, ref int numExamples, int maxB
{
for (int i = 0; i < NumFeatures; i++)
{
Host.CheckAlive();
GetFeatureValues(cursor, i, getter, ref temp, ref doubleTemp, copier);
double[] upperBounds = BinUpperBounds[i];
Host.AssertValue(upperBounds);
Expand Down Expand Up @@ -1919,6 +1924,7 @@ private void InitializeBins(int maxBins, IParallelTraining parallelTraining)
List<int> trivialFeatures = new List<int>();
for (iFeature = 0; iFeature < NumFeatures; iFeature++)
{
Host.CheckAlive();
if (!localConstructBinFeatures[iFeature])
continue;
// The following strange call will actually sparsify.
Expand Down Expand Up @@ -2230,6 +2236,7 @@ private IEnumerable<FeatureFlockBase> CreateFlocksCore(IChannel ch, IProgressCha

for (; iFeature < featureLim; ++iFeature)
{
Host.CheckAlive();
double[] bup = BinUpperBounds[iFeature];
Contracts.Assert(Utils.Size(bup) > 0);
if (bup.Length == 1)
Expand Down
3 changes: 2 additions & 1 deletion src/Microsoft.ML.FastTree/RandomForest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,8 @@ private protected override TreeLearner ConstructTreeLearner(IChannel ch)
FastTreeTrainerOptions.HistogramPoolSize, FastTreeTrainerOptions.Seed, FastTreeTrainerOptions.FeatureFractionPerSplit,
FastTreeTrainerOptions.AllowEmptyTrees, FastTreeTrainerOptions.GainConfidenceLevel, FastTreeTrainerOptions.MaximumCategoricalGroupCountPerNode,
FastTreeTrainerOptions.MaximumCategoricalSplitPointCount, _quantileEnabled, FastTreeTrainerOptions.NumberOfQuantileSamples, ParallelTraining,
FastTreeTrainerOptions.MinimumExampleFractionForCategoricalSplit, FastTreeTrainerOptions.Bundling, FastTreeTrainerOptions.MinimumExamplesForCategoricalSplit, FastTreeTrainerOptions.Bias);
FastTreeTrainerOptions.MinimumExampleFractionForCategoricalSplit, FastTreeTrainerOptions.Bundling, FastTreeTrainerOptions.MinimumExamplesForCategoricalSplit,
FastTreeTrainerOptions.Bias, Host);
}

internal abstract class RandomForestObjectiveFunction : ObjectiveFunctionBase
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@ internal class RandomForestLeastSquaresTreeLearner : LeastSquaresRegressionTreeL
public RandomForestLeastSquaresTreeLearner(Dataset trainData, int numLeaves, int minDocsInLeaf, Double entropyCoefficient, Double featureFirstUsePenalty,
Double featureReusePenalty, Double softmaxTemperature, int histogramPoolSize, int randomSeed, Double splitFraction, bool allowEmptyTrees,
Double gainConfidenceLevel, int maxCategoricalGroupsPerNode, int maxCategoricalSplitPointsPerNode, bool quantileEnabled, int quantileSampleCount, IParallelTraining parallelTraining,
double minDocsPercentageForCategoricalSplit, Bundle bundling, int minDocsForCategoricalSplit, double bias)
double minDocsPercentageForCategoricalSplit, Bundle bundling, int minDocsForCategoricalSplit, double bias, IHost host)
: base(trainData, numLeaves, minDocsInLeaf, entropyCoefficient, featureFirstUsePenalty, featureReusePenalty, softmaxTemperature, histogramPoolSize,
randomSeed, splitFraction, false, allowEmptyTrees, gainConfidenceLevel, maxCategoricalGroupsPerNode, maxCategoricalSplitPointsPerNode, -1, parallelTraining,
minDocsPercentageForCategoricalSplit, bundling, minDocsForCategoricalSplit, bias)
minDocsPercentageForCategoricalSplit, bundling, minDocsForCategoricalSplit, bias, host)
{
_quantileSampleCount = quantileSampleCount;
_quantileEnabled = quantileEnabled;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ internal class LeastSquaresRegressionTreeLearner : TreeLearner, ILeafSplitStatis
protected readonly bool FilterZeros;
protected readonly double BsrMaxTreeOutput;

protected readonly IHost Host;

// size of reserved memory
private readonly long _sizeOfReservedMemory;

Expand Down Expand Up @@ -114,12 +116,13 @@ internal class LeastSquaresRegressionTreeLearner : TreeLearner, ILeafSplitStatis
/// <param name="bundling"></param>
/// <param name="minDocsForCategoricalSplit"></param>
/// <param name="bias"></param>
/// <param name="host">Host</param>
public LeastSquaresRegressionTreeLearner(Dataset trainData, int numLeaves, int minDocsInLeaf, double entropyCoefficient,
double featureFirstUsePenalty, double featureReusePenalty, double softmaxTemperature, int histogramPoolSize,
int randomSeed, double splitFraction, bool filterZeros, bool allowEmptyTrees, double gainConfidenceLevel,
int maxCategoricalGroupsPerNode, int maxCategoricalSplitPointPerNode,
double bsrMaxTreeOutput, IParallelTraining parallelTraining, double minDocsPercentageForCategoricalSplit,
Bundle bundling, int minDocsForCategoricalSplit, double bias)
Bundle bundling, int minDocsForCategoricalSplit, double bias, IHost host)
: base(trainData, numLeaves)
{
MinDocsInLeaf = minDocsInLeaf;
Expand All @@ -135,6 +138,7 @@ public LeastSquaresRegressionTreeLearner(Dataset trainData, int numLeaves, int m
MinDocsForCategoricalSplit = minDocsForCategoricalSplit;
Bundling = bundling;
Bias = bias;
Host = host;

_calculateLeafSplitCandidates = ThreadTaskManager.MakeTask(
FindBestThresholdForFlockThreadWorker, TrainData.NumFlocks);
Expand All @@ -148,6 +152,7 @@ public LeastSquaresRegressionTreeLearner(Dataset trainData, int numLeaves, int m
histogramPool[i] = new SufficientStatsBase[TrainData.NumFlocks];
for (int j = 0; j < TrainData.NumFlocks; j++)
{
Host.CheckAlive();
var ss = histogramPool[i][j] = TrainData.Flocks[j].CreateSufficientStats(HasWeights);
_sizeOfReservedMemory += ss.SizeInBytes();
}
Expand Down Expand Up @@ -498,6 +503,7 @@ protected virtual void SetBestFeatureForLeaf(LeafSplitCandidates leafSplitCandid
/// </summary>
private void FindBestThresholdForFlockThreadWorker(int flock)
{
Host.CheckAlive();
int featureMin = TrainData.FlockToFirstFeature(flock);
int featureLim = featureMin + TrainData.Flocks[flock].Count;
// Check if any feature is active.
Expand Down Expand Up @@ -649,6 +655,8 @@ public double CalculateSplittedLeafOutput(int count, double sumTargets, double s
protected virtual void FindBestThresholdFromHistogram(SufficientStatsBase histogram,
LeafSplitCandidates leafSplitCandidates, int flock)
{
Host.CheckAlive();

// Cache histograms for the parallel interface.
int featureMin = TrainData.FlockToFirstFeature(flock);
int featureLim = featureMin + TrainData.Flocks[flock].Count;
Expand Down