Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions python/tvm/contrib/nvcc.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,15 @@
"""Utility to invoke nvcc compiler in the system"""
from __future__ import absolute_import as _abs

import subprocess
import os
import subprocess
import warnings

import tvm._ffi
from tvm.runtime import ndarray as nd

from . import utils
from .._ffi.base import py_str
from . import utils


def compile_cuda(code, target="ptx", arch=None, options=None, path_target=None):
Expand Down Expand Up @@ -78,6 +78,10 @@ def compile_cuda(code, target="ptx", arch=None, options=None, path_target=None):
else:
cmd += ["-arch", arch]

# Fast math, pass in through options later
cmd += ["--use_fast_math"]
cmd += ["--ptxas-options=--fmad"]

if options:
if isinstance(options, str):
cmd += [options]
Expand Down
2 changes: 2 additions & 0 deletions src/target/opt/build_cuda_on.cc
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,8 @@ std::string NVRTCCompile(const std::string& code, bool include_path = false) {
compile_params.push_back(include_option);
}

compile_params.push_back("--use_fast_math");

for (const auto& string : compile_params) {
param_cstrings.push_back(string.c_str());
}
Expand Down