diff --git a/CHANGELOG.md b/CHANGELOG.md index 311189757..2abd6ed91 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -36,7 +36,7 @@ straightforward as possible. ### Changed -- +- ENH: Function Reverse Arithmetic Priority [#488](https://github.com/RocketPy-Team/RocketPy/pull/488) ### Fixed diff --git a/rocketpy/mathutils/function.py b/rocketpy/mathutils/function.py index e52fbb2fb..8bdad9d2b 100644 --- a/rocketpy/mathutils/function.py +++ b/rocketpy/mathutils/function.py @@ -24,6 +24,9 @@ class Function: extrapolation, plotting and algebra. """ + # Arithmetic priority + __array_ufunc__ = None + def __init__( self, source, @@ -1837,7 +1840,9 @@ def __add__(self, other): return Function(lambda x: (self.get_value(x) + other(x))) # If other is Float except... except AttributeError: - if isinstance(other, (float, int, complex)): + if isinstance( + other, (float, int, complex, np.ndarray, np.integer, np.floating) + ): # Check if Function object source is array or callable if isinstance(self.source, np.ndarray): # Operate on grid values @@ -1967,7 +1972,9 @@ def __mul__(self, other): return Function(lambda x: (self.get_value(x) * other(x))) # If other is Float except... except AttributeError: - if isinstance(other, (float, int, complex)): + if isinstance( + other, (float, int, complex, np.ndarray, np.integer, np.floating) + ): # Check if Function object source is array or callable if isinstance(self.source, np.ndarray): # Operate on grid values @@ -2056,7 +2063,9 @@ def __truediv__(self, other): return Function(lambda x: (self.get_value_opt(x) / other(x))) # If other is Float except... except AttributeError: - if isinstance(other, (float, int, complex)): + if isinstance( + other, (float, int, complex, np.ndarray, np.integer, np.floating) + ): # Check if Function object source is array or callable if isinstance(self.source, np.ndarray): # Operate on grid values @@ -2095,7 +2104,9 @@ def __rtruediv__(self, other): A Function object which gives the result of other(x)/self(x). """ # Check if Function object source is array and other is float - if isinstance(other, (float, int, complex)): + if isinstance( + other, (float, int, complex, np.ndarray, np.integer, np.floating) + ): if isinstance(self.source, np.ndarray): # Operate on grid values ys = other / self.y_array @@ -2163,7 +2174,9 @@ def __pow__(self, other): return Function(lambda x: (self.get_value_opt(x) ** other(x))) # If other is Float except... except AttributeError: - if isinstance(other, (float, int, complex)): + if isinstance( + other, (float, int, complex, np.ndarray, np.integer, np.floating) + ): # Check if Function object source is array or callable if isinstance(self.source, np.ndarray): # Operate on grid values @@ -2202,7 +2215,9 @@ def __rpow__(self, other): A Function object which gives the result of other(x)**self(x). """ # Check if Function object source is array and other is float - if isinstance(other, (float, int, complex)): + if isinstance( + other, (float, int, complex, np.ndarray, np.integer, np.floating) + ): if isinstance(self.source, np.ndarray): # Operate on grid values ys = other**self.y_array diff --git a/tests/test_function.py b/tests/test_function.py index c67a21b30..5362b0486 100644 --- a/tests/test_function.py +++ b/tests/test_function.py @@ -390,3 +390,83 @@ def test_shepard_interpolation(x, y, z_expected): func = Function(source=source, inputs=["x", "y"], outputs=["z"]) z = func(x, y) assert np.isclose(z, z_expected, atol=1e-8).all() + + +@pytest.mark.parametrize("other", [1, 0.1, np.int_(1), np.float_(0.1), np.array([1])]) +def test_sum_arithmetic_priority(other): + """Test the arithmetic priority of the add operation of the Function class, + specially comparing to the numpy array operations. + """ + func_lambda = Function(lambda x: x**2) + func_array = Function([(0, 0), (1, 1), (2, 4)]) + + assert isinstance(func_lambda + func_array, Function) + assert isinstance(func_array + func_lambda, Function) + assert isinstance(func_lambda + other, Function) + assert isinstance(other + func_lambda, Function) + assert isinstance(func_array + other, Function) + assert isinstance(other + func_array, Function) + + +@pytest.mark.parametrize("other", [1, 0.1, np.int_(1), np.float_(0.1), np.array([1])]) +def test_sub_arithmetic_priority(other): + """Test the arithmetic priority of the sub operation of the Function class, + specially comparing to the numpy array operations. + """ + func_lambda = Function(lambda x: x**2) + func_array = Function([(0, 0), (1, 1), (2, 4)]) + + assert isinstance(func_lambda - func_array, Function) + assert isinstance(func_array - func_lambda, Function) + assert isinstance(func_lambda - other, Function) + assert isinstance(other - func_lambda, Function) + assert isinstance(func_array - other, Function) + assert isinstance(other - func_array, Function) + + +@pytest.mark.parametrize("other", [1, 0.1, np.int_(1), np.float_(0.1), np.array([1])]) +def test_mul_arithmetic_priority(other): + """Test the arithmetic priority of the mul operation of the Function class, + specially comparing to the numpy array operations. + """ + func_lambda = Function(lambda x: x**2) + func_array = Function([(0, 0), (1, 1), (2, 4)]) + + assert isinstance(func_lambda * func_array, Function) + assert isinstance(func_array * func_lambda, Function) + assert isinstance(func_lambda * other, Function) + assert isinstance(other * func_lambda, Function) + assert isinstance(func_array * other, Function) + assert isinstance(other * func_array, Function) + + +@pytest.mark.parametrize("other", [1, 0.1, np.int_(1), np.float_(0.1), np.array([1])]) +def test_truediv_arithmetic_priority(other): + """Test the arithmetic priority of the truediv operation of the Function class, + specially comparing to the numpy array operations. + """ + func_lambda = Function(lambda x: x**2) + func_array = Function([(1, 1), (2, 4)]) + + assert isinstance(func_lambda / func_array, Function) + assert isinstance(func_array / func_lambda, Function) + assert isinstance(func_lambda / other, Function) + assert isinstance(other / func_lambda, Function) + assert isinstance(func_array / other, Function) + assert isinstance(other / func_array, Function) + + +@pytest.mark.parametrize("other", [1, 0.1, np.int_(1), np.float_(0.1), np.array([1])]) +def test_pow_arithmetic_priority(other): + """Test the arithmetic priority of the pow operation of the Function class, + specially comparing to the numpy array operations. + """ + func_lambda = Function(lambda x: x**2) + func_array = Function([(0, 0), (1, 1), (2, 4)]) + + assert isinstance(func_lambda**func_array, Function) + assert isinstance(func_array**func_lambda, Function) + assert isinstance(func_lambda**other, Function) + assert isinstance(other**func_lambda, Function) + assert isinstance(func_array**other, Function) + assert isinstance(other**func_array, Function)