diff --git a/datafusion/core/tests/sqllogictests/test_files/pg_compat/pg_compat_null.slt b/datafusion/core/tests/sqllogictests/test_files/pg_compat/pg_compat_null.slt new file mode 100644 index 0000000000000..f933c90acc734 --- /dev/null +++ b/datafusion/core/tests/sqllogictests/test_files/pg_compat/pg_compat_null.slt @@ -0,0 +1,103 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + +onlyif postgres +statement ok +CREATE TABLE aggregate_test_100_by_sql +( + c1 character varying NOT NULL, + c2 smallint NOT NULL, + c3 smallint NOT NULL, + c4 smallint, + c5 integer, + c6 bigint NOT NULL, + c7 smallint NOT NULL, + c8 integer NOT NULL, + c9 bigint NOT NULL, + c10 character varying NOT NULL, + c11 real NOT NULL, + c12 double precision NOT NULL, + c13 character varying NOT NULL +); + + +# Copy the data +onlyif postgres +statement ok +COPY aggregate_test_100_by_sql + FROM '../../testing/data/csv/aggregate_test_100.csv' + DELIMITER ',' + CSV HEADER; + + +### +## Setup test for datafusion +### +onlyif DataFusion +statement ok +CREATE EXTERNAL TABLE aggregate_test_100_by_sql ( + c1 VARCHAR NOT NULL, + c2 TINYINT NOT NULL, + c3 SMALLINT NOT NULL, + c4 SMALLINT, + c5 INT, + c6 BIGINT NOT NULL, + c7 SMALLINT NOT NULL, + c8 INT NOT NULL, + c9 BIGINT UNSIGNED NOT NULL, + c10 VARCHAR NOT NULL, + c11 FLOAT NOT NULL, + c12 DOUBLE NOT NULL, + c13 VARCHAR NOT NULL +) +STORED AS CSV +WITH HEADER ROW +LOCATION '../../testing/data/csv/aggregate_test_100.csv' + + +statement ok +CREATE TABLE aggregate_test_100_nullable_by_sql AS +SELECT + *, + CASE + WHEN c4 % 3 = 0 THEN NULL + ELSE c5 + END AS n5, + CASE + WHEN c3 % 3 != 0 THEN c9 + ELSE NULL + END AS n9 +FROM aggregate_test_100_by_sql + + +query III +SELECT + COUNT(*), COUNT(n5), COUNT(n9) +FROM aggregate_test_100_nullable_by_sql +---- +100 66 72 + + +######## +# Clean up after the test +######## +statement ok +DROP TABLE aggregate_test_100_by_sql + +statement ok +DROP TABLE aggregate_test_100_nullable_by_sql diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index fafda79a6f61d..bfe464b232cd5 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -21,6 +21,7 @@ use crate::expr::{ }; use crate::field_util::get_indexed_field; use crate::type_coercion::binary::binary_operator_data_type; +use crate::type_coercion::other::get_coerce_type_for_case_when; use crate::{aggregate_function, function, window_function}; use arrow::compute::can_cast_types; use arrow::datatypes::DataType; @@ -68,7 +69,26 @@ impl ExprSchemable for Expr { Expr::OuterReferenceColumn(ty, _) => Ok(ty.clone()), Expr::ScalarVariable(ty, _) => Ok(ty.clone()), Expr::Literal(l) => Ok(l.get_datatype()), - Expr::Case(case) => case.when_then_expr[0].1.get_type(schema), + Expr::Case(case) => { + // when #5681 will be fixed, this code can be reverted to: + // case.when_then_expr[0].1.get_type(schema) + let then_types = case + .when_then_expr + .iter() + .map(|when_then| when_then.1.get_type(schema)) + .collect::>>()?; + let else_type = match &case.else_expr { + None => Ok(None), + Some(expr) => expr.get_type(schema).map(Some), + }?; + get_coerce_type_for_case_when(&then_types, else_type.as_ref()).ok_or_else( + || { + DataFusionError::Internal(String::from( + "Cannot infer type for CASE statement", + )) + }, + ) + } Expr::Cast(Cast { data_type, .. }) | Expr::TryCast(TryCast { data_type, .. }) => Ok(data_type.clone()), Expr::ScalarUDF { fun, args } => {