diff --git a/.github/workflows/ci-build.yml b/.github/workflows/ci-build.yml index 2ed4960d..3ffcb6f4 100644 --- a/.github/workflows/ci-build.yml +++ b/.github/workflows/ci-build.yml @@ -53,6 +53,7 @@ jobs: python -m pip install -r develop.txt python -m pip install -r docs/requirements.txt python -m pip install astropy scikit-image scikit-learn + python -m pip install tensorflow>=2.4.1 python -m pip install twine python -m pip install . diff --git a/modopt/tests/test_base.py b/modopt/tests/test_base.py index 26e1e4ea..873a4506 100644 --- a/modopt/tests/test_base.py +++ b/modopt/tests/test_base.py @@ -9,12 +9,14 @@ """ from builtins import range -from unittest import TestCase +from unittest import TestCase, skipIf import numpy as np import numpy.testing as npt from modopt.base import np_adjust, transform, types +from modopt.base.backend import (LIBRARIES, change_backend, get_array_module, + get_backend) class NPAdjustTestCase(TestCase): @@ -275,3 +277,53 @@ def test_check_npndarray(self): self.data3, dtype=np.integer, ) + + +class TestBackend(TestCase): + """Test the backend codes.""" + + def setUp(self): + """Set test parameter values.""" + self.input = np.array([10, 10]) + + @skipIf(LIBRARIES['tensorflow'] is None, 'tensorflow library not installed') + def test_tf_backend(self): + """Test tensorflow backend.""" + xp, backend = get_backend('tensorflow') + if backend != 'tensorflow' or xp != LIBRARIES['tensorflow']: + raise AssertionError('tensorflow get_backend fails!') + tf_input = change_backend(self.input, 'tensorflow') + if ( + get_array_module(LIBRARIES['tensorflow'].ones(1)) != LIBRARIES['tensorflow'] + or get_array_module(tf_input) != LIBRARIES['tensorflow'] + ): + raise AssertionError('tensorflow backend fails!') + + @skipIf(LIBRARIES['cupy'] is None, 'cupy library not installed') + def test_cp_backend(self): + """Test cupy backend.""" + xp, backend = get_backend('cupy') + if backend != 'cupy' or xp != LIBRARIES['cupy']: + raise AssertionError('cupy get_backend fails!') + cp_input = change_backend(self.input, 'cupy') + if ( + get_array_module(LIBRARIES['cupy'].ones(1)) != LIBRARIES['cupy'] + or get_array_module(cp_input) != LIBRARIES['cupy'] + ): + raise AssertionError('cupy backend fails!') + + def test_np_backend(self): + """Test numpy backend.""" + xp, backend = get_backend('numpy') + if backend != 'numpy' or xp != LIBRARIES['numpy']: + raise AssertionError('numpy get_backend fails!') + np_input = change_backend(self.input, 'numpy') + if ( + get_array_module(LIBRARIES['numpy'].ones(1)) != LIBRARIES['numpy'] + or get_array_module(np_input) != LIBRARIES['numpy'] + ): + raise AssertionError('numpy backend fails!') + + def tearDown(self): + """Tear Down of objects.""" + self.input = None diff --git a/setup.cfg b/setup.cfg index 74fb3f79..d2f544f0 100644 --- a/setup.cfg +++ b/setup.cfg @@ -55,6 +55,8 @@ per-file-ignores = modopt/signal/wavelet.py: S404,S603 #Todo: Clean up tests modopt/tests/*.py: E731,F401,WPS301,WPS420,WPS425,WPS437,WPS604 + #Todo: Import has bad parenthesis + modopt/tests/test_base.py: WPS318,WPS319,E501,WPS301 #WPS Settings max-arguments = 25 max-attributes = 40