diff --git a/be/src/exprs/math_functions.cpp b/be/src/exprs/math_functions.cpp index ba4fd7f2c13962..6fab6485d3bbea 100644 --- a/be/src/exprs/math_functions.cpp +++ b/be/src/exprs/math_functions.cpp @@ -20,6 +20,7 @@ #include #include #include +#include #include #include "common/compiler_util.h" @@ -276,10 +277,17 @@ DoubleVal MathFunctions::pow( void MathFunctions::rand_prepare( FunctionContext* ctx, FunctionContext::FunctionStateScope scope) { + std::mt19937* generator = reinterpret_cast( + ctx->allocate(sizeof(std::mt19937))); + if (UNLIKELY(generator == NULL)) { + LOG(ERROR) << "allocate random seed generator failed."; + return; + } + ctx->set_function_state(scope, generator); + new (generator) std::mt19937(); if (scope == FunctionContext::THREAD_LOCAL) { - uint32_t* seed = reinterpret_cast(ctx->allocate(sizeof(uint32_t))); - ctx->set_function_state(scope, seed); if (ctx->get_num_args() == 1) { + uint32_t seed = 0; // This is a call to RandSeed, initialize the seed // TODO: should we support non-constant seed? if (!ctx->is_arg_constant(0)) { @@ -287,25 +295,24 @@ void MathFunctions::rand_prepare( return; } BigIntVal* seed_arg = static_cast(ctx->get_constant_arg(0)); - if (seed_arg->is_null) { - seed = NULL; - } else { - *seed = seed_arg->val; + if (!seed_arg->is_null) { + seed = seed_arg->val; } + generator->seed(seed); } else { - // This is a call to Rand, initialize seed to 0 - // TODO: can we change this behavior? This is stupid. - *seed = 0; + generator->seed(std::random_device()()); } } } DoubleVal MathFunctions::rand(FunctionContext* ctx) { - uint32_t* seed = reinterpret_cast( + std::mt19937* generator = reinterpret_cast( ctx->get_function_state(FunctionContext::THREAD_LOCAL)); - *seed = ::rand_r(seed); - // Normalize to [0,1]. - return DoubleVal(static_cast(*seed) / RAND_MAX); + DCHECK(generator != nullptr); + static const double min = 0.0; + static const double max = 1.0; + std::uniform_real_distribution distribution(min, max); + return DoubleVal(distribution(*generator)); } DoubleVal MathFunctions::rand_seed(FunctionContext* ctx, const BigIntVal& seed) { @@ -315,6 +322,16 @@ DoubleVal MathFunctions::rand_seed(FunctionContext* ctx, const BigIntVal& seed) return rand(ctx); } +void MathFunctions::rand_close(FunctionContext* ctx, + FunctionContext::FunctionStateScope scope) { + if (scope == FunctionContext::THREAD_LOCAL) { + uint8_t* generator = reinterpret_cast( + ctx->get_function_state(FunctionContext::THREAD_LOCAL)); + ctx->free(generator); + ctx->set_function_state(FunctionContext::THREAD_LOCAL, nullptr); + } +} + StringVal MathFunctions::bin(FunctionContext* ctx, const BigIntVal& v) { if (v.is_null) { return StringVal::null(); diff --git a/be/src/exprs/math_functions.h b/be/src/exprs/math_functions.h index 47a919e0f98094..6d729b364632eb 100644 --- a/be/src/exprs/math_functions.h +++ b/be/src/exprs/math_functions.h @@ -91,12 +91,14 @@ class MathFunctions { doris_udf::FunctionContext* ctx, const doris_udf::DoubleVal& base, const doris_udf::DoubleVal& exp); - /// Used for both Rand() and RandSeed() + /// Used for both rand() and rand_seed() static void rand_prepare( doris_udf::FunctionContext*, doris_udf::FunctionContext::FunctionStateScope); static doris_udf::DoubleVal rand(doris_udf::FunctionContext*); static doris_udf::DoubleVal rand_seed( doris_udf::FunctionContext*, const doris_udf::BigIntVal& seed); + static void rand_close( + FunctionContext* ctx, FunctionContext::FunctionStateScope scope); static doris_udf::StringVal bin( doris_udf::FunctionContext* ctx, const doris_udf::BigIntVal& v); diff --git a/be/src/testutil/function_utils.cpp b/be/src/testutil/function_utils.cpp index 0506dcc462a511..e7b3108b26f2c0 100644 --- a/be/src/testutil/function_utils.cpp +++ b/be/src/testutil/function_utils.cpp @@ -22,7 +22,6 @@ #include "runtime/mem_pool.h" #include "runtime/mem_tracker.h" #include "udf/udf_internal.h" -#include "udf/udf.h" namespace doris { @@ -44,6 +43,14 @@ FunctionUtils::FunctionUtils(RuntimeState* state) { _state, _memory_pool, return_type, arg_types, 0, false); } +FunctionUtils::FunctionUtils(const doris_udf::FunctionContext::TypeDesc& return_type, + const std::vector& arg_types, int varargs_buffer_size) { + _mem_tracker.reset(new MemTracker(-1, "function util")); + _memory_pool = new MemPool(_mem_tracker.get()); + _fn_ctx = FunctionContextImpl::create_context( + _state, _memory_pool, return_type, arg_types, varargs_buffer_size, false); +} + FunctionUtils::~FunctionUtils() { _fn_ctx->impl()->close(); delete _fn_ctx; diff --git a/be/src/testutil/function_utils.h b/be/src/testutil/function_utils.h index 041f6b2b557e5b..b3c8af6e235a47 100644 --- a/be/src/testutil/function_utils.h +++ b/be/src/testutil/function_utils.h @@ -16,10 +16,9 @@ // under the License. #include +#include -namespace doris_udf { -class FunctionContext; -} +#include "udf/udf.h" namespace doris { @@ -31,6 +30,8 @@ class FunctionUtils { public: FunctionUtils(); FunctionUtils(RuntimeState* state); + FunctionUtils(const doris_udf::FunctionContext::TypeDesc& return_type, + const std::vector& arg_types, int varargs_buffer_size); ~FunctionUtils(); doris_udf::FunctionContext* get_fn_ctx() { diff --git a/be/test/exprs/math_functions_test.cpp b/be/test/exprs/math_functions_test.cpp index d4dbdbb909fa17..49a5d1de2d0b6b 100644 --- a/be/test/exprs/math_functions_test.cpp +++ b/be/test/exprs/math_functions_test.cpp @@ -127,8 +127,42 @@ TEST_F(MathFunctionsTest, abs) { ASSERT_EQ(siv2, MathFunctions::abs(ctx, TinyIntVal(-1))); ASSERT_EQ(siv3, MathFunctions::abs(ctx, TinyIntVal(INT8_MAX))); ASSERT_EQ(siv4, MathFunctions::abs(ctx, TinyIntVal(INT8_MIN))); +} - +TEST_F(MathFunctionsTest, rand) { + doris_udf::FunctionContext::TypeDesc type; + type.type = doris_udf::FunctionContext::TYPE_DOUBLE; + std::vector arg_types; + doris_udf::FunctionContext::TypeDesc type1; + type1.type = doris_udf::FunctionContext::TYPE_BIGINT; + arg_types.push_back(type1); + FunctionUtils* utils1 = new FunctionUtils(type, arg_types, 8); + FunctionContext* ctx1 = utils1->get_fn_ctx(); + std::vector constant_args; + BigIntVal bi(1); + constant_args.push_back(&bi); + ctx1->impl()->set_constant_args(constant_args); + + MathFunctions::rand_prepare(ctx1, FunctionContext::THREAD_LOCAL); + DoubleVal dv1 = MathFunctions::rand_seed(ctx1, BigIntVal(0)); + MathFunctions::rand_close(ctx1, FunctionContext::THREAD_LOCAL); + + MathFunctions::rand_prepare(ctx1, FunctionContext::THREAD_LOCAL); + DoubleVal dv2 = MathFunctions::rand_seed(ctx1, BigIntVal(0)); + MathFunctions::rand_close(ctx1, FunctionContext::THREAD_LOCAL); + + ASSERT_EQ(dv1.val, dv2.val); + delete utils1; + + MathFunctions::rand_prepare(ctx, FunctionContext::THREAD_LOCAL); + DoubleVal dv3 = MathFunctions::rand(ctx); + MathFunctions::rand_close(ctx, FunctionContext::THREAD_LOCAL); + + MathFunctions::rand_prepare(ctx, FunctionContext::THREAD_LOCAL); + DoubleVal dv4 = MathFunctions::rand(ctx); + MathFunctions::rand_close(ctx, FunctionContext::THREAD_LOCAL); + + ASSERT_NE(dv3.val, dv4.val); } } // namespace doris diff --git a/gensrc/script/doris_builtins_functions.py b/gensrc/script/doris_builtins_functions.py index bd44884a32092c..e8c9aa485f0fd5 100755 --- a/gensrc/script/doris_builtins_functions.py +++ b/gensrc/script/doris_builtins_functions.py @@ -338,10 +338,14 @@ [['rand', 'random'], 'DOUBLE', [], '_ZN5doris13MathFunctions4randEPN9doris_udf15FunctionContextE', '_ZN5doris13MathFunctions12rand_prepareEPN9doris_udf' + '15FunctionContextENS2_18FunctionStateScopeE', + '_ZN5doris13MathFunctions10rand_closeEPN9doris_udf' '15FunctionContextENS2_18FunctionStateScopeE'], [['rand', 'random'], 'DOUBLE', ['BIGINT'], '_ZN5doris13MathFunctions9rand_seedEPN9doris_udf15FunctionContextERKNS1_9BigIntValE', '_ZN5doris13MathFunctions12rand_prepareEPN9doris_udf' + '15FunctionContextENS2_18FunctionStateScopeE', + '_ZN5doris13MathFunctions10rand_closeEPN9doris_udf' '15FunctionContextENS2_18FunctionStateScopeE'], [['bin'], 'VARCHAR', ['BIGINT'],