diff --git a/benchmark/matmul.cpp b/benchmark/matmul.cpp index 10c32fecdaf..1a3e2c6343a 100644 --- a/benchmark/matmul.cpp +++ b/benchmark/matmul.cpp @@ -9,6 +9,7 @@ #include #include #include +#include #include #include #include @@ -231,8 +232,16 @@ static void SingleMatmulBase( cparams.enable_magic_zero = false; // Compile kernel + auto launch_constraints = LaunchParams(); FusionExecutor fe; - fe.compileFusion(fusion, args, LaunchParams(), cparams); + fe.compileFusion(fusion, args, launch_constraints, cparams); + auto properties = at::cuda::getDeviceProperties(inputs.first.get_device()); + if (properties->major >= 8 || + (properties->major == 7 && properties->minor >= 5)) { + TORCH_CHECK( + getBankConflictInfo(fe.kernel(), launch_constraints).empty(), + "Shared memory bank conflict not removed."); + } // Warm up run auto outputs = fe.runFusion({inputs.first, inputs.second}); diff --git a/csrc/index_compute.cpp b/csrc/index_compute.cpp index ae126efb40d..b9964df33af 100644 --- a/csrc/index_compute.cpp +++ b/csrc/index_compute.cpp @@ -559,21 +559,14 @@ void IndexCompute::handle(Swizzle2D* swizzle_2d) { // Handle inactive swizzles by just passing through index // and extend information. - TORCH_INTERNAL_ASSERT( - index_map_.count(in_x_id) == index_map_.count(in_y_id), - "input index should be either both defined or both undefined"); - if (index_map_.count(in_x_id)) { - // Only propagate original index through if - // the input index hasn't been computed. - // TODO: - // This part should be cleaner once we remove the - // second index traversal pass. - return; + if (!index_map_.count(in_x_id)) { + index_map_[in_x_id] = out_x_ind; + extent_map_[in_x_id] = getExtent(out_x_id); + } + if (!index_map_.count(in_y_id)) { + index_map_[in_y_id] = out_y_ind; + extent_map_[in_y_id] = getExtent(out_y_id); } - index_map_[in_x_id] = out_x_ind; - index_map_[in_y_id] = out_y_ind; - extent_map_[in_y_id] = getExtent(out_y_id); - extent_map_[in_x_id] = getExtent(out_x_id); } else { // Generate integer swizzle math if the // swizzle is activated. See also diff --git a/csrc/scheduler/matmul.cpp b/csrc/scheduler/matmul.cpp index c8eebe6ae55..e86eb3df484 100644 --- a/csrc/scheduler/matmul.cpp +++ b/csrc/scheduler/matmul.cpp @@ -17,6 +17,12 @@ namespace nvfuser { namespace { + +// Returns true if given number is power of 2 +constexpr bool isPowOf2(int x) { + return x > 1 && (x & (x - 1)) == 0; +} + // Move the broadcast axes to the left on the specified number of inner // dimensions e.g. (when number_of_inner_pos == 3): // [... I0, B, I1] -> [... B, I0, I1] @@ -49,6 +55,386 @@ void moveInnerBroadcastLeft(TensorView* tv, int number_of_inner_pos = 3) { tv->reorder(order_map); } +//! Automatically generates the shared memory swizzled data layout +//! for matmul mainloop. +//! The shared mem datalayout is always 2D currently, and this utility +//! function assumes that the innermost 2 dimensions on shared_mem_tv +//! are the ones begin swizzled. +void prologSwizzle(TensorView* shared_mem_tv, const MatmulParams& params) { + // Check that the innermost 2 dimensions are concrete and static + // sized so that the swizzle function can be defined. + + // Utility to check concrete static size: + auto check_concrete_static_dim = [](IterDomain* id) { + TORCH_INTERNAL_ASSERT( + !id->isBroadcast() && !id->isReduction(), + "no support on reduction or broadcast dims, but get ", + id->toString()); + TORCH_INTERNAL_ASSERT( + id->extent()->isConstInt(), + "swizzled dimensions need to be statically, but get ", + id->toString()); + }; + + TORCH_INTERNAL_ASSERT( + shared_mem_tv->nDims() >= 2, + "At least 2D input needed for swizzling, but get ", + shared_mem_tv->toString()); + check_concrete_static_dim(shared_mem_tv->axis(-2)); + check_concrete_static_dim(shared_mem_tv->axis(-1)); + + // Extract the constant sizes of the swizzled tile + const auto tile_size_x = shared_mem_tv->axis(-2)->extent()->evaluateInt(); + const auto tile_size_y = shared_mem_tv->axis(-1)->extent()->evaluateInt(); + + if (isTuring(params.mma_op) || isAmpere(params.mma_op)) { + // TODO: right now, we are assuming ldmatrix access, which only supports + // sizeof(T) == 16bit (i.e. half/bfloat16) load according to offical doc: + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-load-instruction-ldmatrix + // In the future, when we start adding support for tf32(different macro), + // fp32(ffma), double, int8, fp8, etc. we need to update this function. + TORCH_INTERNAL_ASSERT(dataTypeSize(*shared_mem_tv->getDataType()) == 2); + + // ldmatrix loads a ldmatrix_rows x ldmatrix_cols = 8 x 8 matrix each time, + constexpr int ldmatrix_rows = 8; + constexpr int ldmatrix_cols = 8; + + // Column size of the tile needs to be multiples of 8 for ldmatrix to work. + TORCH_INTERNAL_ASSERT( + tile_size_x >= ldmatrix_rows && tile_size_x % ldmatrix_rows == 0 && + tile_size_y >= ldmatrix_cols && tile_size_y % ldmatrix_cols == 0, + "Prolog swizzle for ldmatrix, illegal tile size for prolog swizzle", + tile_size_x, + "x", + tile_size_y); + + /* Note [How to remove bank conflict for ldmatrix?] + * + * **This note is interleaved with code, I suggest reading this note like + * reading a jupyter notebook** + * + * Our task is to make sure different rows does not fall into the same + * bank of shared memory. + * + * Introduction to bank conflict can be found at page 54-72 of: + * https://on-demand.gputechconf.com/gtc/2018/presentation/s81006-volta-architecture-and-performance-optimization.pdf + * + * When we talk about bank conflict removal, we are talking about the + * following task: + * "there are 32 banks, and each bank contains one 4-byte word, we want to + * make sure different lanes in a warp does not access different word + * addresses in the same bank" + * For example, if thread 0 is accessing word address 1, and thread 1 is + * accessing word address 33, then these two threads will have a bank + * conflict because they are accessing different word addresses in the same + * bank. However, if thread 0 is accessing byte address 4 and thread 1 is + * accessing byte address 6 then there will be no bank conflict because 4 + * and 6 both belong to word 1. + */ + + constexpr int smem_bytes_per_word = 4; + constexpr int smem_banks = 32; + + /* but here, for our convenience, because ldmatrix always use vectorized + * access of 8 items = 16 bytes = 4 words, we further group words into + * units: we consider each 4 words as a "unit", and each 4 banks as a + * "megabank". So we can rephrase our task as: + * "there are 8 megabanks, and each megabanks contains one 4-word unit, we + * want to make sure different lanes in a warp does not access different + * unit addresses in the same megabank" + * In this terminology, matrices are in the row major format, each matrix + * has 8 rows, and each row has exactly one unit. + */ + + constexpr int items_per_unit = ldmatrix_cols; + constexpr int bytes_per_unit = + items_per_unit * primDataTypeSize(DataType::Half); + constexpr int words_per_unit = bytes_per_unit / smem_bytes_per_word; + constexpr int num_megabanks = smem_banks / words_per_unit; + + /* In the following example, each CTA tile contains 2 rows and 3 colums of + * matrices, each 8x8 size: + * +----------+----------+----------+ + * | matrix 0 | matrix 1 | matrix 2 | + * +----------+----------+----------+ + * | matrix 3 | matrix 4 | matrix 5 | + * +----------+----------+----------+ + * The addresses of different rows in the same matrix are offset by 3 units. + * In this perspective, loading a matrix is a strided memory access with the + * following stride (in units): + */ + + // number of units per row + int row_stride = tile_size_y / items_per_unit; + + /* So the bank conflicting problem is now converted to the following game: + * I have a clock that has one pointer and `num_megabanks` ticks. I start + * my game by making my pointer pointing to somewhere, and turn forward + * the pointer `ldmatrix_rows` times, each time by `row_stride` ticks. + * This problem can be well modeled by modular arithmetic in number theory + * using the concept "integers modulo n" a.k.a. "Z/nZ"[1]. + * Take n = 6 as an example, Z/6Z only has 6 elements: 0, 1, 2, 3, 4, 5. + * Additions and multiplications are defined in a cyclic manner: + * 5 + 1 = 0, 5 + 2 = 1, 5 + 3 = 2, 5 + 4 = 3, ... + * 2 * 1 = 2, 2 * 2 = 4, 2 * 3 = 0, 2 * 4 = 2, ... + * With this definition, Z is mapped to Z/nZ naturally by i -> i % n [2] + * + * It worth mention that Z/nZ is a "commutative ring", that is, we can use + * addition and multiplication rules just like using normal integers: + * a + b = b + a, a * (b + c) = a * b + a * c, ... + * In short, we can reason about Z/nZ just like we are reasoning about + * integers, except that every number is automatically "% n". + * + * Reference: + * [1] https://en.wikipedia.org/wiki/Modular_arithmetic#Integers_modulo_n + * [2] The % is under Euclidean definition, that is -1 % 6 is 5 instead of + * -1, see [The Mathematics of Integer Arithmetic] for more detail. But + * we are only interested in non-negative numbers here, so there is no + * need to worry about this problem + */ + + // row_stride in Z/nZ, where n is num_megabanks: + // assert(row_stride >= 0); + // assert(num_megabanks >= 0); + int row_stride_znz = row_stride % num_megabanks; + + /* Consider the following function in Z/nZ: + * f(i; init) = init + i * stride + * where init is the initial position of the pointer in the clock when we + * start the game, and stride is the number of ticks we move forward each + * time, and i is the number of times we move forward. For a fixed init, we + * abbrivate f(i; init) as f(i). + * + * In our problem, f(i) is the megabank of the `i`th row of the matrix, and + * `init` is the megabank of the 0th row of the matrix. + * + * One very important property of f(i) is: + * - if f(i1) == f(i2), then for every j, f(i1 + j) = f(i2 + j) + * This property is true because: + * f(i1 + j) = f(i1) + j * stride = f(i2) + j * stride = f(i2 + j) + * + * The above property tells us, as we turn the clock forward: + * - initially, we will go to a never-visited tick in each turn, but, + * - at some point, we will return back to our original position, and, + * - after we return, we start repeat the pervious pattern again and again. + * + * As an example, consider f(i) where init = 0, stride = 6, under Z/8Z: + * i 0 1 2 3 4 5 6 7 + * f(i) 0 6 4 2 0 6 4 2 + * We can see that f(i) is repeating a pattern of four unique numbers + * "0 6 4 2" twice. In our bank conflict problem, this means we are using 4 + * different megabanks, and we have a 2-way conflict. + * + * The question of interest is, does the above observation generalize? That + * is, does f(i) always repeat a pattern of p unique numbers q times? Note + * that p and q must satisfy p * q = n. + * + * The answer to the above question is: yes! Consider the following + * equation: + * f(i1 + j) == f(i1) + * We want to know what is the smallest positive number j that makes the + * above equation true. Because this tells us in how many steps we will see + * repeat. This equation can be simplified as: + * f(i1 + j) == f(i1) + j * stride == f(i1) + * ==> j * stride == 0 + * + * An important tool to study this equation is multiplicative inverse: + * https://en.wikipedia.org/wiki/Modular_multiplicative_inverse + * A number i has multiplicative inverse `minv(i)` in Z/nZ if and only if it + * coprime with n. `minv(i)` is the number that `i * minv(i) == 1`. So in + * Z/nZ, the equation `ax = b` has solution `x = minv(a)*b` if a has + * multiplicative inverse. For example, in Z/15Z, `minv(2) = 8` because + * (2 * 8) % 15 = 1 + * + * stride has an multiplicative inverse if and only if stride coprime with + * n, that is, g := gcd(stride, n) == 1. In such case, the solution to our + * equation j * stride == 0 is j = minv(stride) * 0 = 0, that is: f(i) does + * not repeat, that is: there is no bank conflict. + */ + + int g = std::gcd(num_megabanks, row_stride_znz); + if (g == 1) { + return; // No need to swizzle in this case. + } + + /* For the case where stride does not coprime with n, we note that + * j * stride == 0 in Z/nZ is equivalent to (j * stride) % n = 0 in Z. We + * can write stride and n as: + * stride = s * g, n = m * g + * According to Theorem 4.13 in [The Mathematics of Integer Arithmetic], we + * have: + * (j * stride) % n = 0 + * ==> (j * s) % m * g = 0 + * ==> (j * s) % m = 0 + * which is equivalent to j * s == 0 in Z/mZ. Because s coprime with m, we + * further get: + * j == 0 (in Z/mZ) + * That is, j is a multiple of m in Z. So the smallest positive j that make + * the equation hold is n / g. + * + * That is: f(i) always repeat a pattern of n/g unique numbers g times. + * In other word: we are using n/g megabanks, and we have a g-way bank + * conflict. + * + * Let's use the word "pattern" to refer to the set of values of `f` at + * different `i`, that is: + * pattern k = { f(i; init=k) | i in Z/nZ } + * For the example of stride = 6 under Z/8Z, we have the following patterns + * f(i): 01234567 + * pattern 0: x_x_x_x_ + * pattern 1: _x_x_x_x + * (x => occupied, _ => unoccupied) + */ + + int repeated_pattern_size = num_megabanks / g; + + if (repeated_pattern_size >= ldmatrix_rows) { + return; // No need to swizzle in this case. + } + + /* Now we know that we have a g-way bank conflict. How do we remove this + * bank conflict? The answer is to mix the storage of different matrices. + * We first split the matrices along the row axis into g pieces, each piece + * has n/g rows. With this split, each piece occupies exactly one pattern. + * We want to use some non-traditional storage to let different pieces of + * the same matrix to occupy different patterns. + * + * Because Z/nZ has n items, each pattern has n/g different items, so we + * have in total g different patterns. We want to find the corresponding + * `init` values of these g different patterns. + * + * Consider two different init values `init1` and `init2`. When do they + * represent the same pattern? They represent the same pattern if and only + * if `f(0; init2)` falls on the pattern of `init1`, that is, there exist an + * i such that + * f(i; init1) == f(0; init2) + * which simplifies to + * init1 + i * stride == init2 + * ==> init2 - init1 == i * stride + * What values can `i * stride` be? It can be an arbitrary multiple of g: + * i * stride in Z/nZ is (i * stride) % n in Z. Let m = n/g, according to + * Theorem 4.13 in [The Mathematics of Integer Arithmetic] + * (i * stride) % n = (i * s) % m * g + * Because s coprime with m, we know that for an arbitrary value `j` in + * Z/mZ, we can take `i = minv(s) * j` to make `i * s == j`. + * + * That said, for init values that are off by a multiple of g they + * correspond to the same pattern, otherwise they belongs to different + * patterns. So, we can use + * init = 0, 1, ..., g - 1 + * to canonically represent g patterns. Let's call the above + * `init` values "pattern id". + * + * Now we have the idea about how to remove bank conflict: We can do an + * inner split of our row dimension by `repeated_pattern_size` to get + * (repeat, pattern), then different indices of the "repeat" dimension will + * be using the same megabank, and different indices of the "pattern" + * dimension will be using different megabank. We don't need to touch the + * "pattern" dimension, but we need to play with the "repeat" dimension to + * interleave it with matrice ids so that each matrix is distributed across + * different banks. + * + * For example, if we have repeated_pattern_size = 4, we would want to do + * something like below: + * +----------+----------+ + * 0| | | + * 1| matrix 0 | matrix 1 | + * 2| | | + * 3| | | + * +----------+----------+ + * 4| | | + * 5| matrix 1 | matrix 0 | + * 6| | | + * 7| | | + * +----------+----------+ + */ + + // -2 -1 + // [row, col] + TORCH_INTERNAL_ASSERT( + tile_size_x % ldmatrix_rows == 0, "Partial matrices not supported"); + shared_mem_tv->split(-2, ldmatrix_rows); + TORCH_INTERNAL_ASSERT( + tile_size_y % ldmatrix_cols == 0, "Partial matrices not supported"); + shared_mem_tv->split(-1, ldmatrix_cols); + // -4 -3 -2 -1 + // [matrix id, matrix, matrix id, matrix] + TORCH_INTERNAL_ASSERT( + ldmatrix_rows % repeated_pattern_size == 0, + "ldmatrix_rows is assumed to be a multiple of repeated_pattern_size"); + shared_mem_tv->split(-3, repeated_pattern_size); + // -5 -4 -3 -2 -1 + // [matrix id, repeat, pattern, matrix id, matrix] + int swizzle_period = ldmatrix_rows / repeated_pattern_size; + TORCH_INTERNAL_ASSERT( + tile_size_y % (swizzle_period * ldmatrix_cols) == 0, + "need aperiodic swizzle config for tile size ", + tile_size_x, + "x", + tile_size_y, + "with units ", + ldmatrix_rows, + "x", + ldmatrix_cols); + shared_mem_tv->split(-2, swizzle_period); + // -6 -5 -4 -3 -2 -1 + // [matrix id, repeat, pattern, matrix id outer, pattern id, matrix] + // swizzle repeat with pattern id to make repeat no longer repeat + if (isPowOf2(swizzle_period)) { + shared_mem_tv->swizzle(Swizzle2DType::XOR, -5, -2); + } else { + shared_mem_tv->swizzle(Swizzle2DType::CyclicShift, -5, -2); + } + + // Merge back the tile for subsequent vectorization scheduling + // TODO: could potentially simplify away the merges + shared_mem_tv->merge(-6); + shared_mem_tv->merge(-5); + shared_mem_tv->merge(-3); + shared_mem_tv->merge(-2); + } else if (isVolta(params.mma_op)) { + // TODO: Volta is slightly more complex, and a fixed recipe would + // not scale. In a follow up this would be inferred from the mma + // macro layout themselves as we already have them registered in + // the utils. + return; + } else { + TORCH_INTERNAL_ASSERT(false, "Prolog swizzle: unsupported mma macro"); + } +} + +//! Generates the prolog schedule on the shared memory buffer +//! tensor. The scheduling performs two steps: +//! +//! 1. Swizzled the shared mem data layout. +//! 2. Coalesce and vectorize the read write schedule. +void scheduleProlog(TensorView* shared_mem_tv, const MatmulParams& params) { + shared_mem_tv->setMemoryType(MemoryType::Shared); + + mma_utils::orderTiledConcreteIdAsRoot(shared_mem_tv); + + // Swizzle the shared memory data layout + prologSwizzle(shared_mem_tv, params); + + // Assuming we are always vectorizing smem write by 128b at the moment: + // TODO: would need a data-type and alignment dependent interface + // to support non-vectorizable shapes. + // The vectorizable width logic would be in a separate PR as the + // current effort tries to focus on generating swizzles. + shared_mem_tv->merge(-2); + mma_utils::scheduleContiguousVectorLoad( + shared_mem_tv, params.tile_sizes, 8, true); + + // Propagate prolog tensors + // propagate up the DAG, and propagate parallel type. + scheduler_utils::BoundedDirectionalTransformPropagator::backward( + shared_mem_tv, + -1, + {}, + scheduler_utils::BoundedDirectionalTransformPropagator::Options() + .propagateParallelType()); +} + } // namespace void scheduleMatmul(Fusion* fusion, const MatmulParams& params) { @@ -235,33 +621,10 @@ void scheduleMatmul(Fusion* fusion, const MatmulParams& params) { cc, -1, {acw_smem, bcw_smem}, {c}); // Schedule prolog: - // TODO: this section goes to a separate matmul util, - // and needs more configurability. + // TODO: this section needs more configurability. // ------------------------------------------------------------------ - mma_utils::orderTiledConcreteIdAsRoot(acw_smem); - // [... M, K] - acw_smem->merge(-2); - mma_utils::scheduleContiguousVectorLoad(acw_smem, gemm_tile, 8, false); - - // [... N, K] - mma_utils::orderTiledConcreteIdAsRoot(bcw_smem); - bcw_smem->merge(-2); - mma_utils::scheduleContiguousVectorLoad(bcw_smem, gemm_tile, 8, false); - - // Propagate prolog tensors - // propagate up the DAG, and propagate parallel type. - scheduler_utils::BoundedDirectionalTransformPropagator::backward( - acw_smem, - -1, - {a}, - scheduler_utils::BoundedDirectionalTransformPropagator::Options() - .propagateParallelType()); - scheduler_utils::BoundedDirectionalTransformPropagator::backward( - bcw_smem, - -1, - {b}, - scheduler_utils::BoundedDirectionalTransformPropagator::Options() - .propagateParallelType()); + scheduleProlog(acw_smem, params); + scheduleProlog(bcw_smem, params); // Set computeAt, setup the loop nesting structure on the kernel. // TODO: this section goes to a separate matmul util, @@ -310,19 +673,12 @@ void scheduleMatmul(Fusion* fusion, const MatmulParams& params) { cc->applyMmaSwizzle( mma_builder.operand(MmaOptions::Operand::Accumulator).build()); - // Set memory type: - acw_smem->setMemoryType(MemoryType::Shared); - bcw_smem->setMemoryType(MemoryType::Shared); - // Set parallelization: // TODO: this section goes to a separate matmul util, // and needs more configurability. // ------------------------------------------------------------------ // Vectorize smem stores/loads: - acw_smem->axis(-1)->parallelize(ParallelType::Vectorize); - bcw_smem->axis(-1)->parallelize(ParallelType::Vectorize); - acr->axis(-1)->parallelize(ParallelType::Vectorize); bcr->axis(-1)->parallelize(ParallelType::Vectorize); diff --git a/csrc/scheduler/utils.cpp b/csrc/scheduler/utils.cpp index 30c5453810c..a43560ba135 100644 --- a/csrc/scheduler/utils.cpp +++ b/csrc/scheduler/utils.cpp @@ -1618,7 +1618,7 @@ bool isFakeBoundaryTensorview( //! transform to by BoundedDirectionalTransformPropagator. std::unordered_set getDirectionalPropagatePathSet( TensorView* from_tv, - std::vector boundary_tvs, + const std::vector& boundary_tvs, BoundedDirectionalTransformPropagator::Options options, PropagateDirection direction) { // Prepare to collect all candidate tensorviews @@ -1730,9 +1730,9 @@ void BoundedDirectionalTransformPropagator::backward( if (!options.has_value()) { options = Options(); } - TORCH_INTERNAL_ASSERT( - !to.empty(), - "Propagation needs to be bounded, so no support for empty boundary."); + if (to.empty()) { + to = ir_utils::inputTvsOf(from); + } // Collect all tvs to included on the backward path as specified // by boundary and options. diff --git a/csrc/type.cpp b/csrc/type.cpp index 9637fad1f97..cb6e6dd7810 100644 --- a/csrc/type.cpp +++ b/csrc/type.cpp @@ -1108,34 +1108,7 @@ size_t dataTypeSize(DataType type) { [](auto&& dtype) -> size_t { using T = std::decay_t; if constexpr (std::is_same_v) { - switch (dtype) { - case DataType::Bool: - return sizeof(bool); - case DataType::ComplexDouble: - return sizeof(std::complex); - case DataType::ComplexFloat: - return sizeof(std::complex); - case DataType::Double: - return sizeof(double); - case DataType::Float: - return sizeof(float); - case DataType::Half: - return sizeof(at::Half); - case DataType::BFloat16: - return sizeof(at::BFloat16); - case DataType::Index: - TORCH_INTERNAL_ASSERT( - false, - "The actual type of Index is only known at compile time."); - case DataType::Int: - return sizeof(uint64_t); - case DataType::Int32: - return sizeof(uint32_t); - case DataType::SMemAddress: - return sizeof(unsigned); - default: - TORCH_INTERNAL_ASSERT(false, "Size undefined for data type."); - } + return primDataTypeSize(dtype); } else if constexpr (std::is_same_v) { return sizeof(void*); } else if constexpr (std::is_same_v) { diff --git a/csrc/type.h b/csrc/type.h index b0104bf0b7a..0d4facbcc4c 100644 --- a/csrc/type.h +++ b/csrc/type.h @@ -704,6 +704,36 @@ TORCH_CUDA_CU_API const char* load_store_type2string(LoadStoreOpType t); TORCH_CUDA_CU_API c10::optional cast_func_str( const std::pair&); +constexpr inline size_t primDataTypeSize(PrimDataType type) { + switch (type) { + case DataType::Bool: + return sizeof(bool); + case DataType::ComplexDouble: + return sizeof(std::complex); + case DataType::ComplexFloat: + return sizeof(std::complex); + case DataType::Double: + return sizeof(double); + case DataType::Float: + return sizeof(float); + case DataType::Half: + return sizeof(at::Half); + case DataType::BFloat16: + return sizeof(at::BFloat16); + case DataType::Index: + TORCH_INTERNAL_ASSERT( + false, "The actual type of Index is only known at compile time."); + case DataType::Int: + return sizeof(uint64_t); + case DataType::Int32: + return sizeof(uint32_t); + case DataType::SMemAddress: + return sizeof(unsigned); + default: + TORCH_INTERNAL_ASSERT(false, "Size undefined for data type."); + } +} + TORCH_CUDA_CU_API size_t dataTypeSize(DataType type); // If the index type is known it will be automatically used here diff --git a/test/test_gpu_matmul_sass.cpp b/test/test_gpu_matmul_sass.cpp index ff5370a3831..3eaed52f61e 100644 --- a/test/test_gpu_matmul_sass.cpp +++ b/test/test_gpu_matmul_sass.cpp @@ -69,6 +69,8 @@ sass::Container getSASSFor( params.double_buffer_options.smem_double_buffer_stage = 4; scheduleMatmul(&fusion, params); + fusion.printTransforms(); + at::manual_seed(0); auto inputs = fp16MatmulAtInput(M, N, K, layout); @@ -245,6 +247,20 @@ TEST_F(NVFuserTest, FusionAmpereMatmulSASSModifiersCheck_CUDA) { } } +#if 0 + +TODO: With swizzle, the cuda code looks like: + +#pragma unroll +for(nvfuser_index_t i507 = 0; i507 < 8; ++i507) { + int i18439; + i18439 = i18438 + i507; + Turing::ldMatrixT (*reinterpret_cast*>(&T9[(4 * i507)]),((i18437 + (128 * (i18439 / 8))) + (16 * (i6455 ^ (i18439 % 8))))); +} + +where i6455 = (((nvfuser_index_t)threadIdx.x) % 16) % 8 so it no longer make sense to require the memory access pattern below. +We need to reinvestigate the test below to determine whether to change it or delete it. + // Check that all LDSM instructions has the following pattern: // LDSM.16.M88.2 R2, [R213] ; // LDSM.16.M88.2 R136, [R213+0x200] ; @@ -317,5 +333,6 @@ TEST_F(NVFuserTest, FusionAmpereMatmulSASSRegisterUsageLDSM_CUDA) { } } } +#endif } // namespace nvfuser diff --git a/test/test_gpu_tensorcore.cpp b/test/test_gpu_tensorcore.cpp index 9807980c312..68fbc0675b6 100644 --- a/test/test_gpu_tensorcore.cpp +++ b/test/test_gpu_tensorcore.cpp @@ -318,6 +318,9 @@ TEST_F(NVFuserTest, FusionVoltaMatmul_CUDA) { params.tile_sizes = gemm_tile; scheduleMatmul(&fusion, params); + // prologSwizzle on Volta is not supported yet + // ASSERT_TRUE(fusion.bankConflictInfo().empty()); + at::manual_seed(0); auto inputs = fp16MatmulAtInput(M, N, K, layout); @@ -367,6 +370,9 @@ TEST_F(NVFuserTest, FusionVoltaMatmulRegDoubleBuffer_CUDA) { params.double_buffer_options.double_buffer_smem_read = true; scheduleMatmul(&fusion, params); + // prologSwizzle on Volta is not supported yet + // ASSERT_TRUE(fusion.bankConflictInfo().empty()); + at::manual_seed(0); auto inputs = fp16MatmulAtInput(M, N, K, layout); @@ -651,6 +657,8 @@ TEST_F(NVFuserTest, FusionAmpereMatmul_CUDA) { params.double_buffer_options.smem_double_buffer_stage = 4; scheduleMatmul(&fusion, params); + ASSERT_TRUE(fusion.bankConflictInfo().empty()); + at::manual_seed(0); auto inputs = fp16MatmulAtInput(M, N, K, layout); @@ -706,6 +714,8 @@ TEST_F(NVFuserTest, FusionAmpereMatmulPipelineGmem_CUDA) { params.double_buffer_options.smem_double_buffer_stage = stage; scheduleMatmul(&fusion, params); + ASSERT_TRUE(fusion.bankConflictInfo().empty()); + at::manual_seed(0); auto inputs = fp16MatmulAtInput(M, N, K, layout); @@ -772,6 +782,8 @@ TEST_F(NVFuserTest, FusionAmpereSwizzle_CUDA) { scheduleMatmul(&fusion, params); + ASSERT_TRUE(fusion.bankConflictInfo().empty()); + at::manual_seed(0); auto inputs = fp16MatmulAtInput(M, N, K, layout); @@ -879,6 +891,8 @@ TEST_F(NVFuserTest, FusionAmpereMatmulRegDoubleBuffer_CUDA) { params.double_buffer_options.double_buffer_smem_read = true; scheduleMatmul(&fusion, params); + ASSERT_TRUE(fusion.bankConflictInfo().empty()); + at::manual_seed(0); auto inputs = fp16MatmulAtInput(M, N, K, layout); @@ -1802,6 +1816,8 @@ TEST_F(NVFuserTest, FusionTuringMatmul_CUDA) { params.tile_sizes = gemm_tile; scheduleMatmul(&fusion, params); + ASSERT_TRUE(fusion.bankConflictInfo().empty()); + at::manual_seed(0); auto inputs = fp16MatmulAtInput(M, N, K, layout); @@ -2818,6 +2834,8 @@ TEST_F(NVFuserTest, FusionAmpereMatmulLargeLoad_CUDA) { params.double_buffer_options.smem_double_buffer_stage = 3; scheduleMatmul(&fusion, params); + ASSERT_TRUE(fusion.bankConflictInfo().empty()); + at::manual_seed(0); auto inputs = fp16MatmulAtInput(M, N, K, layout); @@ -2866,6 +2884,8 @@ TEST_F(NVFuserTest, FusionTuringMatmulLargeLoad_CUDA) { params.tile_sizes = gemm_tile; scheduleMatmul(&fusion, params); + ASSERT_TRUE(fusion.bankConflictInfo().empty()); + at::manual_seed(0); auto inputs = fp16MatmulAtInput(M, N, K, layout); @@ -2921,6 +2941,8 @@ TEST_F(NVFuserTest, FusionAmpereMatmulTileCheck4warp_CUDA) { params.double_buffer_options.double_buffer_smem_write = true; scheduleMatmul(&fusion, params); + ASSERT_TRUE(fusion.bankConflictInfo().empty()); + at::manual_seed(0); auto inputs = fp16MatmulAtInput(M, N, K, layout); @@ -2986,6 +3008,8 @@ TEST_F(NVFuserTest, FusionAmpereMatmulTileCheck8warp_CUDA) { scheduleMatmul(&fusion, params); + ASSERT_TRUE(fusion.bankConflictInfo().empty()); + at::manual_seed(0); auto inputs = fp16MatmulAtInput(M, N, K, layout); @@ -3045,6 +3069,8 @@ TEST_F(NVFuserTest, FusionAmpereMatmulTileCheck6warp_CUDA) { scheduleMatmul(&fusion, params); + ASSERT_TRUE(fusion.bankConflictInfo().empty()); + at::manual_seed(0); auto inputs = fp16MatmulAtInput(M, N, K, layout); @@ -3097,6 +3123,8 @@ TEST_F(NVFuserTest, FusionAmpereMatmulLargeLoadLargeK_CUDA) { params.double_buffer_options.smem_double_buffer_stage = 3; scheduleMatmul(&fusion, params); + ASSERT_TRUE(fusion.bankConflictInfo().empty()); + at::manual_seed(0); auto inputs = fp16MatmulAtInput(M, N, K, layout);