diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/__init__.py index b8b5b5f22eb0..6533da45885e 100644 --- a/python/tvm/relax/op/__init__.py +++ b/python/tvm/relax/op/__init__.py @@ -38,7 +38,7 @@ from . import memory from . import nn -# Operator gradient functions +# Register operator gradient functions from . import _op_gradient diff --git a/python/tvm/relax/op/_op_gradient.py b/python/tvm/relax/op/_op_gradient.py index d57bcc8621fa..17bafe0a37bf 100644 --- a/python/tvm/relax/op/_op_gradient.py +++ b/python/tvm/relax/op/_op_gradient.py @@ -87,14 +87,14 @@ def _get_dtype(expr: Expr) -> str: return dtype -def _fit_shape(bb: BlockBuilder, expr: Expr, target: Expr) -> Expr: +def _fit_shape(bb: BlockBuilder, input_grad: Expr, input: Expr) -> Expr: """When expr and target has the same shape, return expr; otherwise return `collapse_sum_to(expr, target.struct_info.shape)`. Will use BlockBuilder to normalize expr first. """ - target_shape = _get_shape(target) - expr_sinfo = _get_shape(bb.normalize(expr)).struct_info + target_shape = _get_shape(input) + expr_sinfo = _get_shape(bb.normalize(input_grad)).struct_info target_sinfo = target_shape.struct_info assert isinstance(expr_sinfo, ShapeStructInfo) assert isinstance(target_sinfo, ShapeStructInfo) @@ -109,9 +109,9 @@ def _check_shape_equal(lhs: ShapeStructInfo, rhs: ShapeStructInfo): return True return ( - expr + input_grad if _check_shape_equal(expr_sinfo, target_sinfo) - else collapse_sum_to(expr, target_shape) + else collapse_sum_to(input_grad, target_shape) ) @@ -250,15 +250,14 @@ def maximum_grad( `z = relax.maximum(x, y)` Backward: - Returns `[z_grad * (where(x < y, 0, 1)), z_grad * (where(x >= y, 0, 1))]`. + Returns `[where(x < y, 0, z_grad), where(x >= y, 0, z_grad)]`. """ x = orig_call.args[0] y = orig_call.args[1] - one = relax.const(1, _get_dtype(x)) zero = relax.const(0, _get_dtype(x)) return [ - where(less(x, y), zero, one) * output_grad, - where(greater_equal(x, y), zero, one) * output_grad, + where(less(x, y), zero, output_grad), + where(greater_equal(x, y), zero, output_grad), ] @@ -275,15 +274,14 @@ def minimum_grad( `z = relax.minimum(x, y)` Backward: - Returns `[z_grad * (where(x >= y, 0, 1)), z_grad * (where(x < y, 0, 1))]`. + Returns `[where(x >= y, 0, z_grad), where(x < y, 0, z_grad)]`. """ x = orig_call.args[0] y = orig_call.args[1] - one = relax.const(1, _get_dtype(x)) zero = relax.const(0, _get_dtype(x)) return [ - where(greater_equal(x, y), zero, one) * output_grad, - where(less(x, y), zero, one) * output_grad, + where(greater_equal(x, y), zero, output_grad), + where(less(x, y), zero, output_grad), ] @@ -1030,12 +1028,11 @@ def relu_grad( `y = relax.relu(x)` Backward: - Returns `[y_grad * (where(x < 0, 0, 1))]`. + Returns `[where(x < 0, 0, y_grad)]`. 
""" x = orig_call.args[0] - one = relax.const(1, _get_dtype(x)) zero = relax.const(0, _get_dtype(x)) - return [where(less(x, zero), zero, one) * output_grad] + return [where(less(x, zero), zero, output_grad)] @register_gradient("relax.nn.silu") @@ -1090,10 +1087,10 @@ def log_softmax_grad( `y = relax.log_softmax(x, axis)` Backward: - Returns `[y_grad - sum(y_output_grad, axis, keepdims=True) * softmax(x)]` + Returns `[y_grad - sum(y_grad, axis, keepdims=True) * exp(y)]` """ - x_softmax = exp(orig_var) - return [(output_grad - sum(output_grad, orig_call.attrs.axis, True) * x_softmax)] + y_exp = exp(orig_var) + return [(output_grad - sum(output_grad, orig_call.attrs.axis, True) * y_exp)] @register_gradient("relax.nn.cross_entropy_with_logits") diff --git a/python/tvm/relax/op/grad/grad.py b/python/tvm/relax/op/grad/grad.py index e1f15918766e..2218db223208 100644 --- a/python/tvm/relax/op/grad/grad.py +++ b/python/tvm/relax/op/grad/grad.py @@ -43,6 +43,59 @@ def no_grad(input: Expr) -> Expr: return _ffi_api.no_grad(input) # type: ignore +def start_checkpoint(input: Expr) -> Expr: + """Mark the start of the checkpoint stage. The computation between start_checkpoint and + end_checkpoint will be marked as the checkpoint stage. + + Rather than storing all intermediate activations of the entire computation graph for + computing backward, the checkpointed stage does not save intermediate activations, and instead + recomputes them in backward process. + + For instance, + ``` + a = relax.Var("a", relax.TensorStructInfo((2, 2), "float32")) + b = relax.Var("b", relax.TensorStructInfo((2, 2), "float32")) + c = a * 2 + d = b * 2 + c_cp = start_checkpoint(c) + d_cp = start_checkpoint(d) + e = c_cp + d_cp + e_out = end_checkpoint(e) + ``` + Then `e` will be recomputed in the backward stage. + + See tvm.relax.transform.Gradient, tvm.relax.testing.nn.checkpoint, + tvm.relax.op.grad.end_checkpoint for more information. + + Parameters + ---------- + input : relax.Expr + The tensor marking the input of the checkpoint stage. + + Returns + ------- + result : relax.Expr + The same tensor as the input. + """ + return _ffi_api.start_checkpoint(input) # type: ignore + + +def end_checkpoint(input: Expr) -> Expr: + """Mark the end of checkpoint stage. See tvm.relax.op.grad.start_checkpoint. + + Parameters + ---------- + input : relax.Expr + The output of the checkpoint stage. + + Returns + ------- + result : relax.Expr + The same tensor as the input. + """ + return _ffi_api.end_checkpoint(input) # type: ignore + + def nll_loss_backward( output_grad: Expr, predictions: Expr, diff --git a/python/tvm/relax/testing/nn.py b/python/tvm/relax/testing/nn.py index a43dfab56e38..ba7ca2653fa8 100644 --- a/python/tvm/relax/testing/nn.py +++ b/python/tvm/relax/testing/nn.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
-# pylint: disable=redefined-builtin +# pylint: disable=redefined-builtin, invalid-name """PyTorch-like nn.Module API for constructing workloads.""" @@ -24,6 +24,7 @@ import tvm from tvm import relax, topi, tir +from tvm.relax.op.grad.grad import end_checkpoint, start_checkpoint def emit(expr: relax.Expr, name_hint: str = "") -> relax.Var: @@ -34,6 +35,126 @@ def emit_te(func: Callable, *args: Any, **kwargs: Any) -> relax.Var: return relax.BlockBuilder.current().emit_te(func, *args, **kwargs) +def checkpoint( + func: Callable, *args: Any, **kwargs: Any +) -> Union[relax.Var, List[relax.Var], List[Any]]: + """Mark function(*args, **kwargs) should be computed in a checkpointed manner during backward. + + To be specific, args and kwargs will be checkpointed, and func(*args, **kwargs) will be + recomputed in the backward stage. + """ + args = [start_checkpoint(v) if isinstance(v, relax.Expr) else v for v in args] + kwargs = {k: start_checkpoint(v) if isinstance(v, relax.Expr) else v for k, v in kwargs.items()} + result = func(*args, **kwargs) + if isinstance(result, (list, tuple)): + result = [end_checkpoint(v) if isinstance(v, relax.Expr) else v for v in result] + else: + assert isinstance(result, relax.Expr) + result = end_checkpoint(result) + return result + + +def emit_checkpoint( + func: Callable, *args: Any, **kwargs: Any +) -> Union[relax.Var, List[relax.Var], List[Any]]: + """Mark function(*args, **kwargs) should be computed in a checkpointed manner during backward. + + To be specific, args and kwargs will be checkpointed, and func(*args, **kwargs) will be + recomputed in the backward stage. + + This interface will additionally emit the exprs marked with start_checkpoint() and + end_checkpoint() with suffix "_scp" and "_ecp" respectively, for easily understanding the + result tvmscript. + """ + bb = relax.BlockBuilder.current() + args = [ + bb.emit(start_checkpoint(v), v.name_hint + "_scp") if isinstance(v, relax.Var) else v + for v in args + ] + kwargs = { + k: bb.emit(start_checkpoint(v), v.name_hint + "_scp") if isinstance(v, relax.Var) else v + for k, v in kwargs.items() + } + result = func(*args, **kwargs) + if isinstance(result, (list, tuple)): + result = list(result) + for i, v in enumerate(result): + if isinstance(v, relax.Expr): + if not isinstance(v, relax.Var): + v = bb.emit(v) + result[i] = bb.emit(end_checkpoint(v), v.name_hint + "_ecp") + else: + assert isinstance(result, relax.Expr) + result_emit = bb.emit(result) + result = bb.emit(end_checkpoint(result_emit), result_emit.name_hint + "_ecp") + + return result + + +def emit_checkpoint_sequential( + functions: List[Callable], + segments: Union[int, List[int]], + input: relax.Var, + checkpoint_last: bool = False, +) -> Union[relax.Var, List[relax.Var], List[Any]]: + """A helper function for checkpointing sequential models. This interface has similar purpose + as torch.utils.checkpoint.checkpoint_sequential. + + Sequential models execute a list of modules/functions in order (sequentially). Therefore, we + can divide such a model in various segments and checkpoint each segment. By default, we will + checkpoint all segments except the last, meaning their inputs will be saved from the forward + stage and they will be recomputed in the backward stage. + + Parameters + ---------- + functions : List[Callable] + The list of functions to be executed sequentially. + + segments : Union[int, List[int]] + The segments. 
If segments is int `n`, functions will be evenly divided into `n` segments; + if segments is a list of ints, it marks the start of every segment. + + input : relax.Var + The input of the first function. + + checkpoint_last : bool + Whether the last segment will be checkpointed. Default: False + + Returns + ------- + output : Union[relax.Var, List[relax.Var], List[Any]] + The emited output of the last function. + """ + bb = relax.BlockBuilder.current() + + def run_function(start, end, functions): + def forward(input): + for j in range(start, end): + input = functions[j](input) + return input + + return forward + + n = len(functions) + if not isinstance(segments, list): + segments = list(range(0, n, n // segments)) + [n] + if segments[-1] != n: + segments = segments + [n] + + assert len(segments) >= 2 + + for i in range(len(segments) - 1): + if i == len(segments) - 2 and not checkpoint_last: + input = run_function(segments[i], segments[i + 1], functions)(input) + else: + input = emit_checkpoint(run_function(segments[i], segments[i + 1], functions), input) + + assert isinstance(input, relax.Expr) + if not isinstance(input, relax.Var): + input = bb.emit(input) + return input + + def _try_unique_name(name: str): """Attempt to uniquify the name diff --git a/python/tvm/relax/training/utils.py b/python/tvm/relax/training/utils.py index 69fa85043e4b..bf9e937457b6 100644 --- a/python/tvm/relax/training/utils.py +++ b/python/tvm/relax/training/utils.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# pylint: disable=invalid-name +# pylint: disable=invalid-name, unused-argument """Utility functions for relax training.""" from typing import Optional, Callable @@ -24,7 +24,7 @@ from tvm._ffi.registry import register_func from tvm.relax.block_builder import BlockBuilder -from ..expr import Function +from ..expr import Function, Var, Call from . import _ffi_api @@ -189,13 +189,13 @@ def register(func: Callable): # It will return the emitted var. 
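A rough sketch of driving the `emit_checkpoint_sequential` helper defined above (the stage functions, shapes, and segment count are made up): four elementwise stages split into two segments, with every segment except the last checkpointed.

```
from tvm import relax
from tvm.relax.testing import nn

bb = relax.BlockBuilder()
x = relax.Var("x", relax.TensorStructInfo((4, 4), "float32"))
# each stage is a Callable taking and returning a relax.Expr
stages = [(lambda inp: relax.op.add(inp, inp)) for _ in range(4)]

with bb.function("main", [x]):
    with bb.dataflow():
        # two segments: stages [0, 2) are checkpointed, stages [2, 4) are not
        out = nn.emit_checkpoint_sequential(stages, 2, x)
        gv = bb.emit_output(relax.op.sum(out))
    bb.emit_func_output(gv)

mod = bb.get()
```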
def handler( - builder: BlockBuilder, output_grad_var: relax.Var, call_tir: relax.Call + orig_var: Var, call_tir_with_grad: Call, output_grad: Var, ctx: BlockBuilder ) -> relax.Expr: - return builder.emit_te( + return ctx.emit_te( func, - output_grad_var, - *call_tir.args[1], - **call_tir.attrs.te_grad_kwargs, + output_grad, + *call_tir_with_grad.args[1], + **call_tir_with_grad.attrs.te_grad_kwargs, primfunc_name_hint=te_grad_name, ) diff --git a/python/tvm/relax/transform/legalize_ops/grad.py b/python/tvm/relax/transform/legalize_ops/grad.py index f5c295afc683..1d527bea6ae6 100644 --- a/python/tvm/relax/transform/legalize_ops/grad.py +++ b/python/tvm/relax/transform/legalize_ops/grad.py @@ -29,6 +29,16 @@ def _no_grad(bb: BlockBuilder, call: Call) -> Expr: return call.args[0] +@register_legalize("relax.grad.start_checkpoint") +def _start_checkpoint(bb: BlockBuilder, call: Call) -> Expr: + return call.args[0] + + +@register_legalize("relax.grad.end_checkpoint") +def _end_checkpoint(bb: BlockBuilder, call: Call) -> Expr: + return call.args[0] + + @register_legalize("relax.grad.nll_loss_backward") def _grad_nll_loss_backward(bb: BlockBuilder, call: Call) -> Expr: # topi.sum don't support zero-dim x @@ -51,9 +61,8 @@ def te_nll_loss_backward(output_grad, predictions, targets, weights, reduction, if reduction == "sum": output_grad = topi.broadcast_to(output_grad, targets.shape) elif reduction == "mean": - output_grad = topi.divide( - topi.broadcast_to(output_grad, targets.shape), topi_sum_extend(all_weights) - ) + weight_sum = topi_sum_extend(all_weights) + output_grad = topi.divide(topi.broadcast_to(output_grad, targets.shape), weight_sum) # handle no batch if predictions.ndim == 1: diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py index d8d11a50d824..d1d337a8c026 100644 --- a/python/tvm/relax/transform/transform.py +++ b/python/tvm/relax/transform/transform.py @@ -75,6 +75,10 @@ def main_adjoint(original_parameters): R.output(original_outputs, grad_1, grad_2, ...) return (original_return_value, (grad_1, grad_2, ...)) + This AD pass also supports checkpointing as described in + "Training deep nets with sublinear memory cost." - Chen, Tianqi, et al. (2016). + See tvm.relax.testing.nn.checkpoint for more details. 
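A minimal sketch of invoking the pass (the module and shapes mirror the simple unit test further below; nothing here is specific to checkpointing): Gradient adds a public `main_adjoint` function returning the original output together with a tuple of parameter adjoints.

```
from tvm import relax
from tvm.script import ir as I, relax as R

@I.ir_module
class Module:
    @R.function
    def main(x: R.Tensor((3, 3), "float32")):
        with R.dataflow():
            gv = R.sum(x)
            R.output(gv)
        return gv

After = relax.transform.Gradient("main")(Module)
After.show()  # now contains both "main" and the generated "main_adjoint"
```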
+ Parameters ---------- func_name : str diff --git a/src/relax/op/tensor/grad.cc b/src/relax/op/tensor/grad.cc index 2fef2d09b9ec..6f3068446030 100644 --- a/src/relax/op/tensor/grad.cc +++ b/src/relax/op/tensor/grad.cc @@ -42,10 +42,55 @@ StructInfo InferStructInfoNoGrad(const Call& call, const BlockBuilder& ctx) { } TVM_REGISTER_OP("relax.grad.no_grad") - .set_num_inputs(0) + .set_num_inputs(1) + .add_argument("x", "Expr", "The corresponding input tensor.") .set_attr("FInferStructInfo", InferStructInfoNoGrad) .set_attr("FPurity", Bool(true)); +/* relax.grad.start_checkpoint */ +Expr start_checkpoint(Expr input) { + static const Op& op = Op::Get("relax.grad.start_checkpoint"); + return Call(op, {std::move(input)}, {}, {}); +} + +TVM_REGISTER_GLOBAL("relax.op.grad.start_checkpoint").set_body_typed(start_checkpoint); + +StructInfo InferStructInfoStartCheckpoint(const Call& call, const BlockBuilder& ctx) { + if (!call->args[0].as()) { + ctx->ReportFatal(Diagnostic::Error(call) + << "The argument of relax.op.grad.start_checkpoint should be a Var."); + } + return GetStructInfo(call->args[0]); +} + +TVM_REGISTER_OP("relax.grad.start_checkpoint") + .set_num_inputs(1) + .add_argument("x", "Expr", "The tensor marking the input of the checkpoint stage.") + .set_attr("FInferStructInfo", InferStructInfoStartCheckpoint) + .set_attr("FPurity", Bool(true)); + +/* relax.grad.end_checkpoint */ +Expr end_checkpoint(Expr input) { + static const Op& op = Op::Get("relax.grad.end_checkpoint"); + return Call(op, {std::move(input)}, {}, {}); +} + +TVM_REGISTER_GLOBAL("relax.op.grad.end_checkpoint").set_body_typed(end_checkpoint); + +StructInfo InferStructInfoEndCheckpoint(const Call& call, const BlockBuilder& ctx) { + if (!call->args[0].as()) { + ctx->ReportFatal(Diagnostic::Error(call) + << "The argument of relax.op.grad.end_checkpoint should be a Var."); + } + return GetStructInfo(call->args[0]); +} + +TVM_REGISTER_OP("relax.grad.end_checkpoint") + .set_num_inputs(1) + .add_argument("x", "Expr", "The output of the checkpoint stage.") + .set_attr("FInferStructInfo", InferStructInfoEndCheckpoint) + .set_attr("FPurity", Bool(true)); + /* relax.grad.nll_loss_backward */ Expr nll_loss_backward(Expr output_grad, Expr predictions, Expr targets, Optional weights, String reduction, int ignore_index) { diff --git a/src/relax/transform/gradient.cc b/src/relax/transform/gradient.cc index 9a271a3ebe7b..4fd18386703a 100644 --- a/src/relax/transform/gradient.cc +++ b/src/relax/transform/gradient.cc @@ -25,6 +25,7 @@ * with respect to the only return value of the function, which needs to be scalar. */ +#include #include #include #include @@ -40,10 +41,235 @@ namespace tvm { namespace relax { +// We will use NestedMsg to handle adjoint updates involving tuple handling using AdjointMsg = NestedMsg; +using VarIdSet = std::unordered_set; -// A tool class for GradientMutator -// Visit the forward bindings and generate the backward bindings +// Used in CallTIRWithGradCollector. call_tir -> call_tir_with_grad +using CallTIRWithGradInfo = std::unordered_map; + +/*! + * \brief Collect all call_tir_with_grad nodes, transform them into call_tir nodes, and collect the + * te_grad_name and te_grad_kwargs information. + */ +class CallTIRWithGradEliminator : private ExprMutator { + public: + /*! + * \brief Collect all variables that needs to be checkpointed, and remove the start_checkpoint + * and the end_checkpoint bindings. 
+ * + * \param func The original function + * \return The function with all start_checkpoint and end_checkpoint bindings removed, and a + * VarIdSet containing all checkpointed vars. + */ + static Function Transform(const Function& func) { + return Downcast(CallTIRWithGradEliminator().VisitExpr(func)); + } + + private: + using ExprMutator::VisitExpr_; + + Expr VisitExpr_(const CallNode* call_node) final { + if (call_node->op != Op::Get("relax.call_tir_with_grad")) { + return ExprMutator::VisitExpr_(call_node); + } + return Call(Op::Get("relax.call_tir"), call_node->args, {}, call_node->sinfo_args, + call_node->span); + } +}; + +/*! + * \brief Collect all variables that needs to be checkpointed, and remove the start_checkpoint + * and the end_checkpoint bindings. + * + * Here we have some principles to determine which var should be checkpointed: + * 1. Input of the function is checkpointed + * 2. For var x marked with start_checkpoint() (wrapped by start_checkpoint), it means x is an input + * to some checkpoint function. So var x is checkpointed + * 3. For other var x , find its predecessor path. + * a. If every predecessor path is marked with end_checkpoint(), x is checkpointed + * b. Else, there must exists a predecessor path marked with start_checkpoint(). So x is not + * checkpointed + */ +class CheckpointCollector : private ExprMutator { + public: + /*! + * \brief Collect all variables that needs to be checkpointed, and remove the start_checkpoint + * and the end_checkpoint bindings. + * + * \param func The original function + * \return The function with all start_checkpoint and end_checkpoint bindings removed, and a + * VarIdSet containing all checkpointed vars. + */ + static std::pair Collect(const Function& func) { + auto collector = CheckpointCollector(); + return std::make_pair(Downcast(collector.VisitExpr(func)), collector.checkpoints_); + } + + private: + Expr VisitExpr_(const FunctionNode* func) final { + for (auto var : func->params) { + checkpoints_.insert(var->vid); + } + + return ExprMutator::VisitExpr_(func); + } + + void VisitBinding(const Binding& binding) { + static const auto s_cp = Op::Get("relax.grad.start_checkpoint"); + static const auto e_cp = Op::Get("relax.grad.end_checkpoint"); + + // If every variable that the variable of binding relys on is either + // 1) the output of end_checkpoint; 2) checkpointed + // then the variable of binding will be checkpointed + auto var_binding = binding.as(); + ICHECK(var_binding); + + auto value_call = var_binding->value.as(); + if (!value_call || (value_call->op != s_cp && value_call->op != e_cp)) { + bool all_inner_var_checkpointed = true; + PostOrderVisit(var_binding->value, [this, &all_inner_var_checkpointed](const Expr& expr) { + if (auto var = expr.as()) { + all_inner_var_checkpointed &= + (checkpoints_.count(var->vid) != 0 || e_vars_.count(var->vid) != 0); + } + }); + + if (all_inner_var_checkpointed) { + checkpoints_.insert(var_binding->var->vid); + } + } + + ExprMutator::VisitBinding(binding); + } + + // mark vars to be checkpointed, and eliminate bindings with checkpoint calls + void VisitBinding_(const VarBindingNode* binding, const CallNode* value) final { + static const auto s_cp = Op::Get("relax.grad.start_checkpoint"); + static const auto e_cp = Op::Get("relax.grad.end_checkpoint"); + + if (value->op == s_cp || value->op == e_cp) { + // Eliminate the binding + auto var = value->args[0].as(); + ICHECK(var) << "The first argument of relax.grad.start_checkpoint and " + "relax.grad.end_checkpoint should be a Var"; + 
// var might already be remapped. Find the original var + auto orig_var = Downcast(ExprMutator::VisitExpr(GetRef(var))); + // Add remapping from binding->var to new_var + if (!binding->var.as() && var->IsInstance()) { + // For output binding, emit a dummy binding + this->var_remap_[binding->var->vid] = builder_->EmitOutput(orig_var, orig_var->name_hint()); + } else { + this->var_remap_[binding->var->vid] = orig_var; + } + + if (value->op == s_cp) { + // mark the original var to be checkpointed + checkpoints_.insert(orig_var->vid); + } else if (value->op == e_cp) { + e_vars_.insert(binding->var->vid); + } + } else { + ExprMutator::VisitBinding_(binding, value); + } + } + + VarIdSet checkpoints_; + VarIdSet e_vars_; +}; + +/*! + * \brief A tool class for BackwardBindingGenerator + * Generate the checkpoint bindings. To be specific, in the backward process, we need to use vars + * computed in the forward process. Those vars contained in the given checkpoints array, and the + * inputs of the function, will be used as is; other vars will be computed again (this will + * generate bindings) using the checkpoint vars. + */ +class CheckpointGenerator : private ExprMutator { + public: + /*! + * \brief Generate the checkpoint bindings for BackwardBindingGenerator + * + * \param builder The BlockBuilder of BackwardBindingGenerator, used to generate bindings + * \param orig_params The parameters of the forward function + * \param forward_block The forward DataflowBlock + * \param checkpoints The checkpointed vars. checkpoints being empty means all Vars are + * checkpointed + */ + CheckpointGenerator(const BlockBuilder& builder, const Array& orig_params, + const DataflowBlock& forward_block, const VarIdSet& checkpoints) + : builder_(builder) { + // func params will always be checkpointed + for (auto var : orig_params) { + checkpoint_map_.Set(var, var); + } + + for (auto binding : forward_block->bindings) { + auto* var_binding = binding.as(); + CHECK(var_binding) << "Now only support VarBindingNode"; + auto var = var_binding->var; + binding_map_.Set(var, var_binding->value); + if (checkpoints.count(var->vid)) { + checkpoint_map_.Set(var, var); + } + } + } + + // Receives the forward binding var and value, returns the checkpointed binding var and value. 
+ std::pair UpdateBinding(const Var& var, const Expr& value) { + Expr new_value = VisitExpr(value); + auto it = checkpoint_map_.find(var); + if (it != checkpoint_map_.end()) { + return std::make_pair((*it).second, new_value); + } + auto new_var = builder_->Emit(new_value, var->name_hint() + "_cp"); + checkpoint_map_.Set(var, new_var); + return std::make_pair(new_var, new_value); + } + + private: + using ExprMutator::VisitExpr_; + + // Visit the use-site of a defined Var + Expr VisitExpr_(const VarNode* op) final { return VisitVar(GetRef(op)); } + + // Visit the use-site of a defined DataflowVar + Expr VisitExpr_(const DataflowVarNode* op) final { return VisitVar(GetRef(op)); } + + Expr VisitVar(const Var& var) { + auto it = checkpoint_map_.find(var); + if (it != checkpoint_map_.end()) { + return (*it).second; + } + Var new_var = builder_->Emit(VisitExpr(binding_map_[var]), var->name_hint() + "_cp"); + checkpoint_map_.Set(var, new_var); + return new_var; + } + + // The only purpose of this function is create a new expr for Call node + // to pass the structual equal check + Expr VisitExpr_(const CallNode* call_node) final { + Expr new_op = this->VisitExpr(call_node->op); + + tvm::Array call_args; + for (Expr arg : call_node->args) { + Expr new_arg = this->VisitExpr(arg); + call_args.push_back(new_arg); + } + return Call(new_op, call_args, call_node->attrs, call_node->sinfo_args); + } + + BlockBuilder builder_; + // The mapping from the forward vars to the checkpoint vars. + Map checkpoint_map_; + // The mapping from the forward vars to their bindings, used to generate checkpoint bindings + Map binding_map_; +}; + +/*! + * \brief A tool class for GradientMutator + * Visit the forward bindings and generate the backward bindings + */ class BackwardBindingGenerator : private ExprVisitor { public: /*! @@ -52,23 +278,26 @@ class BackwardBindingGenerator : private ExprVisitor { * \param builder The BlockBuilder of GradientMutator, used to generate bindings * \param forward_block The forward DataflowBlock * \param require_grads The Var list to differentiate w.r.t. + * \param orig_params The params of the forward function. Used for checkpointing * \param target_var The target Var to differentiate * \param orig_return_value The original return value of the function. The new return value is a - * 2-tuple, containing the original return value, and a tuple of the adjoints of parameters. + * 2-tuple, containing the original return value, and a tuple of the adjoints of parameters + * \param checkpoints The checkpointed vars. checkpoints being empty means all Vars are + * checkpointed * \return The return expr of new adjoint function. */ static Expr Generate(const BlockBuilder& builder, const DataflowBlock& forward_block, const Array& require_grads, const Var& target_var, - const Expr& orig_return_value) { - BackwardBindingGenerator generator(builder); + const Array& orig_params, const Expr& orig_return_value, + const VarIdSet& checkpoints) { + CheckpointGenerator checkpoint_generator(builder, orig_params, forward_block, checkpoints); + BackwardBindingGenerator generator(builder, checkpoint_generator); - // Initialize the adjoint of target_var as ones op. We have already check the target. + // Initialize the adjoint of target_var as ones op. We have already checked the target. 
auto* target_sinfo = GetStructInfoAs(target_var); - const Expr& target_adjoint = ones(target_sinfo->shape.value(), target_sinfo->dtype); - UpdateStructInfo(target_adjoint, GetRef(target_sinfo)); - generator.adjoint_msg_map_.Set(target_var, AdjointMsg(target_adjoint)); + generator.UpdateAdjoint(target_var, ones(target_sinfo->shape.value(), target_sinfo->dtype)); - // We do reverse-mode ad, so visit bindings backwards + // Do reverse-mode ad, so visit bindings backwards for (auto it = forward_block->bindings.rbegin(); it != forward_block->bindings.rend(); ++it) { generator.VisitBinding(*it); } @@ -77,29 +306,26 @@ class BackwardBindingGenerator : private ExprVisitor { } private: - explicit BackwardBindingGenerator(const BlockBuilder& builder) : builder_(builder) {} + explicit BackwardBindingGenerator(const BlockBuilder& builder, + const CheckpointGenerator& checkpoint_generator) + : builder_(builder), checkpoint_generator_(checkpoint_generator) {} void VisitBinding(const Binding& binding) final { // TODO(chaofan, yixin): support other types of bindings - CHECK(binding->IsInstance()) << "now only support VarBindingNode"; + CHECK(binding->IsInstance()) << "Now only support VarBindingNode"; auto* var_binding = binding.as(); - auto it = adjoint_msg_map_.find(var_binding->var); - if (it == adjoint_msg_map_.end()) { - // This var is not used in the following bindings + if (adjoint_var_map_.count(var_binding->var) == 0) { + // Optimization: this var is not used in the following bindings return; } - // Meet the definition of binding->var - // Create the adjoint var and bind the adjoint value to it - EmitAdjoint(var_binding->var, (*it).second, true); - Expr value = var_binding->value; // TODO(chaofan, yixin): support other types of binding values CHECK(value->IsInstance() || value->IsInstance() || value->IsInstance() || value->IsInstance() || value->IsInstance()) - << "now does not support the type of binding value: " << value; + << "Now does not support the type of binding value: " << value; ExprVisitor::VisitBinding_(var_binding); } @@ -114,44 +340,66 @@ class BackwardBindingGenerator : private ExprVisitor { static const OpAttrMap& gradient_op_map = Op::GetAttrMap("FPrimalGradient"); + static const constexpr char* te_grad_func_prefix = "tvm.relax.te_grad._register."; Var adjoint_var = adjoint_var_map_[binding->var]; const Op& call_op = Downcast(call->op); + // Support for checkpointing + auto [checkpoint_var, checkpoint_call] = + checkpoint_generator_.UpdateBinding(binding->var, GetRef(call)); + if (call_op == Op::Get("relax.call_tir")) { - LOG(FATAL) << "Differentiation of call_tir op is not supported yet."; + LOG(FATAL) << "Differentiation of call_tir op without registering corresponding gradient " + "function is not supported yet."; + } else if (call_op == Op::Get("relax.call_tir_with_grad")) { + // tir gradient registering + auto te_grad_name = call->attrs.as()->te_grad_name; + auto* grad_func = tvm::runtime::Registry::Get(te_grad_func_prefix + te_grad_name); + CHECK(grad_func) << "TIR gradient function " << te_grad_name << " is not registered"; + Var partials = + (*grad_func)(checkpoint_var, Downcast(checkpoint_call), adjoint_var, builder_); + Tuple args = Downcast(call->args[1]); + auto* tuple_sinfo = GetStructInfoAs(partials); + if (!tuple_sinfo) { + // result_var is a tensor + ICHECK(args->fields.size() == 1); + UpdateAdjoint(args->fields[0], partials); + } else { + ICHECK(args->fields.size() == tuple_sinfo->fields.size()); + for (int i = 0; i < static_cast(args->fields.size()); ++i) { + 
UpdateAdjoint(args->fields[i], TupleGetItem(partials, i)); + } + } } else { - Array partials = - gradient_op_map[call_op](binding->var, GetRef(call), adjoint_var, builder_); + const Array& partials = gradient_op_map[call_op]( + checkpoint_var, Downcast(checkpoint_call), adjoint_var, builder_); ICHECK(partials.size() == call->args.size()) << "partials number != inputs number"; for (size_t i = 0; i < partials.size(); ++i) { - if (IsCallNoGrad(partials[i])) { // no grad: don't update + Expr partial = partials[i]; + if (IsCallNoGrad(partial)) { // no grad: don't update continue; } - if (!partials[i]->struct_info_.defined()) { - UpdateStructInfo(partials[i], GetStructInfo(call->args[i])); - } - UpdateAdjoint(call->args[i], partials[i]); + UpdateAdjoint(call->args[i], partial); } } } // For Tuple nodes, we would iterate over the input tuple and update adjoint exprs for each input // e.g. - // a = (b, c) + // a = (b, c) --> // b_adjoint += a_adjoint_var[0], c_adjoint += a_adjoint_var[1] - // a = ((b, c), d) + // + // a = ((b, c), d) --> // b_adjoint += a_adjoint_var[0][0], c_adjoint += a_adjoint_var[0][1], // d_adjoint += a_adjoint_var[1] - // - // Here we use adjoint_var to simplify calculation void VisitBinding_(const VarBindingNode* binding, const TupleNode* tuple) final { UpdateAdjoint(GetRef(tuple), adjoint_var_map_[binding->var]); } // For TupleGetItem nodes, we do a partial update // e.g. - // b = a[0] + // b = a[0] --> // a_adjoint[0] += b_adjoint_var // If a_adjoint does not exist, we would create a zeros tuple as a_adjoint first, and then add void VisitBinding_(const VarBindingNode* binding, const TupleGetItemNode* tuple_get_item) final { @@ -161,14 +409,16 @@ class BackwardBindingGenerator : private ExprVisitor { ICHECK(tuple_sinfo) << "The tuple field of a TupleGetItem must has a TupleStructInfo"; const Var& tuple_var = Downcast(tuple_get_item->tuple); - if (adjoint_msg_map_.count(tuple_var) == 0) { - const AdjointMsg& init = InitZerosAdjointNested(GetRef(tuple_sinfo)); - adjoint_msg_map_.Set(tuple_var, init); + if (adjoint_var_map_.count(tuple_var) == 0) { + auto nested_zeros = Downcast(NestedZeros(GetRef(tuple_sinfo))); + auto tuple_fields = nested_zeros->fields; + tuple_fields.Set(tuple_get_item->index, adjoint_var_map_[binding->var]); + EmitAdjoint(tuple_var, Tuple(tuple_fields), false); + } else { + Expr updated_adjoint = AddInTuple(adjoint_var_map_[tuple_var], tuple_get_item->index, + adjoint_var_map_[binding->var]); + EmitAdjoint(tuple_var, updated_adjoint, false); } - - adjoint_msg_map_.Set(tuple_var, - AddInAdjointMsg(adjoint_msg_map_[tuple_var], tuple_get_item->index, - ExprToAdjointMsg(adjoint_var_map_[binding->var]))); } // For assign nodes, we add the adjoint of output to the adjoint of input @@ -183,16 +433,24 @@ class BackwardBindingGenerator : private ExprVisitor { // For constant nodes, we do not have to handle it because it does not contribute to the adjoint void VisitBinding_(const VarBindingNode* binding, const ConstantNode* var) final { return; } - // Add partial (Expr type) to the adjoint of expr + // Add partial to the adjoint of expr + // expr may be a argument of a func call / tuple definition. Its type can be + // 1) var 2) constant (in this case, the adjoint will not be updated) + // 3) (maybe nested) tuple of vars / constant + // + // We use NestedMsg to simplify handling (nested) tuples. That requires converting partial from + // expr to NestedMsg or backwards. 
void UpdateAdjoint(const Expr& expr, const Expr& partial) { - DecomposeNestedMsg(expr, ExprToAdjointMsg(partial), [&](Expr leaf, AdjointMsg msg) { + AdjointMsg partial_msg = ExprToAdjointMsg(builder_->Normalize(partial)); + DecomposeNestedMsg(expr, partial_msg, [&](Expr leaf, AdjointMsg msg) { if (leaf->IsInstance()) { const Var& v = Downcast(leaf); - if (adjoint_msg_map_.count(v) == 0) { - adjoint_msg_map_.Set(v, msg); - } else { - adjoint_msg_map_.Set(v, TupleAwareAdd(adjoint_msg_map_[v], msg)); + Expr updated_adjoint_expr = builder_->Normalize(AdjointMsgToExpr(msg)); + auto it = adjoint_var_map_.find(v); + if (it != adjoint_var_map_.end()) { + updated_adjoint_expr = TupleAwareAdd((*it).second, updated_adjoint_expr); } + EmitAdjoint(v, updated_adjoint_expr, false); } else if (leaf->IsInstance()) { // nothing to do } else if (leaf->IsInstance()) { @@ -205,21 +463,6 @@ class BackwardBindingGenerator : private ExprVisitor { }); } - // Transform the adjoint expressed as NestedMsg into adjoint Expr, and then emit it - // If the adjoint is assigned to a DataflowVar (the adjoint corresponds to a non-output binding), - // it would be stored in adjoint_var_map_ for future lookup - Var EmitAdjoint(const Var& source_var, const AdjointMsg& adjoint, bool is_dataflow_var) { - Var adjoint_var; - if (is_dataflow_var) { - adjoint_var = builder_->Emit(AdjointMsgToExpr(adjoint), source_var->name_hint() + "_adjoint"); - adjoint_var_map_.Set(source_var, adjoint_var); - } else { - adjoint_var = - builder_->EmitOutput(AdjointMsgToExpr(adjoint), source_var->name_hint() + "_adjoint"); - } - return adjoint_var; - } - // Handle the return value of the AD function. // Returns the new return value, which would be like: // Tuple(original_return_value, @@ -229,19 +472,29 @@ class BackwardBindingGenerator : private ExprVisitor { Array out_adjoints; for (Var var : require_grads) { - // If the var don't have adjoint msg, it do not contribute to the target - // so its adjoint is zeros - AdjointMsg adjoint = - adjoint_msg_map_.Get(var).value_or(InitZerosAdjointNested(GetStructInfo(var))); - Var adjoint_var = EmitAdjoint(var, adjoint, false); - out_adjoints.push_back(adjoint_var); + // If the var don't have adjoint var, it do not contribute to the target. 
So its adjoint is + // zeros + auto it = adjoint_var_map_.find(var); + if (it == adjoint_var_map_.end()) { + UpdateAdjoint(var, NestedZeros(GetStructInfo(var))); + } + Var adjoint_output_var = EmitAdjoint(var, adjoint_var_map_[var], true); + out_adjoints.push_back(adjoint_output_var); } return Tuple({orig_return_value, Tuple(out_adjoints)}); } - static bool IsCallZeros(const Expr& expr) { - return expr->IsInstance() && Downcast(expr)->op == Op::Get("relax.zeros"); + // Emit the adjoint expr as the name `original_var_name` + "_adjoint" + Var EmitAdjoint(const Var& source_var, const Expr& adjoint, bool is_output) { + Var adjoint_var; + if (is_output) { + adjoint_var = builder_->EmitOutput(adjoint, source_var->name_hint() + "_adjoint_out"); + } else { + adjoint_var = builder_->Emit(adjoint, source_var->name_hint() + "_adjoint"); + adjoint_var_map_.Set(source_var, adjoint_var); + } + return adjoint_var; } static bool IsCallNoGrad(const Expr& expr) { @@ -266,99 +519,139 @@ class BackwardBindingGenerator : private ExprVisitor { }); } - // Create a zeros AdjointMsg with specified struct info - // When sinfo is TupleStructInfo, we would create a nested zeros Tuple - static AdjointMsg InitZerosAdjointNested(const StructInfo& sinfo) { - return MapToNestedMsg(sinfo, [](StructInfo sinfo) { + // Create a zeros Expr with specified struct info + // When sinfo is TupleStructInfo, we would create a (nested) Tuple containing zeros + static Expr NestedZeros(const StructInfo& sinfo) { + AdjointMsg msg = MapToNestedMsg(sinfo, [](StructInfo sinfo) { auto* tensor_sinfo = sinfo.as(); ICHECK(tensor_sinfo) << "The leaf of adjoint should be a Tensor."; - ICHECK(tensor_sinfo->shape.defined()) << "Error: missing shape when building zeros tuple."; + ICHECK(tensor_sinfo->shape.defined()) << "Missing shape when building zeros tuple."; const Expr& init = zeros(tensor_sinfo->shape.value(), tensor_sinfo->dtype); - UpdateStructInfo(init, sinfo); return init; }); + return AdjointMsgToExpr(msg); } - // Return base + increment. A tuple-aware addition. - static AdjointMsg TupleAwareAdd(const AdjointMsg& base, const AdjointMsg& increment) { - return CombineNestedMsg(base, increment, [&](Expr lhs, Expr rhs) { - // a small optimization: a+0=a, 0+a=a. - if (IsCallZeros(lhs)) { - return rhs; - } else if (IsCallZeros(rhs)) { - return lhs; - } - auto* sinfo = GetStructInfoAs(lhs); - ICHECK(sinfo) << "The leaf of adjoint should have StructInfo and be a Tensor."; - ICHECK(GetStructInfoAs(rhs)) - << "The leaf of adjoint should have StructInfo and be a Tensor."; - Expr res = add(lhs, rhs); - UpdateStructInfo(res, GetRef(sinfo)); - return res; - }); + // Return lhs + rhs. Requires lhs and rhs has the same StructInfo. + // Use NestedMsg to handle cases when lhs and rhs are tuples. + static Expr TupleAwareAdd(const Expr& lhs, const Expr& rhs) { + AdjointMsg res = CombineNestedMsg( + ExprToAdjointMsg(lhs), ExprToAdjointMsg(rhs), [](Expr l_leaf, Expr r_leaf) { + auto* sinfo = GetStructInfoAs(l_leaf); + ICHECK(sinfo) << "The leaf of adjoint should have StructInfo and be a Tensor."; + ICHECK(GetStructInfoAs(r_leaf)) + << "The leaf of adjoint should have StructInfo and be a Tensor."; + Expr res = add(l_leaf, r_leaf); + UpdateStructInfo(res, GetRef(sinfo)); + return res; + }); + return AdjointMsgToExpr(res); } // Perform an addition in a specified position of tuple. - // e.g. 
tuple=(a, b, c), index=1, increment=d, then return (a, b+d, c) - static AdjointMsg AddInAdjointMsg(const AdjointMsg& adjoint, int index, - const AdjointMsg& increment) { - ICHECK(adjoint.IsNested()) << "The adjoint should be nested."; - Array arr = adjoint.NestedArray(); - ICHECK(index >= 0 && index < static_cast(arr.size())); - arr.Set(index, TupleAwareAdd(arr[index], increment)); - return AdjointMsg(arr); + // tuple[index] += increment + // Impl: + // Step 1) t1 = tuple[0], t2 = tuple[1], t3 = tuple[2] + // Step 2)t2_new = t2 + increment (TupleAwareAdd) + // Step 3) tuple_new = (t1, t2_new, t3) + static Expr AddInTuple(const Expr& tuple, int index, const Expr& increment) { + auto* sinfo = GetStructInfoAs(tuple); + ICHECK(sinfo) << "The first argument of AddInTuple should have tuple struct info."; + ICHECK(index >= 0 && index < static_cast(sinfo->fields.size())); + Array res; + for (size_t i = 0; i < sinfo->fields.size(); ++i) { + Expr field; + if (const auto* expr_tuple = tuple.as()) { + field = expr_tuple->fields[i]; + } else { + field = TupleGetItem(tuple, i); + } + if (static_cast(i) == index) { + field = TupleAwareAdd(field, increment); + } + res.push_back(field); + } + return Tuple(res); } // The block builder of the corresponding GradientMutator, to emit bindings BlockBuilder builder_; // Forward Var to its adjoint Var Map adjoint_var_map_; - // Forward Var to its adjoint NestedMsg - // We use NestedMsg to save the adjoint information (equivalent to adjoint Expr) - // When emitting, adjoint information will be transformed into adjoint Expr - Map adjoint_msg_map_; + // The generator for checkpoint bindings + CheckpointGenerator checkpoint_generator_; }; class GradientMutator : private ExprMutator { public: static IRModule Transform(IRModule mod, String func_name, Optional> require_grads, int target_index) { - auto* old_func_ptr = mod->Lookup(func_name).as(); - CHECK(old_func_ptr) << func_name << "is not a Relax Function"; - auto old_func = GetRef(old_func_ptr); - - // when require_grads is not specified, it would be set to all params of the function - auto require_grads_value = require_grads.value_or(old_func->params); - - CheckRequireGrads(require_grads_value, old_func->params, func_name); - - Function new_func = CopyWithNewVars(old_func); - // map the parameter list into new params - for (size_t i = 0; i < require_grads_value.size(); ++i) { - int idx = - std::find(old_func->params.begin(), old_func->params.end(), require_grads_value[i]) - - old_func->params.begin(); - require_grads_value.Set(i, new_func->params[idx]); + // Step 1. Copy function + auto* old_func = mod->Lookup(func_name).as(); + CHECK(old_func) << func_name << "is not a Relax Function"; + auto new_func = CopyWithNewVars(GetRef(old_func)); + + // Step 2. Handle the checkpoints and eliminate start_checkpoint and end_checkpoint ops + auto checkpoint_collected = CheckpointCollector::Collect(new_func); + new_func = checkpoint_collected.first; + auto checkpoints = checkpoint_collected.second; + + // Step 3. Collect call_tir_with_grad information + auto tir_grad_collected = CallTIRWithGradEliminator::Transform(new_func); + + // Step 4. 
Handle require_grads + // When require_grads is not specified, it would be set to all params of the function + if (require_grads) { + CheckRequireGrads(require_grads.value(), old_func->params, func_name); } + // then map the parameter list into new params + auto require_grads_value = require_grads.value_or(old_func->params).Map([&](const Var& v) { + return new_func->params[std::find(old_func->params.begin(), old_func->params.end(), v) - + old_func->params.begin()]; + }); - GradientMutator mutator(mod, require_grads_value, target_index); - - // make the adjoint public - auto new_name = func_name + "_adjoint"; - Function new_func_transformed = WithAttr(Downcast(mutator.VisitExpr(new_func)), - tvm::attr::kGlobalSymbol, new_name); - - mutator.builder_->AddFunction(new_func_transformed, new_name); - return mutator.builder_->GetContextIRModule(); + // Step 5. Generate the adjoint function, use RemoveAllUnused to simplify it, and then return + // the IRModule with the adjoint function + return GradientMutator(mod, require_grads_value, target_index, checkpoints) + .AddAdjointFunction(new_func, func_name, true); } private: - GradientMutator(const IRModule& module, const Array& require_grads, int target_index) - : ExprMutator(module), require_grads_(require_grads), target_index_(target_index) {} + GradientMutator(const IRModule& module, const Array& require_grads, int target_index, + const VarIdSet& checkpoints) + : ExprMutator(module), + require_grads_(require_grads), + checkpoints_(checkpoints), + target_index_(target_index) {} + + // Add the adjoint function of func to the IRModule using BlockBuilder + IRModule AddAdjointFunction(const Function& func, const String& func_name, + bool remove_all_unused = true) { + // Step 5.1 forward -> forward + backward + auto new_func = Downcast(VisitExpr(func)); + + // Step 5.2 Convert call_tir_with_grad nodes into call_tir nodes + // because call_tir_with_grad nodes is not actually implemented + new_func = CallTIRWithGradEliminator::Transform(new_func); + + if (remove_all_unused) { + new_func = RemoveAllUnused(new_func); + } + + // Step 5.3 mark the transformed function as public + // because the original function may be public, and have gsymbol attribute as func_name + auto new_func_name = func_name + "_adjoint"; + auto new_func_with_gsymbol = WithAttr(new_func, tvm::attr::kGlobalSymbol, new_func_name); + + // Step 5.4 Add the transformed function to IRModule + builder_->AddFunction(new_func_with_gsymbol, new_func_name); + return builder_->GetContextIRModule(); + } Expr VisitExpr_(const FunctionNode* func) final { CHECK(func->body->IsInstance()) << "The body of the function must be SeqExpr."; + orig_params_ = func->params; Expr new_body = this->VisitExpr(func->body); return Function(func->params, new_body, NullOpt, func->is_pure, func->attrs); @@ -376,7 +669,7 @@ class GradientMutator : private ExprMutator { CheckAndSetTarget(seq_expr->body, target_index_); BindingBlock new_block = this->VisitBindingBlock(seq_expr->blocks[0]); - return SeqExpr({new_block}, this->return_expr_); + return SeqExpr({new_block}, return_expr_); } BindingBlock VisitBindingBlock_(const DataflowBlockNode* block) final { @@ -387,9 +680,9 @@ class GradientMutator : private ExprMutator { } // generate backward bindings and the return value - return_expr_ = BackwardBindingGenerator::Generate(this->builder_, GetRef(block), - this->require_grads_, this->target_var_, - orig_return_expr_); + return_expr_ = BackwardBindingGenerator::Generate(builder_, GetRef(block), + require_grads_, 
target_var_, orig_params_, + orig_return_expr_, checkpoints_); return builder_->EndBlock(); } @@ -435,13 +728,14 @@ class GradientMutator : private ExprMutator { // 3. the type of the input var should be Tensor of floating point dtype, or Tuple of that static void CheckRequireGrads(const Array& require_grads, const Array& func_params, const String& func_name) { - std::unordered_set var_set; + VarIdSet var_set; for (const auto& var : require_grads) { CHECK(std::find(func_params.begin(), func_params.end(), var) != func_params.end()) << "There is no Var named " << var->name_hint() << " in the parameters of the function " << func_name; - CHECK_EQ(var_set.count(var), 0) << "Var " << var->name_hint() << " appears more than once"; - var_set.emplace(var); + CHECK_EQ(var_set.count(var->vid), 0) + << "Var " << var->name_hint() << " appears more than once"; + var_set.emplace(var->vid); CHECK(IsNestedTensorConditioned(GetStructInfo(var), IsFloatTensorSInfo)) << "Only Tensors of floating point dtype or Tuples of float " @@ -452,10 +746,13 @@ class GradientMutator : private ExprMutator { // differentiation sources Array require_grads_; + // checkpoint + VarIdSet checkpoints_; // the differentiation target int target_index_; Var target_var_; // the return value of the original function and the differentiated function + Array orig_params_; Expr orig_return_expr_; Expr return_expr_; }; diff --git a/tests/python/relax/test_op_grad.py b/tests/python/relax/test_op_grad.py index 01c4226d96c0..d27bae78e212 100644 --- a/tests/python/relax/test_op_grad.py +++ b/tests/python/relax/test_op_grad.py @@ -14,7 +14,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +import pytest import tvm +from tvm._ffi.base import TVMError import tvm.testing from tvm import relax from tvm.ir import Op @@ -43,6 +45,11 @@ def test_op_correctness(): "relax.grad.take_backward" ) assert relax.op.grad.no_grad(x).op == Op.get("relax.grad.no_grad") + assert relax.op.grad.no_grad(x).args[0] == x + assert relax.op.grad.start_checkpoint(x).op == Op.get("relax.grad.start_checkpoint") + assert relax.op.grad.start_checkpoint(x).args[0] == x + assert relax.op.grad.end_checkpoint(x).op == Op.get("relax.grad.end_checkpoint") + assert relax.op.grad.end_checkpoint(x).args[0] == x def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: relax.StructInfo): @@ -50,6 +57,44 @@ def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: r tvm.ir.assert_structural_equal(ret.struct_info, expected_sinfo) +def test_start_checkpoint_input_not_var(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((3, 4), "float32")) + y = relax.Var("y", R.Tensor((3, 4), "float32")) + + # ok because x + y will be normalized into a relax Var + with bb.function("main", [x, y]): + gv = bb.emit(relax.op.grad.start_checkpoint(x + y)) + bb.emit_func_output(gv) + + # wrong: tuple will not be normalized + with pytest.raises(TVMError): + bb.normalize(relax.op.grad.start_checkpoint((x, y))) + + # wrong: const will not be normalized + with pytest.raises(TVMError): + bb.normalize(relax.op.grad.start_checkpoint(relax.const(1, "float32"))) + + +def test_end_checkpoint_input_not_var(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((3, 4), "float32")) + y = relax.Var("y", R.Tensor((3, 4), "float32")) + + # ok because x + y will be normalized into a relax Var + with bb.function("main", [x, y]): + gv = bb.emit(relax.op.grad.end_checkpoint(x + 
y)) + bb.emit_func_output(gv) + + # wrong: tuple will not be normalized + with pytest.raises(TVMError): + bb.normalize(relax.op.grad.end_checkpoint((x, y))) + + # wrong: const will not be normalized + with pytest.raises(TVMError): + bb.normalize(relax.op.grad.end_checkpoint(relax.const(1, "float32"))) + + def test_nll_loss_backward_infer_struct_info(): bb = relax.BlockBuilder() diff --git a/tests/python/relax/test_op_gradient_numeric.py b/tests/python/relax/test_op_gradient_numeric.py index 0a7fa3672e52..4b4c5cabc416 100644 --- a/tests/python/relax/test_op_gradient_numeric.py +++ b/tests/python/relax/test_op_gradient_numeric.py @@ -747,7 +747,7 @@ def test_nll_loss_no_batch(target, dev, nll_reduction1, nll_weighted1, nll_ignor ) -(c2d_shape1, c2d_shape2, c2d_kwargs,) = tvm.testing.parameters( +(c2d_shape1, c2d_shape2, c2d_kwargs) = tvm.testing.parameters( ( (3, 2, 10, 10), (3, 2, 3, 3), @@ -797,7 +797,7 @@ def test_conv2d(target, dev, c2d_shape1, c2d_shape2, c2d_kwargs): ) -(pool_size, pool_kwargs,) = tvm.testing.parameters( +(pool_size, pool_kwargs) = tvm.testing.parameters( ( (3, 3), {}, diff --git a/tests/python/relax/test_training_setup_trainer.py b/tests/python/relax/test_training_setup_trainer.py index 401ffdf76ccf..0197aa9745a8 100644 --- a/tests/python/relax/test_training_setup_trainer.py +++ b/tests/python/relax/test_training_setup_trainer.py @@ -68,13 +68,14 @@ def backbone_loss_adjoint(x: R.Tensor((2, 2), dtype="float64"), y: R.Tensor((2, gv: R.Tensor((), dtype="float64") = R.sum(lv1, axis=None, keepdims=False) gv_adjoint: R.Tensor((), dtype="float64") = R.ones(R.shape([]), dtype="float64") lv1_adjoint: R.Tensor((2, 2), dtype="float64") = R.broadcast_to(gv_adjoint, R.shape([2, 2])) + lv_adjoint: R.Tensor((2, 2), dtype="float64") = R.multiply(lv1_adjoint, lv) lv_1: R.Tensor((2, 2), dtype="float64") = R.multiply(lv1_adjoint, lv) - lv1_1: R.Tensor((2, 2), dtype="float64") = R.multiply(lv1_adjoint, lv) - lv_adjoint: R.Tensor((2, 2), dtype="float64") = R.add(lv_1, lv1_1) - x1_adjoint: R.Tensor((2, 2), dtype="float64") = lv_adjoint + lv_adjoint1: R.Tensor((2, 2), dtype="float64") = R.add(lv_adjoint, lv_1) + x1_adjoint: R.Tensor((2, 2), dtype="float64") = lv_adjoint1 y_adjoint: R.Tensor((2, 2), dtype="float64") = x1_adjoint - R.output(gv, y_adjoint) - return (gv, (y_adjoint,)) + y_adjoint_out: R.Tensor((2, 2), dtype="float64") = y_adjoint + R.output(gv, y_adjoint_out) + return (gv, (y_adjoint_out,)) @R.function def optimizer(params: R.Tuple(R.Tensor((2, 2), dtype="float64")), gradients: R.Tuple(R.Tensor((2, 2), dtype="float64")), optim_states: R.Tuple(R.Tensor((), dtype="int64"))) -> R.Tuple(R.Tuple(R.Tensor((2, 2), dtype="float64")), R.Tuple(R.Tensor((), dtype="int64"))): @@ -142,13 +143,14 @@ def backbone_loss_adjoint(x: R.Tensor((2, 2), dtype="float64"), y: R.Tensor((2, gv: R.Tensor((), dtype="float64") = R.sum(lv1, axis=None, keepdims=False) gv_adjoint: R.Tensor((), dtype="float64") = R.ones(R.shape([]), dtype="float64") lv1_adjoint: R.Tensor((2, 2), dtype="float64") = R.broadcast_to(gv_adjoint, R.shape([2, 2])) + lv_adjoint: R.Tensor((2, 2), dtype="float64") = R.multiply(lv1_adjoint, lv) lv_1: R.Tensor((2, 2), dtype="float64") = R.multiply(lv1_adjoint, lv) - lv1_1: R.Tensor((2, 2), dtype="float64") = R.multiply(lv1_adjoint, lv) - lv_adjoint: R.Tensor((2, 2), dtype="float64") = R.add(lv_1, lv1_1) - x1_adjoint: R.Tensor((2, 2), dtype="float64") = lv_adjoint + lv_adjoint1: R.Tensor((2, 2), dtype="float64") = R.add(lv_adjoint, lv_1) + x1_adjoint: R.Tensor((2, 2), dtype="float64") = 
lv_adjoint1 y_adjoint: R.Tensor((2, 2), dtype="float64") = x1_adjoint - R.output(z1, gv, y_adjoint) - return ((gv, z1), (y_adjoint,)) + y_adjoint_out: R.Tensor((2, 2), dtype="float64") = y_adjoint + R.output(z1, gv, y_adjoint_out) + return ((gv, z1), (y_adjoint_out,)) @R.function def optimizer(params: R.Tuple(R.Tensor((2, 2), dtype="float64")), gradients: R.Tuple(R.Tensor((2, 2), dtype="float64")), optim_states: R.Tuple(R.Tensor((), dtype="int64"), R.Tensor((2, 2), dtype="float64"))) -> R.Tuple(R.Tuple(R.Tensor((2, 2), dtype="float64")), R.Tuple(R.Tensor((), dtype="int64"), R.Tensor((2, 2), dtype="float64"))): @@ -163,9 +165,10 @@ def optimizer(params: R.Tuple(R.Tensor((2, 2), dtype="float64")), gradients: R.T lv1: R.Tensor((2, 2), dtype="float64") = R.multiply(R.const(0.10000000000000001, "float64"), y_v_new) y_new: R.Tensor((2, 2), dtype="float64") = R.subtract(y, lv1) params_new: R.Tuple(R.Tensor((2, 2), dtype="float64")) = (y_new,) - optim_states_new: R.Tuple(R.Tensor((), dtype="int64"), R.Tensor((2, 2), dtype="float64")) = (num_steps_new, y_v_new) + optim_states_new: R.Tuple(R.Tensor((), dtype="int64"), R.Tensor((2, 2), dtype="float64")) = num_steps_new, y_v_new R.output(params_new, optim_states_new) return (params_new, optim_states_new) + # fmt: on sinfo = relax.TensorStructInfo((2, 2), "float64") diff --git a/tests/python/relax/test_transform_gradient.py b/tests/python/relax/test_transform_gradient.py index 50063fe385bb..b96932f8c5d5 100644 --- a/tests/python/relax/test_transform_gradient.py +++ b/tests/python/relax/test_transform_gradient.py @@ -14,14 +14,15 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -import tvm +import numpy as np import pytest + +import tvm import tvm.testing from tvm import relax +from tvm._ffi.base import TVMError from tvm.ir.base import assert_structural_equal from tvm.script.parser import relax as R, tir as T, ir as I -from tvm._ffi.base import TVMError -import numpy as np def test_simple(): @@ -38,20 +39,21 @@ def main(x: R.Tensor((3, 3), "float32")): @I.ir_module class Expected: @R.function - def main(x: R.Tensor((3, 3), "float32")) -> R.Tensor(None, "float32", ndim=0): + def main_adjoint(x: R.Tensor((3, 3), dtype="float32")) -> R.Tuple(R.Tensor((), dtype="float32"), R.Tuple(R.Tensor((3, 3), dtype="float32"))): with R.dataflow(): - gv: R.Tensor((), "float32") = R.sum(x, axis=None, keepdims=False) - R.output(gv) - return gv + gv: R.Tensor((), dtype="float32") = R.sum(x, axis=None, keepdims=False) + gv_adjoint: R.Tensor((), dtype="float32") = R.ones(R.shape([]), dtype="float32") + x_adjoint: R.Tensor((3, 3), dtype="float32") = R.broadcast_to(gv_adjoint, R.shape([3, 3])) + x_adjoint_out: R.Tensor((3, 3), dtype="float32") = x_adjoint + R.output(gv, x_adjoint_out) + return (gv, (x_adjoint_out,)) @R.function - def main_adjoint(x: R.Tensor((3, 3), "float32")) -> R.Tuple(R.Tensor(None, "float32", ndim=0),R.Tuple(R.Tensor(None, "float32", ndim=2)),): + def main(x: R.Tensor((3, 3), dtype="float32")) -> R.Tensor((), dtype="float32"): with R.dataflow(): - gv: R.Tensor((), "float32") = R.sum(x, axis=None, keepdims=False) - gv_adjoint: R.Tensor((), "float32") = R.ones((), "float32") - x_adjoint: R.Tensor((3, 3), "float32") = R.broadcast_to(gv_adjoint, (3, 3)) - R.output(gv, x_adjoint) - return (gv, (x_adjoint,)) + gv: R.Tensor((), dtype="float32") = R.sum(x, axis=None, keepdims=False) + R.output(gv) + return gv # fmt: on After = relax.transform.Gradient("main")(Before) @@ 
-74,27 +76,27 @@ def main(x: R.Tensor((3, 3), "float32")): @I.ir_module class Expected: @R.function - def main(x: R.Tensor((3, 3), "float32")) -> R.Tensor((), "float32"): - # block 0 + def main_adjoint(x: R.Tensor((3, 3), dtype="float32")) -> R.Tuple(R.Tensor((), dtype="float32"), R.Tuple(R.Tensor((3, 3), dtype="float32"))): with R.dataflow(): - lv1: R.Tensor((3, 3), "float32") = x - lv2: R.Tensor((3, 3), "float32") = lv1 - gv: R.Tensor((), "float32") = R.sum(lv2, axis=None, keepdims=False) - R.output(gv) - return gv + lv1: R.Tensor((3, 3), dtype="float32") = x + lv2: R.Tensor((3, 3), dtype="float32") = lv1 + gv: R.Tensor((), dtype="float32") = R.sum(lv2, axis=None, keepdims=False) + gv_adjoint: R.Tensor((), dtype="float32") = R.ones(R.shape([]), dtype="float32") + lv2_adjoint: R.Tensor((3, 3), dtype="float32") = R.broadcast_to(gv_adjoint, R.shape([3, 3])) + lv1_adjoint: R.Tensor((3, 3), dtype="float32") = lv2_adjoint + x_adjoint: R.Tensor((3, 3), dtype="float32") = lv1_adjoint + x_adjoint_out: R.Tensor((3, 3), dtype="float32") = x_adjoint + R.output(gv, x_adjoint_out) + return (gv, (x_adjoint_out,)) @R.function - def main_adjoint(x: R.Tensor((3, 3), "float32")) -> R.Tuple(R.Tensor((), "float32"), R.Tuple(R.Tensor((3, 3), "float32"))): + def main(x: R.Tensor((3, 3), dtype="float32")) -> R.Tensor((), dtype="float32"): with R.dataflow(): - lv1: R.Tensor((3, 3), "float32") = x - lv2: R.Tensor((3, 3), "float32") = lv1 - gv: R.Tensor((), "float32") = R.sum(lv2, axis=None, keepdims=False) - gv_adjoint: R.Tensor((), "float32") = R.ones((), "float32") - lv2_adjoint: R.Tensor((3, 3), "float32") = R.broadcast_to(gv_adjoint, (3, 3)) - lv1_adjoint: R.Tensor((3, 3), "float32") = lv2_adjoint - x_adjoint: R.Tensor((3, 3), "float32") = lv1_adjoint - R.output(gv, x_adjoint) - return (gv, (x_adjoint,)) + lv1: R.Tensor((3, 3), dtype="float32") = x + lv2: R.Tensor((3, 3), dtype="float32") = lv1 + gv: R.Tensor((), dtype="float32") = R.sum(lv2, axis=None, keepdims=False) + R.output(gv) + return gv # fmt: on After = relax.transform.Gradient("main")(Before) @@ -117,27 +119,29 @@ def main(x: R.Tensor((3, 3), "float32")): @I.ir_module class Expected: @R.function - def main(x: R.Tensor((3, 3), "float32")) -> R.Tensor((), "float32"): + def main_adjoint(x: R.Tensor((3, 3), dtype="float32")) -> R.Tuple(R.Tensor((), dtype="float32"), R.Tuple(R.Tensor((3, 3), dtype="float32"))): with R.dataflow(): - lv1: R.Tensor((3, 3), "float32") = R.add(x, x) - lv2: R.Tensor((3, 3), "float32") = R.add(lv1, x) - gv: R.Tensor((), "float32") = R.sum(lv2, axis=None, keepdims=False) - R.output(gv) - return gv + lv1: R.Tensor((3, 3), dtype="float32") = R.add(x, x) + lv2: R.Tensor((3, 3), dtype="float32") = R.add(lv1, x) + gv: R.Tensor((), dtype="float32") = R.sum(lv2, axis=None, keepdims=False) + gv_adjoint: R.Tensor((), dtype="float32") = R.ones(R.shape([]), dtype="float32") + lv2_adjoint: R.Tensor((3, 3), dtype="float32") = R.broadcast_to(gv_adjoint, R.shape([3, 3])) + lv1_adjoint: R.Tensor((3, 3), dtype="float32") = lv2_adjoint + x_adjoint: R.Tensor((3, 3), dtype="float32") = lv2_adjoint + x_adjoint1: R.Tensor((3, 3), dtype="float32") = R.add(x_adjoint, lv1_adjoint) + x_adjoint2: R.Tensor((3, 3), dtype="float32") = R.add(x_adjoint1, lv1_adjoint) + x_adjoint_out: R.Tensor((3, 3), dtype="float32") = x_adjoint2 + R.output(gv, x_adjoint_out) + return (gv, (x_adjoint_out,)) @R.function - def main_adjoint(x: R.Tensor((3, 3), "float32")) -> R.Tuple(R.Tensor((), "float32"), R.Tuple(R.Tensor((3, 3), "float32"))): + def main(x: R.Tensor((3, 3), 
dtype="float32")) -> R.Tensor((), dtype="float32"): with R.dataflow(): - lv1: R.Tensor((3, 3), "float32") = R.add(x, x) - lv2: R.Tensor((3, 3), "float32") = R.add(lv1, x) - gv: R.Tensor((), "float32") = R.sum(lv2, axis=None, keepdims=False) - gv_adjoint: R.Tensor((), "float32") = R.ones((), "float32") - lv2_adjoint: R.Tensor((3, 3), "float32") = R.broadcast_to(gv_adjoint, (3, 3)) - lv1_adjoint: R.Tensor((3, 3), "float32") = lv2_adjoint - lv: R.Tensor((3, 3), "float32") = R.add(lv2_adjoint, lv1_adjoint) - x_adjoint: R.Tensor((3, 3), "float32") = R.add(lv, lv1_adjoint) - R.output(gv, x_adjoint) - return (gv, (x_adjoint,)) + lv1: R.Tensor((3, 3), dtype="float32") = R.add(x, x) + lv2: R.Tensor((3, 3), dtype="float32") = R.add(lv1, x) + gv: R.Tensor((), dtype="float32") = R.sum(lv2, axis=None, keepdims=False) + R.output(gv) + return gv # fmt: on After = relax.transform.Gradient("main")(Before) @@ -149,7 +153,7 @@ def test_unused(): @I.ir_module class Before: @R.function - def main(x: R.Tensor((3, 3), "float32")): + def main(x: R.Tensor((3, 3), "float32"), y: R.Tensor((3, 3), "float32")): with R.dataflow(): lv1 = R.add(x, x) lv2 = R.add(lv1, x) @@ -160,24 +164,25 @@ def main(x: R.Tensor((3, 3), "float32")): @I.ir_module class Expected: @R.function - def main(x: R.Tensor((3, 3), "float32")) -> R.Tensor((), "float32"): + def main_adjoint(x: R.Tensor((3, 3), dtype="float32"), y: R.Tensor((3, 3), dtype="float32")) -> R.Tuple(R.Tensor((), dtype="float32"), R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3), dtype="float32"))): with R.dataflow(): - lv1: R.Tensor((3, 3), "float32") = R.add(x, x) - lv2: R.Tensor((3, 3), "float32") = R.add(lv1, x) - gv: R.Tensor((), "float32") = R.sum(x, axis=None, keepdims=False) - R.output(gv) - return gv + gv: R.Tensor((), dtype="float32") = R.sum(x, axis=None, keepdims=False) + gv_adjoint: R.Tensor((), dtype="float32") = R.ones(R.shape([]), dtype="float32") + x_adjoint: R.Tensor((3, 3), dtype="float32") = R.broadcast_to(gv_adjoint, R.shape([3, 3])) + x_adjoint_out: R.Tensor((3, 3), dtype="float32") = x_adjoint + y_adjoint: R.Tensor((3, 3), dtype="float32") = R.zeros(R.shape([3, 3]), dtype="float32") + y_adjoint_out: R.Tensor((3, 3), dtype="float32") = y_adjoint + R.output(gv, x_adjoint_out, y_adjoint_out) + return (gv, (x_adjoint_out, y_adjoint_out)) @R.function - def main_adjoint(x: R.Tensor((3, 3), "float32")) -> R.Tuple(R.Tensor((), "float32"), R.Tuple(R.Tensor((3, 3), "float32"))): + def main(x: R.Tensor((3, 3), dtype="float32"), y: R.Tensor((3, 3), dtype="float32")) -> R.Tensor((), dtype="float32"): with R.dataflow(): - lv1: R.Tensor((3, 3), "float32") = R.add(x, x) - lv2: R.Tensor((3, 3), "float32") = R.add(lv1, x) - gv: R.Tensor((), "float32") = R.sum(x, axis=None, keepdims=False) - gv_adjoint: R.Tensor((), "float32") = R.ones((), "float32") - x_adjoint: R.Tensor((3, 3), "float32") = R.broadcast_to(gv_adjoint, (3, 3)) - R.output(gv, x_adjoint) - return (gv, (x_adjoint,)) + lv1: R.Tensor((3, 3), dtype="float32") = R.add(x, x) + lv2: R.Tensor((3, 3), dtype="float32") = R.add(lv1, x) + gv: R.Tensor((), dtype="float32") = R.sum(x, axis=None, keepdims=False) + R.output(gv) + return gv # fmt: on After = relax.transform.Gradient("main")(Before) @@ -189,11 +194,7 @@ def test_default_require_grads(): @I.ir_module class Before: @R.function - def main( - x: R.Tensor((3, 3), "float32"), - y: R.Tensor((3, 3), "float32"), - z: R.Tensor((3, 3), "float32"), - ): + def main(x: R.Tensor((3, 3), "float32"), y: R.Tensor((3, 3), "float32"), z: R.Tensor((3, 3), "float32")): 
with R.dataflow(): lv1 = R.add(x, y) lv2 = R.add(lv1, z) @@ -204,33 +205,31 @@ def main( @I.ir_module class Expected1: @R.function - def main( - x: R.Tensor((3, 3), "float32"), - y: R.Tensor((3, 3), "float32"), - z: R.Tensor((3, 3), "float32"), - ) -> R.Tensor((), "float32"): - # block 0 + def main_adjoint(x: R.Tensor((3, 3), dtype="float32"), y: R.Tensor((3, 3), dtype="float32"), z: R.Tensor((3, 3), dtype="float32")) -> R.Tuple(R.Tensor((), dtype="float32"), R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3), dtype="float32"))): with R.dataflow(): - lv1: R.Tensor((3, 3), "float32") = R.add(x, y) - lv2: R.Tensor((3, 3), "float32") = R.add(lv1, z) - gv: R.Tensor((), "float32") = R.sum(lv2, axis=None, keepdims=False) - R.output(gv) - return gv + lv1: R.Tensor((3, 3), dtype="float32") = R.add(x, y) + lv2: R.Tensor((3, 3), dtype="float32") = R.add(lv1, z) + gv: R.Tensor((), dtype="float32") = R.sum(lv2, axis=None, keepdims=False) + gv_adjoint: R.Tensor((), dtype="float32") = R.ones(R.shape([]), dtype="float32") + lv2_adjoint: R.Tensor((3, 3), dtype="float32") = R.broadcast_to(gv_adjoint, R.shape([3, 3])) + lv1_adjoint: R.Tensor((3, 3), dtype="float32") = lv2_adjoint + z_adjoint: R.Tensor((3, 3), dtype="float32") = lv2_adjoint + x_adjoint: R.Tensor((3, 3), dtype="float32") = lv1_adjoint + y_adjoint: R.Tensor((3, 3), dtype="float32") = lv1_adjoint + x_adjoint_out: R.Tensor((3, 3), dtype="float32") = x_adjoint + y_adjoint_out: R.Tensor((3, 3), dtype="float32") = y_adjoint + z_adjoint_out: R.Tensor((3, 3), dtype="float32") = z_adjoint + R.output(gv, x_adjoint_out, y_adjoint_out, z_adjoint_out) + return (gv, (x_adjoint_out, y_adjoint_out, z_adjoint_out)) @R.function - def main_adjoint(x: R.Tensor((3, 3), "float32"), y: R.Tensor((3, 3), "float32"), z: R.Tensor((3, 3), "float32")) -> R.Tuple(R.Tensor((), "float32"), R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 3), "float32"), R.Tensor((3, 3), "float32"))): + def main(x: R.Tensor((3, 3), dtype="float32"), y: R.Tensor((3, 3), dtype="float32"), z: R.Tensor((3, 3), dtype="float32")) -> R.Tensor((), dtype="float32"): with R.dataflow(): - lv1: R.Tensor((3, 3), "float32") = R.add(x, y) - lv2: R.Tensor((3, 3), "float32") = R.add(lv1, z) - gv: R.Tensor((), "float32") = R.sum(lv2, axis=None, keepdims=False) - gv_adjoint: R.Tensor((), "float32") = R.ones((), "float32") - lv2_adjoint: R.Tensor((3, 3), "float32") = R.broadcast_to(gv_adjoint, (3, 3)) - lv1_adjoint: R.Tensor((3, 3), "float32") = lv2_adjoint - x_adjoint: R.Tensor((3, 3), "float32") = lv1_adjoint - y_adjoint: R.Tensor((3, 3), "float32") = lv1_adjoint - z_adjoint: R.Tensor((3, 3), "float32") = lv2_adjoint - R.output(gv, x_adjoint, y_adjoint, z_adjoint) - return (gv, (x_adjoint, y_adjoint, z_adjoint)) + lv1: R.Tensor((3, 3), dtype="float32") = R.add(x, y) + lv2: R.Tensor((3, 3), dtype="float32") = R.add(lv1, z) + gv: R.Tensor((), dtype="float32") = R.sum(lv2, axis=None, keepdims=False) + R.output(gv) + return gv # fmt: on After1 = relax.transform.Gradient("main")(Before) @@ -240,28 +239,27 @@ def main_adjoint(x: R.Tensor((3, 3), "float32"), y: R.Tensor((3, 3), "float32"), @I.ir_module class Expected2: @R.function - def main(x: R.Tensor((3, 3), "float32"), y: R.Tensor((3, 3), "float32"), z: R.Tensor((3, 3), "float32")) -> R.Tensor((), "float32"): - # block 0 + def main_adjoint(x: R.Tensor((3, 3), dtype="float32"), y: R.Tensor((3, 3), dtype="float32"), z: R.Tensor((3, 3), dtype="float32")) -> R.Tuple(R.Tensor((), dtype="float32"), R.Tuple(R.Tensor((3, 3), 
dtype="float32"))): with R.dataflow(): - lv1: R.Tensor((3, 3), "float32") = R.add(x, y) - lv2: R.Tensor((3, 3), "float32") = R.add(lv1, z) - gv: R.Tensor((), "float32") = R.sum(lv2, axis=None, keepdims=False) - R.output(gv) - return gv + lv1: R.Tensor((3, 3), dtype="float32") = R.add(x, y) + lv2: R.Tensor((3, 3), dtype="float32") = R.add(lv1, z) + gv: R.Tensor((), dtype="float32") = R.sum(lv2, axis=None, keepdims=False) + gv_adjoint: R.Tensor((), dtype="float32") = R.ones(R.shape([]), dtype="float32") + lv2_adjoint: R.Tensor((3, 3), dtype="float32") = R.broadcast_to(gv_adjoint, R.shape([3, 3])) + lv1_adjoint: R.Tensor((3, 3), dtype="float32") = lv2_adjoint + x_adjoint: R.Tensor((3, 3), dtype="float32") = lv1_adjoint + x_adjoint_out: R.Tensor((3, 3), dtype="float32") = x_adjoint + R.output(gv, x_adjoint_out) + return (gv, (x_adjoint_out,)) @R.function - def main_adjoint(x: R.Tensor((3, 3), "float32"), y: R.Tensor((3, 3), "float32"), z: R.Tensor((3, 3), "float32")) -> R.Tuple(R.Tensor((), "float32"), R.Tuple(R.Tensor((3, 3), "float32"))): - # block 0 + def main(x: R.Tensor((3, 3), dtype="float32"), y: R.Tensor((3, 3), dtype="float32"), z: R.Tensor((3, 3), dtype="float32")) -> R.Tensor((), dtype="float32"): with R.dataflow(): - lv1: R.Tensor((3, 3), "float32") = R.add(x, y) - lv2: R.Tensor((3, 3), "float32") = R.add(lv1, z) - gv: R.Tensor((), "float32") = R.sum(lv2, axis=None, keepdims=False) - gv_adjoint: R.Tensor((), "float32") = R.ones((), "float32") - lv2_adjoint: R.Tensor((3, 3), "float32") = R.broadcast_to(gv_adjoint, (3, 3)) - lv1_adjoint: R.Tensor((3, 3), "float32") = lv2_adjoint - x_adjoint: R.Tensor((3, 3), "float32") = lv1_adjoint - R.output(gv, x_adjoint) - return (gv, (x_adjoint,)) + lv1: R.Tensor((3, 3), dtype="float32") = R.add(x, y) + lv2: R.Tensor((3, 3), dtype="float32") = R.add(lv1, z) + gv: R.Tensor((), dtype="float32") = R.sum(lv2, axis=None, keepdims=False) + R.output(gv) + return gv # fmt: on After2 = relax.transform.Gradient("main", require_grads=Before["main"].params[0])(Before) @@ -284,25 +282,27 @@ def main(x: R.Tensor((3, 3), "float32"), y: R.Tensor((3, 3), "float32")): @I.ir_module class Expected: @R.function - def main(x: R.Tensor((3, 3), "float32"), y: R.Tensor((3, 3), "float32")) -> R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((), "float32"), R.Tensor((), "float32")): + def main_adjoint(x: R.Tensor((3, 3), dtype="float32"), y: R.Tensor((3, 3), dtype="float32")) -> R.Tuple(R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((), dtype="float32"), R.Tensor((), dtype="float32")), R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3), dtype="float32"))): with R.dataflow(): - lv1: R.Tensor((3, 3), "float32") = x - lv2: R.Tensor((), "float32") = R.sum(x, axis=None, keepdims=False) - lv3: R.Tensor((), "float32") = R.sum(y, axis=None, keepdims=False) - R.output(lv1, lv2, lv3) - return (lv1, lv2, lv3) + lv1: R.Tensor((3, 3), dtype="float32") = x + lv2: R.Tensor((), dtype="float32") = R.sum(x, axis=None, keepdims=False) + lv3: R.Tensor((), dtype="float32") = R.sum(y, axis=None, keepdims=False) + lv3_adjoint: R.Tensor((), dtype="float32") = R.ones(R.shape([]), dtype="float32") + y_adjoint: R.Tensor((3, 3), dtype="float32") = R.broadcast_to(lv3_adjoint, R.shape([3, 3])) + x_adjoint: R.Tensor((3, 3), dtype="float32") = R.zeros(R.shape([3, 3]), dtype="float32") + x_adjoint_out: R.Tensor((3, 3), dtype="float32") = x_adjoint + y_adjoint_out: R.Tensor((3, 3), dtype="float32") = y_adjoint + R.output(lv1, lv2, lv3, x_adjoint_out, y_adjoint_out) + return ((lv1, lv2, lv3), 
(x_adjoint_out, y_adjoint_out)) @R.function - def main_adjoint(x: R.Tensor((3, 3), "float32"), y: R.Tensor((3, 3), "float32")) -> R.Tuple(R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((), "float32"), R.Tensor((), "float32")), R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 3), "float32"))): + def main(x: R.Tensor((3, 3), dtype="float32"), y: R.Tensor((3, 3), dtype="float32")) -> R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((), dtype="float32"), R.Tensor((), dtype="float32")): with R.dataflow(): - lv1: R.Tensor((3, 3), "float32") = x - lv2: R.Tensor((), "float32") = R.sum(x, axis=None, keepdims=False) - lv3: R.Tensor((), "float32") = R.sum(y, axis=None, keepdims=False) - lv3_adjoint: R.Tensor((), "float32") = R.ones((), "float32") - x_adjoint: R.Tensor((3, 3), "float32") = R.zeros((3, 3), "float32") - y_adjoint: R.Tensor((3, 3), "float32") = R.broadcast_to(lv3_adjoint, (3, 3)) - R.output(lv1, lv2, lv3, x_adjoint, y_adjoint) - return ((lv1, lv2, lv3), (x_adjoint, y_adjoint)) + lv1: R.Tensor((3, 3), dtype="float32") = x + lv2: R.Tensor((), dtype="float32") = R.sum(x, axis=None, keepdims=False) + lv3: R.Tensor((), dtype="float32") = R.sum(y, axis=None, keepdims=False) + R.output(lv1, lv2, lv3) + return (lv1, lv2, lv3) # fmt: on After = relax.transform.Gradient("main", target_index=2)(Before) @@ -328,39 +328,43 @@ def main( R.output(gv) return gv + @I.ir_module class Expected: @R.function - def main(x: R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 3), "float32")), y: R.Tensor((3, 3), "float32"), z: R.Tensor((3, 3), "float32")) -> R.Tensor(None, "float32", ndim=0): + def main_adjoint(x: R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3), dtype="float32")), y: R.Tensor((3, 3), dtype="float32"), z: R.Tensor((3, 3), dtype="float32")) -> R.Tuple(R.Tensor((), dtype="float32"), R.Tuple(R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3), dtype="float32")), R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3), dtype="float32"))): with R.dataflow(): - lv1: R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 3), "float32")) = (y, z) - lv2: R.Tensor((3, 3), "float32") = x[0] - lv3: R.Tensor((3, 3), "float32") = lv1[0] - lv4: R.Tensor((3, 3), "float32") = R.add(lv2, lv3) - gv: R.Tensor((), "float32") = R.sum(lv4, axis=None, keepdims=False) - R.output(gv) - return gv + lv1: R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3), dtype="float32")) = (y, z) + lv2: R.Tensor((3, 3), dtype="float32") = x[0] + lv3: R.Tensor((3, 3), dtype="float32") = lv1[0] + lv4: R.Tensor((3, 3), dtype="float32") = R.add(lv2, lv3) + gv: R.Tensor((), dtype="float32") = R.sum(lv4, axis=None, keepdims=False) + gv_adjoint: R.Tensor((), dtype="float32") = R.ones(R.shape([]), dtype="float32") + lv4_adjoint: R.Tensor((3, 3), dtype="float32") = R.broadcast_to(gv_adjoint, R.shape([3, 3])) + lv2_adjoint: R.Tensor((3, 3), dtype="float32") = lv4_adjoint + lv3_adjoint: R.Tensor((3, 3), dtype="float32") = lv4_adjoint + lv: R.Tensor((3, 3), dtype="float32") = R.zeros(R.shape([3, 3]), dtype="float32") + lv1_adjoint: R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3), dtype="float32")) = (lv3_adjoint, lv) + lv1_1: R.Tensor((3, 3), dtype="float32") = R.zeros(R.shape([3, 3]), dtype="float32") + x_adjoint: R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3), dtype="float32")) = (lv2_adjoint, lv1_1) + y_adjoint: R.Tensor((3, 3), dtype="float32") = lv1_adjoint[0] + z_adjoint: R.Tensor((3, 3), dtype="float32") = lv1_adjoint[1] + x_adjoint_out: R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3), 
dtype="float32")) = x_adjoint + y_adjoint_out: R.Tensor((3, 3), dtype="float32") = y_adjoint + z_adjoint_out: R.Tensor((3, 3), dtype="float32") = z_adjoint + R.output(gv, x_adjoint_out, y_adjoint_out, z_adjoint_out) + return (gv, (x_adjoint_out, y_adjoint_out, z_adjoint_out)) @R.function - def main_adjoint(x: R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 3), "float32")), y: R.Tensor((3, 3), "float32"), z: R.Tensor((3, 3), "float32")) -> R.Tuple(R.Tensor((), "float32"), R.Tuple(R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 3), "float32")), R.Tensor((3, 3), "float32"), R.Tensor((3, 3), "float32"))): + def main(x: R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3), dtype="float32")), y: R.Tensor((3, 3), dtype="float32"), z: R.Tensor((3, 3), dtype="float32")) -> R.Tensor((), dtype="float32"): with R.dataflow(): - lv1: R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 3), "float32")) = (y, z) - lv2: R.Tensor((3, 3), "float32") = x[0] - lv3: R.Tensor((3, 3), "float32") = lv1[0] - lv4: R.Tensor((3, 3), "float32") = R.add(lv2, lv3) - gv: R.Tensor((), "float32") = R.sum(lv4, axis=None, keepdims=False) - gv_adjoint: R.Tensor((), "float32") = R.ones((), "float32") - lv4_adjoint: R.Tensor((3, 3), "float32") = R.broadcast_to(gv_adjoint, (3, 3)) - lv3_adjoint: R.Tensor((3, 3), "float32") = lv4_adjoint - lv2_adjoint: R.Tensor((3, 3), "float32") = lv4_adjoint - lv: R.Tensor((3, 3), "float32") = R.zeros((3, 3), "float32") - lv1_adjoint: R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 3), "float32")) = (lv3_adjoint, lv) - lv11: R.Tensor((3, 3), "float32") = R.zeros((3, 3), "float32") - x_adjoint: R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 3), "float32")) = (lv2_adjoint, lv11) - y_adjoint: R.Tensor((3, 3), "float32") = lv1_adjoint[0] - z_adjoint: R.Tensor((3, 3), "float32") = lv1_adjoint[1] - R.output(gv, x_adjoint, y_adjoint, z_adjoint) - return (gv, (x_adjoint, y_adjoint, z_adjoint)) + lv1: R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3), dtype="float32")) = (y, z) + lv2: R.Tensor((3, 3), dtype="float32") = x[0] + lv3: R.Tensor((3, 3), dtype="float32") = lv1[0] + lv4: R.Tensor((3, 3), dtype="float32") = R.add(lv2, lv3) + gv: R.Tensor((), dtype="float32") = R.sum(lv4, axis=None, keepdims=False) + R.output(gv) + return gv # fmt: on After = relax.transform.Gradient("main")(Before) @@ -375,57 +379,60 @@ class Before: def main(x: R.Tensor((3, 3), "float32"), y: R.Tensor((3, 3), "float32")): with R.dataflow(): lv1 = (x, y) - lv4 = lv1[0] - lv7 = R.add(lv4, x) - lv2 = lv1 - lv3 = lv2[0] - lv5 = R.add(lv3, lv7) - gv = R.sum(lv5) + lv2 = lv1[0] + lv3 = R.add(lv2, x) + lv4 = lv1 + lv5 = lv4[0] + lv6 = R.add(lv5, lv3) + gv = R.sum(lv6) R.output(gv) return gv @I.ir_module class Expected: @R.function - def main(x: R.Tensor((3, 3), "float32"), y: R.Tensor((3, 3), "float32")) -> R.Tensor((), "float32"): + def main_adjoint(x: R.Tensor((3, 3), dtype="float32"), y: R.Tensor((3, 3), dtype="float32")) -> R.Tuple(R.Tensor((), dtype="float32"), R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3), dtype="float32"))): with R.dataflow(): - lv1: R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 3), "float32")) = (x, y) - lv4: R.Tensor((3, 3), "float32") = lv1[0] - lv7: R.Tensor((3, 3), "float32") = R.add(lv4, x) - lv2: R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 3), "float32")) = lv1 - lv3: R.Tensor((3, 3), "float32") = lv2[0] - lv5: R.Tensor((3, 3), "float32") = R.add(lv3, lv7) - gv: R.Tensor((), "float32") = R.sum(lv5, axis=None, keepdims=False) - R.output(gv) - return gv + lv1: 
R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3), dtype="float32")) = (x, y) + lv2: R.Tensor((3, 3), dtype="float32") = lv1[0] + lv3: R.Tensor((3, 3), dtype="float32") = R.add(lv2, x) + lv4: R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3), dtype="float32")) = lv1 + lv5: R.Tensor((3, 3), dtype="float32") = lv4[0] + lv6: R.Tensor((3, 3), dtype="float32") = R.add(lv5, lv3) + gv: R.Tensor((), dtype="float32") = R.sum(lv6, axis=None, keepdims=False) + gv_adjoint: R.Tensor((), dtype="float32") = R.ones(R.shape([]), dtype="float32") + lv6_adjoint: R.Tensor((3, 3), dtype="float32") = R.broadcast_to(gv_adjoint, R.shape([3, 3])) + lv5_adjoint: R.Tensor((3, 3), dtype="float32") = lv6_adjoint + lv3_adjoint: R.Tensor((3, 3), dtype="float32") = lv6_adjoint + lv: R.Tensor((3, 3), dtype="float32") = R.zeros(R.shape([3, 3]), dtype="float32") + lv4_adjoint: R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3), dtype="float32")) = (lv5_adjoint, lv) + lv1_adjoint: R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3), dtype="float32")) = lv4_adjoint + lv2_adjoint: R.Tensor((3, 3), dtype="float32") = lv3_adjoint + x_adjoint: R.Tensor((3, 3), dtype="float32") = lv3_adjoint + lv1_1: R.Tensor((3, 3), dtype="float32") = lv1_adjoint[0] + lv2_1: R.Tensor((3, 3), dtype="float32") = R.add(lv1_1, lv2_adjoint) + lv3_1: R.Tensor((3, 3), dtype="float32") = lv1_adjoint[1] + lv1_adjoint1: R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3), dtype="float32")) = (lv2_1, lv3_1) + lv4_1: R.Tensor((3, 3), dtype="float32") = lv1_adjoint1[0] + x_adjoint1: R.Tensor((3, 3), dtype="float32") = R.add(x_adjoint, lv4_1) + y_adjoint: R.Tensor((3, 3), dtype="float32") = lv1_adjoint1[1] + x_adjoint_out: R.Tensor((3, 3), dtype="float32") = x_adjoint1 + y_adjoint_out: R.Tensor((3, 3), dtype="float32") = y_adjoint + R.output(gv, x_adjoint_out, y_adjoint_out) + return (gv, (x_adjoint_out, y_adjoint_out)) @R.function - def main_adjoint(x: R.Tensor((3, 3), "float32"), y: R.Tensor((3, 3), "float32")) -> R.Tuple(R.Tensor((), "float32"), R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 3), "float32"))): - # block 0 + def main(x: R.Tensor((3, 3), dtype="float32"), y: R.Tensor((3, 3), dtype="float32")) -> R.Tensor((), dtype="float32"): with R.dataflow(): - lv1: R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 3), "float32")) = (x, y) - lv4: R.Tensor((3, 3), "float32") = lv1[0] - lv7: R.Tensor((3, 3), "float32") = R.add(lv4, x) - lv2: R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 3), "float32")) = lv1 - lv3: R.Tensor((3, 3), "float32") = lv2[0] - lv5: R.Tensor((3, 3), "float32") = R.add(lv3, lv7) - gv: R.Tensor((), "float32") = R.sum(lv5, axis=None, keepdims=False) - gv_adjoint: R.Tensor((), "float32") = R.ones((), "float32") - lv5_adjoint: R.Tensor((3, 3), "float32") = R.broadcast_to(gv_adjoint, (3, 3)) - lv3_adjoint: R.Tensor((3, 3), "float32") = lv5_adjoint - lv: R.Tensor((3, 3), "float32") = R.zeros((3, 3), "float32") - lv2_adjoint: R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 3), "float32")) = (lv3_adjoint, lv) - lv7_adjoint: R.Tensor((3, 3), "float32") = lv5_adjoint - lv4_adjoint: R.Tensor((3, 3), "float32") = lv7_adjoint - lv11: R.Tensor((3, 3), "float32") = lv2_adjoint[0] - lv21: R.Tensor((3, 3), "float32") = R.add(lv11, lv4_adjoint) - lv31: R.Tensor((3, 3), "float32") = lv2_adjoint[1] - lv1_adjoint: R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 3), "float32")) = (lv21, lv31) - lv41: R.Tensor((3, 3), "float32") = lv1_adjoint[0] - x_adjoint: R.Tensor((3, 3), "float32") = R.add(lv7_adjoint, 
lv41) - y_adjoint: R.Tensor((3, 3), "float32") = lv1_adjoint[1] - R.output(gv, x_adjoint, y_adjoint) - return (gv, (x_adjoint, y_adjoint)) + lv1: R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3), dtype="float32")) = (x, y) + lv2: R.Tensor((3, 3), dtype="float32") = lv1[0] + lv3: R.Tensor((3, 3), dtype="float32") = R.add(lv2, x) + lv4: R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3), dtype="float32")) = lv1 + lv5: R.Tensor((3, 3), dtype="float32") = lv4[0] + lv6: R.Tensor((3, 3), dtype="float32") = R.add(lv5, lv3) + gv: R.Tensor((), dtype="float32") = R.sum(lv6, axis=None, keepdims=False) + R.output(gv) + return gv # fmt: on After = relax.transform.Gradient("main")(Before) @@ -459,51 +466,66 @@ def main( @I.ir_module class Expected: @R.function - def main(x: R.Tuple(R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 3), "float32")), R.Tensor((3, 3), "float32")), y: R.Tensor((3, 3), "float32"), z: R.Tensor((3, 3), "float32"), u: R.Tensor((3, 3), "float32")) -> R.Tensor((), "float32"): + def main_adjoint(x: R.Tuple(R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3), dtype="float32")), R.Tensor((3, 3), dtype="float32")), y: R.Tensor((3, 3), dtype="float32"), z: R.Tensor((3, 3), dtype="float32"), u: R.Tensor((3, 3), dtype="float32")) -> R.Tuple(R.Tensor((), dtype="float32"), R.Tuple(R.Tuple(R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3), dtype="float32")), R.Tensor((3, 3), dtype="float32")), R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3), dtype="float32"))): with R.dataflow(): - lv1: R.Tuple(R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 3), "float32")), R.Tensor((3, 3), "float32")) = ((y, z), u) - lv2: R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 3), "float32")) = x[0] - lv3: R.Tensor((3, 3), "float32") = lv2[0] - lv4: R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 3), "float32")) = lv1[0] - lv5: R.Tensor((3, 3), "float32") = lv4[1] - lv6: R.Tensor((3, 3), "float32") = R.add(lv3, lv5) - lv7: R.Tensor((3, 3), "float32") = x[1] - lv8: R.Tensor((3, 3), "float32") = R.add(lv6, lv7) - gv: R.Tensor((), "float32") = R.sum(lv8, axis=None, keepdims=False) - R.output(gv) - return gv + lv1: R.Tuple(R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3), dtype="float32")), R.Tensor((3, 3), dtype="float32")) = ((y, z), u) + lv2: R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3), dtype="float32")) = x[0] + lv3: R.Tensor((3, 3), dtype="float32") = lv2[0] + lv4: R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3), dtype="float32")) = lv1[0] + lv5: R.Tensor((3, 3), dtype="float32") = lv4[1] + lv6: R.Tensor((3, 3), dtype="float32") = R.add(lv3, lv5) + lv7: R.Tensor((3, 3), dtype="float32") = x[1] + lv8: R.Tensor((3, 3), dtype="float32") = R.add(lv6, lv7) + gv: R.Tensor((), dtype="float32") = R.sum(lv8, axis=None, keepdims=False) + gv_adjoint: R.Tensor((), dtype="float32") = R.ones(R.shape([]), dtype="float32") + lv8_adjoint: R.Tensor((3, 3), dtype="float32") = R.broadcast_to(gv_adjoint, R.shape([3, 3])) + lv6_adjoint: R.Tensor((3, 3), dtype="float32") = lv8_adjoint + lv7_adjoint: R.Tensor((3, 3), dtype="float32") = lv8_adjoint + lv: R.Tensor((3, 3), dtype="float32") = R.zeros(R.shape([3, 3]), dtype="float32") + lv1_1: R.Tensor((3, 3), dtype="float32") = R.zeros(R.shape([3, 3]), dtype="float32") + x_adjoint: R.Tuple(R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3), dtype="float32")), R.Tensor((3, 3), dtype="float32")) = ((lv, lv1_1), lv7_adjoint) + lv3_adjoint: R.Tensor((3, 3), dtype="float32") 
= lv6_adjoint + lv5_adjoint: R.Tensor((3, 3), dtype="float32") = lv6_adjoint + lv2_1: R.Tensor((3, 3), dtype="float32") = R.zeros(R.shape([3, 3]), dtype="float32") + lv4_adjoint: R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3), dtype="float32")) = (lv2_1, lv5_adjoint) + lv3_1: R.Tensor((3, 3), dtype="float32") = R.zeros(R.shape([3, 3]), dtype="float32") + lv1_adjoint: R.Tuple(R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3), dtype="float32")), R.Tensor((3, 3), dtype="float32")) = (lv4_adjoint, lv3_1) + lv4_1: R.Tensor((3, 3), dtype="float32") = R.zeros(R.shape([3, 3]), dtype="float32") + lv2_adjoint: R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3), dtype="float32")) = (lv3_adjoint, lv4_1) + lv5_1: R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3), dtype="float32")) = x_adjoint[0] + lv6_1: R.Tensor((3, 3), dtype="float32") = lv5_1[0] + lv7_1: R.Tensor((3, 3), dtype="float32") = lv2_adjoint[0] + lv8_1: R.Tensor((3, 3), dtype="float32") = R.add(lv6_1, lv7_1) + lv9: R.Tensor((3, 3), dtype="float32") = lv5_1[1] + lv10: R.Tensor((3, 3), dtype="float32") = lv2_adjoint[1] + lv11: R.Tensor((3, 3), dtype="float32") = R.add(lv9, lv10) + lv12: R.Tensor((3, 3), dtype="float32") = x_adjoint[1] + x_adjoint1: R.Tuple(R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3), dtype="float32")), R.Tensor((3, 3), dtype="float32")) = ((lv8_1, lv11), lv12) + lv13: R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3), dtype="float32")) = lv1_adjoint[0] + y_adjoint: R.Tensor((3, 3), dtype="float32") = lv13[0] + z_adjoint: R.Tensor((3, 3), dtype="float32") = lv13[1] + u_adjoint: R.Tensor((3, 3), dtype="float32") = lv1_adjoint[1] + x_adjoint_out: R.Tuple(R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3), dtype="float32")), R.Tensor((3, 3), dtype="float32")) = x_adjoint1 + y_adjoint_out: R.Tensor((3, 3), dtype="float32") = y_adjoint + z_adjoint_out: R.Tensor((3, 3), dtype="float32") = z_adjoint + u_adjoint_out: R.Tensor((3, 3), dtype="float32") = u_adjoint + R.output(gv, x_adjoint_out, y_adjoint_out, z_adjoint_out, u_adjoint_out) + return (gv, (x_adjoint_out, y_adjoint_out, z_adjoint_out, u_adjoint_out)) @R.function - def main_adjoint(x: R.Tuple(R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 3), "float32")), R.Tensor((3, 3), "float32")), y: R.Tensor((3, 3), "float32"), z: R.Tensor((3, 3), "float32"), u: R.Tensor((3, 3), "float32")) -> R.Tuple(R.Tensor((), "float32"), R.Tuple(R.Tuple(R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 3), "float32")), R.Tensor((3, 3), "float32")), R.Tensor((3, 3), "float32"), R.Tensor((3, 3), "float32"), R.Tensor((3, 3), "float32"))): + def main(x: R.Tuple(R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3), dtype="float32")), R.Tensor((3, 3), dtype="float32")), y: R.Tensor((3, 3), dtype="float32"), z: R.Tensor((3, 3), dtype="float32"), u: R.Tensor((3, 3), dtype="float32")) -> R.Tensor((), dtype="float32"): with R.dataflow(): - lv1: R.Tuple(R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 3), "float32")), R.Tensor((3, 3), "float32")) = ((y, z), u) - lv2: R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 3), "float32")) = x[0] - lv3: R.Tensor((3, 3), "float32") = lv2[0] - lv4: R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 3), "float32")) = lv1[0] - lv5: R.Tensor((3, 3), "float32") = lv4[1] - lv6: R.Tensor((3, 3), "float32") = R.add(lv3, lv5) - lv7: R.Tensor((3, 3), "float32") = x[1] - lv8: R.Tensor((3, 3), "float32") = R.add(lv6, lv7) - gv: R.Tensor((), "float32") = R.sum(lv8, axis=None, keepdims=False) - 
gv_adjoint: R.Tensor((), "float32") = R.ones((), "float32") - lv8_adjoint: R.Tensor((3, 3), "float32") = R.broadcast_to(gv_adjoint, (3, 3)) - lv7_adjoint: R.Tensor((3, 3), "float32") = lv8_adjoint - lv6_adjoint: R.Tensor((3, 3), "float32") = lv8_adjoint - lv5_adjoint: R.Tensor((3, 3), "float32") = lv6_adjoint - lv: R.Tensor((3, 3), "float32") = R.zeros((3, 3), "float32") - lv4_adjoint: R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 3), "float32")) = (lv, lv5_adjoint) - lv3_adjoint: R.Tensor((3, 3), "float32") = lv6_adjoint - lv11: R.Tensor((3, 3), "float32") = R.zeros((3, 3), "float32") - lv2_adjoint: R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 3), "float32")) = (lv3_adjoint, lv11) - lv21: R.Tensor((3, 3), "float32") = R.zeros((3, 3), "float32") - lv1_adjoint: R.Tuple(R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 3), "float32")), R.Tensor((3, 3), "float32")) = (lv4_adjoint, lv21) - x_adjoint: R.Tuple(R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 3), "float32")), R.Tensor((3, 3), "float32")) = (lv2_adjoint, lv7_adjoint) - lv31: R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 3), "float32")) = lv1_adjoint[0] - y_adjoint: R.Tensor((3, 3), "float32") = lv31[0] - z_adjoint: R.Tensor((3, 3), "float32") = lv31[1] - u_adjoint: R.Tensor((3, 3), "float32") = lv1_adjoint[1] - R.output(gv, x_adjoint, y_adjoint, z_adjoint, u_adjoint) - return (gv, (x_adjoint, y_adjoint, z_adjoint, u_adjoint)) + lv1: R.Tuple(R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3), dtype="float32")), R.Tensor((3, 3), dtype="float32")) = ((y, z), u) + lv2: R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3), dtype="float32")) = x[0] + lv3: R.Tensor((3, 3), dtype="float32") = lv2[0] + lv4: R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3), dtype="float32")) = lv1[0] + lv5: R.Tensor((3, 3), dtype="float32") = lv4[1] + lv6: R.Tensor((3, 3), dtype="float32") = R.add(lv3, lv5) + lv7: R.Tensor((3, 3), dtype="float32") = x[1] + lv8: R.Tensor((3, 3), dtype="float32") = R.add(lv6, lv7) + gv: R.Tensor((), dtype="float32") = R.sum(lv8, axis=None, keepdims=False) + R.output(gv) + return gv # fmt: on After = relax.transform.Gradient("main")(Before) @@ -512,7 +534,6 @@ def main_adjoint(x: R.Tuple(R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 3) def test_tuple_update(): """One tensor `x` is used in and out of tuple many times.""" - # fmt: off @I.ir_module class Before: @@ -536,61 +557,66 @@ def main(x: R.Tensor((3, 3), "float32"), y: R.Tensor((3, 3), "float32")): @I.ir_module class Expected: @R.function - def main(x: R.Tensor((3, 3), "float32"), y: R.Tensor((3, 3), "float32")) -> R.Tensor((), "float32"): + def main_adjoint(x: R.Tensor((3, 3), dtype="float32"), y: R.Tensor((3, 3), dtype="float32")) -> R.Tuple(R.Tensor((), dtype="float32"), R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3), dtype="float32"))): with R.dataflow(): - lv0: R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 3), "float32")) = (x, y) - lv1: R.Tensor((3, 3), "float32") = R.add(x, y) - lv2: R.Tensor((3, 3), "float32") = lv0[0] - lv3: R.Tensor((3, 3), "float32") = R.add(lv2, y) - lv4: R.Tensor((3, 3), "float32") = R.add(lv1, lv3) - lv5: R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 3), "float32")) = (x, y) - lv6: R.Tensor((3, 3), "float32") = lv5[0] - lv7: R.Tensor((3, 3), "float32") = lv0[0] - lv8: R.Tensor((3, 3), "float32") = R.add(lv4, lv6) - lv9: R.Tensor((3, 3), "float32") = R.add(lv8, lv7) - gv: R.Tensor((), "float32") = R.sum(lv9, axis=None, keepdims=False) - R.output(gv) - return gv + lv0: 
R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3), dtype="float32")) = (x, y) + lv1: R.Tensor((3, 3), dtype="float32") = R.add(x, y) + lv2: R.Tensor((3, 3), dtype="float32") = lv0[0] + lv3: R.Tensor((3, 3), dtype="float32") = R.add(lv2, y) + lv4: R.Tensor((3, 3), dtype="float32") = R.add(lv1, lv3) + lv5: R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3), dtype="float32")) = (x, y) + lv6: R.Tensor((3, 3), dtype="float32") = lv5[0] + lv7: R.Tensor((3, 3), dtype="float32") = lv0[0] + lv8: R.Tensor((3, 3), dtype="float32") = R.add(lv4, lv6) + lv9: R.Tensor((3, 3), dtype="float32") = R.add(lv8, lv7) + gv: R.Tensor((), dtype="float32") = R.sum(lv9, axis=None, keepdims=False) + gv_adjoint: R.Tensor((), dtype="float32") = R.ones(R.shape([]), dtype="float32") + lv9_adjoint: R.Tensor((3, 3), dtype="float32") = R.broadcast_to(gv_adjoint, R.shape([3, 3])) + lv8_adjoint: R.Tensor((3, 3), dtype="float32") = lv9_adjoint + lv7_adjoint: R.Tensor((3, 3), dtype="float32") = lv9_adjoint + lv4_adjoint: R.Tensor((3, 3), dtype="float32") = lv8_adjoint + lv6_adjoint: R.Tensor((3, 3), dtype="float32") = lv8_adjoint + lv: R.Tensor((3, 3), dtype="float32") = R.zeros(R.shape([3, 3]), dtype="float32") + lv0_adjoint: R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3), dtype="float32")) = (lv7_adjoint, lv) + lv1_1: R.Tensor((3, 3), dtype="float32") = R.zeros(R.shape([3, 3]), dtype="float32") + lv5_adjoint: R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3), dtype="float32")) = (lv6_adjoint, lv1_1) + x_adjoint: R.Tensor((3, 3), dtype="float32") = lv5_adjoint[0] + y_adjoint: R.Tensor((3, 3), dtype="float32") = lv5_adjoint[1] + lv1_adjoint: R.Tensor((3, 3), dtype="float32") = lv4_adjoint + lv3_adjoint: R.Tensor((3, 3), dtype="float32") = lv4_adjoint + lv2_adjoint: R.Tensor((3, 3), dtype="float32") = lv3_adjoint + y_adjoint1: R.Tensor((3, 3), dtype="float32") = R.add(y_adjoint, lv3_adjoint) + lv2_1: R.Tensor((3, 3), dtype="float32") = lv0_adjoint[0] + lv3_1: R.Tensor((3, 3), dtype="float32") = R.add(lv2_1, lv2_adjoint) + lv4_1: R.Tensor((3, 3), dtype="float32") = lv0_adjoint[1] + lv0_adjoint1: R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3), dtype="float32")) = (lv3_1, lv4_1) + x_adjoint1: R.Tensor((3, 3), dtype="float32") = R.add(x_adjoint, lv1_adjoint) + y_adjoint2: R.Tensor((3, 3), dtype="float32") = R.add(y_adjoint1, lv1_adjoint) + lv5_1: R.Tensor((3, 3), dtype="float32") = lv0_adjoint1[0] + x_adjoint2: R.Tensor((3, 3), dtype="float32") = R.add(x_adjoint1, lv5_1) + lv6_1: R.Tensor((3, 3), dtype="float32") = lv0_adjoint1[1] + y_adjoint3: R.Tensor((3, 3), dtype="float32") = R.add(y_adjoint2, lv6_1) + x_adjoint_out: R.Tensor((3, 3), dtype="float32") = x_adjoint2 + y_adjoint_out: R.Tensor((3, 3), dtype="float32") = y_adjoint3 + R.output(gv, x_adjoint_out, y_adjoint_out) + return (gv, (x_adjoint_out, y_adjoint_out)) @R.function - def main_adjoint(x: R.Tensor((3, 3), "float32"), y: R.Tensor((3, 3), "float32")) -> R.Tuple(R.Tensor((), "float32"), R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 3), "float32"))): + def main(x: R.Tensor((3, 3), dtype="float32"), y: R.Tensor((3, 3), dtype="float32")) -> R.Tensor((), dtype="float32"): with R.dataflow(): - lv0: R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 3), "float32")) = (x, y) - lv1: R.Tensor((3, 3), "float32") = R.add(x, y) - lv2: R.Tensor((3, 3), "float32") = lv0[0] - lv3: R.Tensor((3, 3), "float32") = R.add(lv2, y) - lv4: R.Tensor((3, 3), "float32") = R.add(lv1, lv3) - lv5: R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 
3), "float32")) = (x, y) - lv6: R.Tensor((3, 3), "float32") = lv5[0] - lv7: R.Tensor((3, 3), "float32") = lv0[0] - lv8: R.Tensor((3, 3), "float32") = R.add(lv4, lv6) - lv9: R.Tensor((3, 3), "float32") = R.add(lv8, lv7) - gv: R.Tensor((), "float32") = R.sum(lv9, axis=None, keepdims=False) - gv_adjoint: R.Tensor((), "float32") = R.ones((), "float32") - lv9_adjoint: R.Tensor((3, 3), "float32") = R.broadcast_to(gv_adjoint, (3, 3)) - lv8_adjoint: R.Tensor((3, 3), "float32") = lv9_adjoint - lv7_adjoint: R.Tensor((3, 3), "float32") = lv9_adjoint - lv6_adjoint: R.Tensor((3, 3), "float32") = lv8_adjoint - lv: R.Tensor((3, 3), "float32") = R.zeros((3, 3), "float32") - lv5_adjoint: R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 3), "float32")) = (lv6_adjoint, lv) - lv4_adjoint: R.Tensor((3, 3), "float32") = lv8_adjoint - lv3_adjoint: R.Tensor((3, 3), "float32") = lv4_adjoint - lv2_adjoint: R.Tensor((3, 3), "float32") = lv3_adjoint - lv1_adjoint: R.Tensor((3, 3), "float32") = lv4_adjoint - lv11: R.Tensor((3, 3), "float32") = R.add(lv7_adjoint, lv2_adjoint) - lv21: R.Tensor((3, 3), "float32") = R.zeros((3, 3), "float32") - lv0_adjoint: R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 3), "float32")) = (lv11, lv21) - lv31: R.Tensor((3, 3), "float32") = lv5_adjoint[0] - lv41: R.Tensor((3, 3), "float32") = R.add(lv31, lv1_adjoint) - lv51: R.Tensor((3, 3), "float32") = lv0_adjoint[0] - x_adjoint: R.Tensor((3, 3), "float32") = R.add(lv41, lv51) - lv61: R.Tensor((3, 3), "float32") = lv5_adjoint[1] - lv71: R.Tensor((3, 3), "float32") = R.add(lv61, lv3_adjoint) - lv81: R.Tensor((3, 3), "float32") = R.add(lv71, lv1_adjoint) - lv91: R.Tensor((3, 3), "float32") = lv0_adjoint[1] - y_adjoint: R.Tensor((3, 3), "float32") = R.add(lv81, lv91) - R.output(gv, x_adjoint, y_adjoint) - return (gv, (x_adjoint, y_adjoint)) + lv0: R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3), dtype="float32")) = (x, y) + lv1: R.Tensor((3, 3), dtype="float32") = R.add(x, y) + lv2: R.Tensor((3, 3), dtype="float32") = lv0[0] + lv3: R.Tensor((3, 3), dtype="float32") = R.add(lv2, y) + lv4: R.Tensor((3, 3), dtype="float32") = R.add(lv1, lv3) + lv5: R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3), dtype="float32")) = (x, y) + lv6: R.Tensor((3, 3), dtype="float32") = lv5[0] + lv7: R.Tensor((3, 3), dtype="float32") = lv0[0] + lv8: R.Tensor((3, 3), dtype="float32") = R.add(lv4, lv6) + lv9: R.Tensor((3, 3), dtype="float32") = R.add(lv8, lv7) + gv: R.Tensor((), dtype="float32") = R.sum(lv9, axis=None, keepdims=False) + R.output(gv) + return gv # fmt: on After = relax.transform.Gradient("main")(Before) @@ -613,26 +639,27 @@ def main(x: R.Tensor((6,), "float32")): @I.ir_module class Expected: @R.function - def main(x: R.Tensor((6,), "float32")) -> R.Tensor((), "float32"): + def main_adjoint(x: R.Tensor((6,), dtype="float32")) -> R.Tuple(R.Tensor((), dtype="float32"), R.Tuple(R.Tensor((6,), dtype="float32"))): with R.dataflow(): - lv1: R.Tuple(R.Tensor((3,), "float32"), R.Tensor((3,), "float32")) = R.split(x, indices_or_sections=2, axis=0) - lv2: R.Tensor((6,), "float32") = R.concat(lv1, axis=0) - gv: R.Tensor((), "float32") = R.sum(lv2, axis=None, keepdims=False) - R.output(gv) - return gv + lv1: R.Tuple(R.Tensor((3,), dtype="float32"), R.Tensor((3,), dtype="float32")) = R.split(x, indices_or_sections=2, axis=0) + lv2: R.Tensor((6,), dtype="float32") = R.concat(lv1, axis=0) + gv: R.Tensor((), dtype="float32") = R.sum(lv2, axis=None, keepdims=False) + gv_adjoint: R.Tensor((), dtype="float32") = R.ones(R.shape([]), 
dtype="float32") + lv2_adjoint: R.Tensor((6,), dtype="float32") = R.broadcast_to(gv_adjoint, R.shape([6])) + lv1_adjoint: R.Tuple(R.Tensor((3,), dtype="float32"), R.Tensor((3,), dtype="float32")) = R.split(lv2_adjoint, indices_or_sections=[3], axis=0) + x_adjoint: R.Tensor((6,), dtype="float32") = R.concat(lv1_adjoint, axis=0) + x_adjoint_out: R.Tensor((6,), dtype="float32") = x_adjoint + R.output(gv, x_adjoint_out) + return (gv, (x_adjoint_out,)) @R.function - def main_adjoint(x: R.Tensor((6,), "float32")) -> R.Tuple(R.Tensor((), "float32"), R.Tuple(R.Tensor((6,), "float32"))): + def main(x: R.Tensor((6,), dtype="float32")) -> R.Tensor((), dtype="float32"): with R.dataflow(): - lv1: R.Tuple(R.Tensor((3,), "float32"), R.Tensor((3,), "float32")) = R.split(x, indices_or_sections=2, axis=0) - lv2: R.Tensor((6,), "float32") = R.concat(lv1, axis=0) - gv: R.Tensor((), "float32") = R.sum(lv2, axis=None, keepdims=False) - gv_adjoint: R.Tensor((), "float32") = R.ones((), "float32") - lv2_adjoint: R.Tensor((6,), "float32") = R.broadcast_to(gv_adjoint, (6,)) - lv1_adjoint: R.Tuple(R.Tensor((3,), "float32"), R.Tensor((3,), "float32")) = R.split(lv2_adjoint, indices_or_sections=[3], axis=0) - x_adjoint: R.Tensor((6,), "float32") = R.concat(lv1_adjoint, axis=0) - R.output(gv, x_adjoint) - return (gv, (x_adjoint,)) + lv1: R.Tuple(R.Tensor((3,), dtype="float32"), R.Tensor((3,), dtype="float32")) = R.split(x, indices_or_sections=2, axis=0) + lv2: R.Tensor((6,), dtype="float32") = R.concat(lv1, axis=0) + gv: R.Tensor((), dtype="float32") = R.sum(lv2, axis=None, keepdims=False) + R.output(gv) + return gv # fmt: on After = relax.transform.Gradient("main")(Before) @@ -659,46 +686,48 @@ def main(x: R.Tensor((3,), "float32"), y: R.Tuple(R.Tensor((3, ), "float32"), R. @I.ir_module class Expected: @R.function - def main(x: R.Tensor((3,), "float32"), y: R.Tuple(R.Tensor((3,), "float32"), R.Tensor((3,), "float32"))) -> R.Tensor((), "float32"): + def main_adjoint(x: R.Tensor((3,), dtype="float32"), y: R.Tuple(R.Tensor((3,), dtype="float32"), R.Tensor((3,), dtype="float32"))) -> R.Tuple(R.Tensor((), dtype="float32"), R.Tuple(R.Tensor((3,), dtype="float32"), R.Tuple(R.Tensor((3,), dtype="float32"), R.Tensor((3,), dtype="float32")))): with R.dataflow(): - lv1: R.Tuple(R.Tensor((3,), "float32"), R.Tensor((3,), "float32")) = (x, x) - lv2: R.Tensor((6,), "float32") = R.concat(lv1, axis=0) - lv3: R.Tensor((6,), "float32") = R.concat((x, x), axis=0) - lv4: R.Tensor((6,), "float32") = R.concat(y, axis=0) - lv5: R.Tensor((6,), "float32") = R.add(lv2, lv3) - lv6: R.Tensor((6,), "float32") = R.add(lv5, lv4) - gv: R.Tensor((), "float32") = R.sum(lv6, axis=None, keepdims=False) - R.output(gv) - return gv + lv1: R.Tuple(R.Tensor((3,), dtype="float32"), R.Tensor((3,), dtype="float32")) = (x, x) + lv2: R.Tensor((6,), dtype="float32") = R.concat(lv1, axis=0) + lv3: R.Tensor((6,), dtype="float32") = R.concat((x, x), axis=0) + lv4: R.Tensor((6,), dtype="float32") = R.concat(y, axis=0) + lv5: R.Tensor((6,), dtype="float32") = R.add(lv2, lv3) + lv6: R.Tensor((6,), dtype="float32") = R.add(lv5, lv4) + gv: R.Tensor((), dtype="float32") = R.sum(lv6, axis=None, keepdims=False) + gv_adjoint: R.Tensor((), dtype="float32") = R.ones(R.shape([]), dtype="float32") + lv6_adjoint: R.Tensor((6,), dtype="float32") = R.broadcast_to(gv_adjoint, R.shape([6])) + lv5_adjoint: R.Tensor((6,), dtype="float32") = lv6_adjoint + lv4_adjoint: R.Tensor((6,), dtype="float32") = lv6_adjoint + lv2_adjoint: R.Tensor((6,), dtype="float32") = lv5_adjoint + lv3_adjoint: 
R.Tensor((6,), dtype="float32") = lv5_adjoint + y_adjoint: R.Tuple(R.Tensor((3,), dtype="float32"), R.Tensor((3,), dtype="float32")) = R.split(lv4_adjoint, indices_or_sections=[3], axis=0) + lv: R.Tuple(R.Tensor((3,), dtype="float32"), R.Tensor((3,), dtype="float32")) = R.split(lv3_adjoint, indices_or_sections=[3], axis=0) + x_adjoint: R.Tensor((3,), dtype="float32") = lv[0] + lv1_1: R.Tensor((3,), dtype="float32") = lv[1] + x_adjoint1: R.Tensor((3,), dtype="float32") = R.add(x_adjoint, lv1_1) + lv1_adjoint: R.Tuple(R.Tensor((3,), dtype="float32"), R.Tensor((3,), dtype="float32")) = R.split(lv2_adjoint, indices_or_sections=[3], axis=0) + lv2_1: R.Tensor((3,), dtype="float32") = lv1_adjoint[0] + x_adjoint2: R.Tensor((3,), dtype="float32") = R.add(x_adjoint1, lv2_1) + lv3_1: R.Tensor((3,), dtype="float32") = lv1_adjoint[1] + x_adjoint3: R.Tensor((3,), dtype="float32") = R.add(x_adjoint2, lv3_1) + x_adjoint_out: R.Tensor((3,), dtype="float32") = x_adjoint3 + y_adjoint_out: R.Tuple(R.Tensor((3,), dtype="float32"), R.Tensor((3,), dtype="float32")) = y_adjoint + R.output(gv, x_adjoint_out, y_adjoint_out) + return (gv, (x_adjoint_out, y_adjoint_out)) @R.function - def main_adjoint(x: R.Tensor((3,), "float32"), y: R.Tuple(R.Tensor((3,), "float32"), R.Tensor((3,), "float32"))) -> R.Tuple(R.Tensor((), "float32"), R.Tuple(R.Tensor((3,), "float32"), R.Tuple(R.Tensor((3,), "float32"), R.Tensor((3,), "float32")))): + def main(x: R.Tensor((3,), dtype="float32"), y: R.Tuple(R.Tensor((3,), dtype="float32"), R.Tensor((3,), dtype="float32"))) -> R.Tensor((), dtype="float32"): with R.dataflow(): - lv1: R.Tuple(R.Tensor((3,), "float32"), R.Tensor((3,), "float32")) = (x, x) - lv2: R.Tensor((6,), "float32") = R.concat(lv1, axis=0) - lv3: R.Tensor((6,), "float32") = R.concat((x, x), axis=0) - lv4: R.Tensor((6,), "float32") = R.concat(y, axis=0) - lv5: R.Tensor((6,), "float32") = R.add(lv2, lv3) - lv6: R.Tensor((6,), "float32") = R.add(lv5, lv4) - gv: R.Tensor((), "float32") = R.sum(lv6, axis=None, keepdims=False) - gv_adjoint: R.Tensor((), "float32") = R.ones((), "float32") - lv6_adjoint: R.Tensor((6,), "float32") = R.broadcast_to(gv_adjoint, (6,)) - lv5_adjoint: R.Tensor((6,), "float32") = lv6_adjoint - lv4_adjoint: R.Tensor((6,), "float32") = lv6_adjoint - lv3_adjoint: R.Tensor((6,), "float32") = lv5_adjoint - lv2_adjoint: R.Tensor((6,), "float32") = lv5_adjoint - lv1_adjoint: R.Tuple(R.Tensor((3,), "float32"), R.Tensor((3,), "float32")) = R.split(lv2_adjoint, indices_or_sections=[3], axis=0) - lv: R.Tuple(R.Tensor((3,), "float32"), R.Tensor((3,), "float32")) = R.split(lv3_adjoint, indices_or_sections=[3], axis=0) - lv11: R.Tensor((3,), "float32") = lv[0] - lv21: R.Tensor((3,), "float32") = lv[1] - lv31: R.Tensor((3,), "float32") = R.add(lv11, lv21) - lv41: R.Tensor((3,), "float32") = lv1_adjoint[0] - lv51: R.Tensor((3,), "float32") = R.add(lv31, lv41) - lv61: R.Tensor((3,), "float32") = lv1_adjoint[1] - x_adjoint: R.Tensor((3,), "float32") = R.add(lv51, lv61) - y_adjoint: R.Tuple(R.Tensor((3,), "float32"), R.Tensor((3,), "float32")) = R.split(lv4_adjoint, indices_or_sections=[3], axis=0) - R.output(gv, x_adjoint, y_adjoint) - return (gv, (x_adjoint, y_adjoint)) + lv1: R.Tuple(R.Tensor((3,), dtype="float32"), R.Tensor((3,), dtype="float32")) = (x, x) + lv2: R.Tensor((6,), dtype="float32") = R.concat(lv1, axis=0) + lv3: R.Tensor((6,), dtype="float32") = R.concat((x, x), axis=0) + lv4: R.Tensor((6,), dtype="float32") = R.concat(y, axis=0) + lv5: R.Tensor((6,), dtype="float32") = R.add(lv2, lv3) + lv6: 
R.Tensor((6,), dtype="float32") = R.add(lv5, lv4) + gv: R.Tensor((), dtype="float32") = R.sum(lv6, axis=None, keepdims=False) + R.output(gv) + return gv # fmt: on After = relax.transform.Gradient("main")(Before) @@ -728,43 +757,41 @@ def main(x: R.Tensor((3,), "float32")): @I.ir_module class Expected: @R.function - def main(x: R.Tensor((3,), "float32")) -> R.Tensor((), "float32"): - # block 0 + def main_adjoint(x: R.Tensor((3,), dtype="float32")) -> R.Tuple(R.Tensor((), dtype="float32"), R.Tuple(R.Tensor((3,), dtype="float32"))): with R.dataflow(): - lv1: R.Tensor((6,), "float32") = R.concat((c1, c2), axis=0) - lv2: R.Tensor((6,), "float32") = R.concat((c3, x), axis=0) - lv3: R.Tensor((6,), "float32") = R.concat((x, x), axis=0) - lv4: R.Tensor((6,), "float32") = R.add(lv1, lv2) - lv5: R.Tensor((6,), "float32") = R.add(lv4, lv3) - gv: R.Tensor((), "float32") = R.sum(lv5, axis=None, keepdims=False) - R.output(gv) - return gv + lv1: R.Tensor((6,), dtype="float32") = R.concat((c1, c2), axis=0) + lv2: R.Tensor((6,), dtype="float32") = R.concat((c3, x), axis=0) + lv3: R.Tensor((6,), dtype="float32") = R.concat((x, x), axis=0) + lv4: R.Tensor((6,), dtype="float32") = R.add(lv1, lv2) + lv5: R.Tensor((6,), dtype="float32") = R.add(lv4, lv3) + gv: R.Tensor((), dtype="float32") = R.sum(lv5, axis=None, keepdims=False) + gv_adjoint: R.Tensor((), dtype="float32") = R.ones(R.shape([]), dtype="float32") + lv5_adjoint: R.Tensor((6,), dtype="float32") = R.broadcast_to(gv_adjoint, R.shape([6])) + lv4_adjoint: R.Tensor((6,), dtype="float32") = lv5_adjoint + lv3_adjoint: R.Tensor((6,), dtype="float32") = lv5_adjoint + lv2_adjoint: R.Tensor((6,), dtype="float32") = lv4_adjoint + lv: R.Tuple(R.Tensor((3,), dtype="float32"), R.Tensor((3,), dtype="float32")) = R.split(lv3_adjoint, indices_or_sections=[3], axis=0) + x_adjoint: R.Tensor((3,), dtype="float32") = lv[0] + lv1_1: R.Tensor((3,), dtype="float32") = lv[1] + x_adjoint1: R.Tensor((3,), dtype="float32") = R.add(x_adjoint, lv1_1) + lv2_1: R.Tuple(R.Tensor((3,), dtype="float32"), R.Tensor((3,), dtype="float32")) = R.split(lv2_adjoint, indices_or_sections=[3], axis=0) + lv3_1: R.Tensor((3,), dtype="float32") = lv2_1[1] + x_adjoint2: R.Tensor((3,), dtype="float32") = R.add(x_adjoint1, lv3_1) + x_adjoint_out: R.Tensor((3,), dtype="float32") = x_adjoint2 + R.output(gv, x_adjoint_out) + return (gv, (x_adjoint_out,)) @R.function - def main_adjoint(x: R.Tensor((3,), "float32")) -> R.Tuple(R.Tensor((), "float32"), R.Tuple(R.Tensor((3,), "float32"))): - # block 0 + def main(x: R.Tensor((3,), dtype="float32")) -> R.Tensor((), dtype="float32"): with R.dataflow(): - lv1: R.Tensor((6,), "float32") = R.concat((c1, c2), axis=0) - lv2: R.Tensor((6,), "float32") = R.concat((c3, x), axis=0) - lv3: R.Tensor((6,), "float32") = R.concat((x, x), axis=0) - lv4: R.Tensor((6,), "float32") = R.add(lv1, lv2) - lv5: R.Tensor((6,), "float32") = R.add(lv4, lv3) - gv: R.Tensor((), "float32") = R.sum(lv5, axis=None, keepdims=False) - gv_adjoint: R.Tensor((), "float32") = R.ones((), "float32") - lv5_adjoint: R.Tensor((6,), "float32") = R.broadcast_to(gv_adjoint, (6,)) - lv4_adjoint: R.Tensor((6,), "float32") = lv5_adjoint - lv3_adjoint: R.Tensor((6,), "float32") = lv5_adjoint - lv2_adjoint: R.Tensor((6,), "float32") = lv4_adjoint - lv1_adjoint: R.Tensor((6,), "float32") = lv4_adjoint - lv: R.Tuple(R.Tensor((3,), "float32"), R.Tensor((3,), "float32")) = R.split(lv3_adjoint, indices_or_sections=[3], axis=0) - lv11: R.Tensor((3,), "float32") = lv[0] - lv21: R.Tensor((3,), "float32") = lv[1] - 
lv31: R.Tensor((3,), "float32") = R.add(lv11, lv21) - lv41: R.Tuple(R.Tensor((3,), "float32"), R.Tensor((3,), "float32")) = R.split(lv2_adjoint, indices_or_sections=[3], axis=0) - lv51: R.Tensor((3,), "float32") = lv41[1] - x_adjoint: R.Tensor((3,), "float32") = R.add(lv31, lv51) - R.output(gv, x_adjoint) - return (gv, (x_adjoint,)) + lv1: R.Tensor((6,), dtype="float32") = R.concat((c1, c2), axis=0) + lv2: R.Tensor((6,), dtype="float32") = R.concat((c3, x), axis=0) + lv3: R.Tensor((6,), dtype="float32") = R.concat((x, x), axis=0) + lv4: R.Tensor((6,), dtype="float32") = R.add(lv1, lv2) + lv5: R.Tensor((6,), dtype="float32") = R.add(lv4, lv3) + gv: R.Tensor((), dtype="float32") = R.sum(lv5, axis=None, keepdims=False) + R.output(gv) + return gv # fmt: on After = relax.transform.Gradient("main")(Before) @@ -794,42 +821,85 @@ def main(x: R.Tensor((3, 3), "float32"), y: R.Tensor((3, 3), "float32")): @I.ir_module class Expected: @R.function - def main(x: R.Tensor((3, 3), "float32"), y: R.Tensor((3, 3), "float32")) -> R.Tensor((), "float32"): + def main_adjoint(x: R.Tensor((3, 3), dtype="float32"), y: R.Tensor((3, 3), dtype="float32")) -> R.Tuple(R.Tensor((), dtype="float32"), R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3), dtype="float32"))): + with R.dataflow(): + lv1: R.Tensor((3, 3), dtype="float32") = R.add(x, cst) + lv2: R.Tensor((3, 3), dtype="float32") = cst + lv3: R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3), dtype="float32"))) = (cst, (cst, lv1)) + lv4: R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3), dtype="float32")) = lv3[1] + lv5: R.Tensor((3, 3), dtype="float32") = lv4[1] + lv6: R.Tensor((3, 3), dtype="float32") = R.subtract(lv5, lv2) + gv: R.Tensor((), dtype="float32") = R.sum(lv6, axis=None, keepdims=False) + gv_adjoint: R.Tensor((), dtype="float32") = R.ones(R.shape([]), dtype="float32") + lv6_adjoint: R.Tensor((3, 3), dtype="float32") = R.broadcast_to(gv_adjoint, R.shape([3, 3])) + lv5_adjoint: R.Tensor((3, 3), dtype="float32") = lv6_adjoint + lv: R.Tensor((3, 3), dtype="float32") = R.zeros(R.shape([3, 3]), dtype="float32") + lv4_adjoint: R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3), dtype="float32")) = (lv, lv5_adjoint) + lv1_1: R.Tensor((3, 3), dtype="float32") = R.zeros(R.shape([3, 3]), dtype="float32") + lv3_adjoint: R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3), dtype="float32"))) = (lv1_1, lv4_adjoint) + lv2_1: R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3), dtype="float32")) = lv3_adjoint[1] + lv1_adjoint: R.Tensor((3, 3), dtype="float32") = lv2_1[1] + x_adjoint: R.Tensor((3, 3), dtype="float32") = lv1_adjoint + x_adjoint_out: R.Tensor((3, 3), dtype="float32") = x_adjoint + y_adjoint: R.Tensor((3, 3), dtype="float32") = R.zeros(R.shape([3, 3]), dtype="float32") + y_adjoint_out: R.Tensor((3, 3), dtype="float32") = y_adjoint + R.output(gv, x_adjoint_out, y_adjoint_out) + return (gv, (x_adjoint_out, y_adjoint_out)) + + @R.function + def main(x: R.Tensor((3, 3), dtype="float32"), y: R.Tensor((3, 3), dtype="float32")) -> R.Tensor((), dtype="float32"): + with R.dataflow(): + lv1: R.Tensor((3, 3), dtype="float32") = R.add(x, cst) + lv2: R.Tensor((3, 3), dtype="float32") = cst + lv3: R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3), dtype="float32"))) = (cst, (cst, lv1)) + lv4: R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3), dtype="float32")) = lv3[1] + 
lv5: R.Tensor((3, 3), dtype="float32") = lv4[1] + lv6: R.Tensor((3, 3), dtype="float32") = R.subtract(lv5, lv2) + gv: R.Tensor((), dtype="float32") = R.sum(lv6, axis=None, keepdims=False) + R.output(gv) + return gv + # fmt: on + + After = relax.transform.Gradient("main")(Before) + assert_structural_equal(After, Expected) + + +def test_shape_expr(): + # fmt: off + @I.ir_module + class Before: + @R.function + def main(x: R.Tensor((3, 4), "float32")): with R.dataflow(): - lv1: R.Tensor((3, 3), "float32") = R.add(x, cst) - lv2: R.Tensor((3, 3), "float32") = cst - lv3: R.Tuple(R.Tensor((3, 3), "float32"), R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 3), "float32"))) = (cst, (cst, lv1)) - lv4: R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 3), "float32")) = lv3[1] - lv5: R.Tensor((3, 3), "float32") = lv4[1] - lv6: R.Tensor((3, 3), "float32") = R.subtract(lv5, lv2) - gv: R.Tensor((), "float32") = R.sum(lv6, axis=None, keepdims=False) + s = R.shape([3, 2, 2]) + lv = R.reshape(x, s) + gv = R.sum(lv) R.output(gv) return gv + @I.ir_module + class Expected: + @R.function + def main_adjoint(x: R.Tensor((3, 4), dtype="float32")) -> R.Tuple(R.Tensor((), dtype="float32"), R.Tuple(R.Tensor((3, 4), dtype="float32"))): + with R.dataflow(): + s: R.Shape([3, 2, 2]) = R.shape([3, 2, 2]) + lv: R.Tensor((3, 2, 2), dtype="float32") = R.reshape(x, s) + gv: R.Tensor((), dtype="float32") = R.sum(lv, axis=None, keepdims=False) + gv_adjoint: R.Tensor((), dtype="float32") = R.ones(R.shape([]), dtype="float32") + lv_adjoint: R.Tensor((3, 2, 2), dtype="float32") = R.broadcast_to(gv_adjoint, R.shape([3, 2, 2])) + x_adjoint: R.Tensor((3, 4), dtype="float32") = R.reshape(lv_adjoint, R.shape([3, 4])) + x_adjoint_out: R.Tensor((3, 4), dtype="float32") = x_adjoint + R.output(gv, x_adjoint_out) + return (gv, (x_adjoint_out,)) + @R.function - def main_adjoint(x: R.Tensor((3, 3), "float32"), y: R.Tensor((3, 3), "float32")) -> R.Tuple(R.Tensor((), "float32"), R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 3), "float32"))): + def main(x: R.Tensor((3, 4), dtype="float32")) -> R.Tensor((), dtype="float32"): with R.dataflow(): - lv1: R.Tensor((3, 3), "float32") = R.add(x, cst) - lv2: R.Tensor((3, 3), "float32") = cst - lv3: R.Tuple(R.Tensor((3, 3), "float32"), R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 3), "float32"))) = (cst, (cst, lv1)) - lv4: R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 3), "float32")) = lv3[1] - lv5: R.Tensor((3, 3), "float32") = lv4[1] - lv6: R.Tensor((3, 3), "float32") = R.subtract(lv5, lv2) - gv: R.Tensor((), "float32") = R.sum(lv6, axis=None, keepdims=False) - gv_adjoint: R.Tensor((), "float32") = R.ones((), "float32") - lv6_adjoint: R.Tensor((3, 3), "float32") = R.broadcast_to(gv_adjoint, (3, 3)) - lv5_adjoint: R.Tensor((3, 3), "float32") = lv6_adjoint - lv: R.Tensor((3, 3), "float32") = R.zeros((3, 3), "float32") - lv4_adjoint: R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 3), "float32")) = (lv, lv5_adjoint) - lv11: R.Tensor((3, 3), "float32") = R.zeros((3, 3), "float32") - lv3_adjoint: R.Tuple(R.Tensor((3, 3), "float32"), R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 3), "float32"))) = (lv11, lv4_adjoint) - lv2_adjoint: R.Tensor((3, 3), "float32") = R.negative(lv6_adjoint) - lv21: R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 3), "float32")) = lv3_adjoint[1] - lv1_adjoint: R.Tensor((3, 3), "float32") = lv21[1] - x_adjoint: R.Tensor((3, 3), "float32") = lv1_adjoint - y_adjoint: R.Tensor((3, 3), "float32") = R.zeros((3, 3), "float32") - R.output(gv, x_adjoint, y_adjoint) - 
return (gv, (x_adjoint, y_adjoint)) + s: R.Shape([3, 2, 2]) = R.shape([3, 2, 2]) + lv: R.Tensor((3, 2, 2), dtype="float32") = R.reshape(x, s) + gv: R.Tensor((), dtype="float32") = R.sum(lv, axis=None, keepdims=False) + R.output(gv) + return gv # fmt: on After = relax.transform.Gradient("main")(Before) @@ -890,6 +960,29 @@ def main( old_bindings_len = len(old_bindings) new_bindings = After["main_adjoint"].body.blocks[0].bindings[:old_bindings_len] assert_structural_equal(old_bindings, new_bindings, True) + assert relax.analysis.well_formed(After) + + +def test_tir_copy(): + @I.ir_module + class Before: + @R.function + def main( + x0: R.Tensor(("n", "n"), "float32"), + x1: R.Tensor(("n", "n"), "float32"), + x2: R.Tensor(("n", "n"), "float32"), + x3: R.Tensor(("n", "n"), "float32"), + ): + with R.dataflow(): + lv0 = R.add(x0, x1) + lv1 = R.add(x2, x3) + lv2 = R.add(lv0, lv1) + gv = R.sum(lv2) + R.output(gv) + return gv + + After = relax.transform.Gradient("main")(Before) + assert relax.analysis.well_formed(After) def test_report_error(): @@ -1055,47 +1148,6 @@ def main(x: R.Tuple(R.Tensor((3, 3), "int64"), R.Tensor((3, 3), "int64"))): relax.transform.Gradient("main")(IntDtypeTuple) -def test_shape_expr(): - # fmt: off - @I.ir_module - class Before: - @R.function - def main(x: R.Tensor((3, 4), "float32")): - with R.dataflow(): - s = R.shape([3, 2, 2]) - lv = R.reshape(x, s) - gv = R.sum(lv) - R.output(gv) - return gv - - @I.ir_module - class Expected: - @R.function - def main_adjoint(x: R.Tensor((3, 4), dtype="float32")) -> R.Tuple(R.Tensor((), dtype="float32"), R.Tuple(R.Tensor((3, 4), dtype="float32"))): - with R.dataflow(): - s: R.Shape([3, 2, 2]) = R.shape([3, 2, 2]) - lv = R.reshape(x, s) - gv: R.Tensor((), dtype="float32") = R.sum(lv, axis=None, keepdims=False) - gv_adjoint: R.Tensor((), dtype="float32") = R.ones(R.shape([]), dtype="float32") - lv_adjoint : R.Tensor([3, 2, 2], "float32") = R.broadcast_to(gv_adjoint, R.shape([3, 2, 2])) - x_adjoint: R.Tensor((3, 4), dtype="float32") = R.reshape(lv_adjoint, R.shape([3, 4])) - R.output(gv, x_adjoint) - return (gv, (x_adjoint,)) - - @R.function - def main(x: R.Tensor((3, 4), dtype="float32")) -> R.Tensor((), dtype="float32"): - with R.dataflow(): - s: R.Shape([3, 2, 2]) = R.shape([3, 2, 2]) - lv = R.reshape(x, s) - gv: R.Tensor((), dtype="float32") = R.sum(lv, axis=None, keepdims=False) - R.output(gv) - return gv - # fmt: on - - After = relax.transform.Gradient("main")(Before) - assert_structural_equal(After, Expected) - - def test_mlp_script(): """ An example of single layer multi-layer perceptron. You can add extra layers if you want. 
@@ -1133,17 +1185,18 @@ def main_adjoint(x: R.Tensor((3, 10), dtype="float32"), w0: R.Tensor((10, 5), dt lv: R.Tensor((), dtype="float32") = R.divide(loss_adjoint, R.const(3, "float32")) lv1: R.Tensor((), dtype="float32") = R.negative(lv) logits_adjoint: R.Tensor((3, 5), dtype="float32") = R.multiply(lv1, label) - lv2: R.Tensor((3, 1), dtype="float32") = R.sum(logits_adjoint, axis=[-1], keepdims=True) - lv3: R.Tensor((3, 5), dtype="float32") = R.exp(logits) - lv4: R.Tensor((3, 5), dtype="float32") = R.multiply(lv2, lv3) - out_adjoint: R.Tensor((3, 5), dtype="float32") = R.subtract(logits_adjoint, lv4) + lv3: R.Tensor((3, 1), dtype="float32") = R.sum(logits_adjoint, axis=[-1], keepdims=True) + lv4: R.Tensor((3, 5), dtype="float32") = R.exp(logits) + lv5: R.Tensor((3, 5), dtype="float32") = R.multiply(lv3, lv4) + out_adjoint: R.Tensor((3, 5), dtype="float32") = R.subtract(logits_adjoint, lv5) lv0_adjoint: R.Tensor((3, 5), dtype="float32") = out_adjoint - lv5: R.Tensor((5, 10), dtype="float32") = R.permute_dims(w0, axes=[1, 0]) - lv6: R.Tensor((10, 3), dtype="float32") = R.permute_dims(x, axes=[1, 0]) - w0_adjoint: R.Tensor((10, 5), dtype="float32") = R.matmul(lv6, lv0_adjoint, out_dtype="void") b0_adjoint: R.Tensor((5,), dtype="float32") = R.collapse_sum_to(out_adjoint, R.shape([5])) - R.output(loss, w0_adjoint, b0_adjoint) - return (loss, (w0_adjoint, b0_adjoint)) + lv7: R.Tensor((10, 3), dtype="float32") = R.permute_dims(x, axes=[1, 0]) + w0_adjoint: R.Tensor((10, 5), dtype="float32") = R.matmul(lv7, lv0_adjoint, out_dtype="void") + w0_adjoint_out: R.Tensor((10, 5), dtype="float32") = w0_adjoint + b0_adjoint_out: R.Tensor((5,), dtype="float32") = b0_adjoint + R.output(loss, w0_adjoint_out, b0_adjoint_out) + return (loss, (w0_adjoint_out, b0_adjoint_out)) @R.function def main(x: R.Tensor((3, 10), dtype="float32"), w0: R.Tensor((10, 5), dtype="float32"), b0: R.Tensor((5,), dtype="float32"), label: R.Tensor((3, 5), dtype="float32")) -> R.Tensor((), dtype="float32"): diff --git a/tests/python/relax/test_transform_gradient_checkpoint.py b/tests/python/relax/test_transform_gradient_checkpoint.py new file mode 100644 index 000000000000..3e94125b77ca --- /dev/null +++ b/tests/python/relax/test_transform_gradient_checkpoint.py @@ -0,0 +1,689 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Unit tests for gradient with checkpointing.""" +import tvm +from tvm.relax.block_builder import BlockBuilder +from tvm.relax.testing.nn import checkpoint, emit_checkpoint, emit_checkpoint_sequential +import tvm.testing +from tvm import relax +from tvm.ir.base import assert_structural_equal +from tvm.script.parser import relax as R, ir as I + + +def test_sequential(): + """Comp. 
graph is a sequence""" + # fmt: off + @I.ir_module + class Before: + @R.function + def main(x: R.Tensor((3, 3), "float32")): + with R.dataflow(): + x_scp = R.grad.start_checkpoint(x) + lv1 = R.power(x_scp, R.const(3, "float32")) + lv1_ecp = R.grad.end_checkpoint(lv1) + lv2 = R.power(lv1_ecp, R.const(3, "float32")) + lv2_scp = R.grad.start_checkpoint(lv2) + lv3 = R.power(lv2_scp, R.const(3, "float32")) + lv4 = R.power(lv3, R.const(3, "float32")) + gv = R.sum(lv4) + gv_ecp = R.grad.end_checkpoint(gv) + R.output(gv_ecp) + return gv_ecp + + @I.ir_module + class Expected: + @R.function + def main_adjoint(x: R.Tensor((3, 3), dtype="float32")) -> R.Tuple(R.Tensor((), dtype="float32"), R.Tuple(R.Tensor((3, 3), dtype="float32"))): + with R.dataflow(): + lv1: R.Tensor((3, 3), dtype="float32") = R.power(x, R.const(3, "float32")) + lv2: R.Tensor((3, 3), dtype="float32") = R.power(lv1, R.const(3, "float32")) + lv3: R.Tensor((3, 3), dtype="float32") = R.power(lv2, R.const(3, "float32")) + lv4: R.Tensor((3, 3), dtype="float32") = R.power(lv3, R.const(3, "float32")) + gv: R.Tensor((), dtype="float32") = R.sum(lv4, axis=None, keepdims=False) + gv_1: R.Tensor((), dtype="float32") = gv + gv_adjoint: R.Tensor((), dtype="float32") = R.ones(R.shape([]), dtype="float32") + gv_adjoint1: R.Tensor((), dtype="float32") = gv_adjoint + lv3_cp: R.Tensor((3, 3), dtype="float32") = R.power(lv2, R.const(3, "float32")) + lv4_adjoint: R.Tensor((3, 3), dtype="float32") = R.broadcast_to(gv_adjoint1, R.shape([3, 3])) + lv: R.Tensor((3, 3), dtype="float32") = R.multiply(lv4_adjoint, R.const(3, "float32")) + lv1_1: R.Tensor((), dtype="float32") = R.subtract(R.const(3, "float32"), R.const(1, "float32")) + lv2_1: R.Tensor((3, 3), dtype="float32") = R.power(lv3_cp, lv1_1) + lv3_adjoint: R.Tensor((3, 3), dtype="float32") = R.multiply(lv, lv2_1) + lv6: R.Tensor((3, 3), dtype="float32") = R.multiply(lv3_adjoint, R.const(3, "float32")) + lv7: R.Tensor((), dtype="float32") = R.subtract(R.const(3, "float32"), R.const(1, "float32")) + lv8: R.Tensor((3, 3), dtype="float32") = R.power(lv2, lv7) + lv2_adjoint: R.Tensor((3, 3), dtype="float32") = R.multiply(lv6, lv8) + lv1_cp: R.Tensor((3, 3), dtype="float32") = R.power(x, R.const(3, "float32")) + lv12: R.Tensor((3, 3), dtype="float32") = R.multiply(lv2_adjoint, R.const(3, "float32")) + lv13: R.Tensor((), dtype="float32") = R.subtract(R.const(3, "float32"), R.const(1, "float32")) + lv14: R.Tensor((3, 3), dtype="float32") = R.power(lv1_cp, lv13) + lv1_adjoint: R.Tensor((3, 3), dtype="float32") = R.multiply(lv12, lv14) + lv18: R.Tensor((3, 3), dtype="float32") = R.multiply(lv1_adjoint, R.const(3, "float32")) + lv19: R.Tensor((), dtype="float32") = R.subtract(R.const(3, "float32"), R.const(1, "float32")) + lv20: R.Tensor((3, 3), dtype="float32") = R.power(x, lv19) + x_adjoint: R.Tensor((3, 3), dtype="float32") = R.multiply(lv18, lv20) + x_adjoint_out: R.Tensor((3, 3), dtype="float32") = x_adjoint + R.output(gv_1, x_adjoint_out) + return (gv_1, (x_adjoint_out,)) + + @R.function + def main(x: R.Tensor((3, 3), dtype="float32")) -> R.Tensor((), dtype="float32"): + with R.dataflow(): + x_scp: R.Tensor((3, 3), dtype="float32") = R.grad.start_checkpoint(x) + lv1: R.Tensor((3, 3), dtype="float32") = R.power(x_scp, R.const(3, "float32")) + lv1_ecp: R.Tensor((3, 3), dtype="float32") = R.grad.end_checkpoint(lv1) + lv2: R.Tensor((3, 3), dtype="float32") = R.power(lv1_ecp, R.const(3, "float32")) + lv2_scp: R.Tensor((3, 3), dtype="float32") = R.grad.start_checkpoint(lv2) + lv3: R.Tensor((3, 3), 
dtype="float32") = R.power(lv2_scp, R.const(3, "float32")) + lv4: R.Tensor((3, 3), dtype="float32") = R.power(lv3, R.const(3, "float32")) + gv: R.Tensor((), dtype="float32") = R.sum(lv4, axis=None, keepdims=False) + gv_ecp: R.Tensor((), dtype="float32") = R.grad.end_checkpoint(gv) + R.output(gv_ecp) + return gv_ecp + # fmt: on + + After = relax.transform.Gradient("main")(Before) + assert_structural_equal(After, Expected) + + +def test_sequential_consecutive(): + """Comp. graph is a sequence""" + # fmt: off + @I.ir_module + class Before: + @R.function + def main(x: R.Tensor((3, 3), "float32")): + with R.dataflow(): + x_scp = R.grad.start_checkpoint(x) + lv1 = R.power(x_scp, R.const(3, "float32")) + lv2 = R.power(lv1, R.const(3, "float32")) + lv2_ecp = R.grad.end_checkpoint(lv2) + lv2_scp = R.grad.start_checkpoint(lv2_ecp) + lv3 = R.power(lv2_scp, R.const(3, "float32")) + lv4 = R.power(lv3, R.const(3, "float32")) + lv4_ecp = R.grad.end_checkpoint(lv4) + gv = R.sum(lv4_ecp) + R.output(gv) + return gv + + @I.ir_module + class Expected: + @R.function + def main_adjoint(x: R.Tensor((3, 3), dtype="float32")) -> R.Tuple(R.Tensor((), dtype="float32"), R.Tuple(R.Tensor((3, 3), dtype="float32"))): + with R.dataflow(): + lv1: R.Tensor((3, 3), dtype="float32") = R.power(x, R.const(3, "float32")) + lv2: R.Tensor((3, 3), dtype="float32") = R.power(lv1, R.const(3, "float32")) + lv3: R.Tensor((3, 3), dtype="float32") = R.power(lv2, R.const(3, "float32")) + lv4: R.Tensor((3, 3), dtype="float32") = R.power(lv3, R.const(3, "float32")) + gv: R.Tensor((), dtype="float32") = R.sum(lv4, axis=None, keepdims=False) + gv_adjoint: R.Tensor((), dtype="float32") = R.ones(R.shape([]), dtype="float32") + lv3_cp: R.Tensor((3, 3), dtype="float32") = R.power(lv2, R.const(3, "float32")) + lv4_adjoint: R.Tensor((3, 3), dtype="float32") = R.broadcast_to(gv_adjoint, R.shape([3, 3])) + lv: R.Tensor((3, 3), dtype="float32") = R.multiply(lv4_adjoint, R.const(3, "float32")) + lv1_1: R.Tensor((), dtype="float32") = R.subtract(R.const(3, "float32"), R.const(1, "float32")) + lv2_1: R.Tensor((3, 3), dtype="float32") = R.power(lv3_cp, lv1_1) + lv3_adjoint: R.Tensor((3, 3), dtype="float32") = R.multiply(lv, lv2_1) + lv6: R.Tensor((3, 3), dtype="float32") = R.multiply(lv3_adjoint, R.const(3, "float32")) + lv7: R.Tensor((), dtype="float32") = R.subtract(R.const(3, "float32"), R.const(1, "float32")) + lv8: R.Tensor((3, 3), dtype="float32") = R.power(lv2, lv7) + lv2_adjoint: R.Tensor((3, 3), dtype="float32") = R.multiply(lv6, lv8) + lv1_cp: R.Tensor((3, 3), dtype="float32") = R.power(x, R.const(3, "float32")) + lv12: R.Tensor((3, 3), dtype="float32") = R.multiply(lv2_adjoint, R.const(3, "float32")) + lv13: R.Tensor((), dtype="float32") = R.subtract(R.const(3, "float32"), R.const(1, "float32")) + lv14: R.Tensor((3, 3), dtype="float32") = R.power(lv1_cp, lv13) + lv1_adjoint: R.Tensor((3, 3), dtype="float32") = R.multiply(lv12, lv14) + lv18: R.Tensor((3, 3), dtype="float32") = R.multiply(lv1_adjoint, R.const(3, "float32")) + lv19: R.Tensor((), dtype="float32") = R.subtract(R.const(3, "float32"), R.const(1, "float32")) + lv20: R.Tensor((3, 3), dtype="float32") = R.power(x, lv19) + x_adjoint: R.Tensor((3, 3), dtype="float32") = R.multiply(lv18, lv20) + x_adjoint_out: R.Tensor((3, 3), dtype="float32") = x_adjoint + R.output(gv, x_adjoint_out) + return (gv, (x_adjoint_out,)) + + @R.function + def main(x: R.Tensor((3, 3), dtype="float32")) -> R.Tensor((), dtype="float32"): + with R.dataflow(): + x_scp: R.Tensor((3, 3), dtype="float32") = 
R.grad.start_checkpoint(x) + lv1: R.Tensor((3, 3), dtype="float32") = R.power(x_scp, R.const(3, "float32")) + lv2: R.Tensor((3, 3), dtype="float32") = R.power(lv1, R.const(3, "float32")) + lv2_ecp: R.Tensor((3, 3), dtype="float32") = R.grad.end_checkpoint(lv2) + lv2_scp: R.Tensor((3, 3), dtype="float32") = R.grad.start_checkpoint(lv2_ecp) + lv3: R.Tensor((3, 3), dtype="float32") = R.power(lv2_scp, R.const(3, "float32")) + lv4: R.Tensor((3, 3), dtype="float32") = R.power(lv3, R.const(3, "float32")) + lv4_ecp: R.Tensor((3, 3), dtype="float32") = R.grad.end_checkpoint(lv4) + gv: R.Tensor((), dtype="float32") = R.sum(lv4_ecp, axis=None, keepdims=False) + R.output(gv) + return gv + + # fmt: on + + After = relax.transform.Gradient("main")(Before) + assert_structural_equal(After, Expected) + + +def test_tuple(): + """Comp. graph is a sequence""" + # fmt: off + @I.ir_module + class Before: + @R.function + def main(x: R.Tensor((3, 3), "float32")): + with R.dataflow(): + x_scp = R.grad.start_checkpoint(x) + lv1 = R.power(x_scp, R.const(3, "float32")) + lv2 = (x, lv1) + lv3 = lv2 + lv4 = R.power(lv3[0], R.const(3, "float32")) + lv4_ecp = R.grad.end_checkpoint(lv4) + gv = R.sum(lv4_ecp) + R.output(gv) + return gv + + @I.ir_module + class Expected: + @R.function + def main_adjoint(x: R.Tensor((3, 3), dtype="float32")) -> R.Tuple(R.Tensor((), dtype="float32"), R.Tuple(R.Tensor((3, 3), dtype="float32"))): + with R.dataflow(): + lv1: R.Tensor((3, 3), dtype="float32") = R.power(x, R.const(3, "float32")) + lv2: R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3), dtype="float32")) = x, lv1 + lv3: R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3), dtype="float32")) = lv2 + lv4: R.Tensor((3, 3), dtype="float32") = lv3[0] + lv4_1: R.Tensor((3, 3), dtype="float32") = R.power(lv4, R.const(3, "float32")) + gv: R.Tensor((), dtype="float32") = R.sum(lv4_1, axis=None, keepdims=False) + gv_adjoint: R.Tensor((), dtype="float32") = R.ones(R.shape([]), dtype="float32") + lv1_cp: R.Tensor((3, 3), dtype="float32") = R.power(x, R.const(3, "float32")) + lv2_cp: R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3), dtype="float32")) = x, lv1_cp + lv3_cp: R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3), dtype="float32")) = lv2_cp + lv4_cp: R.Tensor((3, 3), dtype="float32") = lv3_cp[0] + lv4_adjoint: R.Tensor((3, 3), dtype="float32") = R.broadcast_to(gv_adjoint, R.shape([3, 3])) + lv: R.Tensor((3, 3), dtype="float32") = R.multiply(lv4_adjoint, R.const(3, "float32")) + lv1_1: R.Tensor((), dtype="float32") = R.subtract(R.const(3, "float32"), R.const(1, "float32")) + lv2_1: R.Tensor((3, 3), dtype="float32") = R.power(lv4_cp, lv1_1) + lv4_adjoint1: R.Tensor((3, 3), dtype="float32") = R.multiply(lv, lv2_1) + lv6: R.Tensor((3, 3), dtype="float32") = R.zeros(R.shape([3, 3]), dtype="float32") + lv3_adjoint: R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3), dtype="float32")) = lv4_adjoint1, lv6 + lv2_adjoint: R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3), dtype="float32")) = lv3_adjoint + x_adjoint: R.Tensor((3, 3), dtype="float32") = lv2_adjoint[0] + lv1_adjoint: R.Tensor((3, 3), dtype="float32") = lv2_adjoint[1] + lv7: R.Tensor((3, 3), dtype="float32") = R.multiply(lv1_adjoint, R.const(3, "float32")) + lv8: R.Tensor((), dtype="float32") = R.subtract(R.const(3, "float32"), R.const(1, "float32")) + lv9: R.Tensor((3, 3), dtype="float32") = R.power(x, lv8) + lv12: R.Tensor((3, 3), dtype="float32") = R.multiply(lv7, lv9) + x_adjoint1: R.Tensor((3, 3), dtype="float32") = 
R.add(x_adjoint, lv12) + x_adjoint_out: R.Tensor((3, 3), dtype="float32") = x_adjoint1 + R.output(gv, x_adjoint_out) + return (gv, (x_adjoint_out,)) + + @R.function + def main(x: R.Tensor((3, 3), dtype="float32")) -> R.Tensor((), dtype="float32"): + with R.dataflow(): + x_scp: R.Tensor((3, 3), dtype="float32") = R.grad.start_checkpoint(x) + lv1: R.Tensor((3, 3), dtype="float32") = R.power(x_scp, R.const(3, "float32")) + lv2: R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3), dtype="float32")) = x, lv1 + lv3: R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3), dtype="float32")) = lv2 + lv4: R.Tensor((3, 3), dtype="float32") = lv3[0] + lv4_1: R.Tensor((3, 3), dtype="float32") = R.power(lv4, R.const(3, "float32")) + lv4_ecp: R.Tensor((3, 3), dtype="float32") = R.grad.end_checkpoint(lv4_1) + gv: R.Tensor((), dtype="float32") = R.sum(lv4_ecp, axis=None, keepdims=False) + R.output(gv) + return gv + # fmt: on + + After = relax.transform.Gradient("main")(Before) + assert_structural_equal(After, Expected) + + +def test_tree(): + """Comp. graph is a output-directed tree""" + # fmt: off + @I.ir_module + class Before: + @R.function + def main(x: R.Tensor((3, 3), "float32"), y: R.Tensor((3, 3), "float32"), z: R.Tensor((3, 3), "float32"), u: R.Tensor((3, 3), "float32"), v: R.Tensor((3, 3), "float32")): + with R.dataflow(): + lv1 = x * y + lv1_scp = R.grad.start_checkpoint(lv1) + z_scp = R.grad.start_checkpoint(z) + lv2 = lv1_scp * z_scp + lv2_ecp = R.grad.end_checkpoint(lv2) + u_scp = R.grad.start_checkpoint(u) + v_scp = R.grad.start_checkpoint(v) + lv3 = u_scp * v_scp + lv3_ecp = R.grad.end_checkpoint(lv3) + lv4 = lv2_ecp * lv3_ecp + gv = R.sum(lv4) + R.output(gv) + return gv + + @I.ir_module + class Expected1: + @R.function + def main_adjoint(x: R.Tensor((3, 3), dtype="float32"), y: R.Tensor((3, 3), dtype="float32"), z: R.Tensor((3, 3), dtype="float32"), u: R.Tensor((3, 3), dtype="float32"), v: R.Tensor((3, 3), dtype="float32")) -> R.Tuple(R.Tensor((), dtype="float32"), R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3), dtype="float32"))): + with R.dataflow(): + lv1: R.Tensor((3, 3), dtype="float32") = R.multiply(x, y) + lv2: R.Tensor((3, 3), dtype="float32") = R.multiply(lv1, z) + lv3: R.Tensor((3, 3), dtype="float32") = R.multiply(u, v) + lv4: R.Tensor((3, 3), dtype="float32") = R.multiply(lv2, lv3) + gv: R.Tensor((), dtype="float32") = R.sum(lv4, axis=None, keepdims=False) + gv_adjoint: R.Tensor((), dtype="float32") = R.ones(R.shape([]), dtype="float32") + lv4_adjoint: R.Tensor((3, 3), dtype="float32") = R.broadcast_to(gv_adjoint, R.shape([3, 3])) + lv2_cp: R.Tensor((3, 3), dtype="float32") = R.multiply(lv1, z) + lv3_cp: R.Tensor((3, 3), dtype="float32") = R.multiply(u, v) + lv2_adjoint: R.Tensor((3, 3), dtype="float32") = R.multiply(lv4_adjoint, lv3_cp) + lv3_adjoint: R.Tensor((3, 3), dtype="float32") = R.multiply(lv4_adjoint, lv2_cp) + u_adjoint: R.Tensor((3, 3), dtype="float32") = R.multiply(lv3_adjoint, v) + v_adjoint: R.Tensor((3, 3), dtype="float32") = R.multiply(lv3_adjoint, u) + lv1_adjoint: R.Tensor((3, 3), dtype="float32") = R.multiply(lv2_adjoint, z) + z_adjoint: R.Tensor((3, 3), dtype="float32") = R.multiply(lv2_adjoint, lv1) + x_adjoint: R.Tensor((3, 3), dtype="float32") = R.multiply(lv1_adjoint, y) + y_adjoint: R.Tensor((3, 3), dtype="float32") = R.multiply(lv1_adjoint, x) + x_adjoint_out: R.Tensor((3, 3), dtype="float32") = x_adjoint + y_adjoint_out: 
R.Tensor((3, 3), dtype="float32") = y_adjoint + z_adjoint_out: R.Tensor((3, 3), dtype="float32") = z_adjoint + u_adjoint_out: R.Tensor((3, 3), dtype="float32") = u_adjoint + v_adjoint_out: R.Tensor((3, 3), dtype="float32") = v_adjoint + R.output(gv, x_adjoint_out, y_adjoint_out, z_adjoint_out, u_adjoint_out, v_adjoint_out) + return (gv, (x_adjoint_out, y_adjoint_out, z_adjoint_out, u_adjoint_out, v_adjoint_out)) + + @R.function + def main(x: R.Tensor((3, 3), dtype="float32"), y: R.Tensor((3, 3), dtype="float32"), z: R.Tensor((3, 3), dtype="float32"), u: R.Tensor((3, 3), dtype="float32"), v: R.Tensor((3, 3), dtype="float32")) -> R.Tensor((), dtype="float32"): + with R.dataflow(): + lv1 = x * y + lv1_scp = R.grad.start_checkpoint(lv1) + z_scp = R.grad.start_checkpoint(z) + lv2 = lv1_scp * z_scp + lv2_ecp = R.grad.end_checkpoint(lv2) + u_scp = R.grad.start_checkpoint(u) + v_scp = R.grad.start_checkpoint(v) + lv3 = u_scp * v_scp + lv3_ecp = R.grad.end_checkpoint(lv3) + lv4 = lv2_ecp * lv3_ecp + gv = R.sum(lv4) + R.output(gv) + return gv + # fmt: on + + After1 = relax.transform.Gradient("main")(Before) + assert_structural_equal(After1, Expected1) + + # fmt: off + @I.ir_module + class Expected2: + @R.function + def main_adjoint(x: R.Tensor((3, 3), dtype="float32"), y: R.Tensor((3, 3), dtype="float32"), z: R.Tensor((3, 3), dtype="float32"), u: R.Tensor((3, 3), dtype="float32"), v: R.Tensor((3, 3), dtype="float32")) -> R.Tuple(R.Tensor((), dtype="float32"), R.Tuple(R.Tensor((3, 3), dtype="float32"))): + with R.dataflow(): + lv1: R.Tensor((3, 3), dtype="float32") = R.multiply(x, y) + lv2: R.Tensor((3, 3), dtype="float32") = R.multiply(lv1, z) + lv3: R.Tensor((3, 3), dtype="float32") = R.multiply(u, v) + lv4: R.Tensor((3, 3), dtype="float32") = R.multiply(lv2, lv3) + gv: R.Tensor((), dtype="float32") = R.sum(lv4, axis=None, keepdims=False) + gv_adjoint: R.Tensor((), dtype="float32") = R.ones(R.shape([]), dtype="float32") + lv4_adjoint: R.Tensor((3, 3), dtype="float32") = R.broadcast_to(gv_adjoint, R.shape([3, 3])) + lv3_cp: R.Tensor((3, 3), dtype="float32") = R.multiply(u, v) + lv2_adjoint: R.Tensor((3, 3), dtype="float32") = R.multiply(lv4_adjoint, lv3_cp) + z_adjoint: R.Tensor((3, 3), dtype="float32") = R.multiply(lv2_adjoint, lv1) + z_adjoint_out: R.Tensor((3, 3), dtype="float32") = z_adjoint + R.output(gv, z_adjoint_out) + return (gv, (z_adjoint_out,)) + + @R.function + def main(x: R.Tensor((3, 3), dtype="float32"), y: R.Tensor((3, 3), dtype="float32"), z: R.Tensor((3, 3), dtype="float32"), u: R.Tensor((3, 3), dtype="float32"), v: R.Tensor((3, 3), dtype="float32")) -> R.Tensor((), dtype="float32"): + with R.dataflow(): + lv1 = x * y + lv1_scp = R.grad.start_checkpoint(lv1) + z_scp = R.grad.start_checkpoint(z) + lv2 = lv1_scp * z_scp + lv2_ecp = R.grad.end_checkpoint(lv2) + u_scp = R.grad.start_checkpoint(u) + v_scp = R.grad.start_checkpoint(v) + lv3 = u_scp * v_scp + lv3_ecp = R.grad.end_checkpoint(lv3) + lv4 = lv2_ecp * lv3_ecp + gv = R.sum(lv4) + R.output(gv) + return gv + # fmt: on + + After2 = relax.transform.Gradient("main", require_grads=Before["main"].params[2])(Before) + assert_structural_equal(After2, Expected2) + + +def test_dag(): + """Comp. graph is a DAG with only one output. Here we only test the simple case: comp. 
graph + is a sequence of sub-graphs, and the checkpoints are the intersections of connected + subgraphs.""" + # fmt: off + @I.ir_module + class Before: + @R.function + def main(x: R.Tensor((3, 3), "float32")): + with R.dataflow(): + lv = R.grad.start_checkpoint(x) + lv1 = R.multiply(lv, R.const(2, "float32")) + lv2 = R.multiply(lv1, R.const(2, "float32")) + lv3 = R.grad.end_checkpoint(lv2) + lv4 = R.multiply(x, lv3) + lv5 = R.grad.start_checkpoint(lv4) + lv6 = R.multiply(lv5, R.const(2, "float32")) + lv7 = R.multiply(lv6, R.const(2, "float32")) + lv8 = R.grad.end_checkpoint(lv7) + lv9 = R.multiply(lv4, lv8) + lv10 = R.grad.start_checkpoint(lv9) + lv11 = R.multiply(lv10, R.const(2, "float32")) + lv12 = R.multiply(lv11, R.const(2, "float32")) + lv13 = R.grad.end_checkpoint(lv12) + lv14 = R.multiply(lv9, lv13) + gv: R.Tensor((), dtype="float32") = R.sum(lv14, axis=None, keepdims=False) + R.output(gv) + return gv + + @I.ir_module + class Expected: + @R.function + def main_adjoint(x: R.Tensor((3, 3), dtype="float32")) -> R.Tuple(R.Tensor((), dtype="float32"), R.Tuple(R.Tensor((3, 3), dtype="float32"))): + with R.dataflow(): + lv1: R.Tensor((3, 3), dtype="float32") = R.multiply(x, R.const(2, "float32")) + lv2: R.Tensor((3, 3), dtype="float32") = R.multiply(lv1, R.const(2, "float32")) + lv3: R.Tensor((3, 3), dtype="float32") = R.multiply(x, lv2) + lv4: R.Tensor((3, 3), dtype="float32") = R.multiply(lv3, R.const(2, "float32")) + lv5: R.Tensor((3, 3), dtype="float32") = R.multiply(lv4, R.const(2, "float32")) + lv6: R.Tensor((3, 3), dtype="float32") = R.multiply(lv3, lv5) + lv7: R.Tensor((3, 3), dtype="float32") = R.multiply(lv6, R.const(2, "float32")) + lv8: R.Tensor((3, 3), dtype="float32") = R.multiply(lv7, R.const(2, "float32")) + lv9: R.Tensor((3, 3), dtype="float32") = R.multiply(lv6, lv8) + gv: R.Tensor((), dtype="float32") = R.sum(lv9, axis=None, keepdims=False) + gv_adjoint: R.Tensor((), dtype="float32") = R.ones(R.shape([]), dtype="float32") + lv9_adjoint: R.Tensor((3, 3), dtype="float32") = R.broadcast_to(gv_adjoint, R.shape([3, 3])) + lv7_cp: R.Tensor((3, 3), dtype="float32") = R.multiply(lv6, R.const(2, "float32")) + lv8_cp: R.Tensor((3, 3), dtype="float32") = R.multiply(lv7_cp, R.const(2, "float32")) + lv6_adjoint: R.Tensor((3, 3), dtype="float32") = R.multiply(lv9_adjoint, lv8_cp) + lv8_adjoint: R.Tensor((3, 3), dtype="float32") = R.multiply(lv9_adjoint, lv6) + lv7_adjoint: R.Tensor((3, 3), dtype="float32") = R.multiply(lv8_adjoint, R.const(2, "float32")) + lv1_1: R.Tensor((3, 3), dtype="float32") = R.multiply(lv7_adjoint, R.const(2, "float32")) + lv6_adjoint1: R.Tensor((3, 3), dtype="float32") = R.add(lv6_adjoint, lv1_1) + lv4_cp: R.Tensor((3, 3), dtype="float32") = R.multiply(lv3, R.const(2, "float32")) + lv5_cp: R.Tensor((3, 3), dtype="float32") = R.multiply(lv4_cp, R.const(2, "float32")) + lv3_adjoint: R.Tensor((3, 3), dtype="float32") = R.multiply(lv6_adjoint1, lv5_cp) + lv5_adjoint: R.Tensor((3, 3), dtype="float32") = R.multiply(lv6_adjoint1, lv3) + lv4_adjoint: R.Tensor((3, 3), dtype="float32") = R.multiply(lv5_adjoint, R.const(2, "float32")) + lv4_1: R.Tensor((3, 3), dtype="float32") = R.multiply(lv4_adjoint, R.const(2, "float32")) + lv3_adjoint1: R.Tensor((3, 3), dtype="float32") = R.add(lv3_adjoint, lv4_1) + lv1_cp: R.Tensor((3, 3), dtype="float32") = R.multiply(x, R.const(2, "float32")) + lv2_cp: R.Tensor((3, 3), dtype="float32") = R.multiply(lv1_cp, R.const(2, "float32")) + x_adjoint: R.Tensor((3, 3), dtype="float32") = R.multiply(lv3_adjoint1, lv2_cp) + lv2_adjoint: 
R.Tensor((3, 3), dtype="float32") = R.multiply(lv3_adjoint1, x) + lv1_adjoint: R.Tensor((3, 3), dtype="float32") = R.multiply(lv2_adjoint, R.const(2, "float32")) + lv7_1: R.Tensor((3, 3), dtype="float32") = R.multiply(lv1_adjoint, R.const(2, "float32")) + x_adjoint1: R.Tensor((3, 3), dtype="float32") = R.add(x_adjoint, lv7_1) + x_adjoint_out: R.Tensor((3, 3), dtype="float32") = x_adjoint1 + R.output(gv, x_adjoint_out) + return (gv, (x_adjoint_out,)) + + @R.function + def main(x: R.Tensor((3, 3), dtype="float32")) -> R.Tensor((), dtype="float32"): + with R.dataflow(): + lv = R.grad.start_checkpoint(x) + lv1 = R.multiply(lv, R.const(2, "float32")) + lv2 = R.multiply(lv1, R.const(2, "float32")) + lv3 = R.grad.end_checkpoint(lv2) + lv4 = R.multiply(x, lv3) + lv5 = R.grad.start_checkpoint(lv4) + lv6 = R.multiply(lv5, R.const(2, "float32")) + lv7 = R.multiply(lv6, R.const(2, "float32")) + lv8 = R.grad.end_checkpoint(lv7) + lv9 = R.multiply(lv4, lv8) + lv10 = R.grad.start_checkpoint(lv9) + lv11 = R.multiply(lv10, R.const(2, "float32")) + lv12 = R.multiply(lv11, R.const(2, "float32")) + lv13 = R.grad.end_checkpoint(lv12) + lv14 = R.multiply(lv9, lv13) + gv: R.Tensor((), dtype="float32") = R.sum(lv14, axis=None, keepdims=False) + R.output(gv) + return gv + # fmt: on + + After = relax.transform.Gradient("main")(Before) + assert_structural_equal(After, Expected) + + +def test_checkpoint_api(): + """Test on tvm.relax.testing.nn.checkpoint API""" + + def func1(x): + return relax.op.power(x, relax.const(3, "float32")) + + def func2(x): + y = relax.op.power(relax.op.power(x, relax.const(3, "float32")), relax.const(3, "float32")) + return relax.op.sum(y) + + bb = BlockBuilder() + x = relax.Var("x", relax.TensorStructInfo((3, 3), "float32")) + with bb.function("main", [x]): + with bb.dataflow(): + lv1 = bb.emit(checkpoint(func1, x)) + lv2 = bb.emit(relax.op.power(lv1, relax.const(3, "float32"))) + lv3 = bb.emit_output(checkpoint(func2, lv2)) + bb.emit_func_output(lv3) + + # fmt: off + @I.ir_module + class Expected: + @R.function + def main(x: R.Tensor((3, 3), "float32")): + with R.dataflow(): + x_scp = R.grad.start_checkpoint(x) + lv1 = R.power(x_scp, R.const(3, "float32")) + lv1_ecp = R.grad.end_checkpoint(lv1) + lv2 = R.power(lv1_ecp, R.const(3, "float32")) + lv2_scp = R.grad.start_checkpoint(lv2) + lv3 = R.power(lv2_scp, R.const(3, "float32")) + lv4 = R.power(lv3, R.const(3, "float32")) + gv = R.sum(lv4) + gv_ecp = R.grad.end_checkpoint(gv) + R.output(gv_ecp) + return gv_ecp + # fmt: on + + assert_structural_equal(bb.get(), Expected) + + +def test_checkpoint_tree(): + """Comp. 
graph is a output-directed tree""" + + def func(x, y, z, w): + return x * y, z * w + + bb = BlockBuilder() + x = relax.Var("x", relax.TensorStructInfo((3, 3), "float32")) + y = relax.Var("y", relax.TensorStructInfo((3, 3), "float32")) + z = relax.Var("z", relax.TensorStructInfo((3, 3), "float32")) + u = relax.Var("u", relax.TensorStructInfo((3, 3), "float32")) + v = relax.Var("v", relax.TensorStructInfo((3, 3), "float32")) + with bb.function("main", [x, y, z, u, v]): + with bb.dataflow(): + lv1 = bb.emit(x * y) + cp = checkpoint(func, lv1, z, u, v) + lv2 = bb.emit(cp[0]) + lv3 = bb.emit(cp[1]) + lv4 = bb.emit(lv2 * lv3) + gv = bb.emit_output(relax.op.sum(lv4)) + bb.emit_func_output(gv) + + # fmt: off + @I.ir_module + class Expected: + @R.function + def main(x: R.Tensor((3, 3), "float32"), y: R.Tensor((3, 3), "float32"), z: R.Tensor((3, 3), "float32"), u: R.Tensor((3, 3), "float32"), v: R.Tensor((3, 3), "float32")): + with R.dataflow(): + lv1 = x * y + lv1_scp = R.grad.start_checkpoint(lv1) + z_scp = R.grad.start_checkpoint(z) + lv2 = lv1_scp * z_scp + lv2_ecp = R.grad.end_checkpoint(lv2) + u_scp = R.grad.start_checkpoint(u) + v_scp = R.grad.start_checkpoint(v) + lv3 = u_scp * v_scp + lv3_ecp = R.grad.end_checkpoint(lv3) + lv4 = lv2_ecp * lv3_ecp + gv = R.sum(lv4) + R.output(gv) + return gv + # fmt: on + + assert_structural_equal(bb.get(), Expected) + + +def test_checkpoint_dag(): + """Comp. graph is a DAG with only one output. Here we only test the simple case: comp. graph + is a sequence of sub-graphs, and the checkpoints are the intersections of connected + subgraphs.""" + + def func(x): + return x * relax.const(2, "float32") * relax.const(2, "float32") + + bb = BlockBuilder() + x = relax.Var("x", relax.TensorStructInfo((3, 3), "float32")) + with bb.function("main", [x]): + with bb.dataflow(): + lv1 = bb.emit(checkpoint(func, x)) + lv2 = bb.emit(x * lv1) + lv3 = bb.emit(checkpoint(func, lv2)) + lv4 = bb.emit(lv2 * lv3) + lv5 = bb.emit(checkpoint(func, lv4)) + lv6 = bb.emit(lv4 * lv5) + gv = bb.emit_output(relax.op.sum(lv6)) + bb.emit_func_output(gv) + + # fmt: off + @I.ir_module + class Expected: + @R.function + def main(x: R.Tensor((3, 3), dtype="float32")) -> R.Tensor((), dtype="float32"): + with R.dataflow(): + lv = R.grad.start_checkpoint(x) + lv1 = R.multiply(lv, R.const(2, "float32")) + lv2 = R.multiply(lv1, R.const(2, "float32")) + lv3 = R.grad.end_checkpoint(lv2) + lv4 = R.multiply(x, lv3) + lv5 = R.grad.start_checkpoint(lv4) + lv6 = R.multiply(lv5, R.const(2, "float32")) + lv7 = R.multiply(lv6, R.const(2, "float32")) + lv8 = R.grad.end_checkpoint(lv7) + lv9 = R.multiply(lv4, lv8) + lv10 = R.grad.start_checkpoint(lv9) + lv11 = R.multiply(lv10, R.const(2, "float32")) + lv12 = R.multiply(lv11, R.const(2, "float32")) + lv13 = R.grad.end_checkpoint(lv12) + lv14 = R.multiply(lv9, lv13) + gv: R.Tensor((), dtype="float32") = R.sum(lv14, axis=None, keepdims=False) + R.output(gv) + return gv + # fmt: on + + assert_structural_equal(bb.get(), Expected) + + +def test_checkpoint_sequential(): + def func(x): + return x + x + + bb = BlockBuilder() + x = relax.Var("x", relax.TensorStructInfo((3, 3), "float32")) + with bb.function("main", [x]): + with bb.dataflow(): + lv1 = emit_checkpoint_sequential([func] * 5, 2, x) + lv2 = emit_checkpoint_sequential([func] * 4, 2, lv1) + gv = bb.emit_output(lv2) + bb.emit_func_output(gv) + + # fmt: off + @I.ir_module + class Expected: + @R.function + def main(x: R.Tensor((3, 3), dtype="float32")) -> R.Tensor((3, 3), dtype="float32"): + with R.dataflow(): + 
x_scp: R.Tensor((3, 3), dtype="float32") = R.grad.start_checkpoint(x) + lv: R.Tensor((3, 3), dtype="float32") = R.add(x_scp, x_scp) + lv1: R.Tensor((3, 3), dtype="float32") = R.add(lv, lv) + lv1_ecp: R.Tensor((3, 3), dtype="float32") = R.grad.end_checkpoint(lv1) + lv1_ecp_scp: R.Tensor((3, 3), dtype="float32") = R.grad.start_checkpoint(lv1_ecp) + lv2: R.Tensor((3, 3), dtype="float32") = R.add(lv1_ecp_scp, lv1_ecp_scp) + lv3: R.Tensor((3, 3), dtype="float32") = R.add(lv2, lv2) + lv3_ecp: R.Tensor((3, 3), dtype="float32") = R.grad.end_checkpoint(lv3) + lv4: R.Tensor((3, 3), dtype="float32") = R.add(lv3_ecp, lv3_ecp) + lv4_scp: R.Tensor((3, 3), dtype="float32") = R.grad.start_checkpoint(lv4) + lv5: R.Tensor((3, 3), dtype="float32") = R.add(lv4_scp, lv4_scp) + lv6: R.Tensor((3, 3), dtype="float32") = R.add(lv5, lv5) + lv6_ecp: R.Tensor((3, 3), dtype="float32") = R.grad.end_checkpoint(lv6) + lv7: R.Tensor((3, 3), dtype="float32") = R.add(lv6_ecp, lv6_ecp) + lv8: R.Tensor((3, 3), dtype="float32") = R.add(lv7, lv7) + gv: R.Tensor((3, 3), dtype="float32") = lv8 + R.output(gv) + return gv + # fmt: on + + assert_structural_equal(bb.get(), Expected) + + +def test_checkpoint_sequential_checkpoint_last(): + def func(x): + return x + x + + bb = BlockBuilder() + x = relax.Var("x", relax.TensorStructInfo((3, 3), "float32")) + with bb.function("main", [x]): + with bb.dataflow(): + lv1 = emit_checkpoint_sequential([func] * 5, 2, x, checkpoint_last=True) + lv2 = emit_checkpoint_sequential([func] * 4, 2, lv1, checkpoint_last=True) + gv = bb.emit_output(lv2) + bb.emit_func_output(gv) + + # fmt: off + @I.ir_module + class Expected: + @R.function + def main(x: R.Tensor((3, 3), dtype="float32")) -> R.Tensor((3, 3), dtype="float32"): + with R.dataflow(): + x_scp: R.Tensor((3, 3), dtype="float32") = R.grad.start_checkpoint(x) + lv: R.Tensor((3, 3), dtype="float32") = R.add(x_scp, x_scp) + lv1: R.Tensor((3, 3), dtype="float32") = R.add(lv, lv) + lv1_ecp: R.Tensor((3, 3), dtype="float32") = R.grad.end_checkpoint(lv1) + lv1_ecp_scp: R.Tensor((3, 3), dtype="float32") = R.grad.start_checkpoint(lv1_ecp) + lv2: R.Tensor((3, 3), dtype="float32") = R.add(lv1_ecp_scp, lv1_ecp_scp) + lv3: R.Tensor((3, 3), dtype="float32") = R.add(lv2, lv2) + lv3_ecp: R.Tensor((3, 3), dtype="float32") = R.grad.end_checkpoint(lv3) + lv3_ecp_scp: R.Tensor((3, 3), dtype="float32") = R.grad.start_checkpoint(lv3_ecp) + lv4: R.Tensor((3, 3), dtype="float32") = R.add(lv3_ecp_scp, lv3_ecp_scp) + lv4_ecp: R.Tensor((3, 3), dtype="float32") = R.grad.end_checkpoint(lv4) + lv4_ecp_scp: R.Tensor((3, 3), dtype="float32") = R.grad.start_checkpoint(lv4_ecp) + lv5: R.Tensor((3, 3), dtype="float32") = R.add(lv4_ecp_scp, lv4_ecp_scp) + lv6: R.Tensor((3, 3), dtype="float32") = R.add(lv5, lv5) + lv6_ecp: R.Tensor((3, 3), dtype="float32") = R.grad.end_checkpoint(lv6) + lv6_ecp_scp: R.Tensor((3, 3), dtype="float32") = R.grad.start_checkpoint(lv6_ecp) + lv7: R.Tensor((3, 3), dtype="float32") = R.add(lv6_ecp_scp, lv6_ecp_scp) + lv8: R.Tensor((3, 3), dtype="float32") = R.add(lv7, lv7) + lv8_ecp: R.Tensor((3, 3), dtype="float32") = R.grad.end_checkpoint(lv8) + gv: R.Tensor((3, 3), dtype="float32") = lv8_ecp + R.output(gv) + return gv + # fmt: on + + assert_structural_equal(bb.get(), Expected) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_training_register_te_gradient.py b/tests/python/relax/test_transform_gradient_te_register.py similarity index 63% rename from tests/python/relax/test_training_register_te_gradient.py rename to 
tests/python/relax/test_transform_gradient_te_register.py index 9803e4fab34e..b6b785fe3c49 100644 --- a/tests/python/relax/test_training_register_te_gradient.py +++ b/tests/python/relax/test_transform_gradient_te_register.py @@ -14,12 +14,12 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""Unit tests for relax training utils.""" +"""Unit tests for registering tir gradient functions in the gradient pass.""" import pytest import tvm import tvm.testing -from tvm import relax +from tvm import relax, tir from tvm.ir.base import assert_structural_equal from tvm.script.parser import relax as R, tir as T, ir as I @@ -91,21 +91,23 @@ def f_mul_grad(A: T.Buffer((T.int64(5), T.int64(5)), "float32"), B: T.Buffer((T. def main_adjoint(a: R.Tensor((5, 5), dtype="float32"), b: R.Tensor((5, 5), dtype="float32")) -> R.Tuple(R.Tensor((), dtype="float32"), R.Tuple(R.Tensor((5, 5), dtype="float32"), R.Tensor((5, 5), dtype="float32"))): cls = Expected with R.dataflow(): - lv = R.call_tir(cls.f_mul, (a, b), out_sinfo=R.Tensor((5, 5), dtype="float32"), te_grad_name="f_mul_grad") + lv = R.call_tir(cls.f_mul, (a, b), out_sinfo=R.Tensor((5, 5), dtype="float32")) gv: R.Tensor((), dtype="float32") = R.sum(lv, axis=None, keepdims=False) gv_adjoint: R.Tensor((), dtype="float32") = R.ones(R.shape([]), dtype="float32") lv_adjoint: R.Tensor((5, 5), dtype="float32") = R.broadcast_to(gv_adjoint, R.shape([5, 5])) lv_1 = R.call_tir(cls.f_mul_grad, (lv_adjoint, a, b), out_sinfo=[R.Tensor((5, 5), dtype="float32"), R.Tensor((5, 5), dtype="float32")]) a_adjoint: R.Tensor((5, 5), dtype="float32") = lv_1[0] b_adjoint: R.Tensor((5, 5), dtype="float32") = lv_1[1] - R.output(gv, a_adjoint, b_adjoint) - return (gv, (a_adjoint, b_adjoint)) + a_adjoint_out: R.Tensor((5, 5), dtype="float32") = a_adjoint + b_adjoint_out: R.Tensor((5, 5), dtype="float32") = b_adjoint + R.output(gv, a_adjoint_out, b_adjoint_out) + return (gv, (a_adjoint_out, b_adjoint_out)) @R.function def main(a: R.Tensor((5, 5), dtype="float32"), b: R.Tensor((5, 5), dtype="float32")) -> R.Tensor((), dtype="float32"): cls = Expected with R.dataflow(): - lv = R.call_tir(cls.f_mul, (a, b), out_sinfo=R.Tensor((5, 5), dtype="float32"), te_grad_name="f_mul_grad") + lv = R.call_tir_with_grad(cls.f_mul, (a, b), out_sinfo=R.Tensor((5, 5), dtype="float32"), te_grad_name="f_mul_grad") gv: R.Tensor((), dtype="float32") = R.sum(lv, axis=None, keepdims=False) R.output(gv) return gv @@ -113,7 +115,6 @@ def main(a: R.Tensor((5, 5), dtype="float32"), b: R.Tensor((5, 5), dtype="float3 return Expected -@pytest.mark.skip("gradient will be refactored in later pull requests") def test_emit_te(register_te_grads): # Build the target module using emit_te def f_mul(src1, src2): @@ -128,7 +129,11 @@ def mul(*idx): bb = relax.BlockBuilder() with bb.function("main", [a, b]): with bb.dataflow(): - d = bb.emit_te(f_mul, a, b, primfunc_name_hint="f_mul", te_grad_name="f_mul_grad") + d = bb.emit( + bb.call_te_with_grad( + f_mul, a, b, primfunc_name_hint="f_mul", te_grad_name="f_mul_grad" + ) + ) out = bb.emit_output(R.sum(d)) bb.emit_func_output(out) @@ -137,7 +142,6 @@ def mul(*idx): assert_structural_equal(After, get_expected_1()) -@pytest.mark.skip("gradient will be refactored in later pull requests") def test_call_tir(register_te_grads): # fmt: off @I.ir_module @@ -157,7 +161,7 @@ def f_mul(A: T.Buffer((T.int64(5), T.int64(5)), "float32"), B: T.Buffer((T.int64 def main(a: R.Tensor((5, 5), dtype="float32"), 
b: R.Tensor((5, 5), dtype="float32")) -> R.Tensor((), dtype="float32"): cls = Before with R.dataflow(): - lv = R.call_tir(cls.f_mul, (a, b), out_sinfo=R.Tensor((5, 5), dtype="float32"), te_grad_name="f_mul_grad") + lv = R.call_tir_with_grad(cls.f_mul, (a, b), out_sinfo=R.Tensor((5, 5), dtype="float32"), te_grad_name="f_mul_grad") gv: R.Tensor((), dtype="float32") = R.sum(lv, axis=None, keepdims=False) R.output(gv) return gv @@ -197,20 +201,21 @@ def f_mulk_grad(A: T.Buffer((T.int64(5), T.int64(5)), "float32"), B: T.Buffer((T def main_adjoint(a: R.Tensor((5, 5), dtype="float32")) -> R.Tuple(R.Tensor((), dtype="float32"), R.Tuple(R.Tensor((5, 5), dtype="float32"))): cls = Expected with R.dataflow(): - lv = R.call_tir(cls.f_mul, (a,), out_sinfo=R.Tensor((5, 5), dtype="float32"), te_grad_name="f_mulk_grad", te_grad_kwargs={"k": T.float32(2)}) + lv = R.call_tir(cls.f_mul, (a,), out_sinfo=R.Tensor((5, 5), dtype="float32")) gv: R.Tensor((), dtype="float32") = R.sum(lv, axis=None, keepdims=False) gv_adjoint: R.Tensor((), dtype="float32") = R.ones(R.shape([]), dtype="float32") lv_adjoint: R.Tensor((5, 5), dtype="float32") = R.broadcast_to(gv_adjoint, R.shape([5, 5])) lv_1 = R.call_tir(cls.f_mulk_grad, (lv_adjoint, a), out_sinfo=R.Tensor((5, 5), dtype="float32")) a_adjoint: R.Tensor((5, 5), dtype="float32") = lv_1 - R.output(gv, a_adjoint) - return (gv, (a_adjoint,)) + a_adjoint_out: R.Tensor((5, 5), dtype="float32") = a_adjoint + R.output(gv, a_adjoint_out) + return (gv, (a_adjoint_out,)) @R.function def main(a: R.Tensor((5, 5), dtype="float32")) -> R.Tensor((), dtype="float32"): cls = Expected with R.dataflow(): - lv = R.call_tir(cls.f_mul, (a,), out_sinfo=R.Tensor((5, 5), dtype="float32"), te_grad_name="f_mulk_grad", te_grad_kwargs={"k": T.float32(2)}) + lv = R.call_tir_with_grad(cls.f_mul, (a,), out_sinfo=R.Tensor((5, 5), dtype="float32"), te_grad_name="f_mulk_grad", te_grad_kwargs={"k": T.float32(2)}) gv: R.Tensor((), dtype="float32") = R.sum(lv, axis=None, keepdims=False) R.output(gv) return gv @@ -218,7 +223,6 @@ def main(a: R.Tensor((5, 5), dtype="float32")) -> R.Tensor((), dtype="float32"): return Expected -@pytest.mark.skip("gradient will be refactored in later pull requests") def test_emit_te_kwargs(register_te_grads): # Build the target module using emit_te def f_mul2(src): @@ -229,24 +233,24 @@ def f_mul2(src): bb = relax.BlockBuilder() with bb.function("main", [a]): with bb.dataflow(): - d = bb.emit_te( - f_mul2, - a, - primfunc_name_hint="f_mul", - te_grad_name="f_mulk_grad", - te_grad_kwargs={"k": T.float32(2)}, + d = bb.emit( + bb.call_te_with_grad( + f_mul2, + a, + primfunc_name_hint="f_mul", + te_grad_name="f_mulk_grad", + te_grad_kwargs={"k": T.float32(2)}, + ) ) out = bb.emit_output(R.sum(d)) bb.emit_func_output(out) Before = bb.get() After = Gradient("main")(Before) - After.show(None, False) assert_structural_equal(After, get_expected_2()) -@pytest.mark.skip("gradient will be refactored in later pull requests") def test_call_tir_kwargs(register_te_grads): # fmt: off @I.ir_module @@ -266,7 +270,7 @@ def f_mul(A: T.Buffer((T.int64(5), T.int64(5)), "float32"), f_mul2: T.Buffer((T. 
def main(a: R.Tensor((5, 5), dtype="float32")) -> R.Tensor((), dtype="float32"): cls = Before with R.dataflow(): - lv = R.call_tir(cls.f_mul, (a,), out_sinfo=R.Tensor((5, 5), dtype="float32"), te_grad_name="f_mulk_grad", te_grad_kwargs={"k": T.float32(2)}) + lv = R.call_tir_with_grad(cls.f_mul, (a,), out_sinfo=R.Tensor((5, 5), dtype="float32"), te_grad_name="f_mulk_grad", te_grad_kwargs={"k": T.float32(2)}) gv: R.Tensor((), dtype="float32") = R.sum(lv, axis=None, keepdims=False) R.output(gv) return gv @@ -276,5 +280,105 @@ def main(a: R.Tensor((5, 5), dtype="float32")) -> R.Tensor((), dtype="float32"): assert_structural_equal(After, get_expected_2()) +def get_expected_3(): + # fmt: off + @I.ir_module + class Expected: + @T.prim_func + def f_mul(var_A: T.handle, var_B: T.handle, var_f_mul: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.int64() + A = T.match_buffer(var_A, (n, n)) + B = T.match_buffer(var_B, (n, n)) + f_mul_1 = T.match_buffer(var_f_mul, (n, n)) + # with T.block("root"): + for i0, i1 in T.grid(n, n): + with T.block("f_mul"): + v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) + T.reads(A[v_i0, v_i1], B[v_i0, v_i1]) + T.writes(f_mul_1[v_i0, v_i1]) + f_mul_1[v_i0, v_i1] = A[v_i0, v_i1] * B[v_i0, v_i1] + + @T.prim_func + def f_mul_grad(var_A: T.handle, var_B: T.handle, var_C: T.handle, var_f_mul_grad_1: T.handle, var_f_mul_grad_2: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.int64() + A = T.match_buffer(var_A, (n, n)) + B = T.match_buffer(var_B, (n, n)) + C = T.match_buffer(var_C, (n, n)) + f_mul_grad_1 = T.match_buffer(var_f_mul_grad_1, (n, n)) + f_mul_grad_2 = T.match_buffer(var_f_mul_grad_2, (n, n)) + # with T.block("root"): + for i0, i1 in T.grid(n, n): + with T.block("f_mul_grad_1"): + v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) + T.reads(C[v_i0, v_i1], A[v_i0, v_i1]) + T.writes(f_mul_grad_1[v_i0, v_i1]) + f_mul_grad_1[v_i0, v_i1] = C[v_i0, v_i1] * A[v_i0, v_i1] + for i0, i1 in T.grid(n, n): + with T.block("f_mul_grad_2"): + v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) + T.reads(B[v_i0, v_i1], A[v_i0, v_i1]) + T.writes(f_mul_grad_2[v_i0, v_i1]) + f_mul_grad_2[v_i0, v_i1] = B[v_i0, v_i1] * A[v_i0, v_i1] + + @R.function + def main_adjoint(a: R.Tensor(("n", "n"), dtype="float32"), b: R.Tensor(("n", "n"), dtype="float32")) -> R.Tuple(R.Tensor((), dtype="float32"), R.Tuple(R.Tensor(("n", "n"), dtype="float32"), R.Tensor(("n", "n"), dtype="float32"))): + n = T.int64() + cls = Expected + with R.dataflow(): + lv = R.call_tir(cls.f_mul, (a, b), out_sinfo=R.Tensor((n, n), dtype="float32")) + gv: R.Tensor((), dtype="float32") = R.sum(lv, axis=None, keepdims=False) + gv_adjoint: R.Tensor((), dtype="float32") = R.ones(R.shape([]), dtype="float32") + lv_adjoint: R.Tensor((n, n), dtype="float32") = R.broadcast_to(gv_adjoint, R.shape([n, n])) + lv_1 = R.call_tir(cls.f_mul_grad, (lv_adjoint, a, b), out_sinfo=[R.Tensor((n, n), dtype="float32"), R.Tensor((n, n), dtype="float32")]) + a_adjoint: R.Tensor((n, n), dtype="float32") = lv_1[0] + b_adjoint: R.Tensor((n, n), dtype="float32") = lv_1[1] + a_adjoint_out: R.Tensor((n, n), dtype="float32") = a_adjoint + b_adjoint_out: R.Tensor((n, n), dtype="float32") = b_adjoint + R.output(gv, a_adjoint_out, b_adjoint_out) + return (gv, (a_adjoint_out, b_adjoint_out)) + + @R.function + def main(a: R.Tensor(("n", "n"), dtype="float32"), b: R.Tensor(("n", "n"), dtype="float32")) -> R.Tensor((), dtype="float32"): + n = T.int64() + cls = Expected + with R.dataflow(): + lv = R.call_tir_with_grad(cls.f_mul, (a, b), out_sinfo=R.Tensor((n, 
n), dtype="float32"), te_grad_name="f_mul_grad") + gv: R.Tensor((), dtype="float32") = R.sum(lv, axis=None, keepdims=False) + R.output(gv) + return gv + # fmt: on + return Expected + + +def test_tir_var(register_te_grads): + def f_mul(src1, src2): + def mul(*idx): + return src1[idx] * src2[idx] + + return tvm.te.compute(src1.shape, mul, name="f_mul") + + n = tir.Var("n", "int64") + a = relax.Var("a", relax.TensorStructInfo([n, n], "float32")) + b = relax.Var("b", relax.TensorStructInfo([n, n], "float32")) + + bb = relax.BlockBuilder() + with bb.function("main", [a, b]): + with bb.dataflow(): + d = bb.emit( + bb.call_te_with_grad( + f_mul, a, b, primfunc_name_hint="f_mul", te_grad_name="f_mul_grad" + ) + ) + out = bb.emit_output(R.sum(d)) + bb.emit_func_output(out) + + Before = bb.get() + After = Gradient("main")(Before) + assert_structural_equal(After, get_expected_3()) + assert relax.analysis.well_formed(After) + + if __name__ == "__main__": tvm.testing.main()