From 03111cef09b33a25792246785ab0ded68be1733c Mon Sep 17 00:00:00 2001 From: Stas Bekman Date: Tue, 8 Dec 2020 19:54:14 -0800 Subject: [PATCH 1/2] fix arch flags, add PTX --- op_builder/builder.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/op_builder/builder.py b/op_builder/builder.py index f44aee79637a..6fd13bf75781 100644 --- a/op_builder/builder.py +++ b/op_builder/builder.py @@ -221,7 +221,7 @@ def compute_capability_args(self, cross_compile_archs=None): 1. `TORCH_CUDA_ARCH_LIST` takes priority over `cross_compile_archs`. 2. If neither is set default compute capabilities will be used - 3. Under `jit_mode` compute capabilities of all visible cards will be used. + 3. Under `jit_mode` compute capabilities of all visible cards will be used plus PTX Format: @@ -243,6 +243,7 @@ def compute_capability_args(self, cross_compile_archs=None): if cc not in ccs: ccs.append(cc) ccs = sorted(ccs) + ccs[-1] += '+PTX' else: # Cross-compile mode, compile for various architectures # env override takes priority @@ -261,7 +262,9 @@ def compute_capability_args(self, cross_compile_archs=None): args = [] for cc in ccs: cc = cc.replace('.', '') - args.append(f'-gencode=arch=compute_{cc},code=compute_{cc}') + args.append(f'-gencode=arch=compute_{cc},code=sm_{cc}') + if cc.endswith('+PTX'): + args.append(f'-gencode=arch=compute_{cc},code=compute_{cc}') return args From 2803061fd8126dda0950472122408a7ebf11d025 Mon Sep 17 00:00:00 2001 From: Stas Bekman Date: Tue, 8 Dec 2020 21:21:48 -0800 Subject: [PATCH 2/2] bug fix --- op_builder/builder.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/op_builder/builder.py b/op_builder/builder.py index 6fd13bf75781..1f350065b4f6 100644 --- a/op_builder/builder.py +++ b/op_builder/builder.py @@ -261,10 +261,10 @@ def compute_capability_args(self, cross_compile_archs=None): args = [] for cc in ccs: - cc = cc.replace('.', '') - args.append(f'-gencode=arch=compute_{cc},code=sm_{cc}') + num = cc[0] + cc[2] + args.append(f'-gencode=arch=compute_{num},code=sm_{num}') if cc.endswith('+PTX'): - args.append(f'-gencode=arch=compute_{cc},code=compute_{cc}') + args.append(f'-gencode=arch=compute_{num},code=compute_{num}') return args