|
6 | 6 | from numpy import nan |
7 | 7 | import pytest |
8 | 8 |
|
9 | | -from pandas import DataFrame, MultiIndex, Series, concat, merge, to_datetime |
| 9 | +from pandas import ( |
| 10 | + DataFrame, MultiIndex, Series, array, concat, merge, to_datetime) |
10 | 11 | from pandas.core import common as com |
11 | 12 | from pandas.core.sorting import ( |
12 | 13 | decons_group_index, get_group_index, is_int64_overflow_possible, |
@@ -358,34 +359,43 @@ def test_basic_sort(self): |
358 | 359 | expected = np.array([]) |
359 | 360 | tm.assert_numpy_array_equal(result, expected) |
360 | 361 |
|
361 | | - def test_labels(self): |
| 362 | + @pytest.mark.parametrize('verify', [True, False]) |
| 363 | + def test_labels(self, verify): |
362 | 364 | values = [3, 1, 2, 0, 4] |
363 | 365 | expected = np.array([0, 1, 2, 3, 4]) |
364 | 366 |
|
365 | 367 | labels = [0, 1, 1, 2, 3, 0, -1, 4] |
366 | | - result, result_labels = safe_sort(values, labels) |
| 368 | + result, result_labels = safe_sort(values, labels, verify=verify) |
367 | 369 | expected_labels = np.array([3, 1, 1, 2, 0, 3, -1, 4], dtype=np.intp) |
368 | 370 | tm.assert_numpy_array_equal(result, expected) |
369 | 371 | tm.assert_numpy_array_equal(result_labels, expected_labels) |
370 | 372 |
|
371 | 373 | # na_sentinel |
372 | 374 | labels = [0, 1, 1, 2, 3, 0, 99, 4] |
373 | | - result, result_labels = safe_sort(values, labels, |
374 | | - na_sentinel=99) |
| 375 | + result, result_labels = safe_sort(values, labels, na_sentinel=99, |
| 376 | + verify=verify) |
375 | 377 | expected_labels = np.array([3, 1, 1, 2, 0, 3, 99, 4], dtype=np.intp) |
376 | 378 | tm.assert_numpy_array_equal(result, expected) |
377 | 379 | tm.assert_numpy_array_equal(result_labels, expected_labels) |
378 | 380 |
|
379 | | - # out of bound indices |
380 | | - labels = [0, 101, 102, 2, 3, 0, 99, 4] |
381 | | - result, result_labels = safe_sort(values, labels) |
382 | | - expected_labels = np.array([3, -1, -1, 2, 0, 3, -1, 4], dtype=np.intp) |
| 381 | + labels = [] |
| 382 | + result, result_labels = safe_sort(values, labels, verify=verify) |
| 383 | + expected_labels = np.array([], dtype=np.intp) |
383 | 384 | tm.assert_numpy_array_equal(result, expected) |
384 | 385 | tm.assert_numpy_array_equal(result_labels, expected_labels) |
385 | 386 |
|
386 | | - labels = [] |
387 | | - result, result_labels = safe_sort(values, labels) |
388 | | - expected_labels = np.array([], dtype=np.intp) |
| 387 | + @pytest.mark.parametrize('na_sentinel', [-1, 99]) |
| 388 | + def test_labels_out_of_bound(self, na_sentinel): |
| 389 | + values = [3, 1, 2, 0, 4] |
| 390 | + expected = np.array([0, 1, 2, 3, 4]) |
| 391 | + |
| 392 | + # out of bound indices |
| 393 | + labels = [0, 101, 102, 2, 3, 0, 99, 4] |
| 394 | + result, result_labels = safe_sort( |
| 395 | + values, labels, na_sentinel=na_sentinel) |
| 396 | + expected_labels = np.array( |
| 397 | + [3, na_sentinel, na_sentinel, 2, 0, 3, na_sentinel, 4], |
| 398 | + dtype=np.intp) |
389 | 399 | tm.assert_numpy_array_equal(result, expected) |
390 | 400 | tm.assert_numpy_array_equal(result_labels, expected_labels) |
391 | 401 |
|
@@ -430,3 +440,22 @@ def test_exceptions(self): |
430 | 440 | with pytest.raises(ValueError, |
431 | 441 | match="values should be unique"): |
432 | 442 | safe_sort(values=[0, 1, 2, 1], labels=[0, 1]) |
| 443 | + |
| 444 | + def test_extension_array(self): |
| 445 | + # a = array([1, 3, np.nan, 2], dtype='Int64') |
| 446 | + a = array([1, 3, 2], dtype='Int64') |
| 447 | + result = safe_sort(a) |
| 448 | + # expected = array([1, 2, 3, np.nan], dtype='Int64') |
| 449 | + expected = array([1, 2, 3], dtype='Int64') |
| 450 | + tm.assert_extension_array_equal(result, expected) |
| 451 | + |
| 452 | + @pytest.mark.parametrize('verify', [True, False]) |
| 453 | + @pytest.mark.parametrize('na_sentinel', [-1, 99]) |
| 454 | + def test_extension_array_labels(self, verify, na_sentinel): |
| 455 | + a = array([1, 3, 2], dtype='Int64') |
| 456 | + result, labels = safe_sort(a, [0, 1, na_sentinel, 2], |
| 457 | + na_sentinel=na_sentinel, verify=verify) |
| 458 | + expected_values = array([1, 2, 3], dtype='Int64') |
| 459 | + expected_labels = np.array([0, 2, na_sentinel, 1], dtype=np.intp) |
| 460 | + tm.assert_extension_array_equal(result, expected_values) |
| 461 | + tm.assert_numpy_array_equal(labels, expected_labels) |
0 commit comments