diff --git a/be/src/exprs/timestamp_functions.cpp b/be/src/exprs/timestamp_functions.cpp index ec968583e56ebc..3658ce21f5af41 100644 --- a/be/src/exprs/timestamp_functions.cpp +++ b/be/src/exprs/timestamp_functions.cpp @@ -884,39 +884,43 @@ void TimestampFunctions::convert_tz_prepare(doris_udf::FunctionContext* context, doris_udf::FunctionContext::FunctionStateScope scope) { if (scope != FunctionContext::FRAGMENT_LOCAL || context->get_num_args() != 3 || context->get_arg_type(1)->type != doris_udf::FunctionContext::Type::TYPE_VARCHAR || - context->get_arg_type(2)->type != doris_udf::FunctionContext::Type::TYPE_VARCHAR || - !context->is_arg_constant(1) || !context->is_arg_constant(2)) { + context->get_arg_type(2)->type != doris_udf::FunctionContext::Type::TYPE_VARCHAR) { return; } ConvertTzCtx* ctc = new ConvertTzCtx(); context->set_function_state(scope, ctc); - // find from timezone - StringVal* from = reinterpret_cast(context->get_constant_arg(1)); - if (UNLIKELY(from->is_null)) { - ctc->is_valid = false; - return; - } - if (!TimezoneUtils::find_cctz_time_zone(std::string((char*)from->ptr, from->len), - ctc->from_tz)) { - ctc->is_valid = false; - return; + if (context->is_arg_constant(1)) { + // find from timezone + StringVal* from = reinterpret_cast(context->get_constant_arg(1)); + if (UNLIKELY(from->is_null)) { + ctc->is_valid = false; + return; + } + if (!TimezoneUtils::find_cctz_time_zone(std::string((char*)from->ptr, from->len), + ctc->from_tz)) { + ctc->is_valid = false; + return; + } + ctc->constant_from = true; } - // find to timezone - StringVal* to = reinterpret_cast(context->get_constant_arg(2)); - if (UNLIKELY(to->is_null)) { - ctc->is_valid = false; - return; - } - if (!TimezoneUtils::find_cctz_time_zone(std::string((char*)to->ptr, to->len), ctc->to_tz)) { - ctc->is_valid = false; - return; + if (context->is_arg_constant(2)) { + // find to timezone + StringVal* to = reinterpret_cast(context->get_constant_arg(2)); + if (UNLIKELY(to->is_null)) { + ctc->is_valid = false; + return; + } + if (!TimezoneUtils::find_cctz_time_zone(std::string((char*)to->ptr, to->len), ctc->to_tz)) { + ctc->is_valid = false; + return; + } + ctc->constant_to = true; } ctc->is_valid = true; - return; } DateTimeVal TimestampFunctions::convert_tz(FunctionContext* ctx, const DateTimeVal& ts_val, @@ -944,12 +948,40 @@ DateTimeVal TimestampFunctions::convert_tz(FunctionContext* ctx, const DateTimeV } int64_t timestamp; - if (!ts_value.unix_timestamp(×tamp, ctc->from_tz)) { - return DateTimeVal::null(); + + if (ctc->constant_from) { + if (!ts_value.unix_timestamp(×tamp, ctc->from_tz)) { + return DateTimeVal::null(); + } + } else { + auto from_tz_string = from_tz.to_string(); + if (UNLIKELY(ctc->time_zone_cache.find(from_tz_string) == ctc->time_zone_cache.cend())) { + if (UNLIKELY(!TimezoneUtils::find_cctz_time_zone( + from_tz_string, ctc->time_zone_cache[from_tz_string]))) { + return DateTimeVal::null(); + } + } + if (!ts_value.unix_timestamp(×tamp, ctc->time_zone_cache[from_tz_string])) { + return DateTimeVal::null(); + } } + DateTimeValue ts_value2; - if (!ts_value2.from_unixtime(timestamp, ctc->to_tz)) { - return DateTimeVal::null(); + if (ctc->constant_to) { + if (!ts_value2.from_unixtime(timestamp, ctc->to_tz)) { + return DateTimeVal::null(); + } + } else { + auto to_tz_string = to_tz.to_string(); + if (UNLIKELY(ctc->time_zone_cache.find(to_tz_string) == ctc->time_zone_cache.cend())) { + if (UNLIKELY(!TimezoneUtils::find_cctz_time_zone(to_tz_string, + ctc->time_zone_cache[to_tz_string]))) { + return DateTimeVal::null(); + } + } + if (!ts_value2.from_unixtime(timestamp, ctc->time_zone_cache[to_tz_string])) { + return DateTimeVal::null(); + } } DateTimeVal return_val; diff --git a/be/src/exprs/timestamp_functions.h b/be/src/exprs/timestamp_functions.h index e558c8044c3cad..2391de5812984d 100644 --- a/be/src/exprs/timestamp_functions.h +++ b/be/src/exprs/timestamp_functions.h @@ -42,8 +42,11 @@ struct FormatCtx { struct ConvertTzCtx { // false means the format is invalid, and the function always return null bool is_valid = false; + bool constant_from = false; + bool constant_to = false; cctz::time_zone from_tz; cctz::time_zone to_tz; + std::map time_zone_cache; }; class TimestampFunctions { diff --git a/be/src/vec/functions/function_convert_tz.h b/be/src/vec/functions/function_convert_tz.h index f96031c58b5b0e..9d6f1e755d4419 100644 --- a/be/src/vec/functions/function_convert_tz.h +++ b/be/src/vec/functions/function_convert_tz.h @@ -25,6 +25,9 @@ namespace doris::vectorized { +struct ConvertTzCtx { + std::map time_zone_cache; +}; class FunctionConvertTZ : public IFunction { public: static constexpr auto name = "convert_tz"; @@ -42,6 +45,24 @@ class FunctionConvertTZ : public IFunction { bool use_default_implementation_for_constants() const override { return true; } bool use_default_implementation_for_nulls() const override { return false; } + Status prepare(FunctionContext* context, FunctionContext::FunctionStateScope scope) override { + if (scope != FunctionContext::THREAD_LOCAL) { + return Status::OK(); + } + context->set_function_state(scope, new ConvertTzCtx); + return Status::OK(); + } + + Status close(FunctionContext* context, FunctionContext::FunctionStateScope scope) override { + if (scope == FunctionContext::THREAD_LOCAL) { + auto* convert_ctx = reinterpret_cast( + context->get_function_state(FunctionContext::THREAD_LOCAL)); + delete convert_ctx; + context->set_function_state(FunctionContext::THREAD_LOCAL, nullptr); + } + return Status::OK(); + } + Status execute_impl(FunctionContext* context, Block& block, const ColumnNumbers& arguments, size_t result, size_t input_rows_count) override { auto result_column = ColumnDateTime::create(); @@ -79,6 +100,10 @@ class FunctionConvertTZ : public IFunction { const ColumnString* from_tz_column, const ColumnString* to_tz_column, ColumnDateTime* result_column, NullMap& result_null_map, size_t input_rows_count) { + auto convert_ctx = reinterpret_cast( + context->get_function_state(FunctionContext::FunctionStateScope::THREAD_LOCAL)); + std::map time_zone_cache_; + auto& time_zone_cache = convert_ctx ? convert_ctx->time_zone_cache : time_zone_cache_; for (size_t i = 0; i < input_rows_count; i++) { if (result_null_map[i]) { result_column->insert_default(); @@ -88,18 +113,36 @@ class FunctionConvertTZ : public IFunction { StringRef from_tz = from_tz_column->get_data_at(i); StringRef to_tz = to_tz_column->get_data_at(i); + if (time_zone_cache.find(from_tz) == time_zone_cache.cend()) { + if (!TimezoneUtils::find_cctz_time_zone(from_tz.to_string(), + time_zone_cache[from_tz])) { + result_null_map[i] = true; + result_column->insert_default(); + continue; + } + } + + if (time_zone_cache.find(to_tz) == time_zone_cache.cend()) { + if (!TimezoneUtils::find_cctz_time_zone(to_tz.to_string(), + time_zone_cache[to_tz])) { + result_null_map[i] = true; + result_column->insert_default(); + continue; + } + } + VecDateTimeValue ts_value = binary_cast(date_column->get_element(i)); int64_t timestamp; - if (!ts_value.unix_timestamp(×tamp, from_tz.to_string())) { + if (!ts_value.unix_timestamp(×tamp, time_zone_cache[from_tz])) { result_null_map[i] = true; result_column->insert_default(); continue; } VecDateTimeValue ts_value2; - if (!ts_value2.from_unixtime(timestamp, to_tz.to_string())) { + if (!ts_value2.from_unixtime(timestamp, time_zone_cache[to_tz])) { result_null_map[i] = true; result_column->insert_default(); continue;