diff --git a/r/src/array_from_vector.cpp b/r/src/array_from_vector.cpp index 19322309a6b..63f1307f1c0 100644 --- a/r/src/array_from_vector.cpp +++ b/r/src/array_from_vector.cpp @@ -351,14 +351,19 @@ std::shared_ptr MakeFactorArrayImpl(Rcpp::IntegerVector_ factor, std::shared_ptr MakeFactorArray(Rcpp::IntegerVector_ factor, const std::shared_ptr& type) { - SEXP levels = factor.attr("levels"); - int n = Rf_length(levels); - if (n < 128) { - return MakeFactorArrayImpl(factor, type); - } else if (n < 32768) { - return MakeFactorArrayImpl(factor, type); - } else { - return MakeFactorArrayImpl(factor, type); + const auto& dict_type = checked_cast(*type); + switch (dict_type.index_type()->id()) { + case Type::INT8: + return MakeFactorArrayImpl(factor, type); + case Type::INT16: + return MakeFactorArrayImpl(factor, type); + case Type::INT32: + return MakeFactorArrayImpl(factor, type); + case Type::INT64: + return MakeFactorArrayImpl(factor, type); + default: + Rcpp::stop(tfm::format("Cannot convert to dictionary with index_type %s", + dict_type.index_type()->ToString())); } } @@ -1297,8 +1302,8 @@ bool CheckCompatibleFactor(SEXP obj, const std::shared_ptr& typ return false; } - auto* dict_type = checked_cast(type.get()); - return dict_type->value_type()->Equals(utf8()); + const auto& dict_type = checked_cast(*type); + return dict_type.value_type()->Equals(utf8()); } arrow::Status CheckCompatibleStruct(SEXP obj, diff --git a/r/src/array_to_vector.cpp b/r/src/array_to_vector.cpp index 9586c839970..df6f3b34ed0 100644 --- a/r/src/array_to_vector.cpp +++ b/r/src/array_to_vector.cpp @@ -357,6 +357,7 @@ class Converter_Dictionary : public Converter { case Type::UINT16: case Type::INT16: case Type::INT32: + // TODO: also add int64, uint32, uint64 downcasts, if possible break; default: Rcpp::stop("Cannot convert Dictionary Array of type `%s` to R", diff --git a/r/tests/testthat/test-Table.R b/r/tests/testthat/test-Table.R index 726fa96afee..aa23789d031 100644 --- a/r/tests/testthat/test-Table.R +++ b/r/tests/testthat/test-Table.R @@ -321,3 +321,19 @@ test_that("Table handles null type (ARROW-7064)", { tab <- Table$create(a = 1:10, n = vctrs::unspecified(10)) expect_equivalent(tab$schema, schema(a = int32(), n = null())) }) + +test_that("Can create table with specific dictionary types", { + fact <- example_data[,"fct"] + int_types <- c(int8(), int16(), int32(), int64()) + # TODO: test uint types when format allows + # uint_types <- c(uint8(), uint16(), uint32(), uint64()) + for (i in int_types) { + sch <- schema(fct = dictionary(i, utf8())) + tab <- Table$create(fact, schema = sch) + expect_equal(sch, tab$schema) + if (i != int64()) { + # TODO: same downcast to int32 as we do for int64() type elsewhere + expect_identical(as.data.frame(tab), fact) + } + } +})