python/tvm/relay/op/_tensor_grad.py (74 additions, 8 deletions)

@@ -48,6 +48,9 @@
    tile,
    transpose,
    where,
    repeat,
    expand_dims,
    full_like
)


@@ -198,6 +201,7 @@ def clip_grad(orig, grad):

@register_gradient("nn.max_pool2d")
def max_pool2d_grad(orig, grad):
    """Returns the gradient of max_pool2d."""
    attrs = orig.attrs
    pool_grad = _nn.max_pool2d_grad(grad, orig.args[0], pool_size=attrs.pool_size,
                                    strides=attrs.strides, padding=attrs.padding,
@@ -207,6 +211,7 @@ def max_pool2d_grad(orig, grad):

@register_gradient("nn.avg_pool2d")
def avg_pool2d_grad(orig, grad):
    """Returns the gradient of avg_pool2d."""
    attrs = orig.attrs
    pool_grad = _nn.avg_pool2d_grad(grad, orig.args[0], pool_size=attrs.pool_size,
                                    strides=attrs.strides, padding=attrs.padding,
@@ -215,6 +220,26 @@ def avg_pool2d_grad(orig, grad):
    return [pool_grad]


@register_gradient("nn.global_avg_pool2d")
def global_avg_pool2d_grad(orig, grad):
    """Returns the gradient of global_avg_pool2d."""
    data = orig.args[0]
    shape = data.checked_type.shape
    layout = orig.attrs.layout

    # we assume NCHW or NHWC layout for now, but easy to add more
    assert layout in ["NCHW", "NHWC"]
    if layout == "NCHW":
        pool_size = shape[2], shape[3]
    elif layout == "NHWC":
        pool_size = shape[1], shape[2]

    pool_grad = _nn.avg_pool2d_grad(grad, data, pool_size=pool_size,
                                    strides=(1, 1), padding=(0, 0),
                                    layout=layout)
    return [pool_grad]

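For intuition, the rule above amounts to spreading each output gradient uniformly over the H x W window that was averaged. A minimal NumPy sketch of the same computation (illustrative only; array names, shapes, and NCHW layout are assumed):

import numpy as np

x = np.random.rand(1, 2, 4, 4).astype("float32")    # NCHW input
out_grad = np.ones((1, 2, 1, 1), dtype="float32")   # gradient w.r.t. the 1x1 pooled output
h, w = x.shape[2], x.shape[3]
# every input element receives its window's output gradient divided by H*W,
# which is what avg_pool2d_grad computes when pool_size covers the full spatial extent
ref_grad = np.broadcast_to(out_grad / (h * w), x.shape)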

# not implemented, this is only for testing.
@register_gradient("concatenate")
def concatenate_grad(orig, grad):
@@ -287,16 +312,53 @@ def conv2d_grad(orig, grad):
    return [backward_data, backward_weight]


def _get_reduce_axis(call):
    """Helper function that returns the reduce axis of the call as plain python ints."""
    x, axis = call.args[0], call.attrs.axis
    shape = x.checked_type.concrete_shape

    # should never exclude when axis is None
    assert not (axis is None and call.attrs.exclude)

    if axis is None:
        return None

    # convert to nonnegative integers and sort
    axis = sorted([ax if ax >= 0 else len(shape) + ax for ax in map(int, axis)])
    if call.attrs.exclude:
        axis = [ax for ax in range(len(shape)) if ax not in axis]
    return axis

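As a concrete illustration of the normalization above (plain Python, values assumed): for a rank-3 input, axis=(-1, 1) wraps the negative index and sorts to [1, 2], and with exclude=True the complement [0] is returned instead.

shape = (4, 2, 1)
axis = (-1, 1)
axis = sorted(ax if ax >= 0 else len(shape) + ax for ax in map(int, axis))  # [1, 2]
exclude = True
if exclude:
    axis = [ax for ax in range(len(shape)) if ax not in axis]               # [0]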

def _unreduce_expand(x, axis):
    """Helper function that returns x expanded on the reduced dimensions in axis."""
    # assume axis is sorted nonnegative ints
    for ax in axis:
        x = expand_dims(x, ax)
    return x

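A small NumPy stand-in for the expand_dims loop above (shapes assumed): reducing a (5, 4, 3) tensor over axis=[0, 2] without keepdims leaves shape (4,); re-inserting the reduced axes in ascending order restores shape (1, 4, 1), ready to be repeated or broadcast back to (5, 4, 3).

import numpy as np

reduced = np.zeros((4,))   # result of reducing a (5, 4, 3) tensor over axis=[0, 2]
for ax in [0, 2]:
    reduced = np.expand_dims(reduced, ax)
assert reduced.shape == (1, 4, 1)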

@register_gradient("max")
def max_grad(orig, grad):
    """Returns the gradient of max"""
-    # Only support axis=0, since broadcasting orig to x behaves incorrectly
-    x, axis = orig.args[0], orig.attrs.axis
-    assert(axis is not None and len(axis) == 1 and int(axis[0]) == 0)
-    orig = broadcast_to_like(orig, x)
-    grad = broadcast_to_like(grad, x)
-    indicators = cast_like(equal(orig, x), grad)
-    return [indicators * grad]
    x, axis = orig.args[0], _get_reduce_axis(orig)
    shape = x.checked_type.concrete_shape

    repeated = orig
    if axis is None:
        repeated = full_like(x, repeated)
    else:
        # expand dims (if necessary) and repeat along each axis
        if not orig.attrs.keepdims:
            repeated = _unreduce_expand(repeated, axis)
            grad = _unreduce_expand(grad, axis)
        for ax in axis:
            repeated = repeat(repeated, shape[ax], ax)

    indicators = cast_like(equal(repeated, x), grad)
    num_selected = _sum(indicators, axis, keepdims=True)
    # spread error across all max weights
    return [indicators * grad / num_selected]

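A worked example of the rule above (NumPy, values assumed): for x = [2., 5., 5., 1.] reduced with a full max and an incoming gradient of 1.0, the two tied maxima each receive half of the gradient and the other entries receive none.

import numpy as np

x = np.array([2., 5., 5., 1.])
indicators = (x == x.max()).astype(x.dtype)   # [0., 1., 1., 0.]
num_selected = indicators.sum()               # 2.0
dx = indicators * 1.0 / num_selected          # [0., 0.5, 0.5, 0.]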

@register_gradient("nn.softmax")
@@ -372,7 +434,11 @@ def negative_grad(orig, grad):
@register_gradient("sum")
def sum_grad(orig, grad):
    """Returns grad broadcasted to data dims"""
-    data = orig.args[0]
    data, axis = orig.args[0], _get_reduce_axis(orig)
    if not orig.attrs.keepdims:
        if axis is None:
            axis = list(range(len(data.checked_type.concrete_shape)))
        grad = _unreduce_expand(grad, axis)
    return [broadcast_to_like(grad, data)]

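A sketch of the sum gradient above (NumPy analogue, shapes assumed): for a (4, 2) input summed over axis=1 without keepdims, the (4,) output gradient is expanded to (4, 1) and broadcast back, so every element in a row receives that row's output gradient.

import numpy as np

out_grad = np.arange(4, dtype="float32")          # gradient w.r.t. the (4,) summed output
dx = np.broadcast_to(out_grad[:, None], (4, 2))   # each row repeats its output gradient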

tests/python/relay/test_op_grad_level2.py (26 additions, 3 deletions)

@@ -48,8 +48,7 @@ def verify_max_pool2d_grad(x_shape, pool_size, strides, padding, ceil_mode):


def test_max_pool2d_grad():
-    verify_max_pool2d_grad((1, 4, 16, 16), pool_size=(2, 2), strides=(2, 2), padding=(0, 0),
-                           ceil_mode=False)
    verify_max_pool2d_grad((1, 4, 16, 16), pool_size=(2, 2), strides=(2, 2), padding=(0, 0), ceil_mode=False)
    verify_max_pool2d_grad((1, 4, 16, 16), pool_size=(1, 1), strides=(1, 1), padding=(1, 1), ceil_mode=False)


@@ -75,14 +74,37 @@ def verify_avg_pool2d_grad(x_shape, pool_size, strides, padding, ceil_mode, count_include_pad):
        op_res, (op_grad, ) = intrp.evaluate(bwd_func)(data)
        np.testing.assert_allclose(op_grad.asnumpy(), ref_grad, rtol=0.01)


def test_avg_pool2d_grad():
    verify_avg_pool2d_grad((1, 4, 16, 16), pool_size=(2, 2), strides=(2, 2), padding=(0, 0),
                           ceil_mode=False, count_include_pad=True)
    verify_avg_pool2d_grad((1, 4, 16, 16), pool_size=(1, 1), strides=(1, 1), padding=(1, 1),
                           ceil_mode=False, count_include_pad=False)


def verify_global_avg_pool2d_grad(x_shape):
    x = relay.var("x", relay.TensorType(x_shape, "float32"))
    y = tvm.relay.nn.global_avg_pool2d(x)

    fwd_func = relay.Function([x], y)
    fwd_func = run_infer_type(fwd_func)
    bwd_func = run_infer_type(gradient(fwd_func))

    data = np.random.rand(*x_shape).astype("float32")
    y_shape = topi.util.get_const_tuple(fwd_func.ret_type.shape)
    out_grad = np.ones(shape=y_shape)
    ref_grad = topi.testing.pool_grad_nchw(data, out_grad, pool_size=(x_shape[2], x_shape[3]),
                                           strides=(1, 1), padding=[0, 0, 0, 0], pool_type='avg',
                                           ceil_mode=False)

    for target, ctx in ctx_list():
        intrp = relay.create_executor(ctx=ctx, target=target)
        op_res, (op_grad, ) = intrp.evaluate(bwd_func)(data)
        np.testing.assert_allclose(op_grad.asnumpy(), ref_grad, rtol=0.01)

def test_global_avg_pool2d_grad():
    verify_global_avg_pool2d_grad((1, 4, 16, 16))
    verify_global_avg_pool2d_grad((1, 8, 8, 24))

def verify_conv2d_grad(dshape, wshape, strides, padding, dilation, groups=1, mode='higher_order'):
    try:
        import torch
@@ -155,6 +177,7 @@ def test_batch_flatten_grad():
if __name__ == "__main__":
    test_max_pool2d_grad()
    test_avg_pool2d_grad()
    test_global_avg_pool2d_grad()
    test_conv2d_grad()
    test_dense_grad()
    test_batch_flatten_grad()
tests/python/relay/test_op_grad_level4.py (11 additions, 8 deletions)

@@ -29,18 +29,21 @@ def test_sum_grad():
    verify_sum_grad((4, 2))
    verify_sum_grad((4, 2), axis=-1, keepdims=True)
    verify_sum_grad((4, 2, 1), axis=(1, 2), exclude=True)
    verify_sum_grad((4, 2, 1), axis=1)


-def test_max_grad():
-    s = (10, 10)
-    t = relay.TensorType(s)
-    x = relay.var("x", t)
-    axis = 0
-    z = relay.max(x, axis)
-
-    fwd_func = relay.Function([x], z)
def verify_max_grad(d_shape, axis=None, keepdims=False, exclude=False):
    data = relay.var("data", relay.TensorType(d_shape, "float32"))
    fwd_func = relay.Function([data], relay.max(data, axis=axis, keepdims=keepdims, exclude=exclude))
    check_grad(fwd_func, scale=1e-3)


def test_max_grad():
    verify_max_grad((10, 10), axis=None)
    verify_max_grad((10, 10), axis=-1)
    verify_max_grad((6, 3, 2), axis=(1, 2), keepdims=True)
    verify_max_grad((5, 4, 3), axis=(0, 2), exclude=True)


if __name__ == "__main__":
    pytest.main()