Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 9 additions & 9 deletions monai/apps/pathology/transforms/spatial/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional, Sequence, Tuple, Union, cast
from typing import Optional, Sequence, Tuple, Union

import numpy as np
import torch
Expand Down Expand Up @@ -113,14 +113,16 @@ def __init__(
):
self.tile_count = tile_count
self.tile_size = tile_size
self.step = step
self.random_offset = random_offset
self.pad_full = pad_full
self.background_val = background_val
self.filter_mode = filter_mode

if self.step is None:
self.step = self.tile_size # non-overlapping grid
if step is None:
# non-overlapping grid
self.step = self.tile_size
else:
self.step = step

self.offset = (0, 0)
self.random_idxs = np.array((0,))
Expand All @@ -131,7 +133,6 @@ def __init__(
def randomize(self, img_size: Sequence[int]) -> None:

c, h, w = img_size
tile_step = cast(int, self.step)

self.offset = (0, 0)
if self.random_offset:
Expand All @@ -147,8 +148,8 @@ def randomize(self, img_size: Sequence[int]) -> None:
h = h + pad_h
w = w + pad_w

h_n = (h - self.tile_size + tile_step) // tile_step
w_n = (w - self.tile_size + tile_step) // tile_step
h_n = (h - self.tile_size + self.step) // self.step
w_n = (w - self.tile_size + self.step) // self.step
tile_total = h_n * w_n

if self.tile_count is not None and tile_total > self.tile_count:
Expand All @@ -160,7 +161,6 @@ def __call__(self, image: np.ndarray) -> np.ndarray:

# add random offset
self.randomize(img_size=image.shape)
tile_step = cast(int, self.step)

if self.random_offset and (self.offset[0] > 0 or self.offset[1] > 0):
image = image[:, self.offset[0] :, self.offset[1] :]
Expand All @@ -177,7 +177,7 @@ def __call__(self, image: np.ndarray) -> np.ndarray:
)

# extact tiles
xstep, ystep = tile_step, tile_step
xstep, ystep = self.step, self.step
xsize, ysize = self.tile_size, self.tile_size
clen, xlen, ylen = image.shape
cstride, xstride, ystride = image.strides
Expand Down