Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
45 commits
Select commit Hold shift + click to select a range
7ef15e0
[MatMul] Prolog build out, adding automatic swizzle generator
zasdfgbnm Mar 20, 2023
6afb375
Merge branch 'main' into matmul_swizzle_gen
zasdfgbnm Mar 20, 2023
0ee396c
Merge branch 'main' into matmul_swizzle_gen
zasdfgbnm Mar 22, 2023
a28293a
Merge branch 'main' into matmul_swizzle_gen
zasdfgbnm Mar 23, 2023
0700c9e
Merge branch 'main' into matmul_swizzle_gen
zasdfgbnm Mar 23, 2023
87c16d4
Merge branch 'main' into matmul_swizzle_gen
xwang233 Mar 24, 2023
f7be625
Merge branch 'main' into matmul_swizzle_gen
xwang233 Mar 24, 2023
4abdc49
Merge branch 'main' of github.com:NVIDIA/Fuser into matmul_swizzle_gen
zasdfgbnm Mar 29, 2023
98cfadc
Merge branch 'matmul_swizzle_gen' of github.com:NVIDIA/Fuser into mat…
zasdfgbnm Mar 29, 2023
4fda654
Merge branch 'main' into matmul_swizzle_gen
zasdfgbnm Mar 29, 2023
91852d2
Merge branch 'main' into matmul_swizzle_gen
zasdfgbnm Mar 31, 2023
6589e06
Merge branch 'main' of github.com:NVIDIA/Fuser into matmul_swizzle_gen
zasdfgbnm Apr 3, 2023
7b8e75a
fix
zasdfgbnm Apr 3, 2023
2a84c52
Merge branch 'main' into matmul_swizzle_gen
zasdfgbnm Apr 5, 2023
7a83199
Merge branch 'main' into matmul_swizzle_gen
zasdfgbnm Apr 5, 2023
e38a9a6
Merge branch 'main' into matmul_swizzle_gen
zasdfgbnm Apr 8, 2023
80e0c2e
Merge branch 'main' of github.com:NVIDIA/Fuser into matmul_swizzle_gen
zasdfgbnm Apr 10, 2023
ed2205f
Merge branch 'matmul_swizzle_gen' of github.com:NVIDIA/Fuser into mat…
zasdfgbnm Apr 10, 2023
9ac4f51
test bank conflict
zasdfgbnm Apr 10, 2023
d1c15f1
cleanup
zasdfgbnm Apr 10, 2023
e084d4b
Matmul prolog swizzle new algo
zasdfgbnm Apr 10, 2023
566d26e
save
zasdfgbnm Apr 11, 2023
9152319
save
zasdfgbnm Apr 11, 2023
a6e0878
save
zasdfgbnm Apr 11, 2023
278261e
save
zasdfgbnm Apr 11, 2023
0fb2d83
save
zasdfgbnm Apr 11, 2023
348b5d2
Merge branch 'main' of github.com:NVIDIA/Fuser into matmul_swizzle_ge…
zasdfgbnm Apr 11, 2023
de64355
Merge branch 'main' of github.com:NVIDIA/Fuser into matmul_swizzle_ge…
zasdfgbnm Apr 12, 2023
1266943
save
zasdfgbnm Apr 12, 2023
b39a04f
update
zasdfgbnm Apr 13, 2023
de1f058
another update
zasdfgbnm Apr 13, 2023
f4ac9c8
guard assert
zasdfgbnm Apr 13, 2023
f5848b6
save
zasdfgbnm Apr 13, 2023
d1256b1
update
zasdfgbnm Apr 13, 2023
5610465
simplify the case with g != 1
zasdfgbnm Apr 14, 2023
7879c97
Merge branch 'main' into matmul_swizzle_gen_new_algo
zasdfgbnm Apr 14, 2023
77cef35
k->m
zasdfgbnm Apr 14, 2023
d54a76f
fix indexing error
zasdfgbnm Apr 14, 2023
35be768
update doc
zasdfgbnm Apr 14, 2023
0259bbf
Merge branch 'main' into matmul_swizzle_gen_new_algo
zasdfgbnm Apr 14, 2023
9d8db7f
minv
zasdfgbnm Apr 14, 2023
1da2441
disable FusionAmpereMatmulSASSRegisterUsageLDSM_CUDA
zasdfgbnm Apr 14, 2023
e85f61d
Merge branch 'main' of github.com:NVIDIA/Fuser into matmul_swizzle_ge…
zasdfgbnm Apr 17, 2023
abe3328
Merge branch 'main' of github.com:NVIDIA/Fuser into matmul_swizzle_ge…
zasdfgbnm Apr 17, 2023
0a60fbf
Merge branch 'main' into matmul_swizzle_gen_new_algo
zasdfgbnm Apr 17, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion benchmark/matmul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include <fusion.h>
#include <ir_all_nodes.h>
#include <ir_utils.h>
#include <lower_bank_conflict.h>
#include <ops/all_ops.h>
#include <scheduler/all_schedulers.h>
#include <scheduler/matmul.h>
Expand Down Expand Up @@ -231,8 +232,16 @@ static void SingleMatmulBase(
cparams.enable_magic_zero = false;

// Compile kernel
auto launch_constraints = LaunchParams();
FusionExecutor fe;
fe.compileFusion(fusion, args, LaunchParams(), cparams);
fe.compileFusion(fusion, args, launch_constraints, cparams);
auto properties = at::cuda::getDeviceProperties(inputs.first.get_device());
if (properties->major >= 8 ||
(properties->major == 7 && properties->minor >= 5)) {
TORCH_CHECK(
getBankConflictInfo(fe.kernel(), launch_constraints).empty(),
"Shared memory bank conflict not removed.");
}

// Warm up run
auto outputs = fe.runFusion({inputs.first, inputs.second});
Expand Down
21 changes: 7 additions & 14 deletions csrc/index_compute.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -559,21 +559,14 @@ void IndexCompute::handle(Swizzle2D* swizzle_2d) {
// Handle inactive swizzles by just passing through index
// and extend information.

TORCH_INTERNAL_ASSERT(
index_map_.count(in_x_id) == index_map_.count(in_y_id),
"input index should be either both defined or both undefined");
if (index_map_.count(in_x_id)) {
// Only propagate original index through if
// the input index hasn't been computed.
// TODO:
// This part should be cleaner once we remove the
// second index traversal pass.
return;
if (!index_map_.count(in_x_id)) {
index_map_[in_x_id] = out_x_ind;
extent_map_[in_x_id] = getExtent(out_x_id);
}
if (!index_map_.count(in_y_id)) {
index_map_[in_y_id] = out_y_ind;
extent_map_[in_y_id] = getExtent(out_y_id);
}
index_map_[in_x_id] = out_x_ind;
index_map_[in_y_id] = out_y_ind;
extent_map_[in_y_id] = getExtent(out_y_id);
extent_map_[in_x_id] = getExtent(out_x_id);
} else {
// Generate integer swizzle math if the
// swizzle is activated. See also
Expand Down
422 changes: 389 additions & 33 deletions csrc/scheduler/matmul.cpp

Large diffs are not rendered by default.

8 changes: 4 additions & 4 deletions csrc/scheduler/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1618,7 +1618,7 @@ bool isFakeBoundaryTensorview(
//! transform to by BoundedDirectionalTransformPropagator.
std::unordered_set<TensorView*> getDirectionalPropagatePathSet(
TensorView* from_tv,
std::vector<TensorView*> boundary_tvs,
const std::vector<TensorView*>& boundary_tvs,
BoundedDirectionalTransformPropagator::Options options,
PropagateDirection direction) {
// Prepare to collect all candidate tensorviews
Expand Down Expand Up @@ -1730,9 +1730,9 @@ void BoundedDirectionalTransformPropagator::backward(
if (!options.has_value()) {
options = Options();
}
TORCH_INTERNAL_ASSERT(
!to.empty(),
"Propagation needs to be bounded, so no support for empty boundary.");
if (to.empty()) {
to = ir_utils::inputTvsOf(from);
}

// Collect all tvs to included on the backward path as specified
// by boundary and options.
Expand Down
29 changes: 1 addition & 28 deletions csrc/type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1108,34 +1108,7 @@ size_t dataTypeSize(DataType type) {
[](auto&& dtype) -> size_t {
using T = std::decay_t<decltype(dtype)>;
if constexpr (std::is_same_v<T, PrimDataType>) {
switch (dtype) {
case DataType::Bool:
return sizeof(bool);
case DataType::ComplexDouble:
return sizeof(std::complex<double>);
case DataType::ComplexFloat:
return sizeof(std::complex<float>);
case DataType::Double:
return sizeof(double);
case DataType::Float:
return sizeof(float);
case DataType::Half:
return sizeof(at::Half);
case DataType::BFloat16:
return sizeof(at::BFloat16);
case DataType::Index:
TORCH_INTERNAL_ASSERT(
false,
"The actual type of Index is only known at compile time.");
case DataType::Int:
return sizeof(uint64_t);
case DataType::Int32:
return sizeof(uint32_t);
case DataType::SMemAddress:
return sizeof(unsigned);
default:
TORCH_INTERNAL_ASSERT(false, "Size undefined for data type.");
}
return primDataTypeSize(dtype);
} else if constexpr (std::is_same_v<T, PointerOf>) {
return sizeof(void*);
} else if constexpr (std::is_same_v<T, ArrayOf>) {
Expand Down
30 changes: 30 additions & 0 deletions csrc/type.h
Original file line number Diff line number Diff line change
Expand Up @@ -704,6 +704,36 @@ TORCH_CUDA_CU_API const char* load_store_type2string(LoadStoreOpType t);
TORCH_CUDA_CU_API c10::optional<std::string> cast_func_str(
const std::pair<DataType, DataType>&);

//! Return the size in bytes of a primitive scalar type.
//!
//! DataType::Index has no fixed width here (it is resolved to Int or Int32
//! only at kernel compile time), so querying its size is an internal error,
//! as is any type without a defined size.
constexpr inline size_t primDataTypeSize(PrimDataType type) {
  switch (type) {
    case DataType::Bool:
      return sizeof(bool);
    case DataType::ComplexDouble:
      return sizeof(std::complex<double>);
    case DataType::ComplexFloat:
      return sizeof(std::complex<float>);
    case DataType::Double:
      return sizeof(double);
    case DataType::Float:
      return sizeof(float);
    case DataType::Half:
      return sizeof(at::Half);
    case DataType::BFloat16:
      return sizeof(at::BFloat16);
    case DataType::Index:
      TORCH_INTERNAL_ASSERT(
          false, "The actual type of Index is only known at compile time.");
    case DataType::Int:
      // Int is a signed 64-bit type; use the signed fixed-width type so the
      // sizeof expression matches the type's semantics (same size as the
      // previous sizeof(uint64_t)).
      return sizeof(int64_t);
    case DataType::Int32:
      // Int32 is a signed 32-bit type (same size as the previous
      // sizeof(uint32_t)).
      return sizeof(int32_t);
    case DataType::SMemAddress:
      // Shared-memory addresses are held as a 32-bit unsigned value.
      return sizeof(unsigned);
    default:
      TORCH_INTERNAL_ASSERT(false, "Size undefined for data type.");
  }
}

TORCH_CUDA_CU_API size_t dataTypeSize(DataType type);

// If the index type is known it will be automatically used here
Expand Down
17 changes: 17 additions & 0 deletions test/test_gpu_matmul_sass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ sass::Container getSASSFor(
params.double_buffer_options.smem_double_buffer_stage = 4;
scheduleMatmul(&fusion, params);

fusion.printTransforms();

at::manual_seed(0);
auto inputs = fp16MatmulAtInput(M, N, K, layout);

Expand Down Expand Up @@ -245,6 +247,20 @@ TEST_F(NVFuserTest, FusionAmpereMatmulSASSModifiersCheck_CUDA) {
}
}

#if 0

TODO: With swizzle, the cuda code looks like:

#pragma unroll
for(nvfuser_index_t i507 = 0; i507 < 8; ++i507) {
int i18439;
i18439 = i18438 + i507;
Turing::ldMatrixT (*reinterpret_cast<Array<__half,4,4>*>(&T9[(4 * i507)]),((i18437 + (128 * (i18439 / 8))) + (16 * (i6455 ^ (i18439 % 8)))));
}

where i6455 = (((nvfuser_index_t)threadIdx.x) % 16) % 8, so it no longer makes sense to require the memory access pattern below.
We need to reinvestigate the test below to determine whether to change it or delete it.
We need to reinvestigate the test below to determine whether to change it or delete it.

// Check that all LDSM instructions has the following pattern:
// LDSM.16.M88.2 R2, [R213] ;
// LDSM.16.M88.2 R136, [R213+0x200] ;
Expand Down Expand Up @@ -317,5 +333,6 @@ TEST_F(NVFuserTest, FusionAmpereMatmulSASSRegisterUsageLDSM_CUDA) {
}
}
}
#endif

} // namespace nvfuser
28 changes: 28 additions & 0 deletions test/test_gpu_tensorcore.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,9 @@ TEST_F(NVFuserTest, FusionVoltaMatmul_CUDA) {
params.tile_sizes = gemm_tile;
scheduleMatmul(&fusion, params);

// prologSwizzle on Volta is not supported yet
// ASSERT_TRUE(fusion.bankConflictInfo().empty());

at::manual_seed(0);
auto inputs = fp16MatmulAtInput(M, N, K, layout);

Expand Down Expand Up @@ -367,6 +370,9 @@ TEST_F(NVFuserTest, FusionVoltaMatmulRegDoubleBuffer_CUDA) {
params.double_buffer_options.double_buffer_smem_read = true;
scheduleMatmul(&fusion, params);

// prologSwizzle on Volta is not supported yet
// ASSERT_TRUE(fusion.bankConflictInfo().empty());

at::manual_seed(0);
auto inputs = fp16MatmulAtInput(M, N, K, layout);

Expand Down Expand Up @@ -651,6 +657,8 @@ TEST_F(NVFuserTest, FusionAmpereMatmul_CUDA) {
params.double_buffer_options.smem_double_buffer_stage = 4;
scheduleMatmul(&fusion, params);

ASSERT_TRUE(fusion.bankConflictInfo().empty());

at::manual_seed(0);
auto inputs = fp16MatmulAtInput(M, N, K, layout);

Expand Down Expand Up @@ -706,6 +714,8 @@ TEST_F(NVFuserTest, FusionAmpereMatmulPipelineGmem_CUDA) {
params.double_buffer_options.smem_double_buffer_stage = stage;
scheduleMatmul(&fusion, params);

ASSERT_TRUE(fusion.bankConflictInfo().empty());

at::manual_seed(0);
auto inputs = fp16MatmulAtInput(M, N, K, layout);

Expand Down Expand Up @@ -772,6 +782,8 @@ TEST_F(NVFuserTest, FusionAmpereSwizzle_CUDA) {

scheduleMatmul(&fusion, params);

ASSERT_TRUE(fusion.bankConflictInfo().empty());

at::manual_seed(0);
auto inputs = fp16MatmulAtInput(M, N, K, layout);

Expand Down Expand Up @@ -879,6 +891,8 @@ TEST_F(NVFuserTest, FusionAmpereMatmulRegDoubleBuffer_CUDA) {
params.double_buffer_options.double_buffer_smem_read = true;
scheduleMatmul(&fusion, params);

ASSERT_TRUE(fusion.bankConflictInfo().empty());

at::manual_seed(0);
auto inputs = fp16MatmulAtInput(M, N, K, layout);

Expand Down Expand Up @@ -1802,6 +1816,8 @@ TEST_F(NVFuserTest, FusionTuringMatmul_CUDA) {
params.tile_sizes = gemm_tile;
scheduleMatmul(&fusion, params);

ASSERT_TRUE(fusion.bankConflictInfo().empty());

at::manual_seed(0);
auto inputs = fp16MatmulAtInput(M, N, K, layout);

Expand Down Expand Up @@ -2818,6 +2834,8 @@ TEST_F(NVFuserTest, FusionAmpereMatmulLargeLoad_CUDA) {
params.double_buffer_options.smem_double_buffer_stage = 3;
scheduleMatmul(&fusion, params);

ASSERT_TRUE(fusion.bankConflictInfo().empty());

at::manual_seed(0);
auto inputs = fp16MatmulAtInput(M, N, K, layout);

Expand Down Expand Up @@ -2866,6 +2884,8 @@ TEST_F(NVFuserTest, FusionTuringMatmulLargeLoad_CUDA) {
params.tile_sizes = gemm_tile;
scheduleMatmul(&fusion, params);

ASSERT_TRUE(fusion.bankConflictInfo().empty());

at::manual_seed(0);
auto inputs = fp16MatmulAtInput(M, N, K, layout);

Expand Down Expand Up @@ -2921,6 +2941,8 @@ TEST_F(NVFuserTest, FusionAmpereMatmulTileCheck4warp_CUDA) {
params.double_buffer_options.double_buffer_smem_write = true;
scheduleMatmul(&fusion, params);

ASSERT_TRUE(fusion.bankConflictInfo().empty());

at::manual_seed(0);
auto inputs = fp16MatmulAtInput(M, N, K, layout);

Expand Down Expand Up @@ -2986,6 +3008,8 @@ TEST_F(NVFuserTest, FusionAmpereMatmulTileCheck8warp_CUDA) {

scheduleMatmul(&fusion, params);

ASSERT_TRUE(fusion.bankConflictInfo().empty());

at::manual_seed(0);
auto inputs = fp16MatmulAtInput(M, N, K, layout);

Expand Down Expand Up @@ -3045,6 +3069,8 @@ TEST_F(NVFuserTest, FusionAmpereMatmulTileCheck6warp_CUDA) {

scheduleMatmul(&fusion, params);

ASSERT_TRUE(fusion.bankConflictInfo().empty());

at::manual_seed(0);
auto inputs = fp16MatmulAtInput(M, N, K, layout);

Expand Down Expand Up @@ -3097,6 +3123,8 @@ TEST_F(NVFuserTest, FusionAmpereMatmulLargeLoadLargeK_CUDA) {
params.double_buffer_options.smem_double_buffer_stage = 3;
scheduleMatmul(&fusion, params);

ASSERT_TRUE(fusion.bankConflictInfo().empty());

at::manual_seed(0);
auto inputs = fp16MatmulAtInput(M, N, K, layout);

Expand Down