Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 130 additions & 3 deletions python/pyarrow/tests/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -864,7 +864,7 @@ def test_union_array_slice():
assert arr[i:j].to_pylist() == lst[i:j]


def _check_cast_case(case, safe=True):
def _check_cast_case(case, *, safe=True, check_array_construction=True):
in_data, in_type, out_data, out_type = case
if isinstance(out_data, pa.Array):
assert out_data.type == out_type
Expand All @@ -884,8 +884,9 @@ def _check_cast_case(case, safe=True):

# constructing an array with out type which optionally involves casting
# for more see ARROW-1949
in_arr = pa.array(in_data, type=out_type, safe=safe)
assert in_arr.equals(expected)
if check_array_construction:
in_arr = pa.array(in_data, type=out_type, safe=safe)
assert in_arr.equals(expected)


def test_cast_integers_safe():
Expand Down Expand Up @@ -1011,6 +1012,132 @@ def test_floating_point_truncate_unsafe():
_check_cast_case(case, safe=False)


def test_decimal_to_int_safe():
    """Check decimal -> integer casts for whole values that fit the target.

    Every case is a ``(in_data, in_type, out_data, out_type)`` tuple in the
    layout unpacked by ``_check_cast_case``.  The ``None`` entry in each
    input has a matching ``None`` in the expected output.
    """
    safe_cases = [
        (
            [decimal.Decimal("123456"), None, decimal.Decimal("-912345")],
            pa.decimal128(32, 5),
            [123456, None, -912345],
            pa.int32(),
        ),
        (
            [decimal.Decimal("1234"), None, decimal.Decimal("-9123")],
            pa.decimal128(19, 10),
            [1234, None, -9123],
            pa.int16(),
        ),
        (
            [decimal.Decimal("123"), None, decimal.Decimal("-91")],
            pa.decimal128(19, 10),
            [123, None, -91],
            pa.int8(),
        ),
    ]
    for case in safe_cases:
        # ``safe=True`` is the helper's default, so calling it once with and
        # once without the keyword ran the identical check twice; a single
        # explicit call covers it.
        _check_cast_case(case, safe=True)


def test_decimal_to_int_value_out_of_bounds():
    """Check decimal -> integer casts whose values overflow the target.

    A safe cast has to raise ``ArrowInvalid``; with ``safe=False`` the cast
    is checked against the expected values listed in each case.
    """
    dec = decimal.Decimal

    int32_case = (
        np.array([
            dec("1234567890123"),
            None,
            dec("-912345678901234"),
        ]),
        pa.decimal128(32, 5),
        [1912276171, None, -135950322],
        pa.int32(),
    )
    int16_case = (
        [dec("123456"), None, dec("-912345678")],
        pa.decimal128(32, 5),
        [-7616, None, -19022],
        pa.int16(),
    )
    int8_case = (
        [dec("1234"), None, dec("-9123")],
        pa.decimal128(32, 5),
        [-46, None, 93],
        pa.int8(),
    )

    for overflow_case in (int32_case, int16_case, int8_case):
        # the default (safe) cast must refuse the conversion
        expect_overflow = pytest.raises(
            pa.ArrowInvalid, match='Integer value out of bounds'
        )
        with expect_overflow:
            _check_cast_case(overflow_case)

        # XXX `safe=False` can be ignored when constructing an array
        # from a sequence of Python objects (ARROW-8567), hence the
        # array-construction part of the check is skipped here
        _check_cast_case(
            overflow_case,
            safe=False,
            check_array_construction=False,
        )


def test_decimal_to_int_non_integer():
    """Check decimal -> integer casts for values with a fractional part.

    A safe cast has to raise ``ArrowInvalid``; with ``safe=False`` the
    result is checked against the expected integers listed in each case.
    """
    dec = decimal.Decimal

    int32_case = (
        [dec("123456.21"), None, dec("-912345.13")],
        pa.decimal128(32, 5),
        [123456, None, -912345],
        pa.int32(),
    )
    int16_case = (
        [dec("1234.134"), None, dec("-9123.1")],
        pa.decimal128(19, 10),
        [1234, None, -9123],
        pa.int16(),
    )
    int8_case = (
        [dec("123.1451"), None, dec("-91.21")],
        pa.decimal128(19, 10),
        [123, None, -91],
        pa.int8(),
    )

    for fractional_case in (int32_case, int16_case, int8_case):
        # the default (safe) cast must refuse to drop the fraction
        expected_message = 'Rescaling decimal value would cause data loss'
        with pytest.raises(pa.ArrowInvalid, match=expected_message):
            _check_cast_case(fractional_case)

        # the unsafe cast is checked against the listed integers
        _check_cast_case(fractional_case, safe=False)


def test_decimal_to_decimal():
    """Check casts between decimal128 types of different precision/scale."""
    source = pa.array(
        [decimal.Decimal("1234.12"), None],
        type=pa.decimal128(19, 10)
    )

    # target scale still holds every digit: the values must be unchanged
    lossless_type = pa.decimal128(15, 6)
    same_values = pa.array(
        [decimal.Decimal("1234.12"), None],
        type=lossless_type
    )
    assert source.cast(lossless_type).equals(same_values)

    # scale 1 cannot hold ".12": the safe cast has to raise
    lossy_type = pa.decimal128(9, 1)
    with pytest.raises(pa.ArrowInvalid,
                       match='Rescaling decimal value would cause data loss'):
        source.cast(lossy_type)

    # the unsafe cast is expected to produce 1234.1
    shortened = pa.array(
        [decimal.Decimal("1234.1"), None],
        type=lossy_type
    )
    assert source.cast(lossy_type, safe=False).equals(shortened)

    # TODO FIXME
    # this should fail but decimal overflow is not implemented
    source.cast(pa.decimal128(1, 2))


def test_safe_cast_nan_to_int_raises():
arr = pa.array([np.nan, 1.])

Expand Down