Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 30 additions & 13 deletions be/src/exprs/math_functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include <iomanip>
#include <sstream>
#include <cmath>
#include <random>
#include <stdlib.h>

#include "common/compiler_util.h"
Expand Down Expand Up @@ -276,36 +277,42 @@ DoubleVal MathFunctions::pow(

/// Prepare hook shared by rand() and rand(seed).
///
/// For the THREAD_LOCAL scope this allocates a std::mt19937 engine from the function
/// context, constructs it in place and stores it as the THREAD_LOCAL function state:
///   - rand(seed): the seed must be a constant argument; a NULL seed is treated as 0.
///   - rand():     the engine is seeded from std::random_device.
///
/// Nothing is allocated for any other scope: rand() only reads the THREAD_LOCAL state
/// and rand_close() only releases the THREAD_LOCAL state, so a buffer published under
/// another scope would never be used or freed.
///
/// On allocation failure an error is logged and no state is set. On a non-constant
/// seed the context error is set and the engine keeps its default-constructed seed.
void MathFunctions::rand_prepare(
        FunctionContext* ctx, FunctionContext::FunctionStateScope scope) {
    if (scope != FunctionContext::THREAD_LOCAL) {
        return;
    }

    std::mt19937* generator = reinterpret_cast<std::mt19937*>(
            ctx->allocate(sizeof(std::mt19937)));
    if (UNLIKELY(generator == nullptr)) {
        LOG(ERROR) << "allocate random seed generator failed.";
        return;
    }
    // Construct the engine inside the context-owned buffer before publishing it.
    new (generator) std::mt19937();
    ctx->set_function_state(scope, generator);

    if (ctx->get_num_args() == 1) {
        // This is a call to rand(seed): initialize the engine from the seed argument.
        // TODO: should we support non-constant seed?
        if (!ctx->is_arg_constant(0)) {
            ctx->set_error("Seed argument to rand() must be constant");
            return;
        }
        uint32_t seed = 0;
        BigIntVal* seed_arg = static_cast<BigIntVal*>(ctx->get_constant_arg(0));
        if (!seed_arg->is_null) {
            // The engine takes a 32-bit seed; keep the low 32 bits of the BIGINT.
            seed = static_cast<uint32_t>(seed_arg->val);
        }
        generator->seed(seed);
    } else {
        // This is a call to rand(): use a non-deterministic seed.
        generator->seed(std::random_device()());
    }
}

/// Returns a pseudo-random double produced by std::uniform_real_distribution over
/// [0.0, 1.0), drawing from the std::mt19937 engine that rand_prepare() stored as
/// the THREAD_LOCAL function state.
DoubleVal MathFunctions::rand(FunctionContext* ctx) {
    std::mt19937* generator = reinterpret_cast<std::mt19937*>(
            ctx->get_function_state(FunctionContext::THREAD_LOCAL));
    // rand_prepare() must have run successfully for this context.
    DCHECK(generator != nullptr);
    std::uniform_real_distribution<double> distribution(0.0, 1.0);
    return DoubleVal(distribution(*generator));
}

DoubleVal MathFunctions::rand_seed(FunctionContext* ctx, const BigIntVal& seed) {
Expand All @@ -315,6 +322,16 @@ DoubleVal MathFunctions::rand_seed(FunctionContext* ctx, const BigIntVal& seed)
return rand(ctx);
}

/// Close hook paired with rand_prepare(): hands the buffer stored as the THREAD_LOCAL
/// function state back to the context and resets that state to nullptr.
/// Calls for any other scope are a no-op.
void MathFunctions::rand_close(FunctionContext* ctx,
        FunctionContext::FunctionStateScope scope) {
    if (scope != FunctionContext::THREAD_LOCAL) {
        return;
    }
    // `scope` is THREAD_LOCAL from here on.
    uint8_t* state_buffer = reinterpret_cast<uint8_t*>(ctx->get_function_state(scope));
    ctx->free(state_buffer);
    ctx->set_function_state(scope, nullptr);
}

StringVal MathFunctions::bin(FunctionContext* ctx, const BigIntVal& v) {
if (v.is_null) {
return StringVal::null();
Expand Down
4 changes: 3 additions & 1 deletion be/src/exprs/math_functions.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,12 +91,14 @@ class MathFunctions {
doris_udf::FunctionContext* ctx, const doris_udf::DoubleVal& base,
const doris_udf::DoubleVal& exp);

/// Prepare hook used for both rand() and rand_seed(): sets up the random generator
/// that rand() later reads from the THREAD_LOCAL function state.
static void rand_prepare(
doris_udf::FunctionContext*, doris_udf::FunctionContext::FunctionStateScope);
/// Returns the next random double from the generator held in the context's
/// THREAD_LOCAL function state; rand_prepare() must have been called first.
static doris_udf::DoubleVal rand(doris_udf::FunctionContext*);
/// Seeded variant of rand(); its result comes from rand(ctx). The seed itself is read
/// as constant argument 0 by rand_prepare() when the function has one argument.
static doris_udf::DoubleVal rand_seed(
doris_udf::FunctionContext*, const doris_udf::BigIntVal& seed);
static void rand_close(
FunctionContext* ctx, FunctionContext::FunctionStateScope scope);

/// Returns a NULL StringVal when `v` is NULL.
/// NOTE(review): presumably yields the binary-digit string of `v` otherwise; confirm
/// against the definition.
static doris_udf::StringVal bin(
doris_udf::FunctionContext* ctx, const doris_udf::BigIntVal& v);
Expand Down
9 changes: 8 additions & 1 deletion be/src/testutil/function_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
#include "runtime/mem_pool.h"
#include "runtime/mem_tracker.h"
#include "udf/udf_internal.h"
#include "udf/udf.h"

namespace doris {

Expand All @@ -44,6 +43,14 @@ FunctionUtils::FunctionUtils(RuntimeState* state) {
_state, _memory_pool, return_type, arg_types, 0, false);
}

// Builds a function context whose return type, argument types and varargs buffer size
// are supplied by the caller. The context is backed by a fresh MemPool that reports
// to a MemTracker created with a limit of -1 and the label "function util".
FunctionUtils::FunctionUtils(const doris_udf::FunctionContext::TypeDesc& return_type,
        const std::vector<doris_udf::FunctionContext::TypeDesc>& arg_types,
        int varargs_buffer_size) {
    MemTracker* tracker = new MemTracker(-1, "function util");
    _mem_tracker.reset(tracker);
    _memory_pool = new MemPool(tracker);
    // NOTE(review): _state is not assigned in this constructor; presumably it has a
    // default member initializer. Confirm in the class definition.
    _fn_ctx = FunctionContextImpl::create_context(
            _state, _memory_pool, return_type, arg_types, varargs_buffer_size, false);
}

FunctionUtils::~FunctionUtils() {
_fn_ctx->impl()->close();
delete _fn_ctx;
Expand Down
7 changes: 4 additions & 3 deletions be/src/testutil/function_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,9 @@
// under the License.

#include <memory>
#include <vector>

namespace doris_udf {
class FunctionContext;
}
#include "udf/udf.h"

namespace doris {

Expand All @@ -31,6 +30,8 @@ class FunctionUtils {
public:
FunctionUtils();
FunctionUtils(RuntimeState* state);
// Creates a function context with the given return type, argument types and
// varargs buffer size.
FunctionUtils(const doris_udf::FunctionContext::TypeDesc& return_type,
const std::vector<doris_udf::FunctionContext::TypeDesc>& arg_types, int varargs_buffer_size);
~FunctionUtils();

doris_udf::FunctionContext* get_fn_ctx() {
Expand Down
36 changes: 35 additions & 1 deletion be/test/exprs/math_functions_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -127,8 +127,42 @@ TEST_F(MathFunctionsTest, abs) {
ASSERT_EQ(siv2, MathFunctions::abs(ctx, TinyIntVal(-1)));
ASSERT_EQ(siv3, MathFunctions::abs(ctx, TinyIntVal(INT8_MAX)));
ASSERT_EQ(siv4, MathFunctions::abs(ctx, TinyIntVal(INT8_MIN)));
}


// Checks two properties of the rand() family:
//   1. rand(seed) is reproducible: two prepare/call/close cycles on a context whose
//      constant seed argument is 1 yield the same first value.
//   2. rand() without a seed is re-seeded on every prepare, so two cycles are
//      expected to yield different first values.
TEST_F(MathFunctionsTest, rand) {
    {
        doris_udf::FunctionContext::TypeDesc return_type;
        return_type.type = doris_udf::FunctionContext::TYPE_DOUBLE;
        doris_udf::FunctionContext::TypeDesc seed_type;
        seed_type.type = doris_udf::FunctionContext::TYPE_BIGINT;
        std::vector<doris_udf::FunctionContext::TypeDesc> arg_types;
        arg_types.push_back(seed_type);

        // Declared before the context so they outlive it, as in the heap-based version.
        BigIntVal bi(1);
        std::vector<doris_udf::AnyVal*> constant_args;
        constant_args.push_back(&bi);

        // Scoped object instead of new/delete: a failing ASSERT_* returns from the
        // test body immediately, which would otherwise skip the delete and leak.
        FunctionUtils utils1(return_type, arg_types, 8);
        FunctionContext* ctx1 = utils1.get_fn_ctx();
        ctx1->impl()->set_constant_args(constant_args);

        MathFunctions::rand_prepare(ctx1, FunctionContext::THREAD_LOCAL);
        DoubleVal dv1 = MathFunctions::rand_seed(ctx1, BigIntVal(0));
        MathFunctions::rand_close(ctx1, FunctionContext::THREAD_LOCAL);

        MathFunctions::rand_prepare(ctx1, FunctionContext::THREAD_LOCAL);
        DoubleVal dv2 = MathFunctions::rand_seed(ctx1, BigIntVal(0));
        MathFunctions::rand_close(ctx1, FunctionContext::THREAD_LOCAL);

        ASSERT_EQ(dv1.val, dv2.val);
    }

    MathFunctions::rand_prepare(ctx, FunctionContext::THREAD_LOCAL);
    DoubleVal dv3 = MathFunctions::rand(ctx);
    MathFunctions::rand_close(ctx, FunctionContext::THREAD_LOCAL);

    MathFunctions::rand_prepare(ctx, FunctionContext::THREAD_LOCAL);
    DoubleVal dv4 = MathFunctions::rand(ctx);
    MathFunctions::rand_close(ctx, FunctionContext::THREAD_LOCAL);

    ASSERT_NE(dv3.val, dv4.val);
}

} // namespace doris
Expand Down
4 changes: 4 additions & 0 deletions gensrc/script/doris_builtins_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,10 +338,14 @@
[['rand', 'random'], 'DOUBLE', [],
'_ZN5doris13MathFunctions4randEPN9doris_udf15FunctionContextE',
'_ZN5doris13MathFunctions12rand_prepareEPN9doris_udf'
'15FunctionContextENS2_18FunctionStateScopeE',
'_ZN5doris13MathFunctions10rand_closeEPN9doris_udf'
'15FunctionContextENS2_18FunctionStateScopeE'],
[['rand', 'random'], 'DOUBLE', ['BIGINT'],
'_ZN5doris13MathFunctions9rand_seedEPN9doris_udf15FunctionContextERKNS1_9BigIntValE',
'_ZN5doris13MathFunctions12rand_prepareEPN9doris_udf'
'15FunctionContextENS2_18FunctionStateScopeE',
'_ZN5doris13MathFunctions10rand_closeEPN9doris_udf'
'15FunctionContextENS2_18FunctionStateScopeE'],

[['bin'], 'VARCHAR', ['BIGINT'],
Expand Down