From cda6e11eb89828ade70d3c342fff8bb955cb2b69 Mon Sep 17 00:00:00 2001 From: sofiagilardini Date: Sun, 6 Aug 2023 18:43:38 +0100 Subject: [PATCH 1/4] fix type bug in to() method --- cebra/integrations/sklearn/cebra.py | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/cebra/integrations/sklearn/cebra.py b/cebra/integrations/sklearn/cebra.py index 234297c4..56e38678 100644 --- a/cebra/integrations/sklearn/cebra.py +++ b/cebra/integrations/sklearn/cebra.py @@ -1282,21 +1282,21 @@ def to(self, device: Union[str, torch.device]): raise TypeError( "The 'device' parameter must be a string or torch.device object." ) - - if (not device == 'cpu') and (not device.startswith('cuda')) and ( - not device == 'mps'): - raise ValueError( - "The 'device' parameter must be a valid device string or device object." - ) - + if isinstance(device, str): - device = torch.device(device) + if (not device == 'cpu') and (not device.startswith('cuda')) and ( + not device == 'mps'): + raise ValueError( + "The 'device' parameter must be a valid device string or device object." + ) - if (not device.type == 'cpu') and ( - not device.type.startswith('cuda')) and (not device == 'mps'): - raise ValueError( - "The 'device' parameter must be a valid device string or device object." - ) + elif isinstance(device, torch.device): + if (not device.type == 'cpu') and ( + not device.type.startswith('cuda')) and (not device == 'mps'): + raise ValueError( + "The 'device' parameter must be a valid device string or device object." + ) + device = device.type if hasattr(self, "device_"): self.device_ = device From 118c099cc39594890c8311141866e8c7a10b8402 Mon Sep 17 00:00:00 2001 From: sofiagilardini Date: Sun, 13 Aug 2023 22:45:23 +0100 Subject: [PATCH 2/4] added test for to() method --- tests/test_sklearn.py | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/tests/test_sklearn.py b/tests/test_sklearn.py index ee509578..06b5a1aa 100644 --- a/tests/test_sklearn.py +++ b/tests/test_sklearn.py @@ -856,6 +856,39 @@ def get_ordered_cuda_devices(): ordered_cuda_devices = get_ordered_cuda_devices() if torch.cuda.is_available( ) else [] +def test_fit_after_moving_to_device(): + expected_device = 'cpu' + expected_type = type(expected_device) + + X = np.random.uniform(0, 1, (10, 5)) + cebra_model = cebra_sklearn_cebra.CEBRA(model_architecture="offset1-model", + max_iterations=5, + device=expected_device) + + assert type(cebra_model.device) == expected_type + assert cebra_model.device == expected_device + + cebra_model.partial_fit(X) + assert type(cebra_model.device) == expected_type + assert cebra_model.device == expected_device + if hasattr(cebra_model, 'device_'): + assert type(cebra_model.device_) == expected_type + assert cebra_model.device_ == expected_device + + # Move the model to device using the to() method + cebra_model.to('cpu') + assert type(cebra_model.device) == expected_type + assert cebra_model.device == expected_device + if hasattr(cebra_model, 'device_'): + assert type(cebra_model.device_) == expected_type + assert cebra_model.device_ == expected_device + + cebra_model.partial_fit(X) + assert type(cebra_model.device) == expected_type + assert cebra_model.device == expected_device + if hasattr(cebra_model, 'device_'): + assert type(cebra_model.device_) == expected_type + assert cebra_model.device_ == expected_device @pytest.mark.parametrize("device", ['cpu'] + ordered_cuda_devices) def test_move_cpu_to_cuda_device(device): From 0533dcd265cb9c0920943a97b813beff99db0366 Mon Sep 17 00:00:00 2001 From: Rodrigo Date: Mon, 11 Sep 2023 11:12:58 +0200 Subject: [PATCH 3/4] update tests --- tests/test_sklearn.py | 27 ++++++++++++++++++--------- 1 file changed, 18 insertions(+), 9 deletions(-) diff --git a/tests/test_sklearn.py b/tests/test_sklearn.py index 06b5a1aa..02148116 100644 --- a/tests/test_sklearn.py +++ b/tests/test_sklearn.py @@ -908,9 +908,12 @@ def test_move_cpu_to_cuda_device(device): new_device = 'cpu' if device.startswith('cuda') else 'cuda:0' cebra_model.to(new_device) - assert cebra_model.device == torch.device(new_device) - assert next(cebra_model.solver_.model.parameters()).device == torch.device( - new_device) + assert cebra_model.device == new_device + device_model = next(cebra_model.solver_.model.parameters()).device + device_str = str(device_model) + if device_model.type == 'cuda': + device_str = f'cuda:{device_model.index}' + assert device_str == new_device with tempfile.NamedTemporaryFile(mode="w+b", delete=True) as savefile: cebra_model.save(savefile.name) @@ -936,9 +939,12 @@ def test_move_cpu_to_mps_device(device): new_device = 'cpu' if device == 'mps' else 'mps' cebra_model.to(new_device) - assert cebra_model.device == torch.device(new_device) - assert next(cebra_model.solver_.model.parameters()).device == torch.device( - new_device) + assert cebra_model.device == new_device + device_model = next(cebra_model.solver_.model.parameters()).device + device_str = str(device_model) + if device_model.type == 'cuda': + device_str = f'cuda:{device_model.index}' + assert device_str == new_device with tempfile.NamedTemporaryFile(mode="w+b", delete=True) as savefile: cebra_model.save(savefile.name) @@ -972,9 +978,12 @@ def test_move_mps_to_cuda_device(device): new_device = 'mps' if device.startswith('cuda') else 'cuda:0' cebra_model.to(new_device) - assert cebra_model.device == torch.device(new_device) - assert next(cebra_model.solver_.model.parameters()).device == torch.device( - new_device) + assert cebra_model.device == new_device + device_model = next(cebra_model.solver_.model.parameters()).device + device_str = str(device_model) + if device_model.type == 'cuda': + device_str = f'cuda:{device_model.index}' + assert device_str == new_device with tempfile.NamedTemporaryFile(mode="w+b", delete=True) as savefile: cebra_model.save(savefile.name) From 0964ec24bbc11ac8b335d30790f2f70c5fa346be Mon Sep 17 00:00:00 2001 From: Rodrigo Date: Tue, 12 Sep 2023 10:14:55 +0200 Subject: [PATCH 4/4] update tests --- tests/test_sklearn.py | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/tests/test_sklearn.py b/tests/test_sklearn.py index 02148116..ff07979a 100644 --- a/tests/test_sklearn.py +++ b/tests/test_sklearn.py @@ -940,11 +940,9 @@ def test_move_cpu_to_mps_device(device): cebra_model.to(new_device) assert cebra_model.device == new_device + device_model = next(cebra_model.solver_.model.parameters()).device - device_str = str(device_model) - if device_model.type == 'cuda': - device_str = f'cuda:{device_model.index}' - assert device_str == new_device + assert device_model.type == new_device with tempfile.NamedTemporaryFile(mode="w+b", delete=True) as savefile: cebra_model.save(savefile.name) @@ -1005,11 +1003,7 @@ def test_mps(): if torch.backends.mps.is_available() and torch.backends.mps.is_built(): torch.backends.mps.is_available = lambda: False - with pytest.raises(ValueError): - cebra_model.fit(X) - - torch.backends.mps.is_available = lambda: True - torch.backends.mps.is_built = lambda: False + with pytest.raises(ValueError): cebra_model.fit(X)