Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
128 changes: 115 additions & 13 deletions be/src/vec/functions/function_convert_tz.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,13 @@
#include "vec/runtime/vdatetime_value.h"
namespace doris::vectorized {

// Time-zone state that FunctionConvertTZ::open() prepares once so that the
// per-row conversion does not have to resolve the zone names again.
struct ConvertTzState {
    // True when both time-zone arguments were constant columns at open() time.
    bool use_state {false};
    // True only after both constant zone names were resolved successfully.
    bool is_valid {false};
    // Resolved source time zone; only meaningful when is_valid is true.
    cctz::time_zone from_tz;
    // Resolved target time zone; only meaningful when is_valid is true.
    cctz::time_zone to_tz;
};

template <typename ArgDateType>
class FunctionConvertTZ : public IFunction {
using DateValueType = date_cast::TypeToValueTypeV<ArgDateType>;
Expand Down Expand Up @@ -88,8 +95,62 @@ class FunctionConvertTZ : public IFunction {
std::make_shared<DataTypeString>()};
}

// Creates the ConvertTzState for every scope other than THREAD_LOCAL and
// registers it on the context. When both time-zone arguments (indexes 1 and 2)
// are constant columns, the zones are resolved up front via
// init_convert_tz_state(); otherwise the state is left marked as unused.
Status open(FunctionContext* context, FunctionContext::FunctionStateScope scope) override {
    if (scope == FunctionContext::THREAD_LOCAL) {
        return Status::OK();
    }

    auto tz_state = std::make_shared<ConvertTzState>();
    context->set_function_state(scope, tz_state);

    DCHECK_EQ(context->get_num_args(), 3);
    const auto* from_tz_const = context->get_constant_col(1);
    const auto* to_tz_const = context->get_constant_col(2);

    // The cached zones are only usable when the second and third arguments are both constants.
    const bool both_const = (from_tz_const != nullptr) && (to_tz_const != nullptr);
    tz_state->use_state = both_const;
    if (both_const) {
        init_convert_tz_state(tz_state, from_tz_const, to_tz_const);
    }

    return IFunction::open(context, scope);
}

// Resolves the two constant time-zone names into state->from_tz / state->to_tz.
// state->is_valid ends up true only if neither constant is null and both names
// are found by TimezoneUtils::find_cctz_time_zone; otherwise it is false.
void init_convert_tz_state(std::shared_ptr<ConvertTzState> state,
                           const ColumnPtrWrapper* const_from_tz,
                           const ColumnPtrWrapper* const_to_tz) {
    auto from_ref = const_from_tz->column_ptr->get_data_at(0);
    auto to_ref = const_to_tz->column_ptr->get_data_at(0);

    // Assume failure until both zones have been resolved.
    state->is_valid = false;

    // A null constant on either side can never produce a valid conversion.
    if (from_ref.data == nullptr || to_ref.data == nullptr) {
        return;
    }

    auto from_name = from_ref.to_string();
    auto to_name = to_ref.to_string();

    // Short-circuit keeps the original order: the target zone is only looked
    // up when the source zone was found.
    state->is_valid = TimezoneUtils::find_cctz_time_zone(from_name, state->from_tz) &&
                      TimezoneUtils::find_cctz_time_zone(to_name, state->to_tz);
}

Status execute_impl(FunctionContext* context, Block& block, const ColumnNumbers& arguments,
size_t result, size_t input_rows_count) const override {
auto* convert_tz_state = reinterpret_cast<ConvertTzState*>(
context->get_function_state(FunctionContext::FRAGMENT_LOCAL));
if (!convert_tz_state) {
return Status::RuntimeError(
"funciton context for function '{}' must have ConvertTzState;", get_name());
}
auto result_null_map_column = ColumnUInt8::create(input_rows_count, 0);

bool col_const[3];
Expand All @@ -106,12 +167,16 @@ class FunctionConvertTZ : public IFunction {

if (col_const[1] && col_const[2]) {
auto result_column = ColumnType::create();
execute_tz_const(context, assert_cast<const ColumnType*>(argument_columns[0].get()),
assert_cast<const ColumnString*>(argument_columns[1].get()),
assert_cast<const ColumnString*>(argument_columns[2].get()),
assert_cast<ReturnColumnType*>(result_column.get()),
assert_cast<ColumnUInt8*>(result_null_map_column.get())->get_data(),
input_rows_count);
if (convert_tz_state->use_state) {
execute_tz_const_with_state(
convert_tz_state, assert_cast<const ColumnType*>(argument_columns[0].get()),
assert_cast<ReturnColumnType*>(result_column.get()),
assert_cast<ColumnUInt8*>(result_null_map_column.get())->get_data(),
input_rows_count);
} else {
return Status::RuntimeError("ConvertTzState is not initialized in function {}",
get_name());
}
block.get_by_position(result).column = ColumnNullable::create(
std::move(result_column), std::move(result_null_map_column));
} else {
Expand Down Expand Up @@ -144,18 +209,55 @@ class FunctionConvertTZ : public IFunction {
}
}

static void execute_tz_const(FunctionContext* context, const ColumnType* date_column,
const ColumnString* from_tz_column,
const ColumnString* to_tz_column, ReturnColumnType* result_column,
NullMap& result_null_map, size_t input_rows_count) {
auto from_tz = from_tz_column->get_data_at(0).to_string();
auto to_tz = to_tz_column->get_data_at(0).to_string();
// Converts every row of date_column from convert_tz_state->from_tz to
// convert_tz_state->to_tz using the zones cached in the state, appending one
// value per row to result_column. Rows that cannot be converted are flagged in
// result_null_map and receive a default value instead.
//
// convert_tz_state: cached zones plus the is_valid flag.
// date_column:      input date/datetime values.
// result_column:    output column, appended to row by row.
// result_null_map:  per-row null flags; read for pre-existing nulls and written on failure.
// input_rows_count: number of rows to process.
static void execute_tz_const_with_state(ConvertTzState* convert_tz_state,
const ColumnType* date_column,
ReturnColumnType* result_column,
NullMap& result_null_map, size_t input_rows_count) {
cctz::time_zone& from_tz = convert_tz_state->from_tz;
cctz::time_zone& to_tz = convert_tz_state->to_tz;
// Flags the row as null and appends a default value so the column stays aligned.
// NOTE(review): the parameter is `int` while every caller passes a `size_t` index — confirm
// whether this should be `size_t`.
auto push_null = [&](int row) {
result_null_map[row] = true;
result_column->insert_default();
};
if (!convert_tz_state->is_valid) {
// If an invalid timezone is present, return null
for (size_t i = 0; i < input_rows_count; i++) {
push_null(i);
}
return;
}
for (size_t i = 0; i < input_rows_count; i++) {
// Rows already flagged as null only get a default value appended.
if (result_null_map[i]) {
result_column->insert_default();
continue;
}
// NOTE(review): execute_inner_loop is called here and the inline conversion below then runs
// for the same row. This text comes from a diff listing, so this call may be a removed line
// that was kept by the extraction — confirm against the real file.
execute_inner_loop(date_column, from_tz, to_tz, result_column, result_null_map, i);

// Reinterpret the stored native value as the argument's date value type.
DateValueType ts_value =
binary_cast<NativeType, DateValueType>(date_column->get_element(i));
ReturnDateValueType ts_value2;

if constexpr (std::is_same_v<ArgDateType, DataTypeDateTimeV2>) {
// DateTimeV2 uses a two-part timestamp (presumably seconds plus a sub-second part — TODO confirm).
std::pair<int64_t, int64_t> timestamp;
if (!ts_value.unix_timestamp(&timestamp, from_tz)) {
push_null(i);
continue;
}
ts_value2.from_unixtime(timestamp, to_tz);
} else {
// Other argument types use a single 64-bit timestamp.
int64_t timestamp;
if (!ts_value.unix_timestamp(&timestamp, from_tz)) {
push_null(i);
continue;
}
ts_value2.from_unixtime(timestamp, to_tz);
}

// A converted value that is not a valid date becomes null.
if (!ts_value2.is_valid_date()) [[unlikely]] {
push_null(i);
continue;
}

result_column->insert(binary_cast<ReturnDateValueType, ReturnNativeType>(ts_value2));
}
}

Expand Down