diff --git a/cpp/src/arrow/dataset/dataset.cc b/cpp/src/arrow/dataset/dataset.cc index e292cf4a9bc..1e4c9b7f719 100644 --- a/cpp/src/arrow/dataset/dataset.cc +++ b/cpp/src/arrow/dataset/dataset.cc @@ -202,11 +202,7 @@ Result InMemoryDataset::GetFragmentsImpl(compute::Expression) auto create_fragment = [schema](std::shared_ptr batch) -> Result> { - if (!batch->schema()->Equals(schema)) { - return Status::TypeError("yielded batch had schema ", *batch->schema(), - " which did not match InMemorySource's: ", *schema); - } - + RETURN_NOT_OK(CheckProjectable(*schema, *batch->schema())); return std::make_shared(RecordBatchVector{std::move(batch)}); }; diff --git a/cpp/src/arrow/dataset/dataset_test.cc b/cpp/src/arrow/dataset/dataset_test.cc index 66d69c30c82..35b6e8129e2 100644 --- a/cpp/src/arrow/dataset/dataset_test.cc +++ b/cpp/src/arrow/dataset/dataset_test.cc @@ -62,9 +62,13 @@ TEST_F(TestInMemoryDataset, ReplaceSchema) { schema_, RecordBatchVector{static_cast(kNumberBatches), batch}); // drop field - ASSERT_OK(dataset->ReplaceSchema(schema({field("i32", int32())})).status()); + auto new_schema = schema({field("i32", int32())}); + ASSERT_OK_AND_ASSIGN(auto new_dataset, dataset->ReplaceSchema(new_schema)); + AssertDatasetHasSchema(new_dataset, new_schema); // add field (will be materialized as null during projection) - ASSERT_OK(dataset->ReplaceSchema(schema({field("str", utf8())})).status()); + new_schema = schema({field("str", utf8())}); + ASSERT_OK_AND_ASSIGN(new_dataset, dataset->ReplaceSchema(new_schema)); + AssertDatasetHasSchema(new_dataset, new_schema); // incompatible type ASSERT_RAISES(TypeError, dataset->ReplaceSchema(schema({field("i32", utf8())})).status()); @@ -107,6 +111,40 @@ TEST_F(TestInMemoryDataset, InMemoryFragment) { AssertSchemaEqual(batch->schema(), schema); } +TEST_F(TestInMemoryDataset, HandlesDifferingSchemas) { + constexpr int64_t kBatchSize = 1024; + + // These schemas can be merged + SetSchema({field("i32", int32()), field("f64", float64())}); + auto batch1 = ConstantArrayGenerator::Zeroes(kBatchSize, schema_); + SetSchema({field("i32", int32())}); + auto batch2 = ConstantArrayGenerator::Zeroes(kBatchSize, schema_); + RecordBatchVector batches{batch1, batch2}; + + auto dataset = std::make_shared(schema_, batches); + + ASSERT_OK_AND_ASSIGN(auto scanner_builder, dataset->NewScan()); + ASSERT_OK_AND_ASSIGN(auto scanner, scanner_builder->Finish()); + ASSERT_OK_AND_ASSIGN(auto table, scanner->ToTable()); + ASSERT_EQ(*table->schema(), *schema_); + ASSERT_EQ(table->num_rows(), 2 * kBatchSize); + + // These cannot be merged + SetSchema({field("i32", int32()), field("f64", float64())}); + batch1 = ConstantArrayGenerator::Zeroes(kBatchSize, schema_); + SetSchema({field("i32", struct_({field("x", date32())}))}); + batch2 = ConstantArrayGenerator::Zeroes(kBatchSize, schema_); + batches = RecordBatchVector({batch1, batch2}); + + dataset = std::make_shared(schema_, batches); + + ASSERT_OK_AND_ASSIGN(scanner_builder, dataset->NewScan()); + ASSERT_OK_AND_ASSIGN(scanner, scanner_builder->Finish()); + EXPECT_RAISES_WITH_MESSAGE_THAT( + TypeError, testing::HasSubstr("fields had matching names but differing types"), + scanner->ToTable()); +} + class TestUnionDataset : public DatasetFixtureMixin {}; TEST_F(TestUnionDataset, ReplaceSchema) { @@ -131,9 +169,13 @@ TEST_F(TestUnionDataset, ReplaceSchema) { AssertDatasetEquals(reader.get(), dataset.get()); // drop field - ASSERT_OK(dataset->ReplaceSchema(schema({field("i32", int32())})).status()); + auto new_schema = schema({field("i32", int32())}); + ASSERT_OK_AND_ASSIGN(auto new_dataset, dataset->ReplaceSchema(new_schema)); + AssertDatasetHasSchema(new_dataset, new_schema); // add nullable field (will be materialized as null during projection) - ASSERT_OK(dataset->ReplaceSchema(schema({field("str", utf8())})).status()); + new_schema = schema({field("str", utf8())}); + ASSERT_OK_AND_ASSIGN(new_dataset, dataset->ReplaceSchema(new_schema)); + AssertDatasetHasSchema(new_dataset, new_schema); // incompatible type ASSERT_RAISES(TypeError, dataset->ReplaceSchema(schema({field("i32", utf8())})).status()); diff --git a/cpp/src/arrow/dataset/file_test.cc b/cpp/src/arrow/dataset/file_test.cc index cc89c163cb7..226c23ef5e4 100644 --- a/cpp/src/arrow/dataset/file_test.cc +++ b/cpp/src/arrow/dataset/file_test.cc @@ -148,9 +148,13 @@ TEST_F(TestFileSystemDataset, ReplaceSchema) { FileSystemDataset::Make(schm, literal(true), format, nullptr, {})); // drop field - ASSERT_OK(dataset->ReplaceSchema(schema({field("i32", int32())})).status()); + auto new_schema = schema({field("i32", int32())}); + ASSERT_OK_AND_ASSIGN(auto new_dataset, dataset->ReplaceSchema(new_schema)); + AssertDatasetHasSchema(new_dataset, new_schema); // add nullable field (will be materialized as null during projection) - ASSERT_OK(dataset->ReplaceSchema(schema({field("str", utf8())})).status()); + new_schema = schema({field("str", utf8())}); + ASSERT_OK_AND_ASSIGN(new_dataset, dataset->ReplaceSchema(new_schema)); + AssertDatasetHasSchema(new_dataset, new_schema); // incompatible type ASSERT_RAISES(TypeError, dataset->ReplaceSchema(schema({field("i32", utf8())})).status()); diff --git a/cpp/src/arrow/dataset/test_util.h b/cpp/src/arrow/dataset/test_util.h index 9ec0a59860e..b7fc66e2ae2 100644 --- a/cpp/src/arrow/dataset/test_util.h +++ b/cpp/src/arrow/dataset/test_util.h @@ -81,6 +81,14 @@ using compute::project; using fs::internal::GetAbstractPathExtension; +/// \brief Assert a dataset produces data with the schema +void AssertDatasetHasSchema(std::shared_ptr ds, std::shared_ptr schema) { + ASSERT_OK_AND_ASSIGN(auto scanner_builder, ds->NewScan()); + ASSERT_OK_AND_ASSIGN(auto scanner, scanner_builder->Finish()); + ASSERT_OK_AND_ASSIGN(auto table, scanner->ToTable()); + ASSERT_EQ(*table->schema(), *schema); +} + class FileSourceFixtureMixin : public ::testing::Test { public: std::unique_ptr GetSource(std::shared_ptr buffer) { diff --git a/r/tests/testthat/test-dataset.R b/r/tests/testthat/test-dataset.R index 493222f6a10..8fd5c7d78e2 100644 --- a/r/tests/testthat/test-dataset.R +++ b/r/tests/testthat/test-dataset.R @@ -595,6 +595,24 @@ test_that("UnionDataset can merge schemas", { expect_equal(actual, expected) }) +test_that("UnionDataset handles InMemoryDatasets", { + sub_df1 <- Table$create( + x = Array$create(c(1, 2, 3)), + y = Array$create(c("a", "b", "c")) + ) + sub_df2 <- Table$create( + x = Array$create(c(4, 5)), + z = Array$create(c("d", "e")) + ) + + ds1 <- InMemoryDataset$create(sub_df1) + ds2 <- InMemoryDataset$create(sub_df2) + ds <- c(ds1, ds2) + actual <- ds %>% collect(as_data_frame = FALSE) + expected <- concat_tables(sub_df1, sub_df2) + expect_equal(actual, expected) +}) + test_that("map_batches", { ds <- open_dataset(dataset_dir, partitioning = "part")