From 657b4889d966249dc84a3aa83aed75652fc52fab Mon Sep 17 00:00:00 2001 From: Prashanth Govindarajan Date: Fri, 30 Apr 2021 15:01:44 -0700 Subject: [PATCH 1/4] Handle nulls better in Merge --- .../ArrowStringDataFrameColumn.cs | 26 ++- src/Microsoft.Data.Analysis/DataFrame.Join.cs | 180 +++++++++++------- .../DataFrameColumn.cs | 7 +- .../PrimitiveDataFrameColumn.cs | 25 ++- .../StringDataFrameColumn.cs | 23 ++- .../DataFrameTests.cs | 128 +++++++++++++ 6 files changed, 291 insertions(+), 98 deletions(-) diff --git a/src/Microsoft.Data.Analysis/ArrowStringDataFrameColumn.cs b/src/Microsoft.Data.Analysis/ArrowStringDataFrameColumn.cs index a80e85c173..d05912ca6f 100644 --- a/src/Microsoft.Data.Analysis/ArrowStringDataFrameColumn.cs +++ b/src/Microsoft.Data.Analysis/ArrowStringDataFrameColumn.cs @@ -460,34 +460,42 @@ private ArrowStringDataFrameColumn Clone(PrimitiveDataFrameColumn mapIndice /// public override DataFrame ValueCounts() { - Dictionary> groupedValues = GroupColumnValues(); + Dictionary> groupedValues = GroupColumnValues(out HashSet _); return StringDataFrameColumn.ValueCountsImplementation(groupedValues); } /// public override GroupBy GroupBy(int columnIndex, DataFrame parent) { - Dictionary> dictionary = GroupColumnValues(); + Dictionary> dictionary = GroupColumnValues(out HashSet _); return new GroupBy(parent, columnIndex, dictionary); } /// - public override Dictionary> GroupColumnValues() + public override Dictionary> GroupColumnValues(out HashSet nullIndices) { if (typeof(TKey) == typeof(string)) { + nullIndices = new HashSet(); Dictionary> multimap = new Dictionary>(EqualityComparer.Default); for (long i = 0; i < Length; i++) { - string str = this[i] ?? 
"__null__"; - bool containsKey = multimap.TryGetValue(str, out ICollection values); - if (containsKey) + string str = this[i]; + if (str != null) { - values.Add(i); + bool containsKey = multimap.TryGetValue(str, out ICollection values); + if (containsKey) + { + values.Add(i); + } + else + { + multimap.Add(str, new List() { i }); + } } else { - multimap.Add(str, new List() { i }); + nullIndices.Add(i); } } return multimap as Dictionary>; @@ -499,7 +507,7 @@ public override Dictionary> GroupColumnValues() } /// - public ArrowStringDataFrameColumn FillNulls(string value, bool inPlace = false) + public ArrowStringDataFrameColumn FillNulls(string value, bool inPlace = false) { if (value == null) { diff --git a/src/Microsoft.Data.Analysis/DataFrame.Join.cs b/src/Microsoft.Data.Analysis/DataFrame.Join.cs index 381268dee2..a8bc9bfc20 100644 --- a/src/Microsoft.Data.Analysis/DataFrame.Join.cs +++ b/src/Microsoft.Data.Analysis/DataFrame.Join.cs @@ -168,7 +168,7 @@ public DataFrame Merge(DataFrame other, string leftJoinColumn, string righ { // First hash other dataframe on the rightJoinColumn DataFrameColumn otherColumn = other.Columns[rightJoinColumn]; - Dictionary> multimap = otherColumn.GroupColumnValues(); + Dictionary> multimap = otherColumn.GroupColumnValues(out HashSet otherColumnNullIndices); // Go over the records in this dataframe and match with the dictionary DataFrameColumn thisColumn = Columns[leftJoinColumn]; @@ -176,74 +176,64 @@ public DataFrame Merge(DataFrame other, string leftJoinColumn, string righ for (long i = 0; i < thisColumn.Length; i++) { var thisColumnValue = thisColumn[i]; - TKey thisColumnValueOrDefault = (TKey)(thisColumnValue == null ? 
default(TKey) : thisColumnValue); - if (multimap.TryGetValue(thisColumnValueOrDefault, out ICollection rowNumbers)) + if (thisColumnValue != null) { - foreach (long row in rowNumbers) + if (multimap.TryGetValue((TKey)thisColumnValue, out ICollection rowNumbers)) { - if (thisColumnValue == null) + foreach (long row in rowNumbers) { - // Match only with nulls in otherColumn - if (otherColumn[row] == null) - { - leftRowIndices.Append(i); - rightRowIndices.Append(row); - } - } - else - { - // Cannot match nulls in otherColumn - if (otherColumn[row] != null) - { - leftRowIndices.Append(i); - rightRowIndices.Append(row); - } + leftRowIndices.Append(i); + rightRowIndices.Append(row); } } + else + { + leftRowIndices.Append(i); + rightRowIndices.Append(null); + } } else { - leftRowIndices.Append(i); - rightRowIndices.Append(null); + foreach (long row in otherColumnNullIndices) + { + leftRowIndices.Append(i); + rightRowIndices.Append(row); + } } } } else if (joinAlgorithm == JoinAlgorithm.Right) { DataFrameColumn thisColumn = Columns[leftJoinColumn]; - Dictionary> multimap = thisColumn.GroupColumnValues(); + Dictionary> multimap = thisColumn.GroupColumnValues(out HashSet thisColumnNullIndices); DataFrameColumn otherColumn = other.Columns[rightJoinColumn]; for (long i = 0; i < otherColumn.Length; i++) { var otherColumnValue = otherColumn[i]; - TKey otherColumnValueOrDefault = (TKey)(otherColumnValue == null ? 
default(TKey) : otherColumnValue); - if (multimap.TryGetValue(otherColumnValueOrDefault, out ICollection rowNumbers)) + if (otherColumnValue != null) { - foreach (long row in rowNumbers) + if (multimap.TryGetValue((TKey)otherColumnValue, out ICollection rowNumbers)) { - if (otherColumnValue == null) + foreach (long row in rowNumbers) { - if (thisColumn[row] == null) - { - leftRowIndices.Append(row); - rightRowIndices.Append(i); - } - } - else - { - if (thisColumn[row] != null) - { - leftRowIndices.Append(row); - rightRowIndices.Append(i); - } + leftRowIndices.Append(row); + rightRowIndices.Append(i); } } + else + { + leftRowIndices.Append(null); + rightRowIndices.Append(i); + } } else { - leftRowIndices.Append(null); - rightRowIndices.Append(i); + foreach (long thisColumnNullIndex in thisColumnNullIndices) + { + leftRowIndices.Append(thisColumnNullIndex); + rightRowIndices.Append(i); + } } } } @@ -289,63 +279,107 @@ public DataFrame Merge(DataFrame other, string leftJoinColumn, string righ else if (joinAlgorithm == JoinAlgorithm.FullOuter) { DataFrameColumn otherColumn = other.Columns[rightJoinColumn]; - Dictionary> multimap = otherColumn.GroupColumnValues(); + Dictionary> multimap = otherColumn.GroupColumnValues(out HashSet otherColumnNullIndices); Dictionary intersection = new Dictionary(EqualityComparer.Default); // Go over the records in this dataframe and match with the dictionary DataFrameColumn thisColumn = Columns[leftJoinColumn]; + Int64DataFrameColumn thisColumnNullIndices = new Int64DataFrameColumn("ThisColumnNullIndices"); for (long i = 0; i < thisColumn.Length; i++) { var thisColumnValue = thisColumn[i]; - TKey thisColumnValueOrDefault = (TKey)(thisColumnValue == null ? 
default(TKey) : thisColumnValue); - if (multimap.TryGetValue(thisColumnValueOrDefault, out ICollection rowNumbers)) + if (thisColumnValue != null) { - foreach (long row in rowNumbers) + if (multimap.TryGetValue((TKey)thisColumnValue, out ICollection rowNumbers)) { - if (thisColumnValue == null) + foreach (long row in rowNumbers) { - // Has to match only with nulls in otherColumn - if (otherColumn[row] == null) + leftRowIndices.Append(i); + rightRowIndices.Append(row); + if (!intersection.ContainsKey((TKey)thisColumnValue)) { - leftRowIndices.Append(i); - rightRowIndices.Append(row); - if (!intersection.ContainsKey(thisColumnValueOrDefault)) - { - intersection.Add(thisColumnValueOrDefault, rowNumber); - } - } - } - else - { - // Cannot match to nulls in otherColumn - if (otherColumn[row] != null) - { - leftRowIndices.Append(i); - rightRowIndices.Append(row); - if (!intersection.ContainsKey(thisColumnValueOrDefault)) - { - intersection.Add(thisColumnValueOrDefault, rowNumber); - } + intersection.Add((TKey)thisColumnValue, rowNumber); } } } + else + { + leftRowIndices.Append(i); + rightRowIndices.Append(null); + } } else { - leftRowIndices.Append(i); - rightRowIndices.Append(null); + thisColumnNullIndices.Append(i); } } for (long i = 0; i < otherColumn.Length; i++) { - TKey value = (TKey)(otherColumn[i] ?? default(TKey)); - if (!intersection.ContainsKey(value)) + var value = otherColumn[i]; + if (value != null) + { + if (!intersection.ContainsKey((TKey)value)) + { + leftRowIndices.Append(null); + rightRowIndices.Append(i); + } + } + } + + // Now handle the null rows + foreach (long? 
thisColumnNullIndex in thisColumnNullIndices) + { + foreach (long otherColumnNullIndex in otherColumnNullIndices) + { + leftRowIndices.Append(thisColumnNullIndex.Value); + rightRowIndices.Append(otherColumnNullIndex); + } + if (otherColumnNullIndices.Count == 0) + { + leftRowIndices.Append(thisColumnNullIndex.Value); + rightRowIndices.Append(null); + } + } + if (thisColumnNullIndices.Length == 0) + { + foreach (long otherColumnNullIndex in otherColumnNullIndices) { leftRowIndices.Append(null); - rightRowIndices.Append(i); + rightRowIndices.Append(otherColumnNullIndex); } } + + //// Now handle the null rows + //IEnumerator thisColumnNullIndicesEnumerator = thisColumnNullIndices.GetEnumerator(); + //HashSet.Enumerator otherColumnNullIndicesEnumerator = otherColumnNullIndices.GetEnumerator(); + //while (thisColumnNullIndicesEnumerator.MoveNext() && otherColumnNullIndicesEnumerator.MoveNext()) + //{ + // long? thisColumnNullIndex = thisColumnNullIndicesEnumerator.Current; + // long otherColumnNullIndex = otherColumnNullIndicesEnumerator.Current; + // leftRowIndices.Append(thisColumnNullIndex); + // rightRowIndices.Append(otherColumnNullIndex); + //} + //while (!otherColumnNullIndicesEnumerator.MoveNext()) + //{ + // long? 
thisColumnNullIndex = thisColumnNullIndicesEnumerator.Current; + // leftRowIndices.Append(thisColumnNullIndex); + // rightRowIndices.Append(null); + // if (!thisColumnNullIndicesEnumerator.MoveNext()) + // { + // break; + // } + //} + //while (!thisColumnNullIndicesEnumerator.MoveNext()) + //{ + // long otherColumnNullIndex = otherColumnNullIndicesEnumerator.Current; + // leftRowIndices.Append(null); + // rightRowIndices.Append(otherColumnNullIndex); + // if (!otherColumnNullIndicesEnumerator.MoveNext()) + // { + // break; + // } + //} } else throw new NotImplementedException(nameof(joinAlgorithm)); diff --git a/src/Microsoft.Data.Analysis/DataFrameColumn.cs b/src/Microsoft.Data.Analysis/DataFrameColumn.cs index bd21d6fe96..67e79cc300 100644 --- a/src/Microsoft.Data.Analysis/DataFrameColumn.cs +++ b/src/Microsoft.Data.Analysis/DataFrameColumn.cs @@ -203,7 +203,12 @@ public virtual DataFrameColumn Sort(bool ascending = true) return Clone(sortIndices, !ascending, NullCount); } - public virtual Dictionary> GroupColumnValues() => throw new NotImplementedException(); + /// <summary> + /// Groups the rows of this column by their value. 
+ /// </summary> + /// <typeparam name="TKey">The type of data held by this column</typeparam> + /// <returns>A mapping of value(<typeparamref name="TKey"/>) to the indices containing this value</returns> + public virtual Dictionary> GroupColumnValues(out HashSet nullIndices) => throw new NotImplementedException(); /// <summary> /// Returns a DataFrame containing counts of unique values diff --git a/src/Microsoft.Data.Analysis/PrimitiveDataFrameColumn.cs b/src/Microsoft.Data.Analysis/PrimitiveDataFrameColumn.cs index a7e7d20cb9..f063a66695 100644 --- a/src/Microsoft.Data.Analysis/PrimitiveDataFrameColumn.cs +++ b/src/Microsoft.Data.Analysis/PrimitiveDataFrameColumn.cs @@ -313,7 +313,7 @@ protected override DataFrameColumn FillNullsImplementation(object value, bool in public override DataFrame ValueCounts() { - Dictionary> groupedValues = GroupColumnValues(); + Dictionary> groupedValues = GroupColumnValues(out HashSet _); PrimitiveDataFrameColumn keys = new PrimitiveDataFrameColumn("Values"); PrimitiveDataFrameColumn counts = new PrimitiveDataFrameColumn("Counts"); foreach (KeyValuePair> keyValuePair in groupedValues) @@ -520,32 +520,41 @@ internal SingleDataFrameColumn CloneAsSingleColumn() /// public override GroupBy GroupBy(int columnIndex, DataFrame parent) { - Dictionary> dictionary = GroupColumnValues(); + Dictionary> dictionary = GroupColumnValues(out HashSet _); return new GroupBy(parent, columnIndex, dictionary); } - public override Dictionary> GroupColumnValues() + public override Dictionary> GroupColumnValues(out HashSet nullIndices) { if (typeof(TKey) == typeof(T)) { Dictionary> multimap = new Dictionary>(EqualityComparer.Default); + nullIndices = new HashSet(); for (int b = 0; b < _columnContainer.Buffers.Count; b++) { ReadOnlyDataFrameBuffer buffer = _columnContainer.Buffers[b]; ReadOnlySpan readOnlySpan = buffer.ReadOnlySpan; + ReadOnlySpan nullBitMapSpan = _columnContainer.NullBitMapBuffers[b].ReadOnlySpan; long previousLength = b * ReadOnlyDataFrameBuffer.MaxCapacity; for (int i = 0; i < readOnlySpan.Length; i++) { long currentLength = i + 
previousLength; - bool containsKey = multimap.TryGetValue(readOnlySpan[i], out ICollection values); - if (containsKey) + if (_columnContainer.IsValid(nullBitMapSpan, i)) { - values.Add(currentLength); + bool containsKey = multimap.TryGetValue(readOnlySpan[i], out ICollection values); + if (containsKey) + { + values.Add(currentLength); + } + else + { + multimap.Add(readOnlySpan[i], new List() { currentLength }); + } } else { - multimap.Add(readOnlySpan[i], new List() { currentLength }); - } + nullIndices.Add(currentLength); + } } } return multimap as Dictionary>; diff --git a/src/Microsoft.Data.Analysis/StringDataFrameColumn.cs b/src/Microsoft.Data.Analysis/StringDataFrameColumn.cs index 7ada30e10c..7fb7f0c2ec 100644 --- a/src/Microsoft.Data.Analysis/StringDataFrameColumn.cs +++ b/src/Microsoft.Data.Analysis/StringDataFrameColumn.cs @@ -398,31 +398,40 @@ internal static DataFrame ValueCountsImplementation(Dictionary> groupedValues = GroupColumnValues(); + Dictionary> groupedValues = GroupColumnValues(out HashSet _); return ValueCountsImplementation(groupedValues); } public override GroupBy GroupBy(int columnIndex, DataFrame parent) { - Dictionary> dictionary = GroupColumnValues(); + Dictionary> dictionary = GroupColumnValues(out HashSet _); return new GroupBy(parent, columnIndex, dictionary); } - public override Dictionary> GroupColumnValues() + public override Dictionary> GroupColumnValues(out HashSet nullIndices) { if (typeof(TKey) == typeof(string)) { Dictionary> multimap = new Dictionary>(EqualityComparer.Default); + nullIndices = new HashSet(); for (long i = 0; i < Length; i++) { - bool containsKey = multimap.TryGetValue(this[i] ?? default, out ICollection values); - if (containsKey) + string str = this[i]; + if (str != null) { - values.Add(i); + bool containsKey = multimap.TryGetValue(this[i], out ICollection values); + if (containsKey) + { + values.Add(i); + } + else + { + multimap.Add(this[i] ?? 
default, new List() { i }); + } } else { - multimap.Add(this[i] ?? default, new List() { i }); + nullIndices.Add(i); } } return multimap as Dictionary>; diff --git a/test/Microsoft.Data.Analysis.Tests/DataFrameTests.cs b/test/Microsoft.Data.Analysis.Tests/DataFrameTests.cs index 72072fd533..5188b81625 100644 --- a/test/Microsoft.Data.Analysis.Tests/DataFrameTests.cs +++ b/test/Microsoft.Data.Analysis.Tests/DataFrameTests.cs @@ -1669,6 +1669,134 @@ public void TestMerge() VerifyMerge(merge, left, right, JoinAlgorithm.Inner); } + private void MatchRowsOnMergedDataFrame(DataFrame merge, DataFrame left, DataFrame right, long mergeRow, long? leftRow, long? rightRow) + { + Assert.Equal(merge.Columns.Count, left.Columns.Count + right.Columns.Count); + DataFrameRow dataFrameMergeRow = merge.Rows[mergeRow]; + int columnIndex = 0; + foreach (object value in dataFrameMergeRow) + { + object compare = null; + if (columnIndex < left.Columns.Count) + { + if (leftRow != null) + { + compare = left.Rows[leftRow.Value][columnIndex]; + } + } + else + { + int rightColumnIndex = columnIndex - left.Columns.Count; + if (rightRow != null) + { + compare = right.Rows[rightRow.Value][rightColumnIndex]; + } + } + Assert.Equal(value, compare); + columnIndex++; + } + } + + [Theory] + [InlineData(10, 5, JoinAlgorithm.Left)] + [InlineData(5, 10, JoinAlgorithm.Right)] + public void TestMergeEdgeCases_Left(int leftLength, int rightLength, JoinAlgorithm joinAlgorithm) + { + DataFrame left = MakeDataFrameWithAllMutableColumnTypes(leftLength); + if (leftLength > 5) + { + left["Int"][8] = null; + } + DataFrame right = MakeDataFrameWithAllMutableColumnTypes(rightLength); + if (rightLength > 5) + { + right["Int"][8] = null; + } + + DataFrame merge = left.Merge(right, "Int", "Int", joinAlgorithm: joinAlgorithm); + Assert.Equal(10, merge.Rows.Count); + Assert.Equal(merge.Columns.Count, left.Columns.Count + right.Columns.Count); + int[] matchedFullRows = new int[] { 0, 1, 3, 4 }; + for (long i = 0; i < 
matchedFullRows.Length; i++) + { + int rowIndex = matchedFullRows[i]; + MatchRowsOnMergedDataFrame(merge, left, right, rowIndex, rowIndex, rowIndex); + } + + int[] matchedLeftOrRightRowsNullOtherRows = new int[] { 2, 5, 6, 7, 8, 9 }; + for (long i = 0; i < matchedLeftOrRightRowsNullOtherRows.Length; i++) + { + int rowIndex = matchedLeftOrRightRowsNullOtherRows[i]; + MatchRowsOnMergedDataFrame(merge, left, right, rowIndex, rowIndex, null); + } + } + + [Fact] + public void TestMergeEdgeCases_Outer() + { + DataFrame left = MakeDataFrameWithAllMutableColumnTypes(5); + left["Int"][3] = null; + DataFrame right = MakeDataFrameWithAllMutableColumnTypes(5); + // Creates this case: + /* + * Left: Right: + * 0 0 + * 1 5 + * null(2) null(2) + * null(3) null(3) + * 4 6 + */ + /* + * Merge will result in a DataFrame like: + * Int_Left Int_Right + * 0 0 + * 1 null + * 4 null + * null 5 + * null 6 + * null(2) null(2) + * null(2) null(3) + * null(3) null(2) + * null(3) null(3) + */ + right["Int"][1] = 5; + right["Int"][3] = null; + right["Int"][4] = 6; + + DataFrame merge = left.Merge(right, "Int", "Int", joinAlgorithm: JoinAlgorithm.FullOuter); + Assert.Equal(9, merge.Rows.Count); + Assert.Equal(merge.Columns.Count, left.Columns.Count + right.Columns.Count); + + int[] mergeRows = new int[] { 0, 5, 6, 7, 8 }; + int[] leftRows = new int[] { 0, 2, 2, 3, 3 }; + int[] rightRows = new int[] { 0, 2, 3, 2, 3 }; + for (long i = 0; i < mergeRows.Length; i++) + { + int rowIndex = mergeRows[i]; + int leftRowIndex = leftRows[i]; + int rightRowIndex = rightRows[i]; + MatchRowsOnMergedDataFrame(merge, left, right, rowIndex, leftRowIndex, rightRowIndex); + } + + mergeRows = new int[] { 1, 2 }; + leftRows = new int[] { 1, 4 }; + for (long i = 0; i < mergeRows.Length; i++) + { + int rowIndex = mergeRows[i]; + int leftRowIndex = leftRows[i]; + MatchRowsOnMergedDataFrame(merge, left, right, rowIndex, leftRowIndex, null); + } + + mergeRows = new int[] { 3, 4 }; + rightRows = new int[] { 1, 4 }; + for 
(long i = 0; i < mergeRows.Length; i++) + { + int rowIndex = mergeRows[i]; + int rightRowIndex = rightRows[i]; + MatchRowsOnMergedDataFrame(merge, left, right, rowIndex, null, rightRowIndex); + } + } + [Fact] public void TestDescription() { From f0c0daa8ed70eb3e3ac3d99f779cc91a40409802 Mon Sep 17 00:00:00 2001 From: Prashanth Govindarajan Date: Fri, 30 Apr 2021 16:41:16 -0700 Subject: [PATCH 2/4] sq --- src/Microsoft.Data.Analysis/DataFrame.Join.cs | 34 +++++----- .../DataFrameTests.cs | 66 +++++++++++++++++-- 2 files changed, 76 insertions(+), 24 deletions(-) diff --git a/src/Microsoft.Data.Analysis/DataFrame.Join.cs b/src/Microsoft.Data.Analysis/DataFrame.Join.cs index a8bc9bfc20..c06759f4a8 100644 --- a/src/Microsoft.Data.Analysis/DataFrame.Join.cs +++ b/src/Microsoft.Data.Analysis/DataFrame.Join.cs @@ -243,37 +243,33 @@ public DataFrame Merge(DataFrame other, string leftJoinColumn, string righ long leftRowCount = Rows.Count; long rightRowCount = other.Rows.Count; - var leftColumnIsSmaller = (leftRowCount <= rightRowCount); + bool leftColumnIsSmaller = leftRowCount <= rightRowCount; DataFrameColumn hashColumn = leftColumnIsSmaller ? Columns[leftJoinColumn] : other.Columns[rightJoinColumn]; DataFrameColumn otherColumn = ReferenceEquals(hashColumn, Columns[leftJoinColumn]) ? other.Columns[rightJoinColumn] : Columns[leftJoinColumn]; - Dictionary> multimap = hashColumn.GroupColumnValues(); + Dictionary> multimap = hashColumn.GroupColumnValues(out HashSet smallerDataFrameColumnNullIndices); for (long i = 0; i < otherColumn.Length; i++) { var otherColumnValue = otherColumn[i]; - TKey otherColumnValueOrDefault = (TKey)(otherColumnValue == null ? 
default(TKey) : otherColumnValue); - if (multimap.TryGetValue(otherColumnValueOrDefault, out ICollection rowNumbers)) + if (otherColumnValue != null) { - foreach (long row in rowNumbers) + if (multimap.TryGetValue((TKey)otherColumnValue, out ICollection rowNumbers)) { - if (otherColumnValue == null) - { - if (hashColumn[row] == null) - { - leftRowIndices.Append(leftColumnIsSmaller ? row : i); - rightRowIndices.Append(leftColumnIsSmaller ? i : row); - } - } - else + foreach (long row in rowNumbers) { - if (hashColumn[row] != null) - { - leftRowIndices.Append(leftColumnIsSmaller ? row : i); - rightRowIndices.Append(leftColumnIsSmaller ? i : row); - } + leftRowIndices.Append(leftColumnIsSmaller ? row : i); + rightRowIndices.Append(leftColumnIsSmaller ? i : row); } } } + else + { + foreach (long nullIndex in smallerDataFrameColumnNullIndices) + { + leftRowIndices.Append(leftColumnIsSmaller ? nullIndex : i); + rightRowIndices.Append(leftColumnIsSmaller ? i : nullIndex); + } + } } } else if (joinAlgorithm == JoinAlgorithm.FullOuter) diff --git a/test/Microsoft.Data.Analysis.Tests/DataFrameTests.cs b/test/Microsoft.Data.Analysis.Tests/DataFrameTests.cs index 5188b81625..4b05b0852f 100644 --- a/test/Microsoft.Data.Analysis.Tests/DataFrameTests.cs +++ b/test/Microsoft.Data.Analysis.Tests/DataFrameTests.cs @@ -1152,7 +1152,7 @@ public void TestGroupBy() if (originalColumn.Name == "Bool") continue; DataFrameColumn headColumn = head.Columns[originalColumn.Name]; - Assert.Equal(originalColumn[5], headColumn[verify[5]]); + Assert.Equal(originalColumn[7], headColumn[verify[5]]); } Assert.Equal(6, head.Rows.Count); @@ -1569,14 +1569,14 @@ public void TestSample() // all sampled rows should be unique. HashSet uniqueRowValues = new HashSet(); - foreach(int? value in sampled.Columns["Int"]) + foreach (int? 
value in sampled.Columns["Int"]) { uniqueRowValues.Add(value); } Assert.Equal(uniqueRowValues.Count, sampled.Rows.Count); // should throw exception as sample size is greater than dataframe rows - Assert.Throws(()=> df.Sample(13)); + Assert.Throws(() => df.Sample(13)); } [Theory] @@ -1658,7 +1658,7 @@ public void TestMerge() Assert.Equal(16, merge.Rows.Count); Assert.Equal(merge.Columns.Count, left.Columns.Count + right.Columns.Count); Assert.Null(merge.Columns["Int_left"][12]); - Assert.Null(merge.Columns["Int_left"][5]); + Assert.Null(merge.Columns["Int_left"][15]); VerifyMerge(merge, left, right, JoinAlgorithm.FullOuter); // Inner merge @@ -1727,7 +1727,63 @@ public void TestMergeEdgeCases_Left(int leftLength, int rightLength, JoinAlgorit for (long i = 0; i < matchedLeftOrRightRowsNullOtherRows.Length; i++) { int rowIndex = matchedLeftOrRightRowsNullOtherRows[i]; - MatchRowsOnMergedDataFrame(merge, left, right, rowIndex, rowIndex, null); + if (leftLength > 5) + { + MatchRowsOnMergedDataFrame(merge, left, right, rowIndex, rowIndex, null); + } + else + { + MatchRowsOnMergedDataFrame(merge, left, right, rowIndex, null, rowIndex); + } + } + } + + [Fact] + public void TestMergeEdgeCases_Inner() + { + DataFrame left = MakeDataFrameWithAllMutableColumnTypes(5); + DataFrame right = MakeDataFrameWithAllMutableColumnTypes(10); + left["Int"][3] = null; + right["Int"][6] = null; + // Creates this case: + /* + * Left: Right: + * 0 0 + * 1 1 + * null(2) 2 + * null(3) 3 + * 4 4 + * null(5) + * null(6) + * 7 + * 8 + * 9 + */ + /* + * Merge will result in a DataFrame like: + * Int_Left Int_Right + * 0 0 + * 1 1 + * 4 4 + * null(2) null(5) + * null(3) null(5) + * null(2) null(6) + * null(3) null(6) + */ + + DataFrame merge = left.Merge(right, "Int", "Int", joinAlgorithm: JoinAlgorithm.Inner); + Assert.Equal(7, merge.Rows.Count); + Assert.Equal(merge.Columns.Count, left.Columns.Count + right.Columns.Count); + + int[] mergeRows = new int[] { 0, 1, 2, 3, 4, 5, 6 }; + int[] leftRows = 
new int[] { 0, 1, 4, 2, 3, 2, 3 }; + int[] rightRows = new int[] { 0, 1, 4, 5, 5, 6, 6 }; + for (long i = 0; i < mergeRows.Length; i++) + { + int rowIndex = mergeRows[i]; + int leftRowIndex = leftRows[i]; + int rightRowIndex = rightRows[i]; + MatchRowsOnMergedDataFrame(merge, left, right, rowIndex, leftRowIndex, rightRowIndex); } } From a3d211f65f8b9f538678219da778b63d0e3c4902 Mon Sep 17 00:00:00 2001 From: Prashanth Govindarajan Date: Fri, 30 Apr 2021 16:51:31 -0700 Subject: [PATCH 3/4] sq --- src/Microsoft.Data.Analysis/DataFrame.Join.cs | 31 ------------------- .../PrimitiveDataFrameColumn.cs | 2 +- 2 files changed, 1 insertion(+), 32 deletions(-) diff --git a/src/Microsoft.Data.Analysis/DataFrame.Join.cs b/src/Microsoft.Data.Analysis/DataFrame.Join.cs index c06759f4a8..da99e6254f 100644 --- a/src/Microsoft.Data.Analysis/DataFrame.Join.cs +++ b/src/Microsoft.Data.Analysis/DataFrame.Join.cs @@ -345,37 +345,6 @@ public DataFrame Merge(DataFrame other, string leftJoinColumn, string righ rightRowIndices.Append(otherColumnNullIndex); } } - - //// Now handle the null rows - //IEnumerator thisColumnNullIndicesEnumerator = thisColumnNullIndices.GetEnumerator(); - //HashSet.Enumerator otherColumnNullIndicesEnumerator = otherColumnNullIndices.GetEnumerator(); - //while (thisColumnNullIndicesEnumerator.MoveNext() && otherColumnNullIndicesEnumerator.MoveNext()) - //{ - // long? thisColumnNullIndex = thisColumnNullIndicesEnumerator.Current; - // long otherColumnNullIndex = otherColumnNullIndicesEnumerator.Current; - // leftRowIndices.Append(thisColumnNullIndex); - // rightRowIndices.Append(otherColumnNullIndex); - //} - //while (!otherColumnNullIndicesEnumerator.MoveNext()) - //{ - // long? 
thisColumnNullIndex = thisColumnNullIndicesEnumerator.Current; - // leftRowIndices.Append(thisColumnNullIndex); - // rightRowIndices.Append(null); - // if (!thisColumnNullIndicesEnumerator.MoveNext()) - // { - // break; - // } - //} - //while (!thisColumnNullIndicesEnumerator.MoveNext()) - //{ - // long otherColumnNullIndex = otherColumnNullIndicesEnumerator.Current; - // leftRowIndices.Append(null); - // rightRowIndices.Append(otherColumnNullIndex); - // if (!otherColumnNullIndicesEnumerator.MoveNext()) - // { - // break; - // } - //} } else throw new NotImplementedException(nameof(joinAlgorithm)); diff --git a/src/Microsoft.Data.Analysis/PrimitiveDataFrameColumn.cs b/src/Microsoft.Data.Analysis/PrimitiveDataFrameColumn.cs index f063a66695..10f1627692 100644 --- a/src/Microsoft.Data.Analysis/PrimitiveDataFrameColumn.cs +++ b/src/Microsoft.Data.Analysis/PrimitiveDataFrameColumn.cs @@ -554,7 +554,7 @@ public override Dictionary> GroupColumnValues(out else { nullIndices.Add(currentLength); - } + } } } return multimap as Dictionary>; From 822385c0dd0000a6a786b492a24b8d0880a4acd7 Mon Sep 17 00:00:00 2001 From: Prashanth Govindarajan Date: Wed, 5 May 2021 10:43:47 -0700 Subject: [PATCH 4/4] Add unit test --- .../StringDataFrameColumn.cs | 4 ++-- .../DataFrameTests.cs | 17 ++++++++++++++++- 2 files changed, 18 insertions(+), 3 deletions(-) diff --git a/src/Microsoft.Data.Analysis/StringDataFrameColumn.cs b/src/Microsoft.Data.Analysis/StringDataFrameColumn.cs index 7fb7f0c2ec..761fbcda6b 100644 --- a/src/Microsoft.Data.Analysis/StringDataFrameColumn.cs +++ b/src/Microsoft.Data.Analysis/StringDataFrameColumn.cs @@ -419,14 +419,14 @@ public override Dictionary> GroupColumnValues(out string str = this[i]; if (str != null) { - bool containsKey = multimap.TryGetValue(this[i], out ICollection values); + bool containsKey = multimap.TryGetValue(str, out ICollection values); if (containsKey) { values.Add(i); } else { - multimap.Add(this[i] ?? 
default, new List() { i }); + multimap.Add(str, new List() { i }); } } else diff --git a/test/Microsoft.Data.Analysis.Tests/DataFrameTests.cs b/test/Microsoft.Data.Analysis.Tests/DataFrameTests.cs index 4b05b0852f..348aeeea97 100644 --- a/test/Microsoft.Data.Analysis.Tests/DataFrameTests.cs +++ b/test/Microsoft.Data.Analysis.Tests/DataFrameTests.cs @@ -1700,7 +1700,7 @@ private void MatchRowsOnMergedDataFrame(DataFrame merge, DataFrame left, DataFra [Theory] [InlineData(10, 5, JoinAlgorithm.Left)] [InlineData(5, 10, JoinAlgorithm.Right)] - public void TestMergeEdgeCases_Left(int leftLength, int rightLength, JoinAlgorithm joinAlgorithm) + public void TestMergeEdgeCases_LeftOrRight(int leftLength, int rightLength, JoinAlgorithm joinAlgorithm) { DataFrame left = MakeDataFrameWithAllMutableColumnTypes(leftLength); if (leftLength > 5) @@ -1853,6 +1853,21 @@ public void TestMergeEdgeCases_Outer() } } + [Fact] + public void TestMerge_Issue5778() + { + DataFrame left = MakeDataFrameWithAllMutableColumnTypes(2, false); + DataFrame right = MakeDataFrameWithAllMutableColumnTypes(1); + + DataFrame merge = left.Merge(right, "Int", "Int"); + + Assert.Equal(2, merge.Rows.Count); + Assert.Equal(0, (int)merge.Columns["Int_left"][0]); + Assert.Equal(1, (int)merge.Columns["Int_left"][1]); + MatchRowsOnMergedDataFrame(merge, left, right, 0, 0, 0); + MatchRowsOnMergedDataFrame(merge, left, right, 1, 1, 0); + } + [Fact] public void TestDescription() {