This repository was archived by the owner on Nov 17, 2023. It is now read-only.
Merged
2 changes: 1 addition & 1 deletion Makefile
@@ -448,7 +448,7 @@ endif
# be JIT-compiled by the updated driver from the included PTX.
ifeq ($(USE_CUDA), 1)
ifeq ($(CUDA_ARCH),)
KNOWN_CUDA_ARCHS := 30 35 50 52 60 61 70 75
KNOWN_CUDA_ARCHS := 30 35 50 52 60 61 70 75 80
# Run nvcc on a zero-length file to check architecture-level support.
# Create args to include SASS in the fat binary for supported levels.
CUDA_ARCH := $(foreach arch,$(KNOWN_CUDA_ARCHS), \
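For context, the new `80` entry targets Ampere-class devices (sm_80), which are also the GPUs whose tensor cores make the TF32 changes below take effect; newer architectures without SASS are still JIT-compiled from the embedded PTX. A stand-alone sketch (not part of this PR) of a run-time check for that capability level:

```cpp
// Sketch only: query the device's compute capability with the CUDA runtime.
// Compute capability 8.0+ corresponds to the new "80" architecture entry.
#include <cuda_runtime.h>
#include <cstdio>

int main() {
  cudaDeviceProp prop;
  if (cudaGetDeviceProperties(&prop, 0) != cudaSuccess) {
    std::fprintf(stderr, "no CUDA device visible\n");
    return 1;
  }
  std::printf("device 0: sm_%d%d\n", prop.major, prop.minor);
  if (prop.major >= 8) {
    std::printf("Ampere or newer: TF32-capable tensor cores present\n");
  }
  return 0;
}
```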
360 changes: 232 additions & 128 deletions python/mxnet/test_utils.py

Large diffs are not rendered by default.

8 changes: 8 additions & 0 deletions src/operator/linalg.h
@@ -280,6 +280,14 @@ void linalg_batch_det_backward_helper(const Tensor<xpu, 3, DType>& LU,
const DType zero_det,
const mxnet::OpContext& ctx);

#ifdef __CUDACC__
#if CUDA_VERSION < 11000
#define VERSION_ADJUSTED_TF32_MATH CUBLAS_DEFAULT_MATH
#else
#define VERSION_ADJUSTED_TF32_MATH CUBLAS_TF32_TENSOR_OP_MATH
#endif
#endif // __CUDACC__

#include "linalg_impl.h"

#endif // MXNET_OPERATOR_LINALG_H_
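The macro added above resolves to the strongest FP32 math mode the toolkit knows about: plain default math before CUDA 11, TF32 tensor-op math from CUDA 11 on. A minimal sketch of the intended call pattern, using the raw cuBLAS API rather than MXNet's wrappers (the handle and GEMM arguments here are illustrative):

```cpp
// Sketch only, not this PR's code: save the handle's math mode, switch to the
// version-adjusted TF32 mode for the call, then restore the caller's mode.
#include <cuda.h>
#include <cublas_v2.h>

#if CUDA_VERSION < 11000
#define VERSION_ADJUSTED_TF32_MATH CUBLAS_DEFAULT_MATH
#else
#define VERSION_ADJUSTED_TF32_MATH CUBLAS_TF32_TENSOR_OP_MATH
#endif

void sgemm_with_tf32(cublasHandle_t handle, int m, int n, int k,
                     const float* A, const float* B, float* C) {
  const float alpha = 1.0f, beta = 0.0f;
  cublasMath_t saved;
  cublasGetMathMode(handle, &saved);                      // remember caller's mode
  cublasSetMathMode(handle, VERSION_ADJUSTED_TF32_MATH);  // allow TF32 on CUDA 11+
  cublasSgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, m, n, k,
              &alpha, A, m, B, k, &beta, C, m);           // column-major m x n result
  cublasSetMathMode(handle, saved);                       // restore on exit
}
```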
34 changes: 26 additions & 8 deletions src/operator/linalg_impl.h
@@ -205,12 +205,15 @@ inline void linalg_gemm<gpu, float>(const Tensor<gpu, 2, float>& A,
#else
cublasDataType_t full_datatype = CUBLAS_DATA_FULL;
#endif
auto handle = Stream<gpu>::GetBlasHandle(s);
cublasMath_t saved_math_mode = SetCublasMathMode(handle, VERSION_ADJUSTED_TF32_MATH);
CUBLAS_CALL(cublasSgemmEx(
Stream<gpu>::GetBlasHandle(s), (tB ? CUBLAS_OP_T : CUBLAS_OP_N),
handle, (tB ? CUBLAS_OP_T : CUBLAS_OP_N),
(tA ? CUBLAS_OP_T : CUBLAS_OP_N), C.size(1), C.size(0),
(tB ? B.size(1) : B.size(0)), &alpha, B.dptr_, full_datatype, B.stride_,
A.dptr_, full_datatype, A.stride_, &beta, C.dptr_, full_datatype,
C.stride_))
C.stride_));
CUBLAS_CALL(cublasSetMathMode(handle, saved_math_mode));
}
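The diff leans on `SetCublasMathMode` returning whatever mode was previously active so it can be restored after the GEMM. A plausible shape for that helper, sketched here only to make the save/restore contract explicit (the real helper lives elsewhere in the code base and wraps the calls in CUBLAS_CALL):

```cpp
// Assumed shape, for illustration: install a new math mode on the handle and
// hand back the old one so the caller can restore it afterwards.
#include <cublas_v2.h>

inline cublasMath_t SetCublasMathMode(cublasHandle_t handle, cublasMath_t new_mode) {
  cublasMath_t old_mode = CUBLAS_DEFAULT_MATH;
  cublasGetMathMode(handle, &old_mode);  // query the current mode
  cublasSetMathMode(handle, new_mode);   // switch to the requested mode
  return old_mode;
}
```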

#else
@@ -228,13 +231,16 @@ void linalg_gemm_axis<gpu, DType>(const Tensor<gpu, 3, DType>& A, const Tensor<g
using mshadow::gpu; \
CHECK_NOTNULL(s); \
linalg_check_batch_size(A.size(1), B.size(1), C.size(1)); \
CUBLAS_CALL(cublas##fname(Stream<gpu>::GetBlasHandle(s), \
auto handle = Stream<gpu>::GetBlasHandle(s); \
cublasMath_t saved_math_mode = SetCublasMathMode(handle, VERSION_ADJUSTED_TF32_MATH); \
CUBLAS_CALL(cublas##fname(handle, \
(tB ? CUBLAS_OP_T : CUBLAS_OP_N), \
(tA ? CUBLAS_OP_T : CUBLAS_OP_N), \
C.size(2), C.size(0), (tB ? B.size(2) : B.size(0)), &alpha, \
B.dptr_, B.size(1)*B.stride_, B.stride_, \
A.dptr_, A.size(1)*A.stride_, A.stride_, &beta, \
C.dptr_, C.size(1)*C.stride_, C.stride_, A.size(1))) \
CUBLAS_CALL(cublasSetMathMode(handle, saved_math_mode)); \
}
LINALG_GPU_GEMM_AXIS(SgemmStridedBatched, float)
LINALG_GPU_GEMM_AXIS(DgemmStridedBatched, double)
@@ -342,13 +348,22 @@ void linalg_gemm<gpu, mshadow::half::half_t>(const Tensor<gpu, 2, mshadow::half:
linalg_check_batch_size(A.size(0), B.size(0), C.size(0)); \
check_gemm(A[0], B[0], C[0], alpha, beta, tA, tB); \
using namespace mshadow::cuda; \
CUBLAS_CALL(cublas##fname(Stream<gpu>::GetBlasHandle(s), \
auto handle = Stream<gpu>::GetBlasHandle(s); \
cublasMath_t saved_math_mode = SetCublasMathMode(handle, VERSION_ADJUSTED_TF32_MATH); \
CUBLAS_CALL(cublas##fname(handle, \
(tB ? CUBLAS_OP_T : CUBLAS_OP_N), \
(tA ? CUBLAS_OP_T : CUBLAS_OP_N), \
C.size(2), C.size(1), (tB ? B.size(2) : B.size(1)), \
&alpha, B.dptr_, B.stride_, B.size(1) * B.stride_, \
A.dptr_, A.stride_, A.size(1) * A.stride_, \
&beta, C.dptr_, C.stride_, C.size(1) * C.stride_, A.size(0))) \
&alpha, \
B.dptr_, B.stride_, \
static_cast<int64_t>(B.size(1) * B.stride_), \
A.dptr_, A.stride_, \
static_cast<int64_t>(A.size(1) * A.stride_), \
&beta, \
C.dptr_, C.stride_, \
static_cast<int64_t>(C.size(1) * C.stride_), \
A.size(0))) \
CUBLAS_CALL(cublasSetMathMode(handle, saved_math_mode)); \
}

LINALG_GPU_BATCH_GEMM(DgemmStridedBatched, double)
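The explicit `static_cast<int64_t>` widenings above match the `long long int` type of the batch-stride parameters in `cublas<T>gemmStridedBatched`; widening before the multiply keeps `size(1) * stride_` from wrapping on builds where the tensor index type is 32-bit. A tiny sketch of the pattern (names are illustrative):

```cpp
// Sketch only: widen one operand first so the multiplication itself happens in
// 64 bits, then let the result flow into a `long long` batch-stride argument.
#include <cstdint>

long long batch_stride(int32_t rows_per_matrix, int32_t leading_dim) {
  return static_cast<int64_t>(rows_per_matrix) * leading_dim;
}
```

With two 32-bit operands, `rows_per_matrix * leading_dim` would be evaluated as an `int` multiply and could overflow before any conversion at the call site.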
@@ -373,7 +388,7 @@ void linalg_gemm<gpu, mshadow::half::half_t>(const Tensor<gpu, 2, mshadow::half:

using namespace mshadow::cuda;
auto cublas_math_mode =
use_tensor_ops ? CUBLAS_TENSOR_OP_MATH : CUBLAS_DEFAULT_MATH;
use_tensor_ops ? CUBLAS_TENSOR_OP_MATH : VERSION_ADJUSTED_TF32_MATH;
auto previous_math_mode = SetCublasMathMode(blas_handle, cublas_math_mode);

// cublasGemmStridedBatchedEx is only supported for GPU with architecture
@@ -414,6 +429,8 @@ void linalg_gemm<gpu, mshadow::half::half_t>(const Tensor<gpu, 2, mshadow::half:
CHECK_NOTNULL(s); \
linalg_check_batch_size(A.size(0), B.size(0), C.size(0)); \
linalg_check_batch_size(A.size(2), B.size(2), C.size(2)); \
auto handle = Stream<gpu>::GetBlasHandle(s); \
cublasMath_t saved_math_mode = SetCublasMathMode(handle, VERSION_ADJUSTED_TF32_MATH); \
for (index_t i = 0; i < A.size(2); ++i) { \
CUBLAS_CALL(cublas##fname(Stream<gpu>::GetBlasHandle(s), \
(tB ? CUBLAS_OP_T : CUBLAS_OP_N), \
Expand All @@ -423,6 +440,7 @@ void linalg_gemm<gpu, mshadow::half::half_t>(const Tensor<gpu, 2, mshadow::half:
A.dptr_+i*A.stride_, A.size(2) * A.stride_, A.size(1)*A.size(2)*A.stride_, &beta, \
C.dptr_+i*C.stride_, C.size(2) * C.stride_, C.size(1)*C.size(2)*C.stride_, A.size(0))) \
}\
SetCublasMathMode(handle, saved_math_mode); \
}

LINALG_GPU_BATCH_GEMM_AXIS(SgemmStridedBatched, float)
19 changes: 11 additions & 8 deletions src/operator/numpy/np_true_divide-inl.h
@@ -58,14 +58,17 @@ void TrueDivideScalarCompute(const nnvm::NodeAttrs &attrs,
});
} else {
#ifndef _WIN32
CHECK_EQ(outputs[0].type_flag_, kFloat32) << "true_divide only supports float32 output "
"when input's dtype is "
<< type_string(inputs[0].type_flag_);
MXNET_INT_TYPE_SWITCH(inputs[0].type_flag_, DType, {
MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
Kernel<op_with_req<OP, Req>, xpu>::Launch(
s, data.Size(), out.dptr<float>(), data.dptr<DType>(),
static_cast<float>(alpha));
CHECK(out.type_flag_ == mshadow::kFloat32 || out.type_flag_ == mshadow::kFloat64)
<< "true_divide only supports float32 and float64"
" output when input's dtype is "
<< type_string(inputs[0].type_flag_);
MSHADOW_REAL_TYPE_SWITCH(out.type_flag_, ODType, {
MXNET_INT_TYPE_SWITCH(inputs[0].type_flag_, DType, {
MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
Kernel<op_with_req<OP, Req>, xpu>::Launch(
s, data.Size(), out.dptr<ODType>(), data.dptr<DType>(),
static_cast<ODType>(alpha));
});
});
});
#else
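The reworked non-Windows branch dispatches on two types at once: the outer real-type switch picks float32 or float64 for the output, the inner switch picks the integer input type, and the kernel is instantiated for that (ODType, DType) pair with the scalar converted to the output type. A simplified, macro-free sketch of the same shape (names here are illustrative, not MXNet's):

```cpp
// Simplified sketch of the nested dispatch, without MXNet's type-switch macros.
#include <cstddef>

template <typename ODType, typename DType>
void true_divide_scalar(ODType* out, const DType* in, std::size_t n, double alpha) {
  const ODType denom = static_cast<ODType>(alpha);
  for (std::size_t i = 0; i < n; ++i) {
    out[i] = static_cast<ODType>(in[i]) / denom;   // integer input, real output
  }
}

template <typename DType>
void dispatch_output(bool out_is_float64, void* out, const DType* in,
                     std::size_t n, double alpha) {
  if (out_is_float64) {
    true_divide_scalar(static_cast<double*>(out), in, n, alpha);
  } else {
    true_divide_scalar(static_cast<float*>(out), in, n, alpha);
  }
}
```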
2 changes: 1 addition & 1 deletion tests/python/gpu/test_forward.py
@@ -74,7 +74,7 @@ def test_consistency(dump=False):
ctx_list = [{'ctx': mx.gpu(0), 'data': data.shape, 'type_dict': {'data': data.dtype}},
{'ctx': mx.cpu(0), 'data': data.shape, 'type_dict': {'data': data.dtype}}]
gt = check_consistency(sym, ctx_list, arg_params=arg_params, aux_params=aux_params,
tol=1e-3, grad_req='null', raise_on_err=False, ground_truth=gt)
rtol=1e-3, atol=1e-3, grad_req='null', raise_on_err=False, ground_truth=gt)
if dump:
np.savez('data/inception-v3-dump.npz', **{n: a.asnumpy() for n, a in gt.items()})

16 changes: 11 additions & 5 deletions tests/python/gpu/test_gluon_gpu.py
@@ -50,10 +50,9 @@ def check_rnn_layer(layer):
states = layer.begin_state(16)
co, cs = layer(x, states)

# atol of 1e-6 required, as exposed by seed 2124685726
assert_almost_equal(go, co, rtol=1e-2, atol=1e-6)
assert_almost_equal(go, co)
for g, c in zip(gs, cs):
assert_almost_equal(g, c, rtol=1e-2, atol=1e-6)
assert_almost_equal(g, c)


@with_seed()
@@ -70,9 +69,9 @@ def check_rnn_layer_w_rand_inputs(layer):
states = layer.begin_state(16)
co, cs = layer(x, states)

assert_almost_equal(go, co, rtol=1e-2, atol=1e-6)
assert_almost_equal(go, co)
for g, c in zip(gs, cs):
assert_almost_equal(g, c, rtol=1e-2, atol=1e-6)
assert_almost_equal(g, c)


@with_seed()
@@ -481,6 +480,13 @@ def tensor_size(big_tensor_bytes):
# This in the past has given cudnnFind() trouble when it needed to allocate similar I/O's
# from the area carved out by the MXNET_GPU_MEM_POOL_RESERVE setting (by default 5%).
(free_mem_bytes, total_mem_bytes) = mx.context.gpu_memory_info(ctx.device_id)
# This test needs to be 'qualified' for use with each new larger memory size
largest_supported_total_mem_GB = 32
if (total_mem_bytes > largest_supported_total_mem_GB * 1024 * 1024 * 1024):
sys.stderr.write(
' bypassing test due to too-large global memory of size {} ... '.format(total_mem_bytes))
return

start_size = tensor_size(0.20 * total_mem_bytes)
num_trials = 10
sys.stderr.write(
2 changes: 1 addition & 1 deletion tests/python/gpu/test_gluon_model_zoo_gpu.py
@@ -91,7 +91,7 @@ def test_inference():
max_val = np.max(np.abs(cpu_out.asnumpy()))
gpu_max_val = np.max(np.abs(gpu_out.asnumpy()))
eprint(model_name + ": CPU " + str(max_val) + ", GPU " + str(gpu_max_val))
assert_almost_equal(cpu_out / max_val, gpu_out / gpu_max_val, rtol=1e-3, atol=1e-3)
assert_almost_equal(cpu_out / max_val, gpu_out / gpu_max_val)

def get_nn_model(name):
if "densenet" in name: