From 0198ba75481834801888b00a737846a69f022d88 Mon Sep 17 00:00:00 2001
From: Eric Lunderberg
Date: Mon, 4 Mar 2024 20:45:52 +0000
Subject: [PATCH] [SLM] Implement pattern replacement

Prior to this commit, if an optimized `nn.Module` implementation was
defined, either the model definition needed to be rewritten to include
the optimized implementation, or the user needed to write their own
`nn.Mutator` implementation.

This commit implements `nn.replace_implementation`, which provides a
`nn.Mutator` that replaces all instances of a `nn.Module` with an
optimized implementation.  This allows a user to write optimized
implementations as subclasses, as shown below.

```python
class BaseImplementation(nn.Module): ...

class OptimizedImplementation(BaseImplementation): ...
```

After defining the optimized implementation,
`nn.replace_implementation(OptimizedImplementation)` returns a
`nn.Mutator` that can inject the optimized module into an existing
end-to-end SLM model.

In addition, the SLM-to-SLM transformation can be converted into a
Relax-to-Relax transformation, providing an easy path for migrating
optimized kernels into a Relax optimization pipeline.
---
 python/tvm/relax/frontend/nn/__init__.py      |   2 +
 python/tvm/relax/frontend/nn/core.py          |  36 +--
 python/tvm/relax/frontend/nn/exporter.py      |   6 +-
 .../frontend/nn/replace_implementation.py     | 234 ++++++++++++++++++
 python/tvm/relax/frontend/nn/spec.py          |   4 +-
 ...test_frontend_nn_replace_implementation.py | 156 ++++++++++++
 6 files changed, 419 insertions(+), 19 deletions(-)
 create mode 100644 python/tvm/relax/frontend/nn/replace_implementation.py
 create mode 100644 tests/python/relax/test_frontend_nn_replace_implementation.py

diff --git a/python/tvm/relax/frontend/nn/__init__.py b/python/tvm/relax/frontend/nn/__init__.py
index a8200d8dd627..77c3dee4969b 100644
--- a/python/tvm/relax/frontend/nn/__init__.py
+++ b/python/tvm/relax/frontend/nn/__init__.py
@@ -26,6 +26,7 @@
     ConvTranspose1D,
     Embedding,
     GroupNorm,
+    Identity,
     IOEffect,
     KVCache,
     LayerNorm,
@@ -37,3 +38,4 @@
 from .op import *
 from .subroutine import SubroutineMixin
 from .visitor import Mutator
+from .replace_implementation import replace_implementation
diff --git a/python/tvm/relax/frontend/nn/core.py b/python/tvm/relax/frontend/nn/core.py
index 21118b1cb8af..86b7803aeb78 100644
--- a/python/tvm/relax/frontend/nn/core.py
+++ b/python/tvm/relax/frontend/nn/core.py
@@ -107,11 +107,11 @@ def __init__(self, *, _expr: rx.Expr) -> None:
     def _check_tensor(expr: rx.Expr) -> None:
         assert expr.struct_info_ is not None
         assert isinstance(expr.struct_info, TensorStructInfo)
-        assert expr.struct_info.ndim != -1
-        assert expr.struct_info.shape is not None
-        assert expr.struct_info.shape.struct_info_ is not None
-        assert isinstance(expr.struct_info.shape.struct_info, ShapeStructInfo)
-        assert expr.struct_info.shape.struct_info.values is not None
+        if expr.struct_info.ndim != -1:
+            assert expr.struct_info.shape is not None
+            assert expr.struct_info.shape.struct_info_ is not None
+            assert isinstance(expr.struct_info.shape.struct_info, ShapeStructInfo)
+            assert expr.struct_info.shape.struct_info.values is not None
 
     _check_tensor(_expr)
     self._expr = _expr
@@ -148,26 +148,28 @@ def placeholder(
         If shape is a string `name`, we create a symbolic shape
         `tvm.tir.Var(name, "int64")`.
""" - new_shape = [] - for expr in shape: + + def _normalize_dim(expr): if isinstance(expr, (int, tir.IntImm)): expr = int(expr) assert expr >= 0 - new_shape.append(expr) - continue - if isinstance(expr, str): - expr = tir.Var(expr, "int64") - new_shape.append(expr) - continue - if not isinstance(expr, tir.PrimExpr): + return expr + elif isinstance(expr, str): + return tir.Var(expr, "int64") + elif isinstance(expr, tir.PrimExpr): + assert expr.dtype == "int64" + return expr + else: raise TypeError(f"Invalid shape: {shape}") - assert expr.dtype == "int64" - new_shape.append(expr) + + if shape is not None: + shape = [_normalize_dim(dim) for dim in shape] + return Tensor( _expr=rx.Var( name_hint=name, struct_info=TensorStructInfo( - shape=new_shape, # type: ignore[arg-type] + shape=shape, # type: ignore[arg-type] dtype=dtype, ), ) diff --git a/python/tvm/relax/frontend/nn/exporter.py b/python/tvm/relax/frontend/nn/exporter.py index 1a7dcd6a648b..8fdb54911f45 100644 --- a/python/tvm/relax/frontend/nn/exporter.py +++ b/python/tvm/relax/frontend/nn/exporter.py @@ -304,8 +304,12 @@ def _convert_input(arg_name, arg_spec): if isinstance(arg_spec, _spec.Int): arg = _get_var(arg_name) elif isinstance(arg_spec, _spec.Tensor): + shape = arg_spec.shape + if shape is not None: + shape = [_get_var(x) if isinstance(x, str) else x for x in shape] + arg = core.Tensor.placeholder( # pylint: disable=protected-access - shape=[_get_var(x) if isinstance(x, str) else x for x in arg_spec.shape], + shape=shape, dtype=arg_spec.dtype, name=arg_name, ) diff --git a/python/tvm/relax/frontend/nn/replace_implementation.py b/python/tvm/relax/frontend/nn/replace_implementation.py new file mode 100644 index 000000000000..3b7001a7c4fe --- /dev/null +++ b/python/tvm/relax/frontend/nn/replace_implementation.py @@ -0,0 +1,234 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Pattern matching in SLM""" + + +import inspect +from typing import Dict, Tuple, List + +import tvm +from tvm.relax.frontend import nn + +from tvm.relax import dpl as relax_pattern + + +def _relax_function_to_pattern( + func: "tvm.relax.Function", +) -> Tuple[List[relax_pattern.WildcardPattern], relax_pattern.DFPattern]: + """Convert a relax function into a pattern to be matched + + TODO(Lunderberg): Replace `tvm.relax.dpl` with function objects. + Pattern-matching and replacement can be done using a function + object as the pattern. 
+    """
+
+    params: List[relax_pattern.WildcardPattern] = []
+    patterns: Dict[tvm.relax.Var, relax_pattern.DFPattern] = {}
+
+    for param in func.params:
+        wildcard = relax_pattern.WildcardPattern().has_struct_info(param.struct_info)
+        params.append(wildcard)
+        patterns[param] = wildcard
+
+    def _make_pattern(expr: tvm.relax.Expr) -> relax_pattern.DFPattern:
+        if isinstance(expr, tvm.relax.Var):
+            return patterns[expr]
+        elif isinstance(expr, tvm.relax.Call):
+            op = _make_pattern(expr.op)
+            args = [_make_pattern(arg) for arg in expr.args]
+            return op(*args)
+        elif isinstance(expr, tvm.relax.Tuple):
+            fields = [_make_pattern(field) for field in expr.fields]
+            return relax_pattern.TuplePattern(fields)
+        elif isinstance(expr, tvm.ir.Op):
+            return relax_pattern.ExprPattern(expr)
+        else:
+            raise TypeError(
+                f"Cannot convert relax expression {expr} of type {type(expr)}, "
+                f"which has struct info {expr.struct_info_}, "
+                f"into DFPattern."
+            )
+
+    seq_expr = func.body
+    for block in seq_expr.blocks:
+        for binding in block.bindings:
+            patterns[binding.var] = _make_pattern(binding.value)
+
+    top_pattern = _make_pattern(seq_expr.body)
+
+    return params, top_pattern
+
+
+def _relax_function_to_rewriter(
+    param_patterns: List[relax_pattern.WildcardPattern],
+    replacement_func: "tvm.relax.Function",
+):
+    """Generate a rewriter from a relax.Function"""
+
+    assert len(replacement_func.params) == len(param_patterns)
+
+    def rewriter(expr, matches):
+        match_results = [matches[param_pattern] for param_pattern in param_patterns]
+        func = tvm.relax.utils.copy_with_new_vars(replacement_func)
+
+        input_bindings = [
+            tvm.relax.VarBinding(param, match_result)
+            for param, match_result in zip(func.params, match_results)
+        ]
+        output_expr = tvm.relax.SeqExpr([tvm.relax.DataflowBlock(input_bindings)], func.body)
+
+        output_var = tvm.relax.Var("match_result", expr.struct_info)
+        output_binding = tvm.relax.VarBinding(output_var, output_expr)
+
+        return tvm.relax.SeqExpr([tvm.relax.DataflowBlock([output_binding])], output_var)
+
+    return rewriter
+
+
+def _relax_transform_by_rewrite_call(pattern, rewriter):
+    @tvm.ir.transform.module_pass(name="relax.PatternReplacement", opt_level=0)
+    def transform(mod, _context):
+        updates = {}
+        for gvar, func in mod.functions.items():
+            if isinstance(func, tvm.relax.Function):
+                new_func = relax_pattern.rewrite_call(pattern, rewriter, func)
+                if not func.same_as(new_func):
+                    updates[gvar] = new_func
+
+        if updates:
+            mod = mod.clone()
+            mod.update(updates)
+
+        return mod
+
+    return transform
+
+
+def _no_op_init__(self):  # pylint: disable=unused-argument
+    pass
+
+
+class ReplaceWithSubclass(nn.Mutator):
+    """An SLM mutator to inject an optimized subclass
+
+    Parameters
+    ----------
+    optimized_subclass: type
+
+        An optimized subclass of a `nn.Module` subclass.
+    """
+
+    def __init__(self, optimized_subclass: type):
+        base_class = optimized_subclass.__base__
+
+        assert issubclass(
+            optimized_subclass, nn.Module
+        ), "The optimized implementation must inherit from a subclass of nn.Module"
+        assert (
+            base_class is not nn.Module
+        ), "The optimized implementation must not be a direct subclass of nn.Module"
+
+        self.base_class = base_class
+        self.optimized_subclass = optimized_subclass
+
+    def visit_module(self, name: str, node: nn.Module) -> nn.Module:
+        """Replace a nn.Module subclass with an optimized version"""
+
+        node = super().visit_module(name, node)
+        if isinstance(node, self.base_class):
+            # We want to replace the nn.Module without needing to
+            # construct a new instance.
+            node.__class__ = self.optimized_subclass
+
+            cached_init = self.base_class.__init__
+            self.base_class.__init__ = _no_op_init__
+            try:
+                node.__init__()
+            finally:
+                self.base_class.__init__ = cached_init
+
+        return node
+
+    def as_relax_transform(self) -> tvm.ir.transform.Pass:
+        """Produce a Relax-to-Relax transform"""
+        init_sig = inspect.signature(self.base_class.__init__)
+
+        init_kwargs = {}
+        for name, param in init_sig.parameters.items():
+            if name == "self":
+                pass
+            elif issubclass(int, param.annotation):
+                # The annotation is either `int` on its own, or a
+                # Union that includes `int`.
+                init_kwargs[name] = tvm.tir.Var(name, "int64")
+            else:
+                raise TypeError(
+                    f"Cannot determine argument type for __init__ argument {name}, "
+                    f"with type annotation {param.annotation}"
+                )
+
+        forward_sig = inspect.signature(self.base_class.forward)
+        forward_spec = {}
+        for name, param in forward_sig.parameters.items():
+            if name == "self":
+                pass
+            elif param.annotation is nn.Tensor:
+                forward_spec[name] = nn.spec.Tensor(None, "void")
+            else:
+                raise TypeError(
+                    f"Cannot determine argument type for forward argument {name}, "
+                    f"with type annotation {param.annotation}"
+                )
+
+        spec = {"forward": forward_spec}
+
+        base_impl = self.base_class(**init_kwargs)
+        optimized_impl = self.optimized_subclass(**init_kwargs)
+
+        base_tvm, _ = base_impl.export_tvm(spec)
+        optimized_tvm, _ = optimized_impl.export_tvm(spec)
+
+        base_tvm = base_tvm["forward"]
+        optimized_tvm = optimized_tvm["forward"]
+
+        param_patterns, match_pattern = _relax_function_to_pattern(base_tvm)
+        match_rewriter = _relax_function_to_rewriter(param_patterns, optimized_tvm)
+
+        return _relax_transform_by_rewrite_call(match_pattern, match_rewriter)
+
+
+def replace_implementation(optimized_subclass: type):
+    """Produce a mutator to replace an existing nn.Module
+
+    This utility allows users to write an optimized implementation of
+    an existing `nn.Module`, and to substitute it into an existing
+    end-to-end model.
+
+    Parameters
+    ----------
+    optimized_subclass: type
+
+        An optimized subclass of a `nn.Module` subclass.
+
+    Returns
+    -------
+    mutator: nn.Mutator
+
+        A mutator that replaces `optimized_subclass.__base__` with
+        `optimized_subclass`.
+    """
+    return ReplaceWithSubclass(optimized_subclass)
diff --git a/python/tvm/relax/frontend/nn/spec.py b/python/tvm/relax/frontend/nn/spec.py
index 54928ce07b80..db933171b484 100644
--- a/python/tvm/relax/frontend/nn/spec.py
+++ b/python/tvm/relax/frontend/nn/spec.py
@@ -44,7 +44,9 @@ class Tensor:  # pylint: disable=too-few-public-methods
     dtype: str
 
     def __init__(self, shape: typing.Sequence[typing.Union[int, str]], dtype: str) -> None:
-        self.shape = list(shape)
+        if shape is not None:
+            shape = list(shape)
+        self.shape = shape
         self.dtype = dtype
 
     def __repr__(self) -> str:
diff --git a/tests/python/relax/test_frontend_nn_replace_implementation.py b/tests/python/relax/test_frontend_nn_replace_implementation.py
new file mode 100644
index 000000000000..135bf53f57b2
--- /dev/null
+++ b/tests/python/relax/test_frontend_nn_replace_implementation.py
@@ -0,0 +1,156 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import tvm
+import tvm.testing
+from tvm.relax.frontend import nn
+
+when_to_apply_replacement = tvm.testing.parameter(
+    "relax",
+    "slm",
+)
+
+
+def test_replace_implementation(when_to_apply_replacement):
+    class FeedForward(nn.Module):
+        """The base implementation to be replaced"""
+
+        def __init__(self, hidden_size: int, intermediate_size: int):
+            super().__init__()
+            self.gate_proj = nn.Linear(
+                in_features=hidden_size,
+                out_features=intermediate_size,
+                bias=False,
+            )
+            self.up_proj = nn.Linear(
+                in_features=hidden_size,
+                out_features=intermediate_size,
+                bias=False,
+            )
+            self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
+
+        def forward(self, hidden_states: nn.Tensor):
+            gate = self.gate_proj(hidden_states)
+            up = self.up_proj(hidden_states)
+            return self.down_proj(nn.op.silu(gate) * up)
+
+    class OptimizedFeedForward(FeedForward):
+        """The optimized implementation"""
+
+        def __init__(self, *args, **kwargs):
+            super().__init__(*args, **kwargs)
+
+        @property
+        def gate_up_proj(self):
+            hidden_size = self.gate_proj.in_features
+            intermediate_size = self.gate_proj.out_features
+
+            gate_up_proj = nn.Linear(
+                in_features=hidden_size,
+                out_features=2 * intermediate_size,
+                bias=False,
+            )
+            gate_up_proj.weight = nn.op.concat([self.gate_proj.weight, self.up_proj.weight], dim=0)
+            return gate_up_proj
+
+        def forward(self, hidden_states: nn.Tensor):
+            concat_x1_x2 = self.gate_up_proj(hidden_states)
+            gate, up = nn.op.split(concat_x1_x2, 2, axis=-1)
+            return self.down_proj(nn.op.silu(gate) * up)
+
+    class DecoderLayer(nn.Module):
+        """A Module that internally uses the base implementation"""
+
+        def __init__(self, hidden_size: int, intermediate_size: int):
+            self.self_attn = nn.Identity()  # For the sake of testing
+            self.mlp = FeedForward(hidden_size, intermediate_size)
+            self.input_layernorm = nn.RMSNorm(hidden_size, axes=-1, bias=False)
+            self.post_attention_layernorm = nn.RMSNorm(hidden_size, axes=-1, bias=False)
+
+        def forward(self, hidden_states: nn.Tensor):
+            hidden_states += self.self_attn(self.input_layernorm(hidden_states))
+            hidden_states += self.mlp(self.post_attention_layernorm(hidden_states))
+            return hidden_states
+
+    class ExpectedDecoderLayer(nn.Module):
+        """A Module that internally uses the optimized implementation
+
+        This class is for testing purposes.  After injecting the
+        optimized implementation, the resulting end-to-end SLM/Relax
+        model should be equivalent to this hand-written version.
+        """
+
+        def __init__(self, hidden_size: int, intermediate_size: int):
+            self.self_attn = nn.Identity()  # For the sake of testing
+            self.mlp = OptimizedFeedForward(hidden_size, intermediate_size)
+            self.input_layernorm = nn.RMSNorm(hidden_size, axes=-1, bias=False)
+            self.post_attention_layernorm = nn.RMSNorm(hidden_size, axes=-1, bias=False)
+
+        def forward(self, hidden_states: nn.Tensor):
+            hidden_states += self.self_attn(self.input_layernorm(hidden_states))
+            hidden_states += self.mlp(self.post_attention_layernorm(hidden_states))
+            return hidden_states
+
+    batch_size = 16
+    hidden_size = 4096
+    intermediate_size = 11008
+    dtype = "float32"
+
+    slm_model = DecoderLayer(hidden_size, intermediate_size)
+
+    mutator = nn.replace_implementation(OptimizedFeedForward)
+
+    if when_to_apply_replacement == "slm":
+        slm_model = mutator.visit("", slm_model)
+
+    model_expected = ExpectedDecoderLayer(hidden_size, intermediate_size)
+
+    spec = {
+        "forward": {"hidden_states": nn.spec.Tensor([batch_size, hidden_size], dtype)},
+    }
+
+    relax_expected = model_expected.export_tvm(spec)[0]
+    assert tvm.relax.analysis.well_formed(relax_expected)
+
+    relax_model = slm_model.export_tvm(spec)[0]
+    assert tvm.relax.analysis.well_formed(relax_model)
+
+    if when_to_apply_replacement == "relax":
+        transform = mutator.as_relax_transform()
+        relax_model = transform(relax_model)
+
+    normalize = tvm.ir.transform.Sequential(
+        [
+            # Normalize the IRModule by applying a topological sort within each
+            # dataflow block.  Otherwise, equivalent replacements performed at
+            # a different step of optimization can result in a different order
+            # of intermediates.
+            tvm.relax.transform.TopologicalSort(order="depth-first", direction="from-outputs"),
+            # The SLM exporter produces a trivial `var = dataflow_var`
+            # binding for the output, which should be removed before
+            # validating the output.
+            tvm.relax.transform.CanonicalizeBindings(),
+        ]
+    )
+    relax_model = normalize(relax_model)
+    relax_expected = normalize(relax_expected)
+
+    tvm.ir.assert_structural_equal(relax_model, relax_expected)
+
+
+if __name__ == "__main__":
+    tvm.testing.main()
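
For reviewers, a minimal usage sketch of the API introduced by this patch, distilled from the commit message and the new test. `FeedForward`, `OptimizedFeedForward`, `slm_model`, and `spec` refer to the definitions in the test above, and `apply_at_slm_level` is an illustrative flag mirroring the test's `when_to_apply_replacement` parameter; only `nn.replace_implementation`, `Mutator.visit`, and `as_relax_transform` come from this patch.

```python
from tvm.relax.frontend import nn

# OptimizedFeedForward subclasses FeedForward; both, along with
# slm_model and spec, are assumed to be defined as in the test above.
mutator = nn.replace_implementation(OptimizedFeedForward)

if apply_at_slm_level:
    # Rewrite the SLM module tree in place, then export to Relax.
    slm_model = mutator.visit("", slm_model)
    relax_mod, params = slm_model.export_tvm(spec)
else:
    # Export the unmodified model, then apply the equivalent
    # Relax-to-Relax pass produced by the same mutator.
    relax_mod, params = slm_model.export_tvm(spec)
    relax_mod = mutator.as_relax_transform()(relax_mod)
```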