Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions include/tvm/tir/builtin.h
Original file line number Diff line number Diff line change
Expand Up @@ -797,6 +797,24 @@ TVM_DLL const Op& start_profile_intrinsic();
*/
TVM_DLL const Op& end_profile_intrinsic();

// Intrinsics for the VK_NV_cooperative_matrix Vulkan extension.

/*!
 * \brief The intrinsic corresponding to the OpCooperativeMatrixLoadNV instruction.
 *
 * The Python wrapper of the same name (python/tvm/tir/op.py) calls this
 * intrinsic with the arguments, in order:
 * (buffer_mat, offset, src, rows, cols, stride, column_major).
 */
TVM_DLL const Op& cooperative_matrix_load_NV();

/*!
 * \brief The intrinsic corresponding to the OpCooperativeMatrixStoreNV instruction.
 *
 * The Python wrapper of the same name (python/tvm/tir/op.py) calls this
 * intrinsic with the arguments, in order:
 * (dst, buffer_mat, offset, stride, column_major).
 */
TVM_DLL const Op& cooperative_matrix_store_NV();

/*!
 * \brief Create a new cooperative matrix filled with the provided value.
 *
 * There is no such instruction in the extension, but it is added for convenience.
 *
 * The Python wrapper of the same name (python/tvm/tir/op.py) calls this
 * intrinsic with the arguments, in order:
 * (buffer_mat, offset, rows, cols, value).
 */
TVM_DLL const Op& cooperative_matrix_fill_NV();

/*!
 * \brief The intrinsic corresponding to the OpCooperativeMatrixMulAddNV instruction.
 *
 * The Python wrapper of the same name (python/tvm/tir/op.py) calls this
 * intrinsic with the arguments, in order:
 * (A, A_off, B, B_off, C, C_off), i.e. each matrix buffer followed by its
 * element offset.
 */
TVM_DLL const Op& cooperative_matrix_mad_NV();

/*! \brief The kind of structure field info used in intrinsic */
enum TVMStructFieldKind : int {
// array head address
Expand Down
8 changes: 8 additions & 0 deletions python/tvm/script/ir_builder/tir/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -1820,6 +1820,10 @@ def wrapped(*args, **kwargs):
TVMBackendFreeWorkspace = _op_wrapper(_tir_op.TVMBackendFreeWorkspace)
start_profile_intrinsic = _op_wrapper(_tir_op.start_profile_intrinsic)
end_profile_intrinsic = _op_wrapper(_tir_op.end_profile_intrinsic)
# Intrinsics for the VK_NV_cooperative_matrix Vulkan extension.
cooperative_matrix_load_NV = _op_wrapper(_tir_op.cooperative_matrix_load_NV)
cooperative_matrix_store_NV = _op_wrapper(_tir_op.cooperative_matrix_store_NV)
cooperative_matrix_fill_NV = _op_wrapper(_tir_op.cooperative_matrix_fill_NV)
cooperative_matrix_mad_NV = _op_wrapper(_tir_op.cooperative_matrix_mad_NV)


def _dtype_forward(func):
Expand Down Expand Up @@ -2144,4 +2148,8 @@ def wrapped(*args, **kwargs):
"IterVar",
"CommReducer",
"Range",
"cooperative_matrix_load_NV",
"cooperative_matrix_store_NV",
"cooperative_matrix_fill_NV",
"cooperative_matrix_mad_NV",
]
142 changes: 142 additions & 0 deletions python/tvm/tir/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -3037,6 +3037,148 @@ def TVMBackendFreeWorkspace(device_type, device_id, ptr):
return call_intrin("int32", "tir.TVMBackendFreeWorkspace", device_type, device_id, ptr)


# Intrinsics for the VK_NV_cooperative_matrix Vulkan extension.


def cooperative_matrix_load_NV(buffer_mat, offset, src, rows, cols, stride, column_major):
    """Build a call to the intrinsic for the OpCooperativeMatrixLoadNV instruction.

    Parameters
    ----------
    buffer_mat : Var
        Destination buffer, which has the "cooperative_matrix_nv" scope.

    offset : IntImm
        Element offset, inside the destination buffer, of the matrix being loaded.

    src : Expr
        Pointer expression the matrix is loaded from.

    rows : int
        Row count of the matrix.

    cols : int
        Column count of the matrix.

    stride : Expr
        Matrix stride, expressed in number of elements.

    column_major : bool
        True when the matrix elements are laid out in column-major order.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    # Operands are forwarded to the intrinsic in exactly this order.
    operands = (buffer_mat, offset, src, rows, cols, stride, column_major)
    return call_intrin("handle", "tir.cooperative_matrix_load_NV", *operands)


def cooperative_matrix_store_NV(dst, buffer_mat, offset, stride, column_major):
    """The intrinsic corresponding to the OpCooperativeMatrixStoreNV instruction.

    Parameters
    ----------
    dst : Expr
        The destination pointer expression.

    buffer_mat : Var
        The source buffer with "cooperative_matrix_nv" scope.

    offset : IntImm
        The element offset, in the source buffer `buffer_mat`, of the matrix
        to be stored.

    stride : Expr
        The stride of the matrix in the number of elements.

    column_major : bool
        Whether the matrix elements are stored in the column-major order.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    # Operand order: destination pointer first, then the source matrix
    # buffer together with its element offset, then the layout parameters.
    operands = (dst, buffer_mat, offset, stride, column_major)
    return call_intrin("handle", "tir.cooperative_matrix_store_NV", *operands)


def cooperative_matrix_fill_NV(buffer_mat, offset, rows, cols, value):
    """Build a call that creates a cooperative matrix filled with one value.

    The extension itself has no such instruction; this intrinsic exists for
    convenience.

    Parameters
    ----------
    buffer_mat : Var
        Buffer with the "cooperative_matrix_nv" scope that gets filled.

    offset : IntImm
        Element offset, inside `buffer_mat`, of the matrix being filled.

    rows : int
        Row count of the matrix.

    cols : int
        Column count of the matrix.

    value : Expr
        Value written to the matrix.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    # Operands are forwarded to the intrinsic in exactly this order.
    operands = (buffer_mat, offset, rows, cols, value)
    return call_intrin("handle", "tir.cooperative_matrix_fill_NV", *operands)


def cooperative_matrix_mad_NV(A, A_off, B, B_off, C, C_off):
    """Build a call to the intrinsic for the OpCooperativeMatrixMulAddNV instruction.

    Parameters
    ----------
    A : Var
        Buffer with the "cooperative_matrix_nv" scope holding the "A" matrix.

    A_off : IntImm
        Element offset of the "A" matrix inside buffer `A`.

    B : Var
        Buffer with the "cooperative_matrix_nv" scope holding the "B" matrix.

    B_off : IntImm
        Element offset of the "B" matrix inside buffer `B`.

    C : Var
        Buffer with the "cooperative_matrix_nv" scope holding the "C" matrix.

    C_off : IntImm
        Element offset of the "C" matrix inside buffer `C`.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    # Each matrix buffer is immediately followed by its element offset.
    operands = (A, A_off, B, B_off, C, C_off)
    return call_intrin("handle", "tir.cooperative_matrix_mad_NV", *operands)


# pylint: disable=unnecessary-lambda
sum = comm_reducer(lambda x, y: x + y, lambda t: const(0, dtype=t), name="sum")
min = comm_reducer(lambda x, y: _ffi_api._OpMin(x, y, None), max_value, name="min") # type: ignore
Expand Down
8 changes: 8 additions & 0 deletions src/runtime/thread_storage_scope.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,11 @@ enum class StorageRank {
kTexture = 7,
/*! \brief global scope amx tmm memory */
kAMXTMM = 8,
/*!
* \brief Scope representing the cooperative matrix in the Vulkan
* VK_NV_cooperative_matrix extension.
*/
kCooperativeMatrixNV = 9,
};

/*!
Expand Down Expand Up @@ -154,6 +159,9 @@ struct StorageScope {
} else if (s.compare(0, 7, "amx.tmm") == 0) {
r.rank = StorageRank::kAMXTMM;
r.tag = s.substr(7, std::string::npos);
} else if (s == "cooperative_matrix_nv") {
r.rank = StorageRank::kCooperativeMatrixNV;
r.tag = "";
} else {
LOG(FATAL) << "unknown storage scope " << s;
}
Expand Down
4 changes: 3 additions & 1 deletion src/runtime/vulkan/vulkan_device.cc
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ VulkanDeviceProperties::VulkanDeviceProperties(const VulkanInstance& instance,
!support::BoolEnvironmentVar("TVM_VULKAN_DISABLE_DEDICATED_ALLOCATION");

supports_integer_dot_product = device.HasExtension("VK_KHR_shader_integer_dot_product");
supports_cooperative_matrix_nv = device.HasExtension("VK_NV_cooperative_matrix");

// The check of VK_SHADER_STAGE_COMPUTE_BIT isn't technically
// needed, since it will be set so long as at least one queue has
Expand Down Expand Up @@ -435,7 +436,8 @@ std::vector<const char*> VulkanDevice::SelectEnabledExtensions() const {
"VK_KHR_get_memory_requirements2",
"VK_KHR_dedicated_allocation",
"VK_KHR_spirv_1_4",
"VK_KHR_shader_integer_dot_product"};
"VK_KHR_shader_integer_dot_product",
"VK_NV_cooperative_matrix"};

uint32_t device_extension_prop_count;
VULKAN_CALL(vkEnumerateDeviceExtensionProperties(physical_device_, nullptr,
Expand Down
1 change: 1 addition & 0 deletions src/runtime/vulkan/vulkan_device.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ struct VulkanDeviceProperties {
bool supports_push_descriptor{false};
bool supports_dedicated_allocation{false};
bool supports_integer_dot_product{false};
bool supports_cooperative_matrix_nv{false};
uint32_t supported_subgroup_operations{0};
uint32_t max_num_threads{1};
uint32_t thread_warp_size{1};
Expand Down
4 changes: 4 additions & 0 deletions src/runtime/vulkan/vulkan_device_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,10 @@ void VulkanDeviceAPI::GetTargetProperty(Device dev, const std::string& property,
*rv = prop.supports_integer_dot_product;
}

if (property == "supports_cooperative_matrix_nv") {
*rv = prop.supports_cooperative_matrix_nv;
}

if (property == "device_name") {
*rv = prop.device_name;
}
Expand Down
Loading