Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ straightforward as possible.

### Changed

-
- ENH: Function Reverse Arithmetic Priority [#488](https://github.com/RocketPy-Team/RocketPy/pull/488)

### Fixed

Expand Down
27 changes: 21 additions & 6 deletions rocketpy/mathutils/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ class Function:
extrapolation, plotting and algebra.
"""

# Arithmetic priority
__array_ufunc__ = None

def __init__(
self,
source,
Expand Down Expand Up @@ -1837,7 +1840,9 @@ def __add__(self, other):
return Function(lambda x: (self.get_value(x) + other(x)))
# If other is Float except...
except AttributeError:
if isinstance(other, (float, int, complex)):
if isinstance(
other, (float, int, complex, np.ndarray, np.integer, np.floating)
):
# Check if Function object source is array or callable
if isinstance(self.source, np.ndarray):
# Operate on grid values
Expand Down Expand Up @@ -1967,7 +1972,9 @@ def __mul__(self, other):
return Function(lambda x: (self.get_value(x) * other(x)))
# If other is Float except...
except AttributeError:
if isinstance(other, (float, int, complex)):
if isinstance(
other, (float, int, complex, np.ndarray, np.integer, np.floating)
):
# Check if Function object source is array or callable
if isinstance(self.source, np.ndarray):
# Operate on grid values
Expand Down Expand Up @@ -2056,7 +2063,9 @@ def __truediv__(self, other):
return Function(lambda x: (self.get_value_opt(x) / other(x)))
# If other is Float except...
except AttributeError:
if isinstance(other, (float, int, complex)):
if isinstance(
other, (float, int, complex, np.ndarray, np.integer, np.floating)
):
# Check if Function object source is array or callable
if isinstance(self.source, np.ndarray):
# Operate on grid values
Expand Down Expand Up @@ -2095,7 +2104,9 @@ def __rtruediv__(self, other):
A Function object which gives the result of other(x)/self(x).
"""
# Check if Function object source is array and other is float
if isinstance(other, (float, int, complex)):
if isinstance(
other, (float, int, complex, np.ndarray, np.integer, np.floating)
):
if isinstance(self.source, np.ndarray):
# Operate on grid values
ys = other / self.y_array
Expand Down Expand Up @@ -2163,7 +2174,9 @@ def __pow__(self, other):
return Function(lambda x: (self.get_value_opt(x) ** other(x)))
# If other is Float except...
except AttributeError:
if isinstance(other, (float, int, complex)):
if isinstance(
other, (float, int, complex, np.ndarray, np.integer, np.floating)
):
# Check if Function object source is array or callable
if isinstance(self.source, np.ndarray):
# Operate on grid values
Expand Down Expand Up @@ -2202,7 +2215,9 @@ def __rpow__(self, other):
A Function object which gives the result of other(x)**self(x).
"""
# Check if Function object source is array and other is float
if isinstance(other, (float, int, complex)):
if isinstance(
other, (float, int, complex, np.ndarray, np.integer, np.floating)
):
if isinstance(self.source, np.ndarray):
# Operate on grid values
ys = other**self.y_array
Expand Down
80 changes: 80 additions & 0 deletions tests/test_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,3 +390,83 @@ def test_shepard_interpolation(x, y, z_expected):
func = Function(source=source, inputs=["x", "y"], outputs=["z"])
z = func(x, y)
assert np.isclose(z, z_expected, atol=1e-8).all()


@pytest.mark.parametrize("other", [1, 0.1, np.int_(1), np.float_(0.1), np.array([1])])
def test_sum_arithmetic_priority(other):
"""Test the arithmetic priority of the add operation of the Function class,
specially comparing to the numpy array operations.
"""
func_lambda = Function(lambda x: x**2)
func_array = Function([(0, 0), (1, 1), (2, 4)])

assert isinstance(func_lambda + func_array, Function)
assert isinstance(func_array + func_lambda, Function)
assert isinstance(func_lambda + other, Function)
assert isinstance(other + func_lambda, Function)
assert isinstance(func_array + other, Function)
assert isinstance(other + func_array, Function)


@pytest.mark.parametrize("other", [1, 0.1, np.int_(1), np.float_(0.1), np.array([1])])
def test_sub_arithmetic_priority(other):
"""Test the arithmetic priority of the sub operation of the Function class,
specially comparing to the numpy array operations.
"""
func_lambda = Function(lambda x: x**2)
func_array = Function([(0, 0), (1, 1), (2, 4)])

assert isinstance(func_lambda - func_array, Function)
assert isinstance(func_array - func_lambda, Function)
assert isinstance(func_lambda - other, Function)
assert isinstance(other - func_lambda, Function)
assert isinstance(func_array - other, Function)
assert isinstance(other - func_array, Function)


@pytest.mark.parametrize("other", [1, 0.1, np.int_(1), np.float_(0.1), np.array([1])])
def test_mul_arithmetic_priority(other):
"""Test the arithmetic priority of the mul operation of the Function class,
specially comparing to the numpy array operations.
"""
func_lambda = Function(lambda x: x**2)
func_array = Function([(0, 0), (1, 1), (2, 4)])

assert isinstance(func_lambda * func_array, Function)
assert isinstance(func_array * func_lambda, Function)
assert isinstance(func_lambda * other, Function)
assert isinstance(other * func_lambda, Function)
assert isinstance(func_array * other, Function)
assert isinstance(other * func_array, Function)


@pytest.mark.parametrize("other", [1, 0.1, np.int_(1), np.float_(0.1), np.array([1])])
def test_truediv_arithmetic_priority(other):
"""Test the arithmetic priority of the truediv operation of the Function class,
specially comparing to the numpy array operations.
"""
func_lambda = Function(lambda x: x**2)
func_array = Function([(1, 1), (2, 4)])

assert isinstance(func_lambda / func_array, Function)
assert isinstance(func_array / func_lambda, Function)
assert isinstance(func_lambda / other, Function)
assert isinstance(other / func_lambda, Function)
assert isinstance(func_array / other, Function)
assert isinstance(other / func_array, Function)


@pytest.mark.parametrize("other", [1, 0.1, np.int_(1), np.float_(0.1), np.array([1])])
def test_pow_arithmetic_priority(other):
"""Test the arithmetic priority of the pow operation of the Function class,
specially comparing to the numpy array operations.
"""
func_lambda = Function(lambda x: x**2)
func_array = Function([(0, 0), (1, 1), (2, 4)])

assert isinstance(func_lambda**func_array, Function)
assert isinstance(func_array**func_lambda, Function)
assert isinstance(func_lambda**other, Function)
assert isinstance(other**func_lambda, Function)
assert isinstance(func_array**other, Function)
assert isinstance(other**func_array, Function)