Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ def _logits_cumulative(self, inputs):
shape = tf.shape(inputs)
inputs = tf.reshape(inputs, (-1, 1, self.batch_shape.num_elements()))
inputs = tf.transpose(inputs, (2, 1, 0))
logits = inputs
logits = tf.cast(inputs,dtype=tf.keras.backend.floatx())
for i in range(len(self.num_filters) + 1):
matrix = tf.nn.softplus(self._matrices[i])
logits = tf.linalg.matmul(matrix, logits)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -266,9 +266,11 @@ def _quantize_no_offset(self, inputs):

@tf.custom_gradient
def _quantize_offset(self, inputs, offset):

return tf.round(inputs - offset) + offset, lambda x: (x, None)

def _quantize(self, inputs, offset=None):
inputs = tf.cast(inputs,dtype=tf.keras.backend.floatx())
if offset is None:
outputs = self._quantize_no_offset(inputs)
else:
Expand Down
18 changes: 9 additions & 9 deletions tensorflow_compression/python/layers/gdn.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ def alpha_parameter(self, value):
if isinstance(value, dict):
value = tf.keras.utils.deserialize_keras_object(value)
if value is not None and not callable(value):
value = tf.convert_to_tensor(value, dtype=self.dtype)
value = tf.cast(value, dtype=self.dtype)
self._alpha_parameter = value

@property
Expand All @@ -224,7 +224,7 @@ def beta_parameter(self, value):
if isinstance(value, dict):
value = tf.keras.utils.deserialize_keras_object(value)
if value is not None and not callable(value):
value = tf.convert_to_tensor(value, dtype=self.dtype)
value = tf.cast(value, dtype=self.dtype)
self._beta_parameter = value

@property
Expand All @@ -238,7 +238,7 @@ def gamma_parameter(self, value):
if isinstance(value, dict):
value = tf.keras.utils.deserialize_keras_object(value)
if value is not None and not callable(value):
value = tf.convert_to_tensor(value, dtype=self.dtype)
value = tf.cast(value, dtype=self.dtype)
self._gamma_parameter = value

@property
Expand All @@ -252,7 +252,7 @@ def epsilon_parameter(self, value):
if isinstance(value, dict):
value = tf.keras.utils.deserialize_keras_object(value)
if value is not None and not callable(value):
value = tf.convert_to_tensor(value, dtype=self.dtype)
value = tf.cast(value, dtype=self.dtype)
self._epsilon_parameter = value

@property
Expand Down Expand Up @@ -296,31 +296,31 @@ def alpha(self) -> tf.Tensor:
if self.alpha_parameter is None:
raise RuntimeError("alpha is not initialized yet. Call build().")
if callable(self.alpha_parameter):
return tf.convert_to_tensor(self.alpha_parameter(), dtype=self.dtype)
return tf.cast(self.alpha_parameter(), dtype=self.dtype)
return self.alpha_parameter

@property
def beta(self) -> tf.Tensor:
if self.beta_parameter is None:
raise RuntimeError("beta is not initialized yet. Call build().")
if callable(self.beta_parameter):
return tf.convert_to_tensor(self.beta_parameter(), dtype=self.dtype)
return tf.cast(self.beta_parameter(), dtype=self.dtype)
return self.beta_parameter

@property
def gamma(self) -> tf.Tensor:
if self.gamma_parameter is None:
raise RuntimeError("gamma is not initialized yet. Call build().")
if callable(self.gamma_parameter):
return tf.convert_to_tensor(self.gamma_parameter(), dtype=self.dtype)
return tf.cast(self.gamma_parameter(), dtype=self.dtype)
return self.gamma_parameter

@property
def epsilon(self) -> tf.Tensor:
if self.epsilon_parameter is None:
raise RuntimeError("epsilon is not initialized yet. Call build().")
if callable(self.epsilon_parameter):
return tf.convert_to_tensor(self.epsilon_parameter(), dtype=self.dtype)
return tf.cast(self.epsilon_parameter(), dtype=self.dtype)
return self.epsilon_parameter

@property
Expand Down Expand Up @@ -363,7 +363,7 @@ def build(self, input_shape):
super().build(input_shape)

def call(self, inputs) -> tf.Tensor:
inputs = tf.convert_to_tensor(inputs, dtype=self.dtype)
inputs = tf.cast(inputs, dtype=self.dtype)
rank = inputs.shape.rank
if rank is None or rank < 2:
raise ValueError(f"Input tensor must have at least rank 2, received "
Expand Down