Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions cpp/src/arrow/compute/api_scalar.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,18 @@ struct ARROW_EXPORT SplitPatternOptions : public SplitOptions {
std::string pattern;
};

/// Options for the replace-slice string kernels: the half-open range
/// [start, stop) of each input string is substituted with `replacement`.
struct ARROW_EXPORT ReplaceSliceOptions : public FunctionOptions {
  explicit ReplaceSliceOptions(int64_t start, int64_t stop, std::string replacement)
      : start{start}, stop{stop}, replacement{std::move(replacement)} {}

  /// First position covered by the slice (inclusive)
  int64_t start;
  /// Position at which the slice ends (exclusive)
  int64_t stop;
  /// Text substituted for the selected slice
  std::string replacement;
};

struct ARROW_EXPORT ReplaceSubstringOptions : public FunctionOptions {
explicit ReplaceSubstringOptions(std::string pattern, std::string replacement,
int64_t max_replacements = -1)
Expand Down
160 changes: 160 additions & 0 deletions cpp/src/arrow/compute/kernels/scalar_string.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2288,6 +2288,165 @@ const FunctionDoc replace_substring_regex_doc(
{"strings"}, "ReplaceSubstringOptions");
#endif

// ----------------------------------------------------------------------
// Replace slice

/// State and output-size estimation shared by the binary and UTF8
/// replace-slice transforms.
struct ReplaceSliceTransformBase : public StringTransformBase {
  using State = OptionsWrapper<ReplaceSliceOptions>;

  /// Non-owning pointer to the kernel options. Only the address of the
  /// constructor argument is stored, so that object must outlive this transform.
  const ReplaceSliceOptions* options;

  explicit ReplaceSliceTransformBase(const ReplaceSliceOptions& options)
      : options{&options} {}

