Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions include/tvm/auto_scheduler/search_task.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@ class HardwareParamsNode : public Object {
// GPU related parameters got from device query API
/*! \brief The max shared memory per block in bytes. */
int max_shared_memory_per_block;
/*! \brief The max number of register per block. */
int max_registers_per_block;
/*! \brief The max local memory per block in bytes. */
int max_local_memory_per_block;
/*! \brief The max number of threads per block. */
int max_threads_per_block;
/*! \brief The max vthread extent. */
Expand All @@ -60,7 +60,7 @@ class HardwareParamsNode : public Object {
v->Visit("vector_unit_bytes", &vector_unit_bytes);
v->Visit("cache_line_bytes", &cache_line_bytes);
v->Visit("max_shared_memory_per_block", &max_shared_memory_per_block);
v->Visit("max_registers_per_block", &max_registers_per_block);
v->Visit("max_local_memory_per_block", &max_local_memory_per_block);
v->Visit("max_threads_per_block", &max_threads_per_block);
v->Visit("max_vthread_extent", &max_vthread_extent);
v->Visit("warp_size", &warp_size);
Expand Down Expand Up @@ -90,13 +90,13 @@ class HardwareParams : public ObjectRef {
* \param vector_unit_bytes The width of vector units in bytes.
* \param cache_line_bytes The size of cache line in bytes.
* \param max_shared_memory_per_block The max amount of shared memory per block for GPU.
* \param max_registers_per_block The max number of registers per block for GPU.
* \param max_local_memory_per_block The max amount of local memory per block for GPU.
* \param max_threads_per_block The max number of threads per block for GPU.
* \param max_vthread_extent The max extent of vthread for GPU.
* \param warp_size The warp size for GPU
*/
HardwareParams(int num_cores, int vector_unit_bytes, int cache_line_bytes,
int max_shared_memory_per_block, int max_registers_per_block,
int max_shared_memory_per_block, int max_local_memory_per_block,
int max_threads_per_block, int max_vthread_extent, int warp_size);

TVM_DEFINE_OBJECT_REF_METHODS(HardwareParams, ObjectRef, HardwareParamsNode);
Expand Down
8 changes: 4 additions & 4 deletions python/tvm/auto_scheduler/search_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,8 @@ class HardwareParams(Object):
The size of cache line in bytes.
max_shared_memory_per_block : int
The max shared memory per block in bytes.
max_registers_per_block : int
The max number of register per block.
max_local_memory_per_block : int
The max local memory per block in bytes.
max_threads_per_block : int
The max number of threads per block.
max_vthread_extent : int
Expand All @@ -65,7 +65,7 @@ def __init__(
vector_unit_bytes,
cache_line_bytes,
max_shared_memory_per_block,
max_registers_per_block,
max_local_memory_per_block,
max_threads_per_block,
max_vthread_extent,
warp_size,
Expand All @@ -76,7 +76,7 @@ def __init__(
vector_unit_bytes,
cache_line_bytes,
max_shared_memory_per_block,
max_registers_per_block,
max_local_memory_per_block,
max_threads_per_block,
max_vthread_extent,
warp_size,
Expand Down
2 changes: 1 addition & 1 deletion src/auto_scheduler/feature.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1310,7 +1310,7 @@ void GetPerStoreFeaturesWorkerFunc(const SearchTask& task, const State& state, i
pass_list.push_back(tir::transform::Simplify());
tvm::Map<String, tvm::PrimExpr> gpu_params{
{"max_shared_memory_per_block", task->hardware_params->max_shared_memory_per_block},
{"max_local_memory_per_block", task->hardware_params->max_registers_per_block},
{"max_local_memory_per_block", task->hardware_params->max_local_memory_per_block},
{"max_threads_per_block", task->hardware_params->max_threads_per_block},
{"max_vector_bytes", task->hardware_params->vector_unit_bytes},
{"max_vthread", task->hardware_params->max_vthread_extent},
Expand Down
4 changes: 2 additions & 2 deletions src/auto_scheduler/measure_record.cc
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ struct Handler<::tvm::auto_scheduler::HardwareParamsNode> {
writer->WriteArrayItem(data.vector_unit_bytes);
writer->WriteArrayItem(data.cache_line_bytes);
writer->WriteArrayItem(data.max_shared_memory_per_block);
writer->WriteArrayItem(data.max_registers_per_block);
writer->WriteArrayItem(data.max_local_memory_per_block);
writer->WriteArrayItem(data.max_threads_per_block);
writer->WriteArrayItem(data.max_vthread_extent);
writer->WriteArrayItem(data.warp_size);
Expand All @@ -140,7 +140,7 @@ struct Handler<::tvm::auto_scheduler::HardwareParamsNode> {
reader->Read(&data->max_shared_memory_per_block);
s = reader->NextArrayItem();
CHECK(s);
reader->Read(&data->max_registers_per_block);
reader->Read(&data->max_local_memory_per_block);
s = reader->NextArrayItem();
CHECK(s);
reader->Read(&data->max_threads_per_block);
Expand Down
19 changes: 10 additions & 9 deletions src/auto_scheduler/search_task.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,14 @@ TVM_REGISTER_NODE_TYPE(HardwareParamsNode);
TVM_REGISTER_NODE_TYPE(SearchTaskNode);

HardwareParams::HardwareParams(int num_cores, int vector_unit_bytes, int cache_line_bytes,
int max_shared_memory_per_block, int max_registers_per_block,
int max_shared_memory_per_block, int max_local_memory_per_block,
int max_threads_per_block, int max_vthread_extent, int warp_size) {
auto node = make_object<HardwareParamsNode>();
node->num_cores = num_cores;
node->vector_unit_bytes = vector_unit_bytes;
node->cache_line_bytes = cache_line_bytes;
node->max_shared_memory_per_block = max_shared_memory_per_block;
node->max_registers_per_block = max_registers_per_block;
node->max_local_memory_per_block = max_local_memory_per_block;
node->max_threads_per_block = max_threads_per_block;
node->max_vthread_extent = max_vthread_extent;
node->warp_size = warp_size;
Expand All @@ -64,8 +64,9 @@ HardwareParams HardwareParamsNode::GetDefaultHardwareParams(const Target& target
device_api->GetAttr(ctx, tvm::runtime::DeviceAttrKind::kMaxSharedMemoryPerBlock, &ret);
int max_shared_memory_per_block = ret;

device_api->GetAttr(ctx, tvm::runtime::DeviceAttrKind::kMaxRegistersPerBlock, &ret);
int max_registers_per_block = ret;
// There is no explicit local memory limition in CUDA runtime,
// so we can use INT32_MAX to disalbe the check on local_memory.
int max_local_memory_per_block = INT32_MAX;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add a comment as the PR description.


device_api->GetAttr(ctx, tvm::runtime::DeviceAttrKind::kMaxThreadsPerBlock, &ret);
int max_threads_per_block = ret;
Expand All @@ -74,17 +75,17 @@ HardwareParams HardwareParamsNode::GetDefaultHardwareParams(const Target& target
int warp_size = ret;

int max_vthread_extent = warp_size / 4;
return HardwareParams(-1, 16, 64, max_shared_memory_per_block, max_registers_per_block,
return HardwareParams(-1, 16, 64, max_shared_memory_per_block, max_local_memory_per_block,
max_threads_per_block, max_vthread_extent, warp_size);
} else if (target->kind->device_type == kDLMetal) {
// Reference: https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
// This setting looks working for Metal GPUs later than A10
int max_shared_memory_per_block = 32 * 1024;
int max_registers_per_block = 4 * 1024;
int max_local_memory_per_block = INT32_MAX; // skip the check on local memory
int max_threads_per_block = 1024;
int warp_size = 8;
int max_vthread_extent = warp_size / 4;
return HardwareParams(-1, 16, 64, max_shared_memory_per_block, max_registers_per_block,
return HardwareParams(-1, 16, 64, max_shared_memory_per_block, max_local_memory_per_block,
max_threads_per_block, max_vthread_extent, warp_size);
} else {
LOG(FATAL) << "No default hardware parameters for target: " << target;
Expand All @@ -110,10 +111,10 @@ SearchTask::SearchTask(ComputeDAG compute_dag, String workload_key, Target targe

TVM_REGISTER_GLOBAL("auto_scheduler.HardwareParams")
.set_body_typed([](int num_cores, int vector_unit_bytes, int cache_line_bytes,
int max_shared_memory_per_block, int max_registers_per_block,
int max_shared_memory_per_block, int max_local_memory_per_block,
int max_threads_per_block, int max_vthread_extent, int warp_size) {
return HardwareParams(num_cores, vector_unit_bytes, cache_line_bytes,
max_shared_memory_per_block, max_registers_per_block,
max_shared_memory_per_block, max_local_memory_per_block,
max_threads_per_block, max_vthread_extent, warp_size);
});

Expand Down