
Added casts to GDN, DeepFactorized, and ContinuousBase for better mixed precision compatibility #105

Closed
MahmoudAshraf97 wants to merge 3 commits into tensorflow:master from MahmoudAshraf97:master

Conversation

@MahmoudAshraf97
Contributor

Hello,
These changes add support for mixed precision training, because explicitly instantiating the layers and objects in float16 throws errors in operations involving other layers' variables and weights.

  • Changed tf.convert_to_tensor to tf.cast for better mixed precision support
  • Added a cast to floatx for logits for better mixed precision support
  • Added a cast to floatx for quantization inputs for better mixed precision support
@jonaballe
Collaborator

Hi Mahmoud, thanks for the PR, but I'm not sure this is the best way to deal with mixed precision. Can you explain which parts of the model you want to have in which precision?

All the layer and entropy model classes have a dtype argument; isn't it enough to set these arguments when you instantiate them, and then ensure that any tensors you are passing into them have the right dtype?
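For illustration, a minimal sketch of the suggestion above, using a plain Keras layer in place of the TFC classes (the layer choice and shapes are arbitrary): construct the layer with an explicit dtype, then make sure the tensors fed into it match that dtype.

```python
import tensorflow as tf

# Construct the layer with an explicit dtype, then ensure inputs match it.
layer = tf.keras.layers.Dense(8, dtype="float16")
x = tf.random.normal((2, 4), dtype=tf.float16)
y = layer(x)
assert y.dtype == tf.float16  # the output stays in the requested dtype
```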

@MahmoudAshraf97
Contributor Author

MahmoudAshraf97 commented Feb 1, 2022

All the layer and entropy model classes have a dtype argument; isn't it enough to set these arguments when you instantiate them, and then ensure that any tensors you are passing into them have the right dtype?

This argument works perfectly in single precision mode, whether float16 or float32. With mixed precision, however, TensorFlow automatically uses different dtypes for the forward pass, backpropagation, and the calculation and storage of the weights, so hard-coding the dtype into a layer or object via the dtype argument will not work: the layer must be flexible enough to work with multiple dtypes in the same model at different training stages.

Also, I used a version of TFC 2.2 with this commit for a while, and training a model in fp32 and in mixed precision converged to the same result, so I think it's safe to assume that it yields correct results.

@jonaballe
Collaborator

I see, thanks for the explanation. We weren't aware of the Keras developments regarding mixed precision. I'll look into this some more, will keep you posted.

@jonaballe
Collaborator

We added full mixed precision support in commit c20abdb. Can you please check if this solves your issue?

@MahmoudAshraf97
Contributor Author

MahmoudAshraf97 commented Feb 10, 2022

@jonycgn unfortunately, the mentioned commit does not resolve the issue. Using the code at https://github.com/MahmoudAshraf97/AutoencoderCompression with TF 2.8 and TFC 2.8 still throws the error ValueError: Tensor conversion requested dtype float16 for Tensor with dtype float32 at every use of the tf.convert_to_tensor function.

Although the documentation states that the function accepts tensor objects, it throws the same error if a tensor is passed with a dtype different from the requested dtype; see this snippet for a quick example:

test = tf.zeros(5, dtype=tf.float32)
tf.convert_to_tensor(test, dtype=tf.uint8)

Passing any dtype value other than tf.float32 raises an error.

Nevertheless, the commit handles mixed precision training in a good way; the problem here is with tf.convert_to_tensor's strange behavior. My proposal to solve this issue is to handle tf tensors using tf.cast and use tf.convert_to_tensor for other objects.
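The difference in behavior between the two functions can be sketched in a few lines (a minimal example; the dtypes are arbitrary):

```python
import tensorflow as tf

t = tf.zeros(5, dtype=tf.float32)

# tf.convert_to_tensor refuses to change the dtype of an existing
# tensor and raises ValueError instead...
try:
    tf.convert_to_tensor(t, dtype=tf.float16)
except ValueError:
    pass  # "Tensor conversion requested dtype float16 for Tensor with dtype float32"

# ...while tf.cast performs the conversion.
assert tf.cast(t, tf.float16).dtype == tf.float16
```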

@jonaballe
Collaborator

Could you clarify which specific instances of tf.convert_to_tensor are causing your problems?

@jonaballe
Collaborator

FWIW, I can't see your code in https://github.com/MahmoudAshraf97/AutoencoderCompression. Is it private?

@MahmoudAshraf97
Contributor Author

/tensorflow_compression/python/entropy_models/continuous_batched.py", line 304, in __call__ bottleneck = tf.convert_to_tensor(bottleneck, dtype=self.bottleneck_dtype)
Please note that all instances of tf.convert_to_tensor cause the issue; this is confirmed by replacing one instance with tf.cast, after which the issue appears somewhere else.
The repo is public, but to save time: the error occurs at the line defining the bottleneck tensor:
"modules.py", line 178, in call y_hat, bits = entropy_model(y, training=training)

@jonaballe
Collaborator

If you hit this problem in line 178, this would indicate that the analysis transform did not output a tensor that has the same dtype as bottleneck_dtype. I don't think this should happen if everything is set up correctly (i.e. normally, both should be float32, and if you use mixed precision, both should be float16). So that would indicate that the problem is somewhere else. Did you set the tf.keras.mixed_precision.global_policy?
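For reference, a minimal sketch of enabling the policy the question above refers to (assuming a TF version where the tf.keras.mixed_precision API is stable, i.e. TF ≥ 2.4):

```python
import tensorflow as tf

# Enable mixed precision globally: computations in float16, variables in float32.
tf.keras.mixed_precision.set_global_policy("mixed_float16")
policy = tf.keras.mixed_precision.global_policy()
print(policy.compute_dtype, policy.variable_dtype)  # float16 float32

# Restore the default so the rest of the program is unaffected.
tf.keras.mixed_precision.set_global_policy("float32")
```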

@MahmoudAshraf97
Contributor Author

TL;DR: bottleneck_dtype should default to tf.keras.mixed_precision.global_policy().variable_dtype

In my code I didn't explicitly set the dtype for any layer or tensor, so they're all handled by Keras and work perfectly in single precision, whether float16 or float32. As I stated in an earlier comment, tf.convert_to_tensor only accepts tf tensor objects whose dtype is the same as the requested dtype, which is always the case in single precision training; otherwise it throws an error. This doesn't happen with other objects such as floats or numpy arrays, regardless of their dtype.
Why does the problem occur in mixed precision training? The input to the entropy model is a tf variable, which has the variable_dtype (tf.float32 by default), so the output tensor should also be variable_dtype, not compute_dtype, which is the current case with the bottleneck tensor. A suggested workaround is to explicitly set bottleneck_dtype to tf.float32, but this would make the code inflexible, IMHO; the main advantage of mixed precision is the ability to run old code without any modification that might affect it in other training precisions.

This is a Colab notebook that presents a minimal example to try with.
Notice that the code works fine in two cases only:

  • without mixed precision
  • with mixed precision enabled and bottleneck_dtype set to tf.keras.mixed_precision.global_policy().variable_dtype

If mixed precision is enabled without setting bottleneck_dtype, the error occurs.
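For context, under the mixed_float16 policy Keras keeps variables in float32 (variable_dtype) while layer outputs are produced in the compute dtype (float16); a minimal sketch of that split, using a plain Dense layer in place of the model above:

```python
import tensorflow as tf

tf.keras.mixed_precision.set_global_policy("mixed_float16")
layer = tf.keras.layers.Dense(4)
y = layer(tf.zeros((1, 4)))              # float32 input is auto-cast to float16
assert layer.kernel.dtype == tf.float32  # variables use variable_dtype
assert y.dtype == tf.float16             # outputs use compute_dtype

# Restore the default policy.
tf.keras.mixed_precision.set_global_policy("float32")
```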

@jonaballe
Collaborator

jonaballe commented Feb 10, 2022

Thanks for your explanation.

However, I don't see why the input to the entropy model should be a variable. Certainly that could be the case for some models, but typically it will not be a variable; instead, it will be the output of some encoder-side neural network. These outputs should have compute_dtype, not variable_dtype.

Edit: Also, in the code that you pointed to in your repository, y is the output of the analysis transform, so what you said shouldn't apply there. In addition, note that you can always set bottleneck_dtype to tf.keras.mixed_precision.global_policy().variable_dtype manually, if your model requires it.

@MahmoudAshraf97
Contributor Author

I might have used the wrong term when I said that the input would be a variable. The actual situation is that, in mixed precision, Keras uses tf.float32 for everything other than computations; this includes all tensors, as per here. So naturally the input to the entropy model, or any layer, is variable_dtype, which defaults to tf.float32.

As for my code, I don't think the problem occurs only with my code, as my experiments with the example notebook suggest otherwise. Anyway, here is a full traceback from the error in my code:

Traceback (most recent call last):
  File "train.py", line 49, in <module>
    model = AutoencoderModel(1)
  File "/mnt/e/Final Project/refactor/modules.py", line 170, in __init__
    self.build((None, None, None, 3))
  File "/home/mahmoud/.local/lib/python3.8/site-packages/keras/engine/training.py", line 440, in build
    self.call(x, **kwargs)
  File "/mnt/e/Final Project/refactor/modules.py", line 178, in call
    y_hat, bits = entropy_model(y, training=training)
  File "/home/mahmoud/.local/lib/python3.8/site-packages/tensorflow/python/module/module.py", line 311, in method_with_name_scope
    return method(self, *args, **kwargs)
  File "/home/mahmoud/.local/lib/python3.8/site-packages/tensorflow_compression/python/entropy_models/continuous_batched.py", line 304, in __call__
    bottleneck = tf.convert_to_tensor(bottleneck, dtype=self.bottleneck_dtype)
  File "/home/mahmoud/.local/lib/python3.8/site-packages/tensorflow/python/util/traceback_utils.py", line 153, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "/home/mahmoud/.local/lib/python3.8/site-packages/tensorflow/python/framework/ops.py", line 1662, in convert_to_tensor
    raise ValueError(
ValueError: Tensor conversion requested dtype float16 for Tensor with dtype float32: <tf.Tensor 'Encoder/conv2d_58/leaky_re_lu_18/LeakyRelu:0' shape=(None, None, None, 32) dtype=float32>

@jonaballe
Collaborator

I'm not sure this is the right way to interpret mixed precision training. If what you are saying is correct, this would mean that every layer would need to do its computations in float16, but then cast the output back to float32. The next layer would then cast it back to float16, and so on… That seems like an unnecessary number of casts.

We have a unit test here (and analogous ones for the other classes), that ensures that if the input tensor is float16, then the output of the layer will be float16 as well. Rather than trying to cast everything back and forth between 16 and 32 bits, have you tried just ensuring that the input image's dtype is compute_dtype, and then feeding it to the model?
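The suggestion above can be sketched as follows (a hypothetical input pipeline; the image shape is arbitrary): cast the input to compute_dtype once at the model boundary instead of casting back and forth inside the layers.

```python
import tensorflow as tf

tf.keras.mixed_precision.set_global_policy("mixed_float16")
compute_dtype = tf.keras.mixed_precision.global_policy().compute_dtype

image = tf.random.uniform((1, 32, 32, 3))  # input pipeline yields float32
image = tf.cast(image, compute_dtype)      # cast once at the model boundary
assert image.dtype == tf.float16           # now matches the layers' compute_dtype

# Restore the default policy.
tf.keras.mixed_precision.set_global_policy("float32")
```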

@jonaballe
Collaborator

jonaballe commented Feb 16, 2022

Hey, we added support for mixed precision to the models under models/, and it seems to work fine (see commit 963aa2d).

If the current code doesn't work for you, could you try to find out what is different in your codebase and let me know? I'm reluctant to switch all calls from tf.convert_to_tensor to tf.cast, since that would mean that if you are not using the right datatype, it would silently succeed without letting you know that there was a problem.

@jonaballe
Collaborator

Closing this due to inactivity. Please reopen in case you experience further issues with mixed precision training.

@jonaballe jonaballe closed this Mar 2, 2022