Move TF building to an actual `build()` method #23760
Conversation
force-pushed from e3068b1 to aa599e4
This should be ready to review now! Some tests are failing, but that looks like Hub connection issues.
sgugger
left a comment
Thanks for your PR. I think it needs more TF-expert eyes, so tagging @amyeroberts.
The changes in Longformer and LED are very big, so they should go in their own PR to make it easier for future blame.
```python
)
out = tf.matmul(x, w, transpose_b=True)
if b is not None:
    out += b
```
Is the transpose of b unnecessary here?
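For context, `transpose_b=True` already transposes the second operand of the matmul, so a separate explicit transpose of that weight would be redundant. A quick sketch of the equivalence (standalone example, not the model code under review):

```python
import tensorflow as tf

x = tf.random.normal((2, 4))
w = tf.random.normal((3, 4))

# transpose_b=True multiplies by w^T internally...
a = tf.matmul(x, w, transpose_b=True)
# ...which is the same as transposing w explicitly first.
b = tf.matmul(x, tf.transpose(w))
```

Both produce a `(2, 3)` result, which is why doubling up the transpose would change the semantics.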
```python
logger = logging.get_logger(__name__)
```

```python
build_context = threading.local()
```
This is not used anywhere here?
Ah, you're right, sorry! This is leftover from an earlier approach I was trying.
amyeroberts
left a comment
Nice :) Generally looks like a good clean-up, but I'm not a fan of the current API - it overrides default Keras API behaviour in a non-obvious way. I left a more detailed comment in `modeling_tf_utils.py`.
One question I have is about the change in conditional logic checks in TF modeling code, i.e. removing `tf.cond(...)` - is this necessary with this new build logic, or just an update based on the new logic?
```python
    lambda: tf.transpose(pixel_values, perm=(0, 2, 3, 1)),
    lambda: pixel_values,
)
if shape_list(pixel_values)[1] == num_channels:
```
I don't think the updated check is compatible with graph mode / compilation of models for XLA
shape_list is actually quite smart (because I wasn't the one who wrote it) - it returns a list of the static shape for each dimension when this is known at compile time, and the dynamic shape when it isn't. In an XLA compilation pass shapes are fully static, and so the static shape will always be fully known. As a result, all shape_list calls will just return the static shape in an XLA context and not introduce data-dependent conditionals.
Nice! But also don't put yourself down - I thought you'd written it :) I fear we shall never rid ourselves of shape_list
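The static-where-possible behaviour described above can be sketched like this (a simplified version of the `shape_list` helper, not the exact transformers implementation):

```python
import tensorflow as tf

def shape_list(tensor):
    # Prefer static dimensions; fall back to the dynamic shape tensor only
    # for dimensions that are unknown at trace time. Under XLA, all dims
    # are static, so this never introduces data-dependent values.
    static = tensor.shape.as_list()
    dynamic = tf.shape(tensor)
    return [dynamic[i] if dim is None else dim for i, dim in enumerate(static)]

# In eager mode every dimension is static, so plain Python ints come back:
dims = shape_list(tf.zeros((2, 3)))
```

Because the result is a list of plain ints whenever shapes are static, comparisons like `shape_list(x)[1] == num_channels` stay ordinary Python conditionals in an XLA context.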
```python
def build_with_dummies(self, dummies=None):
    if self.built_with_dummies and dummies is None:
        return
    if dummies is None:
        dummies = self.dummy_inputs
    self(dummies, training=False)
    self.built_with_dummies = True
```
I think the API here is a bit confusing:
- Allowing `dummies` to essentially be any input
- Rebuilding if `dummies` is None
- Having `input_shape` as an argument for `build` and then not using it

I would rework it so that `build` matches the logic of the parent class, and `build_with_dummies` just uses the model's dummy values. This way `build` remains generic and `build_with_x` specifically builds with x, e.g. something like:
```python
def build(self, input_shape):
    if self.built or call_context().in_call:
        self.built = True
        return
    self(input_shape, training=False)
    self.built = True
    self.built_with_dummies = True

def build_with_dummies(self):
    if self.built_with_dummies:
        return
    self(self.dummy_inputs, training=False)
    self.built_with_dummies = True
    self.built = True
```

This would:
- Change all the calls from `model.build()` to `model.build_with_dummies()`. We can then remove all the comments next to the build calls explaining we're using dummy inputs.
- Remove the need to call `super().build(input_shape)` when we want the old build logic.
- Remove the need to set the `input_shape` to `None` in all the current build methods.
Also - why have the distinction between `built` and `built_with_dummies`?
Actually, I should explain my reasoning for some of the changes here - you're probably right that I can improve the API, though! Firstly, there's the removal of `tf.cond(...)`; secondly, there's the distinction between `built` and `built_with_dummies`.
😑 In this case, let's rid ourselves of this pseudolayer! I'm pro the if/else changes :)
Yep, that's what I would go for. Would it be possible to still have some of the logic to exit early if already built? Or would this be too tricky to handle to be worth it?
I think we could, but it's probably not necessary - the only cases where we build the model with specific inputs are in weird PT-TF crossloading functions, which should always be called during or near model init anyway, so I think it's fine if there's a risk of a little bit of duplicated work there to save on overall code complexity.
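The early-exit behaviour being discussed reduces to a tiny idempotence guard. Here's a minimal TF-free sketch (class and attribute names are illustrative, not the actual transformers code):

```python
class ToyBuildable:
    """Illustrates the idempotent build guard discussed above."""

    def __init__(self):
        self.built = False
        self.build_count = 0  # stands in for the cost of a dummy forward pass

    def build_with_dummies(self):
        if self.built:
            return  # exit early: never repeat the (expensive) build
        self.build_count += 1
        self.built = True
```

Calling `build_with_dummies()` twice performs the work once; the duplicated-work risk only applies to paths that build with specific (non-dummy) inputs and bypass this guard.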
force-pushed from 4386387 to c76c308
@amyeroberts Done!
amyeroberts
left a comment
Thanks for iterating - nice cleanup!
Just a small Q about build logic for LED / Longformer
```python
with tf.name_scope("query_global"):
    self.query_global.build((self.config.hidden_size,))
with tf.name_scope("key_global"):
    self.key_global.build((self.config.hidden_size,))
with tf.name_scope("value_global"):
    self.value_global.build((self.config.hidden_size,))
```
Two silly questions:
- Why do these layers need to be built separately here?
- Why no `super().build(input_shape)` call in the method?
Good question! The answer is that it's hard to get dummy inputs that correctly touch all of those layers, so they tend to be left un-built unless we explicitly build them.
As for the `super().build()`, I just forgot! The base `build()` method doesn't really do anything, but you're right that I should probably still call it just in case.
Also, this PR looks ready, but I'm going to let it sit for a couple of days to make sure the CI is working again after my last library-breaking PR, then merge it.
Change of plans: the CI is working except for OOM errors during building for some of the pipelines, and since this cleans up building a bit we're going to merge this one too and see if it helps. If it doesn't, I'll open a new PR to see if I can lower the memory usage in the affected models.
```python
if parse(tf.__version__).minor >= 13:
    from keras import backend as K
    from keras.__internal__ import KerasTensor
    from keras.engine.base_layer_utils import call_context
```
This would break, since keras 2.13 has moved the import to `keras.src.engine`.
See #23663
Thanks for the catch, I'll make the fix ASAP!
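A common way to survive this kind of internal-module move is a try/except import fallback. A generic sketch of the pattern (the helper name is made up, and the actual fix in the PR may differ):

```python
import importlib

def import_attr_with_fallback(primary_module, fallback_module, attr):
    # Try the new module path first, then fall back to the old one.
    try:
        module = importlib.import_module(primary_module)
    except ImportError:
        module = importlib.import_module(fallback_module)
    return getattr(module, attr)

# e.g. resolve an attribute even though the first module path doesn't exist:
sqrt = import_attr_with_fallback("not_a_real_module", "math", "sqrt")
```

Applied here, the new `keras.src.engine` path would be tried first, with the pre-2.13 `keras.engine` path as the fallback.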
```python
    lambda: _make_causal_mask(input_shape, past_key_values_length=past_key_values_length),
    lambda: _expand_mask(tf.ones((batch_size, seq_len + past_key_values_length)), tgt_len=seq_len),
)
if seq_len > 1:
```
Hi @Rocketknight1, I am unfortunately having an issue with this change.
When I build a functional keras model using the Whisper encoder & decoder layers, I cannot serialize the model, because this change raises the error:

```
Using a symbolic `tf.Tensor` as a Python `bool` is not allowed: AutoGraph did convert this function. This might indicate you are trying to use an unsupported feature.
```
Here is a minimal reproducible example to raise the error:
```python
from transformers import TFWhisperModel
import tensorflow as tf

whisper = TFWhisperModel.from_pretrained("openai/whisper-tiny")
inp = tf.keras.Input((80, 3000))
stack = whisper.get_encoder()(inp)
decoder_input_ids = tf.ones((tf.shape(inp)[0], 1), dtype=tf.int32) * whisper.config.decoder_start_token_id
stack = whisper.get_decoder()(input_ids=decoder_input_ids, encoder_hidden_states=stack.last_hidden_state)
model = tf.keras.Model(inp, stack)
model.summary()
model.save("whisper-tiny-custom")
```

What do you think?
I will open an issue for this to be referenced!
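The underlying failure mode is that `if seq_len > 1:` evaluates a symbolic tensor as a Python bool when `seq_len` isn't static at trace time; `tf.cond` is the graph-safe alternative. A minimal sketch of the difference (not the actual Whisper code):

```python
import tensorflow as tf

@tf.function
def double_if_long(x):
    seq_len = tf.shape(x)[1]  # a symbolic tensor inside the traced function
    # A plain `if seq_len > 1:` can raise the symbolic-bool error in some
    # tracing contexts; tf.cond keeps both branches in the graph instead.
    return tf.cond(seq_len > 1, lambda: x * 2.0, lambda: x)
```

With `tf.cond`, serialization of functional Keras models built on top of such layers avoids the data-dependent Python conditional entirely.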
* A fun new PR where I break the entire codebase again
* A fun new PR where I break the entire codebase again
* Handle cross-attention
* Move calls to model(model.dummy_inputs) to the new build() method
* Seeing what fails with the build context thing
* make fix-copies
* Let's see what fails with new build methods
* Fix the pytorch crossload build calls
* Fix the overridden build methods in vision_text_dual_encoder
* Make sure all our build methods set self.built or call super().build(), which also sets it
* make fix-copies
* Remove finished TODO
* Tentatively remove unneeded (?) line
* Transpose b in deberta correctly and remove unused threading local
* Get rid of build_with_dummies and all it stands for
* Rollback some changes to TF-PT crossloading
* Correctly call super().build()
This has been a longstanding dream of mine: to move all TF model building into a proper `build()` method, using symbolic tensors instead of actual dummies. This would allow us to, among other things, stop our very hacky overriding of `save_spec`, as well as allowing us to build our TF models with zero device flops (although the speedup may be system-dependent, as we do have some compile time with this approach). It would make our models much closer to the Keras standard, which would stop Chollet casting curses upon me from afar.

In the past, we've run into serious problems with tensor names moving around when we tried this - I think I've figured out why, though, and I have a couple of ideas to resolve that without lots of hacky edge-case code.

This is an extremely draft PR that will break everything until I finish testing it properly!

Update: Using symbolic tensors is much slower - it works in most cases, but increases the time it takes for our tests to run by a factor of ~4, which is probably not acceptable. Instead, I'm going to rework this PR to move to a standard `build()` method using actual dummies. With some optimizations, I believe we can make this work, while still preserving most of the benefits of this PR, including not repeating the build unnecessarily and adding the ability to override `build()` to speed up our slowest models.