diff --git a/src/Microsoft.ML.AutoML/ColumnInference/ColumnTypeInference.cs b/src/Microsoft.ML.AutoML/ColumnInference/ColumnTypeInference.cs index 6d00054dc5..8401a4578c 100644 --- a/src/Microsoft.ML.AutoML/ColumnInference/ColumnTypeInference.cs +++ b/src/Microsoft.ML.AutoML/ColumnInference/ColumnTypeInference.cs @@ -82,7 +82,7 @@ public bool HasAllBooleanValues() bool value; // (note: Conversions.Instance.TryParse parses an empty string as a Boolean) return !string.IsNullOrEmpty(x.ToString()) && - Conversions.Instance.TryParse(in x, out value); + Conversions.DefaultInstance.TryParse(in x, out value); })) { return true; @@ -164,7 +164,7 @@ public void Apply(IntermediateColumn[] columns) col.SuggestedType = BooleanDataViewType.Instance; bool first; - col.HasHeader = !Conversions.Instance.TryParse(in col.RawData[0], out first); + col.HasHeader = !Conversions.DefaultInstance.TryParse(in col.RawData[0], out first); } } } @@ -179,7 +179,7 @@ public void Apply(IntermediateColumn[] columns) .All(x => { float value; - return Conversions.Instance.TryParse(in x, out value); + return Conversions.DefaultInstance.TryParse(in x, out value); }) ) { diff --git a/src/Microsoft.ML.Core/Utilities/DoubleParser.cs b/src/Microsoft.ML.Core/Utilities/DoubleParser.cs index 22bd8ea82e..58ee99f5ef 100644 --- a/src/Microsoft.ML.Core/Utilities/DoubleParser.cs +++ b/src/Microsoft.ML.Core/Utilities/DoubleParser.cs @@ -12,26 +12,24 @@ namespace Microsoft.ML.Internal.Utilities [BestFriend] internal static class DoubleParser { + [BestFriend] + [Flags] + internal enum OptionFlags : uint + { + Default = 0x00, + + // If this flag is set, then a "," will be used as Decimal Marker + // (i.e., the punctuation mark that separates the integer part of + // a number and its decimal part). If this isn't set, then + // default behavior is to use "." as decimal marker. 
+ UseCommaAsDecimalMarker = 0x01, + } + private const ulong TopBit = 0x8000000000000000UL; private const ulong TopTwoBits = 0xC000000000000000UL; private const ulong TopThreeBits = 0xE000000000000000UL; private const char InfinitySymbol = '\u221E'; - // Note for future development: DoubleParser is a static class and DecimalMarker is a - // static variable, which means only one instance of these can exist at once. As such, - // the value of DecimalMarker cannot vary when datasets with differing decimal markers - // are loaded together at once, which would result in not being able to accurately read - // the dataset with the differing decimal marker. Although this edge case where we attempt - // to load in datasets with different decimal markers at once is unlikely to occur, we - // should still be aware of this and plan to fix it in the future. - - // The decimal marker that separates the integer part from the fractional part of a number - // written in decimal from can vary across different cultures as either '.' or ','. The - // default decimal marker in ML .NET is '.', however through this static char variable, - // we allow users to specify the decimal marker used in their datasets as ',' as well. - [BestFriend] - internal static char DecimalMarker = '.'; - // REVIEW: casting ulong to Double doesn't always do the right thing, for example // with 0x84595161401484A0UL. Hence the gymnastics several places in this code. Note that // long to Double does work. The work around is: @@ -85,9 +83,9 @@ public enum Result /// /// This produces zero for an empty string. 
/// - public static bool TryParse(ReadOnlySpan span, out Single value) + public static bool TryParse(ReadOnlySpan span, out Single value, OptionFlags flags = OptionFlags.Default) { - var res = Parse(span, out value); + var res = Parse(span, out value, flags); Contracts.Assert(res != Result.Empty || value == 0); return res <= Result.Empty; } @@ -95,14 +93,14 @@ public static bool TryParse(ReadOnlySpan span, out Single value) /// /// This produces zero for an empty string. /// - public static bool TryParse(ReadOnlySpan span, out Double value) + public static bool TryParse(ReadOnlySpan span, out Double value, OptionFlags flags = OptionFlags.Default) { - var res = Parse(span, out value); + var res = Parse(span, out value, flags); Contracts.Assert(res != Result.Empty || value == 0); return res <= Result.Empty; } - public static Result Parse(ReadOnlySpan span, out Single value) + public static Result Parse(ReadOnlySpan span, out Single value, OptionFlags flags = OptionFlags.Default) { int ich = 0; for (; ; ich++) @@ -133,7 +131,7 @@ public static Result Parse(ReadOnlySpan span, out Single value) } int ichEnd; - if (!DoubleParser.TryParse(span.Slice(ich, span.Length - ich), out value, out ichEnd)) + if (!DoubleParser.TryParse(span.Slice(ich, span.Length - ich), out value, out ichEnd, flags)) { value = default(Single); return Result.Error; @@ -150,7 +148,7 @@ public static Result Parse(ReadOnlySpan span, out Single value) return Result.Good; } - public static Result Parse(ReadOnlySpan span, out Double value) + public static Result Parse(ReadOnlySpan span, out Double value, OptionFlags flags = OptionFlags.Default) { int ich = 0; for (; ; ich++) @@ -181,7 +179,7 @@ public static Result Parse(ReadOnlySpan span, out Double value) } int ichEnd; - if (!DoubleParser.TryParse(span.Slice(ich, span.Length - ich), out value, out ichEnd)) + if (!DoubleParser.TryParse(span.Slice(ich, span.Length - ich), out value, out ichEnd, flags)) { value = default(Double); return Result.Error; @@ 
-198,14 +196,14 @@ public static Result Parse(ReadOnlySpan span, out Double value) return Result.Good; } - public static bool TryParse(ReadOnlySpan span, out Single value, out int ichEnd) + public static bool TryParse(ReadOnlySpan span, out Single value, out int ichEnd, OptionFlags flags = OptionFlags.Default) { bool neg = false; ulong num = 0; long exp = 0; ichEnd = 0; - if (!TryParseCore(span, ref ichEnd, ref neg, ref num, ref exp)) + if (!TryParseCore(span, ref ichEnd, ref neg, ref num, ref exp, flags)) return TryParseSpecial(span, ref ichEnd, out value); if (num == 0) @@ -287,14 +285,14 @@ public static bool TryParse(ReadOnlySpan span, out Single value, out int i return true; } - public static bool TryParse(ReadOnlySpan span, out Double value, out int ichEnd) + public static bool TryParse(ReadOnlySpan span, out Double value, out int ichEnd, OptionFlags flags = OptionFlags.Default) { bool neg = false; ulong num = 0; long exp = 0; ichEnd = 0; - if (!TryParseCore(span, ref ichEnd, ref neg, ref num, ref exp)) + if (!TryParseCore(span, ref ichEnd, ref neg, ref num, ref exp, flags)) return TryParseSpecial(span, ref ichEnd, out value); if (num == 0) @@ -535,13 +533,19 @@ private static bool TryParseSpecial(ReadOnlySpan span, ref int ich, out Si return false; } - private static bool TryParseCore(ReadOnlySpan span, ref int ich, ref bool neg, ref ulong num, ref long exp) + private static bool TryParseCore(ReadOnlySpan span, ref int ich, ref bool neg, ref ulong num, ref long exp, OptionFlags flags = OptionFlags.Default) { Contracts.Assert(0 <= ich & ich <= span.Length); Contracts.Assert(!neg); Contracts.Assert(num == 0); Contracts.Assert(exp == 0); + char decimalMarker; + if ((flags & OptionFlags.UseCommaAsDecimalMarker) != 0) + decimalMarker = ','; + else + decimalMarker = '.'; + if (ich >= span.Length) return false; @@ -570,11 +574,11 @@ private static bool TryParseCore(ReadOnlySpan span, ref int ich, ref bool break; case '.': - if (DecimalMarker != '.') // Decimal 
marker was not '.', but we encountered a '.', which must be an error. + if (decimalMarker != '.') // Decimal marker was not '.', but we encountered a '.', which must be an error. return false; // Since this was an error, return false, which will later make the caller to set NaN as the out value. goto LPoint; case ',': - if (DecimalMarker != ',') // Same logic as above. + if (decimalMarker != ',') // Same logic as above. return false; goto LPoint; @@ -614,12 +618,12 @@ private static bool TryParseCore(ReadOnlySpan span, ref int ich, ref bool } Contracts.Assert(i < span.Length); - if (span[i] != DecimalMarker) + if (span[i] != decimalMarker) goto LAfterDigits; LPoint: Contracts.Assert(i < span.Length); - Contracts.Assert(span[i] == DecimalMarker); + Contracts.Assert(span[i] == decimalMarker); // Get the digits after the decimal marker, which may be '.' or ',' for (; ; ) diff --git a/src/Microsoft.ML.Data/Commands/ShowSchemaCommand.cs b/src/Microsoft.ML.Data/Commands/ShowSchemaCommand.cs index c68746037a..bd9485d92d 100644 --- a/src/Microsoft.ML.Data/Commands/ShowSchemaCommand.cs +++ b/src/Microsoft.ML.Data/Commands/ShowSchemaCommand.cs @@ -234,7 +234,7 @@ private static void ShowMetadataValue(IndentedTextWriter itw, DataViewSchema Contracts.Assert(!(type is VectorDataViewType)); Contracts.Assert(type.RawType == typeof(T)); - var conv = Conversions.Instance.GetStringConversion(type); + var conv = Conversions.DefaultInstance.GetStringConversion(type); var value = default(T); var sb = default(StringBuilder); @@ -272,7 +272,7 @@ private static void ShowMetadataValueVec(IndentedTextWriter itw, DataViewSche Contracts.AssertValue(type); Contracts.Assert(type.ItemType.RawType == typeof(T)); - var conv = Conversions.Instance.GetStringConversion(type.ItemType); + var conv = Conversions.DefaultInstance.GetStringConversion(type.ItemType); var value = default(VBuffer); schema[col].Annotations.GetValue(kind, ref value); diff --git 
a/src/Microsoft.ML.Data/Commands/TypeInfoCommand.cs b/src/Microsoft.ML.Data/Commands/TypeInfoCommand.cs index 12cb66c79c..fb100e361f 100644 --- a/src/Microsoft.ML.Data/Commands/TypeInfoCommand.cs +++ b/src/Microsoft.ML.Data/Commands/TypeInfoCommand.cs @@ -79,7 +79,7 @@ public void Run() { using (var ch = _host.Start("Run")) { - var conv = Conversions.Instance; + var conv = Conversions.DefaultInstance; var comp = new SetOfKindsComparer(); var dstToSrcMap = new Dictionary, HashSet>(comp); var srcToDstMap = new Dictionary>(); @@ -143,7 +143,7 @@ private TypeNaInfo KindReport(IChannel ch, PrimitiveDataViewType type) ch.AssertValue(type); ch.Assert(type.IsStandardScalar()); - var conv = Conversions.Instance; + var conv = Conversions.DefaultInstance; InPredicate isNaDel; bool hasNaPred = conv.TryGetIsNAPredicate(type, out isNaDel); bool defaultIsNa = false; diff --git a/src/Microsoft.ML.Data/Data/Conversion.cs b/src/Microsoft.ML.Data/Data/Conversion.cs index 82de43bb48..322856c22f 100644 --- a/src/Microsoft.ML.Data/Data/Conversion.cs +++ b/src/Microsoft.ML.Data/Data/Conversion.cs @@ -53,18 +53,26 @@ private static readonly FuncInstanceMethodInfo1 _delegatesStd; @@ -92,7 +100,7 @@ public static Conversions Instance // This has TryParseMapper delegates for parsing values from text. private readonly Dictionary _tryParseDelegates; - private Conversions() + private Conversions(DoubleParser.OptionFlags doubleParserOptionFlags = DoubleParser.OptionFlags.Default) { _delegatesStd = new Dictionary<(Type src, Type dst), Delegate>(); _delegatesAll = new Dictionary<(Type src, Type dst), Delegate>(); @@ -102,6 +110,7 @@ private Conversions() _hasZeroDelegates = new Dictionary(); _getNADelegates = new Dictionary(); _tryParseDelegates = new Dictionary(); + _doubleParserOptionFlags = doubleParserOptionFlags; // !!! WARNING !!!: Do NOT add any standard conversions without clearing from the IDV Type System // design committee. 
Any changes also require updating the IDV Type System Specification. @@ -1333,7 +1342,7 @@ private void TryParseSigned(long max, in TX text, out long? result) public bool TryParse(in TX src, out R4 dst) { var span = src.Span; - if (DoubleParser.TryParse(span, out dst)) + if (DoubleParser.TryParse(span, out dst, _doubleParserOptionFlags)) return true; dst = R4.NaN; return IsStdMissing(ref span); @@ -1346,7 +1355,7 @@ public bool TryParse(in TX src, out R4 dst) public bool TryParse(in TX src, out R8 dst) { var span = src.Span; - if (DoubleParser.TryParse(span, out dst)) + if (DoubleParser.TryParse(span, out dst, _doubleParserOptionFlags)) return true; dst = R8.NaN; return IsStdMissing(ref span); @@ -1630,7 +1639,7 @@ public void Convert(in TX span, ref UG value) public void Convert(in TX src, ref R4 value) { var span = src.Span; - if (DoubleParser.TryParse(span, out value)) + if (DoubleParser.TryParse(span, out value, _doubleParserOptionFlags)) return; // Unparsable is mapped to NA. value = R4.NaN; @@ -1638,7 +1647,7 @@ public void Convert(in TX src, ref R4 value) public void Convert(in TX src, ref R8 value) { var span = src.Span; - if (DoubleParser.TryParse(span, out value)) + if (DoubleParser.TryParse(span, out value, _doubleParserOptionFlags)) return; // Unparsable is mapped to NA. 
value = R8.NaN; diff --git a/src/Microsoft.ML.Data/Data/DataViewUtils.cs b/src/Microsoft.ML.Data/Data/DataViewUtils.cs index 0524d4c728..829118a430 100644 --- a/src/Microsoft.ML.Data/Data/DataViewUtils.cs +++ b/src/Microsoft.ML.Data/Data/DataViewUtils.cs @@ -1352,7 +1352,7 @@ public static ValueGetter> GetSingleValueGetter(DataView var floatGetter = cursor.GetGetter(cursor.Schema[i]); T v = default(T); ValueMapper conversion; - if (!Conversions.Instance.TryGetStringConversion(colType, out conversion)) + if (!Conversions.DefaultInstance.TryGetStringConversion(colType, out conversion)) { var error = $"Cannot display {colType}"; conversion = (in T src, ref StringBuilder builder) => @@ -1383,7 +1383,7 @@ public static ValueGetter> GetVectorFlatteningGetter(Dat var vbuf = default(VBuffer); const int previewValues = 100; ValueMapper conversion; - Conversions.Instance.TryGetStringConversion(colType, out conversion); + Conversions.DefaultInstance.TryGetStringConversion(colType, out conversion); StringBuilder dst = null; ValueGetter> getter = (ref ReadOnlyMemory value) => diff --git a/src/Microsoft.ML.Data/Data/RowCursorUtils.cs b/src/Microsoft.ML.Data/Data/RowCursorUtils.cs index a34070d791..7c5cd1e28d 100644 --- a/src/Microsoft.ML.Data/Data/RowCursorUtils.cs +++ b/src/Microsoft.ML.Data/Data/RowCursorUtils.cs @@ -94,7 +94,7 @@ private static ValueGetter GetGetterAsCore(DataViewType typeSr var getter = row.GetGetter(row.Schema[col]); bool identity; - var conv = Conversions.Instance.GetStandardConversion(typeSrc, typeDst, out identity); + var conv = Conversions.DefaultInstance.GetStandardConversion(typeSrc, typeDst, out identity); if (identity) { Contracts.Assert(typeof(TSrc) == typeof(TDst)); @@ -134,7 +134,7 @@ private static ValueGetter GetGetterAsStringBuilderCore(Dat Contracts.Assert(typeof(TSrc) == typeSrc.RawType); var getter = row.GetGetter(row.Schema[col]); - var conv = Conversions.Instance.GetStringConversion(typeSrc); + var conv = 
Conversions.DefaultInstance.GetStringConversion(typeSrc); var src = default(TSrc); return @@ -260,7 +260,7 @@ private static ValueGetter> GetVecGetterAsCore(VectorD var getter = getterFact.GetGetter>(); bool identity; - var conv = Conversions.Instance.GetStandardConversion(typeSrc.ItemType, typeDst, out identity); + var conv = Conversions.DefaultInstance.GetStandardConversion(typeSrc.ItemType, typeDst, out identity); if (identity) { Contracts.Assert(typeof(TSrc) == typeof(TDst)); diff --git a/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoader.cs b/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoader.cs index 6bc58de054..2cd1e8e8cf 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoader.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoader.cs @@ -1631,7 +1631,6 @@ public BoundLoader(TextLoader loader, IMultiStreamSource files) public DataViewRowCursor GetRowCursor(IEnumerable columnsNeeded, Random rand = null) { _host.CheckValueOrNull(rand); - DoubleParser.DecimalMarker = _loader._decimalMarker; var active = Utils.BuildArray(_loader._bindings.OutputSchema.Count, columnsNeeded); return Cursor.Create(_loader, _files, active); } @@ -1639,7 +1638,6 @@ public DataViewRowCursor GetRowCursor(IEnumerable columns public DataViewRowCursor[] GetRowCursorSet(IEnumerable columnsNeeded, int n, Random rand = null) { _host.CheckValueOrNull(rand); - DoubleParser.DecimalMarker = _loader._decimalMarker; var active = Utils.BuildArray(_loader._bindings.OutputSchema.Count, columnsNeeded); return Cursor.CreateSet(_loader, _files, active, n); } diff --git a/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoaderParser.cs b/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoaderParser.cs index a6f8b73ba6..897c257e2a 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoaderParser.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoaderParser.cs @@ -3,6 +3,7 @@ // See the LICENSE file in the project root for more information. 
using System; +using System.Collections.Concurrent; using System.Collections.Generic; using System.Linq; using System.Reflection; @@ -30,26 +31,40 @@ private static readonly FuncInstanceMethodInfo1> _getCreatorVecCoreMethodInfo = FuncInstanceMethodInfo1>.Create(target => target.GetCreatorVecCore); - private static volatile ValueCreatorCache _instance; - public static ValueCreatorCache Instance + private static volatile ValueCreatorCache _defaultInstance; + public static ValueCreatorCache DefaultInstance { get { - return _instance ?? - Interlocked.CompareExchange(ref _instance, new ValueCreatorCache(), null) ?? - _instance; + return _defaultInstance ?? + Interlocked.CompareExchange(ref _defaultInstance, new ValueCreatorCache(), null) ?? + _defaultInstance; } } + private static readonly ConcurrentDictionary _customInstances + = new ConcurrentDictionary(); + + public static ValueCreatorCache GetInstanceWithDoubleParserOptionFlags(DoubleParser.OptionFlags doubleParserOptionFlags) + { + if (!_customInstances.ContainsKey(doubleParserOptionFlags)) + return _customInstances.GetOrAdd(doubleParserOptionFlags, new ValueCreatorCache(doubleParserOptionFlags)); + + return _customInstances[doubleParserOptionFlags]; + } + private readonly Conversions _conv; // Indexed by DataKind.ToIndex() private readonly Func[] _creatorsOne; private readonly Func[] _creatorsVec; - private ValueCreatorCache() + private ValueCreatorCache(DoubleParser.OptionFlags doubleParserOptionFlags = DoubleParser.OptionFlags.Default) { - _conv = Conversions.Instance; + if (doubleParserOptionFlags == DoubleParser.OptionFlags.Default) + _conv = Conversions.DefaultInstance; + else + _conv = Conversions.CreateInstanceWithDoubleParserOptions(doubleParserOptionFlags); _creatorsOne = new Func[InternalDataKindExtensions.KindCount]; _creatorsVec = new Func[InternalDataKindExtensions.KindCount]; @@ -243,7 +258,7 @@ public PrimitivePipe(RowSet rows, PrimitiveDataViewType type, TryParseMapper _values = new 
VectorValue[Rows.Count]; for (int i = 0; i < _values.Length; i++) _values[i] = new VectorValue(this); - HasNA = Conversions.Instance.TryGetIsNAPredicate(type, out var del); + HasNA = Conversions.DefaultInstance.TryGetIsNAPredicate(type, out var del); } public override void Reset(int irow, int size) @@ -650,7 +665,18 @@ public Parser(TextLoader parent) _infos = parent._bindings.Infos; _creator = new Func[_infos.Length]; - var cache = ValueCreatorCache.Instance; + + ValueCreatorCache cache; + + var doubleParserOptionFlags = DoubleParser.OptionFlags.Default; + if (parent._decimalMarker == ',') + doubleParserOptionFlags |= DoubleParser.OptionFlags.UseCommaAsDecimalMarker; + + if (doubleParserOptionFlags == DoubleParser.OptionFlags.Default) + cache = ValueCreatorCache.DefaultInstance; + else + cache = ValueCreatorCache.GetInstanceWithDoubleParserOptionFlags(doubleParserOptionFlags); + var mapOne = new Dictionary>(); var mapVec = new Dictionary>(); for (int i = 0; i < _creator.Length; i++) @@ -1017,7 +1043,7 @@ public int GatherFields(ReadOnlyMemory lineSpan, ReadOnlySpan span, var spanT = Fields.Spans[Fields.Count - 1]; int csrc; - if (!Conversions.Instance.TryParse(in spanT, out csrc) || csrc <= 0) + if (!Conversions.DefaultInstance.TryParse(in spanT, out csrc) || csrc <= 0) { _stats.LogBadFmt(ref scan, "Bad dimensionality or ambiguous sparse item. 
Use sparse=- for non-sparse file, and/or quote the value."); break; diff --git a/src/Microsoft.ML.Data/DataLoadSave/Text/TextSaver.cs b/src/Microsoft.ML.Data/DataLoadSave/Text/TextSaver.cs index 7ad78898fc..14ce35c5a5 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/Text/TextSaver.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/Text/TextSaver.cs @@ -123,7 +123,7 @@ protected ValueWriterBase(PrimitiveDataViewType type, int source, char sep) Conv = (ValueMapper)(Delegate)c; } else - Conv = Conversions.Instance.GetStringConversion(type); + Conv = Conversions.DefaultInstance.GetStringConversion(type); var d = default(T); Conv(in d, ref Sb); diff --git a/src/Microsoft.ML.Data/DataView/LambdaColumnMapper.cs b/src/Microsoft.ML.Data/DataView/LambdaColumnMapper.cs index c84b737304..f6a3726f13 100644 --- a/src/Microsoft.ML.Data/DataView/LambdaColumnMapper.cs +++ b/src/Microsoft.ML.Data/DataView/LambdaColumnMapper.cs @@ -56,7 +56,7 @@ public static IDataView Create(IHostEnvironment env, string name, ID ident = true; conv = null; } - else if (!Conversions.Instance.TryGetStandardConversion(typeOrig, typeSrc, out conv, out ident)) + else if (!Conversions.DefaultInstance.TryGetStandardConversion(typeOrig, typeSrc, out conv, out ident)) { throw env.ExceptParam(nameof(mapper), "The type of column '{0}', '{1}', cannot be converted to the input type of the mapper '{2}'", diff --git a/src/Microsoft.ML.Data/DataView/LambdaFilter.cs b/src/Microsoft.ML.Data/DataView/LambdaFilter.cs index 3961f50da9..70e46156d1 100644 --- a/src/Microsoft.ML.Data/DataView/LambdaFilter.cs +++ b/src/Microsoft.ML.Data/DataView/LambdaFilter.cs @@ -47,7 +47,7 @@ public static IDataView Create(IHostEnvironment env, string name, IDataVie ident = true; conv = null; } - else if (!Conversions.Instance.TryGetStandardConversion(typeOrig, typeSrc, out conv, out ident)) + else if (!Conversions.DefaultInstance.TryGetStandardConversion(typeOrig, typeSrc, out conv, out ident)) { throw env.ExceptParam(nameof(predicate), "The 
type of column '{0}', '{1}', cannot be converted to the input type of the predicate '{2}'", diff --git a/src/Microsoft.ML.Data/DataView/Transposer.cs b/src/Microsoft.ML.Data/DataView/Transposer.cs index 77dc58275b..0d317f0887 100644 --- a/src/Microsoft.ML.Data/DataView/Transposer.cs +++ b/src/Microsoft.ML.Data/DataView/Transposer.cs @@ -370,7 +370,7 @@ protected override bool MoveNextCore() protected override ValueGetter> GetGetterCore() { - var isDefault = Conversion.Conversions.Instance.GetIsDefaultPredicate(_view.Schema[_col].Type); + var isDefault = Conversion.Conversions.DefaultInstance.GetIsDefaultPredicate(_view.Schema[_col].Type); bool valid = false; VBuffer cached = default(VBuffer); return @@ -518,7 +518,7 @@ private void EnsureValid() Ch.Assert(itemType.RawType == typeof(T)); int vecLen = type.GetValueCount(); Ch.Assert(vecLen > 0); - InPredicate isDefault = Conversion.Conversions.Instance.GetIsDefaultPredicate(itemType); + InPredicate isDefault = Conversion.Conversions.DefaultInstance.GetIsDefaultPredicate(itemType); int maxPossibleSize = _rbuff.Length * vecLen; const int sparseThresholdRatio = 5; int sparseThreshold = (maxPossibleSize + sparseThresholdRatio - 1) / sparseThresholdRatio; diff --git a/src/Microsoft.ML.Data/Transforms/InvertHashUtils.cs b/src/Microsoft.ML.Data/Transforms/InvertHashUtils.cs index 314416e1ad..1cd1c1a126 100644 --- a/src/Microsoft.ML.Data/Transforms/InvertHashUtils.cs +++ b/src/Microsoft.ML.Data/Transforms/InvertHashUtils.cs @@ -38,7 +38,7 @@ public static ValueMapper GetSimpleMapper(DataViewSchema sc Contracts.Assert(0 <= col && col < schema.Count); var type = schema[col].Type.GetItemType(); Contracts.Assert(type.RawType == typeof(T)); - var conv = Conversion.Conversions.Instance; + var conv = Conversion.Conversions.DefaultInstance; // First: if not key, then get the standard string conversion. 
if (!(type is KeyDataViewType keyType)) diff --git a/src/Microsoft.ML.Data/Transforms/KeyToValue.cs b/src/Microsoft.ML.Data/Transforms/KeyToValue.cs index 983f9605cd..5ffc82b1a1 100644 --- a/src/Microsoft.ML.Data/Transforms/KeyToValue.cs +++ b/src/Microsoft.ML.Data/Transforms/KeyToValue.cs @@ -323,16 +323,16 @@ public KeyToValueMap(Mapper parent, KeyDataViewType typeKey, PrimitiveDataViewTy // REVIEW: May want to include more specific information about what the specific value is for the default. DataViewType outputItemType = TypeOutput.GetItemType(); - _na = Data.Conversion.Conversions.Instance.GetNAOrDefault(outputItemType, out _naMapsToDefault); + _na = Data.Conversion.Conversions.DefaultInstance.GetNAOrDefault(outputItemType, out _naMapsToDefault); if (_naMapsToDefault) { // Only initialize _isDefault if _defaultIsNA is true as this is the only case in which it is used. - _isDefault = Data.Conversion.Conversions.Instance.GetIsDefaultPredicate(outputItemType); + _isDefault = Data.Conversion.Conversions.DefaultInstance.GetIsDefaultPredicate(outputItemType); } bool identity; - _convertToUInt = Data.Conversion.Conversions.Instance.GetStandardConversion(typeKey, NumberDataViewType.UInt32, out identity); + _convertToUInt = Data.Conversion.Conversions.DefaultInstance.GetStandardConversion(typeKey, NumberDataViewType.UInt32, out identity); } private void MapKey(in TKey src, ref TValue dst) diff --git a/src/Microsoft.ML.Data/Transforms/NAFilter.cs b/src/Microsoft.ML.Data/Transforms/NAFilter.cs index 21a52335ac..620594fb13 100644 --- a/src/Microsoft.ML.Data/Transforms/NAFilter.cs +++ b/src/Microsoft.ML.Data/Transforms/NAFilter.cs @@ -297,7 +297,7 @@ private static ValueOne CreateOne(Cursor cursor, ColInfo info) Contracts.Assert(info.Type.RawType == typeof(T)); var getSrc = cursor.Input.GetGetter(cursor.Input.Schema[info.Index]); - var hasBad = Data.Conversion.Conversions.Instance.GetIsNAPredicate(info.Type); + var hasBad = 
Data.Conversion.Conversions.DefaultInstance.GetIsNAPredicate(info.Type); return new ValueOne(cursor, getSrc, hasBad); } @@ -309,7 +309,7 @@ private static ValueVec CreateVec(Cursor cursor, ColInfo info) Contracts.Assert(info.Type.RawType == typeof(VBuffer)); var getSrc = cursor.Input.GetGetter>(cursor.Input.Schema[info.Index]); - var hasBad = Data.Conversion.Conversions.Instance.GetHasMissingPredicate((VectorDataViewType)info.Type); + var hasBad = Data.Conversion.Conversions.DefaultInstance.GetHasMissingPredicate((VectorDataViewType)info.Type); return new ValueVec(cursor, getSrc, hasBad); } diff --git a/src/Microsoft.ML.Data/Transforms/RangeFilter.cs b/src/Microsoft.ML.Data/Transforms/RangeFilter.cs index d9db6e1f59..a1faa00ef2 100644 --- a/src/Microsoft.ML.Data/Transforms/RangeFilter.cs +++ b/src/Microsoft.ML.Data/Transforms/RangeFilter.cs @@ -430,7 +430,7 @@ public KeyRowCursor(RangeFilter parent, DataViewRowCursor input, bool[] active) dst = _value; }; bool identity; - _conv = Data.Conversion.Conversions.Instance.GetStandardConversion(Parent._type, NumberDataViewType.UInt64, out identity); + _conv = Data.Conversion.Conversions.DefaultInstance.GetStandardConversion(Parent._type, NumberDataViewType.UInt64, out identity); } protected override Delegate GetGetter() diff --git a/src/Microsoft.ML.Data/Transforms/TypeConverting.cs b/src/Microsoft.ML.Data/Transforms/TypeConverting.cs index 1bfe7f8882..0ab4e0520a 100644 --- a/src/Microsoft.ML.Data/Transforms/TypeConverting.cs +++ b/src/Microsoft.ML.Data/Transforms/TypeConverting.cs @@ -421,7 +421,7 @@ private static bool CanConvertToType(IExceptionContext ectx, DataViewType srcTyp // Ensure that the conversion is legal. We don't actually cache the delegate here. It will get // re-fetched by the utils code when needed. 
- if (!Data.Conversion.Conversions.Instance.TryGetStandardConversion(srcType.GetItemType(), itemType, out Delegate del, out bool identity)) + if (!Data.Conversion.Conversions.DefaultInstance.TryGetStandardConversion(srcType.GetItemType(), itemType, out Delegate del, out bool identity)) return false; typeDst = itemType; @@ -626,7 +626,7 @@ public override SchemaShape GetOutputSchema(SchemaShape inputSchema) throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.InputColumnName); if (!TypeConvertingTransformer.GetNewType(Host, col.ItemType, colInfo.OutputKind.ToInternalDataKind(), colInfo.OutputKeyCount, out PrimitiveDataViewType newType)) throw Host.ExceptParam(nameof(inputSchema), $"Can't convert {colInfo.InputColumnName} into {newType.ToString()}"); - if (!Data.Conversion.Conversions.Instance.TryGetStandardConversion(col.ItemType, newType, out Delegate del, out bool identity)) + if (!Data.Conversion.Conversions.DefaultInstance.TryGetStandardConversion(col.ItemType, newType, out Delegate del, out bool identity)) throw Host.ExceptParam(nameof(inputSchema), $"Don't know how to convert {colInfo.InputColumnName} into {newType.ToString()}"); var metadata = new List(); if (col.ItemType is BooleanDataViewType && newType is NumberDataViewType) diff --git a/src/Microsoft.ML.Data/Transforms/ValueMapping.cs b/src/Microsoft.ML.Data/Transforms/ValueMapping.cs index 23f23e5cb4..7a81bee190 100644 --- a/src/Microsoft.ML.Data/Transforms/ValueMapping.cs +++ b/src/Microsoft.ML.Data/Transforms/ValueMapping.cs @@ -495,7 +495,7 @@ private static TextLoader.Column GenerateValueColumn(IHostEnvironment env, // Try to parse the text as a key value between 1 and ulong.MaxValue. If this succeeds and res>0, // we update max and min accordingly. If res==0 it means the value is missing, in which case we ignore it for // computing max and min. 
- if (Data.Conversion.Conversions.Instance.TryParseKey(in value, ulong.MaxValue - 1, out res)) + if (Data.Conversion.Conversions.DefaultInstance.TryParseKey(in value, ulong.MaxValue - 1, out res)) { if (res < keyMin && res != 0) keyMin = res; @@ -504,7 +504,7 @@ private static TextLoader.Column GenerateValueColumn(IHostEnvironment env, } // If parsing as key did not succeed, the value can still be 0, so we try parsing it as a ulong. If it succeeds, // then the value is 0, and we update min accordingly. - else if (Microsoft.ML.Data.Conversion.Conversions.Instance.TryParse(in value, out res)) + else if (Microsoft.ML.Data.Conversion.Conversions.DefaultInstance.TryParse(in value, out res)) { keyMin = 0; } @@ -863,7 +863,7 @@ public override void Train(IHostEnvironment env, DataViewRowCursor cursor) // First check if there is a String->ValueType conversion method. If so, call the conversion method with an // empty string, the returned value will be the new missing value. // NOTE this will return NA for R4 and R8 types. - if (Data.Conversion.Conversions.Instance.TryGetStandardConversion, TValue>( + if (Data.Conversion.Conversions.DefaultInstance.TryGetStandardConversion, TValue>( TextDataViewType.Instance, ValueColumn.Type, out conv, diff --git a/src/Microsoft.ML.Data/Transforms/ValueToKeyMappingTransformerImpl.cs b/src/Microsoft.ML.Data/Transforms/ValueToKeyMappingTransformerImpl.cs index 9dbe1bb2da..78c467ef15 100644 --- a/src/Microsoft.ML.Data/Transforms/ValueToKeyMappingTransformerImpl.cs +++ b/src/Microsoft.ML.Data/Transforms/ValueToKeyMappingTransformerImpl.cs @@ -67,7 +67,7 @@ private static Builder CreateCore(PrimitiveDataViewType type, bool sorted) // of building our term dictionary. For the other types (practically, only the UX types), // we should ignore nothing. 
InPredicate mapsToMissing; - if (!Data.Conversion.Conversions.Instance.TryGetIsNAPredicate(type, out mapsToMissing)) + if (!Data.Conversion.Conversions.DefaultInstance.TryGetIsNAPredicate(type, out mapsToMissing)) mapsToMissing = (in T val) => false; return new Impl(type, mapsToMissing, sorted); } @@ -207,7 +207,7 @@ protected Builder(PrimitiveDataViewType type) public override void ParseAddTermArg(ref ReadOnlyMemory terms, IChannel ch) { T val; - var tryParse = Data.Conversion.Conversions.Instance.GetTryParseConversion(ItemType); + var tryParse = Data.Conversion.Conversions.DefaultInstance.GetTryParseConversion(ItemType); for (bool more = true; more;) { ReadOnlyMemory term; @@ -233,7 +233,7 @@ public override void ParseAddTermArg(ref ReadOnlyMemory terms, IChannel ch public override void ParseAddTermArg(string[] terms, IChannel ch) { T val; - var tryParse = Data.Conversion.Conversions.Instance.GetTryParseConversion(ItemType); + var tryParse = Data.Conversion.Conversions.DefaultInstance.GetTryParseConversion(ItemType); foreach (var sterm in terms) { ReadOnlyMemory term = sterm.AsMemory(); @@ -748,7 +748,7 @@ internal override void WriteTextTerms(TextWriter writer) { writer.WriteLine("# Number of terms of type '{0}' = {1}", ItemType, Count); StringBuilder sb = null; - var stringMapper = Data.Conversion.Conversions.Instance.GetStringConversion(ItemType); + var stringMapper = Data.Conversion.Conversions.DefaultInstance.GetStringConversion(ItemType); for (int i = 0; i < _values.Count; ++i) { T val = _values.GetItem(i); @@ -1046,7 +1046,7 @@ public override void AddMetadata(DataViewSchema.Annotations.Builder builder) return; if (IsTextMetadata && !(TypedMap.ItemType is TextDataViewType)) { - var conv = Data.Conversion.Conversions.Instance; + var conv = Data.Conversion.Conversions.DefaultInstance; var stringMapper = conv.GetStringConversion(TypedMap.ItemType); ValueGetter>> getter = @@ -1112,7 +1112,7 @@ private bool AddMetadataCore(DataViewType srcMetaType, 
DataViewSchema.Ann var srcType = TypedMap.ItemType as KeyDataViewType; _host.AssertValue(srcType); var dstType = new KeyDataViewType(typeof(uint), srcType.Count); - var convInst = Data.Conversion.Conversions.Instance; + var convInst = Data.Conversion.Conversions.DefaultInstance; ValueMapper conv; bool identity; // If we can't convert this type to U4, don't try to pass along the metadata. @@ -1192,7 +1192,7 @@ private bool WriteTextTermsCore(PrimitiveDataViewType srcMetaType, TextWr var srcType = TypedMap.ItemType as KeyDataViewType; _host.AssertValue(srcType); var dstType = new KeyDataViewType(typeof(uint), srcType.Count); - var convInst = Data.Conversion.Conversions.Instance; + var convInst = Data.Conversion.Conversions.DefaultInstance; ValueMapper conv; bool identity; // If we can't convert this type to U4, don't try. diff --git a/src/Microsoft.ML.FastTree/FastTree.cs b/src/Microsoft.ML.FastTree/FastTree.cs index a782a12b51..766e7eb3d1 100644 --- a/src/Microsoft.ML.FastTree/FastTree.cs +++ b/src/Microsoft.ML.FastTree/FastTree.cs @@ -1289,7 +1289,7 @@ private static int AddColumnIfNeeded(DataViewSchema.Column? 
info, List toTr private ValueMapper, VBuffer> GetCopier(DataViewType itemType1, DataViewType itemType2) { - var conv = Conversions.Instance.GetStandardConversion(itemType1, itemType2, out bool identity); + var conv = Conversions.DefaultInstance.GetStandardConversion(itemType1, itemType2, out bool identity); if (identity) { ValueMapper, VBuffer> identityResult = @@ -1368,7 +1368,7 @@ private Dataset Construct(RoleMappedData examples, ref int numExamples, int maxB BinFinder finder = new BinFinder(); FeaturesToContentMap fmap = new FeaturesToContentMap(examples.Schema); - var hasMissingPred = Conversions.Instance.GetHasMissingPredicate(((ITransposeDataView)trans).GetSlotType(featIdx)); + var hasMissingPred = Conversions.DefaultInstance.GetHasMissingPredicate(((ITransposeDataView)trans).GetSlotType(featIdx)); // There is no good mechanism to filter out rows with missing feature values on transposed data. // So, we instead perform one featurization pass which, if successful, will remain one pass but, // if we ever encounter missing values will become a "detect missing features" pass, which will diff --git a/src/Microsoft.ML.FastTree/TreeEnsembleFeaturizer.cs b/src/Microsoft.ML.FastTree/TreeEnsembleFeaturizer.cs index 2d2977fbe6..9630bec3f9 100644 --- a/src/Microsoft.ML.FastTree/TreeEnsembleFeaturizer.cs +++ b/src/Microsoft.ML.FastTree/TreeEnsembleFeaturizer.cs @@ -733,8 +733,8 @@ private static IDataView AppendFloatMapper(IHostEnvironment env, IChanne // key-types we just upfront convert it to the most general type (ulong) and work from there. 
KeyDataViewType dstType = new KeyDataViewType(typeof(ulong), type.Count); bool identity; - var converter = Conversions.Instance.GetStandardConversion(type, dstType, out identity); - var isNa = Conversions.Instance.GetIsNAPredicate(type); + var converter = Conversions.DefaultInstance.GetStandardConversion(type, dstType, out identity); + var isNa = Conversions.DefaultInstance.GetIsNAPredicate(type); ValueMapper mapper; if (seed == 0) diff --git a/src/Microsoft.ML.Mkl.Components/SymSgdClassificationTrainer.cs b/src/Microsoft.ML.Mkl.Components/SymSgdClassificationTrainer.cs index 2ce327b9ba..1476f4d28c 100644 --- a/src/Microsoft.ML.Mkl.Components/SymSgdClassificationTrainer.cs +++ b/src/Microsoft.ML.Mkl.Components/SymSgdClassificationTrainer.cs @@ -254,7 +254,7 @@ private TPredictor CreatePredictor(VBuffer weights, float bias) VBuffer maybeSparseWeights = default; VBufferUtils.CreateMaybeSparseCopy(in weights, ref maybeSparseWeights, - Conversions.Instance.GetIsDefaultPredicate(NumberDataViewType.Single)); + Conversions.DefaultInstance.GetIsDefaultPredicate(NumberDataViewType.Single)); var predictor = new LinearBinaryModelParameters(Host, in maybeSparseWeights, bias); return new ParameterMixingCalibratedModelParameters(Host, predictor, new PlattCalibrator(Host, -1, 0)); } diff --git a/src/Microsoft.ML.Parquet/PartitionedFileLoader.cs b/src/Microsoft.ML.Parquet/PartitionedFileLoader.cs index 106bd27cb0..863662eeed 100644 --- a/src/Microsoft.ML.Parquet/PartitionedFileLoader.cs +++ b/src/Microsoft.ML.Parquet/PartitionedFileLoader.cs @@ -626,7 +626,7 @@ private ValueGetter GetterDelegateCore(int col, DataViewType typ Ch.Check(col >= 0 && col < _colValues.Length); Ch.AssertValue(type); - var conv = Conversions.Instance.GetStandardConversion(TextDataViewType.Instance, type) as ValueMapper, TValue>; + var conv = Conversions.DefaultInstance.GetStandardConversion(TextDataViewType.Instance, type) as ValueMapper, TValue>; if (conv == null) { throw Ch.Except("Invalid TValue: '{0}' 
of the conversion.", typeof(TValue)); diff --git a/src/Microsoft.ML.StandardTrainers/Standard/SdcaBinary.cs b/src/Microsoft.ML.StandardTrainers/Standard/SdcaBinary.cs index 1b03a860dc..d3695c870e 100644 --- a/src/Microsoft.ML.StandardTrainers/Standard/SdcaBinary.cs +++ b/src/Microsoft.ML.StandardTrainers/Standard/SdcaBinary.cs @@ -1525,7 +1525,7 @@ private protected LinearBinaryModelParameters CreateLinearBinaryModelParameters( VBuffer maybeSparseWeights = default; // below should be `in weights[0]`, but can't because of https://github.com/dotnet/roslyn/issues/29371 VBufferUtils.CreateMaybeSparseCopy(weights[0], ref maybeSparseWeights, - Conversions.Instance.GetIsDefaultPredicate(NumberDataViewType.Single)); + Conversions.DefaultInstance.GetIsDefaultPredicate(NumberDataViewType.Single)); return new LinearBinaryModelParameters(Host, in maybeSparseWeights, bias[0]); } @@ -1813,7 +1813,7 @@ private protected override IPredictorWithFeatureWeights CreatePredictor(V VBuffer maybeSparseWeights = default; // below should be `in weights[0]`, but can't because of https://github.com/dotnet/roslyn/issues/29371 VBufferUtils.CreateMaybeSparseCopy(weights[0], ref maybeSparseWeights, - Conversions.Instance.GetIsDefaultPredicate(NumberDataViewType.Single)); + Conversions.DefaultInstance.GetIsDefaultPredicate(NumberDataViewType.Single)); var predictor = new LinearBinaryModelParameters(Host, in maybeSparseWeights, bias[0]); if (Info.NeedCalibration) @@ -2210,7 +2210,7 @@ private protected LinearBinaryModelParameters CreateLinearBinaryModelParameters( VBuffer maybeSparseWeights = default; VBufferUtils.CreateMaybeSparseCopy(weights, ref maybeSparseWeights, - Conversions.Instance.GetIsDefaultPredicate(NumberDataViewType.Single)); + Conversions.DefaultInstance.GetIsDefaultPredicate(NumberDataViewType.Single)); return new LinearBinaryModelParameters(Host, in maybeSparseWeights, bias); } diff --git a/src/Microsoft.ML.StandardTrainers/Standard/SdcaRegression.cs 
b/src/Microsoft.ML.StandardTrainers/Standard/SdcaRegression.cs index eb0818e5d3..6329fabbb8 100644 --- a/src/Microsoft.ML.StandardTrainers/Standard/SdcaRegression.cs +++ b/src/Microsoft.ML.StandardTrainers/Standard/SdcaRegression.cs @@ -150,7 +150,7 @@ private protected override LinearRegressionModelParameters CreatePredictor(VBuff VBuffer maybeSparseWeights = default; // below should be `in weights[0]`, but can't because of https://github.com/dotnet/roslyn/issues/29371 VBufferUtils.CreateMaybeSparseCopy(weights[0], ref maybeSparseWeights, - Conversions.Instance.GetIsDefaultPredicate(NumberDataViewType.Single)); + Conversions.DefaultInstance.GetIsDefaultPredicate(NumberDataViewType.Single)); return new LinearRegressionModelParameters(Host, in maybeSparseWeights, bias[0]); } diff --git a/src/Microsoft.ML.TimeSeries/SlidingWindowTransformBase.cs b/src/Microsoft.ML.TimeSeries/SlidingWindowTransformBase.cs index 1851e518c0..78adaa4f9c 100644 --- a/src/Microsoft.ML.TimeSeries/SlidingWindowTransformBase.cs +++ b/src/Microsoft.ML.TimeSeries/SlidingWindowTransformBase.cs @@ -100,7 +100,7 @@ private TInput GetNaValue() int index; sch.TryGetColumnIndex(InputColumnName, out index); DataViewType col = sch[index].Type; - TInput nanValue = Data.Conversion.Conversions.Instance.GetNAOrDefault(col); + TInput nanValue = Data.Conversion.Conversions.DefaultInstance.GetNAOrDefault(col); // We store the nan_value here to avoid getting it each time a state is instanciated. 
return nanValue; diff --git a/src/Microsoft.ML.Transforms/CountFeatureSelection.cs b/src/Microsoft.ML.Transforms/CountFeatureSelection.cs index 1c09ae54ef..dd210b14cb 100644 --- a/src/Microsoft.ML.Transforms/CountFeatureSelection.cs +++ b/src/Microsoft.ML.Transforms/CountFeatureSelection.cs @@ -399,8 +399,8 @@ public CountAggregator(DataViewType type, ValueGetter getter) getter(ref t); VBufferEditor.CreateFromBuffer(ref _buffer).Values[0] = t; }; - _isDefault = Data.Conversion.Conversions.Instance.GetIsDefaultPredicate(type); - if (!Data.Conversion.Conversions.Instance.TryGetIsNAPredicate(type, out _isMissing)) + _isDefault = Data.Conversion.Conversions.DefaultInstance.GetIsDefaultPredicate(type); + if (!Data.Conversion.Conversions.DefaultInstance.TryGetIsNAPredicate(type, out _isMissing)) _isMissing = (in T value) => false; } @@ -410,8 +410,8 @@ public CountAggregator(VectorDataViewType type, ValueGetter> getter) var size = type.Size; _count = new long[size]; _fillBuffer = () => getter(ref _buffer); - _isDefault = Data.Conversion.Conversions.Instance.GetIsDefaultPredicate(type.ItemType); - if (!Data.Conversion.Conversions.Instance.TryGetIsNAPredicate(type.ItemType, out _isMissing)) + _isDefault = Data.Conversion.Conversions.DefaultInstance.GetIsDefaultPredicate(type.ItemType); + if (!Data.Conversion.Conversions.DefaultInstance.TryGetIsNAPredicate(type.ItemType, out _isMissing)) _isMissing = (in T value) => false; } diff --git a/src/Microsoft.ML.Transforms/Expression/BuiltinFunctions.cs b/src/Microsoft.ML.Transforms/Expression/BuiltinFunctions.cs index 25644e89f9..2039123587 100644 --- a/src/Microsoft.ML.Transforms/Expression/BuiltinFunctions.cs +++ b/src/Microsoft.ML.Transforms/Expression/BuiltinFunctions.cs @@ -844,21 +844,21 @@ public static BL IsNA(R8 a) public static BL ToBL(TX a) { BL res = default(BL); - Conversions.Instance.Convert(in a, ref res); + Conversions.DefaultInstance.Convert(in a, ref res); return res; } public static I4 ToI4(TX a) { I4 res = 
default(I4); - Conversions.Instance.Convert(in a, ref res); + Conversions.DefaultInstance.Convert(in a, ref res); return res; } public static I8 ToI8(TX a) { I8 res = default(I8); - Conversions.Instance.Convert(in a, ref res); + Conversions.DefaultInstance.Convert(in a, ref res); return res; } @@ -880,7 +880,7 @@ public static R4 ToR4(R8 a) public static R4 ToR4(TX a) { R4 res = default(R4); - Conversions.Instance.Convert(in a, ref res); + Conversions.DefaultInstance.Convert(in a, ref res); return res; } @@ -902,7 +902,7 @@ public static R8 ToR8(R8 a) public static R8 ToR8(TX a) { R8 res = default(R8); - Conversions.Instance.Convert(in a, ref res); + Conversions.DefaultInstance.Convert(in a, ref res); return res; } diff --git a/src/Microsoft.ML.Transforms/ExpressionTransformer.cs b/src/Microsoft.ML.Transforms/ExpressionTransformer.cs index 476e88025b..8b55a30aed 100644 --- a/src/Microsoft.ML.Transforms/ExpressionTransformer.cs +++ b/src/Microsoft.ML.Transforms/ExpressionTransformer.cs @@ -609,7 +609,7 @@ private ValueGetter> GetGetterVec(IExceptionContext ectx var src0 = default(VBuffer); var dstDef = fn(default(T0)); - var isDef = Conversions.Instance.GetIsDefaultPredicate(outputColumnItemType); + var isDef = Conversions.DefaultInstance.GetIsDefaultPredicate(outputColumnItemType); if (isDef(in dstDef)) { // Sparsity is preserved. 
@@ -668,7 +668,7 @@ private ValueGetter> GetGetterVec(IExceptionContext ectx.Assert(inputColumns.Length == 2); ectx.Assert(perm.Length == 2); - var isDef = Conversions.Instance.GetIsDefaultPredicate(outputColumnItemType); + var isDef = Conversions.DefaultInstance.GetIsDefaultPredicate(outputColumnItemType); var fn = (Func)del; var getSrc0 = input.GetGetter>(inputColumns[perm[0]]); var getSrc1 = input.GetGetter(inputColumns[perm[1]]); @@ -728,7 +728,7 @@ private ValueGetter> GetGetterVec(IExceptionCont ectx.Assert(inputColumns.Length == 3); ectx.Assert(perm.Length == 3); - var isDef = Conversions.Instance.GetIsDefaultPredicate(outputColumnItemType); + var isDef = Conversions.DefaultInstance.GetIsDefaultPredicate(outputColumnItemType); var fn = (Func)del; var getSrc0 = input.GetGetter>(inputColumns[perm[0]]); var getSrc1 = input.GetGetter(inputColumns[perm[1]]); @@ -791,7 +791,7 @@ private ValueGetter> GetGetterVec(IException ectx.Assert(inputColumns.Length == 4); ectx.Assert(perm.Length == 4); - var isDef = Conversions.Instance.GetIsDefaultPredicate(outputColumnItemType); + var isDef = Conversions.DefaultInstance.GetIsDefaultPredicate(outputColumnItemType); var fn = (Func)del; var getSrc0 = input.GetGetter>(inputColumns[perm[0]]); var getSrc1 = input.GetGetter(inputColumns[perm[1]]); @@ -857,7 +857,7 @@ private ValueGetter> GetGetterVec(IExcep ectx.Assert(inputColumns.Length == 5); ectx.Assert(perm.Length == 5); - var isDef = Conversions.Instance.GetIsDefaultPredicate(outputColumnItemType); + var isDef = Conversions.DefaultInstance.GetIsDefaultPredicate(outputColumnItemType); var fn = (Func)del; var getSrc0 = input.GetGetter>(inputColumns[perm[0]]); var getSrc1 = input.GetGetter(inputColumns[perm[1]]); diff --git a/src/Microsoft.ML.Transforms/HashJoiningTransform.cs b/src/Microsoft.ML.Transforms/HashJoiningTransform.cs index 7242b129b7..c977f423ef 100644 --- a/src/Microsoft.ML.Transforms/HashJoiningTransform.cs +++ 
b/src/Microsoft.ML.Transforms/HashJoiningTransform.cs @@ -619,7 +619,7 @@ private HashDelegate ComposeHashDelegate() // Default case: convert to text and hash as a string. var sb = default(StringBuilder); - var conv = Data.Conversion.Conversions.Instance.GetStringConversion(); + var conv = Data.Conversion.Conversions.DefaultInstance.GetStringConversion(); return (in TSrc value, uint seed) => { diff --git a/src/Microsoft.ML.Transforms/MissingValueDroppingTransformer.cs b/src/Microsoft.ML.Transforms/MissingValueDroppingTransformer.cs index 73080dbc22..7d9c5dc0a1 100644 --- a/src/Microsoft.ML.Transforms/MissingValueDroppingTransformer.cs +++ b/src/Microsoft.ML.Transforms/MissingValueDroppingTransformer.cs @@ -46,7 +46,7 @@ public override SchemaShape GetOutputSchema(SchemaShape inputSchema) throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", inputColumnName); if (originalColumn.Kind == SchemaShape.Column.VectorKind.Scalar) throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", originalColumn.Name, "Vector", "Scalar"); - if (!Data.Conversion.Conversions.Instance.TryGetIsNAPredicate(originalColumn.ItemType, out _)) + if (!Data.Conversion.Conversions.DefaultInstance.TryGetIsNAPredicate(originalColumn.ItemType, out _)) throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", originalColumn.Name, "Single, Double or Key", originalColumn.ItemType.ToString()); var col = new SchemaShape.Column(outputColumnName, SchemaShape.Column.VectorKind.VariableVector, originalColumn.ItemType, originalColumn.IsKey, originalColumn.Annotations); resultDic[outputColumnName] = col; @@ -195,7 +195,7 @@ public Mapper(MissingValueDroppingTransformer parent, DataViewSchema inputSchema var srcCol = inputSchema[_srcCols[i]]; if (!(srcCol.Type is VectorDataViewType)) throw _parent.Host.Except($"Column '{srcCol.Name}' is not a vector column"); - if (!Data.Conversion.Conversions.Instance.TryGetIsNAPredicate(srcCol.Type.GetItemType(), out _isNAs[i])) + if 
(!Data.Conversion.Conversions.DefaultInstance.TryGetIsNAPredicate(srcCol.Type.GetItemType(), out _isNAs[i])) throw _parent.Host.Except($"Column '{srcCol.Name}' is of type {srcCol.Type.GetItemType()}, which does not support missing values"); _srcTypes[i] = srcCol.Type; _types[i] = new VectorDataViewType((PrimitiveDataViewType)srcCol.Type.GetItemType()); diff --git a/src/Microsoft.ML.Transforms/MissingValueHandlingTransformer.cs b/src/Microsoft.ML.Transforms/MissingValueHandlingTransformer.cs index 96e09c1299..99499410bb 100644 --- a/src/Microsoft.ML.Transforms/MissingValueHandlingTransformer.cs +++ b/src/Microsoft.ML.Transforms/MissingValueHandlingTransformer.cs @@ -163,7 +163,7 @@ internal static IDataTransform Create(IHostEnvironment env, Options options, IDa throw h.Except("Column '{0}' does not exist", column.Source); var replaceType = input.Schema[inputCol].Type; var replaceItemType = replaceType.GetItemType(); - if (!Data.Conversion.Conversions.Instance.TryGetStandardConversion(BooleanDataViewType.Instance, replaceItemType, out Delegate conv, out bool identity)) + if (!Data.Conversion.Conversions.DefaultInstance.TryGetStandardConversion(BooleanDataViewType.Instance, replaceItemType, out Delegate conv, out bool identity)) { throw h.Except("Cannot concatenate indicator column of type '{0}' to input column of type '{1}'", BooleanDataViewType.Instance, replaceItemType); diff --git a/src/Microsoft.ML.Transforms/MissingValueIndicatorTransformer.cs b/src/Microsoft.ML.Transforms/MissingValueIndicatorTransformer.cs index dde8e81c06..54cec41759 100644 --- a/src/Microsoft.ML.Transforms/MissingValueIndicatorTransformer.cs +++ b/src/Microsoft.ML.Transforms/MissingValueIndicatorTransformer.cs @@ -229,7 +229,7 @@ private static Delegate GetIsNADelegate(DataViewType type) private static Delegate GetIsNADelegate(DataViewType type) { - return Data.Conversion.Conversions.Instance.GetIsNAPredicate(type.GetItemType()); + return 
Data.Conversion.Conversions.DefaultInstance.GetIsNAPredicate(type.GetItemType()); } protected override Delegate MakeGetter(DataViewRow input, int iinfo, Func activeOutput, out Action disposer) @@ -534,7 +534,7 @@ public override SchemaShape GetOutputSchema(SchemaShape inputSchema) var result = inputSchema.ToDictionary(x => x.Name); foreach (var colPair in Transformer.Columns) { - if (!inputSchema.TryFindColumn(colPair.inputColumnName, out var col) || !Data.Conversion.Conversions.Instance.TryGetIsNAPredicate(col.ItemType, out Delegate del)) + if (!inputSchema.TryFindColumn(colPair.inputColumnName, out var col) || !Data.Conversion.Conversions.DefaultInstance.TryGetIsNAPredicate(col.ItemType, out Delegate del)) throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", colPair.inputColumnName); var metadata = new List(); if (col.Annotations.TryFindColumn(AnnotationUtils.Kinds.SlotNames, out var slotMeta)) diff --git a/src/Microsoft.ML.Transforms/MissingValueReplacing.cs b/src/Microsoft.ML.Transforms/MissingValueReplacing.cs index 56c8a857b7..76578f4f2d 100644 --- a/src/Microsoft.ML.Transforms/MissingValueReplacing.cs +++ b/src/Microsoft.ML.Transforms/MissingValueReplacing.cs @@ -162,7 +162,7 @@ internal static string TestType(DataViewType type) private static string TestType(DataViewType type) { Contracts.Assert(type.GetItemType().RawType == typeof(T)); - if (!Data.Conversion.Conversions.Instance.TryGetIsNAPredicate(type.GetItemType(), out InPredicate isNA)) + if (!Data.Conversion.Conversions.DefaultInstance.TryGetIsNAPredicate(type.GetItemType(), out InPredicate isNA)) { return string.Format("Type '{0}' is not supported by {1} since it doesn't have an NA value", type, LoadName); @@ -254,7 +254,7 @@ private T[] GetValuesArray(VBuffer src, VectorDataViewType srcType, int ii Host.Assert(srcType != null); Host.Assert(srcType.Size == src.Length); VBufferUtils.Densify(ref src); - InPredicate defaultPred = 
Data.Conversion.Conversions.Instance.GetIsDefaultPredicate(srcType.ItemType); + InPredicate defaultPred = Data.Conversion.Conversions.DefaultInstance.GetIsDefaultPredicate(srcType.ItemType); _repIsDefault[iinfo] = new BitArray(srcType.Size); var srcValues = src.GetValues(); for (int slot = 0; slot < srcValues.Length; slot++) @@ -366,7 +366,7 @@ private BitArray ComputeDefaultSlots(DataViewType type, Array values) { Host.Assert(values.Length == type.GetVectorSize()); BitArray defaultSlots = new BitArray(values.Length); - InPredicate defaultPred = Data.Conversion.Conversions.Instance.GetIsDefaultPredicate(type.GetItemType()); + InPredicate defaultPred = Data.Conversion.Conversions.DefaultInstance.GetIsDefaultPredicate(type.GetItemType()); T[] typedValues = (T[])values; for (int slot = 0; slot < values.Length; slot++) { @@ -394,7 +394,7 @@ private Delegate GetIsNADelegate(DataViewType type) } private Delegate GetIsNADelegate(DataViewType type) - => Data.Conversion.Conversions.Instance.GetIsNAPredicate(type.GetItemType()); + => Data.Conversion.Conversions.DefaultInstance.GetIsNAPredicate(type.GetItemType()); /// /// Converts a string to its respective value in the corresponding type. @@ -413,7 +413,7 @@ private object GetSpecifiedValue(string srcStr, DataViewType dstType, InPredi { // Handles converting input strings to correct types. var srcTxt = srcStr.AsMemory(); - var strToT = Data.Conversion.Conversions.Instance.GetStandardConversion, T>(TextDataViewType.Instance, dstType.GetItemType(), out bool identity); + var strToT = Data.Conversion.Conversions.DefaultInstance.GetStandardConversion, T>(TextDataViewType.Instance, dstType.GetItemType(), out bool identity); strToT(in srcTxt, ref val); // Make sure that the srcTxt can legitimately be converted to dstType, throw error otherwise. 
if (isNA(in val)) @@ -667,7 +667,7 @@ private Delegate ComposeGetterVec(DataViewRow input, int iinfo) { var getSrc = input.GetGetter>(input.Schema[ColMapNewToOld[iinfo]]); var isNA = (InPredicate)_isNAs[iinfo]; - var isDefault = Data.Conversion.Conversions.Instance.GetIsDefaultPredicate(_infos[iinfo].TypeSrc.GetItemType()); + var isDefault = Data.Conversion.Conversions.DefaultInstance.GetIsDefaultPredicate(_infos[iinfo].TypeSrc.GetItemType()); var src = default(VBuffer); ValueGetter> getter; diff --git a/src/Microsoft.ML.Transforms/MutualInformationFeatureSelection.cs b/src/Microsoft.ML.Transforms/MutualInformationFeatureSelection.cs index efecb8fee8..45b960935f 100644 --- a/src/Microsoft.ML.Transforms/MutualInformationFeatureSelection.cs +++ b/src/Microsoft.ML.Transforms/MutualInformationFeatureSelection.cs @@ -766,7 +766,7 @@ private void FillTable(in VBuffer features, int offset, int numFeatures) /// private static ValueMapper, VBuffer> BinKeys(DataViewType colType) { - var conv = Data.Conversion.Conversions.Instance.GetStandardConversion(colType, NumberDataViewType.UInt32, out bool identity); + var conv = Data.Conversion.Conversions.DefaultInstance.GetStandardConversion(colType, NumberDataViewType.UInt32, out bool identity); ValueMapper mapper; if (identity) { diff --git a/src/Microsoft.ML.Transforms/SvmLight/SvmLightLoader.cs b/src/Microsoft.ML.Transforms/SvmLight/SvmLightLoader.cs index 11d1a0e774..aaa116498b 100644 --- a/src/Microsoft.ML.Transforms/SvmLight/SvmLightLoader.cs +++ b/src/Microsoft.ML.Transforms/SvmLight/SvmLightLoader.cs @@ -142,8 +142,8 @@ private sealed class InputMapper public InputMapper() { _seps = new char[] { ' ', '\t' }; - _tryFloatParse = Conversions.Instance.GetTryParseConversion(NumberDataViewType.Single); - _tryLongParse = Conversions.Instance.GetTryParseConversion(NumberDataViewType.Int64); + _tryFloatParse = Conversions.DefaultInstance.GetTryParseConversion(NumberDataViewType.Single); + _tryLongParse = 
Conversions.DefaultInstance.GetTryParseConversion(NumberDataViewType.Int64); } public void MapInput(Input input, IntermediateInput intermediate) @@ -306,7 +306,7 @@ public void ParseIndices(IntermediateInput input, Indices output) var inputValues = input.FeatureKeys.GetValues(); for (int i = 0; i < inputValues.Length; i++) { - if (Conversions.Instance.TryParse(in inputValues[i], out uint index)) + if (Conversions.DefaultInstance.TryParse(in inputValues[i], out uint index)) { if (index < _offset) { @@ -671,7 +671,7 @@ private static IDataView GetData(IHostEnvironment env, long? numRows, IMultiStre private static uint InferMax(IHostEnvironment env, IDataView view) { ulong keyMax = 0; - var parser = Conversions.Instance.GetTryParseConversion(NumberDataViewType.UInt64); + var parser = Conversions.DefaultInstance.GetTryParseConversion(NumberDataViewType.UInt64); var col = view.Schema.GetColumnOrNull(nameof(IntermediateInput.FeatureKeys)); env.Assert(col.HasValue); diff --git a/src/Microsoft.ML.Transforms/UngroupTransform.cs b/src/Microsoft.ML.Transforms/UngroupTransform.cs index 7b5ad5aa57..03aa8c4c82 100644 --- a/src/Microsoft.ML.Transforms/UngroupTransform.cs +++ b/src/Microsoft.ML.Transforms/UngroupTransform.cs @@ -633,7 +633,7 @@ private ValueGetter MakeGetter(int col, PrimitiveDataViewType itemType) // cachedIndex == row.Count || _pivotColPosition <= row.Indices[cachedIndex]. 
int cachedIndex = 0; VBuffer row = default(VBuffer); - T naValue = Data.Conversion.Conversions.Instance.GetNAOrDefault(itemType); + T naValue = Data.Conversion.Conversions.DefaultInstance.GetNAOrDefault(itemType); return (ref T value) => { diff --git a/test/Microsoft.ML.AutoML.Tests/ConversionTests.cs b/test/Microsoft.ML.AutoML.Tests/ConversionTests.cs index 5289aded40..3dae4e5c21 100644 --- a/test/Microsoft.ML.AutoML.Tests/ConversionTests.cs +++ b/test/Microsoft.ML.AutoML.Tests/ConversionTests.cs @@ -34,7 +34,7 @@ public void ConvertFloatMissingValues() foreach(var missingValue in missingValues) { float value; - var success = Conversions.Instance.TryParse(missingValue.AsMemory(), out value); + var success = Conversions.DefaultInstance.TryParse(missingValue.AsMemory(), out value); _output.WriteLine($"{missingValue} parsed as {value}"); Assert.True(success); //Assert.Equal(float.NaN, value); @@ -51,7 +51,7 @@ public void ConvertFloatParseFailure() foreach (var value in values) { - var success = Conversions.Instance.TryParse(value.AsMemory(), out float _); + var success = Conversions.DefaultInstance.TryParse(value.AsMemory(), out float _); Assert.False(success); } } @@ -70,7 +70,7 @@ public void ConvertBoolMissingValues() foreach (var missingValue in missingValues) { - var success = Conversions.Instance.TryParse(missingValue.AsMemory(), out bool _); + var success = Conversions.DefaultInstance.TryParse(missingValue.AsMemory(), out bool _); Assert.True(success); } } @@ -88,7 +88,7 @@ public void ConvertBoolParseFailure() foreach (var value in values) { - var success = Conversions.Instance.TryParse(value.AsMemory(), out bool _); + var success = Conversions.DefaultInstance.TryParse(value.AsMemory(), out bool _); Assert.False(success); } } diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/DataTypes.cs b/test/Microsoft.ML.Core.Tests/UnitTests/DataTypes.cs index 6c93616cbf..21384c7721 100644 --- a/test/Microsoft.ML.Core.Tests/UnitTests/DataTypes.cs +++ 
b/test/Microsoft.ML.Core.Tests/UnitTests/DataTypes.cs @@ -18,14 +18,14 @@ public DataTypesTest(ITestOutputHelper helper) { } - private readonly static Conversions _conv = Conversions.Instance; + private readonly static Conversions _conv = Conversions.DefaultInstance; [Fact] public void R4ToSBtoR4() { - var r4ToSB = Conversions.Instance.GetStringConversion(NumberDataViewType.Single); + var r4ToSB = Conversions.DefaultInstance.GetStringConversion(NumberDataViewType.Single); - var txToR4 = Conversions.Instance.GetStandardConversion< ReadOnlyMemory, float>( + var txToR4 = Conversions.DefaultInstance.GetStandardConversion< ReadOnlyMemory, float>( TextDataViewType.Instance, NumberDataViewType.Single, out bool identity2); Assert.NotNull(r4ToSB); @@ -47,9 +47,9 @@ public void R4ToSBtoR4() [Fact] public void R8ToSBtoR8() { - var r8ToSB = Conversions.Instance.GetStringConversion(NumberDataViewType.Double); + var r8ToSB = Conversions.DefaultInstance.GetStringConversion(NumberDataViewType.Double); - var txToR8 = Conversions.Instance.GetStandardConversion, double>( + var txToR8 = Conversions.DefaultInstance.GetStandardConversion, double>( TextDataViewType.Instance, NumberDataViewType.Double, out bool identity2); Assert.NotNull(r8ToSB); @@ -232,7 +232,7 @@ public void TXToLong() public void DTToDT() { bool identity; - var dtToDT = Conversions.Instance.GetStandardConversion( + var dtToDT = Conversions.DefaultInstance.GetStandardConversion( DateTimeDataViewType.Instance, DateTimeDataViewType.Instance, out identity); Assert.NotNull(dtToDT); @@ -252,7 +252,7 @@ private static ValueMapper GetMapper(DataViewType dstTyp { Assert.True(typeof(TDst) == dstType.RawType); - return Conversions.Instance.GetStandardConversion( + return Conversions.DefaultInstance.GetStandardConversion( TextDataViewType.Instance, dstType, out bool identity); } } diff --git a/test/Microsoft.ML.Tests/ExpressionLanguageTests/ExpressionLanguageTests.cs 
b/test/Microsoft.ML.Tests/ExpressionLanguageTests/ExpressionLanguageTests.cs index a05b4fa419..3180440f9a 100644 --- a/test/Microsoft.ML.Tests/ExpressionLanguageTests/ExpressionLanguageTests.cs +++ b/test/Microsoft.ML.Tests/ExpressionLanguageTests/ExpressionLanguageTests.cs @@ -291,7 +291,7 @@ private Func, bool> GetGetter(int i, DataViewType dst, obje src => { bool v; - bool tmp = Conversions.Instance.TryParse(in src, out v); + bool tmp = Conversions.DefaultInstance.TryParse(in src, out v); args[i] = v; return tmp; }; @@ -300,7 +300,7 @@ private Func, bool> GetGetter(int i, DataViewType dst, obje src => { int v; - bool tmp = Conversions.Instance.TryParse(in src, out v); + bool tmp = Conversions.DefaultInstance.TryParse(in src, out v); args[i] = v; return tmp; }; @@ -309,7 +309,7 @@ private Func, bool> GetGetter(int i, DataViewType dst, obje src => { long v; - bool tmp = Conversions.Instance.TryParse(in src, out v); + bool tmp = Conversions.DefaultInstance.TryParse(in src, out v); args[i] = v; return tmp; }; @@ -318,7 +318,7 @@ private Func, bool> GetGetter(int i, DataViewType dst, obje src => { float v; - bool tmp = Conversions.Instance.TryParse(in src, out v); + bool tmp = Conversions.DefaultInstance.TryParse(in src, out v); args[i] = v; return tmp; }; @@ -327,7 +327,7 @@ private Func, bool> GetGetter(int i, DataViewType dst, obje src => { double v; - bool tmp = Conversions.Instance.TryParse(in src, out v); + bool tmp = Conversions.DefaultInstance.TryParse(in src, out v); args[i] = v; return tmp; }; @@ -353,35 +353,35 @@ private Action GetPrinter(DataViewType dst, StringBuilder sb) src => { var v = (bool)src; - Conversions.Instance.Convert(in v, ref sb); + Conversions.DefaultInstance.Convert(in v, ref sb); }; case InternalDataKind.I4: return src => { var v = (int)src; - Conversions.Instance.Convert(in v, ref sb); + Conversions.DefaultInstance.Convert(in v, ref sb); }; case InternalDataKind.I8: return src => { var v = (long)src; - Conversions.Instance.Convert(in 
v, ref sb); + Conversions.DefaultInstance.Convert(in v, ref sb); }; case InternalDataKind.R4: return src => { var v = (Single)src; - Conversions.Instance.Convert(in v, ref sb); + Conversions.DefaultInstance.Convert(in v, ref sb); }; case InternalDataKind.R8: return src => { var v = (Double)src; - Conversions.Instance.Convert(in v, ref sb); + Conversions.DefaultInstance.Convert(in v, ref sb); }; case InternalDataKind.TX: return diff --git a/test/Microsoft.ML.Tests/TextLoaderTests.cs b/test/Microsoft.ML.Tests/TextLoaderTests.cs index a4d44c5cc2..358fad1272 100644 --- a/test/Microsoft.ML.Tests/TextLoaderTests.cs +++ b/test/Microsoft.ML.Tests/TextLoaderTests.cs @@ -6,7 +6,9 @@ using System.Collections.Generic; using System.IO; using System.Linq; +using System.Reflection; using Microsoft.ML.Data; +using Microsoft.ML.Internal.Utilities; using Microsoft.ML.Model; using Microsoft.ML.RunTests; using Microsoft.ML.Runtime; @@ -1002,6 +1004,126 @@ public void TestWrongDecimalMarkerInputs(bool useCommaAsDecimalMarker) } } + [Theory] + [InlineData(true, true)] + [InlineData(false, false)] + [InlineData(true, false)] + [InlineData(false, true)] + public void TestDifferentDecimalMarkersAtTheSameTime(bool useCorrectPeriod, bool useCorrectComma) + { + // Two TextLoaders configured with different decimal markers ('.' and ',') must each yield the + // expected values even when their cursors are advanced at the same time. useCorrectPeriod / useCorrectComma + // choose whether each loader reads the file written with its own marker (true) or with the other one (false). + + var mlContext = new MLContext(seed: 1); + + var periodPath = GetDataPath("iris.txt"); + var commaPath = GetDataPath("iris-decimal-marker-as-comma.txt"); + + var optionsPeriod = new TextLoader.Options() + { + Columns = new[] + { + new TextLoader.Column("Label", DataKind.UInt32, 0), + new TextLoader.Column("Features", DataKind.Single, new[] { new TextLoader.Range(1, 4) }) + }, + DecimalMarker = '.'
+ }; + + var optionsComma = new TextLoader.Options() + { + Columns = new[] + { + new TextLoader.Column("Label", DataKind.UInt32, 0), + new TextLoader.Column("Features", DataKind.Single, new[] { new TextLoader.Range(1, 4) }) + }, + DecimalMarker = ',' + }; + + for (int j = 0; j < 2; j++) + { + // Run the scenario twice inside the same test, to also verify that TextLoader creates only 1 + // custom instance of ValueCreatorCache (asserted at the end of each iteration) + + IDataView dataViewPeriod; + IDataView dataViewComma; + + if (useCorrectPeriod) + dataViewPeriod = mlContext.Data.LoadFromTextFile(periodPath, optionsPeriod); + else + dataViewPeriod = mlContext.Data.LoadFromTextFile(commaPath, optionsPeriod); + + if (useCorrectComma) + dataViewComma = mlContext.Data.LoadFromTextFile(commaPath, optionsComma); + else + dataViewComma = mlContext.Data.LoadFromTextFile(periodPath, optionsComma); + + VBuffer<float> featuresPeriod = default; + VBuffer<float> featuresComma = default; + + + using (var cursorPeriod = dataViewPeriod.GetRowCursor(dataViewPeriod.Schema)) + using (var cursorComma = dataViewComma.GetRowCursor(dataViewComma.Schema)) + { + var delegatePeriod = cursorPeriod.GetGetter<VBuffer<float>>(dataViewPeriod.Schema["Features"]); + var delegateComma = cursorComma.GetGetter<VBuffer<float>>(dataViewComma.Schema["Features"]); + while (cursorPeriod.MoveNext() && cursorComma.MoveNext()) + { + delegatePeriod(ref featuresPeriod); + delegateComma(ref featuresComma); + + var featuresPeriodArray = featuresPeriod.GetValues().ToArray(); + var featuresCommaArray = featuresComma.GetValues().ToArray(); + Assert.Equal(featuresPeriodArray.Length, featuresCommaArray.Length); + + for (int i = 0; i < featuresPeriodArray.Length; i++) + { + if (useCorrectPeriod && useCorrectComma) + { + // Check that neither of the two files loaded NaNs + // As both of them should have been loaded correctly + Assert.Equal(featuresPeriodArray[i], featuresCommaArray[i]); + Assert.NotEqual(Single.NaN, featuresPeriodArray[i]); + } + else if (!useCorrectPeriod && !useCorrectComma) + { + //
Check that everything was loaded as NaN + // Because the wrong decimal marker was used for both loaders + Assert.Equal(featuresPeriodArray[i], featuresCommaArray[i]); + Assert.Equal(Single.NaN, featuresPeriodArray[i]); + } + else if (!useCorrectPeriod && useCorrectComma) + { + // Check that only the file with commas was loaded correctly + Assert.Equal(Single.NaN, featuresPeriodArray[i]); + Assert.NotEqual(Single.NaN, featuresCommaArray[i]); + } + else + { + // Check that only the file with periods was loaded correctly + Assert.NotEqual(Single.NaN, featuresPeriodArray[i]); + Assert.Equal(Single.NaN, featuresCommaArray[i]); + } + } + } + } + + // Check how many custom instances there are of TextLoader.ValueCreatorCache + var vccType = typeof(TextLoader).GetNestedType("ValueCreatorCache", BindingFlags.NonPublic | BindingFlags.Static); + var customInstancesInfo = vccType.GetField("_customInstances", BindingFlags.NonPublic | BindingFlags.Static); + var customInstancesObject = customInstancesInfo.GetValue(null); + var customInstancesCount = (int)customInstancesObject.GetType().GetProperty("Count").GetValue(customInstancesObject, null); + var customInstancesContainsMethod = customInstancesObject.GetType().GetMethod("ContainsKey"); + + // Regardless of useCorrectPeriod and useCorrectComma + // Since we always created a TextLoader with Comma as DecimalMarker + // There should always be 1, and only 1, custom instance of ValueCreatorCache, corresponding to the comma option + // Even after running the loop above multiple times. + Assert.Equal(1, customInstancesCount); + Assert.True((bool)customInstancesContainsMethod.Invoke(customInstancesObject, new[] { (object) DoubleParser.OptionFlags.UseCommaAsDecimalMarker })); + } + } + private class IrisNoFields { }