  /// Upper bound on the number of output code units.
  ///
  /// Every one of the `ninputs` strings gains at most one copy of the
  /// replacement on top of the `input_ncodeunits` code units already present.
  int64_t MaxCodeunits(int64_t ninputs, int64_t input_ncodeunits) override {
    // std::string::size() is unsigned; convert once so the arithmetic stays in
    // int64_t instead of being promoted to unsigned and converted back.
    const auto replacement_ncodeunits =
        static_cast<int64_t>(options->replacement.size());
    return ninputs * replacement_ncodeunits + input_ncodeunits;
  }
};

struct BinaryReplaceSliceTransform : ReplaceSliceTransformBase {
using ReplaceSliceTransformBase::ReplaceSliceTransformBase;
int64_t Transform(const uint8_t* input, int64_t input_string_ncodeunits,
uint8_t* output) {
const auto& opts = *options;
int64_t before_slice = 0;
int64_t after_slice = 0;
uint8_t* output_start = output;

if (opts.start >= 0) {
// Count from left
before_slice = std::min<int64_t>(input_string_ncodeunits, opts.start);
} else {
// Count from right
before_slice = std::max<int64_t>(0, input_string_ncodeunits + opts.start);
}
// Mimic Pandas: if stop would be before start, treat as 0-length slice
if (opts.stop >= 0) {
// Count from left
after_slice =
std::min<int64_t>(input_string_ncodeunits, std::max(before_slice, opts.stop));
} else {
// Count from right
after_slice = std::max<int64_t>(before_slice, input_string_ncodeunits + opts.stop);
}
output = std::copy(input, input + before_slice, output);
output = std::copy(opts.replacement.begin(), opts.replacement.end(), output);
output = std::copy(input + after_slice, input + input_string_ncodeunits, output);
return output - output_start;
}
};

struct Utf8ReplaceSliceTransform : ReplaceSliceTransformBase {
using ReplaceSliceTransformBase::ReplaceSliceTransformBase;
int64_t Transform(const uint8_t* input, int64_t input_string_ncodeunits,
uint8_t* output) {
const auto& opts = *options;
const uint8_t* begin = input;
const uint8_t* end = input + input_string_ncodeunits;
const uint8_t *begin_sliced, *end_sliced;
uint8_t* output_start = output;

// Mimic Pandas: if stop would be before start, treat as 0-length slice
if (opts.start >= 0) {
// Count from left
if (!arrow::util::UTF8AdvanceCodepoints(begin, end, &begin_sliced, opts.start)) {
return kTransformError;
}
if (opts.stop > options->start) {
// Continue counting from left
const int64_t length = opts.stop - options->start;
if (!arrow::util::UTF8AdvanceCodepoints(begin_sliced, end, &end_sliced, length)) {
return kTransformError;
}
} else if (opts.stop < 0) {
// Count from right
if (!arrow::util::UTF8AdvanceCodepointsReverse(begin_sliced, end, &end_sliced,
-opts.stop)) {
return kTransformError;
}
} else {
// Zero-length slice
end_sliced = begin_sliced;
}
} else {
// Count from right
if (!arrow::util::UTF8AdvanceCodepointsReverse(begin, end, &begin_sliced,
-opts.start)) {
return kTransformError;
}
if (opts.stop >= 0) {
// Restart counting from left
if (!arrow::util::UTF8AdvanceCodepoints(begin, end, &end_sliced, opts.stop)) {
return kTransformError;
}
if (end_sliced <= begin_sliced) {
// Zero-length slice
end_sliced = begin_sliced;
}
} else if ((opts.stop < 0) && (options->stop > options->start)) {
// Count from right
if (!arrow::util::UTF8AdvanceCodepointsReverse(begin_sliced, end, &end_sliced,
-opts.stop)) {
return kTransformError;
}
} else {
// zero-length slice
end_sliced = begin_sliced;
}
}
output = std::copy(begin, begin_sliced, output);
output = std::copy(opts.replacement.begin(), options->replacement.end(), output);
output = std::copy(end_sliced, end, output);
return output - output_start;
}
};

// Executor alias pairing StringTransformExecWithState with the byte-based
// BinaryReplaceSliceTransform for a given array type.
template <typename Type>
using BinaryReplaceSlice =
StringTransformExecWithState<Type, BinaryReplaceSliceTransform>;
// Executor alias pairing StringTransformExecWithState with the
// codepoint-based Utf8ReplaceSliceTransform for a given array type.
template <typename Type>
using Utf8ReplaceSlice = StringTransformExecWithState<Type, Utf8ReplaceSliceTransform>;

// User-facing documentation for the "binary_replace_slice" function.
// Adjacent literals are concatenated verbatim, so each fragment carries its
// own separating space.
const FunctionDoc binary_replace_slice_doc(
    "Replace a slice of a binary string with `replacement`",
    ("For each string in `strings`, replace a slice of the string defined by `start`"
     " and `stop` with `replacement`. `start` is inclusive and `stop` is exclusive, "
     "and both are measured in bytes.\n"
     "Null values emit null."),
    {"strings"}, "ReplaceSliceOptions");

// User-facing documentation for the "utf8_replace_slice" function.
// Adjacent literals are concatenated verbatim, so each fragment carries its
// own separating space. Positions are described as codepoints because the
// UTF8 kernel resolves them with the UTF8AdvanceCodepoints helpers.
const FunctionDoc utf8_replace_slice_doc(
    "Replace a slice of a string with `replacement`",
    ("For each string in `strings`, replace a slice of the string defined by `start`"
     " and `stop` with `replacement`. `start` is inclusive and `stop` is exclusive, "
     "and both are measured in codepoints.\n"
     "Null values emit null."),
    {"strings"}, "ReplaceSliceOptions");

// Register the "binary_replace_slice" and "utf8_replace_slice" scalar
// functions with `registry`.
void AddReplaceSlice(FunctionRegistry* registry) {
  // Byte-based variant: one kernel for each base binary type.
  auto binary_func = std::make_shared<ScalarFunction>(
      "binary_replace_slice", Arity::Unary(), &binary_replace_slice_doc);
  for (const auto& ty : BaseBinaryTypes()) {
    auto exec = GenerateTypeAgnosticVarBinaryBase<BinaryReplaceSlice>(ty);
    DCHECK_OK(binary_func->AddKernel({ty}, ty, std::move(exec),
                                     ReplaceSliceTransformBase::State::Init));
  }
  DCHECK_OK(registry->AddFunction(std::move(binary_func)));

  // Codepoint-based variant: explicit kernels for utf8 and large_utf8.
  auto utf8_func = std::make_shared<ScalarFunction>("utf8_replace_slice", Arity::Unary(),
                                                    &utf8_replace_slice_doc);
  DCHECK_OK(utf8_func->AddKernel({utf8()}, utf8(), Utf8ReplaceSlice<StringType>::Exec,
                                 ReplaceSliceTransformBase::State::Init));
  DCHECK_OK(utf8_func->AddKernel({large_utf8()}, large_utf8(),
                                 Utf8ReplaceSlice<LargeStringType>::Exec,
                                 ReplaceSliceTransformBase::State::Init));
  DCHECK_OK(registry->AddFunction(std::move(utf8_func)));
}

// ----------------------------------------------------------------------
// Extract with regex

Expand Down Expand Up @@ -3434,6 +3593,7 @@ void RegisterScalarStringAscii(FunctionRegistry* registry) {
MemAllocation::NO_PREALLOCATE);
AddExtractRegex(registry);
#endif
AddReplaceSlice(registry);
AddSlice(registry);
AddSplit(registry);
AddStrptime(registry);
Expand Down
100 changes: 100 additions & 0 deletions cpp/src/arrow/compute/kernels/scalar_string_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,56 @@ TYPED_TEST(TestBinaryKernels, CountSubstring) {
// TODO: case-insensitive
}

// Exercises binary_replace_slice with positive, negative, mixed and inverted
// [start, stop) bounds, on inputs shorter and longer than the slice.
TYPED_TEST(TestBinaryKernels, AsciiReplaceSlice) {
  // Input column shared by most of the cases below
  const char* const kInput = R"([null, "", "a", "ab", "abc", "abcd", "abcde"])";

  // [0, 1): replace the first byte
  ReplaceSliceOptions first_byte{0, 1, "XX"};
  this->CheckUnary("binary_replace_slice", "[]", this->type(), "[]", &first_byte);
  this->CheckUnary("binary_replace_slice", R"([null, "", "a", "ab", "abc"])",
                   this->type(), R"([null, "XX", "XX", "XXb", "XXbc"])", &first_byte);

  // [0, 5): slice at least as long as most of the inputs
  ReplaceSliceOptions first_five{0, 5, "XX"};
  this->CheckUnary("binary_replace_slice",
                   R"([null, "", "a", "ab", "abc", "abcde", "abcdef"])", this->type(),
                   R"([null, "XX", "XX", "XX", "XX", "XX", "XXf"])", &first_five);

  // [2, 4): slice in the middle of the string
  ReplaceSliceOptions middle{2, 4, "XX"};
  this->CheckUnary("binary_replace_slice", kInput, this->type(),
                   R"([null, "XX", "aXX", "abXX", "abXX", "abXX", "abXXe"])", &middle);

  // [-3, -2): both bounds counted from the right
  ReplaceSliceOptions both_negative{-3, -2, "XX"};
  this->CheckUnary("binary_replace_slice", kInput, this->type(),
                   R"([null, "XX", "XXa", "XXab", "XXbc", "aXXcd", "abXXde"])",
                   &both_negative);

  // [2, -2): start from the left, stop from the right
  ReplaceSliceOptions negative_stop{2, -2, "XX"};
  this->CheckUnary("binary_replace_slice", kInput, this->type(),
                   R"([null, "XX", "aXX", "abXX", "abXXc", "abXXcd", "abXXde"])",
                   &negative_stop);

  // [-1, 2): start from the right, stop from the left
  ReplaceSliceOptions negative_start{-1, 2, "XX"};
  this->CheckUnary("binary_replace_slice", kInput, this->type(),
                   R"([null, "XX", "XX", "aXX", "abXXc", "abcXXd", "abcdXXe"])",
                   &negative_start);

  // Effectively the same as [2, 2)
  ReplaceSliceOptions inverted{2, 0, "XX"};
  this->CheckUnary("binary_replace_slice", kInput, this->type(),
                   R"([null, "XX", "aXX", "abXX", "abXXc", "abXXcd", "abXXcde"])",
                   &inverted);

  // Effectively the same as [-3, -3)
  ReplaceSliceOptions inverted_negative{-3, -5, "XX"};
  this->CheckUnary("binary_replace_slice", kInput, this->type(),
                   R"([null, "XX", "XXa", "XXab", "XXabc", "aXXbcd", "abXXcde"])",
                   &inverted_negative);
}

template <typename TestType>
class TestStringKernels : public BaseTestStringKernels<TestType> {};

Expand Down Expand Up @@ -745,6 +795,56 @@ TYPED_TEST(TestStringKernels, SplitRegexReverse) {
}
#endif

// Exercises utf8_replace_slice with positive, negative, mixed and inverted
// [start, stop) bounds, on inputs mixing single-byte and multi-byte characters.
TYPED_TEST(TestStringKernels, Utf8ReplaceSlice) {
  // Input column shared by most of the cases below
  const char* const kInput = R"([null, "", "π", "πb", "πbθ", "πbθd", "πbθde"])";

  // [0, 1): replace the first character
  ReplaceSliceOptions first_char{0, 1, "χχ"};
  this->CheckUnary("utf8_replace_slice", "[]", this->type(), "[]", &first_char);
  this->CheckUnary("utf8_replace_slice", R"([null, "", "π", "πb", "πbθ"])", this->type(),
                   R"([null, "χχ", "χχ", "χχb", "χχbθ"])", &first_char);

  // [0, 5): slice at least as long as most of the inputs
  ReplaceSliceOptions first_five{0, 5, "χχ"};
  this->CheckUnary("utf8_replace_slice",
                   R"([null, "", "π", "πb", "πbθ", "πbθde", "πbθdef"])", this->type(),
                   R"([null, "χχ", "χχ", "χχ", "χχ", "χχ", "χχf"])", &first_five);

  // [2, 4): slice in the middle of the string
  ReplaceSliceOptions middle{2, 4, "χχ"};
  this->CheckUnary("utf8_replace_slice", kInput, this->type(),
                   R"([null, "χχ", "πχχ", "πbχχ", "πbχχ", "πbχχ", "πbχχe"])", &middle);

  // [-3, -2): both bounds counted from the right
  ReplaceSliceOptions both_negative{-3, -2, "χχ"};
  this->CheckUnary("utf8_replace_slice", kInput, this->type(),
                   R"([null, "χχ", "χχπ", "χχπb", "χχbθ", "πχχθd", "πbχχde"])",
                   &both_negative);

  // [2, -2): start from the left, stop from the right
  ReplaceSliceOptions negative_stop{2, -2, "χχ"};
  this->CheckUnary("utf8_replace_slice", kInput, this->type(),
                   R"([null, "χχ", "πχχ", "πbχχ", "πbχχθ", "πbχχθd", "πbχχde"])",
                   &negative_stop);

  // [-1, 2): start from the right, stop from the left
  ReplaceSliceOptions negative_start{-1, 2, "χχ"};
  this->CheckUnary("utf8_replace_slice", kInput, this->type(),
                   R"([null, "χχ", "χχ", "πχχ", "πbχχθ", "πbθχχd", "πbθdχχe"])",
                   &negative_start);

  // Effectively the same as [2, 2)
  ReplaceSliceOptions inverted{2, 0, "χχ"};
  this->CheckUnary("utf8_replace_slice", kInput, this->type(),
                   R"([null, "χχ", "πχχ", "πbχχ", "πbχχθ", "πbχχθd", "πbχχθde"])",
                   &inverted);

  // Effectively the same as [-3, -3)
  ReplaceSliceOptions inverted_negative{-3, -5, "χχ"};
  this->CheckUnary("utf8_replace_slice", kInput, this->type(),
                   R"([null, "χχ", "χχπ", "χχπb", "χχπbθ", "πχχbθd", "πbχχθde"])",
                   &inverted_negative);
}

TYPED_TEST(TestStringKernels, ReplaceSubstring) {
ReplaceSubstringOptions options{"foo", "bazz"};
this->CheckUnary("replace_substring", R"(["foo", "this foo that foo", null])",
Expand Down
Loading