Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions MANIFEST.in
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@ include version
recursive-include spams_wrap *.h
recursive-include spams_wrap *.cpp

include spams/version

recursive-include spams/tests *.py
recursive-include spams/data *.png
recursive-include tests *.py
recursive-include data *.png
recursive-include doc *
21 changes: 19 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,26 @@ Manipulated objects are imported from numpy and scipy. Matrices should be stored

### Testing the interface

- From the command line (to be called from the project root directory):
```bash
python tests/test_spams.py -h # to get help
python tests/test_spams.py # will run all the tests
python tests/test_spams.py -h # print the man page
python tests/test_spams.py # run all the tests
```

- From Python (assuming `spams` package is installed):
```python
from spams.tests import test_spams

test_spams('-h') # print the man page
test_spams() # run all tests
test_spams(['sort', 'calcAAt']) # run specific tests
test_spams(python_exec='python3') # specify the python exec
```

- From the command line (assuming `spams` package is installed):
```bash
# c.f. previous point for the different options
python -c "from spams.tests import test_spams; test_spams()"
```

---
Expand Down
16 changes: 5 additions & 11 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# import platform
import sys

from setuptools import setup, Extension, find_packages
from setuptools import setup, Extension
from setuptools.command.build_ext import build_ext

from distutils.sysconfig import get_python_inc
Expand Down Expand Up @@ -193,18 +193,12 @@ def get_extension():
python_requires='>=3',
install_requires=['Cython>=0.29', 'numpy>=1.12',
'Pillow>=6.0', 'scipy>=1.0', 'six>=1.12'],
packages=find_packages(),
packages=['myscipy_rand', 'spams_wrap', 'spams', 'spams.tests'],
cmdclass={'build_ext': CustomBuildExtCommand},
ext_modules=get_extension(),
data_files=[('tests', ['tests/test_spams.py',
'tests/test_decomp.py',
'tests/test_dictLearn.py',
'tests/test_linalg.py',
'tests/test_prox.py',
'tests/test_utils.py']),
('doc', ['doc/doc_spams.pdf']),
('tests', ['data/boat.png', 'data/lena.png'])],
include_package_data=True,
package_data={
"spams": ["data/*.png", "version"]
},
zip_safe=True
)

Expand Down
1 change: 1 addition & 0 deletions spams/data
1 change: 1 addition & 0 deletions spams/tests
1 change: 1 addition & 0 deletions spams/version
1 change: 1 addition & 0 deletions tests/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .run import test_spams
59 changes: 59 additions & 0 deletions tests/run.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import subprocess
import os

def test_spams(args=None, python_exec = 'python'):
"""Run spams-python tests from python

To run the tests, from the command line, please check the README.md file

Input:
args (string|string list): argument to pass to the the main test script,
use `args='-h'` for more details. If None (default), all tests are
run.
python_exec (string): python executable. By default, it is `'python'`.
Python 3+ is required. On some system, you may have to set
`python_exec = 'python3'`.
"""

# check spams availability
try:
import spams
except:
raise ModuleNotFoundError("No module named 'spams'")

# main test file
test_file = os.path.join("tests", "test_spams.py")
if os.path.isfile(test_file):
test_file = os.path.abspath(test_file)
else:
test_file = os.path.join(
os.path.dirname(os.path.abspath(spams.__file__)), test_file
)

# test dir
test_dir = os.path.dirname(test_file)

# check python executable
try:
python_version = subprocess.run([python_exec, "-V"], capture_output=True)
except:
raise SystemError(f"{python_exec} is not a valid python executable")
# check python version
if not "Python 3." in python_version.stdout.decode('UTF-8'):
raise SystemError("Python 3+ is required, try using python_exec = 'python3'")

# check args
if args is None:
args = []
elif not (isinstance(args, str) or \
(isinstance(args, list) and \
all(isinstance(arg, str) for arg in args))):
raise TypeError("'args' input should be a string or a list of strings.")

if isinstance(args, str):
args = [args]

# run
subprocess.run([python_exec, 'test_spams.py'] + args, cwd = test_dir)


38 changes: 33 additions & 5 deletions tests/test_dictLearn.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import absolute_import, division, print_function

import sys
import os
import numpy as np
import scipy
import scipy.sparse as ssp
Expand All @@ -15,6 +16,33 @@
ssprand = myscipy_rand.rand
else:
ssprand = ssp.rand


def get_img_file_path(img):
"""Return path to an image file

Arguments:
img (string): image filename without path among 'boat.png' or
'lena.png'.

Output:
img_file (string): normalized path to image input filename.

"""
# check input
if not img in ["boat.png", "lena.png"]:
raise ValueError("bad input, `img` should be 'boat.png' or 'lena.png'")
# try local file
img_file = os.path.join("data", img)
if os.path.isfile(img_file):
img_file = os.path.abspath(img_file)
else:
# file from install
img_file = os.path.join(
os.path.dirname(os.path.abspath(spams.__file__)), img_file
)
# output
return img_file


def _extract_lasso_param(f_param):
Expand Down Expand Up @@ -46,8 +74,8 @@ def _objective(X, D, param, imgname=None):


def test_trainDL():
img_file = 'data/boat.png'
try:
img_file = get_img_file_path("boat.png")
img = Image.open(img_file)
except:
print("Cannot load image %s : skipping test" % img_file)
Expand Down Expand Up @@ -140,8 +168,8 @@ def test_trainDL():


def test_trainDL_Memory():
img_file = 'data/lena.png'
try:
img_file = get_img_file_path("lena.png")
img = Image.open(img_file)
except:
print("Cannot load image %s : skipping test" % img_file)
Expand Down Expand Up @@ -202,8 +230,8 @@ def test_trainDL_Memory():


def test_structTrainDL():
img_file = 'data/lena.png'
try:
img_file = get_img_file_path("lena.png")
img = Image.open(img_file)
except Exception as e:
print("Cannot load image %s (%s) : skipping test" % (img_file, e))
Expand Down Expand Up @@ -376,8 +404,8 @@ def test_structTrainDL():


def test_nmf():
img_file = 'data/boat.png'
try:
img_file = get_img_file_path("boat.png")
img = Image.open(img_file)
except:
print("Cannot load image %s : skipping test" % img_file)
Expand Down Expand Up @@ -413,8 +441,8 @@ def test_nmf():

# Archetypal Analysis, run first steps with FISTA and run last steps with activeSet,
def test_archetypalAnalysis():
img_file = 'data/lena.png'
try:
img_file = get_img_file_path("lena.png")
img = Image.open(img_file)
except Exception as e:
print("Cannot load image %s (%s) : skipping test" % (img_file, e))
Expand Down
2 changes: 1 addition & 1 deletion version
Original file line number Diff line number Diff line change
@@ -1 +1 @@
2.6.5.1
2.6.5.2