From 0832b074ec3fb7a077848a213429b978909f7828 Mon Sep 17 00:00:00 2001 From: Pavol Juhas Date: Thu, 25 Apr 2019 10:49:54 -0400 Subject: [PATCH 1/5] TST: exceptions for invalid setPairMask arguments And verify it can accept numpy.int. --- src/diffpy/srreal/tests/testpairquantity.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/src/diffpy/srreal/tests/testpairquantity.py b/src/diffpy/srreal/tests/testpairquantity.py index 3ff606fc..d5f2470f 100644 --- a/src/diffpy/srreal/tests/testpairquantity.py +++ b/src/diffpy/srreal/tests/testpairquantity.py @@ -5,6 +5,7 @@ import unittest import pickle +import numpy from diffpy.srreal.pairquantity import PairQuantity from diffpy.srreal.srreal_ext import BasePairQuantity @@ -76,6 +77,22 @@ def test_setStructure(self): return + def test_setPairMask_args(self): + """check argument type handling in setPairMask + """ + spm = self.pq.setPairMask + gpm = self.pq.getPairMask + self.assertRaises(TypeError, spm, 0.0, 0, False) + self.assertRaises(TypeError, spm, numpy.float32(0.5), 0, False) + self.assertTrue(gpm(0, 0)) + spm(numpy.int32(1), 0, True, others=False) + self.assertTrue(gpm(0, 1)) + self.assertTrue(gpm(1, 0)) + self.assertFalse(gpm(0, 0)) + self.assertFalse(gpm(2, 7)) + return + + def test_getStructure(self): """check PairQuantity.getStructure() """ From 24cdbd6eef46a782ca75798be1a5e16a7a68a15e Mon Sep 17 00:00:00 2001 From: Pavol Juhas Date: Thu, 25 Apr 2019 11:06:52 -0400 Subject: [PATCH 2/5] BUG: raise error for invalid numpy int conversion --- src/extensions/srreal_converters.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/src/extensions/srreal_converters.cpp b/src/extensions/srreal_converters.cpp index 49dc183e..b93fc7a9 100644 --- a/src/extensions/srreal_converters.cpp +++ b/src/extensions/srreal_converters.cpp @@ -304,6 +304,7 @@ int extractint(boost::python::object obj) if (PyArray_CheckScalar(pobj)) { int rv = PyArray_PyIntAsInt(pobj); + if (rv == -1 && PyErr_Occurred()) python::throw_error_already_set(); return rv; } // nothing worked, call geti which will raise an exception From 374ab556bb9d7294c1ac6bd67cf466963a663069 Mon Sep 17 00:00:00 2001 From: Pavol Juhas Date: Thu, 25 Apr 2019 11:13:44 -0400 Subject: [PATCH 3/5] MNT: expose internal helper `isiterable` Move it to srreal_validators.cpp. --- src/extensions/srreal_converters.cpp | 14 +------------- src/extensions/srreal_validators.cpp | 14 ++++++++++++++ src/extensions/srreal_validators.hpp | 1 + 3 files changed, 16 insertions(+), 13 deletions(-) diff --git a/src/extensions/srreal_converters.cpp b/src/extensions/srreal_converters.cpp index b93fc7a9..4fefe4a7 100644 --- a/src/extensions/srreal_converters.cpp +++ b/src/extensions/srreal_converters.cpp @@ -31,6 +31,7 @@ #include #include "srreal_converters.hpp" +#include "srreal_validators.hpp" #include "srreal_numpy_symbol.hpp" // numpy/arrayobject.h needs to be included after srreal_numpy_symbol.hpp, @@ -73,19 +74,6 @@ boost::python::object newNumPyArray(int dim, const int* sz, int typenum) return rv; } - -bool isiterable(boost::python::object obj) -{ - using namespace boost::python; -#if PY_MAJOR_VERSION >= 3 - object Iterable = import("collections.abc").attr("Iterable"); -#else - object Iterable = import("collections").attr("Iterable"); -#endif - bool rv = (1 == PyObject_IsInstance(obj.ptr(), Iterable.ptr())); - return rv; -} - } // namespace namespace srrealmodule { diff --git a/src/extensions/srreal_validators.cpp b/src/extensions/srreal_validators.cpp index 50742d06..4a3cb217 100644 --- a/src/extensions/srreal_validators.cpp +++ b/src/extensions/srreal_validators.cpp @@ -17,6 +17,7 @@ *****************************************************************************/ #include +#include #include "srreal_validators.hpp" @@ -43,6 +44,19 @@ void ensure_non_negative(int value) } } + +bool isiterable(boost::python::object obj) +{ + using boost::python::import; +#if PY_MAJOR_VERSION >= 3 + object Iterable = import("collections.abc").attr("Iterable"); +#else + object Iterable = import("collections").attr("Iterable"); +#endif + bool rv = (1 == PyObject_IsInstance(obj.ptr(), Iterable.ptr())); + return rv; +} + } // namespace srrealmodule // End of file diff --git a/src/extensions/srreal_validators.hpp b/src/extensions/srreal_validators.hpp index ef703e60..f123498f 100644 --- a/src/extensions/srreal_validators.hpp +++ b/src/extensions/srreal_validators.hpp @@ -23,6 +23,7 @@ namespace srrealmodule { void ensure_index_bounds(int idx, int lo, int hi); void ensure_non_negative(int value); +bool isiterable(boost::python::object obj); } // namespace srrealmodule From 677a0be9975ab9449a4e5d2e0ab702ee7f27088b Mon Sep 17 00:00:00 2001 From: Pavol Juhas Date: Thu, 25 Apr 2019 11:21:14 -0400 Subject: [PATCH 4/5] MNT: improve argument conversion in setPairMask Avoid round-about conversion of numpy.int scalars. --- src/extensions/wrap_PairQuantity.cpp | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/extensions/wrap_PairQuantity.cpp b/src/extensions/wrap_PairQuantity.cpp index d9fda7f7..d3403ce5 100644 --- a/src/extensions/wrap_PairQuantity.cpp +++ b/src/extensions/wrap_PairQuantity.cpp @@ -39,6 +39,7 @@ #include "srreal_converters.hpp" #include "srreal_pickling.hpp" +#include "srreal_validators.hpp" #include @@ -488,13 +489,13 @@ void set_pair_mask(PairQuantity& obj, python::object others) { if (!others.is_none()) mask_all_pairs(obj, others); - python::extract geti(i); - python::extract getj(j); bool mask = msk; - // short circuit for normal call - if (geti.check() && getj.check()) + // short circuit for normal call with scalar values + if (!isiterable(i) && !isiterable(j)) { - obj.setPairMask(geti(), getj(), mask); + const int i1 = extractint(i); + const int j1 = extractint(j); + obj.setPairMask(i1, j1, mask); return; } std::vector iindices = parsepairindex(i); From 12379de9461e230c0c48e8d52cb6c89a264c7926 Mon Sep 17 00:00:00 2001 From: Pavol Juhas Date: Thu, 25 Apr 2019 12:51:54 -0400 Subject: [PATCH 5/5] TST: use more invalid numpy scalar Old numpy versions convert numpy.float32 to integer. Make sure test raises TypeError from extractint. --- src/diffpy/srreal/tests/testpairquantity.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffpy/srreal/tests/testpairquantity.py b/src/diffpy/srreal/tests/testpairquantity.py index d5f2470f..c54c27e9 100644 --- a/src/diffpy/srreal/tests/testpairquantity.py +++ b/src/diffpy/srreal/tests/testpairquantity.py @@ -83,7 +83,7 @@ def test_setPairMask_args(self): spm = self.pq.setPairMask gpm = self.pq.getPairMask self.assertRaises(TypeError, spm, 0.0, 0, False) - self.assertRaises(TypeError, spm, numpy.float32(0.5), 0, False) + self.assertRaises(TypeError, spm, numpy.complex(0.5), 0, False) self.assertTrue(gpm(0, 0)) spm(numpy.int32(1), 0, True, others=False) self.assertTrue(gpm(0, 1))