diff --git a/modopt/base/backend.py b/modopt/base/backend.py index def77151..5fbe912f 100644 --- a/modopt/base/backend.py +++ b/modopt/base/backend.py @@ -95,7 +95,7 @@ def get_array_module(input_data): if LIBRARIES['tensorflow'] is not None: if isinstance(input_data, LIBRARIES['tensorflow'].ndarray): return LIBRARIES['tensorflow'] - elif LIBRARIES['cupy'] is not None: + if LIBRARIES['cupy'] is not None: if isinstance(input_data, LIBRARIES['cupy'].ndarray): return LIBRARIES['cupy'] return np