diff --git a/python/geom.py b/python/geom.py index a258333ad..ce8d759cd 100644 --- a/python/geom.py +++ b/python/geom.py @@ -17,9 +17,9 @@ def check_nonnegative(prop, val): class Vector3(object): def __init__(self, x=0.0, y=0.0, z=0.0): - self.x = float(x) - self.y = float(y) - self.z = float(z) + self.x = float(x) if type(x) is int else x + self.y = float(y) if type(y) is int else y + self.z = float(z) if type(z) is int else z def __eq__(self, other): return self.x == other.x and self.y == other.y and self.z == other.z @@ -64,12 +64,37 @@ def scale(self, s): def dot(self, v): return self.x * v.x + self.y * v.y + self.z * v.z + def cdot(self, v): + conj_vec = Vector3(self.x.conjugate(), + self.y.conjugate(), + self.z.conjugate()) + return conj_vec.dot(v) + + def cross(self, v): + x = self.y * v.z - self.z * v.y + y = self.z * v.x - self.x * v.z + z = self.x * v.y - self.y * v.x + + return Vector3(x, y, z) + def norm(self): return math.sqrt(abs(self.dot(self))) def unit(self): return self.scale(1 / self.norm()) + def close(self, v, tol=1.0e-7): + return (abs(self.x - v.x) <= tol and + abs(self.y - v.y) <= tol and + abs(self.z - v.z) <= tol) + + def rotate(self, axis, theta): + u = axis.unit() + vpar = u.scale(u.dot(self)) + vcross = u.cross(self) + vperp = self - vpar + return vpar + (vperp.scale(math.cos(theta)) + vcross.scale(math.sin(theta))) + class Medium(object): diff --git a/python/tests/geom.py b/python/tests/geom.py index 95ec34190..7a5a86b59 100644 --- a/python/tests/geom.py +++ b/python/tests/geom.py @@ -1,4 +1,5 @@ import unittest +from math import pi import numpy as np import meep as mp import meep.geom as gm @@ -253,5 +254,39 @@ def test_use_as_numpy_array(self): self.assertTrue(type(res) is np.ndarray) np.testing.assert_array_equal(np.array([20, 20, 20]), res) + def test_cross(self): + v1 = mp.Vector3(x=1) + v2 = mp.Vector3(z=1) + self.assertEqual(v1.cross(v2), mp.Vector3(y=-1)) + + v1 = mp.Vector3(1, 1) + v2 = mp.Vector3(0, 1, 1) + self.assertEqual(v1.cross(v2), mp.Vector3(1, -1, 1)) + + def test_cdot(self): + complex_vec1 = mp.Vector3(complex(1, 1), complex(1, 1), complex(1, 1)) + complex_vec2 = mp.Vector3(complex(2, 2), complex(2, 2), complex(2, 2)) + + self.assertEqual(complex_vec1.cdot(complex_vec2), 12 + 0j) + + def test_rotate(self): + axis = mp.Vector3(z=1) + v = mp.Vector3(x=1) + res = v.rotate(axis, pi) + self.assertTrue(res.close(mp.Vector3(x=-1))) + + def test_close(self): + v1 = mp.Vector3(1e-7) + v2 = mp.Vector3(1e-8) + self.assertTrue(v1.close(v2)) + + v1 = mp.Vector3(1e-6) + v2 = mp.Vector3(1e-7) + self.assertFalse(v1.close(v2)) + + v1 = mp.Vector3(1e-10) + v2 = mp.Vector3(1e-11) + self.assertTrue(v1.close(v2, tol=1e-10)) + if __name__ == '__main__': unittest.main()