Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 15 additions & 15 deletions src/operator/tensor/ordering_op-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -385,8 +385,8 @@ void TopKImpl(const RunContext &ctx,
int axis = 0;
bool do_transpose = false;
bool is_ascend = false;
index_t k = 0;
size_t alignment = std::max(sizeof(DType), sizeof(index_t));
int k = 0;
size_t alignment = std::max(sizeof(DType), sizeof(int));
mxnet::TShape target_shape;
ParseTopKParam(src.shape_, param,
&target_shape, &batch_size, &element_num, &axis, &k, &do_transpose, &is_ascend);
Expand All @@ -395,31 +395,31 @@ void TopKImpl(const RunContext &ctx,
<< "The total element_num is " << element_num << ", but the selected IDType can only represent "
<< mxnet::common::MaxIntegerValue<IDType>() << " elements";
Tensor<xpu, 3, DType> dat = src.FlatTo3D<xpu, DType>(axis, axis, s);
// Temp space needed by the full sorts.
size_t temp_size = std::max(
mxnet::op::SortByKeyWorkspaceSize<index_t, DType, xpu>(src.Size()),
mxnet::op::SortByKeyWorkspaceSize<DType, index_t, xpu>(src.Size()));

size_t temp_size = 0;
// Temp space needed by the gpu-based full sorts.
temp_size = std::max(temp_size,
mxnet::op::SortByKeyWorkspaceSize<int, int, xpu>(src.Size()));
temp_size = std::max(temp_size,
mxnet::op::SortByKeyWorkspaceSize<index_t, index_t, xpu>(src.Size()));
mxnet::op::SortByKeyWorkspaceSize<int, DType, xpu>(src.Size()));
temp_size = std::max(temp_size,
mxnet::op::SortByKeyWorkspaceSize<DType, int, xpu>(src.Size()));
// Additional temp space for gpu full sorts for batch ids.
temp_size += PadBytes(sizeof(int) * src.Size(), alignment);
// Temp space for cpu sorts.
temp_size = std::max(temp_size, sizeof(DType) * src.Size());

temp_size = std::max(temp_size, static_cast<size_t>(sizeof(DType) * src.Size()));
size_t workspace_size = temp_size + PadBytes(sizeof(DType) * src.Size(), alignment)
+ PadBytes(sizeof(int) * src.Size(), alignment);
if (param.ret_typ == topk_enum::kReturnMask) {
workspace_size += PadBytes(sizeof(index_t) * batch_size * k, alignment);
workspace_size += PadBytes(sizeof(int) * batch_size * k, alignment);
}
workspace = resource.get_space_typed<xpu, 1, char>(Shape1(workspace_size), s);
char* workspace_curr_ptr = workspace.dptr_;
sorted_dat = Tensor<xpu, 1, DType>(reinterpret_cast<DType*>(workspace_curr_ptr),
Shape1(src.Size()), s); // contain sorted dat
Shape1(src.Size()), s); // contain sorted dat
workspace_curr_ptr += PadBytes(sizeof(DType) * src.Size(), alignment);
indices = Tensor<xpu, 1, index_t>(reinterpret_cast<index_t*>(workspace_curr_ptr),
Shape1(src.Size()), s); // indices in the original matrix
workspace_curr_ptr += PadBytes(sizeof(index_t) * src.Size(), alignment);
indices = Tensor<xpu, 1, int>(reinterpret_cast<int*>(workspace_curr_ptr),
Shape1(src.Size()), s); // indices in the original matrix
workspace_curr_ptr += PadBytes(sizeof(int) * src.Size(), alignment);

if (param.ret_typ == topk_enum::kReturnMask) {
sel_indices = Tensor<xpu, 1, int>(reinterpret_cast<int*>(workspace_curr_ptr),
Expand Down