diff --git a/python/tvm/relay/op/_tensor.py b/python/tvm/relay/op/_tensor.py
index 8666e3985eca..0c875045032f 100644
--- a/python/tvm/relay/op/_tensor.py
+++ b/python/tvm/relay/op/_tensor.py
@@ -230,6 +230,7 @@ def elemwise_shape_func(attrs, inputs, _):
 register_shape_func("cast", False, elemwise_shape_func)
+register_shape_func("cast_like", False, elemwise_shape_func)
 register_shape_func("zeros", False, no_data_full_shape_func)
 register_shape_func("zeros_like", False, elemwise_shape_func)
 register_shape_func("ones", False, no_data_full_shape_func)
diff --git a/python/tvm/relay/op/_tensor_grad.py b/python/tvm/relay/op/_tensor_grad.py
index 85168a561399..9f4f20c9000c 100644
--- a/python/tvm/relay/op/_tensor_grad.py
+++ b/python/tvm/relay/op/_tensor_grad.py
@@ -20,8 +20,11 @@
 from tvm.topi.nn.util import get_pad_tuple
 from tvm.topi.util import get_const_tuple
+from tvm.error import OpError
 
-from ..expr import Tuple, TupleGetItem, const
+from ..expr import Tuple, TupleGetItem, const, Var
+from ..ty import TensorType
+from ..loops import while_loop
 from . import nn as _nn
 from .op import register_gradient
 from .reduce import sum as _sum
@@ -40,6 +43,7 @@
     equal,
     shape_of,
     log,
+    concatenate,
 )
 from .transform import (
     broadcast_to_like,
@@ -55,6 +59,10 @@
     repeat,
     expand_dims,
     full_like,
+    split,
+    squeeze,
+    strided_set,
+    arange,
 )
 
 
@@ -665,3 +673,134 @@ def cross_entropy_with_logits_grad(orig, grad):
     batch_size = take(shape, const(0, dtype="int32"), axis=0)
     grad = grad / batch_size.astype(x.checked_type.dtype)
     return [-grad * y, -grad * x]
+
+
+@register_gradient("take")
+def take_grad(orig, grad):
+    """
+    Returns the gradient of take.
+    """
+
+    def make_scalar_tensor(v):
+        if isinstance(v, int):
+            v = const(v, dtype="int32")
+        return reshape(v, (1,))
+
+    # TODO(@altanh): we currently assume indices are in range
+    data, indices = orig.args
+    axis = orig.attrs.axis
+    zero, one = map(make_scalar_tensor, [0, 1])
+    data_grad = zeros_like(data)
+    try:
+        data_shape = data.checked_type.concrete_shape
+    except TypeError as ty_err:
+        raise OpError("currently take_grad only supports data with concrete shape") from ty_err
+
+    if axis is None:
+        axis = 0
+        data_grad = reshape(data_grad, (-1,))
+        data_shape = 1
+        for dim in data.checked_type.concrete_shape:
+            data_shape *= dim
+        data_shape = (data_shape,)
+    else:
+        axis = int(axis)
+    strides = [1] * len(data_shape)
+
+    if len(indices.checked_type.shape) == 0:
+        # axis on grad has been squeezed in this case
+        num_indices = one
+        indices = reshape(indices, (1,))
+        grad = expand_dims(grad, int(axis))
+    elif len(indices.checked_type.shape) == 1:
+        num_indices = take(shape_of(indices), zero, axis=0)
+    else:
+        raise OpError("take_grad only supports scalar or 1D indices")
+
+    def loop_cond(data_grad, i):
+        return squeeze(less(i, num_indices))
+
+    def loop_body(data_grad, i):
+        index = take(indices, i, axis=0)
+        grad_slice = take(grad, i, axis=axis)
+        begin, end = [], []
+        for ax, size in enumerate(data_shape):
+            size = make_scalar_tensor(size)
+            begin.append(zero if ax != axis else index)
+            end.append(size if ax != axis else index + one)
+        begin, end = concatenate(begin, axis=0), concatenate(end, axis=0)
+        # data_grad[:,...,index at axis,...,:] += grad_slice
+        update = strided_slice(data_grad, begin, end, strides=strides)
+        update = update + grad_slice  # no need to expand grad_slice since i has shape (1,)
+        next_data_grad = strided_set(data_grad, update, begin, end, strides=strides)
+        return (next_data_grad, i + one)
+
+    loop_vars = [
+        Var("data_grad", type_annotation=TensorType(data_shape, data.checked_type.dtype)),
+        Var("i", type_annotation=TensorType((1,), "int32")),
+    ]
+
+    loop = while_loop(loop_cond, loop_vars, loop_body)
+    result = loop(data_grad, zero)
+    data_grad = TupleGetItem(result, 0)
+
+    if orig.attrs.axis is None:
+        data_grad = reshape_like(data_grad, data)
+
+    return [data_grad, zeros_like(orig.args[1])]
+
+
+@register_gradient("contrib_reverse_reshape")
+def reverse_reshape_grad(orig, grad):
+    """
+    Returns the gradient of reverse_reshape (same as reshape).
+    """
+    return [reshape_like(grad, orig.args[0])]
+
+
+@register_gradient("stack")
+def stack_grad(orig, grad):
+    """
+    Returns grad split across stacked inputs.
+    """
+    stack_axis = int(orig.attrs.axis)
+    sections = len(orig.args[0].checked_type.fields)
+    splits = split(grad, sections, stack_axis)
+    splits = Tuple([squeeze(x, axis=[stack_axis]) for x in splits])
+    return [splits]
+
+
+@register_gradient("squeeze")
+def squeeze_grad(orig, grad):
+    """
+    Returns grad expanded to input size.
+    """
+    # use reshape_like rather than expand_dims: when axis=None we no longer
+    # know which axes were squeezed
+    return [reshape_like(grad, orig.args[0])]
+
+
+@register_gradient("expand_dims")
+def expand_dims_grad(orig, grad):
+    """
+    Returns grad squeezed on expanded dims.
+    """
+    axis = int(orig.attrs.axis)
+    for _ in range(orig.attrs.num_newaxis):
+        grad = squeeze(grad, axis=[axis])
+    return [grad]
+
+
+@register_gradient("arange")
+def arange_grad(orig, grad):
+    """
+    Returns the gradient of arange.
+    """
+    start, stop, step = orig.args
+    length = take(shape_of(orig), const(0, dtype="int32"), axis=0)
+
+    grad_start = cast_like(_sum(grad), start)
+    grad_stop = zeros_like(stop)
+    grad_step = cast_like(arange(length, dtype="int32"), grad) * grad
+    grad_step = cast_like(_sum(grad_step), step)
+
+    return [grad_start, grad_stop, grad_step]
diff --git a/python/tvm/relay/testing/__init__.py b/python/tvm/relay/testing/__init__.py
index 6eb71b581ab2..9c87f2795e5c 100644
--- a/python/tvm/relay/testing/__init__.py
+++ b/python/tvm/relay/testing/__init__.py
@@ -64,7 +64,11 @@ def run_infer_type(expr):
 
 
 def _np_randn_from_type(t, scale=1, mean=0):
-    return (mean + (scale * np.random.randn(*(int(d) for d in t.shape)))).astype(t.dtype)
+    res = mean + (scale * np.random.randn(*(int(d) for d in t.shape)))
+    # if t.shape == (), then randn returns a scalar, so wrap it for the dtype conversion
+    if np.isscalar(res):
+        res = np.array(res)
+    return res.astype(t.dtype)
 
 
 def check_grad(
diff --git a/src/relay/analysis/type_solver.cc b/src/relay/analysis/type_solver.cc
index 55f736895018..8f14b557dc54 100644
--- a/src/relay/analysis/type_solver.cc
+++ b/src/relay/analysis/type_solver.cc
@@ -615,7 +615,7 @@ bool TypeSolver::Solve() {
       rnode->resolved = resolved;
     } catch (const Error& err) {
-      this->diag_ctx_.Emit(Diagnostic::Error(rnode->span) << "err");
+      this->diag_ctx_.Emit(Diagnostic::Error(rnode->span) << err.what());
       rnode->resolved = false;
     } catch (const dmlc::Error& e) {
       ICHECK(false) << e.what();
diff --git a/tests/python/relay/test_op_grad_level1.py b/tests/python/relay/test_op_grad_level1.py
index c0270eae80d2..cac07c437a42 100644
--- a/tests/python/relay/test_op_grad_level1.py
+++ b/tests/python/relay/test_op_grad_level1.py
@@ -144,5 +144,11 @@ def test_bias_add_grad():
     verify_bias_add((4, 8), (8,))
 
 
+def test_expand_dims_grad():
+    data = relay.var("data", shape=(2, 3), dtype="float64")
+    fwd_func = relay.Function([data], relay.expand_dims(data, axis=1, num_newaxis=2))
+    check_grad(fwd_func)
+
+
 if __name__ == "__main__":
     pytest.main([__file__])
diff --git a/tests/python/relay/test_op_grad_level10.py b/tests/python/relay/test_op_grad_level10.py
index 462a75255f90..4a6ffb933881 100644
--- a/tests/python/relay/test_op_grad_level10.py
+++ b/tests/python/relay/test_op_grad_level10.py
@@ -67,5 +67,10 @@ def test_batch_matmul_grad():
     check_grad(relay.Function([x, y], relay.op.nn.batch_matmul(x, y)))
 
 
+def test_reverse_reshape_grad():
+    x = relay.var("x", shape=(3, 4, 5), dtype="float64")
+    check_grad(relay.Function([x], relay.op.reverse_reshape(x, (-1, 0))))
+
+
 if __name__ == "__main__":
     pytest.main([__file__])
diff --git a/tests/python/relay/test_op_grad_level3.py b/tests/python/relay/test_op_grad_level3.py
index 0b4f8920aa5c..9c27afd87205 100644
--- a/tests/python/relay/test_op_grad_level3.py
+++ b/tests/python/relay/test_op_grad_level3.py
@@ -20,7 +20,7 @@
 
 import tvm
 from tvm import te
 from tvm import relay
-from tvm.relay.testing import check_grad, run_infer_type
+from tvm.relay.testing import check_grad, run_infer_type, _np_randn_from_type
 from tvm.relay.transform import gradient
 import tvm.testing
@@ -75,5 +75,47 @@ def test_copy_grad():
     check_grad(fwd_func)
 
 
+def test_take_grad():
+    data_dtype = relay.TensorType((3, 4, 5), "float64")
+    data = relay.var("data", data_dtype)
+    indices = relay.var("indices", relay.TensorType((relay.Any(),), "int32"))
+    inputs = [_np_randn_from_type(data_dtype, scale=1e-5), np.array([1, 2], dtype="int32")]
+    test_inputs = [inputs[0]]
+
+    # take on axis
+    fwd_func = relay.Function([data, indices], relay.take(data, indices, axis=1))
+    check_grad(fwd_func, inputs=inputs, test_inputs=test_inputs)
+
+    # take on flattened
+    fwd_func = relay.Function([data, indices], relay.take(data, indices, axis=None))
+    check_grad(fwd_func, inputs=inputs, test_inputs=test_inputs)
+
+
+def test_stack_grad():
+    args = [relay.var(c, shape=(2, 3, 4), dtype="float64") for c in "xyz"]
+    fwd_func = relay.Function(args, relay.stack(args, axis=0))
+    check_grad(fwd_func)
+
+
+def test_squeeze_grad():
+    data = relay.var("data", shape=(2, 1, 1, 3, 4, 1), dtype="float64")
+    fwd_func = relay.Function([data], relay.squeeze(data))
+    fwd_func_subset = relay.Function([data], relay.squeeze(data, axis=[1, -1]))
+    check_grad(fwd_func)
+    check_grad(fwd_func_subset)
+
+
+def test_arange_grad():
+    # TODO: numerically testing arange is awkward because the two-sided
+    # finite-difference approximation can change the output shape
+    dtype = "float64"
+    start = relay.var("start", relay.TensorType((), dtype))
+    stop = relay.var("stop", relay.TensorType((), dtype))
+    step = relay.var("step", relay.TensorType((), dtype))
+    values = [np.array(v, dtype=dtype) for v in [2.5, 9.5, 1.8]]
+    fwd_func = relay.Function([start, stop, step], relay.arange(start, stop, step, dtype))
+    check_grad(fwd_func, inputs=values)
+
+
 if __name__ == "__main__":
     pytest.main()
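For reference (not part of the patch itself): a minimal NumPy sketch of the scatter-add that the while_loop in take_grad performs for 1-D indices along a given axis. The helper name take_grad_reference and the shapes below are illustrative assumptions, not TVM API.

import numpy as np

def take_grad_reference(data_shape, indices, grad, axis):
    # Accumulate each output-gradient slice back into a zero tensor shaped like the
    # data, mirroring the strided_slice/strided_set loop above:
    #   data_grad[..., idx at axis, ...] += grad[..., i at axis, ...]
    data_grad = np.zeros(data_shape, dtype=grad.dtype)
    for i, idx in enumerate(indices):
        dst = [slice(None)] * len(data_shape)
        src = [slice(None)] * grad.ndim
        dst[axis] = idx
        src[axis] = i
        data_grad[tuple(dst)] += grad[tuple(src)]
    return data_grad

# take(data, [1, 2], axis=1) on data of shape (3, 4, 5) yields shape (3, 2, 5), so an
# all-ones output gradient scatters 3 * 2 * 5 = 30 ones into the data gradient.
ones_grad = np.ones((3, 2, 5))
assert take_grad_reference((3, 4, 5), np.array([1, 2]), ones_grad, axis=1).sum() == 30.0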