Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions cpp/src/arrow/dataset/dataset_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,7 @@ TEST(TestProjector, CheckProjectable) {
auto i8_req = field("i8", int8(), false);
auto u16_req = field("u16", uint16(), false);
auto str_req = field("str", utf8(), false);
auto str_nil = field("str", null());

// trivial
Assert({}).ProjectableTo({});
Expand All @@ -235,6 +236,8 @@ TEST(TestProjector, CheckProjectable) {
Assert({i8}).NotProjectableTo({i8_req},
"not nullable but is not required in origin schema");
Assert({i8_req}).ProjectableTo({i8});
Assert({str_nil}).ProjectableTo({str});
Assert({str_nil}).NotProjectableTo({str_req});

// change field type
Assert({i8}).NotProjectableTo({field("i8", utf8())},
Expand All @@ -257,15 +260,18 @@ TEST(TestProjector, MismatchedType) {
TEST(TestProjector, AugmentWithNull) {
constexpr int64_t kBatchSize = 1024;

auto from_schema = schema({field("f64", float64()), field("b", boolean())});
auto from_schema =
schema({field("f64", float64()), field("b", boolean()), field("str", null())});
auto batch = ConstantArrayGenerator::Zeroes(kBatchSize, from_schema);
auto to_schema = schema({field("i32", int32()), field("f64", float64())});
auto to_schema =
schema({field("i32", int32()), field("f64", float64()), field("str", utf8())});

RecordBatchProjector projector(to_schema);

ASSERT_OK_AND_ASSIGN(auto null_i32, MakeArrayOfNull(int32(), batch->num_rows()));
auto expected_batch =
RecordBatch::Make(to_schema, batch->num_rows(), {null_i32, batch->column(0)});
ASSERT_OK_AND_ASSIGN(auto null_str, MakeArrayOfNull(utf8(), batch->num_rows()));
auto expected_batch = RecordBatch::Make(to_schema, batch->num_rows(),
{null_i32, batch->column(0), null_str});

ASSERT_OK_AND_ASSIGN(auto reconciled_batch, projector.Project(*batch));
AssertBatchesEqual(*expected_batch, *reconciled_batch);
Expand Down
14 changes: 12 additions & 2 deletions cpp/src/arrow/dataset/projector.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,15 @@ Status CheckProjectable(const Schema& from, const Schema& to) {
from);
}

if (from_field->type()->id() == Type::NA) {
// promotion from null to any type is supported
if (to_field->nullable()) continue;

return Status::TypeError("field ", to_field->ToString(),
" is not nullable but has type ", NullType(),
" in origin schema ", from);
}

if (!from_field->type()->Equals(to_field->type())) {
return Status::TypeError("fields had matching names but differing types. From: ",
from_field->ToString(), " To: ", to_field->ToString());
Expand Down Expand Up @@ -98,7 +107,7 @@ Result<std::shared_ptr<RecordBatch>> RecordBatchProjector::Project(
RETURN_NOT_OK(ResizeMissingColumns(batch.num_rows(), pool));
}

std::vector<std::shared_ptr<Array>> columns(to_->num_fields());
ArrayVector columns(to_->num_fields());

for (int i = 0; i < to_->num_fields(); ++i) {
if (column_indices_[i] != kNoMatch) {
Expand All @@ -120,7 +129,8 @@ Status RecordBatchProjector::SetInputSchema(std::shared_ptr<Schema> from,
ARROW_ASSIGN_OR_RAISE(auto match,
FieldRef(to_->field(i)->name()).FindOneOrNone(*from_));

if (match.indices().empty()) {
if (match.indices().empty() ||
from_->field(match.indices()[0])->type()->id() == Type::NA) {
// Mark column i as missing by setting missing_columns_[i]
// to a non-null placeholder.
ARROW_ASSIGN_OR_RAISE(missing_columns_[i],
Expand Down
2 changes: 2 additions & 0 deletions cpp/src/arrow/testing/generator.h
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,8 @@ class ARROW_TESTING_EXPORT ConstantArrayGenerator {
static std::shared_ptr<arrow::Array> Zeroes(int64_t size,
const std::shared_ptr<DataType>& type) {
switch (type->id()) {
case Type::NA:
return std::make_shared<NullArray>(size);
case Type::BOOL:
return Boolean(size);
case Type::UINT8:
Expand Down
15 changes: 15 additions & 0 deletions python/pyarrow/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2124,6 +2124,21 @@ def test_dataset_project_only_partition_columns(tempdir):
assert all_cols.column('part').equals(part_only.column('part'))


@pytest.mark.parquet
@pytest.mark.pandas
def test_dataset_project_null_column(tempdir):
import pandas as pd
df = pd.DataFrame({"col": np.array([None, None, None], dtype='object')})

f = tempdir / "test_dataset_project_null_column.parquet"
df.to_parquet(f, engine="pyarrow")

dataset = ds.dataset(f, format="parquet",
schema=pa.schema([("col", pa.int64())]))
expected = pa.table({'col': pa.array([None, None, None], pa.int64())})
assert dataset.to_table().equals(expected)


def _check_dataset_roundtrip(dataset, base_dir, expected_files,
base_dir_path=None, partitioning=None):
base_dir_path = base_dir_path or base_dir
Expand Down