From ec8e11dba9021d294e8da3ef2d7ad5c4f6fb4e39 Mon Sep 17 00:00:00 2001 From: Haichen Shen Date: Mon, 25 Nov 2019 23:37:39 +0000 Subject: [PATCH 1/6] [Contrib] Add cudnn softmax --- python/tvm/contrib/cudnn.py | 10 +++ python/tvm/relay/op/nn/_nn.py | 4 +- python/tvm/relay/op/op_attrs.py | 5 ++ python/tvm/relay/op/strategy/cuda.py | 21 +++-- python/tvm/relay/op/strategy/generic.py | 21 +++-- python/tvm/relay/op/strategy/hls.py | 14 +-- python/tvm/relay/op/strategy/opengl.py | 14 +-- python/tvm/relay/op/strategy/x86.py | 14 +-- src/relay/op/nn/nn.cc | 16 ++-- src/runtime/contrib/cudnn/cudnn_utils.cc | 10 +++ src/runtime/contrib/cudnn/cudnn_utils.h | 8 ++ src/runtime/contrib/cudnn/softmax.cc | 104 +++++++++++++++++++++++ topi/python/topi/cuda/__init__.py | 2 +- topi/python/topi/cuda/softmax.py | 11 +++ topi/python/topi/nn/softmax.py | 1 - 15 files changed, 218 insertions(+), 37 deletions(-) create mode 100644 src/runtime/contrib/cudnn/softmax.cc diff --git a/python/tvm/contrib/cudnn.py b/python/tvm/contrib/cudnn.py index e62724512d49..03a94f9065f9 100644 --- a/python/tvm/contrib/cudnn.py +++ b/python/tvm/contrib/cudnn.py @@ -402,3 +402,13 @@ def conv_forward(x, ins[1], outs[0], conv_dtype), name="y") + +def softmax(x, axis=-1): + assert axis == -1 + return te.extern( + x.shape, [x], + lambda ins, outs: tvm.tir.call_packed( + "tvm.contrib.cudnn.softmax.forward", + ins[0], + outs[0], + axis), name="y") diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py index aa35fa2e8274..bc096fa5dc92 100644 --- a/python/tvm/relay/op/nn/_nn.py +++ b/python/tvm/relay/op/nn/_nn.py @@ -34,12 +34,12 @@ # softmax -reg.register_schedule("nn.softmax", strategy.schedule_softmax) +reg.register_strategy("nn.softmax", strategy.softmax_strategy) reg.register_pattern("nn.softmax", OpPattern.OPAQUE) # log_softmax -reg.register_schedule("nn.log_softmax", strategy.schedule_softmax) +#reg.register_schedule("nn.log_softmax", strategy.schedule_softmax) reg.register_pattern("nn.log_softmax", OpPattern.OPAQUE) diff --git a/python/tvm/relay/op/op_attrs.py b/python/tvm/relay/op/op_attrs.py index edc2160e38bc..f2493be8fa7d 100644 --- a/python/tvm/relay/op/op_attrs.py +++ b/python/tvm/relay/op/op_attrs.py @@ -59,6 +59,11 @@ class DenseAttrs(Attrs): """Attributes for nn.dense""" +@tvm._ffi.register_object("relay.attrs.SoftmaxAttrs") +class SoftmaxAttrs(Attrs): + """Attributes for nn.softmax""" + + @tvm._ffi.register_object("relay.attrs.FIFOBufferAttrs") class FIFOBufferAttrs(Attrs): """Attributes for nn.fifo_buffer""" diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py index 8ccd6bf51508..07f5a3dbc572 100644 --- a/python/tvm/relay/op/strategy/cuda.py +++ b/python/tvm/relay/op/strategy/cuda.py @@ -58,11 +58,22 @@ def schedule_adaptive_pool_cuda(attrs, outs, target): with target: return topi.cuda.schedule_adaptive_pool(outs) -@schedule_softmax.register(["cuda", "gpu"]) -def schedule_softmax_cuda(attrs, outs, target): - """schedule softmax for cuda""" - with target: - return topi.cuda.schedule_softmax(outs) +@softmax_strategy.register(["cuda", "gpu"]) +def softmax_strategy_cuda(attrs, inputs, out_type, target): + """softmax cuda strategy""" + axis = attrs.get_int("axis") + strategy = _op.OpStrategy() + strategy.add_implementation( + wrap_compute_softmax(topi.nn.softmax), + wrap_topi_schedule(topi.cuda.schedule_softmax), + name="softmax.cuda") + if target.target_name == "cuda" and "cudnn" in target.libs and axis == -1: + strategy.add_implementation( + 
wrap_compute_softmax(topi.cuda.softmax_cudnn),
+            wrap_topi_schedule(topi.cuda.schedule_softmax_cudnn),
+            name="softmax.cudnn",
+            plevel=15)
+    return strategy
 
 @schedule_lrn.register(["cuda", "gpu"])
 def schedule_lrn_cuda(attrs, outs, target):
     """schedule LRN for cuda"""
diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py
index 573df3675eee..f4e875102619 100644
--- a/python/tvm/relay/op/strategy/generic.py
+++ b/python/tvm/relay/op/strategy/generic.py
@@ -107,11 +107,22 @@ def schedule_adaptive_pool(attrs, outs, target):
         return topi.generic.schedule_adaptive_pool(outs)
 
 # softmax
-@generic_func
-def schedule_softmax(attrs, outs, target):
-    """Schedule softmax"""
-    with target:
-        return topi.generic.schedule_softmax(outs)
+def wrap_compute_softmax(topi_compute):
+    """Wrap softmax topi compute"""
+    def _compute_softmax(attrs, inputs, out_type):
+        axis = attrs.get_int("axis")
+        return [topi_compute(inputs[0], axis)]
+    return _compute_softmax
+
+@override_native_generic_func("softmax_strategy")
+def softmax_strategy(attrs, outs, target):
+    """softmax generic strategy"""
+    strategy = _op.OpStrategy()
+    strategy.add_implementation(
+        wrap_compute_softmax(topi.nn.softmax),
+        wrap_topi_schedule(topi.generic.schedule_softmax),
+        name="softmax.generic")
+    return strategy
 
 # lrn
 @generic_func
 def schedule_lrn(attrs, outs, target):
diff --git a/python/tvm/relay/op/strategy/hls.py b/python/tvm/relay/op/strategy/hls.py
index 514902b86833..6b23456a80f5 100644
--- a/python/tvm/relay/op/strategy/hls.py
+++ b/python/tvm/relay/op/strategy/hls.py
@@ -50,11 +50,15 @@ def schedule_adaptive_pool_hls(attrs, outs, target):
     with target:
         return topi.hls.schedule_adaptive_pool(outs)
 
-@schedule_softmax.register("hls")
-def schedule_softmax_hls(attrs, outs, target):
-    """schedule softmax for hls"""
-    with target:
-        return topi.hls.schedule_softmax(outs)
+@softmax_strategy.register("hls")
+def softmax_strategy_hls(attrs, inputs, out_type, target):
+    """softmax hls strategy"""
+    strategy = _op.OpStrategy()
+    strategy.add_implementation(
+        wrap_compute_softmax(topi.nn.softmax),
+        wrap_topi_schedule(topi.hls.schedule_softmax),
+        name="softmax.hls")
+    return strategy
 
 @override_native_generic_func("conv2d_strategy")
 def conv2d_strategy_hls(attrs, inputs, out_type, target):
diff --git a/python/tvm/relay/op/strategy/opengl.py b/python/tvm/relay/op/strategy/opengl.py
index 45e290c50e0f..e90a570459f8 100644
--- a/python/tvm/relay/op/strategy/opengl.py
+++ b/python/tvm/relay/op/strategy/opengl.py
@@ -44,11 +44,15 @@ def schedule_adaptive_pool_opengl(attrs, outs, target):
     with target:
         return topi.opengl.schedule_adaptive_pool(outs)
 
-@schedule_softmax.register("opengl")
-def schedule_softmax_opengl(attrs, outs, target):
-    """schedule softmax for opengl"""
-    with target:
-        return topi.opengl.schedule_softmax(outs)
+@softmax_strategy.register("opengl")
+def softmax_strategy_opengl(attrs, inputs, out_type, target):
+    """softmax opengl strategy"""
+    strategy = _op.OpStrategy()
+    strategy.add_implementation(
+        wrap_compute_softmax(topi.nn.softmax),
+        wrap_topi_schedule(topi.opengl.schedule_softmax),
+        name="softmax.opengl")
+    return strategy
 
 @conv2d_strategy.register("opengl")
 def conv2d_strategy_opengl(attrs, inputs, out_type, target):
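A note on the strategy pattern these hunks introduce: wrap_compute_softmax adapts a TOPI compute to the (attrs, inputs, out_type) calling convention, and a matching wrapper adapts a TOPI schedule to (attrs, outs, target). A minimal sketch of the schedule-side wrapper, assuming it mirrors the existing wrap_topi_schedule helper in generic.py (simplified here for illustration, not part of this diff):

    def wrap_topi_schedule(topi_schedule):
        """Wrap a TOPI schedule into the (attrs, outs, target) convention."""
        def _schedule(attrs, outs, target):
            # Enter the target scope so the schedule sees the right target.
            with target:
                return topi_schedule(outs)
        return _schedule

add_implementation then records each (compute, schedule) pair under a name and a priority level; when several implementations are valid for an op, the highest plevel wins, which is how the cuDNN kernel registered above (plevel=15) takes precedence over the default TOPI softmax (plevel defaults to 10, if memory serves).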
diff --git a/python/tvm/relay/op/strategy/x86.py b/python/tvm/relay/op/strategy/x86.py
index e35838c1c5e8..451cc9ddf4f3 100644
--- a/python/tvm/relay/op/strategy/x86.py
+++ b/python/tvm/relay/op/strategy/x86.py
@@ -55,11 +55,15 @@ def schedule_adaptive_pool_cpu(attrs, outs, target):
     with target:
         return topi.x86.schedule_adaptive_pool(outs)
 
-@schedule_softmax.register("cpu")
-def schedule_softmax_cpu(attrs, outs, target):
-    """schedule softmax for x86"""
-    with target:
-        return topi.x86.schedule_softmax(outs)
+@softmax_strategy.register("cpu")
+def softmax_strategy_cpu(attrs, inputs, out_type, target):
+    """softmax x86 strategy"""
+    strategy = _op.OpStrategy()
+    strategy.add_implementation(
+        wrap_compute_softmax(topi.nn.softmax),
+        wrap_topi_schedule(topi.x86.schedule_softmax),
+        name="softmax.x86")
+    return strategy
 
 @conv2d_strategy.register("cpu")
 def conv2d_strategy_cpu(attrs, inputs, out_type, target):
diff --git a/src/relay/op/nn/nn.cc b/src/relay/op/nn/nn.cc
index 5203ffc39217..735cd3776306 100644
--- a/src/relay/op/nn/nn.cc
+++ b/src/relay/op/nn/nn.cc
@@ -347,14 +347,14 @@ RELAY_REGISTER_OP("nn.softmax")
 .set_num_inputs(1)
 .add_argument("data", "Tensor", "The input tensor.")
 .set_support_level(1)
-.add_type_rel("Identity", IdentityRel)
-.set_attr<FTVMCompute>("FTVMCompute", [](const Attrs& attrs,
-                                         const Array<te::Tensor>& inputs,
-                                         const Type& out_type) {
-  const auto* param = attrs.as<SoftmaxAttrs>();
-  CHECK(param != nullptr);
-  return Array<te::Tensor>{ topi::nn::softmax(inputs[0], param->axis) };
-});
+.add_type_rel("Identity", IdentityRel);
+// .set_attr<FTVMCompute>("FTVMCompute", [](const Attrs& attrs,
+//                                          const Array<te::Tensor>& inputs,
+//                                          const Type& out_type) {
+//   const auto* param = attrs.as<SoftmaxAttrs>();
+//   CHECK(param != nullptr);
+//   return Array<te::Tensor>{ topi::nn::softmax(inputs[0], param->axis) };
+// });
 
 // relay.nn.log_softmax
diff --git a/src/runtime/contrib/cudnn/cudnn_utils.cc b/src/runtime/contrib/cudnn/cudnn_utils.cc
index fa185e97d1f5..9c895c5b7e06 100644
--- a/src/runtime/contrib/cudnn/cudnn_utils.cc
+++ b/src/runtime/contrib/cudnn/cudnn_utils.cc
@@ -140,5 +140,15 @@ void ConvEntry::CleanWorkspace() {
   workspace_size = 0;
 }
 
+// SoftmaxEntry
+
+SoftmaxEntry::SoftmaxEntry() {
+  CUDNN_CALL(cudnnCreateTensorDescriptor(&shape_desc));
+}
+
+SoftmaxEntry::~SoftmaxEntry() {
+  CUDNN_CALL(cudnnDestroyTensorDescriptor(shape_desc));
+}
+
 }  // namespace contrib
 }  // namespace tvm
diff --git a/src/runtime/contrib/cudnn/cudnn_utils.h b/src/runtime/contrib/cudnn/cudnn_utils.h
index 004224523ecd..950983047bb6 100644
--- a/src/runtime/contrib/cudnn/cudnn_utils.h
+++ b/src/runtime/contrib/cudnn/cudnn_utils.h
@@ -85,12 +85,20 @@ struct ConvEntry {
   void CleanWorkspace();
 };  // ConvThreadEntry
 
+struct SoftmaxEntry {
+  cudnnSoftmaxMode_t mode;
+  cudnnDataType_t data_type;
+  cudnnTensorDescriptor_t shape_desc;
+  SoftmaxEntry();
+  ~SoftmaxEntry();
+}; // SoftmaxEntry
 
 struct CuDNNThreadEntry {
   CuDNNThreadEntry();
   ~CuDNNThreadEntry();
   cudnnHandle_t handle{nullptr};
   ConvEntry conv_entry;
+  SoftmaxEntry softmax_entry;
   runtime::DeviceAPI *cuda_api{nullptr};
   static CuDNNThreadEntry* ThreadLocal();
 };  // CuDNNThreadEntry
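The new SoftmaxEntry hangs off CuDNNThreadEntry, so the tensor descriptor is created once per thread and reused by every softmax call, mirroring how ConvEntry manages conv state. Once the packed function in the new softmax.cc below is registered, it is reachable from Python through the global registry; a minimal sketch, assuming a cuDNN-enabled build:

    import tvm

    # The second argument lets the lookup return None instead of raising
    # when the symbol is missing (e.g. TVM built without cuDNN).
    fsoftmax = tvm.get_global_func("tvm.contrib.cudnn.softmax.forward", True)
    assert fsoftmax is not None, "this build has no cuDNN support"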
diff --git a/src/runtime/contrib/cudnn/softmax.cc b/src/runtime/contrib/cudnn/softmax.cc
new file mode 100644
index 000000000000..320ec1bfbdc4
--- /dev/null
+++ b/src/runtime/contrib/cudnn/softmax.cc
@@ -0,0 +1,104 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file Use external cudnn utils function
+ */
+#include <tvm/runtime/registry.h>
+#include <tvm/runtime/device_api.h>
+#include "cudnn_utils.h"
+
+namespace tvm {
+namespace contrib {
+
+using namespace runtime;
+
+/*
+  cudnnStatus_t cudnnSoftmaxForward(
+    cudnnHandle_t handle,
+    cudnnSoftmaxAlgorithm_t algorithm,
+    cudnnSoftmaxMode_t mode,
+    const void *alpha,
+    const cudnnTensorDescriptor_t xDesc,
+    const void *x,
+    const void *beta,
+    const cudnnTensorDescriptor_t yDesc,
+    void *y)
+
+2.62. cudnnSoftmaxAlgorithm_t
+
+CUDNN_SOFTMAX_FAST
+This implementation applies the straightforward softmax operation.
+
+CUDNN_SOFTMAX_ACCURATE
+This implementation scales each point of the softmax input domain by its maximum value to avoid potential floating point overflows in the softmax evaluation.
+
+CUDNN_SOFTMAX_LOG
+This entry performs the log softmax operation, avoiding overflows by scaling each point in the input domain as in CUDNN_SOFTMAX_ACCURATE.
+
+2.63. cudnnSoftmaxMode_t
+cudnnSoftmaxMode_t is used to select over which data the cudnnSoftmaxForward() and cudnnSoftmaxBackward() are computing their results.
+
+Values
+CUDNN_SOFTMAX_MODE_INSTANCE
+The softmax operation is computed per image (N) across the dimensions C,H,W.
+
+CUDNN_SOFTMAX_MODE_CHANNEL
+The softmax operation is computed per spatial location (H,W) per image (N) across the dimension C.
+
+*/
+TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.softmax.forward")
+.set_body([](TVMArgs args, TVMRetValue *ret) {
+  DLTensor* x = args[0];
+  DLTensor* y = args[1];
+  int axis = args[2];
+  int ndim = x->ndim;
+  int64_t* shape = x->shape;
+  if (axis < 0) axis += ndim;
+  CHECK(axis >= 0 && axis < ndim);
+  CHECK(axis == ndim - 1);
+  int64_t N = 1;
+  for (int i = 0; i < ndim - 1; ++i) {
+    N *= shape[i];
+  }
+
+  CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal();
+  entry_ptr->softmax_entry.mode = CUDNN_SOFTMAX_MODE_INSTANCE;
+  entry_ptr->softmax_entry.data_type = CuDNNDataType::DLTypeToCuDNNType(x->dtype);
+  // Set shape descriptor
+  CUDNN_CALL(cudnnSetTensor4dDescriptor(entry_ptr->softmax_entry.shape_desc,
+                                        CUDNN_TENSOR_NCHW,
+                                        entry_ptr->softmax_entry.data_type,
+                                        static_cast<int>(N),
+                                        static_cast<int>(shape[ndim - 1]),
+                                        1,
+                                        1));
+  CUDNN_CALL(cudnnSoftmaxForward(entry_ptr->handle,
+                                 CUDNN_SOFTMAX_ACCURATE,
+                                 entry_ptr->softmax_entry.mode,
+                                 CuDNNDataType::GetConst<1>(entry_ptr->conv_entry.data_type),
+                                 entry_ptr->softmax_entry.shape_desc,
+                                 x->data,
+                                 CuDNNDataType::GetConst<0>(entry_ptr->conv_entry.data_type),
+                                 entry_ptr->softmax_entry.shape_desc,
+                                 y->data));
+});
+
+}  // namespace contrib
+}  // namespace tvm
diff --git a/topi/python/topi/cuda/__init__.py b/topi/python/topi/cuda/__init__.py
index 6e38318a0062..99d5d0efd189 100644
--- a/topi/python/topi/cuda/__init__.py
+++ b/topi/python/topi/cuda/__init__.py
@@ -32,7 +32,7 @@
 from .deformable_conv2d import *
 from .conv3d import *
 from .reduction import schedule_reduce
-from .softmax import schedule_softmax
+from .softmax import *
 from .injective import schedule_injective, schedule_elemwise, schedule_broadcast
 from .dense import *
 from .pooling import *
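The descriptor setup in softmax.cc deserves a gloss: for axis == ndim - 1 the input is presented to cuDNN as a 4-D NCHW tensor of shape (N, C, 1, 1), where N collapses every leading dimension and C is the softmax axis, so CUDNN_SOFTMAX_MODE_INSTANCE (softmax per image over C, H, W) reduces exactly over the requested axis. A NumPy sketch of the equivalent view (illustrative only, not part of the patch):

    import numpy as np

    def instance_mode_view(x):
        """Collapse (d0, ..., dk) to the (N, C, 1, 1) view handed to cuDNN
        when the softmax axis is the last one."""
        n = int(np.prod(x.shape[:-1], dtype=np.int64))
        c = x.shape[-1]
        return x.reshape(n, c, 1, 1)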
diff --git a/topi/python/topi/cuda/softmax.py b/topi/python/topi/cuda/softmax.py
index 54d5bfbae121..7382c3d21d26 100644
--- a/topi/python/topi/cuda/softmax.py
+++ b/topi/python/topi/cuda/softmax.py
@@ -17,6 +17,8 @@
 # pylint: disable=invalid-name, unused-variable, trailing-whitespace
 """Schedule for softmax operator"""
 from tvm import te
+from tvm.contrib import cudnn
+from .. import generic
 from .injective import schedule_injective_from_existing
 
@@ -79,3 +81,12 @@ def schedule_softmax(outs):
             s[softmax].bind(tx, thread_x)
 
     return s
+
+
+def softmax_cudnn(x, axis=-1):
+    assert axis == -1
+    return cudnn.softmax(x, axis)
+
+
+def schedule_softmax_cudnn(outs):
+    return generic.schedule_extern(outs)
diff --git a/topi/python/topi/nn/softmax.py b/topi/python/topi/nn/softmax.py
index c414372ade93..fb513844dacd 100644
--- a/topi/python/topi/nn/softmax.py
+++ b/topi/python/topi/nn/softmax.py
@@ -77,7 +77,6 @@ def _normalize(exp, expsum, *indices):
     return te.compute(shape, lambda *indices: _normalize(exp, expsum, *indices),
                       name='T_softmax_norm', attrs={"axis" : axis})
 
-
 @tvm.te.tag_scope(tag='log_softmax_output')
 def log_softmax(x):
     """Perform log softmax activation on the data

From 9dcdb0f70e74dfe3542113e7ecb474aea0c7aa3d Mon Sep 17 00:00:00 2001
From: Haichen Shen
Date: Wed, 1 Apr 2020 23:45:47 +0000
Subject: [PATCH 2/6] update

---
 python/tvm/contrib/cudnn.py             | 17 ++++++++++++++++-
 python/tvm/relay/op/nn/_nn.py           |  2 +-
 python/tvm/relay/op/strategy/cuda.py    |  6 ++++++
 python/tvm/relay/op/strategy/generic.py |  9 ++++++++-
 python/tvm/relay/op/strategy/hls.py     |  6 ++++++
 python/tvm/relay/op/strategy/opengl.py  |  6 ++++++
 python/tvm/relay/op/strategy/x86.py     |  6 ++++++
 src/runtime/contrib/cudnn/softmax.cc    |  9 ++++++---
 tests/python/contrib/test_cudnn.py      | 21 +++++++++++++++++++++
 9 files changed, 76 insertions(+), 6 deletions(-)

diff --git a/python/tvm/contrib/cudnn.py b/python/tvm/contrib/cudnn.py
index 03a94f9065f9..65ae7287b948 100644
--- a/python/tvm/contrib/cudnn.py
+++ b/python/tvm/contrib/cudnn.py
@@ -404,7 +404,22 @@ def conv_forward(x,
                          conv_dtype), name="y")
 
 def softmax(x, axis=-1):
-    assert axis == -1
+    """Compute softmax using CuDNN
+
+    Parameters
+    ----------
+    x : tvm.te.Tensor
+        The input tensor
+
+    axis : int
+        The axis to compute the softmax
+
+    Returns
+    -------
+    ret : tvm.te.Tensor
+        The result tensor
+    """
+    assert axis == -1 or axis == len(x.shape) - 1
     return te.extern(
         x.shape, [x],
         lambda ins, outs: tvm.tir.call_packed(
diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py
index bc096fa5dc92..bd7eb02f7e91 100644
--- a/python/tvm/relay/op/nn/_nn.py
+++ b/python/tvm/relay/op/nn/_nn.py
@@ -39,7 +39,7 @@
 
 # log_softmax
-#reg.register_schedule("nn.log_softmax", strategy.schedule_softmax)
+reg.register_schedule("nn.log_softmax", strategy.schedule_log_softmax)
 reg.register_pattern("nn.log_softmax", OpPattern.OPAQUE)
 
diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py
index 07f5a3dbc572..17c9a39f1030 100644
--- a/python/tvm/relay/op/strategy/cuda.py
+++ b/python/tvm/relay/op/strategy/cuda.py
@@ -75,6 +75,12 @@ def softmax_strategy_cuda(attrs, inputs, out_type, target):
             plevel=15)
     return strategy
 
+@schedule_log_softmax.register(["cuda", "gpu"])
+def schedule_log_softmax_cuda(attrs, outs, target):
+    """schedule log_softmax for cuda"""
+    with target:
+        return topi.cuda.schedule_softmax(outs)
+
 @schedule_lrn.register(["cuda", "gpu"])
 def schedule_lrn_cuda(attrs, outs, target):
     """schedule LRN for cuda"""
diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py
index f4e875102619..89065a6cbe71 100644
--- a/python/tvm/relay/op/strategy/generic.py
+++ b/python/tvm/relay/op/strategy/generic.py
@@ -115,7 +115,7 @@ def _compute_softmax(attrs, inputs, out_type):
     return _compute_softmax
 
 @override_native_generic_func("softmax_strategy")
-def softmax_strategy(attrs, outs, target):
+def softmax_strategy(attrs, inputs, out_type, target):
     """softmax generic strategy"""
     strategy = _op.OpStrategy()
     strategy.add_implementation(
@@ -124,6 +124,13 @@ def softmax_strategy(attrs, outs, target):
         name="softmax.generic")
     return strategy
 
+# log_softmax
+@generic_func
+def schedule_log_softmax(attrs, outs, target):
+    """Schedule log_softmax op"""
+    with target:
+        return topi.generic.schedule_softmax(outs)
+
 # lrn
 @generic_func
 def schedule_lrn(attrs, outs, target):
diff --git a/python/tvm/relay/op/strategy/hls.py b/python/tvm/relay/op/strategy/hls.py
index 6b23456a80f5..d41e85fc484c 100644
--- a/python/tvm/relay/op/strategy/hls.py
+++ b/python/tvm/relay/op/strategy/hls.py
@@ -60,6 +60,12 @@ def softmax_strategy_hls(attrs, inputs, out_type, target):
         name="softmax.hls")
     return strategy
 
+@schedule_log_softmax.register("hls")
+def schedule_log_softmax_hls(attrs, outs, target):
+    """schedule log_softmax for hls"""
+    with target:
+        return topi.hls.schedule_softmax(outs)
+
 @override_native_generic_func("conv2d_strategy")
 def conv2d_strategy_hls(attrs, inputs, out_type, target):
     """conv2d hls strategy"""
diff --git a/python/tvm/relay/op/strategy/opengl.py b/python/tvm/relay/op/strategy/opengl.py
index e90a570459f8..12c288c83b7e 100644
--- a/python/tvm/relay/op/strategy/opengl.py
+++ b/python/tvm/relay/op/strategy/opengl.py
@@ -54,6 +54,12 @@ def softmax_strategy_opengl(attrs, inputs, out_type, target):
         name="softmax.opengl")
     return strategy
 
+@schedule_log_softmax.register("opengl")
+def schedule_log_softmax_opengl(attrs, outs, target):
+    """schedule log_softmax for opengl"""
+    with target:
+        return topi.opengl.schedule_softmax(outs)
+
 @conv2d_strategy.register("opengl")
 def conv2d_strategy_opengl(attrs, inputs, out_type, target):
     """conv2d opengl strategy"""
diff --git a/python/tvm/relay/op/strategy/x86.py b/python/tvm/relay/op/strategy/x86.py
index 451cc9ddf4f3..c1f0244aa66c 100644
--- a/python/tvm/relay/op/strategy/x86.py
+++ b/python/tvm/relay/op/strategy/x86.py
@@ -65,6 +65,12 @@ def softmax_strategy_cpu(attrs, inputs, out_type, target):
         name="softmax.x86")
     return strategy
 
+@schedule_log_softmax.register("cpu")
+def schedule_log_softmax_cpu(attrs, outs, target):
+    """schedule log_softmax op for x86"""
+    with target:
+        return topi.x86.schedule_softmax(outs)
+
 @conv2d_strategy.register("cpu")
 def conv2d_strategy_cpu(attrs, inputs, out_type, target):
     """conv2d x86 strategy"""
diff --git a/src/runtime/contrib/cudnn/softmax.cc b/src/runtime/contrib/cudnn/softmax.cc
index 320ec1bfbdc4..bdf86ff6aa89 100644
--- a/src/runtime/contrib/cudnn/softmax.cc
+++ b/src/runtime/contrib/cudnn/softmax.cc
@@ -72,7 +72,7 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.softmax.forward")
   int64_t* shape = x->shape;
   if (axis < 0) axis += ndim;
   CHECK(axis >= 0 && axis < ndim);
-  CHECK(axis == ndim - 1);
+  CHECK(axis == ndim - 1) << "Currently only support axis=-1 for cudnn softmax";
   int64_t N = 1;
   for (int i = 0; i < ndim - 1; ++i) {
     N *= shape[i];
@@ -81,6 +81,7 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.softmax.forward")
   CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal();
   entry_ptr->softmax_entry.mode = CUDNN_SOFTMAX_MODE_INSTANCE;
   entry_ptr->softmax_entry.data_type = CuDNNDataType::DLTypeToCuDNNType(x->dtype);
+
   // Set shape descriptor
   CUDNN_CALL(cudnnSetTensor4dDescriptor(entry_ptr->softmax_entry.shape_desc,
                                         CUDNN_TENSOR_NCHW,
@@ -89,13 +90,15 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.softmax.forward")
                                         static_cast<int>(shape[ndim - 1]),
                                         1,
                                         1));
+  auto alpha = CuDNNDataType::GetConst<1>(entry_ptr->softmax_entry.data_type);
+  auto beta = CuDNNDataType::GetConst<0>(entry_ptr->softmax_entry.data_type);
   CUDNN_CALL(cudnnSoftmaxForward(entry_ptr->handle,
                                  CUDNN_SOFTMAX_ACCURATE,
                                  entry_ptr->softmax_entry.mode,
-                                 CuDNNDataType::GetConst<1>(entry_ptr->conv_entry.data_type),
+                                 alpha,
                                  entry_ptr->softmax_entry.shape_desc,
                                  x->data,
-                                 CuDNNDataType::GetConst<0>(entry_ptr->conv_entry.data_type),
+                                 beta,
                                  entry_ptr->softmax_entry.shape_desc,
                                  y->data));
 });
diff --git a/tests/python/contrib/test_cudnn.py b/tests/python/contrib/test_cudnn.py
index 58e7b4905988..17d94614eb72 100644
--- a/tests/python/contrib/test_cudnn.py
+++ b/tests/python/contrib/test_cudnn.py
@@ -158,6 +158,27 @@ def verify():
 def test_conv3d():
     verify_conv3d("float32", "float32", tensor_format=0)
 
+
+def verify_softmax(shape, axis, dtype="float32"):
+    A = te.placeholder(shape, dtype=dtype, name='A')
+    B = cudnn.softmax(A, axis)
+    s = te.create_schedule([B.op])
+
+    ctx = tvm.gpu(0)
+    a_np = np.random.uniform(size=shape).astype(dtype)
+    b_np = topi.testing.softmax_python(a_np)
+    a = tvm.nd.array(a_np, ctx)
+    b = tvm.nd.array(b_np, ctx)
+    f = tvm.build(s, [A, B], "cuda", target_host="llvm", name="softmax")
+    f(a, b)
+    tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-3)
+
+def test_softmax():
+    verify_softmax((32, 10), -1)
+    verify_softmax((3, 4), -1)
+    verify_softmax((1, 5), -1, "float64")
+
 if __name__ == "__main__":
     test_conv2d()
     test_conv3d()
+    test_softmax()

From f03a2d4302edf55cad6d0822426bed7162182ff6 Mon Sep 17 00:00:00 2001
From: Haichen Shen
Date: Thu, 2 Apr 2020 00:48:05 +0000
Subject: [PATCH 3/6] support axis other than -1

---
 python/tvm/contrib/cudnn.py          |  2 +-
 python/tvm/relay/op/strategy/cuda.py |  2 +-
 src/runtime/contrib/cudnn/softmax.cc | 87 ++++++++++++----------------
 tests/python/contrib/test_cudnn.py   | 18 ++++++
 topi/python/topi/cuda/softmax.py     |  3 +-
 5 files changed, 59 insertions(+), 53 deletions(-)

diff --git a/python/tvm/contrib/cudnn.py b/python/tvm/contrib/cudnn.py
index 65ae7287b948..7d885e8e6492 100644
--- a/python/tvm/contrib/cudnn.py
+++ b/python/tvm/contrib/cudnn.py
@@ -419,7 +419,7 @@ def softmax(x, axis=-1):
     ret : tvm.te.Tensor
         The result tensor
     """
-    assert axis == -1 or axis == len(x.shape) - 1
+    #assert axis == -1 or axis == len(x.shape) - 1
     return te.extern(
         x.shape, [x],
         lambda ins, outs: tvm.tir.call_packed(
diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py
index 17c9a39f1030..1ab546b00dfb 100644
--- a/python/tvm/relay/op/strategy/cuda.py
+++ b/python/tvm/relay/op/strategy/cuda.py
@@ -67,7 +67,7 @@ def softmax_strategy_cuda(attrs, inputs, out_type, target):
         wrap_compute_softmax(topi.nn.softmax),
         wrap_topi_schedule(topi.cuda.schedule_softmax),
         name="softmax.cuda")
-    if target.target_name == "cuda" and "cudnn" in target.libs and axis == -1:
+    if target.target_name == "cuda" and "cudnn" in target.libs:
         strategy.add_implementation(
             wrap_compute_softmax(topi.cuda.softmax_cudnn),
             wrap_topi_schedule(topi.cuda.schedule_softmax_cudnn),
diff --git a/src/runtime/contrib/cudnn/softmax.cc b/src/runtime/contrib/cudnn/softmax.cc
index bdf86ff6aa89..fb6d8a6fdc56 100644
--- a/src/runtime/contrib/cudnn/softmax.cc
+++ b/src/runtime/contrib/cudnn/softmax.cc
@@ -18,7 +18,8 @@
  */
 
 /*!
- * \file Use external cudnn utils function
+ * \file src/runtime/contrib/cudnn/softmax.cc
+ * \brief Use external cudnn softmax function
  */
 #include <tvm/runtime/registry.h>
 #include <tvm/runtime/device_api.h>
@@ -29,40 +30,6 @@ namespace contrib {
 
 using namespace runtime;
 
-/*
-  cudnnStatus_t cudnnSoftmaxForward(
-    cudnnHandle_t handle,
-    cudnnSoftmaxAlgorithm_t algorithm,
-    cudnnSoftmaxMode_t mode,
-    const void *alpha,
-    const cudnnTensorDescriptor_t xDesc,
-    const void *x,
-    const void *beta,
-    const cudnnTensorDescriptor_t yDesc,
-    void *y)
-
-2.62. cudnnSoftmaxAlgorithm_t
-
-CUDNN_SOFTMAX_FAST
-This implementation applies the straightforward softmax operation.
-
-CUDNN_SOFTMAX_ACCURATE
-This implementation scales each point of the softmax input domain by its maximum value to avoid potential floating point overflows in the softmax evaluation.
-
-CUDNN_SOFTMAX_LOG
-This entry performs the log softmax operation, avoiding overflows by scaling each point in the input domain as in CUDNN_SOFTMAX_ACCURATE.
-
-2.63. cudnnSoftmaxMode_t
-cudnnSoftmaxMode_t is used to select over which data the cudnnSoftmaxForward() and cudnnSoftmaxBackward() are computing their results.
-
-Values
-CUDNN_SOFTMAX_MODE_INSTANCE
-The softmax operation is computed per image (N) across the dimensions C,H,W.
-
-CUDNN_SOFTMAX_MODE_CHANNEL
-The softmax operation is computed per spatial location (H,W) per image (N) across the dimension C.
-
-*/
 TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.softmax.forward")
 .set_body([](TVMArgs args, TVMRetValue *ret) {
   DLTensor* x = args[0];
@@ -72,24 +39,44 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.softmax.forward")
   int64_t* shape = x->shape;
   if (axis < 0) axis += ndim;
   CHECK(axis >= 0 && axis < ndim);
-  CHECK(axis == ndim - 1) << "Currently only support axis=-1 for cudnn softmax";
-  int64_t N = 1;
-  for (int i = 0; i < ndim - 1; ++i) {
-    N *= shape[i];
-  }
-
+
   CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal();
-  entry_ptr->softmax_entry.mode = CUDNN_SOFTMAX_MODE_INSTANCE;
   entry_ptr->softmax_entry.data_type = CuDNNDataType::DLTypeToCuDNNType(x->dtype);
 
-  // Set shape descriptor
-  CUDNN_CALL(cudnnSetTensor4dDescriptor(entry_ptr->softmax_entry.shape_desc,
-                                        CUDNN_TENSOR_NCHW,
-                                        entry_ptr->softmax_entry.data_type,
-                                        static_cast<int>(N),
-                                        static_cast<int>(shape[ndim - 1]),
-                                        1,
-                                        1));
+  // Set mode and shape descriptor
+  if (axis == ndim - 1) {
+    int64_t N = 1;
+    for (int i = 0; i < ndim - 1; ++i) {
+      N *= shape[i];
+    }
+    entry_ptr->softmax_entry.mode = CUDNN_SOFTMAX_MODE_INSTANCE;
+    CUDNN_CALL(cudnnSetTensor4dDescriptor(entry_ptr->softmax_entry.shape_desc,
+                                          CUDNN_TENSOR_NCHW,
+                                          entry_ptr->softmax_entry.data_type,
+                                          static_cast<int>(N),
+                                          static_cast<int>(shape[ndim - 1]),
+                                          1,
+                                          1));
+  } else {
+    int64_t pre_axis_dim = 1;
+    int64_t post_axis_dim = 1;
+    for (int i = 0; i < ndim; ++i) {
+      if (i < axis) {
+        pre_axis_dim *= shape[i];
+      } else if (i > axis) {
+        post_axis_dim *= shape[i];
+      }
+    }
+    entry_ptr->softmax_entry.mode = CUDNN_SOFTMAX_MODE_CHANNEL;
+    CUDNN_CALL(cudnnSetTensor4dDescriptor(entry_ptr->softmax_entry.shape_desc,
+                                          CUDNN_TENSOR_NCHW,
+                                          entry_ptr->softmax_entry.data_type,
+                                          static_cast<int>(pre_axis_dim),
+                                          static_cast<int>(shape[axis]),
+                                          static_cast<int>(post_axis_dim),
+                                          1));
+  }
+
   auto alpha = CuDNNDataType::GetConst<1>(entry_ptr->softmax_entry.data_type);
   auto beta = CuDNNDataType::GetConst<0>(entry_ptr->softmax_entry.data_type);
   CUDNN_CALL(cudnnSoftmaxForward(entry_ptr->handle,
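With this change an interior axis is handled by CUDNN_SOFTMAX_MODE_CHANNEL: the tensor is viewed as (pre_axis_dim, shape[axis], post_axis_dim, 1) and cuDNN reduces over the C dimension at every (N, H, W) position. A NumPy sketch of the same mapping (illustrative, assuming a normalized non-negative axis):

    import numpy as np

    def channel_mode_view(x, axis):
        """Reshape to the (N, C, H, 1) view used for an interior softmax
        axis; the reduction then runs over C per (N, H, W) location."""
        pre = int(np.prod(x.shape[:axis], dtype=np.int64))
        post = int(np.prod(x.shape[axis + 1:], dtype=np.int64))
        return x.reshape(pre, x.shape[axis], post, 1)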
diff --git a/tests/python/contrib/test_cudnn.py b/tests/python/contrib/test_cudnn.py
index 17d94614eb72..12f4f5ffd9eb 100644
--- a/tests/python/contrib/test_cudnn.py
+++ b/tests/python/contrib/test_cudnn.py
@@ -173,10 +173,28 @@ def verify_softmax(shape, axis, dtype="float32"):
     f(a, b)
     tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-3)
 
+def verify_softmax_4d(shape, dtype="float32"):
+    A = te.placeholder(shape, dtype=dtype, name='A')
+    B = cudnn.softmax(A, axis=1)
+    s = te.create_schedule([B.op])
+
+    ctx = tvm.gpu(0)
+    n, c, h, w = shape
+    a_np = np.random.uniform(size=shape).astype(dtype)
+    b_np = topi.testing.softmax_python(a_np.transpose(0, 2, 3, 1).reshape(h*w, c))
+    b_np = b_np.reshape(n, h, w, c).transpose(0, 3, 1, 2)
+    a = tvm.nd.array(a_np, ctx)
+    b = tvm.nd.array(b_np, ctx)
+    f = tvm.build(s, [A, B], "cuda", target_host="llvm", name="softmax")
+    f(a, b)
+    tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-3)
+
 def test_softmax():
     verify_softmax((32, 10), -1)
     verify_softmax((3, 4), -1)
     verify_softmax((1, 5), -1, "float64")
+    verify_softmax_4d((1, 16, 256, 256))
+    verify_softmax_4d((1, 16, 256, 256), "float64")
 
 if __name__ == "__main__":
     test_conv2d()
diff --git a/topi/python/topi/cuda/softmax.py b/topi/python/topi/cuda/softmax.py
index 7382c3d21d26..62c437ae96ac 100644
--- a/topi/python/topi/cuda/softmax.py
+++ b/topi/python/topi/cuda/softmax.py
@@ -84,9 +84,10 @@ def schedule_softmax(outs):
 
 
 def softmax_cudnn(x, axis=-1):
-    assert axis == -1
+    """Perform softmax on the data using cudnn"""
     return cudnn.softmax(x, axis)
 
 
 def schedule_softmax_cudnn(outs):
+    """Schedule for softmax cudnn op"""
     return generic.schedule_extern(outs)

From da19a21681e4c7afdc59d6bb0c863ac5caa14d1c Mon Sep 17 00:00:00 2001
From: Haichen Shen
Date: Thu, 2 Apr 2020 00:50:28 +0000
Subject: [PATCH 4/6] clean up

---
 src/relay/op/nn/nn.cc | 7 -------
 1 file changed, 7 deletions(-)

diff --git a/src/relay/op/nn/nn.cc b/src/relay/op/nn/nn.cc
index 735cd3776306..36e659e9266e 100644
--- a/src/relay/op/nn/nn.cc
+++ b/src/relay/op/nn/nn.cc
@@ -348,13 +348,6 @@ RELAY_REGISTER_OP("nn.softmax")
 .add_argument("data", "Tensor", "The input tensor.")
 .set_support_level(1)
 .add_type_rel("Identity", IdentityRel);
-// .set_attr<FTVMCompute>("FTVMCompute", [](const Attrs& attrs,
-//                                          const Array<te::Tensor>& inputs,
-//                                          const Type& out_type) {
-//   const auto* param = attrs.as<SoftmaxAttrs>();
-//   CHECK(param != nullptr);
-//   return Array<te::Tensor>{ topi::nn::softmax(inputs[0], param->axis) };
-// });
 
 // relay.nn.log_softmax

From c30d0074a549fe404b24b17d236737272b6f6662 Mon Sep 17 00:00:00 2001
From: Haichen Shen
Date: Thu, 2 Apr 2020 03:50:15 +0000
Subject: [PATCH 5/6] lint

---
 python/tvm/contrib/cudnn.py             | 1 -
 python/tvm/relay/op/strategy/cuda.py    | 1 -
 src/runtime/contrib/cudnn/cudnn_utils.h | 2 +-
 3 files changed, 1 insertion(+), 3 deletions(-)

diff --git a/python/tvm/contrib/cudnn.py b/python/tvm/contrib/cudnn.py
index 7d885e8e6492..5043520ccf13 100644
--- a/python/tvm/contrib/cudnn.py
+++ b/python/tvm/contrib/cudnn.py
@@ -419,7 +419,6 @@ def softmax(x, axis=-1):
     ret : tvm.te.Tensor
         The result tensor
     """
-    #assert axis == -1 or axis == len(x.shape) - 1
     return te.extern(
         x.shape, [x],
         lambda ins, outs: tvm.tir.call_packed(
cuda strategy""" - axis = attrs.get_int("axis") strategy = _op.OpStrategy() strategy.add_implementation( wrap_compute_softmax(topi.nn.softmax), diff --git a/src/runtime/contrib/cudnn/cudnn_utils.h b/src/runtime/contrib/cudnn/cudnn_utils.h index 950983047bb6..ee6bb5089e38 100644 --- a/src/runtime/contrib/cudnn/cudnn_utils.h +++ b/src/runtime/contrib/cudnn/cudnn_utils.h @@ -91,7 +91,7 @@ struct SoftmaxEntry { cudnnTensorDescriptor_t shape_desc; SoftmaxEntry(); ~SoftmaxEntry(); -}; // SoftmaxEntry +}; // SoftmaxEntry struct CuDNNThreadEntry { CuDNNThreadEntry(); From 4359b75486ebe58f1a585395d47031b0c4bce8db Mon Sep 17 00:00:00 2001 From: Haichen Shen Date: Thu, 2 Apr 2020 05:43:53 +0000 Subject: [PATCH 6/6] fix test --- tests/python/contrib/test_cudnn.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/tests/python/contrib/test_cudnn.py b/tests/python/contrib/test_cudnn.py index 12f4f5ffd9eb..5d1f100c1fc4 100644 --- a/tests/python/contrib/test_cudnn.py +++ b/tests/python/contrib/test_cudnn.py @@ -190,6 +190,13 @@ def verify_softmax_4d(shape, dtype="float32"): tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-3) def test_softmax(): + if not tvm.runtime.enabled("cuda"): + print("skip because cuda is not enabled...") + return + if not tvm.get_global_func("tvm.contrib.cudnn.conv.output_shape", True): + print("skip because cudnn is not enabled...") + return + verify_softmax((32, 10), -1) verify_softmax((3, 4), -1) verify_softmax((1, 5), -1, "float64")
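Taken together, the series makes the cuDNN softmax selectable from Relay just by adding cudnn to the target libraries. A hedged end-to-end sketch against the Relay API of this vintage (relay.build_config was the prevailing way to set options at the time); treat it as illustrative rather than as part of the series:

    import tvm
    from tvm import relay

    # nn.softmax over the last axis; with -libs=cudnn the CUDA strategy
    # above also registers softmax.cudnn at plevel=15, which outranks
    # the default softmax.cuda implementation.
    x = relay.var("x", shape=(8, 100), dtype="float32")
    func = relay.Function([x], relay.nn.softmax(x, axis=-1))
    mod = tvm.IRModule.from_expr(func)

    with relay.build_config(opt_level=3):
        graph, lib, params = relay.build(mod, target="cuda -libs=cudnn")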