From 121b99fd7612d24db4a834998103b22fb0a4a12d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alexander=20Gr=C3=A9us?= Date: Thu, 20 Jun 2024 21:34:09 +0200 Subject: [PATCH 1/5] feat: add `ImageDataset.split` --- .../data/image/containers/_image_list.py | 4 +- .../data/labeled/containers/_image_dataset.py | 110 +++++++++++++----- .../labeled/containers/test_image_dataset.py | 98 ++++++++++++++++ 3 files changed, 182 insertions(+), 30 deletions(-) diff --git a/src/safeds/data/image/containers/_image_list.py b/src/safeds/data/image/containers/_image_list.py index 5e648d209..7a86d03ee 100644 --- a/src/safeds/data/image/containers/_image_list.py +++ b/src/safeds/data/image/containers/_image_list.py @@ -404,7 +404,7 @@ def __contains__(self, item: object) -> bool: Returns ------- has_item: - Weather the given item is in this image list + Whether the given item is in this image list """ return isinstance(item, Image) and self.has_image(item) @@ -524,7 +524,7 @@ def has_image(self, image: Image) -> bool: Returns ------- has_image: - Weather the given image is in this image list + Whether the given image is in this image list """ # ------------------------------------------------------------------------------------------------------------------ diff --git a/src/safeds/data/labeled/containers/_image_dataset.py b/src/safeds/data/labeled/containers/_image_dataset.py index ec50d9835..5138ce74d 100644 --- a/src/safeds/data/labeled/containers/_image_dataset.py +++ b/src/safeds/data/labeled/containers/_image_dataset.py @@ -43,7 +43,7 @@ class ImageDataset(Dataset[ImageList, Out_co]): batch_size: the batch size used for training shuffle: - weather the data should be shuffled after each epoch of training + whether the data should be shuffled after each epoch of training """ def __init__(self, input_data: ImageList, output_data: Out_co, batch_size: int = 1, shuffle: bool = False) -> None: @@ -108,13 +108,13 @@ def __iter__(self) -> ImageDataset: return im_ds def __next__(self) -> tuple[Tensor, Tensor]: - if self._next_batch_index * self._batch_size >= len(self._input): + if self._next_batch_index * self._batch_size >= len(self._shuffle_tensor_indices): raise StopIteration self._next_batch_index += 1 return self._get_batch(self._next_batch_index - 1) def __len__(self) -> int: - return self._input.image_count + return len(self._shuffle_tensor_indices) def __eq__(self, other: object) -> bool: """ @@ -138,6 +138,7 @@ def __eq__(self, other: object) -> bool: and isinstance(other._output, type(self._output)) and (self._input == other._input) and (self._output == other._output) + and (self._shuffle_tensor_indices.tolist() == other._shuffle_tensor_indices.tolist()) ) def __hash__(self) -> int: @@ -149,7 +150,7 @@ def __hash__(self) -> int: hash: the hash value """ - return _structural_hash(self._input, self._output, self._shuffle_after_epoch, self._batch_size) + return _structural_hash(self._input, self._output, self._shuffle_after_epoch, self._batch_size, self._shuffle_tensor_indices.tolist()) def __sizeof__(self) -> int: """ @@ -205,7 +206,7 @@ def get_input(self) -> ImageList: input: the input data of this dataset """ - return self._sort_image_list_with_shuffle_tensor_indices(self._input) + return self._sort_image_list_with_shuffle_tensor_indices_reduce_if_necessary(self._input) def get_output(self) -> Out_co: """ @@ -222,19 +223,22 @@ def get_output(self) -> Out_co: elif isinstance(output, _ColumnAsTensor): return output._to_column(self._shuffle_tensor_indices) # type: ignore[return-value] else: - return self._sort_image_list_with_shuffle_tensor_indices(self._output) # type: ignore[return-value] + return self._sort_image_list_with_shuffle_tensor_indices_reduce_if_necessary(self._output) # type: ignore[return-value] - def _sort_image_list_with_shuffle_tensor_indices(self, image_list: _SingleSizeImageList) -> _SingleSizeImageList: + def _sort_image_list_with_shuffle_tensor_indices_reduce_if_necessary(self, image_list: _SingleSizeImageList) -> _SingleSizeImageList: shuffled_image_list = _SingleSizeImageList() - shuffled_image_list._tensor = image_list._tensor - shuffled_image_list._indices_to_tensor_positions = { - index: self._shuffle_tensor_indices[tensor_position].item() - for index, tensor_position in image_list._indices_to_tensor_positions.items() + tensor_pos = [ + image_list._indices_to_tensor_positions[shuffled_index] + for shuffled_index in sorted(self._shuffle_tensor_indices.tolist()) + ] + temp_pos = { + shuffled_index: new_index for new_index, shuffled_index in enumerate(self._shuffle_tensor_indices.tolist()) } + shuffled_image_list._tensor = image_list._tensor[tensor_pos] shuffled_image_list._tensor_positions_to_indices = [ - index - for index, _ in sorted(shuffled_image_list._indices_to_tensor_positions.items(), key=lambda item: item[1]) + new_index for _, new_index in sorted(temp_pos.items(), key=lambda item: item[0]) ] + shuffled_image_list._indices_to_tensor_positions = shuffled_image_list._calc_new_indices_to_tensor_positions() return shuffled_image_list def _get_batch(self, batch_number: int, batch_size: int | None = None) -> tuple[Tensor, Tensor]: @@ -247,18 +251,16 @@ def _get_batch(self, batch_number: int, batch_size: int | None = None) -> tuple[ _check_bounds("batch_size", batch_size, lower_bound=_ClosedBound(1)) - if batch_number < 0 or batch_size * batch_number >= len(self._input): + if batch_number < 0 or batch_size * batch_number >= len(self._shuffle_tensor_indices): raise IndexOutOfBoundsError(batch_size * batch_number) max_index = ( - batch_size * (batch_number + 1) if batch_size * (batch_number + 1) < len(self._input) else len(self._input) + batch_size * (batch_number + 1) if batch_size * (batch_number + 1) < len(self._shuffle_tensor_indices) else len(self._shuffle_tensor_indices) ) input_tensor = ( self._input._tensor[ - self._shuffle_tensor_indices[ - [ - self._input._indices_to_tensor_positions[index] - for index in range(batch_size * batch_number, max_index) - ] + [ + self._input._indices_to_tensor_positions[index] + for index in self._shuffle_tensor_indices[batch_size * batch_number: max_index].tolist() ] ].to(torch.float32) / 255 @@ -267,24 +269,22 @@ def _get_batch(self, batch_number: int, batch_size: int | None = None) -> tuple[ if isinstance(self._output, _SingleSizeImageList): output_tensor = ( self._output._tensor[ - self._shuffle_tensor_indices[ - [ - self._output._indices_to_tensor_positions[index] - for index in range(batch_size * batch_number, max_index) - ] + [ + self._input._indices_to_tensor_positions[index] + for index in self._shuffle_tensor_indices[batch_size * batch_number: max_index].tolist() ] ].to(torch.float32) / 255 ) else: # _output is instance of _TableAsTensor or _ColumnAsTensor - output_tensor = self._output._tensor[self._shuffle_tensor_indices[batch_size * batch_number : max_index]] + output_tensor = self._output._tensor[self._shuffle_tensor_indices[batch_size * batch_number: max_index]] return input_tensor, output_tensor def shuffle(self) -> ImageDataset[Out_co]: """ Return a new `ImageDataset` with shuffled data. - The original dataset list is not modified. + The original dataset is not modified. Returns ------- @@ -296,10 +296,64 @@ def shuffle(self) -> ImageDataset[Out_co]: _init_default_device() im_dataset: ImageDataset[Out_co] = copy.copy(self) - im_dataset._shuffle_tensor_indices = torch.randperm(len(self)) + im_dataset._shuffle_tensor_indices = self._shuffle_tensor_indices[torch.randperm(len(self._shuffle_tensor_indices))] im_dataset._next_batch_index = 0 return im_dataset + def split(self, percentage_in_first: float, *, shuffle: bool = True) -> tuple[ImageDataset[Out_co], ImageDataset[Out_co]]: + """ + Create two image datasets by splitting the data of the current dataset. + + The first dataset contains a percentage of the data specified by `percentage_in_first`, and the second dataset + contains the remaining data. + + The original dataset is not modified. + By default, the data is shuffled before splitting. You can disable this by setting `shuffle` to False. + + Parameters + ---------- + percentage_in_first: + The percentage of data to include in the first dataset. Must be between 0 and 1. + shuffle: + Whether to shuffle the data before splitting. + + Returns + ------- + first_dataset: + The first dataset. + second_dataset: + The second dataset. + + Raises + ------ + OutOfBoundsError + If `percentage_in_first` is not between 0 and 1. + """ + import torch + + _check_bounds( + "percentage_in_first", + percentage_in_first, + lower_bound=_ClosedBound(0), + upper_bound=_ClosedBound(1), + ) + + first_dataset: ImageDataset[Out_co] = copy.copy(self) + second_dataset: ImageDataset[Out_co] = copy.copy(self) + + if shuffle: + shuffled_indices = torch.randperm(len(self._shuffle_tensor_indices)) + else: + shuffled_indices = torch.arange(len(self._shuffle_tensor_indices)) + + first_dataset._shuffle_tensor_indices, second_dataset._shuffle_tensor_indices = shuffled_indices.split( + [ + round(percentage_in_first * len(self)), + len(self) - round(percentage_in_first * len(self)), + ] + ) + return first_dataset, second_dataset + class _TableAsTensor: def __init__(self, table: Table) -> None: diff --git a/tests/safeds/data/labeled/containers/test_image_dataset.py b/tests/safeds/data/labeled/containers/test_image_dataset.py index 0c487d3b3..0dfc5a030 100644 --- a/tests/safeds/data/labeled/containers/test_image_dataset.py +++ b/tests/safeds/data/labeled/containers/test_image_dataset.py @@ -381,6 +381,104 @@ def test_get_batch_device(self, device: Device) -> None: assert batch[1].device == _get_device() +@pytest.mark.parametrize("device", get_devices(), ids=get_devices_ids()) +@pytest.mark.parametrize( + "shuffle", + [ + True, + False + ] +) +class TestSplit: + + @pytest.mark.parametrize( + "output", + [ + Column("images", images_all()[:4] + images_all()[5:]), + Table({"0": [1, 0, 0, 0, 0, 0], "1": [0, 1, 0, 0, 0, 0], "2": [0, 0, 1, 0, 0, 0], "3": [0, 0, 0, 1, 0, 0], "4": [0, 0, 0, 0, 1, 0], "5": [0, 0, 0, 0, 0, 1]}), + _EmptyImageList(), + ], + ids=["Column", "Table", "ImageList"] + ) + def test_should_split(self, device: Device, shuffle: bool, output: Column | Table | ImageList) -> None: + configure_test_with_device(device) + image_list = ImageList.from_files(resolve_resource_path(images_all())).remove_duplicate_images().resize(10, 10) + if isinstance(output, _EmptyImageList): + output = image_list + image_dataset = ImageDataset(image_list, output) + image_dataset1, image_dataset2 = image_dataset.split(0.4, shuffle=shuffle) + offset = len(image_dataset1) + assert len(image_dataset1) == round(0.4 * len(image_dataset)) + assert len(image_dataset2) == len(image_dataset) - offset + assert len(image_dataset1.get_input()) == round(0.4 * len(image_dataset)) + assert len(image_dataset2.get_input()) == len(image_dataset) - offset + if isinstance(image_dataset.get_output(), Table): + assert image_dataset1.get_output().row_count == round(0.4 * len(image_dataset)) + assert image_dataset2.get_output().row_count == len(image_dataset) - offset + else: + assert len(image_dataset1.get_output()) == round(0.4 * len(image_dataset)) + assert len(image_dataset2.get_output()) == len(image_dataset) - offset + + assert image_dataset != image_dataset1 + assert image_dataset != image_dataset2 + assert image_dataset1 != image_dataset2 + + for i, image in enumerate(image_dataset1.get_input().to_images()): + index = image_list.index(image)[0] + if not shuffle: + assert index == i + if isinstance(image_dataset1.get_output(), ImageList): + assert image_list.index(image_dataset1.get_output().get_image(i))[0] == index + elif isinstance(image_dataset1.get_output(), Column): + assert output.to_list().index(image_dataset1.get_output().to_list()[i]) == index + elif isinstance(image_dataset1.get_output(), Table): + assert output.get_column(str(index)).to_list()[index] == 1 + + for i, image in enumerate(image_dataset2.get_input().to_images()): + index = image_list.index(image)[0] + if not shuffle: + assert index == i + offset + if isinstance(image_dataset2.get_output(), ImageList): + assert image_list.index(image_dataset2.get_output().get_image(i))[0] == index + elif isinstance(image_dataset2.get_output(), Column): + assert output.to_list().index(image_dataset2.get_output().to_list()[i]) == index + elif isinstance(image_dataset2.get_output(), Table): + assert output.get_column(str(index)).to_list()[index] == 1 + + image_dataset._batch_size = len(image_dataset) + image_dataset1._batch_size = 1 + image_dataset2._batch_size = 1 + image_dataset_batch = next(iter(image_dataset)) + + for i, b in enumerate(iter(image_dataset1)): + assert b[0] in image_dataset_batch[0] + index = (b[0] == image_dataset_batch[0]).all(dim=[1, 2, 3]).nonzero()[0][0] + if not shuffle: + assert index == i + assert torch.all(torch.eq(b[0], image_dataset_batch[0][index])) + assert torch.all(torch.eq(b[1], image_dataset_batch[1][index])) + + for i, b in enumerate(iter(image_dataset2)): + assert b[0] in image_dataset_batch[0] + index = (b[0] == image_dataset_batch[0]).all(dim=[1, 2, 3]).nonzero()[0][0] + if not shuffle: + assert index == i + offset + assert torch.all(torch.eq(b[0], image_dataset_batch[0][index])) + assert torch.all(torch.eq(b[1], image_dataset_batch[1][index])) + + @pytest.mark.parametrize( + "percentage", + [-1, -0.1, 1.1, 2] + ) + def test_should_raise(self, device: Device, shuffle: bool, percentage: float): + configure_test_with_device(device) + image_list = ImageList.from_files(resolve_resource_path(images_all())).resize(10, 10) + image_dataset = ImageDataset(image_list, Column("images", images_all())) + with pytest.raises(OutOfBoundsError): + image_dataset.split(percentage, shuffle=shuffle) + + + @pytest.mark.parametrize("device", get_devices(), ids=get_devices_ids()) class TestTableAsTensor: def test_should_raise_if_not_one_hot_encoded(self, device: Device) -> None: From 4d7cd67e365e95bd07960652e0ff68731d4fce1c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alexander=20Gr=C3=A9us?= Date: Thu, 20 Jun 2024 21:51:44 +0200 Subject: [PATCH 2/5] refactor: linter --- .../labeled/containers/test_image_dataset.py | 32 +++++++++++-------- 1 file changed, 18 insertions(+), 14 deletions(-) diff --git a/tests/safeds/data/labeled/containers/test_image_dataset.py b/tests/safeds/data/labeled/containers/test_image_dataset.py index 0dfc5a030..fed6ef483 100644 --- a/tests/safeds/data/labeled/containers/test_image_dataset.py +++ b/tests/safeds/data/labeled/containers/test_image_dataset.py @@ -405,18 +405,20 @@ def test_should_split(self, device: Device, shuffle: bool, output: Column | Tabl image_list = ImageList.from_files(resolve_resource_path(images_all())).remove_duplicate_images().resize(10, 10) if isinstance(output, _EmptyImageList): output = image_list - image_dataset = ImageDataset(image_list, output) + image_dataset = ImageDataset(image_list, output) # type: ignore[type-var] image_dataset1, image_dataset2 = image_dataset.split(0.4, shuffle=shuffle) offset = len(image_dataset1) assert len(image_dataset1) == round(0.4 * len(image_dataset)) assert len(image_dataset2) == len(image_dataset) - offset assert len(image_dataset1.get_input()) == round(0.4 * len(image_dataset)) assert len(image_dataset2.get_input()) == len(image_dataset) - offset - if isinstance(image_dataset.get_output(), Table): + if isinstance(image_dataset1.get_output(), Table): assert image_dataset1.get_output().row_count == round(0.4 * len(image_dataset)) - assert image_dataset2.get_output().row_count == len(image_dataset) - offset else: assert len(image_dataset1.get_output()) == round(0.4 * len(image_dataset)) + if isinstance(image_dataset2.get_output(), Table): + assert image_dataset2.get_output().row_count == len(image_dataset) - offset + else: assert len(image_dataset2.get_output()) == len(image_dataset) - offset assert image_dataset != image_dataset1 @@ -427,22 +429,24 @@ def test_should_split(self, device: Device, shuffle: bool, output: Column | Tabl index = image_list.index(image)[0] if not shuffle: assert index == i - if isinstance(image_dataset1.get_output(), ImageList): - assert image_list.index(image_dataset1.get_output().get_image(i))[0] == index - elif isinstance(image_dataset1.get_output(), Column): - assert output.to_list().index(image_dataset1.get_output().to_list()[i]) == index - elif isinstance(image_dataset1.get_output(), Table): + out = image_dataset1.get_output() + if isinstance(out, ImageList): + assert image_list.index(out.get_image(i))[0] == index + elif isinstance(out, Column): + assert output.to_list().index(out.to_list()[i]) == index + elif isinstance(out, Table): assert output.get_column(str(index)).to_list()[index] == 1 for i, image in enumerate(image_dataset2.get_input().to_images()): index = image_list.index(image)[0] if not shuffle: assert index == i + offset - if isinstance(image_dataset2.get_output(), ImageList): - assert image_list.index(image_dataset2.get_output().get_image(i))[0] == index - elif isinstance(image_dataset2.get_output(), Column): - assert output.to_list().index(image_dataset2.get_output().to_list()[i]) == index - elif isinstance(image_dataset2.get_output(), Table): + out = image_dataset2.get_output() + if isinstance(out, ImageList): + assert image_list.index(out.get_image(i))[0] == index + elif isinstance(out, Column): + assert output.to_list().index(out.to_list()[i]) == index + elif isinstance(out, Table): assert output.get_column(str(index)).to_list()[index] == 1 image_dataset._batch_size = len(image_dataset) @@ -470,7 +474,7 @@ def test_should_split(self, device: Device, shuffle: bool, output: Column | Tabl "percentage", [-1, -0.1, 1.1, 2] ) - def test_should_raise(self, device: Device, shuffle: bool, percentage: float): + def test_should_raise(self, device: Device, shuffle: bool, percentage: float) -> None: configure_test_with_device(device) image_list = ImageList.from_files(resolve_resource_path(images_all())).resize(10, 10) image_dataset = ImageDataset(image_list, Column("images", images_all())) From 96cbfacb0c7b16ac4bdf918e49cfd4d54169a53b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alexander=20Gr=C3=A9us?= Date: Thu, 20 Jun 2024 22:01:56 +0200 Subject: [PATCH 3/5] refactor: linter --- .../labeled/containers/test_image_dataset.py | 22 ++++++++++--------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/tests/safeds/data/labeled/containers/test_image_dataset.py b/tests/safeds/data/labeled/containers/test_image_dataset.py index fed6ef483..5ab03ccb2 100644 --- a/tests/safeds/data/labeled/containers/test_image_dataset.py +++ b/tests/safeds/data/labeled/containers/test_image_dataset.py @@ -412,14 +412,16 @@ def test_should_split(self, device: Device, shuffle: bool, output: Column | Tabl assert len(image_dataset2) == len(image_dataset) - offset assert len(image_dataset1.get_input()) == round(0.4 * len(image_dataset)) assert len(image_dataset2.get_input()) == len(image_dataset) - offset - if isinstance(image_dataset1.get_output(), Table): - assert image_dataset1.get_output().row_count == round(0.4 * len(image_dataset)) + im1_output = image_dataset1.get_output() + im2_output = image_dataset2.get_output() + if isinstance(im1_output, Table): + assert im1_output.row_count == round(0.4 * len(image_dataset)) else: - assert len(image_dataset1.get_output()) == round(0.4 * len(image_dataset)) - if isinstance(image_dataset2.get_output(), Table): - assert image_dataset2.get_output().row_count == len(image_dataset) - offset + assert len(im1_output) == round(0.4 * len(image_dataset)) + if isinstance(im2_output, Table): + assert im2_output.row_count == len(image_dataset) - offset else: - assert len(image_dataset2.get_output()) == len(image_dataset) - offset + assert len(im2_output) == len(image_dataset) - offset assert image_dataset != image_dataset1 assert image_dataset != image_dataset2 @@ -432,9 +434,9 @@ def test_should_split(self, device: Device, shuffle: bool, output: Column | Tabl out = image_dataset1.get_output() if isinstance(out, ImageList): assert image_list.index(out.get_image(i))[0] == index - elif isinstance(out, Column): + elif isinstance(out, Column) and isinstance(output, Column): assert output.to_list().index(out.to_list()[i]) == index - elif isinstance(out, Table): + elif isinstance(out, Table) and isinstance(output, Table): assert output.get_column(str(index)).to_list()[index] == 1 for i, image in enumerate(image_dataset2.get_input().to_images()): @@ -444,9 +446,9 @@ def test_should_split(self, device: Device, shuffle: bool, output: Column | Tabl out = image_dataset2.get_output() if isinstance(out, ImageList): assert image_list.index(out.get_image(i))[0] == index - elif isinstance(out, Column): + elif isinstance(out, Column) and isinstance(output, Column): assert output.to_list().index(out.to_list()[i]) == index - elif isinstance(out, Table): + elif isinstance(out, Table) and isinstance(output, Table): assert output.get_column(str(index)).to_list()[index] == 1 image_dataset._batch_size = len(image_dataset) From 18b54d347252723ad3567d6e45ad5b42fa51abf3 Mon Sep 17 00:00:00 2001 From: megalinter-bot <129584137+megalinter-bot@users.noreply.github.com> Date: Thu, 20 Jun 2024 20:03:29 +0000 Subject: [PATCH 4/5] style: apply automated linter fixes --- .../data/labeled/containers/_image_dataset.py | 32 +++++++++++++------ .../labeled/containers/test_image_dataset.py | 27 ++++++++-------- 2 files changed, 36 insertions(+), 23 deletions(-) diff --git a/src/safeds/data/labeled/containers/_image_dataset.py b/src/safeds/data/labeled/containers/_image_dataset.py index 5138ce74d..6ae1f7121 100644 --- a/src/safeds/data/labeled/containers/_image_dataset.py +++ b/src/safeds/data/labeled/containers/_image_dataset.py @@ -150,7 +150,13 @@ def __hash__(self) -> int: hash: the hash value """ - return _structural_hash(self._input, self._output, self._shuffle_after_epoch, self._batch_size, self._shuffle_tensor_indices.tolist()) + return _structural_hash( + self._input, + self._output, + self._shuffle_after_epoch, + self._batch_size, + self._shuffle_tensor_indices.tolist(), + ) def __sizeof__(self) -> int: """ @@ -225,7 +231,9 @@ def get_output(self) -> Out_co: else: return self._sort_image_list_with_shuffle_tensor_indices_reduce_if_necessary(self._output) # type: ignore[return-value] - def _sort_image_list_with_shuffle_tensor_indices_reduce_if_necessary(self, image_list: _SingleSizeImageList) -> _SingleSizeImageList: + def _sort_image_list_with_shuffle_tensor_indices_reduce_if_necessary( + self, image_list: _SingleSizeImageList, + ) -> _SingleSizeImageList: shuffled_image_list = _SingleSizeImageList() tensor_pos = [ image_list._indices_to_tensor_positions[shuffled_index] @@ -254,13 +262,15 @@ def _get_batch(self, batch_number: int, batch_size: int | None = None) -> tuple[ if batch_number < 0 or batch_size * batch_number >= len(self._shuffle_tensor_indices): raise IndexOutOfBoundsError(batch_size * batch_number) max_index = ( - batch_size * (batch_number + 1) if batch_size * (batch_number + 1) < len(self._shuffle_tensor_indices) else len(self._shuffle_tensor_indices) + batch_size * (batch_number + 1) + if batch_size * (batch_number + 1) < len(self._shuffle_tensor_indices) + else len(self._shuffle_tensor_indices) ) input_tensor = ( self._input._tensor[ [ self._input._indices_to_tensor_positions[index] - for index in self._shuffle_tensor_indices[batch_size * batch_number: max_index].tolist() + for index in self._shuffle_tensor_indices[batch_size * batch_number : max_index].tolist() ] ].to(torch.float32) / 255 @@ -271,13 +281,13 @@ def _get_batch(self, batch_number: int, batch_size: int | None = None) -> tuple[ self._output._tensor[ [ self._input._indices_to_tensor_positions[index] - for index in self._shuffle_tensor_indices[batch_size * batch_number: max_index].tolist() + for index in self._shuffle_tensor_indices[batch_size * batch_number : max_index].tolist() ] ].to(torch.float32) / 255 ) else: # _output is instance of _TableAsTensor or _ColumnAsTensor - output_tensor = self._output._tensor[self._shuffle_tensor_indices[batch_size * batch_number: max_index]] + output_tensor = self._output._tensor[self._shuffle_tensor_indices[batch_size * batch_number : max_index]] return input_tensor, output_tensor def shuffle(self) -> ImageDataset[Out_co]: @@ -296,11 +306,15 @@ def shuffle(self) -> ImageDataset[Out_co]: _init_default_device() im_dataset: ImageDataset[Out_co] = copy.copy(self) - im_dataset._shuffle_tensor_indices = self._shuffle_tensor_indices[torch.randperm(len(self._shuffle_tensor_indices))] + im_dataset._shuffle_tensor_indices = self._shuffle_tensor_indices[ + torch.randperm(len(self._shuffle_tensor_indices)) + ] im_dataset._next_batch_index = 0 return im_dataset - def split(self, percentage_in_first: float, *, shuffle: bool = True) -> tuple[ImageDataset[Out_co], ImageDataset[Out_co]]: + def split( + self, percentage_in_first: float, *, shuffle: bool = True, + ) -> tuple[ImageDataset[Out_co], ImageDataset[Out_co]]: """ Create two image datasets by splitting the data of the current dataset. @@ -350,7 +364,7 @@ def split(self, percentage_in_first: float, *, shuffle: bool = True) -> tuple[Im [ round(percentage_in_first * len(self)), len(self) - round(percentage_in_first * len(self)), - ] + ], ) return first_dataset, second_dataset diff --git a/tests/safeds/data/labeled/containers/test_image_dataset.py b/tests/safeds/data/labeled/containers/test_image_dataset.py index 5ab03ccb2..bd02f32ce 100644 --- a/tests/safeds/data/labeled/containers/test_image_dataset.py +++ b/tests/safeds/data/labeled/containers/test_image_dataset.py @@ -382,23 +382,26 @@ def test_get_batch_device(self, device: Device) -> None: @pytest.mark.parametrize("device", get_devices(), ids=get_devices_ids()) -@pytest.mark.parametrize( - "shuffle", - [ - True, - False - ] -) +@pytest.mark.parametrize("shuffle", [True, False]) class TestSplit: @pytest.mark.parametrize( "output", [ Column("images", images_all()[:4] + images_all()[5:]), - Table({"0": [1, 0, 0, 0, 0, 0], "1": [0, 1, 0, 0, 0, 0], "2": [0, 0, 1, 0, 0, 0], "3": [0, 0, 0, 1, 0, 0], "4": [0, 0, 0, 0, 1, 0], "5": [0, 0, 0, 0, 0, 1]}), + Table( + { + "0": [1, 0, 0, 0, 0, 0], + "1": [0, 1, 0, 0, 0, 0], + "2": [0, 0, 1, 0, 0, 0], + "3": [0, 0, 0, 1, 0, 0], + "4": [0, 0, 0, 0, 1, 0], + "5": [0, 0, 0, 0, 0, 1], + }, + ), _EmptyImageList(), ], - ids=["Column", "Table", "ImageList"] + ids=["Column", "Table", "ImageList"], ) def test_should_split(self, device: Device, shuffle: bool, output: Column | Table | ImageList) -> None: configure_test_with_device(device) @@ -472,10 +475,7 @@ def test_should_split(self, device: Device, shuffle: bool, output: Column | Tabl assert torch.all(torch.eq(b[0], image_dataset_batch[0][index])) assert torch.all(torch.eq(b[1], image_dataset_batch[1][index])) - @pytest.mark.parametrize( - "percentage", - [-1, -0.1, 1.1, 2] - ) + @pytest.mark.parametrize("percentage", [-1, -0.1, 1.1, 2]) def test_should_raise(self, device: Device, shuffle: bool, percentage: float) -> None: configure_test_with_device(device) image_list = ImageList.from_files(resolve_resource_path(images_all())).resize(10, 10) @@ -484,7 +484,6 @@ def test_should_raise(self, device: Device, shuffle: bool, percentage: float) -> image_dataset.split(percentage, shuffle=shuffle) - @pytest.mark.parametrize("device", get_devices(), ids=get_devices_ids()) class TestTableAsTensor: def test_should_raise_if_not_one_hot_encoded(self, device: Device) -> None: From e969b0887fec1387d4c2045655e52ff3916d98fa Mon Sep 17 00:00:00 2001 From: megalinter-bot <129584137+megalinter-bot@users.noreply.github.com> Date: Thu, 20 Jun 2024 20:05:00 +0000 Subject: [PATCH 5/5] style: apply automated linter fixes --- src/safeds/data/labeled/containers/_image_dataset.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/safeds/data/labeled/containers/_image_dataset.py b/src/safeds/data/labeled/containers/_image_dataset.py index 6ae1f7121..c852c11a8 100644 --- a/src/safeds/data/labeled/containers/_image_dataset.py +++ b/src/safeds/data/labeled/containers/_image_dataset.py @@ -232,7 +232,8 @@ def get_output(self) -> Out_co: return self._sort_image_list_with_shuffle_tensor_indices_reduce_if_necessary(self._output) # type: ignore[return-value] def _sort_image_list_with_shuffle_tensor_indices_reduce_if_necessary( - self, image_list: _SingleSizeImageList, + self, + image_list: _SingleSizeImageList, ) -> _SingleSizeImageList: shuffled_image_list = _SingleSizeImageList() tensor_pos = [ @@ -313,7 +314,10 @@ def shuffle(self) -> ImageDataset[Out_co]: return im_dataset def split( - self, percentage_in_first: float, *, shuffle: bool = True, + self, + percentage_in_first: float, + *, + shuffle: bool = True, ) -> tuple[ImageDataset[Out_co], ImageDataset[Out_co]]: """ Create two image datasets by splitting the data of the current dataset.