diff --git a/tensorflow_compression/python/distributions/deep_factorized.py b/tensorflow_compression/python/distributions/deep_factorized.py index a5faeae..af88663 100644 --- a/tensorflow_compression/python/distributions/deep_factorized.py +++ b/tensorflow_compression/python/distributions/deep_factorized.py @@ -178,7 +178,7 @@ def _logits_cumulative(self, inputs): shape = tf.shape(inputs) inputs = tf.reshape(inputs, (-1, 1, self.batch_shape.num_elements())) inputs = tf.transpose(inputs, (2, 1, 0)) - logits = inputs + logits = tf.cast(inputs,dtype=tf.keras.backend.floatx()) for i in range(len(self.num_filters) + 1): matrix = tf.nn.softplus(self._matrices[i]) logits = tf.linalg.matmul(matrix, logits) diff --git a/tensorflow_compression/python/entropy_models/continuous_base.py b/tensorflow_compression/python/entropy_models/continuous_base.py index e30f5f2..c11249a 100644 --- a/tensorflow_compression/python/entropy_models/continuous_base.py +++ b/tensorflow_compression/python/entropy_models/continuous_base.py @@ -266,9 +266,11 @@ def _quantize_no_offset(self, inputs): @tf.custom_gradient def _quantize_offset(self, inputs, offset): + return tf.round(inputs - offset) + offset, lambda x: (x, None) def _quantize(self, inputs, offset=None): + inputs = tf.cast(inputs,dtype=tf.keras.backend.floatx()) if offset is None: outputs = self._quantize_no_offset(inputs) else: diff --git a/tensorflow_compression/python/layers/gdn.py b/tensorflow_compression/python/layers/gdn.py index 5dcc45a..6036054 100644 --- a/tensorflow_compression/python/layers/gdn.py +++ b/tensorflow_compression/python/layers/gdn.py @@ -210,7 +210,7 @@ def alpha_parameter(self, value): if isinstance(value, dict): value = tf.keras.utils.deserialize_keras_object(value) if value is not None and not callable(value): - value = tf.convert_to_tensor(value, dtype=self.dtype) + value = tf.cast(value, dtype=self.dtype) self._alpha_parameter = value @property @@ -224,7 +224,7 @@ def beta_parameter(self, value): if isinstance(value, dict): value = tf.keras.utils.deserialize_keras_object(value) if value is not None and not callable(value): - value = tf.convert_to_tensor(value, dtype=self.dtype) + value = tf.cast(value, dtype=self.dtype) self._beta_parameter = value @property @@ -238,7 +238,7 @@ def gamma_parameter(self, value): if isinstance(value, dict): value = tf.keras.utils.deserialize_keras_object(value) if value is not None and not callable(value): - value = tf.convert_to_tensor(value, dtype=self.dtype) + value = tf.cast(value, dtype=self.dtype) self._gamma_parameter = value @property @@ -252,7 +252,7 @@ def epsilon_parameter(self, value): if isinstance(value, dict): value = tf.keras.utils.deserialize_keras_object(value) if value is not None and not callable(value): - value = tf.convert_to_tensor(value, dtype=self.dtype) + value = tf.cast(value, dtype=self.dtype) self._epsilon_parameter = value @property @@ -296,7 +296,7 @@ def alpha(self) -> tf.Tensor: if self.alpha_parameter is None: raise RuntimeError("alpha is not initialized yet. Call build().") if callable(self.alpha_parameter): - return tf.convert_to_tensor(self.alpha_parameter(), dtype=self.dtype) + return tf.cast(self.alpha_parameter(), dtype=self.dtype) return self.alpha_parameter @property @@ -304,7 +304,7 @@ def beta(self) -> tf.Tensor: if self.beta_parameter is None: raise RuntimeError("beta is not initialized yet. Call build().") if callable(self.beta_parameter): - return tf.convert_to_tensor(self.beta_parameter(), dtype=self.dtype) + return tf.cast(self.beta_parameter(), dtype=self.dtype) return self.beta_parameter @property @@ -312,7 +312,7 @@ def gamma(self) -> tf.Tensor: if self.gamma_parameter is None: raise RuntimeError("gamma is not initialized yet. Call build().") if callable(self.gamma_parameter): - return tf.convert_to_tensor(self.gamma_parameter(), dtype=self.dtype) + return tf.cast(self.gamma_parameter(), dtype=self.dtype) return self.gamma_parameter @property @@ -320,7 +320,7 @@ def epsilon(self) -> tf.Tensor: if self.epsilon_parameter is None: raise RuntimeError("epsilon is not initialized yet. Call build().") if callable(self.epsilon_parameter): - return tf.convert_to_tensor(self.epsilon_parameter(), dtype=self.dtype) + return tf.cast(self.epsilon_parameter(), dtype=self.dtype) return self.epsilon_parameter @property @@ -363,7 +363,7 @@ def build(self, input_shape): super().build(input_shape) def call(self, inputs) -> tf.Tensor: - inputs = tf.convert_to_tensor(inputs, dtype=self.dtype) + inputs = tf.cast(inputs, dtype=self.dtype) rank = inputs.shape.rank if rank is None or rank < 2: raise ValueError(f"Input tensor must have at least rank 2, received "