Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 75 additions & 39 deletions be/src/vec/functions/round.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,17 @@
#pragma once

#include <cstddef>
#include <memory>

#include "common/exception.h"
#include "common/status.h"
#include "vec/columns/column_const.h"
#include "vec/columns/columns_number.h"
#include "vec/common/assert_cast.h"
#include "vec/core/column_with_type_and_name.h"
#include "vec/core/types.h"
#include "vec/data_types/data_type.h"
#include "vec/data_types/data_type_nullable.h"
#include "vec/functions/function.h"
#if defined(__SSE4_1__) || defined(__aarch64__)
#include "util/sse_util.hpp"
Expand Down Expand Up @@ -430,7 +434,10 @@ struct Dispatcher {
FloatRoundingImpl<T, rounding_mode, scale_mode, tie_breaking_mode>,
IntegerRoundingImpl<T, rounding_mode, scale_mode, tie_breaking_mode>>>;

static ColumnPtr apply_vec_const(const IColumn* col_general, Int16 scale_arg) {
// scale_arg: scale for function computation
// result_scale: scale for result decimal, this scale is got from planner
static ColumnPtr apply_vec_const(const IColumn* col_general, const Int16 scale_arg,
[[maybe_unused]] Int16 result_scale) {
if constexpr (IsNumber<T>) {
const auto* const col = check_and_get_column<ColumnVector<T>>(col_general);
auto col_res = ColumnVector<T>::create();
Expand All @@ -457,17 +464,35 @@ struct Dispatcher {
} else if constexpr (IsDecimalNumber<T>) {
const auto* const decimal_col = check_and_get_column<ColumnDecimal<T>>(col_general);
const auto& vec_src = decimal_col->get_data();

UInt32 result_scale =
std::min(static_cast<UInt32>(std::max(scale_arg, static_cast<Int16>(0))),
decimal_col->get_scale());
const size_t input_rows_count = vec_src.size();
auto col_res = ColumnDecimal<T>::create(vec_src.size(), result_scale);
auto& vec_res = col_res->get_data();

if (!vec_res.empty()) {
FunctionRoundingImpl<ScaleMode::Negative>::apply(
decimal_col->get_data(), decimal_col->get_scale(), vec_res, scale_arg);
}
// We need to always make sure result decimal's scale is as expected as its in plan
// So we need to append enough zero to result.

// Case 0: scale_arg <= -(integer part digits count)
// do nothing, because result is 0
// Case 1: scale_arg <= 0 && scale_arg > -(integer part digits count)
// decimal parts has been erased, so add them back by multiply 10^(result_scale)
// Case 2: scale_arg > 0 && scale_arg < result_scale
// decimal part now has scale_arg digits, so multiply 10^(result_scale - scal_arg)
// Case 3: scale_arg >= input_scale
// do nothing

if (scale_arg <= 0) {
for (size_t i = 0; i < input_rows_count; ++i) {
vec_res[i].value *= int_exp10(result_scale);
}
} else if (scale_arg > 0 && scale_arg < result_scale) {
for (size_t i = 0; i < input_rows_count; ++i) {
vec_res[i].value *= int_exp10(result_scale - scale_arg);
}
}

return col_res;
} else {
Expand All @@ -477,7 +502,9 @@ struct Dispatcher {
}
}

static ColumnPtr apply_vec_vec(const IColumn* col_general, const IColumn* col_scale) {
// result_scale: scale for result decimal, this scale is got from planner
static ColumnPtr apply_vec_vec(const IColumn* col_general, const IColumn* col_scale,
[[maybe_unused]] Int16 result_scale) {
const auto& col_scale_i32 = assert_cast<const ColumnInt32&>(*col_scale);
const size_t input_row_count = col_scale_i32.size();
for (size_t i = 0; i < input_row_count; ++i) {
Expand Down Expand Up @@ -515,10 +542,8 @@ struct Dispatcher {
return col_res;
} else if constexpr (IsDecimalNumber<T>) {
const auto* decimal_col = assert_cast<const ColumnDecimal<T>*>(col_general);

// ALWAYS use SAME scale with source Decimal column
const Int32 input_scale = decimal_col->get_scale();
auto col_res = ColumnDecimal<T>::create(input_row_count, input_scale);
auto col_res = ColumnDecimal<T>::create(input_row_count, result_scale);

for (size_t i = 0; i < input_row_count; ++i) {
DecimalRoundingImpl<T, rounding_mode, tie_breaking_mode>::apply(
Expand All @@ -534,15 +559,15 @@ struct Dispatcher {
// do nothing, because result is 0
// Case 1: scale_arg <= 0 && scale_arg > -(integer part digits count)
// decimal parts has been erased, so add them back by multiply 10^(scale_arg)
// Case 2: scale_arg > 0 && scale_arg < decimal part digits count
// decimal part now has scale_arg digits, so multiply 10^(input_scale - scal_arg)
// Case 2: scale_arg > 0 && scale_arg < result_scale
// decimal part now has scale_arg digits, so multiply 10^(result_scale - scal_arg)
// Case 3: scale_arg >= input_scale
// do nothing
const Int32 scale_arg = col_scale_i32.get_data()[i];
if (scale_arg <= 0) {
col_res->get_element(i).value *= int_exp10(input_scale);
} else if (scale_arg > 0 && scale_arg < input_scale) {
col_res->get_element(i).value *= int_exp10(input_scale - scale_arg);
col_res->get_element(i).value *= int_exp10(result_scale);
} else if (scale_arg > 0 && scale_arg < result_scale) {
col_res->get_element(i).value *= int_exp10(result_scale - scale_arg);
}
}

Expand All @@ -554,8 +579,9 @@ struct Dispatcher {
}
}

static ColumnPtr apply_const_vec(const ColumnConst* const_col_general,
const IColumn* col_scale) {
// result_scale: scale for result decimal, this scale is got from planner
static ColumnPtr apply_const_vec(const ColumnConst* const_col_general, const IColumn* col_scale,
[[maybe_unused]] Int16 result_scale) {
const auto& col_scale_i32 = assert_cast<const ColumnInt32&>(*col_scale);
const size_t input_rows_count = col_scale->size();

Expand All @@ -575,8 +601,7 @@ struct Dispatcher {
assert_cast<const ColumnDecimal<T>&>(const_col_general->get_data_column());
const T& general_val = data_col_general.get_data()[0];
Int32 input_scale = data_col_general.get_scale();

auto col_res = ColumnDecimal<T>::create(input_rows_count, input_scale);
auto col_res = ColumnDecimal<T>::create(input_rows_count, result_scale);

for (size_t i = 0; i < input_rows_count; ++i) {
DecimalRoundingImpl<T, rounding_mode, tie_breaking_mode>::apply(
Expand All @@ -592,15 +617,15 @@ struct Dispatcher {
// do nothing, because result is 0
// Case 1: scale_arg <= 0 && scale_arg > -(integer part digits count)
// decimal parts has been erased, so add them back by multiply 10^(scale_arg)
// Case 2: scale_arg > 0 && scale_arg < decimal part digits count
// decimal part now has scale_arg digits, so multiply 10^(input_scale - scal_arg)
// Case 2: scale_arg > 0 && scale_arg < result_scale
// decimal part now has scale_arg digits, so multiply 10^(result_scale - scal_arg)
// Case 3: scale_arg >= input_scale
// do nothing
const Int32 scale_arg = col_scale_i32.get_data()[i];
if (scale_arg <= 0) {
col_res->get_element(i).value *= int_exp10(input_scale);
} else if (scale_arg > 0 && scale_arg < input_scale) {
col_res->get_element(i).value *= int_exp10(input_scale - scale_arg);
col_res->get_element(i).value *= int_exp10(result_scale);
} else if (scale_arg > 0 && scale_arg < result_scale) {
col_res->get_element(i).value *= int_exp10(result_scale - scale_arg);
}
}

Expand Down Expand Up @@ -679,33 +704,47 @@ class FunctionRounding : public IFunction {
return Status::OK();
}

/// SELECT number, truncate(123.345, 1) FROM number("numbers"="10")
/// should NOT behave like two column arguments, so we can not use const column default implementation
bool use_default_implementation_for_constants() const override { return false; }
bool use_default_implementation_for_constants() const override { return true; }

//// We moved and optimized the execute_impl logic of function_truncate.h from PR#32746,
//// as well as make it suitable for all functions.
Status execute_impl(FunctionContext* context, Block& block, const ColumnNumbers& arguments,
size_t result, size_t input_rows_count) const override {
const ColumnWithTypeAndName& column_general = block.get_by_position(arguments[0]);
ColumnWithTypeAndName& column_result = block.get_by_position(result);
const DataTypePtr result_type = block.get_by_position(result).type;
const bool is_col_general_const = is_column_const(*column_general.column);
const auto* col_general = is_col_general_const
? assert_cast<const ColumnConst&>(*column_general.column)
.get_data_column_ptr()
: column_general.column.get();

ColumnPtr res;

/// potential argument types:
/// if the SECOND argument is MISSING(would be considered as ZERO const) or CONST, then we have the following type:
/// 1. func(Column), func(ColumnConst), func(Column, ColumnConst), func(ColumnConst, ColumnConst)
/// 1. func(Column), func(Column, ColumnConst)
/// otherwise, the SECOND arugment is COLUMN, we have another type:
/// 2. func(Column, Column), func(ColumnConst, Column)

auto call = [&](const auto& types) -> bool {
using Types = std::decay_t<decltype(types)>;
using DataType = typename Types::LeftType;

// For decimal, we will always make sure result Decimal has exactly same precision and scale with
// arguments from query plan.
Int16 result_scale = 0;
if constexpr (IsDataTypeDecimal<DataType>) {
if (column_result.type->get_type_id() == TypeIndex::Nullable) {
if (auto nullable_type = std::dynamic_pointer_cast<const DataTypeNullable>(
column_result.type)) {
result_scale = nullable_type->get_nested_type()->get_scale();
} else {
throw doris::Exception(ErrorCode::INTERNAL_ERROR,
"Illegal nullable column");
}
} else {
result_scale = column_result.type->get_scale();
}
}

if constexpr (IsDataTypeNumber<DataType> || IsDataTypeDecimal<DataType>) {
using FieldType = typename DataType::FieldType;
if (arguments.size() == 1 ||
Expand All @@ -718,23 +757,20 @@ class FunctionRounding : public IFunction {
}

res = Dispatcher<FieldType, rounding_mode, tie_breaking_mode>::apply_vec_const(
col_general, scale_arg);

if (is_col_general_const) {
// Important, make sure the result column has the same size as the input column
res = ColumnConst::create(std::move(res), input_rows_count);
}
col_general, scale_arg, result_scale);
} else {
// the SECOND arugment is COLUMN
if (is_col_general_const) {
res = Dispatcher<FieldType, rounding_mode, tie_breaking_mode>::
apply_const_vec(
&assert_cast<const ColumnConst&>(*column_general.column),
block.get_by_position(arguments[1]).column.get());
block.get_by_position(arguments[1]).column.get(),
result_scale);
} else {
res = Dispatcher<FieldType, rounding_mode, tie_breaking_mode>::
apply_vec_vec(col_general,
block.get_by_position(arguments[1]).column.get());
block.get_by_position(arguments[1]).column.get(),
result_scale);
}
}
return true;
Expand All @@ -758,7 +794,7 @@ class FunctionRounding : public IFunction {
column_general.type->get_name(), name);
}

block.replace_by_position(result, std::move(res));
column_result.column = std::move(res);
return Status::OK();
}
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,12 @@ default FunctionSignature computePrecision(FunctionSignature signature) {
Expression floatLength = getArgument(1);
int scale;

if (floatLength.isLiteral() || (floatLength instanceof Cast && floatLength.child(0).isLiteral()
// If scale arg is an integer literal, or it is a cast(Integer as Integer)
// then we will try to use its value as result scale
// In any other cases, we will make sure result decimal has same scale with input.
if ((floatLength.isLiteral() && floatLength.getDataType() instanceof Int32OrLessType)
|| (floatLength instanceof Cast && floatLength.child(0).isLiteral()
&& floatLength.child(0).getDataType() instanceof Int32OrLessType)) {
// Scale argument is a literal or cast from other literal
if (floatLength instanceof Cast) {
scale = ((IntegerLikeLiteral) floatLength.child(0)).getIntValue();
} else {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,115 @@
-- This file is automatically generated. You should know what you did if you want to edit this
-- !select --
123.100

-- !select --
123.100
123.100
123.100
123.100
123.100
123.100
123.100
123.100
123.100
123.100

-- !select --
120.000
120.000
120.000
120.000
120.000
120.000
120.000
120.000
120.000
120.000

-- !select --
123.100
123.100
123.100
123.100
123.100
123.100
123.100
123.100
123.100
123.100

-- !select --
120.000
120.000
120.000
120.000
120.000
120.000
120.000
120.000
120.000
120.000

-- !select --
123.200
123.200
123.200
123.200
123.200
123.200
123.200
123.200
123.200
123.200

-- !select --
130.000
130.000
130.000
130.000
130.000
130.000
130.000
130.000
130.000
130.000

-- !select --
123.100
123.100
123.100
123.100
123.100
123.100
123.100
123.100
123.100
123.100

-- !select --
120.000
120.000
120.000
120.000
120.000
120.000
120.000
120.000
120.000
120.000

-- !select --
4434.41

-- !select --
0

-- !select --
false \N 4434

-- !select --
0

-- !select --
10

Expand Down Expand Up @@ -97,6 +208,18 @@
-- !select --
16.025 16.02500 16.02500

-- !select_fix --
16.025 16.02500 16.02500

-- !select_fix --
16.025 16.02500 16.02500

-- !select_fix --
16.025 16.02500 16.02500

-- !select_fix --
16.025 16.02500 16.02500

-- !nereids_round_arg1 --
10

Expand Down
Loading