From 524509e09be85dc07987335bf1bb5bf1e5af3501 Mon Sep 17 00:00:00 2001
From: chinakook
Date: Wed, 7 Mar 2018 11:26:18 +0800
Subject: [PATCH 1/2] Bug fix and performance optimization for rtc

1. Fix the `super().__init__()` call, which fails under Python 2, by
   using the explicit `super(Softmax, self).__init__()` form.
2. Compile and cache the CUDA kernels once at operator initialization
   instead of recompiling them on every forward/backward call.
---
 example/numpy-ops/custom_softmax_rtc.py | 131 +++++++++++++-----------
 1 file changed, 72 insertions(+), 59 deletions(-)

diff --git a/example/numpy-ops/custom_softmax_rtc.py b/example/numpy-ops/custom_softmax_rtc.py
index 906cbbeac04c..1ce7e7346d98 100644
--- a/example/numpy-ops/custom_softmax_rtc.py
+++ b/example/numpy-ops/custom_softmax_rtc.py
@@ -23,51 +23,77 @@ class Softmax(mx.operator.CustomOp):
     def __init__(self):
-        self.fwd_kernel_mod = None
-        self.bwd_kernel_mod = None
-        super().__init__()
+        super(Softmax, self).__init__()
+        # Each thread processes a row (a sample in the batch).
+        fwd_src = r"""
+        template<class DType>
+        __global__ void fwd(const DType* x, DType* y, const int row_size, const int req) {
+            const int offset = row_size * threadIdx.x;
+            DType max = x[offset];
+            for(int i = 1; i < row_size; ++i) {
+                if(max < x[offset + i]) {
+                    max = x[offset + i];
+                }
+            }
+            DType sum = 0;
+            for(int i = 0; i < row_size; ++i) {
+                sum += exp(x[offset + i] - max);
+            }
+            switch(req) {
+                case 1:
+                    for(int i = 0; i < row_size; ++i) {
+                        y[offset + i] = exp(x[offset + i] - max) / sum;
+                    }
+                    break;
+                case 2:
+                    for(int i = 0; i < row_size; ++i) {
+                        y[offset + i] += exp(x[offset + i] - max) / sum;
+                    }
+                    break;
+            }
+        }
+        """
+
+        # Each block processes a row, and each thread in a block calculates one element of `dx`.
+        bwd_src = r"""
+        template<class DType>
+        __global__ void bwd(const DType* l, const DType* y, DType* dx, const int req) {
+            const int z = static_cast<int>(l[blockIdx.x]);
+            const int i = threadIdx.x + blockDim.x * blockIdx.x;
+            if(req == 1) {
+                dx[i] = threadIdx.x == z ? y[i] - 1 : y[i];
+            } else {
+                dx[i] += threadIdx.x == z ? y[i] - 1 : y[i];
+            }
+        }
+        """
+        fwd_kernel_mod = mx.rtc.CudaModule(fwd_src, exports=["fwd<float>", "fwd<double>"])
+        bwd_kernel_mod = mx.rtc.CudaModule(bwd_src, exports=["bwd<float>", "bwd<double>"])
+
+        fwd_kernel_float_signature = "const {0}*, const {0}*, const int, const int".format("float")
+        self.fwd_float_kernel = fwd_kernel_mod.get_kernel("fwd<{}>".format("float"), fwd_kernel_float_signature)
+
+        bwd_kernel_float_signature = "const {0}*, const {0}*, {0}*, const int".format("float")
+        self.bwd_float_kernel = bwd_kernel_mod.get_kernel("bwd<{}>".format("float"), bwd_kernel_float_signature)
+
+        fwd_kernel_double_signature = "const {0}*, const {0}*, const int, const int".format("double")
+        self.fwd_double_kernel = fwd_kernel_mod.get_kernel("fwd<{}>".format("double"), fwd_kernel_double_signature)
+
+        bwd_kernel_double_signature = "const {0}*, const {0}*, {0}*, const int".format("double")
+        self.bwd_double_kernel = bwd_kernel_mod.get_kernel("bwd<{}>".format("double"), bwd_kernel_double_signature)
 
     def forward(self, is_train, req, in_data, out_data, aux):
         if req[0] == "null":
             return
         x = in_data[0] # input
         y = out_data[0] # output
-        if self.fwd_kernel_mod is None:
-            # Each thread processes a row (a sample in the batch).
-            src = r"""
-            template<class DType>
-            __global__ void fwd(const DType* x, DType* y, const int row_size, const int req) {
-                const int offset = row_size * threadIdx.x;
-                DType max = x[offset];
-                for(int i = 1; i < row_size; ++i) {
-                    if(max < x[offset + i]) {
-                        max = x[offset + i];
-                    }
-                }
-                DType sum = 0;
-                for(int i = 0; i < row_size; ++i) {
-                    sum += exp(x[offset + i] - max);
-                }
-                switch(req) {
-                    case 1:
-                        for(int i = 0; i < row_size; ++i) {
-                            y[offset + i] = exp(x[offset + i] - max) / sum;
-                        }
-                        break;
-                    case 2:
-                        for(int i = 0; i < row_size; ++i) {
-                            y[offset + i] += exp(x[offset + i] - max) / sum;
-                        }
-                        break;
-                }
-            }
-            """
-            self.fwd_kernel_mod = mx.rtc.CudaModule(src, exports=["fwd<float>", "fwd<double>"])
-        dtype = "double" if y.dtype == np.float64 else "float"
-        kernel_signature = "const {0}*, const {0}*, const int, const int".format(dtype)
-        kernel = self.fwd_kernel_mod.get_kernel("fwd<{}>".format(dtype), kernel_signature)
-        # args, ctx, grid_shape, block_shape, shared_mem = 0
-        kernel.launch((x, y, x.shape[1], self._reqCode(req[0])), mx.gpu(0), (1, 1, 1), (x.shape[0], 1, 1))
+
+        if y.dtype == np.float64:
+            # args, ctx, grid_shape, block_shape, shared_mem = 0
+            self.fwd_double_kernel.launch((x, y, x.shape[1], self._reqCode(req[0])), mx.gpu(0), (1, 1, 1), (x.shape[0], 1, 1))
+        else:
+            # args, ctx, grid_shape, block_shape, shared_mem = 0
+            self.fwd_float_kernel.launch((x, y, x.shape[1], self._reqCode(req[0])), mx.gpu(0), (1, 1, 1), (x.shape[0], 1, 1))
 
     def backward(self, req, out_grad, in_data, out_data, in_grad, aux):
         if req[0] == "null":
@@ -75,26 +101,13 @@ def backward(self, req, out_grad, in_data, out_data, in_grad, aux):
         l = in_data[1] # label
         y = out_data[0] # output from the forward pass
         dx = in_grad[0] # the storage for the gradient
-        if self.bwd_kernel_mod is None:
-            # Each block processes a row, and each thread in a block calculates one element of `dx`.
-            src = r"""
-            template<class DType>
-            __global__ void bwd(const DType* l, const DType* y, DType* dx, const int req) {
-                const int z = static_cast<int>(l[blockIdx.x]);
-                const int i = threadIdx.x + blockDim.x * blockIdx.x;
-                if(req == 1) {
-                    dx[i] = threadIdx.x == z ? y[i] - 1 : y[i];
-                } else {
-                    dx[i] += threadIdx.x == z ? y[i] - 1 : y[i];
-                }
-            }
-            """
-            self.bwd_kernel_mod = mx.rtc.CudaModule(src, exports=["bwd<float>", "bwd<double>"])
-        dtype = "double" if dx.dtype == np.float64 else "float"
-        kernel_signature = "const {0}*, const {0}*, {0}*, const int".format(dtype)
-        kernel = self.bwd_kernel_mod.get_kernel("bwd<{}>".format(dtype), kernel_signature)
-        # args, ctx, grid_shape, block_shape, shared_mem = 0
-        kernel.launch((l, y, dx, self._reqCode(req[0])), mx.gpu(0), (y.shape[0], 1, 1), (y.shape[1], 1, 1))
+
+        if dx.dtype == np.float64:
+            # args, ctx, grid_shape, block_shape, shared_mem = 0
+            self.bwd_double_kernel.launch((l, y, dx, self._reqCode(req[0])), mx.gpu(0), (y.shape[0], 1, 1), (y.shape[1], 1, 1))
+        else:
+            # args, ctx, grid_shape, block_shape, shared_mem = 0
+            self.bwd_float_kernel.launch((l, y, dx, self._reqCode(req[0])), mx.gpu(0), (y.shape[0], 1, 1), (y.shape[1], 1, 1))
 
     def _reqCode(self, req):
         if(req == "write"):

From 00a3643574e31b02eebc552951dd27f1975826f2 Mon Sep 17 00:00:00 2001
From: chinakook
Date: Wed, 7 Mar 2018 15:26:46 +0800
Subject: [PATCH 2/2] Update custom_softmax_rtc.py

Remove the unnecessary string formatting from the kernel signatures and
kernel names.
---
 example/numpy-ops/custom_softmax_rtc.py | 16 ++++++++--------
 1 file changed, 8 insertions(+), 8 deletions(-)

diff --git a/example/numpy-ops/custom_softmax_rtc.py b/example/numpy-ops/custom_softmax_rtc.py
index 1ce7e7346d98..d07041b002d3 100644
--- a/example/numpy-ops/custom_softmax_rtc.py
+++ b/example/numpy-ops/custom_softmax_rtc.py
@@ -70,17 +70,17 @@ def __init__(self):
         fwd_kernel_mod = mx.rtc.CudaModule(fwd_src, exports=["fwd<float>", "fwd<double>"])
         bwd_kernel_mod = mx.rtc.CudaModule(bwd_src, exports=["bwd<float>", "bwd<double>"])
 
-        fwd_kernel_float_signature = "const {0}*, const {0}*, const int, const int".format("float")
-        self.fwd_float_kernel = fwd_kernel_mod.get_kernel("fwd<{}>".format("float"), fwd_kernel_float_signature)
+        fwd_kernel_float_signature = "const float*, const float*, const int, const int"
+        self.fwd_float_kernel = fwd_kernel_mod.get_kernel("fwd<float>", fwd_kernel_float_signature)
 
-        bwd_kernel_float_signature = "const {0}*, const {0}*, {0}*, const int".format("float")
-        self.bwd_float_kernel = bwd_kernel_mod.get_kernel("bwd<{}>".format("float"), bwd_kernel_float_signature)
+        bwd_kernel_float_signature = "const float*, const float*, float*, const int"
+        self.bwd_float_kernel = bwd_kernel_mod.get_kernel("bwd<float>", bwd_kernel_float_signature)
 
-        fwd_kernel_double_signature = "const {0}*, const {0}*, const int, const int".format("double")
-        self.fwd_double_kernel = fwd_kernel_mod.get_kernel("fwd<{}>".format("double"), fwd_kernel_double_signature)
+        fwd_kernel_double_signature = "const double*, const double*, const int, const int"
+        self.fwd_double_kernel = fwd_kernel_mod.get_kernel("fwd<double>", fwd_kernel_double_signature)
 
-        bwd_kernel_double_signature = "const {0}*, const {0}*, {0}*, const int".format("double")
-        self.bwd_double_kernel = bwd_kernel_mod.get_kernel("bwd<{}>".format("double"), bwd_kernel_double_signature)
+        bwd_kernel_double_signature = "const double*, const double*, double*, const int"
+        self.bwd_double_kernel = bwd_kernel_mod.get_kernel("bwd<double>", bwd_kernel_double_signature)
 
     def forward(self, is_train, req, in_data, out_data, aux):
         if req[0] == "null":
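
Note on the pattern these patches introduce: mx.rtc.CudaModule compiles its
source once, and get_kernel returns a reusable handle, so moving both steps
into __init__ removes the per-call NVRTC compilation that the old
forward/backward paths paid for. The standalone sketch below is illustrative
only (the `scale` kernel and all of its names are hypothetical, not part of
this patch) and assumes a CUDA-enabled MXNet build with a GPU visible as
mx.gpu(0):

    import mxnet as mx

    # Compile once. A templated kernel must list each needed instantiation
    # in `exports` so that it can be looked up by name afterwards.
    src = r"""
    template<class DType>
    __global__ void scale(const DType* x, DType* y, const DType alpha) {
        const int i = threadIdx.x + blockDim.x * blockIdx.x;
        y[i] = alpha * x[i];
    }
    """
    module = mx.rtc.CudaModule(src, exports=["scale<float>"])

    # Fetch and cache the kernel handle once, as the patched __init__ does.
    kernel = module.get_kernel("scale<float>",
                               "const float* x, float* y, const float alpha")

    x = mx.nd.ones((8,), ctx=mx.gpu(0))
    y = mx.nd.zeros((8,), ctx=mx.gpu(0))

    # Launch as many times as needed without recompiling:
    # args, ctx, grid_shape, block_shape (shared_mem defaults to 0).
    kernel.launch((x, y, 2.0), mx.gpu(0), (1, 1, 1), (8, 1, 1))
    print(y.asnumpy())  # expect eight values of 2.0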