diff --git a/python/pyarrow/tests/test_array.py b/python/pyarrow/tests/test_array.py index 9648b7ebabf..09f6d8e884b 100644 --- a/python/pyarrow/tests/test_array.py +++ b/python/pyarrow/tests/test_array.py @@ -864,7 +864,7 @@ def test_union_array_slice(): assert arr[i:j].to_pylist() == lst[i:j] -def _check_cast_case(case, safe=True): +def _check_cast_case(case, *, safe=True, check_array_construction=True): in_data, in_type, out_data, out_type = case if isinstance(out_data, pa.Array): assert out_data.type == out_type @@ -884,8 +884,9 @@ def _check_cast_case(case, safe=True): # constructing an array with out type which optionally involves casting # for more see ARROW-1949 - in_arr = pa.array(in_data, type=out_type, safe=safe) - assert in_arr.equals(expected) + if check_array_construction: + in_arr = pa.array(in_data, type=out_type, safe=safe) + assert in_arr.equals(expected) def test_cast_integers_safe(): @@ -1011,6 +1012,132 @@ def test_floating_point_truncate_unsafe(): _check_cast_case(case, safe=False) +def test_decimal_to_int_safe(): + safe_cases = [ + ( + [decimal.Decimal("123456"), None, decimal.Decimal("-912345")], + pa.decimal128(32, 5), + [123456, None, -912345], + pa.int32() + ), + ( + [decimal.Decimal("1234"), None, decimal.Decimal("-9123")], + pa.decimal128(19, 10), + [1234, None, -9123], + pa.int16() + ), + ( + [decimal.Decimal("123"), None, decimal.Decimal("-91")], + pa.decimal128(19, 10), + [123, None, -91], + pa.int8() + ), + ] + for case in safe_cases: + _check_cast_case(case) + _check_cast_case(case, safe=True) + + +def test_decimal_to_int_value_out_of_bounds(): + out_of_bounds_cases = [ + ( + np.array([ + decimal.Decimal("1234567890123"), + None, + decimal.Decimal("-912345678901234") + ]), + pa.decimal128(32, 5), + [1912276171, None, -135950322], + pa.int32() + ), + ( + [decimal.Decimal("123456"), None, decimal.Decimal("-912345678")], + pa.decimal128(32, 5), + [-7616, None, -19022], + pa.int16() + ), + ( + [decimal.Decimal("1234"), None, decimal.Decimal("-9123")], + pa.decimal128(32, 5), + [-46, None, 93], + pa.int8() + ), + ] + + for case in out_of_bounds_cases: + # test safe casting raises + with pytest.raises(pa.ArrowInvalid, + match='Integer value out of bounds'): + _check_cast_case(case) + + # XXX `safe=False` can be ignored when constructing an array + # from a sequence of Python objects (ARROW-8567) + _check_cast_case(case, safe=False, check_array_construction=False) + + +def test_decimal_to_int_non_integer(): + non_integer_cases = [ + ( + [ + decimal.Decimal("123456.21"), + None, + decimal.Decimal("-912345.13") + ], + pa.decimal128(32, 5), + [123456, None, -912345], + pa.int32() + ), + ( + [decimal.Decimal("1234.134"), None, decimal.Decimal("-9123.1")], + pa.decimal128(19, 10), + [1234, None, -9123], + pa.int16() + ), + ( + [decimal.Decimal("123.1451"), None, decimal.Decimal("-91.21")], + pa.decimal128(19, 10), + [123, None, -91], + pa.int8() + ), + ] + + for case in non_integer_cases: + # test safe casting raises + msg_regexp = 'Rescaling decimal value would cause data loss' + with pytest.raises(pa.ArrowInvalid, match=msg_regexp): + _check_cast_case(case) + + _check_cast_case(case, safe=False) + + +def test_decimal_to_decimal(): + arr = pa.array( + [decimal.Decimal("1234.12"), None], + type=pa.decimal128(19, 10) + ) + result = arr.cast(pa.decimal128(15, 6)) + expected = pa.array( + [decimal.Decimal("1234.12"), None], + type=pa.decimal128(15, 6) + ) + assert result.equals(expected) + + with pytest.raises(pa.ArrowInvalid, + match='Rescaling decimal value would cause data loss'): + result = arr.cast(pa.decimal128(9, 1)) + + result = arr.cast(pa.decimal128(9, 1), safe=False) + expected = pa.array( + [decimal.Decimal("1234.1"), None], + type=pa.decimal128(9, 1) + ) + assert result.equals(expected) + + # TODO FIXME + # this should fail but decimal overflow is not implemented + result = arr.cast(pa.decimal128(1, 2)) + + def test_safe_cast_nan_to_int_raises(): arr = pa.array([np.nan, 1.])