From f0b218f471010a2dfbbd3fa66cc9aac019f48d1e Mon Sep 17 00:00:00 2001 From: Kai Striega Date: Fri, 11 Feb 2022 10:42:54 +0800 Subject: [PATCH] ENH: Return close results in rate calculation --- numpy_financial/_financial.py | 18 +++++++++++------- numpy_financial/tests/test_financial.py | 13 +++++++++++++ 2 files changed, 24 insertions(+), 7 deletions(-) diff --git a/numpy_financial/_financial.py b/numpy_financial/_financial.py index 3e3183c..2554775 100644 --- a/numpy_financial/_financial.py +++ b/numpy_financial/_financial.py @@ -658,17 +658,21 @@ def rate(nper, pmt, pv, fv, when='end', guess=None, tol=None, maxiter=100): rn = guess iterator = 0 close = False - while (iterator < maxiter) and not close: + while (iterator < maxiter) and not np.all(close): rnp1 = rn - _g_div_gp(rn, nper, pmt, pv, fv, when) diff = abs(rnp1-rn) - close = np.all(diff < tol) + close = diff < tol iterator += 1 rn = rnp1 - if not close: - # Return nan's in array of the same shape as rn - return default_type(np.nan) + rn - else: - return rn + + if not np.all(close): + if np.isscalar(rn): + return default_type(np.nan) + else: + # Return nan's in array of the same shape as rn + # where the solution is not close to tol. + rn[~close] = np.nan + return rn def _roots(p): diff --git a/numpy_financial/tests/test_financial.py b/numpy_financial/tests/test_financial.py index 9648cb3..8099b0f 100644 --- a/numpy_financial/tests/test_financial.py +++ b/numpy_financial/tests/test_financial.py @@ -123,6 +123,19 @@ def test_rate_decimal(self): Decimal('10000')) assert_equal(Decimal('0.1106908537142689284704528100'), rate) + def test_gh48(self): + """ + Test the correct result is returned with only infeasible solutions + converted to nan. + """ + des = [-0.39920185, -0.02305873, -0.41818459, 0.26513414, numpy.nan] + nper = 2 + pmt = 0 + pv = [-593.06, -4725.38, -662.05, -428.78, -13.65] + fv = [214.07, 4509.97, 224.11, 686.29, -329.67] + actual = npf.rate(nper, pmt, pv, fv) + assert_allclose(actual, des) + class TestNpv: def test_npv(self):