Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 4 additions & 6 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,14 @@ endif ()
set(NVFUSER_ROOT ${PROJECT_SOURCE_DIR})
set(NVFUSER_SRCS_DIR "${NVFUSER_ROOT}/csrc")
if (PROJECT_IS_TOP_LEVEL)
message(STATUS "top-level build")
find_package(Torch REQUIRED)
find_package(Python REQUIRED Development Interpreter)
find_package(pybind11 REQUIRED)
# need this since the pytorch execution uses a different name
set(PYTHON_EXECUTABLE ${Python_EXECUTABLE})
set(ATEN_CUDA_ROOT "${TORCH_INSTALL_PREFIX}/include/ATen")
# CXX flags is necessary since https://github.com/pytorch/pytorch/issues/98093
string(APPEND CMAKE_CXX_FLAGS ${TORCH_CXX_FLAGS})
if(BUILD_NVFUSER_BENCHMARK)
add_subdirectory(${PROJECT_SOURCE_DIR}/third_party/googletest)
add_subdirectory(${PROJECT_SOURCE_DIR}/third_party/benchmark)
Expand Down Expand Up @@ -151,10 +152,11 @@ list(APPEND NVFUSER_SRCS
${NVFUSER_SRCS_DIR}/scheduler/pointwise.cpp
${NVFUSER_SRCS_DIR}/scheduler/pointwise_utils.cpp
${NVFUSER_SRCS_DIR}/scheduler/transpose.cpp
${NVFUSER_SRCS_DIR}/scheduler/matmul.cpp
${NVFUSER_SRCS_DIR}/scheduler/matmul_utils.cpp
${NVFUSER_SRCS_DIR}/scheduler/normalization.cpp
${NVFUSER_SRCS_DIR}/scheduler/normalization_utils.cpp
${NVFUSER_SRCS_DIR}/scheduler/reduction.cpp
${NVFUSER_SRCS_DIR}/scheduler/matmul.cpp
${NVFUSER_SRCS_DIR}/scheduler/reduction_utils.cpp
${NVFUSER_SRCS_DIR}/scheduler/registry.cpp
${NVFUSER_SRCS_DIR}/scheduler/utils.cpp
Expand Down Expand Up @@ -198,8 +200,6 @@ target_include_directories(${NVFUSER_CODEGEN} PUBLIC
)
set_property(TARGET ${NVFUSER_CODEGEN} PROPERTY CXX_STANDARD 17)

message(STATUS "install_lib_dir: ${TORCH_INSTALL_LIB_DIR}")
message(STATUS "cmake_install_includedir: ${CMAKE_INSTALL_INCLUDEDIR}")
if (PROJECT_IS_TOP_LEVEL)
target_link_libraries(${NVFUSER_CODEGEN} PRIVATE torch ${TORCH_LIBRARIES})
# TODO: setup header and lib installation
Expand Down Expand Up @@ -272,7 +272,6 @@ if(BUILD_PYTHON)
target_compile_definitions(${NVFUSER} PRIVATE EXTENSION_NAME=_C)

if (PROJECT_IS_TOP_LEVEL)
message(STATUS "skipping the rest of the libs")
target_compile_options(${NVFUSER} PRIVATE -Wall -Wno-unused-function)
target_link_libraries(${NVFUSER} PRIVATE ${TORCH_LIBRARIES})
target_link_libraries(${NVFUSER} PRIVATE "${TORCH_INSTALL_PREFIX}/lib/libtorch_python.so")
Expand Down Expand Up @@ -359,7 +358,6 @@ if(BUILD_TEST)
target_include_directories(${NVFUSER_TESTS} PRIVATE "${NVFUSER_ROOT}")

if (PROJECT_IS_TOP_LEVEL)
message(STATUS "skipping the rest of the libs")
target_compile_options(${NVFUSER_TESTS} PRIVATE -Wall -Wno-unused-function)
target_link_libraries(${NVFUSER_TESTS} PRIVATE ${TORCH_LIBRARIES})
else() # PROJECT_IS_TOP_LEVEL
Expand Down
18 changes: 8 additions & 10 deletions benchmark/matmul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include <ops/all_ops.h>
#include <scheduler/all_schedulers.h>
#include <scheduler/matmul.h>
#include <scheduler/matmul_heuristic.h>

#include <benchmark/benchmark.h>

Expand Down Expand Up @@ -128,7 +129,7 @@ std::pair<at::Tensor, at::Tensor> fp16MatmulAtInput(

// TODO: separate compute and schedule definition once the can-schedule
// logic and pattern matching is ready.
void setupMatmul(Fusion* fusion, MatmulLayout layout, MatmulParam params) {
void setupMatmul(Fusion* fusion, MatmulLayout layout, MatmulParams params) {
// Only hgemm on the initial setup
auto a = makeContigTensor(2, DataType::Half);
auto b = makeContigTensor(2, DataType::Half);
Expand All @@ -139,7 +140,7 @@ void setupMatmul(Fusion* fusion, MatmulLayout layout, MatmulParam params) {
fusion->addInput(b);
fusion->addOutput(c);

scheduleMatmul(c, a, b, params);
scheduleMatmul(fusion, params);
}

void checkMatch(at::Tensor expect, at::Tensor result, int64_t k) {
Expand Down Expand Up @@ -197,7 +198,7 @@ void checkMatch(at::Tensor expect, at::Tensor result, int64_t k) {
static void SingleMatmulBase(
benchmark::State& benchmark_state,
MatmulLayout layout,
MatmulParam params) {
MatmulParams params) {
std::vector<int64_t> input_mnk{
benchmark_state.range(0),
benchmark_state.range(1),
Expand Down Expand Up @@ -288,7 +289,7 @@ size_t getSmemSize(GemmTile cta_tile, int stage_number) {
}

// TODO: this part eventually will be automated by heuristics
MatmulParam getMatmulParams(
MatmulParams getMatmulParams(
GemmTile cta_tile,
int stage_number,
MatmulLayout layout) {
Expand All @@ -298,12 +299,9 @@ MatmulParam getMatmulParams(
gemm_tile.warp_tile = GemmTile(64, 64, cta_tile.k);
gemm_tile.instruction_tile = GemmTile(16, 16, 16);

// Collect mma swizzle info
auto mma_builder =
MmaBuilder(MmaOptions::MacroType::Ampere_16_16_16, gemm_tile)
.layout(layout);

MatmulParam params(mma_builder);
MatmulParams params;
params.mma_op = MmaOptions::MacroType::Ampere_16_16_16;
params.layout = layout;
params.tile_sizes = gemm_tile;
params.async_gmem_load_operands = true;
params.double_buffer_options.double_buffer_smem_write = true;
Expand Down
11 changes: 11 additions & 0 deletions csrc/ir_nodes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1351,6 +1351,17 @@ MmaOp::MmaOp(
in_b->getValType().value() == ValType::TensorIndex,
in_b->getValType().value());

const auto isBroadcastIn = [](const Val* val) {
if (val->getValType().value() == ValType::TensorView) {
const auto* tv = val->as<TensorView>();
return tv->hasBroadcast();
}
return true;
};

TORCH_INTERNAL_ASSERT(isBroadcastIn(in_a));
TORCH_INTERNAL_ASSERT(isBroadcastIn(in_b));
Comment on lines +1354 to +1363
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note for the future:

I think we should not only check that the input has a broadcast, but also add many more consistency checks. For this PR, I think it is fine, but let's create a separate follow-up PR for it:

I think we should check:

  • Each input has two concrete IDs and one broadcast ID
  • The broadcast axes of the two inputs are different
  • The output has two concrete IDs and one reduction ID
  • The axis of the output's reduction ID must correspond to a concrete ID in both inputs

And with this check here, we can remove corresponding checks in the scheduler.


addOutput(out);
addInput(in_a);
addInput(in_b);
Expand Down
11 changes: 11 additions & 0 deletions csrc/ir_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -438,6 +438,17 @@ std::vector<SelectOp*> getSelectOps(Fusion* fusion) {
return select_ops;
}

// Collects every MmaOp expression contained in the fusion, preserving the
// order in which fusion->exprs() yields them.
std::vector<MmaOp*> getMmaOps(Fusion* fusion) {
  std::vector<MmaOp*> result;
  for (auto expr : fusion->exprs()) {
    if (!expr->isA<MmaOp>()) {
      continue;
    }
    result.push_back(expr->as<MmaOp>());
  }
  return result;
}

namespace {

class ValReplacementMutator : private OptOutMutator {
Expand Down
2 changes: 2 additions & 0 deletions csrc/ir_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,8 @@ TORCH_CUDA_CU_API std::vector<IndexSelectOp*> getIndexSelectOps(Fusion* fusion);

TORCH_CUDA_CU_API std::vector<TorchGatherOp*> getTorchGatherOps(Fusion* fusion);

TORCH_CUDA_CU_API std::vector<MmaOp*> getMmaOps(Fusion* fusion);

TORCH_CUDA_CU_API std::vector<SelectOp*> getSelectOps(Fusion* fusion);

// Returns the initialization value of tv or nullptr if not initialized.
Expand Down
16 changes: 8 additions & 8 deletions csrc/iter_visitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ class TORCH_CUDA_CU_API IterVisitor : public OptOutDispatch {
// registered outputs.
void traverse(Fusion* fusion);

// Same as traverse put it traverses every edge, meaning it will traverse
// Same as traverse but it traverses every edge, meaning it will traverse
// values more than once.
void traverseAllPaths(Fusion* fusion);

Expand All @@ -147,8 +147,8 @@ class TORCH_CUDA_CU_API IterVisitor : public OptOutDispatch {
};

/*
* Backward visitor IterVisitor calls handle in reverse order from outputs
* to inputs It would be really nice to unify this with IterVisitor, however,
* Backward visitor calls handle in reverse order from outputs to inputs.
* It would be really nice to unify this with IterVisitor, however,
* the challenge there is that we specify traversal from outputs towards inputs
* because it implicitly provides DCE. However, if users are not careful, they
* could miss necessary outputs to do a backward traversal.
Expand All @@ -163,10 +163,10 @@ class TORCH_CUDA_CU_API IterVisitor : public OptOutDispatch {
* outputs of some exprs, example being the `N` output of welford ops.
* `must_cover_all_expr_outputs` is added to disable the check, and in
* this case the visitor pass need be aware
* 1. Exprs with any output that has a use chain that ends with a final
* consumer in the `from` list `will be` visited.
* 2. Vals that doesn't have a use chain that ends with a final
* consumer in the `from` list `will not be` visited, even though its
* 1. Exprs in the `from` list with any output that has a use chain that
* ends with a final consumer `will be` visited.
* 2. Vals in the `from` list that doesn't have a use chain that ends with
* a final consumer `will not be` visited, even though its
* definition expr might be visited. An example is if the `N` output
* of an welford op is unused, but other outputs are, the welford op
* will be visited but the `N` output will not.
Expand Down Expand Up @@ -302,7 +302,7 @@ class StmtSort : public IterVisitor {
bool traverse_members = false,
bool traverse_attributes = false);

// Returns ordered Statements required to produce from, including from.
// Returns ordered Statements required to produce 'to', including 'to'.
static std::vector<Statement*> getStmts(
Fusion* fusion,
const std::vector<Val*>& to,
Expand Down
83 changes: 77 additions & 6 deletions csrc/mma_type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include <fusion.h>
#include <ir_all_nodes.h>
#include <mma_type.h>
#include <functional>

namespace nvfuser {

Expand Down Expand Up @@ -124,7 +125,8 @@ bool isTuring(MmaOptions::MacroType macro) {
}

bool isAmpere(MmaOptions::MacroType macro) {
return macro == MmaOptions::MacroType::Ampere_16_8_16 ||
return macro == MmaOptions::MacroType::Ampere_16_8_8 ||
macro == MmaOptions::MacroType::Ampere_16_8_16 ||
macro == MmaOptions::MacroType::Ampere_16_16_16;
}

Expand All @@ -134,11 +136,9 @@ int getOutputRegisterSize(MmaOptions::MacroType macro) {
case MmaOptions::MacroType::Ampere_16_16_16:
case MmaOptions::MacroType::Turing_16_16_16:
return 8;
break;
case MmaOptions::MacroType::Turing_16_8_16:
case MmaOptions::MacroType::Ampere_16_8_16:
return 4;
break;
default:
TORCH_INTERNAL_ASSERT(false, "unknown macro");
break;
Expand All @@ -150,13 +150,11 @@ int getInputARegisterSize(MmaOptions::MacroType macro) {
switch (macro) {
case MmaOptions::MacroType::Volta_16_16_4:
return 4;
break;
case MmaOptions::MacroType::Turing_16_8_16:
case MmaOptions::MacroType::Turing_16_16_16:
case MmaOptions::MacroType::Ampere_16_8_16:
case MmaOptions::MacroType::Ampere_16_16_16:
return 8;
break;
default:
TORCH_INTERNAL_ASSERT(false, "unknown macro");
break;
Expand All @@ -168,7 +166,6 @@ int getInputBRegisterSize(MmaOptions::MacroType macro) {
switch (macro) {
case MmaOptions::MacroType::Volta_16_16_4:
return 4;
break;
case MmaOptions::MacroType::Turing_16_8_16:
case MmaOptions::MacroType::Ampere_16_8_16:
return 4;
Expand Down Expand Up @@ -196,6 +193,25 @@ bool isOperandTransposed(MmaOptions options) {
return false;
}

// Translates an MMA macro enumerator into its (M, N, K) instruction shape.
GemmTile getMmaOpShape(MmaOptions::MacroType macro) {
  switch (macro) {
    case MmaOptions::MacroType::NoMMA:
      // No tensor-core instruction: degenerate unit shape.
      return GemmTile(1, 1, 1);
    case MmaOptions::MacroType::Volta_16_16_4:
      return GemmTile(16, 16, 4);
    case MmaOptions::MacroType::Ampere_16_8_8:
      return GemmTile(16, 8, 8);
    case MmaOptions::MacroType::Turing_16_8_16:
    case MmaOptions::MacroType::Ampere_16_8_16:
      return GemmTile(16, 8, 16);
    case MmaOptions::MacroType::Turing_16_16_16:
    case MmaOptions::MacroType::Ampere_16_16_16:
      return GemmTile(16, 16, 16);
  }

  TORCH_INTERNAL_ASSERT(false, "unknown MMA macro");
}

std::string toString(MmaOptions::MmaInputLayout input_layout) {
std::stringstream ss;
switch (input_layout) {
Expand Down Expand Up @@ -238,4 +254,59 @@ std::string toString(MmaOptions::MacroType mt) {
return ss.str();
}

// Renders a GemmTile as "[m, n, k]".
std::string toString(const GemmTile& tile) {
  return "[" + std::to_string(tile.m) + ", " + std::to_string(tile.n) +
      ", " + std::to_string(tile.k) + "]";
}

// Renders a MatMulTileOptions as a single human-readable line listing the
// instruction, warp, and CTA tiles.
std::string toString(const MatMulTileOptions& opts) {
  std::string result = "MatMulTileOptions: ";
  result += "instruction tile " + toString(opts.instruction_tile) + ", ";
  result += "warp tile " + toString(opts.warp_tile) + ", ";
  result += "CTA tile " + toString(opts.cta_tile);
  return result;
}

// Stringifies an MMA macro as its bare enumerator name (e.g.
// "Ampere_16_8_16"). The unused bool parameter only disambiguates this
// overload from the other toString(MmaOptions::MacroType) declared in
// mma_type.h, which formats differently.
std::string toString(MmaOptions::MacroType mt, bool) {
  switch (mt) {
    case MmaOptions::MacroType::Ampere_16_8_8:
      return "Ampere_16_8_8";
    case MmaOptions::MacroType::Ampere_16_8_16:
      return "Ampere_16_8_16";
    case MmaOptions::MacroType::Ampere_16_16_16:
      return "Ampere_16_16_16";
    case MmaOptions::MacroType::NoMMA:
      return "NoOp";
    case MmaOptions::MacroType::Turing_16_8_16:
      return "Turing_16_8_16";
    case MmaOptions::MacroType::Turing_16_16_16:
      return "Turing_16_16_16";
    case MmaOptions::MacroType::Volta_16_16_4:
      return "Volta_16_16_4";
  }
  // Every enumerator is handled above; reaching here means a new macro was
  // added without updating this switch.
  TORCH_INTERNAL_ASSERT(false, "Unsupported mma type");
  return "Unsupported";
}

// Hashes an MMA macro via its underlying integral value.
size_t hash(MmaOptions::MacroType macro) {
  const auto raw_value = static_cast<size_t>(macro);
  return std::hash<size_t>{}(raw_value);
}

// Hashes an MMA input layout via its underlying integral value.
size_t hash(MmaOptions::MmaInputLayout input_layout) {
  const auto raw_value = static_cast<size_t>(input_layout);
  return std::hash<size_t>{}(raw_value);
}

size_t hash(const GemmTile& tile) {
return std::hash<size_t>{}(
(static_cast<size_t>(tile.m) << 32) +
(static_cast<size_t>(tile.n) << 16) + (static_cast<size_t>(tile.k)));
}

// Hashes a MatMulTileOptions by XOR-combining its three tile hashes, each
// shifted to a distinct position so identical tiles in different roles do
// not cancel out.
size_t hash(const MatMulTileOptions& opts) {
  const size_t instruction_hash = hash(opts.instruction_tile);
  const size_t warp_hash = hash(opts.warp_tile);
  const size_t cta_hash = hash(opts.cta_tile);
  return instruction_hash ^ (warp_hash << 1) ^ (cta_hash << 2);
}

} // namespace nvfuser
20 changes: 15 additions & 5 deletions csrc/mma_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,15 @@ struct GemmTile {
int m, n, k;
GemmTile(int m_, int n_, int k_) : m(m_), n(n_), k(k_) {}

bool operator==(const GemmTile& other) {
bool operator==(const GemmTile& other) const {
return m == other.m && n == other.n && k == other.k;
}

GemmTile operator/(const GemmTile& other) {
GemmTile operator/(const GemmTile& other) const {
return GemmTile(m / other.m, n / other.n, k / other.k);
}

std::vector<int> toVector() {
std::vector<int> toVector() const {
return {m, n, k};
}
};
Expand Down Expand Up @@ -186,9 +186,19 @@ int getOutputRegisterSize(MmaOptions::MacroType macro);
int getInputARegisterSize(MmaOptions::MacroType macro);
int getInputBRegisterSize(MmaOptions::MacroType macro);

// Unpack MMA op shape
GemmTile getMmaOpShape(MmaOptions::MacroType macro);

// MMA stringify utils
std::string toString(MmaOptions::MacroType macro);
std::string toString(MmaOptions::MmaInputLayout input_layout);
std::string toString(MmaOptions::MacroType mt);

std::string toString(const GemmTile& tile);
std::string toString(const MatMulTileOptions& opts);
std::string toString(MmaOptions::MacroType macro, bool);

// MMA hash utils
size_t hash(MmaOptions::MacroType macro);
size_t hash(MmaOptions::MmaInputLayout input_layout);
size_t hash(const GemmTile& tile);
size_t hash(const MatMulTileOptions& opts);
} // namespace nvfuser
4 changes: 3 additions & 1 deletion csrc/scheduler/all_schedulers.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
*/
// clang-format on
#pragma once
#include <scheduler/matmul.h>
#include <scheduler/normalization.h>
#include <scheduler/pointwise.h>
#include <scheduler/reduction.h>
Expand All @@ -19,7 +20,8 @@ enum class TORCH_CUDA_CU_API ScheduleHeuristic {
PointWise,
Reduction,
Persistent,
Transpose
Transpose,
Matmul
};

} // namespace nvfuser
Loading