diff --git a/tfscripts/compat/v1/pooling.py b/tfscripts/compat/v1/pooling.py index 0b77b6a..b7a9403 100644 --- a/tfscripts/compat/v1/pooling.py +++ b/tfscripts/compat/v1/pooling.py @@ -34,6 +34,14 @@ def pool3d(layer, ksize, strides, padding, pooling_type): The pooled output tensor. """ + # tensorflow's pooling operations do not support float64, so + # use workaround with casting to float32 and then back again + if layer.dtype == tf.float64: + layer = tf.cast(layer, tf.float32) + was_float64 = True + else: + was_float64 = False + # pool over depth, if necessary: if ksize[-1] != 1 or strides[-1] != 1: layer = pool_over_depth(layer, @@ -70,6 +78,9 @@ def pool3d(layer, ksize, strides, padding, pooling_type): padding=padding) layer = (layer_avg + layer_max) / 2. + if was_float64: + layer = tf.cast(layer, tf.float64) + return layer @@ -97,6 +108,14 @@ def pool(layer, ksize, strides, padding, pooling_type): The pooled output tensor. """ + # tensorflow's pooling operations do not support float64, so + # use workaround with casting to float32 and then back again + if layer.dtype == tf.float64: + layer = tf.cast(layer, tf.float32) + was_float64 = True + else: + was_float64 = False + # pool over depth, if necessary: if ksize[-1] != 1 or strides[-1] != 1: layer = pool_over_depth(layer, @@ -138,6 +157,10 @@ def pool(layer, ksize, strides, padding, pooling_type): padding=padding, ) layer = (layer_avg + layer_max) / 2. + + if was_float64: + layer = tf.cast(layer, tf.float64) + return layer diff --git a/tfscripts/pooling.py b/tfscripts/pooling.py index 06ec647..aa2ad17 100644 --- a/tfscripts/pooling.py +++ b/tfscripts/pooling.py @@ -34,6 +34,14 @@ def pool3d(layer, ksize, strides, padding, pooling_type): The pooled output tensor. """ + # tensorflow's pooling operations do not support float64, so + # use workaround with casting to float32 and then back again + if layer.dtype == tf.float64: + layer = tf.cast(layer, tf.float32) + was_float64 = True + else: + was_float64 = False + # pool over depth, if necessary: if ksize[-1] != 1 or strides[-1] != 1: layer = pool_over_depth(layer, @@ -70,6 +78,9 @@ def pool3d(layer, ksize, strides, padding, pooling_type): padding=padding) layer = (layer_avg + layer_max) / 2. + if was_float64: + layer = tf.cast(layer, tf.float64) + return layer @@ -97,6 +108,14 @@ def pool2d(layer, ksize, strides, padding, pooling_type): The pooled output tensor. """ + # tensorflow's pooling operations do not support float64, so + # use workaround with casting to float32 and then back again + if layer.dtype == tf.float64: + layer = tf.cast(layer, tf.float32) + was_float64 = True + else: + was_float64 = False + # pool over depth, if necessary: if ksize[-1] != 1 or strides[-1] != 1: layer = pool_over_depth(layer, @@ -130,6 +149,10 @@ def pool2d(layer, ksize, strides, padding, pooling_type): strides=strides, padding=padding) layer = (layer_avg + layer_max) / 2. + + if was_float64: + layer = tf.cast(layer, tf.float64) + return layer