Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
227 changes: 227 additions & 0 deletions python/tvm/relay/backend/contrib/ethosu/legalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,6 +413,224 @@ def __call__(self, *args, **kwargs):
pass


class BinaryElementwiseRewriter(DFPatternCallback):
"""Convert ethosu binary elementwise composite functions to
ethosu_binary_elementwise operators"""

def __init__(
self,
params_class: Type,
pattern: CallPattern,
):
super().__init__(require_type=True)
self.params_class = params_class
self.pattern = pattern

def callback(
self, pre: tvm.relay.Expr, post: tvm.relay.Expr, node_map: tvm.ir.container.Map
) -> tvm.relay.Expr:
params = self.params_class(post.op.body)
params.ifm.tensor = post.args[1] if params.reversed_operands else post.args[0]
params.ifm2.tensor = post.args[0] if params.reversed_operands else post.args[1]
channels_map = {
"NHWC": 3,
}
if str(params.ofm.layout) not in channels_map.keys():
raise UnsupportedLayout(str(params.ofm.layout))

activation_map = {"clip": "CLIP"}
if params.activation:
activation = activation_map[params.activation.op.name]
clip_min = int(params.activation.attrs.a_min)
clip_max = int(params.activation.attrs.a_max)
else:
activation = "NONE"
clip_min = 0
clip_max = 0

# We don't yet support activation functions that need to get legalized to LUTs.
lut = relay.const([], dtype="int8")

return ethosu_ops.ethosu_binary_elementwise(
ifm=params.ifm.tensor,
ifm2=params.ifm2.tensor,
lut=lut,
operator_type=params.operator_type,
ifm_scale=float(params.ifm.q_params.scale_f32),
ifm_zero_point=int(params.ifm.q_params.zero_point),
ifm2_scale=float(params.ifm2.q_params.scale_f32),
ifm2_zero_point=int(params.ifm2.q_params.zero_point),
ofm_scale=float(params.ofm.q_params.scale_f32),
ofm_zero_point=int(params.ofm.q_params.zero_point),
ifm_channels=params.ifm.shape[3],
ifm2_channels=params.ifm2.shape[3],
reversed_operands=params.reversed_operands,
ofm_dtype=params.ofm.dtype,
activation=activation,
clip_min=clip_min,
clip_max=clip_max,
ifm_layout=str(params.ifm.layout),
ifm2_layout=str(params.ifm2.layout),
ofm_layout=str(params.ofm.layout),
)


class AddRewriter(BinaryElementwiseRewriter):
def __init__(self):
super().__init__(
params_class=ethosu_patterns.AddParams,
pattern=(wildcard().has_attr({"Composite": ethosu_patterns.AddParams.composite_name}))(
wildcard(), wildcard()
),
)


@ir.transform.module_pass(opt_level=1)
class LegalizeAdd:
"""This is the pass that wraps the AddRewriter"""

def transform_module(
self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext
) -> tvm.ir.IRModule:
for global_var, func in mod.functions.items():
func = rewrite(AddRewriter(), func)
mod.update_func(global_var, func)
return mod

def __call__(self, *args, **kwargs):
pass


class SubRewriter(BinaryElementwiseRewriter):
def __init__(self):
super().__init__(
params_class=ethosu_patterns.SubParams,
pattern=(wildcard().has_attr({"Composite": ethosu_patterns.SubParams.composite_name}))(
wildcard(), wildcard()
),
)


@ir.transform.module_pass(opt_level=1)
class LegalizeSub:
"""This is the pass that wraps the SubRewriter"""

def transform_module(
self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext
) -> tvm.ir.IRModule:
for global_var, func in mod.functions.items():
func = rewrite(SubRewriter(), func)
mod.update_func(global_var, func)
return mod

def __call__(self, *args, **kwargs):
pass


class MulRewriter(BinaryElementwiseRewriter):
def __init__(self):
super().__init__(
params_class=ethosu_patterns.MulParams,
pattern=(wildcard().has_attr({"Composite": ethosu_patterns.MulParams.composite_name}))(
wildcard(), wildcard()
),
)


@ir.transform.module_pass(opt_level=1)
class LegalizeMul:
"""This is the pass that wraps the MulRewriter"""

def transform_module(
self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext
) -> tvm.ir.IRModule:
for global_var, func in mod.functions.items():
func = rewrite(MulRewriter(), func)
mod.update_func(global_var, func)
return mod

def __call__(self, *args, **kwargs):
pass


class MinRewriter(BinaryElementwiseRewriter):
def __init__(self):
super().__init__(
params_class=ethosu_patterns.MinParams,
pattern=(wildcard().has_attr({"Composite": ethosu_patterns.MinParams.composite_name}))(
wildcard(), wildcard()
),
)


@ir.transform.module_pass(opt_level=1)
class LegalizeMin:
"""This is the pass that wraps the MinRewriter"""

def transform_module(
self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext
) -> tvm.ir.IRModule:
for global_var, func in mod.functions.items():
func = rewrite(MinRewriter(), func)
mod.update_func(global_var, func)
return mod

def __call__(self, *args, **kwargs):
pass


class MaxRewriter(BinaryElementwiseRewriter):
def __init__(self):
super().__init__(
params_class=ethosu_patterns.MaxParams,
pattern=(wildcard().has_attr({"Composite": ethosu_patterns.MaxParams.composite_name}))(
wildcard(), wildcard()
),
)


@ir.transform.module_pass(opt_level=1)
class LegalizeMax:
"""This is the pass that wraps the MaxRewriter"""

def transform_module(
self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext
) -> tvm.ir.IRModule:
for global_var, func in mod.functions.items():
func = rewrite(MaxRewriter(), func)
mod.update_func(global_var, func)
return mod

def __call__(self, *args, **kwargs):
pass


class ShlRewriter(BinaryElementwiseRewriter):
def __init__(self):
super().__init__(
params_class=ethosu_patterns.ShlParams,
pattern=(wildcard().has_attr({"Composite": ethosu_patterns.ShlParams.composite_name}))(
wildcard(), wildcard()
),
)


@ir.transform.module_pass(opt_level=1)
class LegalizeShl:
"""This is the pass that wraps the ShlRewriter"""

def transform_module(
self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext
) -> tvm.ir.IRModule:
for global_var, func in mod.functions.items():
func = rewrite(ShlRewriter(), func)
mod.update_func(global_var, func)
return mod

def __call__(self, *args, **kwargs):
pass


@ir.transform.module_pass(opt_level=1)
class LegalizeEthosU:
"""This is the pass to call graph-rewrites to perform graph transformation
Expand All @@ -423,11 +641,20 @@ class LegalizeEthosU:
def transform_module(
self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext
) -> tvm.ir.IRModule:
"""This is the method that replaces the operations with hardware/codegen supported
operations.
"""
mod = LegalizeSplit()(mod)
mod = LegalizeConv2D()(mod)
mod = LegalizeDepthwiseConv2D()(mod)
mod = LegalizeMaxPooling()(mod)
mod = LegalizeAvgPooling()(mod)
mod = LegalizeAdd()(mod)
mod = LegalizeSub()(mod)
mod = LegalizeMul()(mod)
mod = LegalizeMin()(mod)
mod = LegalizeMax()(mod)
mod = LegalizeShl()(mod)
return mod

def __call__(self, *args, **kwargs):
Expand Down
1 change: 1 addition & 0 deletions python/tvm/relay/backend/contrib/ethosu/op/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,4 @@
from .convolution import ethosu_conv2d
from .depthwise import ethosu_depthwise_conv2d
from .pooling import ethosu_pooling
from .binary_elementwise import ethosu_binary_elementwise
Loading