diff --git a/python/tvm/topi/nn/batch_norm.py b/python/tvm/topi/nn/batch_norm.py
index 1b4fad762568..3181efd7daa6 100644
--- a/python/tvm/topi/nn/batch_norm.py
+++ b/python/tvm/topi/nn/batch_norm.py
@@ -16,6 +16,7 @@
 # under the License.
 """Batch normalization."""
 import typing
+from functools import reduce
 
 from tvm import te
 from tvm import topi
@@ -31,6 +32,8 @@ def batch_norm(
     epsilon: typing.Optional[float] = None,
     center: typing.Optional[bool] = None,
     scale: typing.Optional[bool] = None,
+    training: typing.Optional[bool] = None,
+    momentum: typing.Optional[float] = None,
 ) -> typing.List[te.Tensor]:
     """Batch normalization layer (Ioffe and Szegedy, 2014).
 
@@ -69,6 +72,13 @@
         If True, scale normalized tensor by gamma. If False, gamma
         is ignored.
 
+    training : bool, optional, default=False
+        Indicating whether it is in training mode. If True, update
+        moving_mean and moving_var.
+
+    momentum : float, optional, default=0.1
+        The value used for the moving_mean and moving_var update.
+
     Returns
     -------
     output : list of tvm.te.Tensor
@@ -92,19 +102,47 @@
     if scale is None:
         scale = True
 
+    if training is None:
+        training = False
+
+    if momentum is None:
+        momentum = 0.1
+
     shape = [1] * len(data.shape)
     shape[axis] = data.shape[axis]
 
-    moving_mean_rs = topi.reshape(moving_mean, shape)
-    moving_var_rs = topi.reshape(moving_var, shape)
-
-    out = (data - moving_mean_rs) / topi.math.sqrt(moving_var_rs + epsilon)
+    if training:
+        reduce_axes = list(range(len(data.shape)))
+        reduce_axes.remove(axis)
+        shape_prod = reduce(lambda x, y: x * y, [data.shape[ax] for ax in reduce_axes], 1)
+        data_mean = topi.sum(data, axis=reduce_axes) / shape_prod
+        data_mean_rs = topi.reshape(data_mean, shape)
+        data_var = (
+            topi.sum((data - data_mean_rs) * (data - data_mean_rs), axis=reduce_axes) / shape_prod
+        )
+        data_var_rs = topi.reshape(data_var, shape)
+        out = (data - data_mean_rs) / topi.math.sqrt(data_var_rs + epsilon)
+    else:
+        moving_mean_rs = topi.reshape(moving_mean, shape)
+        moving_var_rs = topi.reshape(moving_var, shape)
+        out = (data - moving_mean_rs) / topi.math.sqrt(moving_var_rs + epsilon)
 
     if scale:
         out = out * topi.reshape(gamma, shape)
     if center:
         out = out + topi.reshape(beta, shape)
 
+    if training:
+        assert 0 <= momentum <= 1, "the valid momentum range is [0, 1]."
+        data_var = (
+            topi.sum((data - data_mean_rs) * (data - data_mean_rs), axis=reduce_axes) / shape_prod
+        )
+        return [
+            out,
+            (1 - momentum) * moving_mean + momentum * data_mean,
+            (1 - momentum) * moving_var + momentum * data_var,
+        ]
+
     # Moving mean and var aren't updated during test. To avoid
     # placeholder reuse, we multiply by 1 and return them.
     return [out, moving_mean * 1, moving_var * 1]
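For reference, the training branch added above normalizes with batch statistics rather than the running statistics: it reduces over every axis except `axis`, divides by the element count `shape_prod` (the biased variance, not the `N - 1` estimator), and only then applies gamma/beta. A minimal NumPy sketch of the same math follows; the function name and default arguments are illustrative, not part of the patch:

    import numpy as np

    def batch_norm_train_sketch(x, gamma, beta, axis=1, epsilon=1e-5):
        # Reduce over every axis except the channel axis, mirroring
        # `reduce_axes` in the TE implementation above.
        reduce_axes = tuple(ax for ax in range(x.ndim) if ax != axis)
        shape = [1] * x.ndim
        shape[axis] = x.shape[axis]

        # Batch statistics; np.var is the biased estimator (divide by N),
        # which matches the sum-of-squares / shape_prod expression above.
        mean = x.mean(axis=reduce_axes).reshape(shape)
        var = x.var(axis=reduce_axes).reshape(shape)

        out = (x - mean) / np.sqrt(var + epsilon)
        return out * gamma.reshape(shape) + beta.reshape(shape)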
diff --git a/python/tvm/topi/testing/batch_norm.py b/python/tvm/topi/testing/batch_norm.py
index 0a79b6849d4e..d86249018aa0 100644
--- a/python/tvm/topi/testing/batch_norm.py
+++ b/python/tvm/topi/testing/batch_norm.py
@@ -28,6 +28,8 @@ def batch_norm(
     epsilon: float,
     center: bool,
     scale: bool,
+    training: bool,
+    momentum: float,
 ):
     """Batch Normalization operator implemented in Numpy.
 
@@ -62,6 +64,13 @@
         If True, scale normalized tensor by gamma. If False, gamma
         is ignored.
 
+    training : bool
+        Indicating whether it is in training mode. If True, update
+        moving_mean and moving_var.
+
+    momentum : float
+        The value used for the moving_mean and moving_var update.
+
     Returns
     -------
     output : np.ndarray
@@ -76,14 +85,30 @@
     shape = [1] * len(x.shape)
     shape[axis] = x.shape[axis]
 
-    moving_mean_rs = moving_mean.reshape(shape)
-    moving_var_rs = moving_var.reshape(shape)
-
-    out = (x - moving_mean_rs) / np.sqrt(moving_var_rs + epsilon)
+    if training:
+        reduce_axes = list(range(len(x.shape)))
+        reduce_axes.remove(axis)
+        reduce_axes = tuple(reduce_axes)
+        data_mean = np.mean(x, axis=reduce_axes)
+        data_var = np.var(x, axis=reduce_axes)
+        data_mean_rs = np.reshape(data_mean, shape)
+        data_var_rs = np.reshape(data_var, shape)
+        out = (x - data_mean_rs) / np.sqrt(data_var_rs + epsilon)
+    else:
+        moving_mean_rs = moving_mean.reshape(shape)
+        moving_var_rs = moving_var.reshape(shape)
+        out = (x - moving_mean_rs) / np.sqrt(moving_var_rs + epsilon)
 
     if scale:
         out = out * gamma.reshape(shape)
     if center:
         out = out + beta.reshape(shape)
 
+    if training:
+        return [
+            out,
+            (1 - momentum) * moving_mean + momentum * data_mean,
+            (1 - momentum) * moving_var + momentum * data_var,
+        ]
+
     return [out, moving_mean, moving_var]
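Both the TE and the NumPy reference implementations update the running statistics with the same convex blend, new = (1 - momentum) * old + momentum * batch_stat, so momentum controls how far the moving values shift toward the current batch on each step. A small self-contained illustration with made-up values:

    import numpy as np

    # With momentum = 0.1 the running statistics move 10% of the way
    # toward the current batch statistics on every training step.
    momentum = 0.1
    moving_mean = np.array([0.0, 1.0])
    data_mean = np.array([1.0, 3.0])

    updated = (1 - momentum) * moving_mean + momentum * data_mean
    print(updated)  # [0.1 1.2]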
diff --git a/tests/python/topi/python/test_topi_batch_norm.py b/tests/python/topi/python/test_topi_batch_norm.py
index f6c37ce6de3d..c7feb5d7c860 100644
--- a/tests/python/topi/python/test_topi_batch_norm.py
+++ b/tests/python/topi/python/test_topi_batch_norm.py
@@ -33,19 +33,19 @@
 @pytest.mark.parametrize(
-    "shape, axis, epsilon, center, scale",
+    "shape, axis, epsilon, center, scale, training, momentum",
     [
-        ((1,), 0, 0.1, True, True),
-        ((2, 3), 0, 0.1, True, True),
-        ((1, 2, 4), 0, 0.1, True, True),
-        ((1, 2, 3, 4), 0, 0.001, False, False),
-        ((2, 3, 4, 1), 1, 0.01, False, True),
-        ((3, 4, 1, 2), 2, 0.1, True, False),
-        ((4, 1, 2, 3), 3, 1.0, True, True),
-        ((1, 2, 4, 4, 5), 0, 0.1, True, True),
+        ((1,), 0, 0.1, True, True, False, 0.1),
+        ((2, 3), 0, 0.1, True, True, False, 0.1),
+        ((1, 2, 4), 0, 0.1, True, True, False, 0.1),
+        ((1, 2, 3, 4), 0, 0.001, False, False, False, 0.1),
+        ((2, 3, 4, 1), 1, 0.01, False, True, False, 0.1),
+        ((3, 4, 1, 2), 2, 0.1, True, False, True, 0.1),
+        ((4, 1, 2, 3), 3, 1.0, True, True, True, 0.2),
+        ((1, 2, 4, 4, 5), 0, 0.1, True, True, True, 0.3),
     ],
 )
-def test_batch_norm(shape, axis, epsilon, center, scale):
+def test_batch_norm(shape, axis, epsilon, center, scale, training, momentum):
     x_np = np.random.random(shape).astype("float32")
     gamma_np = np.random.random(shape[axis]).astype("float32")
     beta_np = np.random.random(shape[axis]).astype("float32")
@@ -53,7 +53,17 @@ def test_batch_norm(shape, axis, epsilon, center, scale):
     moving_var_np = np.random.random(shape[axis]).astype("float32")
 
     out_x_np, out_moving_mean_np, out_moving_var_np = tvm.topi.testing.batch_norm(
-        x_np, gamma_np, beta_np, moving_mean_np, moving_var_np, axis, epsilon, center, scale
+        x_np,
+        gamma_np,
+        beta_np,
+        moving_mean_np,
+        moving_var_np,
+        axis,
+        epsilon,
+        center,
+        scale,
+        training,
+        momentum,
     )
 
     x_te = te.placeholder(shape, name="x", dtype="float32")
@@ -65,7 +75,17 @@ def test_batch_norm(shape, axis, epsilon, center, scale):
     with tvm.target.Target(_DEVICE):
         fcompute, fschedule = tvm.topi.testing.dispatch(_DEVICE, _BATCH_NORM_IMPLEMENT)
         out_x, out_moving_mean, out_moving_var = fcompute(
-            x_te, gamma_te, beta_te, moving_mean_te, moving_var_te, axis, epsilon, center, scale
+            x_te,
+            gamma_te,
+            beta_te,
+            moving_mean_te,
+            moving_var_te,
+            axis,
+            epsilon,
+            center,
+            scale,
+            training,
+            momentum,
         )
         s = fschedule([out_x, out_moving_mean, out_moving_var])
@@ -113,4 +133,4 @@ def test_batch_norm(shape, axis, epsilon, center, scale):
 
 
 if __name__ == "__main__":
-    test_batch_norm()
+    tvm.testing.main()
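Outside the parametrized suite, the new training path can be exercised directly through the TE API. A sketch under stated assumptions: an LLVM target, illustrative NCHW shapes, and a naive (unscheduled) schedule; none of this is part of the patch itself:

    import tvm
    from tvm import te, topi

    shape, axis = (2, 3, 4, 4), 1
    x = te.placeholder(shape, name="x", dtype="float32")
    gamma = te.placeholder((shape[axis],), name="gamma", dtype="float32")
    beta = te.placeholder((shape[axis],), name="beta", dtype="float32")
    mean = te.placeholder((shape[axis],), name="moving_mean", dtype="float32")
    var = te.placeholder((shape[axis],), name="moving_var", dtype="float32")

    # training=True returns the normalized output plus the moving_mean
    # and moving_var blended toward the batch statistics by momentum.
    outs = topi.nn.batch_norm(
        x, gamma, beta, mean, var,
        axis=axis, epsilon=1e-5, center=True, scale=True,
        training=True, momentum=0.1,
    )
    s = te.create_schedule([t.op for t in outs])
    f = tvm.build(s, [x, gamma, beta, mean, var, *outs], "llvm")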