Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 13 additions & 13 deletions cebra/integrations/sklearn/cebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -1282,21 +1282,21 @@ def to(self, device: Union[str, torch.device]):
raise TypeError(
"The 'device' parameter must be a string or torch.device object."
)

if (not device == 'cpu') and (not device.startswith('cuda')) and (
not device == 'mps'):
raise ValueError(
"The 'device' parameter must be a valid device string or device object."
)


if isinstance(device, str):
device = torch.device(device)
if (not device == 'cpu') and (not device.startswith('cuda')) and (
not device == 'mps'):
raise ValueError(
"The 'device' parameter must be a valid device string or device object."
)

if (not device.type == 'cpu') and (
not device.type.startswith('cuda')) and (not device == 'mps'):
raise ValueError(
"The 'device' parameter must be a valid device string or device object."
)
elif isinstance(device, torch.device):
if (not device.type == 'cpu') and (
not device.type.startswith('cuda')) and (not device == 'mps'):
raise ValueError(
"The 'device' parameter must be a valid device string or device object."
)
device = device.type

if hasattr(self, "device_"):
self.device_ = device
Expand Down
64 changes: 50 additions & 14 deletions tests/test_sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -856,6 +856,39 @@ def get_ordered_cuda_devices():
ordered_cuda_devices = get_ordered_cuda_devices() if torch.cuda.is_available(
) else []

def test_fit_after_moving_to_device():
    """Regression test: `device`/`device_` must stay plain strings.

    Constructs a CEBRA model on CPU, then checks after every lifecycle step
    (construction, partial_fit, to(), partial_fit again) that the public
    `device` attribute — and the fitted `device_` attribute once it exists —
    is still a `str` equal to the requested device, i.e. `to()` does not
    replace it with a `torch.device` object.
    """
    expected_device = 'cpu'

    def _assert_device_attrs(model):
        # Checked after every step: attributes must remain plain strings.
        assert isinstance(model.device, str)
        assert model.device == expected_device
        # `device_` only exists after fitting; check it conditionally.
        if hasattr(model, 'device_'):
            assert isinstance(model.device_, str)
            assert model.device_ == expected_device

    X = np.random.uniform(0, 1, (10, 5))
    cebra_model = cebra_sklearn_cebra.CEBRA(model_architecture="offset1-model",
                                            max_iterations=5,
                                            device=expected_device)
    _assert_device_attrs(cebra_model)

    cebra_model.partial_fit(X)
    _assert_device_attrs(cebra_model)

    # Move the model to device using the to() method
    cebra_model.to('cpu')
    _assert_device_attrs(cebra_model)

    cebra_model.partial_fit(X)
    _assert_device_attrs(cebra_model)

@pytest.mark.parametrize("device", ['cpu'] + ordered_cuda_devices)
def test_move_cpu_to_cuda_device(device):
Expand All @@ -875,9 +908,12 @@ def test_move_cpu_to_cuda_device(device):
new_device = 'cpu' if device.startswith('cuda') else 'cuda:0'
cebra_model.to(new_device)

assert cebra_model.device == torch.device(new_device)
assert next(cebra_model.solver_.model.parameters()).device == torch.device(
new_device)
assert cebra_model.device == new_device
device_model = next(cebra_model.solver_.model.parameters()).device
device_str = str(device_model)
if device_model.type == 'cuda':
device_str = f'cuda:{device_model.index}'
assert device_str == new_device

with tempfile.NamedTemporaryFile(mode="w+b", delete=True) as savefile:
cebra_model.save(savefile.name)
Expand All @@ -903,9 +939,10 @@ def test_move_cpu_to_mps_device(device):
new_device = 'cpu' if device == 'mps' else 'mps'
cebra_model.to(new_device)

assert cebra_model.device == torch.device(new_device)
assert next(cebra_model.solver_.model.parameters()).device == torch.device(
new_device)
assert cebra_model.device == new_device

device_model = next(cebra_model.solver_.model.parameters()).device
assert device_model.type == new_device

with tempfile.NamedTemporaryFile(mode="w+b", delete=True) as savefile:
cebra_model.save(savefile.name)
Expand Down Expand Up @@ -939,9 +976,12 @@ def test_move_mps_to_cuda_device(device):
new_device = 'mps' if device.startswith('cuda') else 'cuda:0'
cebra_model.to(new_device)

assert cebra_model.device == torch.device(new_device)
assert next(cebra_model.solver_.model.parameters()).device == torch.device(
new_device)
assert cebra_model.device == new_device
device_model = next(cebra_model.solver_.model.parameters()).device
device_str = str(device_model)
if device_model.type == 'cuda':
device_str = f'cuda:{device_model.index}'
assert device_str == new_device

with tempfile.NamedTemporaryFile(mode="w+b", delete=True) as savefile:
cebra_model.save(savefile.name)
Expand All @@ -963,11 +1003,7 @@ def test_mps():

if torch.backends.mps.is_available() and torch.backends.mps.is_built():
torch.backends.mps.is_available = lambda: False
with pytest.raises(ValueError):
cebra_model.fit(X)

torch.backends.mps.is_available = lambda: True
torch.backends.mps.is_built = lambda: False

with pytest.raises(ValueError):
cebra_model.fit(X)

Expand Down