Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
246 changes: 123 additions & 123 deletions be/src/vec/aggregate_functions/aggregate_function_java_udaf.h

Large diffs are not rendered by default.

44 changes: 21 additions & 23 deletions be/src/vec/functions/function_java_udf.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,27 @@ Status JavaFunctionCall::execute(FunctionContext* context, Block& block,
auto& key_null_map_data =
assert_cast<ColumnVector<UInt8>*>(key_data_column_null_map.get())->get_data();
auto key_nested_nullmap_address = reinterpret_cast<int64_t>(key_null_map_data.data());
ColumnNullable& map_value_column_nullable =
assert_cast<ColumnNullable&>(map_col->get_values());
auto value_data_column_null_map = map_value_column_nullable.get_null_map_column_ptr();
auto value_data_column = map_value_column_nullable.get_nested_column_ptr();
auto& value_null_map_data =
assert_cast<ColumnVector<UInt8>*>(value_data_column_null_map.get())->get_data();
auto value_nested_nullmap_address = reinterpret_cast<int64_t>(value_null_map_data.data());
jmethodID map_size = env->GetMethodID(hashmap_class, "size", "()I");
int element_size = 0; // get all element size in num_rows of map column
for (int i = 0; i < num_rows; ++i) {
jobject obj = env->GetObjectArrayElement(result_obj, i);
if (obj == nullptr) {
continue;
}
element_size = element_size + env->CallIntMethod(obj, map_size);
env->DeleteLocalRef(obj);
}
map_key_column_nullable.resize(element_size);
memset(key_null_map_data.data(), 0, element_size);
map_value_column_nullable.resize(element_size);
memset(value_null_map_data.data(), 0, element_size);
int64_t key_nested_data_address = 0, key_nested_offset_address = 0;
if (key_data_column->is_column_string()) {
ColumnString* str_col = assert_cast<ColumnString*>(key_data_column.get());
Expand All @@ -358,16 +379,7 @@ Status JavaFunctionCall::execute(FunctionContext* context, Block& block,
key_nested_data_address =
reinterpret_cast<int64_t>(key_data_column->get_raw_data().data);
}

ColumnNullable& map_value_column_nullable =
assert_cast<ColumnNullable&>(map_col->get_values());
auto value_data_column_null_map = map_value_column_nullable.get_null_map_column_ptr();
auto value_data_column = map_value_column_nullable.get_nested_column_ptr();
auto& value_null_map_data =
assert_cast<ColumnVector<UInt8>*>(value_data_column_null_map.get())->get_data();
auto value_nested_nullmap_address = reinterpret_cast<int64_t>(value_null_map_data.data());
int64_t value_nested_data_address = 0, value_nested_offset_address = 0;
// array type need pass address: [nullmap_address], offset_address, nested_nullmap_address, nested_data_address/nested_char_address,nested_offset_address
if (value_data_column->is_column_string()) {
ColumnString* str_col = assert_cast<ColumnString*>(value_data_column.get());
ColumnString::Chars& chars = assert_cast<ColumnString::Chars&>(str_col->get_chars());
Expand All @@ -379,20 +391,6 @@ Status JavaFunctionCall::execute(FunctionContext* context, Block& block,
value_nested_data_address =
reinterpret_cast<int64_t>(value_data_column->get_raw_data().data);
}
jmethodID map_size = env->GetMethodID(hashmap_class, "size", "()I");
int element_size = 0; // get all element size in num_rows of map column
for (int i = 0; i < num_rows; ++i) {
jobject obj = env->GetObjectArrayElement(result_obj, i);
if (obj == nullptr) {
continue;
}
element_size = element_size + env->CallIntMethod(obj, map_size);
env->DeleteLocalRef(obj);
}
map_key_column_nullable.resize(element_size);
memset(key_null_map_data.data(), 0, element_size);
map_value_column_nullable.resize(element_size);
memset(value_null_map_data.data(), 0, element_size);
env->CallNonvirtualVoidMethod(jni_ctx->executor, jni_env->executor_cl,
jni_env->executor_result_map_batch_id, result_nullable,
num_rows, result_obj, nullmap_address, offset_address,
Expand Down

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -163,43 +163,6 @@ public void addBatchPlaces(int rowStart, int rowEnd, long placeAddr, int offset,
}
}

/**
* invoke add function, add row in loop [rowStart, rowEnd).
*/
public void add(boolean isSinglePlace, long rowStart, long rowEnd) throws UdfRuntimeException {
try {
long idx = rowStart;
do {
Long curPlace = null;
if (isSinglePlace) {
curPlace = UdfUtils.UNSAFE.getLong(null, UdfUtils.UNSAFE.getLong(null, inputPlacesPtr));
} else {
curPlace = UdfUtils.UNSAFE.getLong(null, UdfUtils.UNSAFE.getLong(null, inputPlacesPtr) + 8L * idx);
}
Object[] inputArgs = new Object[argTypes.length + 1];
Object state = stateObjMap.get(curPlace);
if (state != null) {
inputArgs[0] = state;
} else {
Object newState = createAggState();
stateObjMap.put(curPlace, newState);
inputArgs[0] = newState;
}
do {
Object[] inputObjects = allocateInputObjects(idx, 1);
for (int i = 0; i < argTypes.length; ++i) {
inputArgs[i + 1] = inputObjects[i];
}
allMethods.get(UDAF_ADD_FUNCTION).invoke(udf, inputArgs);
idx++;
} while (isSinglePlace && idx < rowEnd);
} while (idx < rowEnd);
} catch (Exception e) {
LOG.warn("invoke add function meet some error: " + e.getCause().toString());
throw new UdfRuntimeException("UDAF failed to add: ", e);
}
}

/**
* invoke user create function to get obj.
*/
Expand Down Expand Up @@ -292,40 +255,71 @@ public void merge(long place, byte[] data) throws UdfRuntimeException {
/**
* invoke getValue to return finally result.
*/
public boolean getValue(long row, long place) throws UdfRuntimeException {

public Object getValue(long place) throws UdfRuntimeException {
try {
if (stateObjMap.get(place) == null) {
stateObjMap.put(place, createAggState());
}
return storeUdfResult(allMethods.get(UDAF_RESULT_FUNCTION).invoke(udf, stateObjMap.get((Long) place)),
row, retClass);
return allMethods.get(UDAF_RESULT_FUNCTION).invoke(udf, stateObjMap.get((Long) place));
} catch (Exception e) {
LOG.warn("invoke getValue function meet some error: " + e.getCause().toString());
throw new UdfRuntimeException("UDAF failed to result", e);
}
}

@Override
protected boolean storeUdfResult(Object obj, long row, Class retClass) throws UdfRuntimeException {
if (obj == null) {
// If result is null, return true directly when row == 0 as we have already inserted default value.
if (UdfUtils.UNSAFE.getLong(null, outputNullPtr) == -1) {
public void copyTupleBasicResult(Object result, int row, long outputNullMapPtr, long outputBufferBase,
long charsAddress,
long offsetsAddr) throws UdfRuntimeException {
if (result == null) {
// put null obj
if (outputNullMapPtr == -1) {
throw new UdfRuntimeException("UDAF failed to store null data to not null column");
} else {
UdfUtils.UNSAFE.putByte(outputNullMapPtr + row, (byte) 1);
}
return true;
return;
}
try {
if (outputNullMapPtr != -1) {
UdfUtils.UNSAFE.putByte(outputNullMapPtr + row, (byte) 0);
}
copyTupleBasicResult(result, row, retClass, outputBufferBase, charsAddress,
offsetsAddr, retType);
} catch (UdfRuntimeException e) {
LOG.info(e.toString());
}
return super.storeUdfResult(obj, row, retClass);
}

@Override
protected long getCurrentOutputOffset(long row, boolean isArrayType) {
if (isArrayType) {
return Integer.toUnsignedLong(
UdfUtils.UNSAFE.getInt(null, UdfUtils.UNSAFE.getLong(null, outputOffsetsPtr) + 8L * (row - 1)));
} else {
return Integer.toUnsignedLong(
UdfUtils.UNSAFE.getInt(null, UdfUtils.UNSAFE.getLong(null, outputOffsetsPtr) + 4L * (row - 1)));
public void copyTupleArrayResult(long hasPutElementNum, boolean isNullable, int row, Object result,
long nullMapAddr,
long offsetsAddr, long nestedNullMapAddr, long dataAddr, long strOffsetAddr) throws UdfRuntimeException {
if (nullMapAddr > 0) {
UdfUtils.UNSAFE.putByte(nullMapAddr + row, (byte) 0);
}
copyTupleArrayResultImpl(hasPutElementNum, isNullable, row, result, nullMapAddr, offsetsAddr, nestedNullMapAddr,
dataAddr, strOffsetAddr, retType.getItemType().getPrimitiveType());
}

public void copyTupleMapResult(long hasPutElementNum, boolean isNullable, int row, Object result, long nullMapAddr,
long offsetsAddr,
long keyNsestedNullMapAddr, long keyDataAddr, long keyStrOffsetAddr,
long valueNsestedNullMapAddr, long valueDataAddr, long valueStrOffsetAddr) throws UdfRuntimeException {
if (nullMapAddr > 0) {
UdfUtils.UNSAFE.putByte(nullMapAddr + row, (byte) 0);
}
PrimitiveType keyType = retType.getKeyType().getPrimitiveType();
PrimitiveType valueType = retType.getValueType().getPrimitiveType();
Object[] keyCol = new Object[1];
Object[] valueCol = new Object[1];
Object[] resultArr = new Object[1];
resultArr[0] = result;
buildArrayListFromHashMap(resultArr, keyType, valueType, keyCol, valueCol);
copyTupleArrayResultImpl(hasPutElementNum, isNullable, row,
valueCol[0], nullMapAddr, offsetsAddr,
valueNsestedNullMapAddr, valueDataAddr, valueStrOffsetAddr, valueType);
copyTupleArrayResultImpl(hasPutElementNum, isNullable, row, keyCol[0], nullMapAddr, offsetsAddr,
keyNsestedNullMapAddr, keyDataAddr, keyStrOffsetAddr, keyType);
}

@Override
Expand Down
Loading