From 5967e80fe66bb05d0479b1bafecc5160242915da Mon Sep 17 00:00:00 2001 From: Lin Yuan Date: Mon, 19 Aug 2019 14:45:28 -0700 Subject: [PATCH 1/7] fix alignment --- src/operator/tensor/ordering_op-inl.h | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/operator/tensor/ordering_op-inl.h b/src/operator/tensor/ordering_op-inl.h index 880acf1f4cae..fd124c0e2a63 100644 --- a/src/operator/tensor/ordering_op-inl.h +++ b/src/operator/tensor/ordering_op-inl.h @@ -404,7 +404,7 @@ void TopKImpl(const RunContext &ctx, bool do_transpose = false; bool is_ascend = false; index_t k = 0; - size_t alignment = std::max(sizeof(DType), sizeof(int)); + size_t alignment = std::max(sizeof(DType), sizeof(index_t)); mxnet::TShape target_shape; ParseTopKParam(src.shape_, param, &target_shape, &batch_size, &element_num, &axis, &k, &do_transpose, &is_ascend); @@ -417,11 +417,11 @@ void TopKImpl(const RunContext &ctx, size_t temp_size = 0; // Temp space needed by the gpu-based full sorts. temp_size = std::max(temp_size, - mxnet::op::SortByKeyWorkspaceSize(src.Size())); + mxnet::op::SortByKeyWorkspaceSize(src.Size())); temp_size = std::max(temp_size, - mxnet::op::SortByKeyWorkspaceSize(src.Size())); + mxnet::op::SortByKeyWorkspaceSize(src.Size())); temp_size = std::max(temp_size, - mxnet::op::SortByKeyWorkspaceSize(src.Size())); + mxnet::op::SortByKeyWorkspaceSize(src.Size())); // Additional temp space for gpu full sorts for batch ids. temp_size += PadBytes(sizeof(index_t) * src.Size(), alignment); // Temp space for cpu sorts. 
@@ -429,15 +429,15 @@ void TopKImpl(const RunContext &ctx, size_t workspace_size = temp_size + PadBytes(sizeof(DType) * src.Size(), alignment) + PadBytes(sizeof(index_t) * src.Size(), alignment); if (param.ret_typ == topk_enum::kReturnMask) { - workspace_size += PadBytes(sizeof(int) * batch_size * k, alignment); + workspace_size += PadBytes(sizeof(index_t) * batch_size * k, alignment); } workspace = resource.get_space_typed(Shape1(workspace_size), s); char* workspace_curr_ptr = workspace.dptr_; sorted_dat = Tensor(reinterpret_cast(workspace_curr_ptr), - Shape1(src.Size()), s); // contain sorted dat + Shape1(src.Size()), s); // contain sorted dat workspace_curr_ptr += PadBytes(sizeof(DType) * src.Size(), alignment); indices = Tensor(reinterpret_cast(workspace_curr_ptr), - Shape1(src.Size()), s); // indices in the original matrix + Shape1(src.Size()), s); // indices in the original matrix workspace_curr_ptr += PadBytes(sizeof(index_t) * src.Size(), alignment); if (param.ret_typ == topk_enum::kReturnMask) { From b4f9793d44093042a70a4b7553f4c2c9283e7c84 Mon Sep 17 00:00:00 2001 From: Lin Yuan Date: Mon, 19 Aug 2019 16:24:28 -0700 Subject: [PATCH 2/7] use correct type for shape index --- 3rdparty/mshadow/mshadow/tensor.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/3rdparty/mshadow/mshadow/tensor.h b/3rdparty/mshadow/mshadow/tensor.h index 0d662621aa4d..df3678e26fa5 100755 --- a/3rdparty/mshadow/mshadow/tensor.h +++ b/3rdparty/mshadow/mshadow/tensor.h @@ -69,7 +69,7 @@ struct Shape { * \param idx dimension index * \return the corresponding dimension size */ - MSHADOW_XINLINE index_t &operator[](index_t idx) { + MSHADOW_XINLINE index_t &operator[](int idx) { return shape_[idx]; } /*! @@ -77,7 +77,7 @@ struct Shape { * \param idx dimension index * \return the corresponding dimension size */ - MSHADOW_XINLINE const index_t &operator[](index_t idx) const { + MSHADOW_XINLINE const index_t &operator[](int idx) const { return shape_[idx]; } /*! 
From 38e41a2b06426283881619413cba41e81dfc13d7 Mon Sep 17 00:00:00 2001 From: Lin Yuan Date: Mon, 19 Aug 2019 22:47:03 -0700 Subject: [PATCH 3/7] clean up unnecessary space in topk --- src/operator/tensor/ordering_op-inl.h | 18 ++++++------------ 1 file changed, 6 insertions(+), 12 deletions(-) diff --git a/src/operator/tensor/ordering_op-inl.h b/src/operator/tensor/ordering_op-inl.h index fd124c0e2a63..3889b5953569 100644 --- a/src/operator/tensor/ordering_op-inl.h +++ b/src/operator/tensor/ordering_op-inl.h @@ -414,18 +414,12 @@ void TopKImpl(const RunContext &ctx, << element_num << ", but the selected index_t can only represent " << mxnet::common::MaxIntegerValue() << " elements"; Tensor dat = src.FlatTo3D(axis, axis, s); - size_t temp_size = 0; - // Temp space needed by the gpu-based full sorts. - temp_size = std::max(temp_size, - mxnet::op::SortByKeyWorkspaceSize(src.Size())); - temp_size = std::max(temp_size, - mxnet::op::SortByKeyWorkspaceSize(src.Size())); - temp_size = std::max(temp_size, - mxnet::op::SortByKeyWorkspaceSize(src.Size())); - // Additional temp space for gpu full sorts for batch ids. - temp_size += PadBytes(sizeof(index_t) * src.Size(), alignment); - // Temp space for cpu sorts. - temp_size = std::max(temp_size, static_cast(sizeof(DType) * src.Size())); + // Temp space needed by the full sorts. 
+ size_t temp_size = std::max( + mxnet::op::SortByKeyWorkspaceSize(src.Size()), + mxnet::op::SortByKeyWorkspaceSize(src.Size()) + ); + size_t workspace_size = temp_size + PadBytes(sizeof(DType) * src.Size(), alignment) + PadBytes(sizeof(index_t) * src.Size(), alignment); if (param.ret_typ == topk_enum::kReturnMask) { From 8d76ed18e2d6d04e2b67dfd5ceb53e91301f4cc6 Mon Sep 17 00:00:00 2001 From: Lin Yuan Date: Tue, 20 Aug 2019 00:28:03 -0700 Subject: [PATCH 4/7] fix lint --- src/operator/tensor/ordering_op-inl.h | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/operator/tensor/ordering_op-inl.h b/src/operator/tensor/ordering_op-inl.h index 3889b5953569..c0265827fdab 100644 --- a/src/operator/tensor/ordering_op-inl.h +++ b/src/operator/tensor/ordering_op-inl.h @@ -417,8 +417,7 @@ void TopKImpl(const RunContext &ctx, // Temp space needed by the full sorts. size_t temp_size = std::max( mxnet::op::SortByKeyWorkspaceSize(src.Size()), - mxnet::op::SortByKeyWorkspaceSize(src.Size()) - ); + mxnet::op::SortByKeyWorkspaceSize(src.Size())); size_t workspace_size = temp_size + PadBytes(sizeof(DType) * src.Size(), alignment) + PadBytes(sizeof(index_t) * src.Size(), alignment); From 316443bbd70f9737252dd80a652e17e51444a417 Mon Sep 17 00:00:00 2001 From: Lin Yuan Date: Tue, 20 Aug 2019 14:20:06 -0700 Subject: [PATCH 5/7] add additional temp space --- src/operator/tensor/ordering_op-inl.h | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/operator/tensor/ordering_op-inl.h b/src/operator/tensor/ordering_op-inl.h index c0265827fdab..796acb4c5f36 100644 --- a/src/operator/tensor/ordering_op-inl.h +++ b/src/operator/tensor/ordering_op-inl.h @@ -419,6 +419,11 @@ void TopKImpl(const RunContext &ctx, mxnet::op::SortByKeyWorkspaceSize(src.Size()), mxnet::op::SortByKeyWorkspaceSize(src.Size())); + // Additional temp space for gpu full sorts for batch ids. + temp_size += PadBytes(sizeof(index_t) * src.Size(), alignment); + // Temp space for cpu sorts. 
+ temp_size = std::max(temp_size, sizeof(DType) * src.Size()); + size_t workspace_size = temp_size + PadBytes(sizeof(DType) * src.Size(), alignment) + PadBytes(sizeof(index_t) * src.Size(), alignment); if (param.ret_typ == topk_enum::kReturnMask) { From ec603711a11fd119f68662b7fcd617ab07819b53 Mon Sep 17 00:00:00 2001 From: Lin Yuan Date: Wed, 21 Aug 2019 14:05:01 -0700 Subject: [PATCH 6/7] address reviewer comment --- 3rdparty/mshadow/mshadow/tensor.h | 4 ++-- src/operator/tensor/ordering_op-inl.h | 2 ++ 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/3rdparty/mshadow/mshadow/tensor.h b/3rdparty/mshadow/mshadow/tensor.h index df3678e26fa5..c6847d9cbc65 100755 --- a/3rdparty/mshadow/mshadow/tensor.h +++ b/3rdparty/mshadow/mshadow/tensor.h @@ -484,7 +484,7 @@ struct Tensor: public TRValue, * \param idx the dimension count from the highest dimensin * \return the size */ - MSHADOW_XINLINE index_t size(index_t idx) const { + MSHADOW_XINLINE index_t size(int idx) const { return shape_[idx]; } /*! @@ -506,7 +506,7 @@ struct Tensor: public TRValue, * \param idx index * \return the result tensor */ - MSHADOW_XINLINE Tensor operator[](index_t idx) const { + MSHADOW_XINLINE Tensor operator[](int idx) const { return Tensor(dptr_ + this->MemSize<1>() * idx, shape_.SubShape(), stride_, stream_); } diff --git a/src/operator/tensor/ordering_op-inl.h b/src/operator/tensor/ordering_op-inl.h index 796acb4c5f36..b36d79acfc7b 100644 --- a/src/operator/tensor/ordering_op-inl.h +++ b/src/operator/tensor/ordering_op-inl.h @@ -419,6 +419,8 @@ void TopKImpl(const RunContext &ctx, mxnet::op::SortByKeyWorkspaceSize(src.Size()), mxnet::op::SortByKeyWorkspaceSize(src.Size())); + temp_size = std::max(temp_size, + mxnet::op::SortByKeyWorkspaceSize(src.Size())); // Additional temp space for gpu full sorts for batch ids. temp_size += PadBytes(sizeof(index_t) * src.Size(), alignment); // Temp space for cpu sorts. 
From f6f0750c984bb2c8abea742f180c790deda0f464 Mon Sep 17 00:00:00 2001 From: Lin Yuan Date: Wed, 21 Aug 2019 14:13:27 -0700 Subject: [PATCH 7/7] fix incorrect index type --- 3rdparty/mshadow/mshadow/tensor.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/3rdparty/mshadow/mshadow/tensor.h b/3rdparty/mshadow/mshadow/tensor.h index c6847d9cbc65..ad29e751a050 100755 --- a/3rdparty/mshadow/mshadow/tensor.h +++ b/3rdparty/mshadow/mshadow/tensor.h @@ -506,7 +506,7 @@ struct Tensor: public TRValue, * \param idx index * \return the result tensor */ - MSHADOW_XINLINE Tensor operator[](int idx) const { + MSHADOW_XINLINE Tensor operator[](index_t idx) const { return Tensor(dptr_ + this->MemSize<1>() * idx, shape_.SubShape(), stride_, stream_); }