2 changes: 1 addition & 1 deletion python/tvm/relax/op/__init__.py
@@ -38,7 +38,7 @@
from . import memory
from . import nn

# Operator gradient functions
# Register operator gradient functions
from . import _op_gradient


35 changes: 16 additions & 19 deletions python/tvm/relax/op/_op_gradient.py
@@ -87,14 +87,14 @@ def _get_dtype(expr: Expr) -> str:
return dtype


def _fit_shape(bb: BlockBuilder, expr: Expr, target: Expr) -> Expr:
def _fit_shape(bb: BlockBuilder, input_grad: Expr, input: Expr) -> Expr:
"""When expr and target has the same shape, return expr;
otherwise return `collapse_sum_to(expr, target.struct_info.shape)`.

Will use BlockBuilder to normalize expr first.
"""
target_shape = _get_shape(target)
expr_sinfo = _get_shape(bb.normalize(expr)).struct_info
target_shape = _get_shape(input)
expr_sinfo = _get_shape(bb.normalize(input_grad)).struct_info
target_sinfo = target_shape.struct_info
assert isinstance(expr_sinfo, ShapeStructInfo)
assert isinstance(target_sinfo, ShapeStructInfo)
@@ -109,9 +109,9 @@ def _check_shape_equal(lhs: ShapeStructInfo, rhs: ShapeStructInfo):
return True

return (
expr
input_grad
if _check_shape_equal(expr_sinfo, target_sinfo)
else collapse_sum_to(expr, target_shape)
else collapse_sum_to(input_grad, target_shape)
)
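
As an aside for reviewers, the collapse-sum behaviour that `_fit_shape` relies on mirrors how the gradient of a broadcast operand is reduced back to the operand's shape. The NumPy sketch below is illustrative only (not part of the PR), for an input of shape (3,) broadcast to (2, 3):

```
# NumPy illustration of the shape fitting done by _fit_shape: when the forward
# op broadcast x from (3,) to (2, 3), the gradient w.r.t. x is the (2, 3)
# output gradient summed over the broadcast axis, which is what
# collapse_sum_to(output_grad, x.shape) computes in Relax.
import numpy as np

x = np.ones(3)                               # operand with shape (3,)
output_grad = np.arange(6.0).reshape(2, 3)   # gradient in the broadcast shape (2, 3)

x_grad = output_grad.sum(axis=0)             # collapse back to x's shape
assert x_grad.shape == x.shape
```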


@@ -250,15 +250,14 @@ def maximum_grad(
`z = relax.maximum(x, y)`

Backward:
Returns `[z_grad * (where(x < y, 0, 1)), z_grad * (where(x >= y, 0, 1))]`.
Returns `[where(x < y, 0, z_grad), where(x >= y, 0, z_grad)]`.
"""
x = orig_call.args[0]
y = orig_call.args[1]
one = relax.const(1, _get_dtype(x))
zero = relax.const(0, _get_dtype(x))
return [
where(less(x, y), zero, one) * output_grad,
where(greater_equal(x, y), zero, one) * output_grad,
where(less(x, y), zero, output_grad),
where(greater_equal(x, y), zero, output_grad),
]
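
A quick numerical check (illustrative, not part of the PR) that the new form `where(x < y, 0, z_grad)` matches the old `z_grad * where(x < y, 0, 1)` while saving a multiply; the same rewrite is applied to `minimum_grad` and `relu_grad` below:

```
# NumPy sanity check of the where-based gradient rewrite for z = maximum(x, y)
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 3))
y = rng.standard_normal((2, 3))
z_grad = rng.standard_normal((2, 3))

old_x_grad = z_grad * np.where(x < y, 0.0, 1.0)   # previous formulation
new_x_grad = np.where(x < y, 0.0, z_grad)         # formulation in this PR
assert np.allclose(old_x_grad, new_x_grad)
```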


@@ -275,15 +274,14 @@ def minimum_grad(
`z = relax.minimum(x, y)`

Backward:
Returns `[z_grad * (where(x >= y, 0, 1)), z_grad * (where(x < y, 0, 1))]`.
Returns `[where(x >= y, 0, z_grad), where(x < y, 0, z_grad)]`.
"""
x = orig_call.args[0]
y = orig_call.args[1]
one = relax.const(1, _get_dtype(x))
zero = relax.const(0, _get_dtype(x))
return [
where(greater_equal(x, y), zero, one) * output_grad,
where(less(x, y), zero, one) * output_grad,
where(greater_equal(x, y), zero, output_grad),
where(less(x, y), zero, output_grad),
]


@@ -1030,12 +1028,11 @@ def relu_grad(
`y = relax.relu(x)`

Backward:
Returns `[y_grad * (where(x < 0, 0, 1))]`.
Returns `[where(x < 0, 0, y_grad)]`.
"""
x = orig_call.args[0]
one = relax.const(1, _get_dtype(x))
zero = relax.const(0, _get_dtype(x))
return [where(less(x, zero), zero, one) * output_grad]
return [where(less(x, zero), zero, output_grad)]


@register_gradient("relax.nn.silu")
@@ -1090,10 +1087,10 @@ def log_softmax_grad(
`y = relax.log_softmax(x, axis)`

Backward:
Returns `[y_grad - sum(y_output_grad, axis, keepdims=True) * softmax(x)]`
Returns `[y_grad - sum(y_grad, axis, keepdims=True) * exp(y)]`
"""
x_softmax = exp(orig_var)
return [(output_grad - sum(output_grad, orig_call.attrs.axis, True) * x_softmax)]
y_exp = exp(orig_var)
return [(output_grad - sum(output_grad, orig_call.attrs.axis, True) * y_exp)]
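
As a reference for the formula above, a finite-difference check (NumPy, illustrative only) that the gradient of `log_softmax` is `y_grad - sum(y_grad, axis, keepdims=True) * exp(y)`:

```
# Verify d/dx [ sum(g * log_softmax(x)) ] == g - sum(g, axis, keepdims=True) * exp(log_softmax(x))
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((3, 4))
g = rng.standard_normal((3, 4))   # upstream gradient y_grad
axis = -1

def log_softmax(t):
    return t - np.log(np.sum(np.exp(t), axis=axis, keepdims=True))

y = log_softmax(x)
analytic = g - np.sum(g, axis=axis, keepdims=True) * np.exp(y)

# central finite differences of L(x) = sum(g * log_softmax(x))
eps = 1e-6
numeric = np.zeros_like(x)
for idx in np.ndindex(*x.shape):
    xp, xm = x.copy(), x.copy()
    xp[idx] += eps
    xm[idx] -= eps
    numeric[idx] = (np.sum(g * log_softmax(xp)) - np.sum(g * log_softmax(xm))) / (2 * eps)

assert np.allclose(analytic, numeric, atol=1e-4)
```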


@register_gradient("relax.nn.cross_entropy_with_logits")
53 changes: 53 additions & 0 deletions python/tvm/relax/op/grad/grad.py
@@ -43,6 +43,59 @@ def no_grad(input: Expr) -> Expr:
return _ffi_api.no_grad(input) # type: ignore


def start_checkpoint(input: Expr) -> Expr:
"""Mark the start of the checkpoint stage. The computation between start_checkpoint and
end_checkpoint will be marked as the checkpoint stage.

Rather than storing all intermediate activations of the entire computation graph for the
backward computation, the checkpointed stage does not save its intermediate activations, and
instead recomputes them during the backward process.

For instance,
```
a = relax.Var("a", relax.TensorStructInfo((2, 2), "float32"))
b = relax.Var("b", relax.TensorStructInfo((2, 2), "float32"))
c = a * 2
d = b * 2
c_cp = start_checkpoint(c)
d_cp = start_checkpoint(d)
e = c_cp + d_cp
e_out = end_checkpoint(e)
```
Then `e` will be recomputed in the backward stage.

See tvm.relax.transform.Gradient, tvm.relax.testing.nn.checkpoint,
tvm.relax.op.grad.end_checkpoint for more information.

Parameters
----------
input : relax.Expr
The tensor marking the input of the checkpoint stage.

Returns
-------
result : relax.Expr
The same tensor as the input.
"""
return _ffi_api.start_checkpoint(input) # type: ignore


def end_checkpoint(input: Expr) -> Expr:
"""Mark the end of checkpoint stage. See tvm.relax.op.grad.start_checkpoint.

Parameters
----------
input : relax.Expr
The output of the checkpoint stage.

Returns
-------
result : relax.Expr
The same tensor as the input.
"""
return _ffi_api.end_checkpoint(input) # type: ignore
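
For reviewers, a minimal BlockBuilder sketch (shapes and variable names are illustrative assumptions) showing that both markers preserve the struct info of their argument; they carry no runtime semantics of their own, and LegalizeOps later lowers each of them to its argument (see the legalize_ops/grad.py hunk below):

```
# Minimal sketch: the markers are shape/dtype-preserving identities at the IR level.
from tvm import relax
from tvm.relax.op.grad.grad import end_checkpoint, start_checkpoint

bb = relax.BlockBuilder()
x = relax.Var("x", relax.TensorStructInfo((2, 2), "float32"))
with bb.function("identity_with_markers", [x]):
    with bb.dataflow():
        x_cp = bb.emit(start_checkpoint(x))
        y = bb.emit(relax.op.multiply(x_cp, x_cp))
        out = bb.emit_output(end_checkpoint(y))
    bb.emit_func_output(out)

# x_cp and out keep TensorStructInfo((2, 2), "float32"), the same as their inputs.
```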


def nll_loss_backward(
output_grad: Expr,
predictions: Expr,
123 changes: 122 additions & 1 deletion python/tvm/relax/testing/nn.py
@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=redefined-builtin
# pylint: disable=redefined-builtin, invalid-name
"""PyTorch-like nn.Module API for constructing workloads."""


@@ -24,6 +24,7 @@

import tvm
from tvm import relax, topi, tir
from tvm.relax.op.grad.grad import end_checkpoint, start_checkpoint


def emit(expr: relax.Expr, name_hint: str = "") -> relax.Var:
@@ -34,6 +35,126 @@ def emit_te(func: Callable, *args: Any, **kwargs: Any) -> relax.Var:
return relax.BlockBuilder.current().emit_te(func, *args, **kwargs)


def checkpoint(
func: Callable, *args: Any, **kwargs: Any
) -> Union[relax.Var, List[relax.Var], List[Any]]:
"""Mark function(*args, **kwargs) should be computed in a checkpointed manner during backward.

To be specific, args and kwargs will be checkpointed, and func(*args, **kwargs) will be
recomputed in the backward stage.
"""
args = [start_checkpoint(v) if isinstance(v, relax.Expr) else v for v in args]
kwargs = {k: start_checkpoint(v) if isinstance(v, relax.Expr) else v for k, v in kwargs.items()}
result = func(*args, **kwargs)
if isinstance(result, (list, tuple)):
result = [end_checkpoint(v) if isinstance(v, relax.Expr) else v for v in result]
else:
assert isinstance(result, relax.Expr)
result = end_checkpoint(result)
return result


def emit_checkpoint(
func: Callable, *args: Any, **kwargs: Any
) -> Union[relax.Var, List[relax.Var], List[Any]]:
"""Mark function(*args, **kwargs) should be computed in a checkpointed manner during backward.

To be specific, args and kwargs will be checkpointed, and func(*args, **kwargs) will be
recomputed in the backward stage.

This interface additionally emits the exprs wrapped by start_checkpoint() and
end_checkpoint() as bindings with the suffixes "_scp" and "_ecp" respectively, to make the
resulting TVMScript easier to read.
"""
bb = relax.BlockBuilder.current()
args = [
bb.emit(start_checkpoint(v), v.name_hint + "_scp") if isinstance(v, relax.Var) else v
for v in args
]
kwargs = {
k: bb.emit(start_checkpoint(v), v.name_hint + "_scp") if isinstance(v, relax.Var) else v
for k, v in kwargs.items()
}
result = func(*args, **kwargs)
if isinstance(result, (list, tuple)):
result = list(result)
for i, v in enumerate(result):
if isinstance(v, relax.Expr):
if not isinstance(v, relax.Var):
v = bb.emit(v)
result[i] = bb.emit(end_checkpoint(v), v.name_hint + "_ecp")
else:
assert isinstance(result, relax.Expr)
result_emit = bb.emit(result)
result = bb.emit(end_checkpoint(result_emit), result_emit.name_hint + "_ecp")

return result
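
A short usage sketch (shapes and variable names are illustrative assumptions): compared with `checkpoint`, `emit_checkpoint` also emits the marker calls as named bindings, so the printed TVMScript contains vars such as `x_scp` and `lv_ecp`:

```
# Hedged sketch of emit_checkpoint inside a BlockBuilder function scope.
from tvm import relax
from tvm.relax.testing import nn

bb = relax.BlockBuilder()
x = relax.Var("x", relax.TensorStructInfo((4, 4), "float32"))
with bb.function("forward", [x]):
    with bb.dataflow():
        def block(t):
            return relax.op.nn.relu(relax.op.add(t, t))

        # emits x_scp = start_checkpoint(x), the block body, and a "<var>_ecp" binding
        y = nn.emit_checkpoint(block, x)
        out = bb.emit_output(y)
    bb.emit_func_output(out)

# print(bb.get())  # the "_scp"/"_ecp" bindings are visible in the TVMScript
```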


def emit_checkpoint_sequential(
functions: List[Callable],
segments: Union[int, List[int]],
input: relax.Var,
checkpoint_last: bool = False,
) -> Union[relax.Var, List[relax.Var], List[Any]]:
"""A helper function for checkpointing sequential models. This interface has similar purpose
as torch.utils.checkpoint.checkpoint_sequential.

Sequential models execute a list of modules/functions in order (sequentially). Therefore, we
can divide such a model in various segments and checkpoint each segment. By default, we will
checkpoint all segments except the last, meaning their inputs will be saved from the forward
stage and they will be recomputed in the backward stage.

Parameters
----------
functions : List[Callable]
The list of functions to be executed sequentially.

segments : Union[int, List[int]]
The segmentation. If segments is an int `n`, the functions are divided into `n` roughly equal
segments (an extra segment may be produced when the number of functions is not divisible by
`n`); if segments is a list of ints, each entry marks the start index of a segment. See the
worked example after this function.

input : relax.Var
The input of the first function.

checkpoint_last : bool
Whether the last segment will be checkpointed. Default: False

Returns
-------
output : Union[relax.Var, List[relax.Var], List[Any]]
The emitted output of the last function.
"""
bb = relax.BlockBuilder.current()

def run_function(start, end, functions):
def forward(input):
for j in range(start, end):
input = functions[j](input)
return input

return forward

n = len(functions)
if not isinstance(segments, list):
segments = list(range(0, n, n // segments)) + [n]
if segments[-1] != n:
segments = segments + [n]

assert len(segments) >= 2

for i in range(len(segments) - 1):
if i == len(segments) - 2 and not checkpoint_last:
input = run_function(segments[i], segments[i + 1], functions)(input)
else:
input = emit_checkpoint(run_function(segments[i], segments[i + 1], functions), input)

assert isinstance(input, relax.Expr)
if not isinstance(input, relax.Var):
input = bb.emit(input)
return input
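
A worked example of the segment computation above (plain Python, not part of the PR): with 5 functions and `segments=2`, the split points become `[0, 2, 4, 5]`, i.e. three segments rather than exactly two, and with `checkpoint_last=False` the final `(4, 5)` segment is not checkpointed:

```
# Reproduce the split-point arithmetic used by emit_checkpoint_sequential
n = 5          # len(functions)
segments = 2
points = list(range(0, n, n // segments)) + [n]                 # [0, 2, 4, 5]
ranges = [(points[i], points[i + 1]) for i in range(len(points) - 1)]
print(ranges)  # [(0, 2), (2, 4), (4, 5)]
```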


def _try_unique_name(name: str):
"""Attempt to uniquify the name

14 changes: 7 additions & 7 deletions python/tvm/relax/training/utils.py
@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
# pylint: disable=invalid-name, unused-argument
"""Utility functions for relax training."""

from typing import Optional, Callable
@@ -24,7 +24,7 @@
from tvm._ffi.registry import register_func
from tvm.relax.block_builder import BlockBuilder

from ..expr import Function
from ..expr import Function, Var, Call
from . import _ffi_api


@@ -189,13 +189,13 @@ def register(func: Callable):
# It will return the emitted var.

def handler(
builder: BlockBuilder, output_grad_var: relax.Var, call_tir: relax.Call
orig_var: Var, call_tir_with_grad: Call, output_grad: Var, ctx: BlockBuilder
) -> relax.Expr:
return builder.emit_te(
return ctx.emit_te(
func,
output_grad_var,
*call_tir.args[1],
**call_tir.attrs.te_grad_kwargs,
output_grad,
*call_tir_with_grad.args[1],
**call_tir_with_grad.attrs.te_grad_kwargs,
primfunc_name_hint=te_grad_name,
)
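
For context, a hedged sketch of a TE gradient function that matches the calling convention used by the handler above: `emit_te` passes the output gradient first, then the original `call_tir` arguments as `te.Tensor`s, then any `te_grad_kwargs`. The function name below is a made-up illustration; it would be registered through the decorator defined in this file:

```
# Hedged sketch (not part of the PR): a TE gradient for z = x * y w.r.t. x.
from tvm import topi

def multiply_x_te_grad(output_grad, x, y, **kwargs):
    # output_grad, x and y arrive as te.Tensor; contract the upstream gradient
    # with d(x * y)/dx = y.
    return topi.multiply(output_grad, y)
```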

15 changes: 12 additions & 3 deletions python/tvm/relax/transform/legalize_ops/grad.py
@@ -29,6 +29,16 @@ def _no_grad(bb: BlockBuilder, call: Call) -> Expr:
return call.args[0]


@register_legalize("relax.grad.start_checkpoint")
def _start_checkpoint(bb: BlockBuilder, call: Call) -> Expr:
return call.args[0]


@register_legalize("relax.grad.end_checkpoint")
def _end_checkpoint(bb: BlockBuilder, call: Call) -> Expr:
return call.args[0]


@register_legalize("relax.grad.nll_loss_backward")
def _grad_nll_loss_backward(bb: BlockBuilder, call: Call) -> Expr:
# topi.sum doesn't support zero-dim x
@@ -51,9 +61,8 @@ def te_nll_loss_backward(output_grad, predictions, targets, weights, reduction,
if reduction == "sum":
output_grad = topi.broadcast_to(output_grad, targets.shape)
elif reduction == "mean":
output_grad = topi.divide(
topi.broadcast_to(output_grad, targets.shape), topi_sum_extend(all_weights)
)
weight_sum = topi_sum_extend(all_weights)
output_grad = topi.divide(topi.broadcast_to(output_grad, targets.shape), weight_sum)

# handle no batch
if predictions.ndim == 1:
4 changes: 4 additions & 0 deletions python/tvm/relax/transform/transform.py
@@ -75,6 +75,10 @@ def main_adjoint(original_parameters):
R.output(original_outputs, grad_1, grad_2, ...)
return (original_return_value, (grad_1, grad_2, ...))

This AD pass also supports checkpointing as described in
"Training deep nets with sublinear memory cost." - Chen, Tianqi, et al. (2016).
See tvm.relax.testing.nn.checkpoint for more details.

Parameters
----------
func_name : str
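
To round out the review, a hedged end-to-end sketch (untested; shapes and names are illustrative) combining `tvm.relax.testing.nn.checkpoint` with this Gradient pass, following the pattern from the `start_checkpoint` docstring earlier in this diff:

```
# Build a scalar-output function whose middle computation is checkpointed,
# then differentiate it. Gradient("main") adds a "main_adjoint" function.
from tvm import relax
from tvm.relax.testing import nn

bb = relax.BlockBuilder()
a = relax.Var("a", relax.TensorStructInfo((2, 2), "float32"))
b = relax.Var("b", relax.TensorStructInfo((2, 2), "float32"))
with bb.function("main", [a, b]):
    with bb.dataflow():
        c = nn.emit(relax.op.multiply(a, relax.const(2.0, "float32")))
        d = nn.emit(relax.op.multiply(b, relax.const(2.0, "float32")))
        # c + d is recomputed in the backward stage instead of being stored
        e = nn.emit(nn.checkpoint(lambda lhs, rhs: relax.op.add(lhs, rhs), c, d))
        out = bb.emit_output(nn.emit(relax.op.sum(e)))
    bb.emit_func_output(out)

mod = bb.get()
mod = relax.transform.Gradient("main")(mod)
# print(mod["main_adjoint"])
```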