diff --git a/src/Microsoft.ML.FastTree/RandomForestClassification.cs b/src/Microsoft.ML.FastTree/RandomForestClassification.cs index 74957f9aea..62809c333d 100644 --- a/src/Microsoft.ML.FastTree/RandomForestClassification.cs +++ b/src/Microsoft.ML.FastTree/RandomForestClassification.cs @@ -265,7 +265,15 @@ public static extern unsafe int DecisionForestClassificationCompute( [BestFriend] private bool IsDispatchingToOneDalEnabled() { - return OneDalUtils.IsDispatchingEnabled(); + try + { + return OneDalUtils.IsDispatchingEnabled(); + } + catch (Exception) + { + // Bail to default implementation upon encountering any situation where dispatch failed + return false; + } } [BestFriend] diff --git a/src/Microsoft.ML.FastTree/RandomForestRegression.cs b/src/Microsoft.ML.FastTree/RandomForestRegression.cs index 4598b13aae..f1969f2cb2 100644 --- a/src/Microsoft.ML.FastTree/RandomForestRegression.cs +++ b/src/Microsoft.ML.FastTree/RandomForestRegression.cs @@ -398,7 +398,15 @@ public static extern unsafe int DecisionForestRegressionCompute( [BestFriend] private bool IsDispatchingToOneDalEnabled() { - return OneDalUtils.IsDispatchingEnabled(); + try + { + return OneDalUtils.IsDispatchingEnabled(); + } + catch (Exception) + { + // fall back to original implementation for any circumstance that prevents dispatching + return false; + } } [BestFriend] diff --git a/src/Microsoft.ML.Mkl.Components/OlsLinearRegression.cs b/src/Microsoft.ML.Mkl.Components/OlsLinearRegression.cs index 57d32ff80b..6f4f721121 100644 --- a/src/Microsoft.ML.Mkl.Components/OlsLinearRegression.cs +++ b/src/Microsoft.ML.Mkl.Components/OlsLinearRegression.cs @@ -14,9 +14,9 @@ using Microsoft.ML.Internal.Internallearn; using Microsoft.ML.Internal.Utilities; using Microsoft.ML.Model; +using Microsoft.ML.OneDal; using Microsoft.ML.Runtime; using Microsoft.ML.Trainers; -using Microsoft.ML.OneDal; [assembly: LoadableClass(OlsTrainer.Summary, typeof(OlsTrainer), typeof(OlsTrainer.Options), new[] { typeof(SignatureRegressorTrainer), typeof(SignatureTrainer), typeof(SignatureFeatureScorerTrainer) }, @@ -409,7 +409,15 @@ private void ComputeMklRegression(IChannel ch, FloatLabelCursor.Factory cursorFa [BestFriend] private bool IsDispatchingToOneDalEnabled() { - return OneDalUtils.IsDispatchingEnabled(); + try + { + return OneDalUtils.IsDispatchingEnabled(); + } + catch (Exception) + { + // Bail to default implementation upon any situation that prevents dispatching + return false; + } } private OlsModelParameters TrainCore(IChannel ch, FloatLabelCursor.Factory cursorFactory, int featureCount) diff --git a/src/Microsoft.ML.OneDal/OneDalUtils.cs b/src/Microsoft.ML.OneDal/OneDalUtils.cs index c9d347f56b..061c4a3a88 100644 --- a/src/Microsoft.ML.OneDal/OneDalUtils.cs +++ b/src/Microsoft.ML.OneDal/OneDalUtils.cs @@ -3,8 +3,8 @@ // See the LICENSE file in the project root for more information. using System; -using System.IO; using System.Collections.Generic; +using System.IO; using System.Linq; using System.Runtime.InteropServices; using Microsoft.ML.Internal.Utilities; @@ -17,14 +17,6 @@ namespace Microsoft.ML.OneDal internal static class OneDalUtils { -#if false - [BestFriend] - internal static bool IsDispatchingEnabled() - { - return Environment.GetEnvironmentVariable("MLNET_BACKEND") == "ONEDAL" && - System.Runtime.InteropServices.RuntimeInformation.ProcessArchitecture == System.Runtime.InteropServices.Architecture.X64; - } -#else [BestFriend] internal static bool IsDispatchingEnabled() { @@ -47,7 +39,6 @@ internal static bool IsDispatchingEnabled() } return false; } -#endif [BestFriend] internal static long GetTrainData(IChannel channel, FloatLabelCursor.Factory cursorFactory, ref List featuresList, ref List labelsList, int numberOfFeatures)