diff --git a/python/tvm/contrib/nvcc.py b/python/tvm/contrib/nvcc.py index 612be292e873..09583d066463 100644 --- a/python/tvm/contrib/nvcc.py +++ b/python/tvm/contrib/nvcc.py @@ -18,15 +18,15 @@ """Utility to invoke nvcc compiler in the system""" from __future__ import absolute_import as _abs -import subprocess import os +import subprocess import warnings import tvm._ffi from tvm.runtime import ndarray as nd -from . import utils from .._ffi.base import py_str +from . import utils def compile_cuda(code, target="ptx", arch=None, options=None, path_target=None): @@ -78,6 +78,10 @@ def compile_cuda(code, target="ptx", arch=None, options=None, path_target=None): else: cmd += ["-arch", arch] + # Fast math, pass in through options later + cmd += ["--use_fast_math"] + cmd += ["--ptxas-options=--fmad"] + if options: if isinstance(options, str): cmd += [options] diff --git a/src/target/opt/build_cuda_on.cc b/src/target/opt/build_cuda_on.cc index 4a2917daa5ed..03d2fd04c28a 100644 --- a/src/target/opt/build_cuda_on.cc +++ b/src/target/opt/build_cuda_on.cc @@ -102,6 +102,8 @@ std::string NVRTCCompile(const std::string& code, bool include_path = false) { compile_params.push_back(include_option); } + compile_params.push_back("--use_fast_math"); + for (const auto& string : compile_params) { param_cstrings.push_back(string.c_str()); }