Merged
8 changes: 4 additions & 4 deletions CHANGELOG.md
@@ -15,9 +15,10 @@ vNext
-
-
-
-
-
-
- Some Module arguments can now be passed either as a dataclass attribute or
  as an argument to `__call__`. See the [design note](https://flax.readthedocs.io/en/latest/design_notes/arguments.html).
- `use_running_average` and `deterministic` no longer have a default. They should be passed explicitly.
- `broadcast_dims` is now an attribute of `Dropout` instead of a `__call__` argument.
-
-
-
@@ -95,4 +96,3 @@ v0.1
`rhs_dilation` -> `kernel_dilation`
- Change default layer names from numbers '0', '1', etc. to
include the Module class name, e.g. 'Dense_0', 'LayerNorm_1'.

91 changes: 91 additions & 0 deletions docs/design_notes/arguments.md
@@ -0,0 +1,91 @@
# Dealing with Module Arguments

## Introduction

In Linen we can define `Module` arguments either as dataclass attributes or as arguments to methods (usually `__call__`).
Typically the distinction is clear:
* Completely fixed properties, such as the choice of kernel initializer or the number of output features, are hyperparameters and should be defined as dataclass attributes. Two Module instances with different hyperparameters typically cannot share parameters in a meaningful way.
* Dynamic properties, such as the input data and top-level "mode switches" like `train=True/False`, should be passed as arguments to `__call__` or another method.

Some cases are, however, ambiguous. Take for example the `Dropout` module.
We have a number of clear hyperparameters:
1. The dropout rate
2. The axes for which a dropout mask is generated

And some clear call time arguments:
1. The input that should be masked using dropout
2. The (optional) rng used to sample the random mask

There is, however, one ambiguous property: the `deterministic` attribute of a `Dropout` module.

If `deterministic` is `True`, no dropout mask is sampled; this is typically used during model evaluation.
However, when we pass `eval=True` or `train=False` to a top-level Module, the `deterministic` flag
must be applied everywhere, so the boolean needs to be passed down to every layer that might use `Dropout`.
If instead `deterministic` is a dataclass attribute, we might do the following:

```python
from functools import partial
from flax import linen as nn

class ResidualModel(nn.Module):
  drop_rate: float

  @nn.compact
  def __call__(self, x, *, train):
    dropout = partial(nn.Dropout, rate=self.drop_rate, deterministic=not train)
    for i in range(10):
      x += ResidualBlock(dropout=dropout, ...)(x)
```

It makes sense to pass `deterministic` to the constructor here, because this way we can pass the dropout template to the sub-modules.
The sub-modules then no longer need to care about train vs eval mode and can simply use the `dropout` argument.
Note that because the dropout layer can only be constructed inside the sub-module, we can only partially apply `deterministic` to the constructor, not to `__call__`.
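The template pattern can be sketched without Flax at all: `functools.partial` pre-binds the hyperparameters, so the sub-module can construct the layer without knowing about train vs eval. A minimal sketch; the `Dropout` class below is a plain-Python stand-in, not the real `nn.Dropout`:

```python
from dataclasses import dataclass
from functools import partial

@dataclass
class Dropout:  # stand-in for nn.Dropout, for illustration only
    rate: float
    deterministic: bool

    def __call__(self, x):
        # A real dropout layer would sample a random mask here; with
        # deterministic=True the input passes through unchanged.
        if self.deterministic:
            return x
        # Inverted-dropout scaling; the random mask itself is elided.
        return [v / (1.0 - self.rate) for v in x]

# The parent binds the hyperparameters once...
dropout = partial(Dropout, rate=0.5, deterministic=True)

# ...and the sub-module constructs the layer without knowing train/eval:
layer = dropout()
assert layer.rate == 0.5 and layer.deterministic is True
```

This is exactly what the `ResidualModel` sketch above does with `nn.Dropout`: the child only ever calls `self.dropout()`.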

However, if `deterministic` is a dataclass attribute we run into trouble when using the setup pattern. We would **want** to write our module code like this:

```python
class SomeModule(nn.Module):
  drop_rate: float

  def setup(self):
    self.dropout = nn.Dropout(rate=self.drop_rate)

  @nn.compact
  def __call__(self, x, *, train):
    # ...
    x = self.dropout(x, deterministic=not train)
    # ...
```

But as defined above `deterministic` would be an attribute, so this doesn't work.
Here it makes sense to pass `deterministic` during `__call__` because it depends on the `train` argument.

## Solution

We can support both use cases described above by allowing such a property to be passed
either as a dataclass attribute or as a method argument (but not both at once).
This can be implemented as follows:
```python
from typing import Optional

class MyDropout(nn.Module):
  drop_rate: float
  deterministic: Optional[bool] = None

  @nn.compact
  def __call__(self, x, deterministic=None):
    deterministic = nn.merge_param('deterministic', self.deterministic, deterministic)
    # ...
```

In this example, `nn.merge_param` ensures that either `self.deterministic` or the `deterministic` argument is set, but not both.
An error is raised if both values are `None` or if both are not `None`.
This avoids the confusing behavior where two different parts of the code set the same parameter and one silently overrules the other.
It also avoids a default value, which would probably cause either the train step or the eval step of a training procedure to be broken by default.
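The contract can be stated in a few lines of plain Python. This is a sketch mirroring the behavior described above, not the Flax implementation itself:

```python
def merge_param(name, a, b):
    """Return whichever of `a`/`b` is not None; exactly one must be set."""
    if a is None and b is None:
        raise ValueError(f'Parameter "{name}" must be passed to the '
                         'constructor or at call time.')
    if a is not None and b is not None:
        raise ValueError(f'Parameter "{name}" was passed to the constructor '
                         'and at call time. Should be passed just once.')
    return a if a is not None else b

# Exactly one source may set the value:
assert merge_param('deterministic', True, None) is True
assert merge_param('deterministic', None, False) is False
```

Note that `False` is a valid value: the check is against `None`, not truthiness, so `deterministic=False` passed at call time is merged correctly.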



## Functional Core and flax.nn

The old NN API and functional core define functions rather than classes.
Therefore, there is no clear distinction between hyperparameters and call-time arguments.
The only way to pre-bind the hyperparameters is by using `partial`.
On the upside, there are no ambiguous cases where method arguments could also be attributes.
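Concretely, with a function-style layer every hyperparameter is just another argument, and `partial` is how you fix some of them ahead of time. A sketch with a hypothetical function-style `dropout`, not the actual flax.nn API:

```python
from functools import partial

def dropout(x, rate, deterministic):
    # Hypothetical function-style layer: the hyperparameters `rate` and
    # `deterministic` are ordinary arguments, indistinguishable from `x`.
    if deterministic:
        return x
    # Inverted-dropout scaling; random mask sampling is elided.
    return [v / (1.0 - rate) for v in x]

# partial is the only way to pre-bind the hyperparameters:
eval_dropout = partial(dropout, rate=0.5, deterministic=True)
train_dropout = partial(dropout, rate=0.5, deterministic=False)

assert eval_dropout([1.0, 2.0]) == [1.0, 2.0]
assert train_dropout([1.0, 2.0]) == [2.0, 4.0]
```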
2 changes: 1 addition & 1 deletion examples/pixelcnn/model_test.py
@@ -126,7 +126,7 @@ def test_conv_transpose_down_right(self):
def test_pcnn_shape(self):
x = random.normal(self.rng, (2, 4, 4, 3))
model = pixelcnn.PixelCNNPP(depth=0, features=2, dropout_p=0)
out, _ = model.init_with_output(self.rng, x)
out, _ = model.init_with_output(self.rng, x, train=False)
self.assertEqual(out.shape, (2, 4, 4, 100))


14 changes: 8 additions & 6 deletions examples/pixelcnn/pixelcnn.py
@@ -57,14 +57,17 @@ class PixelCNNPP(nn.Module):
dropout_p: float = 0.5

@nn.compact
def __call__(self, images):
def __call__(self, images, *, train):
# Special convolutional and resnet blocks which allow information flow
# downwards and to the right.
conv_down = partial(ConvDown, features=self.features)
conv_down_right = partial(ConvDownRight, features=self.features)

res_down = partial(ResDown, dropout_p=self.dropout_p)
res_down_right = partial(ResDownRight, dropout_p=self.dropout_p)
dropout = partial(
nn.Dropout, rate=self.dropout_p, deterministic=not train)

res_down = partial(ResDown, dropout=dropout)
res_down_right = partial(ResDownRight, dropout=dropout)

# Conv Modules which halve or double the spatial dimensions
halve_down = partial(conv_down, strides=(2, 2))
@@ -301,8 +304,8 @@ def __call__(self, inputs):
# Resnet modules
class GatedResnet(nn.Module):
conv_module: Callable[..., Any]
dropout: Callable[..., Any]
nonlinearity: Callable[..., Any] = concat_elu
dropout_p: float = 0.

@nn.compact
def __call__(self, inputs, aux=None):
@@ -311,8 +314,7 @@
if aux is not None:
y = self.nonlinearity(y + ConvOneByOne(c)(self.nonlinearity(aux)))

if self.dropout_p > 0:
y = nn.Dropout(rate=self.dropout_p)(y)
y = self.dropout()(y)

# Set init_scale=0.1 so that the res block is close to the identity at
# initialization.
7 changes: 4 additions & 3 deletions examples/pixelcnn/train.py
@@ -85,7 +85,8 @@ def loss_fn(params):
config,
dropout_p=config.dropout_rate).apply({'params': params},
batch['image'],
rngs={'dropout': dropout_rng})
rngs={'dropout': dropout_rng},
train=True)
return neg_log_likelihood_loss(pcnn_out, batch['image'])

lr = learning_rate_fn(optimizer.state.step)
@@ -105,7 +106,7 @@ def loss_fn(params):

def eval_step(config, params, batch):
images = batch['image']
pcnn_out = model(config, dropout_p=0.).apply({'params': params}, images)
pcnn_out = model(config).apply({'params': params}, images, train=False)
return {
'loss': jax.lax.pmean(neg_log_likelihood_loss(pcnn_out, images), 'batch')
}
@@ -173,7 +174,7 @@ def train_and_evaluate(config: ml_collections.ConfigDict, workdir: str):
{
'params': init_rng,
'dropout': dropout_rng
}, init_batch)['params']
}, init_batch, train=False)['params']
optimizer_def = optim.Adam(beta1=0.95, beta2=0.9995)
optimizer = optimizer_def.create(initial_variables)

24 changes: 16 additions & 8 deletions flax/linen/attention.py
@@ -25,7 +25,7 @@

from flax.linen.linear import default_kernel_init
from flax.linen.linear import DenseGeneral
from flax.linen.module import Module, compact
from flax.linen.module import Module, compact, merge_param
from flax.linen.initializers import zeros

PRNGKey = Any
@@ -125,7 +127,9 @@ class MultiHeadDotProductAttention(Module):
out_features: dimension of the last projection
broadcast_dropout: bool: use a broadcasted dropout along batch dims.
dropout_rate: dropout rate
deterministic: bool, deterministic or not (to apply dropout)
deterministic: if false, the attention weights are masked randomly
using dropout, whereas if true, the attention weights
are deterministic.
precision: numerical precision of the computation see `jax.lax.Precision`
for details.
kernel_init: initializer for the kernel of the Dense layers.
Expand All @@ -142,7 +144,7 @@ class MultiHeadDotProductAttention(Module):
out_features: Optional[int] = None
broadcast_dropout: bool = True
dropout_rate: float = 0.
deterministic: bool = False
deterministic: Optional[bool] = None
precision: Any = None
kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = default_kernel_init
bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = zeros
@@ -154,7 +156,8 @@ class MultiHeadDotProductAttention(Module):
def __call__(self,
inputs_q: Array,
inputs_kv: Array,
mask: Optional[Array] = None):
mask: Optional[Array] = None,
deterministic: Optional[bool] = None):
"""Applies multi-head dot product attention on the input data.

Projects the inputs into multi-headed query, key, and value vectors,
@@ -167,10 +170,14 @@ def __call__(self,
`[batch_sizes..., length, features]`.
mask: attention mask of shape
`[batch_sizes..., num_heads, query_length, key/value_length]`.
deterministic: if false, the attention weights are masked randomly
using dropout, whereas if true, the attention weights
are deterministic.

Returns:
output of shape `[batch_sizes..., length, features]`.
"""
deterministic = merge_param('deterministic', self.deterministic, deterministic)
features = self.out_features or inputs_q.shape[-1]
qkv_features = self.qkv_features or inputs_q.shape[-1]
assert qkv_features % self.num_heads == 0, (
@@ -238,7 +245,7 @@ def __call__(self,
attention_bias = None

dropout_rng = None
if not self.deterministic and self.dropout_rate > 0.:
if not deterministic and self.dropout_rate > 0.:
dropout_rng = self.make_rng('dropout')

# apply attention
@@ -250,7 +257,7 @@ def __call__(self,
dropout_rng=dropout_rng,
dropout_rate=self.dropout_rate,
broadcast_dropout=self.broadcast_dropout,
deterministic=self.deterministic,
deterministic=deterministic,
dtype=self.dtype,
precision=self.precision)

@@ -270,8 +277,9 @@ class SelfAttention(MultiHeadDotProductAttention):
"""Self-attention special case of multi-head dot-product attention."""

@compact
def __call__(self, inputs_q: Array, mask: Optional[Array] = None):
return super().__call__(inputs_q, inputs_q, mask)
def __call__(self, inputs_q: Array, mask: Optional[Array] = None,
deterministic: Optional[bool] = None):
return super().__call__(inputs_q, inputs_q, mask, deterministic=deterministic)


# mask-making utility functions
36 changes: 36 additions & 0 deletions flax/linen/module.py
@@ -814,6 +814,42 @@ def variables(self) -> VariableDict:
return self.scope.variables()


T = TypeVar('T')
def merge_param(name: str, a: Optional[T], b: Optional[T]) -> T:
"""Merges construction and call time argument.

This is a utility for supporting the pattern where a Module hyper parameter
can be passed to `__init__` or `__call__`.

Example::

class Foo(nn.Module):
train: Optional[bool] = None

def __call__(self, train: Optional[bool] = None):
train = nn.merge_param('train', self.train, train)

An error is thrown when both arguments are `None` or both values are not `None`.

Args:
name: the name of the parameter. Used for error messages.
a: option a
b: option b
Returns:
`a` or `b`, whichever is not `None`.

"""
if a is None and b is None:
raise ValueError(f'Parameter "{name}" must be passed to the constructor or at call time.')
if a is not None and b is not None:
raise ValueError(f'Parameter "{name}" was passed to the constructor and at call time.'
' Should be passed just once.')
if a is None:
return b
else:
return a


# @contextmanager
# def mutate(self, mutable=True, **updates):
# cloned = self.clone(**updates)
12 changes: 8 additions & 4 deletions flax/linen/normalization.py
@@ -20,7 +20,7 @@
from jax.nn import initializers
import jax.numpy as jnp

from flax.linen.module import Module, compact
from flax.linen.module import Module, compact, merge_param


PRNGKey = Any
@@ -60,7 +60,7 @@ class BatchNorm(Module):
the examples on the first two and last two devices. See `jax.lax.psum`
for more details.
"""
use_running_average: bool = False
use_running_average: Optional[bool] = None
axis: int = -1
momentum: float = 0.99
epsilon: float = 1e-5
@@ -73,15 +73,19 @@ class BatchNorm(Module):
axis_index_groups: Any = None

@compact
def __call__(self, x):
def __call__(self, x, use_running_average: Optional[bool] = None):
"""Normalizes the input using batch statistics.

Args:
x: the input to be normalized.
use_running_average: if true, the statistics stored in batch_stats
will be used instead of computing the batch statistics on the input.

Returns:
Normalized inputs (the same shape as inputs).
"""
use_running_average = merge_param(
'use_running_average', self.use_running_average, use_running_average)
x = jnp.asarray(x, jnp.float32)
axis = self.axis if isinstance(self.axis, tuple) else (self.axis,)
axis = _absolute_dims(x.ndim, axis)
@@ -99,7 +103,7 @@ def __call__(self, x, use_running_average: Optional[bool] = None):
lambda s: jnp.ones(s, jnp.float32),
reduced_feature_shape)

if self.use_running_average:
if use_running_average:
mean, var = ra_mean.value, ra_var.value
else:
mean = jnp.mean(x, axis=reduction_axis, keepdims=False)