Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
242 changes: 122 additions & 120 deletions cpp/src/arrow/python/builtin_convert.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

#include <algorithm>
#include <limits>
#include <map>
#include <sstream>
#include <string>
#include <utility>
Expand Down Expand Up @@ -49,9 +50,11 @@ Status InvalidConversion(PyObject* obj, const std::string& expected_types,
return Status::OK();
}

class ScalarVisitor {
class TypeInferrer {
// A type inference visitor for Python values

public:
ScalarVisitor()
TypeInferrer()
: total_count_(0),
none_count_(0),
bool_count_(0),
Expand All @@ -62,14 +65,46 @@ class ScalarVisitor {
binary_count_(0),
unicode_count_(0),
decimal_count_(0),
list_count_(0),
struct_count_(0),
max_decimal_metadata_(std::numeric_limits<int32_t>::min(),
std::numeric_limits<int32_t>::min()),
decimal_type_() {
PyAcquireGIL lock;
Status status = internal::ImportDecimalType(&decimal_type_);
DCHECK_OK(status);
}

// Infer value type from a sequence of values.
//
// Every element of `obj` is forwarded to Visit(). Two input kinds are handled:
//  - NumPy arrays (PyArray_Check): elements are fetched one at a time with
//    PyArray_GETPTR1 / PyArray_GETITEM
//  - anything satisfying PySequence_Check: materialized with PySequence_Fast
//    and then read by index
// Any other object yields Status::TypeError. The first Python error or
// non-OK status from Visit() stops the traversal and is returned.
Status VisitSequence(PyObject* obj) {
  if (PyArray_Check(obj)) {
    Py_ssize_t size = PySequence_Size(obj);
    // PySequence_Size returns -1 with a Python exception set on failure.
    // Without this check the loop below would simply not run and we would
    // return OK while leaving the exception pending.
    RETURN_IF_PYERROR();

    // The cast does not depend on the loop index, so perform it once
    auto array = reinterpret_cast<PyArrayObject*>(obj);
    OwnedRef value_ref;

    for (Py_ssize_t i = 0; i < size; ++i) {
      auto ptr = reinterpret_cast<const char*>(PyArray_GETPTR1(array, i));

      value_ref.reset(PyArray_GETITEM(array, ptr));
      RETURN_IF_PYERROR();
      RETURN_NOT_OK(Visit(value_ref.obj()));
    }
  } else if (PySequence_Check(obj)) {
    OwnedRef seq_ref(PySequence_Fast(obj, "Object is not a sequence or iterable"));
    RETURN_IF_PYERROR();
    PyObject* seq = seq_ref.obj();

    Py_ssize_t size = PySequence_Fast_GET_SIZE(seq);
    for (Py_ssize_t i = 0; i < size; ++i) {
      // Items are read straight out of `seq`; no separate reference is taken
      PyObject* value = PySequence_Fast_GET_ITEM(seq, i);
      RETURN_NOT_OK(Visit(value));
    }
  } else {
    return Status::TypeError("Object is not a sequence or iterable");
  }
  return Status::OK();
}

Status Visit(PyObject* obj) {
++total_count_;
if (obj == Py_None || internal::PyFloat_IsNaN(obj)) {
Expand Down Expand Up @@ -103,6 +138,10 @@ class ScalarVisitor {
ss << type->ToString();
return Status::Invalid(ss.str());
}
} else if (PyList_Check(obj) || PyArray_Check(obj)) {
return VisitList(obj);
} else if (PyDict_Check(obj)) {
return VisitDict(obj);
} else if (PyObject_IsInstance(obj, decimal_type_.obj())) {
RETURN_NOT_OK(max_decimal_metadata_.Update(obj));
++decimal_count_;
Expand All @@ -118,14 +157,36 @@ class ScalarVisitor {
return Status::OK();
}

std::shared_ptr<DataType> GetType() {
// Check that the values observed so far are mutually compatible.
//
// If any list was seen, every value must be either a list or one counted in
// none_count_, and the nested list inferrer must validate as well. Otherwise,
// if any struct was seen, the same rule applies to structs and each per-key
// inferrer is validated. In all other cases the result is OK.
Status Validate() const {
  const bool saw_lists = list_count_ > 0;
  const bool saw_structs = struct_count_ > 0;

  if (saw_lists) {
    const bool only_lists_and_nulls = (list_count_ + none_count_ == total_count_);
    if (!only_lists_and_nulls) {
      return Status::Invalid("cannot mix list and non-list, non-null values");
    }
    RETURN_NOT_OK(list_inferrer_->Validate());
    return Status::OK();
  }

  if (!saw_structs) {
    return Status::OK();
  }

  const bool only_structs_and_nulls = (struct_count_ + none_count_ == total_count_);
  if (!only_structs_and_nulls) {
    return Status::Invalid("cannot mix struct and non-struct, non-null values");
  }
  for (const auto& entry : struct_inferrers_) {
    const TypeInferrer& child = entry.second;
    RETURN_NOT_OK(child.Validate());
  }
  return Status::OK();
}

std::shared_ptr<DataType> GetType() const {
// TODO(wesm): handling mixed-type cases
if (decimal_count_) {
if (list_count_) {
auto value_type = list_inferrer_->GetType();
DCHECK(value_type != nullptr);
return list(value_type);
} else if (struct_count_) {
return GetStructType();
} else if (decimal_count_) {
return decimal(max_decimal_metadata_.precision(), max_decimal_metadata_.scale());
} else if (float_count_) {
return float64();
} else if (int_count_) {
// TODO(wesm): tighter type later
return int64();
} else if (date_count_) {
return date64();
Expand All @@ -144,6 +205,53 @@ class ScalarVisitor {

int64_t total_count() const { return total_count_; }

protected:
// Record one list-like value: lazily create the nested inferrer on first
// use, bump the list counter, then feed the elements of `obj` to it.
Status VisitList(PyObject* obj) {
  const bool needs_child = (list_inferrer_ == nullptr);
  if (needs_child) {
    list_inferrer_.reset(new TypeInferrer);
  }
  ++list_count_;
  TypeInferrer* child = list_inferrer_.get();
  return child->VisitSequence(obj);
}

// Record one dict as a struct observation.
//
// Each entry's key must be a str or bytes object (TypeError otherwise); it is
// converted to a std::string and used to look up, or create, the inferrer
// dedicated to that field. The entry's value is then visited by that
// inferrer. struct_count_ is only incremented once all entries succeeded.
Status VisitDict(PyObject* obj) {
  Py_ssize_t cursor = 0;
  PyObject* py_key;
  PyObject* py_value;

  while (PyDict_Next(obj, &cursor, &py_key, &py_value)) {
    std::string field_name;
    if (PyUnicode_Check(py_key)) {
      RETURN_NOT_OK(internal::PyUnicode_AsStdString(py_key, &field_name));
    } else if (PyBytes_Check(py_key)) {
      field_name = internal::PyBytes_AsStdString(py_key);
    } else {
      std::stringstream message;
      message << "Expected dict key of type str or bytes, got '";
      message << Py_TYPE(py_key)->tp_name << "'";
      return Status::TypeError(message.str());
    }

    // Find the inferrer for this field, creating it the first time it is seen
    auto entry = struct_inferrers_.find(field_name);
    const bool first_occurrence = (entry == struct_inferrers_.end());
    if (first_occurrence) {
      auto inserted = struct_inferrers_.insert(std::make_pair(field_name, TypeInferrer()));
      entry = inserted.first;
    }
    RETURN_NOT_OK(entry->second.Visit(py_value));
  }

  ++struct_count_;
  return Status::OK();
}

std::shared_ptr<DataType> GetStructType() const {
std::vector<std::shared_ptr<Field>> fields;
for (const auto& it : struct_inferrers_) {
const auto struct_field = field(it.first, it.second.GetType());
fields.emplace_back(struct_field);
}
return struct_(fields);
}

private:
int64_t total_count_;
int64_t none_count_;
Expand All @@ -155,6 +263,10 @@ class ScalarVisitor {
int64_t binary_count_;
int64_t unicode_count_;
int64_t decimal_count_;
int64_t list_count_;
std::unique_ptr<TypeInferrer> list_inferrer_;
int64_t struct_count_;
std::map<std::string, TypeInferrer> struct_inferrers_;

internal::DecimalMetadata max_decimal_metadata_;

Expand All @@ -163,116 +275,6 @@ class ScalarVisitor {
OwnedRefNoGIL decimal_type_;
};

static constexpr int MAX_NESTING_LEVELS = 32;

// SeqVisitor is used to infer the type.
//
// It walks a (possibly nested) Python sequence, forwards every non-list,
// non-dict element to a ScalarVisitor and records at which nesting levels
// such elements were found. GetType() wraps the scalar type in one ListType
// per nesting level; Validate() rejects inputs whose scalars sit at
// different depths.
class SeqVisitor {
 public:
  SeqVisitor() : max_nesting_level_(0), max_observed_level_(0), nesting_histogram_() {
    std::fill(nesting_histogram_, nesting_histogram_ + MAX_NESTING_LEVELS, 0);
  }

  // Visit every element of the sequence `obj`, found at nesting depth `level`.
  // co-recursive with VisitElem
  //
  // Returns TypeError if `obj` is not a sequence, Invalid if `level` is beyond
  // what nesting_histogram_ can record, or the first error raised while
  // fetching or visiting an element.
  Status Visit(PyObject* obj, int level = 0) {
    // nesting_histogram_ has MAX_NESTING_LEVELS slots and is indexed by level
    // in VisitElem; refuse deeper input rather than write out of bounds
    if (level >= MAX_NESTING_LEVELS) {
      return Status::Invalid("Nesting level exceeds the maximum supported depth");
    }
    max_nesting_level_ = std::max(max_nesting_level_, level);

    // Loop through a sequence
    if (!PySequence_Check(obj))
      return Status::TypeError("Object is not a sequence or iterable");

    Py_ssize_t size = PySequence_Size(obj);
    // PySequence_Size returns -1 with a Python exception set on failure;
    // report it instead of skipping the loop with the exception pending
    RETURN_IF_PYERROR();

    for (int64_t i = 0; i < size; ++i) {
      OwnedRef ref;
      if (PyArray_Check(obj)) {
        auto array = reinterpret_cast<PyArrayObject*>(obj);
        auto ptr = reinterpret_cast<const char*>(PyArray_GETPTR1(array, i));

        ref.reset(PyArray_GETITEM(array, ptr));
        RETURN_IF_PYERROR();

        RETURN_NOT_OK(VisitElem(ref, level));
      } else {
        ref.reset(PySequence_GetItem(obj, i));
        RETURN_IF_PYERROR();
        RETURN_NOT_OK(VisitElem(ref, level));
      }
    }
    return Status::OK();
  }

  // Return the inferred type: the scalar type (or null() when no scalar was
  // visited) wrapped in max_nesting_level_ layers of ListType.
  std::shared_ptr<DataType> GetType() {
    // If all the non-list inputs were null (or there were no inputs)
    std::shared_ptr<DataType> result;
    if (scalars_.total_count() == 0) {
      // Lists of Lists of NULL
      result = null();
    } else {
      // Lists of Lists of [X]
      result = scalars_.GetType();
    }
    for (int i = 0; i < max_nesting_level_; ++i) {
      result = std::make_shared<ListType>(result);
    }
    return result;
  }

  // Return Invalid when scalars were observed at more than one nesting level,
  // or when some list nests deeper than the deepest scalar; OK otherwise.
  Status Validate() const {
    if (scalars_.total_count() > 0) {
      if (num_nesting_levels() > 1) {
        return Status::Invalid("Mixed nesting levels not supported");
        // If the nesting goes deeper than the deepest scalar
      } else if (max_observed_level_ < max_nesting_level_) {
        return Status::Invalid("Mixed nesting levels not supported");
      }
    }
    return Status::OK();
  }

  // Returns the number of nesting levels which have scalar elements.
  int num_nesting_levels() const {
    int result = 0;
    for (int i = 0; i < MAX_NESTING_LEVELS; ++i) {
      if (nesting_histogram_[i] > 0) {
        ++result;
      }
    }
    return result;
  }

 private:
  ScalarVisitor scalars_;

  // Track observed
  // Deepest nesting level (regardless of scalars)
  int max_nesting_level_;
  // Deepest nesting level at which a non-list element was visited
  int max_observed_level_;

  // Number of scalar elements at each nesting level.
  // (TODO: We really only need to know if a scalar is present, not the count).
  int nesting_histogram_[MAX_NESTING_LEVELS];

  // Visits a specific element (inner part of the loop).
  // Lists and NumPy arrays recurse into Visit one level deeper, dicts are
  // rejected with NotImplemented, anything else goes to the scalar visitor.
  Status VisitElem(const OwnedRef& item_ref, int level) {
    DCHECK_NE(item_ref.obj(), NULLPTR);
    if (PyList_Check(item_ref.obj()) || PyArray_Check(item_ref.obj())) {
      RETURN_NOT_OK(Visit(item_ref.obj(), level + 1));
    } else if (PyDict_Check(item_ref.obj())) {
      return Status::NotImplemented("No type inference for dicts");
    } else {
      // We permit nulls at any level of nesting, but they aren't treated like
      // other scalar values as far as the checking for mixed nesting structure
      if (item_ref.obj() != Py_None) {
        // `level` is below MAX_NESTING_LEVELS here: Visit rejects deeper levels
        ++nesting_histogram_[level];
      }
      if (level > max_observed_level_) {
        max_observed_level_ = level;
      }
      return scalars_.Visit(item_ref.obj());
    }
    return Status::OK();
  }
};

// Convert *obj* to a sequence if necessary
// Fill *size* to its length. If >= 0 on entry, *size* is an upper size
// bound that may lead to truncation.
Expand Down Expand Up @@ -319,11 +321,11 @@ Status ConvertToSequenceAndInferSize(PyObject* obj, PyObject** seq, int64_t* siz
// Non-exhaustive type inference
Status InferArrowType(PyObject* obj, std::shared_ptr<DataType>* out_type) {
PyDateTime_IMPORT;
SeqVisitor seq_visitor;
RETURN_NOT_OK(seq_visitor.Visit(obj));
RETURN_NOT_OK(seq_visitor.Validate());
TypeInferrer inferrer;
RETURN_NOT_OK(inferrer.VisitSequence(obj));
RETURN_NOT_OK(inferrer.Validate());

*out_type = seq_visitor.GetType();
*out_type = inferrer.GetType();
if (*out_type == nullptr) {
return Status::TypeError("Unable to determine data type");
}
Expand Down
3 changes: 1 addition & 2 deletions python/benchmarks/convert_builtins.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,7 @@ class InferPyListToArray(object):
"""
size = 10 ** 5
types = ('int64', 'float64', 'bool', 'decimal', 'binary', 'ascii',
'unicode', 'int64 list')
# TODO add 'struct' when supported
'unicode', 'int64 list', 'struct')

param_names = ['type']
params = [types]
Expand Down
Loading