From a4ba56f4aaa812947b21f8917be20a20645a9b8f Mon Sep 17 00:00:00 2001 From: Supavit Dumrongprechachan Date: Wed, 25 Jul 2018 14:18:32 +1000 Subject: [PATCH 1/3] Add bisection and brent's method for root finding I added two jitted robust root finding methods: Bisection and Brentq which are based on Scipy's version which is written in C. They both follow the previous jitted root finding procedures by returning a namedtuple with relevant information. I also added a basic test for each method. --- quantecon/optimize/__init__.py | 2 +- quantecon/optimize/root_finding.py | 246 +++++++++++++++++- quantecon/optimize/tests/test_root_finding.py | 30 ++- 3 files changed, 266 insertions(+), 12 deletions(-) diff --git a/quantecon/optimize/__init__.py b/quantecon/optimize/__init__.py index ad3267faf..de079da8b 100644 --- a/quantecon/optimize/__init__.py +++ b/quantecon/optimize/__init__.py @@ -3,4 +3,4 @@ """ from .scalar_maximization import brent_max -from .root_finding import newton, newton_halley, newton_secant +from .root_finding import * diff --git a/quantecon/optimize/root_finding.py b/quantecon/optimize/root_finding.py index e88e7de51..cb5c7b11c 100644 --- a/quantecon/optimize/root_finding.py +++ b/quantecon/optimize/root_finding.py @@ -2,13 +2,17 @@ from numba import jit, njit from collections import namedtuple -__all__ = ['newton', 'newton_halley', 'newton_secant'] +__all__ = ['newton', 'newton_halley', 'newton_secant', 'bisect', 'brentq'] _ECONVERGED = 0 _ECONVERR = -1 -results = namedtuple('results', - ('root function_calls iterations converged')) +_iter = 100 +_xtol = 2e-12 +_rtol = 4*np.finfo(float).eps + +results = namedtuple('results', 'root function_calls iterations converged') + @njit def _results(r): @@ -16,12 +20,13 @@ def _results(r): x, funcalls, iterations, flag = r return results(x, funcalls, iterations, flag == 0) + @njit def newton(func, x0, fprime, args=(), tol=1.48e-8, maxiter=50, disp=True): """ - Find a zero from the Newton-Raphson method using the jitted version of - Scipy's newton for scalars. Note that this does not provide an alternative + Find a zero from the Newton-Raphson method using the jitted version of + Scipy's newton for scalars. Note that this does not provide an alternative method such as secant. Thus, it is important that `fprime` can be provided. Note that `func` and `fprime` must be jitted via Numba. @@ -85,18 +90,19 @@ def newton(func, x0, fprime, args=(), tol=1.48e-8, maxiter=50, break newton_step = fval / fder # Newton step - p = p0 - newton_step + p = p0 - newton_step if abs(p - p0) < tol: status = _ECONVERGED break p0 = p - + if disp and status == _ECONVERR: msg = "Failed to converge" raise RuntimeError(msg) return _results((p, funcalls, itr + 1, status)) + @njit def newton_halley(func, x0, fprime, fprime2, args=(), tol=1.48e-8, maxiter=50, disp=True): @@ -179,6 +185,7 @@ def newton_halley(func, x0, fprime, fprime2, args=(), tol=1.48e-8, return _results((p, funcalls, itr + 1, status)) + @njit def newton_secant(func, x0, args=(), tol=1.48e-8, maxiter=50, disp=True): @@ -254,4 +261,227 @@ def newton_secant(func, x0, args=(), tol=1.48e-8, maxiter=50, msg = "Failed to converge" raise RuntimeError(msg) - return _results((p, funcalls, itr + 1, status)) \ No newline at end of file + return _results((p, funcalls, itr + 1, status)) + + +@njit +def _bisect_interval(a, b, fa, fb): + """Conditional checks for intervals in methods involving bisection""" + if fa*fb > 0: + raise ValueError("f(a) and f(b) must have different signs") + root = 0.0 + status = _ECONVERR + + # Root found at either end of [a,b] + if fa == 0: + root = a + status = _ECONVERGED + if fb == 0: + root = b + status = _ECONVERGED + + return root, status + + +@njit +def bisect(f, a, b, args=(), xtol=_xtol, + rtol=_rtol, maxiter=_iter, disp=True): + """ + Find root of a function within an interval adapted from Scipy's bisect. + + Basic bisection routine to find a zero of the function `f` between the + arguments `a` and `b`. `f(a)` and `f(b)` cannot have the same signs. + + `f` must be jitted via numba. + + Parameters + ---------- + f : jitted and callable + Python function returning a number. `f` must be continuous. + a : number + One end of the bracketing interval [a,b]. + b : number + The other end of the bracketing interval [a,b]. + args : tuple, optional + Extra arguments to be used in the function call. + xtol : number, optional + The computed root ``x0`` will satisfy ``np.allclose(x, x0, + atol=xtol, rtol=rtol)``, where ``x`` is the exact root. The + parameter must be nonnegative. + rtol : number, optional + The computed root ``x0`` will satisfy ``np.allclose(x, x0, + atol=xtol, rtol=rtol)``, where ``x`` is the exact root. + maxiter : number, optional + Maximum number of iterations. + disp : bool, optional + If True, raise a RuntimeError if the algorithm didn't converge. + + Returns + ------- + results : namedtuple + + """ + + if xtol <= 0: + raise ValueError("xtol is too small (<= 0)") + + if maxiter < 1: + raise ValueError("maxiter must be greater than 0") + + # Convert to float + xa = a * 1.0 + xb = b * 1.0 + + fa = f(xa, *args) + fb = f(xb, *args) + funcalls = 2 + root, status = _bisect_interval(xa, xb, fa, fb) + + # Check for sign error and early termination + if status == _ECONVERGED: + itr = 0 + else: + # Perform bisection + dm = xb - xa + for itr in range(maxiter): + dm *= 0.5 + xm = xa + dm + fm = f(xm, *args) + funcalls += 1 + + if fm * fa >= 0: + xa = xm + + if fm == 0 or abs(dm) < xtol + rtol * abs(xm): + root = xm + status = _ECONVERGED + itr += 1 + break + + if disp and status == _ECONVERR: + raise RuntimeError("Failed to converge") + + return _results((root, funcalls, itr, status)) + + +@njit +def brentq(f, a, b, args=(), xtol=_xtol, + rtol=_rtol, maxiter=_iter, disp=True): + """ + Find a root of a function in a bracketing interval using Brent's method + adapted from Scipy's brentq. + + Uses the classic Brent's method to find a zero of the function `f` on + the sign changing interval [a , b]. + + `f` must be jitted via numba. + + Parameters + ---------- + f : jitted and callable + Python function returning a number. `f` must be continuous. + a : number + One end of the bracketing interval [a,b]. + b : number + The other end of the bracketing interval [a,b]. + args : tuple, optional + Extra arguments to be used in the function call. + xtol : number, optional + The computed root ``x0`` will satisfy ``np.allclose(x, x0, + atol=xtol, rtol=rtol)``, where ``x`` is the exact root. The + parameter must be nonnegative. + rtol : number, optional + The computed root ``x0`` will satisfy ``np.allclose(x, x0, + atol=xtol, rtol=rtol)``, where ``x`` is the exact root. + maxiter : number, optional + Maximum number of iterations. + disp : bool, optional + If True, raise a RuntimeError if the algorithm didn't converge. + + Returns + ------- + results : namedtuple + + """ + if xtol <= 0: + raise ValueError("xtol is too small (<= 0)") + if maxiter < 1: + raise ValueError("maxiter must be greater than 0") + + # Convert to float + xpre = a * 1.0 + xcur = b * 1.0 + + fpre = f(xpre, *args) + fcur = f(xcur, *args) + funcalls = 2 + + root, status = _bisect_interval(xpre, xcur, fpre, fcur) + + # Check for sign error and early termination + if status == _ECONVERGED: + itr = 0 + else: + # Perform Brent's method + for itr in range(maxiter): + + if fpre * fcur < 0: + xblk = xpre + fblk = fpre + spre = scur = xcur - xpre + if abs(fblk) < abs(fcur): + xpre = xcur + xcur = xblk + xblk = xpre + + fpre = fcur + fcur = fblk + fblk = fpre + + delta = (xtol + rtol * abs(xcur)) / 2 + sbis = (xblk - xcur) / 2 + + # Root found + if fcur == 0 or abs(sbis) < delta: + status = _ECONVERGED + root = xcur + itr += 1 + break + + if abs(spre) > delta and abs(fcur) < abs(fpre): + if xpre == xblk: + # interpolate + stry = -fcur * (xcur - xpre) / (fcur - fpre) + else: + # extrapolate + dpre = (fpre - fcur) / (xpre - xcur) + dblk = (fblk - fcur) / (xblk - xcur) + stry = -fcur * (fblk * dblk - fpre * dpre) / \ + (dblk * dpre * (fblk - fpre)) + + if (2 * abs(stry) < min(abs(spre), 3 * abs(sbis) - delta)): + # good short step + spre = scur + scur = stry + else: + # bisect + spre = sbis + scur = sbis + else: + # bisect + spre = sbis + scur = sbis + + xpre = xcur + fpre = fcur + if (abs(scur) > delta): + xcur += scur + else: + xcur += (delta if sbis > 0 else -delta) + fcur = f(xcur, *args) + funcalls += 1 + + if disp and status == _ECONVERR: + raise RuntimeError("Failed to converge") + + return _results((root, funcalls, itr, status)) diff --git a/quantecon/optimize/tests/test_root_finding.py b/quantecon/optimize/tests/test_root_finding.py index ad0015ce4..ede1b718e 100644 --- a/quantecon/optimize/tests/test_root_finding.py +++ b/quantecon/optimize/tests/test_root_finding.py @@ -2,7 +2,8 @@ from numpy.testing import assert_almost_equal, assert_allclose from numba import njit -from quantecon.optimize import newton, newton_halley, newton_secant +from quantecon.optimize import * + @njit def func(x): @@ -19,6 +20,7 @@ def func_prime(x): """ return (3*x**2) + @njit def func_prime2(x): """ @@ -26,6 +28,7 @@ def func_prime2(x): """ return 6*x + @njit def func_two(x): """ @@ -41,6 +44,7 @@ def func_two_prime(x): """ return 4*np.cos(4*(x - 1/4)) + 20*x**19 + 1 + @njit def func_two_prime2(x): """ @@ -67,8 +71,8 @@ def test_newton_basic_two(): true_fval = 1.0 fval = newton(func, 5, func_prime) assert_allclose(true_fval, fval.root, rtol=1e-5, atol=0) - - + + def test_newton_hard(): """ Harder test for convergence. @@ -76,6 +80,7 @@ def test_newton_hard(): true_fval = 0.408 fval = newton(func_two, 0.4, func_two_prime) assert_allclose(true_fval, fval.root, rtol=1e-5, atol=0.01) + def test_halley_basic(): """ @@ -85,6 +90,7 @@ def test_halley_basic(): fval = newton_halley(func, 5, func_prime, func_prime2) assert_almost_equal(true_fval, fval.root, decimal=4) + def test_halley_hard(): """ Harder test for halley method @@ -93,6 +99,7 @@ def test_halley_hard(): fval = newton_halley(func_two, 0.4, func_two_prime, func_two_prime2) assert_allclose(true_fval, fval.root, rtol=1e-5, atol=0.01) + def test_secant_basic(): """ Basic test for secant option. @@ -111,8 +118,25 @@ def test_secant_hard(): assert_allclose(true_fval, fval.root, rtol=1e-5, atol=0.01) +def run_check(method, name): + a = -1 + b = np.sqrt(3) + true_fval = 0.408 + r = method(func_two, a, b) + assert_allclose(true_fval, r.root, atol=0.01, rtol=1e-5, + err_msg='method %s' % name) + + +def test_bisect_basic(): + run_check(bisect, 'bisect') + + +def test_brentq_basic(): + run_check(brentq, 'brentq') + # executing testcases. + if __name__ == '__main__': import sys import nose From 1b179405245bc7d7d6157528bd64e2b399491090 Mon Sep 17 00:00:00 2001 From: Supavit Dumrongprechachan Date: Wed, 25 Jul 2018 15:46:59 +1000 Subject: [PATCH 2/3] Fix import to list items --- quantecon/optimize/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/quantecon/optimize/__init__.py b/quantecon/optimize/__init__.py index de079da8b..0704bb4f8 100644 --- a/quantecon/optimize/__init__.py +++ b/quantecon/optimize/__init__.py @@ -3,4 +3,4 @@ """ from .scalar_maximization import brent_max -from .root_finding import * +from .root_finding import newton, newton_halley, newton_secant, bisect, brentq From 20b6437912ce31544db2d3d5671516f6dca3e18b Mon Sep 17 00:00:00 2001 From: Supavit Dumrongprechachan Date: Mon, 30 Jul 2018 14:09:33 +1000 Subject: [PATCH 3/3] Add default args in docstring and Format code I have included optional default arguments in the docstring of all functions. All trailing whitespaces are now fixed conforming to PEP8. The check is done via pycodestyle and http://pep8online.com/. --- quantecon/optimize/root_finding.py | 62 +++++++++++++++--------------- 1 file changed, 31 insertions(+), 31 deletions(-) diff --git a/quantecon/optimize/root_finding.py b/quantecon/optimize/root_finding.py index cb5c7b11c..2734dbe08 100644 --- a/quantecon/optimize/root_finding.py +++ b/quantecon/optimize/root_finding.py @@ -1,5 +1,5 @@ import numpy as np -from numba import jit, njit +from numba import njit from collections import namedtuple __all__ = ['newton', 'newton_halley', 'newton_secant', 'bisect', 'brentq'] @@ -43,13 +43,13 @@ def newton(func, x0, fprime, args=(), tol=1.48e-8, maxiter=50, actual zero. fprime : callable and jitted The derivative of the function (when available and convenient). - args : tuple, optional + args : tuple, optional(default=()) Extra arguments to be used in the function call. - tol : float, optional + tol : float, optional(default=1.48e-8) The allowable error of the zero value. - maxiter : int, optional + maxiter : int, optional(default=50) Maximum number of iterations. - disp : bool, optional + disp : bool, optional(default=True) If True, raise a RuntimeError if the algorithm didn't converge Returns @@ -90,7 +90,7 @@ def newton(func, x0, fprime, args=(), tol=1.48e-8, maxiter=50, break newton_step = fval / fder # Newton step - p = p0 - newton_step + p = p0 - newton_step if abs(p - p0) < tol: status = _ECONVERGED break @@ -125,13 +125,13 @@ def newton_halley(func, x0, fprime, fprime2, args=(), tol=1.48e-8, The derivative of the function (when available and convenient). fprime2 : callable and jitted The second order derivative of the function - args : tuple, optional + args : tuple, optional(default=()) Extra arguments to be used in the function call. - tol : float, optional + tol : float, optional(default=1.48e-8) The allowable error of the zero value. - maxiter : int, optional + maxiter : int, optional(default=50) Maximum number of iterations. - disp : bool, optional + disp : bool, optional(default=True) If True, raise a RuntimeError if the algorithm didn't converge Returns @@ -190,11 +190,11 @@ def newton_halley(func, x0, fprime, fprime2, args=(), tol=1.48e-8, def newton_secant(func, x0, args=(), tol=1.48e-8, maxiter=50, disp=True): """ - Find a zero from the secant method using the jitted version of + Find a zero from the secant method using the jitted version of Scipy's secant method. - + Note that `func` must be jitted via Numba. - + Parameters ---------- func : callable and jitted @@ -204,13 +204,13 @@ def newton_secant(func, x0, args=(), tol=1.48e-8, maxiter=50, x0 : float An initial estimate of the zero that should be somewhere near the actual zero. - args : tuple, optional + args : tuple, optional(default=()) Extra arguments to be used in the function call. - tol : float, optional + tol : float, optional(default=1.48e-8) The allowable error of the zero value. - maxiter : int, optional + maxiter : int, optional(default=50) Maximum number of iterations. - disp : bool, optional + disp : bool, optional(default=True) If True, raise a RuntimeError if the algorithm didn't converge. Returns @@ -221,17 +221,17 @@ def newton_secant(func, x0, args=(), tol=1.48e-8, maxiter=50, iterations - Number of iterations needed to find the root. converged - True if the routine converged """ - + if tol <= 0: raise ValueError("tol is too small <= 0") if maxiter < 1: raise ValueError("maxiter must be greater than 0") - + # Convert to float (don't use float(x0); this works also for complex x0) p0 = 1.0 * x0 funcalls = 0 status = _ECONVERR - + # Secant method if x0 >= 0: p1 = x0 * (1 + 1e-4) + 1e-4 @@ -256,7 +256,7 @@ def newton_secant(func, x0, args=(), tol=1.48e-8, maxiter=50, p1 = p q1 = func(p1, *args) funcalls += 1 - + if disp and status == _ECONVERR: msg = "Failed to converge" raise RuntimeError(msg) @@ -302,18 +302,18 @@ def bisect(f, a, b, args=(), xtol=_xtol, One end of the bracketing interval [a,b]. b : number The other end of the bracketing interval [a,b]. - args : tuple, optional + args : tuple, optional(default=()) Extra arguments to be used in the function call. - xtol : number, optional + xtol : number, optional(default=2e-12) The computed root ``x0`` will satisfy ``np.allclose(x, x0, atol=xtol, rtol=rtol)``, where ``x`` is the exact root. The parameter must be nonnegative. - rtol : number, optional + rtol : number, optional(default=4*np.finfo(float).eps) The computed root ``x0`` will satisfy ``np.allclose(x, x0, atol=xtol, rtol=rtol)``, where ``x`` is the exact root. - maxiter : number, optional + maxiter : number, optional(default=100) Maximum number of iterations. - disp : bool, optional + disp : bool, optional(default=True) If True, raise a RuntimeError if the algorithm didn't converge. Returns @@ -384,18 +384,18 @@ def brentq(f, a, b, args=(), xtol=_xtol, One end of the bracketing interval [a,b]. b : number The other end of the bracketing interval [a,b]. - args : tuple, optional + args : tuple, optional(default=()) Extra arguments to be used in the function call. - xtol : number, optional + xtol : number, optional(default=2e-12) The computed root ``x0`` will satisfy ``np.allclose(x, x0, atol=xtol, rtol=rtol)``, where ``x`` is the exact root. The parameter must be nonnegative. - rtol : number, optional + rtol : number, optional(default=4*np.finfo(float).eps) The computed root ``x0`` will satisfy ``np.allclose(x, x0, atol=xtol, rtol=rtol)``, where ``x`` is the exact root. - maxiter : number, optional + maxiter : number, optional(default=100) Maximum number of iterations. - disp : bool, optional + disp : bool, optional(default=True) If True, raise a RuntimeError if the algorithm didn't converge. Returns