46 changes: 42 additions & 4 deletions python/tvm/topi/nn/batch_norm.py
@@ -16,6 +16,7 @@
# under the License.
"""Batch normalization."""
import typing
from functools import reduce

from tvm import te
from tvm import topi
@@ -31,6 +32,8 @@ def batch_norm(
epsilon: typing.Optional[float] = None,
center: typing.Optional[bool] = None,
scale: typing.Optional[bool] = None,
training: typing.Optional[bool] = None,
momentum: typing.Optional[float] = None,
) -> typing.List[te.Tensor]:
"""Batch normalization layer (Ioffe and Szegedy, 2014).

@@ -69,6 +72,13 @@
If True, scale normalized tensor by gamma. If False, gamma
is ignored.

    training : bool, optional, default=False
        Indicates whether it is in training mode. If True, moving_mean
        and moving_var are updated.

momentum : float, optional, default=0.1
The value used for the moving_mean and moving_var update.

Returns
-------
output : list of tvm.te.Tensor
@@ -92,19 +102,47 @@
if scale is None:
scale = True

if training is None:
training = False

if momentum is None:
momentum = 0.1

shape = [1] * len(data.shape)
shape[axis] = data.shape[axis]

moving_mean_rs = topi.reshape(moving_mean, shape)
moving_var_rs = topi.reshape(moving_var, shape)

out = (data - moving_mean_rs) / topi.math.sqrt(moving_var_rs + epsilon)
if training:
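        # Per-channel batch statistics: reduce over every axis except the channel axis.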
reduce_axes = list(range(len(data.shape)))
reduce_axes.remove(axis)
shape_prod = reduce(lambda x, y: x * y, [data.shape[ax] for ax in reduce_axes], 1)
data_mean = topi.sum(data, axis=reduce_axes) / shape_prod
data_mean_rs = topi.reshape(data_mean, shape)
data_var = (
topi.sum((data - data_mean_rs) * (data - data_mean_rs), axis=reduce_axes) / shape_prod
)
data_var_rs = topi.reshape(data_var, shape)
out = (data - data_mean_rs) / topi.math.sqrt(data_var_rs + epsilon)
else:
moving_mean_rs = topi.reshape(moving_mean, shape)
moving_var_rs = topi.reshape(moving_var, shape)
out = (data - moving_mean_rs) / topi.math.sqrt(moving_var_rs + epsilon)

if scale:
out = out * topi.reshape(gamma, shape)
if center:
out = out + topi.reshape(beta, shape)

    if training:
        assert 0 <= momentum <= 1, "the valid momentum range is [0, 1]."
return [
out,
(1 - momentum) * moving_mean + momentum * data_mean,
(1 - momentum) * moving_var + momentum * data_var,
]

    # Moving mean and var aren't updated at inference time. To avoid
    # placeholder reuse, we multiply by 1 and return them.
return [out, moving_mean * 1, moving_var * 1]
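
A minimal usage sketch of the extended op may help orient reviewers (the NCHW shapes, `axis=1`, and hyperparameter values below are illustrative assumptions, not taken from this PR). With `training=True`, the output is normalized with the batch statistics, and the moving statistics come back blended as `new = (1 - momentum) * old + momentum * batch`:

```python
from tvm import te, topi

# Illustrative NCHW input; the channel axis is 1.
data = te.placeholder((2, 3, 4, 4), name="data", dtype="float32")
gamma = te.placeholder((3,), name="gamma", dtype="float32")
beta = te.placeholder((3,), name="beta", dtype="float32")
moving_mean = te.placeholder((3,), name="moving_mean", dtype="float32")
moving_var = te.placeholder((3,), name="moving_var", dtype="float32")

# training=True normalizes with batch statistics and returns the
# EMA-updated moving_mean / moving_var alongside the output tensor.
out, new_mean, new_var = topi.nn.batch_norm(
    data, gamma, beta, moving_mean, moving_var,
    axis=1, epsilon=1e-5, center=True, scale=True,
    training=True, momentum=0.1,
)
```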
33 changes: 29 additions & 4 deletions python/tvm/topi/testing/batch_norm.py
@@ -28,6 +28,8 @@ def batch_norm(
epsilon: float,
center: bool,
scale: bool,
training: bool,
momentum: float,
):
"""Batch Normalization operator implemented in Numpy.

@@ -62,6 +64,13 @@
If True, scale normalized tensor by gamma. If False, gamma
is ignored.

    training : bool
        Indicates whether it is in training mode. If True, moving_mean
        and moving_var are updated.

momentum : float
        The value used for the moving_mean and moving_var update.

Returns
-------
output : np.ndarray
@@ -76,14 +85,30 @@
shape = [1] * len(x.shape)
shape[axis] = x.shape[axis]

moving_mean_rs = moving_mean.reshape(shape)
moving_var_rs = moving_var.reshape(shape)

out = (x - moving_mean_rs) / np.sqrt(moving_var_rs + epsilon)
if training:
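        # Per-channel batch statistics; np.var uses the biased (ddof=0) estimator.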
reduce_axes = list(range(len(x.shape)))
reduce_axes.remove(axis)
reduce_axes = tuple(reduce_axes)
data_mean = np.mean(x, axis=reduce_axes)
data_var = np.var(x, axis=reduce_axes)
data_mean_rs = np.reshape(data_mean, shape)
data_var_rs = np.reshape(data_var, shape)
out = (x - data_mean_rs) / np.sqrt(data_var_rs + epsilon)
else:
moving_mean_rs = moving_mean.reshape(shape)
moving_var_rs = moving_var.reshape(shape)
out = (x - moving_mean_rs) / np.sqrt(moving_var_rs + epsilon)

if scale:
out = out * gamma.reshape(shape)
if center:
out = out + beta.reshape(shape)

if training:
return [
out,
(1 - momentum) * moving_mean + momentum * data_mean,
(1 - momentum) * moving_var + momentum * data_var,
]

return [out, moving_mean, moving_var]
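
The update rule in this NumPy reference can be sanity-checked directly (the shapes and the momentum of 0.1 below are arbitrary, assumed for illustration): the returned statistics should equal the exponential moving average of the old values and the batch statistics.

```python
import numpy as np
import tvm.topi.testing

x = np.random.rand(2, 3, 4, 4).astype("float32")
gamma = np.ones(3, "float32")
beta = np.zeros(3, "float32")
mean = np.zeros(3, "float32")
var = np.ones(3, "float32")

out, new_mean, new_var = tvm.topi.testing.batch_norm(
    x, gamma, beta, mean, var, 1, 1e-5, True, True, True, 0.1
)

# new = (1 - momentum) * old + momentum * batch, per channel.
axes = (0, 2, 3)
np.testing.assert_allclose(new_mean, 0.9 * mean + 0.1 * x.mean(axis=axes), rtol=1e-5)
np.testing.assert_allclose(new_var, 0.9 * var + 0.1 * x.var(axis=axes), rtol=1e-5)
```

Note that `np.var` defaults to the biased estimator, which matches the `shape_prod` divisor used in the TE implementation above.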
46 changes: 33 additions & 13 deletions tests/python/topi/python/test_topi_batch_norm.py
@@ -33,27 +33,37 @@


@pytest.mark.parametrize(
"shape, axis, epsilon, center, scale",
"shape, axis, epsilon, center, scale, training, momentum",
[
((1,), 0, 0.1, True, True),
((2, 3), 0, 0.1, True, True),
((1, 2, 4), 0, 0.1, True, True),
((1, 2, 3, 4), 0, 0.001, False, False),
((2, 3, 4, 1), 1, 0.01, False, True),
((3, 4, 1, 2), 2, 0.1, True, False),
((4, 1, 2, 3), 3, 1.0, True, True),
((1, 2, 4, 4, 5), 0, 0.1, True, True),
((1,), 0, 0.1, True, True, False, 0.1),
((2, 3), 0, 0.1, True, True, False, 0.1),
((1, 2, 4), 0, 0.1, True, True, False, 0.1),
((1, 2, 3, 4), 0, 0.001, False, False, False, 0.1),
((2, 3, 4, 1), 1, 0.01, False, True, False, 0.1),
((3, 4, 1, 2), 2, 0.1, True, False, True, 0.1),
((4, 1, 2, 3), 3, 1.0, True, True, True, 0.2),
((1, 2, 4, 4, 5), 0, 0.1, True, True, True, 0.3),
],
)
def test_batch_norm(shape, axis, epsilon, center, scale):
def test_batch_norm(shape, axis, epsilon, center, scale, training, momentum):
x_np = np.random.random(shape).astype("float32")
gamma_np = np.random.random(shape[axis]).astype("float32")
beta_np = np.random.random(shape[axis]).astype("float32")
moving_mean_np = np.random.random(shape[axis]).astype("float32")
moving_var_np = np.random.random(shape[axis]).astype("float32")

out_x_np, out_moving_mean_np, out_moving_var_np = tvm.topi.testing.batch_norm(
x_np, gamma_np, beta_np, moving_mean_np, moving_var_np, axis, epsilon, center, scale
x_np,
gamma_np,
beta_np,
moving_mean_np,
moving_var_np,
axis,
epsilon,
center,
scale,
training,
momentum,
)

x_te = te.placeholder(shape, name="x", dtype="float32")
@@ -65,7 +75,17 @@ def test_batch_norm(shape, axis, epsilon, center, scale):
with tvm.target.Target(_DEVICE):
fcompute, fschedule = tvm.topi.testing.dispatch(_DEVICE, _BATCH_NORM_IMPLEMENT)
out_x, out_moving_mean, out_moving_var = fcompute(
x_te, gamma_te, beta_te, moving_mean_te, moving_var_te, axis, epsilon, center, scale
x_te,
gamma_te,
beta_te,
moving_mean_te,
moving_var_te,
axis,
epsilon,
center,
scale,
training,
momentum,
)
s = fschedule([out_x, out_moving_mean, out_moving_var])

@@ -113,4 +133,4 @@ def test_batch_norm(shape, axis, epsilon, center, scale):


if __name__ == "__main__":
test_batch_norm()
tvm.testing.main()