Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion tests/test_tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ def test_percentile():

np.testing.assert_almost_equal(p_np, p_t)


def test_normalize_tf():
size = 1024
curr_file_path = os.path.dirname(os.path.realpath(__file__))
Expand Down
1 change: 0 additions & 1 deletion tests/test_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from torchvision import transforms
from skimage.metrics import structural_similarity as ssim


def test_cov():
x = np.random.randn(10, 10)
cov_np = np.cov(x)
Expand Down
4 changes: 2 additions & 2 deletions torchstain/tf/utils/percentile.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,5 @@ def percentile(t: tf.Tensor, q: float) -> Union[int, float]:
:param q: Percentile to compute, which must be between 0 and 100 inclusive.
:return: Resulting value (scalar).
"""
k = 1 + tf.math.round(.01 * tf.cast(q, tf.float32) * (tf.cast(tf.size(t), tf.float32) - 1))
return tf.sort(tf.reshape(t, [-1]))[tf.cast(k, tf.int32)]
k = 1 + tf.math.round(.01 * tf.cast(q, tf.float32) * (tf.cast(tf.math.reduce_prod(tf.size(t)), tf.float32) - 1))
return tf.sort(tf.reshape(t, [-1]))[tf.cast(k - 1, tf.int32)]
2 changes: 1 addition & 1 deletion torchstain/torch/utils/percentile.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
"""
def percentile(t: torch.tensor, q: float) -> Union[int, float]:
"""
Return the ``q``-th percentile of the flattened input tensor's data.
Return the ``q``-th percentile of the flattenepip d input tensor's data.

CAUTION:
* Needs PyTorch >= 1.1.0, as ``torch.kthvalue()`` is used.
Expand Down