diff --git a/flask_babel/__init__.py b/flask_babel/__init__.py index bddb77b..58ebcee 100644 --- a/flask_babel/__init__.py +++ b/flask_babel/__init__.py @@ -251,10 +251,13 @@ def get_locale(): locale = getattr(ctx, 'babel_locale', None) if locale is None: babel = current_app.extensions['babel'] - if babel.locale_selector_func is None: + locale_selector_func = getattr(ctx, 'locale_selector_func', None) + if locale_selector_func is None: + locale_selector_func = babel.locale_selector_func + if locale_selector_func is None: locale = babel.default_locale else: - rv = babel.locale_selector_func() + rv = locale_selector_func() if rv is None: locale = babel.default_locale else: @@ -325,20 +328,18 @@ def force_locale(locale): yield return - babel = current_app.extensions['babel'] - - orig_locale_selector_func = babel.locale_selector_func + orig_locale_selector_func = getattr(ctx, 'locale_selector_func', None) orig_attrs = {} for key in ('babel_translations', 'babel_locale'): orig_attrs[key] = getattr(ctx, key, None) try: - babel.locale_selector_func = lambda: locale + ctx.locale_selector_func = lambda: locale for key in orig_attrs: setattr(ctx, key, None) yield finally: - babel.locale_selector_func = orig_locale_selector_func + ctx.locale_selector_func = orig_locale_selector_func for key, value in orig_attrs.items(): setattr(ctx, key, value) diff --git a/tests/tests.py b/tests/tests.py index a987f05..2008ea3 100644 --- a/tests/tests.py +++ b/tests/tests.py @@ -204,6 +204,29 @@ def select_locale(): assert str(babel.get_locale()) == 'en_US' assert str(babel.get_locale()) == 'de_DE' + def test_force_locale_with_two_concurrent_requests(self): + app = flask.Flask(__name__) + b = babel.Babel(app) + + @b.localeselector + def select_locale(): + return 'de_DE' + + def make_request(forced_locale): + with app.test_request_context(): + assert str(babel.get_locale()) == 'de_DE' + with babel.force_locale(forced_locale): + yield + + request1 = make_request('en_US') + next(request1) + + request2 = make_request('en_US') + next(request2) + + next(request2, None) + next(request1, None) + class NumberFormattingTestCase(unittest.TestCase):