diff --git a/Lib/asyncio/base_events.py b/Lib/asyncio/base_events.py index dc0ca3f02b9bf6..deb5069321604b 100644 --- a/Lib/asyncio/base_events.py +++ b/Lib/asyncio/base_events.py @@ -16,6 +16,8 @@ import collections import collections.abc import concurrent.futures +import contextvars +import functools import heapq import itertools import logging @@ -728,7 +730,7 @@ def call_soon_threadsafe(self, callback, *args, context=None): self._write_to_self() return handle - def run_in_executor(self, executor, func, *args): + def run_in_executor(self, executor, func, *args, context=None): self._check_closed() if self._debug: self._check_callback(func, 'run_in_executor') @@ -737,8 +739,15 @@ def run_in_executor(self, executor, func, *args): if executor is None: executor = concurrent.futures.ThreadPoolExecutor() self._default_executor = executor + + if context is None: + context = contextvars.copy_context() + + if args: + func = functools.partial(func, *args) + return futures.wrap_future( - executor.submit(func, *args), loop=self) + executor.submit(context.run, func), loop=self) def set_default_executor(self, executor): self._default_executor = executor diff --git a/Lib/test/test_asyncio/test_events.py b/Lib/test/test_asyncio/test_events.py index 11cd950df1cedb..fbd4bcdfac4711 100644 --- a/Lib/test/test_asyncio/test_events.py +++ b/Lib/test/test_asyncio/test_events.py @@ -2,6 +2,7 @@ import collections.abc import concurrent.futures +import contextvars import functools import io import os @@ -36,6 +37,9 @@ from test.test_asyncio import utils as test_utils from test import support +foo_ctx = contextvars.ContextVar('foo') +foo_ctx.set('bar') + def tearDownModule(): asyncio.set_event_loop_policy(None) @@ -369,6 +373,46 @@ def run(): time.sleep(0.4) self.assertFalse(called) + def test_run_in_executor_hierarchy(self): + def run(): + foo_ctx.set('foo') + res = foo_ctx.get() + self.assertEqual(res, 'foo') + return res + + f = self.loop.run_in_executor(None, run) + res = self.loop.run_until_complete(f) + self.assertEqual(res, 'foo') + + res = foo_ctx.get() + self.assertEqual(res, 'bar') + + def test_run_in_executor_no_context(self): + def run(): + return foo_ctx.get() + + f = self.loop.run_in_executor(None, run) + res = self.loop.run_until_complete(f) + self.assertEqual(res, 'bar') + + def test_run_in_executor_context(self): + def run(): + return foo_ctx.get() + + context = contextvars.copy_context() + f = self.loop.run_in_executor(None, run, context=context) + res = self.loop.run_until_complete(f) + self.assertEqual(res, 'bar') + + def test_run_in_executor_context_args(self): + def run(arg): + return (arg, foo_ctx.get()) + + context = contextvars.copy_context() + f = self.loop.run_in_executor(None, run, 'yo', context=context) + res = self.loop.run_until_complete(f) + self.assertEqual(res, ('yo', 'bar')) + def test_reader_callback(self): r, w = socket.socketpair() r.setblocking(False) diff --git a/Misc/NEWS.d/next/Library/2018-07-01-02-37-05.bpo-34014.RfrJGJ.rst b/Misc/NEWS.d/next/Library/2018-07-01-02-37-05.bpo-34014.RfrJGJ.rst new file mode 100644 index 00000000000000..28c0a9c504da4a --- /dev/null +++ b/Misc/NEWS.d/next/Library/2018-07-01-02-37-05.bpo-34014.RfrJGJ.rst @@ -0,0 +1 @@ +Added support of contextvars for BaseEventLoop.run_in_executor