diff --git a/python/src/types.rs b/python/src/types.rs index bd6ef0d376e63..6201ce374c642 100644 --- a/python/src/types.rs +++ b/python/src/types.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -use datafusion::arrow::datatypes::DataType; +use datafusion::arrow::datatypes::{DataType, TimeUnit}; use pyo3::{FromPyObject, PyAny, PyResult}; use crate::errors; @@ -28,37 +28,123 @@ pub struct PyDataType { impl<'source> FromPyObject<'source> for PyDataType { fn extract(ob: &'source PyAny) -> PyResult { + let str_ob = ob.to_string(); let id = ob.getattr("id")?.extract::()?; - let data_type = data_type_id(&id)?; + let data_type = data_type_id(&id, &str_ob)?; Ok(PyDataType { data_type }) } } -fn data_type_id(id: &i32) -> Result { +fn data_type_id(id: &i32, str_ob: &str) -> Result { // see https://github.com/apache/arrow/blob/3694794bdfd0677b95b8c95681e392512f1c9237/python/pyarrow/includes/libarrow.pxd // this is not ideal as it does not generalize for non-basic types // Find a way to get a unique name from the pyarrow.DataType - Ok(match id { - 1 => DataType::Boolean, - 2 => DataType::UInt8, - 3 => DataType::Int8, - 4 => DataType::UInt16, - 5 => DataType::Int16, - 6 => DataType::UInt32, - 7 => DataType::Int32, - 8 => DataType::UInt64, - 9 => DataType::Int64, - 10 => DataType::Float16, - 11 => DataType::Float32, - 12 => DataType::Float64, - 13 => DataType::Utf8, - 14 => DataType::Binary, - 34 => DataType::LargeUtf8, - 35 => DataType::LargeBinary, - other => { + if str_ob.contains("date") { + Ok(data_type_date(str_ob)?) + } else if str_ob.contains("time") { + Ok(data_type_timestamp(str_ob)?) 
+ } else { + Ok(match id { + 1 => DataType::Boolean, + 2 => DataType::UInt8, + 3 => DataType::Int8, + 4 => DataType::UInt16, + 5 => DataType::Int16, + 6 => DataType::UInt32, + 7 => DataType::Int32, + 8 => DataType::UInt64, + 9 => DataType::Int64, + 10 => DataType::Float16, + 11 => DataType::Float32, + 12 => DataType::Float64, + 13 => DataType::Utf8, + 14 => DataType::Binary, + 34 => DataType::LargeUtf8, + 35 => DataType::LargeBinary, + other => { + return Err(errors::DataFusionError::Common(format!( + "The type {} is not valid", + other + ))) + } + }) + } +} + +fn data_type_timestamp(str_ob: &str) -> Result { + // maps to usage from apache/arrow/pyarrow/types.pxi + Ok(match str_ob.as_ref() { + "time32[s]" => DataType::Time32(TimeUnit::Second), + "time32[ms]" => DataType::Time32(TimeUnit::Millisecond), + "time64[us]" => DataType::Time64(TimeUnit::Microsecond), + "time64[ns]" => DataType::Time64(TimeUnit::Nanosecond), + "timestamp[s]" => DataType::Timestamp(TimeUnit::Second, None), + "timestamp[ms]" => DataType::Timestamp(TimeUnit::Millisecond, None), + "timestamp[us]" => DataType::Timestamp(TimeUnit::Microsecond, None), + "timestamp[ns]" => DataType::Timestamp(TimeUnit::Nanosecond, None), + _ => data_type_timestamp_infer(str_ob)?, + }) +} + +fn data_type_date(str_ob: &str) -> Result { + // maps to usage from apache/arrow/pyarrow/types.pxi + Ok(match str_ob.as_ref() { + "date32" => DataType::Date32, + "date64" => DataType::Date64, + "date32[day]" => DataType::Date32, + "date64[ms]" => DataType::Date64, + _ => { + return Err(errors::DataFusionError::Common(format!( + "invalid date {} provided", + str_ob + ))) + } + }) +} + +fn time_unit_str(unit: &str) -> Result { + Ok(match unit { + "s" => TimeUnit::Second, + "ms" => TimeUnit::Millisecond, + "us" => TimeUnit::Microsecond, + "ns" => TimeUnit::Nanosecond, + _ => { + return Err(errors::DataFusionError::Common(format!( + "invalid timestamp unit {} provided", + unit + ))) + } + }) +} + +fn 
data_type_timestamp_infer(str_ob: &str) -> Result { + // parse the timestamp string object - this approach is less than idea, as it requires maintaining + // this and more direct access methods are better + let chunks: Vec<_> = str_ob.split("[").collect(); + let timestamp_str: String = chunks[0].to_string(); + let unit_tz: String = chunks[1].to_string().replace(",", "").replace("]", ""); + + let mut tz: Option = None; + let unit: TimeUnit; + + if unit_tz.len() < 3 { + unit = time_unit_str(&unit_tz)?; + } else { + // manage timezones + let chunks: Vec<_> = unit_tz.split(" ").collect(); + let tz_part: Vec<_> = unit_tz.split("=").collect(); + unit = time_unit_str(&chunks[0])?; + tz = Some(tz_part[1].to_string()); + } + + Ok(match timestamp_str.as_ref() { + "time32" => DataType::Time32(unit), + "time64" => DataType::Time64(unit), + "timestamp" => DataType::Timestamp(unit, tz), + _ => { return Err(errors::DataFusionError::Common(format!( - "The type {} is not valid", - other + "invalid timestamp string {} provided", + str_ob ))) } }) diff --git a/python/tests/generic.py b/python/tests/generic.py index 8d5adaaaf9563..c9e96484febc5 100644 --- a/python/tests/generic.py +++ b/python/tests/generic.py @@ -66,6 +66,17 @@ def data_date32(): ) +def data_date64(): + data = [ + datetime.date(2000, 1, 1), + datetime.date(1980, 1, 1), + datetime.date(2030, 1, 1), + ] + return pa.array( + data, type=pa.date64(), mask=np.array([False, True, False]) + ) + + def data_timedelta(f): data = [ datetime.timedelta(days=100), diff --git a/python/tests/test_dates.py b/python/tests/test_dates.py new file mode 100644 index 0000000000000..0f7ce3baa9492 --- /dev/null +++ b/python/tests/test_dates.py @@ -0,0 +1,106 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
# The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from datetime import datetime

import pyarrow as pa
import pytest
from datafusion import ExecutionContext
from datafusion import functions as f

from . import generic as helpers


@pytest.fixture
def ctx():
    """A fresh execution context for each test."""
    return ExecutionContext()


@pytest.fixture
def df():
    """A DataFrame with one timestamp column and two date columns."""
    ctx = ExecutionContext()

    # build a single RecordBatch and wrap it in a DataFrame
    batch = pa.RecordBatch.from_arrays(
        [
            helpers.data_datetime("s"),
            helpers.data_date32(),
            helpers.data_date64(),
        ],
        names=["ts", "dt1", "dt2"],
    )

    return ctx.create_dataframe([[batch]])


def test_select_ts_date(df):
    """Selecting temporal columns must round-trip the input arrays."""
    # execute and collect the first (and only) batch
    result = df.select(f.col("ts"), f.col("dt1"), f.col("dt2")).collect()[0]

    assert result.column(0) == helpers.data_datetime("s")
    assert result.column(1) == helpers.data_date32()
    assert result.column(2) == helpers.data_date64()


# Three calendar days around the epoch, shared by every datetime case below.
_EPOCH_DAYS = [datetime(1970, 1, 1), datetime(1970, 1, 2), datetime(1970, 1, 3)]


@pytest.mark.parametrize(
    ("input_values", "input_type", "output_type"),
    [
        (_EPOCH_DAYS, pa.date32(), pa.date32()),
        (_EPOCH_DAYS, pa.date64(), pa.date64()),
        (_EPOCH_DAYS, pa.timestamp("ms"), pa.timestamp("ms")),
        (_EPOCH_DAYS, pa.timestamp("s"), pa.timestamp("s")),
        (_EPOCH_DAYS, pa.timestamp("us"), pa.timestamp("us")),
        (_EPOCH_DAYS, pa.timestamp("ns"), pa.timestamp("ns")),
        ([0, 1, 2], pa.time32("s"), pa.time32("s")),
        ([0, 1, 2], pa.time64("us"), pa.time64("us")),
    ],
)
def test_datetypes(ctx, input_values, input_type, output_type):
    """Each temporal type must survive a DataFusion round-trip unchanged."""
    batch = pa.RecordBatch.from_arrays(
        [pa.array(input_values, type=input_type)], names=["a"]
    )

    result = ctx.create_dataframe([[batch]]).collect()[0]

    # both the arrow type and the values must be preserved
    assert result.column(0).type == output_type
    assert result.column(0) == batch.column(0)