From 655d74a05f8a8b85d72c6db34c96c19f55c102a6 Mon Sep 17 00:00:00 2001 From: jheek Date: Fri, 27 Nov 2020 09:35:57 +0000 Subject: [PATCH 1/3] Add call override for use_running_average --- CHANGELOG.md | 3 +- docs/design_notes/arguments.md | 91 +++++++++++++++++++++++++++++ examples/pixelcnn/model_test.py | 2 +- examples/pixelcnn/pixelcnn.py | 14 +++-- examples/pixelcnn/train.py | 7 ++- flax/linen/attention.py | 24 +++++--- flax/linen/module.py | 36 ++++++++++++ flax/linen/normalization.py | 12 ++-- flax/linen/stochastic.py | 15 ++++- tests/linen/linen_attention_test.py | 7 ++- tests/linen/linen_test.py | 2 +- 11 files changed, 185 insertions(+), 28 deletions(-) create mode 100644 docs/design_notes/arguments.md diff --git a/CHANGELOG.md b/CHANGELOG.md index 99613d3ef..6e3430710 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,7 +15,8 @@ vNext - - - - - + - Some Module arguments should now be passed either as dataclass attribute or + as argument to `__call__`. See [design note](https://flax.readthedocs.io/en/latest/design_notes/arguments.html) - - - diff --git a/docs/design_notes/arguments.md b/docs/design_notes/arguments.md new file mode 100644 index 000000000..66944f5e8 --- /dev/null +++ b/docs/design_notes/arguments.md @@ -0,0 +1,91 @@ +# Dealing with Module Arguments + +## Introduction + +In Linen we can define `Module` arguments either as dataclass attributes or as arguments to methods (usually `__call__`). +Typically the distinction is clear: +* Completely fixed properties, such as the choice of kernel initializer or number of output features, are hyperparameters and should be defined as dataclass attributes. Typically two Module instances with different hyperparamaters cannot share in a meaningfull way. +* Dynamic properties, such as input data and top-level "mode switches" like `train=True/False` that should be passed as arguments to `__call__` or another method. + +Some cases are however ambigious. Take for example the `Dropout` module. +We have a number of clear hyperparameters: +1. The dropout rate +2. The axes for which a dropout mask is generated + +And some clear call time arguments: +1. The input that should be masked using dropout +2. The (optional) rng used to sample the random mask + +There is however one property that is ambigious -- the `deterministic` property in a Dropout module. + +If `deterministic` is `True` no dropout mask is sampled. This is typically used during model evaluation. +However, if we pass `eval=True` or `train=False` to a top-level Module. The `deterministic` argument needs +to be applied everywhere and the boolean argument needs to be passed down to all the layers that might use `Dropout`. +If instead `deterministic` is a dataclass attribute, we might do the following: + +```python +from functools import partial +from flax import linen as nn + +class ResidualModel(nn.Module): + drop_rate: float + + @nn.compact + def __call__(self, x, *, train): + dropout = partial(nn.Dropout, rate=self.drop_rate, deterministic=not train) + for i in range(10): + x += ResidualBlock(dropout=dropout, ...)(x) +``` + +It makes sense to pass `determinstic` to the constructor here because this way we can pass the dropout template to the sub-modules. +Now the sub-module no longer needs to take care of train vs eval mode and can simply use the `dropout` argument. +Note that because the dropout layer can only be constructed in the sub-module we can only partially `deterministic` to the constructor but not to `__call__`. + +However, if `deterministic` is a dataclass attribute we run into trouble when using the setup pattern. We would **want** to write our module code like this: + +```python +class SomeModule(nn.Module): + drop_rate: float + + def setup(self): + self.dropout = nn.Dropout(rate=self.drop_rate) + + @nn.compact + def __call__(self, x, *, train): + # ... + x = self.dropout(x, deterministic=not train) + # ... +``` + +But as defined above `deterministic` would be an attribute, so this doesn't work. +Here it makes sense to pass `deterministic` during `__call__` because it depends on the `train` argument. + +## Solution + +We can support both use cases described before by allowing certain properties to be passed +as dataclass attributes or as method argument (but not both!). +This can be implemented as follows: +```python +class MyDropout(nn.Module): + drop_rate: float + deterministic: Optional[bool] = None + + @nn.compact + def __call__(self, x, deterministic=None): + deterministic = nn.merge_param('deterministic', self.deterministic, deterministic) + # ... +``` + +In this example `nn.merge_param` will ensure that either `self.deterministic` or `deterministic` is set but not both. +An error is raised if both values are `None` or both values are not `None`. +This avoids confusing behavior where 2 different parts of the code set the same parameter and one is overruled by the other. +It also avoids a default value which would problably cause either the train step or eval step of a training procedure to be broken by default. + + + +## Functional Core and flax.nn + +The old NN api and functional core define functions rather than classes. +Therefore, there is no clear distinction between hyper parameters and call time arguments. +The only way to pre-determine the hyper parameters is by using `partial`. +On the upside, there are no ambiguous cases where method arguments could also be attributes. diff --git a/examples/pixelcnn/model_test.py b/examples/pixelcnn/model_test.py index ec3b7e9a6..be59e00fe 100644 --- a/examples/pixelcnn/model_test.py +++ b/examples/pixelcnn/model_test.py @@ -126,7 +126,7 @@ def test_conv_transpose_down_right(self): def test_pcnn_shape(self): x = random.normal(self.rng, (2, 4, 4, 3)) model = pixelcnn.PixelCNNPP(depth=0, features=2, dropout_p=0) - out, _ = model.init_with_output(self.rng, x) + out, _ = model.init_with_output(self.rng, x, train=False) self.assertEqual(out.shape, (2, 4, 4, 100)) diff --git a/examples/pixelcnn/pixelcnn.py b/examples/pixelcnn/pixelcnn.py index 2750166a9..4dcc803c2 100644 --- a/examples/pixelcnn/pixelcnn.py +++ b/examples/pixelcnn/pixelcnn.py @@ -57,14 +57,17 @@ class PixelCNNPP(nn.Module): dropout_p: float = 0.5 @nn.compact - def __call__(self, images): + def __call__(self, images, *, train): # Special convolutional and resnet blocks which allow information flow # downwards and to the right. conv_down = partial(ConvDown, features=self.features) conv_down_right = partial(ConvDownRight, features=self.features) - res_down = partial(ResDown, dropout_p=self.dropout_p) - res_down_right = partial(ResDownRight, dropout_p=self.dropout_p) + dropout = partial( + nn.Dropout, rate=self.dropout_p, deterministic=not train) + + res_down = partial(ResDown, dropout=dropout) + res_down_right = partial(ResDownRight, dropout=dropout) # Conv Modules which halve or double the spatial dimensions halve_down = partial(conv_down, strides=(2, 2)) @@ -301,8 +304,8 @@ def __call__(self, inputs): # Resnet modules class GatedResnet(nn.Module): conv_module: Callable[..., Any] + dropout: Callable[..., Any] nonlinearity: Callable[..., Any] = concat_elu - dropout_p: float = 0. @nn.compact def __call__(self, inputs, aux=None): @@ -311,8 +314,7 @@ def __call__(self, inputs, aux=None): if aux is not None: y = self.nonlinearity(y + ConvOneByOne(c)(self.nonlinearity(aux))) - if self.dropout_p > 0: - y = nn.Dropout(rate=self.dropout_p)(y) + y = self.dropout()(y) # Set init_scale=0.1 so that the res block is close to the identity at # initialization. diff --git a/examples/pixelcnn/train.py b/examples/pixelcnn/train.py index 39cfd89a4..b506521cf 100644 --- a/examples/pixelcnn/train.py +++ b/examples/pixelcnn/train.py @@ -85,7 +85,8 @@ def loss_fn(params): config, dropout_p=config.dropout_rate).apply({'params': params}, batch['image'], - rngs={'dropout': dropout_rng}) + rngs={'dropout': dropout_rng}, + train=True) return neg_log_likelihood_loss(pcnn_out, batch['image']) lr = learning_rate_fn(optimizer.state.step) @@ -105,7 +106,7 @@ def loss_fn(params): def eval_step(config, params, batch): images = batch['image'] - pcnn_out = model(config, dropout_p=0.).apply({'params': params}, images) + pcnn_out = model(config).apply({'params': params}, images, train=False) return { 'loss': jax.lax.pmean(neg_log_likelihood_loss(pcnn_out, images), 'batch') } @@ -173,7 +174,7 @@ def train_and_evaluate(config: ml_collections.ConfigDict, workdir: str): { 'params': init_rng, 'dropout': dropout_rng - }, init_batch)['params'] + }, init_batch, train=False)['params'] optimizer_def = optim.Adam(beta1=0.95, beta2=0.9995) optimizer = optimizer_def.create(initial_variables) diff --git a/flax/linen/attention.py b/flax/linen/attention.py index 05be68230..2ac7384d7 100644 --- a/flax/linen/attention.py +++ b/flax/linen/attention.py @@ -25,7 +25,7 @@ from flax.linen.linear import default_kernel_init from flax.linen.linear import DenseGeneral -from flax.linen.module import Module, compact +from flax.linen.module import Module, compact, merge_param from flax.linen.initializers import zeros PRNGKey = Any @@ -125,7 +125,9 @@ class MultiHeadDotProductAttention(Module): out_features: dimension of the last projection broadcast_dropout: bool: use a broadcasted dropout along batch dims. dropout_rate: dropout rate - deterministic: bool, deterministic or not (to apply dropout) + deterministic: if false, the attention weight is masked randomly + using dropout, whereas if true, the attention weights + are deterministic. precision: numerical precision of the computation see `jax.lax.Precision` for details. kernel_init: initializer for the kernel of the Dense layers. @@ -142,7 +144,7 @@ class MultiHeadDotProductAttention(Module): out_features: Optional[int] = None broadcast_dropout: bool = True dropout_rate: float = 0. - deterministic: bool = False + deterministic: Optional[bool] = None precision: Any = None kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = default_kernel_init bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = zeros @@ -154,7 +156,8 @@ class MultiHeadDotProductAttention(Module): def __call__(self, inputs_q: Array, inputs_kv: Array, - mask: Optional[Array] = None): + mask: Optional[Array] = None, + deterministic: Optional[bool] = None): """Applies multi-head dot product attention on the input data. Projects the inputs into multi-headed query, key, and value vectors, @@ -167,10 +170,14 @@ def __call__(self, `[batch_sizes..., length, features]`. mask: attention mask of shape `[batch_sizes..., num_heads, query_length, key/value_length]`. + deterministic: if false, the attention weight is masked randomly + using dropout, whereas if true, the attention weights + are deterministic. Returns: output of shape `[batch_sizes..., length, features]`. """ + deterministic = merge_param('deterministic', self.deterministic, deterministic) features = self.out_features or inputs_q.shape[-1] qkv_features = self.qkv_features or inputs_q.shape[-1] assert qkv_features % self.num_heads == 0, ( @@ -238,7 +245,7 @@ def __call__(self, attention_bias = None dropout_rng = None - if not self.deterministic and self.dropout_rate > 0.: + if not deterministic and self.dropout_rate > 0.: dropout_rng = self.make_rng('dropout') # apply attention @@ -250,7 +257,7 @@ def __call__(self, dropout_rng=dropout_rng, dropout_rate=self.dropout_rate, broadcast_dropout=self.broadcast_dropout, - deterministic=self.deterministic, + deterministic=deterministic, dtype=self.dtype, precision=self.precision) @@ -270,8 +277,9 @@ class SelfAttention(MultiHeadDotProductAttention): """Self-attention special case of multi-head dot-product attention.""" @compact - def __call__(self, inputs_q: Array, mask: Optional[Array] = None): - return super().__call__(inputs_q, inputs_q, mask) + def __call__(self, inputs_q: Array, mask: Optional[Array] = None, + deterministic: Optional[bool] = None): + return super().__call__(inputs_q, inputs_q, mask, deterministic=deterministic) # mask-making utility functions diff --git a/flax/linen/module.py b/flax/linen/module.py index e98ced6b2..b30f7c7af 100644 --- a/flax/linen/module.py +++ b/flax/linen/module.py @@ -814,6 +814,42 @@ def variables(self) -> VariableDict: return self.scope.variables() +T = TypeVar('T') +def merge_param(name: str, a: Optional[T], b: Optional[T]) -> T: + """Merges construction and call time argument. + + This is a utility for supporting the pattern where a Module hyper parameter + can be passed to `__init__` or `__call__`. + + Example:: + + class Foo(nn.Module): + train: Optional[bool] = None + + def __call__(self, train: Optional[bool] = None): + train = nn.merge_param('train', self.train, train) + + An error is thrown when both arguments are `None` or both values are not `None`. + + Args: + name: the name of the parameter. Used for error messages. + a: option a + b: option b + Returns: + a or b whichever is not `None`. + + """ + if a is None and b is None: + raise ValueError(f'Parameter "{name}" must be passed to the constructor or at call time.') + if a is not None and b is not None: + raise ValueError(f'Parameter "{name}" was passed to the constructor and at call time.' + ' Should be passed just once.') + if a is None: + return b + else: + return a + + # @contextmanager # def mutate(self, mutable=True, **updates): # cloned = self.clone(**updates) diff --git a/flax/linen/normalization.py b/flax/linen/normalization.py index 158058db1..e0b068616 100644 --- a/flax/linen/normalization.py +++ b/flax/linen/normalization.py @@ -20,7 +20,7 @@ from jax.nn import initializers import jax.numpy as jnp -from flax.linen.module import Module, compact +from flax.linen.module import Module, compact, merge_param PRNGKey = Any @@ -60,7 +60,7 @@ class BatchNorm(Module): the examples on the first two and last two devices. See `jax.lax.psum` for more details. """ - use_running_average: bool = False + use_running_average: Optional[bool] = None axis: int = -1 momentum: float = 0.99 epsilon: float = 1e-5 @@ -73,15 +73,19 @@ class BatchNorm(Module): axis_index_groups: Any = None @compact - def __call__(self, x): + def __call__(self, x, use_running_average: Optional[bool] = None): """Normalizes the input using batch statistics. Args: x: the input to be normalized. + use_running_average: if true, the statistics stored in batch_stats + will be used instead of computing the batch statistics on the input. Returns: Normalized inputs (the same shape as inputs). """ + use_running_average = merge_param( + 'use_running_average', self.use_running_average, use_running_average) x = jnp.asarray(x, jnp.float32) axis = self.axis if isinstance(self.axis, tuple) else (self.axis,) axis = _absolute_dims(x.ndim, axis) @@ -99,7 +103,7 @@ def __call__(self, x): lambda s: jnp.ones(s, jnp.float32), reduced_feature_shape) - if self.use_running_average: + if use_running_average: mean, var = ra_mean.value, ra_var.value else: mean = jnp.mean(x, axis=reduction_axis, keepdims=False) diff --git a/flax/linen/stochastic.py b/flax/linen/stochastic.py index 8c8c9fbf9..8d7a18797 100644 --- a/flax/linen/stochastic.py +++ b/flax/linen/stochastic.py @@ -15,11 +15,13 @@ """Stochastic modules. """ +from typing import Iterable, Optional + from jax import lax from jax import random import jax.numpy as jnp -from flax.linen.module import Module, compact +from flax.linen.module import Module, compact, merge_param class Dropout(Module): @@ -27,11 +29,17 @@ class Dropout(Module): Attributes: rate: the dropout probability. (_not_ the keep rate!) + broadcast_dims: dimensions that will share the same dropout mask + deterministic: if false the inputs are scaled by `1 / (1 - rate)` and + masked, whereas if true, no mask is applied and the inputs are returned as + is. """ rate: float + broadcast_dims: Iterable[int] = () + deterministic: Optional[bool] = None @compact - def __call__(self, inputs, deterministic=False, rng=None, broadcast_dims=()): + def __call__(self, inputs, deterministic: Optional[bool] = None, rng=None): """Applies a random dropout mask to the input. Args: @@ -47,6 +55,7 @@ def __call__(self, inputs, deterministic=False, rng=None, broadcast_dims=()): Returns: The masked inputs reweighted to preserve mean. """ + deterministic = merge_param('deterministic', self.deterministic, deterministic) if self.rate == 0.: return inputs keep_prob = 1. - self.rate @@ -56,7 +65,7 @@ def __call__(self, inputs, deterministic=False, rng=None, broadcast_dims=()): if rng is None: rng = self.make_rng('dropout') broadcast_shape = list(inputs.shape) - for dim in broadcast_dims: + for dim in self.broadcast_dims: broadcast_shape[dim] = 1 mask = random.bernoulli(rng, p=keep_prob, shape=broadcast_shape) mask = jnp.broadcast_to(mask, inputs.shape) diff --git a/tests/linen/linen_attention_test.py b/tests/linen/linen_attention_test.py index 3e23b707f..2a6e70465 100644 --- a/tests/linen/linen_attention_test.py +++ b/tests/linen/linen_attention_test.py @@ -44,6 +44,7 @@ def test_multihead_self_attention(self): qkv_features=16, kernel_init=initializers.ones, bias_init=initializers.zeros, + deterministic=False, ) y, _ = sa_module.init_with_output(rng, x) self.assertEqual(y.shape, x.shape) @@ -57,6 +58,7 @@ def test_multihead_encoder_decoder_attention(self): qkv_features=16, kernel_init=initializers.ones, bias_init=initializers.zeros, + deterministic=False, ) y, _ = sa_module.init_with_output(rng, q, kv) self.assertEqual(y.shape, q.shape) @@ -70,6 +72,7 @@ def test_multihead_self_attention_w_dropout(self): kernel_init=initializers.ones, bias_init=initializers.zeros, dropout_rate=0.1, + deterministic=False, ) rng1, rng2 = random.split(rng) rngs = {'params': rng1, 'dropout': rng2} @@ -98,6 +101,7 @@ def test_decoding(self, spatial_shape, attn_dims): num_heads=num_heads, qkv_features=num_heads * num_features, precision=lax.Precision.HIGHEST, + deterministic=False, decode=False) decode_module = module.clone(decode=True) @@ -130,7 +134,8 @@ def test_autoregresive_receptive_field_1d(self): module = nn.MultiHeadDotProductAttention( num_heads=num_heads, - kernel_init=jax.nn.initializers.ones) + kernel_init=jax.nn.initializers.ones, + deterministic=False) initial_vars = module.init(rng1, inputs, inputs) causal_mask = nn.attention.make_causal_mask(jnp.ones(input_shape[:-1])) diff --git a/tests/linen/linen_test.py b/tests/linen/linen_test.py index 0ff1d183c..cbac1975c 100644 --- a/tests/linen/linen_test.py +++ b/tests/linen/linen_test.py @@ -91,7 +91,7 @@ def test_batch_norm(self): rng = random.PRNGKey(0) key1, key2 = random.split(rng) x = random.normal(key1, (4, 3, 2)) - model_cls = nn.BatchNorm(momentum=0.9) + model_cls = nn.BatchNorm(momentum=0.9, use_running_average=False) y, initial_params = model_cls.init_with_output(key2, x) mean = y.mean((0, 1)) From 06dbc5cedef4039cab0688ba36057dbe4e867ed8 Mon Sep 17 00:00:00 2001 From: Jonathan Heek Date: Thu, 4 Feb 2021 15:22:24 +0000 Subject: [PATCH 2/3] Add breaking changes to CHANGELOG --- CHANGELOG.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6e3430710..1516c634f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,8 +17,8 @@ vNext - - Some Module arguments should now be passed either as dataclass attribute or as argument to `__call__`. See [design note](https://flax.readthedocs.io/en/latest/design_notes/arguments.html) - - - - + - `use_running_average` and `deterministic` no longer have a default. They should be passed explicitly + - `broadcast_dims` is now a attribute to `Dropout` instead of a `__call__` argument. - - - From 289880a51956abdf26d438d6fbc125494970e544 Mon Sep 17 00:00:00 2001 From: jheek Date: Thu, 4 Feb 2021 16:24:18 +0100 Subject: [PATCH 3/3] Apply suggestions from code review Co-authored-by: Avital Oliver --- CHANGELOG.md | 3 +-- docs/design_notes/arguments.md | 4 ++-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1516c634f..ea571e7e0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,7 +15,7 @@ vNext - - - - - Some Module arguments should now be passed either as dataclass attribute or + - Some Module arguments can now be passed either as dataclass attribute or as argument to `__call__`. See [design note](https://flax.readthedocs.io/en/latest/design_notes/arguments.html) - `use_running_average` and `deterministic` no longer have a default. They should be passed explicitly - `broadcast_dims` is now a attribute to `Dropout` instead of a `__call__` argument. @@ -96,4 +96,3 @@ v0.1 `rhs_dilation` -> `kernel_dilation` - Change default layer names from numbers '0', '1', etc. to include the Module class name, e.g. 'Dense_0', 'LayerNorm_1'. - diff --git a/docs/design_notes/arguments.md b/docs/design_notes/arguments.md index 66944f5e8..037575f74 100644 --- a/docs/design_notes/arguments.md +++ b/docs/design_notes/arguments.md @@ -4,7 +4,7 @@ In Linen we can define `Module` arguments either as dataclass attributes or as arguments to methods (usually `__call__`). Typically the distinction is clear: -* Completely fixed properties, such as the choice of kernel initializer or number of output features, are hyperparameters and should be defined as dataclass attributes. Typically two Module instances with different hyperparamaters cannot share in a meaningfull way. +* Completely fixed properties, such as the choice of kernel initializer or number of output features, are hyperparameters and should be defined as dataclass attributes. Typically two Module instances with different hyperparamaters cannot share in a meaningful way. * Dynamic properties, such as input data and top-level "mode switches" like `train=True/False` that should be passed as arguments to `__call__` or another method. Some cases are however ambigious. Take for example the `Dropout` module. @@ -39,7 +39,7 @@ class ResidualModel(nn.Module): It makes sense to pass `determinstic` to the constructor here because this way we can pass the dropout template to the sub-modules. Now the sub-module no longer needs to take care of train vs eval mode and can simply use the `dropout` argument. -Note that because the dropout layer can only be constructed in the sub-module we can only partially `deterministic` to the constructor but not to `__call__`. +Note that because the dropout layer can only be constructed in the sub-module we can only partially apply `deterministic` to the constructor but not to `__call__`. However, if `deterministic` is a dataclass attribute we run into trouble when using the setup pattern. We would **want** to write our module code like this: