Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 32 additions & 4 deletions csrc/executor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,8 @@ void FusionExecutor::debugCompileFusionFromStr(
}

std::tie(compiled_kernel_, last_compiler_log_, last_compiled_binary_) =
executor_utils::nvrtcCompile(c10::nullopt, code, name, fusion_id_);
executor_utils::nvrtcCompile(
c10::nullopt, code, name, fusion_id_, swizzle_factor);
TORCH_INTERNAL_ASSERT(
fusion_id_ > 0, "assign a fusion_id_ <= 0 is not accepted.");
}
Expand Down Expand Up @@ -236,6 +237,11 @@ void FusionExecutor::compileFusion(
// TODO: refactor the options_ passed through
options_.device = c10::Device(c10::DeviceType::CUDA, args.getDeviceIndex());

// Set the swizzle factor before compilation and before launch params are computed
if (compile_params.swizzle_factor) {
swizzle_factor = std::max(1, *compile_params.swizzle_factor);
}

// Set the index type of compile params if not already set. If set,
// make sure the compile param type is valid with the given kernel
// arguments.
Expand Down Expand Up @@ -362,6 +368,7 @@ void FusionExecutor::compileFusion(
structured_code,
(kernelNamespace() + "::" + kernelName()).c_str(),
fusion_id_,
swizzle_factor,
block_size,
maxrregcount_high_water_mark,
save_compiled_binary_ || isDebugDumpEnabled(DebugDumpOption::Sass));
Expand Down Expand Up @@ -778,6 +785,12 @@ LaunchParams FusionExecutor::computeLaunchParams(

launch_params.setSmem(dynamic_smem_size);

const int gdimx = launch_params.gdimx();
const int gdimy = launch_params.gdimy();

launch_params.bind(gdimx * swizzle_factor, ParallelType::BIDx, true);
launch_params.bind(
(gdimy + swizzle_factor - 1) / swizzle_factor, ParallelType::BIDy, true);
return launch_params;
}

Expand Down Expand Up @@ -1161,8 +1174,21 @@ std::vector<at::Tensor> FusionExecutor::runFusion(

// Recompile the kernel if the number of threads in the block has increased
// or maxrregcount has changed
if (launch_params_.nThreads() > block_size_high_water_mark ||
compile_params.maxrregcount != maxrregcount_high_water_mark) {

bool need_to_recompile = false;

need_to_recompile |= launch_params_.nThreads() > block_size_high_water_mark;

need_to_recompile |=
compile_params.maxrregcount != maxrregcount_high_water_mark;

if (compile_params.swizzle_factor) {
need_to_recompile |=
swizzle_factor != std::max(1, *compile_params.swizzle_factor);
swizzle_factor = std::max(1, *compile_params.swizzle_factor);
}

if (need_to_recompile) {
const auto kernel = lowered_->kernel();
kernel_code_ = codegen::generateCudaKernel(kernel, kernelName());
const auto structured_code =
Expand All @@ -1176,6 +1202,7 @@ std::vector<at::Tensor> FusionExecutor::runFusion(
structured_code,
(kernelNamespace() + "::" + kernelName()).c_str(),
fusion_id_,
swizzle_factor,
block_size_high_water_mark,
maxrregcount_high_water_mark,
save_compiled_binary_);
Expand Down Expand Up @@ -1446,7 +1473,8 @@ void FusionExecutor::compileRtc(
fusion_id_ = 1;

std::tie(compiled_kernel_, last_compiler_log_, last_compiled_binary_) =
executor_utils::nvrtcCompile(c10::nullopt, scode, name, fusion_id_);
executor_utils::nvrtcCompile(
c10::nullopt, scode, name, fusion_id_, swizzle_factor);
}

float FusionExecutor::runRtc(
Expand Down
1 change: 1 addition & 0 deletions csrc/executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,7 @@ class TORCH_CUDA_CU_API FusionExecutor : public NonCopyable {
// increases, recompile to adjust maxregister count.
int64_t block_size_high_water_mark = 1;
int maxrregcount_high_water_mark = 255;
int swizzle_factor = 1;

// lookup table to take short cut to retrieve recorded information in order to
// launch kernels without re-inference parameters.
Expand Down
14 changes: 7 additions & 7 deletions csrc/executor_params.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,25 +33,25 @@ void LaunchParams::assertValid() {
gdimz());
}

void LaunchParams::bind(int64_t val, ParallelType p_type) {
void LaunchParams::bind(int64_t val, ParallelType p_type, bool allow_rebind) {
switch (p_type) {
case ParallelType::TIDx:
checkAndSet(val, bdimx_, "blockDim.x");
checkAndSet(val, bdimx_, "blockDim.x", allow_rebind);
break;
case ParallelType::BIDx:
checkAndSet(val, gdimx_, "gridDim.x");
checkAndSet(val, gdimx_, "gridDim.x", allow_rebind);
break;
case ParallelType::TIDy:
checkAndSet(val, bdimy_, "blockDim.y");
checkAndSet(val, bdimy_, "blockDim.y", allow_rebind);
break;
case ParallelType::BIDy:
checkAndSet(val, gdimy_, "gridDim.y");
checkAndSet(val, gdimy_, "gridDim.y", allow_rebind);
break;
case ParallelType::TIDz:
checkAndSet(val, bdimz_, "blockdim.z");
checkAndSet(val, bdimz_, "blockdim.z", allow_rebind);
break;
case ParallelType::BIDz:
checkAndSet(val, gdimz_, "gridDim.z");
checkAndSet(val, gdimz_, "gridDim.z", allow_rebind);
break;
default:
TORCH_INTERNAL_ASSERT(
Expand Down
37 changes: 21 additions & 16 deletions csrc/executor_params.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,20 @@ struct TORCH_CUDA_CU_API CompileParams {
std::optional<PrimDataType> index_type = std::nullopt;
int maxrregcount = 255;
bool enable_magic_zero = true;
std::optional<int> swizzle_factor = std::nullopt;

bool operator==(const CompileParams& other) const {
  // Disallow comparison if the index type is nullopt. Unlike index_type,
  // swizzle_factor may legitimately remain nullopt (meaning "no swizzling
  // requested" — see FusionExecutor::compileFusion, which only reads it
  // when set), and std::optional::operator== compares nullopt states
  // correctly, so no assert is needed for it. Requiring it to be set
  // would make every comparison of default-constructed params fail.
  TORCH_INTERNAL_ASSERT(
      index_type.has_value(),
      "cannot compare as the index type is not defined");
  TORCH_INTERNAL_ASSERT(
      other.index_type.has_value(),
      "cannot compare as the other index type is not defined");
  return index_type == other.index_type &&
      maxrregcount == other.maxrregcount &&
      enable_magic_zero == other.enable_magic_zero &&
      swizzle_factor == other.swizzle_factor;
}

bool operator!=(const CompileParams& other) const {
Expand Down Expand Up @@ -100,32 +102,35 @@ class TORCH_CUDA_CU_API LaunchParams {
void checkAndSet(
const int64_t incoming_val,
int64_t& class_val,
std::string val) {
TORCH_INTERNAL_ASSERT(
class_val == UNINITIALIZED_VAL || incoming_val == class_val,
"Tried to set ",
val,
" from ",
class_val,
" to ",
incoming_val,
", but it was already set and new value does not match.",
" Thread dims all have to be bound to the same value.");
std::string val,
bool allow_rebind = false) {
if (!allow_rebind) {
TORCH_INTERNAL_ASSERT(
class_val == UNINITIALIZED_VAL || incoming_val == class_val,
"Tried to set ",
val,
" from ",
class_val,
" to ",
incoming_val,
", but it was already set and new value does not match.",
" Thread dims all have to be bound to the same value.");
}
TORCH_CHECK(
incoming_val > 0,
"Received a thread binding on ",
val,
" that is ",
incoming_val,
". Cannot create negative threads.");
if (class_val == UNINITIALIZED_VAL) {
if (class_val == UNINITIALIZED_VAL || allow_rebind) {
class_val = incoming_val;
}
assertValid();
}

// Binds the dimension associated with p_type to val
void bind(int64_t val, ParallelType p_type);
void bind(int64_t val, ParallelType p_type, bool allow_rebind = false);

// Adjusted value based on get functions above for each value
int64_t getDim(ParallelType p_type) const;
Expand Down
5 changes: 5 additions & 0 deletions csrc/executor_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1028,6 +1028,7 @@ std::tuple<NvrtcFunction, std::string, std::vector<char>> nvrtcCompile(
const std::string& code,
const std::string& func_name,
int id,
const int swizzle_factor,
c10::optional<int> opt_block_size,
const int max_register_heuristic,
bool return_compiled_binary) {
Expand Down Expand Up @@ -1092,6 +1093,10 @@ std::tuple<NvrtcFunction, std::string, std::vector<char>> nvrtcCompile(
args.push_back("-DPYTORCH_NVFUSER_PROFILE_KERNEL");
}

const std::string swizzle_factor_str =
"-DSWIZZLE_FACTOR=" + std::to_string(std::max(swizzle_factor, 1));
args.push_back(swizzle_factor_str.c_str());

const char* ptxas_opt_level = getenv("PYTORCH_NVFUSER_JIT_OPT_LEVEL");
std::string jit_opt_level = "-O";

Expand Down
1 change: 1 addition & 0 deletions csrc/executor_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ std::tuple<NvrtcFunction, std::string, std::vector<char>> nvrtcCompile(
const std::string& code,
const std::string& func_name,
int id,
const int swizzle_factor,
c10::optional<int> opt_block_size = c10::nullopt,
const int max_register_heuristic = 255,
bool return_compiled_binary = false);
Expand Down
38 changes: 20 additions & 18 deletions csrc/lower_bank_conflict.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,32 @@ inline int64_t getPhaseSize(int64_t word_size_bytes) {
return 32;
}

// Maps an index-variable name, as emitted by the codegen, to its
// ParallelType. Thread indices use the raw CUDA builtins, while block
// indices use the getBlockId*() helper-call spellings — these strings
// must stay in sync with parallel_type2string (csrc/type.cpp) and the
// device helpers in runtime/helpers.cu, which apply grid swizzling.
// Hard-asserts on any name that is not a recognized parallel type, so
// callers must pre-filter arbitrary names before calling this.
ParallelType getParallelType(const std::string& name) {
if (name == "threadIdx.x") {
return ParallelType::TIDx;
} else if (name == "threadIdx.y") {
return ParallelType::TIDy;
} else if (name == "threadIdx.z") {
return ParallelType::TIDz;
} else if (name == "getBlockIdX()") {
return ParallelType::BIDx;
} else if (name == "getBlockIdY()") {
return ParallelType::BIDy;
} else if (name == "getBlockIdZ()") {
return ParallelType::BIDz;
}
TORCH_INTERNAL_ASSERT(false, "Not a parallel type");
}

// Returns true iff the given index-variable name is one of the CUDA
// thread-index builtins (threadIdx.x/y/z).
bool isThreadIdx(const std::string& name) {
  for (const char* tid : {"threadIdx.x", "threadIdx.y", "threadIdx.z"}) {
    if (name == tid) {
      return true;
    }
  }
  return false;
}

// Returns true iff `name` is one of the block-index expressions emitted
// by the codegen (the getBlockId*() helper calls that replace the raw
// blockIdx.* builtins so grid swizzling can be applied).
//
// Implemented with direct string comparison rather than via
// getParallelType(name): getParallelType hard-asserts on any name that
// is not a recognized parallel type, whereas this predicate must simply
// return false for such names (e.g. "blockDim.x" or plain loop
// indices), as the pre-rename implementation did.
bool isBlockIdx(const std::string& name) {
  return name == "getBlockIdX()" || name == "getBlockIdY()" ||
      name == "getBlockIdZ()";
}

bool isBlockDim(const std::string& name) {
// Returns true iff `name` is one of the grid-dimension expressions.
// Fixed: the original combined the three equality checks with `&&`,
// which can never hold for three distinct strings, so the predicate
// always returned false. The comparisons must be joined with `||`,
// mirroring isThreadIdx/isBlockDim.
bool isGridDim(const std::string& name) {
  return name == "gridDim.x" || name == "gridDim.y" || name == "gridDim.z";
}

ParallelType getParallelType(const std::string& name) {
if (name == "threadIdx.x") {
return ParallelType::TIDx;
} else if (name == "threadIdx.y") {
return ParallelType::TIDy;
} else if (name == "threadIdx.z") {
return ParallelType::TIDz;
} else if (name == "blockIdx.x") {
return ParallelType::BIDx;
} else if (name == "blockIdx.y") {
return ParallelType::BIDy;
} else if (name == "blockIdx.z") {
return ParallelType::BIDz;
}
TORCH_INTERNAL_ASSERT(false, "Not a parallel type");
}

std::vector<int64_t> evaluateAddressesOnFirstPhase(
kir::TensorIndex* ti,
const std::vector<kir::ForLoop*>& for_loops,
Expand Down
11 changes: 11 additions & 0 deletions csrc/scheduler/matmul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,17 @@ void scheduleMatmul(
// [... M,N,K]
scheduler_utils::matmul_utils::makeTile(cc, gemm_tile.cta_tile.toVector());

// int factor =
// getenv("SWIZZLE_FACTOR") != nullptr ?
// std::atoi(getenv("SWIZZLE_FACTOR")) : 1;
// if (factor != 1) {
// cc->split(1, factor, false); // outer split
// cc->split(0, factor, true); // inner
// // Mo, Mi, Ni, Mo
// cc->merge(1, 2);
// cc->merge(2, 1);
// }

// [Mo, No, Ko, Mi, Ni, Ki]
// Propagate tiling globally
scheduler_utils::transformPropagateToAllFrom(cc, -1);
Expand Down
6 changes: 3 additions & 3 deletions csrc/type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -607,11 +607,11 @@ static const char* rng_op_type2string(RNGOpType t) {
static const char* parallel_type2string(ParallelType t) {
switch (t) {
case ParallelType::BIDz:
return "blockIdx.z";
return "getBlockIdZ()";
case ParallelType::BIDy:
return "blockIdx.y";
return "getBlockIdY()";
case ParallelType::BIDx:
return "blockIdx.x";
return "getBlockIdX()";
case ParallelType::TIDz:
return "threadIdx.z";
case ParallelType::TIDy:
Expand Down
24 changes: 24 additions & 0 deletions runtime/helpers.cu
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,30 @@
#include <assert.h>
#endif // __NVCC__

// Grid-swizzle factor, injected at NVRTC compile time via
// -DSWIZZLE_FACTOR=<n> (see executor_utils::nvrtcCompile, which clamps
// it to >= 1).
constexpr unsigned swizzle_factor = SWIZZLE_FACTOR;

// Logical block index along x under grid swizzling. The launcher scales
// gridDim.x up by swizzle_factor (see FusionExecutor::computeLaunchParams),
// so the logical x index is recovered by dividing the physical
// blockIdx.x back down. When swizzle_factor == 1 the `if constexpr`
// branch compiles this away to plain blockIdx.x.
__device__ unsigned getBlockIdX() {
static_assert(swizzle_factor >= 1);
if constexpr (swizzle_factor == 1) {
return blockIdx.x;
} else {
return blockIdx.x / swizzle_factor;
}
}

// Logical block index along y under grid swizzling: the remainder of
// the physical x index is folded into y, so a run of swizzle_factor
// consecutive physical x blocks covers a swizzle_factor-wide group of
// logical rows (presumably to improve L2 tile locality for matmul —
// see scheduler/matmul.cpp).
// NOTE(review): the launcher rounds gridDim.y up with a ceil-div, so
// when the original gridDim.y is not divisible by swizzle_factor this
// can yield y indices past the original extent — assumes the kernel
// predicates those blocks away; verify.
__device__ unsigned getBlockIdY() {
static_assert(swizzle_factor >= 1);
if constexpr (swizzle_factor == 1) {
return blockIdx.y;
} else {
return blockIdx.y * swizzle_factor + (blockIdx.x % swizzle_factor);
}
}

// z is never swizzled; pass the physical block index straight through.
__device__ unsigned getBlockIdZ() {
return blockIdx.z;
}

__device__ constexpr int ceilDiv(int a, int b) {
return (a + b - 1) / b;
}
Expand Down
Loading