diff --git a/src/diffpy/srreal/tests/testpairquantity.py b/src/diffpy/srreal/tests/testpairquantity.py index 3ff606fc..c54c27e9 100644 --- a/src/diffpy/srreal/tests/testpairquantity.py +++ b/src/diffpy/srreal/tests/testpairquantity.py @@ -5,6 +5,7 @@ import unittest import pickle +import numpy from diffpy.srreal.pairquantity import PairQuantity from diffpy.srreal.srreal_ext import BasePairQuantity @@ -76,6 +77,22 @@ def test_setStructure(self): return + def test_setPairMask_args(self): + """check argument type handling in setPairMask + """ + spm = self.pq.setPairMask + gpm = self.pq.getPairMask + self.assertRaises(TypeError, spm, 0.0, 0, False) + self.assertRaises(TypeError, spm, numpy.complex(0.5), 0, False) + self.assertTrue(gpm(0, 0)) + spm(numpy.int32(1), 0, True, others=False) + self.assertTrue(gpm(0, 1)) + self.assertTrue(gpm(1, 0)) + self.assertFalse(gpm(0, 0)) + self.assertFalse(gpm(2, 7)) + return + + def test_getStructure(self): """check PairQuantity.getStructure() """ diff --git a/src/extensions/srreal_converters.cpp b/src/extensions/srreal_converters.cpp index 49dc183e..4fefe4a7 100644 --- a/src/extensions/srreal_converters.cpp +++ b/src/extensions/srreal_converters.cpp @@ -31,6 +31,7 @@ #include #include "srreal_converters.hpp" +#include "srreal_validators.hpp" #include "srreal_numpy_symbol.hpp" // numpy/arrayobject.h needs to be included after srreal_numpy_symbol.hpp, @@ -73,19 +74,6 @@ boost::python::object newNumPyArray(int dim, const int* sz, int typenum) return rv; } - -bool isiterable(boost::python::object obj) -{ - using namespace boost::python; -#if PY_MAJOR_VERSION >= 3 - object Iterable = import("collections.abc").attr("Iterable"); -#else - object Iterable = import("collections").attr("Iterable"); -#endif - bool rv = (1 == PyObject_IsInstance(obj.ptr(), Iterable.ptr())); - return rv; -} - } // namespace namespace srrealmodule { @@ -304,6 +292,7 @@ int extractint(boost::python::object obj) if (PyArray_CheckScalar(pobj)) { int rv = PyArray_PyIntAsInt(pobj); + if (rv == -1 && PyErr_Occurred()) python::throw_error_already_set(); return rv; } // nothing worked, call geti which will raise an exception diff --git a/src/extensions/srreal_validators.cpp b/src/extensions/srreal_validators.cpp index 50742d06..4a3cb217 100644 --- a/src/extensions/srreal_validators.cpp +++ b/src/extensions/srreal_validators.cpp @@ -17,6 +17,7 @@ *****************************************************************************/ #include +#include #include "srreal_validators.hpp" @@ -43,6 +44,19 @@ void ensure_non_negative(int value) } } + +bool isiterable(boost::python::object obj) +{ + using boost::python::import; +#if PY_MAJOR_VERSION >= 3 + object Iterable = import("collections.abc").attr("Iterable"); +#else + object Iterable = import("collections").attr("Iterable"); +#endif + bool rv = (1 == PyObject_IsInstance(obj.ptr(), Iterable.ptr())); + return rv; +} + } // namespace srrealmodule // End of file diff --git a/src/extensions/srreal_validators.hpp b/src/extensions/srreal_validators.hpp index ef703e60..f123498f 100644 --- a/src/extensions/srreal_validators.hpp +++ b/src/extensions/srreal_validators.hpp @@ -23,6 +23,7 @@ namespace srrealmodule { void ensure_index_bounds(int idx, int lo, int hi); void ensure_non_negative(int value); +bool isiterable(boost::python::object obj); } // namespace srrealmodule diff --git a/src/extensions/wrap_PairQuantity.cpp b/src/extensions/wrap_PairQuantity.cpp index d9fda7f7..d3403ce5 100644 --- a/src/extensions/wrap_PairQuantity.cpp +++ b/src/extensions/wrap_PairQuantity.cpp @@ -39,6 +39,7 @@ #include "srreal_converters.hpp" #include "srreal_pickling.hpp" +#include "srreal_validators.hpp" #include @@ -488,13 +489,13 @@ void set_pair_mask(PairQuantity& obj, python::object others) { if (!others.is_none()) mask_all_pairs(obj, others); - python::extract geti(i); - python::extract getj(j); bool mask = msk; - // short circuit for normal call - if (geti.check() && getj.check()) + // short circuit for normal call with scalar values + if (!isiterable(i) && !isiterable(j)) { - obj.setPairMask(geti(), getj(), mask); + const int i1 = extractint(i); + const int j1 = extractint(j); + obj.setPairMask(i1, j1, mask); return; } std::vector iindices = parsepairindex(i);