Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions src/diffpy/srreal/tests/testpairquantity.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import unittest
import pickle
import numpy

from diffpy.srreal.pairquantity import PairQuantity
from diffpy.srreal.srreal_ext import BasePairQuantity
Expand Down Expand Up @@ -76,6 +77,22 @@ def test_setStructure(self):
return


def test_setPairMask_args(self):
"""check argument type handling in setPairMask
"""
spm = self.pq.setPairMask
gpm = self.pq.getPairMask
self.assertRaises(TypeError, spm, 0.0, 0, False)
self.assertRaises(TypeError, spm, numpy.complex(0.5), 0, False)
self.assertTrue(gpm(0, 0))
spm(numpy.int32(1), 0, True, others=False)
self.assertTrue(gpm(0, 1))
self.assertTrue(gpm(1, 0))
self.assertFalse(gpm(0, 0))
self.assertFalse(gpm(2, 7))
return


def test_getStructure(self):
"""check PairQuantity.getStructure()
"""
Expand Down
15 changes: 2 additions & 13 deletions src/extensions/srreal_converters.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#include <diffpy/srreal/StructureAdapter.hpp>

#include "srreal_converters.hpp"
#include "srreal_validators.hpp"

#include "srreal_numpy_symbol.hpp"
// numpy/arrayobject.h needs to be included after srreal_numpy_symbol.hpp,
Expand Down Expand Up @@ -73,19 +74,6 @@ boost::python::object newNumPyArray(int dim, const int* sz, int typenum)
return rv;
}


bool isiterable(boost::python::object obj)
{
using namespace boost::python;
#if PY_MAJOR_VERSION >= 3
object Iterable = import("collections.abc").attr("Iterable");
#else
object Iterable = import("collections").attr("Iterable");
#endif
bool rv = (1 == PyObject_IsInstance(obj.ptr(), Iterable.ptr()));
return rv;
}

} // namespace

namespace srrealmodule {
Expand Down Expand Up @@ -304,6 +292,7 @@ int extractint(boost::python::object obj)
if (PyArray_CheckScalar(pobj))
{
int rv = PyArray_PyIntAsInt(pobj);
if (rv == -1 && PyErr_Occurred()) python::throw_error_already_set();
return rv;
}
// nothing worked, call geti which will raise an exception
Expand Down
14 changes: 14 additions & 0 deletions src/extensions/srreal_validators.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
*****************************************************************************/

#include <boost/python/errors.hpp>
#include <boost/python/import.hpp>

#include "srreal_validators.hpp"

Expand All @@ -43,6 +44,19 @@ void ensure_non_negative(int value)
}
}


bool isiterable(boost::python::object obj)
{
using boost::python::import;
#if PY_MAJOR_VERSION >= 3
object Iterable = import("collections.abc").attr("Iterable");
#else
object Iterable = import("collections").attr("Iterable");
#endif
bool rv = (1 == PyObject_IsInstance(obj.ptr(), Iterable.ptr()));
return rv;
}

} // namespace srrealmodule

// End of file
1 change: 1 addition & 0 deletions src/extensions/srreal_validators.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ namespace srrealmodule {

void ensure_index_bounds(int idx, int lo, int hi);
void ensure_non_negative(int value);
bool isiterable(boost::python::object obj);

} // namespace srrealmodule

Expand Down
11 changes: 6 additions & 5 deletions src/extensions/wrap_PairQuantity.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@

#include "srreal_converters.hpp"
#include "srreal_pickling.hpp"
#include "srreal_validators.hpp"

#include <diffpy/srreal/PairQuantity.hpp>

Expand Down Expand Up @@ -488,13 +489,13 @@ void set_pair_mask(PairQuantity& obj,
python::object others)
{
if (!others.is_none()) mask_all_pairs(obj, others);
python::extract<int> geti(i);
python::extract<int> getj(j);
bool mask = msk;
// short circuit for normal call
if (geti.check() && getj.check())
// short circuit for normal call with scalar values
if (!isiterable(i) && !isiterable(j))
{
obj.setPairMask(geti(), getj(), mask);
const int i1 = extractint(i);
const int j1 = extractint(j);
obj.setPairMask(i1, j1, mask);
return;
}
std::vector<int> iindices = parsepairindex(i);
Expand Down