Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit 0ac4a0b

Browse files
Handle 3D tensors in cuDNN legacy API
1 parent 8db4e88 commit 0ac4a0b

File tree

4 files changed

+31
-29
lines changed

4 files changed

+31
-29
lines changed

src/common/cuda/cudnn_cxx.cc

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -112,15 +112,6 @@ std::vector<Descriptor> GetSomeAttrs(size_t max_n,
112112
return ret;
113113
}
114114

115-
// Computes dense ("packed") strides for a tensor of extents `dims`, where
// `order` is a permutation of [0, dims.size()) giving the memory layout from
// outermost to innermost dimension (the last entry of `order` gets stride 1).
std::vector<int64_t> PackedStrides(const std::vector<size_t>& order,
                                   const std::vector<int64_t>& dims) {
  CHECK_EQ(order.size(), dims.size());
  std::vector<int64_t> ret(dims.size(), 1);
  // Guard: with empty dims, `dims.size() - 1` below wraps around (size_t
  // underflow) and the loop would index `order` out of bounds.
  if (dims.empty())
    return ret;
  // Walk from the second-innermost dimension outward; each stride is the
  // next-inner dimension's extent times its stride.
  for (size_t i = dims.size() - 1; i--;)
    ret[order[i]] = dims[order[i + 1]] * ret[order[i + 1]];
  return ret;
}
123-
124115
std::vector<Descriptor> GetPlans(cudnnBackendHeurMode_t h_mode,
125116
cudnnHandle_t handle,
126117
const Descriptor& op_graph,

src/common/cuda/cudnn_cxx.h

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -244,8 +244,14 @@ std::vector<Descriptor> GetSomeAttrs(size_t max_n,
244244
cudnnBackendDescriptorType_t type);
245245

246246
// Order sets layout, as a permutation of dims, with N,C,<spatial dims> being identity.
247-
std::vector<int64_t> PackedStrides(const std::vector<size_t>& order,
248-
const std::vector<int64_t>& dims);
247+
// Computes dense ("packed") strides for a tensor of extents `dims`, where
// `order` is a permutation of [0, dims.size()) giving the memory layout from
// outermost to innermost dimension (the last entry of `order` gets stride 1).
template <typename T>
std::vector<T> PackedStrides(const std::vector<size_t>& order, const std::vector<T>& dims) {
  CHECK_EQ(order.size(), dims.size());
  std::vector<T> ret(dims.size(), 1);
  // Guard: with empty dims, `dims.size() - 1` below wraps around (size_t
  // underflow) and the loop would index `order` out of bounds.
  if (dims.empty())
    return ret;
  // Walk from the second-innermost dimension outward; each stride is the
  // next-inner dimension's extent times its stride.
  for (size_t i = dims.size() - 1; i--;)
    ret[order[i]] = dims[order[i + 1]] * ret[order[i + 1]];
  return ret;
}
249255

250256
// Given an engine config's `notes`, return whether that config is compatible, i.e. does
251257
// the config have all of the required notes and none of the notes that are being excluded.

src/operator/cudnn_ops.cc

Lines changed: 5 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@
2929

3030
#include <dmlc/parameter.h>
3131

32-
#include <algorithm>
3332
#include <cstdlib>
3433
#include <iomanip>
3534
#include <iterator>
@@ -79,10 +78,6 @@ size_t LayoutInfo::ChannelIdx() const {
7978
return channel_last ? 1 + n_space_dims : 1;
8079
}
8180

82-
std::vector<int64_t> LayoutInfo::Strides(const std::vector<int64_t>& dims) const {
83-
return PackedStrides(Order(), dims);
84-
}
85-
8681
LayoutInfo GetLayoutInfo(mshadow::LayoutFlag layout) {
8782
static std::unordered_map<mshadow::LayoutFlag, LayoutInfo> layout_map{
8883
{mshadow::kNCW, {1, false}},
@@ -165,14 +160,8 @@ Descriptor MakeTensorDesc(int64_t uid,
165160
for (size_t i = 0; i < dims.size(); ++i)
166161
dims[i] = blob.shape_[rev_order[i]];
167162
auto strides = li.Strides(dims);
168-
if (li.n_space_dims == 1 && expand_1d) {
169-
dims.insert(dims.begin() + 2, 1);
170-
std::vector<size_t> order(dims.size());
171-
std::iota(order.begin(), order.end(), 0);
172-
if (li.channel_last)
173-
std::rotate(order.begin() + 1, order.begin() + 2, order.end());
174-
strides = PackedStrides(order, dims);
175-
}
163+
if (expand_1d)
164+
li.ExpandIf1d(&dims, &strides);
176165
return MakeTensorDesc(
177166
uid, CudnnType(static_cast<mshadow::TypeFlag>(blob.type_flag_)), dims, strides, is_virtual);
178167
}
@@ -803,9 +792,8 @@ void SetLegacyTensor(cudnnTensorDescriptor_t desc, const TBlob& blob, const Layo
803792
auto rev_order = ReverseOrder(li.Order());
804793
for (size_t i = 0; i < dims.size(); ++i)
805794
dims[i] = blob.shape_[rev_order[i]];
806-
auto strides64 = li.Strides(std::vector<int64_t>(dims.begin(), dims.end()));
807-
std::vector<int> strides(strides64.begin(), strides64.end());
808-
795+
auto strides = li.Strides(dims);
796+
li.ExpandIf1d(&dims, &strides);
809797
auto type = static_cast<mshadow::TypeFlag>(blob.type_flag_);
810798
CUDNN_CALL(cudnnSetTensorNdDescriptor(desc, CudnnType(type), dims.size(), &dims[0], &strides[0]));
811799
}
@@ -817,7 +805,7 @@ void SetLegacyCTensorExpandDims(cudnnTensorDescriptor_t desc,
817805
dims[1] = blob.shape_[0];
818806
std::vector<int> strides(dims.size(), 1);
819807
strides[0] = blob.shape_[0];
820-
808+
li.ExpandIf1d(&dims, &strides);
821809
auto type = static_cast<mshadow::TypeFlag>(blob.type_flag_);
822810
CUDNN_CALL(cudnnSetTensorNdDescriptor(desc, CudnnType(type), dims.size(), &dims[0], &strides[0]));
823811
}

src/operator/cudnn_ops.h

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929

3030
#include <mxnet/op_attr_types.h>
3131

32+
#include <algorithm>
3233
#include <mutex>
3334
#include <tuple>
3435
#include <unordered_map>
@@ -89,7 +90,23 @@ struct LayoutInfo {
8990

9091
std::vector<size_t> Order() const;
9192
size_t ChannelIdx() const;
92-
std::vector<int64_t> Strides(const std::vector<int64_t>& dims) const;
93+
94+
template <typename T>
95+
std::vector<T> Strides(const std::vector<T>& dims) const {
96+
return cudnn_cxx::PackedStrides(Order(), dims);
97+
}
98+
99+
template <typename T>
100+
void ExpandIf1d(std::vector<T>* dims, std::vector<T>* strides) const {
101+
if (n_space_dims != 1)
102+
return;
103+
dims->insert(dims->begin() + 2, 1);
104+
std::vector<size_t> order(dims->size());
105+
std::iota(order.begin(), order.end(), 0);
106+
if (channel_last)
107+
std::rotate(order.begin() + 1, order.begin() + 2, order.end());
108+
*strides = cudnn_cxx::PackedStrides(order, *dims);
109+
}
93110
};
94111

95112
LayoutInfo GetLayoutInfo(mshadow::LayoutFlag layout);

0 commit comments

Comments
 (0)