diff --git a/.github/workflows/comment_bot.yml b/.github/workflows/comment_bot.yml index 35d889152fb..b78ae80fb97 100644 --- a/.github/workflows/comment_bot.yml +++ b/.github/workflows/comment_bot.yml @@ -95,8 +95,8 @@ jobs: set -ex export PATH=/home/runner/.local/bin:$PATH python3 -m pip install --upgrade pip setuptools wheel - python3 -m pip install -r dev/archery/requirements-lint.txt - python3 run-cmake-format.py + python3 -m pip install -e dev/archery[lint] + archery lint --cmake-format --fix - name: Run clang-format on cpp if: env.CLANG_FORMAT_CPP == 'true' || endsWith(github.event.comment.body, 'everything') run: | diff --git a/.github/workflows/r.yml b/.github/workflows/r.yml index 5acb47a0ae0..e160ba8128a 100644 --- a/.github/workflows/r.yml +++ b/.github/workflows/r.yml @@ -247,6 +247,7 @@ jobs: Sys.setenv( RWINLIB_LOCAL = file.path(Sys.getenv("GITHUB_WORKSPACE"), "libarrow.zip"), MAKEFLAGS = paste0("-j", parallel::detectCores()), + ARROW_R_DEV = TRUE, "_R_CHECK_FORCE_SUGGESTS_" = FALSE ) rcmdcheck::rcmdcheck("r", diff --git a/c_glib/arrow-dataset-glib/arrow-dataset-glib.h b/c_glib/arrow-dataset-glib/arrow-dataset-glib.h index 03e56516112..58f4e216cc7 100644 --- a/c_glib/arrow-dataset-glib/arrow-dataset-glib.h +++ b/c_glib/arrow-dataset-glib/arrow-dataset-glib.h @@ -23,6 +23,8 @@ #include #include +#include #include #include +#include #include diff --git a/c_glib/arrow-dataset-glib/arrow-dataset-glib.hpp b/c_glib/arrow-dataset-glib/arrow-dataset-glib.hpp index 65341b9b77e..8e996506884 100644 --- a/c_glib/arrow-dataset-glib/arrow-dataset-glib.hpp +++ b/c_glib/arrow-dataset-glib/arrow-dataset-glib.hpp @@ -25,4 +25,5 @@ #include #include #include +#include #include diff --git a/c_glib/arrow-dataset-glib/dataset-factory.cpp b/c_glib/arrow-dataset-glib/dataset-factory.cpp index 146db69adfc..433e58b2031 100644 --- a/c_glib/arrow-dataset-glib/dataset-factory.cpp +++ b/c_glib/arrow-dataset-glib/dataset-factory.cpp @@ -23,6 +23,7 @@ #include #include #include +#include G_BEGIN_DECLS @@ -142,6 +143,7 @@ gadataset_dataset_factory_finish(GADatasetDatasetFactory *factory, typedef struct GADatasetFileSystemDatasetFactoryPrivate_ { GADatasetFileFormat *format; GArrowFileSystem *file_system; + GADatasetPartitioning *partitioning; GList *files; arrow::dataset::FileSystemFactoryOptions options; } GADatasetFileSystemDatasetFactoryPrivate; @@ -149,6 +151,8 @@ typedef struct GADatasetFileSystemDatasetFactoryPrivate_ { enum { PROP_FORMAT = 1, PROP_FILE_SYSTEM, + PROP_PARTITIONING, + PROP_PARTITION_BASE_DIR, }; G_DEFINE_TYPE_WITH_PRIVATE(GADatasetFileSystemDatasetFactory, @@ -175,6 +179,11 @@ gadataset_file_system_dataset_factory_dispose(GObject *object) priv->file_system = NULL; } + if (priv->partitioning) { + g_object_unref(priv->partitioning); + priv->partitioning = NULL; + } + if (priv->files) { g_list_free_full(priv->files, g_object_unref); priv->files = NULL; @@ -205,6 +214,29 @@ gadataset_file_system_dataset_factory_set_property(GObject *object, case PROP_FORMAT: priv->format = GADATASET_FILE_FORMAT(g_value_dup_object(value)); break; + case PROP_PARTITIONING: + { + auto partitioning = g_value_get_object(value); + if (partitioning == priv->partitioning) { + break; + } + auto old_partitioning = priv->partitioning; + if (partitioning) { + g_object_ref(partitioning); + priv->partitioning = GADATASET_PARTITIONING(partitioning); + priv->options.partitioning = + gadataset_partitioning_get_raw(priv->partitioning); + } else { + priv->options.partitioning = arrow::dataset::Partitioning::Default(); + } + 
if (old_partitioning) { + g_object_unref(old_partitioning); + } + } + break; + case PROP_PARTITION_BASE_DIR: + priv->options.partition_base_dir = g_value_get_string(value); + break; default: G_OBJECT_WARN_INVALID_PROPERTY_ID(object, prop_id, pspec); break; @@ -226,6 +258,12 @@ gadataset_file_system_dataset_factory_get_property(GObject *object, case PROP_FILE_SYSTEM: g_value_set_object(value, priv->file_system); break; + case PROP_PARTITIONING: + g_value_set_object(value, priv->partitioning); + break; + case PROP_PARTITION_BASE_DIR: + g_value_set_string(value, priv->options.partition_base_dir.c_str()); + break; default: G_OBJECT_WARN_INVALID_PROPERTY_ID(object, prop_id, pspec); break; @@ -279,6 +317,35 @@ gadataset_file_system_dataset_factory_class_init( GARROW_TYPE_FILE_SYSTEM, static_cast(G_PARAM_READABLE)); g_object_class_install_property(gobject_class, PROP_FILE_SYSTEM, spec); + + /** + * GADatasetFileSystemDatasetFactory:partitioning: + * + * Partitioning used by #GADatasetFileSystemDataset. + * + * Since: 6.0.0 + */ + spec = g_param_spec_object("partitioning", + "Partitioning", + "Partitioning used by GADatasetFileSystemDataset", + GADATASET_TYPE_PARTITIONING, + static_cast(G_PARAM_READWRITE)); + g_object_class_install_property(gobject_class, PROP_PARTITIONING, spec); + + /** + * GADatasetFileSystemDatasetFactory:partition-base-dir: + * + * Partition base directory used by #GADatasetFileSystemDataset. + * + * Since: 6.0.0 + */ + spec = g_param_spec_string("partition-base-dir", + "Partition base directory", + "Partition base directory " + "used by GADatasetFileSystemDataset", + NULL, + static_cast(G_PARAM_READWRITE)); + g_object_class_install_property(gobject_class, PROP_PARTITION_BASE_DIR, spec); } /** @@ -454,6 +521,7 @@ gadataset_file_system_dataset_factory_finish( "dataset", &arrow_dataset, "file-system", priv->file_system, "format", priv->format, + "partitioning", priv->partitioning, NULL)); } diff --git a/c_glib/arrow-dataset-glib/dataset.cpp b/c_glib/arrow-dataset-glib/dataset.cpp index 3bd62f99ef3..8613bedad42 100644 --- a/c_glib/arrow-dataset-glib/dataset.cpp +++ b/c_glib/arrow-dataset-glib/dataset.cpp @@ -18,11 +18,14 @@ */ #include +#include #include #include #include -#include +#include +#include +#include G_BEGIN_DECLS @@ -36,13 +39,8 @@ G_BEGIN_DECLS * * #GADatasetFileSystemDataset is a class for file system dataset. * - * #GADatasetFileFormat is a base class for file formats. - * - * #GADatasetCSVFileFormat is a class for CSV file format. - * - * #GADatasetIPCFileFormat is a class for IPC file format. - * - * #GADatasetParquetFileFormat is a class for Apache Parquet file format. + * #GADatasetFileSystemDatasetWriteOptions is a class for options to + * write a dataset to file system dataset. 
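For context on the factory changes above: both new properties are plain GObject properties, so existing callers opt in with g_object_set() before finishing the factory. A minimal sketch, assuming a factory obtained from the existing gadataset_file_system_dataset_factory_new() API and a partitioning built with the gadataset_directory_partitioning_new() constructor added later in this diff; the base directory path is illustrative:

```c
#include <arrow-dataset-glib/arrow-dataset-glib.h>

/* Sketch: discover a partitioned dataset through the new
 * "partitioning" and "partition-base-dir" factory properties. */
static GADatasetDataset *
open_partitioned_dataset(GADatasetFileSystemDatasetFactory *factory,
                         GADatasetPartitioning *partitioning,
                         GError **error)
{
  g_object_set(factory,
               "partitioning", partitioning,
               "partition-base-dir", "/data/dataset",
               NULL);
  return gadataset_dataset_factory_finish(GADATASET_DATASET_FACTORY(factory),
                                          error);
}
```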
* * Since: 5.0.0 */ @@ -190,14 +188,326 @@ gadataset_dataset_get_type_name(GADatasetDataset *dataset) } +typedef struct GADatasetFileSystemDatasetWriteOptionsPrivate_ { + arrow::dataset::FileSystemDatasetWriteOptions options; + GADatasetFileWriteOptions *file_write_options; + GArrowFileSystem *file_system; + GADatasetPartitioning *partitioning; +} GADatasetFileSystemDatasetWriteOptionsPrivate; + +enum { + PROP_FILE_WRITE_OPTIONS = 1, + PROP_FILE_SYSTEM, + PROP_BASE_DIR, + PROP_PARTITIONING, + PROP_MAX_PARTITIONS, + PROP_BASE_NAME_TEMPLATE, +}; + +G_DEFINE_TYPE_WITH_PRIVATE(GADatasetFileSystemDatasetWriteOptions, + gadataset_file_system_dataset_write_options, + G_TYPE_OBJECT) + +#define GADATASET_FILE_SYSTEM_DATASET_WRITE_OPTIONS_GET_PRIVATE(obj) \ + static_cast( \ + gadataset_file_system_dataset_write_options_get_instance_private( \ + GADATASET_FILE_SYSTEM_DATASET_WRITE_OPTIONS(obj))) + +static void +gadataset_file_system_dataset_write_options_finalize(GObject *object) +{ + auto priv = GADATASET_FILE_SYSTEM_DATASET_WRITE_OPTIONS_GET_PRIVATE(object); + priv->options.~FileSystemDatasetWriteOptions(); + G_OBJECT_CLASS(gadataset_file_system_dataset_write_options_parent_class)-> + finalize(object); +} + +static void +gadataset_file_system_dataset_write_options_dispose(GObject *object) +{ + auto priv = GADATASET_FILE_SYSTEM_DATASET_WRITE_OPTIONS_GET_PRIVATE(object); + + if (priv->file_write_options) { + g_object_unref(priv->file_write_options); + priv->file_write_options = NULL; + } + + if (priv->file_system) { + g_object_unref(priv->file_system); + priv->file_system = NULL; + } + + if (priv->partitioning) { + g_object_unref(priv->partitioning); + priv->partitioning = NULL; + } + + G_OBJECT_CLASS(gadataset_file_system_dataset_write_options_parent_class)-> + dispose(object); +} + +static void +gadataset_file_system_dataset_write_options_set_property(GObject *object, + guint prop_id, + const GValue *value, + GParamSpec *pspec) +{ + auto priv = GADATASET_FILE_SYSTEM_DATASET_WRITE_OPTIONS_GET_PRIVATE(object); + + switch (prop_id) { + case PROP_FILE_WRITE_OPTIONS: + { + auto file_write_options = g_value_get_object(value); + if (file_write_options == priv->file_write_options) { + break; + } + auto old_file_write_options = priv->file_write_options; + if (file_write_options) { + g_object_ref(file_write_options); + priv->file_write_options = + GADATASET_FILE_WRITE_OPTIONS(file_write_options); + priv->options.file_write_options = + gadataset_file_write_options_get_raw(priv->file_write_options); + } else { + priv->options.file_write_options = nullptr; + } + if (old_file_write_options) { + g_object_unref(old_file_write_options); + } + } + break; + case PROP_FILE_SYSTEM: + { + auto file_system = g_value_get_object(value); + if (file_system == priv->file_system) { + break; + } + auto old_file_system = priv->file_system; + if (file_system) { + g_object_ref(file_system); + priv->file_system = GARROW_FILE_SYSTEM(file_system); + priv->options.filesystem = garrow_file_system_get_raw(priv->file_system); + } else { + priv->options.filesystem = nullptr; + } + if (old_file_system) { + g_object_unref(old_file_system); + } + } + break; + case PROP_BASE_DIR: + priv->options.base_dir = g_value_get_string(value); + break; + case PROP_PARTITIONING: + { + auto partitioning = g_value_get_object(value); + if (partitioning == priv->partitioning) { + break; + } + auto old_partitioning = priv->partitioning; + if (partitioning) { + g_object_ref(partitioning); + priv->partitioning = GADATASET_PARTITIONING(partitioning); + 
priv->options.partitioning = + gadataset_partitioning_get_raw(priv->partitioning); + } else { + priv->options.partitioning = arrow::dataset::Partitioning::Default(); + } + if (old_partitioning) { + g_object_unref(old_partitioning); + } + } + break; + case PROP_MAX_PARTITIONS: + priv->options.max_partitions = g_value_get_uint(value); + break; + case PROP_BASE_NAME_TEMPLATE: + priv->options.basename_template = g_value_get_string(value); + break; + default: + G_OBJECT_WARN_INVALID_PROPERTY_ID(object, prop_id, pspec); + break; + } +} + +static void +gadataset_file_system_dataset_write_options_get_property(GObject *object, + guint prop_id, + GValue *value, + GParamSpec *pspec) +{ + auto priv = GADATASET_FILE_SYSTEM_DATASET_WRITE_OPTIONS_GET_PRIVATE(object); + + switch (prop_id) { + case PROP_FILE_WRITE_OPTIONS: + g_value_set_object(value, priv->file_write_options); + break; + case PROP_FILE_SYSTEM: + g_value_set_object(value, priv->file_system); + break; + case PROP_BASE_DIR: + g_value_set_string(value, priv->options.base_dir.c_str()); + break; + case PROP_PARTITIONING: + g_value_set_object(value, priv->partitioning); + break; + case PROP_MAX_PARTITIONS: + g_value_set_uint(value, priv->options.max_partitions); + break; + case PROP_BASE_NAME_TEMPLATE: + g_value_set_string(value, priv->options.basename_template.c_str()); + break; + default: + G_OBJECT_WARN_INVALID_PROPERTY_ID(object, prop_id, pspec); + break; + } +} + +static void +gadataset_file_system_dataset_write_options_init( + GADatasetFileSystemDatasetWriteOptions *object) +{ + auto priv = GADATASET_FILE_SYSTEM_DATASET_WRITE_OPTIONS_GET_PRIVATE(object); + new(&(priv->options)) arrow::dataset::FileSystemDatasetWriteOptions; + priv->options.partitioning = arrow::dataset::Partitioning::Default(); +} + +static void +gadataset_file_system_dataset_write_options_class_init( + GADatasetFileSystemDatasetWriteOptionsClass *klass) +{ + auto gobject_class = G_OBJECT_CLASS(klass); + gobject_class->finalize = + gadataset_file_system_dataset_write_options_finalize; + gobject_class->dispose = + gadataset_file_system_dataset_write_options_dispose; + gobject_class->set_property = + gadataset_file_system_dataset_write_options_set_property; + gobject_class->get_property = + gadataset_file_system_dataset_write_options_get_property; + + arrow::dataset::FileSystemDatasetWriteOptions default_options; + GParamSpec *spec; + /** + * GADatasetFileSystemDatasetWriteOptions:file_write_options: + * + * Options for individual fragment writing. + * + * Since: 6.0.0 + */ + spec = g_param_spec_object("file-write-options", + "File write options", + "Options for individual fragment writing", + GADATASET_TYPE_FILE_WRITE_OPTIONS, + static_cast(G_PARAM_READWRITE)); + g_object_class_install_property(gobject_class, PROP_FILE_WRITE_OPTIONS, spec); + + /** + * GADatasetFileSystemDatasetWriteOptions:file_system: + * + * #GArrowFileSystem into which a dataset will be written. + * + * Since: 6.0.0 + */ + spec = g_param_spec_object("file-system", + "File system", + "GArrowFileSystem into which " + "a dataset will be written", + GARROW_TYPE_FILE_SYSTEM, + static_cast(G_PARAM_READWRITE)); + g_object_class_install_property(gobject_class, PROP_FILE_SYSTEM, spec); + + /** + * GADatasetFileSystemDatasetWriteOptions:base_dir: + * + * Root directory into which the dataset will be written. 
+ * + * Since: 6.0.0 + */ + spec = g_param_spec_string("base-dir", + "Base directory", + "Root directory into which " + "the dataset will be written", + NULL, + static_cast(G_PARAM_READWRITE)); + g_object_class_install_property(gobject_class, PROP_BASE_DIR, spec); + + /** + * GADatasetFileSystemDatasetWriteOptions:partitioning: + * + * #GADatasetPartitioning used to generate fragment paths. + * + * Since: 6.0.0 + */ + spec = g_param_spec_object("partitioning", + "Partitioning", + "GADatasetPartitioning used to " + "generate fragment paths", + GADATASET_TYPE_PARTITIONING, + static_cast(G_PARAM_READWRITE)); + g_object_class_install_property(gobject_class, PROP_PARTITIONING, spec); + + /** + * GADatasetFileSystemDatasetWriteOptions:max-partitions: + * + * Maximum number of partitions any batch may be written into. + * + * Since: 6.0.0 + */ + spec = g_param_spec_uint("max-partitions", + "Max partitions", + "Maximum number of partitions " + "any batch may be written into", + 0, + G_MAXINT, + default_options.max_partitions, + static_cast(G_PARAM_READWRITE)); + g_object_class_install_property(gobject_class, PROP_MAX_PARTITIONS, spec); + + /** + * GADatasetFileSystemDatasetWriteOptions:base-name-template: + * + * Template string used to generate fragment base names. {i} will be + * replaced by an auto incremented integer. + * + * Since: 6.0.0 + */ + spec = g_param_spec_string("base-name-template", + "Base name template", + "Template string used to generate fragment " + "base names. {i} will be replaced by " + "an auto incremented integer", + NULL, + static_cast(G_PARAM_READWRITE)); + g_object_class_install_property(gobject_class, PROP_BASE_NAME_TEMPLATE, spec); +} + +/** + * gadataset_file_system_dataset_write_options_new: + * + * Returns: The newly created #GADatasetFileSystemDatasetWriteOptions. 
+ * + * Since: 6.0.0 + */ +GADatasetFileSystemDatasetWriteOptions * +gadataset_file_system_dataset_write_options_new(void) +{ + return GADATASET_FILE_SYSTEM_DATASET_WRITE_OPTIONS( + g_object_new(GADATASET_TYPE_FILE_SYSTEM_DATASET_WRITE_OPTIONS, + NULL)); +} + + typedef struct GADatasetFileSystemDatasetPrivate_ { GADatasetFileFormat *format; GArrowFileSystem *file_system; + GADatasetPartitioning *partitioning; } GADatasetFileSystemDatasetPrivate; enum { - PROP_FORMAT = 1, - PROP_FILE_SYSTEM, + PROP_FILE_SYSTEM_DATASET_FORMAT = 1, + PROP_FILE_SYSTEM_DATASET_FILE_SYSTEM, + PROP_FILE_SYSTEM_DATASET_PARTITIONING, }; G_DEFINE_TYPE_WITH_PRIVATE(GADatasetFileSystemDataset, @@ -236,12 +546,15 @@ gadataset_file_system_dataset_set_property(GObject *object, auto priv = GADATASET_FILE_SYSTEM_DATASET_GET_PRIVATE(object); switch (prop_id) { - case PROP_FORMAT: + case PROP_FILE_SYSTEM_DATASET_FORMAT: priv->format = GADATASET_FILE_FORMAT(g_value_dup_object(value)); break; - case PROP_FILE_SYSTEM: + case PROP_FILE_SYSTEM_DATASET_FILE_SYSTEM: priv->file_system = GARROW_FILE_SYSTEM(g_value_dup_object(value)); break; + case PROP_FILE_SYSTEM_DATASET_PARTITIONING: + priv->partitioning = GADATASET_PARTITIONING(g_value_dup_object(value)); + break; default: G_OBJECT_WARN_INVALID_PROPERTY_ID(object, prop_id, pspec); break; @@ -257,12 +570,15 @@ gadataset_file_system_dataset_get_property(GObject *object, auto priv = GADATASET_FILE_SYSTEM_DATASET_GET_PRIVATE(object); switch (prop_id) { - case PROP_FORMAT: + case PROP_FILE_SYSTEM_DATASET_FORMAT: g_value_set_object(value, priv->format); break; - case PROP_FILE_SYSTEM: + case PROP_FILE_SYSTEM_DATASET_FILE_SYSTEM: g_value_set_object(value, priv->file_system); break; + case PROP_FILE_SYSTEM_DATASET_PARTITIONING: + g_value_set_object(value, priv->partitioning); + break; default: G_OBJECT_WARN_INVALID_PROPERTY_ID(object, prop_id, pspec); break; @@ -296,7 +612,9 @@ gadataset_file_system_dataset_class_init(GADatasetFileSystemDatasetClass *klass) GADATASET_TYPE_FILE_FORMAT, static_cast(G_PARAM_READWRITE | G_PARAM_CONSTRUCT_ONLY)); - g_object_class_install_property(gobject_class, PROP_FORMAT, spec); + g_object_class_install_property(gobject_class, + PROP_FILE_SYSTEM_DATASET_FORMAT, + spec); /** * GADatasetFileSystemDataset:file-system: @@ -311,7 +629,52 @@ gadataset_file_system_dataset_class_init(GADatasetFileSystemDatasetClass *klass) GARROW_TYPE_FILE_SYSTEM, static_cast(G_PARAM_READWRITE | G_PARAM_CONSTRUCT_ONLY)); - g_object_class_install_property(gobject_class, PROP_FILE_SYSTEM, spec); + g_object_class_install_property(gobject_class, + PROP_FILE_SYSTEM_DATASET_FILE_SYSTEM, + spec); + + /** + * GADatasetFileSystemDataset:partitioning: + * + * Partitioning of the dataset. + * + * Since: 6.0.0 + */ + spec = g_param_spec_object("partitioning", + "Partitioning", + "Partitioning of the dataset", + GADATASET_TYPE_PARTITIONING, + static_cast(G_PARAM_READWRITE | + G_PARAM_CONSTRUCT_ONLY)); + g_object_class_install_property(gobject_class, + PROP_FILE_SYSTEM_DATASET_PARTITIONING, + spec); +} + +/** + * gadataset_file_system_dataset_write_scanner: + * @scanner: A #GADatasetScanner that produces data to be written. + * @options: A #GADatasetFileSystemDatasetWriteOptions. + * @error: (nullable): Return location for a #GError or %NULL. + * + * Returns: %TRUE on success, %FALSE on error. 
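The new write options object is configured entirely through the properties installed above, so a caller can wire everything up with a single g_object_set() call. A sketch, assuming `format`, `file_system`, and `partitioning` are created elsewhere and using the gadataset_file_format_get_default_write_options() helper added later in this diff; the base directory and base name template are illustrative:

```c
#include <arrow-dataset-glib/arrow-dataset-glib.h>

/* Sketch: assemble FileSystemDatasetWriteOptions from its properties. */
static GADatasetFileSystemDatasetWriteOptions *
make_write_options(GADatasetFileFormat *format,
                   GArrowFileSystem *file_system,
                   GADatasetPartitioning *partitioning)
{
  GADatasetFileSystemDatasetWriteOptions *options =
    gadataset_file_system_dataset_write_options_new();
  GADatasetFileWriteOptions *file_write_options =
    gadataset_file_format_get_default_write_options(format);
  g_object_set(options,
               "file-write-options", file_write_options,
               "file-system", file_system,
               "base-dir", "/data/output",
               "partitioning", partitioning,
               "base-name-template", "part-{i}.parquet",
               NULL);
  g_object_unref(file_write_options);
  return options;
}
```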
+ * + * Since: 6.0.0 + */ +gboolean +gadataset_file_system_dataset_write_scanner( + GADatasetScanner *scanner, + GADatasetFileSystemDatasetWriteOptions *options, + GError **error) +{ + auto arrow_scanner = gadataset_scanner_get_raw(scanner); + auto arrow_options = + gadataset_file_system_dataset_write_options_get_raw(options); + auto status = + arrow::dataset::FileSystemDataset::Write(*arrow_options, arrow_scanner); + return garrow::check(error, + status, + "[file-system-dataset][write-scanner]"); } @@ -363,3 +726,11 @@ gadataset_dataset_get_raw(GADatasetDataset *dataset) auto priv = GADATASET_DATASET_GET_PRIVATE(dataset); return priv->dataset; } + +arrow::dataset::FileSystemDatasetWriteOptions * +gadataset_file_system_dataset_write_options_get_raw( + GADatasetFileSystemDatasetWriteOptions *options) +{ + auto priv = GADATASET_FILE_SYSTEM_DATASET_WRITE_OPTIONS_GET_PRIVATE(options); + return &(priv->options); +} diff --git a/c_glib/arrow-dataset-glib/dataset.h b/c_glib/arrow-dataset-glib/dataset.h index 97cf35d74d7..86d077caa98 100644 --- a/c_glib/arrow-dataset-glib/dataset.h +++ b/c_glib/arrow-dataset-glib/dataset.h @@ -24,6 +24,7 @@ G_BEGIN_DECLS typedef struct _GADatasetScannerBuilder GADatasetScannerBuilder; +typedef struct _GADatasetScanner GADatasetScanner; #define GADATASET_TYPE_DATASET (gadataset_dataset_get_type()) G_DECLARE_DERIVABLE_TYPE(GADatasetDataset, @@ -49,6 +50,23 @@ gchar * gadataset_dataset_get_type_name(GADatasetDataset *dataset); +#define GADATASET_TYPE_FILE_SYSTEM_DATASET_WRITE_OPTIONS \ + (gadataset_file_system_dataset_write_options_get_type()) +G_DECLARE_DERIVABLE_TYPE(GADatasetFileSystemDatasetWriteOptions, + gadataset_file_system_dataset_write_options, + GADATASET, + FILE_SYSTEM_DATASET_WRITE_OPTIONS, + GObject) +struct _GADatasetFileSystemDatasetWriteOptionsClass +{ + GObjectClass parent_class; +}; + +GARROW_AVAILABLE_IN_6_0 +GADatasetFileSystemDatasetWriteOptions * +gadataset_file_system_dataset_write_options_new(void); + + #define GADATASET_TYPE_FILE_SYSTEM_DATASET \ (gadataset_file_system_dataset_get_type()) G_DECLARE_DERIVABLE_TYPE(GADatasetFileSystemDataset, @@ -61,5 +79,12 @@ struct _GADatasetFileSystemDatasetClass GADatasetDatasetClass parent_class; }; +GARROW_AVAILABLE_IN_6_0 +gboolean +gadataset_file_system_dataset_write_scanner( + GADatasetScanner *scanner, + GADatasetFileSystemDatasetWriteOptions *options, + GError **error); + G_END_DECLS diff --git a/c_glib/arrow-dataset-glib/dataset.hpp b/c_glib/arrow-dataset-glib/dataset.hpp index 94dddd2eb7a..1dab391e8af 100644 --- a/c_glib/arrow-dataset-glib/dataset.hpp +++ b/c_glib/arrow-dataset-glib/dataset.hpp @@ -23,6 +23,7 @@ #include + GADatasetDataset * gadataset_dataset_new_raw( std::shared_ptr *arrow_dataset); @@ -39,10 +40,7 @@ gadataset_dataset_new_raw_valist( std::shared_ptr gadataset_dataset_get_raw(GADatasetDataset *dataset); -GADatasetFileFormat * -gadataset_file_format_new_raw( - std::shared_ptr *arrow_format); -std::shared_ptr -gadataset_dataset_get_raw(GADatasetDataset *dataset); - +arrow::dataset::FileSystemDatasetWriteOptions * +gadataset_file_system_dataset_write_options_get_raw( + GADatasetFileSystemDatasetWriteOptions *options); diff --git a/c_glib/arrow-dataset-glib/enums.c.template b/c_glib/arrow-dataset-glib/enums.c.template new file mode 100644 index 00000000000..8921ab06252 --- /dev/null +++ b/c_glib/arrow-dataset-glib/enums.c.template @@ -0,0 +1,52 @@ +/*** BEGIN file-header ***/ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license 
agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +/*** END file-header ***/ + +/*** BEGIN file-production ***/ + +/* enumerations from "@filename@" */ +/*** END file-production ***/ + +/*** BEGIN value-header ***/ +GType +@enum_name@_get_type(void) +{ + static GType etype = 0; + if (G_UNLIKELY(etype == 0)) { + static const G@Type@Value values[] = { +/*** END value-header ***/ + +/*** BEGIN value-production ***/ + {@VALUENAME@, "@VALUENAME@", "@valuenick@"}, +/*** END value-production ***/ + +/*** BEGIN value-tail ***/ + {0, NULL, NULL} + }; + etype = g_@type@_register_static(g_intern_static_string("@EnumName@"), values); + } + return etype; +} +/*** END value-tail ***/ + +/*** BEGIN file-tail ***/ +/*** END file-tail ***/ diff --git a/c_glib/arrow-dataset-glib/enums.h.template b/c_glib/arrow-dataset-glib/enums.h.template new file mode 100644 index 00000000000..d6a0a455f5a --- /dev/null +++ b/c_glib/arrow-dataset-glib/enums.h.template @@ -0,0 +1,41 @@ +/*** BEGIN file-header ***/ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#pragma once + +#include + +G_BEGIN_DECLS +/*** END file-header ***/ + +/*** BEGIN file-production ***/ + +/* enumerations from "@filename@" */ +/*** END file-production ***/ + +/*** BEGIN value-header ***/ +GType @enum_name@_get_type(void) G_GNUC_CONST; +#define @ENUMPREFIX@_TYPE_@ENUMSHORT@ (@enum_name@_get_type()) +/*** END value-header ***/ + +/*** BEGIN file-tail ***/ + +G_END_DECLS +/*** END file-tail ***/ diff --git a/c_glib/arrow-dataset-glib/file-format.cpp b/c_glib/arrow-dataset-glib/file-format.cpp index 43f6a198f23..c0c92d966f8 100644 --- a/c_glib/arrow-dataset-glib/file-format.cpp +++ b/c_glib/arrow-dataset-glib/file-format.cpp @@ -18,6 +18,11 @@ */ #include +#include +#include +#include +#include +#include #include @@ -29,6 +34,11 @@ G_BEGIN_DECLS * @title: File format classes * @include: arrow-dataset-glib/arrow-dataset-glib.h * + * #GADatasetFileWriteOptions is a class for options to write a file + * of this format. + * + * #GADatasetFileWriter is a class for writing a file of this format. 
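For reviewers unfamiliar with glib-mkenums: once meson.build (further down) runs the two templates above over the installed headers, each registered enum gains a GType function. For the GADatasetSegmentEncoding enum introduced later in this diff, the expansion looks roughly like the following (a sketch; names follow from the identifier_prefix and symbol_prefix settings, and exact whitespace may differ):

```c
GType
gadataset_segment_encoding_get_type(void)
{
  static GType etype = 0;
  if (G_UNLIKELY(etype == 0)) {
    static const GEnumValue values[] = {
      {GADATASET_SEGMENT_ENCODING_NONE,
       "GADATASET_SEGMENT_ENCODING_NONE", "none"},
      {GADATASET_SEGMENT_ENCODING_URI,
       "GADATASET_SEGMENT_ENCODING_URI", "uri"},
      {0, NULL, NULL}
    };
    etype = g_enum_register_static(
      g_intern_static_string("GADatasetSegmentEncoding"), values);
  }
  return etype;
}
```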
+ * * #GADatasetFileFormat is a base class for file format classes. * * #GADatasetCSVFileFormat is a class for CSV file format. @@ -40,12 +50,218 @@ G_BEGIN_DECLS * Since: 3.0.0 */ +typedef struct GADatasetFileWriteOptionsPrivate_ { + std::shared_ptr options; +} GADatasetFileWriteOptionsPrivate; + +enum { + PROP_OPTIONS = 1, +}; + +G_DEFINE_TYPE_WITH_PRIVATE(GADatasetFileWriteOptions, + gadataset_file_write_options, + G_TYPE_OBJECT) + +#define GADATASET_FILE_WRITE_OPTIONS_GET_PRIVATE(obj) \ + static_cast( \ + gadataset_file_write_options_get_instance_private( \ + GADATASET_FILE_WRITE_OPTIONS(obj))) + +static void +gadataset_file_write_options_finalize(GObject *object) +{ + auto priv = GADATASET_FILE_WRITE_OPTIONS_GET_PRIVATE(object); + priv->options.~shared_ptr(); + G_OBJECT_CLASS(gadataset_file_write_options_parent_class)->finalize(object); +} + +static void +gadataset_file_write_options_set_property(GObject *object, + guint prop_id, + const GValue *value, + GParamSpec *pspec) +{ + auto priv = GADATASET_FILE_WRITE_OPTIONS_GET_PRIVATE(object); + + switch (prop_id) { + case PROP_OPTIONS: + priv->options = + *static_cast *>( + g_value_get_pointer(value)); + break; + default: + G_OBJECT_WARN_INVALID_PROPERTY_ID(object, prop_id, pspec); + break; + } +} + +static void +gadataset_file_write_options_init(GADatasetFileWriteOptions *object) +{ + auto priv = GADATASET_FILE_WRITE_OPTIONS_GET_PRIVATE(object); + new(&priv->options) std::shared_ptr; +} + +static void +gadataset_file_write_options_class_init(GADatasetFileWriteOptionsClass *klass) +{ + auto gobject_class = G_OBJECT_CLASS(klass); + + gobject_class->finalize = gadataset_file_write_options_finalize; + gobject_class->set_property = gadataset_file_write_options_set_property; + + GParamSpec *spec; + spec = g_param_spec_pointer("options", + "Options", + "The raw " + "std::shared *", + static_cast(G_PARAM_WRITABLE | + G_PARAM_CONSTRUCT_ONLY)); + g_object_class_install_property(gobject_class, PROP_OPTIONS, spec); +} + + +typedef struct GADatasetFileWriterPrivate_ { + std::shared_ptr writer; +} GADatasetFileWriterPrivate; + +enum { + PROP_WRITER = 1, +}; + +G_DEFINE_TYPE_WITH_PRIVATE(GADatasetFileWriter, + gadataset_file_writer, + G_TYPE_OBJECT) + +#define GADATASET_FILE_WRITER_GET_PRIVATE(obj) \ + static_cast( \ + gadataset_file_writer_get_instance_private( \ + GADATASET_FILE_WRITER(obj))) + +static void +gadataset_file_writer_finalize(GObject *object) +{ + auto priv = GADATASET_FILE_WRITER_GET_PRIVATE(object); + priv->writer.~shared_ptr(); + G_OBJECT_CLASS(gadataset_file_writer_parent_class)->finalize(object); +} + +static void +gadataset_file_writer_set_property(GObject *object, + guint prop_id, + const GValue *value, + GParamSpec *pspec) +{ + auto priv = GADATASET_FILE_WRITER_GET_PRIVATE(object); + + switch (prop_id) { + case PROP_WRITER: + priv->writer = + *static_cast *>( + g_value_get_pointer(value)); + break; + default: + G_OBJECT_WARN_INVALID_PROPERTY_ID(object, prop_id, pspec); + break; + } +} + +static void +gadataset_file_writer_init(GADatasetFileWriter *object) +{ + auto priv = GADATASET_FILE_WRITER_GET_PRIVATE(object); + new(&(priv->writer)) std::shared_ptr; +} + +static void +gadataset_file_writer_class_init(GADatasetFileWriterClass *klass) +{ + auto gobject_class = G_OBJECT_CLASS(klass); + + gobject_class->finalize = gadataset_file_writer_finalize; + gobject_class->set_property = gadataset_file_writer_set_property; + + GParamSpec *spec; + spec = g_param_spec_pointer("writer", + "Writer", + "The raw " + "std::shared *", + 
static_cast(G_PARAM_WRITABLE | + G_PARAM_CONSTRUCT_ONLY)); + g_object_class_install_property(gobject_class, PROP_WRITER, spec); +} + +/** + * gadataset_file_writer_write_record_batch: + * @writer: A #GADatasetFileWriter. + * @record_batch: A #GArrowRecordBatch to be written. + * @error: (nullable): Return location for a #GError or %NULL. + * + * Returns: %TRUE on success, %FALSE on error. + * + * Since: 6.0.0 + */ +gboolean +gadataset_file_writer_write_record_batch(GADatasetFileWriter *writer, + GArrowRecordBatch *record_batch, + GError **error) +{ + const auto arrow_writer = gadataset_file_writer_get_raw(writer); + const auto arrow_record_batch = garrow_record_batch_get_raw(record_batch); + auto status = arrow_writer->Write(arrow_record_batch); + return garrow::check(error, status, "[file-writer][write-record-batch]"); +} + +/** + * gadataset_file_writer_write_record_batch_reader: + * @writer: A #GADatasetFileWriter. + * @reader: A #GArrowRecordBatchReader to be written. + * @error: (nullable): Return location for a #GError or %NULL. + * + * Returns: %TRUE on success, %FALSE on error. + * + * Since: 6.0.0 + */ +gboolean +gadataset_file_writer_write_record_batch_reader(GADatasetFileWriter *writer, + GArrowRecordBatchReader *reader, + GError **error) +{ + const auto arrow_writer = gadataset_file_writer_get_raw(writer); + auto arrow_reader = garrow_record_batch_reader_get_raw(reader); + auto status = arrow_writer->Write(arrow_reader.get()); + return garrow::check(error, + status, + "[file-writer][write-record-batch-reader]"); +} + +/** + * gadataset_file_writer_finish: + * @writer: A #GADatasetFileWriter. + * @error: (nullable): Return location for a #GError or %NULL. + * + * Returns: %TRUE on success, %FALSE on error. + * + * Since: 6.0.0 + */ +gboolean +gadataset_file_writer_finish(GADatasetFileWriter *writer, + GError **error) +{ + const auto arrow_writer = gadataset_file_writer_get_raw(writer); + auto status = arrow_writer->Finish(); + return garrow::check(error, + status, + "[file-writer][finish]"); +} + + typedef struct GADatasetFileFormatPrivate_ { - std::shared_ptr file_format; + std::shared_ptr format; } GADatasetFileFormatPrivate; enum { - PROP_FILE_FORMAT = 1, + PROP_FORMAT = 1, }; G_DEFINE_TYPE_WITH_PRIVATE(GADatasetFileFormat, @@ -61,9 +277,7 @@ static void gadataset_file_format_finalize(GObject *object) { auto priv = GADATASET_FILE_FORMAT_GET_PRIVATE(object); - - priv->file_format.~shared_ptr(); - + priv->format.~shared_ptr(); G_OBJECT_CLASS(gadataset_file_format_parent_class)->finalize(object); } @@ -76,8 +290,8 @@ gadataset_file_format_set_property(GObject *object, auto priv = GADATASET_FILE_FORMAT_GET_PRIVATE(object); switch (prop_id) { - case PROP_FILE_FORMAT: - priv->file_format = + case PROP_FORMAT: + priv->format = *static_cast *>( g_value_get_pointer(value)); break; @@ -91,7 +305,7 @@ static void gadataset_file_format_init(GADatasetFileFormat *object) { auto priv = GADATASET_FILE_FORMAT_GET_PRIVATE(object); - new(&priv->file_format) std::shared_ptr; + new(&priv->format) std::shared_ptr; } static void @@ -103,49 +317,106 @@ gadataset_file_format_class_init(GADatasetFileFormatClass *klass) gobject_class->set_property = gadataset_file_format_set_property; GParamSpec *spec; - spec = g_param_spec_pointer("file-format", - "FileFormat", + spec = g_param_spec_pointer("format", + "Format", "The raw std::shared *", static_cast(G_PARAM_WRITABLE | G_PARAM_CONSTRUCT_ONLY)); - g_object_class_install_property(gobject_class, PROP_FILE_FORMAT, spec); + 
g_object_class_install_property(gobject_class, PROP_FORMAT, spec); } /** * gadataset_file_format_get_type_name: - * @file_format: A #GADatasetFileFormat. + * @format: A #GADatasetFileFormat. * - * Returns: The type name of @file_format. + * Returns: The type name of @format. * * It should be freed with g_free() when no longer needed. * * Since: 3.0.0 */ gchar * -gadataset_file_format_get_type_name(GADatasetFileFormat *file_format) +gadataset_file_format_get_type_name(GADatasetFileFormat *format) { - const auto arrow_file_format = gadataset_file_format_get_raw(file_format); - const auto &type_name = arrow_file_format->type_name(); + const auto arrow_format = gadataset_file_format_get_raw(format); + const auto &type_name = arrow_format->type_name(); return g_strndup(type_name.data(), type_name.size()); } +/** + * gadataset_file_format_get_default_write_options: + * @format: A #GADatasetFileFormat. + * + * Returns: (transfer full): The default #GADatasetFileWriteOptions of @format. + * + * Since: 6.0.0 + */ +GADatasetFileWriteOptions * +gadataset_file_format_get_default_write_options(GADatasetFileFormat *format) +{ + const auto arrow_format = gadataset_file_format_get_raw(format); + auto arrow_options = arrow_format->DefaultWriteOptions(); + return gadataset_file_write_options_new_raw(&arrow_options); +} + +/** + * gadataset_file_format_open_writer: + * @format: A #GADatasetFileFormat. + * @destination: A #GArrowOutputStream. + * @file_system: The #GArrowFileSystem of @destination. + * @path: The path of @destination. + * @schema: A #GArrowSchema that is used by written record batches. + * @options: A #GADatasetFileWriteOptions. + * @error: (nullable): Return location for a #GError or %NULL. + * + * Returns: (transfer full): The newly created #GADatasetFileWriter of @format + * on success, %NULL on error. + * + * Since: 6.0.0 + */ +GADatasetFileWriter * +gadataset_file_format_open_writer(GADatasetFileFormat *format, + GArrowOutputStream *destination, + GArrowFileSystem *file_system, + const gchar *path, + GArrowSchema *schema, + GADatasetFileWriteOptions *options, + GError **error) +{ + const auto arrow_format = gadataset_file_format_get_raw(format); + auto arrow_destination = garrow_output_stream_get_raw(destination); + auto arrow_file_system = garrow_file_system_get_raw(file_system); + auto arrow_schema = garrow_schema_get_raw(schema); + auto arrow_options = gadataset_file_write_options_get_raw(options); + auto arrow_writer_result = + arrow_format->MakeWriter(arrow_destination, + arrow_schema, + arrow_options, + {arrow_file_system, path}); + if (garrow::check(error, arrow_writer_result, "[file-format][open-writer]")) { + auto arrow_writer = *arrow_writer_result; + return gadataset_file_writer_new_raw(&arrow_writer); + } else { + return NULL; + } +} + /** * gadataset_file_format_equal: - * @file_format: A #GADatasetFileFormat. - * @other_file_format: A #GADatasetFileFormat to be compared. + * @format: A #GADatasetFileFormat. + * @other_format: A #GADatasetFileFormat to be compared. * * Returns: %TRUE if they are the same content file format, %FALSE otherwise. 
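Taken together, the writer entry points added above compose into a simple write path: fetch the format's default options, open a writer, write, finish. A minimal sketch, assuming the destination stream, file system, schema, and record batch already exist; the path is illustrative:

```c
#include <arrow-dataset-glib/arrow-dataset-glib.h>

/* Sketch: write a single record batch with the new FileWriter API. */
static gboolean
write_one_batch(GADatasetFileFormat *format,
                GArrowOutputStream *destination,
                GArrowFileSystem *file_system,
                GArrowSchema *schema,
                GArrowRecordBatch *record_batch,
                GError **error)
{
  GADatasetFileWriteOptions *options =
    gadataset_file_format_get_default_write_options(format);
  GADatasetFileWriter *writer =
    gadataset_file_format_open_writer(format,
                                      destination,
                                      file_system,
                                      "data/part-0.parquet",
                                      schema,
                                      options,
                                      error);
  g_object_unref(options);
  if (!writer) {
    return FALSE;
  }
  gboolean success =
    gadataset_file_writer_write_record_batch(writer, record_batch, error) &&
    gadataset_file_writer_finish(writer, error);
  g_object_unref(writer);
  return success;
}
```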
* * Since: 3.0.0 */ gboolean -gadataset_file_format_equal(GADatasetFileFormat *file_format, - GADatasetFileFormat *other_file_format) +gadataset_file_format_equal(GADatasetFileFormat *format, + GADatasetFileFormat *other_format) { - const auto arrow_file_format = gadataset_file_format_get_raw(file_format); - const auto arrow_other_file_format = - gadataset_file_format_get_raw(other_file_format); - return arrow_file_format->Equals(*arrow_other_file_format); + const auto arrow_format = gadataset_file_format_get_raw(format); + const auto arrow_other_format = gadataset_file_format_get_raw(other_format); + return arrow_format->Equals(*arrow_other_format); } @@ -173,10 +444,9 @@ gadataset_csv_file_format_class_init(GADatasetCSVFileFormatClass *klass) GADatasetCSVFileFormat * gadataset_csv_file_format_new(void) { - std::shared_ptr arrow_file_format = + std::shared_ptr arrow_format = std::make_shared(); - return GADATASET_CSV_FILE_FORMAT( - gadataset_file_format_new_raw(&arrow_file_format)); + return GADATASET_CSV_FILE_FORMAT(gadataset_file_format_new_raw(&arrow_format)); } @@ -204,10 +474,9 @@ gadataset_ipc_file_format_class_init(GADatasetIPCFileFormatClass *klass) GADatasetIPCFileFormat * gadataset_ipc_file_format_new(void) { - std::shared_ptr arrow_file_format = + std::shared_ptr arrow_format = std::make_shared(); - return GADATASET_IPC_FILE_FORMAT( - gadataset_file_format_new_raw(&arrow_file_format)); + return GADATASET_IPC_FILE_FORMAT(gadataset_file_format_new_raw(&arrow_format)); } @@ -235,21 +504,56 @@ gadataset_parquet_file_format_class_init(GADatasetParquetFileFormatClass *klass) GADatasetParquetFileFormat * gadataset_parquet_file_format_new(void) { - std::shared_ptr arrow_file_format = + std::shared_ptr arrow_format = std::make_shared(); return GADATASET_PARQUET_FILE_FORMAT( - gadataset_file_format_new_raw(&arrow_file_format)); + gadataset_file_format_new_raw(&arrow_format)); } G_END_DECLS +GADatasetFileWriteOptions * +gadataset_file_write_options_new_raw( + std::shared_ptr *arrow_options) +{ + return GADATASET_FILE_WRITE_OPTIONS( + g_object_new(GADATASET_TYPE_FILE_WRITE_OPTIONS, + "options", arrow_options, + NULL)); +} + +std::shared_ptr +gadataset_file_write_options_get_raw(GADatasetFileWriteOptions *options) +{ + auto priv = GADATASET_FILE_WRITE_OPTIONS_GET_PRIVATE(options); + return priv->options; +} + + +GADatasetFileWriter * +gadataset_file_writer_new_raw( + std::shared_ptr *arrow_writer) +{ + return GADATASET_FILE_WRITER(g_object_new(GADATASET_TYPE_FILE_WRITER, + "writer", arrow_writer, + NULL)); +} + +std::shared_ptr +gadataset_file_writer_get_raw(GADatasetFileWriter *writer) +{ + auto priv = GADATASET_FILE_WRITER_GET_PRIVATE(writer); + return priv->writer; +} + + GADatasetFileFormat * gadataset_file_format_new_raw( - std::shared_ptr *arrow_file_format) + std::shared_ptr *arrow_format) { GType type = GADATASET_TYPE_FILE_FORMAT; - const auto &type_name = (*arrow_file_format)->type_name(); + const auto &type_name = (*arrow_format)->type_name(); if (type_name == "csv") { type = GADATASET_TYPE_CSV_FILE_FORMAT; } else if (type_name == "ipc") { @@ -258,13 +562,13 @@ gadataset_file_format_new_raw( type = GADATASET_TYPE_PARQUET_FILE_FORMAT; } return GADATASET_FILE_FORMAT(g_object_new(type, - "file-format", arrow_file_format, + "format", arrow_format, NULL)); } std::shared_ptr -gadataset_file_format_get_raw(GADatasetFileFormat *file_format) +gadataset_file_format_get_raw(GADatasetFileFormat *format) { - auto priv = GADATASET_FILE_FORMAT_GET_PRIVATE(file_format); - return 
priv->file_format; + auto priv = GADATASET_FILE_FORMAT_GET_PRIVATE(format); + return priv->format; } diff --git a/c_glib/arrow-dataset-glib/file-format.h b/c_glib/arrow-dataset-glib/file-format.h index 7a6f46f56e9..16a8340747c 100644 --- a/c_glib/arrow-dataset-glib/file-format.h +++ b/c_glib/arrow-dataset-glib/file-format.h @@ -23,6 +23,47 @@ G_BEGIN_DECLS +#define GADATASET_TYPE_FILE_WRITE_OPTIONS \ + (gadataset_file_write_options_get_type()) +G_DECLARE_DERIVABLE_TYPE(GADatasetFileWriteOptions, + gadataset_file_write_options, + GADATASET, + FILE_WRITE_OPTIONS, + GObject) +struct _GADatasetFileWriteOptionsClass +{ + GObjectClass parent_class; +}; + + +#define GADATASET_TYPE_FILE_WRITER \ + (gadataset_file_writer_get_type()) +G_DECLARE_DERIVABLE_TYPE(GADatasetFileWriter, + gadataset_file_writer, + GADATASET, + FILE_WRITER, + GObject) +struct _GADatasetFileWriterClass +{ + GObjectClass parent_class; +}; + +GARROW_AVAILABLE_IN_6_0 +gboolean +gadataset_file_writer_write_record_batch(GADatasetFileWriter *writer, + GArrowRecordBatch *record_batch, + GError **error); +GARROW_AVAILABLE_IN_6_0 +gboolean +gadataset_file_writer_write_record_batch_reader(GADatasetFileWriter *writer, + GArrowRecordBatchReader *reader, + GError **error); +GARROW_AVAILABLE_IN_6_0 +gboolean +gadataset_file_writer_finish(GADatasetFileWriter *writer, + GError **error); + + #define GADATASET_TYPE_FILE_FORMAT (gadataset_file_format_get_type()) G_DECLARE_DERIVABLE_TYPE(GADatasetFileFormat, gadataset_file_format, @@ -36,12 +77,24 @@ struct _GADatasetFileFormatClass GARROW_AVAILABLE_IN_3_0 gchar * -gadataset_file_format_get_type_name(GADatasetFileFormat *file_format); +gadataset_file_format_get_type_name(GADatasetFileFormat *format); +GARROW_AVAILABLE_IN_6_0 +GADatasetFileWriteOptions * +gadataset_file_format_get_default_write_options(GADatasetFileFormat *format); +GARROW_AVAILABLE_IN_6_0 +GADatasetFileWriter * +gadataset_file_format_open_writer(GADatasetFileFormat *format, + GArrowOutputStream *destination, + GArrowFileSystem *file_system, + const gchar *path, + GArrowSchema *schema, + GADatasetFileWriteOptions *options, + GError **error); GARROW_AVAILABLE_IN_3_0 gboolean -gadataset_file_format_equal(GADatasetFileFormat *file_format, - GADatasetFileFormat *other_file_format); +gadataset_file_format_equal(GADatasetFileFormat *format, + GADatasetFileFormat *other_format); #define GADATASET_TYPE_CSV_FILE_FORMAT (gadataset_csv_file_format_get_type()) diff --git a/c_glib/arrow-dataset-glib/file-format.hpp b/c_glib/arrow-dataset-glib/file-format.hpp index 5dfb20b3caa..636dc5c015b 100644 --- a/c_glib/arrow-dataset-glib/file-format.hpp +++ b/c_glib/arrow-dataset-glib/file-format.hpp @@ -23,8 +23,22 @@ #include +GADatasetFileWriteOptions * +gadataset_file_write_options_new_raw( + std::shared_ptr *arrow_options); +std::shared_ptr +gadataset_file_write_options_get_raw(GADatasetFileWriteOptions *options); + + +GADatasetFileWriter * +gadataset_file_writer_new_raw( + std::shared_ptr *arrow_writer); +std::shared_ptr +gadataset_file_writer_get_raw(GADatasetFileWriter *writer); + + GADatasetFileFormat * gadataset_file_format_new_raw( - std::shared_ptr *arrow_file_format); + std::shared_ptr *arrow_format); std::shared_ptr -gadataset_file_format_get_raw(GADatasetFileFormat *file_format); +gadataset_file_format_get_raw(GADatasetFileFormat *format); diff --git a/c_glib/arrow-dataset-glib/meson.build b/c_glib/arrow-dataset-glib/meson.build index b3f617330cf..0d9b8564ecb 100644 --- a/c_glib/arrow-dataset-glib/meson.build +++ 
b/c_glib/arrow-dataset-glib/meson.build @@ -22,6 +22,7 @@ sources = files( 'dataset.cpp', 'file-format.cpp', 'fragment.cpp', + 'partitioning.cpp', 'scanner.cpp', ) @@ -31,6 +32,7 @@ c_headers = files( 'dataset.h', 'file-format.h', 'fragment.h', + 'partitioning.h', 'scanner.h', ) @@ -40,9 +42,22 @@ cpp_headers = files( 'dataset.hpp', 'file-format.hpp', 'fragment.hpp', + 'partitioning.hpp', 'scanner.hpp', ) +enums = gnome.mkenums('enums', + sources: c_headers, + identifier_prefix: 'GADataset', + symbol_prefix: 'gadataset', + c_template: 'enums.c.template', + h_template: 'enums.h.template', + install_dir: join_paths(include_dir, meson.project_name()), + install_header: true) +enums_source = enums[0] +enums_header = enums[1] + + headers = c_headers + cpp_headers install_headers(headers, subdir: 'arrow-dataset-glib') @@ -51,7 +66,7 @@ dependencies = [ arrow_glib, ] libarrow_dataset_glib = library('arrow-dataset-glib', - sources: sources, + sources: sources + enums, install: true, dependencies: dependencies, include_directories: base_include_directories, @@ -59,7 +74,8 @@ libarrow_dataset_glib = library('arrow-dataset-glib', version: library_version) arrow_dataset_glib = declare_dependency(link_with: libarrow_dataset_glib, include_directories: base_include_directories, - dependencies: dependencies) + dependencies: dependencies, + sources: enums_header) pkgconfig.generate(libarrow_dataset_glib, filebase: 'arrow-dataset-glib', @@ -71,7 +87,7 @@ pkgconfig.generate(libarrow_dataset_glib, if have_gi gnome.generate_gir(libarrow_dataset_glib, dependencies: declare_dependency(sources: arrow_glib_gir), - sources: sources + c_headers, + sources: sources + c_headers + enums, namespace: 'ArrowDataset', nsversion: api_version, identifier_prefix: 'GADataset', diff --git a/c_glib/arrow-dataset-glib/partitioning.cpp b/c_glib/arrow-dataset-glib/partitioning.cpp new file mode 100644 index 00000000000..bce33671a35 --- /dev/null +++ b/c_glib/arrow-dataset-glib/partitioning.cpp @@ -0,0 +1,440 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include + +#include +#include + +G_BEGIN_DECLS + +/** + * SECTION: partitioning + * @section_id: partitioning + * @title: Partitioning classes + * @include: arrow-dataset-glib/arrow-dataset-glib.h + * + * #GADatasetPartitioningOptions is a class for partitioning options. + * + * #GADatasetPartitioning is a base class for partitioning classes + * such as #GADatasetDirectoryPartitioning. + * + * #GADatasetKeyValuePartitioning is a base class for key-value style + * partitioning classes such as #GADatasetDirectoryPartitioning. + * + * #GADatasetDirectoryPartitioning is a class for partitioning that + * uses directory structure. 
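To make the class hierarchy concrete: with a schema of `year: int16, month: int8`, a #GADatasetDirectoryPartitioning maps the path segments `/2009/11` to the equivalent of `year == 2009 and month == 11`. A sketch of building one with the constructor added below, using existing arrow-glib schema helpers:

```c
#include <arrow-dataset-glib/arrow-dataset-glib.h>

/* Sketch: directory partitioning over year/month path segments. */
static GADatasetDirectoryPartitioning *
make_year_month_partitioning(GError **error)
{
  GList *fields = NULL;
  fields = g_list_append(fields,
                         garrow_field_new("year",
                                          garrow_int16_data_type_new()));
  fields = g_list_append(fields,
                         garrow_field_new("month",
                                          garrow_int8_data_type_new()));
  GArrowSchema *schema = garrow_schema_new(fields);
  g_list_free_full(fields, g_object_unref);
  GADatasetDirectoryPartitioning *partitioning =
    gadataset_directory_partitioning_new(schema, NULL, NULL, error);
  g_object_unref(schema);
  return partitioning;
}
```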
+ * + * Since: 6.0.0 + */ + +typedef struct GADatasetPartitioningOptionsPrivate_ { + gboolean infer_dictionary; + GArrowSchema *schema; + GADatasetSegmentEncoding segment_encoding; +} GADatasetPartitioningOptionsPrivate; + +enum { + PROP_INFER_DICTIONARY = 1, + PROP_SCHEMA, + PROP_SEGMENT_ENCODING, +}; + +G_DEFINE_TYPE_WITH_PRIVATE(GADatasetPartitioningOptions, + gadataset_partitioning_options, + G_TYPE_OBJECT) + +#define GADATASET_PARTITIONING_OPTIONS_GET_PRIVATE(obj) \ + static_cast( \ + gadataset_partitioning_options_get_instance_private( \ + GADATASET_PARTITIONING_OPTIONS(obj))) + +static void +gadataset_partitioning_options_dispose(GObject *object) +{ + auto priv = GADATASET_PARTITIONING_OPTIONS_GET_PRIVATE(object); + + if (priv->schema) { + g_object_unref(priv->schema); + priv->schema = nullptr; + } + + G_OBJECT_CLASS(gadataset_partitioning_options_parent_class)->dispose(object); +} + +static void +gadataset_partitioning_options_set_property(GObject *object, + guint prop_id, + const GValue *value, + GParamSpec *pspec) +{ + auto priv = GADATASET_PARTITIONING_OPTIONS_GET_PRIVATE(object); + + switch (prop_id) { + case PROP_INFER_DICTIONARY: + priv->infer_dictionary = g_value_get_boolean(value); + break; + case PROP_SCHEMA: + { + auto schema = g_value_get_object(value); + if (priv->schema == schema) { + break; + } + auto old_schema = priv->schema; + if (schema) { + g_object_ref(schema); + priv->schema = GARROW_SCHEMA(schema); + } else { + priv->schema = NULL; + } + if (old_schema) { + g_object_unref(old_schema); + } + } + break; + case PROP_SEGMENT_ENCODING: + priv->segment_encoding = + static_cast(g_value_get_enum(value)); + break; + default: + G_OBJECT_WARN_INVALID_PROPERTY_ID(object, prop_id, pspec); + break; + } +} + +static void +gadataset_partitioning_options_get_property(GObject *object, + guint prop_id, + GValue *value, + GParamSpec *pspec) +{ + auto priv = GADATASET_PARTITIONING_OPTIONS_GET_PRIVATE(object); + + switch (prop_id) { + case PROP_INFER_DICTIONARY: + g_value_set_boolean(value, priv->infer_dictionary); + break; + case PROP_SCHEMA: + g_value_set_object(value, priv->schema); + break; + case PROP_SEGMENT_ENCODING: + g_value_set_enum(value, priv->segment_encoding); + break; + default: + G_OBJECT_WARN_INVALID_PROPERTY_ID(object, prop_id, pspec); + break; + } +} + +static void +gadataset_partitioning_options_init(GADatasetPartitioningOptions *object) +{ +} + +static void +gadataset_partitioning_options_class_init( + GADatasetPartitioningOptionsClass *klass) +{ + auto gobject_class = G_OBJECT_CLASS(klass); + + gobject_class->dispose = gadataset_partitioning_options_dispose; + gobject_class->set_property = gadataset_partitioning_options_set_property; + gobject_class->get_property = gadataset_partitioning_options_get_property; + + arrow::dataset::PartitioningFactoryOptions default_options; + GParamSpec *spec; + /** + * GADatasetPartitioningOptions:infer-dictionary: + * + * When inferring a schema for partition fields, yield dictionary + * encoded types instead of plain. This can be more efficient when + * materializing virtual columns, and Expressions parsed by the + * finished Partitioning will include dictionaries of all unique + * inspected values for each field. 
+ *
+ * Since: 6.0.0
+ */
+  spec = g_param_spec_boolean("infer-dictionary",
+                              "Infer dictionary",
+                              "Whether to encode partitioned field values as "
+                              "dictionaries",
+                              default_options.infer_dictionary,
+                              static_cast<GParamFlags>(G_PARAM_READWRITE));
+  g_object_class_install_property(gobject_class, PROP_INFER_DICTIONARY, spec);
+
+  /**
+   * GADatasetPartitioningOptions:schema:
+   *
+   * Optionally, an expected schema can be provided, in which case
+   * inference will only check discovered fields against the schema
+   * and update internal state (such as dictionaries).
+   *
+   * Since: 6.0.0
+   */
+  spec = g_param_spec_object("schema",
+                             "Schema",
+                             "Inference will only check discovered fields "
+                             "against the schema and update internal state",
+                             GARROW_TYPE_SCHEMA,
+                             static_cast<GParamFlags>(G_PARAM_READWRITE));
+  g_object_class_install_property(gobject_class, PROP_SCHEMA, spec);
+
+  /**
+   * GADatasetPartitioningOptions:segment-encoding:
+   *
+   * After splitting a path into components, decode the path
+   * components before parsing according to this scheme.
+   *
+   * Since: 6.0.0
+   */
+  spec = g_param_spec_enum("segment-encoding",
+                           "Segment encoding",
+                           "After splitting a path into components, "
+                           "decode the path components before "
+                           "parsing according to this scheme",
+                           GADATASET_TYPE_SEGMENT_ENCODING,
+                           static_cast<GADatasetSegmentEncoding>(
+                             default_options.segment_encoding),
+                           static_cast<GParamFlags>(G_PARAM_READWRITE));
+  g_object_class_install_property(gobject_class, PROP_SEGMENT_ENCODING, spec);
+}
+
+/**
+ * gadataset_partitioning_options_new:
+ *
+ * Returns: The newly created #GADatasetPartitioningOptions.
+ *
+ * Since: 6.0.0
+ */
+GADatasetPartitioningOptions *
+gadataset_partitioning_options_new(void)
+{
+  return GADATASET_PARTITIONING_OPTIONS(
+    g_object_new(GADATASET_TYPE_PARTITIONING_OPTIONS,
+                 NULL));
+}
+
+
+typedef struct GADatasetPartitioningPrivate_ {
+  std::shared_ptr<arrow::dataset::Partitioning> partitioning;
+} GADatasetPartitioningPrivate;
+
+enum {
+  PROP_PARTITIONING = 1,
+};
+
+G_DEFINE_TYPE_WITH_PRIVATE(GADatasetPartitioning,
+                           gadataset_partitioning,
+                           G_TYPE_OBJECT)
+
+#define GADATASET_PARTITIONING_GET_PRIVATE(obj)        \
+  static_cast<GADatasetPartitioningPrivate *>(         \
+    gadataset_partitioning_get_instance_private(       \
+      GADATASET_PARTITIONING(obj)))
+
+static void
+gadataset_partitioning_finalize(GObject *object)
+{
+  auto priv = GADATASET_PARTITIONING_GET_PRIVATE(object);
+  priv->partitioning.~shared_ptr();
+  G_OBJECT_CLASS(gadataset_partitioning_parent_class)->finalize(object);
+}
+
+static void
+gadataset_partitioning_set_property(GObject *object,
+                                    guint prop_id,
+                                    const GValue *value,
+                                    GParamSpec *pspec)
+{
+  auto priv = GADATASET_PARTITIONING_GET_PRIVATE(object);
+
+  switch (prop_id) {
+  case PROP_PARTITIONING:
+    priv->partitioning =
+      *static_cast<std::shared_ptr<arrow::dataset::Partitioning> *>(
+        g_value_get_pointer(value));
+    break;
+  default:
+    G_OBJECT_WARN_INVALID_PROPERTY_ID(object, prop_id, pspec);
+    break;
+  }
+}
+
+static void
+gadataset_partitioning_init(GADatasetPartitioning *object)
+{
+  auto priv = GADATASET_PARTITIONING_GET_PRIVATE(object);
+  new(&priv->partitioning) std::shared_ptr<arrow::dataset::Partitioning>;
+}
+
+static void
+gadataset_partitioning_class_init(GADatasetPartitioningClass *klass)
+{
+  auto gobject_class = G_OBJECT_CLASS(klass);
+
+  gobject_class->finalize = gadataset_partitioning_finalize;
+  gobject_class->set_property = gadataset_partitioning_set_property;
+
+  GParamSpec *spec;
+  spec = g_param_spec_pointer("partitioning",
+                              "Partitioning",
+                              "The raw "
+                              "std::shared_ptr<arrow::dataset::Partitioning> *",
+                              static_cast<GParamFlags>(G_PARAM_WRITABLE |
+                                                       G_PARAM_CONSTRUCT_ONLY));
+  g_object_class_install_property(gobject_class, PROP_PARTITIONING, spec);
+}
+
+/**
+ *
gadataset_partitioning_new: + * + * Returns: The newly created #GADatasetPartitioning that doesn't + * partition. + * + * Since: 6.0.0 + */ +GADatasetPartitioning * +gadataset_partitioning_new(void) +{ + auto arrow_partitioning = arrow::dataset::Partitioning::Default(); + return GADATASET_PARTITIONING( + g_object_new(GADATASET_TYPE_PARTITIONING, + "partitioning", &arrow_partitioning, + NULL)); +} + +/** + * gadataset_partitioning_get_type_name: + * @partitioning: A #GADatasetPartitioning. + * + * Returns: The type name of @partitioning. + * + * It should be freed with g_free() when no longer needed. + * + * Since: 6.0.0 + */ +gchar * +gadataset_partitioning_get_type_name(GADatasetPartitioning *partitioning) +{ + auto arrow_partitioning = gadataset_partitioning_get_raw(partitioning); + auto arrow_type_name = arrow_partitioning->type_name(); + return g_strndup(arrow_type_name.c_str(), + arrow_type_name.size()); +} + + +G_DEFINE_TYPE(GADatasetKeyValuePartitioning, + gadataset_key_value_partitioning, + GADATASET_TYPE_PARTITIONING) + +static void +gadataset_key_value_partitioning_init(GADatasetKeyValuePartitioning *object) +{ +} + +static void +gadataset_key_value_partitioning_class_init( + GADatasetKeyValuePartitioningClass *klass) +{ +} + + +G_DEFINE_TYPE(GADatasetDirectoryPartitioning, + gadataset_directory_partitioning, + GADATASET_TYPE_KEY_VALUE_PARTITIONING) + +static void +gadataset_directory_partitioning_init(GADatasetDirectoryPartitioning *object) +{ +} + +static void +gadataset_directory_partitioning_class_init( + GADatasetDirectoryPartitioningClass *klass) +{ +} + +/** + * gadataset_directory_partitioning_new: + * @schema: A #GArrowSchema that describes all partitioned segments. + * @dictionaries: (nullable) (element-type GArrowArray): A list of #GArrowArray + * for dictionary data types in @schema. + * @options: (nullable): A #GADatasetPartitioningOptions. + * @error: (nullable): Return location for a #GError or %NULL. + * + * Returns: The newly created #GADatasetDirectoryPartitioning on success, + * %NULL on error. 
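One note on the @options parameter documented above: segment decoding is opt-in through the segment-encoding property. A short sketch, reusing a `schema` built as in the earlier partitioning example:

```c
/* Sketch: decode URL-encoded path segments before parsing. */
GError *error = NULL;
GADatasetPartitioningOptions *options = gadataset_partitioning_options_new();
g_object_set(options,
             "segment-encoding", GADATASET_SEGMENT_ENCODING_URI,
             NULL);
GADatasetDirectoryPartitioning *partitioning =
  gadataset_directory_partitioning_new(schema, NULL, options, &error);
g_object_unref(options);
```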
+ *
+ * Since: 6.0.0
+ */
+GADatasetDirectoryPartitioning *
+gadataset_directory_partitioning_new(GArrowSchema *schema,
+                                     GList *dictionaries,
+                                     GADatasetPartitioningOptions *options,
+                                     GError **error)
+{
+  auto arrow_schema = garrow_schema_get_raw(schema);
+  std::vector<std::shared_ptr<arrow::Array>> arrow_dictionaries;
+  for (auto node = dictionaries; node; node = node->next) {
+    auto dictionary = GARROW_ARRAY(node->data);
+    if (dictionary) {
+      arrow_dictionaries.push_back(garrow_array_get_raw(dictionary));
+    } else {
+      arrow_dictionaries.push_back(nullptr);
+    }
+  }
+  arrow::dataset::KeyValuePartitioningOptions arrow_options;
+  if (options) {
+    arrow_options =
+      gadataset_partitioning_options_get_raw_key_value_partitioning_options(
+        options);
+  }
+  auto arrow_partitioning =
+    std::make_shared<arrow::dataset::DirectoryPartitioning>(
+      arrow_schema,
+      arrow_dictionaries,
+      arrow_options);
+  return GADATASET_DIRECTORY_PARTITIONING(
+    g_object_new(GADATASET_TYPE_DIRECTORY_PARTITIONING,
+                 "partitioning", &arrow_partitioning,
+                 NULL));
+}
+
+
+G_END_DECLS
+
+arrow::dataset::KeyValuePartitioningOptions
+gadataset_partitioning_options_get_raw_key_value_partitioning_options(
+  GADatasetPartitioningOptions *options)
+{
+  auto priv = GADATASET_PARTITIONING_OPTIONS_GET_PRIVATE(options);
+  arrow::dataset::KeyValuePartitioningOptions arrow_options;
+  arrow_options.segment_encoding =
+    static_cast<arrow::dataset::SegmentEncoding>(priv->segment_encoding);
+  return arrow_options;
+}
+
+std::shared_ptr<arrow::dataset::Partitioning>
+gadataset_partitioning_get_raw(GADatasetPartitioning *partitioning)
+{
+  auto priv = GADATASET_PARTITIONING_GET_PRIVATE(partitioning);
+  return priv->partitioning;
+}
diff --git a/c_glib/arrow-dataset-glib/partitioning.h b/c_glib/arrow-dataset-glib/partitioning.h
new file mode 100644
index 00000000000..d408d9bd502
--- /dev/null
+++ b/c_glib/arrow-dataset-glib/partitioning.h
@@ -0,0 +1,110 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#pragma once
+
+#include <arrow-glib/arrow-glib.h>
+
+G_BEGIN_DECLS
+
+/**
+ * GADatasetSegmentEncoding:
+ * @GADATASET_SEGMENT_ENCODING_NONE: No encoding.
+ * @GADATASET_SEGMENT_ENCODING_URI: Segment values are URL-encoded.
+ *
+ * They correspond to `arrow::dataset::SegmentEncoding` values.
+ * + * Since: 6.0.0 + */ +typedef enum { + GADATASET_SEGMENT_ENCODING_NONE, + GADATASET_SEGMENT_ENCODING_URI, +} GADatasetSegmentEncoding; + + +#define GADATASET_TYPE_PARTITIONING_OPTIONS \ + (gadataset_partitioning_options_get_type()) +G_DECLARE_DERIVABLE_TYPE(GADatasetPartitioningOptions, + gadataset_partitioning_options, + GADATASET, + PARTITIONING_OPTIONS, + GObject) +struct _GADatasetPartitioningOptionsClass +{ + GObjectClass parent_class; +}; + +GARROW_AVAILABLE_IN_6_0 +GADatasetPartitioningOptions * +gadataset_partitioning_options_new(void); + + +#define GADATASET_TYPE_PARTITIONING (gadataset_partitioning_get_type()) +G_DECLARE_DERIVABLE_TYPE(GADatasetPartitioning, + gadataset_partitioning, + GADATASET, + PARTITIONING, + GObject) +struct _GADatasetPartitioningClass +{ + GObjectClass parent_class; +}; + +GARROW_AVAILABLE_IN_6_0 +GADatasetPartitioning * +gadataset_partitioning_new(void); +GARROW_AVAILABLE_IN_6_0 +gchar * +gadataset_partitioning_get_type_name(GADatasetPartitioning *partitioning); + + +#define GADATASET_TYPE_KEY_VALUE_PARTITIONING \ + (gadataset_key_value_partitioning_get_type()) +G_DECLARE_DERIVABLE_TYPE(GADatasetKeyValuePartitioning, + gadataset_key_value_partitioning, + GADATASET, + KEY_VALUE_PARTITIONING, + GADatasetPartitioning) +struct _GADatasetKeyValuePartitioningClass +{ + GADatasetPartitioningClass parent_class; +}; + + +#define GADATASET_TYPE_DIRECTORY_PARTITIONING \ + (gadataset_directory_partitioning_get_type()) +G_DECLARE_DERIVABLE_TYPE(GADatasetDirectoryPartitioning, + gadataset_directory_partitioning, + GADATASET, + DIRECTORY_PARTITIONING, + GADatasetKeyValuePartitioning) +struct _GADatasetDirectoryPartitioningClass +{ + GADatasetKeyValuePartitioningClass parent_class; +}; + +GARROW_AVAILABLE_IN_6_0 +GADatasetDirectoryPartitioning * +gadataset_directory_partitioning_new(GArrowSchema *schema, + GList *dictionaries, + GADatasetPartitioningOptions *options, + GError **error); + + +G_END_DECLS diff --git a/c_glib/arrow-dataset-glib/partitioning.hpp b/c_glib/arrow-dataset-glib/partitioning.hpp new file mode 100644 index 00000000000..2481ecb3340 --- /dev/null +++ b/c_glib/arrow-dataset-glib/partitioning.hpp @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +#pragma once + +#include <arrow/dataset/api.h> + +#include <arrow-dataset-glib/partitioning.h> + +arrow::dataset::KeyValuePartitioningOptions +gadataset_partitioning_options_get_raw_key_value_partitioning_options( + GADatasetPartitioningOptions *options); + +std::shared_ptr<arrow::dataset::Partitioning> +gadataset_partitioning_get_raw(GADatasetPartitioning *partitioning); diff --git a/c_glib/arrow-dataset-glib/scanner.cpp b/c_glib/arrow-dataset-glib/scanner.cpp index efa2a5c3287..ddd3fd88af7 100644 --- a/c_glib/arrow-dataset-glib/scanner.cpp +++ b/c_glib/arrow-dataset-glib/scanner.cpp @@ -18,6 +18,7 @@ */ #include +#include #include #include @@ -225,6 +226,24 @@ gadataset_scanner_builder_new(GADatasetDataset *dataset, GError **error) } } +/** + * gadataset_scanner_builder_new_record_batch_reader: + * @reader: A #GArrowRecordBatchReader that produces record batches. + * + * Returns: (nullable): A newly created #GADatasetScannerBuilder. + * + * Since: 6.0.0 + */ +GADatasetScannerBuilder * +gadataset_scanner_builder_new_record_batch_reader( + GArrowRecordBatchReader *reader) +{ + auto arrow_reader = garrow_record_batch_reader_get_raw(reader); + auto arrow_scanner_builder = + arrow::dataset::ScannerBuilder::FromRecordBatchReader(arrow_reader); + return gadataset_scanner_builder_new_raw(&arrow_scanner_builder); +} + /** * gadataset_scanner_builder_finish: * @builder: A #GADatasetScannerBuilder. diff --git a/c_glib/arrow-dataset-glib/scanner.h b/c_glib/arrow-dataset-glib/scanner.h index 446815d6db1..ba7f9c6b7c3 100644 --- a/c_glib/arrow-dataset-glib/scanner.h +++ b/c_glib/arrow-dataset-glib/scanner.h @@ -55,6 +55,10 @@ GARROW_AVAILABLE_IN_5_0 GADatasetScannerBuilder * gadataset_scanner_builder_new(GADatasetDataset *dataset, GError **error); +GARROW_AVAILABLE_IN_6_0 +GADatasetScannerBuilder * +gadataset_scanner_builder_new_record_batch_reader( + GArrowRecordBatchReader *reader); GARROW_AVAILABLE_IN_5_0 GADatasetScanner * gadataset_scanner_builder_finish(GADatasetScannerBuilder *builder, diff --git a/c_glib/arrow-glib/composite-data-type.cpp b/c_glib/arrow-glib/composite-data-type.cpp index 95cd283c1b3..fadcafe6b40 100644 --- a/c_glib/arrow-glib/composite-data-type.cpp +++ b/c_glib/arrow-glib/composite-data-type.cpp @@ -376,7 +376,7 @@ garrow_map_data_type_new(GArrowDataType *key_type, * garrow_map_data_type_get_key_type: * @map_data_type: A #GArrowMapDataType. * - * Return: (transfer full): The key type of the map. + * Returns: (transfer full): The key type of the map. * * Since: 0.17.0 */ @@ -395,7 +395,7 @@ garrow_map_data_type_get_key_type(GArrowMapDataType *map_data_type) * garrow_map_data_type_get_item_type: * @map_data_type: A #GArrowMapDataType. * - * Return: (transfer full): The item type of the map. + * Returns: (transfer full): The item type of the map. * * Since: 0.17.0 */ diff --git a/c_glib/arrow-glib/compute.cpp b/c_glib/arrow-glib/compute.cpp index 8783510728a..2f4a0de215c 100644 --- a/c_glib/arrow-glib/compute.cpp +++ b/c_glib/arrow-glib/compute.cpp @@ -126,6 +126,8 @@ G_BEGIN_DECLS * #GArrowFunctionOptions is a base class for all function options * classes such as #GArrowCastOptions. * + * #GArrowFunctionDoc is a class for function documentation. + * * #GArrowFunction is a class to process data. * * #GArrowExecuteNodeOptions is a base class for all execute node @@ -165,6 +167,12 @@ G_BEGIN_DECLS * #GArrowSortOptions is a class to customize the `sort_indices` * function. * + * #GArrowSetLookupOptions is a class to customize the `is_in` and + * `index_in` functions.
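The scanner-builder addition above pairs with the existing finish/to_table API; a minimal C sketch, assuming `reader` is a valid #GArrowRecordBatchReader supplied by the caller:

    /* Sketch: scan an in-memory record batch reader via the new
     * 6.0.0 entry point; `reader` is assumed to exist already. */
    GError *error = NULL;
    GADatasetScannerBuilder *builder =
      gadataset_scanner_builder_new_record_batch_reader(reader);
    GADatasetScanner *scanner =
      gadataset_scanner_builder_finish(builder, &error);
    if (scanner) {
      GArrowTable *table = gadataset_scanner_to_table(scanner, &error);
      /* ... use table ... */
    }
    g_object_unref(builder);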
+ * + * #GArrowVarianceOptions is a class to customize the `stddev` and + * `variance` functions. + * * There are many functions to compute data on an array. */ @@ -254,6 +262,145 @@ garrow_function_options_class_init(GArrowFunctionOptionsClass *klass) } + +typedef struct GArrowFunctionDocPrivate_ { + arrow::compute::FunctionDoc *doc; +} GArrowFunctionDocPrivate; + +enum { + PROP_DOC = 1, +}; + +G_DEFINE_TYPE_WITH_PRIVATE(GArrowFunctionDoc, + garrow_function_doc, + G_TYPE_OBJECT) + +#define GARROW_FUNCTION_DOC_GET_PRIVATE(object) \ + static_cast<GArrowFunctionDocPrivate *>( \ + garrow_function_doc_get_instance_private( \ + GARROW_FUNCTION_DOC(object))) + +static void +garrow_function_doc_set_property(GObject *object, + guint prop_id, + const GValue *value, + GParamSpec *pspec) +{ + auto priv = GARROW_FUNCTION_DOC_GET_PRIVATE(object); + + switch (prop_id) { + case PROP_DOC: + priv->doc = + static_cast<arrow::compute::FunctionDoc *>(g_value_get_pointer(value)); + break; + default: + G_OBJECT_WARN_INVALID_PROPERTY_ID(object, prop_id, pspec); + break; + } +} + +static void +garrow_function_doc_init(GArrowFunctionDoc *object) +{ +} + +static void +garrow_function_doc_class_init(GArrowFunctionDocClass *klass) +{ + auto gobject_class = G_OBJECT_CLASS(klass); + gobject_class->set_property = garrow_function_doc_set_property; + + GParamSpec *spec; + spec = g_param_spec_pointer("doc", + "Doc", + "The raw arrow::compute::FunctionDoc *", + static_cast<GParamFlags>(G_PARAM_WRITABLE | + G_PARAM_CONSTRUCT_ONLY)); + g_object_class_install_property(gobject_class, PROP_DOC, spec); +} + +/** + * garrow_function_doc_get_summary: + * @doc: A #GArrowFunctionDoc. + * + * Returns: A one-line summary of the function, using a verb. + * + * It should be freed with g_free() when no longer needed. + * + * Since: 6.0.0 + */ +gchar * +garrow_function_doc_get_summary(GArrowFunctionDoc *doc) +{ + auto arrow_doc = garrow_function_doc_get_raw(doc); + return g_strndup(arrow_doc->summary.data(), + arrow_doc->summary.size()); +} + +/** + * garrow_function_doc_get_description: + * @doc: A #GArrowFunctionDoc. + * + * Returns: A detailed description of the function, meant to follow + * the summary. + * + * It should be freed with g_free() when no longer needed. + * + * Since: 6.0.0 + */ +gchar * +garrow_function_doc_get_description(GArrowFunctionDoc *doc) +{ + auto arrow_doc = garrow_function_doc_get_raw(doc); + return g_strndup(arrow_doc->description.data(), + arrow_doc->description.size()); +} + +/** + * garrow_function_doc_get_arg_names: + * @doc: A #GArrowFunctionDoc. + * + * Returns: (array zero-terminated=1) (element-type utf8) (transfer full): + * Symbolic names (identifiers) for the function arguments. + * + * It's a %NULL-terminated string array. It must be freed with + * g_strfreev() when no longer needed. + * + * Since: 6.0.0 + */ +gchar ** +garrow_function_doc_get_arg_names(GArrowFunctionDoc *doc) +{ + auto arrow_doc = garrow_function_doc_get_raw(doc); + const auto &arrow_arg_names = arrow_doc->arg_names; + auto n = arrow_arg_names.size(); + auto arg_names = g_new(gchar *, n + 1); + for (size_t i = 0; i < n; ++i) { + arg_names[i] = g_strndup(arrow_arg_names[i].data(), + arrow_arg_names[i].size()); + } + arg_names[n] = NULL; + return arg_names; +} + +/** + * garrow_function_doc_get_options_class_name: + * @doc: A #GArrowFunctionDoc. + * + * Returns: Name of the options class, if any. + * + * It should be freed with g_free() when no longer needed.
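Taken together, the getters above make compute functions self-describing from C; a short sketch (the "or" function name mirrors the Ruby test added later in this patch):

    /* Sketch: look up a compute function and print its summary. */
    GArrowFunction *function = garrow_function_find("or");
    GArrowFunctionDoc *doc = garrow_function_get_doc(function);
    gchar *summary = garrow_function_doc_get_summary(doc);
    g_print("or: %s\n", summary);
    g_free(summary);
    g_object_unref(doc);
    g_object_unref(function);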
+ * + * Since: 6.0.0 + */ +gchar * +garrow_function_doc_get_options_class_name(GArrowFunctionDoc *doc) +{ + auto arrow_doc = garrow_function_doc_get_raw(doc); + return g_strndup(arrow_doc->options_class.data(), + arrow_doc->options_class.size()); +} + + typedef struct GArrowFunctionPrivate_ { + std::shared_ptr<arrow::compute::Function> function; +} GArrowFunctionPrivate; @@ -397,6 +544,22 @@ garrow_function_execute(GArrowFunction *function, } } +/** + * garrow_function_get_doc: + * @function: A #GArrowFunction. + * + * Returns: (transfer full): The function documentation. + * + * Since: 6.0.0 + */ +GArrowFunctionDoc * +garrow_function_get_doc(GArrowFunction *function) +{ + auto arrow_function = garrow_function_get_raw(function); + const auto &arrow_doc = arrow_function->doc(); + return garrow_function_doc_new_raw(&arrow_doc); +} + typedef struct GArrowExecuteNodeOptionsPrivate_ { arrow::compute::ExecNodeOptions *options; @@ -1271,7 +1434,7 @@ garrow_execute_plan_build_source_node(GArrowExecutePlan *plan, /** * garrow_execute_plan_build_aggregate_node: * @plan: A #GArrowExecutePlan. - * @input: A #GarrowExecuteNode. + * @input: A #GArrowExecuteNode. * @options: A #GArrowAggregateNodeOptions. * @error: (nullable): Return location for a #GError or %NULL. * @@ -1304,7 +1467,7 @@ garrow_execute_plan_build_aggregate_node(GArrowExecutePlan *plan, /** * garrow_execute_plan_build_sink_node: * @plan: A #GArrowExecutePlan. - * @input: A #GarrowExecuteNode. + * @input: A #GArrowExecuteNode. * @options: A #GArrowSinkNodeOptions. * @error: (nullable): Return location for a #GError or %NULL. * @@ -2417,6 +2580,312 @@ garrow_sort_options_set_sort_keys(GArrowSortOptions *options, } + +typedef struct GArrowSetLookupOptionsPrivate_ { + GArrowDatum *value_set; +} GArrowSetLookupOptionsPrivate; + +enum { + PROP_SET_LOOKUP_OPTIONS_VALUE_SET = 1, + PROP_SET_LOOKUP_OPTIONS_SKIP_NULLS, +}; + +G_DEFINE_TYPE_WITH_PRIVATE(GArrowSetLookupOptions, + garrow_set_lookup_options, + GARROW_TYPE_FUNCTION_OPTIONS) + +#define GARROW_SET_LOOKUP_OPTIONS_GET_PRIVATE(object) \ + static_cast<GArrowSetLookupOptionsPrivate *>( \ + garrow_set_lookup_options_get_instance_private( \ + GARROW_SET_LOOKUP_OPTIONS(object))) + +static void +garrow_set_lookup_options_dispose(GObject *object) +{ + auto priv = GARROW_SET_LOOKUP_OPTIONS_GET_PRIVATE(object); + + if (priv->value_set) { + g_object_unref(priv->value_set); + priv->value_set = NULL; + } + + G_OBJECT_CLASS(garrow_set_lookup_options_parent_class)->dispose(object); +} + +static void +garrow_set_lookup_options_set_property(GObject *object, + guint prop_id, + const GValue *value, + GParamSpec *pspec) +{ + auto priv = GARROW_SET_LOOKUP_OPTIONS_GET_PRIVATE(object); + auto options = + garrow_set_lookup_options_get_raw(GARROW_SET_LOOKUP_OPTIONS(object)); + + switch (prop_id) { + case PROP_SET_LOOKUP_OPTIONS_VALUE_SET: + priv->value_set = GARROW_DATUM(g_value_dup_object(value)); + options->value_set = garrow_datum_get_raw(priv->value_set); + break; + case PROP_SET_LOOKUP_OPTIONS_SKIP_NULLS: + options->skip_nulls = g_value_get_boolean(value); + break; + default: + G_OBJECT_WARN_INVALID_PROPERTY_ID(object, prop_id, pspec); + break; + } +} + +static void +garrow_set_lookup_options_get_property(GObject *object, + guint prop_id, + GValue *value, + GParamSpec *pspec) +{ + auto priv = GARROW_SET_LOOKUP_OPTIONS_GET_PRIVATE(object); + auto options = + garrow_set_lookup_options_get_raw(GARROW_SET_LOOKUP_OPTIONS(object)); + + switch (prop_id) { + case PROP_SET_LOOKUP_OPTIONS_VALUE_SET: + g_value_set_object(value, priv->value_set); + break; + case
PROP_SET_LOOKUP_OPTIONS_SKIP_NULLS: + g_value_set_boolean(value, options->skip_nulls); + break; + default: + G_OBJECT_WARN_INVALID_PROPERTY_ID(object, prop_id, pspec); + break; + } +} + +static void +garrow_set_lookup_options_init(GArrowSetLookupOptions *object) +{ + auto priv = GARROW_FUNCTION_OPTIONS_GET_PRIVATE(object); + priv->options = static_cast<arrow::compute::FunctionOptions *>( + new arrow::compute::SetLookupOptions()); +} + +static void +garrow_set_lookup_options_class_init(GArrowSetLookupOptionsClass *klass) +{ + auto gobject_class = G_OBJECT_CLASS(klass); + + gobject_class->dispose = garrow_set_lookup_options_dispose; + gobject_class->set_property = garrow_set_lookup_options_set_property; + gobject_class->get_property = garrow_set_lookup_options_get_property; + + + arrow::compute::SetLookupOptions options; + + GParamSpec *spec; + /** + * GArrowSetLookupOptions:value-set: + * + * The set of values to look up input values into. + * + * Since: 6.0.0 + */ + spec = g_param_spec_object("value-set", + "Value set", + "The set of values to look up input values into", + GARROW_TYPE_DATUM, + static_cast<GParamFlags>(G_PARAM_READWRITE | + G_PARAM_CONSTRUCT_ONLY)); + g_object_class_install_property(gobject_class, + PROP_SET_LOOKUP_OPTIONS_VALUE_SET, + spec); + + /** + * GArrowSetLookupOptions:skip-nulls: + * + * Whether NULLs are skipped or not. + * + * Since: 6.0.0 + */ + spec = g_param_spec_boolean("skip-nulls", + "Skip NULLs", + "Whether NULLs are skipped or not", + options.skip_nulls, + static_cast<GParamFlags>(G_PARAM_READWRITE)); + g_object_class_install_property(gobject_class, + PROP_SET_LOOKUP_OPTIONS_SKIP_NULLS, + spec); +} + +/** + * garrow_set_lookup_options_new: + * @value_set: A #GArrowArrayDatum or #GArrowChunkedArrayDatum to be looked up. + * + * Returns: A newly created #GArrowSetLookupOptions.
+ * + * Since: 6.0.0 + */ +GArrowSetLookupOptions * +garrow_set_lookup_options_new(GArrowDatum *value_set) +{ + return GARROW_SET_LOOKUP_OPTIONS( + g_object_new(GARROW_TYPE_SET_LOOKUP_OPTIONS, + "value-set", value_set, + NULL)); +} + + +enum { + PROP_VARIANCE_OPTIONS_DDOF = 1, + PROP_VARIANCE_OPTIONS_SKIP_NULLS, + PROP_VARIANCE_OPTIONS_MIN_COUNT, +}; + +G_DEFINE_TYPE(GArrowVarianceOptions, + garrow_variance_options, + GARROW_TYPE_FUNCTION_OPTIONS) + +#define GARROW_VARIANCE_OPTIONS_GET_PRIVATE(object) \ + static_cast<GArrowVarianceOptionsPrivate *>( \ + garrow_variance_options_get_instance_private( \ + GARROW_VARIANCE_OPTIONS(object))) + +static void +garrow_variance_options_set_property(GObject *object, + guint prop_id, + const GValue *value, + GParamSpec *pspec) +{ + auto options = + garrow_variance_options_get_raw(GARROW_VARIANCE_OPTIONS(object)); + + switch (prop_id) { + case PROP_VARIANCE_OPTIONS_DDOF: + options->ddof = g_value_get_int(value); + break; + case PROP_VARIANCE_OPTIONS_SKIP_NULLS: + options->skip_nulls = g_value_get_boolean(value); + break; + case PROP_VARIANCE_OPTIONS_MIN_COUNT: + options->min_count = g_value_get_uint(value); + break; + default: + G_OBJECT_WARN_INVALID_PROPERTY_ID(object, prop_id, pspec); + break; + } +} + +static void +garrow_variance_options_get_property(GObject *object, + guint prop_id, + GValue *value, + GParamSpec *pspec) +{ + auto options = + garrow_variance_options_get_raw(GARROW_VARIANCE_OPTIONS(object)); + + switch (prop_id) { + case PROP_VARIANCE_OPTIONS_DDOF: + g_value_set_int(value, options->ddof); + break; + case PROP_VARIANCE_OPTIONS_SKIP_NULLS: + g_value_set_boolean(value, options->skip_nulls); + break; + case PROP_VARIANCE_OPTIONS_MIN_COUNT: + g_value_set_uint(value, options->min_count); + break; + default: + G_OBJECT_WARN_INVALID_PROPERTY_ID(object, prop_id, pspec); + break; + } +} + +static void +garrow_variance_options_init(GArrowVarianceOptions *object) +{ + auto priv = GARROW_FUNCTION_OPTIONS_GET_PRIVATE(object); + priv->options = static_cast<arrow::compute::FunctionOptions *>( + new arrow::compute::VarianceOptions()); +} + +static void +garrow_variance_options_class_init(GArrowVarianceOptionsClass *klass) +{ + auto gobject_class = G_OBJECT_CLASS(klass); + + gobject_class->set_property = garrow_variance_options_set_property; + gobject_class->get_property = garrow_variance_options_get_property; + + + arrow::compute::VarianceOptions options; + + GParamSpec *spec; + /** + * GArrowVarianceOptions:ddof: + * + * The Delta Degrees of Freedom (ddof) to be used. + * + * Since: 6.0.0 + */ + spec = g_param_spec_int("ddof", + "Delta Degrees of Freedom", + "The Delta Degrees of Freedom (ddof) to be used", + G_MININT, + G_MAXINT, + options.ddof, + static_cast<GParamFlags>(G_PARAM_READWRITE)); + g_object_class_install_property(gobject_class, + PROP_VARIANCE_OPTIONS_DDOF, + spec); + + /** + * GArrowVarianceOptions:skip-nulls: + * + * Whether NULLs are skipped or not. + * + * Since: 6.0.0 + */ + spec = g_param_spec_boolean("skip-nulls", + "Skip NULLs", + "Whether NULLs are skipped or not", + options.skip_nulls, + static_cast<GParamFlags>(G_PARAM_READWRITE)); + g_object_class_install_property(gobject_class, + PROP_VARIANCE_OPTIONS_SKIP_NULLS, + spec); + + /** + * GArrowVarianceOptions:min-count: + * + * If less than this many non-null values are observed, emit null.
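A usage sketch for the variance options implemented above; it assumes `array` is a numeric #GArrowArray built elsewhere, and it spells out the full garrow_function_execute() argument list, which the surrounding diff context abbreviates:

    /* Sketch: sample variance (ddof = 1) through the generic function API. */
    GArrowVarianceOptions *options = garrow_variance_options_new();
    g_object_set(options, "ddof", 1, NULL);
    GArrowFunction *function = garrow_function_find("variance");
    GArrowDatum *arg = GARROW_DATUM(garrow_array_datum_new(array));
    GList *args = g_list_append(NULL, arg);
    GError *error = NULL;
    GArrowDatum *result =
      garrow_function_execute(function, args,
                              GARROW_FUNCTION_OPTIONS(options), NULL, &error);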
+ * + * Since: 6.0.0 + */ + spec = g_param_spec_uint("min-count", + "Min count", + "If less than this many non-null values " + "are observed, emit null", + 0, + G_MAXUINT, + options.min_count, + static_cast<GParamFlags>(G_PARAM_READWRITE)); + g_object_class_install_property(gobject_class, + PROP_VARIANCE_OPTIONS_MIN_COUNT, + spec); + +} + +/** + * garrow_variance_options_new: + * + * Returns: A newly created #GArrowVarianceOptions. + * + * Since: 6.0.0 + */ +GArrowVarianceOptions * +garrow_variance_options_new(void) +{ + return GARROW_VARIANCE_OPTIONS( + g_object_new(GARROW_TYPE_VARIANCE_OPTIONS, NULL)); +} + + /** * garrow_array_cast: * @array: A #GArrowArray. @@ -3610,6 +4079,23 @@ garrow_function_options_get_raw(GArrowFunctionOptions *options) return priv->options; } + +GArrowFunctionDoc * +garrow_function_doc_new_raw(const arrow::compute::FunctionDoc *arrow_doc) +{ + return GARROW_FUNCTION_DOC(g_object_new(GARROW_TYPE_FUNCTION_DOC, + "doc", arrow_doc, + NULL)); +} + +arrow::compute::FunctionDoc * +garrow_function_doc_get_raw(GArrowFunctionDoc *doc) +{ + auto priv = GARROW_FUNCTION_DOC_GET_PRIVATE(doc); + return priv->doc; +} + + GArrowFunction * garrow_function_new_raw(std::shared_ptr<arrow::compute::Function> *arrow_function) { @@ -3755,3 +4241,18 @@ garrow_sort_options_get_raw(GArrowSortOptions *options) return static_cast<arrow::compute::SortOptions *>( garrow_function_options_get_raw(GARROW_FUNCTION_OPTIONS(options))); } + +arrow::compute::SetLookupOptions * +garrow_set_lookup_options_get_raw(GArrowSetLookupOptions *options) +{ + return static_cast<arrow::compute::SetLookupOptions *>( + garrow_function_options_get_raw(GARROW_FUNCTION_OPTIONS(options))); +} + + +arrow::compute::VarianceOptions * +garrow_variance_options_get_raw(GArrowVarianceOptions *options) +{ + return static_cast<arrow::compute::VarianceOptions *>( + garrow_function_options_get_raw(GARROW_FUNCTION_OPTIONS(options))); +} diff --git a/c_glib/arrow-glib/compute.h b/c_glib/arrow-glib/compute.h index 108b27ff7ba..2171d6abd9a 100644 --- a/c_glib/arrow-glib/compute.h +++ b/c_glib/arrow-glib/compute.h @@ -51,6 +51,31 @@ struct _GArrowFunctionOptionsClass }; + +#define GARROW_TYPE_FUNCTION_DOC (garrow_function_doc_get_type()) +G_DECLARE_DERIVABLE_TYPE(GArrowFunctionDoc, + garrow_function_doc, + GARROW, + FUNCTION_DOC, + GObject) +struct _GArrowFunctionDocClass +{ + GObjectClass parent_class; +}; + +GARROW_AVAILABLE_IN_6_0 +gchar * +garrow_function_doc_get_summary(GArrowFunctionDoc *doc); +GARROW_AVAILABLE_IN_6_0 +gchar * +garrow_function_doc_get_description(GArrowFunctionDoc *doc); +GARROW_AVAILABLE_IN_6_0 +gchar ** +garrow_function_doc_get_arg_names(GArrowFunctionDoc *doc); +GARROW_AVAILABLE_IN_6_0 +gchar * +garrow_function_doc_get_options_class_name(GArrowFunctionDoc *doc); + + #define GARROW_TYPE_FUNCTION (garrow_function_get_type()) G_DECLARE_DERIVABLE_TYPE(GArrowFunction, garrow_function, @@ -73,6 +98,10 @@ GArrowDatum *garrow_function_execute(GArrowFunction *function, GArrowExecuteContext *context, GError **error); +GARROW_AVAILABLE_IN_6_0 +GArrowFunctionDoc * +garrow_function_get_doc(GArrowFunction *function); + #define GARROW_TYPE_EXECUTE_NODE_OPTIONS (garrow_execute_node_options_get_type()) G_DECLARE_DERIVABLE_TYPE(GArrowExecuteNodeOptions, @@ -436,6 +465,38 @@ garrow_sort_options_add_sort_key(GArrowSortOptions *options, GArrowSortKey *sort_key); + +#define GARROW_TYPE_SET_LOOKUP_OPTIONS (garrow_set_lookup_options_get_type()) +G_DECLARE_DERIVABLE_TYPE(GArrowSetLookupOptions, + garrow_set_lookup_options, + GARROW, + SET_LOOKUP_OPTIONS, + GArrowFunctionOptions) +struct _GArrowSetLookupOptionsClass +{ + GArrowFunctionOptionsClass parent_class; +}; + 
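The corresponding C-level usage of the set-lookup options declared above — a sketch assuming `values` and `value_set` are #GArrowArray inputs owned by the caller:

    /* Sketch: evaluate "is_in" — which elements of values appear in
     * value_set? Both arrays are assumed to exist already. */
    GArrowArrayDatum *value_set_datum = garrow_array_datum_new(value_set);
    GArrowSetLookupOptions *options =
      garrow_set_lookup_options_new(GARROW_DATUM(value_set_datum));
    GArrowFunction *is_in = garrow_function_find("is_in");
    GArrowDatum *arg = GARROW_DATUM(garrow_array_datum_new(values));
    GList *args = g_list_append(NULL, arg);
    GError *error = NULL;
    GArrowDatum *result =
      garrow_function_execute(is_in, args,
                              GARROW_FUNCTION_OPTIONS(options), NULL, &error);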
+GARROW_AVAILABLE_IN_6_0 +GArrowSetLookupOptions * +garrow_set_lookup_options_new(GArrowDatum *value_set); + + +#define GARROW_TYPE_VARIANCE_OPTIONS (garrow_variance_options_get_type()) +G_DECLARE_DERIVABLE_TYPE(GArrowVarianceOptions, + garrow_variance_options, + GARROW, + VARIANCE_OPTIONS, + GArrowFunctionOptions) +struct _GArrowVarianceOptionsClass +{ + GArrowFunctionOptionsClass parent_class; +}; + +GARROW_AVAILABLE_IN_6_0 +GArrowVarianceOptions * +garrow_variance_options_new(void); + + GArrowArray *garrow_array_cast(GArrowArray *array, GArrowDataType *target_data_type, GArrowCastOptions *options, diff --git a/c_glib/arrow-glib/compute.hpp b/c_glib/arrow-glib/compute.hpp index 01265eee2a8..88f55d5329c 100644 --- a/c_glib/arrow-glib/compute.hpp +++ b/c_glib/arrow-glib/compute.hpp @@ -30,6 +30,11 @@ garrow_execute_context_get_raw(GArrowExecuteContext *context); arrow::compute::FunctionOptions * garrow_function_options_get_raw(GArrowFunctionOptions *options); +GArrowFunctionDoc * +garrow_function_doc_new_raw(const arrow::compute::FunctionDoc *arrow_doc); +arrow::compute::FunctionDoc * +garrow_function_doc_get_raw(GArrowFunctionDoc *doc); + GArrowFunction * garrow_function_new_raw(std::shared_ptr<arrow::compute::Function> *arrow_function); std::shared_ptr<arrow::compute::Function> @@ -89,3 +94,11 @@ garrow_sort_key_get_raw(GArrowSortKey *sort_key); arrow::compute::SortOptions * garrow_sort_options_get_raw(GArrowSortOptions *options); + + +arrow::compute::SetLookupOptions * +garrow_set_lookup_options_get_raw(GArrowSetLookupOptions *options); + + +arrow::compute::VarianceOptions * +garrow_variance_options_get_raw(GArrowVarianceOptions *options); diff --git a/c_glib/arrow-glib/input-stream.cpp b/c_glib/arrow-glib/input-stream.cpp index 64f366a6282..37e4702ff16 100644 --- a/c_glib/arrow-glib/input-stream.cpp +++ b/c_glib/arrow-glib/input-stream.cpp @@ -50,6 +50,8 @@ G_BEGIN_DECLS * * #GArrowBufferInputStream is a class to read data on buffer. * + * #GArrowFileInputStream is a class to read data in file. + * * #GArrowMemoryMappedInputStream is a class to read data in file by * mapping the file on memory. It supports zero copy. * @@ -631,6 +633,86 @@ garrow_buffer_input_stream_get_buffer(GArrowBufferInputStream *input_stream) } + +G_DEFINE_TYPE(GArrowFileInputStream, + garrow_file_input_stream, + GARROW_TYPE_SEEKABLE_INPUT_STREAM); + +static void +garrow_file_input_stream_init(GArrowFileInputStream *object) +{ +} + +static void +garrow_file_input_stream_class_init(GArrowFileInputStreamClass *klass) +{ +} + +/** + * garrow_file_input_stream_new: + * @path: The path of the file to be opened. + * @error: (nullable): Return location for a #GError or %NULL. + * + * Returns: (nullable): A newly created #GArrowFileInputStream + * or %NULL on error. + * + * Since: 6.0.0 + */ +GArrowFileInputStream * +garrow_file_input_stream_new(const gchar *path, + GError **error) +{ + auto arrow_stream_result = arrow::io::ReadableFile::Open(path); + if (garrow::check(error, arrow_stream_result, "[file-input-stream][new]")) { + auto arrow_stream = *arrow_stream_result; + return garrow_file_input_stream_new_raw(&arrow_stream); + } else { + return NULL; + } +} + +/** + * garrow_file_input_stream_new_file_descriptor: + * @file_descriptor: The file descriptor of this input stream. + * @error: (nullable): Return location for a #GError or %NULL. + * + * Returns: (nullable): A newly created #GArrowFileInputStream + * or %NULL on error.
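A minimal C sketch of opening a file with the new stream class above; the path is an illustrative assumption:

    /* Sketch: open a local file and read five bytes from the start. */
    GError *error = NULL;
    GArrowFileInputStream *input =
      garrow_file_input_stream_new("/tmp/data.arrow", &error);
    if (input) {
      GArrowBuffer *buffer =
        garrow_seekable_input_stream_read_at(
          GARROW_SEEKABLE_INPUT_STREAM(input), 0, 5, &error);
      /* ... use buffer ... */
      g_object_unref(input);
    }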
+ * + * Since: 6.0.0 + */ +GArrowFileInputStream * +garrow_file_input_stream_new_file_descriptor(gint file_descriptor, + GError **error) +{ + auto arrow_stream_result = arrow::io::ReadableFile::Open(file_descriptor); + if (garrow::check(error, + arrow_stream_result, + "[file-input-stream][new-file-descriptor]")) { + auto arrow_stream = *arrow_stream_result; + return garrow_file_input_stream_new_raw(&arrow_stream); + } else { + return NULL; + } +} + +/** + * garrow_file_input_stream_get_file_descriptor: + * @stream: A #GArrowFileInputStream. + * + * Returns: The file descriptor of @stream. + * + * Since: 6.0.0 + */ +gint +garrow_file_input_stream_get_file_descriptor(GArrowFileInputStream *stream) +{ + auto arrow_stream = + std::static_pointer_cast<arrow::io::ReadableFile>( + garrow_input_stream_get_raw(GARROW_INPUT_STREAM(stream))); + return arrow_stream->file_descriptor(); +} + + G_DEFINE_TYPE(GArrowMemoryMappedInputStream, garrow_memory_mapped_input_stream, GARROW_TYPE_SEEKABLE_INPUT_STREAM); @@ -657,18 +739,14 @@ GArrowMemoryMappedInputStream * garrow_memory_mapped_input_stream_new(const gchar *path, GError **error) { - auto arrow_memory_mapped_file_result = - arrow::io::MemoryMappedFile::Open(std::string(path), - arrow::io::FileMode::READ); - if (arrow_memory_mapped_file_result.ok()) { - auto arrow_memory_mapped_file = - arrow_memory_mapped_file_result.ValueOrDie(); - return garrow_memory_mapped_input_stream_new_raw(&(arrow_memory_mapped_file)); + auto arrow_stream_result = + arrow::io::MemoryMappedFile::Open(path, arrow::io::FileMode::READ); + if (garrow::check(error, + arrow_stream_result, + "[memory-mapped-input-stream][new]")) { + auto arrow_stream = *arrow_stream_result; + return garrow_memory_mapped_input_stream_new_raw(&arrow_stream); } else { - std::string context("[memory-mapped-input-stream][open]: <"); - context += path; - context += ">"; - garrow::check(error, arrow_memory_mapped_file_result, context.c_str()); return NULL; } } @@ -1203,16 +1281,28 @@ garrow_buffer_input_stream_get_raw(GArrowBufferInputStream *buffer_input_stream) return arrow_buffer_reader; } + +GArrowFileInputStream * +garrow_file_input_stream_new_raw( + std::shared_ptr<arrow::io::ReadableFile> *arrow_stream) +{ + return GARROW_FILE_INPUT_STREAM(g_object_new(GARROW_TYPE_FILE_INPUT_STREAM, + "input-stream", arrow_stream, + NULL)); +} + + GArrowMemoryMappedInputStream * -garrow_memory_mapped_input_stream_new_raw(std::shared_ptr<arrow::io::MemoryMappedFile> *arrow_memory_mapped_file) +garrow_memory_mapped_input_stream_new_raw( + std::shared_ptr<arrow::io::MemoryMappedFile> *arrow_stream) { - auto object = g_object_new(GARROW_TYPE_MEMORY_MAPPED_INPUT_STREAM, - "input-stream", arrow_memory_mapped_file, - NULL); - auto memory_mapped_input_stream = GARROW_MEMORY_MAPPED_INPUT_STREAM(object); - return memory_mapped_input_stream; + return GARROW_MEMORY_MAPPED_INPUT_STREAM( + g_object_new(GARROW_TYPE_MEMORY_MAPPED_INPUT_STREAM, + "input-stream", arrow_stream, + NULL)); } + GArrowCompressedInputStream * garrow_compressed_input_stream_new_raw(std::shared_ptr<arrow::io::CompressedInputStream> *arrow_raw, GArrowCodec *codec, diff --git a/c_glib/arrow-glib/input-stream.h b/c_glib/arrow-glib/input-stream.h index 4b4c51eb3e7..5f583c80486 100644 --- a/c_glib/arrow-glib/input-stream.h +++ b/c_glib/arrow-glib/input-stream.h @@ -104,54 +104,42 @@ GArrowBufferInputStream *garrow_buffer_input_stream_new(GArrowBuffer *buffer); GArrowBuffer *garrow_buffer_input_stream_get_buffer(GArrowBufferInputStream *input_stream); -#define GARROW_TYPE_MEMORY_MAPPED_INPUT_STREAM \ - (garrow_memory_mapped_input_stream_get_type()) -#define GARROW_MEMORY_MAPPED_INPUT_STREAM(obj) \ -
(G_TYPE_CHECK_INSTANCE_CAST((obj), \ - GARROW_TYPE_MEMORY_MAPPED_INPUT_STREAM, \ - GArrowMemoryMappedInputStream)) -#define GARROW_MEMORY_MAPPED_INPUT_STREAM_CLASS(klass) \ - (G_TYPE_CHECK_CLASS_CAST((klass), \ - GARROW_TYPE_MEMORY_MAPPED_INPUT_STREAM, \ - GArrowMemoryMappedInputStreamClass)) -#define GARROW_IS_MEMORY_MAPPED_INPUT_STREAM(obj) \ - (G_TYPE_CHECK_INSTANCE_TYPE((obj), \ - GARROW_TYPE_MEMORY_MAPPED_INPUT_STREAM)) -#define GARROW_IS_MEMORY_MAPPED_INPUT_STREAM_CLASS(klass) \ - (G_TYPE_CHECK_CLASS_TYPE((klass), \ - GARROW_TYPE_MEMORY_MAPPED_INPUT_STREAM)) -#define GARROW_MEMORY_MAPPED_INPUT_STREAM_GET_CLASS(obj) \ - (G_TYPE_INSTANCE_GET_CLASS((obj), \ - GARROW_TYPE_MEMORY_MAPPED_INPUT_STREAM, \ - GArrowMemoryMappedInputStreamClass)) - -typedef struct _GArrowMemoryMappedInputStream GArrowMemoryMappedInputStream; -#ifndef __GTK_DOC_IGNORE__ -typedef struct _GArrowMemoryMappedInputStreamClass GArrowMemoryMappedInputStreamClass; -#endif - -/** - * GArrowMemoryMappedInputStream: - * - * It wraps `arrow::io::MemoryMappedFile`. - */ -struct _GArrowMemoryMappedInputStream +#define GARROW_TYPE_FILE_INPUT_STREAM (garrow_file_input_stream_get_type()) +G_DECLARE_DERIVABLE_TYPE(GArrowFileInputStream, + garrow_file_input_stream, + GARROW, + FILE_INPUT_STREAM, + GArrowSeekableInputStream) +struct _GArrowFileInputStreamClass { - /*< private >*/ - GArrowSeekableInputStream parent_instance; + GArrowSeekableInputStreamClass parent_class; }; -#ifndef __GTK_DOC_IGNORE__ +GArrowFileInputStream * +garrow_file_input_stream_new(const gchar *path, + GError **error); +GArrowFileInputStream * +garrow_file_input_stream_new_file_descriptor(gint file_descriptor, + GError **error); +gint +garrow_file_input_stream_get_file_descriptor(GArrowFileInputStream *stream); + + +#define GARROW_TYPE_MEMORY_MAPPED_INPUT_STREAM \ + (garrow_memory_mapped_input_stream_get_type()) +G_DECLARE_DERIVABLE_TYPE(GArrowMemoryMappedInputStream, + garrow_memory_mapped_input_stream, + GARROW, + MEMORY_MAPPED_INPUT_STREAM, + GArrowSeekableInputStream) struct _GArrowMemoryMappedInputStreamClass { GArrowSeekableInputStreamClass parent_class; }; -#endif -GType garrow_memory_mapped_input_stream_get_type(void) G_GNUC_CONST; - -GArrowMemoryMappedInputStream *garrow_memory_mapped_input_stream_new(const gchar *path, - GError **error); +GArrowMemoryMappedInputStream * +garrow_memory_mapped_input_stream_new(const gchar *path, + GError **error); #define GARROW_TYPE_GIO_INPUT_STREAM \ diff --git a/c_glib/arrow-glib/input-stream.hpp b/c_glib/arrow-glib/input-stream.hpp index 88fbb8f64c1..2a0a3d3ddcc 100644 --- a/c_glib/arrow-glib/input-stream.hpp +++ b/c_glib/arrow-glib/input-stream.hpp @@ -40,7 +40,16 @@ garrow_buffer_input_stream_new_raw(std::shared_ptr<arrow::io::BufferReader> *arr GArrowBuffer *buffer); std::shared_ptr<arrow::io::BufferReader> garrow_buffer_input_stream_get_raw(GArrowBufferInputStream *input_stream); -GArrowMemoryMappedInputStream *garrow_memory_mapped_input_stream_new_raw(std::shared_ptr<arrow::io::MemoryMappedFile> *arrow_memory_mapped_file); + +GArrowFileInputStream * +garrow_file_input_stream_new_raw( + std::shared_ptr<arrow::io::ReadableFile> *arrow_stream); + + +GArrowMemoryMappedInputStream * +garrow_memory_mapped_input_stream_new_raw( + std::shared_ptr<arrow::io::MemoryMappedFile> *arrow_stream); + GArrowCompressedInputStream * garrow_compressed_input_stream_new_raw(std::shared_ptr<arrow::io::CompressedInputStream> *arrow_raw, diff --git a/c_glib/arrow-glib/reader.cpp b/c_glib/arrow-glib/reader.cpp index 3e6539feb5a..98038248050 100644 --- a/c_glib/arrow-glib/reader.cpp +++ b/c_glib/arrow-glib/reader.cpp @@ -1360,10 +1360,13 @@
garrow_csv_read_options_set_null_values(GArrowCSVReadOptions *options, * garrow_csv_read_options_get_null_values: * @options: A #GArrowCSVReadOptions. * - * Return: (nullable) (array zero-terminated=1) (element-type utf8) (transfer full): - * The values to be processed as null. It's a %NULL-terminated string array. + * Returns: (nullable) (array zero-terminated=1) (element-type utf8) (transfer full): + * The values to be processed as null. + * * If the number of values is zero, this returns %NULL. - * It must be freed with g_strfreev() when no longer needed. + * + * It's a %NULL-terminated string array. It must be freed with + * g_strfreev() when no longer needed. * * Since: 0.14.0 */ @@ -1425,10 +1428,13 @@ garrow_csv_read_options_set_true_values(GArrowCSVReadOptions *options, * garrow_csv_read_options_get_true_values: * @options: A #GArrowCSVReadOptions. * - * Return: (nullable) (array zero-terminated=1) (element-type utf8) (transfer full): - * The values to be processed as true. It's a %NULL-terminated string array. + * Returns: (nullable) (array zero-terminated=1) (element-type utf8) (transfer full): + * The values to be processed as true. + * * If the number of values is zero, this returns %NULL. - * It must be freed with g_strfreev() when no longer needed. + * + * It's a %NULL-terminated string array. It must be freed with + * g_strfreev() when no longer needed. * * Since: 0.14.0 */ @@ -1490,10 +1496,13 @@ garrow_csv_read_options_set_false_values(GArrowCSVReadOptions *options, * garrow_csv_read_options_get_false_values: * @options: A #GArrowCSVReadOptions. * - * Return: (nullable) (array zero-terminated=1) (element-type utf8) (transfer full): - * The values to be processed as false. It's a %NULL-terminated string array. + * Returns: (nullable) (array zero-terminated=1) (element-type utf8) (transfer full): + * The values to be processed as false. + * * If the number of values is zero, this returns %NULL. - * It must be freed with g_strfreev() when no longer needed. + * + * It's a %NULL-terminated string array. It must be freed with + * g_strfreev() when no longer needed. * * Since: 0.14.0 */ @@ -1556,10 +1565,13 @@ garrow_csv_read_options_set_column_names(GArrowCSVReadOptions *options, * garrow_csv_read_options_get_column_names: * @options: A #GArrowCSVReadOptions. * - * Return: (nullable) (array zero-terminated=1) (element-type utf8) (transfer full): - * The column names. It's a %NULL-terminated string array. + * Returns: (nullable) (array zero-terminated=1) (element-type utf8) (transfer full): + * The column names. + * * If the number of values is zero, this returns %NULL. - * It must be freed with g_strfreev() when no longer needed. + * + * It's a %NULL-terminated string array. It must be freed with + * g_strfreev() when no longer needed. 
* * Since: 0.15.0 */ diff --git a/c_glib/doc/arrow-dataset-glib/arrow-dataset-glib-docs.xml b/c_glib/doc/arrow-dataset-glib/arrow-dataset-glib-docs.xml index 3e8da5bd9d1..b13195b0703 100644 --- a/c_glib/doc/arrow-dataset-glib/arrow-dataset-glib-docs.xml +++ b/c_glib/doc/arrow-dataset-glib/arrow-dataset-glib-docs.xml @@ -39,6 +39,8 @@ Data + Partitioning + Dataset Dataset factory @@ -66,8 +68,12 @@ Index of deprecated API + + Index of new symbols in 6.0.0 + + - Index of new symbols in 4.0.0 + Index of new symbols in 5.0.0 diff --git a/c_glib/test/dataset/test-file-system-dataset.rb b/c_glib/test/dataset/test-file-system-dataset.rb index 6d6ec3b18c6..0e856b678f8 100644 --- a/c_glib/test/dataset/test-file-system-dataset.rb +++ b/c_glib/test/dataset/test-file-system-dataset.rb @@ -16,19 +16,73 @@ # under the License. class TestDatasetFileSystemDataset < Test::Unit::TestCase + include Helper::Buildable + include Helper::Readable + def setup omit("Arrow Dataset is required") unless defined?(ArrowDataset) Dir.mktmpdir do |tmpdir| @dir = tmpdir - format = ArrowDataset::IPCFileFormat.new - factory = ArrowDataset::FileSystemDatasetFactory.new(format) - factory.file_system = Arrow::LocalFileSystem.new - @dataset = factory.finish + @format = ArrowDataset::IPCFileFormat.new + @factory = ArrowDataset::FileSystemDatasetFactory.new(@format) + @file_system = Arrow::LocalFileSystem.new + @factory.file_system = @file_system + partitioning_schema = build_schema(label: Arrow::StringDataType.new) + @partitioning = + ArrowDataset::DirectoryPartitioning.new(partitioning_schema) + @factory.partitioning = @partitioning yield end end def test_type_name - assert_equal("filesystem", @dataset.type_name) + dataset = @factory.finish + assert_equal("filesystem", dataset.type_name) + end + + def test_format + dataset = @factory.finish + assert_equal(@format, dataset.format) + end + + def test_file_system + dataset = @factory.finish + assert_equal(@file_system, dataset.file_system) + end + + def test_partitioning + dataset = @factory.finish + assert_equal(@partitioning, dataset.partitioning) + end + + def test_read_write + table = build_table(label: build_string_array(["a", "a", "b", "c"]), + count: build_int32_array([1, 10, 2, 3])) + table_reader = Arrow::TableBatchReader.new(table) + scanner_builder = ArrowDataset::ScannerBuilder.new(table_reader) + scanner = scanner_builder.finish + options = ArrowDataset::FileSystemDatasetWriteOptions.new + options.file_write_options = @format.default_write_options + options.file_system = Arrow::LocalFileSystem.new + options.base_dir = @dir + options.base_name_template = "{i}.arrow" + options.partitioning = @partitioning + ArrowDataset::FileSystemDataset.write_scanner(scanner, options) + Find.find(@dir) do |path| + @factory.add_path(path) if File.file?(path) + end + @factory.partition_base_dir = @dir + dataset = @factory.finish + assert_equal(build_table(count: [ + build_int32_array([1, 10]), + build_int32_array([2]), + build_int32_array([3]), + ], + label: [ + build_string_array(["a", "a"]), + build_string_array(["b"]), + build_string_array(["c"]), + ]), + dataset.to_table) end end diff --git a/c_glib/test/dataset/test-file-writer.rb b/c_glib/test/dataset/test-file-writer.rb new file mode 100644 index 00000000000..5b25d6044d6 --- /dev/null +++ b/c_glib/test/dataset/test-file-writer.rb @@ -0,0 +1,65 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +class TestDatasetFileWriter < Test::Unit::TestCase + include Helper::Buildable + include Helper::Readable + + def setup + omit("Arrow Dataset is required") unless defined?(ArrowDataset) + Dir.mktmpdir do |tmpdir| + @dir = tmpdir + @format = ArrowDataset::IPCFileFormat.new + @file_system = Arrow::LocalFileSystem.new + @path = File.join(@dir, "data.arrow") + @output = @file_system.open_output_stream(@path) + @schema = build_schema(visible: Arrow::BooleanDataType.new, + point: Arrow::UInt8DataType.new) + @writer = @format.open_writer(@output, + @file_system, + @path, + @schema, + @format.default_write_options) + yield + end + end + + def test_write_record_batch + record_batch = build_record_batch( + visible: build_boolean_array([true, false, true]), + point: build_uint8_array([1, 2, 3])) + @writer.write_record_batch(record_batch) + @writer.finish + @output.close + read_table(@path) do |written_table| + assert_equal(Arrow::Table.new(record_batch.schema, + [record_batch]), + written_table) + end + end + + def test_write_record_batch_reader + table = build_table(visible: build_boolean_array([true, false, true]), + point: build_uint8_array([1, 2, 3])) + @writer.write_record_batch_reader(Arrow::TableBatchReader.new(table)) + @writer.finish + @output.close + read_table(@path) do |written_table| + assert_equal(table, written_table) + end + end +end diff --git a/c_glib/test/dataset/test-partitioning-options.rb b/c_glib/test/dataset/test-partitioning-options.rb new file mode 100644 index 00000000000..9ff585aa7cf --- /dev/null +++ b/c_glib/test/dataset/test-partitioning-options.rb @@ -0,0 +1,46 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +class TestDatasetPartitioningOptions < Test::Unit::TestCase + include Helper::Buildable + + def setup + omit("Arrow Dataset is required") unless defined?(ArrowDataset) + @options = ArrowDataset::PartitioningOptions.new + end + + def test_infer_dictionary + assert_false(@options.infer_dictionary?) + @options.infer_dictionary = true + assert_true(@options.infer_dictionary?) 
+ end + + def test_schema + assert_nil(@options.schema) + schema = build_schema(year: Arrow::UInt16DataType.new) + @options.schema = schema + assert_equal(schema, @options.schema) + end + + def test_segment_encoding + assert_equal(ArrowDataset::SegmentEncoding::NONE, + @options.segment_encoding) + @options.segment_encoding = :uri + assert_equal(ArrowDataset::SegmentEncoding::URI, + @options.segment_encoding) + end +end diff --git a/c_glib/test/dataset/test-partitioning.rb b/c_glib/test/dataset/test-partitioning.rb new file mode 100644 index 00000000000..d98e51f3c59 --- /dev/null +++ b/c_glib/test/dataset/test-partitioning.rb @@ -0,0 +1,34 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +class TestDatasetPartitioning < Test::Unit::TestCase + include Helper::Buildable + + def setup + omit("Arrow Dataset is required") unless defined?(ArrowDataset) + end + + def test_default + assert_equal("default", ArrowDataset::Partitioning.new.type_name) + end + + def test_directory + schema = build_schema(year: Arrow::UInt16DataType.new) + partitioning = ArrowDataset::DirectoryPartitioning.new(schema) + assert_equal("schema", partitioning.type_name) + end +end diff --git a/c_glib/test/dataset/test-scanner.rb b/c_glib/test/dataset/test-scanner.rb index f7702d4905f..ed6a706c6f2 100644 --- a/c_glib/test/dataset/test-scanner.rb +++ b/c_glib/test/dataset/test-scanner.rb @@ -45,4 +45,11 @@ def setup def test_to_table assert_equal(@table, @scanner.to_table) end + + def test_new_record_batch_reader + reader = Arrow::TableBatchReader.new(@table) + builder = ArrowDataset::ScannerBuilder.new(reader) + scanner = builder.finish + assert_equal(@table, scanner.to_table) + end end diff --git a/c_glib/test/helper/buildable.rb b/c_glib/test/helper/buildable.rb index 356fa651c6a..3a1240cfa1f 100644 --- a/c_glib/test/helper/buildable.rb +++ b/c_glib/test/helper/buildable.rb @@ -17,6 +17,13 @@ module Helper module Buildable + def build_schema(fields) + fields = fields.collect do |name, data_type| + Arrow::Field.new(name, data_type) + end + Arrow::Schema.new(fields) + end + def build_null_array(values) build_array(Arrow::NullArrayBuilder.new, values) end diff --git a/c_glib/test/helper/readable.rb b/c_glib/test/helper/readable.rb new file mode 100644 index 00000000000..81bf0795c6b --- /dev/null +++ b/c_glib/test/helper/readable.rb @@ -0,0 +1,47 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +module Helper + module Readable + def read_table(input, type: :file) + if input.is_a?(Arrow::Buffer) + input_stream = Arrow::BufferInputStream.new(input) + else + input_stream = Arrow::FileInputStream.new(input) + end + begin + if type == :file + reader = Arrow::RecordBatchFileReader.new(input_stream) + record_batches = [] + reader.n_record_batches.times do |i| + record_batches << reader.read_record_batch(i) + end + yield(Arrow::Table.new(record_batches[0].schema, record_batches)) + else + reader = Arrow::RecordBatchStreamReader.new(input_stream) + begin + yield(reader.read_all) + ensure + reader.close + end + end + ensure + input_stream.close + end + end + end +end diff --git a/c_glib/test/run-test.rb b/c_glib/test/run-test.rb index abae4e722c5..621c78c3986 100755 --- a/c_glib/test/run-test.rb +++ b/c_glib/test/run-test.rb @@ -84,6 +84,7 @@ def should_unlock_gvl?(info, klass) end require "fileutils" +require "find" require "rbconfig" require "stringio" require "tempfile" @@ -97,6 +98,7 @@ def should_unlock_gvl?(info, klass) end require_relative "helper/omittable" require_relative "helper/plasma-store" +require_relative "helper/readable" require_relative "helper/writable" exit(Test::Unit::AutoRunner.run(true, test_dir.to_s)) diff --git a/c_glib/test/test-file-input-stream.rb b/c_glib/test/test-file-input-stream.rb new file mode 100644 index 00000000000..2b43f97f5dd --- /dev/null +++ b/c_glib/test/test-file-input-stream.rb @@ -0,0 +1,102 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +class TestFileInputStream < Test::Unit::TestCase + def setup + @data = "Hello World" + @tempfile = Tempfile.open("arrow-file-input-stream") + @tempfile.write(@data) + @tempfile.close + end + + def test_new + input = Arrow::FileInputStream.new(@tempfile.path) + begin + buffer = input.read(5) + assert_equal("Hello", buffer.data.to_s) + ensure + input.close + end + end + + def test_close + input = Arrow::FileInputStream.new(@tempfile.path) + assert do + not input.closed? + end + input.close + assert do + input.closed?
+ end + end + + def test_size + input = Arrow::FileInputStream.new(@tempfile.path) + begin + assert_equal(@data.bytesize, input.size) + ensure + input.close + end + end + + def test_read + input = Arrow::FileInputStream.new(@tempfile.path) + begin + buffer = input.read(5) + assert_equal("Hello", buffer.data.to_s) + ensure + input.close + end + end + + def test_read_at + input = Arrow::FileInputStream.new(@tempfile.path) + begin + buffer = input.read_at(6, 5) + assert_equal("World", buffer.data.to_s) + ensure + input.close + end + end + + def test_mode + input = Arrow::FileInputStream.new(@tempfile.path) + begin + assert_equal(Arrow::FileMode::READ, input.mode) + ensure + input.close + end + end + + def test_file_descriptor + @tempfile.open + begin + fd = @tempfile.fileno + input = Arrow::FileInputStream.new(fd) + begin + assert_equal(fd, input.file_descriptor) + ensure + input.close + end + ensure + begin + @tempfile.close + rescue + end + end + end +end diff --git a/c_glib/test/test-function-doc.rb b/c_glib/test/test-function-doc.rb new file mode 100644 index 00000000000..7e624a5ab7c --- /dev/null +++ b/c_glib/test/test-function-doc.rb @@ -0,0 +1,43 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +class TestFunctionDoc < Test::Unit::TestCase + def setup + @doc = Arrow::Function.find("or").doc + end + + def test_summary + assert_equal("Logical 'or' boolean values", + @doc.summary) + end + + def test_description + assert_equal(<<-DESCRIPTION.chomp, @doc.description) +When a null is encountered in either input, a null is output. +For a different null behavior, see function "or_kleene". 
+ DESCRIPTION + end + + def test_arg_names + assert_equal(["x", "y"], @doc.arg_names) + end + + def test_options_class_name + doc = Arrow::Function.find("cast").doc + assert_equal("CastOptions", doc.options_class_name) + end +end diff --git a/c_glib/test/test-is-in.rb b/c_glib/test/test-is-in.rb index ba44075d6b3..590b5e3798a 100644 --- a/c_glib/test/test-is-in.rb +++ b/c_glib/test/test-is-in.rb @@ -46,6 +46,16 @@ def test_null_in_both assert_equal(build_boolean_array([false, true, true, true]), left.is_in(right)) end + + def test_options + left = build_int16_array([1, 0, nil, 2]) + right = build_int16_array([2, 0, nil]) + is_in = Arrow::Function.find("is_in") + options = Arrow::SetLookupOptions.new(Arrow::ArrayDatum.new(right)) + assert_equal(build_boolean_array([false, true, true, true]), + is_in.execute([Arrow::ArrayDatum.new(left)], + options).value) + end end sub_test_case("ChunkedArray") do @@ -92,5 +102,19 @@ def test_null_in_both assert_equal(build_boolean_array([false, true, true, true]), left.is_in_chunked_array(right)) end + + def test_options + left = build_int16_array([1, 0, nil, 2]) + chunks = [ + build_int16_array([2, 0]), + build_int16_array([3, nil]) + ] + right = Arrow::ChunkedArray.new(chunks) + is_in = Arrow::Function.find("is_in") + options = Arrow::SetLookupOptions.new(Arrow::ChunkedArrayDatum.new(right)) + assert_equal(build_boolean_array([false, true, true, true]), + is_in.execute([Arrow::ArrayDatum.new(left)], + options).value) + end end end diff --git a/c_glib/test/test-set-lookup-options.rb b/c_glib/test/test-set-lookup-options.rb new file mode 100644 index 00000000000..779bacef683 --- /dev/null +++ b/c_glib/test/test-set-lookup-options.rb @@ -0,0 +1,43 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +class TestSetLookupOptions < Test::Unit::TestCase + include Helper::Buildable + + def test_new + value_set = Arrow::ArrayDatum.new(build_int8_array([1, 2, 3])) + options = Arrow::SetLookupOptions.new(value_set) + assert_equal(value_set, options.value_set) + end + + sub_test_case("instance methods") do + def setup + value_set = Arrow::ArrayDatum.new(build_int8_array([1, 2, 3])) + @options = Arrow::SetLookupOptions.new(value_set) + end + + def test_skip_nulls + assert do + not @options.skip_nulls? + end + @options.skip_nulls = true + assert do + @options.skip_nulls? + end + end + end +end diff --git a/c_glib/test/test-variance-options.rb b/c_glib/test/test-variance-options.rb new file mode 100644 index 00000000000..64bdf670bf0 --- /dev/null +++ b/c_glib/test/test-variance-options.rb @@ -0,0 +1,46 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +class TestVarianceOptions < Test::Unit::TestCase + include Helper::Buildable + + def setup + @options = Arrow::VarianceOptions.new + end + + def test_ddof + assert_equal(0, @options.ddof) + @options.ddof = 1 + assert_equal(1, @options.ddof) + end + + def test_skip_nulls + assert do + @options.skip_nulls? + end + @options.skip_nulls = false + assert do + not @options.skip_nulls? + end + end + + def test_min_count + assert_equal(0, @options.min_count) + @options.min_count = 1 + assert_equal(1, @options.min_count) + end +end diff --git a/ci/appveyor-cpp-setup.bat b/ci/appveyor-cpp-setup.bat index 47448ce0887..3bf01ec12dd 100644 --- a/ci/appveyor-cpp-setup.bat +++ b/ci/appveyor-cpp-setup.bat @@ -70,6 +70,12 @@ if "%JOB%" NEQ "Build_Debug" ( "fsspec" ^ "python=%PYTHON%" ^ || exit /B + + @rem On Windows, GTest is always bundled from source instead of using + @rem conda binaries, avoid any interference between the two versions. + if "%JOB%" == "Toolchain" ( + conda uninstall -n arrow -q -y -c conda-forge gtest + ) ) @rem diff --git a/ci/scripts/go_test.sh b/ci/scripts/go_test.sh index 7dd873df3e1..9b2572e1b43 100755 --- a/ci/scripts/go_test.sh +++ b/ci/scripts/go_test.sh @@ -21,10 +21,18 @@ set -ex source_dir=${1}/go +testargs="-race" +case "$(uname)" in + MINGW*) + # -race doesn't work on windows currently + testargs="" + ;; +esac + pushd ${source_dir}/arrow for d in $(go list ./... | grep -v vendor); do - go test $d + go test $testargs -tags "test" $d done popd @@ -32,7 +40,7 @@ popd pushd ${source_dir}/parquet for d in $(go list ./... | grep -v vendor); do - go test $d + go test $testargs $d done popd diff --git a/ci/scripts/python_wheel_macos_build.sh b/ci/scripts/python_wheel_macos_build.sh index 82e0339c9d0..1a52a2ad52b 100755 --- a/ci/scripts/python_wheel_macos_build.sh +++ b/ci/scripts/python_wheel_macos_build.sh @@ -53,6 +53,7 @@ export PIP_SITE_PACKAGES=$(python -c 'import site; print(site.getsitepackages()[ export PIP_TARGET_PLATFORM="macosx_${MACOSX_DEPLOYMENT_TARGET//./_}_${arch}" pip install \ + --upgrade \ --only-binary=:all: \ --target $PIP_SITE_PACKAGES \ --platform $PIP_TARGET_PLATFORM \ diff --git a/ci/scripts/r_docker_configure.sh b/ci/scripts/r_docker_configure.sh index 2b9bc03bea0..d138d030eca 100755 --- a/ci/scripts/r_docker_configure.sh +++ b/ci/scripts/r_docker_configure.sh @@ -37,7 +37,7 @@ if [ "$RHUB_PLATFORM" = "linux-x86_64-fedora-clang" ]; then dnf install -y libcxx-devel sed -i.bak -E -e 's/(CXX1?1? 
=.*)/\1 -stdlib=libc++/g' $(${R_BIN} RHOME)/etc/Makeconf rm -rf $(${R_BIN} RHOME)/etc/Makeconf.bak - + sed -i.bak -E -e 's/(CXXFLAGS = )(.*)/\1 -g -O3 -Wall -pedantic -frtti -fPIC/' $(${R_BIN} RHOME)/etc/Makeconf rm -rf $(${R_BIN} RHOME)/etc/Makeconf.bak fi @@ -75,9 +75,3 @@ fi # Workaround for html help install failure; see https://github.com/r-lib/devtools/issues/2084#issuecomment-530912786 Rscript -e 'x <- file.path(R.home("doc"), "html"); if (!file.exists(x)) {dir.create(x, recursive=TRUE); file.copy(system.file("html/R.css", package="stats"), x)}' - -if [ "`which curl`" ]; then - # We need this on R >= 4.0 - curl -L https://sourceforge.net/projects/checkbaskisms/files/2.0.0.2/checkbashisms/download > /usr/local/bin/checkbashisms - chmod 755 /usr/local/bin/checkbashisms -fi diff --git a/ci/scripts/r_windows_build.sh b/ci/scripts/r_windows_build.sh index 47120eef433..8a96b3f5e79 100755 --- a/ci/scripts/r_windows_build.sh +++ b/ci/scripts/r_windows_build.sh @@ -92,10 +92,10 @@ cp $MSYS_LIB_DIR/mingw32/lib/lib{zstd,lz4,crypto,utf8proc,re2,aws*}.a $DST_DIR/l # Do the same also for ucrt64 if [ "$RTOOLS_VERSION" != "35" ]; then -ls $MSYS_LIB_DIR/ucrt64/lib/ -mkdir -p $DST_DIR/lib/x64-ucrt -mv ucrt64/lib/*.a $DST_DIR/${RWINLIB_LIB_DIR}/x64-ucrt -cp $MSYS_LIB_DIR/ucrt64/lib/lib{zstd,lz4,crypto,utf8proc,re2,aws*}.a $DST_DIR/lib/x64-ucrt + ls $MSYS_LIB_DIR/ucrt64/lib/ + mkdir -p $DST_DIR/lib/x64-ucrt + mv ucrt64/lib/*.a $DST_DIR/lib/x64-ucrt + cp $MSYS_LIB_DIR/ucrt64/lib/lib{thrift,snappy,zstd,lz4,crypto,utf8proc,re2,aws*}.a $DST_DIR/lib/x64-ucrt fi # Create build artifact diff --git a/cpp/build-support/fuzzing/generate_corpuses.sh b/cpp/build-support/fuzzing/generate_corpuses.sh index f0d8e162375..e3f00e64782 100755 --- a/cpp/build-support/fuzzing/generate_corpuses.sh +++ b/cpp/build-support/fuzzing/generate_corpuses.sh @@ -27,15 +27,21 @@ fi set -ex CORPUS_DIR=/tmp/corpus -ARROW_CPP=$(cd $(dirname $BASH_SOURCE)/../..; pwd) +ARROW_ROOT=$(cd $(dirname $BASH_SOURCE)/../../..; pwd) +ARROW_CPP=$ARROW_ROOT/cpp OUT=$1 # NOTE: name of seed corpus output file should be "-seed_corpus.zip" # where "" is the exact name of the fuzz target executable the # seed corpus is generated for. +IPC_INTEGRATION_FILES=$(find ${ARROW_ROOT}/testing/data/arrow-ipc-stream/integration -name "*.stream") + rm -rf ${CORPUS_DIR} ${OUT}/arrow-ipc-generate-fuzz-corpus -stream ${CORPUS_DIR} +# Several IPC integration files can have the same name, make sure +# they all appear in the corpus by numbering the duplicates. 
+cp --backup=numbered ${IPC_INTEGRATION_FILES} ${CORPUS_DIR}
${ARROW_CPP}/build-support/fuzzing/pack_corpus.py ${CORPUS_DIR} ${OUT}/arrow-ipc-stream-fuzz_seed_corpus.zip
rm -rf ${CORPUS_DIR}
@@ -48,5 +54,6 @@ ${ARROW_CPP}/build-support/fuzzing/pack_corpus.py ${CORPUS_DIR} ${OUT}/arrow-ipc
rm -rf ${CORPUS_DIR}
${OUT}/parquet-arrow-generate-fuzz-corpus ${CORPUS_DIR}
+# Add Parquet testing examples
cp ${ARROW_CPP}/submodules/parquet-testing/data/*.parquet ${CORPUS_DIR}
${ARROW_CPP}/build-support/fuzzing/pack_corpus.py ${CORPUS_DIR} ${OUT}/parquet-arrow-fuzz_seed_corpus.zip
diff --git a/cpp/src/arrow/CMakeLists.txt b/cpp/src/arrow/CMakeLists.txt
index 637f3d1a54f..e06fad9a1de 100644
--- a/cpp/src/arrow/CMakeLists.txt
+++ b/cpp/src/arrow/CMakeLists.txt
@@ -190,6 +190,7 @@ set(ARROW_SRCS
   io/slow.cc
   io/stdio.cc
   io/transform.cc
+  util/async_util.cc
   util/basic_decimal.cc
   util/bit_block_counter.cc
   util/bit_run_reader.cc
diff --git a/cpp/src/arrow/adapters/orc/adapter.cc b/cpp/src/arrow/adapters/orc/adapter.cc
index 2f74b40e40d..94a3b6e882a 100644
--- a/cpp/src/arrow/adapters/orc/adapter.cc
+++ b/cpp/src/arrow/adapters/orc/adapter.cc
@@ -430,10 +430,14 @@ ORCFileReader::~ORCFileReader() {}
 Status ORCFileReader::Open(const std::shared_ptr<io::RandomAccessFile>& file,
                            MemoryPool* pool, std::unique_ptr<ORCFileReader>* reader) {
+  return Open(file, pool).Value(reader);
+}
+
+Result<std::unique_ptr<ORCFileReader>> ORCFileReader::Open(
+    const std::shared_ptr<io::RandomAccessFile>& file, MemoryPool* pool) {
   auto result = std::unique_ptr<ORCFileReader>(new ORCFileReader());
   RETURN_NOT_OK(result->impl_->Open(file, pool));
-  *reader = std::move(result);
-  return Status::OK();
+  return std::move(result);
 }
 
 Result<std::shared_ptr<const KeyValueMetadata>> ORCFileReader::ReadMetadata() {
@@ -444,33 +448,79 @@
 Status ORCFileReader::ReadSchema(std::shared_ptr<Schema>* out) {
   return impl_->ReadSchema(out);
 }
 
+Result<std::shared_ptr<Schema>> ORCFileReader::ReadSchema() {
+  std::shared_ptr<Schema> schema;
+  RETURN_NOT_OK(impl_->ReadSchema(&schema));
+  return schema;
+}
+
 Status ORCFileReader::Read(std::shared_ptr<Table>* out) { return impl_->Read(out); }
 
+Result<std::shared_ptr<Table>> ORCFileReader::Read() {
+  std::shared_ptr<Table> table;
+  RETURN_NOT_OK(impl_->Read(&table));
+  return table;
+}
+
 Status ORCFileReader::Read(const std::shared_ptr<Schema>& schema,
                            std::shared_ptr<Table>* out) {
   return impl_->Read(schema, out);
 }
 
+Result<std::shared_ptr<Table>> ORCFileReader::Read(
+    const std::shared_ptr<Schema>& schema) {
+  std::shared_ptr<Table> table;
+  RETURN_NOT_OK(impl_->Read(schema, &table));
+  return table;
+}
+
 Status ORCFileReader::Read(const std::vector<int>& include_indices,
                            std::shared_ptr<Table>* out) {
   return impl_->Read(include_indices, out);
 }
 
+Result<std::shared_ptr<Table>> ORCFileReader::Read(
+    const std::vector<int>& include_indices) {
+  std::shared_ptr<Table> table;
+  RETURN_NOT_OK(impl_->Read(include_indices, &table));
+  return table;
+}
+
 Status ORCFileReader::Read(const std::shared_ptr<Schema>& schema,
                            const std::vector<int>& include_indices,
                            std::shared_ptr<Table>* out) {
   return impl_->Read(schema, include_indices, out);
 }
 
+Result<std::shared_ptr<Table>> ORCFileReader::Read(
+    const std::shared_ptr<Schema>& schema, const std::vector<int>& include_indices) {
+  std::shared_ptr<Table> table;
+  RETURN_NOT_OK(impl_->Read(schema, include_indices, &table));
+  return table;
+}
+
 Status ORCFileReader::ReadStripe(int64_t stripe, std::shared_ptr<RecordBatch>* out) {
   return impl_->ReadStripe(stripe, out);
 }
 
+Result<std::shared_ptr<RecordBatch>> ORCFileReader::ReadStripe(int64_t stripe) {
+  std::shared_ptr<RecordBatch> recordBatch;
+  RETURN_NOT_OK(impl_->ReadStripe(stripe, &recordBatch));
+  return recordBatch;
+}
+
 Status ORCFileReader::ReadStripe(int64_t stripe, const std::vector<int>& include_indices,
                                  std::shared_ptr<RecordBatch>* out) {
   return impl_->ReadStripe(stripe, include_indices, out);
 }
 
+Result<std::shared_ptr<RecordBatch>> ORCFileReader::ReadStripe(
+    int64_t stripe, const std::vector<int>& include_indices) {
+  std::shared_ptr<RecordBatch> recordBatch;
+  RETURN_NOT_OK(impl_->ReadStripe(stripe, include_indices, &recordBatch));
+  return recordBatch;
+}
+
 Status ORCFileReader::Seek(int64_t row_number) { return impl_->Seek(row_number); }
 
 Status ORCFileReader::NextStripeReader(int64_t batch_sizes,
@@ -478,12 +528,26 @@ Status ORCFileReader::NextStripeReader(int64_t batch_sizes,
   return impl_->NextStripeReader(batch_sizes, out);
 }
 
+Result<std::shared_ptr<RecordBatchReader>> ORCFileReader::NextStripeReader(
+    int64_t batch_size) {
+  std::shared_ptr<RecordBatchReader> reader;
+  RETURN_NOT_OK(impl_->NextStripeReader(batch_size, &reader));
+  return reader;
+}
+
 Status ORCFileReader::NextStripeReader(int64_t batch_size,
                                        const std::vector<int>& include_indices,
                                        std::shared_ptr<RecordBatchReader>* out) {
   return impl_->NextStripeReader(batch_size, include_indices, out);
 }
 
+Result<std::shared_ptr<RecordBatchReader>> ORCFileReader::NextStripeReader(
+    int64_t batch_size, const std::vector<int>& include_indices) {
+  std::shared_ptr<RecordBatchReader> reader;
+  RETURN_NOT_OK(impl_->NextStripeReader(batch_size, include_indices, &reader));
+  return reader;
+}
+
 int64_t ORCFileReader::NumberOfStripes() { return impl_->NumberOfStripes(); }
 
 int64_t ORCFileReader::NumberOfRows() { return impl_->NumberOfRows(); }
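To make the new surface concrete, here is a minimal, hypothetical usage sketch (not part of the diff itself; the file path and helper name are illustrative) contrasting the Result-returning overloads added above with the deprecated Status-plus-out-parameter style they replace:

#include <memory>
#include <string>
#include "arrow/adapters/orc/adapter.h"
#include "arrow/io/file.h"
#include "arrow/memory_pool.h"
#include "arrow/result.h"
#include "arrow/status.h"
#include "arrow/table.h"

arrow::Status ReadOrcFileExample(const std::string& path) {
  ARROW_ASSIGN_OR_RAISE(auto file, arrow::io::ReadableFile::Open(path));
  // New style: the reader and the table both arrive through arrow::Result,
  // so errors propagate via ARROW_ASSIGN_OR_RAISE instead of out-parameters.
  ARROW_ASSIGN_OR_RAISE(auto reader, arrow::adapters::orc::ORCFileReader::Open(
                                         file, arrow::default_memory_pool()));
  ARROW_ASSIGN_OR_RAISE(std::shared_ptr<arrow::Table> table, reader->Read());
  // Old style, now behind ARROW_DEPRECATED and implemented as shims over the
  // Result overloads:
  //   std::unique_ptr<arrow::adapters::orc::ORCFileReader> reader;
  //   ARROW_RETURN_NOT_OK(
  //       arrow::adapters::orc::ORCFileReader::Open(file, pool, &reader));
  return arrow::Status::OK();
}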
diff --git a/cpp/src/arrow/adapters/orc/adapter.h b/cpp/src/arrow/adapters/orc/adapter.h
index 012c1701980..036795188f6 100644
--- a/cpp/src/arrow/adapters/orc/adapter.h
+++ b/cpp/src/arrow/adapters/orc/adapter.h
@@ -27,6 +27,7 @@
 #include "arrow/status.h"
 #include "arrow/type.h"
 #include "arrow/type_fwd.h"
+#include "arrow/util/macros.h"
 #include "arrow/util/visibility.h"
 
 namespace arrow {
@@ -45,9 +46,18 @@ class ARROW_EXPORT ORCFileReader {
   /// \param[in] pool a MemoryPool to use for buffer allocations
   /// \param[out] reader the returned reader object
   /// \return Status
+  ARROW_DEPRECATED("Deprecated in 6.0.0. Use Result-returning overload instead.")
   static Status Open(const std::shared_ptr<io::RandomAccessFile>& file,
                      MemoryPool* pool, std::unique_ptr<ORCFileReader>* reader);
 
+  /// \brief Creates a new ORC reader
+  ///
+  /// \param[in] file the data source
+  /// \param[in] pool a MemoryPool to use for buffer allocations
+  /// \return the returned reader object
+  static Result<std::unique_ptr<ORCFileReader>> Open(
+      const std::shared_ptr<io::RandomAccessFile>& file, MemoryPool* pool);
+
   /// \brief Return the metadata read from the ORC file
   ///
   /// \return A KeyValueMetadata object containing the ORC metadata
@@ -56,31 +66,63 @@ class ARROW_EXPORT ORCFileReader {
   /// \brief Return the schema read from the ORC file
   ///
   /// \param[out] out the returned Schema object
+  ARROW_DEPRECATED("Deprecated in 6.0.0. Use Result-returning overload instead.")
   Status ReadSchema(std::shared_ptr<Schema>* out);
 
+  /// \brief Return the schema read from the ORC file
+  ///
+  /// \return the returned Schema object
+  Result<std::shared_ptr<Schema>> ReadSchema();
+
   /// \brief Read the file as a Table
   ///
   /// The table will be composed of one record batch per stripe.
   ///
   /// \param[out] out the returned Table
+  ARROW_DEPRECATED("Deprecated in 6.0.0. Use Result-returning overload instead.")
   Status Read(std::shared_ptr<Table>* out);
 
+  /// \brief Read the file as a Table
+  ///
+  /// The table will be composed of one record batch per stripe.
+  ///
+  /// \return the returned Table
+  Result<std::shared_ptr<Table>> Read();
+
   /// \brief Read the file as a Table
   ///
   /// The table will be composed of one record batch per stripe.
   ///
   /// \param[in] schema the Table schema
   /// \param[out] out the returned Table
+  ARROW_DEPRECATED("Deprecated in 6.0.0. Use Result-returning overload instead.")
   Status Read(const std::shared_ptr<Schema>& schema, std::shared_ptr<Table>* out);
 
+  /// \brief Read the file as a Table
+  ///
+  /// The table will be composed of one record batch per stripe.
+  ///
+  /// \param[in] schema the Table schema
+  /// \return the returned Table
+  Result<std::shared_ptr<Table>> Read(const std::shared_ptr<Schema>& schema);
+
   /// \brief Read the file as a Table
   ///
   /// The table will be composed of one record batch per stripe.
   ///
   /// \param[in] include_indices the selected field indices to read
   /// \param[out] out the returned Table
+  ARROW_DEPRECATED("Deprecated in 6.0.0. Use Result-returning overload instead.")
   Status Read(const std::vector<int>& include_indices, std::shared_ptr<Table>* out);
 
+  /// \brief Read the file as a Table
+  ///
+  /// The table will be composed of one record batch per stripe.
+  ///
+  /// \param[in] include_indices the selected field indices to read
+  /// \return the returned Table
+  Result<std::shared_ptr<Table>> Read(const std::vector<int>& include_indices);
+
   /// \brief Read the file as a Table
   ///
   /// The table will be composed of one record batch per stripe.
@@ -88,23 +130,50 @@ class ARROW_EXPORT ORCFileReader {
   /// \param[in] schema the Table schema
   /// \param[in] include_indices the selected field indices to read
   /// \param[out] out the returned Table
+  ARROW_DEPRECATED("Deprecated in 6.0.0. Use Result-returning overload instead.")
   Status Read(const std::shared_ptr<Schema>& schema,
               const std::vector<int>& include_indices, std::shared_ptr<Table>* out);
 
+  /// \brief Read the file as a Table
+  ///
+  /// The table will be composed of one record batch per stripe.
+  ///
+  /// \param[in] schema the Table schema
+  /// \param[in] include_indices the selected field indices to read
+  /// \return the returned Table
+  Result<std::shared_ptr<Table>> Read(const std::shared_ptr<Schema>& schema,
+                                      const std::vector<int>& include_indices);
+
   /// \brief Read a single stripe as a RecordBatch
   ///
   /// \param[in] stripe the stripe index
   /// \param[out] out the returned RecordBatch
+  ARROW_DEPRECATED("Deprecated in 6.0.0. Use Result-returning overload instead.")
   Status ReadStripe(int64_t stripe, std::shared_ptr<RecordBatch>* out);
 
+  /// \brief Read a single stripe as a RecordBatch
+  ///
+  /// \param[in] stripe the stripe index
+  /// \return the returned RecordBatch
+  Result<std::shared_ptr<RecordBatch>> ReadStripe(int64_t stripe);
+
   /// \brief Read a single stripe as a RecordBatch
   ///
   /// \param[in] stripe the stripe index
   /// \param[in] include_indices the selected field indices to read
   /// \param[out] out the returned RecordBatch
+  ARROW_DEPRECATED("Deprecated in 6.0.0. Use Result-returning overload instead.")
   Status ReadStripe(int64_t stripe, const std::vector<int>& include_indices,
                     std::shared_ptr<RecordBatch>* out);
 
+  /// \brief Read a single stripe as a RecordBatch
+  ///
+  /// \param[in] stripe the stripe index
+  /// \param[in] include_indices the selected field indices to read
+  /// \return the returned RecordBatch
+  Result<std::shared_ptr<RecordBatch>> ReadStripe(
+      int64_t stripe, const std::vector<int>& include_indices);
+
   /// \brief Seek to designated row. Invoke NextStripeReader() after seek
   ///        will return stripe reader starting from designated row.
   ///
@@ -119,8 +188,19 @@
   /// \param[in] batch_size the number of rows each record batch contains in
   /// record batch iteration.
   /// \param[out] out the returned stripe reader
+  ARROW_DEPRECATED("Deprecated in 6.0.0. Use Result-returning overload instead.")
   Status NextStripeReader(int64_t batch_size, std::shared_ptr<RecordBatchReader>* out);
 
+  /// \brief Get a stripe level record batch iterator with specified row count
+  ///        in each record batch. NextStripeReader serves as a fine grain
+  ///        alternative to ReadStripe which may cause OOM issue by loading
+  ///        the whole stripes into memory.
+  ///
+  /// \param[in] batch_size the number of rows each record batch contains in
+  /// record batch iteration.
+  /// \return the returned stripe reader
+  Result<std::shared_ptr<RecordBatchReader>> NextStripeReader(int64_t batch_size);
+
   /// \brief Get a stripe level record batch iterator with specified row count
   ///        in each record batch. NextStripeReader serves as a fine grain
   ///        alternative to ReadStripe which may cause OOM issue by loading
   ///        the whole stripes into memory.
@@ -131,9 +211,23 @@
   ///
   /// \param[in] include_indices the selected field indices to read
   /// \param[out] out the returned stripe reader
+  ARROW_DEPRECATED("Deprecated in 6.0.0. Use Result-returning overload instead.")
   Status NextStripeReader(int64_t batch_size, const std::vector<int>& include_indices,
                           std::shared_ptr<RecordBatchReader>* out);
 
+  /// \brief Get a stripe level record batch iterator with specified row count
+  ///        in each record batch. NextStripeReader serves as a fine grain
+  ///        alternative to ReadStripe which may cause OOM issue by loading
+  ///        the whole stripes into memory.
+  ///
+  /// \param[in] batch_size the number of rows each record batch contains in
+  /// record batch iteration.
+  ///
+  /// \param[in] include_indices the selected field indices to read
+  /// \return the returned stripe reader
+  Result<std::shared_ptr<RecordBatchReader>> NextStripeReader(
+      int64_t batch_size, const std::vector<int>& include_indices);
+
   /// \brief The number of stripes in the file
   int64_t NumberOfStripes();
diff --git a/cpp/src/arrow/adapters/orc/adapter_test.cc b/cpp/src/arrow/adapters/orc/adapter_test.cc
index 9f7fb561362..39c66b90f6d 100644
--- a/cpp/src/arrow/adapters/orc/adapter_test.cc
+++ b/cpp/src/arrow/adapters/orc/adapter_test.cc
@@ -237,13 +237,12 @@ void AssertTableWriteReadEqual(const std::shared_ptr<Table>& input_table,
   ARROW_EXPECT_OK(writer->Close());
   EXPECT_OK_AND_ASSIGN(auto buffer, buffer_output_stream->Finish());
   std::shared_ptr<io::RandomAccessFile> in_stream(new io::BufferReader(buffer));
-  std::unique_ptr<adapters::orc::ORCFileReader> reader;
-  ARROW_EXPECT_OK(
-      adapters::orc::ORCFileReader::Open(in_stream, default_memory_pool(), &reader));
-  std::shared_ptr<Table>
actual_output_table; - ARROW_EXPECT_OK(reader->Read(&actual_output_table)); + EXPECT_OK_AND_ASSIGN( + auto reader, adapters::orc::ORCFileReader::Open(in_stream, default_memory_pool())); + EXPECT_OK_AND_ASSIGN(auto actual_output_table, reader->Read()); AssertTablesEqual(*expected_output_table, *actual_output_table, false, false); } + void AssertArrayWriteReadEqual(const std::shared_ptr& input_array, const std::shared_ptr& expected_output_array, const int64_t max_size = kDefaultSmallMemStreamSize) { @@ -323,9 +322,8 @@ TEST(TestAdapterRead, ReadIntAndStringFileMultipleStripes) { std::make_shared(reinterpret_cast(mem_stream.getData()), static_cast(mem_stream.getLength())))); - std::unique_ptr reader; - ASSERT_TRUE( - adapters::orc::ORCFileReader::Open(in_stream, default_memory_pool(), &reader).ok()); + ASSERT_OK_AND_ASSIGN( + auto reader, adapters::orc::ORCFileReader::Open(in_stream, default_memory_pool())); EXPECT_OK_AND_ASSIGN(auto metadata, reader->ReadMetadata()); auto expected_metadata = std::const_pointer_cast( @@ -334,8 +332,7 @@ TEST(TestAdapterRead, ReadIntAndStringFileMultipleStripes) { ASSERT_EQ(stripe_row_count * stripe_count, reader->NumberOfRows()); ASSERT_EQ(stripe_count, reader->NumberOfStripes()); accumulated = 0; - std::shared_ptr stripe_reader; - EXPECT_TRUE(reader->NextStripeReader(reader_batch_size, &stripe_reader).ok()); + EXPECT_OK_AND_ASSIGN(auto stripe_reader, reader->NextStripeReader(reader_batch_size)); while (stripe_reader) { std::shared_ptr record_batch; EXPECT_TRUE(stripe_reader->ReadNext(&record_batch).ok()); @@ -350,14 +347,14 @@ TEST(TestAdapterRead, ReadIntAndStringFileMultipleStripes) { } EXPECT_TRUE(stripe_reader->ReadNext(&record_batch).ok()); } - EXPECT_TRUE(reader->NextStripeReader(reader_batch_size, &stripe_reader).ok()); + EXPECT_OK_AND_ASSIGN(stripe_reader, reader->NextStripeReader(reader_batch_size)); } // test seek operation int64_t start_offset = 830; EXPECT_TRUE(reader->Seek(stripe_row_count + start_offset).ok()); - EXPECT_TRUE(reader->NextStripeReader(reader_batch_size, &stripe_reader).ok()); + EXPECT_OK_AND_ASSIGN(stripe_reader, reader->NextStripeReader(reader_batch_size)); std::shared_ptr record_batch; EXPECT_TRUE(stripe_reader->ReadNext(&record_batch).ok()); while (record_batch) { diff --git a/cpp/src/arrow/array/array_test.cc b/cpp/src/arrow/array/array_test.cc index cc45a369400..d9617c4e603 100644 --- a/cpp/src/arrow/array/array_test.cc +++ b/cpp/src/arrow/array/array_test.cc @@ -3214,4 +3214,73 @@ TEST(TestSwapEndianArrayData, MonthDayNanoInterval) { ASSERT_OK(swap_array->ValidateFull()); } +DataTypeVector SwappableTypes() { + return DataTypeVector{int8(), + int16(), + int32(), + int64(), + uint8(), + uint16(), + uint32(), + uint64(), + decimal128(19, 4), + decimal256(37, 8), + timestamp(TimeUnit::MICRO, ""), + time32(TimeUnit::SECOND), + time64(TimeUnit::NANO), + date32(), + date64(), + day_time_interval(), + month_interval(), + month_day_nano_interval(), + binary(), + utf8(), + large_binary(), + large_utf8(), + list(int16()), + large_list(int16()), + dictionary(int16(), utf8())}; +} + +TEST(TestSwapEndianArrayData, RandomData) { + random::RandomArrayGenerator rng(42); + + for (const auto& type : SwappableTypes()) { + ARROW_SCOPED_TRACE("type = ", type->ToString()); + auto arr = rng.ArrayOf(*field("", type), /*size=*/31); + ASSERT_OK_AND_ASSIGN(auto swapped_data, + ::arrow::internal::SwapEndianArrayData(arr->data())); + auto swapped = MakeArray(swapped_data); + ASSERT_OK_AND_ASSIGN(auto roundtripped_data, + 
::arrow::internal::SwapEndianArrayData(swapped_data)); + auto roundtripped = MakeArray(roundtripped_data); + ASSERT_OK(roundtripped->ValidateFull()); + + AssertArraysEqual(*arr, *roundtripped, /*verbose=*/true); + if (type->id() == Type::INT8 || type->id() == Type::UINT8) { + AssertArraysEqual(*arr, *swapped, /*verbose=*/true); + } else { + // Random generated data is unlikely to be made of byte-palindromes + ASSERT_FALSE(arr->Equals(*swapped)); + } + } +} + +TEST(TestSwapEndianArrayData, InvalidLength) { + // IPC-incoming data may be invalid, SwapEndianArrayData shouldn't crash + // by accessing memory out of bounds. + random::RandomArrayGenerator rng(42); + + for (const auto& type : SwappableTypes()) { + ARROW_SCOPED_TRACE("type = ", type->ToString()); + ASSERT_OK_AND_ASSIGN(auto arr, MakeArrayOfNull(type, 0)); + auto data = arr->data(); + // Fake length + data->length = 123456789; + ASSERT_OK_AND_ASSIGN(auto swapped_data, ::arrow::internal::SwapEndianArrayData(data)); + auto swapped = MakeArray(swapped_data); + ASSERT_RAISES(Invalid, swapped->Validate()); + } +} + } // namespace arrow diff --git a/cpp/src/arrow/array/util.cc b/cpp/src/arrow/array/util.cc index f12281155b8..232947d2c88 100644 --- a/cpp/src/arrow/array/util.cc +++ b/cpp/src/arrow/array/util.cc @@ -78,11 +78,16 @@ class ArrayDataWrapper { class ArrayDataEndianSwapper { public: - ArrayDataEndianSwapper(const std::shared_ptr& data, int64_t length) - : data_(data), length_(length) { + explicit ArrayDataEndianSwapper(const std::shared_ptr& data) : data_(data) { out_ = data->Copy(); } + // WARNING: this facility can be called on invalid Array data by the IPC reader. + // Do not rely on the advertised ArrayData length, instead use the physical + // buffer sizes to avoid accessing memory out of bounds. 
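+ // For instance, the primitive-buffer swapper below derives its element
+ // count as "in_buffer->size() / sizeof(T)" from the buffer's physical
+ // size instead of trusting data_->length, so a corrupted length field
+ // cannot cause an out-of-bounds read.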
+ // + // (If this guarantee turns out to be difficult to maintain, we should call + // Validate() instead) Status SwapType(const DataType& type) { RETURN_NOT_OK(VisitTypeInline(type, this)); RETURN_NOT_OK(SwapChildren(type.fields())); @@ -111,6 +116,7 @@ class ArrayDataEndianSwapper { auto in_data = reinterpret_cast(in_buffer->data()); ARROW_ASSIGN_OR_RAISE(auto out_buffer, AllocateBuffer(in_buffer->size())); auto out_data = reinterpret_cast(out_buffer->mutable_data()); + // NOTE: data_->length not trusted (see warning above) int64_t length = in_buffer->size() / sizeof(T); for (int64_t i = 0; i < length; i++) { out_data[i] = BitUtil::ByteSwap(in_data[i]); @@ -146,8 +152,8 @@ class ArrayDataEndianSwapper { auto data = reinterpret_cast(data_->buffers[1]->data()); ARROW_ASSIGN_OR_RAISE(auto new_buffer, AllocateBuffer(data_->buffers[1]->size())); auto new_data = reinterpret_cast(new_buffer->mutable_data()); - int64_t length = length_; - length = data_->buffers[1]->size() / (sizeof(uint64_t) * 2); + // NOTE: data_->length not trusted (see warning above) + const int64_t length = data_->buffers[1]->size() / Decimal128Type::kByteWidth; for (int64_t i = 0; i < length; i++) { uint64_t tmp; auto idx = i * 2; @@ -169,8 +175,8 @@ class ArrayDataEndianSwapper { auto data = reinterpret_cast(data_->buffers[1]->data()); ARROW_ASSIGN_OR_RAISE(auto new_buffer, AllocateBuffer(data_->buffers[1]->size())); auto new_data = reinterpret_cast(new_buffer->mutable_data()); - int64_t length = length_; - length = data_->buffers[1]->size() / (sizeof(uint64_t) * 4); + // NOTE: data_->length not trusted (see warning above) + const int64_t length = data_->buffers[1]->size() / Decimal256Type::kByteWidth; for (int64_t i = 0; i < length; i++) { uint64_t tmp0, tmp1, tmp2; auto idx = i * 4; @@ -206,9 +212,10 @@ class ArrayDataEndianSwapper { auto data = reinterpret_cast(data_->buffers[1]->data()); ARROW_ASSIGN_OR_RAISE(auto new_buffer, AllocateBuffer(data_->buffers[1]->size())); auto new_data = reinterpret_cast(new_buffer->mutable_data()); - int64_t length = data_->length; + // NOTE: data_->length not trusted (see warning above) + const int64_t length = data_->buffers[1]->size() / sizeof(MonthDayNanos); for (int64_t i = 0; i < length; i++) { - MonthDayNanoIntervalType::MonthDayNanos tmp = data[i]; + MonthDayNanos tmp = data[i]; #if ARROW_LITTLE_ENDIAN tmp.months = BitUtil::FromBigEndian(tmp.months); tmp.days = BitUtil::FromBigEndian(tmp.days); @@ -279,7 +286,6 @@ class ArrayDataEndianSwapper { } const std::shared_ptr& data_; - int64_t length_; std::shared_ptr out_; }; @@ -292,7 +298,7 @@ Result> SwapEndianArrayData( if (data->offset != 0) { return Status::Invalid("Unsupported data format: data.offset != 0"); } - ArrayDataEndianSwapper swapper(data, data->length); + ArrayDataEndianSwapper swapper(data); RETURN_NOT_OK(swapper.SwapType(*data->type)); return std::move(swapper.out_); } diff --git a/cpp/src/arrow/c/bridge.cc b/cpp/src/arrow/c/bridge.cc index 8b8153465ee..9484b44590a 100644 --- a/cpp/src/arrow/c/bridge.cc +++ b/cpp/src/arrow/c/bridge.cc @@ -28,6 +28,7 @@ #include "arrow/buffer.h" #include "arrow/c/helpers.h" #include "arrow/c/util_internal.h" +#include "arrow/extension_type.h" #include "arrow/memory_pool.h" #include "arrow/record_batch.h" #include "arrow/result.h" @@ -56,8 +57,6 @@ using internal::ArrayExportTraits; using internal::SchemaExportGuard; using internal::SchemaExportTraits; -// TODO export / import Extension types and arrays - namespace { Status ExportingNotImplemented(const DataType& type) { @@ -171,23 
+170,26 @@ struct SchemaExporter { export_.name_ = field.name(); flags_ = field.nullable() ? ARROW_FLAG_NULLABLE : 0; - const DataType& type = *field.type(); - RETURN_NOT_OK(ExportFormat(type)); - RETURN_NOT_OK(ExportChildren(type.fields())); + const DataType* type = UnwrapExtension(field.type().get()); + RETURN_NOT_OK(ExportFormat(*type)); + RETURN_NOT_OK(ExportChildren(type->fields())); RETURN_NOT_OK(ExportMetadata(field.metadata().get())); return Status::OK(); } - Status ExportType(const DataType& type) { + Status ExportType(const DataType& orig_type) { flags_ = ARROW_FLAG_NULLABLE; - RETURN_NOT_OK(ExportFormat(type)); - RETURN_NOT_OK(ExportChildren(type.fields())); + const DataType* type = UnwrapExtension(&orig_type); + RETURN_NOT_OK(ExportFormat(*type)); + RETURN_NOT_OK(ExportChildren(type->fields())); + // There may be additional metadata to export + RETURN_NOT_OK(ExportMetadata(nullptr)); return Status::OK(); } Status ExportSchema(const Schema& schema) { - static StructType dummy_struct_type({}); + static const StructType dummy_struct_type({}); flags_ = 0; RETURN_NOT_OK(ExportFormat(dummy_struct_type)); @@ -232,6 +234,17 @@ struct SchemaExporter { c_struct->release = ReleaseExportedSchema; } + const DataType* UnwrapExtension(const DataType* type) { + if (type->id() == Type::EXTENSION) { + const auto& ext_type = checked_cast(*type); + additional_metadata_.reserve(2); + additional_metadata_.emplace_back(kExtensionTypeKeyName, ext_type.extension_name()); + additional_metadata_.emplace_back(kExtensionMetadataKeyName, ext_type.Serialize()); + return ext_type.storage_type().get(); + } + return type; + } + Status ExportFormat(const DataType& type) { if (type.id() == Type::DICTIONARY) { const auto& dict_type = checked_cast(type); @@ -259,10 +272,29 @@ struct SchemaExporter { return Status::OK(); } - Status ExportMetadata(const KeyValueMetadata* metadata) { - if (metadata != nullptr && metadata->size() >= 0) { - ARROW_ASSIGN_OR_RAISE(export_.metadata_, EncodeMetadata(*metadata)); + Status ExportMetadata(const KeyValueMetadata* orig_metadata) { + static const KeyValueMetadata empty_metadata; + + if (orig_metadata == nullptr) { + orig_metadata = &empty_metadata; } + if (additional_metadata_.empty()) { + if (orig_metadata->size() > 0) { + ARROW_ASSIGN_OR_RAISE(export_.metadata_, EncodeMetadata(*orig_metadata)); + } + return Status::OK(); + } + // Additional metadata needs to be appended to the existing + // (for extension types) + KeyValueMetadata metadata(orig_metadata->keys(), orig_metadata->values()); + for (const auto& kv : additional_metadata_) { + // The metadata may already be there => ignore + if (metadata.Contains(kv.first)) { + continue; + } + metadata.Append(kv.first, kv.second); + } + ARROW_ASSIGN_OR_RAISE(export_.metadata_, EncodeMetadata(metadata)); return Status::OK(); } @@ -442,6 +474,7 @@ struct SchemaExporter { ExportedSchemaPrivateData export_; int64_t flags_ = 0; + std::vector> additional_metadata_; std::unique_ptr dict_exporter_; std::vector child_exporters_; }; @@ -721,7 +754,13 @@ class FormatStringParser { size_t index_; }; -Result> DecodeMetadata(const char* metadata) { +struct DecodedMetadata { + std::shared_ptr metadata; + std::string extension_name; + std::string extension_serialized; +}; + +Result DecodeMetadata(const char* metadata) { auto read_int32 = [&](int32_t* out) -> Status { int32_t v; memcpy(&v, metadata, 4); @@ -744,21 +783,29 @@ Result> DecodeMetadata(const char* metadata) { return Status::OK(); }; + DecodedMetadata decoded; + if (metadata == nullptr) { - 
return nullptr; + return decoded; } int32_t npairs; RETURN_NOT_OK(read_int32(&npairs)); if (npairs == 0) { - return nullptr; + return decoded; } std::vector keys(npairs); std::vector values(npairs); for (int32_t i = 0; i < npairs; ++i) { RETURN_NOT_OK(read_string(&keys[i])); RETURN_NOT_OK(read_string(&values[i])); + if (keys[i] == kExtensionTypeKeyName) { + decoded.extension_name = values[i]; + } else if (keys[i] == kExtensionMetadataKeyName) { + decoded.extension_serialized = values[i]; + } } - return key_value_metadata(std::move(keys), std::move(values)); + decoded.metadata = key_value_metadata(std::move(keys), std::move(values)); + return decoded; } struct SchemaImporter { @@ -775,10 +822,9 @@ struct SchemaImporter { } Result> MakeField() const { - ARROW_ASSIGN_OR_RAISE(auto metadata, DecodeMetadata(c_struct_->metadata)); const char* name = c_struct_->name ? c_struct_->name : ""; bool nullable = (c_struct_->flags & ARROW_FLAG_NULLABLE) != 0; - return field(name, type_, nullable, std::move(metadata)); + return field(name, type_, nullable, std::move(metadata_.metadata)); } Result> MakeSchema() const { @@ -787,8 +833,7 @@ struct SchemaImporter { "Cannot import schema: ArrowSchema describes non-struct type ", type_->ToString()); } - ARROW_ASSIGN_OR_RAISE(auto metadata, DecodeMetadata(c_struct_->metadata)); - return schema(type_->fields(), std::move(metadata)); + return schema(type_->fields(), std::move(metadata_.metadata)); } Result> MakeType() const { return type_; } @@ -836,6 +881,20 @@ struct SchemaImporter { bool ordered = (c_struct_->flags & ARROW_FLAG_DICTIONARY_ORDERED) != 0; type_ = dictionary(type_, dict_importer.type_, ordered); } + + // Import metadata + ARROW_ASSIGN_OR_RAISE(metadata_, DecodeMetadata(c_struct_->metadata)); + + // Detect extension type + if (!metadata_.extension_name.empty()) { + const auto registered_ext_type = GetExtensionType(metadata_.extension_name); + if (registered_ext_type) { + ARROW_ASSIGN_OR_RAISE( + type_, registered_ext_type->Deserialize(std::move(type_), + metadata_.extension_serialized)); + } + } + return Status::OK(); } @@ -1130,6 +1189,7 @@ struct SchemaImporter { int64_t recursion_level_; std::vector child_importers_; std::shared_ptr type_; + DecodedMetadata metadata_; }; } // namespace @@ -1255,8 +1315,15 @@ struct ArrayImporter { } Status DoImport() { + // Unwrap extension type + const DataType* storage_type = type_.get(); + if (storage_type->id() == Type::EXTENSION) { + storage_type = + checked_cast(*storage_type).storage_type().get(); + } + // First import children (required for reconstituting parent array data) - const auto& fields = type_->fields(); + const auto& fields = storage_type->fields(); if (c_struct_->n_children != static_cast(fields.size())) { return Status::Invalid("ArrowArray struct has ", c_struct_->n_children, " children, expected ", fields.size(), " for type ", @@ -1270,15 +1337,15 @@ struct ArrayImporter { } // Import main data - RETURN_NOT_OK(ImportMainData()); + RETURN_NOT_OK(VisitTypeInline(*storage_type, this)); - bool is_dict_type = (type_->id() == Type::DICTIONARY); + bool is_dict_type = (storage_type->id() == Type::DICTIONARY); if (c_struct_->dictionary != nullptr) { if (!is_dict_type) { return Status::Invalid("Import type is ", type_->ToString(), " but dictionary field in ArrowArray struct is not null"); } - const auto& dict_type = checked_cast(*type_); + const auto& dict_type = checked_cast(*storage_type); // Import dictionary values ArrayImporter dict_importer(dict_type.value_type()); 
RETURN_NOT_OK(dict_importer.ImportDict(this, c_struct_->dictionary)); @@ -1292,13 +1359,11 @@ struct ArrayImporter { return Status::OK(); } - Status ImportMainData() { return VisitTypeInline(*type_, this); } - Status Visit(const DataType& type) { return Status::NotImplemented("Cannot import array of type ", type_->ToString()); } - Status Visit(const FixedWidthType& type) { return ImportFixedSizePrimitive(); } + Status Visit(const FixedWidthType& type) { return ImportFixedSizePrimitive(type); } Status Visit(const NullType& type) { RETURN_NOT_OK(CheckNoChildren()); @@ -1352,16 +1417,15 @@ struct ArrayImporter { return Status::OK(); } - Status ImportFixedSizePrimitive() { - const auto& fw_type = checked_cast(*type_); + Status ImportFixedSizePrimitive(const FixedWidthType& type) { RETURN_NOT_OK(CheckNoChildren()); RETURN_NOT_OK(CheckNumBuffers(2)); RETURN_NOT_OK(AllocateArrayData()); RETURN_NOT_OK(ImportNullBitmap()); - if (BitUtil::IsMultipleOf8(fw_type.bit_width())) { - RETURN_NOT_OK(ImportFixedSizeBuffer(1, fw_type.bit_width() / 8)); + if (BitUtil::IsMultipleOf8(type.bit_width())) { + RETURN_NOT_OK(ImportFixedSizeBuffer(1, type.bit_width() / 8)); } else { - DCHECK_EQ(fw_type.bit_width(), 1); + DCHECK_EQ(type.bit_width(), 1); RETURN_NOT_OK(ImportBitsBuffer(1)); } return Status::OK(); diff --git a/cpp/src/arrow/c/bridge_test.cc b/cpp/src/arrow/c/bridge_test.cc index 54ce0efcf9d..c51cb66c03b 100644 --- a/cpp/src/arrow/c/bridge_test.cc +++ b/cpp/src/arrow/c/bridge_test.cc @@ -31,8 +31,10 @@ #include "arrow/c/util_internal.h" #include "arrow/ipc/json_simple.h" #include "arrow/memory_pool.h" +#include "arrow/testing/extension_type.h" #include "arrow/testing/gtest_util.h" #include "arrow/testing/util.h" +#include "arrow/util/checked_cast.h" #include "arrow/util/endian.h" #include "arrow/util/key_value_metadata.h" #include "arrow/util/logging.h" @@ -45,6 +47,7 @@ using internal::ArrayExportGuard; using internal::ArrayExportTraits; using internal::ArrayStreamExportGuard; using internal::ArrayStreamExportTraits; +using internal::checked_cast; using internal::SchemaExportGuard; using internal::SchemaExportTraits; @@ -122,6 +125,10 @@ using ArrayReleaseCallback = ReleaseCallback; static const std::vector kMetadataKeys1{"key1", "key2"}; static const std::vector kMetadataValues1{"", "bar"}; + +static const std::vector kMetadataKeys2{"key"}; +static const std::vector kMetadataValues2{"abcde"}; + // clang-format off static const std::string kEncodedMetadata1{ // NOLINT: runtime/string #if ARROW_LITTLE_ENDIAN @@ -133,11 +140,7 @@ static const std::string kEncodedMetadata1{ // NOLINT: runtime/string 0, 0, 0, 4, 'k', 'e', 'y', '1', 0, 0, 0, 0, 0, 0, 0, 4, 'k', 'e', 'y', '2', 0, 0, 0, 3, 'b', 'a', 'r'}; #endif -// clang-format on -static const std::vector kMetadataKeys2{"key"}; -static const std::vector kMetadataValues2{"abcde"}; -// clang-format off static const std::string kEncodedMetadata2{ // NOLINT: runtime/string #if ARROW_LITTLE_ENDIAN 1, 0, 0, 0, @@ -146,6 +149,51 @@ static const std::string kEncodedMetadata2{ // NOLINT: runtime/string 0, 0, 0, 1, 0, 0, 0, 3, 'k', 'e', 'y', 0, 0, 0, 5, 'a', 'b', 'c', 'd', 'e'}; #endif + +static const std::string kEncodedUuidMetadata = // NOLINT: runtime/string +#if ARROW_LITTLE_ENDIAN + std::string {2, 0, 0, 0} + + std::string {20, 0, 0, 0} + kExtensionTypeKeyName + + std::string {4, 0, 0, 0} + "uuid" + + std::string {24, 0, 0, 0} + kExtensionMetadataKeyName + + std::string {15, 0, 0, 0} + "uuid-serialized"; +#else + std::string {0, 0, 0, 2} + + std::string {0, 0, 0, 
20} + kExtensionTypeKeyName + + std::string {0, 0, 0, 4} + "uuid" + + std::string {0, 0, 0, 24} + kExtensionMetadataKeyName + + std::string {0, 0, 0, 15} + "uuid-serialized"; +#endif + +static const std::string kEncodedDictExtensionMetadata = // NOLINT: runtime/string +#if ARROW_LITTLE_ENDIAN + std::string {2, 0, 0, 0} + + std::string {20, 0, 0, 0} + kExtensionTypeKeyName + + std::string {14, 0, 0, 0} + "dict-extension" + + std::string {24, 0, 0, 0} + kExtensionMetadataKeyName + + std::string {25, 0, 0, 0} + "dict-extension-serialized"; +#else + std::string {0, 0, 0, 2} + + std::string {0, 0, 0, 20} + kExtensionTypeKeyName + + std::string {0, 0, 0, 14} + "dict-extension" + + std::string {0, 0, 0, 24} + kExtensionMetadataKeyName + + std::string {0, 0, 0, 25} + "dict-extension-serialized"; +#endif + +static const std::string kEncodedComplex128Metadata = // NOLINT: runtime/string +#if ARROW_LITTLE_ENDIAN + std::string {2, 0, 0, 0} + + std::string {20, 0, 0, 0} + kExtensionTypeKeyName + + std::string {10, 0, 0, 0} + "complex128" + + std::string {24, 0, 0, 0} + kExtensionMetadataKeyName + + std::string {21, 0, 0, 0} + "complex128-serialized"; +#else + std::string {0, 0, 0, 2} + + std::string {0, 0, 0, 20} + kExtensionTypeKeyName + + std::string {0, 0, 0, 10} + "complex128" + + std::string {0, 0, 0, 24} + kExtensionMetadataKeyName + + std::string {0, 0, 0, 21} + "complex128-serialized"; +#endif // clang-format on static constexpr int64_t kDefaultFlags = ARROW_FLAG_NULLABLE; @@ -404,6 +452,16 @@ TEST_F(TestSchemaExport, Dictionary) { } } +TEST_F(TestSchemaExport, Extension) { + TestPrimitive(uuid(), "w:16", "", kDefaultFlags, kEncodedUuidMetadata); + + TestNested(dict_extension_type(), {"c", "u"}, {"", ""}, {kDefaultFlags, kDefaultFlags}, + {kEncodedDictExtensionMetadata, ""}); + + TestNested(complex128(), {"+s", "g", "g"}, {"", "real", "imag"}, + {ARROW_FLAG_NULLABLE, 0, 0}, {kEncodedComplex128Metadata, "", ""}); +} + TEST_F(TestSchemaExport, ExportField) { TestPrimitive(field("thing", null()), "n", "thing", ARROW_FLAG_NULLABLE); // With nullable = false @@ -507,11 +565,9 @@ class TestArrayExport : public ::testing::Test { public: void SetUp() override { pool_ = default_memory_pool(); } - static std::function*)> JSONArrayFactory( + static std::function>()> JSONArrayFactory( std::shared_ptr type, const char* json) { - return [=](std::shared_ptr* out) -> Status { - return ::arrow::ipc::internal::json::ArrayFromJSON(type, json, out); - }; + return [=]() { return ArrayFromJSON(type, json); }; } template @@ -519,7 +575,7 @@ class TestArrayExport : public ::testing::Test { auto orig_bytes = pool_->bytes_allocated(); std::shared_ptr arr; - ASSERT_OK(factory(&arr)); + ASSERT_OK_AND_ASSIGN(arr, ToResult(factory())); const ArrayData& data = *arr->data(); // non-owning reference struct ArrowArray c_export; ASSERT_OK(ExportArray(*arr, &c_export)); @@ -562,7 +618,7 @@ class TestArrayExport : public ::testing::Test { auto orig_bytes = pool_->bytes_allocated(); std::shared_ptr arr; - ASSERT_OK(factory(&arr)); + ASSERT_OK_AND_ASSIGN(arr, ToResult(factory())); const ArrayData& data = *arr->data(); // non-owning reference struct ArrowArray c_export_temp, c_export_final; ASSERT_OK(ExportArray(*arr, &c_export_temp)); @@ -607,7 +663,7 @@ class TestArrayExport : public ::testing::Test { auto orig_bytes = pool_->bytes_allocated(); std::shared_ptr arr; - ASSERT_OK(factory(&arr)); + ASSERT_OK_AND_ASSIGN(arr, ToResult(factory())); struct ArrowArray c_export_parent, c_export_child; ASSERT_OK(ExportArray(*arr, 
&c_export_parent)); @@ -661,7 +717,7 @@ class TestArrayExport : public ::testing::Test { auto orig_bytes = pool_->bytes_allocated(); std::shared_ptr arr; - ASSERT_OK(factory(&arr)); + ASSERT_OK_AND_ASSIGN(arr, ToResult(factory())); struct ArrowArray c_export_parent; ASSERT_OK(ExportArray(*arr, &c_export_parent)); @@ -752,10 +808,7 @@ TEST_F(TestArrayExport, Primitive) { } TEST_F(TestArrayExport, PrimitiveSliced) { - auto factory = [](std::shared_ptr* out) -> Status { - *out = ArrayFromJSON(int16(), "[1, 2, null, -3]")->Slice(1, 2); - return Status::OK(); - }; + auto factory = []() { return ArrayFromJSON(int16(), "[1, 2, null, -3]")->Slice(1, 2); }; TestPrimitive(factory); } @@ -802,18 +855,17 @@ TEST_F(TestArrayExport, List) { TEST_F(TestArrayExport, ListSliced) { { - auto factory = [](std::shared_ptr* out) -> Status { - *out = ArrayFromJSON(list(int8()), "[[1, 2], [3, null], [4, 5, 6], null]") - ->Slice(1, 2); - return Status::OK(); + auto factory = []() { + return ArrayFromJSON(list(int8()), "[[1, 2], [3, null], [4, 5, 6], null]") + ->Slice(1, 2); }; TestNested(factory); } { - auto factory = [](std::shared_ptr* out) -> Status { + auto factory = []() { auto values = ArrayFromJSON(int16(), "[1, 2, 3, 4, null, 5, 6, 7, 8]")->Slice(1, 6); auto offsets = ArrayFromJSON(int32(), "[0, 2, 3, 5, 6]")->Slice(2, 4); - return ListArray::FromArrays(*offsets, *values).Value(out); + return ListArray::FromArrays(*offsets, *values); }; TestNested(factory); } @@ -847,28 +899,25 @@ TEST_F(TestArrayExport, Union) { TEST_F(TestArrayExport, Dictionary) { { - auto factory = [](std::shared_ptr* out) -> Status { + auto factory = []() { auto values = ArrayFromJSON(utf8(), R"(["foo", "bar", "quux"])"); auto indices = ArrayFromJSON(uint16(), "[0, 2, 1, null, 1]"); return DictionaryArray::FromArrays(dictionary(indices->type(), values->type()), - indices, values) - .Value(out); + indices, values); }; TestNested(factory); } { - auto factory = [](std::shared_ptr* out) -> Status { + auto factory = []() { auto values = ArrayFromJSON(list(utf8()), R"([["abc", "def"], ["efg"], []])"); auto indices = ArrayFromJSON(int32(), "[0, 2, 1, null, 1]"); return DictionaryArray::FromArrays( - dictionary(indices->type(), values->type(), /*ordered=*/true), indices, - values) - .Value(out); + dictionary(indices->type(), values->type(), /*ordered=*/true), indices, values); }; TestNested(factory); } { - auto factory = [](std::shared_ptr* out) -> Status { + auto factory = []() -> Result> { auto values = ArrayFromJSON(list(utf8()), R"([["abc", "def"], ["efg"], []])"); auto indices = ArrayFromJSON(int32(), "[0, 2, 1, null, 1]"); ARROW_ASSIGN_OR_RAISE( @@ -876,13 +925,20 @@ TEST_F(TestArrayExport, Dictionary) { DictionaryArray::FromArrays(dictionary(indices->type(), values->type()), indices, values)); auto offsets = ArrayFromJSON(int64(), "[0, 2, 5]"); - RETURN_NOT_OK(LargeListArray::FromArrays(*offsets, *dict_array).Value(out)); - return (*out)->ValidateFull(); + ARROW_ASSIGN_OR_RAISE(auto arr, LargeListArray::FromArrays(*offsets, *dict_array)); + RETURN_NOT_OK(arr->ValidateFull()); + return arr; }; TestNested(factory); } } +TEST_F(TestArrayExport, Extension) { + TestPrimitive(ExampleUuid); + TestPrimitive(ExampleSmallint); + TestPrimitive(ExampleComplex128); +} + TEST_F(TestArrayExport, MovePrimitive) { TestMovePrimitive(int8(), "[1, 2, null, -3]"); TestMovePrimitive(fixed_size_binary(3), R"(["foo", "bar", null])"); @@ -898,17 +954,16 @@ TEST_F(TestArrayExport, MoveNested) { TEST_F(TestArrayExport, MoveDictionary) { { - auto factory = 
[](std::shared_ptr* out) -> Status { + auto factory = []() { auto values = ArrayFromJSON(utf8(), R"(["foo", "bar", "quux"])"); auto indices = ArrayFromJSON(int32(), "[0, 2, 1, null, 1]"); return DictionaryArray::FromArrays(dictionary(indices->type(), values->type()), - indices, values) - .Value(out); + indices, values); }; TestMoveNested(factory); } { - auto factory = [](std::shared_ptr* out) -> Status { + auto factory = []() -> Result> { auto values = ArrayFromJSON(list(utf8()), R"([["abc", "def"], ["efg"], []])"); auto indices = ArrayFromJSON(int32(), "[0, 2, 1, null, 1]"); ARROW_ASSIGN_OR_RAISE( @@ -916,8 +971,9 @@ TEST_F(TestArrayExport, MoveDictionary) { DictionaryArray::FromArrays(dictionary(indices->type(), values->type()), indices, values)); auto offsets = ArrayFromJSON(int64(), "[0, 2, 5]"); - RETURN_NOT_OK(LargeListArray::FromArrays(*offsets, *dict_array).Value(out)); - return (*out)->ValidateFull(); + ARROW_ASSIGN_OR_RAISE(auto arr, LargeListArray::FromArrays(*offsets, *dict_array)); + RETURN_NOT_OK(arr->ValidateFull()); + return arr; }; TestMoveNested(factory); } @@ -934,7 +990,7 @@ TEST_F(TestArrayExport, MoveChild) { R"([[1, "foo"], [2, null]])", /*child_id=*/1); { - auto factory = [](std::shared_ptr* out) -> Status { + auto factory = []() -> Result> { auto values = ArrayFromJSON(list(utf8()), R"([["abc", "def"], ["efg"], []])"); auto indices = ArrayFromJSON(int32(), "[0, 2, 1, null, 1]"); ARROW_ASSIGN_OR_RAISE( @@ -942,8 +998,9 @@ TEST_F(TestArrayExport, MoveChild) { DictionaryArray::FromArrays(dictionary(indices->type(), values->type()), indices, values)); auto offsets = ArrayFromJSON(int64(), "[0, 2, 5]"); - RETURN_NOT_OK(LargeListArray::FromArrays(*offsets, *dict_array).Value(out)); - return (*out)->ValidateFull(); + ARROW_ASSIGN_OR_RAISE(auto arr, LargeListArray::FromArrays(*offsets, *dict_array)); + RETURN_NOT_OK(arr->ValidateFull()); + return arr; }; TestMoveChild(factory, /*child_id=*/0); } @@ -1400,6 +1457,32 @@ TEST_F(TestSchemaImport, Dictionary) { CheckImport(expected); } +TEST_F(TestSchemaImport, UnregisteredExtension) { + FillPrimitive("w:16"); + c_struct_.metadata = kEncodedUuidMetadata.c_str(); + auto expected = fixed_size_binary(16); + CheckImport(expected); +} + +TEST_F(TestSchemaImport, RegisteredExtension) { + { + ExtensionTypeGuard guard(uuid()); + FillPrimitive("w:16"); + c_struct_.metadata = kEncodedUuidMetadata.c_str(); + auto expected = uuid(); + CheckImport(expected); + } + { + ExtensionTypeGuard guard(dict_extension_type()); + FillPrimitive(AddChild(), "u"); + FillPrimitive("c"); + FillDictionary(); + c_struct_.metadata = kEncodedDictExtensionMetadata.c_str(); + auto expected = dict_extension_type(); + CheckImport(expected); + } +} + TEST_F(TestSchemaImport, FormatStringError) { FillPrimitive(""); CheckImportError(); @@ -1481,6 +1564,22 @@ TEST_F(TestSchemaImport, DictionaryError) { CheckImportError(); } +TEST_F(TestSchemaImport, ExtensionError) { + ExtensionTypeGuard guard(uuid()); + + // Storage type doesn't match + FillPrimitive("w:15"); + c_struct_.metadata = kEncodedUuidMetadata.c_str(); + CheckImportError(); + + // Invalid serialization + std::string bogus_metadata = kEncodedUuidMetadata; + bogus_metadata[bogus_metadata.size() - 5] += 1; + FillPrimitive("w:16"); + c_struct_.metadata = bogus_metadata.c_str(); + CheckImportError(); +} + TEST_F(TestSchemaImport, RecursionError) { FillPrimitive(AddChild(), "c", "unused"); auto c = AddChild(); @@ -2163,21 +2262,44 @@ TEST_F(TestArrayImport, DictionaryWithOffset) { FillPrimitive(3, 0, 0, 
primitive_buffers_no_nulls4); FillDictionary(); - auto dict_values = ArrayFromJSON(utf8(), R"(["", "bar", "quux"])"); - auto indices = ArrayFromJSON(int8(), "[1, 2, 0]"); - ASSERT_OK_AND_ASSIGN( - auto expected, - DictionaryArray::FromArrays(dictionary(int8(), utf8()), indices, dict_values)); + auto expected = DictArrayFromJSON(dictionary(int8(), utf8()), "[1, 2, 0]", + R"(["", "bar", "quux"])"); CheckImport(expected); FillStringLike(AddChild(), 4, 0, 0, string_buffers_no_nulls1); FillPrimitive(4, 0, 2, primitive_buffers_no_nulls4); FillDictionary(); - dict_values = ArrayFromJSON(utf8(), R"(["foo", "", "bar", "quux"])"); - indices = ArrayFromJSON(int8(), "[0, 1, 3, 0]"); - ASSERT_OK_AND_ASSIGN(expected, DictionaryArray::FromArrays(dictionary(int8(), utf8()), - indices, dict_values)); + expected = DictArrayFromJSON(dictionary(int8(), utf8()), "[0, 1, 3, 0]", + R"(["foo", "", "bar", "quux"])"); + CheckImport(expected); +} + +TEST_F(TestArrayImport, RegisteredExtension) { + ExtensionTypeGuard guard({smallint(), dict_extension_type(), complex128()}); + + // smallint + FillPrimitive(3, 0, 0, primitive_buffers_no_nulls1_16); + auto expected = + ExtensionType::WrapArray(smallint(), ArrayFromJSON(int16(), "[513, 1027, 1541]")); + CheckImport(expected); + + // dict_extension_type + FillStringLike(AddChild(), 4, 0, 0, string_buffers_no_nulls1); + FillPrimitive(6, 0, 0, primitive_buffers_no_nulls4); + FillDictionary(); + + auto storage = DictArrayFromJSON(dictionary(int8(), utf8()), "[1, 2, 0, 1, 3, 0]", + R"(["foo", "", "bar", "quux"])"); + expected = ExtensionType::WrapArray(dict_extension_type(), storage); + CheckImport(expected); + + // complex128 + FillPrimitive(AddChild(), 3, 0, /*offset=*/0, primitive_buffers_no_nulls6); + FillPrimitive(AddChild(), 3, 0, /*offset=*/3, primitive_buffers_no_nulls6); + FillStructLike(3, 0, 0, 2, buffers_no_nulls_no_data); + expected = MakeComplex128(ArrayFromJSON(float64(), "[0.0, 1.5, -2.0]"), + ArrayFromJSON(float64(), "[3.0, 4.0, 5.0]")); CheckImport(expected); } @@ -2341,8 +2463,9 @@ class TestSchemaRoundtrip : public ::testing::Test { public: void SetUp() override { pool_ = default_memory_pool(); } - template - void TestWithTypeFactory(TypeFactory&& factory) { + template + void TestWithTypeFactory(TypeFactory&& factory, + ExpectedTypeFactory&& factory_expected) { std::shared_ptr type, actual; struct ArrowSchema c_schema {}; // zeroed SchemaExportGuard schema_guard(&c_schema); @@ -2359,7 +2482,7 @@ class TestSchemaRoundtrip : public ::testing::Test { // Recreate the type ASSERT_OK_AND_ASSIGN(actual, ImportType(&c_schema)); - type = factory(); + type = factory_expected(); AssertTypeEqual(*type, *actual); type.reset(); actual.reset(); @@ -2367,6 +2490,11 @@ class TestSchemaRoundtrip : public ::testing::Test { ASSERT_EQ(pool_->bytes_allocated(), orig_bytes); } + template + void TestWithTypeFactory(TypeFactory&& factory) { + TestWithTypeFactory(factory, factory); + } + template void TestWithSchemaFactory(SchemaFactory&& factory) { std::shared_ptr schema, actual; @@ -2459,6 +2587,27 @@ TEST_F(TestSchemaRoundtrip, Dictionary) { } } +TEST_F(TestSchemaRoundtrip, UnregisteredExtension) { + TestWithTypeFactory(uuid, []() { return fixed_size_binary(16); }); + TestWithTypeFactory(dict_extension_type, []() { return dictionary(int8(), utf8()); }); + + // Inside nested type + TestWithTypeFactory([]() { return list(dict_extension_type()); }, + []() { return list(dictionary(int8(), utf8())); }); +} + +TEST_F(TestSchemaRoundtrip, RegisteredExtension) { + ExtensionTypeGuard 
guard({uuid(), dict_extension_type(), complex128()}); + TestWithTypeFactory(uuid); + TestWithTypeFactory(dict_extension_type); + TestWithTypeFactory(complex128); + + // Inside nested type + TestWithTypeFactory([]() { return list(uuid()); }); + TestWithTypeFactory([]() { return list(dict_extension_type()); }); + TestWithTypeFactory([]() { return list(complex128()); }); +} + TEST_F(TestSchemaRoundtrip, Map) { TestWithTypeFactory([&]() { return map(utf8(), int32()); }); TestWithTypeFactory([&]() { return map(list(utf8()), int32()); }); @@ -2482,28 +2631,30 @@ TEST_F(TestSchemaRoundtrip, Schema) { class TestArrayRoundtrip : public ::testing::Test { public: - using ArrayFactory = std::function*)>; + using ArrayFactory = std::function>()>; void SetUp() override { pool_ = default_memory_pool(); } static ArrayFactory JSONArrayFactory(std::shared_ptr type, const char* json) { - return [=](std::shared_ptr* out) -> Status { - return ::arrow::ipc::internal::json::ArrayFromJSON(type, json, out); - }; + return [=]() { return ArrayFromJSON(type, json); }; } static ArrayFactory SlicedArrayFactory(ArrayFactory factory) { - return [=](std::shared_ptr* out) -> Status { - std::shared_ptr arr; - RETURN_NOT_OK(factory(&arr)); + return [=]() -> Result> { + ARROW_ASSIGN_OR_RAISE(auto arr, factory()); DCHECK_GE(arr->length(), 2); - *out = arr->Slice(1, arr->length() - 2); - return Status::OK(); + return arr->Slice(1, arr->length() - 2); }; } template void TestWithArrayFactory(ArrayFactory&& factory) { + TestWithArrayFactory(factory, factory); + } + + template + void TestWithArrayFactory(ArrayFactory&& factory, + ExpectedArrayFactory&& factory_expected) { std::shared_ptr array; struct ArrowArray c_array {}; struct ArrowSchema c_schema {}; @@ -2512,7 +2663,7 @@ class TestArrayRoundtrip : public ::testing::Test { auto orig_bytes = pool_->bytes_allocated(); - ASSERT_OK(factory(&array)); + ASSERT_OK_AND_ASSIGN(array, ToResult(factory())); ASSERT_OK(ExportType(*array->type(), &c_schema)); ASSERT_OK(ExportArray(*array, &c_array)); @@ -2539,7 +2690,7 @@ class TestArrayRoundtrip : public ::testing::Test { // Check value of imported array { std::shared_ptr expected; - ASSERT_OK(factory(&expected)); + ASSERT_OK_AND_ASSIGN(expected, ToResult(factory_expected())); AssertTypeEqual(*expected->type(), *array->type()); AssertArraysEqual(*expected, *array, true); } @@ -2556,7 +2707,7 @@ class TestArrayRoundtrip : public ::testing::Test { SchemaExportGuard schema_guard(&c_schema); auto orig_bytes = pool_->bytes_allocated(); - ASSERT_OK(factory(&batch)); + ASSERT_OK_AND_ASSIGN(batch, ToResult(factory())); ASSERT_OK(ExportSchema(*batch->schema(), &c_schema)); ASSERT_OK(ExportRecordBatch(*batch, &c_array)); @@ -2579,7 +2730,7 @@ class TestArrayRoundtrip : public ::testing::Test { // Check value of imported record batch { std::shared_ptr expected; - ASSERT_OK(factory(&expected)); + ASSERT_OK_AND_ASSIGN(expected, ToResult(factory())); AssertSchemaEqual(*expected->schema(), *batch->schema()); AssertBatchesEqual(*expected, *batch); } @@ -2621,15 +2772,15 @@ TEST_F(TestArrayRoundtrip, Primitive) { } TEST_F(TestArrayRoundtrip, UnknownNullCount) { - TestWithArrayFactory([](std::shared_ptr* arr) -> Status { - *arr = ArrayFromJSON(int32(), "[0, 1, 2]"); - if ((*arr)->null_bitmap()) { + TestWithArrayFactory([]() -> Result> { + auto arr = ArrayFromJSON(int32(), "[0, 1, 2]"); + if (arr->null_bitmap()) { return Status::Invalid( "Failed precondition: " "the array shouldn't have a null bitmap."); } - (*arr)->data()->SetNullCount(kUnknownNullCount); - 
return Status::OK(); + arr->data()->SetNullCount(kUnknownNullCount); + return arr; }); } @@ -2670,30 +2821,62 @@ TEST_F(TestArrayRoundtrip, Nested) { TEST_F(TestArrayRoundtrip, Dictionary) { { - auto factory = [](std::shared_ptr* out) -> Status { + auto factory = []() { auto values = ArrayFromJSON(utf8(), R"(["foo", "bar", "quux"])"); auto indices = ArrayFromJSON(int32(), "[0, 2, 1, null, 1]"); return DictionaryArray::FromArrays(dictionary(indices->type(), values->type()), - indices, values) - .Value(out); + indices, values); }; TestWithArrayFactory(factory); TestWithArrayFactory(SlicedArrayFactory(factory)); } { - auto factory = [](std::shared_ptr* out) -> Status { + auto factory = []() { auto values = ArrayFromJSON(list(utf8()), R"([["abc", "def"], ["efg"], []])"); auto indices = ArrayFromJSON(int32(), "[0, 2, 1, null, 1]"); return DictionaryArray::FromArrays( - dictionary(indices->type(), values->type(), /*ordered=*/true), indices, - values) - .Value(out); + dictionary(indices->type(), values->type(), /*ordered=*/true), indices, values); }; TestWithArrayFactory(factory); TestWithArrayFactory(SlicedArrayFactory(factory)); } } +TEST_F(TestArrayRoundtrip, RegisteredExtension) { + ExtensionTypeGuard guard({smallint(), complex128(), dict_extension_type(), uuid()}); + + TestWithArrayFactory(ExampleSmallint); + TestWithArrayFactory(ExampleUuid); + TestWithArrayFactory(ExampleComplex128); + TestWithArrayFactory(ExampleDictExtension); + + // Nested inside outer array + auto NestedFactory = [](ArrayFactory factory) { + return [factory]() -> Result> { + ARROW_ASSIGN_OR_RAISE(auto arr, ToResult(factory())); + return FixedSizeListArray::FromArrays(arr, /*list_size=*/1); + }; + }; + TestWithArrayFactory(NestedFactory(ExampleSmallint)); + TestWithArrayFactory(NestedFactory(ExampleUuid)); + TestWithArrayFactory(NestedFactory(ExampleComplex128)); + TestWithArrayFactory(NestedFactory(ExampleDictExtension)); +} + +TEST_F(TestArrayRoundtrip, UnregisteredExtension) { + auto StorageExtractor = [](ArrayFactory factory) { + return [factory]() -> Result> { + ARROW_ASSIGN_OR_RAISE(auto arr, ToResult(factory())); + return checked_cast(*arr).storage(); + }; + }; + + TestWithArrayFactory(ExampleSmallint, StorageExtractor(ExampleSmallint)); + TestWithArrayFactory(ExampleUuid, StorageExtractor(ExampleUuid)); + TestWithArrayFactory(ExampleComplex128, StorageExtractor(ExampleComplex128)); + TestWithArrayFactory(ExampleDictExtension, StorageExtractor(ExampleDictExtension)); +} + TEST_F(TestArrayRoundtrip, RecordBatch) { auto schema = ::arrow::schema( {field("ints", int16()), field("bools", boolean(), /*nullable=*/false)}); @@ -2701,22 +2884,18 @@ TEST_F(TestArrayRoundtrip, RecordBatch) { auto arr1 = ArrayFromJSON(boolean(), "[false, true, false]"); { - auto factory = [&](std::shared_ptr* out) -> Status { - *out = RecordBatch::Make(schema, 3, {arr0, arr1}); - return Status::OK(); - }; + auto factory = [&]() { return RecordBatch::Make(schema, 3, {arr0, arr1}); }; TestWithBatchFactory(factory); } { // With schema and field metadata - auto factory = [&](std::shared_ptr* out) -> Status { + auto factory = [&]() { auto f0 = schema->field(0); auto f1 = schema->field(1); f1 = f1->WithMetadata(key_value_metadata(kMetadataKeys1, kMetadataValues1)); auto schema_with_md = ::arrow::schema({f0, f1}, key_value_metadata(kMetadataKeys2, kMetadataValues2)); - *out = RecordBatch::Make(schema_with_md, 3, {arr0, arr1}); - return Status::OK(); + return RecordBatch::Make(schema_with_md, 3, {arr0, arr1}); }; TestWithBatchFactory(factory); } 
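The bridge changes above can be exercised end to end. The following is a minimal, hypothetical sketch (not part of this diff; it borrows the test-only uuid() example type from arrow/testing/extension_type.h, and the function name is illustrative) of how a registered extension type now survives a schema roundtrip through the C data interface, while an unregistered one falls back to its storage type:

#include "arrow/c/abi.h"
#include "arrow/c/bridge.h"
#include "arrow/extension_type.h"
#include "arrow/result.h"
#include "arrow/status.h"
#include "arrow/testing/extension_type.h"  // uuid(), test-only example type
#include "arrow/util/checked_cast.h"

arrow::Status RoundtripUuidSchema() {
  // Register the example type so the importer can map the
  // ARROW:extension:name / ARROW:extension:metadata keys back to it.
  ARROW_RETURN_NOT_OK(arrow::RegisterExtensionType(
      arrow::internal::checked_pointer_cast<arrow::ExtensionType>(arrow::uuid())));
  struct ArrowSchema c_schema;
  ARROW_RETURN_NOT_OK(arrow::ExportType(*arrow::uuid(), &c_schema));
  // The exported format string is the storage type's ("w:16"); the extension
  // identity travels in the metadata appended by the exporter.
  ARROW_ASSIGN_OR_RAISE(auto imported, arrow::ImportType(&c_schema));
  // imported now equals uuid(); without the registration it would have been
  // imported as fixed_size_binary(16).
  return arrow::Status::OK();
}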
diff --git a/cpp/src/arrow/compute/api_aggregate.cc b/cpp/src/arrow/compute/api_aggregate.cc index 6d7bdfa6cf9..1216fe27d4e 100644 --- a/cpp/src/arrow/compute/api_aggregate.cc +++ b/cpp/src/arrow/compute/api_aggregate.cc @@ -85,18 +85,23 @@ static auto kScalarAggregateOptionsType = GetFunctionOptionsType(DataMember("mode", &CountOptions::mode)); -static auto kModeOptionsType = - GetFunctionOptionsType(DataMember("n", &ModeOptions::n)); +static auto kModeOptionsType = GetFunctionOptionsType( + DataMember("n", &ModeOptions::n), DataMember("skip_nulls", &ModeOptions::skip_nulls), + DataMember("min_count", &ModeOptions::min_count)); static auto kVarianceOptionsType = GetFunctionOptionsType( DataMember("ddof", &VarianceOptions::ddof), DataMember("skip_nulls", &VarianceOptions::skip_nulls), DataMember("min_count", &VarianceOptions::min_count)); static auto kQuantileOptionsType = GetFunctionOptionsType( DataMember("q", &QuantileOptions::q), - DataMember("interpolation", &QuantileOptions::interpolation)); + DataMember("interpolation", &QuantileOptions::interpolation), + DataMember("skip_nulls", &QuantileOptions::skip_nulls), + DataMember("min_count", &QuantileOptions::min_count)); static auto kTDigestOptionsType = GetFunctionOptionsType( DataMember("q", &TDigestOptions::q), DataMember("delta", &TDigestOptions::delta), - DataMember("buffer_size", &TDigestOptions::buffer_size)); + DataMember("buffer_size", &TDigestOptions::buffer_size), + DataMember("skip_nulls", &TDigestOptions::skip_nulls), + DataMember("min_count", &TDigestOptions::min_count)); static auto kIndexOptionsType = GetFunctionOptionsType(DataMember("value", &IndexOptions::value)); } // namespace @@ -112,7 +117,11 @@ CountOptions::CountOptions(CountMode mode) : FunctionOptions(internal::kCountOptionsType), mode(mode) {} constexpr char CountOptions::kTypeName[]; -ModeOptions::ModeOptions(int64_t n) : FunctionOptions(internal::kModeOptionsType), n(n) {} +ModeOptions::ModeOptions(int64_t n, bool skip_nulls, uint32_t min_count) + : FunctionOptions(internal::kModeOptionsType), + n{n}, + skip_nulls{skip_nulls}, + min_count{min_count} {} constexpr char ModeOptions::kTypeName[]; VarianceOptions::VarianceOptions(int ddof, bool skip_nulls, uint32_t min_count) @@ -122,27 +131,38 @@ VarianceOptions::VarianceOptions(int ddof, bool skip_nulls, uint32_t min_count) min_count(min_count) {} constexpr char VarianceOptions::kTypeName[]; -QuantileOptions::QuantileOptions(double q, enum Interpolation interpolation) +QuantileOptions::QuantileOptions(double q, enum Interpolation interpolation, + bool skip_nulls, uint32_t min_count) : FunctionOptions(internal::kQuantileOptionsType), q{q}, - interpolation{interpolation} {} -QuantileOptions::QuantileOptions(std::vector q, enum Interpolation interpolation) + interpolation{interpolation}, + skip_nulls{skip_nulls}, + min_count{min_count} {} +QuantileOptions::QuantileOptions(std::vector q, enum Interpolation interpolation, + bool skip_nulls, uint32_t min_count) : FunctionOptions(internal::kQuantileOptionsType), q{std::move(q)}, - interpolation{interpolation} {} + interpolation{interpolation}, + skip_nulls{skip_nulls}, + min_count{min_count} {} constexpr char QuantileOptions::kTypeName[]; -TDigestOptions::TDigestOptions(double q, uint32_t delta, uint32_t buffer_size) +TDigestOptions::TDigestOptions(double q, uint32_t delta, uint32_t buffer_size, + bool skip_nulls, uint32_t min_count) : FunctionOptions(internal::kTDigestOptionsType), q{q}, delta{delta}, - buffer_size{buffer_size} {} + buffer_size{buffer_size}, + 
skip_nulls{skip_nulls}, + min_count{min_count} {} TDigestOptions::TDigestOptions(std::vector q, uint32_t delta, - uint32_t buffer_size) + uint32_t buffer_size, bool skip_nulls, uint32_t min_count) : FunctionOptions(internal::kTDigestOptionsType), q{std::move(q)}, delta{delta}, - buffer_size{buffer_size} {} + buffer_size{buffer_size}, + skip_nulls{skip_nulls}, + min_count{min_count} {} constexpr char TDigestOptions::kTypeName[]; IndexOptions::IndexOptions(std::shared_ptr value) diff --git a/cpp/src/arrow/compute/api_aggregate.h b/cpp/src/arrow/compute/api_aggregate.h index 8c27da49765..c8df81773d4 100644 --- a/cpp/src/arrow/compute/api_aggregate.h +++ b/cpp/src/arrow/compute/api_aggregate.h @@ -82,11 +82,16 @@ class ARROW_EXPORT CountOptions : public FunctionOptions { /// By default, returns the most common value and count. class ARROW_EXPORT ModeOptions : public FunctionOptions { public: - explicit ModeOptions(int64_t n = 1); + explicit ModeOptions(int64_t n = 1, bool skip_nulls = true, uint32_t min_count = 0); constexpr static char const kTypeName[] = "ModeOptions"; static ModeOptions Defaults() { return ModeOptions{}; } int64_t n = 1; + /// If true (the default), null values are ignored. Otherwise, if any value is null, + /// emit null. + bool skip_nulls; + /// If less than this many non-null values are observed, emit null. + uint32_t min_count; }; /// \brief Control Delta Degrees of Freedom (ddof) of Variance and Stddev kernel @@ -121,10 +126,12 @@ class ARROW_EXPORT QuantileOptions : public FunctionOptions { MIDPOINT, }; - explicit QuantileOptions(double q = 0.5, enum Interpolation interpolation = LINEAR); + explicit QuantileOptions(double q = 0.5, enum Interpolation interpolation = LINEAR, + bool skip_nulls = true, uint32_t min_count = 0); explicit QuantileOptions(std::vector q, - enum Interpolation interpolation = LINEAR); + enum Interpolation interpolation = LINEAR, + bool skip_nulls = true, uint32_t min_count = 0); constexpr static char const kTypeName[] = "QuantileOptions"; static QuantileOptions Defaults() { return QuantileOptions{}; } @@ -132,6 +139,11 @@ class ARROW_EXPORT QuantileOptions : public FunctionOptions { /// quantile must be between 0 and 1 inclusive std::vector q; enum Interpolation interpolation; + /// If true (the default), null values are ignored. Otherwise, if any value is null, + /// emit null. + bool skip_nulls; + /// If less than this many non-null values are observed, emit null. + uint32_t min_count; }; /// \brief Control TDigest approximate quantile kernel behavior @@ -140,9 +152,11 @@ class ARROW_EXPORT QuantileOptions : public FunctionOptions { class ARROW_EXPORT TDigestOptions : public FunctionOptions { public: explicit TDigestOptions(double q = 0.5, uint32_t delta = 100, - uint32_t buffer_size = 500); + uint32_t buffer_size = 500, bool skip_nulls = true, + uint32_t min_count = 0); explicit TDigestOptions(std::vector q, uint32_t delta = 100, - uint32_t buffer_size = 500); + uint32_t buffer_size = 500, bool skip_nulls = true, + uint32_t min_count = 0); constexpr static char const kTypeName[] = "TDigestOptions"; static TDigestOptions Defaults() { return TDigestOptions{}; } @@ -152,6 +166,11 @@ class ARROW_EXPORT TDigestOptions : public FunctionOptions { uint32_t delta; /// input buffer size, default 500 uint32_t buffer_size; + /// If true (the default), null values are ignored. Otherwise, if any value is null, + /// emit null. + bool skip_nulls; + /// If less than this many non-null values are observed, emit null. 
+ uint32_t min_count; }; /// \brief Control Index kernel behavior diff --git a/cpp/src/arrow/compute/api_scalar.cc b/cpp/src/arrow/compute/api_scalar.cc index b7287129cbc..83aaee5f0fe 100644 --- a/cpp/src/arrow/compute/api_scalar.cc +++ b/cpp/src/arrow/compute/api_scalar.cc @@ -53,6 +53,7 @@ struct EnumTraits return ""; } }; + template <> struct EnumTraits : BasicEnumTraits return ""; } }; + template <> struct EnumTraits : BasicEnumTraits< @@ -98,6 +100,80 @@ struct EnumTraits return ""; } }; +template <> +struct EnumTraits + : BasicEnumTraits { + static std::string name() { return "AssumeTimezoneOptions::Ambiguous"; } + static std::string value_name(compute::AssumeTimezoneOptions::Ambiguous value) { + switch (value) { + case compute::AssumeTimezoneOptions::Ambiguous::AMBIGUOUS_RAISE: + return "AMBIGUOUS_RAISE"; + case compute::AssumeTimezoneOptions::Ambiguous::AMBIGUOUS_EARLIEST: + return "AMBIGUOUS_EARLIEST"; + case compute::AssumeTimezoneOptions::Ambiguous::AMBIGUOUS_LATEST: + return "AMBIGUOUS_LATEST"; + } + return ""; + } +}; +template <> +struct EnumTraits + : BasicEnumTraits { + static std::string name() { return "AssumeTimezoneOptions::Nonexistent"; } + static std::string value_name(compute::AssumeTimezoneOptions::Nonexistent value) { + switch (value) { + case compute::AssumeTimezoneOptions::Nonexistent::NONEXISTENT_RAISE: + return "NONEXISTENT_RAISE"; + case compute::AssumeTimezoneOptions::Nonexistent::NONEXISTENT_EARLIEST: + return "NONEXISTENT_EARLIEST"; + case compute::AssumeTimezoneOptions::Nonexistent::NONEXISTENT_LATEST: + return "NONEXISTENT_LATEST"; + } + return ""; + } +}; + +template <> +struct EnumTraits + : BasicEnumTraits { + static std::string name() { return "compute::RoundMode"; } + static std::string value_name(compute::RoundMode value) { + switch (value) { + case compute::RoundMode::DOWN: + return "DOWN"; + case compute::RoundMode::UP: + return "UP"; + case compute::RoundMode::TOWARDS_ZERO: + return "TOWARDS_ZERO"; + case compute::RoundMode::TOWARDS_INFINITY: + return "TOWARDS_INFINITY"; + case compute::RoundMode::HALF_DOWN: + return "HALF_DOWN"; + case compute::RoundMode::HALF_UP: + return "HALF_UP"; + case compute::RoundMode::HALF_TOWARDS_ZERO: + return "HALF_TOWARDS_ZERO"; + case compute::RoundMode::HALF_TOWARDS_INFINITY: + return "HALF_TOWARDS_INFINITY"; + case compute::RoundMode::HALF_TO_EVEN: + return "HALF_TO_EVEN"; + case compute::RoundMode::HALF_TO_ODD: + return "HALF_TO_ODD"; + } + return ""; + } +}; } // namespace internal namespace compute { @@ -115,6 +191,12 @@ static auto kArithmeticOptionsType = GetFunctionOptionsType( static auto kElementWiseAggregateOptionsType = GetFunctionOptionsType( DataMember("skip_nulls", &ElementWiseAggregateOptions::skip_nulls)); +static auto kRoundOptionsType = GetFunctionOptionsType( + DataMember("ndigits", &RoundOptions::ndigits), + DataMember("round_mode", &RoundOptions::round_mode)); +static auto kRoundToMultipleOptionsType = GetFunctionOptionsType( + DataMember("multiple", &RoundToMultipleOptions::multiple), + DataMember("round_mode", &RoundToMultipleOptions::round_mode)); static auto kJoinOptionsType = GetFunctionOptionsType( DataMember("null_handling", &JoinOptions::null_handling), DataMember("null_replacement", &JoinOptions::null_replacement)); @@ -147,6 +229,10 @@ static auto kStrptimeOptionsType = GetFunctionOptionsType( DataMember("unit", &StrptimeOptions::unit)); static auto kStrftimeOptionsType = GetFunctionOptionsType( DataMember("format", &StrftimeOptions::format)); +static auto kAssumeTimezoneOptionsType 
= GetFunctionOptionsType( + DataMember("timezone", &AssumeTimezoneOptions::timezone), + DataMember("ambiguous", &AssumeTimezoneOptions::ambiguous), + DataMember("nonexistent", &AssumeTimezoneOptions::nonexistent)); static auto kPadOptionsType = GetFunctionOptionsType( DataMember("width", &PadOptions::width), DataMember("padding", &PadOptions::padding)); static auto kTrimOptionsType = GetFunctionOptionsType( @@ -175,6 +261,30 @@ ElementWiseAggregateOptions::ElementWiseAggregateOptions(bool skip_nulls) skip_nulls(skip_nulls) {} constexpr char ElementWiseAggregateOptions::kTypeName[]; +RoundOptions::RoundOptions(int64_t ndigits, RoundMode round_mode) + : FunctionOptions(internal::kRoundOptionsType), + ndigits(ndigits), + round_mode(round_mode) { + static_assert(RoundMode::HALF_DOWN > RoundMode::DOWN && + RoundMode::HALF_DOWN > RoundMode::UP && + RoundMode::HALF_DOWN > RoundMode::TOWARDS_ZERO && + RoundMode::HALF_DOWN > RoundMode::TOWARDS_INFINITY && + RoundMode::HALF_DOWN < RoundMode::HALF_UP && + RoundMode::HALF_DOWN < RoundMode::HALF_TOWARDS_ZERO && + RoundMode::HALF_DOWN < RoundMode::HALF_TOWARDS_INFINITY && + RoundMode::HALF_DOWN < RoundMode::HALF_TO_EVEN && + RoundMode::HALF_DOWN < RoundMode::HALF_TO_ODD, + "Invalid order of round modes. Modes prefixed with HALF need to be " + "enumerated last with HALF_DOWN being the first among them."); +} +constexpr char RoundOptions::kTypeName[]; + +RoundToMultipleOptions::RoundToMultipleOptions(double multiple, RoundMode round_mode) + : FunctionOptions(internal::kRoundToMultipleOptionsType), + multiple(multiple), + round_mode(round_mode) {} +constexpr char RoundToMultipleOptions::kTypeName[]; + JoinOptions::JoinOptions(NullHandlingBehavior null_handling, std::string null_replacement) : FunctionOptions(internal::kJoinOptionsType), null_handling(null_handling), @@ -250,6 +360,15 @@ StrftimeOptions::StrftimeOptions() : StrftimeOptions(kDefaultFormat) {} constexpr char StrftimeOptions::kTypeName[]; constexpr const char* StrftimeOptions::kDefaultFormat; +AssumeTimezoneOptions::AssumeTimezoneOptions(std::string timezone, Ambiguous ambiguous, + Nonexistent nonexistent) + : FunctionOptions(internal::kAssumeTimezoneOptionsType), + timezone(std::move(timezone)), + ambiguous(ambiguous), + nonexistent(nonexistent) {} +AssumeTimezoneOptions::AssumeTimezoneOptions() : AssumeTimezoneOptions("UTC") {} +constexpr char AssumeTimezoneOptions::kTypeName[]; + PadOptions::PadOptions(int64_t width, std::string padding) : FunctionOptions(internal::kPadOptionsType), width(width), @@ -301,6 +420,8 @@ namespace internal { void RegisterScalarOptions(FunctionRegistry* registry) { DCHECK_OK(registry->AddFunctionOptionsType(kArithmeticOptionsType)); DCHECK_OK(registry->AddFunctionOptionsType(kElementWiseAggregateOptionsType)); + DCHECK_OK(registry->AddFunctionOptionsType(kRoundOptionsType)); + DCHECK_OK(registry->AddFunctionOptionsType(kRoundToMultipleOptionsType)); DCHECK_OK(registry->AddFunctionOptionsType(kJoinOptionsType)); DCHECK_OK(registry->AddFunctionOptionsType(kMatchSubstringOptionsType)); DCHECK_OK(registry->AddFunctionOptionsType(kSplitOptionsType)); @@ -311,6 +432,7 @@ void RegisterScalarOptions(FunctionRegistry* registry) { DCHECK_OK(registry->AddFunctionOptionsType(kSetLookupOptionsType)); DCHECK_OK(registry->AddFunctionOptionsType(kStrptimeOptionsType)); DCHECK_OK(registry->AddFunctionOptionsType(kStrftimeOptionsType)); + DCHECK_OK(registry->AddFunctionOptionsType(kAssumeTimezoneOptionsType)); DCHECK_OK(registry->AddFunctionOptionsType(kPadOptionsType)); 
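// The static_assert in the RoundOptions constructor above pins the RoundMode
// enumerator order. A plausible reason (an assumption; the patch does not say):
// kernels can then classify tie-breaking modes with a single comparison instead
// of enumerating every HALF_* value. Hypothetical helper, at namespace scope:
bool IsTieBreakingRoundMode(RoundMode mode) {
  // Valid only because HALF_DOWN is the first of the HALF_* enumerators,
  // which the static_assert enforces.
  return mode >= RoundMode::HALF_DOWN;
}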
DCHECK_OK(registry->AddFunctionOptionsType(kTrimOptionsType)); DCHECK_OK(registry->AddFunctionOptionsType(kSliceOptionsType)); @@ -353,6 +475,15 @@ SCALAR_ARITHMETIC_UNARY(Log10, "log10", "log10_checked") SCALAR_ARITHMETIC_UNARY(Log2, "log2", "log2_checked") SCALAR_ARITHMETIC_UNARY(Log1p, "log1p", "log1p_checked") +Result Round(const Datum& arg, RoundOptions options, ExecContext* ctx) { + return CallFunction("round", {arg}, &options, ctx); +} + +Result RoundToMultiple(const Datum& arg, RoundToMultipleOptions options, + ExecContext* ctx) { + return CallFunction("round_to_multiple", {arg}, &options, ctx); +} + #define SCALAR_ARITHMETIC_BINARY(NAME, REGISTRY_NAME, REGISTRY_CHECKED_NAME) \ Result NAME(const Datum& left, const Datum& right, ArithmeticOptions options, \ ExecContext* ctx) { \ @@ -512,6 +643,11 @@ Result DayOfWeek(const Datum& arg, DayOfWeekOptions options, ExecContext* return CallFunction("day_of_week", {arg}, &options, ctx); } +Result AssumeTimezone(const Datum& arg, AssumeTimezoneOptions options, + ExecContext* ctx) { + return CallFunction("assume_timezone", {arg}, &options, ctx); +} + Result Strftime(const Datum& arg, StrftimeOptions options, ExecContext* ctx) { return CallFunction("strftime", {arg}, &options, ctx); } diff --git a/cpp/src/arrow/compute/api_scalar.h b/cpp/src/arrow/compute/api_scalar.h index 2cbc0fde2b2..9f9a2931398 100644 --- a/cpp/src/arrow/compute/api_scalar.h +++ b/cpp/src/arrow/compute/api_scalar.h @@ -29,6 +29,7 @@ #include "arrow/result.h" #include "arrow/util/macros.h" #include "arrow/util/visibility.h" +#include "arrow/vendored/datetime.h" namespace arrow { namespace compute { @@ -49,10 +50,58 @@ class ARROW_EXPORT ElementWiseAggregateOptions : public FunctionOptions { explicit ElementWiseAggregateOptions(bool skip_nulls = true); constexpr static char const kTypeName[] = "ElementWiseAggregateOptions"; static ElementWiseAggregateOptions Defaults() { return ElementWiseAggregateOptions{}; } - bool skip_nulls; }; +/// Rounding and tie-breaking modes for round compute functions. +/// Additional details and examples are provided in compute.rst. 
+enum class RoundMode : int8_t { + /// Round to nearest integer less than or equal in magnitude (aka "floor") + DOWN, + /// Round to nearest integer greater than or equal in magnitude (aka "ceil") + UP, + /// Get the integral part without fractional digits (aka "trunc") + TOWARDS_ZERO, + /// Round negative values with DOWN rule and positive values with UP rule + TOWARDS_INFINITY, + /// Round ties with DOWN rule + HALF_DOWN, + /// Round ties with UP rule + HALF_UP, + /// Round ties with TOWARDS_ZERO rule + HALF_TOWARDS_ZERO, + /// Round ties with TOWARDS_INFINITY rule + HALF_TOWARDS_INFINITY, + /// Round ties to nearest even integer + HALF_TO_EVEN, + /// Round ties to nearest odd integer + HALF_TO_ODD, +}; + +class ARROW_EXPORT RoundOptions : public FunctionOptions { + public: + explicit RoundOptions(int64_t ndigits = 0, + RoundMode round_mode = RoundMode::HALF_TO_EVEN); + constexpr static char const kTypeName[] = "RoundOptions"; + static RoundOptions Defaults() { return RoundOptions(); } + /// Rounding precision (number of digits to round to) + int64_t ndigits; + /// Rounding and tie-breaking mode + RoundMode round_mode; +}; + +class ARROW_EXPORT RoundToMultipleOptions : public FunctionOptions { + public: + explicit RoundToMultipleOptions(double multiple = 1.0, + RoundMode round_mode = RoundMode::HALF_TO_EVEN); + constexpr static char const kTypeName[] = "RoundToMultipleOptions"; + static RoundToMultipleOptions Defaults() { return RoundToMultipleOptions(); } + /// Rounding scale (multiple to round to) + double multiple; + /// Rounding and tie-breaking mode + RoundMode round_mode; +}; + /// Options for var_args_join. class ARROW_EXPORT JoinOptions : public FunctionOptions { public: @@ -185,7 +234,7 @@ class ARROW_EXPORT StrftimeOptions : public FunctionOptions { constexpr static char const kTypeName[] = "StrftimeOptions"; - constexpr static const char* kDefaultFormat = "%Y-%m-%dT%H:%M:%SZ"; + constexpr static const char* kDefaultFormat = "%Y-%m-%dT%H:%M:%S"; /// The desired format string. std::string format; @@ -278,6 +327,40 @@ struct ARROW_EXPORT DayOfWeekOptions : public FunctionOptions { uint32_t week_start; }; +/// Used to control timestamp timezone conversion and handling ambiguous/nonexistent +/// times. +struct ARROW_EXPORT AssumeTimezoneOptions : public FunctionOptions { + public: + /// \brief How to interpret ambiguous local times that can be interpreted as + /// multiple instants (normally two) due to DST shifts. + /// + /// AMBIGUOUS_EARLIEST emits the earliest instant amongst possible interpretations. + /// AMBIGUOUS_LATEST emits the latest instant amongst possible interpretations. + enum Ambiguous { AMBIGUOUS_RAISE, AMBIGUOUS_EARLIEST, AMBIGUOUS_LATEST }; + + /// \brief How to handle local times that do not exist due to DST shifts. + /// + /// NONEXISTENT_EARLIEST emits the instant "just before" the DST shift instant + /// in the given timestamp precision (for example, for a nanoseconds precision + /// timestamp, this is one nanosecond before the DST shift instant). + /// NONEXISTENT_LATEST emits the DST shift instant. 
+ enum Nonexistent { NONEXISTENT_RAISE, NONEXISTENT_EARLIEST, NONEXISTENT_LATEST }; + + explicit AssumeTimezoneOptions(std::string timezone, + Ambiguous ambiguous = AMBIGUOUS_RAISE, + Nonexistent nonexistent = NONEXISTENT_RAISE); + AssumeTimezoneOptions(); + constexpr static char const kTypeName[] = "AssumeTimezoneOptions"; + + /// Timezone to convert timestamps from + std::string timezone; + + /// How to interpret ambiguous local times (due to DST shifts) + Ambiguous ambiguous; + /// How to interpret non-existent local times (due to DST shifts) + Nonexistent nonexistent; +}; + /// @} /// \brief Get the absolute value of a value. @@ -524,8 +607,9 @@ Result Logb(const Datum& arg, const Datum& base, ExecContext* ctx = NULLPTR); /// \brief Round to the nearest integer less than or equal in magnitude to the -/// argument. Array values can be of arbitrary length. If argument is null the -/// result will be null. +/// argument. +/// +/// If argument is null the result will be null. /// /// \param[in] arg the value to round /// \param[in] ctx the function execution context, optional @@ -534,8 +618,9 @@ ARROW_EXPORT Result Floor(const Datum& arg, ExecContext* ctx = NULLPTR); /// \brief Round to the nearest integer greater than or equal in magnitude to the -/// argument. Array values can be of arbitrary length. If argument is null the -/// result will be null. +/// argument. +/// +/// If argument is null the result will be null. /// /// \param[in] arg the value to round /// \param[in] ctx the function execution context, optional @@ -543,8 +628,9 @@ Result Floor(const Datum& arg, ExecContext* ctx = NULLPTR); ARROW_EXPORT Result Ceil(const Datum& arg, ExecContext* ctx = NULLPTR); -/// \brief Get the integral part without fractional digits. Array values can be -/// of arbitrary length. If argument is null the result will be null. +/// \brief Get the integral part without fractional digits. +/// +/// If argument is null the result will be null. /// /// \param[in] arg the value to truncate /// \param[in] ctx the function execution context, optional @@ -583,10 +669,35 @@ Result MinElementWise( /// /// \param[in] arg the value to extract sign from /// \param[in] ctx the function execution context, optional -/// \return the elementwise sign function +/// \return the element-wise sign function ARROW_EXPORT Result Sign(const Datum& arg, ExecContext* ctx = NULLPTR); +/// \brief Round a value to a given precision. +/// +/// If argument is null the result will be null. +/// +/// \param[in] arg the value to round +/// \param[in] options rounding options (rounding mode and number of digits), optional +/// \param[in] ctx the function execution context, optional +/// \return the element-wise rounded value +ARROW_EXPORT +Result Round(const Datum& arg, RoundOptions options = RoundOptions::Defaults(), + ExecContext* ctx = NULLPTR); + +/// \brief Round a value to a given multiple. +/// +/// If argument is null the result will be null. +/// +/// \param[in] arg the value to round +/// \param[in] options rounding options (rounding mode and multiple), optional +/// \param[in] ctx the function execution context, optional +/// \return the element-wise rounded value +ARROW_EXPORT +Result RoundToMultiple( + const Datum& arg, RoundToMultipleOptions options = RoundToMultipleOptions::Defaults(), + ExecContext* ctx = NULLPTR); + /// \brief Compare a numeric array with a scalar.
/// /// \param[in] left datum to compare, must be an Array @@ -1025,5 +1136,21 @@ ARROW_EXPORT Result Subsecond(const Datum& values, ExecContext* ctx = NULLPTR); ARROW_EXPORT Result Strftime(const Datum& values, StrftimeOptions options, ExecContext* ctx = NULLPTR); +/// \brief Convert timestamps without a timezone to timestamps with a timezone, +/// interpreting each local timestamp in `values` as being in the specified timezone. +/// +/// \param[in] values input to convert +/// \param[in] options for setting the source timezone and the handling of ambiguous and +/// nonexistent timestamps +/// \param[in] ctx the function execution context, optional +/// \return the resulting datum +/// +/// \since 6.0.0 +/// \note API not yet finalized +ARROW_EXPORT Result AssumeTimezone(const Datum& values, + AssumeTimezoneOptions options, + ExecContext* ctx = NULLPTR); + } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/api_vector.cc b/cpp/src/arrow/compute/api_vector.cc index d4c4a915999..34ee0599c3d 100644 --- a/cpp/src/arrow/compute/api_vector.cc +++ b/cpp/src/arrow/compute/api_vector.cc @@ -111,6 +111,9 @@ static auto kSortOptionsType = GetFunctionOptionsType(DataMember("sort_keys", &SortOptions::sort_keys)); static auto kPartitionNthOptionsType = GetFunctionOptionsType( DataMember("pivot", &PartitionNthOptions::pivot)); +static auto kSelectKOptionsType = GetFunctionOptionsType( + DataMember("k", &SelectKOptions::k), + DataMember("sort_keys", &SelectKOptions::sort_keys)); } // namespace } // namespace internal @@ -140,6 +143,12 @@ PartitionNthOptions::PartitionNthOptions(int64_t pivot) : FunctionOptions(internal::kPartitionNthOptionsType), pivot(pivot) {} constexpr char PartitionNthOptions::kTypeName[]; +SelectKOptions::SelectKOptions(int64_t k, std::vector sort_keys) + : FunctionOptions(internal::kSelectKOptionsType), + k(k), + sort_keys(std::move(sort_keys)) {} +constexpr char SelectKOptions::kTypeName[]; + namespace internal { void RegisterVectorOptions(FunctionRegistry* registry) { DCHECK_OK(registry->AddFunctionOptionsType(kFilterOptionsType)); @@ -148,6 +157,7 @@ void RegisterVectorOptions(FunctionRegistry* registry) { DCHECK_OK(registry->AddFunctionOptionsType(kArraySortOptionsType)); DCHECK_OK(registry->AddFunctionOptionsType(kSortOptionsType)); DCHECK_OK(registry->AddFunctionOptionsType(kPartitionNthOptionsType)); + DCHECK_OK(registry->AddFunctionOptionsType(kSelectKOptionsType)); } } // namespace internal @@ -162,6 +172,14 @@ Result> NthToIndices(const Array& values, int64_t n, return result.make_array(); } +Result> SelectKUnstable(const Datum& datum, + const SelectKOptions& options, + ExecContext* ctx) { + ARROW_ASSIGN_OR_RAISE(Datum result, + CallFunction("select_k_unstable", {datum}, &options, ctx)); + return result.make_array(); +} + Result ReplaceWithMask(const Datum& values, const Datum& mask, const Datum& replacements, ExecContext* ctx) { return CallFunction("replace_with_mask", {values, mask, replacements}, ctx); diff --git a/cpp/src/arrow/compute/api_vector.h b/cpp/src/arrow/compute/api_vector.h index 5dc68fc5c83..a1c6f7959e1 100644 --- a/cpp/src/arrow/compute/api_vector.h +++ b/cpp/src/arrow/compute/api_vector.h @@ -120,6 +120,43 @@ class ARROW_EXPORT SortOptions : public FunctionOptions { std::vector sort_keys; }; +/// \brief Options for the select-k function +class ARROW_EXPORT SelectKOptions : public FunctionOptions { + public: + explicit SelectKOptions(int64_t k = -1, std::vector sort_keys = {}); + + constexpr static char const kTypeName[] =
"SelectKOptions"; + + static SelectKOptions Defaults() { return SelectKOptions{-1, {}}; } + + static SelectKOptions TopKDefault(int64_t k, std::vector key_names = {}) { + std::vector keys; + for (const auto& name : key_names) { + keys.emplace_back(SortKey(name, SortOrder::Descending)); + } + if (key_names.empty()) { + keys.emplace_back(SortKey("not-used", SortOrder::Descending)); + } + return SelectKOptions{k, keys}; + } + static SelectKOptions BottomKDefault(int64_t k, + std::vector key_names = {}) { + std::vector keys; + for (const auto& name : key_names) { + keys.emplace_back(SortKey(name, SortOrder::Ascending)); + } + if (key_names.empty()) { + keys.emplace_back(SortKey("not-used", SortOrder::Ascending)); + } + return SelectKOptions{k, keys}; + } + + /// The number of `k` elements to keep. + int64_t k; + /// Column key(s) to order by and how to order by these sort keys. + std::vector sort_keys; +}; + /// \brief Partitioning options for NthToIndices class ARROW_EXPORT PartitionNthOptions : public FunctionOptions { public: @@ -252,6 +289,24 @@ ARROW_EXPORT Result> NthToIndices(const Array& values, int64_t n, ExecContext* ctx = NULLPTR); +/// \brief Returns the indices that would select the first `k` elements of the array in +/// the specified order. +/// +// Perform an indirect sort of the datum, keeping only the first `k` elements. The output +// array will contain indices such that the item indicated by the k-th index will be in +// the position it would be if the datum were sorted by `options.sort_keys`. However, +// indices of null values will not be part of the output. The sort is not guaranteed to be +// stable. +/// +/// \param[in] datum datum to be partitioned +/// \param[in] options options +/// \param[in] ctx the function execution context, optional +/// \return a datum with the same schema as the input +ARROW_EXPORT +Result> SelectKUnstable(const Datum& datum, + const SelectKOptions& options, + ExecContext* ctx = NULLPTR); + /// \brief Returns the indices that would sort an array in the /// specified order. /// diff --git a/cpp/src/arrow/compute/exec/util.h b/cpp/src/arrow/compute/exec/util.h index f5c55afe0f5..63f3315f7e0 100644 --- a/cpp/src/arrow/compute/exec/util.h +++ b/cpp/src/arrow/compute/exec/util.h @@ -70,6 +70,8 @@ class TempVectorStack { top_ = 0; buffer_size_ = size; ARROW_ASSIGN_OR_RAISE(auto buffer, AllocateResizableBuffer(size, pool)); + // Ensure later operations don't accidentally read uninitialized memory. + std::memset(buffer->mutable_data(), 0xFF, size); buffer_ = std::move(buffer); return Status::OK(); } diff --git a/cpp/src/arrow/compute/function.h b/cpp/src/arrow/compute/function.h index 6434d5090f6..f08b50699a5 100644 --- a/cpp/src/arrow/compute/function.h +++ b/cpp/src/arrow/compute/function.h @@ -227,7 +227,7 @@ class ARROW_EXPORT Function { virtual Result Execute(const std::vector& args, const FunctionOptions* options, ExecContext* ctx) const; - /// \brief Returns a the default options for this function. + /// \brief Returns the default options for this function. /// /// Whatever option semantics a Function has, implementations must guarantee /// that default_options() is valid to pass to Execute as options. diff --git a/cpp/src/arrow/compute/function_test.cc b/cpp/src/arrow/compute/function_test.cc index d7ebdf3de1d..183167490b6 100644 --- a/cpp/src/arrow/compute/function_test.cc +++ b/cpp/src/arrow/compute/function_test.cc @@ -15,17 +15,18 @@ // specific language governing permissions and limitations // under the License. 
+#include "arrow/compute/function.h" + +#include + #include #include #include -#include - #include "arrow/compute/api_aggregate.h" #include "arrow/compute/api_scalar.h" #include "arrow/compute/api_vector.h" #include "arrow/compute/cast.h" -#include "arrow/compute/function.h" #include "arrow/compute/kernel.h" #include "arrow/datum.h" #include "arrow/status.h" @@ -57,6 +58,12 @@ TEST(FunctionOptions, Equality) { options.emplace_back(new IndexOptions(ScalarFromJSON(boolean(), "null"))); options.emplace_back(new ArithmeticOptions()); options.emplace_back(new ArithmeticOptions(/*check_overflow=*/true)); + options.emplace_back(new RoundOptions()); + options.emplace_back( + new RoundOptions(/*ndigits=*/2, /*round_mode=*/RoundMode::TOWARDS_INFINITY)); + options.emplace_back(new RoundToMultipleOptions()); + options.emplace_back(new RoundToMultipleOptions( + /*multiple=*/100, /*round_mode=*/RoundMode::TOWARDS_INFINITY)); options.emplace_back(new ElementWiseAggregateOptions()); options.emplace_back(new ElementWiseAggregateOptions(/*skip_nulls=*/false)); options.emplace_back(new JoinOptions()); @@ -80,6 +87,11 @@ TEST(FunctionOptions, Equality) { options.emplace_back(new StrptimeOptions("%Y", TimeUnit::type::MILLI)); options.emplace_back(new StrptimeOptions("%Y", TimeUnit::type::NANO)); options.emplace_back(new StrftimeOptions("%Y-%m-%dT%H:%M:%SZ", "C")); +#ifndef _WIN32 + options.emplace_back(new AssumeTimezoneOptions( + "Europe/Amsterdam", AssumeTimezoneOptions::Ambiguous::AMBIGUOUS_RAISE, + AssumeTimezoneOptions::Nonexistent::NONEXISTENT_RAISE)); +#endif options.emplace_back(new PadOptions(5, " ")); options.emplace_back(new PadOptions(10, "A")); options.emplace_back(new TrimOptions(" ")); @@ -110,6 +122,8 @@ TEST(FunctionOptions, Equality) { {SortKey("key", SortOrder::Descending), SortKey("value", SortOrder::Descending)})); options.emplace_back(new PartitionNthOptions(/*pivot=*/0)); options.emplace_back(new PartitionNthOptions(/*pivot=*/42)); + options.emplace_back(new SelectKOptions(0, {})); + options.emplace_back(new SelectKOptions(5, {{SortKey("key", SortOrder::Ascending)}})); for (size_t i = 0; i < options.size(); i++) { const size_t prev_i = i == 0 ? options.size() - 1 : i - 1; diff --git a/cpp/src/arrow/compute/kernel.h b/cpp/src/arrow/compute/kernel.h index 099bd95bbf2..01750d1f359 100644 --- a/cpp/src/arrow/compute/kernel.h +++ b/cpp/src/arrow/compute/kernel.h @@ -290,9 +290,11 @@ class ARROW_EXPORT OutputType { enum ResolveKind { FIXED, COMPUTED }; /// Type resolution function. Given input types and shapes, return output - /// type and shape. This function SHOULD _not_ be used to check for arity, - /// that is to be performed one or more layers above. May make use of kernel - /// state to know what type to output in some cases. + /// type and shape. This function MAY may use the kernel state to decide + /// the output type based on the functionoptions. + /// + /// This function SHOULD _not_ be used to check for arity, that is to be + /// performed one or more layers above. 
using Resolver = std::function(KernelContext*, const std::vector&)>; @@ -304,7 +306,8 @@ class ARROW_EXPORT OutputType { /// \brief Output the exact type and shape provided by a ValueDescr OutputType(ValueDescr descr); // NOLINT implicit construction - explicit OutputType(Resolver resolver) + /// \brief Output a computed type depending on actual input types + OutputType(Resolver resolver) // NOLINT implicit construction : kind_(COMPUTED), resolver_(std::move(resolver)) {} OutputType(const OutputType& other) { diff --git a/cpp/src/arrow/compute/kernels/CMakeLists.txt b/cpp/src/arrow/compute/kernels/CMakeLists.txt index 4096e497c0a..ce7a85f1557 100644 --- a/cpp/src/arrow/compute/kernels/CMakeLists.txt +++ b/cpp/src/arrow/compute/kernels/CMakeLists.txt @@ -50,11 +50,13 @@ add_arrow_compute_test(vector_test vector_replace_test.cc vector_selection_test.cc vector_sort_test.cc + select_k_test.cc test_util.cc) add_arrow_benchmark(vector_hash_benchmark PREFIX "arrow-compute") add_arrow_benchmark(vector_sort_benchmark PREFIX "arrow-compute") add_arrow_benchmark(vector_partition_benchmark PREFIX "arrow-compute") +add_arrow_benchmark(vector_topk_benchmark PREFIX "arrow-compute") add_arrow_benchmark(vector_replace_benchmark PREFIX "arrow-compute") add_arrow_benchmark(vector_selection_benchmark PREFIX "arrow-compute") diff --git a/cpp/src/arrow/compute/kernels/aggregate_basic.cc b/cpp/src/arrow/compute/kernels/aggregate_basic.cc index 2952eade96b..b19536d33ab 100644 --- a/cpp/src/arrow/compute/kernels/aggregate_basic.cc +++ b/cpp/src/arrow/compute/kernels/aggregate_basic.cc @@ -549,6 +549,9 @@ struct IndexInit { static Result> Init(KernelContext* ctx, const KernelInitArgs& args) { + if (!args.options) { + return Status::Invalid("Must provide IndexOptions for index kernel"); + } IndexInit visitor(ctx, static_cast(*args.options), *args.inputs[0].type); return visitor.Create(); diff --git a/cpp/src/arrow/compute/kernels/aggregate_mode.cc b/cpp/src/arrow/compute/kernels/aggregate_mode.cc index 6ad0eeb6456..f225f6bf569 100644 --- a/cpp/src/arrow/compute/kernels/aggregate_mode.cc +++ b/cpp/src/arrow/compute/kernels/aggregate_mode.cc @@ -130,6 +130,13 @@ struct CountModer { Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { // count values in all chunks, ignore nulls const Datum& datum = batch[0]; + + const ModeOptions& options = ModeState::Get(ctx); + if ((!options.skip_nulls && datum.null_count() > 0) || + (datum.length() - datum.null_count() < options.min_count)) { + return PrepareOutput(/*n=*/0, ctx, out).status(); + } + CountValues(this->counts.data(), datum, this->min); // generator to emit next value:count pair @@ -154,9 +161,16 @@ struct CountModer { template <> struct CountModer { Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + const Datum& datum = batch[0]; + + const ModeOptions& options = ModeState::Get(ctx); + if ((!options.skip_nulls && datum.null_count() > 0) || + (datum.length() - datum.null_count() < options.min_count)) { + return PrepareOutput(/*n=*/0, ctx, out).status(); + } + int64_t counts[2]{}; - const Datum& datum = batch[0]; for (const auto& array : datum.chunks()) { if (array->length() > array->null_count()) { const int64_t true_count = @@ -167,7 +181,6 @@ struct CountModer { } } - const ModeOptions& options = ModeState::Get(ctx); const int64_t distinct_values = (counts[0] != 0) + (counts[1] != 0); const int64_t n = std::min(options.n, distinct_values); @@ -198,12 +211,19 @@ struct SortModer { using Allocator = arrow::stl::allocator; Status 
Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + const Datum& datum = batch[0]; + const int64_t in_length = datum.length() - datum.null_count(); + + const ModeOptions& options = ModeState::Get(ctx); + if ((!options.skip_nulls && datum.null_count() > 0) || + (in_length < options.min_count)) { + return PrepareOutput(/*n=*/0, ctx, out).status(); + } + // copy all chunks to a buffer, ignore nulls and nans std::vector in_buffer(Allocator(ctx->memory_pool())); uint64_t nan_count = 0; - const Datum& datum = batch[0]; - const int64_t in_length = datum.length() - datum.null_count(); if (in_length > 0) { in_buffer.resize(in_length); CopyNonNullValues(datum, in_buffer.data()); @@ -305,6 +325,13 @@ struct Moder::value>> { template Status ScalarMode(KernelContext* ctx, const Scalar& scalar, Datum* out) { using CType = typename T::c_type; + + const ModeOptions& options = ModeState::Get(ctx); + if ((!options.skip_nulls && !scalar.is_valid) || + (static_cast(scalar.is_valid) < options.min_count)) { + return PrepareOutput(/*n=*/0, ctx, out).status(); + } + if (scalar.is_valid) { bool called = false; return Finalize(ctx, out, [&]() { diff --git a/cpp/src/arrow/compute/kernels/aggregate_quantile.cc b/cpp/src/arrow/compute/kernels/aggregate_quantile.cc index 7d2ffe0770c..bfd97f813e5 100644 --- a/cpp/src/arrow/compute/kernels/aggregate_quantile.cc +++ b/cpp/src/arrow/compute/kernels/aggregate_quantile.cc @@ -79,12 +79,18 @@ struct SortQuantiler { Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { const QuantileOptions& options = QuantileState::Get(ctx); + const Datum& datum = batch[0]; // copy all chunks to a buffer, ignore nulls and nans std::vector in_buffer(Allocator(ctx->memory_pool())); + int64_t in_length = 0; + if ((!options.skip_nulls && datum.null_count() > 0) || + (datum.length() - datum.null_count() < options.min_count)) { + in_length = 0; + } else { + in_length = datum.length() - datum.null_count(); + } - const Datum& datum = batch[0]; - const int64_t in_length = datum.length() - datum.null_count(); if (in_length > 0) { in_buffer.resize(in_length); CopyNonNullValues(datum, in_buffer.data()); @@ -232,7 +238,11 @@ struct CountQuantiler { // count values in all chunks, ignore nulls const Datum& datum = batch[0]; - int64_t in_length = CountValues(this->counts.data(), datum, this->min); + int64_t in_length = 0; + if ((options.skip_nulls || (!options.skip_nulls && datum.null_count() == 0)) && + (datum.length() - datum.null_count() >= options.min_count)) { + in_length = CountValues(this->counts.data(), datum, this->min); + } // prepare out array int64_t out_length = options.q.size(); @@ -394,7 +404,7 @@ Status ScalarQuantile(KernelContext* ctx, const QuantileOptions& options, const Scalar& scalar, Datum* out) { using CType = typename T::c_type; ArrayData* output = out->mutable_array(); - if (!scalar.is_valid) { + if (!scalar.is_valid || options.min_count > 1) { output->length = 0; output->null_count = 0; return Status::OK(); diff --git a/cpp/src/arrow/compute/kernels/aggregate_tdigest.cc b/cpp/src/arrow/compute/kernels/aggregate_tdigest.cc index be8d66c4c24..3b616c664a9 100644 --- a/cpp/src/arrow/compute/kernels/aggregate_tdigest.cc +++ b/cpp/src/arrow/compute/kernels/aggregate_tdigest.cc @@ -37,14 +37,23 @@ struct TDigestImpl : public ScalarAggregator { using CType = typename ArrowType::c_type; explicit TDigestImpl(const TDigestOptions& options) - : q{options.q}, tdigest{options.delta, options.buffer_size} {} + : options{options}, + tdigest{options.delta, 
options.buffer_size}, + count{0}, + all_valid{true} {} Status Consume(KernelContext*, const ExecBatch& batch) override { + if (!this->all_valid) return Status::OK(); + if (!options.skip_nulls && batch[0].null_count() > 0) { + this->all_valid = false; + return Status::OK(); + } if (batch[0].is_array()) { const ArrayData& data = *batch[0].array(); const CType* values = data.GetValues(1); if (data.length > data.GetNullCount()) { + this->count += data.length - data.GetNullCount(); VisitSetBitRunsVoid(data.buffers[0], data.offset, data.length, [&](int64_t pos, int64_t len) { for (int64_t i = 0; i < len; ++i) { @@ -55,6 +64,7 @@ struct TDigestImpl : public ScalarAggregator { } else { const CType value = UnboxScalar::Unbox(*batch[0].scalar()); if (batch[0].scalar()->is_valid) { + this->count += 1; for (int64_t i = 0; i < batch.length; i++) { this->tdigest.NanAdd(value); } @@ -64,13 +74,21 @@ struct TDigestImpl : public ScalarAggregator { } Status MergeFrom(KernelContext*, KernelState&& src) override { - auto& other = checked_cast(src); + const auto& other = checked_cast(src); + if (!this->all_valid || !other.all_valid) { + this->all_valid = false; + return Status::OK(); + } this->tdigest.Merge(other.tdigest); + this->count += other.count; return Status::OK(); } Status Finalize(KernelContext* ctx, Datum* out) override { - const int64_t out_length = this->tdigest.is_empty() ? 0 : this->q.size(); + const int64_t out_length = + (this->tdigest.is_empty() || !this->all_valid || this->count < options.min_count) + ? 0 + : options.q.size(); auto out_data = ArrayData::Make(float64(), out_length, 0); out_data->buffers.resize(2, nullptr); @@ -79,7 +97,7 @@ struct TDigestImpl : public ScalarAggregator { ctx->Allocate(out_length * sizeof(double))); double* out_buffer = out_data->template GetMutableValues(1); for (int64_t i = 0; i < out_length; ++i) { - out_buffer[i] = this->tdigest.Quantile(this->q[i]); + out_buffer[i] = this->tdigest.Quantile(this->options.q[i]); } } @@ -87,8 +105,10 @@ struct TDigestImpl : public ScalarAggregator { return Status::OK(); } - const std::vector q; + const TDigestOptions options; TDigest tdigest; + int64_t count; + bool all_valid; }; struct TDigestInitState { diff --git a/cpp/src/arrow/compute/kernels/aggregate_test.cc b/cpp/src/arrow/compute/kernels/aggregate_test.cc index eb73e703b6e..fcf48e25a92 100644 --- a/cpp/src/arrow/compute/kernels/aggregate_test.cc +++ b/cpp/src/arrow/compute/kernels/aggregate_test.cc @@ -1853,6 +1853,10 @@ TYPED_TEST(TestNumericIndexKernel, Basics) { this->AssertIndexIs(chunked_input2, value, 4); this->AssertIndexIs(chunked_input3, value, -1); this->AssertIndexIs(chunked_input4, value, 5); + + EXPECT_RAISES_WITH_MESSAGE_THAT( + Invalid, ::testing::HasSubstr("Must provide IndexOptions"), + CallFunction("index", {ArrayFromJSON(this->type_singleton(), "[0]")})); } TYPED_TEST(TestNumericIndexKernel, Random) { constexpr auto kChunks = 4; @@ -1954,10 +1958,10 @@ class TestPrimitiveModeKernel : public ::testing::Test { using Traits = TypeTraits; using CType = typename ArrowType::c_type; - void AssertModesAre(const Datum& array, const int n, + void AssertModesAre(const Datum& array, const ModeOptions options, const std::vector& expected_modes, const std::vector& expected_counts) { - ASSERT_OK_AND_ASSIGN(Datum out, Mode(array, ModeOptions{n})); + ASSERT_OK_AND_ASSIGN(Datum out, Mode(array, options)); ValidateOutput(out); const StructArray out_array(out.array()); ASSERT_EQ(out_array.length(), expected_modes.size()); @@ -1978,11 +1982,18 @@ class 
TestPrimitiveModeKernel : public ::testing::Test { const std::vector& expected_modes, const std::vector& expected_counts) { auto array = ArrayFromJSON(type_singleton(), json); - AssertModesAre(array, n, expected_modes, expected_counts); + AssertModesAre(array, ModeOptions(n), expected_modes, expected_counts); + } + + void AssertModesAre(const std::string& json, const ModeOptions options, + const std::vector& expected_modes, + const std::vector& expected_counts) { + auto array = ArrayFromJSON(type_singleton(), json); + AssertModesAre(array, options, expected_modes, expected_counts); } void AssertModeIs(const Datum& array, CType expected_mode, int64_t expected_count) { - AssertModesAre(array, 1, {expected_mode}, {expected_count}); + AssertModesAre(array, ModeOptions(1), {expected_mode}, {expected_count}); } void AssertModeIs(const std::string& json, CType expected_mode, @@ -1997,8 +2008,8 @@ class TestPrimitiveModeKernel : public ::testing::Test { AssertModeIs(chunked, expected_mode, expected_count); } - void AssertModesEmpty(const Datum& array, int n) { - ASSERT_OK_AND_ASSIGN(Datum out, Mode(array, ModeOptions{n})); + void AssertModesEmpty(const Datum& array, ModeOptions options) { + ASSERT_OK_AND_ASSIGN(Datum out, Mode(array, options)); auto out_array = out.make_array(); ValidateOutput(*out_array); ASSERT_EQ(out.array()->length, 0); @@ -2006,12 +2017,17 @@ class TestPrimitiveModeKernel : public ::testing::Test { void AssertModesEmpty(const std::string& json, int n = 1) { auto array = ArrayFromJSON(type_singleton(), json); - AssertModesEmpty(array, n); + AssertModesEmpty(array, ModeOptions(n)); } void AssertModesEmpty(const std::vector& json, int n = 1) { auto chunked = ChunkedArrayFromJSON(type_singleton(), json); - AssertModesEmpty(chunked, n); + AssertModesEmpty(chunked, ModeOptions(n)); + } + + void AssertModesEmpty(const std::string& json, ModeOptions options) { + auto array = ArrayFromJSON(type_singleton(), json); + AssertModesEmpty(array, options); } std::shared_ptr type_singleton() { return Traits::type_singleton(); } @@ -2049,13 +2065,37 @@ TEST_F(TestBooleanModeKernel, Basics) { {true, false}, {3, 2}); this->AssertModesEmpty({"[null, null]", "[]", "[null]"}, 4); - auto ty = struct_({field("mode", boolean()), field("count", int64())}); - Datum mode_true = ArrayFromJSON(ty, "[[true, 1]]"); - Datum mode_false = ArrayFromJSON(ty, "[[false, 1]]"); - Datum mode_empty = ArrayFromJSON(ty, "[]"); - EXPECT_THAT(Mode(Datum(true)), ResultWith(mode_true)); - EXPECT_THAT(Mode(Datum(false)), ResultWith(mode_false)); - EXPECT_THAT(Mode(MakeNullScalar(boolean())), ResultWith(mode_empty)); + auto in_ty = boolean(); + this->AssertModesAre("[true, false, false, null]", ModeOptions(/*n=*/1), {false}, {2}); + this->AssertModesEmpty("[true, false, false, null]", + ModeOptions(/*n=*/1, /*skip_nulls=*/false)); + this->AssertModesAre("[true, false, false, null]", + ModeOptions(/*n=*/1, /*skip_nulls=*/true, /*min_count=*/3), + {false}, {2}); + this->AssertModesEmpty("[false, false, null]", + ModeOptions(/*n=*/1, /*skip_nulls=*/true, /*min_count=*/3)); + this->AssertModesAre("[true, false, false]", + ModeOptions(/*n=*/1, /*skip_nulls=*/false, /*min_count=*/3), + {false}, {2}); + this->AssertModesEmpty("[true, false, false, null]", + ModeOptions(/*n=*/1, /*skip_nulls=*/false, /*min_count=*/3)); + this->AssertModesEmpty("[true, false]", + ModeOptions(/*n=*/1, /*skip_nulls=*/false, /*min_count=*/3)); + this->AssertModesAre(ScalarFromJSON(in_ty, "true"), + ModeOptions(/*n=*/1, /*skip_nulls=*/false), {true}, 
{1}); + this->AssertModesEmpty(ScalarFromJSON(in_ty, "null"), + ModeOptions(/*n=*/1, /*skip_nulls=*/false)); + this->AssertModesEmpty(ScalarFromJSON(in_ty, "true"), + ModeOptions(/*n=*/1, /*skip_nulls=*/true, /*min_count=*/2)); + this->AssertModesEmpty(ScalarFromJSON(in_ty, "null"), + ModeOptions(/*n=*/1, /*skip_nulls=*/true, /*min_count=*/2)); + this->AssertModesEmpty(ScalarFromJSON(in_ty, "true"), + ModeOptions(/*n=*/1, /*skip_nulls=*/false, /*min_count=*/2)); + this->AssertModesEmpty(ScalarFromJSON(in_ty, "null"), + ModeOptions(/*n=*/1, /*skip_nulls=*/false, /*min_count=*/2)); + + this->AssertModesAre(ScalarFromJSON(in_ty, "true"), ModeOptions(/*n=*/1), {true}, {1}); + this->AssertModesEmpty(ScalarFromJSON(in_ty, "null"), ModeOptions(/*n=*/1)); } TYPED_TEST_SUITE(TestIntegerModeKernel, IntegralArrowTypes); @@ -2077,10 +2117,35 @@ TYPED_TEST(TestIntegerModeKernel, Basics) { this->AssertModesEmpty("[null, null, null]", 10); auto in_ty = this->type_singleton(); - auto ty = struct_({field("mode", in_ty), field("count", int64())}); - EXPECT_THAT(Mode(*MakeScalar(in_ty, 5)), - ResultWith(Datum(ArrayFromJSON(ty, "[[5, 1]]")))); - EXPECT_THAT(Mode(MakeNullScalar(in_ty)), ResultWith(Datum(ArrayFromJSON(ty, "[]")))); + + this->AssertModesAre("[1, 2, 2, null]", ModeOptions(/*n=*/1), {2}, {2}); + this->AssertModesEmpty("[1, 2, 2, null]", ModeOptions(/*n=*/1, /*skip_nulls=*/false)); + this->AssertModesAre("[1, 2, 2, null]", + ModeOptions(/*n=*/1, /*skip_nulls=*/true, /*min_count=*/3), {2}, + {2}); + this->AssertModesEmpty("[2, 2, null]", + ModeOptions(/*n=*/1, /*skip_nulls=*/true, /*min_count=*/3)); + this->AssertModesAre( + "[1, 2, 2]", ModeOptions(/*n=*/1, /*skip_nulls=*/false, /*min_count=*/3), {2}, {2}); + this->AssertModesEmpty("[1, 2, 2, null]", + ModeOptions(/*n=*/1, /*skip_nulls=*/false, /*min_count=*/3)); + this->AssertModesEmpty("[1, 2]", + ModeOptions(/*n=*/1, /*skip_nulls=*/false, /*min_count=*/3)); + this->AssertModesAre(ScalarFromJSON(in_ty, "1"), + ModeOptions(/*n=*/1, /*skip_nulls=*/false), {1}, {1}); + this->AssertModesEmpty(ScalarFromJSON(in_ty, "null"), + ModeOptions(/*n=*/1, /*skip_nulls=*/false)); + this->AssertModesEmpty(ScalarFromJSON(in_ty, "1"), + ModeOptions(/*n=*/1, /*skip_nulls=*/true, /*min_count=*/2)); + this->AssertModesEmpty(ScalarFromJSON(in_ty, "null"), + ModeOptions(/*n=*/1, /*skip_nulls=*/true, /*min_count=*/2)); + this->AssertModesEmpty(ScalarFromJSON(in_ty, "1"), + ModeOptions(/*n=*/1, /*skip_nulls=*/false, /*min_count=*/2)); + this->AssertModesEmpty(ScalarFromJSON(in_ty, "null"), + ModeOptions(/*n=*/1, /*skip_nulls=*/false, /*min_count=*/2)); + + this->AssertModesAre(ScalarFromJSON(in_ty, "5"), ModeOptions(/*n=*/1), {5}, {1}); + this->AssertModesEmpty(ScalarFromJSON(in_ty, "null"), ModeOptions(/*n=*/1)); } TYPED_TEST_SUITE(TestFloatingModeKernel, RealArrowTypes); @@ -2108,10 +2173,35 @@ TYPED_TEST(TestFloatingModeKernel, Floats) { this->AssertModesAre("[NaN, NaN, 1, null, 1, 2, 2]", 3, {1, 2, NAN}, {2, 2, 2}); auto in_ty = this->type_singleton(); - auto ty = struct_({field("mode", in_ty), field("count", int64())}); - EXPECT_THAT(Mode(*MakeScalar(in_ty, 5.0)), - ResultWith(Datum(ArrayFromJSON(ty, "[[5.0, 1]]")))); - EXPECT_THAT(Mode(MakeNullScalar(in_ty)), ResultWith(Datum(ArrayFromJSON(ty, "[]")))); + + this->AssertModesAre("[1, 2, 2, null]", ModeOptions(/*n=*/1), {2}, {2}); + this->AssertModesEmpty("[1, 2, 2, null]", ModeOptions(/*n=*/1, /*skip_nulls=*/false)); + this->AssertModesAre("[1, 2, 2, null]", + ModeOptions(/*n=*/1, /*skip_nulls=*/true, 
/*min_count=*/3), {2}, + {2}); + this->AssertModesEmpty("[2, 2, null]", + ModeOptions(/*n=*/1, /*skip_nulls=*/true, /*min_count=*/3)); + this->AssertModesAre( + "[1, 2, 2]", ModeOptions(/*n=*/1, /*skip_nulls=*/false, /*min_count=*/3), {2}, {2}); + this->AssertModesEmpty("[1, 2, 2, null]", + ModeOptions(/*n=*/1, /*skip_nulls=*/false, /*min_count=*/3)); + this->AssertModesEmpty("[1, 2]", + ModeOptions(/*n=*/1, /*skip_nulls=*/false, /*min_count=*/3)); + this->AssertModesAre(ScalarFromJSON(in_ty, "1"), + ModeOptions(/*n=*/1, /*skip_nulls=*/false), {1}, {1}); + this->AssertModesEmpty(ScalarFromJSON(in_ty, "null"), + ModeOptions(/*n=*/1, /*skip_nulls=*/false)); + this->AssertModesEmpty(ScalarFromJSON(in_ty, "1"), + ModeOptions(/*n=*/1, /*skip_nulls=*/true, /*min_count=*/2)); + this->AssertModesEmpty(ScalarFromJSON(in_ty, "null"), + ModeOptions(/*n=*/1, /*skip_nulls=*/true, /*min_count=*/2)); + this->AssertModesEmpty(ScalarFromJSON(in_ty, "1"), + ModeOptions(/*n=*/1, /*skip_nulls=*/false, /*min_count=*/2)); + this->AssertModesEmpty(ScalarFromJSON(in_ty, "null"), + ModeOptions(/*n=*/1, /*skip_nulls=*/false, /*min_count=*/2)); + + this->AssertModesAre(ScalarFromJSON(in_ty, "5"), ModeOptions(/*n=*/1), {5}, {1}); + this->AssertModesEmpty(ScalarFromJSON(in_ty, "null"), ModeOptions(/*n=*/1)); } TEST_F(TestInt8ModeKernelValueRange, Basics) { @@ -2672,6 +2762,36 @@ TYPED_TEST(TestIntegerQuantileKernel, Basics) { this->AssertQuantilesEmpty({"[null, null]", "[]", "[null]"}, {0.3, 0.4}); auto ty = this->type_singleton(); + + QuantileOptions keep_nulls(/*q=*/0.5, QuantileOptions::LINEAR, /*skip_nulls=*/false, + /*min_count=*/0); + QuantileOptions min_count(/*q=*/0.5, QuantileOptions::LINEAR, /*skip_nulls=*/true, + /*min_count=*/3); + QuantileOptions keep_nulls_min_count(/*q=*/0.5, QuantileOptions::LINEAR, + /*skip_nulls=*/false, /*min_count=*/3); + auto not_empty = ResultWith(ArrayFromJSON(float64(), "[3.0]")); + auto empty = ResultWith(ArrayFromJSON(float64(), "[]")); + EXPECT_THAT(Quantile(ArrayFromJSON(ty, "[1, 2, 4, 5]"), keep_nulls), not_empty); + EXPECT_THAT(Quantile(ArrayFromJSON(ty, "[1, 2, 4, 5, null]"), keep_nulls), empty); + EXPECT_THAT(Quantile(ArrayFromJSON(ty, "[1, 5]"), keep_nulls), not_empty); + EXPECT_THAT(Quantile(ArrayFromJSON(ty, "[1, 5, null]"), keep_nulls), empty); + EXPECT_THAT(Quantile(ScalarFromJSON(ty, "3"), keep_nulls), not_empty); + EXPECT_THAT(Quantile(ScalarFromJSON(ty, "null"), keep_nulls), empty); + EXPECT_THAT(Quantile(ArrayFromJSON(ty, "[1, 2, 4, 5]"), min_count), not_empty); + EXPECT_THAT(Quantile(ArrayFromJSON(ty, "[1, 2, 4, 5, null]"), min_count), not_empty); + EXPECT_THAT(Quantile(ArrayFromJSON(ty, "[1, 5]"), min_count), empty); + EXPECT_THAT(Quantile(ArrayFromJSON(ty, "[1, 5, null]"), min_count), empty); + EXPECT_THAT(Quantile(ScalarFromJSON(ty, "3"), min_count), empty); + EXPECT_THAT(Quantile(ScalarFromJSON(ty, "null"), min_count), empty); + EXPECT_THAT(Quantile(ArrayFromJSON(ty, "[1, 2, 4, 5]"), keep_nulls_min_count), + not_empty); + EXPECT_THAT(Quantile(ArrayFromJSON(ty, "[1, 2, 4, 5, null]"), keep_nulls_min_count), + empty); + EXPECT_THAT(Quantile(ArrayFromJSON(ty, "[1, 5]"), keep_nulls_min_count), empty); + EXPECT_THAT(Quantile(ArrayFromJSON(ty, "[1, 5, null]"), keep_nulls_min_count), empty); + EXPECT_THAT(Quantile(ScalarFromJSON(ty, "3"), keep_nulls_min_count), empty); + EXPECT_THAT(Quantile(ScalarFromJSON(ty, "null"), keep_nulls_min_count), empty); + for (const auto interpolation : this->interpolations_) { QuantileOptions options({0.0, 0.5, 1.0}, interpolation); 
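// Summary of the option matrix exercised above: with skip_nulls=false any null
// in the input empties the result; with min_count=3 fewer than three non-null
// values empties the result; when combined, either condition alone suffices.
// From the expectations above:
//   Quantile([1, 2, 4, 5, null], keep_nulls) -> []     (a null is present)
//   Quantile([1, 2, 4, 5, null], min_count)  -> [3.0]  (four non-nulls >= 3)
//   Quantile([1, 5], min_count)              -> []     (two non-nulls < 3)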
auto expected_ty = (interpolation == QuantileOptions::LINEAR || @@ -2718,6 +2838,36 @@ TYPED_TEST(TestFloatingQuantileKernel, Floats) { this->AssertQuantilesEmpty({"[NaN, NaN]", "[]", "[null]"}, {0.3, 0.4}); auto ty = this->type_singleton(); + + QuantileOptions keep_nulls(/*q=*/0.5, QuantileOptions::LINEAR, /*skip_nulls=*/false, + /*min_count=*/0); + QuantileOptions min_count(/*q=*/0.5, QuantileOptions::LINEAR, /*skip_nulls=*/true, + /*min_count=*/3); + QuantileOptions keep_nulls_min_count(/*q=*/0.5, QuantileOptions::LINEAR, + /*skip_nulls=*/false, /*min_count=*/3); + auto not_empty = ResultWith(ArrayFromJSON(float64(), "[3.0]")); + auto empty = ResultWith(ArrayFromJSON(float64(), "[]")); + EXPECT_THAT(Quantile(ArrayFromJSON(ty, "[1, 2, 4, 5]"), keep_nulls), not_empty); + EXPECT_THAT(Quantile(ArrayFromJSON(ty, "[1, 2, 4, 5, null]"), keep_nulls), empty); + EXPECT_THAT(Quantile(ArrayFromJSON(ty, "[1, 5]"), keep_nulls), not_empty); + EXPECT_THAT(Quantile(ArrayFromJSON(ty, "[1, 5, null]"), keep_nulls), empty); + EXPECT_THAT(Quantile(ScalarFromJSON(ty, "3"), keep_nulls), not_empty); + EXPECT_THAT(Quantile(ScalarFromJSON(ty, "null"), keep_nulls), empty); + EXPECT_THAT(Quantile(ArrayFromJSON(ty, "[1, 2, 4, 5]"), min_count), not_empty); + EXPECT_THAT(Quantile(ArrayFromJSON(ty, "[1, 2, 4, 5, null]"), min_count), not_empty); + EXPECT_THAT(Quantile(ArrayFromJSON(ty, "[1, 5]"), min_count), empty); + EXPECT_THAT(Quantile(ArrayFromJSON(ty, "[1, 5, null]"), min_count), empty); + EXPECT_THAT(Quantile(ScalarFromJSON(ty, "3"), min_count), empty); + EXPECT_THAT(Quantile(ScalarFromJSON(ty, "null"), min_count), empty); + EXPECT_THAT(Quantile(ArrayFromJSON(ty, "[1, 2, 4, 5]"), keep_nulls_min_count), + not_empty); + EXPECT_THAT(Quantile(ArrayFromJSON(ty, "[1, 2, 4, 5, null]"), keep_nulls_min_count), + empty); + EXPECT_THAT(Quantile(ArrayFromJSON(ty, "[1, 5]"), keep_nulls_min_count), empty); + EXPECT_THAT(Quantile(ArrayFromJSON(ty, "[1, 5, null]"), keep_nulls_min_count), empty); + EXPECT_THAT(Quantile(ScalarFromJSON(ty, "3"), keep_nulls_min_count), empty); + EXPECT_THAT(Quantile(ScalarFromJSON(ty, "null"), keep_nulls_min_count), empty); + for (const auto interpolation : this->interpolations_) { QuantileOptions options({0.0, 0.5, 1.0}, interpolation); auto expected_ty = (interpolation == QuantileOptions::LINEAR || @@ -3015,5 +3165,44 @@ TEST(TestTDigestKernel, Scalar) { } } +TEST(TestTDigestKernel, Options) { + auto ty = float64(); + TDigestOptions keep_nulls(/*q=*/0.5, /*delta=*/100, /*buffer_size=*/500, + /*skip_nulls=*/false, /*min_count=*/0); + TDigestOptions min_count(/*q=*/0.5, /*delta=*/100, /*buffer_size=*/500, + /*skip_nulls=*/true, /*min_count=*/3); + TDigestOptions keep_nulls_min_count(/*q=*/0.5, /*delta=*/100, /*buffer_size=*/500, + /*skip_nulls=*/false, /*min_count=*/3); + + EXPECT_THAT(TDigest(ArrayFromJSON(ty, "[1.0, 2.0, 3.0]"), keep_nulls), + ResultWith(ArrayFromJSON(ty, "[2.0]"))); + EXPECT_THAT(TDigest(ArrayFromJSON(ty, "[1.0, 2.0, 3.0, null]"), keep_nulls), + ResultWith(ArrayFromJSON(ty, "[]"))); + EXPECT_THAT(TDigest(ScalarFromJSON(ty, "1.0"), keep_nulls), + ResultWith(ArrayFromJSON(ty, "[1.0]"))); + EXPECT_THAT(TDigest(ScalarFromJSON(ty, "null"), keep_nulls), + ResultWith(ArrayFromJSON(ty, "[]"))); + + EXPECT_THAT(TDigest(ArrayFromJSON(ty, "[1.0, 2.0, 3.0, null]"), min_count), + ResultWith(ArrayFromJSON(ty, "[2.0]"))); + EXPECT_THAT(TDigest(ArrayFromJSON(ty, "[1.0, 2.0, null]"), min_count), + ResultWith(ArrayFromJSON(ty, "[]"))); + EXPECT_THAT(TDigest(ScalarFromJSON(ty, "1.0"), min_count), + 
ResultWith(ArrayFromJSON(ty, "[]"))); + EXPECT_THAT(TDigest(ScalarFromJSON(ty, "null"), min_count), + ResultWith(ArrayFromJSON(ty, "[]"))); + + EXPECT_THAT(TDigest(ArrayFromJSON(ty, "[1.0, 2.0, 3.0]"), keep_nulls_min_count), + ResultWith(ArrayFromJSON(ty, "[2.0]"))); + EXPECT_THAT(TDigest(ArrayFromJSON(ty, "[1.0, 2.0]"), keep_nulls_min_count), + ResultWith(ArrayFromJSON(ty, "[]"))); + EXPECT_THAT(TDigest(ArrayFromJSON(ty, "[1.0, 2.0, 3.0, null]"), keep_nulls_min_count), + ResultWith(ArrayFromJSON(ty, "[]"))); + EXPECT_THAT(TDigest(ScalarFromJSON(ty, "1.0"), keep_nulls_min_count), + ResultWith(ArrayFromJSON(ty, "[]"))); + EXPECT_THAT(TDigest(ScalarFromJSON(ty, "null"), keep_nulls_min_count), + ResultWith(ArrayFromJSON(ty, "[]"))); +} + } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/codegen_internal.cc b/cpp/src/arrow/compute/kernels/codegen_internal.cc index f8b90085010..f230ca7ff73 100644 --- a/cpp/src/arrow/compute/kernels/codegen_internal.cc +++ b/cpp/src/arrow/compute/kernels/codegen_internal.cc @@ -17,6 +17,7 @@ #include "arrow/compute/kernels/codegen_internal.h" +#include #include #include #include @@ -145,12 +146,6 @@ const std::vector& DecimalTypeIds() { return g_decimal_type_ids; } -const std::vector& AllTimeUnits() { - static std::vector units = {TimeUnit::SECOND, TimeUnit::MILLI, - TimeUnit::MICRO, TimeUnit::NANO}; - return units; -} - const std::vector>& NumericTypes() { std::call_once(codegen_static_initialized, InitStaticData); return g_numeric_types; @@ -341,6 +336,91 @@ std::shared_ptr CommonBinary(const std::vector& descrs) { return large_binary(); } +Status CastBinaryDecimalArgs(DecimalPromotion promotion, + std::vector* descrs) { + auto& left_type = (*descrs)[0].type; + auto& right_type = (*descrs)[1].type; + DCHECK(is_decimal(left_type->id()) || is_decimal(right_type->id())); + + // decimal + float = float + if (is_floating(left_type->id())) { + right_type = left_type; + return Status::OK(); + } else if (is_floating(right_type->id())) { + left_type = right_type; + return Status::OK(); + } + + // precision, scale of left and right args + int32_t p1, s1, p2, s2; + + // decimal + integer = decimal + if (is_decimal(left_type->id())) { + auto decimal = checked_cast(left_type.get()); + p1 = decimal->precision(); + s1 = decimal->scale(); + } else { + DCHECK(is_integer(left_type->id())); + ARROW_ASSIGN_OR_RAISE(p1, MaxDecimalDigitsForInteger(left_type->id())); + s1 = 0; + } + if (is_decimal(right_type->id())) { + auto decimal = checked_cast(right_type.get()); + p2 = decimal->precision(); + s2 = decimal->scale(); + } else { + DCHECK(is_integer(right_type->id())); + ARROW_ASSIGN_OR_RAISE(p2, MaxDecimalDigitsForInteger(right_type->id())); + s2 = 0; + } + if (s1 < 0 || s2 < 0) { + return Status::NotImplemented("Decimals with negative scales not supported"); + } + + // decimal128 + decimal256 = decimal256 + Type::type casted_type_id = Type::DECIMAL128; + if (left_type->id() == Type::DECIMAL256 || right_type->id() == Type::DECIMAL256) { + casted_type_id = Type::DECIMAL256; + } + + // decimal promotion rules compatible with amazon redshift + // https://docs.aws.amazon.com/redshift/latest/dg/r_numeric_computations201.html + int32_t left_scaleup, right_scaleup; + + switch (promotion) { + case DecimalPromotion::kAdd: { + left_scaleup = std::max(s1, s2) - s1; + right_scaleup = std::max(s1, s2) - s2; + break; + } + case DecimalPromotion::kMultiply: { + left_scaleup = right_scaleup = 0; + break; + } + case DecimalPromotion::kDivide: { + left_scaleup = 
+      left_scaleup = std::max(4, s1 + p2 - s2 + 1) + s2 - s1;
+      right_scaleup = 0;
+      break;
+    }
+    default:
+      DCHECK(false) << "Invalid DecimalPromotion value " << static_cast<int>(promotion);
+  }
+
+  ARROW_ASSIGN_OR_RAISE(
+      left_type, DecimalType::Make(casted_type_id, p1 + left_scaleup, s1 + left_scaleup));
+  ARROW_ASSIGN_OR_RAISE(right_type, DecimalType::Make(casted_type_id, p2 + right_scaleup,
+                                                      s2 + right_scaleup));
+  return Status::OK();
+}
+
+bool HasDecimal(const std::vector<ValueDescr>& descrs) {
+  for (const auto& descr : descrs) {
+    if (is_decimal(descr.type->id())) {
+      return true;
+    }
+  }
+  return false;
+}
+
 }  // namespace internal
 }  // namespace compute
 }  // namespace arrow
diff --git a/cpp/src/arrow/compute/kernels/codegen_internal.h b/cpp/src/arrow/compute/kernels/codegen_internal.h
index 9c8b2cef198..98ca835a14c 100644
--- a/cpp/src/arrow/compute/kernels/codegen_internal.h
+++ b/cpp/src/arrow/compute/kernels/codegen_internal.h
@@ -94,8 +94,8 @@ struct OptionsWrapper : public KernelState {
 /// KernelContext and the FunctionOptions as argument
 template <typename StateType, typename OptionsType>
 struct KernelStateFromFunctionOptions : public KernelState {
-  explicit KernelStateFromFunctionOptions(KernelContext* ctx, OptionsType state)
-      : state(StateType(ctx, std::move(state))) {}
+  explicit KernelStateFromFunctionOptions(KernelContext* ctx, OptionsType options)
+      : state(StateType(ctx, std::move(options))) {}
 
   static Result<std::unique_ptr<KernelState>> Init(KernelContext* ctx,
                                                    const KernelInitArgs& args) {
@@ -415,9 +415,6 @@
 const std::vector<std::shared_ptr<DataType>>& IntTypes();
 const std::vector<std::shared_ptr<DataType>>& FloatingPointTypes();
 const std::vector<Type::type>& DecimalTypeIds();
 
-ARROW_EXPORT
-const std::vector<TimeUnit::type>& AllTimeUnits();
-
 // Returns a vector of example instances of parametric types such as
 //
 // * Decimal
@@ -1313,6 +1310,19 @@ std::shared_ptr<DataType> CommonTimestamp(const std::vector<ValueDescr>& descrs)
 ARROW_EXPORT
 std::shared_ptr<DataType> CommonBinary(const std::vector<ValueDescr>& descrs);
 
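The Redshift-compatible promotion rules above reduce to simple precision/scale
arithmetic, which can be sanity-checked in isolation. A minimal standalone sketch
of that arithmetic, assuming nothing from Arrow (PromotedPair, PromoteAdd, and
PromoteDivide are illustrative names, not Arrow APIs):

    #include <algorithm>
    #include <cstdint>
    #include <iostream>

    struct PromotedPair {
      int32_t p1, s1, p2, s2;  // promoted precision/scale of both operands
    };

    // kAdd: scale both operands up to max(s1, s2).
    PromotedPair PromoteAdd(int32_t p1, int32_t s1, int32_t p2, int32_t s2) {
      int32_t left_up = std::max(s1, s2) - s1;
      int32_t right_up = std::max(s1, s2) - s2;
      return {p1 + left_up, s1 + left_up, p2 + right_up, s2 + right_up};
    }

    // kDivide: scale the dividend up so the result keeps at least four
    // fractional digits (the Redshift rule cited above); divisor unchanged.
    PromotedPair PromoteDivide(int32_t p1, int32_t s1, int32_t p2, int32_t s2) {
      int32_t left_up = std::max(4, s1 + p2 - s2 + 1) + s2 - s1;
      return {p1 + left_up, s1 + left_up, p2, s2};
    }

    int main() {
      PromotedPair add = PromoteAdd(5, 2, 7, 4);     // decimal(5,2) + decimal(7,4)
      std::cout << add.p1 << "," << add.s1 << " and "
                << add.p2 << "," << add.s2 << "\n";  // 7,4 and 7,4
      PromotedPair div = PromoteDivide(5, 2, 7, 4);  // decimal(5,2) / decimal(7,4)
      std::cout << div.p1 << "," << div.s1 << " and "
                << div.p2 << "," << div.s2 << "\n";  // 13,10 and 7,4
    }

The actual output-type resolution (including decimal128-to-decimal256 widening)
still happens through DecimalType::Make, as in the function above.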
+/// How to promote decimal precision/scale in CastBinaryDecimalArgs.
+enum class DecimalPromotion : uint8_t {
+  kAdd,
+  kMultiply,
+  kDivide,
+};
+
+ARROW_EXPORT
+Status CastBinaryDecimalArgs(DecimalPromotion promotion, std::vector<ValueDescr>* descrs);
+
+ARROW_EXPORT
+bool HasDecimal(const std::vector<ValueDescr>& descrs);
+
 }  // namespace internal
 }  // namespace compute
 }  // namespace arrow
diff --git a/cpp/src/arrow/compute/kernels/hash_aggregate.cc b/cpp/src/arrow/compute/kernels/hash_aggregate.cc
index 3ea692857cf..23bb73f2a7f 100644
--- a/cpp/src/arrow/compute/kernels/hash_aggregate.cc
+++ b/cpp/src/arrow/compute/kernels/hash_aggregate.cc
@@ -1593,6 +1593,8 @@ struct GroupedTDigestImpl : public GroupedAggregator {
     options_ = *checked_cast<const TDigestOptions*>(options);
     ctx_ = ctx;
     pool_ = ctx->memory_pool();
+    counts_ = TypedBufferBuilder<int64_t>(pool_);
+    no_nulls_ = TypedBufferBuilder<bool>(pool_);
     return Status::OK();
   }
 
@@ -1602,12 +1604,21 @@ struct GroupedTDigestImpl : public GroupedAggregator {
     for (int64_t i = 0; i < added_groups; i++) {
       tdigests_.emplace_back(options_.delta, options_.buffer_size);
     }
+    RETURN_NOT_OK(counts_.Append(new_num_groups, 0));
+    RETURN_NOT_OK(no_nulls_.Append(new_num_groups, true));
     return Status::OK();
   }
 
   Status Consume(const ExecBatch& batch) override {
-    VisitGroupedValuesNonNull<Type>(
-        batch, [&](uint32_t g, CType value) { tdigests_[g].NanAdd(value); });
+    int64_t* counts = counts_.mutable_data();
+    uint8_t* no_nulls = no_nulls_.mutable_data();
+    VisitGroupedValues<Type>(
+        batch,
+        [&](uint32_t g, CType value) {
+          tdigests_[g].NanAdd(value);
+          counts[g]++;
+        },
+        [&](uint32_t g) { BitUtil::SetBitTo(no_nulls, g, false); });
     return Status::OK();
   }
 
@@ -1615,15 +1626,26 @@ struct GroupedTDigestImpl : public GroupedAggregator {
               const ArrayData& group_id_mapping) override {
     auto other = checked_cast<GroupedTDigestImpl*>(&raw_other);
 
+    int64_t* counts = counts_.mutable_data();
+    uint8_t* no_nulls = no_nulls_.mutable_data();
+
+    const int64_t* other_counts = other->counts_.data();
+    const uint8_t* other_no_nulls = other->no_nulls_.data();
+
     auto g = group_id_mapping.GetValues<uint32_t>(1);
     for (int64_t other_g = 0; other_g < group_id_mapping.length; ++other_g, ++g) {
       tdigests_[*g].Merge(other->tdigests_[other_g]);
+      counts[*g] += other_counts[other_g];
+      BitUtil::SetBitTo(
+          no_nulls, *g,
+          BitUtil::GetBit(no_nulls, *g) && BitUtil::GetBit(other_no_nulls, other_g));
     }
 
     return Status::OK();
   }
 
   Result<Datum> Finalize() override {
+    const int64_t* counts = counts_.data();
     std::shared_ptr<Buffer> null_bitmap;
     ARROW_ASSIGN_OR_RAISE(
         std::shared_ptr<Buffer> values,
@@ -1633,7 +1655,7 @@ struct GroupedTDigestImpl : public GroupedAggregator {
 
     double* results = reinterpret_cast<double*>(values->mutable_data());
     for (int64_t i = 0; static_cast<size_t>(i) < tdigests_.size(); ++i) {
-      if (!tdigests_[i].is_empty()) {
+      if (!tdigests_[i].is_empty() && counts[i] >= options_.min_count) {
        for (int64_t j = 0; j < slot_length; j++) {
          results[i * slot_length + j] = tdigests_[i].Quantile(options_.q[j]);
        }
@@ -1649,6 +1671,18 @@ struct GroupedTDigestImpl : public GroupedAggregator {
       std::fill(&results[i * slot_length], &results[(i + 1) * slot_length], 0.0);
     }
 
+    if (!options_.skip_nulls) {
+      null_count = kUnknownNullCount;
+      if (null_bitmap) {
+        arrow::internal::BitmapAnd(null_bitmap->data(), /*left_offset=*/0,
+                                   no_nulls_.data(), /*right_offset=*/0,
+                                   static_cast<int64_t>(tdigests_.size()),
+                                   /*out_offset=*/0, null_bitmap->mutable_data());
+      } else {
+        ARROW_ASSIGN_OR_RAISE(null_bitmap, no_nulls_.Finish());
+      }
+    }
+
     auto child = ArrayData::Make(float64(), tdigests_.size() * options_.q.size(),
                                  {nullptr, std::move(values)}, /*null_count=*/0);
     return ArrayData::Make(out_type(), tdigests_.size(), {std::move(null_bitmap)},
@@ -1661,6 +1695,8 @@ struct GroupedTDigestImpl : public GroupedAggregator {
 
   TDigestOptions options_;
   std::vector<TDigest> tdigests_;
+  TypedBufferBuilder<int64_t> counts_;
+  TypedBufferBuilder<bool> no_nulls_;
   ExecContext* ctx_;
   MemoryPool* pool_;
 };
diff --git a/cpp/src/arrow/compute/kernels/hash_aggregate_test.cc b/cpp/src/arrow/compute/kernels/hash_aggregate_test.cc
index 32e8efa0ab8..df13bd569ea 100644
--- a/cpp/src/arrow/compute/kernels/hash_aggregate_test.cc
+++ b/cpp/src/arrow/compute/kernels/hash_aggregate_test.cc
@@ -1086,27 +1086,40 @@ TEST(GroupBy, VarianceAndStddev) {
 TEST(GroupBy, TDigest) {
   auto batch = RecordBatchFromJSON(
       schema({field("argument", float64()), field("key", int64())}), R"([
-    [1, 1],
-    [null, 1],
-    [0, 2],
-    [null, 3],
-    [4, null],
-    [3, 1],
-    [0, 2],
-    [-1, 2],
-    [1, null],
-    [NaN, 3]
+    [1, 1],
+    [null, 1],
+    [0, 2],
+    [null, 3],
+    [1, 4],
+    [4, null],
+    [3, 1],
+    [0, 2],
+    [-1, 2],
+    [1, null],
+    [NaN, 3],
+    [1, 4],
+    [1, 4],
+    [null, 4]
 ])");
 
   TDigestOptions options1(std::vector<double>{0.5, 0.9, 0.99});
   TDigestOptions options2(std::vector<double>{0.5, 0.9, 0.99}, /*delta=*/50,
                           /*buffer_size=*/1024);
+  TDigestOptions keep_nulls(/*q=*/0.5, /*delta=*/100, /*buffer_size=*/500,
+                            /*skip_nulls=*/false, /*min_count=*/0);
+  TDigestOptions min_count(/*q=*/0.5, /*delta=*/100, /*buffer_size=*/500,
+                           /*skip_nulls=*/true, /*min_count=*/3);
+  TDigestOptions keep_nulls_min_count(/*q=*/0.5, /*delta=*/100, /*buffer_size=*/500,
+                                      /*skip_nulls=*/false, /*min_count=*/3);
   ASSERT_OK_AND_ASSIGN(Datum aggregated_and_grouped,
                        internal::GroupBy(
                            {
                                batch->GetColumnByName("argument"),
                                batch->GetColumnByName("argument"),
                                batch->GetColumnByName("argument"),
+                               batch->GetColumnByName("argument"),
+                               batch->GetColumnByName("argument"),
+                               batch->GetColumnByName("argument"),
                            },
                            {
                                batch->GetColumnByName("key"),
@@ -1115,6 +1128,9 @@ TEST(GroupBy, TDigest) {
                                {"hash_tdigest", nullptr},
                                {"hash_tdigest", &options1},
                                {"hash_tdigest", &options2},
+                               {"hash_tdigest", &keep_nulls},
+                               {"hash_tdigest", &min_count},
+                               {"hash_tdigest", &keep_nulls_min_count},
                            }));
 
   AssertDatumsApproxEqual(
@@ -1122,13 +1138,17 @@ TEST(GroupBy, TDigest) {
                         field("hash_tdigest", fixed_size_list(float64(), 1)),
                         field("hash_tdigest", fixed_size_list(float64(), 3)),
                         field("hash_tdigest", fixed_size_list(float64(), 3)),
+                        field("hash_tdigest", fixed_size_list(float64(), 1)),
+                        field("hash_tdigest", fixed_size_list(float64(), 1)),
+                        field("hash_tdigest", fixed_size_list(float64(), 1)),
                         field("key_0", int64()),
                     }),
                     R"([
-    [[1.0], [1.0, 3.0, 3.0], [1.0, 3.0, 3.0], 1],
-    [[0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0], 2],
-    [null, null, null, 3],
-    [[1.0], [1.0, 4.0, 4.0], [1.0, 4.0, 4.0], null]
+    [[1.0], [1.0, 3.0, 3.0], [1.0, 3.0, 3.0], null, null, null, 1],
+    [[0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0], [0.0], [0.0], 2],
+    [null, null, null, null, null, null, 3],
+    [[1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], null, [1.0], null, 4],
+    [[1.0], [1.0, 4.0, 4.0], [1.0, 4.0, 4.0], [1.0], null, null, null]
 ])"),
       aggregated_and_grouped,
       /*verbose=*/true);
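The expected values above follow from the per-group bookkeeping GroupedTDigestImpl
now carries: a count of non-null values gates min_count, and a "no nulls seen" flag
gates skip_nulls=false. A standalone sketch of that gating, independent of Arrow's
buffers (GroupState and EmitsResult are illustrative names, not Arrow code):

    #include <cstdint>
    #include <iostream>
    #include <optional>
    #include <vector>

    // Per-group state mirroring the counts_/no_nulls_ buffers added above.
    struct GroupState {
      int64_t count = 0;     // non-null values consumed for this group
      bool no_nulls = true;  // cleared once the group sees a null
    };

    // Mirrors the Finalize() gating: a group only emits a non-null result if it
    // has at least min_count values, and saw no nulls when skip_nulls is false.
    bool EmitsResult(const GroupState& g, bool skip_nulls, uint32_t min_count) {
      if (g.count < static_cast<int64_t>(min_count)) return false;
      if (!skip_nulls && !g.no_nulls) return false;
      return true;
    }

    int main() {
      GroupState g;
      std::vector<std::optional<double>> values = {1.0, 3.0, std::nullopt};
      for (const auto& v : values) {
        if (v.has_value()) g.count++; else g.no_nulls = false;
      }
      // Two values and one null, like key 1 in the test above:
      std::cout << EmitsResult(g, /*skip_nulls=*/true, /*min_count=*/3) << "\n";   // 0
      std::cout << EmitsResult(g, /*skip_nulls=*/false, /*min_count=*/0) << "\n";  // 0
      std::cout << EmitsResult(g, /*skip_nulls=*/true, /*min_count=*/2) << "\n";   // 1
    }

Key 1 (values 1, null, 3) is exactly this state: it survives the default options
but yields null under both keep_nulls and min_count, as the expected JSON shows.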
#include "arrow/compute/kernels/util_internal.h" #include "arrow/type.h" @@ -78,24 +80,27 @@ using enable_if_decimal_value = struct AbsoluteValue { template - static constexpr enable_if_floating_point Call(KernelContext*, T arg, Status*) { + static constexpr enable_if_floating_point Call(KernelContext*, Arg arg, + Status*) { return std::fabs(arg); } template - static constexpr enable_if_unsigned_c_integer Call(KernelContext*, T arg, Status*) { + static constexpr enable_if_unsigned_c_integer Call(KernelContext*, Arg arg, + Status*) { return arg; } template - static constexpr enable_if_signed_c_integer Call(KernelContext*, T arg, Status* st) { + static constexpr enable_if_signed_c_integer Call(KernelContext*, Arg arg, + Status* st) { return (arg < 0) ? arrow::internal::SafeSignedNegate(arg) : arg; } }; struct AbsoluteValueChecked { template - static enable_if_signed_c_integer Call(KernelContext*, Arg arg, Status* st) { + static enable_if_signed_c_integer Call(KernelContext*, Arg arg, Status* st) { static_assert(std::is_same::value, ""); if (arg == std::numeric_limits::min()) { *st = Status::Invalid("overflow"); @@ -105,13 +110,15 @@ struct AbsoluteValueChecked { } template - static enable_if_unsigned_c_integer Call(KernelContext* ctx, Arg arg, Status* st) { + static enable_if_unsigned_c_integer Call(KernelContext* ctx, Arg arg, + Status* st) { static_assert(std::is_same::value, ""); return arg; } template - static constexpr enable_if_floating_point Call(KernelContext*, Arg arg, Status* st) { + static constexpr enable_if_floating_point Call(KernelContext*, Arg arg, + Status* st) { static_assert(std::is_same::value, ""); return std::fabs(arg); } @@ -378,7 +385,7 @@ struct Negate { struct NegateChecked { template - static enable_if_signed_c_integer Call(KernelContext*, Arg arg, Status* st) { + static enable_if_signed_c_integer Call(KernelContext*, Arg arg, Status* st) { static_assert(std::is_same::value, ""); T result = 0; if (ARROW_PREDICT_FALSE(NegateWithOverflow(arg, &result))) { @@ -388,7 +395,8 @@ struct NegateChecked { } template - static enable_if_unsigned_c_integer Call(KernelContext* ctx, Arg arg, Status* st) { + static enable_if_unsigned_c_integer Call(KernelContext* ctx, Arg arg, + Status* st) { static_assert(std::is_same::value, ""); DCHECK(false) << "This is included only for the purposes of instantiability from the " "arithmetic kernel generator"; @@ -396,7 +404,8 @@ struct NegateChecked { } template - static constexpr enable_if_floating_point Call(KernelContext*, Arg arg, Status* st) { + static constexpr enable_if_floating_point Call(KernelContext*, Arg arg, + Status* st) { static_assert(std::is_same::value, ""); return -arg; } @@ -466,18 +475,20 @@ struct PowerChecked { struct Sign { template - static constexpr enable_if_floating_point Call(KernelContext*, Arg arg, Status*) { + static constexpr enable_if_floating_point Call(KernelContext*, Arg arg, + Status*) { return std::isnan(arg) ? arg : ((arg == 0) ? 0 : (std::signbit(arg) ? -1 : 1)); } template - static constexpr enable_if_unsigned_c_integer Call(KernelContext*, Arg arg, - Status*) { - return arg > 0; + static constexpr enable_if_unsigned_c_integer Call(KernelContext*, Arg arg, + Status*) { + return (arg > 0) ? 1 : 0; } template - static constexpr enable_if_signed_c_integer Call(KernelContext*, Arg arg, Status*) { + static constexpr enable_if_signed_c_integer Call(KernelContext*, Arg arg, + Status*) { return (arg > 0) ? 1 : ((arg == 0) ? 
@@ -852,24 +863,242 @@ struct LogbChecked {
   }
 };
 
+struct RoundUtil {
+  // Calculate powers of ten with arbitrary integer exponent
+  template <typename T>
+  static enable_if_floating_point<T> Pow10(int64_t power) {
+    static constexpr T lut[] = {1e0F, 1e1F, 1e2F,  1e3F,  1e4F,  1e5F,  1e6F,  1e7F,
+                                1e8F, 1e9F, 1e10F, 1e11F, 1e12F, 1e13F, 1e14F, 1e15F};
+    int64_t lut_size = (sizeof(lut) / sizeof(*lut));
+    int64_t abs_power = std::abs(power);
+    auto pow10 = lut[std::min(abs_power, lut_size - 1)];
+    while (abs_power-- >= lut_size) {
+      pow10 *= 1e1F;
+    }
+    return (power >= 0) ? pow10 : (1 / pow10);
+  }
+};
+
+// Specializations of rounding implementations for round kernels
+template <typename T, RoundMode>
+struct RoundImpl;
+
+template <typename T>
+struct RoundImpl<T, RoundMode::DOWN> {
+  static constexpr enable_if_floating_point<T> Round(const T val) {
+    return std::floor(val);
+  }
+};
+
+template <typename T>
+struct RoundImpl<T, RoundMode::UP> {
+  static constexpr enable_if_floating_point<T> Round(const T val) {
+    return std::ceil(val);
+  }
+};
+
+template <typename T>
+struct RoundImpl<T, RoundMode::TOWARDS_ZERO> {
+  static constexpr enable_if_floating_point<T> Round(const T val) {
+    return std::trunc(val);
+  }
+};
+
+template <typename T>
+struct RoundImpl<T, RoundMode::TOWARDS_INFINITY> {
+  static constexpr enable_if_floating_point<T> Round(const T val) {
+    return std::signbit(val) ? std::floor(val) : std::ceil(val);
+  }
+};
+
+// NOTE: RoundImpl variants for the HALF_* rounding modes are only
+// invoked when the fractional part is equal to 0.5 (std::round is invoked
+// otherwise).
+
+template <typename T>
+struct RoundImpl<T, RoundMode::HALF_DOWN> {
+  static constexpr enable_if_floating_point<T> Round(const T val) {
+    return RoundImpl<T, RoundMode::DOWN>::Round(val);
+  }
+};
+
+template <typename T>
+struct RoundImpl<T, RoundMode::HALF_UP> {
+  static constexpr enable_if_floating_point<T> Round(const T val) {
+    return RoundImpl<T, RoundMode::UP>::Round(val);
+  }
+};
+
+template <typename T>
+struct RoundImpl<T, RoundMode::HALF_TOWARDS_ZERO> {
+  static constexpr enable_if_floating_point<T> Round(const T val) {
+    return RoundImpl<T, RoundMode::TOWARDS_ZERO>::Round(val);
+  }
+};
+
+template <typename T>
+struct RoundImpl<T, RoundMode::HALF_TOWARDS_INFINITY> {
+  static constexpr enable_if_floating_point<T> Round(const T val) {
+    return RoundImpl<T, RoundMode::TOWARDS_INFINITY>::Round(val);
+  }
+};
+
+template <typename T>
+struct RoundImpl<T, RoundMode::HALF_TO_EVEN> {
+  static constexpr enable_if_floating_point<T> Round(const T val) {
+    return std::round(val * T(0.5)) * 2;
+  }
+};
+
+template <typename T>
+struct RoundImpl<T, RoundMode::HALF_TO_ODD> {
+  static constexpr enable_if_floating_point<T> Round(const T val) {
+    return std::floor(val * T(0.5)) + std::ceil(val * T(0.5));
+  }
+};
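Those HALF_* bodies only ever see exact .5 ties (the Round kernels below fall
back to std::round otherwise), so their behavior is easiest to observe on tie
values. A standalone sketch using plain <cmath>, with free functions named after
the modes (illustrative names, not Arrow code):

    #include <cmath>
    #include <cstdio>

    // Tie-breaking on x.5 values, mirroring the RoundImpl bodies above.
    double HalfToEven(double v) { return std::round(v * 0.5) * 2; }
    double HalfToOdd(double v) { return std::floor(v * 0.5) + std::ceil(v * 0.5); }
    double HalfTowardsInfinity(double v) {
      return std::signbit(v) ? std::floor(v) : std::ceil(v);
    }

    int main() {
      for (double v : {-2.5, -1.5, -0.5, 0.5, 1.5, 2.5}) {
        std::printf("%5.1f  even=%5.1f  odd=%5.1f  inf=%5.1f\n", v, HalfToEven(v),
                    HalfToOdd(v), HalfTowardsInfinity(v));
      }
    }

HALF_TO_EVEN matches IEEE 754 round-half-to-even ("banker's rounding"): both
-2.5 and -1.5 land on -2, while HALF_TO_ODD sends them to -3 and -1.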
+
+// Specializations of kernel state for round kernels
+template <typename OptionsType>
+struct RoundOptionsWrapper;
+
+template <>
+struct RoundOptionsWrapper<RoundOptions> : public OptionsWrapper<RoundOptions> {
+  using OptionsType = RoundOptions;
+  using State = RoundOptionsWrapper<OptionsType>;
+  double pow10;
+
+  explicit RoundOptionsWrapper(OptionsType options) : OptionsWrapper(std::move(options)) {
+    // Only positive exponents for powers of 10 are used because combining
+    // multiply and division operations produced more stable rounding than
+    // using multiply-only.  Refer to NumPy's round implementation:
+    // https://github.com/numpy/numpy/blob/7b2f20b406d27364c812f7a81a9c901afbd3600c/numpy/core/src/multiarray/calculation.c#L589
+    pow10 = RoundUtil::Pow10<double>(std::abs(options.ndigits));
+  }
+
+  static Result<std::unique_ptr<KernelState>> Init(KernelContext* ctx,
+                                                   const KernelInitArgs& args) {
+    if (auto options = static_cast<const OptionsType*>(args.options)) {
+      return ::arrow::internal::make_unique<State>(*options);
+    }
+    return Status::Invalid(
+        "Attempted to initialize KernelState from null FunctionOptions");
+  }
+};
+
+template <>
+struct RoundOptionsWrapper<RoundToMultipleOptions>
+    : public OptionsWrapper<RoundToMultipleOptions> {
+  using OptionsType = RoundToMultipleOptions;
+
+  static Result<std::unique_ptr<KernelState>> Init(KernelContext* ctx,
+                                                   const KernelInitArgs& args) {
+    ARROW_ASSIGN_OR_RAISE(auto state, OptionsWrapper::Init(ctx, args));
+    auto options = Get(*state);
+    if (options.multiple <= 0) {
+      return Status::Invalid("Rounding multiple has to be a positive value");
+    }
+    return std::move(state);
+  }
+};
+
+template <RoundMode RndMode>
+struct Round {
+  using State = RoundOptionsWrapper<RoundOptions>;
+
+  template <typename T, typename Arg>
+  static enable_if_floating_point<T> Call(KernelContext* ctx, Arg arg, Status* st) {
+    static_assert(std::is_same<T, Arg>::value, "");
+    // Do not process Inf or NaN because they will trigger the overflow error at end of
+    // function.
+    if (!std::isfinite(arg)) {
+      return arg;
+    }
+    auto state = static_cast<const State*>(ctx->state());
+    auto options = state->options;
+    auto pow10 = T(state->pow10);
+    auto round_val = (options.ndigits >= 0) ? (arg * pow10) : (arg / pow10);
+    auto frac = round_val - std::floor(round_val);
+    if (frac != T(0)) {
+      // Use std::round() if in tie-breaking mode and scaled value is not 0.5.
+      if ((RndMode >= RoundMode::HALF_DOWN) && (frac != T(0.5))) {
+        round_val = std::round(round_val);
+      } else {
+        round_val = RoundImpl<T, RndMode>::Round(round_val);
+      }
+      // Equality check is omitted so that the common case of 10^0 (integer rounding)
+      // uses multiply-only
+      round_val = (options.ndigits > 0) ? (round_val / pow10) : (round_val * pow10);
+      if (!std::isfinite(round_val)) {
+        *st = Status::Invalid("overflow occurred during rounding");
+        return arg;
+      }
+    } else {
+      // If scaled value is an integer, then no rounding is needed.
+      round_val = arg;
+    }
+    return round_val;
+  }
+};
+
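A standalone sketch of the scale/round/unscale flow in Round::Call above,
including the exact-tie test that routes to RoundImpl; the HALF_DOWN stand-in
and the sample value are illustrative only:

    #include <cmath>
    #include <cstdio>

    int main() {
      // Round 0.125 to 2 digits with a HALF_DOWN-style tie-break, mirroring
      // Round::Call: scale up by 10^ndigits, round, then divide back down.
      double arg = 0.125, pow10 = 1e2;            // ndigits = 2
      double scaled = arg * pow10;                // 12.5, exactly representable
      double frac = scaled - std::floor(scaled);  // 0.5: an exact tie
      double rounded = (frac == 0.5) ? std::floor(scaled)   // tie: RoundImpl path
                                     : std::round(scaled);  // otherwise std::round
      std::printf("%g -> %g\n", arg, rounded / pow10);      // prints: 0.125 -> 0.12
    }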
+template <RoundMode RndMode>
+struct RoundToMultiple {
+  using State = RoundOptionsWrapper<RoundToMultipleOptions>;
+
+  template <typename T, typename Arg>
+  static enable_if_floating_point<T> Call(KernelContext* ctx, Arg arg, Status* st) {
+    static_assert(std::is_same<T, Arg>::value, "");
+    // Do not process Inf or NaN because they will trigger the overflow error at end of
+    // function.
+    if (!std::isfinite(arg)) {
+      return arg;
+    }
+    auto options = State::Get(ctx);
+    auto round_val = arg / T(options.multiple);
+    auto frac = round_val - std::floor(round_val);
+    if (frac != T(0)) {
+      // Use std::round() if in tie-breaking mode and scaled value is not 0.5.
+      if ((RndMode >= RoundMode::HALF_DOWN) && (frac != T(0.5))) {
+        round_val = std::round(round_val);
+      } else {
+        round_val = RoundImpl<T, RndMode>::Round(round_val);
+      }
+      round_val *= T(options.multiple);
+      if (!std::isfinite(round_val)) {
+        *st = Status::Invalid("overflow occurred during rounding");
+        return arg;
+      }
+    } else {
+      // If scaled value is an integer, then no rounding is needed.
+      round_val = arg;
+    }
+    return round_val;
+  }
+};
+
 struct Floor {
   template <typename T, typename Arg>
-  static constexpr enable_if_floating_point<T> Call(KernelContext*, Arg arg, Status*) {
-    return std::floor(arg);
+  static constexpr enable_if_floating_point<Arg, T> Call(KernelContext*, Arg arg,
+                                                         Status*) {
+    static_assert(std::is_same<T, Arg>::value, "");
+    return RoundImpl<T, RoundMode::DOWN>::Round(arg);
   }
 };
 
 struct Ceil {
   template <typename T, typename Arg>
-  static constexpr enable_if_floating_point<T> Call(KernelContext*, Arg arg, Status*) {
-    return std::ceil(arg);
+  static constexpr enable_if_floating_point<Arg, T> Call(KernelContext*, Arg arg,
+                                                         Status*) {
+    static_assert(std::is_same<T, Arg>::value, "");
+    return RoundImpl<T, RoundMode::UP>::Round(arg);
   }
 };
 
 struct Trunc {
   template <typename T, typename Arg>
-  static constexpr enable_if_floating_point<T> Call(KernelContext*, Arg arg, Status*) {
-    return std::trunc(arg);
+  static constexpr enable_if_floating_point<Arg, T> Call(KernelContext*, Arg arg,
+                                                         Status*) {
+    static_assert(std::is_same<T, Arg>::value, "");
+    return RoundImpl<T, RoundMode::TOWARDS_ZERO>::Round(arg);
   }
 };
 
@@ -965,78 +1194,6 @@ ArrayKernelExec GenerateArithmeticFloatingPoint(detail::GetTypeId get_id) {
   }
 }
 
-Status CastBinaryDecimalArgs(const std::string& func_name,
-                             std::vector<ValueDescr>* values) {
-  auto& left_type = (*values)[0].type;
-  auto& right_type = (*values)[1].type;
-  DCHECK(is_decimal(left_type->id()) || is_decimal(right_type->id()));
-
-  // decimal + float = float
-  if (is_floating(left_type->id())) {
-    right_type = left_type;
-    return Status::OK();
-  } else if (is_floating(right_type->id())) {
-    left_type = right_type;
-    return Status::OK();
-  }
-
-  // precision, scale of left and right args
-  int32_t p1, s1, p2, s2;
-
-  // decimal + integer = decimal
-  if (is_decimal(left_type->id())) {
-    auto decimal = checked_cast<const DecimalType*>(left_type.get());
-    p1 = decimal->precision();
-    s1 = decimal->scale();
-  } else {
-    DCHECK(is_integer(left_type->id()));
-    p1 = static_cast<int32_t>(std::ceil(std::log10(bit_width(left_type->id()))));
-    s1 = 0;
-  }
-  if (is_decimal(right_type->id())) {
-    auto decimal = checked_cast<const DecimalType*>(right_type.get());
-    p2 = decimal->precision();
-    s2 = decimal->scale();
-  } else {
-    DCHECK(is_integer(right_type->id()));
-    p2 = static_cast<int32_t>(std::ceil(std::log10(bit_width(right_type->id()))));
-    s2 = 0;
-  }
-  if (s1 < 0 || s2 < 0) {
-    return Status::NotImplemented("Decimals with negative scales not supported");
-  }
-
-  // decimal128 + decimal256 = decimal256
-  Type::type casted_type_id = Type::DECIMAL128;
-  if (left_type->id() == Type::DECIMAL256 || right_type->id() == Type::DECIMAL256) {
-    casted_type_id = Type::DECIMAL256;
-  }
-
-  // decimal promotion rules compatible with amazon redshift
-  // https://docs.aws.amazon.com/redshift/latest/dg/r_numeric_computations201.html
-  int32_t left_scaleup, right_scaleup;
-
-  // "add_checked" -> "add"
-  const std::string op = func_name.substr(0, func_name.find("_"));
-  if (op == "add" || op == "subtract") {
-    left_scaleup = std::max(s1, s2) - s1;
-    right_scaleup = std::max(s1, s2) - s2;
-  } else if (op == "multiply") {
-    left_scaleup = right_scaleup = 0;
-  } else if (op == "divide") {
-    left_scaleup = std::max(4, s1 + p2 - s2 + 1) + s2 - s1;
-    right_scaleup = 0;
-  } else {
-    return Status::Invalid("Invalid decimal function: ", func_name);
-  }
-
-  ARROW_ASSIGN_OR_RAISE(
-      left_type, DecimalType::Make(casted_type_id, p1 + left_scaleup, s1 + left_scaleup));
-  ARROW_ASSIGN_OR_RAISE(right_type, DecimalType::Make(casted_type_id, p2 + right_scaleup,
-                                                      s2 + right_scaleup));
-  return Status::OK();
-}
-
 // resolve decimal binary operation output type per *casted* args
 template <typename Op>
 Result<ValueDescr> ResolveDecimalBinaryOperationOutput(
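Worth noting about the block removed above: it derived an integer operand's
precision from ceil(log10(bit_width)), which yields 2 digits for int32 rather
than the 10 it needs; the replacement in codegen_internal.cc calls
MaxDecimalDigitsForInteger instead. A quick standalone check of the correct
digit counts, assuming two's-complement ranges (DigitsForSignedBits is an
illustrative helper, not the Arrow function):

    #include <cmath>
    #include <cstdint>
    #include <cstdio>

    // Digits needed to represent any value of an N-bit signed integer:
    // ceil(log10(2^(N-1))), i.e. the digit count of |minimum value|.
    static int32_t DigitsForSignedBits(int bits) {
      return static_cast<int32_t>(std::ceil((bits - 1) * std::log10(2.0)));
    }

    int main() {
      for (int bits : {8, 16, 32, 64}) {
        std::printf("int%-2d needs %2d digits; ceil(log10(bit_width)) gave %d\n",
                    bits, DigitsForSignedBits(bits),
                    static_cast<int>(std::ceil(std::log10(bits))));
      }
    }

This prints 3, 5, 10, and 19 digits for int8 through int64, against 1 or 2 from
the removed formula.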
@@ -1166,17 +1323,21 @@ struct ArithmeticFunction : ScalarFunction {
   }
 
   Status CheckDecimals(std::vector<ValueDescr>* values) const {
-    bool has_decimal = false;
-    for (const auto& value : *values) {
-      if (is_decimal(value.type->id())) {
-        has_decimal = true;
-        break;
-      }
-    }
-    if (!has_decimal) return Status::OK();
+    if (!HasDecimal(*values)) return Status::OK();
 
     if (values->size() == 2) {
-      return CastBinaryDecimalArgs(name(), values);
+      // "add_checked" -> "add"
+      const auto func_name = name();
+      const std::string op = func_name.substr(0, func_name.find("_"));
+      if (op == "add" || op == "subtract") {
+        return CastBinaryDecimalArgs(DecimalPromotion::kAdd, values);
+      } else if (op == "multiply") {
+        return CastBinaryDecimalArgs(DecimalPromotion::kMultiply, values);
+      } else if (op == "divide") {
+        return CastBinaryDecimalArgs(DecimalPromotion::kDivide, values);
+      } else {
+        return Status::Invalid("Invalid decimal function: ", func_name);
+      }
     }
     return Status::OK();
   }
@@ -1276,6 +1437,65 @@ std::shared_ptr<ScalarFunction> MakeUnaryArithmeticFunctionNotNull(
   return func;
 }
 
+// Generate a kernel given an arithmetic rounding functor
+template