5 changes: 3 additions & 2 deletions python/tvm/topi/adreno/conv2d_nchw.py
@@ -305,8 +305,9 @@ def schedule_conv2d_NCHWc_KCRSk(cfg, s, output):
     if autotvm.GLOBAL_SCOPE.in_tuning or filter_pack_rt:
         if not autotvm.GLOBAL_SCOPE.in_tuning:
             bind_data_copy(s[kernel])
-        WT = s.cache_read(kernel, get_texture_storage(kernel.shape), [conv])
-        bind_data_copy(s[WT])
+        if kernel.shape[2] == 1 and kernel.shape[3] == 1:
+            WT = s.cache_read(kernel, get_texture_storage(kernel.shape), [conv])
+            bind_data_copy(s[WT])
 
     s[conv].set_scope("local")
     if latest_blocked == latest and output != latest:
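The same guard now appears in every Adreno schedule touched below: the weight tensor is routed through the texture cache only when its spatial dimensions are 1x1, and non-1x1 filters stay in an ordinary global buffer. A minimal sketch of the shared pattern, assuming a KCRSk/OIHW4o-packed kernel (spatial dims at indices 2 and 3) and the real helpers from python/tvm/topi/adreno/utils.py; stage_weights itself is a hypothetical name, not a function in the codebase:

    from tvm.topi.adreno.utils import bind_data_copy, get_texture_storage

    def stage_weights(s, kernel, conv):
        # Only 1x1 filters are staged into texture memory; larger filters
        # are left in a plain global buffer.
        if kernel.shape[2] == 1 and kernel.shape[3] == 1:
            WT = s.cache_read(kernel, get_texture_storage(kernel.shape), [conv])
            bind_data_copy(s[WT])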
5 changes: 3 additions & 2 deletions python/tvm/topi/adreno/conv2d_nhwc.py
@@ -303,8 +303,9 @@ def schedule_conv2d_NHWC(cfg, s, output):
     if autotvm.GLOBAL_SCOPE.in_tuning or filter_pack_rt:
         if not autotvm.GLOBAL_SCOPE.in_tuning:
             bind_data_copy(s[kernel])
-        WT = s.cache_read(kernel, get_texture_storage(kernel.shape), [conv])
-        bind_data_copy(s[WT])
+        if kernel.shape[0] == 1 and kernel.shape[1] == 1:
+            WT = s.cache_read(kernel, get_texture_storage(kernel.shape), [conv])
+            bind_data_copy(s[WT])
 
     s[conv].set_scope("local")
     if latest_blocked == latest and output != latest:
2 changes: 2 additions & 0 deletions python/tvm/topi/adreno/conv2d_winograd_common.py
@@ -451,6 +451,8 @@ def schedule_conv2d_winograd(cfg, s, output, pre_computed):
         autotvm.GLOBAL_SCOPE.in_tuning
         or isinstance(kernel.op, tvm.te.ComputeOp)
         and "filter_pack" in kernel.op.tag
+        and kernel.shape[2] == 1
+        and kernel.shape[3] == 1
     ):
         BB = s.cache_read(kernel_pack, get_texture_storage(kernel_pack.shape), [OL])
         bind_data_copy(s[BB])
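Note that this condition relies on `and` binding tighter than `or`: the new shape checks constrain only the filter_pack branch, while tuning mode still takes the cache-read path unconditionally. With explicit parentheses the test reads:

    if autotvm.GLOBAL_SCOPE.in_tuning or (
        isinstance(kernel.op, tvm.te.ComputeOp)
        and "filter_pack" in kernel.op.tag
        and kernel.shape[2] == 1
        and kernel.shape[3] == 1
    ):
        BB = s.cache_read(kernel_pack, get_texture_storage(kernel_pack.shape), [OL])
        bind_data_copy(s[BB])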
5 changes: 3 additions & 2 deletions python/tvm/topi/adreno/depthwise_conv2d_nchw.py
@@ -254,8 +254,9 @@ def schedule_depthwise_conv2d_NCHWc_KCRSk(cfg, s, output):
         # create cache stage for tuning only or in case of 4d case
         AT = s.cache_read(pad_data, get_texture_storage(pad_data.shape), [conv])
         bind_data_copy(s[AT])
-        WT = s.cache_read(kernel, get_texture_storage(kernel.shape), [conv])
-        bind_data_copy(s[WT])
+        if kernel.shape[2] == 1 and kernel.shape[3] == 1:
+            WT = s.cache_read(kernel, get_texture_storage(kernel.shape), [conv])
+            bind_data_copy(s[WT])
 
     # tile and bind spatial axes
     n, fc, y, x, fb = s[latest_blocked].op.axis
5 changes: 3 additions & 2 deletions python/tvm/topi/adreno/depthwise_conv2d_nhwc.py
@@ -250,8 +250,9 @@ def schedule_depthwise_conv2d_NHWC_HWOI(cfg, s, output):
         # create cache stage for tuning only or in case of 4d case
         AT = s.cache_read(pad_data, get_texture_storage(pad_data.shape), [conv])
         bind_data_copy(s[AT])
-        WT = s.cache_read(kernel, get_texture_storage(kernel.shape), [conv])
-        bind_data_copy(s[WT])
+        if kernel.shape[0] == 1 and kernel.shape[1] == 1:
+            WT = s.cache_read(kernel, get_texture_storage(kernel.shape), [conv])
+            bind_data_copy(s[WT])
 
     # tile and bind spatial axes
     n, y, x, fc, fb = s[latest_blocked].op.axis
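The index pairs differ across these four schedules only because of kernel layout: the NCHW variants pack weights as OIHW4o with the spatial extent at positions 2 and 3, while the NHWC variants use HWOI4o/HWIO4o with the spatial extent leading. A small sketch of the mapping the guards assume (spatial_dims is a hypothetical helper, not in the codebase):

    def spatial_dims(kernel_shape, kernel_layout):
        # OIHW4o: H, W at indices 2 and 3; HWOI4o / HWIO4o: H, W at 0 and 1.
        if kernel_layout == "OIHW4o":
            return kernel_shape[2], kernel_shape[3]
        return kernel_shape[0], kernel_shape[1]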
78 changes: 62 additions & 16 deletions src/relay/transforms/annotate_texture_storage.cc
@@ -41,6 +41,7 @@
 
 #include <memory>
 #include <unordered_map>
+#include <unordered_set>
 
 #include "../op/memory/device_copy.h"
 #include "../op/memory/memory.h"
@@ -90,22 +90,28 @@ class StorageInfo : private transform::DeviceAwareExprVisitor {
     for (const auto& a : storage_info.args_to_vars_) {
       if (storage_map.count(a.first)) {
         for (const auto& v : a.second) {
-          storage_map.Set(v, storage_map[a.first]);
-          if (storage_map[a.first][Expr()][0] == "global" &&
-              storage_info.accept_textures_.count(v)) {
+          if (storage_info.buffers_params.find(v) != storage_info.buffers_params.end()) {
             Map<Expr, Array<String>> ent;
-            ent.Set(Expr(), storage_info.accept_textures_[v][Expr()]);
+            ent.Set(Expr(), Array<String>{"global"});
             storage_map.Set(v, ent);
-            for (const auto& calls : storage_info.accept_textures_[v]) {
-              if (calls.first != Expr()) {
-                if (storage_map.count(a.first)) {
-                  Map<Expr, Array<String>> ent_call = storage_map[a.first];
-                  ent_call.Set(calls.first, calls.second);
-                  storage_map.Set(a.first, ent_call);
-                } else {
-                  Map<Expr, Array<String>> ent_call;
-                  ent_call.Set(calls.first, calls.second);
-                  storage_map.Set(a.first, ent_call);
+          } else {
+            storage_map.Set(v, storage_map[a.first]);
+            if (storage_map[a.first][Expr()][0] == "global" &&
+                storage_info.accept_textures_.count(v)) {
+              Map<Expr, Array<String>> ent;
+              ent.Set(Expr(), storage_info.accept_textures_[v][Expr()]);
+              storage_map.Set(v, ent);
+              for (const auto& calls : storage_info.accept_textures_[v]) {
+                if (calls.first != Expr()) {
+                  if (storage_map.count(a.first)) {
+                    Map<Expr, Array<String>> ent_call = storage_map[a.first];
+                    ent_call.Set(calls.first, calls.second);
+                    storage_map.Set(a.first, ent_call);
+                  } else {
+                    Map<Expr, Array<String>> ent_call;
+                    ent_call.Set(calls.first, calls.second);
+                    storage_map.Set(a.first, ent_call);
+                  }
                 }
               }
             }
@@ -160,11 +167,20 @@ class StorageInfo : private transform::DeviceAwareExprVisitor {
             storage_scope_[call].push_back("global.texture");
           }
         }
+        const int weights_pos = 1;
         for (size_t i = 0; i < fn->params.size(); i++) {
           args_to_vars_[call->args[i]].push_back(fn->params[i]);
           // adding info about arguments if they can be converted to texture
           for (const auto& ttype : FlattenTupleType(fn->params[i]->checked_type())) {
             std::string scope = Scope(ttype->shape, GetVirtualDevice(GetRef<Expr>(call)));
+            if (expr_attrib.as<Conv2DAttrs>() || expr_attrib.as<Conv2DWinogradAttrs>()) {
+              if ((i == weights_pos) && !ttype->dtype.is_float16() &&
+                  CanUseBuffers(call->args[i], ttype->shape, fn->attrs)) {
+                buffers_params.insert(fn->params[i]);
+                buffers_args.insert(call->args[i]);
+                scope = "global";
+              }
+            }
             if (scope.find("global.texture") != std::string::npos) {
               if (accept_textures_.count(fn->params[i])) {
                 Map<Expr, Array<String>> ent = accept_textures_[fn->params[i]];
@@ -193,13 +209,15 @@ class StorageInfo : private transform::DeviceAwareExprVisitor {
               }
             }
           }
 
     if (!primitive_supports_texture_) {
+      expr_attrib = call->attrs;
       primitive_supports_texture_ = SupportsTextureStorage(call);
     }
 
     for (auto& arg : call->args) {
-      Visit(arg);
+      if (buffers_args.find(arg) == buffers_args.end()) {
+        Visit(arg);
+      }
     }
     // We have all callees filled into storage_scope_ if they support textures
     // We need to verify if this call expects texture and if it does not, remove from
@@ -398,6 +416,28 @@ class StorageInfo : private transform::DeviceAwareExprVisitor {
     return supports_texture_storage;
   }
 
+  bool CanUseBuffers(const Expr param, const Array<PrimExpr> shape,
+                     const tvm::DictAttrs param_attrs) const {
+    bool use_buffer = false;
+    if (param.as<ConstantNode>() && shape.size() == 5) {
+      auto kernel_layout = param_attrs.GetAttr<String>("kernel_layout");
+      if (kernel_layout == "HWOI4o" || kernel_layout == "HWIO4o") {
+        int a0 = shape[0].as<IntImmNode>()->value;
+        int a1 = shape[1].as<IntImmNode>()->value;
+        if (a0 != 1 && a1 != 1) {
+          use_buffer = true;
+        }
+      } else if (kernel_layout == "OIHW4o") {
+        int a2 = shape[2].as<IntImmNode>()->value;
+        int a3 = shape[3].as<IntImmNode>()->value;
+        if (a2 != 1 && a3 != 1) {
+          use_buffer = true;
+        }
+      }
+    }
+    return use_buffer;
+  }
+
   /*! \brief Temporary state for marking whether a visited function
    * primitive supports texture storage scope */
   bool primitive_supports_texture_ = false;
@@ -409,6 +449,12 @@ class StorageInfo : private transform::DeviceAwareExprVisitor {
   std::unordered_map<Expr, std::vector<Var>, ObjectPtrHash, ObjectPtrEqual> args_to_vars_;
   /*! \brief mapping of arguments that can be converted to texture*/
   Map<Expr, Map<Expr, Array<String>>> accept_textures_;
+  /*! \brief main attribute of the expression*/
+  tvm::Attrs expr_attrib;
+  /*! \brief parameters that are filtered out of storage_map so they use buffers*/
+  std::unordered_set<Expr, ObjectPtrHash> buffers_params;
+  /*! \brief arguments in the expression that will use buffers*/
+  std::unordered_set<Expr, ObjectPtrHash> buffers_args;
 };
 
 }  // namespace
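Taken together, the weights_pos check and CanUseBuffers say: for a conv2d or winograd conv2d, a constant 5-d packed weight in the weights position (argument 1) that is not fp16 is kept in a global buffer exactly when both spatial dimensions differ from 1. A hedged Python paraphrase of the layout rule alone, not an API in the codebase:

    def can_use_buffers(is_constant, shape, kernel_layout):
        # Mirrors CanUseBuffers above: only constant, 5-d packed weights qualify.
        if not is_constant or len(shape) != 5:
            return False
        if kernel_layout in ("HWOI4o", "HWIO4o"):
            return shape[0] != 1 and shape[1] != 1  # spatial dims lead the layout
        if kernel_layout == "OIHW4o":
            return shape[2] != 1 and shape[3] != 1  # spatial dims at positions 2, 3
        return False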
20 changes: 10 additions & 10 deletions tests/python/relay/opencl_texture/test_conv2d_nchw_texture.py
@@ -592,12 +592,12 @@ def test_residual_block(remote, target, dtype):
     static_memory_scope = [
         "",
         "global.texture",
-        "global.texture-weight",
+        "global",
         "global.texture-weight",
         "global.texture",
         "global.texture-weight",
         "global.texture",
-        "global.texture-weight",
+        "global",
         "",
         "",
     ]
@@ -834,11 +834,11 @@ def test_pooling_branching_texture_params(remote, target, dtype):
         "global.texture-weight",
         "global.texture",
         "global.texture",
-        "global.texture-weight",
+        "global",
         "global.texture-weight",
         "global.texture-weight",
         "global.texture",
-        "global.texture-weight",
+        "global",
         "global.texture",
         "",
         "",
@@ -960,11 +960,11 @@ def test_branching_texture_params(remote, target, dtype):
         "global.texture",
         "global.texture-weight",
         "global.texture",
-        "global.texture-weight",
+        "global",
         "global.texture-weight",
         "global.texture-weight",
         "global.texture",
-        "global.texture-weight",
+        "global",
         "global.texture",
         "",
         "",
@@ -1179,9 +1179,9 @@ def test_injective_nwo_inputs1(remote, target, dtype):
     static_memory_scope = [
         "",
         "global.texture",
-        "global.texture-nhwc",
+        "global",
         "global.texture",
-        "global.texture-nhwc",
+        "global",
         "global.texture",
         "global",
         "global",
@@ -1277,10 +1277,10 @@ def test_injective_nwo_inputs2(remote, target, dtype):
     static_memory_scope = [
         "",
         "global.texture",
-        "global.texture-nhwc",
+        "global",
         "global.texture",
         "global",
-        "global.texture-nhwc",
+        "global",
         "global.texture",
         "global",
     ]
Expand Down