diff --git a/.github/workflows/r.yml b/.github/workflows/r.yml index 6f782c22063..37aee196883 100644 --- a/.github/workflows/r.yml +++ b/.github/workflows/r.yml @@ -92,21 +92,20 @@ jobs: continue-on-error: true run: archery docker push ubuntu-r - rstudio: - name: "rstudio/r-base:${{ matrix.r_version }}-${{ matrix.r_image }}" + bundled: + name: "${{ matrix.config.org }}/${{ matrix.config.image }}:${{ matrix.config.tag }}" runs-on: ubuntu-latest if: ${{ !contains(github.event.pull_request.title, 'WIP') }} strategy: fail-fast: false matrix: - # See https://hub.docker.com/r/rstudio/r-base - r_version: ["4.0"] - r_image: - - centos7 + config: + - {org: 'rstudio', image: 'r-base', tag: '4.0-centos7'} + - {org: 'rhub', image: 'debian-gcc-devel', tag: 'latest'} env: - R_ORG: rstudio - R_IMAGE: r-base - R_TAG: ${{ matrix.r_version }}-${{ matrix.r_image }} + R_ORG: ${{ matrix.config.org }} + R_IMAGE: ${{ matrix.config.image }} + R_TAG: ${{ matrix.config.tag }} steps: - name: Checkout Arrow uses: actions/checkout@v2 @@ -120,8 +119,8 @@ jobs: uses: actions/cache@v1 with: path: .docker - key: ${{ matrix.r_image }}-r-${{ hashFiles('cpp/**') }} - restore-keys: ${{ matrix.r_image }}-r- + key: ${{ matrix.config.image }}-r-${{ hashFiles('cpp/**') }} + restore-keys: ${{ matrix.config.image }}-r- - name: Setup Python uses: actions/setup-python@v1 with: diff --git a/cpp/cmake_modules/SetupCxxFlags.cmake b/cpp/cmake_modules/SetupCxxFlags.cmake index 1606199c406..f812c96c2ad 100644 --- a/cpp/cmake_modules/SetupCxxFlags.cmake +++ b/cpp/cmake_modules/SetupCxxFlags.cmake @@ -18,6 +18,7 @@ # Check if the target architecture and compiler supports some special # instruction sets that would boost performance. include(CheckCXXCompilerFlag) +include(CheckCXXSourceCompiles) # Get cpu architecture message(STATUS "System processor: ${CMAKE_SYSTEM_PROCESSOR}") @@ -60,17 +61,36 @@ if(ARROW_CPU_FLAG STREQUAL "x86") # https://gcc.gnu.org/bugzilla/show_bug.cgi?id=65782 message(STATUS "Disable AVX512 support on MINGW for now") else() - check_cxx_compiler_flag(${ARROW_AVX512_FLAG} CXX_SUPPORTS_AVX512) + # Check for AVX512 support in the compiler. + set(OLD_CMAKE_REQURED_FLAGS ${CMAKE_REQUIRED_FLAGS}) + set(CMAKE_REQUIRED_FLAGS "${CMAKE_REQUIRED_FLAGS} ${ARROW_AVX512_FLAG}") + check_cxx_source_compiles(" + #ifdef _MSC_VER + #include + #else + #include + #endif + + int main() { + __m512i mask = _mm512_set1_epi32(0x1); + char out[32]; + _mm512_storeu_si512(out, mask); + return 0; + }" CXX_SUPPORTS_AVX512) + set(CMAKE_REQUIRED_FLAGS ${OLD_CMAKE_REQURED_FLAGS}) endif() - # Runtime SIMD level it can get from compiler + # Runtime SIMD level it can get from compiler and ARROW_RUNTIME_SIMD_LEVEL if(CXX_SUPPORTS_SSE4_2 AND ARROW_RUNTIME_SIMD_LEVEL MATCHES "^(SSE4_2|AVX2|AVX512|MAX)$") + set(ARROW_HAVE_RUNTIME_SSE4_2 ON) add_definitions(-DARROW_HAVE_RUNTIME_SSE4_2) endif() if(CXX_SUPPORTS_AVX2 AND ARROW_RUNTIME_SIMD_LEVEL MATCHES "^(AVX2|AVX512|MAX)$") + set(ARROW_HAVE_RUNTIME_AVX2 ON) add_definitions(-DARROW_HAVE_RUNTIME_AVX2 -DARROW_HAVE_RUNTIME_BMI2) endif() if(CXX_SUPPORTS_AVX512 AND ARROW_RUNTIME_SIMD_LEVEL MATCHES "^(AVX512|MAX)$") + set(ARROW_HAVE_RUNTIME_AVX512 ON) add_definitions(-DARROW_HAVE_RUNTIME_AVX512 -DARROW_HAVE_RUNTIME_BMI2) endif() elseif(ARROW_CPU_FLAG STREQUAL "ppc") diff --git a/cpp/src/arrow/CMakeLists.txt b/cpp/src/arrow/CMakeLists.txt index bbeed8df292..dd17720595a 100644 --- a/cpp/src/arrow/CMakeLists.txt +++ b/cpp/src/arrow/CMakeLists.txt @@ -214,13 +214,13 @@ set(ARROW_SRCS vendored/double-conversion/diy-fp.cc vendored/double-conversion/strtod.cc) -if(CXX_SUPPORTS_AVX2) +if(ARROW_HAVE_RUNTIME_AVX2) list(APPEND ARROW_SRCS util/bpacking_avx2.cc) set_source_files_properties(util/bpacking_avx2.cc PROPERTIES SKIP_PRECOMPILE_HEADERS ON) set_source_files_properties(util/bpacking_avx2.cc PROPERTIES COMPILE_FLAGS ${ARROW_AVX2_FLAG}) endif() -if(CXX_SUPPORTS_AVX512) +if(ARROW_HAVE_RUNTIME_AVX512) list(APPEND ARROW_SRCS util/bpacking_avx512.cc) set_source_files_properties(util/bpacking_avx512.cc PROPERTIES SKIP_PRECOMPILE_HEADERS ON) @@ -387,14 +387,14 @@ if(ARROW_COMPUTE) compute/kernels/vector_selection.cc compute/kernels/vector_sort.cc) - if(CXX_SUPPORTS_AVX2) + if(ARROW_HAVE_RUNTIME_AVX2) list(APPEND ARROW_SRCS compute/kernels/aggregate_basic_avx2.cc) set_source_files_properties(compute/kernels/aggregate_basic_avx2.cc PROPERTIES SKIP_PRECOMPILE_HEADERS ON) set_source_files_properties(compute/kernels/aggregate_basic_avx2.cc PROPERTIES COMPILE_FLAGS ${ARROW_AVX2_FLAG}) endif() - if(CXX_SUPPORTS_AVX512) + if(ARROW_HAVE_RUNTIME_AVX512) list(APPEND ARROW_SRCS compute/kernels/aggregate_basic_avx512.cc) set_source_files_properties(compute/kernels/aggregate_basic_avx512.cc PROPERTIES SKIP_PRECOMPILE_HEADERS ON) diff --git a/cpp/src/arrow/compute/kernels/aggregate_test.cc b/cpp/src/arrow/compute/kernels/aggregate_test.cc index 39b3f8827fb..6d97e79a23f 100644 --- a/cpp/src/arrow/compute/kernels/aggregate_test.cc +++ b/cpp/src/arrow/compute/kernels/aggregate_test.cc @@ -953,14 +953,13 @@ class TestPrimitiveVarStdKernel : public ::testing::Test { using ScalarType = typename TypeTraits::ScalarType; void AssertVarStdIs(const Array& array, const VarianceOptions& options, - double expected_var, double diff = 0) { - AssertVarStdIsInternal(array, options, expected_var, diff); + double expected_var) { + AssertVarStdIsInternal(array, options, expected_var); } void AssertVarStdIs(const std::shared_ptr& array, - const VarianceOptions& options, double expected_var, - double diff = 0) { - AssertVarStdIsInternal(array, options, expected_var, diff); + const VarianceOptions& options, double expected_var) { + AssertVarStdIsInternal(array, options, expected_var); } void AssertVarStdIs(const std::string& json, const VarianceOptions& options, @@ -999,18 +998,14 @@ class TestPrimitiveVarStdKernel : public ::testing::Test { private: void AssertVarStdIsInternal(const Datum& array, const VarianceOptions& options, - double expected_var, double diff = 0) { + double expected_var) { ASSERT_OK_AND_ASSIGN(Datum out_var, Variance(array, options)); ASSERT_OK_AND_ASSIGN(Datum out_std, Stddev(array, options)); auto var = checked_cast(out_var.scalar().get()); auto std = checked_cast(out_std.scalar().get()); ASSERT_TRUE(var->is_valid && std->is_valid); ASSERT_DOUBLE_EQ(std->value * std->value, var->value); - if (diff == 0) { - ASSERT_DOUBLE_EQ(var->value, expected_var); // < 4ULP - } else { - ASSERT_NEAR(var->value, expected_var, diff); - } + ASSERT_DOUBLE_EQ(var->value, expected_var); // < 4ULP } void AssertVarStdIsInvalidInternal(const Datum& array, const VarianceOptions& options) { @@ -1070,22 +1065,39 @@ TEST_F(TestVarStdKernelStability, Basics) { VarianceOptions options{1}; // ddof = 1 this->AssertVarStdIs("[100000004, 100000007, 100000013, 100000016]", options, 30.0); this->AssertVarStdIs("[1000000004, 1000000007, 1000000013, 1000000016]", options, 30.0); + +#ifndef __MINGW32__ // MinGW has precision issues + // This test is to make sure our variance combining method is stable. + // XXX: The reference value from numpy is actually wrong due to floating + // point limits. The correct result should equals variance(90, 0) = 4050. + std::vector chunks = {"[40000008000000490]", "[40000008000000400]"}; + this->AssertVarStdIs(chunks, options, 3904.0); +#endif +} + +// https://en.wikipedia.org/wiki/Kahan_summation_algorithm +void KahanSum(double& sum, double& adjust, double addend) { + double y = addend - adjust; + double t = sum + y; + adjust = (t - sum) - y; + sum = t; } -// Calculate reference variance with Welford's online algorithm +// Calculate reference variance with Welford's online algorithm + Kahan summation // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm std::pair WelfordVar(const Array& array) { const auto& array_numeric = reinterpret_cast(array); const auto values = array_numeric.raw_values(); internal::BitmapReader reader(array.null_bitmap_data(), array.offset(), array.length()); double count = 0, mean = 0, m2 = 0; + double mean_adjust = 0, m2_adjust = 0; for (int64_t i = 0; i < array.length(); ++i) { if (reader.IsSet()) { ++count; double delta = values[i] - mean; - mean += delta / count; + KahanSum(mean, mean_adjust, delta / count); double delta2 = values[i] - mean; - m2 += delta * delta2; + KahanSum(m2, m2_adjust, delta * delta2); } reader.Next(); } @@ -1116,8 +1128,8 @@ TEST_F(TestVarStdKernelRandom, Basics) { double var_population, var_sample; std::tie(var_population, var_sample) = WelfordVar(*(array->Slice(0, total_size))); - this->AssertVarStdIs(chunked, VarianceOptions{0}, var_population, 0.0001); - this->AssertVarStdIs(chunked, VarianceOptions{1}, var_sample, 0.0001); + this->AssertVarStdIs(chunked, VarianceOptions{0}, var_population); + this->AssertVarStdIs(chunked, VarianceOptions{1}, var_sample); } } // namespace compute diff --git a/cpp/src/arrow/compute/kernels/aggregate_var_std.cc b/cpp/src/arrow/compute/kernels/aggregate_var_std.cc index e2b98bb38fc..327372ad486 100644 --- a/cpp/src/arrow/compute/kernels/aggregate_var_std.cc +++ b/cpp/src/arrow/compute/kernels/aggregate_var_std.cc @@ -53,32 +53,33 @@ struct VarStdState { []() {}); this->count = count; - this->sum = sum; + this->mean = mean; this->m2 = m2; } - // Combine `m2` from two chunks - // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm + // Combine `m2` from two chunks (m2 = n*s2) + // https://www.emathzone.com/tutorials/basic-statistics/combined-variance.html void MergeFrom(const ThisType& state) { if (state.count == 0) { return; } if (this->count == 0) { this->count = state.count; - this->sum = state.sum; + this->mean = state.mean; this->m2 = state.m2; return; } - double delta = this->sum / this->count - state.sum / state.count; - this->m2 += state.m2 + - delta * delta * this->count * state.count / (this->count + state.count); + double mean = (this->mean * this->count + state.mean * state.count) / + (this->count + state.count); + this->m2 += state.m2 + this->count * (this->mean - mean) * (this->mean - mean) + + state.count * (state.mean - mean) * (state.mean - mean); this->count += state.count; - this->sum += state.sum; + this->mean = mean; } int64_t count = 0; - double sum = 0; - double m2 = 0; // sum((X-mean)^2) + double mean = 0; + double m2 = 0; // m2 = count*s2 = sum((X-mean)^2) }; enum class VarOrStd : bool { Var, Std }; diff --git a/cpp/src/arrow/dataset/partition_test.cc b/cpp/src/arrow/dataset/partition_test.cc index e9ea2539e89..f49103a585a 100644 --- a/cpp/src/arrow/dataset/partition_test.cc +++ b/cpp/src/arrow/dataset/partition_test.cc @@ -156,6 +156,9 @@ TEST_F(TestPartitioning, DiscoverSchema) { // fall back to string if any segment for field alpha is not parseable as int AssertInspect({"/0/1", "/hello/1"}, {Str("alpha"), Int("beta")}); + // If there are too many digits fall back to string + AssertInspect({"/3760212050/1"}, {Str("alpha"), Int("beta")}); + // missing segment for beta doesn't cause an error or fallback AssertInspect({"/0/1", "/hello"}, {Str("alpha"), Int("beta")}); } @@ -168,6 +171,9 @@ TEST_F(TestPartitioning, DictionaryInference) { // type is still int32 if possible AssertInspect({"/0/1"}, {DictInt("alpha"), DictInt("beta")}); + // If there are too many digits fall back to string + AssertInspect({"/3760212050/1"}, {DictStr("alpha"), DictInt("beta")}); + // successful dictionary inference AssertInspect({"/a/0"}, {DictStr("alpha"), DictInt("beta")}); AssertInspect({"/a/0", "/a/1"}, {DictStr("alpha"), DictInt("beta")}); @@ -256,6 +262,9 @@ TEST_F(TestPartitioning, DiscoverHiveSchema) { // (...so ensure your partitions are ordered the same for all paths) AssertInspect({"/alpha=0/beta=1", "/beta=2/alpha=3"}, {Int("alpha"), Int("beta")}); + // If there are too many digits fall back to string + AssertInspect({"/alpha=3760212050"}, {Str("alpha")}); + // missing path segments will not cause an error AssertInspect({"/alpha=0/beta=1", "/beta=2/alpha=3", "/gamma=what"}, {Int("alpha"), Int("beta"), Str("gamma")}); @@ -269,6 +278,9 @@ TEST_F(TestPartitioning, HiveDictionaryInference) { // type is still int32 if possible AssertInspect({"/alpha=0/beta=1"}, {DictInt("alpha"), DictInt("beta")}); + // If there are too many digits fall back to string + AssertInspect({"/alpha=3760212050"}, {DictStr("alpha")}); + // successful dictionary inference AssertInspect({"/alpha=a/beta=0"}, {DictStr("alpha"), DictInt("beta")}); AssertInspect({"/alpha=a/beta=0", "/alpha=a/1"}, {DictStr("alpha"), DictInt("beta")}); diff --git a/cpp/src/arrow/ipc/metadata_internal.cc b/cpp/src/arrow/ipc/metadata_internal.cc index a82aef328d6..8564b71ec20 100644 --- a/cpp/src/arrow/ipc/metadata_internal.cc +++ b/cpp/src/arrow/ipc/metadata_internal.cc @@ -427,8 +427,7 @@ static Status GetDictionaryEncoding(FBB& fbb, const std::shared_ptr& fiel const DictionaryType& type, int64_t dictionary_id, DictionaryOffset* out) { // We assume that the dictionary index type (as an integer) has already been - // validated elsewhere, and can safely assume we are dealing with signed - // integers + // validated elsewhere, and can safely assume we are dealing with integers const auto& index_type = checked_cast(*type.index_type()); auto index_type_offset = diff --git a/cpp/src/arrow/util/utf8.h b/cpp/src/arrow/util/utf8.h index d5875c4590b..c089fa7fff6 100644 --- a/cpp/src/arrow/util/utf8.h +++ b/cpp/src/arrow/util/utf8.h @@ -27,6 +27,7 @@ #include "arrow/util/macros.h" #include "arrow/util/simd.h" #include "arrow/util/string_view.h" +#include "arrow/util/ubsan.h" #include "arrow/util/visibility.h" namespace arrow { @@ -87,8 +88,9 @@ ARROW_EXPORT void InitializeUTF8(); inline bool ValidateUTF8(const uint8_t* data, int64_t size) { static constexpr uint64_t high_bits_64 = 0x8080808080808080ULL; - // For some reason, defining this variable outside the loop helps clang - uint64_t mask; + static constexpr uint32_t high_bits_32 = 0x80808080UL; + static constexpr uint16_t high_bits_16 = 0x8080U; + static constexpr uint8_t high_bits_8 = 0x80U; #ifndef NDEBUG internal::CheckUTF8Initialized(); @@ -98,8 +100,8 @@ inline bool ValidateUTF8(const uint8_t* data, int64_t size) { // XXX This is doing an unaligned access. Contemporary architectures // (x86-64, AArch64, PPC64) support it natively and often have good // performance nevertheless. - memcpy(&mask, data, 8); - if (ARROW_PREDICT_TRUE((mask & high_bits_64) == 0)) { + uint64_t mask64 = SafeLoadAs(data); + if (ARROW_PREDICT_TRUE((mask64 & high_bits_64) == 0)) { // 8 bytes of pure ASCII, move forward size -= 8; data += 8; @@ -154,13 +156,50 @@ inline bool ValidateUTF8(const uint8_t* data, int64_t size) { return false; } - // Validate string tail one byte at a time + // Check if string tail is full ASCII (common case, fast) + if (size >= 4) { + uint32_t tail_mask = SafeLoadAs(data + size - 4); + uint32_t head_mask = SafeLoadAs(data); + if (ARROW_PREDICT_TRUE(((head_mask | tail_mask) & high_bits_32) == 0)) { + return true; + } + } else if (size >= 2) { + uint16_t tail_mask = SafeLoadAs(data + size - 2); + uint16_t head_mask = SafeLoadAs(data); + if (ARROW_PREDICT_TRUE(((head_mask | tail_mask) & high_bits_16) == 0)) { + return true; + } + } else if (size == 1) { + if (ARROW_PREDICT_TRUE((*data & high_bits_8) == 0)) { + return true; + } + } else { + /* size == 0 */ + return true; + } + + // Fall back to UTF8 validation of tail string. // Note the state table is designed so that, once in the reject state, // we remain in that state until the end. So we needn't check for // rejection at each char (we don't gain much by short-circuiting here). uint16_t state = internal::kUTF8ValidateAccept; - while (size-- > 0) { - state = internal::ValidateOneUTF8Byte(*data++, state); + switch (size) { + case 7: + state = internal::ValidateOneUTF8Byte(data[size - 7], state); + case 6: + state = internal::ValidateOneUTF8Byte(data[size - 6], state); + case 5: + state = internal::ValidateOneUTF8Byte(data[size - 5], state); + case 4: + state = internal::ValidateOneUTF8Byte(data[size - 4], state); + case 3: + state = internal::ValidateOneUTF8Byte(data[size - 3], state); + case 2: + state = internal::ValidateOneUTF8Byte(data[size - 2], state); + case 1: + state = internal::ValidateOneUTF8Byte(data[size - 1], state); + default: + break; } return ARROW_PREDICT_TRUE(state == internal::kUTF8ValidateAccept); } diff --git a/cpp/src/gandiva/CMakeLists.txt b/cpp/src/gandiva/CMakeLists.txt index 99c23f99cd9..0ae5a193f53 100644 --- a/cpp/src/gandiva/CMakeLists.txt +++ b/cpp/src/gandiva/CMakeLists.txt @@ -224,6 +224,7 @@ add_gandiva_test(internals-test like_holder_test.cc decimal_type_util_test.cc random_generator_holder_test.cc + gdv_function_stubs_test.cc EXTRA_DEPENDENCIES LLVM::LLVM_INTERFACE EXTRA_INCLUDES diff --git a/cpp/src/gandiva/function_registry_string.cc b/cpp/src/gandiva/function_registry_string.cc index 2c71126aafe..ea3af5b45c9 100644 --- a/cpp/src/gandiva/function_registry_string.cc +++ b/cpp/src/gandiva/function_registry_string.cc @@ -16,6 +16,7 @@ // under the License. #include "gandiva/function_registry_string.h" + #include "gandiva/function_registry_common.h" namespace gandiva { @@ -61,17 +62,26 @@ std::vector GetStringFunctionRegistry() { UNARY_SAFE_NULL_NEVER_BOOL_FN(isnull, {}), UNARY_SAFE_NULL_NEVER_BOOL_FN(isnotnull, {}), - UNARY_UNSAFE_NULL_IF_NULL(castINT, {}, utf8, int32), - UNARY_UNSAFE_NULL_IF_NULL(castBIGINT, {}, utf8, int64), - UNARY_UNSAFE_NULL_IF_NULL(castFLOAT4, {}, utf8, float32), - UNARY_UNSAFE_NULL_IF_NULL(castFLOAT8, {}, utf8, float64), - NativeFunction("upper", {}, DataTypeVector{utf8()}, utf8(), kResultNullIfNull, "upper_utf8", NativeFunction::kNeedsContext), NativeFunction("lower", {}, DataTypeVector{utf8()}, utf8(), kResultNullIfNull, "lower_utf8", NativeFunction::kNeedsContext), + NativeFunction("castINT", {}, DataTypeVector{utf8()}, int32(), kResultNullIfNull, + "gdv_fn_castINT_utf8", NativeFunction::kNeedsContext), + + NativeFunction("castBIGINT", {}, DataTypeVector{utf8()}, int64(), kResultNullIfNull, + "gdv_fn_castBIGINT_utf8", NativeFunction::kNeedsContext), + + NativeFunction("castFLOAT4", {}, DataTypeVector{utf8()}, float32(), + kResultNullIfNull, "gdv_fn_castFLOAT4_utf8", + NativeFunction::kNeedsContext), + + NativeFunction("castFLOAT8", {}, DataTypeVector{utf8()}, float64(), + kResultNullIfNull, "gdv_fn_castFLOAT8_utf8", + NativeFunction::kNeedsContext), + NativeFunction("castVARCHAR", {}, DataTypeVector{utf8(), int64()}, utf8(), kResultNullIfNull, "castVARCHAR_utf8_int64", NativeFunction::kNeedsContext), diff --git a/cpp/src/gandiva/gdv_function_stubs.cc b/cpp/src/gandiva/gdv_function_stubs.cc index ad3036f96b5..ad93ce8c412 100644 --- a/cpp/src/gandiva/gdv_function_stubs.cc +++ b/cpp/src/gandiva/gdv_function_stubs.cc @@ -20,6 +20,7 @@ #include #include +#include "arrow/util/value_parsing.h" #include "gandiva/engine.h" #include "gandiva/exported_funcs.h" #include "gandiva/in_holder.h" @@ -150,6 +151,37 @@ char* gdv_fn_dec_to_string(int64_t context, int64_t x_high, uint64_t x_low, memcpy(ret, dec_str.data(), *dec_str_len); return ret; } + +#define CAST_NUMERIC_FROM_STRING(OUT_TYPE, ARROW_TYPE, TYPE_NAME) \ + GANDIVA_EXPORT \ + OUT_TYPE gdv_fn_cast##TYPE_NAME##_utf8(int64_t context, const char* data, \ + int32_t len) { \ + OUT_TYPE val = 0; \ + /* trim leading and trailing spaces */ \ + int32_t trimmed_len; \ + int32_t start = 0, end = len - 1; \ + while (start <= end && data[start] == ' ') { \ + ++start; \ + } \ + while (end >= start && data[end] == ' ') { \ + --end; \ + } \ + trimmed_len = end - start + 1; \ + const char* trimmed_data = data + start; \ + if (!arrow::internal::ParseValue(trimmed_data, trimmed_len, &val)) { \ + std::string err = \ + "Failed to cast the string " + std::string(data, len) + " to " #OUT_TYPE; \ + gdv_fn_context_set_error_msg(context, err.c_str()); \ + } \ + return val; \ + } + +CAST_NUMERIC_FROM_STRING(int32_t, arrow::Int32Type, INT) +CAST_NUMERIC_FROM_STRING(int64_t, arrow::Int64Type, BIGINT) +CAST_NUMERIC_FROM_STRING(float, arrow::FloatType, FLOAT4) +CAST_NUMERIC_FROM_STRING(double, arrow::DoubleType, FLOAT8) + +#undef CAST_NUMERIC_FROM_STRING } namespace gandiva { @@ -277,6 +309,34 @@ void ExportedStubFunctions::AddMappings(Engine* engine) const { args = {types->i64_type(), types->i32_type(), types->i1_type()}; engine->AddGlobalMappingForFunc("gdv_fn_random_with_seed", types->double_type(), args, reinterpret_cast(gdv_fn_random_with_seed)); + + args = {types->i64_type(), // int64_t context_ptr + types->i8_ptr_type(), // const char* data + types->i32_type()}; // int32_t lenr + + engine->AddGlobalMappingForFunc("gdv_fn_castINT_utf8", types->i32_type(), args, + reinterpret_cast(gdv_fn_castINT_utf8)); + + args = {types->i64_type(), // int64_t context_ptr + types->i8_ptr_type(), // const char* data + types->i32_type()}; // int32_t lenr + + engine->AddGlobalMappingForFunc("gdv_fn_castBIGINT_utf8", types->i64_type(), args, + reinterpret_cast(gdv_fn_castBIGINT_utf8)); + + args = {types->i64_type(), // int64_t context_ptr + types->i8_ptr_type(), // const char* data + types->i32_type()}; // int32_t lenr + + engine->AddGlobalMappingForFunc("gdv_fn_castFLOAT4_utf8", types->float_type(), args, + reinterpret_cast(gdv_fn_castFLOAT4_utf8)); + + args = {types->i64_type(), // int64_t context_ptr + types->i8_ptr_type(), // const char* data + types->i32_type()}; // int32_t lenr + + engine->AddGlobalMappingForFunc("gdv_fn_castFLOAT8_utf8", types->double_type(), args, + reinterpret_cast(gdv_fn_castFLOAT8_utf8)); } } // namespace gandiva diff --git a/cpp/src/gandiva/gdv_function_stubs.h b/cpp/src/gandiva/gdv_function_stubs.h index 4d66aa3e987..457f42511cc 100644 --- a/cpp/src/gandiva/gdv_function_stubs.h +++ b/cpp/src/gandiva/gdv_function_stubs.h @@ -19,6 +19,8 @@ #include +#include "gandiva/visibility.h" + /// Stub functions that can be accessed from LLVM. extern "C" { @@ -52,4 +54,16 @@ int32_t gdv_fn_dec_from_string(int64_t context, const char* in, int32_t in_lengt char* gdv_fn_dec_to_string(int64_t context, int64_t x_high, uint64_t x_low, int32_t x_scale, int32_t* dec_str_len); + +GANDIVA_EXPORT +int32_t gdv_fn_castINT_utf8(int64_t context, const char* data, int32_t data_len); + +GANDIVA_EXPORT +int64_t gdv_fn_castBIGINT_utf8(int64_t context, const char* data, int32_t data_len); + +GANDIVA_EXPORT +float gdv_fn_castFLOAT4_utf8(int64_t context, const char* data, int32_t data_len); + +GANDIVA_EXPORT +double gdv_fn_castFLOAT8_utf8(int64_t context, const char* data, int32_t data_len); } diff --git a/cpp/src/gandiva/gdv_function_stubs_test.cc b/cpp/src/gandiva/gdv_function_stubs_test.cc new file mode 100644 index 00000000000..90ac1dfa540 --- /dev/null +++ b/cpp/src/gandiva/gdv_function_stubs_test.cc @@ -0,0 +1,163 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "gandiva/gdv_function_stubs.h" + +#include +#include + +#include "gandiva/execution_context.h" + +namespace gandiva { + +TEST(TestGdvFnStubs, TestCastINT) { + gandiva::ExecutionContext ctx; + + int64_t ctx_ptr = reinterpret_cast(&ctx); + + EXPECT_EQ(gdv_fn_castINT_utf8(ctx_ptr, "-45", 3), -45); + EXPECT_EQ(gdv_fn_castINT_utf8(ctx_ptr, "0", 1), 0); + EXPECT_EQ(gdv_fn_castINT_utf8(ctx_ptr, "2147483647", 10), 2147483647); + EXPECT_EQ(gdv_fn_castINT_utf8(ctx_ptr, "02147483647", 11), 2147483647); + EXPECT_EQ(gdv_fn_castINT_utf8(ctx_ptr, "-2147483648", 11), -2147483648LL); + EXPECT_EQ(gdv_fn_castINT_utf8(ctx_ptr, "-02147483648", 12), -2147483648LL); + EXPECT_EQ(gdv_fn_castINT_utf8(ctx_ptr, " 12 ", 4), 12); + + gdv_fn_castINT_utf8(ctx_ptr, "2147483648", 10); + EXPECT_THAT(ctx.get_error(), + ::testing::HasSubstr("Failed to cast the string 2147483648 to int32")); + ctx.Reset(); + + gdv_fn_castINT_utf8(ctx_ptr, "-2147483649", 11); + EXPECT_THAT(ctx.get_error(), + ::testing::HasSubstr("Failed to cast the string -2147483649 to int32")); + ctx.Reset(); + + gdv_fn_castINT_utf8(ctx_ptr, "12.34", 5); + EXPECT_THAT(ctx.get_error(), + ::testing::HasSubstr("Failed to cast the string 12.34 to int32")); + ctx.Reset(); + + gdv_fn_castINT_utf8(ctx_ptr, "abc", 3); + EXPECT_THAT(ctx.get_error(), + ::testing::HasSubstr("Failed to cast the string abc to int32")); + ctx.Reset(); + + gdv_fn_castINT_utf8(ctx_ptr, "", 0); + EXPECT_THAT(ctx.get_error(), + ::testing::HasSubstr("Failed to cast the string to int32")); + ctx.Reset(); + + gdv_fn_castINT_utf8(ctx_ptr, "-", 1); + EXPECT_THAT(ctx.get_error(), + ::testing::HasSubstr("Failed to cast the string - to int32")); + ctx.Reset(); +} + +TEST(TestGdvFnStubs, TestCastBIGINT) { + gandiva::ExecutionContext ctx; + + int64_t ctx_ptr = reinterpret_cast(&ctx); + + EXPECT_EQ(gdv_fn_castBIGINT_utf8(ctx_ptr, "-45", 3), -45); + EXPECT_EQ(gdv_fn_castBIGINT_utf8(ctx_ptr, "0", 1), 0); + EXPECT_EQ(gdv_fn_castBIGINT_utf8(ctx_ptr, "9223372036854775807", 19), + 9223372036854775807LL); + EXPECT_EQ(gdv_fn_castBIGINT_utf8(ctx_ptr, "09223372036854775807", 20), + 9223372036854775807LL); + EXPECT_EQ(gdv_fn_castBIGINT_utf8(ctx_ptr, "-9223372036854775808", 20), + -9223372036854775807LL - 1); + EXPECT_EQ(gdv_fn_castBIGINT_utf8(ctx_ptr, "-009223372036854775808", 22), + -9223372036854775807LL - 1); + EXPECT_EQ(gdv_fn_castBIGINT_utf8(ctx_ptr, " 12 ", 4), 12); + + gdv_fn_castBIGINT_utf8(ctx_ptr, "9223372036854775808", 19); + EXPECT_THAT( + ctx.get_error(), + ::testing::HasSubstr("Failed to cast the string 9223372036854775808 to int64")); + ctx.Reset(); + + gdv_fn_castBIGINT_utf8(ctx_ptr, "-9223372036854775809", 20); + EXPECT_THAT( + ctx.get_error(), + ::testing::HasSubstr("Failed to cast the string -9223372036854775809 to int64")); + ctx.Reset(); + + gdv_fn_castBIGINT_utf8(ctx_ptr, "12.34", 5); + EXPECT_THAT(ctx.get_error(), + ::testing::HasSubstr("Failed to cast the string 12.34 to int64")); + ctx.Reset(); + + gdv_fn_castBIGINT_utf8(ctx_ptr, "abc", 3); + EXPECT_THAT(ctx.get_error(), + ::testing::HasSubstr("Failed to cast the string abc to int64")); + ctx.Reset(); + + gdv_fn_castBIGINT_utf8(ctx_ptr, "", 0); + EXPECT_THAT(ctx.get_error(), + ::testing::HasSubstr("Failed to cast the string to int64")); + ctx.Reset(); + + gdv_fn_castBIGINT_utf8(ctx_ptr, "-", 1); + EXPECT_THAT(ctx.get_error(), + ::testing::HasSubstr("Failed to cast the string - to int64")); + ctx.Reset(); +} + +TEST(TestGdvFnStubs, TestCastFloat4) { + gandiva::ExecutionContext ctx; + + int64_t ctx_ptr = reinterpret_cast(&ctx); + + EXPECT_EQ(gdv_fn_castFLOAT4_utf8(ctx_ptr, "-45.34", 6), -45.34f); + EXPECT_EQ(gdv_fn_castFLOAT4_utf8(ctx_ptr, "0", 1), 0.0f); + EXPECT_EQ(gdv_fn_castFLOAT4_utf8(ctx_ptr, "5", 1), 5.0f); + EXPECT_EQ(gdv_fn_castFLOAT4_utf8(ctx_ptr, " 3.4 ", 5), 3.4f); + + gdv_fn_castFLOAT4_utf8(ctx_ptr, "", 0); + EXPECT_THAT(ctx.get_error(), + ::testing::HasSubstr("Failed to cast the string to float")); + ctx.Reset(); + + gdv_fn_castFLOAT4_utf8(ctx_ptr, "e", 1); + EXPECT_THAT(ctx.get_error(), + ::testing::HasSubstr("Failed to cast the string e to float")); + ctx.Reset(); +} + +TEST(TestGdvFnStubs, TestCastFloat8) { + gandiva::ExecutionContext ctx; + + int64_t ctx_ptr = reinterpret_cast(&ctx); + + EXPECT_EQ(gdv_fn_castFLOAT8_utf8(ctx_ptr, "-45.34", 6), -45.34); + EXPECT_EQ(gdv_fn_castFLOAT8_utf8(ctx_ptr, "0", 1), 0.0); + EXPECT_EQ(gdv_fn_castFLOAT8_utf8(ctx_ptr, "5", 1), 5.0); + EXPECT_EQ(gdv_fn_castFLOAT8_utf8(ctx_ptr, " 3.4 ", 5), 3.4); + + gdv_fn_castFLOAT8_utf8(ctx_ptr, "", 0); + EXPECT_THAT(ctx.get_error(), + ::testing::HasSubstr("Failed to cast the string to double")); + ctx.Reset(); + + gdv_fn_castFLOAT8_utf8(ctx_ptr, "e", 1); + EXPECT_THAT(ctx.get_error(), + ::testing::HasSubstr("Failed to cast the string e to double")); + ctx.Reset(); +} + +} // namespace gandiva diff --git a/cpp/src/gandiva/precompiled/string_ops.cc b/cpp/src/gandiva/precompiled/string_ops.cc index 34dd011ffb3..0432d6c761c 100644 --- a/cpp/src/gandiva/precompiled/string_ops.cc +++ b/cpp/src/gandiva/precompiled/string_ops.cc @@ -23,6 +23,7 @@ extern "C" { #include #include #include + #include "./types.h" FORCE_INLINE @@ -1439,27 +1440,4 @@ const char* binary_string(gdv_int64 context, const char* text, gdv_int32 text_le return ret; } -#define CAST_NUMERIC_FROM_STRING(OUT_TYPE, ARROW_TYPE, TYPE_NAME) \ - FORCE_INLINE \ - gdv_##OUT_TYPE cast##TYPE_NAME##_utf8(int64_t context, const char* data, \ - int32_t len) { \ - gdv_##OUT_TYPE val = 0; \ - int32_t trimmed_len; \ - data = btrim_utf8(context, data, len, &trimmed_len); \ - if (!arrow::internal::ParseValue(data, trimmed_len, &val)) { \ - std::string err = "Failed to cast the string " + std::string(data, trimmed_len) + \ - " to " #OUT_TYPE; \ - gdv_fn_context_set_error_msg(context, err.c_str()); \ - } \ - return val; \ - } - -CAST_NUMERIC_FROM_STRING(int32, arrow::Int32Type, INT) -CAST_NUMERIC_FROM_STRING(int64, arrow::Int64Type, BIGINT) -CAST_NUMERIC_FROM_STRING(float32, arrow::FloatType, FLOAT4) -CAST_NUMERIC_FROM_STRING(float64, arrow::DoubleType, FLOAT8) - -#undef CAST_INT_FROM_STRING -#undef CAST_FLOAT_FROM_STRING - } // extern "C" diff --git a/cpp/src/gandiva/precompiled/string_ops_test.cc b/cpp/src/gandiva/precompiled/string_ops_test.cc index 9bb44af9a1b..b1836d877ab 100644 --- a/cpp/src/gandiva/precompiled/string_ops_test.cc +++ b/cpp/src/gandiva/precompiled/string_ops_test.cc @@ -17,6 +17,7 @@ #include #include + #include "gandiva/execution_context.h" #include "gandiva/precompiled/types.h" @@ -1002,138 +1003,4 @@ TEST(TestStringOps, TestSplitPart) { EXPECT_EQ(std::string(out_str, out_len), "ååçåå"); } -TEST(TestArithmeticOps, TestCastINT) { - gandiva::ExecutionContext ctx; - - int64_t ctx_ptr = reinterpret_cast(&ctx); - - EXPECT_EQ(castINT_utf8(ctx_ptr, "-45", 3), -45); - EXPECT_EQ(castINT_utf8(ctx_ptr, "0", 1), 0); - EXPECT_EQ(castINT_utf8(ctx_ptr, "2147483647", 10), 2147483647); - EXPECT_EQ(castINT_utf8(ctx_ptr, "02147483647", 11), 2147483647); - EXPECT_EQ(castINT_utf8(ctx_ptr, "-2147483648", 11), -2147483648LL); - EXPECT_EQ(castINT_utf8(ctx_ptr, "-02147483648", 12), -2147483648LL); - EXPECT_EQ(castINT_utf8(ctx_ptr, " 12 ", 4), 12); - - castINT_utf8(ctx_ptr, "2147483648", 10); - EXPECT_THAT(ctx.get_error(), - ::testing::HasSubstr("Failed to cast the string 2147483648 to int32")); - ctx.Reset(); - - castINT_utf8(ctx_ptr, "-2147483649", 11); - EXPECT_THAT(ctx.get_error(), - ::testing::HasSubstr("Failed to cast the string -2147483649 to int32")); - ctx.Reset(); - - castINT_utf8(ctx_ptr, "12.34", 5); - EXPECT_THAT(ctx.get_error(), - ::testing::HasSubstr("Failed to cast the string 12.34 to int32")); - ctx.Reset(); - - castINT_utf8(ctx_ptr, "abc", 3); - EXPECT_THAT(ctx.get_error(), - ::testing::HasSubstr("Failed to cast the string abc to int32")); - ctx.Reset(); - - castINT_utf8(ctx_ptr, "", 0); - EXPECT_THAT(ctx.get_error(), - ::testing::HasSubstr("Failed to cast the string to int32")); - ctx.Reset(); - - castINT_utf8(ctx_ptr, "-", 1); - EXPECT_THAT(ctx.get_error(), - ::testing::HasSubstr("Failed to cast the string - to int32")); - ctx.Reset(); -} - -TEST(TestArithmeticOps, TestCastBIGINT) { - gandiva::ExecutionContext ctx; - - int64_t ctx_ptr = reinterpret_cast(&ctx); - - EXPECT_EQ(castBIGINT_utf8(ctx_ptr, "-45", 3), -45); - EXPECT_EQ(castBIGINT_utf8(ctx_ptr, "0", 1), 0); - EXPECT_EQ(castBIGINT_utf8(ctx_ptr, "9223372036854775807", 19), 9223372036854775807LL); - EXPECT_EQ(castBIGINT_utf8(ctx_ptr, "09223372036854775807", 20), 9223372036854775807LL); - EXPECT_EQ(castBIGINT_utf8(ctx_ptr, "-9223372036854775808", 20), - -9223372036854775807LL - 1); - EXPECT_EQ(castBIGINT_utf8(ctx_ptr, "-009223372036854775808", 22), - -9223372036854775807LL - 1); - EXPECT_EQ(castBIGINT_utf8(ctx_ptr, " 12 ", 4), 12); - - castBIGINT_utf8(ctx_ptr, "9223372036854775808", 19); - EXPECT_THAT( - ctx.get_error(), - ::testing::HasSubstr("Failed to cast the string 9223372036854775808 to int64")); - ctx.Reset(); - - castBIGINT_utf8(ctx_ptr, "-9223372036854775809", 20); - EXPECT_THAT( - ctx.get_error(), - ::testing::HasSubstr("Failed to cast the string -9223372036854775809 to int64")); - ctx.Reset(); - - castBIGINT_utf8(ctx_ptr, "12.34", 5); - EXPECT_THAT(ctx.get_error(), - ::testing::HasSubstr("Failed to cast the string 12.34 to int64")); - ctx.Reset(); - - castBIGINT_utf8(ctx_ptr, "abc", 3); - EXPECT_THAT(ctx.get_error(), - ::testing::HasSubstr("Failed to cast the string abc to int64")); - ctx.Reset(); - - castBIGINT_utf8(ctx_ptr, "", 0); - EXPECT_THAT(ctx.get_error(), - ::testing::HasSubstr("Failed to cast the string to int64")); - ctx.Reset(); - - castBIGINT_utf8(ctx_ptr, "-", 1); - EXPECT_THAT(ctx.get_error(), - ::testing::HasSubstr("Failed to cast the string - to int64")); - ctx.Reset(); -} - -TEST(TestArithmeticOps, TestCastFloat4) { - gandiva::ExecutionContext ctx; - - int64_t ctx_ptr = reinterpret_cast(&ctx); - - EXPECT_EQ(castFLOAT4_utf8(ctx_ptr, "-45.34", 6), -45.34f); - EXPECT_EQ(castFLOAT4_utf8(ctx_ptr, "0", 1), 0.0f); - EXPECT_EQ(castFLOAT4_utf8(ctx_ptr, "5", 1), 5.0f); - EXPECT_EQ(castFLOAT4_utf8(ctx_ptr, " 3.4 ", 5), 3.4f); - - castFLOAT4_utf8(ctx_ptr, "", 0); - EXPECT_THAT(ctx.get_error(), - ::testing::HasSubstr("Failed to cast the string to float32")); - ctx.Reset(); - - castFLOAT4_utf8(ctx_ptr, "e", 1); - EXPECT_THAT(ctx.get_error(), - ::testing::HasSubstr("Failed to cast the string e to float32")); - ctx.Reset(); -} - -TEST(TestParseStringHolder, TestCastFloat8) { - gandiva::ExecutionContext ctx; - - int64_t ctx_ptr = reinterpret_cast(&ctx); - - EXPECT_EQ(castFLOAT8_utf8(ctx_ptr, "-45.34", 6), -45.34); - EXPECT_EQ(castFLOAT8_utf8(ctx_ptr, "0", 1), 0.0); - EXPECT_EQ(castFLOAT8_utf8(ctx_ptr, "5", 1), 5.0); - EXPECT_EQ(castFLOAT8_utf8(ctx_ptr, " 3.4 ", 5), 3.4); - - castFLOAT8_utf8(ctx_ptr, "", 0); - EXPECT_THAT(ctx.get_error(), - ::testing::HasSubstr("Failed to cast the string to float64")); - ctx.Reset(); - - castFLOAT8_utf8(ctx_ptr, "e", 1); - EXPECT_THAT(ctx.get_error(), - ::testing::HasSubstr("Failed to cast the string e to float64")); - ctx.Reset(); -} - } // namespace gandiva diff --git a/cpp/src/parquet/CMakeLists.txt b/cpp/src/parquet/CMakeLists.txt index 22ad69219a3..a5e42f7be13 100644 --- a/cpp/src/parquet/CMakeLists.txt +++ b/cpp/src/parquet/CMakeLists.txt @@ -203,7 +203,7 @@ set(PARQUET_SRCS stream_writer.cc types.cc) -if(CXX_SUPPORTS_AVX2) +if(ARROW_HAVE_RUNTIME_AVX2) # AVX2 is used as a proxy for BMI2. list(APPEND PARQUET_SRCS level_comparison_avx2.cc level_conversion_bmi2.cc) set_source_files_properties(level_comparison_avx2.cc diff --git a/dev/archery/archery/bot.py b/dev/archery/archery/bot.py index baa5210130d..d222d1ef377 100644 --- a/dev/archery/archery/bot.py +++ b/dev/archery/archery/bot.py @@ -253,13 +253,15 @@ def crossbow(obj, crossbow): @crossbow.command() -@click.argument('task', nargs=-1, required=False) -@click.option('--group', '-g', multiple=True, +@click.argument('tasks', nargs=-1, required=False) +@click.option('--group', '-g', 'groups', multiple=True, help='Submit task groups as defined in tests.yml') +@click.option('--param', '-p', 'params', multiple=True, + help='Additional task parameters for rendering the CI templates') @click.option('--dry-run/--push', default=False, help='Just display the new changelog, don\'t write it') @click.pass_obj -def submit(obj, task, group, dry_run): +def submit(obj, tasks, groups, params, dry_run): """Submit crossbow testing tasks. See groups defined in arrow/dev/tasks/tests.yml @@ -273,9 +275,11 @@ def submit(obj, task, group, dry_run): if dry_run: args.append('--dry-run') - for g in group: + for p in params: + args.extend(['-p', p]) + for g in groups: args.extend(['-g', g]) - for t in task: + for t in tasks: args.append(t) # pygithub pull request object diff --git a/dev/release/verify-release-candidate.sh b/dev/release/verify-release-candidate.sh index 6fd72ccc542..e0f5f0e4a90 100755 --- a/dev/release/verify-release-candidate.sh +++ b/dev/release/verify-release-candidate.sh @@ -213,7 +213,6 @@ setup_tempdir() { fi } - setup_miniconda() { # Setup short-lived miniconda for Python and integration tests if [ "$(uname)" == "Darwin" ]; then @@ -230,16 +229,18 @@ setup_miniconda() { bash miniconda.sh -b -p $MINICONDA rm -f miniconda.sh fi + echo "Installed miniconda at ${MINICONDA}" . $MINICONDA/etc/profile.d/conda.sh conda create -n arrow-test -y -q -c conda-forge \ - python=3.6 \ - nomkl \ - numpy \ - pandas \ - cython + python=3.6 \ + nomkl \ + numpy \ + pandas \ + cython conda activate arrow-test + echo "Using conda environment ${CONDA_PREFIX}" } # Build and test Java (Requires newer Maven -- I used 3.3.9) @@ -374,7 +375,7 @@ test_python() { fi python setup.py build_ext --inplace - py.test pyarrow -v --pdb + pytest pyarrow -v --pdb popd } @@ -778,15 +779,16 @@ cd ${ARROW_TMPDIR} if [ ${NEED_MINICONDA} -gt 0 ]; then setup_miniconda - echo "Using miniconda environment ${MINICONDA}" fi if [ "${ARTIFACT}" == "source" ]; then dist_name="apache-arrow-${VERSION}" if [ ${TEST_SOURCE} -gt 0 ]; then import_gpg_keys - fetch_archive ${dist_name} - tar xf ${dist_name}.tar.gz + if [ ! -d "${dist_name}" ]; then + fetch_archive ${dist_name} + tar xf ${dist_name}.tar.gz + fi else mkdir -p ${dist_name} if [ ! -f ${TEST_ARCHIVE} ]; then diff --git a/dev/tasks/crossbow.py b/dev/tasks/crossbow.py index 5981d56613e..a68794c3ac1 100755 --- a/dev/tasks/crossbow.py +++ b/dev/tasks/crossbow.py @@ -582,7 +582,8 @@ def put(self, job, prefix='build'): # adding CI's name to the end of the branch in order to use skip # patterns on travis and circleci task.branch = '{}-{}-{}'.format(job.branch, task.ci, task_name) - files = task.render_files(arrow=job.target, + files = task.render_files(**job.params, + arrow=job.target, queue_remote_url=self.remote_url) branch = self.create_branch(task.branch, files=files) self.create_tag(task.tag, branch.target) @@ -709,12 +710,12 @@ def __init__(self, ci, template, artifacts=None, params=None): self._status = None # status cache self._assets = None # assets cache - def render_files(self, **extra_params): + def render_files(self, **params): from jinja2 import Template, StrictUndefined from jinja2.exceptions import TemplateError path = CWD / self.template - params = toolz.merge(self.params, extra_params) + params = toolz.merge(self.params, params) template = Template(path.read_text(), undefined=StrictUndefined) try: rendered = template.render(task=self, **params) @@ -871,15 +872,21 @@ def uploaded_assets(self): class Job(Serializable): """Describes multiple tasks against a single target repository""" - def __init__(self, target, tasks): + def __init__(self, target, tasks, params=None): if not tasks: raise ValueError('no tasks were provided for the job') if not all(isinstance(task, Task) for task in tasks.values()): raise ValueError('each `tasks` mus be an instance of Task') if not isinstance(target, Target): raise ValueError('`target` must be an instance of Target') + if not isinstance(target, Target): + raise ValueError('`target` must be an instance of Target') + if not isinstance(params, dict): + raise ValueError('`params` must be an instance of dict') + self.target = target self.tasks = tasks + self.params = params or {} # additional parameters for the tasks self.branch = None # filled after adding to a queue self._queue = None # set by the queue object after put or get @@ -911,7 +918,7 @@ def date(self): return self.queue.date_of(self) @classmethod - def from_config(cls, config, target, tasks=None, groups=None): + def from_config(cls, config, target, tasks=None, groups=None, params=None): """ Intantiate a job from based on a config. @@ -923,9 +930,11 @@ def from_config(cls, config, target, tasks=None, groups=None): Describes target repository and revision the builds run against. tasks : Optional[List[str]], default None List of glob patterns for matching task names. - groups : tasks : Optional[List[str]], default None + groups : Optional[List[str]], default None List of exact group names matching predefined task sets in the config. + params : Optional[Dict[str, str]], default None + Additional rendering parameters for the task templates. Returns ------- @@ -948,7 +957,7 @@ def from_config(cls, config, target, tasks=None, groups=None): artifacts = [fn.format(**versions) for fn in artifacts] tasks[task_name] = Task(artifacts=artifacts, **task) - return cls(target=target, tasks=tasks) + return cls(target=target, tasks=tasks, params=params) def is_finished(self): for task in self.tasks.values(): @@ -1408,6 +1417,8 @@ def check_config(config_path): @click.argument('tasks', nargs=-1, required=False) @click.option('--group', '-g', 'groups', multiple=True, help='Submit task groups as defined in task.yml') +@click.option('--param', '-p', 'params', multiple=True, + help='Additional task parameters for rendering the CI templates') @click.option('--job-prefix', default='build', help='Arbitrary prefix for branch names, e.g. nightly') @click.option('--config-path', '-c', @@ -1429,7 +1440,7 @@ def check_config(config_path): help='Just display the rendered CI configurations without ' 'submitting them') @click.pass_obj -def submit(obj, tasks, groups, job_prefix, config_path, arrow_version, +def submit(obj, tasks, groups, params, job_prefix, config_path, arrow_version, arrow_remote, arrow_branch, arrow_sha, dry_run): output = obj['output'] queue, arrow = obj['queue'], obj['arrow'] @@ -1448,9 +1459,12 @@ def submit(obj, tasks, groups, job_prefix, config_path, arrow_version, target = Target.from_repo(arrow, remote=arrow_remote, branch=arrow_branch, head=arrow_sha, version=arrow_version) + # parse additional job parameters + params = dict([p.split("=") for p in params]) + # instantiate the job object job = Job.from_config(config=config, target=target, tasks=tasks, - groups=groups) + groups=groups, params=params) if dry_run: yaml.dump(job, output) diff --git a/dev/tasks/tasks.yml b/dev/tasks/tasks.yml index 48823c4f6ea..d4fd68bd8ef 100644 --- a/dev/tasks/tasks.yml +++ b/dev/tasks/tasks.yml @@ -1358,179 +1358,208 @@ tasks: verify-rc-binaries-binary: ci: github - template: verify-rc/github.nix.yml + template: verify-rc/github.linux.yml params: - os: "ubuntu" + env: + TEST_DEFAULT: 0 + TEST_BINARY: 1 artifact: "binaries" - flag: "TEST_BINARY=1" verify-rc-binaries-apt: ci: github - template: verify-rc/github.nix.yml + template: verify-rc/github.linux.yml params: - os: "ubuntu" + env: + TEST_DEFAULT: 0 + TEST_APT: 1 artifact: "binaries" - flag: "TEST_APT=1" verify-rc-binaries-yum: ci: github - template: verify-rc/github.nix.yml + template: verify-rc/github.linux.yml params: - os: "ubuntu" + env: + TEST_DEFAULT: 0 + TEST_YUM: 1 artifact: "binaries" - flag: "TEST_YUM=1" verify-rc-wheels-linux: ci: github - template: verify-rc/github.nix.yml + template: verify-rc/github.linux.yml params: - os: "ubuntu" + env: + TEST_DEFAULT: 0 artifact: "wheels" - flag: "" verify-rc-wheels-macos: ci: github - template: verify-rc/github.nix.yml + template: verify-rc/github.osx.yml params: - os: "macOS" + env: + TEST_DEFAULT: 0 artifact: "wheels" - flag: "" verify-rc-source-macos-java: ci: github - template: verify-rc/github.nix.yml + template: verify-rc/github.osx.yml params: - os: "macOS" + env: + TEST_DEFAULT: 0 + TEST_JAVA: 1 artifact: "source" - flag: "TEST_JAVA=1" verify-rc-source-macos-csharp: ci: github - template: verify-rc/github.nix.yml + template: verify-rc/github.osx.yml params: - os: "macOS" + env: + TEST_DEFAULT: 0 + TEST_CSHARP: 1 artifact: "source" - flag: "TEST_CSHARP=1" verify-rc-source-macos-ruby: ci: github - template: verify-rc/github.nix.yml + template: verify-rc/github.osx.yml params: - os: "macOS" + env: + TEST_DEFAULT: 0 + TEST_RUBY: 1 artifact: "source" - flag: "TEST_RUBY=1" verify-rc-source-macos-python: ci: github - template: verify-rc/github.nix.yml + template: verify-rc/github.osx.yml params: - os: "macOS" + env: + TEST_DEFAULT: 0 + TEST_PYTHON: 1 + # https://stackoverflow.com/questions/56083725/macos-build-issues-lstdc-not-found-while-building-python-package + MACOSX_DEPLOYMENT_TARGET: "10.15" artifact: "source" - flag: "TEST_PYTHON=1" verify-rc-source-macos-js: ci: github - template: verify-rc/github.nix.yml + template: verify-rc/github.osx.yml params: - os: "macOS" + env: + INSTALL_NODE: 0 + TEST_DEFAULT: 0 + TEST_JS: 1 artifact: "source" - flag: "TEST_JS=1" verify-rc-source-macos-go: ci: github - template: verify-rc/github.nix.yml + template: verify-rc/github.osx.yml params: - os: "macOS" + env: + TEST_DEFAULT: 0 + TEST_GO: 1 artifact: "source" - flag: "TEST_GO=1" verify-rc-source-macos-rust: ci: github - template: verify-rc/github.nix.yml + template: verify-rc/github.osx.yml params: - os: "macOS" + env: + TEST_DEFAULT: 0 + TEST_RUST: 1 artifact: "source" - flag: "TEST_RUST=1" verify-rc-source-macos-integration: ci: github - template: verify-rc/github.nix.yml + template: verify-rc/github.osx.yml params: - os: "macOS" + env: + INSTALL_NODE: 0 + TEST_DEFAULT: 0 + TEST_INTEGRATION: 1 artifact: "source" - flag: "TEST_INTEGRATION=1" verify-rc-source-linux-java: ci: github - template: verify-rc/github.nix.yml + template: verify-rc/github.linux.yml params: - os: "ubuntu" + env: + TEST_DEFAULT: 0 + TEST_JAVA: 1 artifact: "source" - flag: "TEST_JAVA=1" verify-rc-source-linux-csharp: ci: github - template: verify-rc/github.nix.yml + template: verify-rc/github.linux.yml params: - os: "ubuntu" + env: + TEST_DEFAULT: 0 + TEST_CSHARP: 1 artifact: "source" - flag: "TEST_CSHARP=1" verify-rc-source-linux-ruby: ci: github - template: verify-rc/github.nix.yml + template: verify-rc/github.linux.yml params: - os: "ubuntu" + env: + TEST_DEFAULT: 0 + TEST_RUBY: 1 artifact: "source" - flag: "TEST_RUBY=1" verify-rc-source-linux-python: ci: github - template: verify-rc/github.nix.yml + template: verify-rc/github.linux.yml params: - os: "ubuntu" + env: + TEST_DEFAULT: 0 + TEST_PYTHON: 1 artifact: "source" - flag: "TEST_PYTHON=1" verify-rc-source-linux-js: ci: github - template: verify-rc/github.nix.yml + template: verify-rc/github.linux.yml params: - os: "ubuntu" + env: + INSTALL_NODE: 0 + TEST_DEFAULT: 0 + TEST_JS: 1 artifact: "source" - flag: "TEST_JS=1" verify-rc-source-linux-go: ci: github - template: verify-rc/github.nix.yml + template: verify-rc/github.linux.yml params: - os: "ubuntu" + env: + TEST_DEFAULT: 0 + TEST_GO: 1 artifact: "source" - flag: "TEST_GO=1" verify-rc-source-linux-rust: ci: github - template: verify-rc/github.nix.yml + template: verify-rc/github.linux.yml params: - os: "ubuntu" + env: + TEST_DEFAULT: 0 + TEST_RUST: 1 artifact: "source" - flag: "TEST_RUST=1" verify-rc-source-linux-integration: ci: github - template: verify-rc/github.nix.yml + template: verify-rc/github.linux.yml params: - os: "ubuntu" + env: + INSTALL_NODE: 0 + TEST_DEFAULT: 0 + TEST_INTEGRATION: 1 artifact: "source" - flag: "TEST_INTEGRATION=1" verify-rc-source-windows: ci: github - template: verify-rc/github.windows.source.yml + template: verify-rc/github.win.yml + params: + script: "verify-release-candidate.bat" verify-rc-wheels-windows: ci: github - template: verify-rc/github.windows.wheels.yml + template: verify-rc/github.win.yml + params: + script: "verify-release-candidate-wheels.bat" ############################## Docker tests ################################# diff --git a/dev/tasks/verify-rc/github.linux.yml b/dev/tasks/verify-rc/github.linux.yml new file mode 100644 index 00000000000..49d937ac6fa --- /dev/null +++ b/dev/tasks/verify-rc/github.linux.yml @@ -0,0 +1,77 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# NOTE: must set "Crossbow" as name to have the badge links working in the +# github comment reports! +name: Crossbow + +on: + push: + branches: + - "*-github-*" + +jobs: + verify: + name: "Verify release candidate Ubuntu {{ artifact }}" + runs-on: ubuntu-latest + {%- if env is defined %} + env: + {%- for key, value in env.items() %} + {{ key }}: {{ value }} + {%- endfor %} + {%- endif %} + steps: + - name: Checkout Arrow + run: | + git clone --no-checkout {{ arrow.remote }} arrow + git -C arrow fetch -t {{ arrow.remote }} {{ arrow.branch }} + git -C arrow checkout FETCH_HEAD + git -C arrow submodule update --init --recursive + - name: Fetch Submodules and Tags + shell: bash + run: cd arrow && ci/scripts/util_checkout.sh + - name: Install System Dependencies + run: | + # TODO: don't require removing newer llvms + sudo apt-get --purge remove -y llvm-9 clang-9 + sudo apt-get install -y \ + wget curl libboost-all-dev jq \ + autoconf-archive gtk-doc-tools libgirepository1.0-dev flex bison + + if [ "$TEST_JAVA" = "1" ]; then + # Maven + MAVEN_VERSION=3.6.3 + wget https://downloads.apache.org/maven/maven-3/$MAVEN_VERSION/binaries/apache-maven-$MAVEN_VERSION-bin.zip + unzip apache-maven-$MAVEN_VERSION-bin.zip + mkdir -p $HOME/java + mv apache-maven-$MAVEN_VERSION $HOME/java + export PATH=$HOME/java/apache-maven-$MAVEN_VERSION/bin:$PATH + fi + + if [ "$TEST_RUBY" = "1" ]; then + ruby --version + sudo gem install bundler + fi + - uses: actions/setup-node@v2-beta + with: + node-version: '14' + - name: Run verification + shell: bash + run: | + arrow/dev/release/verify-release-candidate.sh \ + {{ artifact }} \ + {{ release|default("1.0.0") }} {{ rc|default("0") }} diff --git a/dev/tasks/verify-rc/github.nix.yml b/dev/tasks/verify-rc/github.nix.yml deleted file mode 100644 index 8482cdc97ca..00000000000 --- a/dev/tasks/verify-rc/github.nix.yml +++ /dev/null @@ -1,82 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -# NOTE: must set "Crossbow" as name to have the badge links working in the -# github comment reports! -name: Crossbow - -on: - push: - branches: - - "*-github-*" - -jobs: - verify: - name: "Verify release candidate {{ os }} {{ artifact }} {{ flag }}" - runs-on: {{ os }}-latest - steps: - - name: Checkout Arrow - run: | - git clone --no-checkout {{ arrow.remote }} arrow - git -C arrow fetch -t {{ arrow.remote }} {{ arrow.branch }} - git -C arrow checkout FETCH_HEAD - git -C arrow submodule update --init --recursive - - name: Free Up Disk Space - shell: bash - run: arrow/ci/scripts/util_cleanup.sh - - name: Fetch Submodules and Tags - shell: bash - run: cd arrow && ci/scripts/util_checkout.sh - - name: Run verification - shell: bash - env: - INSTALL_NODE: 0 - run: | - set -e - - {{ flag }} - if [ $(uname) = "Darwin" ]; then - brew update - brew bundle --file=arrow/cpp/Brewfile - brew bundle --file=arrow/c_glib/Brewfile - if [ "$TEST_PYTHON" = "1" ]; then - # https://stackoverflow.com/questions/56083725/macos-build-issues-lstdc-not-found-while-building-python-package - export MACOSX_DEPLOYMENT_TARGET=10.9 - fi - else - # TODO: don't require removing newer llvms - sudo apt-get --purge remove -y llvm-9 clang-9 - sudo apt-get install -y \ - wget curl libboost-all-dev jq \ - autoconf-archive gtk-doc-tools libgirepository1.0-dev flex bison - if [ "$TEST_JAVA" = "1" ]; then - # Maven - MAVEN_VERSION=3.6.3 - wget https://downloads.apache.org/maven/maven-3/$MAVEN_VERSION/binaries/apache-maven-$MAVEN_VERSION-bin.zip - unzip apache-maven-$MAVEN_VERSION-bin.zip - mkdir -p $HOME/java - mv apache-maven-$MAVEN_VERSION $HOME/java - export PATH=$HOME/java/apache-maven-$MAVEN_VERSION/bin:$PATH - fi - if [ "$TEST_RUBY" = "1" ]; then - ruby --version - sudo gem install bundler - fi - fi - # TODO: put version and rc number in some separate file? - # If you edit the versions, be sure to edit the other workflow files in this directory too - TEST_DEFAULT=0 {{ flag }} arrow/dev/release/verify-release-candidate.sh {{ artifact }} 0.17.0 0 diff --git a/dev/tasks/verify-rc/github.windows.wheels.yml b/dev/tasks/verify-rc/github.osx.yml similarity index 67% rename from dev/tasks/verify-rc/github.windows.wheels.yml rename to dev/tasks/verify-rc/github.osx.yml index 082c2aa04ca..a0f6fc4af4e 100644 --- a/dev/tasks/verify-rc/github.windows.wheels.yml +++ b/dev/tasks/verify-rc/github.osx.yml @@ -26,8 +26,14 @@ on: jobs: verify: - name: "Verify release candidate Windows wheels" - runs-on: windows-latest + name: "Verify release candidate macOS {{ artifact }}" + runs-on: macos-latest + {%- if env is defined %} + env: + {%- for key, value in env.items() %} + {{ key }}: {{ value }} + {%- endfor %} + {%- endif %} steps: - name: Checkout Arrow run: | @@ -38,11 +44,18 @@ jobs: - name: Fetch Submodules and Tags shell: bash run: cd arrow && ci/scripts/util_checkout.sh - - uses: s-weigand/setup-conda@v1 + - name: Install System Dependencies + shell: bash + run: | + brew update + brew bundle --file=arrow/cpp/Brewfile + brew bundle --file=arrow/c_glib/Brewfile + - uses: actions/setup-node@v2-beta + with: + node-version: '14' - name: Run verification - shell: cmd + shell: bash run: | - choco install wget - cd arrow - # If you edit the versions, be sure to edit the other workflow files in this directory too - dev/release/verify-release-candidate-wheels.bat 0.17.0 0 + arrow/dev/release/verify-release-candidate.sh \ + {{ artifact }} \ + {{ release|default("1.0.0") }} {{ rc|default("0") }} diff --git a/dev/tasks/verify-rc/github.windows.source.yml b/dev/tasks/verify-rc/github.win.yml similarity index 83% rename from dev/tasks/verify-rc/github.windows.source.yml rename to dev/tasks/verify-rc/github.win.yml index d236bb0a2a5..fbe0ee26812 100644 --- a/dev/tasks/verify-rc/github.windows.source.yml +++ b/dev/tasks/verify-rc/github.win.yml @@ -27,7 +27,13 @@ on: jobs: verify: name: "Verify release candidate Windows source" - runs-on: windows-latest + runs-on: windows-2016 + {%- if env is defined %} + env: + {%- for key, value in env.items() %} + {{ key }}: {{ value }} + {%- endfor %} + {%- endif %} steps: - name: Checkout Arrow run: | @@ -39,11 +45,12 @@ jobs: shell: bash run: cd arrow && ci/scripts/util_checkout.sh - uses: s-weigand/setup-conda@v1 - - name: Run verification - shell: cmd + - name: Install System Dependencies run: | choco install boost-msvc-14.1 choco install wget + - name: Run verification + shell: cmd + run: | cd arrow - # If you edit the versions, be sure to edit the other workflow files in this directory too - dev/release/verify-release-candidate.bat 0.17.0 0 + dev/release/{{ script }} {{ release|default("1.0.0") }} {{ rc|default("0") }} diff --git a/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/ProjectorTest.java b/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/ProjectorTest.java index 753cdf6a10a..85ac83b42da 100644 --- a/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/ProjectorTest.java +++ b/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/ProjectorTest.java @@ -1741,4 +1741,191 @@ public void testCaseInsensitiveFunctions() throws Exception { releaseValueVectors(output); } + @Test + public void testCastInt() throws Exception { + Field inField = Field.nullable("input", new ArrowType.Utf8()); + TreeNode inNode = TreeBuilder.makeField(inField); + TreeNode castINTFn = TreeBuilder.makeFunction("castINT", Lists.newArrayList(inNode), + int32); + Field resultField = Field.nullable("result", int32); + List exprs = + Lists.newArrayList( + TreeBuilder.makeExpression(castINTFn, resultField)); + Schema schema = new Schema(Lists.newArrayList(inField)); + Projector eval = Projector.make(schema, exprs); + int numRows = 5; + byte[] validity = new byte[] {(byte) 255}; + String[] values = + new String[] { + "0", "123", "-123", "-1", "1" + }; + int[] expValues = + new int[] { + 0, 123, -123, -1, 1 + }; + ArrowBuf bufValidity = buf(validity); + List bufData = stringBufs(values); + ArrowFieldNode fieldNode = new ArrowFieldNode(numRows, 0); + ArrowRecordBatch batch = + new ArrowRecordBatch( + numRows, + Lists.newArrayList(fieldNode), + Lists.newArrayList(bufValidity, bufData.get(0), bufData.get(1))); + List output = new ArrayList<>(); + for (int i = 0; i < exprs.size(); i++) { + IntVector intVector = new IntVector(EMPTY_SCHEMA_PATH, allocator); + intVector.allocateNew(numRows); + output.add(intVector); + } + eval.evaluate(batch, output); + eval.close(); + for (ValueVector valueVector : output) { + IntVector intVector = (IntVector) valueVector; + for (int j = 0; j < numRows; j++) { + assertFalse(intVector.isNull(j)); + assertTrue(expValues[j] == intVector.get(j)); + } + } + releaseRecordBatch(batch); + releaseValueVectors(output); + } + + @Test(expected = GandivaException.class) + public void testCastIntInvalidValue() throws Exception { + Field inField = Field.nullable("input", new ArrowType.Utf8()); + TreeNode inNode = TreeBuilder.makeField(inField); + TreeNode castINTFn = TreeBuilder.makeFunction("castINT", Lists.newArrayList(inNode), + int32); + Field resultField = Field.nullable("result", int32); + List exprs = + Lists.newArrayList( + TreeBuilder.makeExpression(castINTFn, resultField)); + Schema schema = new Schema(Lists.newArrayList(inField)); + Projector eval = Projector.make(schema, exprs); + int numRows = 1; + byte[] validity = new byte[] {(byte) 255}; + String[] values = + new String[] { + "abc" + }; + ArrowBuf bufValidity = buf(validity); + List bufData = stringBufs(values); + ArrowFieldNode fieldNode = new ArrowFieldNode(numRows, 0); + ArrowRecordBatch batch = + new ArrowRecordBatch( + numRows, + Lists.newArrayList(fieldNode), + Lists.newArrayList(bufValidity, bufData.get(0), bufData.get(1))); + List output = new ArrayList<>(); + for (int i = 0; i < exprs.size(); i++) { + IntVector intVector = new IntVector(EMPTY_SCHEMA_PATH, allocator); + intVector.allocateNew(numRows); + output.add(intVector); + } + try { + eval.evaluate(batch, output); + } finally { + eval.close(); + releaseRecordBatch(batch); + releaseValueVectors(output); + } + } + + @Test + public void testCastFloat() throws Exception { + Field inField = Field.nullable("input", new ArrowType.Utf8()); + TreeNode inNode = TreeBuilder.makeField(inField); + TreeNode castFLOAT8Fn = TreeBuilder.makeFunction("castFLOAT8", Lists.newArrayList(inNode), + float64); + Field resultField = Field.nullable("result", float64); + List exprs = + Lists.newArrayList( + TreeBuilder.makeExpression(castFLOAT8Fn, resultField)); + Schema schema = new Schema(Lists.newArrayList(inField)); + Projector eval = Projector.make(schema, exprs); + int numRows = 5; + byte[] validity = new byte[] {(byte) 255}; + String[] values = + new String[] { + "2.3", + "-11.11", + "0", + "111", + "12345.67" + }; + double[] expValues = + new double[] { + 2.3, -11.11, 0, 111, 12345.67 + }; + ArrowBuf bufValidity = buf(validity); + List bufData = stringBufs(values); + ArrowFieldNode fieldNode = new ArrowFieldNode(numRows, 0); + ArrowRecordBatch batch = + new ArrowRecordBatch( + numRows, + Lists.newArrayList(fieldNode), + Lists.newArrayList(bufValidity, bufData.get(0), bufData.get(1))); + List output = new ArrayList<>(); + for (int i = 0; i < exprs.size(); i++) { + Float8Vector float8Vector = new Float8Vector(EMPTY_SCHEMA_PATH, allocator); + float8Vector.allocateNew(numRows); + output.add(float8Vector); + } + eval.evaluate(batch, output); + eval.close(); + for (ValueVector valueVector : output) { + Float8Vector float8Vector = (Float8Vector) valueVector; + for (int j = 0; j < numRows; j++) { + assertFalse(float8Vector.isNull(j)); + assertTrue(expValues[j] == float8Vector.get(j)); + } + } + releaseRecordBatch(batch); + releaseValueVectors(output); + } + + @Test(expected = GandivaException.class) + public void testCastFloatInvalidValue() throws Exception { + Field inField = Field.nullable("input", new ArrowType.Utf8()); + TreeNode inNode = TreeBuilder.makeField(inField); + TreeNode castFLOAT8Fn = TreeBuilder.makeFunction("castFLOAT8", Lists.newArrayList(inNode), + float64); + Field resultField = Field.nullable("result", float64); + List exprs = + Lists.newArrayList( + TreeBuilder.makeExpression(castFLOAT8Fn, resultField)); + Schema schema = new Schema(Lists.newArrayList(inField)); + Projector eval = Projector.make(schema, exprs); + int numRows = 5; + byte[] validity = new byte[] {(byte) 255}; + String[] values = + new String[] { + "2.3", + "-11.11", + "abc", + "111", + "12345.67" + }; + ArrowBuf bufValidity = buf(validity); + List bufData = stringBufs(values); + ArrowFieldNode fieldNode = new ArrowFieldNode(numRows, 0); + ArrowRecordBatch batch = + new ArrowRecordBatch( + numRows, + Lists.newArrayList(fieldNode), + Lists.newArrayList(bufValidity, bufData.get(0), bufData.get(1))); + List output = new ArrayList<>(); + for (int i = 0; i < exprs.size(); i++) { + Float8Vector float8Vector = new Float8Vector(EMPTY_SCHEMA_PATH, allocator); + float8Vector.allocateNew(numRows); + output.add(float8Vector); + } + try { + eval.evaluate(batch, output); + } finally { + eval.close(); + releaseRecordBatch(batch); + releaseValueVectors(output); + } + } } diff --git a/java/memory/memory-core/src/main/java/org/apache/arrow/memory/Accountant.java b/java/memory/memory-core/src/main/java/org/apache/arrow/memory/Accountant.java index da93511b4f2..42dac7b8c60 100644 --- a/java/memory/memory-core/src/main/java/org/apache/arrow/memory/Accountant.java +++ b/java/memory/memory-core/src/main/java/org/apache/arrow/memory/Accountant.java @@ -140,7 +140,7 @@ private void updatePeak() { * @param size to increase * @return Whether the allocation fit within limits. */ - boolean forceAllocate(long size) { + public boolean forceAllocate(long size) { final AllocationOutcome.Status outcome = allocate(size, true, true, null); return outcome.isOk(); } @@ -220,7 +220,6 @@ public void releaseBytes(long size) { final long actualToReleaseToParent = Math.min(size, possibleAmountToReleaseToParent); parent.releaseBytes(actualToReleaseToParent); } - } public boolean isOverLimit() { diff --git a/java/memory/memory-core/src/main/java/org/apache/arrow/memory/AllocationManager.java b/java/memory/memory-core/src/main/java/org/apache/arrow/memory/AllocationManager.java index c61d041097e..9c7cfa9d90d 100644 --- a/java/memory/memory-core/src/main/java/org/apache/arrow/memory/AllocationManager.java +++ b/java/memory/memory-core/src/main/java/org/apache/arrow/memory/AllocationManager.java @@ -47,11 +47,11 @@ public abstract class AllocationManager { private static final AtomicLong MANAGER_ID_GENERATOR = new AtomicLong(0); - private final RootAllocator root; + private final BufferAllocator root; private final long allocatorManagerId = MANAGER_ID_GENERATOR.incrementAndGet(); // ARROW-1627 Trying to minimize memory overhead caused by previously used IdentityHashMap // see JIRA for details - private final LowCostIdentityHashMap map = new LowCostIdentityHashMap<>(); + private final LowCostIdentityHashMap map = new LowCostIdentityHashMap<>(); private final long amCreationTime = System.nanoTime(); // The ReferenceManager created at the time of creation of this AllocationManager @@ -60,11 +60,11 @@ public abstract class AllocationManager { private volatile BufferLedger owningLedger; private volatile long amDestructionTime = 0; - protected AllocationManager(BaseAllocator accountingAllocator) { + protected AllocationManager(BufferAllocator accountingAllocator) { Preconditions.checkNotNull(accountingAllocator); accountingAllocator.assertOpen(); - this.root = accountingAllocator.root; + this.root = accountingAllocator.getRoot(); // we do a no retain association since our creator will want to retrieve the newly created // ledger and will create a reference count at that point @@ -87,13 +87,13 @@ void setOwningLedger(final BufferLedger ledger) { * @return The reference manager (new or existing) that associates the underlying * buffer to this new ledger. */ - BufferLedger associate(final BaseAllocator allocator) { + BufferLedger associate(final BufferAllocator allocator) { return associate(allocator, true); } - private BufferLedger associate(final BaseAllocator allocator, final boolean retain) { + private BufferLedger associate(final BufferAllocator allocator, final boolean retain) { allocator.assertOpen(); - Preconditions.checkState(root == allocator.root, + Preconditions.checkState(root == allocator.getRoot(), "A buffer can only be associated between two allocators that share the same root"); synchronized (this) { @@ -118,9 +118,11 @@ private BufferLedger associate(final BaseAllocator allocator, final boolean reta Preconditions.checkState(oldLedger == null, "Detected inconsistent state: A reference manager already exists for this allocator"); - // needed for debugging only: keep a pointer to reference manager inside allocator - // to dump state, verify allocator state etc - allocator.associateLedger(ledger); + if (allocator instanceof BaseAllocator) { + // needed for debugging only: keep a pointer to reference manager inside allocator + // to dump state, verify allocator state etc + ((BaseAllocator) allocator).associateLedger(ledger); + } return ledger; } } @@ -133,7 +135,7 @@ private BufferLedger associate(final BaseAllocator allocator, final boolean reta * calling ReferenceManager drops to 0. */ void release(final BufferLedger ledger) { - final BaseAllocator allocator = (BaseAllocator) ledger.getAllocator(); + final BufferAllocator allocator = ledger.getAllocator(); allocator.assertOpen(); // remove the mapping for the allocator @@ -142,9 +144,12 @@ void release(final BufferLedger ledger) { "Expecting a mapping for allocator and reference manager"); final BufferLedger oldLedger = map.remove(allocator); - // needed for debug only: tell the allocator that AllocationManager is removing a - // reference manager associated with this particular allocator - ((BaseAllocator) oldLedger.getAllocator()).dissociateLedger(oldLedger); + BufferAllocator oldAllocator = oldLedger.getAllocator(); + if (oldAllocator instanceof BaseAllocator) { + // needed for debug only: tell the allocator that AllocationManager is removing a + // reference manager associated with this particular allocator + ((BaseAllocator) oldAllocator).dissociateLedger(oldLedger); + } if (oldLedger == owningLedger) { // the release call was made by the owning reference manager @@ -152,10 +157,10 @@ void release(final BufferLedger ledger) { // the only mapping was for the owner // which now has been removed, it implies we can safely destroy the // underlying memory chunk as it is no longer being referenced - ((BaseAllocator) oldLedger.getAllocator()).releaseBytes(getSize()); + oldAllocator.releaseBytes(getSize()); // free the memory chunk associated with the allocation manager release0(); - ((BaseAllocator) oldLedger.getAllocator()).getListener().onRelease(getSize()); + oldAllocator.getListener().onRelease(getSize()); amDestructionTime = System.nanoTime(); owningLedger = null; } else { @@ -209,7 +214,7 @@ public interface Factory { * @param size Size (in bytes) of memory managed by the AllocationManager * @return The created AllocationManager used by this allocator */ - AllocationManager create(BaseAllocator accountingAllocator, long size); + AllocationManager create(BufferAllocator accountingAllocator, long size); ArrowBuf empty(); } diff --git a/java/memory/memory-core/src/main/java/org/apache/arrow/memory/BaseAllocator.java b/java/memory/memory-core/src/main/java/org/apache/arrow/memory/BaseAllocator.java index 81f664985d5..246b2212e26 100644 --- a/java/memory/memory-core/src/main/java/org/apache/arrow/memory/BaseAllocator.java +++ b/java/memory/memory-core/src/main/java/org/apache/arrow/memory/BaseAllocator.java @@ -61,10 +61,10 @@ abstract class BaseAllocator extends Accountant implements BufferAllocator { public static final Config DEFAULT_CONFIG = ImmutableConfig.builder().build(); // Package exposed for sharing between AllocatorManger and BaseAllocator objects - final String name; - final RootAllocator root; + private final String name; + private final RootAllocator root; private final Object DEBUG_LOCK = DEBUG ? new Object() : null; - final AllocationListener listener; + private final AllocationListener listener; private final BaseAllocator parentAllocator; private final Map childAllocators; private final ArrowBuf empty; @@ -124,7 +124,8 @@ protected BaseAllocator( this.roundingPolicy = config.getRoundingPolicy(); } - AllocationListener getListener() { + @Override + public AllocationListener getListener() { return listener; } @@ -314,6 +315,11 @@ private AllocationManager newAllocationManager(BaseAllocator accountingAllocator return allocationManagerFactory.create(accountingAllocator, size); } + @Override + public BufferAllocator getRoot() { + return root; + } + @Override public BufferAllocator newChildAllocator( final String name, @@ -343,7 +349,7 @@ public BufferAllocator newChildAllocator( synchronized (DEBUG_LOCK) { childAllocators.put(childAllocator, childAllocator); historicalLog.recordEvent("allocator[%s] created new child allocator[%s]", name, - childAllocator.name); + childAllocator.getName()); } } else { childAllocators.put(childAllocator, childAllocator); diff --git a/java/memory/memory-core/src/main/java/org/apache/arrow/memory/BufferAllocator.java b/java/memory/memory-core/src/main/java/org/apache/arrow/memory/BufferAllocator.java index aa1f856c591..8fbf6f7b073 100644 --- a/java/memory/memory-core/src/main/java/org/apache/arrow/memory/BufferAllocator.java +++ b/java/memory/memory-core/src/main/java/org/apache/arrow/memory/BufferAllocator.java @@ -49,6 +49,14 @@ public interface BufferAllocator extends AutoCloseable { */ ArrowBuf buffer(long size, BufferManager manager); + /** + * Get the root allocator of this allocator. If this allocator is already a root, return + * this directly. + * + * @return The root allocator + */ + BufferAllocator getRoot(); + /** * Create a new child allocator. * @@ -126,6 +134,30 @@ BufferAllocator newChildAllocator( */ long getHeadroom(); + /** + * Forcibly allocate bytes. Returns whether the allocation fit within limits. + * + * @param size to increase + * @return Whether the allocation fit within limits. + */ + boolean forceAllocate(long size); + + + /** + * Release bytes from this allocator. + * + * @param size to release + */ + void releaseBytes(long size); + + /** + * Returns the allocation listener used by this allocator. + * + * @return the {@link AllocationListener} instance. Or {@link AllocationListener#NOOP} by default if no listener + * is configured when this allocator was created. + */ + AllocationListener getListener(); + /** * Returns the parent allocator. * diff --git a/java/memory/memory-core/src/main/java/org/apache/arrow/memory/BufferLedger.java b/java/memory/memory-core/src/main/java/org/apache/arrow/memory/BufferLedger.java index 9fa4de71d8d..48b3e183d5a 100644 --- a/java/memory/memory-core/src/main/java/org/apache/arrow/memory/BufferLedger.java +++ b/java/memory/memory-core/src/main/java/org/apache/arrow/memory/BufferLedger.java @@ -31,7 +31,7 @@ * ArrowBufs managed by this reference manager share a common * fate (same reference count). */ -public class BufferLedger implements ValueWithKeyIncluded, ReferenceManager { +public class BufferLedger implements ValueWithKeyIncluded, ReferenceManager { private final IdentityHashMap buffers = BaseAllocator.DEBUG ? new IdentityHashMap<>() : null; private static final AtomicLong LEDGER_ID_GENERATOR = new AtomicLong(0); @@ -41,14 +41,14 @@ public class BufferLedger implements ValueWithKeyIncluded, Refere // manage request for retain // correctly private final long lCreationTime = System.nanoTime(); - private final BaseAllocator allocator; + private final BufferAllocator allocator; private final AllocationManager allocationManager; private final HistoricalLog historicalLog = BaseAllocator.DEBUG ? new HistoricalLog(BaseAllocator.DEBUG_LOG_LENGTH, "BufferLedger[%d]", 1) : null; private volatile long lDestructionTime = 0; - BufferLedger(final BaseAllocator allocator, final AllocationManager allocationManager) { + BufferLedger(final BufferAllocator allocator, final AllocationManager allocationManager) { this.allocator = allocator; this.allocationManager = allocationManager; } @@ -57,7 +57,7 @@ boolean isOwningLedger() { return this == allocationManager.getOwningLedger(); } - public BaseAllocator getKey() { + public BufferAllocator getKey() { return allocator; } @@ -238,7 +238,7 @@ public ArrowBuf deriveBuffer(final ArrowBuf sourceBuffer, long index, long lengt "ArrowBuf(BufferLedger, BufferAllocator[%s], " + "UnsafeDirectLittleEndian[identityHashCode == " + "%d](%s)) => ledger hc == %d", - allocator.name, System.identityHashCode(derivedBuf), derivedBuf.toString(), + allocator.getName(), System.identityHashCode(derivedBuf), derivedBuf.toString(), System.identityHashCode(this)); synchronized (buffers) { @@ -275,7 +275,7 @@ ArrowBuf newArrowBuf(final long length, final BufferManager manager) { historicalLog.recordEvent( "ArrowBuf(BufferLedger, BufferAllocator[%s], " + "UnsafeDirectLittleEndian[identityHashCode == " + "%d](%s)) => ledger hc == %d", - allocator.name, System.identityHashCode(buf), buf.toString(), + allocator.getName(), System.identityHashCode(buf), buf.toString(), System.identityHashCode(this)); synchronized (buffers) { @@ -317,7 +317,7 @@ public ArrowBuf retain(final ArrowBuf srcBuffer, BufferAllocator target) { // alternatively, if there was already a mapping for in // allocation manager, the ref count of the new buffer will be targetrefmanager.refcount() + 1 // and this will be true for all the existing buffers currently managed by targetrefmanager - final BufferLedger targetRefManager = allocationManager.associate((BaseAllocator) target); + final BufferLedger targetRefManager = allocationManager.associate(target); // create a new ArrowBuf to associate with new allocator and target ref manager final long targetBufLength = srcBuffer.capacity(); ArrowBuf targetArrowBuf = targetRefManager.deriveBuffer(srcBuffer, 0, targetBufLength); @@ -336,8 +336,8 @@ public ArrowBuf retain(final ArrowBuf srcBuffer, BufferAllocator target) { boolean transferBalance(final ReferenceManager targetReferenceManager) { Preconditions.checkArgument(targetReferenceManager != null, "Expecting valid target reference manager"); - final BaseAllocator targetAllocator = (BaseAllocator) targetReferenceManager.getAllocator(); - Preconditions.checkArgument(allocator.root == targetAllocator.root, + final BufferAllocator targetAllocator = targetReferenceManager.getAllocator(); + Preconditions.checkArgument(allocator.getRoot() == targetAllocator.getRoot(), "You can only transfer between two allocators that share the same root."); allocator.assertOpen(); @@ -411,7 +411,7 @@ public TransferResult transferOwnership(final ArrowBuf srcBuffer, final BufferAl // alternatively, if there was already a mapping for in // allocation manager, the ref count of the new buffer will be targetrefmanager.refcount() + 1 // and this will be true for all the existing buffers currently managed by targetrefmanager - final BufferLedger targetRefManager = allocationManager.associate((BaseAllocator) target); + final BufferLedger targetRefManager = allocationManager.associate(target); // create a new ArrowBuf to associate with new allocator and target ref manager final long targetBufLength = srcBuffer.capacity(); final ArrowBuf targetArrowBuf = targetRefManager.deriveBuffer(srcBuffer, 0, targetBufLength); @@ -486,7 +486,7 @@ void print(StringBuilder sb, int indent, BaseAllocator.Verbosity verbosity) { .append("ledger[") .append(ledgerId) .append("] allocator: ") - .append(allocator.name) + .append(allocator.getName()) .append("), isOwning: ") .append(", size: ") .append(", references: ") diff --git a/java/memory/memory-core/src/test/java/org/apache/arrow/memory/DefaultAllocationManagerFactory.java b/java/memory/memory-core/src/test/java/org/apache/arrow/memory/DefaultAllocationManagerFactory.java index e4553104715..bfe496532b1 100644 --- a/java/memory/memory-core/src/test/java/org/apache/arrow/memory/DefaultAllocationManagerFactory.java +++ b/java/memory/memory-core/src/test/java/org/apache/arrow/memory/DefaultAllocationManagerFactory.java @@ -34,7 +34,7 @@ public class DefaultAllocationManagerFactory implements AllocationManager.Factor MemoryUtil.UNSAFE.allocateMemory(0)); @Override - public AllocationManager create(BaseAllocator accountingAllocator, long size) { + public AllocationManager create(BufferAllocator accountingAllocator, long size) { return new AllocationManager(accountingAllocator) { private final long allocatedSize = size; private final long address = MemoryUtil.UNSAFE.allocateMemory(size); diff --git a/java/memory/memory-netty/src/main/java/org/apache/arrow/memory/DefaultAllocationManagerFactory.java b/java/memory/memory-netty/src/main/java/org/apache/arrow/memory/DefaultAllocationManagerFactory.java index 15651a38e4a..10cfb5c1648 100644 --- a/java/memory/memory-netty/src/main/java/org/apache/arrow/memory/DefaultAllocationManagerFactory.java +++ b/java/memory/memory-netty/src/main/java/org/apache/arrow/memory/DefaultAllocationManagerFactory.java @@ -26,7 +26,7 @@ public class DefaultAllocationManagerFactory implements AllocationManager.Factor public static final AllocationManager.Factory FACTORY = NettyAllocationManager.FACTORY; @Override - public AllocationManager create(BaseAllocator accountingAllocator, long size) { + public AllocationManager create(BufferAllocator accountingAllocator, long size) { return FACTORY.create(accountingAllocator, size); } diff --git a/java/memory/memory-netty/src/main/java/org/apache/arrow/memory/NettyAllocationManager.java b/java/memory/memory-netty/src/main/java/org/apache/arrow/memory/NettyAllocationManager.java index 45bd5d91347..20004778307 100644 --- a/java/memory/memory-netty/src/main/java/org/apache/arrow/memory/NettyAllocationManager.java +++ b/java/memory/memory-netty/src/main/java/org/apache/arrow/memory/NettyAllocationManager.java @@ -30,7 +30,7 @@ public class NettyAllocationManager extends AllocationManager { public static final AllocationManager.Factory FACTORY = new AllocationManager.Factory() { @Override - public AllocationManager create(BaseAllocator accountingAllocator, long size) { + public AllocationManager create(BufferAllocator accountingAllocator, long size) { return new NettyAllocationManager(accountingAllocator, size); } @@ -65,7 +65,7 @@ public ArrowBuf empty() { */ private final int allocationCutOffValue; - NettyAllocationManager(BaseAllocator accountingAllocator, long requestedSize, int allocationCutOffValue) { + NettyAllocationManager(BufferAllocator accountingAllocator, long requestedSize, int allocationCutOffValue) { super(accountingAllocator); this.allocationCutOffValue = allocationCutOffValue; @@ -80,7 +80,7 @@ public ArrowBuf empty() { } } - NettyAllocationManager(BaseAllocator accountingAllocator, long requestedSize) { + NettyAllocationManager(BufferAllocator accountingAllocator, long requestedSize) { this(accountingAllocator, requestedSize, DEFAULT_ALLOCATION_CUTOFF_VALUE); } diff --git a/java/memory/memory-netty/src/test/java/org/apache/arrow/memory/TestBaseAllocator.java b/java/memory/memory-netty/src/test/java/org/apache/arrow/memory/TestBaseAllocator.java index a42e272a42e..ef49e41785f 100644 --- a/java/memory/memory-netty/src/test/java/org/apache/arrow/memory/TestBaseAllocator.java +++ b/java/memory/memory-netty/src/test/java/org/apache/arrow/memory/TestBaseAllocator.java @@ -393,7 +393,7 @@ private BaseAllocator createAllocatorWithCustomizedAllocationManager() { .maxAllocation(MAX_ALLOCATION) .allocationManagerFactory(new AllocationManager.Factory() { @Override - public AllocationManager create(BaseAllocator accountingAllocator, long requestedSize) { + public AllocationManager create(BufferAllocator accountingAllocator, long requestedSize) { return new AllocationManager(accountingAllocator) { private final Unsafe unsafe = getUnsafe(); private final long address = unsafe.allocateMemory(requestedSize); diff --git a/java/memory/memory-netty/src/test/java/org/apache/arrow/memory/TestNettyAllocationManager.java b/java/memory/memory-netty/src/test/java/org/apache/arrow/memory/TestNettyAllocationManager.java index f386ea66b2a..1b64cd73363 100644 --- a/java/memory/memory-netty/src/test/java/org/apache/arrow/memory/TestNettyAllocationManager.java +++ b/java/memory/memory-netty/src/test/java/org/apache/arrow/memory/TestNettyAllocationManager.java @@ -35,7 +35,7 @@ private BaseAllocator createCustomizedAllocator() { return new RootAllocator(BaseAllocator.configBuilder() .allocationManagerFactory(new AllocationManager.Factory() { @Override - public AllocationManager create(BaseAllocator accountingAllocator, long size) { + public AllocationManager create(BufferAllocator accountingAllocator, long size) { return new NettyAllocationManager(accountingAllocator, size, CUSTOMIZED_ALLOCATION_CUTOFF_VALUE); } diff --git a/java/memory/memory-unsafe/src/main/java/org/apache/arrow/memory/DefaultAllocationManagerFactory.java b/java/memory/memory-unsafe/src/main/java/org/apache/arrow/memory/DefaultAllocationManagerFactory.java index 3963c1875d0..720c3d02d23 100644 --- a/java/memory/memory-unsafe/src/main/java/org/apache/arrow/memory/DefaultAllocationManagerFactory.java +++ b/java/memory/memory-unsafe/src/main/java/org/apache/arrow/memory/DefaultAllocationManagerFactory.java @@ -26,7 +26,7 @@ public class DefaultAllocationManagerFactory implements AllocationManager.Factor public static final AllocationManager.Factory FACTORY = UnsafeAllocationManager.FACTORY; @Override - public AllocationManager create(BaseAllocator accountingAllocator, long size) { + public AllocationManager create(BufferAllocator accountingAllocator, long size) { return FACTORY.create(accountingAllocator, size); } diff --git a/java/memory/memory-unsafe/src/main/java/org/apache/arrow/memory/UnsafeAllocationManager.java b/java/memory/memory-unsafe/src/main/java/org/apache/arrow/memory/UnsafeAllocationManager.java index f9756539c55..b10aba3598d 100644 --- a/java/memory/memory-unsafe/src/main/java/org/apache/arrow/memory/UnsafeAllocationManager.java +++ b/java/memory/memory-unsafe/src/main/java/org/apache/arrow/memory/UnsafeAllocationManager.java @@ -32,7 +32,7 @@ public final class UnsafeAllocationManager extends AllocationManager { public static final AllocationManager.Factory FACTORY = new Factory() { @Override - public AllocationManager create(BaseAllocator accountingAllocator, long size) { + public AllocationManager create(BufferAllocator accountingAllocator, long size) { return new UnsafeAllocationManager(accountingAllocator, size); } @@ -46,7 +46,7 @@ public ArrowBuf empty() { private final long allocatedAddress; - UnsafeAllocationManager(BaseAllocator accountingAllocator, long requestedSize) { + UnsafeAllocationManager(BufferAllocator accountingAllocator, long requestedSize) { super(accountingAllocator); allocatedAddress = MemoryUtil.UNSAFE.allocateMemory(requestedSize); allocatedSize = requestedSize; diff --git a/java/vector/src/main/codegen/data/ValueVectorTypes.tdd b/java/vector/src/main/codegen/data/ValueVectorTypes.tdd index b9e052941ed..7612d3690b9 100644 --- a/java/vector/src/main/codegen/data/ValueVectorTypes.tdd +++ b/java/vector/src/main/codegen/data/ValueVectorTypes.tdd @@ -125,7 +125,7 @@ maxPrecisionDigits: 38, nDecimalDigits: 4, friendlyType: "BigDecimal", typeParams: [ {name: "scale", type: "int"}, { name: "precision", type: "int"}], arrowType: "org.apache.arrow.vector.types.pojo.ArrowType.Decimal", - fields: [{name: "start", type: "int"}, {name: "buffer", type: "ArrowBuf"}] + fields: [{name: "start", type: "long"}, {name: "buffer", type: "ArrowBuf"}] } ] }, diff --git a/java/vector/src/main/codegen/templates/ComplexWriters.java b/java/vector/src/main/codegen/templates/ComplexWriters.java index ab99ac38dcd..5f5025ff59e 100644 --- a/java/vector/src/main/codegen/templates/ComplexWriters.java +++ b/java/vector/src/main/codegen/templates/ComplexWriters.java @@ -139,12 +139,12 @@ public void write(NullableDecimalHolder h){ vector.setValueCount(idx() + 1); } - public void writeDecimal(int start, ArrowBuf buffer){ + public void writeDecimal(long start, ArrowBuf buffer){ vector.setSafe(idx(), 1, start, buffer); vector.setValueCount(idx() + 1); } - public void writeDecimal(int start, ArrowBuf buffer, ArrowType arrowType){ + public void writeDecimal(long start, ArrowBuf buffer, ArrowType arrowType){ DecimalUtility.checkPrecisionAndScale(((ArrowType.Decimal) arrowType).getPrecision(), ((ArrowType.Decimal) arrowType).getScale(), vector.getPrecision(), vector.getScale()); vector.setSafe(idx(), 1, start, buffer); diff --git a/java/vector/src/main/codegen/templates/UnionFixedSizeListWriter.java b/java/vector/src/main/codegen/templates/UnionFixedSizeListWriter.java index 0574dcf572d..94c7d8f6490 100644 --- a/java/vector/src/main/codegen/templates/UnionFixedSizeListWriter.java +++ b/java/vector/src/main/codegen/templates/UnionFixedSizeListWriter.java @@ -189,7 +189,7 @@ public void writeNull() { writer.writeNull(); } - public void writeDecimal(int start, ArrowBuf buffer, ArrowType arrowType) { + public void writeDecimal(long start, ArrowBuf buffer, ArrowType arrowType) { if (writer.idx() >= (idx() + 1) * listSize) { throw new IllegalStateException(String.format("values at index %s is greater than listSize %s", idx(), listSize)); } diff --git a/java/vector/src/main/codegen/templates/UnionListWriter.java b/java/vector/src/main/codegen/templates/UnionListWriter.java index a2664436acc..bb0cff4e06c 100644 --- a/java/vector/src/main/codegen/templates/UnionListWriter.java +++ b/java/vector/src/main/codegen/templates/UnionListWriter.java @@ -204,12 +204,12 @@ public void writeNull() { writer.writeNull(); } - public void writeDecimal(int start, ArrowBuf buffer, ArrowType arrowType) { + public void writeDecimal(long start, ArrowBuf buffer, ArrowType arrowType) { writer.writeDecimal(start, buffer, arrowType); writer.setPosition(writer.idx()+1); } - public void writeDecimal(int start, ArrowBuf buffer) { + public void writeDecimal(long start, ArrowBuf buffer) { writer.writeDecimal(start, buffer); writer.setPosition(writer.idx()+1); } diff --git a/java/vector/src/main/java/org/apache/arrow/vector/DecimalVector.java b/java/vector/src/main/java/org/apache/arrow/vector/DecimalVector.java index 554e174dc2b..04344c35e34 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/DecimalVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/DecimalVector.java @@ -246,7 +246,7 @@ public void setBigEndian(int index, byte[] value) { * @param start start index of data in the buffer * @param buffer ArrowBuf containing decimal value. */ - public void set(int index, int start, ArrowBuf buffer) { + public void set(int index, long start, ArrowBuf buffer) { BitVectorHelper.setBit(validityBuffer, index); valueBuffer.setBytes((long) index * TYPE_WIDTH, buffer, start, TYPE_WIDTH); } @@ -258,7 +258,7 @@ public void set(int index, int start, ArrowBuf buffer) { * @param buffer contains the decimal in little endian bytes * @param length length of the value in the buffer */ - public void setSafe(int index, int start, ArrowBuf buffer, int length) { + public void setSafe(int index, long start, ArrowBuf buffer, int length) { handleSafe(index); BitVectorHelper.setBit(validityBuffer, index); @@ -285,7 +285,7 @@ public void setSafe(int index, int start, ArrowBuf buffer, int length) { * @param buffer contains the decimal in big endian bytes * @param length length of the value in the buffer */ - public void setBigEndianSafe(int index, int start, ArrowBuf buffer, int length) { + public void setBigEndianSafe(int index, long start, ArrowBuf buffer, int length) { handleSafe(index); BitVectorHelper.setBit(validityBuffer, index); @@ -394,7 +394,7 @@ public void setBigEndianSafe(int index, byte[] value) { * @param start start index of data in the buffer * @param buffer ArrowBuf containing decimal value. */ - public void setSafe(int index, int start, ArrowBuf buffer) { + public void setSafe(int index, long start, ArrowBuf buffer) { handleSafe(index); set(index, start, buffer); } @@ -460,7 +460,7 @@ public void setSafe(int index, DecimalHolder holder) { * @param start start position of the value in the buffer * @param buffer buffer containing the value to be stored in the vector */ - public void set(int index, int isSet, int start, ArrowBuf buffer) { + public void set(int index, int isSet, long start, ArrowBuf buffer) { if (isSet > 0) { set(index, start, buffer); } else { @@ -478,7 +478,7 @@ public void set(int index, int isSet, int start, ArrowBuf buffer) { * @param start start position of the value in the buffer * @param buffer buffer containing the value to be stored in the vector */ - public void setSafe(int index, int isSet, int start, ArrowBuf buffer) { + public void setSafe(int index, int isSet, long start, ArrowBuf buffer) { handleSafe(index); set(index, isSet, start, buffer); } diff --git a/java/vector/src/main/java/org/apache/arrow/vector/complex/impl/PromotableWriter.java b/java/vector/src/main/java/org/apache/arrow/vector/complex/impl/PromotableWriter.java index 6f40836e06b..51decee39fd 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/complex/impl/PromotableWriter.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/complex/impl/PromotableWriter.java @@ -320,7 +320,7 @@ public void write(DecimalHolder holder) { } @Override - public void writeDecimal(int start, ArrowBuf buffer, ArrowType arrowType) { + public void writeDecimal(long start, ArrowBuf buffer, ArrowType arrowType) { getWriter(MinorType.DECIMAL, new ArrowType.Decimal(MAX_DECIMAL_PRECISION, ((ArrowType.Decimal) arrowType).getScale())).writeDecimal(start, buffer, arrowType); } diff --git a/java/vector/src/main/java/org/apache/arrow/vector/util/DecimalUtility.java b/java/vector/src/main/java/org/apache/arrow/vector/util/DecimalUtility.java index 711fa3b9cbf..36c988fac7e 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/util/DecimalUtility.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/util/DecimalUtility.java @@ -42,7 +42,7 @@ private DecimalUtility() {} public static BigDecimal getBigDecimalFromArrowBuf(ArrowBuf bytebuf, int index, int scale) { byte[] value = new byte[DECIMAL_BYTE_LENGTH]; byte temp; - final int startIndex = index * DECIMAL_BYTE_LENGTH; + final long startIndex = (long) index * DECIMAL_BYTE_LENGTH; // Decimal stored as little endian, need to swap bytes to make BigDecimal bytebuf.getBytes(startIndex, value, 0, DECIMAL_BYTE_LENGTH); diff --git a/java/vector/src/main/java/org/apache/arrow/vector/util/DictionaryUtility.java b/java/vector/src/main/java/org/apache/arrow/vector/util/DictionaryUtility.java index 345fa592241..9592f3975ab 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/util/DictionaryUtility.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/util/DictionaryUtility.java @@ -49,16 +49,13 @@ public static Field toMessageFormat(Field field, DictionaryProvider provider, Se return field; } DictionaryEncoding encoding = field.getDictionary(); - List children = field.getChildren(); + List children; - List updatedChildren = new ArrayList<>(children.size()); - for (Field child : children) { - updatedChildren.add(toMessageFormat(child, provider, dictionaryIdsUsed)); - } ArrowType type; if (encoding == null) { type = field.getType(); + children = field.getChildren(); } else { long id = encoding.getId(); Dictionary dictionary = provider.lookup(id); @@ -66,10 +63,16 @@ public static Field toMessageFormat(Field field, DictionaryProvider provider, Se throw new IllegalArgumentException("Could not find dictionary with ID " + id); } type = dictionary.getVectorType(); + children = dictionary.getVector().getField().getChildren(); dictionaryIdsUsed.add(id); } + final List updatedChildren = new ArrayList<>(children.size()); + for (Field child : children) { + updatedChildren.add(toMessageFormat(child, provider, dictionaryIdsUsed)); + } + return new Field(field.getName(), new FieldType(field.isNullable(), type, encoding, field.getMetadata()), updatedChildren); } @@ -115,8 +118,10 @@ public static Field toMemoryFormat(Field field, BufferAllocator allocator, Map fieldChildren = null; if (encoding == null) { type = field.getType(); + fieldChildren = updatedChildren; } else { // re-type the field for in-memory format type = encoding.getIndexType(); @@ -127,13 +132,14 @@ public static Field toMemoryFormat(Field field, BufferAllocator allocator, Map Integer.MAX_VALUE); + largeVec.set(0, holder); + + BigDecimal decimal = largeVec.getObject(0); + assertEquals(12345L, decimal.longValue()); + + logger.trace("Successfully setting values from large offsets"); } logger.trace("Successfully released the large vector."); } diff --git a/java/vector/src/test/java/org/apache/arrow/vector/ipc/TestArrowReaderWriter.java b/java/vector/src/test/java/org/apache/arrow/vector/ipc/TestArrowReaderWriter.java index 15a19ed62d6..bddf8b86353 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/ipc/TestArrowReaderWriter.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/ipc/TestArrowReaderWriter.java @@ -21,10 +21,12 @@ import static java.util.Arrays.asList; import static org.apache.arrow.memory.util.LargeMemoryUtil.checkedCastToInt; import static org.apache.arrow.vector.TestUtils.newVarCharVector; +import static org.apache.arrow.vector.TestUtils.newVector; import static org.apache.arrow.vector.testing.ValueVectorDataPopulator.setVector; import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertTrue; import java.io.ByteArrayInputStream; @@ -41,6 +43,7 @@ import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.function.BiFunction; import java.util.stream.Collectors; import org.apache.arrow.flatbuf.FieldNode; @@ -55,11 +58,16 @@ import org.apache.arrow.vector.IntVector; import org.apache.arrow.vector.NullVector; import org.apache.arrow.vector.TestUtils; +import org.apache.arrow.vector.ValueVector; import org.apache.arrow.vector.VarCharVector; import org.apache.arrow.vector.VectorLoader; import org.apache.arrow.vector.VectorSchemaRoot; import org.apache.arrow.vector.VectorUnloader; +import org.apache.arrow.vector.compare.Range; +import org.apache.arrow.vector.compare.RangeEqualsVisitor; +import org.apache.arrow.vector.compare.TypeEqualsVisitor; import org.apache.arrow.vector.compare.VectorEqualsVisitor; +import org.apache.arrow.vector.complex.StructVector; import org.apache.arrow.vector.dictionary.Dictionary; import org.apache.arrow.vector.dictionary.DictionaryEncoder; import org.apache.arrow.vector.dictionary.DictionaryProvider; @@ -69,6 +77,7 @@ import org.apache.arrow.vector.ipc.message.ArrowRecordBatch; import org.apache.arrow.vector.ipc.message.IpcOption; import org.apache.arrow.vector.ipc.message.MessageSerializer; +import org.apache.arrow.vector.types.Types.MinorType; import org.apache.arrow.vector.types.pojo.ArrowType; import org.apache.arrow.vector.types.pojo.DictionaryEncoding; import org.apache.arrow.vector.types.pojo.Field; @@ -87,10 +96,12 @@ public class TestArrowReaderWriter { private VarCharVector dictionaryVector1; private VarCharVector dictionaryVector2; private VarCharVector dictionaryVector3; + private StructVector dictionaryVector4; private Dictionary dictionary1; private Dictionary dictionary2; private Dictionary dictionary3; + private Dictionary dictionary4; private Schema schema; private Schema encodedSchema; @@ -119,6 +130,12 @@ public void init() { "aa".getBytes(StandardCharsets.UTF_8), "bb".getBytes(StandardCharsets.UTF_8), "cc".getBytes(StandardCharsets.UTF_8)); + + dictionaryVector4 = newVector(StructVector.class, "D4", MinorType.STRUCT, allocator); + final Map> dictionaryValues4 = new HashMap<>(); + dictionaryValues4.put("a", Arrays.asList(1, 2, 3)); + dictionaryValues4.put("b", Arrays.asList(4, 5, 6)); + setVector(dictionaryVector4, dictionaryValues4); dictionary1 = new Dictionary(dictionaryVector1, new DictionaryEncoding(/*id=*/1L, /*ordered=*/false, /*indexType=*/null)); @@ -126,6 +143,8 @@ public void init() { new DictionaryEncoding(/*id=*/2L, /*ordered=*/false, /*indexType=*/null)); dictionary3 = new Dictionary(dictionaryVector3, new DictionaryEncoding(/*id=*/1L, /*ordered=*/false, /*indexType=*/null)); + dictionary4 = new Dictionary(dictionaryVector4, + new DictionaryEncoding(/*id=*/3L, /*ordered=*/false, /*indexType=*/null)); } @After @@ -133,6 +152,7 @@ public void terminate() throws Exception { dictionaryVector1.close(); dictionaryVector2.close(); dictionaryVector3.close(); + dictionaryVector4.close(); allocator.close(); } @@ -305,6 +325,82 @@ public void testWriteReadWithDictionaries() throws IOException { } } + @Test + public void testWriteReadWithStructDictionaries() throws IOException { + DictionaryProvider.MapDictionaryProvider provider = + new DictionaryProvider.MapDictionaryProvider(); + provider.put(dictionary4); + + try (final StructVector vector = + newVector(StructVector.class, "D4", MinorType.STRUCT, allocator)) { + final Map> values = new HashMap<>(); + // Index: 0, 2, 1, 2, 1, 0, 0 + values.put("a", Arrays.asList(1, 3, 2, 3, 2, 1, 1)); + values.put("b", Arrays.asList(4, 6, 5, 6, 5, 4, 4)); + setVector(vector, values); + FieldVector encodedVector = (FieldVector) DictionaryEncoder.encode(vector, dictionary4); + + List fields = Arrays.asList(encodedVector.getField()); + List vectors = Collections2.asImmutableList(encodedVector); + try ( + VectorSchemaRoot root = + new VectorSchemaRoot(fields, vectors, encodedVector.getValueCount()); + ByteArrayOutputStream out = new ByteArrayOutputStream(); + ArrowFileWriter writer = new ArrowFileWriter(root, provider, newChannel(out));) { + + writer.start(); + writer.writeBatch(); + writer.end(); + + try ( + SeekableReadChannel channel = new SeekableReadChannel( + new ByteArrayReadableSeekableByteChannel(out.toByteArray())); + ArrowFileReader reader = new ArrowFileReader(channel, allocator)) { + final VectorSchemaRoot readRoot = reader.getVectorSchemaRoot(); + final Schema readSchema = readRoot.getSchema(); + assertEquals(root.getSchema(), readSchema); + assertEquals(1, reader.getDictionaryBlocks().size()); + assertEquals(1, reader.getRecordBlocks().size()); + + reader.loadNextBatch(); + assertEquals(1, readRoot.getFieldVectors().size()); + assertEquals(1, reader.getDictionaryVectors().size()); + + // Read the encoded vector and check it + final FieldVector readEncoded = readRoot.getVector(0); + assertEquals(encodedVector.getValueCount(), readEncoded.getValueCount()); + assertTrue(new RangeEqualsVisitor(encodedVector, readEncoded) + .rangeEquals(new Range(0, 0, encodedVector.getValueCount()))); + + // Read the dictionary + final Map readDictionaryMap = reader.getDictionaryVectors(); + final Dictionary readDictionary = + readDictionaryMap.get(readEncoded.getField().getDictionary().getId()); + assertNotNull(readDictionary); + + // Assert the dictionary vector is correct + final FieldVector readDictionaryVector = readDictionary.getVector(); + assertEquals(dictionaryVector4.getValueCount(), readDictionaryVector.getValueCount()); + final BiFunction typeComparatorIgnoreName = + (v1, v2) -> new TypeEqualsVisitor(v1, false, true).equals(v2); + assertTrue("Dictionary vectors are not equal", + new RangeEqualsVisitor(dictionaryVector4, readDictionaryVector, + typeComparatorIgnoreName) + .rangeEquals(new Range(0, 0, dictionaryVector4.getValueCount()))); + + // Assert the decoded vector is correct + try (final ValueVector readVector = + DictionaryEncoder.decode(readEncoded, readDictionary)) { + assertEquals(vector.getValueCount(), readVector.getValueCount()); + assertTrue("Decoded vectors are not equal", + new RangeEqualsVisitor(vector, readVector, typeComparatorIgnoreName) + .rangeEquals(new Range(0, 0, vector.getValueCount()))); + } + } + } + } + } + @Test public void testEmptyStreamInFileIPC() throws IOException { diff --git a/java/vector/src/test/java/org/apache/arrow/vector/testing/ValueVectorDataPopulator.java b/java/vector/src/test/java/org/apache/arrow/vector/testing/ValueVectorDataPopulator.java index 3d389d86515..15d6a5cf993 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/testing/ValueVectorDataPopulator.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/testing/ValueVectorDataPopulator.java @@ -21,6 +21,8 @@ import java.nio.charset.StandardCharsets; import java.util.List; +import java.util.Map; +import java.util.Map.Entry; import org.apache.arrow.vector.BigIntVector; import org.apache.arrow.vector.BitVector; @@ -60,8 +62,10 @@ import org.apache.arrow.vector.complex.FixedSizeListVector; import org.apache.arrow.vector.complex.LargeListVector; import org.apache.arrow.vector.complex.ListVector; +import org.apache.arrow.vector.complex.StructVector; import org.apache.arrow.vector.holders.IntervalDayHolder; import org.apache.arrow.vector.types.Types; +import org.apache.arrow.vector.types.Types.MinorType; import org.apache.arrow.vector.types.pojo.FieldType; /** @@ -673,4 +677,32 @@ public static void setVector(FixedSizeListVector vector, List... values dataVector.setValueCount(curPos); vector.setValueCount(values.length); } + + /** + * Populate values for {@link StructVector}. + */ + public static void setVector(StructVector vector, Map> values) { + vector.allocateNewSafe(); + + int valueCount = 0; + for (final Entry> entry : values.entrySet()) { + // Add the child + final IntVector child = vector.addOrGet(entry.getKey(), + FieldType.nullable(MinorType.INT.getType()), IntVector.class); + + // Write the values to the child + child.allocateNew(); + final List v = entry.getValue(); + for (int i = 0; i < v.size(); i++) { + if (v.get(i) != null) { + child.set(i, v.get(i)); + vector.setIndexDefined(i); + } else { + child.setNull(i); + } + } + valueCount = Math.max(valueCount, v.size()); + } + vector.setValueCount(valueCount); + } } diff --git a/js/src/data.ts b/js/src/data.ts index 59d16b74b7b..47f644c0a4e 100644 --- a/js/src/data.ts +++ b/js/src/data.ts @@ -263,11 +263,11 @@ export class Data { } /** @nocollapse */ public static List(type: T, offset: number, length: number, nullCount: number, nullBitmap: NullBuffer, valueOffsets: ValueOffsetsBuffer, child: Data | Vector) { - return new Data(type, offset, length, nullCount, [toInt32Array(valueOffsets), undefined, toUint8Array(nullBitmap)], [child]); + return new Data(type, offset, length, nullCount, [toInt32Array(valueOffsets), undefined, toUint8Array(nullBitmap)], child ? [child] : []); } /** @nocollapse */ public static FixedSizeList(type: T, offset: number, length: number, nullCount: number, nullBitmap: NullBuffer, child: Data | Vector) { - return new Data(type, offset, length, nullCount, [undefined, undefined, toUint8Array(nullBitmap)], [child]); + return new Data(type, offset, length, nullCount, [undefined, undefined, toUint8Array(nullBitmap)], child ? [child] : []); } /** @nocollapse */ public static Struct(type: T, offset: number, length: number, nullCount: number, nullBitmap: NullBuffer, children: (Data | Vector)[]) { @@ -275,7 +275,7 @@ export class Data { } /** @nocollapse */ public static Map(type: T, offset: number, length: number, nullCount: number, nullBitmap: NullBuffer, valueOffsets: ValueOffsetsBuffer, child: (Data | Vector)) { - return new Data(type, offset, length, nullCount, [toInt32Array(valueOffsets), undefined, toUint8Array(nullBitmap)], [child]); + return new Data(type, offset, length, nullCount, [toInt32Array(valueOffsets), undefined, toUint8Array(nullBitmap)], child ? [child] : []); } public static Union(type: T, offset: number, length: number, nullCount: number, nullBitmap: NullBuffer, typeIds: TypeIdsBuffer, children: (Data | Vector)[], _?: any): Data; public static Union(type: T, offset: number, length: number, nullCount: number, nullBitmap: NullBuffer, typeIds: TypeIdsBuffer, valueOffsets: ValueOffsetsBuffer, children: (Data | Vector)[]): Data; diff --git a/r/tests/testthat/test-csv.R b/r/tests/testthat/test-csv.R index 94dd10b62b2..3de70b35471 100644 --- a/r/tests/testthat/test-csv.R +++ b/r/tests/testthat/test-csv.R @@ -212,11 +212,11 @@ test_that("read_csv_arrow() can read timestamps", { tf <- tempfile(); on.exit(unlink(tf)) write.csv(tbl, tf, row.names = FALSE) - df <- read_csv_arrow(tf, col_types = schema(time = timestamp())) + df <- read_csv_arrow(tf, col_types = schema(time = timestamp(timezone = "UTC"))) expect_equal(tbl, df) df <- read_csv_arrow(tf, col_types = "t", col_names = "time", skip = 1) - expect_equal(tbl, df) + expect_equal(tbl, df, check.tzone = FALSE) # col_types = "t" makes timezone-naive timestamp }) test_that("read_csv_arrow(timestamp_parsers=)", { diff --git a/rust/arrow/src/array/data.rs b/rust/arrow/src/array/data.rs index f1e32c57d98..ec31686599f 100644 --- a/rust/arrow/src/array/data.rs +++ b/rust/arrow/src/array/data.rs @@ -29,7 +29,7 @@ use crate::util::bit_util; /// An generic representation of Arrow array data which encapsulates common attributes and /// operations for Arrow array. Specific operations for different arrays types (e.g., /// primitive, list, struct) are implemented in `Array`. -#[derive(PartialEq, Debug, Clone)] +#[derive(Debug, Clone)] pub struct ArrayData { /// The data type for this array data data_type: DataType, @@ -202,6 +202,61 @@ impl ArrayData { } } +impl PartialEq for ArrayData { + fn eq(&self, other: &Self) -> bool { + assert_eq!( + self.data_type(), + other.data_type(), + "Data types not the same" + ); + assert_eq!(self.len(), other.len(), "Lengths not the same"); + // TODO: when adding tests for this, test that we can compare with arrays that have offsets + assert_eq!(self.offset(), other.offset(), "Offsets not the same"); + assert_eq!(self.null_count(), other.null_count()); + // compare buffers excluding padding + let self_buffers = self.buffers(); + let other_buffers = other.buffers(); + assert_eq!(self_buffers.len(), other_buffers.len()); + self_buffers.iter().zip(other_buffers).for_each(|(s, o)| { + compare_buffer_regions( + s, + self.offset(), // TODO mul by data length + o, + other.offset(), // TODO mul by data len + ); + }); + // assert_eq!(self.buffers(), other.buffers()); + + assert_eq!(self.child_data(), other.child_data()); + // null arrays can skip the null bitmap, thus only compare if there are no nulls + if self.null_count() != 0 || other.null_count() != 0 { + compare_buffer_regions( + self.null_buffer().unwrap(), + self.offset(), + other.null_buffer().unwrap(), + other.offset(), + ) + } + true + } +} + +/// A helper to compare buffer regions of 2 buffers. +/// Compares the length of the shorter buffer. +fn compare_buffer_regions( + left: &Buffer, + left_offset: usize, + right: &Buffer, + right_offset: usize, +) { + // for convenience, we assume that the buffer lengths are only unequal if one has padding, + // so we take the shorter length so we can discard the padding from the longer length + let shorter_len = left.len().min(right.len()); + let s_sliced = left.bit_slice(left_offset, shorter_len); + let o_sliced = right.bit_slice(right_offset, shorter_len); + assert_eq!(s_sliced, o_sliced); +} + /// Builder for `ArrayData` type #[derive(Debug)] pub struct ArrayDataBuilder { diff --git a/rust/arrow/src/array/null.rs b/rust/arrow/src/array/null.rs index 190d2fa78fc..08c7cf1f21e 100644 --- a/rust/arrow/src/array/null.rs +++ b/rust/arrow/src/array/null.rs @@ -113,7 +113,7 @@ impl From for NullArray { impl fmt::Debug for NullArray { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "NullArray") + write!(f, "NullArray({})", self.len()) } } @@ -146,4 +146,10 @@ mod tests { assert_eq!(array2.null_count(), 16); assert_eq!(array2.offset(), 8); } + + #[test] + fn test_debug_null_array() { + let array = NullArray::new(1024 * 1024); + assert_eq!(format!("{:?}", array), "NullArray(1048576)"); + } } diff --git a/rust/arrow/src/compute/kernels/cast.rs b/rust/arrow/src/compute/kernels/cast.rs index 08c6a2b3042..7d04ba36c72 100644 --- a/rust/arrow/src/compute/kernels/cast.rs +++ b/rust/arrow/src/compute/kernels/cast.rs @@ -44,6 +44,167 @@ use crate::datatypes::*; use crate::error::{ArrowError, Result}; use crate::{array::*, compute::take}; +/// Return true if a value of type `from_type` can be cast into a +/// value of `to_type`. Note that such as cast may be lossy. +/// +/// If this function returns true to stay consistent with the `cast` kernel below. +pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool { + use self::DataType::*; + if from_type == to_type { + return true; + } + + match (from_type, to_type) { + (Struct(_), _) => false, + (_, Struct(_)) => false, + (List(list_from), List(list_to)) => can_cast_types(list_from, list_to), + (List(_), _) => false, + (_, List(list_to)) => can_cast_types(from_type, list_to), + (Dictionary(_, from_value_type), Dictionary(_, to_value_type)) => { + can_cast_types(from_value_type, to_value_type) + } + (Dictionary(_, value_type), _) => can_cast_types(value_type, to_type), + (_, Dictionary(_, value_type)) => can_cast_types(from_type, value_type), + + (_, Boolean) => DataType::is_numeric(from_type), + (Boolean, _) => DataType::is_numeric(to_type) || to_type == &Utf8, + (Utf8, _) => DataType::is_numeric(to_type), + (_, Utf8) => DataType::is_numeric(from_type) || from_type == &Binary, + + // start numeric casts + (UInt8, UInt16) => true, + (UInt8, UInt32) => true, + (UInt8, UInt64) => true, + (UInt8, Int8) => true, + (UInt8, Int16) => true, + (UInt8, Int32) => true, + (UInt8, Int64) => true, + (UInt8, Float32) => true, + (UInt8, Float64) => true, + + (UInt16, UInt8) => true, + (UInt16, UInt32) => true, + (UInt16, UInt64) => true, + (UInt16, Int8) => true, + (UInt16, Int16) => true, + (UInt16, Int32) => true, + (UInt16, Int64) => true, + (UInt16, Float32) => true, + (UInt16, Float64) => true, + + (UInt32, UInt8) => true, + (UInt32, UInt16) => true, + (UInt32, UInt64) => true, + (UInt32, Int8) => true, + (UInt32, Int16) => true, + (UInt32, Int32) => true, + (UInt32, Int64) => true, + (UInt32, Float32) => true, + (UInt32, Float64) => true, + + (UInt64, UInt8) => true, + (UInt64, UInt16) => true, + (UInt64, UInt32) => true, + (UInt64, Int8) => true, + (UInt64, Int16) => true, + (UInt64, Int32) => true, + (UInt64, Int64) => true, + (UInt64, Float32) => true, + (UInt64, Float64) => true, + + (Int8, UInt8) => true, + (Int8, UInt16) => true, + (Int8, UInt32) => true, + (Int8, UInt64) => true, + (Int8, Int16) => true, + (Int8, Int32) => true, + (Int8, Int64) => true, + (Int8, Float32) => true, + (Int8, Float64) => true, + + (Int16, UInt8) => true, + (Int16, UInt16) => true, + (Int16, UInt32) => true, + (Int16, UInt64) => true, + (Int16, Int8) => true, + (Int16, Int32) => true, + (Int16, Int64) => true, + (Int16, Float32) => true, + (Int16, Float64) => true, + + (Int32, UInt8) => true, + (Int32, UInt16) => true, + (Int32, UInt32) => true, + (Int32, UInt64) => true, + (Int32, Int8) => true, + (Int32, Int16) => true, + (Int32, Int64) => true, + (Int32, Float32) => true, + (Int32, Float64) => true, + + (Int64, UInt8) => true, + (Int64, UInt16) => true, + (Int64, UInt32) => true, + (Int64, UInt64) => true, + (Int64, Int8) => true, + (Int64, Int16) => true, + (Int64, Int32) => true, + (Int64, Float32) => true, + (Int64, Float64) => true, + + (Float32, UInt8) => true, + (Float32, UInt16) => true, + (Float32, UInt32) => true, + (Float32, UInt64) => true, + (Float32, Int8) => true, + (Float32, Int16) => true, + (Float32, Int32) => true, + (Float32, Int64) => true, + (Float32, Float64) => true, + + (Float64, UInt8) => true, + (Float64, UInt16) => true, + (Float64, UInt32) => true, + (Float64, UInt64) => true, + (Float64, Int8) => true, + (Float64, Int16) => true, + (Float64, Int32) => true, + (Float64, Int64) => true, + (Float64, Float32) => true, + // end numeric casts + + // temporal casts + (Int32, Date32(_)) => true, + (Int32, Time32(_)) => true, + (Date32(_), Int32) => true, + (Time32(_), Int32) => true, + (Int64, Date64(_)) => true, + (Int64, Time64(_)) => true, + (Date64(_), Int64) => true, + (Time64(_), Int64) => true, + (Date32(DateUnit::Day), Date64(DateUnit::Millisecond)) => true, + (Date64(DateUnit::Millisecond), Date32(DateUnit::Day)) => true, + (Time32(TimeUnit::Second), Time32(TimeUnit::Millisecond)) => true, + (Time32(TimeUnit::Millisecond), Time32(TimeUnit::Second)) => true, + (Time32(_), Time64(_)) => true, + (Time64(TimeUnit::Microsecond), Time64(TimeUnit::Nanosecond)) => true, + (Time64(TimeUnit::Nanosecond), Time64(TimeUnit::Microsecond)) => true, + (Time64(_), Time32(to_unit)) => match to_unit { + TimeUnit::Second => true, + TimeUnit::Millisecond => true, + _ => false, + }, + (Timestamp(_, _), Int64) => true, + (Int64, Timestamp(_, _)) => true, + (Timestamp(_, _), Timestamp(_, _)) => true, + (Timestamp(_, _), Date32(_)) => true, + (Timestamp(_, _), Date64(_)) => true, + // date64 to timestamp might not make sense, + (Null, Int32) => true, + (_, _) => false, + } +} + /// Cast `array` to the provided data type and return a new Array with /// type `to_type`, if possible. /// @@ -356,11 +517,24 @@ pub fn cast(array: &ArrayRef, to_type: &DataType) -> Result { // temporal casts (Int32, Date32(_)) => cast_array_data::(array, to_type.clone()), - (Int32, Time32(_)) => cast_array_data::(array, to_type.clone()), + (Int32, Time32(TimeUnit::Second)) => { + cast_array_data::(array, to_type.clone()) + } + (Int32, Time32(TimeUnit::Millisecond)) => { + cast_array_data::(array, to_type.clone()) + } + // No support for microsecond/nanosecond with i32 (Date32(_), Int32) => cast_array_data::(array, to_type.clone()), (Time32(_), Int32) => cast_array_data::(array, to_type.clone()), (Int64, Date64(_)) => cast_array_data::(array, to_type.clone()), - (Int64, Time64(_)) => cast_array_data::(array, to_type.clone()), + // No support for second/milliseconds with i64 + (Int64, Time64(TimeUnit::Microsecond)) => { + cast_array_data::(array, to_type.clone()) + } + (Int64, Time64(TimeUnit::Nanosecond)) => { + cast_array_data::(array, to_type.clone()) + } + (Date64(_), Int64) => cast_array_data::(array, to_type.clone()), (Time64(_), Int64) => cast_array_data::(array, to_type.clone()), (Date32(DateUnit::Day), Date64(DateUnit::Millisecond)) => { @@ -549,19 +723,36 @@ pub fn cast(array: &ArrayRef, to_type: &DataType) -> Result { (Timestamp(from_unit, _), Date64(_)) => { let from_size = time_unit_multiple(&from_unit); let to_size = MILLISECONDS; - if from_size != to_size { - let time_array = Date64Array::from(array.data()); - Ok(Arc::new(divide( - &time_array, - &Date64Array::from(vec![from_size / to_size; array.len()]), - )?) as ArrayRef) - } else { - cast_array_data::(array, to_type.clone()) + + // Scale time_array by (to_size / from_size) using a + // single integer operation, but need to avoid integer + // math rounding down to zero + + match to_size.cmp(&from_size) { + std::cmp::Ordering::Less => { + let time_array = Date64Array::from(array.data()); + Ok(Arc::new(divide( + &time_array, + &Date64Array::from(vec![from_size / to_size; array.len()]), + )?) as ArrayRef) + } + std::cmp::Ordering::Equal => { + cast_array_data::(array, to_type.clone()) + } + std::cmp::Ordering::Greater => { + let time_array = Date64Array::from(array.data()); + Ok(Arc::new(multiply( + &time_array, + &Date64Array::from(vec![to_size / from_size; array.len()]), + )?) as ArrayRef) + } } } // date64 to timestamp might not make sense, - // end temporal casts + // null to primitive/flat types + (Null, Int32) => Ok(Arc::new(Int32Array::from(vec![None; array.len()]))), + (_, _) => Err(ArrowError::ComputeError(format!( "Casting from {:?} to {:?} not supported", from_type, to_type, @@ -2290,44 +2481,44 @@ mod tests { // Test casting TO StringArray let cast_type = Utf8; - let cast_array = cast(&array, &cast_type).expect("cast to UTF-8 succeeded"); + let cast_array = cast(&array, &cast_type).expect("cast to UTF-8 failed"); assert_eq!(cast_array.data_type(), &cast_type); assert_eq!(array_to_strings(&cast_array), expected); // Test casting TO Dictionary (with different index sizes) let cast_type = Dictionary(Box::new(Int16), Box::new(Utf8)); - let cast_array = cast(&array, &cast_type).expect("cast succeeded"); + let cast_array = cast(&array, &cast_type).expect("cast failed"); assert_eq!(cast_array.data_type(), &cast_type); assert_eq!(array_to_strings(&cast_array), expected); let cast_type = Dictionary(Box::new(Int32), Box::new(Utf8)); - let cast_array = cast(&array, &cast_type).expect("cast succeeded"); + let cast_array = cast(&array, &cast_type).expect("cast failed"); assert_eq!(cast_array.data_type(), &cast_type); assert_eq!(array_to_strings(&cast_array), expected); let cast_type = Dictionary(Box::new(Int64), Box::new(Utf8)); - let cast_array = cast(&array, &cast_type).expect("cast succeeded"); + let cast_array = cast(&array, &cast_type).expect("cast failed"); assert_eq!(cast_array.data_type(), &cast_type); assert_eq!(array_to_strings(&cast_array), expected); let cast_type = Dictionary(Box::new(UInt8), Box::new(Utf8)); - let cast_array = cast(&array, &cast_type).expect("cast succeeded"); + let cast_array = cast(&array, &cast_type).expect("cast failed"); assert_eq!(cast_array.data_type(), &cast_type); assert_eq!(array_to_strings(&cast_array), expected); let cast_type = Dictionary(Box::new(UInt16), Box::new(Utf8)); - let cast_array = cast(&array, &cast_type).expect("cast succeeded"); + let cast_array = cast(&array, &cast_type).expect("cast failed"); assert_eq!(cast_array.data_type(), &cast_type); assert_eq!(array_to_strings(&cast_array), expected); let cast_type = Dictionary(Box::new(UInt32), Box::new(Utf8)); - let cast_array = cast(&array, &cast_type).expect("cast succeeded"); + let cast_array = cast(&array, &cast_type).expect("cast failed"); assert_eq!(cast_array.data_type(), &cast_type); assert_eq!(array_to_strings(&cast_array), expected); let cast_type = Dictionary(Box::new(UInt64), Box::new(Utf8)); - let cast_array = cast(&array, &cast_type).expect("cast succeeded"); + let cast_array = cast(&array, &cast_type).expect("cast failed"); assert_eq!(cast_array.data_type(), &cast_type); assert_eq!(array_to_strings(&cast_array), expected); } @@ -2412,11 +2603,11 @@ mod tests { let expected = vec!["1", "null", "3"]; // Test casting TO PrimitiveArray, different dictionary type - let cast_array = cast(&array, &Utf8).expect("cast to UTF-8 succeeded"); + let cast_array = cast(&array, &Utf8).expect("cast to UTF-8 failed"); assert_eq!(array_to_strings(&cast_array), expected); assert_eq!(cast_array.data_type(), &Utf8); - let cast_array = cast(&array, &Int64).expect("cast to int64 succeeded"); + let cast_array = cast(&array, &Int64).expect("cast to int64 failed"); assert_eq!(array_to_strings(&cast_array), expected); assert_eq!(cast_array.data_type(), &Int64); } @@ -2435,13 +2626,13 @@ mod tests { // Cast to a dictionary (same value type, Int32) let cast_type = Dictionary(Box::new(UInt8), Box::new(Int32)); - let cast_array = cast(&array, &cast_type).expect("cast succeeded"); + let cast_array = cast(&array, &cast_type).expect("cast failed"); assert_eq!(cast_array.data_type(), &cast_type); assert_eq!(array_to_strings(&cast_array), expected); // Cast to a dictionary (different value type, Int8) let cast_type = Dictionary(Box::new(UInt8), Box::new(Int8)); - let cast_array = cast(&array, &cast_type).expect("cast succeeded"); + let cast_array = cast(&array, &cast_type).expect("cast failed"); assert_eq!(cast_array.data_type(), &cast_type); assert_eq!(array_to_strings(&cast_array), expected); } @@ -2460,11 +2651,25 @@ mod tests { // Cast to a dictionary (same value type, Utf8) let cast_type = Dictionary(Box::new(UInt8), Box::new(Utf8)); - let cast_array = cast(&array, &cast_type).expect("cast succeeded"); + let cast_array = cast(&array, &cast_type).expect("cast failed"); assert_eq!(cast_array.data_type(), &cast_type); assert_eq!(array_to_strings(&cast_array), expected); } + #[test] + fn test_cast_null_array_to_int32() { + let array = Arc::new(NullArray::new(6)) as ArrayRef; + + let expected = Int32Array::from(vec![None; 6]); + + // Cast to a dictionary (same value type, Utf8) + let cast_type = DataType::Int32; + let cast_array = cast(&array, &cast_type).expect("cast failed"); + let cast_array = as_primitive_array::(&cast_array); + assert_eq!(cast_array.data_type(), &cast_type); + assert_eq!(cast_array, &expected); + } + /// Print the `DictionaryArray` `array` as a vector of strings fn array_to_strings(array: &ArrayRef) -> Vec { (0..array.len()) @@ -2477,4 +2682,290 @@ mod tests { }) .collect() } + + #[test] + fn test_can_cast_types() { + // this function attempts to ensure that can_cast_types stays + // in sync with cast. It simply tries all combinations of + // types and makes sure that if `can_cast_types` returns + // true, so does `cast` + + let all_types = get_all_types(); + + for array in get_arrays_of_all_types() { + for to_type in &all_types { + println!("Test casting {:?} --> {:?}", array.data_type(), to_type); + let cast_result = cast(&array, &to_type); + let reported_cast_ability = can_cast_types(array.data_type(), to_type); + + // check for mismatch + match (cast_result, reported_cast_ability) { + (Ok(_), false) => { + panic!("Was able to cast array from {:?} to {:?} but can_cast_types reported false", + array.data_type(), to_type) + }, + (Err(e), true) => { + panic!("Was not able to cast array from {:?} to {:?} but can_cast_types reported true. \ + Error was {:?}", + array.data_type(), to_type, e) + }, + // otherwise it was a match + _=> {}, + }; + } + } + } + + /// Create instances of arrays with varying types for cast tests + fn get_arrays_of_all_types() -> Vec { + let tz_name = Arc::new(String::from("America/New_York")); + let binary_data: Vec<&[u8]> = vec![b"foo", b"bar"]; + vec![ + Arc::new(BinaryArray::from(binary_data.clone())), + Arc::new(LargeBinaryArray::from(binary_data.clone())), + make_dictionary_primitive::(), + make_dictionary_primitive::(), + make_dictionary_primitive::(), + make_dictionary_primitive::(), + make_dictionary_primitive::(), + make_dictionary_primitive::(), + make_dictionary_primitive::(), + make_dictionary_primitive::(), + make_dictionary_utf8::(), + make_dictionary_utf8::(), + make_dictionary_utf8::(), + make_dictionary_utf8::(), + make_dictionary_utf8::(), + make_dictionary_utf8::(), + make_dictionary_utf8::(), + make_dictionary_utf8::(), + Arc::new(make_list_array()), + Arc::new(make_large_list_array()), + Arc::new(make_fixed_size_list_array()), + Arc::new(make_fixed_size_binary_array()), + Arc::new(StructArray::from(vec![ + ( + Field::new("a", DataType::Boolean, false), + Arc::new(BooleanArray::from(vec![false, false, true, true])) + as Arc, + ), + ( + Field::new("b", DataType::Int32, false), + Arc::new(Int32Array::from(vec![42, 28, 19, 31])), + ), + ])), + //Arc::new(make_union_array()), + Arc::new(NullArray::new(10)), + Arc::new(StringArray::from(vec!["foo", "bar"])), + Arc::new(LargeStringArray::from(vec!["foo", "bar"])), + Arc::new(BooleanArray::from(vec![true, false])), + Arc::new(Int8Array::from(vec![1, 2])), + Arc::new(Int16Array::from(vec![1, 2])), + Arc::new(Int32Array::from(vec![1, 2])), + Arc::new(Int64Array::from(vec![1, 2])), + Arc::new(UInt8Array::from(vec![1, 2])), + Arc::new(UInt16Array::from(vec![1, 2])), + Arc::new(UInt32Array::from(vec![1, 2])), + Arc::new(UInt64Array::from(vec![1, 2])), + Arc::new(Float32Array::from(vec![1.0, 2.0])), + Arc::new(Float64Array::from(vec![1.0, 2.0])), + Arc::new(TimestampSecondArray::from_vec(vec![1000, 2000], None)), + Arc::new(TimestampMillisecondArray::from_vec(vec![1000, 2000], None)), + Arc::new(TimestampMicrosecondArray::from_vec(vec![1000, 2000], None)), + Arc::new(TimestampNanosecondArray::from_vec(vec![1000, 2000], None)), + Arc::new(TimestampSecondArray::from_vec( + vec![1000, 2000], + Some(tz_name.clone()), + )), + Arc::new(TimestampMillisecondArray::from_vec( + vec![1000, 2000], + Some(tz_name.clone()), + )), + Arc::new(TimestampMicrosecondArray::from_vec( + vec![1000, 2000], + Some(tz_name.clone()), + )), + Arc::new(TimestampNanosecondArray::from_vec( + vec![1000, 2000], + Some(tz_name), + )), + Arc::new(Date32Array::from(vec![1000, 2000])), + Arc::new(Date64Array::from(vec![1000, 2000])), + Arc::new(Time32SecondArray::from(vec![1000, 2000])), + Arc::new(Time32MillisecondArray::from(vec![1000, 2000])), + Arc::new(Time64MicrosecondArray::from(vec![1000, 2000])), + Arc::new(Time64NanosecondArray::from(vec![1000, 2000])), + Arc::new(IntervalYearMonthArray::from(vec![1000, 2000])), + Arc::new(IntervalDayTimeArray::from(vec![1000, 2000])), + Arc::new(DurationSecondArray::from(vec![1000, 2000])), + Arc::new(DurationMillisecondArray::from(vec![1000, 2000])), + Arc::new(DurationMicrosecondArray::from(vec![1000, 2000])), + Arc::new(DurationNanosecondArray::from(vec![1000, 2000])), + ] + } + + fn make_list_array() -> ListArray { + // Construct a value array + let value_data = ArrayData::builder(DataType::Int32) + .len(8) + .add_buffer(Buffer::from(&[0, 1, 2, 3, 4, 5, 6, 7].to_byte_slice())) + .build(); + + // Construct a buffer for value offsets, for the nested array: + // [[0, 1, 2], [3, 4, 5], [6, 7]] + let value_offsets = Buffer::from(&[0, 3, 6, 8].to_byte_slice()); + + // Construct a list array from the above two + let list_data_type = DataType::List(Box::new(DataType::Int32)); + let list_data = ArrayData::builder(list_data_type.clone()) + .len(3) + .add_buffer(value_offsets.clone()) + .add_child_data(value_data.clone()) + .build(); + ListArray::from(list_data) + } + + fn make_large_list_array() -> LargeListArray { + // Construct a value array + let value_data = ArrayData::builder(DataType::Int32) + .len(8) + .add_buffer(Buffer::from(&[0, 1, 2, 3, 4, 5, 6, 7].to_byte_slice())) + .build(); + + // Construct a buffer for value offsets, for the nested array: + // [[0, 1, 2], [3, 4, 5], [6, 7]] + let value_offsets = Buffer::from(&[0i64, 3, 6, 8].to_byte_slice()); + + // Construct a list array from the above two + let list_data_type = DataType::LargeList(Box::new(DataType::Int32)); + let list_data = ArrayData::builder(list_data_type.clone()) + .len(3) + .add_buffer(value_offsets.clone()) + .add_child_data(value_data.clone()) + .build(); + LargeListArray::from(list_data) + } + + fn make_fixed_size_list_array() -> FixedSizeListArray { + // Construct a value array + let value_data = ArrayData::builder(DataType::Int32) + .len(10) + .add_buffer(Buffer::from( + &[0, 1, 2, 3, 4, 5, 6, 7, 8, 9].to_byte_slice(), + )) + .build(); + + // Construct a fixed size list array from the above two + let list_data_type = DataType::FixedSizeList(Box::new(DataType::Int32), 2); + let list_data = ArrayData::builder(list_data_type) + .len(5) + .add_child_data(value_data.clone()) + .build(); + FixedSizeListArray::from(list_data) + } + + fn make_fixed_size_binary_array() -> FixedSizeBinaryArray { + let values: [u8; 15] = *b"hellotherearrow"; + + let array_data = ArrayData::builder(DataType::FixedSizeBinary(5)) + .len(3) + .add_buffer(Buffer::from(&values[..])) + .build(); + FixedSizeBinaryArray::from(array_data) + } + + fn make_union_array() -> UnionArray { + let mut builder = UnionBuilder::new_dense(7); + builder.append::("a", 1).unwrap(); + builder.append::("b", false).unwrap(); + builder.build().unwrap() + } + + /// Creates a dictionary with primitive dictionary values, and keys of type K + fn make_dictionary_primitive() -> ArrayRef { + let keys_builder = PrimitiveBuilder::::new(2); + // Pick Int32 arbitrarily for dictionary values + let values_builder = PrimitiveBuilder::::new(2); + let mut b = PrimitiveDictionaryBuilder::new(keys_builder, values_builder); + b.append(1).unwrap(); + b.append(2).unwrap(); + Arc::new(b.finish()) + } + + /// Creates a dictionary with utf8 values, and keys of type K + fn make_dictionary_utf8() -> ArrayRef { + let keys_builder = PrimitiveBuilder::::new(2); + // Pick Int32 arbitrarily for dictionary values + let values_builder = StringBuilder::new(2); + let mut b = StringDictionaryBuilder::new(keys_builder, values_builder); + b.append("foo").unwrap(); + b.append("bar").unwrap(); + Arc::new(b.finish()) + } + + // Get a selection of datatypes to try and cast to + fn get_all_types() -> Vec { + use DataType::*; + let tz_name = Arc::new(String::from("America/New_York")); + + vec![ + Null, + Boolean, + Int8, + Int16, + Int32, + UInt64, + UInt8, + UInt16, + UInt32, + UInt64, + Float16, + Float32, + Float64, + Timestamp(TimeUnit::Second, None), + Timestamp(TimeUnit::Millisecond, None), + Timestamp(TimeUnit::Microsecond, None), + Timestamp(TimeUnit::Nanosecond, None), + Timestamp(TimeUnit::Second, Some(tz_name.clone())), + Timestamp(TimeUnit::Millisecond, Some(tz_name.clone())), + Timestamp(TimeUnit::Microsecond, Some(tz_name.clone())), + Timestamp(TimeUnit::Nanosecond, Some(tz_name.clone())), + Date32(DateUnit::Day), + Date64(DateUnit::Day), + Date32(DateUnit::Millisecond), + Date64(DateUnit::Millisecond), + Time32(TimeUnit::Second), + Time32(TimeUnit::Millisecond), + Time64(TimeUnit::Microsecond), + Time64(TimeUnit::Nanosecond), + Duration(TimeUnit::Second), + Duration(TimeUnit::Millisecond), + Duration(TimeUnit::Microsecond), + Duration(TimeUnit::Nanosecond), + Interval(IntervalUnit::YearMonth), + Interval(IntervalUnit::DayTime), + Binary, + FixedSizeBinary(10), + LargeBinary, + Utf8, + LargeUtf8, + List(Box::new(DataType::Int8)), + List(Box::new(DataType::Utf8)), + FixedSizeList(Box::new(DataType::Int8), 10), + FixedSizeList(Box::new(DataType::Utf8), 10), + LargeList(Box::new(DataType::Int8)), + LargeList(Box::new(DataType::Utf8)), + Struct(vec![ + Field::new("f1", DataType::Int32, false), + Field::new("f2", DataType::Utf8, true), + ]), + Union(vec![ + Field::new("f1", DataType::Int32, false), + Field::new("f2", DataType::Utf8, true), + ]), + Dictionary(Box::new(DataType::Int8), Box::new(DataType::Int32)), + Dictionary(Box::new(DataType::Int16), Box::new(DataType::Utf8)), + Dictionary(Box::new(DataType::UInt32), Box::new(DataType::Utf8)), + ] + } } diff --git a/rust/arrow/src/datatypes.rs b/rust/arrow/src/datatypes.rs index 0d05f826d37..0c30c625b8d 100644 --- a/rust/arrow/src/datatypes.rs +++ b/rust/arrow/src/datatypes.rs @@ -189,8 +189,8 @@ pub struct Field { name: String, data_type: DataType, nullable: bool, - dict_id: i64, - dict_is_ordered: bool, + pub(crate) dict_id: i64, + pub(crate) dict_is_ordered: bool, } pub trait ArrowNativeType: @@ -1129,6 +1129,16 @@ impl DataType { DataType::Dictionary(_, _) => json!({ "name": "dictionary"}), } } + + /// Returns true if this type is numeric: (UInt*, Unit*, or Float*) + pub fn is_numeric(t: &DataType) -> bool { + use DataType::*; + match t { + UInt8 | UInt16 | UInt32 | UInt64 | Int8 | Int16 | Int32 | Int64 | Float32 + | Float64 => true, + _ => false, + } + } } impl Field { diff --git a/rust/arrow/src/ipc/convert.rs b/rust/arrow/src/ipc/convert.rs index 7a5795de91c..63d55f043c6 100644 --- a/rust/arrow/src/ipc/convert.rs +++ b/rust/arrow/src/ipc/convert.rs @@ -34,18 +34,8 @@ pub fn schema_to_fb(schema: &Schema) -> FlatBufferBuilder { let mut fields = vec![]; for field in schema.fields() { - let fb_field_name = fbb.create_string(field.name().as_str()); - let field_type = get_fb_field_type(field.data_type(), &mut fbb); - let mut field_builder = ipc::FieldBuilder::new(&mut fbb); - field_builder.add_name(fb_field_name); - field_builder.add_type_type(field_type.type_type); - field_builder.add_nullable(field.is_nullable()); - match field_type.children { - None => {} - Some(children) => field_builder.add_children(children), - }; - field_builder.add_type_(field_type.type_); - fields.push(field_builder.finish()); + let fb_field = build_field(&mut fbb, field); + fields.push(fb_field); } let mut custom_metadata = vec![]; @@ -80,18 +70,8 @@ pub fn schema_to_fb_offset<'a: 'b, 'b>( ) -> WIPOffset> { let mut fields = vec![]; for field in schema.fields() { - let fb_field_name = fbb.create_string(field.name().as_str()); - let field_type = get_fb_field_type(field.data_type(), fbb); - let mut field_builder = ipc::FieldBuilder::new(fbb); - field_builder.add_name(fb_field_name); - field_builder.add_type_type(field_type.type_type); - field_builder.add_nullable(field.is_nullable()); - match field_type.children { - None => {} - Some(children) => field_builder.add_children(children), - }; - field_builder.add_type_(field_type.type_); - fields.push(field_builder.finish()); + let fb_field = build_field(fbb, field); + fields.push(fb_field); } let mut custom_metadata = vec![]; @@ -333,6 +313,40 @@ pub(crate) struct FBFieldType<'b> { pub(crate) children: Option>>>>, } +/// Create an IPC Field from an Arrow Field +pub(crate) fn build_field<'a: 'b, 'b>( + fbb: &mut FlatBufferBuilder<'a>, + field: &Field, +) -> WIPOffset> { + let fb_field_name = fbb.create_string(field.name().as_str()); + let field_type = get_fb_field_type(field.data_type(), fbb); + + let fb_dictionary = if let Dictionary(index_type, _) = field.data_type() { + Some(get_fb_dictionary( + index_type, + field.dict_id, + field.dict_is_ordered, + fbb, + )) + } else { + None + }; + + let mut field_builder = ipc::FieldBuilder::new(fbb); + field_builder.add_name(fb_field_name); + if let Some(dictionary) = fb_dictionary { + field_builder.add_dictionary(dictionary) + } + field_builder.add_type_type(field_type.type_type); + field_builder.add_nullable(field.is_nullable()); + match field_type.children { + None => {} + Some(children) => field_builder.add_children(children), + }; + field_builder.add_type_(field_type.type_); + field_builder.finish() +} + /// Get the IPC type of a data type pub(crate) fn get_fb_field_type<'a: 'b, 'b>( data_type: &DataType, @@ -609,10 +623,51 @@ pub(crate) fn get_fb_field_type<'a: 'b, 'b>( children: Some(fbb.create_vector(&children[..])), } } + Dictionary(_, value_type) => { + // In this library, the dictionary "type" is a logical construct. Here we + // pass through to the value type, as we've already captured the index + // type in the DictionaryEncoding metadata in the parent field + get_fb_field_type(value_type, fbb) + } t => unimplemented!("Type {:?} not supported", t), } } +/// Create an IPC dictionary encoding +pub(crate) fn get_fb_dictionary<'a: 'b, 'b>( + index_type: &DataType, + dict_id: i64, + dict_is_ordered: bool, + fbb: &mut FlatBufferBuilder<'a>, +) -> WIPOffset> { + // We assume that the dictionary index type (as an integer) has already been + // validated elsewhere, and can safely assume we are dealing with integers + let mut index_builder = ipc::IntBuilder::new(fbb); + + match *index_type { + Int8 | Int16 | Int32 | Int64 => index_builder.add_is_signed(true), + UInt8 | UInt16 | UInt32 | UInt64 => index_builder.add_is_signed(false), + _ => {} + } + + match *index_type { + Int8 | UInt8 => index_builder.add_bitWidth(8), + Int16 | UInt16 => index_builder.add_bitWidth(16), + Int32 | UInt32 => index_builder.add_bitWidth(32), + Int64 | UInt64 => index_builder.add_bitWidth(64), + _ => {} + } + + let index_builder = index_builder.finish(); + + let mut builder = ipc::DictionaryEncodingBuilder::new(fbb); + builder.add_id(dict_id); + builder.add_indexType(index_builder); + builder.add_isOrdered(dict_is_ordered); + + builder.finish() +} + #[cfg(test)] mod tests { use super::*; @@ -714,6 +769,26 @@ mod tests { false, ), Field::new("struct<>", DataType::Struct(vec![]), true), + Field::new_dict( + "dictionary", + DataType::Dictionary( + Box::new(DataType::Int32), + Box::new(DataType::Utf8), + ), + true, + 123, + true, + ), + Field::new_dict( + "dictionary", + DataType::Dictionary( + Box::new(DataType::UInt8), + Box::new(DataType::UInt32), + ), + true, + 123, + true, + ), ], md, ); diff --git a/rust/arrow/src/ipc/reader.rs b/rust/arrow/src/ipc/reader.rs index 53c422d481c..e4bb003d0bc 100644 --- a/rust/arrow/src/ipc/reader.rs +++ b/rust/arrow/src/ipc/reader.rs @@ -445,6 +445,69 @@ pub fn read_record_batch( RecordBatch::try_new(schema, arrays) } +/// Read the dictionary from the buffer and provided metadata, +/// updating the `dictionaries_by_field` with the resulting dictionary +fn read_dictionary( + buf: &[u8], + batch: ipc::DictionaryBatch, + ipc_schema: &ipc::Schema, + schema: &Schema, + dictionaries_by_field: &mut [Option], +) -> Result<()> { + if batch.isDelta() { + return Err(ArrowError::IoError( + "delta dictionary batches not supported".to_string(), + )); + } + + let id = batch.id(); + + // As the dictionary batch does not contain the type of the + // values array, we need to retrieve this from the schema. + let first_field = find_dictionary_field(ipc_schema, id).ok_or_else(|| { + ArrowError::InvalidArgumentError("dictionary id not found in schema".to_string()) + })?; + + // Get an array representing this dictionary's values. + let dictionary_values: ArrayRef = match schema.field(first_field).data_type() { + DataType::Dictionary(_, ref value_type) => { + // Make a fake schema for the dictionary batch. + let schema = Schema { + fields: vec![Field::new("", value_type.as_ref().clone(), false)], + metadata: HashMap::new(), + }; + // Read a single column + let record_batch = read_record_batch( + &buf, + batch.data().unwrap(), + Arc::new(schema), + &dictionaries_by_field, + )?; + Some(record_batch.column(0).clone()) + } + _ => None, + } + .ok_or_else(|| { + ArrowError::InvalidArgumentError("dictionary id not found in schema".to_string()) + })?; + + // for all fields with this dictionary id, update the dictionaries vector + // in the reader. Note that a dictionary batch may be shared between many fields. + // We don't currently record the isOrdered field. This could be general + // attributes of arrays. + let fields = ipc_schema.fields().unwrap(); + for (i, field) in fields.iter().enumerate() { + if let Some(dictionary) = field.dictionary() { + if dictionary.id() == id { + // Add (possibly multiple) array refs to the dictionaries array. + dictionaries_by_field[i] = Some(dictionary_values.clone()); + } + } + } + + Ok(()) +} + // Linear search for the first dictionary field with a dictionary id. fn find_dictionary_field(ipc_schema: &ipc::Schema, id: i64) -> Option { let fields = ipc_schema.fields().unwrap(); @@ -556,67 +619,13 @@ impl FileReader { ))?; reader.read_exact(&mut buf)?; - if batch.isDelta() { - return Err(ArrowError::IoError( - "delta dictionary batches not supported".to_string(), - )); - } - - let id = batch.id(); - - // As the dictionary batch does not contain the type of the - // values array, we need to retieve this from the schema. - let first_field = - find_dictionary_field(&ipc_schema, id).ok_or_else(|| { - ArrowError::InvalidArgumentError( - "dictionary id not found in schema".to_string(), - ) - })?; - - // Get an array representing this dictionary's values. - let dictionary_values: ArrayRef = - match schema.field(first_field).data_type() { - DataType::Dictionary(_, ref value_type) => { - // Make a fake schema for the dictionary batch. - let schema = Schema { - fields: vec![Field::new( - "", - value_type.as_ref().clone(), - false, - )], - metadata: HashMap::new(), - }; - // Read a single column - let record_batch = read_record_batch( - &buf, - batch.data().unwrap(), - Arc::new(schema), - &dictionaries_by_field, - )?; - Some(record_batch.column(0).clone()) - } - _ => None, - } - .ok_or_else(|| { - ArrowError::InvalidArgumentError( - "dictionary id not found in schema".to_string(), - ) - })?; - - // for all fields with this dictionary id, update the dictionaries vector - // in the reader. Note that a dictionary batch may be shared between many fields. - // We don't currently record the isOrdered field. This could be general - // attributes of arrays. - let fields = ipc_schema.fields().unwrap(); - for (i, field) in fields.iter().enumerate() { - if let Some(dictionary) = field.dictionary() { - if dictionary.id() == id { - // Add (possibly multiple) array refs to the dictionaries array. - dictionaries_by_field[i] = - Some(dictionary_values.clone()); - } - } - } + read_dictionary( + &buf, + batch, + &ipc_schema, + &schema, + &mut dictionaries_by_field, + )?; } _ => { return Err(ArrowError::IoError( @@ -747,17 +756,24 @@ impl RecordBatchReader for FileReader { pub struct StreamReader { /// Buffered stream reader reader: BufReader, + /// The schema that is read from the stream's first message schema: SchemaRef, - /// An indicator of whether the strewam is complete. + + /// The bytes of the IPC schema that is read from the stream's first message /// - /// This value is set to `true` the first time the reader's `next()` returns `None`. - finished: bool, + /// This is kept in order to interpret dictionary data + ipc_schema: Vec, /// Optional dictionaries for each schema field. /// /// Dictionaries may be appended to in the streaming format. dictionaries_by_field: Vec>, + + /// An indicator of whether the stream is complete. + /// + /// This value is set to `true` the first time the reader's `next()` returns `None`. + finished: bool, } impl StreamReader { @@ -783,8 +799,7 @@ impl StreamReader { let mut meta_buffer = vec![0; meta_len as usize]; reader.read_exact(&mut meta_buffer)?; - let vecs = &meta_buffer.to_vec(); - let message = ipc::get_root_as_message(vecs); + let message = ipc::get_root_as_message(meta_buffer.as_slice()); // message header is a Schema, so read it let ipc_schema: ipc::Schema = message.header_as_schema().ok_or_else(|| { ArrowError::IoError("Unable to read IPC message as schema".to_string()) @@ -797,6 +812,7 @@ impl StreamReader { Ok(Self { reader, schema: Arc::new(schema), + ipc_schema: meta_buffer, finished: false, dictionaries_by_field, }) @@ -871,6 +887,30 @@ impl StreamReader { read_record_batch(&buf, batch, self.schema(), &self.dictionaries_by_field).map(Some) } + ipc::MessageHeader::DictionaryBatch => { + let batch = message.header_as_dictionary_batch().ok_or_else(|| { + ArrowError::IoError( + "Unable to read IPC message as dictionary batch".to_string(), + ) + })?; + // read the block that makes up the dictionary batch into a buffer + let mut buf = vec![0; message.bodyLength() as usize]; + self.reader.read_exact(&mut buf)?; + + let ipc_schema = ipc::get_root_as_message(&self.ipc_schema).header_as_schema() + .ok_or_else(|| { + ArrowError::IoError( + "Unable to read schema from stored message header".to_string(), + ) + })?; + + read_dictionary( + &buf, batch, &ipc_schema, &self.schema, &mut self.dictionaries_by_field + )?; + + // read the next message until we encounter a RecordBatch + self.maybe_next() + } ipc::MessageHeader::NONE => { Ok(None) } @@ -940,7 +980,7 @@ mod tests { let paths = vec![ "generated_interval", "generated_datetime", - // "generated_dictionary", + "generated_dictionary", "generated_nested", "generated_primitive_no_batches", "generated_primitive_zerolength", diff --git a/rust/arrow/src/util/display.rs b/rust/arrow/src/util/display.rs index bf0cade562f..87c18d26629 100644 --- a/rust/arrow/src/util/display.rs +++ b/rust/arrow/src/util/display.rs @@ -44,6 +44,22 @@ macro_rules! make_string { }}; } +macro_rules! make_string_from_list { + ($column: ident, $row: ident) => {{ + let list = $column + .as_any() + .downcast_ref::() + .ok_or(ArrowError::InvalidArgumentError(format!( + "Repl error: could not convert list column to list array." + )))? + .value($row); + let string_values = (0..list.len()) + .map(|i| array_value_to_string(&list.clone(), i)) + .collect::>>()?; + Ok(format!("[{}]", string_values.join(", "))) + }}; +} + /// Get the value at the given row in an array as a String. /// /// Note this function is quite inefficient and is unlikely to be @@ -89,6 +105,7 @@ pub fn array_value_to_string(column: &array::ArrayRef, row: usize) -> Result { make_string!(array::Time64NanosecondArray, column, row) } + DataType::List(_) => make_string_from_list!(column, row), DataType::Dictionary(index_type, _value_type) => match **index_type { DataType::Int8 => dict_array_value_to_string::(column, row), DataType::Int16 => dict_array_value_to_string::(column, row), diff --git a/rust/datafusion/benches/aggregate_query_sql.rs b/rust/datafusion/benches/aggregate_query_sql.rs index 547bf9e5d3c..bbb692d329e 100644 --- a/rust/datafusion/benches/aggregate_query_sql.rs +++ b/rust/datafusion/benches/aggregate_query_sql.rs @@ -22,6 +22,7 @@ use criterion::Criterion; use rand::seq::SliceRandom; use rand::Rng; use std::sync::{Arc, Mutex}; +use tokio::runtime::Runtime; extern crate arrow; extern crate datafusion; @@ -38,13 +39,12 @@ use datafusion::datasource::MemTable; use datafusion::error::Result; use datafusion::execution::context::ExecutionContext; -async fn query(ctx: Arc>, sql: &str) { +fn query(ctx: Arc>, sql: &str) { + let mut rt = Runtime::new().unwrap(); + // execute the query let df = ctx.lock().unwrap().sql(&sql).unwrap(); - let results = df.collect().await.unwrap(); - - // display the relation - for _batch in results {} + rt.block_on(df.collect()).unwrap(); } fn create_data(size: usize, null_density: f64) -> Vec> { @@ -116,8 +116,8 @@ fn create_context( } fn criterion_benchmark(c: &mut Criterion) { - let partitions_len = 4; - let array_len = 32768; // 2^15 + let partitions_len = 8; + let array_len = 32768 * 2; // 2^16 let batch_size = 2048; // 2^11 let ctx = create_context(partitions_len, array_len, batch_size).unwrap(); diff --git a/rust/datafusion/benches/math_query_sql.rs b/rust/datafusion/benches/math_query_sql.rs index b7e08106ff6..65f613b6cdd 100644 --- a/rust/datafusion/benches/math_query_sql.rs +++ b/rust/datafusion/benches/math_query_sql.rs @@ -21,6 +21,8 @@ use criterion::Criterion; use std::sync::{Arc, Mutex}; +use tokio::runtime::Runtime; + extern crate arrow; extern crate datafusion; @@ -34,13 +36,12 @@ use datafusion::error::Result; use datafusion::datasource::MemTable; use datafusion::execution::context::ExecutionContext; -async fn query(ctx: Arc>, sql: &str) { +fn query(ctx: Arc>, sql: &str) { + let mut rt = Runtime::new().unwrap(); + // execute the query let df = ctx.lock().unwrap().sql(&sql).unwrap(); - let results = df.collect().await.unwrap(); - - // display the relation - for _batch in results {} + rt.block_on(df.collect()).unwrap(); } fn create_context( @@ -77,24 +78,31 @@ fn create_context( } fn criterion_benchmark(c: &mut Criterion) { + let array_len = 1048576; // 2^20 + let batch_size = 512; // 2^9 + let ctx = create_context(array_len, batch_size).unwrap(); + c.bench_function("sqrt_20_9", |b| { + b.iter(|| query(ctx.clone(), "SELECT sqrt(f32) FROM t")) + }); + + let array_len = 1048576; // 2^20 + let batch_size = 4096; // 2^12 + let ctx = create_context(array_len, batch_size).unwrap(); c.bench_function("sqrt_20_12", |b| { - let array_len = 1048576; // 2^20 - let batch_size = 4096; // 2^12 - let ctx = create_context(array_len, batch_size).unwrap(); b.iter(|| query(ctx.clone(), "SELECT sqrt(f32) FROM t")) }); + let array_len = 4194304; // 2^22 + let batch_size = 4096; // 2^12 + let ctx = create_context(array_len, batch_size).unwrap(); c.bench_function("sqrt_22_12", |b| { - let array_len = 4194304; // 2^22 - let batch_size = 4096; // 2^12 - let ctx = create_context(array_len, batch_size).unwrap(); b.iter(|| query(ctx.clone(), "SELECT sqrt(f32) FROM t")) }); + let array_len = 4194304; // 2^22 + let batch_size = 16384; // 2^14 + let ctx = create_context(array_len, batch_size).unwrap(); c.bench_function("sqrt_22_14", |b| { - let array_len = 4194304; // 2^22 - let batch_size = 16384; // 2^14 - let ctx = create_context(array_len, batch_size).unwrap(); b.iter(|| query(ctx.clone(), "SELECT sqrt(f32) FROM t")) }); } diff --git a/rust/datafusion/benches/sort_limit_query_sql.rs b/rust/datafusion/benches/sort_limit_query_sql.rs index 1b2f1621c67..02440046b99 100644 --- a/rust/datafusion/benches/sort_limit_query_sql.rs +++ b/rust/datafusion/benches/sort_limit_query_sql.rs @@ -32,13 +32,12 @@ use datafusion::execution::context::ExecutionContext; use tokio::runtime::Runtime; -async fn run_query(ctx: Arc>, sql: &str) { +fn query(ctx: Arc>, sql: &str) { + let mut rt = Runtime::new().unwrap(); + // execute the query let df = ctx.lock().unwrap().sql(&sql).unwrap(); - let results = df.collect().await.unwrap(); - - // display the relation - for _batch in results {} + rt.block_on(df.collect()).unwrap(); } fn create_context() -> Arc> { @@ -90,7 +89,7 @@ fn criterion_benchmark(c: &mut Criterion) { c.bench_function("sort_and_limit_by_int", |b| { let ctx = create_context(); b.iter(|| { - run_query( + query( ctx.clone(), "SELECT c1, c13, c6, c10 \ FROM aggregate_test_100 \ @@ -103,7 +102,7 @@ fn criterion_benchmark(c: &mut Criterion) { c.bench_function("sort_and_limit_by_float", |b| { let ctx = create_context(); b.iter(|| { - run_query( + query( ctx.clone(), "SELECT c1, c13, c12 \ FROM aggregate_test_100 \ @@ -116,7 +115,7 @@ fn criterion_benchmark(c: &mut Criterion) { c.bench_function("sort_and_limit_lex_by_int", |b| { let ctx = create_context(); b.iter(|| { - run_query( + query( ctx.clone(), "SELECT c1, c13, c6, c10 \ FROM aggregate_test_100 \ @@ -129,7 +128,7 @@ fn criterion_benchmark(c: &mut Criterion) { c.bench_function("sort_and_limit_lex_by_string", |b| { let ctx = create_context(); b.iter(|| { - run_query( + query( ctx.clone(), "SELECT c1, c13, c6, c10 \ FROM aggregate_test_100 \ diff --git a/rust/datafusion/examples/simple_udaf.rs b/rust/datafusion/examples/simple_udaf.rs index 4d3cc23696a..1f41f0db410 100644 --- a/rust/datafusion/examples/simple_udaf.rs +++ b/rust/datafusion/examples/simple_udaf.rs @@ -24,7 +24,7 @@ use arrow::{ use datafusion::{error::Result, logical_plan::create_udaf, physical_plan::Accumulator}; use datafusion::{prelude::*, scalar::ScalarValue}; -use std::{cell::RefCell, rc::Rc, sync::Arc}; +use std::sync::Arc; // create local execution context with an in-memory table fn create_context() -> Result { @@ -138,7 +138,7 @@ async fn main() -> Result<()> { // the return type; DataFusion expects this to match the type returned by `evaluate`. Arc::new(DataType::Float64), // This is the accumulator factory; DataFusion uses it to create new accumulators. - Arc::new(|| Ok(Rc::new(RefCell::new(GeometricMean::new())))), + Arc::new(|| Ok(Box::new(GeometricMean::new()))), // This is the description of the state. `state()` must match the types here. Arc::new(vec![DataType::Float64, DataType::UInt32]), ); diff --git a/rust/datafusion/src/execution/context.rs b/rust/datafusion/src/execution/context.rs index eabc779e49d..a2dd6c9887e 100644 --- a/rust/datafusion/src/execution/context.rs +++ b/rust/datafusion/src/execution/context.rs @@ -332,7 +332,7 @@ impl ExecutionContext { } _ => { // merge into a single partition - let plan = MergeExec::new(plan.clone(), self.state.config.concurrency); + let plan = MergeExec::new(plan.clone()); // MergeExec must produce a single partition assert_eq!(1, plan.output_partitioning().partition_count()); common::collect(plan.execute(0).await?) @@ -537,8 +537,8 @@ mod tests { ArrayRef, Float64Array, Int32Array, PrimitiveArrayOps, StringArray, }; use arrow::compute::add; + use std::fs::File; use std::thread::{self, JoinHandle}; - use std::{cell::RefCell, fs::File, rc::Rc}; use std::{io::prelude::*, sync::Mutex}; use tempfile::TempDir; use test::*; @@ -1371,11 +1371,7 @@ mod tests { "MY_AVG", DataType::Float64, Arc::new(DataType::Float64), - Arc::new(|| { - Ok(Rc::new(RefCell::new(AvgAccumulator::try_new( - &DataType::Float64, - )?))) - }), + Arc::new(|| Ok(Box::new(AvgAccumulator::try_new(&DataType::Float64)?))), Arc::new(vec![DataType::UInt64, DataType::Float64]), ); diff --git a/rust/datafusion/src/logical_plan/mod.rs b/rust/datafusion/src/logical_plan/mod.rs index b8d0cc7fb82..6df92fe190e 100644 --- a/rust/datafusion/src/logical_plan/mod.rs +++ b/rust/datafusion/src/logical_plan/mod.rs @@ -25,7 +25,10 @@ use fmt::Debug; use std::{any::Any, collections::HashMap, collections::HashSet, fmt, sync::Arc}; use aggregates::{AccumulatorFunctionImplementation, StateTypeFunction}; -use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; +use arrow::{ + compute::can_cast_types, + datatypes::{DataType, Field, Schema, SchemaRef}, +}; use crate::datasource::parquet::ParquetTable; use crate::datasource::TableProvider; @@ -37,8 +40,7 @@ use crate::{ }; use crate::{ physical_plan::{ - aggregates, expressions::binary_operator_data_type, functions, - type_coercion::can_coerce_from, udf::ScalarUDF, + aggregates, expressions::binary_operator_data_type, functions, udf::ScalarUDF, }, sql::parser::FileType, }; @@ -333,12 +335,13 @@ impl Expr { /// /// # Errors /// - /// This function errors when it is impossible to cast the expression to the target [arrow::datatypes::DataType]. + /// This function errors when it is impossible to cast the + /// expression to the target [arrow::datatypes::DataType]. pub fn cast_to(&self, cast_to_type: &DataType, schema: &Schema) -> Result { let this_type = self.get_type(schema)?; if this_type == *cast_to_type { Ok(self.clone()) - } else if can_coerce_from(cast_to_type, &this_type) { + } else if can_cast_types(&this_type, cast_to_type) { Ok(Expr::Cast { expr: Box::new(self.clone()), data_type: cast_to_type.clone(), diff --git a/rust/datafusion/src/physical_plan/aggregates.rs b/rust/datafusion/src/physical_plan/aggregates.rs index 40bb562b0e4..d417c41855d 100644 --- a/rust/datafusion/src/physical_plan/aggregates.rs +++ b/rust/datafusion/src/physical_plan/aggregates.rs @@ -36,11 +36,11 @@ use crate::physical_plan::distinct_expressions; use crate::physical_plan::expressions; use arrow::datatypes::{DataType, Schema}; use expressions::{avg_return_type, sum_return_type}; -use std::{cell::RefCell, fmt, rc::Rc, str::FromStr, sync::Arc}; +use std::{fmt, str::FromStr, sync::Arc}; /// the implementation of an aggregate function pub type AccumulatorFunctionImplementation = - Arc Result>> + Send + Sync>; + Arc Result> + Send + Sync>; /// This signature corresponds to which types an aggregator serializes /// its state, given its return datatype. diff --git a/rust/datafusion/src/physical_plan/distinct_expressions.rs b/rust/datafusion/src/physical_plan/distinct_expressions.rs index 2d2ab627d44..cc771078609 100644 --- a/rust/datafusion/src/physical_plan/distinct_expressions.rs +++ b/rust/datafusion/src/physical_plan/distinct_expressions.rs @@ -17,11 +17,9 @@ //! Implementations for DISTINCT expressions, e.g. `COUNT(DISTINCT c)` -use std::cell::RefCell; use std::convert::TryFrom; use std::fmt::Debug; use std::hash::Hash; -use std::rc::Rc; use std::sync::Arc; use arrow::datatypes::{DataType, Field}; @@ -93,12 +91,12 @@ impl AggregateExpr for DistinctCount { self.exprs.clone() } - fn create_accumulator(&self) -> Result>> { - Ok(Rc::new(RefCell::new(DistinctCountAccumulator { + fn create_accumulator(&self) -> Result> { + Ok(Box::new(DistinctCountAccumulator { values: FnvHashSet::default(), data_types: self.input_data_types.clone(), count_data_type: self.data_type.clone(), - }))) + })) } } @@ -282,8 +280,7 @@ mod tests { DataType::UInt64, ); - let accum = agg.create_accumulator()?; - let mut accum = accum.borrow_mut(); + let mut accum = agg.create_accumulator()?; accum.update_batch(arrays)?; Ok((accum.state()?, accum.evaluate()?)) @@ -300,8 +297,7 @@ mod tests { DataType::UInt64, ); - let accum = agg.create_accumulator()?; - let mut accum = accum.borrow_mut(); + let mut accum = agg.create_accumulator()?; for row in rows.iter() { accum.update(row)? @@ -324,8 +320,7 @@ mod tests { DataType::UInt64, ); - let accum = agg.create_accumulator()?; - let mut accum = accum.borrow_mut(); + let mut accum = agg.create_accumulator()?; accum.merge_batch(arrays)?; Ok((accum.state()?, accum.evaluate()?)) diff --git a/rust/datafusion/src/physical_plan/expressions.rs b/rust/datafusion/src/physical_plan/expressions.rs index 4c9029e7195..084f8186c5e 100644 --- a/rust/datafusion/src/physical_plan/expressions.rs +++ b/rust/datafusion/src/physical_plan/expressions.rs @@ -17,10 +17,9 @@ //! Defines physical expressions that can evaluated at runtime during query execution +use std::convert::TryFrom; use std::fmt; -use std::rc::Rc; use std::sync::Arc; -use std::{cell::RefCell, convert::TryFrom}; use crate::error::{ExecutionError, Result}; use crate::logical_plan::Operator; @@ -50,6 +49,7 @@ use arrow::{ }, datatypes::Field, }; +use compute::can_cast_types; /// returns the name of the state pub fn format_state_name(name: &str, state_name: &str) -> String { @@ -162,10 +162,8 @@ impl AggregateExpr for Sum { vec![self.expr.clone()] } - fn create_accumulator(&self) -> Result>> { - Ok(Rc::new(RefCell::new(SumAccumulator::try_new( - &self.data_type, - )?))) + fn create_accumulator(&self) -> Result> { + Ok(Box::new(SumAccumulator::try_new(&self.data_type)?)) } } @@ -391,11 +389,11 @@ impl AggregateExpr for Avg { ]) } - fn create_accumulator(&self) -> Result>> { - Ok(Rc::new(RefCell::new(AvgAccumulator::try_new( + fn create_accumulator(&self) -> Result> { + Ok(Box::new(AvgAccumulator::try_new( // avg is f64 &DataType::Float64, - )?))) + )?)) } fn expressions(&self) -> Vec> { @@ -521,10 +519,8 @@ impl AggregateExpr for Max { vec![self.expr.clone()] } - fn create_accumulator(&self) -> Result>> { - Ok(Rc::new(RefCell::new(MaxAccumulator::try_new( - &self.data_type, - )?))) + fn create_accumulator(&self) -> Result> { + Ok(Box::new(MaxAccumulator::try_new(&self.data_type)?)) } } @@ -774,10 +770,8 @@ impl AggregateExpr for Min { vec![self.expr.clone()] } - fn create_accumulator(&self) -> Result>> { - Ok(Rc::new(RefCell::new(MinAccumulator::try_new( - &self.data_type, - )?))) + fn create_accumulator(&self) -> Result> { + Ok(Box::new(MinAccumulator::try_new(&self.data_type)?)) } } @@ -869,8 +863,8 @@ impl AggregateExpr for Count { vec![self.expr.clone()] } - fn create_accumulator(&self) -> Result>> { - Ok(Rc::new(RefCell::new(CountAccumulator::new()))) + fn create_accumulator(&self) -> Result> { + Ok(Box::new(CountAccumulator::new())) } } @@ -1532,7 +1526,10 @@ impl PhysicalExpr for CastExpr { } } -/// Returns a cast operation, if casting needed. +/// Returns a physical cast operation that casts `expr` to `cast_type` +/// if casting is needed. +/// +/// Note that such casts may lose type information pub fn cast( expr: Arc, input_schema: &Schema, @@ -1540,19 +1537,12 @@ pub fn cast( ) -> Result> { let expr_type = expr.data_type(input_schema)?; if expr_type == cast_type { - return Ok(expr.clone()); - } - if is_numeric(&expr_type) && (is_numeric(&cast_type) || cast_type == DataType::Utf8) { - Ok(Arc::new(CastExpr { expr, cast_type })) - } else if expr_type == DataType::Binary && cast_type == DataType::Utf8 { - Ok(Arc::new(CastExpr { expr, cast_type })) - } else if is_numeric(&expr_type) - && cast_type == DataType::Timestamp(TimeUnit::Nanosecond, None) - { + Ok(expr.clone()) + } else if can_cast_types(&expr_type, &cast_type) { Ok(Arc::new(CastExpr { expr, cast_type })) } else { Err(ExecutionError::General(format!( - "Invalid CAST from {:?} to {:?}", + "Unsupported CAST from {:?} to {:?}", expr_type, cast_type ))) } @@ -1992,9 +1982,10 @@ mod tests { #[test] fn invalid_cast() -> Result<()> { - let schema = Schema::new(vec![Field::new("a", DataType::Utf8, false)]); - let result = cast(col("a"), &schema, DataType::Int32); - result.expect_err("Invalid CAST from Utf8 to Int32"); + // Ensure a useful error happens at plan time if invalid casts are used + let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); + let result = cast(col("a"), &schema, DataType::LargeBinary); + result.expect_err("expected Invalid CAST"); Ok(()) } @@ -2476,13 +2467,12 @@ mod tests { batch: &RecordBatch, agg: Arc, ) -> Result { - let accum = agg.create_accumulator()?; + let mut accum = agg.create_accumulator()?; let expr = agg.expressions(); let values = expr .iter() .map(|e| e.evaluate(batch)) .collect::>>()?; - let mut accum = accum.borrow_mut(); accum.update_batch(&values)?; accum.evaluate() } diff --git a/rust/datafusion/src/physical_plan/hash_aggregate.rs b/rust/datafusion/src/physical_plan/hash_aggregate.rs index 5f4fe9876b7..2860c3babe1 100644 --- a/rust/datafusion/src/physical_plan/hash_aggregate.rs +++ b/rust/datafusion/src/physical_plan/hash_aggregate.rs @@ -18,8 +18,6 @@ //! Defines the execution plan for the hash aggregate operation use std::any::Any; -use std::cell::RefCell; -use std::rc::Rc; use std::sync::Arc; use crate::error::{ExecutionError, Result}; @@ -278,9 +276,8 @@ fn group_aggregate_batch( .map(|(_, (accumulator_set, indices))| { // 2.2 accumulator_set - .iter() - .zip(&aggr_input_values) .into_iter() + .zip(&aggr_input_values) .map(|(accumulator, aggr_array)| { ( accumulator, @@ -300,12 +297,10 @@ fn group_aggregate_batch( }) // 2.4 .map(|(accumulator, values)| match mode { - AggregateMode::Partial => { - accumulator.borrow_mut().update_batch(&values) - } + AggregateMode::Partial => accumulator.update_batch(&values), AggregateMode::Final => { // note: the aggregation here is over states, not values, thus the merge - accumulator.borrow_mut().merge_batch(&values) + accumulator.merge_batch(&values) } }) .collect::>() @@ -335,7 +330,7 @@ impl GroupedHashAggregateIterator { } } -type AccumulatorSet = Vec>>; +type AccumulatorSet = Vec>; impl Iterator for GroupedHashAggregateIterator { type Item = ArrowResult; @@ -490,7 +485,7 @@ impl HashAggregateIterator { fn aggregate_batch( mode: &AggregateMode, batch: &RecordBatch, - accumulators: &AccumulatorSet, + accumulators: &mut AccumulatorSet, expressions: &Vec>>, ) -> Result<()> { // 1.1 iterate accumulators and respective expressions together @@ -499,7 +494,7 @@ fn aggregate_batch( // 1.1 accumulators - .iter() + .into_iter() .zip(expressions) .map(|(accum, expr)| { // 1.2 @@ -510,8 +505,8 @@ fn aggregate_batch( // 1.3 match mode { - AggregateMode::Partial => accum.borrow_mut().update_batch(values), - AggregateMode::Final => accum.borrow_mut().merge_batch(values), + AggregateMode::Partial => accum.update_batch(values), + AggregateMode::Final => accum.merge_batch(values), } }) .collect::>() @@ -528,7 +523,7 @@ impl Iterator for HashAggregateIterator { // return single batch self.finished = true; - let accumulators = match create_accumulators(&self.aggr_expr) { + let mut accumulators = match create_accumulators(&self.aggr_expr) { Ok(e) => e, Err(e) => return Some(Err(ExecutionError::into_arrow_external_error(e))), }; @@ -547,7 +542,7 @@ impl Iterator for HashAggregateIterator { .as_mut() .into_iter() .map(|batch| { - aggregate_batch(&mode, &batch?, &accumulators, &expressions) + aggregate_batch(&mode, &batch?, &mut accumulators, &expressions) .map_err(ExecutionError::into_arrow_external_error) }) .collect::>() @@ -655,7 +650,7 @@ fn finalize_aggregation( // build the vector of states let a = accumulators .iter() - .map(|accumulator| accumulator.borrow_mut().state()) + .map(|accumulator| accumulator.state()) .map(|value| { value.and_then(|e| { Ok(e.iter().map(|v| v.to_array()).collect::>()) @@ -668,12 +663,7 @@ fn finalize_aggregation( // merge the state to the final value accumulators .iter() - .map(|accumulator| { - accumulator - .borrow_mut() - .evaluate() - .and_then(|v| Ok(v.to_array())) - }) + .map(|accumulator| accumulator.evaluate().and_then(|v| Ok(v.to_array()))) .collect::>>() } } @@ -820,7 +810,7 @@ mod tests { .unwrap(); assert_eq!(*sums, Float64Array::from(vec![2.0, 7.0, 11.0])); - let merge = Arc::new(MergeExec::new(partial_aggregate, 2)); + let merge = Arc::new(MergeExec::new(partial_aggregate)); let final_group: Vec> = (0..groups.len()).map(|i| col(&groups[i].1)).collect(); diff --git a/rust/datafusion/src/physical_plan/limit.rs b/rust/datafusion/src/physical_plan/limit.rs index 8c0e563b031..753cbf7bdbf 100644 --- a/rust/datafusion/src/physical_plan/limit.rs +++ b/rust/datafusion/src/physical_plan/limit.rs @@ -243,8 +243,7 @@ mod tests { // input should have 4 partitions assert_eq!(csv.output_partitioning().partition_count(), num_partitions); - let limit = - GlobalLimitExec::new(Arc::new(MergeExec::new(Arc::new(csv), 2)), 7, 2); + let limit = GlobalLimitExec::new(Arc::new(MergeExec::new(Arc::new(csv))), 7, 2); // the result should contain 4 batches (one per input partition) let iter = limit.execute(0).await?; diff --git a/rust/datafusion/src/physical_plan/merge.rs b/rust/datafusion/src/physical_plan/merge.rs index 02243bc7cc6..7ce737c9910 100644 --- a/rust/datafusion/src/physical_plan/merge.rs +++ b/rust/datafusion/src/physical_plan/merge.rs @@ -32,7 +32,7 @@ use arrow::record_batch::RecordBatch; use super::SendableRecordBatchReader; use async_trait::async_trait; -use tokio::task::{self, JoinHandle}; +use tokio; /// Merge execution plan executes partitions in parallel and combines them into a single /// partition. No guarantees are made about the order of the resulting partition. @@ -40,17 +40,12 @@ use tokio::task::{self, JoinHandle}; pub struct MergeExec { /// Input execution plan input: Arc, - /// Maximum number of concurrent threads - concurrency: usize, } impl MergeExec { /// Create a new MergeExec - pub fn new(input: Arc, max_concurrency: usize) -> Self { - MergeExec { - input, - concurrency: max_concurrency, - } + pub fn new(input: Arc) -> Self { + MergeExec { input } } } @@ -79,10 +74,7 @@ impl ExecutionPlan for MergeExec { children: Vec>, ) -> Result> { match children.len() { - 1 => Ok(Arc::new(MergeExec::new( - children[0].clone(), - self.concurrency, - ))), + 1 => Ok(Arc::new(MergeExec::new(children[0].clone()))), _ => Err(ExecutionError::General( "MergeExec wrong number of children".to_string(), )), @@ -108,35 +100,23 @@ impl ExecutionPlan for MergeExec { self.input.execute(0).await } _ => { - let partitions_per_thread = (input_partitions / self.concurrency).max(1); - let range: Vec = (0..input_partitions).collect(); - let chunks = range.chunks(partitions_per_thread); - - let mut tasks = vec![]; - for chunk in chunks { - let chunk = chunk.to_vec(); - let input = self.input.clone(); - let task: JoinHandle>>> = - task::spawn(async move { - let mut batches: Vec> = vec![]; - for partition in chunk { - let it = input.execute(partition).await?; - common::collect(it).iter().for_each(|b| { - b.iter() - .for_each(|b| batches.push(Arc::new(b.clone()))) - }); - } - Ok(batches) - }); - tasks.push(task); - } + let tasks = (0..input_partitions) + .map(|part_i| { + let input = self.input.clone(); + tokio::spawn(async move { + let it = input.execute(part_i).await?; + common::collect(it) + }) + }) + // this collect *is needed* so that the join below can + // switch between tasks + .collect::>(); - // combine the results from each thread let mut combined_results: Vec> = vec![]; for task in tasks { let result = task.await.unwrap()?; for batch in &result { - combined_results.push(batch.clone()); + combined_results.push(Arc::new(batch.clone())); } } @@ -171,7 +151,7 @@ mod tests { // input should have 4 partitions assert_eq!(csv.output_partitioning().partition_count(), num_partitions); - let merge = MergeExec::new(Arc::new(csv), 2); + let merge = MergeExec::new(Arc::new(csv)); // output of MergeExec should have a single partition assert_eq!(merge.output_partitioning().partition_count(), 1); diff --git a/rust/datafusion/src/physical_plan/mod.rs b/rust/datafusion/src/physical_plan/mod.rs index ac33c67f6ac..1d6c46afe09 100644 --- a/rust/datafusion/src/physical_plan/mod.rs +++ b/rust/datafusion/src/physical_plan/mod.rs @@ -18,9 +18,7 @@ //! Traits for physical query plan, supporting parallel execution for partitioned relations. use std::any::Any; -use std::cell::RefCell; use std::fmt::{Debug, Display}; -use std::rc::Rc; use std::sync::Arc; use crate::execution::context::ExecutionContextState; @@ -122,7 +120,7 @@ pub trait AggregateExpr: Send + Sync + Debug { /// the accumulator used to accumulate values from the expressions. /// the accumulator expects the same number of arguments as `expressions` and must /// return states with the same description as `state_fields` - fn create_accumulator(&self) -> Result>>; + fn create_accumulator(&self) -> Result>; /// the fields that encapsulate the Accumulator's state /// the number of fields here equals the number of states that the accumulator contains diff --git a/rust/datafusion/src/physical_plan/planner.rs b/rust/datafusion/src/physical_plan/planner.rs index bdaf79c7b2c..c4ae2dc6853 100644 --- a/rust/datafusion/src/physical_plan/planner.rs +++ b/rust/datafusion/src/physical_plan/planner.rs @@ -117,10 +117,7 @@ impl DefaultPhysicalPlanner { if child.output_partitioning().partition_count() == 1 { child.clone() } else { - Arc::new(MergeExec::new( - child.clone(), - ctx_state.config.concurrency, - )) + Arc::new(MergeExec::new(child.clone())) } }) .collect(), diff --git a/rust/datafusion/src/physical_plan/sort.rs b/rust/datafusion/src/physical_plan/sort.rs index 3ddfa183117..7c00cc5cb50 100644 --- a/rust/datafusion/src/physical_plan/sort.rs +++ b/rust/datafusion/src/physical_plan/sort.rs @@ -208,7 +208,7 @@ mod tests { options: SortOptions::default(), }, ], - Arc::new(MergeExec::new(Arc::new(csv), 2)), + Arc::new(MergeExec::new(Arc::new(csv))), 2, )?); diff --git a/rust/datafusion/src/physical_plan/udaf.rs b/rust/datafusion/src/physical_plan/udaf.rs index 933fd237c65..db86e1447ab 100644 --- a/rust/datafusion/src/physical_plan/udaf.rs +++ b/rust/datafusion/src/physical_plan/udaf.rs @@ -18,7 +18,7 @@ //! This module contains functions and structs supporting user-defined aggregate functions. use fmt::{Debug, Formatter}; -use std::{cell::RefCell, fmt, rc::Rc}; +use std::fmt; use arrow::{ datatypes::Field, @@ -150,7 +150,7 @@ impl AggregateExpr for AggregateFunctionExpr { Ok(Field::new(&self.name, self.data_type.clone(), true)) } - fn create_accumulator(&self) -> Result>> { + fn create_accumulator(&self) -> Result> { (self.fun.accumulator)() } } diff --git a/rust/datafusion/tests/sql.rs b/rust/datafusion/tests/sql.rs index 52027a4080b..7322b63994d 100644 --- a/rust/datafusion/tests/sql.rs +++ b/rust/datafusion/tests/sql.rs @@ -15,14 +15,15 @@ // specific language governing permissions and limitations // under the License. +use std::convert::TryFrom; use std::env; use std::sync::Arc; extern crate arrow; extern crate datafusion; -use arrow::record_batch::RecordBatch; use arrow::{array::*, datatypes::TimeUnit}; +use arrow::{datatypes::Int64Type, record_batch::RecordBatch}; use arrow::{ datatypes::{DataType, Field, Schema, SchemaRef}, util::display::array_value_to_string, @@ -128,6 +129,100 @@ async fn parquet_single_nan_schema() { } } +#[tokio::test] +async fn parquet_list_columns() { + let mut ctx = ExecutionContext::new(); + let testdata = env::var("PARQUET_TEST_DATA").expect("PARQUET_TEST_DATA not defined"); + ctx.register_parquet( + "list_columns", + &format!("{}/list_columns.parquet", testdata), + ) + .unwrap(); + + let schema = Arc::new(Schema::new(vec![ + Field::new( + "int64_list", + DataType::List(Box::new(DataType::Int64)), + true, + ), + Field::new("utf8_list", DataType::List(Box::new(DataType::Utf8)), true), + ])); + + let sql = "SELECT int64_list, utf8_list FROM list_columns"; + let plan = ctx.create_logical_plan(&sql).unwrap(); + let plan = ctx.optimize(&plan).unwrap(); + let plan = ctx.create_physical_plan(&plan).unwrap(); + let results = ctx.collect(plan).await.unwrap(); + + // int64_list utf8_list + // 0 [1, 2, 3] [abc, efg, hij] + // 1 [None, 1] None + // 2 [4] [efg, None, hij, xyz] + + assert_eq!(1, results.len()); + let batch = &results[0]; + assert_eq!(3, batch.num_rows()); + assert_eq!(2, batch.num_columns()); + assert_eq!(schema, batch.schema()); + + let int_list_array = batch + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + let utf8_list_array = batch + .column(1) + .as_any() + .downcast_ref::() + .unwrap(); + + assert_eq!( + int_list_array + .value(0) + .as_any() + .downcast_ref::>() + .unwrap(), + &PrimitiveArray::::from(vec![Some(1), Some(2), Some(3),]) + ); + + assert_eq!( + utf8_list_array + .value(0) + .as_any() + .downcast_ref::() + .unwrap(), + &StringArray::try_from(vec![Some("abc"), Some("efg"), Some("hij"),]).unwrap() + ); + + assert_eq!( + int_list_array + .value(1) + .as_any() + .downcast_ref::>() + .unwrap(), + &PrimitiveArray::::from(vec![None, Some(1),]) + ); + + assert!(utf8_list_array.is_null(1)); + + assert_eq!( + int_list_array + .value(2) + .as_any() + .downcast_ref::>() + .unwrap(), + &PrimitiveArray::::from(vec![Some(4),]) + ); + + let result = utf8_list_array.value(2); + let result = result.as_any().downcast_ref::().unwrap(); + + assert_eq!(result.value(0), "efg"); + assert!(result.is_null(1)); + assert_eq!(result.value(2), "hij"); + assert_eq!(result.value(3), "xyz"); +} + #[tokio::test] async fn csv_count_star() -> Result<()> { let mut ctx = ExecutionContext::new(); diff --git a/rust/parquet/Cargo.toml b/rust/parquet/Cargo.toml index 50d7c34d341..60e43c93ffa 100644 --- a/rust/parquet/Cargo.toml +++ b/rust/parquet/Cargo.toml @@ -40,6 +40,7 @@ zstd = { version = "0.5", optional = true } chrono = "0.4" num-bigint = "0.3" arrow = { path = "../arrow", version = "2.0.0-SNAPSHOT", optional = true } +base64 = { version = "*", optional = true } [dev-dependencies] rand = "0.7" @@ -52,4 +53,4 @@ arrow = { path = "../arrow", version = "2.0.0-SNAPSHOT" } serde_json = { version = "1.0", features = ["preserve_order"] } [features] -default = ["arrow", "snap", "brotli", "flate2", "lz4", "zstd"] +default = ["arrow", "snap", "brotli", "flate2", "lz4", "zstd", "base64"] diff --git a/rust/parquet/src/arrow/array_reader.rs b/rust/parquet/src/arrow/array_reader.rs index 14bf7d287a3..77990cc1d86 100644 --- a/rust/parquet/src/arrow/array_reader.rs +++ b/rust/parquet/src/arrow/array_reader.rs @@ -25,19 +25,42 @@ use std::sync::Arc; use std::vec::Vec; use arrow::array::{ - ArrayDataBuilder, ArrayDataRef, ArrayRef, BooleanBufferBuilder, BufferBuilderTrait, - Int16BufferBuilder, StructArray, + Array, ArrayData, ArrayDataBuilder, ArrayDataRef, ArrayRef, BinaryArray, + BinaryBuilder, BooleanBufferBuilder, BufferBuilderTrait, FixedSizeBinaryArray, + FixedSizeBinaryBuilder, GenericListArray, Int16BufferBuilder, ListBuilder, + OffsetSizeTrait, PrimitiveArray, PrimitiveArrayOps, PrimitiveBuilder, StringArray, + StringBuilder, StructArray, }; use arrow::buffer::{Buffer, MutableBuffer}; -use arrow::datatypes::{DataType as ArrowType, DateUnit, Field, IntervalUnit, TimeUnit}; +use arrow::datatypes::{ + BooleanType as ArrowBooleanType, DataType as ArrowType, + Date32Type as ArrowDate32Type, Date64Type as ArrowDate64Type, + DurationMicrosecondType as ArrowDurationMicrosecondType, + DurationMillisecondType as ArrowDurationMillisecondType, + DurationNanosecondType as ArrowDurationNanosecondType, + DurationSecondType as ArrowDurationSecondType, Field, + Float32Type as ArrowFloat32Type, Float64Type as ArrowFloat64Type, + Int16Type as ArrowInt16Type, Int32Type as ArrowInt32Type, + Int64Type as ArrowInt64Type, Int8Type as ArrowInt8Type, Schema, + Time32MillisecondType as ArrowTime32MillisecondType, + Time32SecondType as ArrowTime32SecondType, + Time64MicrosecondType as ArrowTime64MicrosecondType, + Time64NanosecondType as ArrowTime64NanosecondType, TimeUnit as ArrowTimeUnit, + TimestampMicrosecondType as ArrowTimestampMicrosecondType, + TimestampMillisecondType as ArrowTimestampMillisecondType, + TimestampNanosecondType as ArrowTimestampNanosecondType, + TimestampSecondType as ArrowTimestampSecondType, ToByteSlice, + UInt16Type as ArrowUInt16Type, UInt32Type as ArrowUInt32Type, + UInt64Type as ArrowUInt64Type, UInt8Type as ArrowUInt8Type, +}; +use arrow::util::bit_util; use crate::arrow::converter::{ BinaryArrayConverter, BinaryConverter, BoolConverter, BooleanArrayConverter, - Converter, Date32Converter, FixedLenBinaryConverter, FixedSizeArrayConverter, - Float32Converter, Float64Converter, Int16Converter, Int32Converter, Int64Converter, - Int8Converter, Int96ArrayConverter, Int96Converter, TimestampMicrosecondConverter, - TimestampMillisecondConverter, UInt16Converter, UInt32Converter, UInt64Converter, - UInt8Converter, Utf8ArrayConverter, Utf8Converter, + Converter, FixedLenBinaryConverter, FixedSizeArrayConverter, Float32Converter, + Float64Converter, Int32Converter, Int64Converter, Int96ArrayConverter, + Int96Converter, LargeBinaryArrayConverter, LargeBinaryConverter, + LargeUtf8ArrayConverter, LargeUtf8Converter, Utf8ArrayConverter, Utf8Converter, }; use crate::arrow::record_reader::RecordReader; use crate::arrow::schema::parquet_to_arrow_field; @@ -77,6 +100,97 @@ pub trait ArrayReader { fn get_rep_levels(&self) -> Option<&[i16]>; } +/// A NullArrayReader reads Parquet columns stored as null int32s with an Arrow +/// NullArray type. +pub struct NullArrayReader { + data_type: ArrowType, + pages: Box, + def_levels_buffer: Option, + rep_levels_buffer: Option, + column_desc: ColumnDescPtr, + record_reader: RecordReader, + _type_marker: PhantomData, +} + +impl NullArrayReader { + /// Construct null array reader. + pub fn new( + mut pages: Box, + column_desc: ColumnDescPtr, + ) -> Result { + let mut record_reader = RecordReader::::new(column_desc.clone()); + if let Some(page_reader) = pages.next() { + record_reader.set_page_reader(page_reader?)?; + } + + Ok(Self { + data_type: ArrowType::Null, + pages, + def_levels_buffer: None, + rep_levels_buffer: None, + column_desc, + record_reader, + _type_marker: PhantomData, + }) + } +} + +/// Implementation of primitive array reader. +impl ArrayReader for NullArrayReader { + fn as_any(&self) -> &dyn Any { + self + } + + /// Returns data type of primitive array. + fn get_data_type(&self) -> &ArrowType { + &self.data_type + } + + /// Reads at most `batch_size` records into array. + fn next_batch(&mut self, batch_size: usize) -> Result { + let mut records_read = 0usize; + while records_read < batch_size { + let records_to_read = batch_size - records_read; + + // NB can be 0 if at end of page + let records_read_once = self.record_reader.read_records(records_to_read)?; + records_read += records_read_once; + + // Record reader exhausted + if records_read_once < records_to_read { + if let Some(page_reader) = self.pages.next() { + // Read from new page reader + self.record_reader.set_page_reader(page_reader?)?; + } else { + // Page reader also exhausted + break; + } + } + } + + // convert to arrays + let array = arrow::array::NullArray::new(records_read); + + // save definition and repetition buffers + self.def_levels_buffer = self.record_reader.consume_def_levels()?; + self.rep_levels_buffer = self.record_reader.consume_rep_levels()?; + self.record_reader.reset(); + Ok(Arc::new(array)) + } + + fn get_def_levels(&self) -> Option<&[i16]> { + self.def_levels_buffer + .as_ref() + .map(|buf| unsafe { buf.typed_data() }) + } + + fn get_rep_levels(&self) -> Option<&[i16]> { + self.rep_levels_buffer + .as_ref() + .map(|buf| unsafe { buf.typed_data() }) + } +} + /// Primitive array readers are leaves of array reader tree. They accept page iterator /// and read them into primitive arrays. pub struct PrimitiveArrayReader { @@ -94,10 +208,15 @@ impl PrimitiveArrayReader { pub fn new( mut pages: Box, column_desc: ColumnDescPtr, + arrow_type: Option, ) -> Result { - let data_type = parquet_to_arrow_field(column_desc.as_ref())? - .data_type() - .clone(); + // Check if Arrow type is specified, else create it from Parquet type + let data_type = match arrow_type { + Some(t) => t, + None => parquet_to_arrow_field(column_desc.as_ref())? + .data_type() + .clone(), + }; let mut record_reader = RecordReader::::new(column_desc.clone()); if let Some(page_reader) = pages.next() { @@ -149,75 +268,40 @@ impl ArrayReader for PrimitiveArrayReader { } } - // convert to arrays + // Convert to arrays by using the Parquet phyisical type. + // The physical types are then cast to Arrow types if necessary let array = - match (&self.data_type, T::get_physical_type()) { - (ArrowType::Boolean, PhysicalType::BOOLEAN) => { - BoolConverter::new(BooleanArrayConverter {}) - .convert(self.record_reader.cast::()) - } - (ArrowType::Int8, PhysicalType::INT32) => { - Int8Converter::new().convert(self.record_reader.cast::()) - } - (ArrowType::Int16, PhysicalType::INT32) => { - Int16Converter::new().convert(self.record_reader.cast::()) - } - (ArrowType::Int32, PhysicalType::INT32) => { + match T::get_physical_type() { + PhysicalType::BOOLEAN => BoolConverter::new(BooleanArrayConverter {}) + .convert(self.record_reader.cast::()), + PhysicalType::INT32 => { + // TODO: the cast is a no-op, but we should remove it Int32Converter::new().convert(self.record_reader.cast::()) } - (ArrowType::UInt8, PhysicalType::INT32) => { - UInt8Converter::new().convert(self.record_reader.cast::()) - } - (ArrowType::UInt16, PhysicalType::INT32) => { - UInt16Converter::new().convert(self.record_reader.cast::()) - } - (ArrowType::UInt32, PhysicalType::INT32) => { - UInt32Converter::new().convert(self.record_reader.cast::()) - } - (ArrowType::Int64, PhysicalType::INT64) => { + PhysicalType::INT64 => { Int64Converter::new().convert(self.record_reader.cast::()) } - (ArrowType::UInt64, PhysicalType::INT64) => { - UInt64Converter::new().convert(self.record_reader.cast::()) - } - (ArrowType::Float32, PhysicalType::FLOAT) => Float32Converter::new() + PhysicalType::FLOAT => Float32Converter::new() .convert(self.record_reader.cast::()), - (ArrowType::Float64, PhysicalType::DOUBLE) => Float64Converter::new() + PhysicalType::DOUBLE => Float64Converter::new() .convert(self.record_reader.cast::()), - (ArrowType::Timestamp(unit, _), PhysicalType::INT64) => match unit { - TimeUnit::Millisecond => TimestampMillisecondConverter::new() - .convert(self.record_reader.cast::()), - TimeUnit::Microsecond => TimestampMicrosecondConverter::new() - .convert(self.record_reader.cast::()), - _ => Err(general_err!("No conversion from parquet type to arrow type for timestamp with unit {:?}", unit)), - }, - (ArrowType::Date32(unit), PhysicalType::INT32) => match unit { - DateUnit::Day => Date32Converter::new() - .convert(self.record_reader.cast::()), - _ => Err(general_err!("No conversion from parquet type to arrow type for date with unit {:?}", unit)), - } - (ArrowType::Time32(_), PhysicalType::INT32) => { - UInt32Converter::new().convert(self.record_reader.cast::()) - } - (ArrowType::Time64(_), PhysicalType::INT64) => { - UInt64Converter::new().convert(self.record_reader.cast::()) + PhysicalType::INT96 + | PhysicalType::BYTE_ARRAY + | PhysicalType::FIXED_LEN_BYTE_ARRAY => { + // TODO: we could use unreachable!() as this is a private fn + Err(general_err!( + "Cannot read primitive array with a complex physical type" + )) } - (ArrowType::Interval(IntervalUnit::YearMonth), PhysicalType::INT32) => { - UInt32Converter::new().convert(self.record_reader.cast::()) - } - (ArrowType::Interval(IntervalUnit::DayTime), PhysicalType::INT64) => { - UInt64Converter::new().convert(self.record_reader.cast::()) - } - (ArrowType::Duration(_), PhysicalType::INT64) => { - UInt64Converter::new().convert(self.record_reader.cast::()) - } - (arrow_type, physical_type) => Err(general_err!( - "Reading {:?} type from parquet {:?} is not supported yet.", - arrow_type, - physical_type - )), }?; + // cast to Arrow type + // TODO: we need to check if it's fine for this to be fallible. + // My assumption is that we can't get to an illegal cast as we can only + // generate types that are supported, because we'd have gotten them from + // the metadata which was written to the Parquet sink + let array = arrow::compute::cast(&array, self.get_data_type())?; + // save definition and repetition buffers self.def_levels_buffer = self.record_reader.consume_def_levels()?; self.rep_levels_buffer = self.record_reader.consume_rep_levels()?; @@ -369,7 +453,13 @@ where data_buffer.into_iter().map(Some).collect() }; - self.converter.convert(data) + // TODO: I did this quickly without thinking through it, there might be edge cases to consider + let array = self.converter.convert(data)?; + + Ok(match self.data_type { + ArrowType::Dictionary(_, _) => arrow::compute::cast(&array, &self.data_type)?, + _ => array, + }) } fn get_def_levels(&self) -> Option<&[i16]> { @@ -390,10 +480,14 @@ where pages: Box, column_desc: ColumnDescPtr, converter: C, + arrow_type: Option, ) -> Result { - let data_type = parquet_to_arrow_field(column_desc.as_ref())? - .data_type() - .clone(); + let data_type = match arrow_type { + Some(t) => t, + None => parquet_to_arrow_field(column_desc.as_ref())? + .data_type() + .clone(), + }; Ok(Self { data_type, @@ -420,6 +514,400 @@ where } } +/// Implementation of list array reader. +pub struct ListArrayReader { + item_reader: Box, + data_type: ArrowType, + item_type: ArrowType, + list_def_level: i16, + list_rep_level: i16, + def_level_buffer: Option, + rep_level_buffer: Option, + _marker: PhantomData, +} + +impl ListArrayReader { + /// Construct list array reader. + pub fn new( + item_reader: Box, + data_type: ArrowType, + item_type: ArrowType, + def_level: i16, + rep_level: i16, + ) -> Self { + Self { + item_reader, + data_type, + item_type, + list_def_level: def_level, + list_rep_level: rep_level, + def_level_buffer: None, + rep_level_buffer: None, + _marker: PhantomData, + } + } +} + +macro_rules! build_empty_list_array_with_primitive_items { + ($item_type:ident) => {{ + let values_builder = PrimitiveBuilder::<$item_type>::new(0); + let mut builder = ListBuilder::new(values_builder); + let empty_list_array = builder.finish(); + Ok(Arc::new(empty_list_array)) + }}; +} + +macro_rules! build_empty_list_array_with_non_primitive_items { + ($builder:ident) => {{ + let values_builder = $builder::new(0); + let mut builder = ListBuilder::new(values_builder); + let empty_list_array = builder.finish(); + Ok(Arc::new(empty_list_array)) + }}; +} + +fn build_empty_list_array(item_type: ArrowType) -> Result { + match item_type { + ArrowType::UInt8 => build_empty_list_array_with_primitive_items!(ArrowUInt8Type), + ArrowType::UInt16 => { + build_empty_list_array_with_primitive_items!(ArrowUInt16Type) + } + ArrowType::UInt32 => { + build_empty_list_array_with_primitive_items!(ArrowUInt32Type) + } + ArrowType::UInt64 => { + build_empty_list_array_with_primitive_items!(ArrowUInt64Type) + } + ArrowType::Int8 => build_empty_list_array_with_primitive_items!(ArrowInt8Type), + ArrowType::Int16 => build_empty_list_array_with_primitive_items!(ArrowInt16Type), + ArrowType::Int32 => build_empty_list_array_with_primitive_items!(ArrowInt32Type), + ArrowType::Int64 => build_empty_list_array_with_primitive_items!(ArrowInt64Type), + ArrowType::Float32 => { + build_empty_list_array_with_primitive_items!(ArrowFloat32Type) + } + ArrowType::Float64 => { + build_empty_list_array_with_primitive_items!(ArrowFloat64Type) + } + ArrowType::Boolean => { + build_empty_list_array_with_primitive_items!(ArrowBooleanType) + } + ArrowType::Date32(_) => { + build_empty_list_array_with_primitive_items!(ArrowDate32Type) + } + ArrowType::Date64(_) => { + build_empty_list_array_with_primitive_items!(ArrowDate64Type) + } + ArrowType::Time32(ArrowTimeUnit::Second) => { + build_empty_list_array_with_primitive_items!(ArrowTime32SecondType) + } + ArrowType::Time32(ArrowTimeUnit::Millisecond) => { + build_empty_list_array_with_primitive_items!(ArrowTime32MillisecondType) + } + ArrowType::Time64(ArrowTimeUnit::Microsecond) => { + build_empty_list_array_with_primitive_items!(ArrowTime64MicrosecondType) + } + ArrowType::Time64(ArrowTimeUnit::Nanosecond) => { + build_empty_list_array_with_primitive_items!(ArrowTime64NanosecondType) + } + ArrowType::Duration(ArrowTimeUnit::Second) => { + build_empty_list_array_with_primitive_items!(ArrowDurationSecondType) + } + ArrowType::Duration(ArrowTimeUnit::Millisecond) => { + build_empty_list_array_with_primitive_items!(ArrowDurationMillisecondType) + } + ArrowType::Duration(ArrowTimeUnit::Microsecond) => { + build_empty_list_array_with_primitive_items!(ArrowDurationMicrosecondType) + } + ArrowType::Duration(ArrowTimeUnit::Nanosecond) => { + build_empty_list_array_with_primitive_items!(ArrowDurationNanosecondType) + } + ArrowType::Timestamp(ArrowTimeUnit::Second, _) => { + build_empty_list_array_with_primitive_items!(ArrowTimestampSecondType) + } + ArrowType::Timestamp(ArrowTimeUnit::Millisecond, _) => { + build_empty_list_array_with_primitive_items!(ArrowTimestampMillisecondType) + } + ArrowType::Timestamp(ArrowTimeUnit::Microsecond, _) => { + build_empty_list_array_with_primitive_items!(ArrowTimestampMicrosecondType) + } + ArrowType::Timestamp(ArrowTimeUnit::Nanosecond, _) => { + build_empty_list_array_with_primitive_items!(ArrowTimestampNanosecondType) + } + ArrowType::Utf8 => { + build_empty_list_array_with_non_primitive_items!(StringBuilder) + } + ArrowType::Binary => { + build_empty_list_array_with_non_primitive_items!(BinaryBuilder) + } + _ => Err(ParquetError::General(format!( + "ListArray of type List({:?}) is not supported by array_reader", + item_type + ))), + } +} + +macro_rules! remove_primitive_array_indices { + ($arr: expr, $item_type:ty, $indices:expr) => {{ + let array_data = match $arr.as_any().downcast_ref::>() { + Some(a) => a, + _ => return Err(ParquetError::General(format!("Error generating next batch for ListArray: {:?} cannot be downcast to PrimitiveArray", $arr))), + }; + let mut builder = PrimitiveBuilder::<$item_type>::new($arr.len()); + for i in 0..array_data.len() { + if !$indices.contains(&i) { + if array_data.is_null(i) { + builder.append_null()?; + } else { + builder.append_value(array_data.value(i))?; + } + } + } + Ok(Arc::new(builder.finish())) + }}; +} + +macro_rules! remove_array_indices_custom_builder { + ($arr: expr, $array_type:ty, $item_builder:ident, $indices:expr) => {{ + let array_data = match $arr.as_any().downcast_ref::<$array_type>() { + Some(a) => a, + _ => return Err(ParquetError::General(format!("Error generating next batch for ListArray: {:?} cannot be downcast to PrimitiveArray", $arr))), + }; + let mut builder = $item_builder::new(array_data.len()); + + for i in 0..array_data.len() { + if !$indices.contains(&i) { + if array_data.is_null(i) { + builder.append_null()?; + } else { + builder.append_value(array_data.value(i))?; + } + } + } + Ok(Arc::new(builder.finish())) + }}; +} + +macro_rules! remove_fixed_size_binary_array_indices { + ($arr: expr, $array_type:ty, $item_builder:ident, $indices:expr, $len:expr) => {{ + let array_data = match $arr.as_any().downcast_ref::<$array_type>() { + Some(a) => a, + _ => return Err(ParquetError::General(format!("Error generating next batch for ListArray: {:?} cannot be downcast to PrimitiveArray", $arr))), + }; + let mut builder = FixedSizeBinaryBuilder::new(array_data.len(), $len); + for i in 0..array_data.len() { + if !$indices.contains(&i) { + if array_data.is_null(i) { + builder.append_null()?; + } else { + builder.append_value(array_data.value(i))?; + } + } + } + Ok(Arc::new(builder.finish())) + }}; +} + +fn remove_indices( + arr: ArrayRef, + item_type: ArrowType, + indices: Vec, +) -> Result { + match item_type { + ArrowType::UInt8 => remove_primitive_array_indices!(arr, ArrowUInt8Type, indices), + ArrowType::UInt16 => { + remove_primitive_array_indices!(arr, ArrowUInt16Type, indices) + } + ArrowType::UInt32 => { + remove_primitive_array_indices!(arr, ArrowUInt32Type, indices) + } + ArrowType::UInt64 => { + remove_primitive_array_indices!(arr, ArrowUInt64Type, indices) + } + ArrowType::Int8 => remove_primitive_array_indices!(arr, ArrowInt8Type, indices), + ArrowType::Int16 => remove_primitive_array_indices!(arr, ArrowInt16Type, indices), + ArrowType::Int32 => remove_primitive_array_indices!(arr, ArrowInt32Type, indices), + ArrowType::Int64 => remove_primitive_array_indices!(arr, ArrowInt64Type, indices), + ArrowType::Float32 => { + remove_primitive_array_indices!(arr, ArrowFloat32Type, indices) + } + ArrowType::Float64 => { + remove_primitive_array_indices!(arr, ArrowFloat64Type, indices) + } + ArrowType::Boolean => { + remove_primitive_array_indices!(arr, ArrowBooleanType, indices) + } + ArrowType::Date32(_) => { + remove_primitive_array_indices!(arr, ArrowDate32Type, indices) + } + ArrowType::Date64(_) => { + remove_primitive_array_indices!(arr, ArrowDate64Type, indices) + } + ArrowType::Time32(ArrowTimeUnit::Second) => { + remove_primitive_array_indices!(arr, ArrowTime32SecondType, indices) + } + ArrowType::Time32(ArrowTimeUnit::Millisecond) => { + remove_primitive_array_indices!(arr, ArrowTime32MillisecondType, indices) + } + ArrowType::Time64(ArrowTimeUnit::Microsecond) => { + remove_primitive_array_indices!(arr, ArrowTime64MicrosecondType, indices) + } + ArrowType::Time64(ArrowTimeUnit::Nanosecond) => { + remove_primitive_array_indices!(arr, ArrowTime64NanosecondType, indices) + } + ArrowType::Duration(ArrowTimeUnit::Second) => { + remove_primitive_array_indices!(arr, ArrowDurationSecondType, indices) + } + ArrowType::Duration(ArrowTimeUnit::Millisecond) => { + remove_primitive_array_indices!(arr, ArrowDurationMillisecondType, indices) + } + ArrowType::Duration(ArrowTimeUnit::Microsecond) => { + remove_primitive_array_indices!(arr, ArrowDurationMicrosecondType, indices) + } + ArrowType::Duration(ArrowTimeUnit::Nanosecond) => { + remove_primitive_array_indices!(arr, ArrowDurationNanosecondType, indices) + } + ArrowType::Timestamp(ArrowTimeUnit::Second, _) => { + remove_primitive_array_indices!(arr, ArrowTimestampSecondType, indices) + } + ArrowType::Timestamp(ArrowTimeUnit::Millisecond, _) => { + remove_primitive_array_indices!(arr, ArrowTimestampMillisecondType, indices) + } + ArrowType::Timestamp(ArrowTimeUnit::Microsecond, _) => { + remove_primitive_array_indices!(arr, ArrowTimestampMicrosecondType, indices) + } + ArrowType::Timestamp(ArrowTimeUnit::Nanosecond, _) => { + remove_primitive_array_indices!(arr, ArrowTimestampNanosecondType, indices) + } + ArrowType::Utf8 => { + remove_array_indices_custom_builder!(arr, StringArray, StringBuilder, indices) + } + ArrowType::Binary => { + remove_array_indices_custom_builder!(arr, BinaryArray, BinaryBuilder, indices) + } + ArrowType::FixedSizeBinary(size) => remove_fixed_size_binary_array_indices!( + arr, + FixedSizeBinaryArray, + FixedSizeBinaryBuilder, + indices, + size + ), + _ => Err(ParquetError::General(format!( + "ListArray of type List({:?}) is not supported by array_reader", + item_type + ))), + } +} + +/// Implementation of ListArrayReader. Nested lists and lists of structs are not yet supported. +impl ArrayReader for ListArrayReader { + fn as_any(&self) -> &dyn Any { + self + } + + /// Returns data type. + /// This must be a List. + fn get_data_type(&self) -> &ArrowType { + &self.data_type + } + + fn next_batch(&mut self, batch_size: usize) -> Result { + let next_batch_array = self.item_reader.next_batch(batch_size)?; + let item_type = self.item_reader.get_data_type().clone(); + + if next_batch_array.len() == 0 { + return build_empty_list_array(item_type); + } + let def_levels = self + .item_reader + .get_def_levels() + .ok_or_else(|| ArrowError("item_reader def levels are None.".to_string()))?; + let rep_levels = self + .item_reader + .get_rep_levels() + .ok_or_else(|| ArrowError("item_reader rep levels are None.".to_string()))?; + + if !((def_levels.len() == rep_levels.len()) + && (rep_levels.len() == next_batch_array.len())) + { + return Err(ArrowError( + "Expected item_reader def_levels and rep_levels to be same length as batch".to_string(), + )); + } + + // Need to remove from the values array the nulls that represent null lists rather than null items + // null lists have def_level = 0 + let mut null_list_indices: Vec = Vec::new(); + for i in 0..def_levels.len() { + if def_levels[i] == 0 { + null_list_indices.push(i); + } + } + let batch_values = match null_list_indices.len() { + 0 => next_batch_array.clone(), + _ => remove_indices(next_batch_array.clone(), item_type, null_list_indices)?, + }; + + // null list has def_level = 0 + // empty list has def_level = 1 + // null item in a list has def_level = 2 + // non-null item has def_level = 3 + // first item in each list has rep_level = 0, subsequent items have rep_level = 1 + + let mut offsets: Vec = Vec::new(); + let mut cur_offset = OffsetSize::zero(); + for i in 0..rep_levels.len() { + if rep_levels[i] == 0 { + offsets.push(cur_offset) + } + if def_levels[i] > 0 { + cur_offset = cur_offset + OffsetSize::one(); + } + } + offsets.push(cur_offset); + + let num_bytes = bit_util::ceil(offsets.len(), 8); + let mut null_buf = MutableBuffer::new(num_bytes).with_bitset(num_bytes, false); + let null_slice = null_buf.data_mut(); + let mut list_index = 0; + for i in 0..rep_levels.len() { + if rep_levels[i] == 0 && def_levels[i] != 0 { + bit_util::set_bit(null_slice, list_index); + } + if rep_levels[i] == 0 { + list_index += 1; + } + } + let value_offsets = Buffer::from(&offsets.to_byte_slice()); + + // null list has def_level = 0 + let null_count = def_levels.iter().filter(|x| x == &&0).count(); + + let list_data = ArrayData::builder(self.get_data_type().clone()) + .len(offsets.len() - 1) + .add_buffer(value_offsets) + .add_child_data(batch_values.data()) + .null_bit_buffer(null_buf.freeze()) + .null_count(null_count) + .offset(next_batch_array.offset()) + .build(); + + let result_array = GenericListArray::::from(list_data); + Ok(Arc::new(result_array)) + } + + fn get_def_levels(&self) -> Option<&[i16]> { + self.def_level_buffer + .as_ref() + .map(|buf| unsafe { buf.typed_data() }) + } + + fn get_rep_levels(&self) -> Option<&[i16]> { + self.rep_level_buffer + .as_ref() + .map(|buf| unsafe { buf.typed_data() }) + } +} + /// Implementation of struct array reader. pub struct StructArrayReader { children: Vec>, @@ -595,6 +1083,7 @@ impl ArrayReader for StructArrayReader { /// Create array reader from parquet schema, column indices, and parquet file reader. pub fn build_array_reader( parquet_schema: SchemaDescPtr, + arrow_schema: Schema, column_indices: T, file_reader: Rc, ) -> Result> @@ -633,13 +1122,19 @@ where fields: filtered_root_fields, }; - ArrayReaderBuilder::new(Rc::new(proj), Rc::new(leaves), file_reader) - .build_array_reader() + ArrayReaderBuilder::new( + Rc::new(proj), + Rc::new(arrow_schema), + Rc::new(leaves), + file_reader, + ) + .build_array_reader() } /// Used to build array reader. struct ArrayReaderBuilder { root_schema: TypePtr, + arrow_schema: Rc, // Key: columns that need to be included in final array builder // Value: column index in schema columns_included: Rc>, @@ -756,16 +1251,94 @@ impl<'a> TypeVisitor>, &'a ArrayReaderBuilderContext } /// Build array reader for list type. - /// Currently this is not supported. fn visit_list_with_item( &mut self, - _list_type: Rc, - _item_type: &Type, - _context: &'a ArrayReaderBuilderContext, + list_type: Rc, + item_type: Rc, + context: &'a ArrayReaderBuilderContext, ) -> Result>> { - Err(ArrowError( - "Reading parquet list array into arrow is not supported yet!".to_string(), - )) + let list_child = &list_type + .get_fields() + .first() + .ok_or_else(|| ArrowError("List field must have a child.".to_string()))?; + let mut new_context = context.clone(); + + new_context.path.append(vec![list_type.name().to_string()]); + + match list_type.get_basic_info().repetition() { + Repetition::REPEATED => { + new_context.def_level += 1; + new_context.rep_level += 1; + } + Repetition::OPTIONAL => { + new_context.def_level += 1; + } + _ => (), + } + + match list_child.get_basic_info().repetition() { + Repetition::REPEATED => { + new_context.def_level += 1; + new_context.rep_level += 1; + } + Repetition::OPTIONAL => { + new_context.def_level += 1; + } + _ => (), + } + + let item_reader = self + .dispatch(item_type.clone(), &new_context) + .unwrap() + .unwrap(); + + let item_reader_type = item_reader.get_data_type().clone(); + + match item_reader_type { + ArrowType::List(_) + | ArrowType::FixedSizeList(_, _) + | ArrowType::Struct(_) + | ArrowType::Dictionary(_, _) => Err(ArrowError(format!( + "reading List({:?}) into arrow not supported yet", + item_type + ))), + _ => { + let arrow_type = self + .arrow_schema + .field_with_name(list_type.name()) + .ok() + .map(|f| f.data_type().to_owned()) + .unwrap_or_else(|| { + ArrowType::List(Box::new(item_reader_type.clone())) + }); + + let list_array_reader: Box = match arrow_type { + ArrowType::List(_) => Box::new(ListArrayReader::::new( + item_reader, + arrow_type, + item_reader_type, + new_context.def_level, + new_context.rep_level, + )), + ArrowType::LargeList(_) => Box::new(ListArrayReader::::new( + item_reader, + arrow_type, + item_reader_type, + new_context.def_level, + new_context.rep_level, + )), + + _ => { + return Err(ArrowError(format!( + "creating ListArrayReader with type {:?} should be unreachable", + arrow_type + ))) + } + }; + + Ok(Some(list_array_reader)) + } + } } } @@ -773,11 +1346,13 @@ impl<'a> ArrayReaderBuilder { /// Construct array reader builder. fn new( root_schema: TypePtr, + arrow_schema: Rc, columns_included: Rc>, file_reader: Rc, ) -> Self { Self { root_schema, + arrow_schema, columns_included, file_reader, } @@ -818,18 +1393,37 @@ impl<'a> ArrayReaderBuilder { self.file_reader.clone(), )?); + let arrow_type = self + .arrow_schema + .field_with_name(cur_type.name()) + .ok() + .map(|f| f.data_type()) + .cloned(); + match cur_type.get_physical_type() { PhysicalType::BOOLEAN => Ok(Box::new(PrimitiveArrayReader::::new( page_iterator, column_desc, + arrow_type, )?)), - PhysicalType::INT32 => Ok(Box::new(PrimitiveArrayReader::::new( - page_iterator, - column_desc, - )?)), + PhysicalType::INT32 => { + if let Some(ArrowType::Null) = arrow_type { + Ok(Box::new(NullArrayReader::::new( + page_iterator, + column_desc, + )?)) + } else { + Ok(Box::new(PrimitiveArrayReader::::new( + page_iterator, + column_desc, + arrow_type, + )?)) + } + } PhysicalType::INT64 => Ok(Box::new(PrimitiveArrayReader::::new( page_iterator, column_desc, + arrow_type, )?)), PhysicalType::INT96 => { let converter = Int96Converter::new(Int96ArrayConverter {}); @@ -837,24 +1431,61 @@ impl<'a> ArrayReaderBuilder { Int96Type, Int96Converter, >::new( - page_iterator, column_desc, converter + page_iterator, + column_desc, + converter, + arrow_type, )?)) } PhysicalType::FLOAT => Ok(Box::new(PrimitiveArrayReader::::new( page_iterator, column_desc, + arrow_type, )?)), - PhysicalType::DOUBLE => Ok(Box::new( - PrimitiveArrayReader::::new(page_iterator, column_desc)?, - )), + PhysicalType::DOUBLE => { + Ok(Box::new(PrimitiveArrayReader::::new( + page_iterator, + column_desc, + arrow_type, + )?)) + } PhysicalType::BYTE_ARRAY => { if cur_type.get_basic_info().logical_type() == LogicalType::UTF8 { - let converter = Utf8Converter::new(Utf8ArrayConverter {}); + if let Some(ArrowType::LargeUtf8) = arrow_type { + let converter = + LargeUtf8Converter::new(LargeUtf8ArrayConverter {}); + Ok(Box::new(ComplexObjectArrayReader::< + ByteArrayType, + LargeUtf8Converter, + >::new( + page_iterator, + column_desc, + converter, + arrow_type, + )?)) + } else { + let converter = Utf8Converter::new(Utf8ArrayConverter {}); + Ok(Box::new(ComplexObjectArrayReader::< + ByteArrayType, + Utf8Converter, + >::new( + page_iterator, + column_desc, + converter, + arrow_type, + )?)) + } + } else if let Some(ArrowType::LargeBinary) = arrow_type { + let converter = + LargeBinaryConverter::new(LargeBinaryArrayConverter {}); Ok(Box::new(ComplexObjectArrayReader::< ByteArrayType, - Utf8Converter, + LargeBinaryConverter, >::new( - page_iterator, column_desc, converter + page_iterator, + column_desc, + converter, + arrow_type, )?)) } else { let converter = BinaryConverter::new(BinaryArrayConverter {}); @@ -862,7 +1493,10 @@ impl<'a> ArrayReaderBuilder { ByteArrayType, BinaryConverter, >::new( - page_iterator, column_desc, converter + page_iterator, + column_desc, + converter, + arrow_type, )?)) } } @@ -884,7 +1518,10 @@ impl<'a> ArrayReaderBuilder { FixedLenByteArrayType, FixedLenBinaryConverter, >::new( - page_iterator, column_desc, converter + page_iterator, + column_desc, + converter, + arrow_type, )?)) } } @@ -901,11 +1538,15 @@ impl<'a> ArrayReaderBuilder { for child in cur_type.get_fields() { if let Some(child_reader) = self.dispatch(child.clone(), context)? { - fields.push(Field::new( - child.name(), - child_reader.get_data_type().clone(), - child.is_optional(), - )); + let field = match self.arrow_schema.field_with_name(child.name()) { + Ok(f) => f.to_owned(), + _ => Field::new( + child.name(), + child_reader.get_data_type().clone(), + child.is_optional(), + ), + }; + fields.push(field); children_reader.push(child_reader); } } @@ -928,6 +1569,7 @@ impl<'a> ArrayReaderBuilder { mod tests { use super::*; use crate::arrow::converter::Utf8Converter; + use crate::arrow::schema::parquet_to_arrow_schema; use crate::basic::{Encoding, Type as PhysicalType}; use crate::column::page::{Page, PageReader}; use crate::data_type::{ByteArray, DataType, Int32Type, Int64Type}; @@ -939,12 +1581,17 @@ mod tests { DataPageBuilder, DataPageBuilderImpl, InMemoryPageIterator, }; use crate::util::test_common::{get_test_file, make_pages}; - use arrow::array::{Array, ArrayRef, PrimitiveArray, StringArray, StructArray}; + use arrow::array::{ + Array, ArrayRef, LargeListArray, ListArray, PrimitiveArray, StringArray, + StructArray, + }; use arrow::datatypes::{ - DataType as ArrowType, Date32Type as ArrowDate32, Field, Int32Type as ArrowInt32, + ArrowPrimitiveType, DataType as ArrowType, Date32Type as ArrowDate32, Field, + Int32Type as ArrowInt32, Int64Type as ArrowInt64, + Time32MillisecondType as ArrowTime32MillisecondArray, + Time64MicrosecondType as ArrowTime64MicrosecondArray, TimestampMicrosecondType as ArrowTimestampMicrosecondType, TimestampMillisecondType as ArrowTimestampMillisecondType, - UInt32Type as ArrowUInt32, UInt64Type as ArrowUInt64, }; use rand::distributions::uniform::SampleUniform; use rand::{thread_rng, Rng}; @@ -1011,9 +1658,12 @@ mod tests { let column_desc = schema.column(0); let page_iterator = EmptyPageIterator::new(schema); - let mut array_reader = - PrimitiveArrayReader::::new(Box::new(page_iterator), column_desc) - .unwrap(); + let mut array_reader = PrimitiveArrayReader::::new( + Box::new(page_iterator), + column_desc, + None, + ) + .unwrap(); // expect no values to be read let array = array_reader.next_batch(50).unwrap(); @@ -1058,6 +1708,7 @@ mod tests { let mut array_reader = PrimitiveArrayReader::::new( Box::new(page_iterator), column_desc, + None, ) .unwrap(); @@ -1101,7 +1752,7 @@ mod tests { } macro_rules! test_primitive_array_reader_one_type { - ($arrow_parquet_type:ty, $physical_type:expr, $logical_type_str:expr, $result_arrow_type:ty, $result_primitive_type:ty) => {{ + ($arrow_parquet_type:ty, $physical_type:expr, $logical_type_str:expr, $result_arrow_type:ty, $result_arrow_cast_type:ty, $result_primitive_type:ty) => {{ let message_type = format!( " message test_schema {{ @@ -1112,7 +1763,7 @@ mod tests { ); let schema = parse_message_type(&message_type) .map(|t| Rc::new(SchemaDescriptor::new(Rc::new(t)))) - .unwrap(); + .expect("Unable to parse message type into a schema descriptor"); let column_desc = schema.column(0); @@ -1141,25 +1792,50 @@ mod tests { let mut array_reader = PrimitiveArrayReader::<$arrow_parquet_type>::new( Box::new(page_iterator), column_desc.clone(), + None, ) - .unwrap(); + .expect("Unable to get array reader"); - let array = array_reader.next_batch(50).unwrap(); + let array = array_reader + .next_batch(50) + .expect("Unable to get batch from reader"); + let result_data_type = <$result_arrow_type>::get_data_type(); let array = array .as_any() .downcast_ref::>() - .unwrap(); - - assert_eq!( - &PrimitiveArray::<$result_arrow_type>::from( - data[0..50] - .iter() - .map(|x| *x as $result_primitive_type) - .collect::>() - ), - array + .expect( + format!( + "Unable to downcast {:?} to {:?}", + array.data_type(), + result_data_type + ) + .as_str(), + ); + + // create expected array as primitive, and cast to result type + let expected = PrimitiveArray::<$result_arrow_cast_type>::from( + data[0..50] + .iter() + .map(|x| *x as $result_primitive_type) + .collect::>(), ); + let expected = Arc::new(expected) as ArrayRef; + let expected = arrow::compute::cast(&expected, &result_data_type) + .expect("Unable to cast expected array"); + assert_eq!(expected.data_type(), &result_data_type); + let expected = expected + .as_any() + .downcast_ref::>() + .expect( + format!( + "Unable to downcast expected {:?} to {:?}", + expected.data_type(), + result_data_type + ) + .as_str(), + ); + assert_eq!(expected, array); } }}; } @@ -1171,27 +1847,31 @@ mod tests { PhysicalType::INT32, "DATE", ArrowDate32, + ArrowInt32, i32 ); test_primitive_array_reader_one_type!( Int32Type, PhysicalType::INT32, "TIME_MILLIS", - ArrowUInt32, - u32 + ArrowTime32MillisecondArray, + ArrowInt32, + i32 ); test_primitive_array_reader_one_type!( Int64Type, PhysicalType::INT64, "TIME_MICROS", - ArrowUInt64, - u64 + ArrowTime64MicrosecondArray, + ArrowInt64, + i64 ); test_primitive_array_reader_one_type!( Int64Type, PhysicalType::INT64, "TIMESTAMP_MILLIS", ArrowTimestampMillisecondType, + ArrowInt64, i64 ); test_primitive_array_reader_one_type!( @@ -1199,6 +1879,7 @@ mod tests { PhysicalType::INT64, "TIMESTAMP_MICROS", ArrowTimestampMicrosecondType, + ArrowInt64, i64 ); } @@ -1245,6 +1926,7 @@ mod tests { let mut array_reader = PrimitiveArrayReader::::new( Box::new(page_iterator), column_desc, + None, ) .unwrap(); @@ -1358,6 +2040,7 @@ mod tests { Box::new(page_iterator), column_desc, converter, + None, ) .unwrap(); @@ -1543,8 +2226,16 @@ mod tests { let file = get_test_file("nulls.snappy.parquet"); let file_reader = Rc::new(SerializedFileReader::new(file).unwrap()); + let file_metadata = file_reader.metadata().file_metadata(); + let arrow_schema = parquet_to_arrow_schema( + file_metadata.schema_descr(), + file_metadata.key_value_metadata(), + ) + .unwrap(); + let array_reader = build_array_reader( file_reader.metadata().file_metadata().schema_descr_ptr(), + arrow_schema, vec![0usize].into_iter(), file_reader, ) @@ -1559,4 +2250,113 @@ mod tests { assert_eq!(array_reader.get_data_type(), &arrow_type); } + + #[test] + fn test_list_array_reader() { + // [[1, null, 2], null, [3, 4]] + let array = Arc::new(PrimitiveArray::::from(vec![ + Some(1), + None, + Some(2), + None, + Some(3), + Some(4), + ])); + let item_array_reader = InMemoryArrayReader::new( + ArrowType::Int32, + array, + Some(vec![3, 2, 3, 0, 3, 3]), + Some(vec![0, 1, 1, 0, 0, 1]), + ); + + let mut list_array_reader = ListArrayReader::::new( + Box::new(item_array_reader), + ArrowType::List(Box::new(ArrowType::Int32)), + ArrowType::Int32, + 1, + 1, + ); + + let next_batch = list_array_reader.next_batch(1024).unwrap(); + let list_array = next_batch.as_any().downcast_ref::().unwrap(); + + assert_eq!(3, list_array.len()); + // This passes as I expect + assert_eq!(1, list_array.null_count()); + + assert_eq!( + list_array + .value(0) + .as_any() + .downcast_ref::>() + .unwrap(), + &PrimitiveArray::::from(vec![Some(1), None, Some(2)]) + ); + + assert!(list_array.is_null(1)); + + assert_eq!( + list_array + .value(2) + .as_any() + .downcast_ref::>() + .unwrap(), + &PrimitiveArray::::from(vec![Some(3), Some(4)]) + ); + } + + #[test] + fn test_large_list_array_reader() { + // [[1, null, 2], null, [3, 4]] + let array = Arc::new(PrimitiveArray::::from(vec![ + Some(1), + None, + Some(2), + None, + Some(3), + Some(4), + ])); + let item_array_reader = InMemoryArrayReader::new( + ArrowType::Int32, + array, + Some(vec![3, 2, 3, 0, 3, 3]), + Some(vec![0, 1, 1, 0, 0, 1]), + ); + + let mut list_array_reader = ListArrayReader::::new( + Box::new(item_array_reader), + ArrowType::LargeList(Box::new(ArrowType::Int32)), + ArrowType::Int32, + 1, + 1, + ); + + let next_batch = list_array_reader.next_batch(1024).unwrap(); + let list_array = next_batch + .as_any() + .downcast_ref::() + .unwrap(); + + assert_eq!(3, list_array.len()); + + assert_eq!( + list_array + .value(0) + .as_any() + .downcast_ref::>() + .unwrap(), + &PrimitiveArray::::from(vec![Some(1), None, Some(2)]) + ); + + assert!(list_array.is_null(1)); + + assert_eq!( + list_array + .value(2) + .as_any() + .downcast_ref::>() + .unwrap(), + &PrimitiveArray::::from(vec![Some(3), Some(4)]) + ); + } } diff --git a/rust/parquet/src/arrow/arrow_reader.rs b/rust/parquet/src/arrow/arrow_reader.rs index b654de1ad0a..88af583a3d4 100644 --- a/rust/parquet/src/arrow/arrow_reader.rs +++ b/rust/parquet/src/arrow/arrow_reader.rs @@ -19,7 +19,9 @@ use crate::arrow::array_reader::{build_array_reader, ArrayReader, StructArrayReader}; use crate::arrow::schema::parquet_to_arrow_schema; -use crate::arrow::schema::parquet_to_arrow_schema_by_columns; +use crate::arrow::schema::{ + parquet_to_arrow_schema_by_columns, parquet_to_arrow_schema_by_root_columns, +}; use crate::errors::{ParquetError, Result}; use crate::file::reader::FileReader; use arrow::datatypes::{DataType as ArrowType, Schema, SchemaRef}; @@ -40,7 +42,12 @@ pub trait ArrowReader { /// Read parquet schema and convert it into arrow schema. /// This schema only includes columns identified by `column_indices`. - fn get_schema_by_columns(&mut self, column_indices: T) -> Result + /// To select leaf columns (i.e. `a.b.c` instead of `a`), set `leaf_columns = true` + fn get_schema_by_columns( + &mut self, + column_indices: T, + leaf_columns: bool, + ) -> Result where T: IntoIterator; @@ -84,16 +91,28 @@ impl ArrowReader for ParquetFileArrowReader { ) } - fn get_schema_by_columns(&mut self, column_indices: T) -> Result + fn get_schema_by_columns( + &mut self, + column_indices: T, + leaf_columns: bool, + ) -> Result where T: IntoIterator, { let file_metadata = self.file_reader.metadata().file_metadata(); - parquet_to_arrow_schema_by_columns( - file_metadata.schema_descr(), - column_indices, - file_metadata.key_value_metadata(), - ) + if leaf_columns { + parquet_to_arrow_schema_by_columns( + file_metadata.schema_descr(), + column_indices, + file_metadata.key_value_metadata(), + ) + } else { + parquet_to_arrow_schema_by_root_columns( + file_metadata.schema_descr(), + column_indices, + file_metadata.key_value_metadata(), + ) + } } fn get_record_reader( @@ -123,6 +142,7 @@ impl ArrowReader for ParquetFileArrowReader { .metadata() .file_metadata() .schema_descr_ptr(), + self.get_schema()?, column_indices, self.file_reader.clone(), )?; diff --git a/rust/parquet/src/arrow/arrow_writer.rs b/rust/parquet/src/arrow/arrow_writer.rs new file mode 100644 index 00000000000..beac7957f9c --- /dev/null +++ b/rust/parquet/src/arrow/arrow_writer.rs @@ -0,0 +1,1367 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Contains writer which writes arrow data into parquet data. + +use std::rc::Rc; + +use arrow::array as arrow_array; +use arrow::datatypes::{DataType as ArrowDataType, SchemaRef}; +use arrow::record_batch::RecordBatch; +use arrow_array::{Array, PrimitiveArrayOps}; + +use super::schema::add_encoded_arrow_schema_to_metadata; +use crate::column::writer::{ColumnWriter, ColumnWriterImpl}; +use crate::errors::{ParquetError, Result}; +use crate::file::properties::WriterProperties; +use crate::{ + data_type::*, + file::writer::{FileWriter, ParquetWriter, RowGroupWriter, SerializedFileWriter}, +}; + +/// Arrow writer +/// +/// Writes Arrow `RecordBatch`es to a Parquet writer +pub struct ArrowWriter { + /// Underlying Parquet writer + writer: SerializedFileWriter, + /// A copy of the Arrow schema. + /// + /// The schema is used to verify that each record batch written has the correct schema + arrow_schema: SchemaRef, +} + +impl ArrowWriter { + /// Try to create a new Arrow writer + /// + /// The writer will fail if: + /// * a `SerializedFileWriter` cannot be created from the ParquetWriter + /// * the Arrow schema contains unsupported datatypes such as Unions + pub fn try_new( + writer: W, + arrow_schema: SchemaRef, + props: Option, + ) -> Result { + let schema = crate::arrow::arrow_to_parquet_schema(&arrow_schema)?; + // add serialized arrow schema + let mut props = props.unwrap_or_else(|| WriterProperties::builder().build()); + add_encoded_arrow_schema_to_metadata(&arrow_schema, &mut props); + + let file_writer = SerializedFileWriter::new( + writer.try_clone()?, + schema.root_schema_ptr(), + Rc::new(props), + )?; + + Ok(Self { + writer: file_writer, + arrow_schema, + }) + } + + /// Write a RecordBatch to writer + /// + /// *NOTE:* The writer currently does not support all Arrow data types + pub fn write(&mut self, batch: &RecordBatch) -> Result<()> { + // validate batch schema against writer's supplied schema + if self.arrow_schema != batch.schema() { + return Err(ParquetError::ArrowError( + "Record batch schema does not match writer schema".to_string(), + )); + } + // compute the definition and repetition levels of the batch + let mut levels = vec![]; + batch.columns().iter().for_each(|array| { + let mut array_levels = + get_levels(array, 0, &vec![1i16; batch.num_rows()][..], None); + levels.append(&mut array_levels); + }); + // reverse levels so we can use Vec::pop(&mut self) + levels.reverse(); + + let mut row_group_writer = self.writer.next_row_group()?; + + // write leaves + for column in batch.columns() { + write_leaves(&mut row_group_writer, column, &mut levels)?; + } + + self.writer.close_row_group(row_group_writer) + } + + /// Close and finalise the underlying Parquet writer + pub fn close(&mut self) -> Result<()> { + self.writer.close() + } +} + +/// Convenience method to get the next ColumnWriter from the RowGroupWriter +#[inline] +#[allow(clippy::borrowed_box)] +fn get_col_writer( + row_group_writer: &mut Box, +) -> Result { + let col_writer = row_group_writer + .next_column()? + .expect("Unable to get column writer"); + Ok(col_writer) +} + +#[allow(clippy::borrowed_box)] +fn write_leaves( + mut row_group_writer: &mut Box, + array: &arrow_array::ArrayRef, + mut levels: &mut Vec, +) -> Result<()> { + match array.data_type() { + ArrowDataType::Null + | ArrowDataType::Int8 + | ArrowDataType::Int16 + | ArrowDataType::Int32 + | ArrowDataType::Int64 + | ArrowDataType::UInt8 + | ArrowDataType::UInt16 + | ArrowDataType::UInt32 + | ArrowDataType::UInt64 + | ArrowDataType::Float32 + | ArrowDataType::Float64 + | ArrowDataType::Timestamp(_, _) + | ArrowDataType::Date32(_) + | ArrowDataType::Date64(_) + | ArrowDataType::Time32(_) + | ArrowDataType::Time64(_) + | ArrowDataType::Duration(_) + | ArrowDataType::Interval(_) + | ArrowDataType::LargeBinary + | ArrowDataType::Binary + | ArrowDataType::Utf8 + | ArrowDataType::LargeUtf8 => { + let mut col_writer = get_col_writer(&mut row_group_writer)?; + write_leaf( + &mut col_writer, + array, + levels.pop().expect("Levels exhausted"), + )?; + row_group_writer.close_column(col_writer)?; + Ok(()) + } + ArrowDataType::List(_) | ArrowDataType::LargeList(_) => { + // write the child list + let data = array.data(); + let child_array = arrow_array::make_array(data.child_data()[0].clone()); + write_leaves(&mut row_group_writer, &child_array, &mut levels)?; + Ok(()) + } + ArrowDataType::Struct(_) => { + let struct_array: &arrow_array::StructArray = array + .as_any() + .downcast_ref::() + .expect("Unable to get struct array"); + for field in struct_array.columns() { + write_leaves(&mut row_group_writer, field, &mut levels)?; + } + Ok(()) + } + ArrowDataType::Dictionary(key_type, value_type) => { + use arrow_array::{ + Int16DictionaryArray, Int32DictionaryArray, Int64DictionaryArray, + Int8DictionaryArray, PrimitiveArray, StringArray, UInt16DictionaryArray, + UInt32DictionaryArray, UInt64DictionaryArray, UInt8DictionaryArray, + }; + use ArrowDataType::*; + use ColumnWriter::*; + + let array = &**array; + let mut col_writer = get_col_writer(&mut row_group_writer)?; + let levels = levels.pop().expect("Levels exhausted"); + + macro_rules! dispatch_dictionary { + ($($kt: pat, $vt: pat, $w: ident => $kat: ty, $vat: ty,)*) => ( + match (&**key_type, &**value_type, &mut col_writer) { + $(($kt, $vt, $w(writer)) => write_dict::<$kat, $vat, _>(array, writer, levels),)* + (kt, vt, _) => panic!("Don't know how to write dictionary of <{:?}, {:?}>", kt, vt), + } + ); + } + + match (&**key_type, &**value_type, &mut col_writer) { + (UInt8, UInt32, Int32ColumnWriter(writer)) => { + let typed_array = array + .as_any() + .downcast_ref::() + .expect("Unable to get dictionary array"); + + let keys = typed_array.keys(); + + let value_buffer = typed_array.values(); + let value_array = + arrow::compute::cast(&value_buffer, &ArrowDataType::Int32)?; + + let values = value_array + .as_any() + .downcast_ref::() + .unwrap(); + + use std::convert::TryFrom; + // This removes NULL values from the NullableIter, but + // they're encoded by the levels, so that's fine. + let materialized_values: Vec<_> = keys + .flatten() + .map(|key| { + usize::try_from(key).unwrap_or_else(|k| { + panic!("key {} does not fit in usize", k) + }) + }) + .map(|key| values.value(key)) + .collect(); + + let materialized_primitive_array = + PrimitiveArray::::from( + materialized_values, + ); + + writer.write_batch( + get_numeric_array_slice::( + &materialized_primitive_array, + ) + .as_slice(), + Some(levels.definition.as_slice()), + levels.repetition.as_deref(), + )?; + row_group_writer.close_column(col_writer)?; + + return Ok(()); + } + _ => {} + } + + dispatch_dictionary!( + Int8, Utf8, ByteArrayColumnWriter => Int8DictionaryArray, StringArray, + Int16, Utf8, ByteArrayColumnWriter => Int16DictionaryArray, StringArray, + Int32, Utf8, ByteArrayColumnWriter => Int32DictionaryArray, StringArray, + Int64, Utf8, ByteArrayColumnWriter => Int64DictionaryArray, StringArray, + UInt8, Utf8, ByteArrayColumnWriter => UInt8DictionaryArray, StringArray, + UInt16, Utf8, ByteArrayColumnWriter => UInt16DictionaryArray, StringArray, + UInt32, Utf8, ByteArrayColumnWriter => UInt32DictionaryArray, StringArray, + UInt64, Utf8, ByteArrayColumnWriter => UInt64DictionaryArray, StringArray, + )?; + + row_group_writer.close_column(col_writer)?; + + Ok(()) + } + ArrowDataType::Float16 => Err(ParquetError::ArrowError( + "Float16 arrays not supported".to_string(), + )), + ArrowDataType::FixedSizeList(_, _) + | ArrowDataType::Boolean + | ArrowDataType::FixedSizeBinary(_) + | ArrowDataType::Union(_) => Err(ParquetError::NYI( + "Attempting to write an Arrow type that is not yet implemented".to_string(), + )), + } +} + +trait Materialize { + type Output; + + // Materialize the packed dictionary. The writer will later repack it. + fn materialize(&self) -> Vec; +} + +macro_rules! materialize_string { + ($($k:ty,)*) => { + $(impl Materialize<$k, arrow_array::StringArray> for dyn Array { + type Output = ByteArray; + + fn materialize(&self) -> Vec { + use std::convert::TryFrom; + + let typed_array = self.as_any() + .downcast_ref::<$k>() + .expect("Unable to get dictionary array"); + + let keys = typed_array.keys(); + + let value_buffer = typed_array.values(); + let values = value_buffer + .as_any() + .downcast_ref::() + .unwrap(); + + // This removes NULL values from the NullableIter, but + // they're encoded by the levels, so that's fine. + keys + .flatten() + .map(|key| usize::try_from(key).unwrap_or_else(|k| panic!("key {} does not fit in usize", k))) + .map(|key| values.value(key)) + .map(ByteArray::from) + .collect() + } + })* + }; +} + +materialize_string! { + arrow_array::Int8DictionaryArray, + arrow_array::Int16DictionaryArray, + arrow_array::Int32DictionaryArray, + arrow_array::Int64DictionaryArray, + arrow_array::UInt8DictionaryArray, + arrow_array::UInt16DictionaryArray, + arrow_array::UInt32DictionaryArray, + arrow_array::UInt64DictionaryArray, +} + +fn write_dict( + array: &(dyn Array + 'static), + writer: &mut ColumnWriterImpl, + levels: Levels, +) -> Result<()> +where + T: DataType, + dyn Array: Materialize, +{ + writer.write_batch( + &array.materialize(), + Some(levels.definition.as_slice()), + levels.repetition.as_deref(), + )?; + + Ok(()) +} + +fn write_leaf( + writer: &mut ColumnWriter, + column: &arrow_array::ArrayRef, + levels: Levels, +) -> Result { + let written = match writer { + ColumnWriter::Int32ColumnWriter(ref mut typed) => { + let array = arrow::compute::cast(column, &ArrowDataType::Int32)?; + let array = array + .as_any() + .downcast_ref::() + .expect("Unable to get int32 array"); + typed.write_batch( + get_numeric_array_slice::(&array).as_slice(), + Some(levels.definition.as_slice()), + levels.repetition.as_deref(), + )? + } + ColumnWriter::BoolColumnWriter(ref mut _typed) => { + unreachable!("Currently unreachable because data type not supported") + } + ColumnWriter::Int64ColumnWriter(ref mut typed) => { + let array = arrow_array::Int64Array::from(column.data()); + typed.write_batch( + get_numeric_array_slice::(&array).as_slice(), + Some(levels.definition.as_slice()), + levels.repetition.as_deref(), + )? + } + ColumnWriter::Int96ColumnWriter(ref mut _typed) => { + unreachable!("Currently unreachable because data type not supported") + } + ColumnWriter::FloatColumnWriter(ref mut typed) => { + let array = arrow_array::Float32Array::from(column.data()); + typed.write_batch( + get_numeric_array_slice::(&array).as_slice(), + Some(levels.definition.as_slice()), + levels.repetition.as_deref(), + )? + } + ColumnWriter::DoubleColumnWriter(ref mut typed) => { + let array = arrow_array::Float64Array::from(column.data()); + typed.write_batch( + get_numeric_array_slice::(&array).as_slice(), + Some(levels.definition.as_slice()), + levels.repetition.as_deref(), + )? + } + ColumnWriter::ByteArrayColumnWriter(ref mut typed) => match column.data_type() { + ArrowDataType::Binary | ArrowDataType::Utf8 => { + let array = arrow_array::BinaryArray::from(column.data()); + typed.write_batch( + get_binary_array(&array).as_slice(), + Some(levels.definition.as_slice()), + levels.repetition.as_deref(), + )? + } + ArrowDataType::LargeBinary | ArrowDataType::LargeUtf8 => { + let array = arrow_array::LargeBinaryArray::from(column.data()); + typed.write_batch( + get_large_binary_array(&array).as_slice(), + Some(levels.definition.as_slice()), + levels.repetition.as_deref(), + )? + } + _ => unreachable!("Currently unreachable because data type not supported"), + }, + ColumnWriter::FixedLenByteArrayColumnWriter(ref mut _typed) => { + unreachable!("Currently unreachable because data type not supported") + } + }; + Ok(written as i64) +} + +/// A struct that represents definition and repetition levels. +/// Repetition levels are only populated if the parent or current leaf is repeated +#[derive(Debug)] +struct Levels { + definition: Vec, + repetition: Option>, +} + +/// Compute nested levels of the Arrow array, recursing into lists and structs +fn get_levels( + array: &arrow_array::ArrayRef, + level: i16, + parent_def_levels: &[i16], + parent_rep_levels: Option<&[i16]>, +) -> Vec { + match array.data_type() { + ArrowDataType::Null => vec![Levels { + definition: parent_def_levels.iter().map(|v| (v - 1).max(0)).collect(), + repetition: None, + }], + ArrowDataType::Boolean + | ArrowDataType::Int8 + | ArrowDataType::Int16 + | ArrowDataType::Int32 + | ArrowDataType::Int64 + | ArrowDataType::UInt8 + | ArrowDataType::UInt16 + | ArrowDataType::UInt32 + | ArrowDataType::UInt64 + | ArrowDataType::Float16 + | ArrowDataType::Float32 + | ArrowDataType::Float64 + | ArrowDataType::Utf8 + | ArrowDataType::LargeUtf8 + | ArrowDataType::Timestamp(_, _) + | ArrowDataType::Date32(_) + | ArrowDataType::Date64(_) + | ArrowDataType::Time32(_) + | ArrowDataType::Time64(_) + | ArrowDataType::Duration(_) + | ArrowDataType::Interval(_) + | ArrowDataType::Binary + | ArrowDataType::LargeBinary => vec![Levels { + definition: get_primitive_def_levels(array, parent_def_levels), + repetition: None, + }], + ArrowDataType::FixedSizeBinary(_) => unimplemented!(), + ArrowDataType::List(_) | ArrowDataType::LargeList(_) => { + let array_data = array.data(); + let child_data = array_data.child_data().get(0).unwrap(); + // get offsets, accounting for large offsets if present + let offsets: Vec = { + if let ArrowDataType::LargeList(_) = array.data_type() { + unsafe { array_data.buffers()[0].typed_data::() }.to_vec() + } else { + let offsets = unsafe { array_data.buffers()[0].typed_data::() }; + offsets.to_vec().into_iter().map(|v| v as i64).collect() + } + }; + let child_array = arrow_array::make_array(child_data.clone()); + + let mut list_def_levels = Vec::with_capacity(child_array.len()); + let mut list_rep_levels = Vec::with_capacity(child_array.len()); + let rep_levels: Vec = parent_rep_levels + .map(|l| l.to_vec()) + .unwrap_or_else(|| vec![0i16; parent_def_levels.len()]); + parent_def_levels + .iter() + .zip(rep_levels) + .zip(offsets.windows(2)) + .for_each(|((parent_def_level, parent_rep_level), window)| { + if *parent_def_level == 0 { + // parent is null, list element must also be null + list_def_levels.push(0); + list_rep_levels.push(0); + } else { + // parent is not null, check if list is empty or null + let start = window[0]; + let end = window[1]; + let len = end - start; + if len == 0 { + list_def_levels.push(*parent_def_level - 1); + list_rep_levels.push(parent_rep_level); + } else { + list_def_levels.push(*parent_def_level); + list_rep_levels.push(parent_rep_level); + for _ in 1..len { + list_def_levels.push(*parent_def_level); + list_rep_levels.push(parent_rep_level + 1); + } + } + } + }); + + // if datatype is a primitive, we can construct levels of the child array + match child_array.data_type() { + // TODO: The behaviour of a > is untested + ArrowDataType::Null => vec![Levels { + definition: list_def_levels, + repetition: Some(list_rep_levels), + }], + ArrowDataType::Boolean => unimplemented!(), + ArrowDataType::Int8 + | ArrowDataType::Int16 + | ArrowDataType::Int32 + | ArrowDataType::Int64 + | ArrowDataType::UInt8 + | ArrowDataType::UInt16 + | ArrowDataType::UInt32 + | ArrowDataType::UInt64 + | ArrowDataType::Float16 + | ArrowDataType::Float32 + | ArrowDataType::Float64 + | ArrowDataType::Timestamp(_, _) + | ArrowDataType::Date32(_) + | ArrowDataType::Date64(_) + | ArrowDataType::Time32(_) + | ArrowDataType::Time64(_) + | ArrowDataType::Duration(_) + | ArrowDataType::Interval(_) => { + let def_levels = + get_primitive_def_levels(&child_array, &list_def_levels[..]); + vec![Levels { + definition: def_levels, + repetition: Some(list_rep_levels), + }] + } + ArrowDataType::Binary + | ArrowDataType::Utf8 + | ArrowDataType::LargeUtf8 => unimplemented!(), + ArrowDataType::FixedSizeBinary(_) => unimplemented!(), + ArrowDataType::LargeBinary => unimplemented!(), + ArrowDataType::List(_) | ArrowDataType::LargeList(_) => { + // nested list + unimplemented!() + } + ArrowDataType::FixedSizeList(_, _) => unimplemented!(), + ArrowDataType::Struct(_) => get_levels( + array, + level + 1, // indicates a nesting level of 2 (list + struct) + &list_def_levels[..], + Some(&list_rep_levels[..]), + ), + ArrowDataType::Union(_) => unimplemented!(), + ArrowDataType::Dictionary(_, _) => unimplemented!(), + } + } + ArrowDataType::FixedSizeList(_, _) => unimplemented!(), + ArrowDataType::Struct(_) => { + let struct_array: &arrow_array::StructArray = array + .as_any() + .downcast_ref::() + .expect("Unable to get struct array"); + let mut struct_def_levels = Vec::with_capacity(struct_array.len()); + for i in 0..array.len() { + struct_def_levels.push(level + struct_array.is_valid(i) as i16); + } + // trying to create levels for struct's fields + let mut struct_levels = vec![]; + struct_array.columns().into_iter().for_each(|col| { + let mut levels = + get_levels(col, level + 1, &struct_def_levels[..], parent_rep_levels); + struct_levels.append(&mut levels); + }); + struct_levels + } + ArrowDataType::Union(_) => unimplemented!(), + ArrowDataType::Dictionary(_, _) => { + // Need to check for these cases not implemented in C++: + // - "Writing DictionaryArray with nested dictionary type not yet supported" + // - "Writing DictionaryArray with null encoded in dictionary type not yet supported" + vec![Levels { + definition: get_primitive_def_levels(array, parent_def_levels), + repetition: None, + }] + } + } +} + +/// Get the definition levels of the numeric array, with level 0 being null and 1 being not null +/// In the case where the array in question is a child of either a list or struct, the levels +/// are incremented in accordance with the `level` parameter. +/// Parent levels are either 0 or 1, and are used to higher (correct terminology?) leaves as null +fn get_primitive_def_levels( + array: &arrow_array::ArrayRef, + parent_def_levels: &[i16], +) -> Vec { + let mut array_index = 0; + let max_def_level = parent_def_levels.iter().max().unwrap(); + let mut primitive_def_levels = vec![]; + parent_def_levels.iter().for_each(|def_level| { + if def_level < max_def_level { + primitive_def_levels.push(*def_level); + } else { + primitive_def_levels.push(def_level - array.is_null(array_index) as i16); + array_index += 1; + } + }); + primitive_def_levels +} + +macro_rules! def_get_binary_array_fn { + ($name:ident, $ty:ty) => { + fn $name(array: &$ty) -> Vec { + let mut values = Vec::with_capacity(array.len() - array.null_count()); + for i in 0..array.len() { + if array.is_valid(i) { + let bytes = ByteArray::from(array.value(i).to_vec()); + values.push(bytes); + } + } + values + } + }; +} + +def_get_binary_array_fn!(get_binary_array, arrow_array::BinaryArray); +def_get_binary_array_fn!(get_large_binary_array, arrow_array::LargeBinaryArray); + +/// Get the underlying numeric array slice, skipping any null values. +/// If there are no null values, it might be quicker to get the slice directly instead of +/// calling this function. +fn get_numeric_array_slice(array: &arrow_array::PrimitiveArray) -> Vec +where + T: DataType, + A: arrow::datatypes::ArrowNumericType, + T::T: From, +{ + let mut values = Vec::with_capacity(array.len() - array.null_count()); + for i in 0..array.len() { + if array.is_valid(i) { + values.push(array.value(i).into()) + } + } + values +} + +#[cfg(test)] +mod tests { + use super::*; + + use std::io::Seek; + use std::sync::Arc; + + use arrow::array::*; + use arrow::datatypes::ToByteSlice; + use arrow::datatypes::{DataType, Field, Schema, UInt32Type, UInt8Type}; + use arrow::record_batch::RecordBatch; + + use crate::arrow::{ArrowReader, ParquetFileArrowReader}; + use crate::file::{metadata::KeyValue, reader::SerializedFileReader}; + use crate::util::test_common::get_temp_file; + + #[test] + fn arrow_writer() { + // define schema + let schema = Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, true), + ]); + + // create some data + let a = Int32Array::from(vec![1, 2, 3, 4, 5]); + let b = Int32Array::from(vec![Some(1), None, None, Some(4), Some(5)]); + + // build a record batch + let batch = RecordBatch::try_new( + Arc::new(schema.clone()), + vec![Arc::new(a), Arc::new(b)], + ) + .unwrap(); + + let file = get_temp_file("test_arrow_writer.parquet", &[]); + let mut writer = ArrowWriter::try_new(file, Arc::new(schema), None).unwrap(); + writer.write(&batch).unwrap(); + writer.close().unwrap(); + } + + #[test] + #[ignore = "repetitions might be incorrect, will be addressed as part of ARROW-9728"] + fn arrow_writer_list() { + // define schema + let schema = Schema::new(vec![Field::new( + "a", + DataType::List(Box::new(DataType::Int32)), + false, + )]); + + // create some data + let a_values = Int32Array::from(vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10]); + + // Construct a buffer for value offsets, for the nested array: + // [[1], [2, 3], null, [4, 5, 6], [7, 8, 9, 10]] + let a_value_offsets = + arrow::buffer::Buffer::from(&[0, 1, 3, 3, 6, 10].to_byte_slice()); + + // Construct a list array from the above two + let a_list_data = ArrayData::builder(DataType::List(Box::new(DataType::Int32))) + .len(5) + .add_buffer(a_value_offsets) + .add_child_data(a_values.data()) + .build(); + let a = ListArray::from(a_list_data); + + // build a record batch + let batch = + RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)]).unwrap(); + + // I think this setup is incorrect because this should pass + assert_eq!(batch.column(0).data().null_count(), 1); + + let file = get_temp_file("test_arrow_writer_list.parquet", &[]); + let mut writer = ArrowWriter::try_new(file, Arc::new(schema), None).unwrap(); + writer.write(&batch).unwrap(); + writer.close().unwrap(); + } + + #[test] + fn arrow_writer_binary() { + let string_field = Field::new("a", DataType::Utf8, false); + let binary_field = Field::new("b", DataType::Binary, false); + let schema = Schema::new(vec![string_field, binary_field]); + + let raw_string_values = vec!["foo", "bar", "baz", "quux"]; + let raw_binary_values = vec![ + b"foo".to_vec(), + b"bar".to_vec(), + b"baz".to_vec(), + b"quux".to_vec(), + ]; + let raw_binary_value_refs = raw_binary_values + .iter() + .map(|x| x.as_slice()) + .collect::>(); + + let string_values = StringArray::from(raw_string_values.clone()); + let binary_values = BinaryArray::from(raw_binary_value_refs); + let batch = RecordBatch::try_new( + Arc::new(schema.clone()), + vec![Arc::new(string_values), Arc::new(binary_values)], + ) + .unwrap(); + + let mut file = get_temp_file("test_arrow_writer_binary.parquet", &[]); + let mut writer = + ArrowWriter::try_new(file.try_clone().unwrap(), Arc::new(schema), None) + .unwrap(); + writer.write(&batch).unwrap(); + writer.close().unwrap(); + + file.seek(std::io::SeekFrom::Start(0)).unwrap(); + let file_reader = SerializedFileReader::new(file).unwrap(); + let mut arrow_reader = ParquetFileArrowReader::new(Rc::new(file_reader)); + let mut record_batch_reader = arrow_reader.get_record_reader(1024).unwrap(); + + let batch = record_batch_reader.next().unwrap().unwrap(); + let string_col = batch + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + let binary_col = batch + .column(1) + .as_any() + .downcast_ref::() + .unwrap(); + + for i in 0..batch.num_rows() { + assert_eq!(string_col.value(i), raw_string_values[i]); + assert_eq!(binary_col.value(i), raw_binary_values[i].as_slice()); + } + } + + #[test] + fn arrow_writer_complex() { + // define schema + let struct_field_d = Field::new("d", DataType::Float64, true); + let struct_field_f = Field::new("f", DataType::Float32, true); + let struct_field_g = + Field::new("g", DataType::List(Box::new(DataType::Int16)), false); + let struct_field_e = Field::new( + "e", + DataType::Struct(vec![struct_field_f.clone(), struct_field_g.clone()]), + true, + ); + let schema = Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, true), + Field::new( + "c", + DataType::Struct(vec![struct_field_d.clone(), struct_field_e.clone()]), + false, + ), + ]); + + // create some data + let a = Int32Array::from(vec![1, 2, 3, 4, 5]); + let b = Int32Array::from(vec![Some(1), None, None, Some(4), Some(5)]); + let d = Float64Array::from(vec![None, None, None, Some(1.0), None]); + let f = Float32Array::from(vec![Some(0.0), None, Some(333.3), None, Some(5.25)]); + + let g_value = Int16Array::from(vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10]); + + // Construct a buffer for value offsets, for the nested array: + // [[1], [2, 3], null, [4, 5, 6], [7, 8, 9, 10]] + let g_value_offsets = + arrow::buffer::Buffer::from(&[0, 1, 3, 3, 6, 10].to_byte_slice()); + + // Construct a list array from the above two + let g_list_data = ArrayData::builder(struct_field_g.data_type().clone()) + .len(5) + .add_buffer(g_value_offsets) + .add_child_data(g_value.data()) + .build(); + let g = ListArray::from(g_list_data); + + let e = StructArray::from(vec![ + (struct_field_f, Arc::new(f) as ArrayRef), + (struct_field_g, Arc::new(g) as ArrayRef), + ]); + + let c = StructArray::from(vec![ + (struct_field_d, Arc::new(d) as ArrayRef), + (struct_field_e, Arc::new(e) as ArrayRef), + ]); + + // build a record batch + let batch = RecordBatch::try_new( + Arc::new(schema.clone()), + vec![Arc::new(a), Arc::new(b), Arc::new(c)], + ) + .unwrap(); + + let props = WriterProperties::builder() + .set_key_value_metadata(Some(vec![KeyValue { + key: "test_key".to_string(), + value: Some("test_value".to_string()), + }])) + .build(); + + let file = get_temp_file("test_arrow_writer_complex.parquet", &[]); + let mut writer = + ArrowWriter::try_new(file, Arc::new(schema), Some(props)).unwrap(); + writer.write(&batch).unwrap(); + writer.close().unwrap(); + } + + const SMALL_SIZE: usize = 100; + + fn roundtrip(filename: &str, expected_batch: RecordBatch) { + let file = get_temp_file(filename, &[]); + + let mut writer = ArrowWriter::try_new( + file.try_clone().unwrap(), + expected_batch.schema(), + None, + ) + .expect("Unable to write file"); + writer.write(&expected_batch).unwrap(); + writer.close().unwrap(); + + let reader = SerializedFileReader::new(file).unwrap(); + let mut arrow_reader = ParquetFileArrowReader::new(Rc::new(reader)); + let mut record_batch_reader = arrow_reader.get_record_reader(1024).unwrap(); + + let actual_batch = record_batch_reader + .next() + .expect("No batch found") + .expect("Unable to get batch"); + + assert_eq!(expected_batch.schema(), actual_batch.schema()); + assert_eq!(expected_batch.num_columns(), actual_batch.num_columns()); + assert_eq!(expected_batch.num_rows(), actual_batch.num_rows()); + for i in 0..expected_batch.num_columns() { + let expected_data = expected_batch.column(i).data(); + let actual_data = actual_batch.column(i).data(); + + assert_eq!(expected_data.data_type(), actual_data.data_type()); + assert_eq!(expected_data.len(), actual_data.len()); + assert_eq!(expected_data.null_count(), actual_data.null_count()); + assert_eq!(expected_data.offset(), actual_data.offset()); + assert_eq!(expected_data.buffers(), actual_data.buffers()); + assert_eq!(expected_data.child_data(), actual_data.child_data()); + // Null counts should be the same, not necessarily bitmaps + // A null bitmap is optional if an array has no nulls + if expected_data.null_count() != 0 { + assert_eq!(expected_data.null_bitmap(), actual_data.null_bitmap()); + } + } + } + + fn one_column_roundtrip(filename: &str, values: ArrayRef, nullable: bool) { + let schema = Schema::new(vec![Field::new( + "col", + values.data_type().clone(), + nullable, + )]); + let expected_batch = + RecordBatch::try_new(Arc::new(schema), vec![values]).unwrap(); + + roundtrip(filename, expected_batch); + } + + fn values_required(iter: I, filename: &str) + where + A: From> + Array + 'static, + I: IntoIterator, + { + let raw_values: Vec<_> = iter.into_iter().collect(); + let values = Arc::new(A::from(raw_values)); + one_column_roundtrip(filename, values, false); + } + + fn values_optional(iter: I, filename: &str) + where + A: From>> + Array + 'static, + I: IntoIterator, + { + let optional_raw_values: Vec<_> = iter + .into_iter() + .enumerate() + .map(|(i, v)| if i % 2 == 0 { None } else { Some(v) }) + .collect(); + let optional_values = Arc::new(A::from(optional_raw_values)); + one_column_roundtrip(filename, optional_values, true); + } + + fn required_and_optional(iter: I, filename: &str) + where + A: From> + From>> + Array + 'static, + I: IntoIterator + Clone, + { + values_required::(iter.clone(), filename); + values_optional::(iter, filename); + } + + #[test] + fn all_null_primitive_single_column() { + let values = Arc::new(Int32Array::from(vec![None; SMALL_SIZE])); + one_column_roundtrip("all_null_primitive_single_column", values, true); + } + #[test] + fn null_single_column() { + let values = Arc::new(NullArray::new(SMALL_SIZE)); + one_column_roundtrip("null_single_column", values, true); + // null arrays are always nullable, a test with non-nullable nulls fails + } + + #[test] + #[should_panic( + expected = "Attempting to write an Arrow type that is not yet implemented" + )] + fn bool_single_column() { + required_and_optional::( + [true, false].iter().cycle().copied().take(SMALL_SIZE), + "bool_single_column", + ); + } + + #[test] + fn i8_single_column() { + required_and_optional::(0..SMALL_SIZE as i8, "i8_single_column"); + } + + #[test] + fn i16_single_column() { + required_and_optional::(0..SMALL_SIZE as i16, "i16_single_column"); + } + + #[test] + fn i32_single_column() { + required_and_optional::(0..SMALL_SIZE as i32, "i32_single_column"); + } + + #[test] + fn i64_single_column() { + required_and_optional::(0..SMALL_SIZE as i64, "i64_single_column"); + } + + #[test] + fn u8_single_column() { + required_and_optional::(0..SMALL_SIZE as u8, "u8_single_column"); + } + + #[test] + fn u16_single_column() { + required_and_optional::( + 0..SMALL_SIZE as u16, + "u16_single_column", + ); + } + + #[test] + fn u32_single_column() { + required_and_optional::( + 0..SMALL_SIZE as u32, + "u32_single_column", + ); + } + + #[test] + fn u64_single_column() { + required_and_optional::( + 0..SMALL_SIZE as u64, + "u64_single_column", + ); + } + + #[test] + fn f32_single_column() { + required_and_optional::( + (0..SMALL_SIZE).map(|i| i as f32), + "f32_single_column", + ); + } + + #[test] + fn f64_single_column() { + required_and_optional::( + (0..SMALL_SIZE).map(|i| i as f64), + "f64_single_column", + ); + } + + // The timestamp array types don't implement From> because they need the timezone + // argument, and they also doesn't support building from a Vec>, so call + // one_column_roundtrip manually instead of calling required_and_optional for these tests. + + #[test] + #[ignore] // Timestamp support isn't correct yet + fn timestamp_second_single_column() { + let raw_values: Vec<_> = (0..SMALL_SIZE as i64).collect(); + let values = Arc::new(TimestampSecondArray::from_vec(raw_values, None)); + + one_column_roundtrip("timestamp_second_single_column", values, false); + } + + #[test] + fn timestamp_millisecond_single_column() { + let raw_values: Vec<_> = (0..SMALL_SIZE as i64).collect(); + let values = Arc::new(TimestampMillisecondArray::from_vec(raw_values, None)); + + one_column_roundtrip("timestamp_millisecond_single_column", values, false); + } + + #[test] + fn timestamp_microsecond_single_column() { + let raw_values: Vec<_> = (0..SMALL_SIZE as i64).collect(); + let values = Arc::new(TimestampMicrosecondArray::from_vec(raw_values, None)); + + one_column_roundtrip("timestamp_microsecond_single_column", values, false); + } + + #[test] + #[ignore] // Timestamp support isn't correct yet + fn timestamp_nanosecond_single_column() { + let raw_values: Vec<_> = (0..SMALL_SIZE as i64).collect(); + let values = Arc::new(TimestampNanosecondArray::from_vec(raw_values, None)); + + one_column_roundtrip("timestamp_nanosecond_single_column", values, false); + } + + #[test] + fn date32_single_column() { + required_and_optional::( + 0..SMALL_SIZE as i32, + "date32_single_column", + ); + } + + #[test] + #[ignore] // Date support isn't correct yet + fn date64_single_column() { + required_and_optional::( + 0..SMALL_SIZE as i64, + "date64_single_column", + ); + } + + #[test] + #[ignore] // DateUnit resolution mismatch + fn time32_second_single_column() { + required_and_optional::( + 0..SMALL_SIZE as i32, + "time32_second_single_column", + ); + } + + #[test] + fn time32_millisecond_single_column() { + required_and_optional::( + 0..SMALL_SIZE as i32, + "time32_millisecond_single_column", + ); + } + + #[test] + fn time64_microsecond_single_column() { + required_and_optional::( + 0..SMALL_SIZE as i64, + "time64_microsecond_single_column", + ); + } + + #[test] + #[ignore] // DateUnit resolution mismatch + fn time64_nanosecond_single_column() { + required_and_optional::( + 0..SMALL_SIZE as i64, + "time64_nanosecond_single_column", + ); + } + + #[test] + #[should_panic(expected = "Converting Duration to parquet not supported")] + fn duration_second_single_column() { + required_and_optional::( + 0..SMALL_SIZE as i64, + "duration_second_single_column", + ); + } + + #[test] + #[should_panic(expected = "Converting Duration to parquet not supported")] + fn duration_millisecond_single_column() { + required_and_optional::( + 0..SMALL_SIZE as i64, + "duration_millisecond_single_column", + ); + } + + #[test] + #[should_panic(expected = "Converting Duration to parquet not supported")] + fn duration_microsecond_single_column() { + required_and_optional::( + 0..SMALL_SIZE as i64, + "duration_microsecond_single_column", + ); + } + + #[test] + #[should_panic(expected = "Converting Duration to parquet not supported")] + fn duration_nanosecond_single_column() { + required_and_optional::( + 0..SMALL_SIZE as i64, + "duration_nanosecond_single_column", + ); + } + + #[test] + #[should_panic(expected = "Currently unreachable because data type not supported")] + fn interval_year_month_single_column() { + required_and_optional::( + 0..SMALL_SIZE as i32, + "interval_year_month_single_column", + ); + } + + #[test] + #[should_panic(expected = "Currently unreachable because data type not supported")] + fn interval_day_time_single_column() { + required_and_optional::( + 0..SMALL_SIZE as i64, + "interval_day_time_single_column", + ); + } + + #[test] + #[ignore] // Binary support isn't correct yet - buffers don't match + fn binary_single_column() { + let one_vec: Vec = (0..SMALL_SIZE as u8).collect(); + let many_vecs: Vec<_> = std::iter::repeat(one_vec).take(SMALL_SIZE).collect(); + let many_vecs_iter = many_vecs.iter().map(|v| v.as_slice()); + + // BinaryArrays can't be built from Vec>, so only call `values_required` + values_required::(many_vecs_iter, "binary_single_column"); + } + + #[test] + #[ignore] // Large binary support isn't correct yet - buffers don't match + fn large_binary_single_column() { + let one_vec: Vec = (0..SMALL_SIZE as u8).collect(); + let many_vecs: Vec<_> = std::iter::repeat(one_vec).take(SMALL_SIZE).collect(); + let many_vecs_iter = many_vecs.iter().map(|v| v.as_slice()); + + // LargeBinaryArrays can't be built from Vec>, so only call `values_required` + values_required::( + many_vecs_iter, + "large_binary_single_column", + ); + } + + #[test] + fn string_single_column() { + let raw_values: Vec<_> = (0..SMALL_SIZE).map(|i| i.to_string()).collect(); + let raw_strs = raw_values.iter().map(|s| s.as_str()); + + required_and_optional::(raw_strs, "string_single_column"); + } + + #[test] + fn large_string_single_column() { + let raw_values: Vec<_> = (0..SMALL_SIZE).map(|i| i.to_string()).collect(); + let raw_strs = raw_values.iter().map(|s| s.as_str()); + + required_and_optional::( + raw_strs, + "large_string_single_column", + ); + } + + #[test] + #[ignore = "repetitions might be incorrect, will be addressed as part of ARROW-9728"] + fn list_single_column() { + let a_values = Int32Array::from(vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10]); + let a_value_offsets = + arrow::buffer::Buffer::from(&[0, 1, 3, 3, 6, 10].to_byte_slice()); + let a_list_data = ArrayData::builder(DataType::List(Box::new(DataType::Int32))) + .len(5) + .add_buffer(a_value_offsets) + .add_child_data(a_values.data()) + .build(); + + // I think this setup is incorrect because this should pass + assert_eq!(a_list_data.null_count(), 1); + + let a = ListArray::from(a_list_data); + let values = Arc::new(a); + + one_column_roundtrip("list_single_column", values, false); + } + + #[test] + #[ignore = "repetitions might be incorrect, will be addressed as part of ARROW-9728"] + fn large_list_single_column() { + let a_values = Int32Array::from(vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10]); + let a_value_offsets = + arrow::buffer::Buffer::from(&[0i64, 1, 3, 3, 6, 10].to_byte_slice()); + let a_list_data = + ArrayData::builder(DataType::LargeList(Box::new(DataType::Int32))) + .len(5) + .add_buffer(a_value_offsets) + .add_child_data(a_values.data()) + .build(); + + // I think this setup is incorrect because this should pass + assert_eq!(a_list_data.null_count(), 1); + + let a = LargeListArray::from(a_list_data); + let values = Arc::new(a); + + one_column_roundtrip("large_list_single_column", values, false); + } + + #[test] + fn struct_single_column() { + let a_values = Int32Array::from(vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10]); + let struct_field_a = Field::new("f", DataType::Int32, false); + let s = StructArray::from(vec![(struct_field_a, Arc::new(a_values) as ArrayRef)]); + + let values = Arc::new(s); + one_column_roundtrip("struct_single_column", values, false); + } + + #[test] + fn arrow_writer_string_dictionary() { + // define schema + let schema = Arc::new(Schema::new(vec![Field::new_dict( + "dictionary", + DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)), + true, + 42, + true, + )])); + + // create some data + let d: Int32DictionaryArray = [Some("alpha"), None, Some("beta"), Some("alpha")] + .iter() + .copied() + .collect(); + + // build a record batch + let expected_batch = RecordBatch::try_new(schema, vec![Arc::new(d)]).unwrap(); + + roundtrip( + "test_arrow_writer_string_dictionary.parquet", + expected_batch, + ); + } + + #[test] + fn arrow_writer_primitive_dictionary() { + // define schema + let schema = Arc::new(Schema::new(vec![Field::new_dict( + "dictionary", + DataType::Dictionary(Box::new(DataType::UInt8), Box::new(DataType::UInt32)), + true, + 42, + true, + )])); + + // create some data + let key_builder = PrimitiveBuilder::::new(3); + let value_builder = PrimitiveBuilder::::new(2); + let mut builder = PrimitiveDictionaryBuilder::new(key_builder, value_builder); + builder.append(12345678).unwrap(); + builder.append_null().unwrap(); + builder.append(22345678).unwrap(); + builder.append(12345678).unwrap(); + let d = builder.finish(); + + // build a record batch + let expected_batch = RecordBatch::try_new(schema, vec![Arc::new(d)]).unwrap(); + + roundtrip( + "test_arrow_writer_primitive_dictionary.parquet", + expected_batch, + ); + } + + #[test] + fn arrow_writer_string_dictionary_unsigned_index() { + // define schema + let schema = Arc::new(Schema::new(vec![Field::new_dict( + "dictionary", + DataType::Dictionary(Box::new(DataType::UInt8), Box::new(DataType::Utf8)), + true, + 42, + true, + )])); + + // create some data + let d: UInt8DictionaryArray = [Some("alpha"), None, Some("beta"), Some("alpha")] + .iter() + .copied() + .collect(); + + // build a record batch + let expected_batch = RecordBatch::try_new(schema, vec![Arc::new(d)]).unwrap(); + + roundtrip( + "test_arrow_writer_string_dictionary_unsigned_index.parquet", + expected_batch, + ); + } +} diff --git a/rust/parquet/src/arrow/converter.rs b/rust/parquet/src/arrow/converter.rs index 9fbfa339168..80008ad2f3d 100644 --- a/rust/parquet/src/arrow/converter.rs +++ b/rust/parquet/src/arrow/converter.rs @@ -17,21 +17,32 @@ use crate::arrow::record_reader::RecordReader; use crate::data_type::{ByteArray, DataType, Int96}; -use arrow::array::{ - Array, ArrayRef, BinaryBuilder, BooleanArray, BooleanBufferBuilder, - BufferBuilderTrait, FixedSizeBinaryBuilder, StringBuilder, - TimestampNanosecondBuilder, +// TODO: clean up imports (best done when there are few moving parts) +use arrow::{ + array::{ + Array, ArrayRef, BinaryBuilder, BooleanArray, BooleanBufferBuilder, + BufferBuilderTrait, FixedSizeBinaryBuilder, LargeBinaryBuilder, + LargeStringBuilder, PrimitiveBuilder, PrimitiveDictionaryBuilder, StringBuilder, + StringDictionaryBuilder, TimestampNanosecondBuilder, + }, + datatypes::Time32MillisecondType, +}; +use arrow::{ + compute::cast, datatypes::Time32SecondType, datatypes::Time64MicrosecondType, + datatypes::Time64NanosecondType, }; -use arrow::compute::cast; use std::convert::From; use std::sync::Arc; use crate::errors::Result; -use arrow::datatypes::{ArrowPrimitiveType, DataType as ArrowDataType}; +use arrow::datatypes::{ + ArrowDictionaryKeyType, ArrowPrimitiveType, DataType as ArrowDataType, +}; use arrow::array::ArrayDataBuilder; use arrow::array::{ - BinaryArray, FixedSizeBinaryArray, PrimitiveArray, StringArray, + BinaryArray, DictionaryArray, FixedSizeBinaryArray, LargeBinaryArray, + LargeStringArray, PrimitiveArray, PrimitiveArrayOps, StringArray, TimestampNanosecondArray, }; use std::marker::PhantomData; @@ -101,7 +112,9 @@ where let primitive_array: ArrayRef = Arc::new(PrimitiveArray::::from(array_data.build())); - Ok(cast(&primitive_array, &ArrowTargetType::get_data_type())?) + // TODO: We should make this cast redundant in favour of 1 cast to rule them all + // Ok(cast(&primitive_array, &ArrowTargetType::get_data_type())?) + Ok(primitive_array) } } @@ -193,6 +206,27 @@ impl Converter>, StringArray> for Utf8ArrayConverter { } } +pub struct LargeUtf8ArrayConverter {} + +impl Converter>, LargeStringArray> for LargeUtf8ArrayConverter { + fn convert(&self, source: Vec>) -> Result { + let data_size = source + .iter() + .map(|x| x.as_ref().map(|b| b.len()).unwrap_or(0)) + .sum(); + + let mut builder = LargeStringBuilder::with_capacity(source.len(), data_size); + for v in source { + match v { + Some(array) => builder.append_value(array.as_utf8()?), + None => builder.append_null(), + }? + } + + Ok(builder.finish()) + } +} + pub struct BinaryArrayConverter {} impl Converter>, BinaryArray> for BinaryArrayConverter { @@ -209,11 +243,114 @@ impl Converter>, BinaryArray> for BinaryArrayConverter { } } +pub struct LargeBinaryArrayConverter {} + +impl Converter>, LargeBinaryArray> for LargeBinaryArrayConverter { + fn convert(&self, source: Vec>) -> Result { + let mut builder = LargeBinaryBuilder::new(source.len()); + for v in source { + match v { + Some(array) => builder.append_value(array.data()), + None => builder.append_null(), + }? + } + + Ok(builder.finish()) + } +} + +pub struct StringDictionaryArrayConverter {} + +impl Converter>, DictionaryArray> + for StringDictionaryArrayConverter +{ + fn convert(&self, source: Vec>) -> Result> { + let data_size = source + .iter() + .map(|x| x.as_ref().map(|b| b.len()).unwrap_or(0)) + .sum(); + + let keys_builder = PrimitiveBuilder::::new(source.len()); + let values_builder = StringBuilder::with_capacity(source.len(), data_size); + + let mut builder = StringDictionaryBuilder::new(keys_builder, values_builder); + for v in source { + match v { + Some(array) => { + let _ = builder.append(array.as_utf8()?)?; + } + None => builder.append_null()?, + } + } + + Ok(builder.finish()) + } +} + +pub struct DictionaryArrayConverter +{ + _dict_value_source_marker: PhantomData, + _dict_value_target_marker: PhantomData, + _parquet_marker: PhantomData, +} + +impl + DictionaryArrayConverter +{ + pub fn new() -> Self { + Self { + _dict_value_source_marker: PhantomData, + _dict_value_target_marker: PhantomData, + _parquet_marker: PhantomData, + } + } +} + +impl + Converter::T>>, DictionaryArray> + for DictionaryArrayConverter +where + K: ArrowPrimitiveType, + DictValueSourceType: ArrowPrimitiveType, + DictValueTargetType: ArrowPrimitiveType, + ParquetType: DataType, + PrimitiveArray: From::T>>>, +{ + fn convert( + &self, + source: Vec::T>>, + ) -> Result> { + let keys_builder = PrimitiveBuilder::::new(source.len()); + let values_builder = PrimitiveBuilder::::new(source.len()); + + let mut builder = PrimitiveDictionaryBuilder::new(keys_builder, values_builder); + + let source_array: Arc = + Arc::new(PrimitiveArray::::from(source)); + let target_array = cast(&source_array, &DictValueTargetType::get_data_type())?; + let target = target_array + .as_any() + .downcast_ref::>() + .unwrap(); + + for i in 0..target.len() { + if target.is_null(i) { + builder.append_null()?; + } else { + let _ = builder.append(target.value(i))?; + } + } + + Ok(builder.finish()) + } +} + pub type BoolConverter<'a> = ArrayRefConverter< &'a mut RecordReader, BooleanArray, BooleanArrayConverter, >; +// TODO: intuition tells me that removing many of these converters could help us consolidate where we cast pub type Int8Converter = CastConverter; pub type UInt8Converter = CastConverter; pub type Int16Converter = CastConverter; @@ -226,13 +363,44 @@ pub type TimestampMillisecondConverter = CastConverter; pub type TimestampMicrosecondConverter = CastConverter; +pub type Time32SecondConverter = + CastConverter; +pub type Time32MillisecondConverter = + CastConverter; +pub type Time64MicrosecondConverter = + CastConverter; +pub type Time64NanosecondConverter = + CastConverter; pub type UInt64Converter = CastConverter; pub type Float32Converter = CastConverter; pub type Float64Converter = CastConverter; pub type Utf8Converter = ArrayRefConverter>, StringArray, Utf8ArrayConverter>; +pub type LargeUtf8Converter = + ArrayRefConverter>, LargeStringArray, LargeUtf8ArrayConverter>; pub type BinaryConverter = ArrayRefConverter>, BinaryArray, BinaryArrayConverter>; +pub type LargeBinaryConverter = ArrayRefConverter< + Vec>, + LargeBinaryArray, + LargeBinaryArrayConverter, +>; +pub type StringDictionaryConverter = ArrayRefConverter< + Vec>, + DictionaryArray, + StringDictionaryArrayConverter, +>; +pub type DictionaryConverter = ArrayRefConverter< + Vec::T>>, + DictionaryArray, + DictionaryArrayConverter, +>; +pub type PrimitiveDictionaryConverter = ArrayRefConverter< + Vec::T>>, + DictionaryArray, + DictionaryArrayConverter, +>; + pub type Int96Converter = ArrayRefConverter>, TimestampNanosecondArray, Int96ArrayConverter>; pub type FixedLenBinaryConverter = ArrayRefConverter< @@ -351,7 +519,10 @@ mod tests { } #[test] + #[ignore = "We need to look at whether this is still relevant after we refactor out the casts"] fn test_converter_arrow_source_i16_target_i32() { + // TODO: this fails if we remove the cast here on converter. Is it still relevant? + // I'd favour removing these Parquet::PHYSICAL > Arrow::DataType, so we can do it in 1 pleace. let raw_data = vec![Some(1i16), None, Some(2i16), Some(3i16)]; converter_arrow_source_target!(raw_data, "INT32", Int16Type, Int16Converter) } diff --git a/rust/parquet/src/arrow/mod.rs b/rust/parquet/src/arrow/mod.rs index ef1544d65bb..979345722d2 100644 --- a/rust/parquet/src/arrow/mod.rs +++ b/rust/parquet/src/arrow/mod.rs @@ -35,7 +35,7 @@ //! //! println!("Converted arrow schema is: {}", arrow_reader.get_schema().unwrap()); //! println!("Arrow schema after projection is: {}", -//! arrow_reader.get_schema_by_columns(vec![2, 4, 6]).unwrap()); +//! arrow_reader.get_schema_by_columns(vec![2, 4, 6], true).unwrap()); //! //! let mut record_batch_reader = arrow_reader.get_record_reader(2048).unwrap(); //! @@ -51,10 +51,18 @@ pub(in crate::arrow) mod array_reader; pub mod arrow_reader; +pub mod arrow_writer; pub(in crate::arrow) mod converter; pub(in crate::arrow) mod record_reader; pub mod schema; pub use self::arrow_reader::ArrowReader; pub use self::arrow_reader::ParquetFileArrowReader; -pub use self::schema::{parquet_to_arrow_schema, parquet_to_arrow_schema_by_columns}; +pub use self::arrow_writer::ArrowWriter; +pub use self::schema::{ + arrow_to_parquet_schema, parquet_to_arrow_schema, parquet_to_arrow_schema_by_columns, + parquet_to_arrow_schema_by_root_columns, +}; + +/// Schema metadata key used to store serialized Arrow IPC schema +pub const ARROW_SCHEMA_META_KEY: &str = "ARROW:schema"; diff --git a/rust/parquet/src/arrow/record_reader.rs b/rust/parquet/src/arrow/record_reader.rs index ccfdaf8f0e5..b30ab7760b2 100644 --- a/rust/parquet/src/arrow/record_reader.rs +++ b/rust/parquet/src/arrow/record_reader.rs @@ -86,6 +86,7 @@ impl<'a, T> FatPtr<'a, T> { self.ptr } + #[allow(clippy::wrong_self_convention)] fn to_slice_mut(&mut self) -> &mut [T] { self.ptr } diff --git a/rust/parquet/src/arrow/schema.rs b/rust/parquet/src/arrow/schema.rs index aebb9e776cc..10270fff464 100644 --- a/rust/parquet/src/arrow/schema.rs +++ b/rust/parquet/src/arrow/schema.rs @@ -26,27 +26,91 @@ use std::collections::{HashMap, HashSet}; use std::rc::Rc; +use arrow::datatypes::{DataType, DateUnit, Field, Schema, TimeUnit}; +use arrow::ipc::writer; + use crate::basic::{LogicalType, Repetition, Type as PhysicalType}; use crate::errors::{ParquetError::ArrowError, Result}; -use crate::file::metadata::KeyValue; +use crate::file::{metadata::KeyValue, properties::WriterProperties}; use crate::schema::types::{ColumnDescriptor, SchemaDescriptor, Type, TypePtr}; -use arrow::datatypes::TimeUnit; -use arrow::datatypes::{DataType, DateUnit, Field, Schema}; - -/// Convert parquet schema to arrow schema including optional metadata. +/// Convert Parquet schema to Arrow schema including optional metadata. +/// Attempts to decode any existing Arrow shcema metadata, falling back +/// to converting the Parquet schema column-wise pub fn parquet_to_arrow_schema( parquet_schema: &SchemaDescriptor, - metadata: &Option>, + key_value_metadata: &Option>, ) -> Result { - parquet_to_arrow_schema_by_columns( - parquet_schema, - 0..parquet_schema.columns().len(), - metadata, - ) + let mut metadata = parse_key_value_metadata(key_value_metadata).unwrap_or_default(); + let arrow_schema_metadata = metadata + .remove(super::ARROW_SCHEMA_META_KEY) + .map(|encoded| get_arrow_schema_from_metadata(&encoded)); + + match arrow_schema_metadata { + Some(Some(schema)) => Ok(schema), + _ => parquet_to_arrow_schema_by_columns( + parquet_schema, + 0..parquet_schema.columns().len(), + key_value_metadata, + ), + } +} + +/// Convert parquet schema to arrow schema including optional metadata, +/// only preserving some root columns. +/// This is useful if we have columns `a.b`, `a.c.e` and `a.d`, +/// and want `a` with all its child fields +pub fn parquet_to_arrow_schema_by_root_columns( + parquet_schema: &SchemaDescriptor, + column_indices: T, + key_value_metadata: &Option>, +) -> Result +where + T: IntoIterator, +{ + // Reconstruct the index ranges of the parent columns + // An Arrow struct gets represented by 1+ columns based on how many child fields the + // struct has. This means that getting fields 1 and 2 might return the struct twice, + // if field 1 is the struct having say 3 fields, and field 2 is a primitive. + // + // The below gets the parent columns, and counts the number of child fields in each parent, + // such that we would end up with: + // - field 1 - columns: [0, 1, 2] + // - field 2 - columns: [3] + let mut parent_columns = vec![]; + let mut curr_name = ""; + let mut prev_name = ""; + let mut indices = vec![]; + (0..(parquet_schema.num_columns())).for_each(|i| { + let p_type = parquet_schema.get_column_root(i); + curr_name = p_type.get_basic_info().name(); + if prev_name == "" { + // first index + indices.push(i); + prev_name = curr_name; + } else if curr_name != prev_name { + prev_name = curr_name; + parent_columns.push((curr_name.to_string(), indices.clone())); + indices = vec![i]; + } else { + indices.push(i); + } + }); + // push the last column if indices has values + if !indices.is_empty() { + parent_columns.push((curr_name.to_string(), indices)); + } + + // gather the required leaf columns + let leaf_columns = column_indices + .into_iter() + .flat_map(|i| parent_columns[i].1.clone()); + + parquet_to_arrow_schema_by_columns(parquet_schema, leaf_columns, key_value_metadata) } -/// Convert parquet schema to arrow schema including optional metadata, only preserving some leaf columns. +/// Convert parquet schema to arrow schema including optional metadata, +/// only preserving some leaf columns. pub fn parquet_to_arrow_schema_by_columns( parquet_schema: &SchemaDescriptor, column_indices: T, @@ -55,32 +119,136 @@ pub fn parquet_to_arrow_schema_by_columns( where T: IntoIterator, { + let mut metadata = parse_key_value_metadata(key_value_metadata).unwrap_or_default(); + let arrow_schema_metadata = metadata + .remove(super::ARROW_SCHEMA_META_KEY) + .map(|encoded| get_arrow_schema_from_metadata(&encoded)) + .unwrap_or_default(); + + // add the Arrow metadata to the Parquet metadata + if let Some(arrow_schema) = &arrow_schema_metadata { + arrow_schema.metadata().iter().for_each(|(k, v)| { + metadata.insert(k.clone(), v.clone()); + }); + } + let mut base_nodes = Vec::new(); let mut base_nodes_set = HashSet::new(); let mut leaves = HashSet::new(); + enum FieldType<'a> { + Parquet(&'a Type), + Arrow(Field), + } + for c in column_indices { - let column = parquet_schema.column(c).self_type() as *const Type; - let root = parquet_schema.get_column_root(c); - let root_raw_ptr = root as *const Type; - - leaves.insert(column); - if !base_nodes_set.contains(&root_raw_ptr) { - base_nodes.push(root); - base_nodes_set.insert(root_raw_ptr); + let column = parquet_schema.column(c); + let name = column.name(); + + if let Some(field) = arrow_schema_metadata + .as_ref() + .and_then(|schema| schema.field_with_name(name).ok().cloned()) + { + base_nodes.push(FieldType::Arrow(field)); + } else { + let column = column.self_type() as *const Type; + let root = parquet_schema.get_column_root(c); + let root_raw_ptr = root as *const Type; + + leaves.insert(column); + if !base_nodes_set.contains(&root_raw_ptr) { + base_nodes.push(FieldType::Parquet(root)); + base_nodes_set.insert(root_raw_ptr); + } } } - let metadata = parse_key_value_metadata(key_value_metadata).unwrap_or_default(); - base_nodes .into_iter() - .map(|t| ParquetTypeConverter::new(t, &leaves).to_field()) + .map(|t| match t { + FieldType::Parquet(t) => ParquetTypeConverter::new(t, &leaves).to_field(), + FieldType::Arrow(f) => Ok(Some(f)), + }) .collect::>>>() .map(|result| result.into_iter().filter_map(|f| f).collect::>()) .map(|fields| Schema::new_with_metadata(fields, metadata)) } +/// Try to convert Arrow schema metadata into a schema +fn get_arrow_schema_from_metadata(encoded_meta: &str) -> Option { + let decoded = base64::decode(encoded_meta); + match decoded { + Ok(bytes) => { + let slice = if bytes[0..4] == [255u8; 4] { + &bytes[8..] + } else { + bytes.as_slice() + }; + let message = arrow::ipc::get_root_as_message(slice); + message + .header_as_schema() + .map(arrow::ipc::convert::fb_to_schema) + } + Err(err) => { + // The C++ implementation returns an error if the schema can't be parsed. + // To prevent this, we explicitly log this, then compute the schema without the metadata + eprintln!( + "Unable to decode the encoded schema stored in {}, {:?}", + super::ARROW_SCHEMA_META_KEY, + err + ); + None + } + } +} + +/// Encodes the Arrow schema into the IPC format, and base64 encodes it +fn encode_arrow_schema(schema: &Schema) -> String { + let options = writer::IpcWriteOptions::default(); + let mut serialized_schema = arrow::ipc::writer::schema_to_bytes(&schema, &options); + + // manually prepending the length to the schema as arrow uses the legacy IPC format + // TODO: change after addressing ARROW-9777 + let schema_len = serialized_schema.ipc_message.len(); + let mut len_prefix_schema = Vec::with_capacity(schema_len + 8); + len_prefix_schema.append(&mut vec![255u8, 255, 255, 255]); + len_prefix_schema.append((schema_len as u32).to_le_bytes().to_vec().as_mut()); + len_prefix_schema.append(&mut serialized_schema.ipc_message); + + base64::encode(&len_prefix_schema) +} + +/// Mutates writer metadata by storing the encoded Arrow schema. +/// If there is an existing Arrow schema metadata, it is replaced. +pub(crate) fn add_encoded_arrow_schema_to_metadata( + schema: &Schema, + props: &mut WriterProperties, +) { + let encoded = encode_arrow_schema(schema); + + let schema_kv = KeyValue { + key: super::ARROW_SCHEMA_META_KEY.to_string(), + value: Some(encoded), + }; + + let mut meta = props.key_value_metadata.clone().unwrap_or_default(); + // check if ARROW:schema exists, and overwrite it + let schema_meta = meta + .iter() + .enumerate() + .find(|(_, kv)| kv.key.as_str() == super::ARROW_SCHEMA_META_KEY); + match schema_meta { + Some((i, _)) => { + meta.remove(i); + meta.push(schema_kv); + } + None => { + meta.push(schema_kv); + } + } + props.key_value_metadata = Some(meta); +} + /// Convert arrow schema to parquet schema pub fn arrow_to_parquet_schema(schema: &Schema) -> Result { let fields: Result> = schema @@ -140,7 +308,10 @@ fn arrow_to_parquet_type(field: &Field) -> Result { }; // create type from field match field.data_type() { - DataType::Null => Err(ArrowError("Null arrays not supported".to_string())), + DataType::Null => Type::primitive_type_builder(name, PhysicalType::INT32) + .with_logical_type(LogicalType::NONE) + .with_repetition(repetition) + .build(), DataType::Boolean => Type::primitive_type_builder(name, PhysicalType::BOOLEAN) .with_repetition(repetition) .build(), @@ -215,42 +386,48 @@ fn arrow_to_parquet_type(field: &Field) -> Result { Type::primitive_type_builder(name, PhysicalType::FIXED_LEN_BYTE_ARRAY) .with_logical_type(LogicalType::INTERVAL) .with_repetition(repetition) - .with_length(3) + .with_length(12) + .build() + } + DataType::Binary | DataType::LargeBinary => { + Type::primitive_type_builder(name, PhysicalType::BYTE_ARRAY) + .with_repetition(repetition) .build() } - DataType::Binary => Type::primitive_type_builder(name, PhysicalType::BYTE_ARRAY) - .with_repetition(repetition) - .build(), DataType::FixedSizeBinary(length) => { Type::primitive_type_builder(name, PhysicalType::FIXED_LEN_BYTE_ARRAY) .with_repetition(repetition) .with_length(*length) .build() } - DataType::Utf8 => Type::primitive_type_builder(name, PhysicalType::BYTE_ARRAY) - .with_logical_type(LogicalType::UTF8) - .with_repetition(repetition) - .build(), - DataType::List(dtype) | DataType::FixedSizeList(dtype, _) => { - Type::group_type_builder(name) - .with_fields(&mut vec![Rc::new( - Type::group_type_builder("list") - .with_fields(&mut vec![Rc::new({ - let list_field = Field::new( - "element", - *dtype.clone(), - field.is_nullable(), - ); - arrow_to_parquet_type(&list_field)? - })]) - .with_repetition(Repetition::REPEATED) - .build()?, - )]) - .with_logical_type(LogicalType::LIST) - .with_repetition(Repetition::REQUIRED) + DataType::Utf8 | DataType::LargeUtf8 => { + Type::primitive_type_builder(name, PhysicalType::BYTE_ARRAY) + .with_logical_type(LogicalType::UTF8) + .with_repetition(repetition) .build() } + DataType::List(dtype) + | DataType::FixedSizeList(dtype, _) + | DataType::LargeList(dtype) => Type::group_type_builder(name) + .with_fields(&mut vec![Rc::new( + Type::group_type_builder("list") + .with_fields(&mut vec![Rc::new({ + let list_field = + Field::new("element", *dtype.clone(), field.is_nullable()); + arrow_to_parquet_type(&list_field)? + })]) + .with_repetition(Repetition::REPEATED) + .build()?, + )]) + .with_logical_type(LogicalType::LIST) + .with_repetition(Repetition::REQUIRED) + .build(), DataType::Struct(fields) => { + if fields.is_empty() { + return Err(ArrowError( + "Parquet does not support writing empty structs".to_string(), + )); + } // recursively convert children to types/nodes let fields: Result> = fields .iter() @@ -267,9 +444,6 @@ fn arrow_to_parquet_type(field: &Field) -> Result { let dict_field = Field::new(name, *value.clone(), field.is_nullable()); arrow_to_parquet_type(&dict_field) } - DataType::LargeUtf8 | DataType::LargeBinary | DataType::LargeList(_) => { - Err(ArrowError("Large arrays not supported".to_string())) - } } } /// This struct is used to group methods and data structures used to convert parquet @@ -555,12 +729,16 @@ impl ParquetTypeConverter<'_> { mod tests { use super::*; - use std::collections::HashMap; + use std::{collections::HashMap, convert::TryFrom, sync::Arc}; - use arrow::datatypes::{DataType, DateUnit, Field, TimeUnit}; + use arrow::datatypes::{DataType, DateUnit, Field, IntervalUnit, TimeUnit}; - use crate::file::metadata::KeyValue; - use crate::schema::{parser::parse_message_type, types::SchemaDescriptor}; + use crate::file::{metadata::KeyValue, reader::SerializedFileReader}; + use crate::{ + arrow::{ArrowReader, ArrowWriter, ParquetFileArrowReader}, + schema::{parser::parse_message_type, types::SchemaDescriptor}, + util::test_common::get_temp_file, + }; #[test] fn test_flat_primitives() { @@ -1194,6 +1372,17 @@ mod tests { }); } + #[test] + #[should_panic(expected = "Parquet does not support writing empty structs")] + fn test_empty_struct_field() { + let arrow_fields = vec![Field::new("struct", DataType::Struct(vec![]), false)]; + let arrow_schema = Schema::new(arrow_fields); + let converted_arrow_schema = arrow_to_parquet_schema(&arrow_schema); + + assert!(converted_arrow_schema.is_err()); + converted_arrow_schema.unwrap(); + } + #[test] fn test_metadata() { let message_type = " @@ -1216,4 +1405,184 @@ mod tests { assert_eq!(converted_arrow_schema.metadata(), &expected_metadata); } + + #[test] + fn test_arrow_schema_roundtrip() -> Result<()> { + // This tests the roundtrip of an Arrow schema + // Fields that are commented out fail roundtrip tests or are unsupported by the writer + let metadata: HashMap = + [("Key".to_string(), "Value".to_string())] + .iter() + .cloned() + .collect(); + + let schema = Schema::new_with_metadata( + vec![ + Field::new("c1", DataType::Utf8, false), + Field::new("c2", DataType::Binary, false), + Field::new("c3", DataType::FixedSizeBinary(3), false), + Field::new("c4", DataType::Boolean, false), + Field::new("c5", DataType::Date32(DateUnit::Day), false), + Field::new("c6", DataType::Date64(DateUnit::Millisecond), false), + Field::new("c7", DataType::Time32(TimeUnit::Second), false), + Field::new("c8", DataType::Time32(TimeUnit::Millisecond), false), + Field::new("c13", DataType::Time64(TimeUnit::Microsecond), false), + Field::new("c14", DataType::Time64(TimeUnit::Nanosecond), false), + Field::new("c15", DataType::Timestamp(TimeUnit::Second, None), false), + Field::new( + "c16", + DataType::Timestamp( + TimeUnit::Millisecond, + Some(Arc::new("UTC".to_string())), + ), + false, + ), + Field::new( + "c17", + DataType::Timestamp( + TimeUnit::Microsecond, + Some(Arc::new("Africa/Johannesburg".to_string())), + ), + false, + ), + Field::new( + "c18", + DataType::Timestamp(TimeUnit::Nanosecond, None), + false, + ), + Field::new("c19", DataType::Interval(IntervalUnit::DayTime), false), + Field::new("c20", DataType::Interval(IntervalUnit::YearMonth), false), + Field::new("c21", DataType::List(Box::new(DataType::Boolean)), false), + // Field::new( + // "c22", + // DataType::FixedSizeList(Box::new(DataType::Boolean), 5), + // false, + // ), + // Field::new( + // "c23", + // DataType::List(Box::new(DataType::LargeList(Box::new( + // DataType::Struct(vec![ + // Field::new("a", DataType::Int16, true), + // Field::new("b", DataType::Float64, false), + // ]), + // )))), + // true, + // ), + Field::new( + "c24", + DataType::Struct(vec![ + Field::new("a", DataType::Utf8, false), + Field::new("b", DataType::UInt16, false), + ]), + false, + ), + Field::new("c25", DataType::Interval(IntervalUnit::YearMonth), true), + Field::new("c26", DataType::Interval(IntervalUnit::DayTime), true), + // Field::new("c27", DataType::Duration(TimeUnit::Second), false), + // Field::new("c28", DataType::Duration(TimeUnit::Millisecond), false), + // Field::new("c29", DataType::Duration(TimeUnit::Microsecond), false), + // Field::new("c30", DataType::Duration(TimeUnit::Nanosecond), false), + Field::new_dict( + "c31", + DataType::Dictionary( + Box::new(DataType::Int32), + Box::new(DataType::Utf8), + ), + true, + 123, + true, + ), + Field::new("c32", DataType::LargeBinary, true), + Field::new("c33", DataType::LargeUtf8, true), + // Field::new( + // "c34", + // DataType::LargeList(Box::new(DataType::List(Box::new( + // DataType::Struct(vec![ + // Field::new("a", DataType::Int16, true), + // Field::new("b", DataType::Float64, true), + // ]), + // )))), + // true, + // ), + Field::new("c35", DataType::Null, true), + ], + metadata, + ); + + // write to an empty parquet file so that schema is serialized + let file = get_temp_file("test_arrow_schema_roundtrip.parquet", &[]); + let mut writer = ArrowWriter::try_new( + file.try_clone().unwrap(), + Arc::new(schema.clone()), + None, + )?; + writer.close()?; + + // read file back + let parquet_reader = SerializedFileReader::try_from(file)?; + let mut arrow_reader = ParquetFileArrowReader::new(Rc::new(parquet_reader)); + let read_schema = arrow_reader.get_schema()?; + assert_eq!(schema, read_schema); + + // read all fields by columns + let partial_read_schema = + arrow_reader.get_schema_by_columns(0..(schema.fields().len()), false)?; + assert_eq!(schema, partial_read_schema); + + Ok(()) + } + + #[test] + #[ignore = "Roundtrip of lists currently fails because we don't check their types correctly in the Arrow schema"] + fn test_arrow_schema_roundtrip_lists() -> Result<()> { + let metadata: HashMap = + [("Key".to_string(), "Value".to_string())] + .iter() + .cloned() + .collect(); + + let schema = Schema::new_with_metadata( + vec![ + Field::new("c21", DataType::List(Box::new(DataType::Boolean)), false), + Field::new( + "c22", + DataType::FixedSizeList(Box::new(DataType::Boolean), 5), + false, + ), + Field::new( + "c23", + DataType::List(Box::new(DataType::LargeList(Box::new( + DataType::Struct(vec![ + Field::new("a", DataType::Int16, true), + Field::new("b", DataType::Float64, false), + ]), + )))), + true, + ), + ], + metadata, + ); + + // write to an empty parquet file so that schema is serialized + let file = get_temp_file("test_arrow_schema_roundtrip_lists.parquet", &[]); + let mut writer = ArrowWriter::try_new( + file.try_clone().unwrap(), + Arc::new(schema.clone()), + None, + )?; + writer.close()?; + + // read file back + let parquet_reader = SerializedFileReader::try_from(file)?; + let mut arrow_reader = ParquetFileArrowReader::new(Rc::new(parquet_reader)); + let read_schema = arrow_reader.get_schema()?; + assert_eq!(schema, read_schema); + + // read all fields by columns + let partial_read_schema = + arrow_reader.get_schema_by_columns(0..(schema.fields().len()), false)?; + assert_eq!(schema, partial_read_schema); + + Ok(()) + } } diff --git a/rust/parquet/src/file/properties.rs b/rust/parquet/src/file/properties.rs index 188d6ec3c9e..b62ce7bbc38 100644 --- a/rust/parquet/src/file/properties.rs +++ b/rust/parquet/src/file/properties.rs @@ -89,8 +89,8 @@ pub type WriterPropertiesPtr = Rc; /// Writer properties. /// -/// It is created as an immutable data structure, use [`WriterPropertiesBuilder`] to -/// assemble the properties. +/// All properties except the key-value metadata are immutable, +/// use [`WriterPropertiesBuilder`] to assemble these properties. #[derive(Debug, Clone)] pub struct WriterProperties { data_pagesize_limit: usize, @@ -99,7 +99,7 @@ pub struct WriterProperties { max_row_group_size: usize, writer_version: WriterVersion, created_by: String, - key_value_metadata: Option>, + pub(crate) key_value_metadata: Option>, default_column_properties: ColumnProperties, column_properties: HashMap, } diff --git a/rust/parquet/src/schema/types.rs b/rust/parquet/src/schema/types.rs index 416073af035..57999050ab3 100644 --- a/rust/parquet/src/schema/types.rs +++ b/rust/parquet/src/schema/types.rs @@ -788,7 +788,7 @@ impl SchemaDescriptor { result.clone() } - fn column_root_of(&self, i: usize) -> &Rc { + fn column_root_of(&self, i: usize) -> &TypePtr { assert!( i < self.leaves.len(), "Index out of bound: {} not in [0, {})", @@ -810,6 +810,10 @@ impl SchemaDescriptor { self.schema.as_ref() } + pub fn root_schema_ptr(&self) -> TypePtr { + self.schema.clone() + } + /// Returns schema name. pub fn name(&self) -> &str { self.schema.name() diff --git a/rust/parquet/src/schema/visitor.rs b/rust/parquet/src/schema/visitor.rs index 6d712ce441f..a1866fb1471 100644 --- a/rust/parquet/src/schema/visitor.rs +++ b/rust/parquet/src/schema/visitor.rs @@ -50,7 +50,7 @@ pub trait TypeVisitor { { self.visit_list_with_item( list_type.clone(), - list_item, + list_item.clone(), context, ) } else { @@ -70,13 +70,13 @@ pub trait TypeVisitor { { self.visit_list_with_item( list_type.clone(), - fields.first().unwrap(), + fields.first().unwrap().clone(), context, ) } else { self.visit_list_with_item( list_type.clone(), - list_item, + list_item.clone(), context, ) } @@ -114,7 +114,7 @@ pub trait TypeVisitor { fn visit_list_with_item( &mut self, list_type: TypePtr, - item_type: &Type, + item_type: TypePtr, context: C, ) -> Result; } @@ -125,7 +125,7 @@ mod tests { use crate::basic::Type as PhysicalType; use crate::errors::Result; use crate::schema::parser::parse_message_type; - use crate::schema::types::{Type, TypePtr}; + use crate::schema::types::TypePtr; use std::rc::Rc; struct TestVisitorContext {} @@ -174,7 +174,7 @@ mod tests { fn visit_list_with_item( &mut self, list_type: TypePtr, - item_type: &Type, + item_type: TypePtr, _context: TestVisitorContext, ) -> Result { assert_eq!(