diff --git a/python/tvm/relay/op/_tensor_grad.py b/python/tvm/relay/op/_tensor_grad.py
index 90120d64c2ac..5836aebce393 100644
--- a/python/tvm/relay/op/_tensor_grad.py
+++ b/python/tvm/relay/op/_tensor_grad.py
@@ -238,14 +238,28 @@ def divide_grad(orig, grad):
 
 @register_gradient("zeros")
 def zeros_grad(orig, grad):
-    """Returns [shape]"""
-    return [orig.args[0]]
+    """Returns []"""
+    return []
+
+
+@register_gradient("dyn.zeros")
+def dyn_zeros_grad(orig, grad):
+    """Returns the gradient of dyn.zeros which is just zero."""
+    assert len(orig.args) == 1
+    return [zeros_like(orig.args[0])]
 
 
 @register_gradient("ones")
 def ones_grad(orig, grad):
-    """Returns [shape]"""
-    return [orig.args[0]]
+    """Returns []"""
+    return []
+
+
+@register_gradient("dyn.ones")
+def dyn_ones_grad(orig, grad):
+    """Returns the gradient of dyn.ones which is just zero."""
+    assert len(orig.args) == 1
+    return [zeros_like(orig.args[0])]
 
 
 @register_gradient("zeros_like")
diff --git a/tests/python/relay/test_op_grad_level3.py b/tests/python/relay/test_op_grad_level3.py
index 0c89aa7d2e9a..904576a181f6 100644
--- a/tests/python/relay/test_op_grad_level3.py
+++ b/tests/python/relay/test_op_grad_level3.py
@@ -20,7 +20,7 @@
 import tvm
 from tvm import te
 from tvm import relay
-from tvm.relay.testing import check_grad, run_infer_type, _np_randn_from_type
+from tvm.relay.testing import check_grad, run_infer_type, run_opt_pass, _np_randn_from_type
 from tvm.relay.transform import gradient
 import tvm.testing
 
@@ -133,5 +133,52 @@ def test_reshape_like_grad():
     check_grad(fwd_func)
 
 
+def test_zeros_ones_grad_const_ints():
+    # when shape is static (i.e. not an input), there is no gradient at all
+    static_ty = relay.TensorType([2, 3, 4], dtype="float32")
+    expected_ty = relay.TupleType([static_ty, relay.TupleType([])])
+
+    for op in [relay.zeros, relay.ones]:
+        fwd_func = relay.Function([], op(static_ty.concrete_shape, static_ty.dtype))
+        bwd_func = run_infer_type(gradient(run_infer_type(fwd_func)))
+        tvm.ir.assert_structural_equal(bwd_func.ret_type, expected_ty)
+
+
+def test_zeros_ones_grad_const_expr():
+    # when shape is static (i.e. not an input), there is no gradient at all
+    shape_const = relay.const(np.array([2, 3, 4]), dtype="int32")
+    static_ty = relay.TensorType([2, 3, 4], dtype="float32")
+    dyn_ty = relay.TensorType([relay.Any(), relay.Any(), relay.Any()], dtype="float32")
+    expected_ty_static = relay.TupleType([static_ty, relay.TupleType([])])
+    expected_ty_dyn = relay.TupleType([dyn_ty, relay.TupleType([])])
+
+    for op in [relay.zeros, relay.ones]:
+        # with DynamicToStatic, the shape should be concretized
+        fwd_func = relay.Function([], op(shape_const, static_ty.dtype))
+        fwd_func = run_opt_pass(fwd_func, relay.transform.DynamicToStatic())
+        bwd_func = run_infer_type(gradient(run_infer_type(fwd_func)))
+        tvm.ir.assert_structural_equal(bwd_func.ret_type, expected_ty_static)
+
+        fwd_func = relay.Function([], op(shape_const, static_ty.dtype))
+        bwd_func = run_infer_type(gradient(run_infer_type(fwd_func)))
+        tvm.ir.assert_structural_equal(bwd_func.ret_type, expected_ty_dyn)
+
+
+def test_zeros_ones_grad_dynamic():
+    rank = np.random.randint(low=1, high=5, dtype="int32")
+    dyn_shape = np.random.randint(low=1, high=4, size=(rank,), dtype="int32")
+    shape_data = relay.var("shape_data", shape=(rank,), dtype="int32")
+
+    for op, op_ref in [(relay.zeros, np.zeros), (relay.ones, np.ones)]:
+        fwd_func = relay.Function([shape_data], op(shape_data, dtype="float32"))
+        bwd_func = run_infer_type(gradient(run_infer_type(fwd_func)))
+
+        for target, ctx in tvm.testing.enabled_targets():
+            intrp = relay.create_executor(ctx=ctx, target=target)
+            res, (grad,) = intrp.evaluate(bwd_func)(dyn_shape)
+            tvm.testing.assert_allclose(res.asnumpy(), op_ref(dyn_shape, dtype="float32"))
+            tvm.testing.assert_allclose(grad.asnumpy(), np.zeros((rank,), dtype="int32"))
+
+
 if __name__ == "__main__":
     pytest.main()
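Note (not part of the patch): the behavior introduced here can be exercised with a minimal standalone sketch that mirrors the new test_zeros_ones_grad_dynamic test. When the shape is a Relay expression rather than constant integers, relay.zeros lowers to dyn.zeros, whose single argument (the shape tensor) now receives a zero gradient; with a static shape the op has no arguments and so contributes no gradients at all. The snippet below assumes an LLVM-enabled TVM build of roughly the same vintage as this patch.

import numpy as np
import tvm
from tvm import relay
from tvm.relay.transform import gradient
from tvm.relay.testing import run_infer_type

# Dynamic case: the shape is a function input, so relay.zeros becomes dyn.zeros.
shape = relay.var("shape", shape=(2,), dtype="int32")
fwd = relay.Function([shape], relay.zeros(shape, dtype="float32"))
bwd = run_infer_type(gradient(run_infer_type(fwd)))

# Run on the default (CPU) executor; the backward function returns the forward
# value plus a tuple with one gradient per input.
intrp = relay.create_executor()
res, (grad,) = intrp.evaluate(bwd)(np.array([2, 3], dtype="int32"))
print(res.asnumpy())   # a (2, 3) array of zeros, the forward result
print(grad.asnumpy())  # [0, 0]: gradient w.r.t. the shape input is all zeros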