5 changes: 5 additions & 0 deletions include/tvm/relay/transform.h
@@ -580,6 +580,11 @@ TVM_DLL Pass AnnotateUsedMemory();
*/
TVM_DLL Pass CapturePostDfsIndexInSpans();

+/*!
+ * \brief Runs the device-dependent memory scope analysis pass, collects the mapping of desirable
+ * expr->memory_scope, and annotates expressions with a VirtualDevice carrying the required memory_scope.
+ */
+TVM_DLL Pass AnnotateMemoryScope(CompilationConfig config);
} // namespace transform

/*!
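For orientation, the snippet below is a minimal sketch of what such an annotation carries on the Python side. It assumes VirtualDevice is reachable as tvm.target.VirtualDevice and uses an Adreno-style OpenCL target string purely for illustration; it does not show how the pass itself is driven.

```python
import tvm
from tvm.target import VirtualDevice

# Assumed Adreno-style OpenCL target; any OpenCL target works for this sketch.
target = tvm.target.Target("opencl -device=adreno")
dev = tvm.device("opencl", 0)

# A VirtualDevice whose memory_scope pins data to the texture scope "global.texture";
# this is the kind of annotation AnnotateMemoryScope attaches to expressions.
texture_vdev = VirtualDevice(dev, target, memory_scope="global.texture")
print(texture_vdev.memory_scope)  # expected: global.texture
```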
19 changes: 14 additions & 5 deletions python/tvm/topi/adreno/conv2d_nchw.py
@@ -29,6 +29,7 @@
add_pad,
bind_data_copy,
get_default_conv2d_config,
+get_texture_storage,
)


@@ -214,8 +215,11 @@ def schedule_conv2d_NCHWc_KCRSk(cfg, s, output):
5d tensors
4. pad should be scheduled separately to create independent opencl kernel. If pad is
inlined into convolution, this gives 1.5x performance drop
-5. We are using cache_read to produce texture and guarantee the best performance
-on the next stage.
+5. We are using cache_read for intermediate tensors to produce textures and to guarantee
+the best performance on the next stage.
+The weights are managed through the static texture planning mechanism and are guaranteed
+to arrive in the texture memory scope.
+That is why cache_read is called only for the data tensor.
6. For 5d convolution we schedule the latest op with binding 5d axis and vectorize
for textures
For 4d tensor we are doing the same for the latest blocked stage, i.e. conversion
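Rule 5 above is the core texture trick. The following is a minimal, self-contained sketch of that pattern, with a toy elementwise compute standing in for the convolution; the "global.texture" scope string comes from this file, while the shape and operator are illustrative only.

```python
import tvm
from tvm import te

# Toy 5d input in the blocked layout used by these schedules (innermost axis of length 4).
data = te.placeholder((1, 8, 56, 56, 4), name="data")
op = te.compute(data.shape, lambda n, c, y, x, b: data[n, c, y, x, b] * 2.0, name="op")

s = te.create_schedule(op.op)

# cache_read copies `data` into the texture scope; the copy becomes its own stage
# (scheduled here via bind_data_copy), so the main kernel reads from texture memory.
AT = s.cache_read(data, "global.texture", [op])
```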
@@ -288,10 +292,15 @@ def schedule_conv2d_NCHWc_KCRSk(cfg, s, output):
s[output].compute_inline()

# create cache stage
-AT = s.cache_read(pad_data, "global.texture", [conv])
+AT = s.cache_read(pad_data, get_texture_storage(pad_data.shape), [conv])
bind_data_copy(s[AT])
-WT = s.cache_read(kernel, "global.texture-weight", [conv])
-bind_data_copy(s[WT])
+if (
+    autotvm.GLOBAL_SCOPE.in_tuning
+    or isinstance(kernel.op, tvm.te.ComputeOp)
+    and "filter_pack" in kernel.op.tag
+):
+    WT = s.cache_read(kernel, get_texture_storage(kernel.shape), [conv])
+    bind_data_copy(s[WT])

# tile and bind spatial axes
n, fc, y, x, fb = s[latest_blocked].op.axis
16 changes: 12 additions & 4 deletions python/tvm/topi/adreno/conv2d_nhwc.py
@@ -210,8 +210,11 @@ def schedule_conv2d_NHWC(cfg, s, output):
5d tensors
4. pad should be scheduled separately to create independent opencl kernel. If pad is
inlined into convolution, this gives 1.5x performance drop
-5. We are using cache_read to produce texture and guarantee the best performance
-on the next stage.
+5. We are using cache_read for intermediate tensors to produce textures and to guarantee
+the best performance on the next stage.
+The weights are managed through the static texture planning mechanism and are guaranteed
+to arrive in the texture memory scope.
+That is why cache_read is called only for the data tensor.
6. For 5d convolution we schedule the latest op with binding 5d axis and vectorize
for textures
For 4d tensor we are doing the same for the latest blocked stage, i.e. conversion
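Rule 6 is the other half of the recipe: bind the fused outer axes to GPU threads and vectorize the innermost length-4 block axis so accesses line up with texture channels. Below is a toy sketch of that step; the shape, thread extent, and compute are made up for illustration and only the schedule primitives mirror what this file does.

```python
import tvm
from tvm import te

# Toy 5d tensor in an NHWC4c-style layout: the last axis of length 4 maps onto texture channels.
inp = te.placeholder((1, 56, 56, 16, 4), name="inp")
out = te.compute(inp.shape, lambda n, y, x, fc, fb: inp[n, y, x, fc, fb] + 1.0, name="out")

s = te.create_schedule(out.op)
n, y, x, fc, fb = s[out].op.axis

# Bind the fused outer axes to GPU threads and vectorize the length-4 block axis.
fused = s[out].fuse(n, y, x, fc)
bx, tx = s[out].split(fused, factor=128)
s[out].bind(bx, te.thread_axis("blockIdx.x"))
s[out].bind(tx, te.thread_axis("threadIdx.x"))
s[out].vectorize(fb)
```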
@@ -287,8 +290,13 @@ def schedule_conv2d_NHWC(cfg, s, output):
# create cache stage
AT = s.cache_read(pad_data, get_texture_storage(pad_data.shape), [conv])
bind_data_copy(s[AT])
-WT = s.cache_read(kernel, get_texture_storage(kernel.shape), [conv])
-bind_data_copy(s[WT])
+if (
+    autotvm.GLOBAL_SCOPE.in_tuning
+    or isinstance(kernel.op, tvm.te.ComputeOp)
+    and "filter_pack" in kernel.op.tag
+):
+    WT = s.cache_read(kernel, get_texture_storage(kernel.shape), [conv])
+    bind_data_copy(s[WT])

# tile and bind spatial axes
n, y, x, fc, fb = s[latest_blocked].op.axis
1 change: 1 addition & 0 deletions src/relay/backend/build_module.cc
@@ -396,6 +396,7 @@ class RelayBuildModule : public runtime::ModuleNode {
relay_module = transform::Inline()(relay_module);
relay_module = transform::InferType()(relay_module);
relay_module = transform::LabelOps()(relay_module);
+relay_module = transform::AnnotateMemoryScope(config_)(relay_module);

ICHECK(relay_module.defined());

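For context on where the new pass sits: AnnotateMemoryScope now runs at the end of the Relay optimization pipeline shown above, so a regular relay.build with an Adreno OpenCL target picks it up automatically. A hedged end-to-end sketch follows; the model, the host triple, and the exact target string are placeholders and not part of this PR.

```python
import tvm
from tvm import relay
from tvm.relay import testing

# Placeholder network; any Relay module with conv2d ops exercises the texture path.
mod, params = testing.mobilenet.get_workload(batch_size=1)

# Assumed Adreno-style OpenCL target with a cross-compilation host triple.
target = tvm.target.Target("opencl -device=adreno", host="llvm -mtriple=aarch64-linux-gnu")

with tvm.transform.PassContext(opt_level=3):
    # relay.build runs the pipeline above (Inline -> InferType -> LabelOps ->
    # AnnotateMemoryScope), annotating conv2d data/weights with texture memory scopes.
    lib = relay.build(mod, target=target, params=params)
```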