Merged
python/tvm/relay/op/contrib/clml.py: 118 changes (87 additions, 31 deletions)
@@ -93,6 +93,7 @@ def visit_call(self, call) -> relay.expr.Expr:
if (
not isinstance(arg, (Var, Constant))
and isinstance(arg, tvm.relay.TupleGetItem)
and isinstance(arg.tuple_value.op, tvm.ir.op.Op)
and arg.tuple_value.op.name == "nn.batch_norm"
and (not isinstance(arg.tuple_value.args[0], (Var, Constant)))
and arg.tuple_value.args[0].op.name == "nn.conv2d"
@@ -260,7 +261,8 @@ def conv_pattern():
)
)
pattern = pattern.optional(is_op("nn.relu"))
pattern = pattern.optional(is_op("clip"))
# Fusion pattern to support the relu6 layer (clip with a_min=0, a_max=6).
pattern = pattern.optional(is_op("clip").has_attr({"a_min": 0.0, "a_max": 6.0}))
return pattern
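
An editorial aside, not part of the diff: the effect of the new `has_attr` guard can be seen with TVM's public dataflow-pattern API. The snippet below (with an assumed input shape) shows that only a clip implementing relu6, i.e. `clip(0, 6)`, matches the constrained pattern, while other clip ranges are left to the standalone `clml.clip` rule.

```python
from tvm import relay
from tvm.relay.dataflow_pattern import is_op, wildcard

x = relay.var("x", shape=(1, 16, 8, 8), dtype="float32")
relu6 = relay.clip(x, a_min=0.0, a_max=6.0)         # relu6 expressed as clip
generic_clip = relay.clip(x, a_min=-1.0, a_max=1.0)

pattern = is_op("clip")(wildcard()).has_attr({"a_min": 0.0, "a_max": 6.0})
print(pattern.match(relu6))         # True: eligible for the conv fusion
print(pattern.match(generic_clip))  # False: left to the plain clip rule
```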

def conv_transpose_pattern():
@@ -276,7 +278,8 @@ def conv_transpose_pattern():
)
)
pattern = pattern.optional(is_op("nn.relu"))
pattern = pattern.optional(is_op("clip"))
# Fusion pattern to support the relu6 layer (clip with a_min=0, a_max=6).
pattern = pattern.optional(is_op("clip").has_attr({"a_min": 0.0, "a_max": 6.0}))
return pattern

def pad_conv_pattern():
@@ -293,7 +296,8 @@ def pad_conv_pattern():
)
)
pattern = pattern.optional(is_op("nn.relu"))
pattern = pattern.optional(is_op("clip"))
# Fusion pattern to support the relu6 layer (clip with a_min=0, a_max=6).
pattern = pattern.optional(is_op("clip").has_attr({"a_min": 0.0, "a_max": 6.0}))
return pattern

def batch_norm_pattern():
@@ -359,6 +363,9 @@ def check_conv(extract):
if attrs.data_layout != "NCHW":
return False

if call.checked_type.shape[0] > 1:
return False

if (
(not clip_found)
and (attrs.kernel_size[0] == 3)
@@ -411,19 +418,13 @@ def check_binary_op(extract):
# Scalars are not supported
if len(call.args[1].checked_type.shape) == 0:
return False
if call.args[0] == call.args[1]:
return False

if tuple(call.args[0].checked_type.shape) != tuple(call.args[1].checked_type.shape):
return False

for arg in call.args:
# Avoid any operators with dtype Int64
if arg.checked_type.dtype == "int64":
return False
# No support for batch > 1
if arg.checked_type.shape[0] > 1:
return False

return True
return check_default_op(call)
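
For clarity (a sketch, not part of the diff): the new `call.args[0] == call.args[1]` guard relies on TVM object equality, which is reference equality for Relay expressions, so it rejects binary ops fed the same expression on both sides.

```python
from tvm import relay

x = relay.var("x", shape=(1, 16), dtype="float32")
y = relay.var("y", shape=(1, 16), dtype="float32")

same = relay.add(x, x)  # both args are literally the same Var object
different = relay.add(x, y)

print(same.args[0] == same.args[1])            # True  -> check_binary_op rejects
print(different.args[0] == different.args[1])  # False -> continues to check_default_op
```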

def check_pad_op(extract):
call = extract
@@ -433,75 +434,134 @@ def check_pad_op(extract):
# Pad layers before any convolution are not guaranteed to be NCHW.
if isinstance(call.args[0], tvm.relay.expr.Var):
return False
return True
return check_default_op(call)

def check_softmax_op(extract):
call = extract
# supports 2D and 4D tensors
# supports 2D and 4D tensors.
if len(call.args[0].checked_type.shape) not in [2, 4]:
return False
return True
return check_default_op(call)

def check_upsampling_op(extract):
call = extract
if call.attrs["method"] != "bilinear":
return False
return True
return check_default_op(call)

def check_concat_op(extract):
call = extract
if call.attrs["axis"] != 1:
return False
return True
return check_default_op(call)

def check_default_op(extract):
call = extract

if isinstance(call, tvm.relay.expr.TupleGetItem):
call = call.tuple_value
call_shape = call.checked_type.fields[0].shape
call_dtype = call.checked_type.fields[0].dtype
else:
call_shape = call.checked_type.shape
call_dtype = call.checked_type.dtype

# int64 and int32 dtypes are not supported in CLML
if call_dtype in ["int64", "int32"]:
return False

# Avoid any operators with dtype Int64
for arg in call.args:
if arg.checked_type.dtype == "int64":
# Supports only up to 4-dim shapes
if len(call_shape) > 4:
return False
# Only support batch dim = 1
if isinstance(call_shape[0], tvm.tir.expr.Any) or call_shape[0] > 1:
return False
# Checking buffer indexing limit
for shape in call_shape:
if shape > 32768:
return False
# Avoid any operators with int64/int32 dtypes or unsupported shapes
for _arg in call.args:
t_arg = _arg if isinstance(_arg, tvm.relay.Tuple) else [_arg]
for arg in t_arg:
checked_type = (
arg.tuple_value.checked_type.fields[arg.index]
if isinstance(arg, tvm.relay.TupleGetItem)
else arg.checked_type
)
if checked_type.dtype in ["int64", "int32"]:
return False
# Supports only up to 4-dim shapes
if len(checked_type.shape) > 4:
return False
# Only support batch dim = 1
if len(checked_type.shape) > 0 and checked_type.shape[0] > 1:
return False
for shape in checked_type.shape:
if shape > 32768:
return False
return True
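
Summarized as a standalone helper (hypothetical, for illustration only), the tensor constraints check_default_op now enforces on both results and operands are: rank at most 4, a static batch dimension of 1, and every extent within CLML's 32768 buffer-indexing limit.

```python
import tvm

def clml_shape_ok(shape) -> bool:
    """Illustrative re-statement of the shape rules in check_default_op."""
    if len(shape) > 4:  # CLML tensors are at most 4-D
        return False
    if len(shape) > 0:
        # Reject dynamic or >1 batch dimensions
        if isinstance(shape[0], tvm.tir.Any) or int(shape[0]) > 1:
            return False
    # Every extent must stay within the buffer indexing limit
    return all(int(dim) <= 32768 for dim in shape)

print(clml_shape_ok([1, 64, 8, 8]))    # True
print(clml_shape_ok([2, 64, 8, 8]))    # False: batch > 1
print(clml_shape_ok([1, 3, 8, 8, 8]))  # False: rank 5
print(clml_shape_ok([1, 65536]))       # False: extent over 32768
```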

def check_batch_matmul_op(extract):
call = extract
# Only support single Matmul
# Only support batch size 1 (a single matmul).
if call.args[0].checked_type.shape[0] > 1:
return False
if call.args[1].checked_type.shape[0] > 1:
return False
return True
return check_default_op(call)

def check_dense1d_op(extract):
call = extract
# Only support single Matmul
# Only support batch size 1 (a single matmul).
if call.args[0].checked_type.shape[0] > 1:
return False
if not (call.op.name in ["nn.bias_add", "add"] and call.args[0].op.name == "nn.dense"):
return False
return check_default_op(call)

def check_dense2d_op(extract):
call = extract
# Only support 2D Matmul without bias
if call.op.name in ["nn.bias_add", "add"] and call.args[0].op.name == "nn.dense":
return False
# Reject any argument whose shape is not 2-D
for _arg in call.args:
t_arg = _arg if isinstance(_arg, tvm.relay.Tuple) else [_arg]
for arg in t_arg:
checked_type = (
arg.tuple_value.checked_type.fields[arg.index]
if isinstance(arg, tvm.relay.TupleGetItem)
else arg.checked_type
)
if len(checked_type.shape) != 2:
return False
return True

def check_reshape(extract):
def check_depth_to_space(extract):
call = extract
call_shape = call.checked_type.shape
arg_shape = call.args[0].checked_type.shape
# Supports only up to 4-dim shapes
if len(call_shape) > 4 or len(arg_shape) > 4:
return False
# Only support batch dim = 1
if call_shape[0] > 1:
return False
# Checking buffer indexing limit
for shape in call_shape:
if shape > 32768:
return False
if call.attrs["layout"] != "NCHW" or call.attrs["mode"] != "DCR":
return False
return True
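
As a quick illustration (assumed shape, not part of the diff): only NCHW layout with DCR mode passes the new check; a CRD-mode call is rejected and falls back to TVM, which is exactly what the updated test below asserts through num_clml_modules and tvm_ops.

```python
from tvm import relay

a = relay.var("a", shape=(1, 64, 8, 8), dtype="float32")
dcr = relay.nn.depth_to_space(a, block_size=4, layout="NCHW", mode="DCR")
crd = relay.nn.depth_to_space(a, block_size=4, layout="NCHW", mode="CRD")

print(dcr.attrs.mode)  # "DCR" -> offloaded as clml.depth_to_space
print(crd.attrs.mode)  # "CRD" -> rejected, runs on the TVM fallback
```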

return [
("clml.pad_conv2d", pad_conv_pattern(), check_conv),
("clml.conv2d", conv_pattern(), check_conv),
("clml.conv2d_transpose", conv_transpose_pattern(), check_conv_transpose),
("clml.dense1d", dense1d_pattern(), check_dense1d_op),
("clml.dense2d", dense2d_pattern(), check_default_op),
("clml.dense2d", dense2d_pattern(), check_dense2d_op),
("clml.pad", pad_pattern(), check_pad_op),
("clml.concat", concat_pattern(), check_concat_op),
("clml.batch_norm", batch_norm_pattern(), check_default_op),
@@ -512,15 +572,15 @@ def check_reshape(extract):
("clml.minimum", is_op("minimum")(wildcard(), wildcard()), check_binary_op),
("clml.maximum", is_op("maximum")(wildcard(), wildcard()), check_binary_op),
("clml.softmax", is_op("nn.softmax")(wildcard()), check_softmax_op),
("clml.reshape", is_op("reshape")(wildcard()), check_reshape),
("clml.reshape", is_op("reshape")(wildcard()), check_default_op),
("clml.avg_pool2d", is_op("nn.avg_pool2d")(wildcard()), check_default_op),
("clml.max_pool2d", is_op("nn.max_pool2d")(wildcard()), check_default_op),
("clml.global_avg_pool2d", is_op("nn.global_avg_pool2d")(wildcard()), check_default_op),
("clml.global_max_pool2d", is_op("nn.global_max_pool2d")(wildcard()), check_default_op),
("clml.relu", is_op("nn.relu")(wildcard()), check_default_op),
("clml.clip", is_op("clip")(wildcard()), check_default_op),
("clml.batch_flatten", is_op("nn.batch_flatten")(wildcard()), check_default_op),
("clml.depth_to_space", is_op("nn.depth_to_space")(wildcard()), check_default_op),
("clml.depth_to_space", is_op("nn.depth_to_space")(wildcard()), check_depth_to_space),
("clml.upsampling", is_op("nn.upsampling")(wildcard()), check_upsampling_op),
(
"clml.batch_matmul",
@@ -538,10 +598,6 @@ def _func_wrapper(expr):
return _func_wrapper


_register_external_op_helper("minimum")
_register_external_op_helper("maximum")


class OpAttrContext(object):
"""Temporarily changes the attr of an op."""

tests/python/contrib/test_clml/test_ops.py: 30 changes (22 additions, 8 deletions)
@@ -809,13 +809,16 @@ def _verify(out, params, inputs):


@pytest.mark.parametrize("dtype", ["float32", "float16"])
@pytest.mark.parametrize("input_shape", [(1, 64, 8, 8), (1, 64, 8, 8), (1, 512, 8, 8)])
@pytest.mark.parametrize("block_size", [4, 8])
@pytest.mark.parametrize("mode", ["DCR", "CRD"])
@tvm.testing.requires_openclml
@tvm.testing.parametrize_targets("opencl -device=adreno")
def test_depth_to_space(remote, dtype, target, executor_type):
def _get_model(a_shape, block_size):
def test_depth_to_space(remote, dtype, target, executor_type, input_shape, block_size, mode):
def _get_model(a_shape, block_size, mode):
np.random.seed(0)
a = relay.var("a", shape=(a_shape), dtype=dtype)
out = relay.nn.depth_to_space(a, block_size)
out = relay.nn.depth_to_space(a, block_size, mode=mode)
inputs = {"a": tvm.nd.array(np.random.uniform(-1, 1, a_shape).astype(dtype))}
params = {}
return out, params, inputs
@@ -841,7 +844,7 @@ def _verify(out, params, inputs):
"attrs": {
"block_size": [[str(int(out.attrs.block_size))]],
"layout": [["NCHW"]],
"mode": [["DCR"]],
"mode": [[out.attrs.mode]],
"dtype": [[dtype]],
"num_inputs": "1",
"num_outputs": "1",
@@ -852,11 +855,22 @@ def _verify(out, params, inputs):
"op": "kernel",
},
]
verify_codegen(remote, mod, params, exp_codegen, target)
# CRD mode is rejected by check_depth_to_space, so the graph should
# stay on TVM instead of producing a CLML module.
num_clml_modules = 1
tvm_ops = 0
if out.attrs.mode != "DCR":
    num_clml_modules = 0
    tvm_ops = 1
verify_codegen(
remote,
mod,
params,
exp_codegen,
target,
num_clml_modules=num_clml_modules,
tvm_ops=tvm_ops,
)

_verify(*(_get_model((1, 64, 8, 8), 4)))
_verify(*(_get_model((1, 64, 8, 8), 8)))
_verify(*(_get_model((1, 512, 8, 8), 8)))
_verify(*(_get_model(input_shape, block_size, mode)))


@pytest.mark.parametrize("dtype", ["float32", "float16"])