Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions op_builder/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ def compute_capability_args(self, cross_compile_archs=None):

1. `TORCH_CUDA_ARCH_LIST` takes priority over `cross_compile_archs`.
2. If neither is set, default compute capabilities will be used.
3. Under `jit_mode` compute capabilities of all visible cards will be used.
3. Under `jit_mode` compute capabilities of all visible cards will be used plus PTX

Format:

Expand All @@ -243,6 +243,7 @@ def compute_capability_args(self, cross_compile_archs=None):
if cc not in ccs:
ccs.append(cc)
ccs = sorted(ccs)
ccs[-1] += '+PTX'
else:
# Cross-compile mode, compile for various architectures
# env override takes priority
Expand All @@ -260,8 +261,10 @@ def compute_capability_args(self, cross_compile_archs=None):

args = []
for cc in ccs:
cc = cc.replace('.', '')
args.append(f'-gencode=arch=compute_{cc},code=compute_{cc}')
num = cc[0] + cc[2]
args.append(f'-gencode=arch=compute_{num},code=sm_{num}')
if cc.endswith('+PTX'):
args.append(f'-gencode=arch=compute_{num},code=compute_{num}')

return args

Expand Down