33# Copyright (c) IPython Development Team.
44# Distributed under the terms of the Modified BSD License.
55
6+ from __future__ import annotations
7+
68import atexit
79import contextvars
810import io
1517from collections import defaultdict , deque
1618from io import StringIO , TextIOBase
1719from threading import Event , Thread , local
18- from typing import Any , Callable , Optional
20+ from typing import Any , Callable
1921
2022import zmq
2123from anyio import create_task_group , run , sleep , to_thread
2527# Globals
2628# -----------------------------------------------------------------------------
2729
28- MASTER = 0
29- CHILD = 1
30+ _PARENT = 0
31+ _CHILD = 1
3032
3133PIPE_BUFFER_SIZE = 1000
3234
@@ -87,9 +89,16 @@ def __init__(self, socket, pipe=False):
8789 Whether this process should listen for IOPub messages
8890 piped from subprocesses.
8991 """
90- self .socket = socket
92+ # ensure all of our sockets as sync zmq.Sockets
93+ # don't create async wrappers until we are within the appropriate coroutines
94+ self .socket : zmq .Socket [bytes ] | None = zmq .Socket (socket )
95+ if self .socket .context is None :
96+ # bug in pyzmq, shadow socket doesn't always inherit context attribute
97+ self .socket .context = socket .context # type:ignore[unreachable]
98+ self ._context = socket .context
99+
91100 self .background_socket = BackgroundSocket (self )
92- self ._master_pid = os .getpid ()
101+ self ._main_pid = os .getpid ()
93102 self ._pipe_flag = pipe
94103 if pipe :
95104 self ._setup_pipe_in ()
@@ -106,8 +115,7 @@ def __init__(self, socket, pipe=False):
106115
107116 def _setup_event_pipe (self ):
108117 """Create the PULL socket listening for events that should fire in this thread."""
109- ctx = self .socket .context
110- self ._pipe_in0 = ctx .socket (zmq .PULL )
118+ self ._pipe_in0 = self ._context .socket (zmq .PULL , socket_class = zmq .Socket )
111119 self ._pipe_in0 .linger = 0
112120
113121 _uuid = b2a_hex (os .urandom (16 )).decode ("ascii" )
@@ -141,8 +149,8 @@ def _event_pipe(self):
141149 event_pipe = self ._local .event_pipe
142150 except AttributeError :
143151 # new thread, new event pipe
144- ctx = zmq . Context ( self . socket . context )
145- event_pipe = ctx . socket (zmq .PUSH )
152+ # create sync base socket
153+ event_pipe = self . _context . socket (zmq .PUSH , socket_class = zmq . Socket )
146154 event_pipe .linger = 0
147155 event_pipe .connect (self ._event_interface )
148156 self ._local .event_pipe = event_pipe
@@ -161,9 +169,11 @@ async def _handle_event(self):
161169 Whenever *an* event arrives on the event stream,
162170 *all* waiting events are processed in order.
163171 """
172+ # create async wrapper within coroutine
173+ pipe_in = zmq .asyncio .Socket (self ._pipe_in0 )
164174 try :
165175 while True :
166- await self . _pipe_in0 .recv ()
176+ await pipe_in .recv ()
167177 # freeze event count so new writes don't extend the queue
168178 # while we are processing
169179 n_events = len (self ._events )
@@ -177,12 +187,12 @@ async def _handle_event(self):
177187
178188 def _setup_pipe_in (self ):
179189 """setup listening pipe for IOPub from forked subprocesses"""
180- ctx = self .socket . context
190+ ctx = self ._context
181191
182192 # use UUID to authenticate pipe messages
183193 self ._pipe_uuid = os .urandom (16 )
184194
185- self ._pipe_in1 = ctx .socket (zmq .PULL )
195+ self ._pipe_in1 = ctx .socket (zmq .PULL , socket_class = zmq . Socket )
186196 self ._pipe_in1 .linger = 0
187197
188198 try :
@@ -199,6 +209,8 @@ def _setup_pipe_in(self):
199209
200210 async def _handle_pipe_msgs (self ):
201211 """handle pipe messages from a subprocess"""
212+ # create async wrapper within coroutine
213+ self ._async_pipe_in1 = zmq .asyncio .Socket (self ._pipe_in1 )
202214 try :
203215 while True :
204216 await self ._handle_pipe_msg ()
@@ -209,8 +221,8 @@ async def _handle_pipe_msgs(self):
209221
210222 async def _handle_pipe_msg (self , msg = None ):
211223 """handle a pipe message from a subprocess"""
212- msg = msg or await self ._pipe_in1 .recv_multipart ()
213- if not self ._pipe_flag or not self ._is_master_process ():
224+ msg = msg or await self ._async_pipe_in1 .recv_multipart ()
225+ if not self ._pipe_flag or not self ._is_main_process ():
214226 return
215227 if msg [0 ] != self ._pipe_uuid :
216228 print ("Bad pipe message: %s" , msg , file = sys .__stderr__ )
@@ -225,14 +237,14 @@ def _setup_pipe_out(self):
225237 pipe_out .connect ("tcp://127.0.0.1:%i" % self ._pipe_port )
226238 return ctx , pipe_out
227239
228- def _is_master_process (self ):
229- return os .getpid () == self ._master_pid
240+ def _is_main_process (self ):
241+ return os .getpid () == self ._main_pid
230242
231243 def _check_mp_mode (self ):
232244 """check for forks, and switch to zmq pipeline if necessary"""
233- if not self ._pipe_flag or self ._is_master_process ():
234- return MASTER
235- return CHILD
245+ if not self ._pipe_flag or self ._is_main_process ():
246+ return _PARENT
247+ return _CHILD
236248
237249 def start (self ):
238250 """Start the IOPub thread"""
@@ -265,7 +277,8 @@ def close(self):
265277 self ._pipe_in0 .close ()
266278 if self ._pipe_flag :
267279 self ._pipe_in1 .close ()
268- self .socket .close ()
280+ if self .socket is not None :
281+ self .socket .close ()
269282 self .socket = None
270283
271284 @property
@@ -301,12 +314,12 @@ def _really_send(self, msg, *args, **kwargs):
301314 return
302315
303316 mp_mode = self ._check_mp_mode ()
304-
305- if mp_mode != CHILD :
306- # we are master, do a regular send
317+ if mp_mode != _CHILD :
318+ # we are the main parent process, do a regular send
319+ assert self . socket is not None
307320 self .socket .send_multipart (msg , * args , ** kwargs )
308321 else :
309- # we are a child, pipe to master
322+ # we are a child, pipe to parent process
310323 # new context/socket for every pipe-out
311324 # since forks don't teardown politely, use ctx.term to ensure send has completed
312325 ctx , pipe_out = self ._setup_pipe_out ()
@@ -379,7 +392,7 @@ class OutStream(TextIOBase):
379392 flush_interval = 0.2
380393 topic = None
381394 encoding = "UTF-8"
382- _exc : Optional [ Any ] = None
395+ _exc : Any = None
383396
384397 def fileno (self ):
385398 """
@@ -477,7 +490,7 @@ def __init__(
477490 self ._thread_to_parent = {}
478491 self ._thread_to_parent_header = {}
479492 self ._parent_header_global = {}
480- self ._master_pid = os .getpid ()
493+ self ._main_pid = os .getpid ()
481494 self ._flush_pending = False
482495 self ._subprocess_flush_pending = False
483496 self ._buffer_lock = threading .RLock ()
@@ -569,8 +582,8 @@ def _setup_stream_redirects(self, name):
569582 self .watch_fd_thread .daemon = True
570583 self .watch_fd_thread .start ()
571584
572- def _is_master_process (self ):
573- return os .getpid () == self ._master_pid
585+ def _is_main_process (self ):
586+ return os .getpid () == self ._main_pid
574587
575588 def set_parent (self , parent ):
576589 """Set the parent header."""
@@ -674,7 +687,7 @@ def _flush(self):
674687 ident = self .topic ,
675688 )
676689
677- def write (self , string : str ) -> Optional [ int ]: # type:ignore[override]
690+ def write (self , string : str ) -> int :
678691 """Write to current stream after encoding if necessary
679692
680693 Returns
@@ -700,15 +713,15 @@ def write(self, string: str) -> Optional[int]: # type:ignore[override]
700713 msg = "I/O operation on closed file"
701714 raise ValueError (msg )
702715
703- is_child = not self ._is_master_process ()
716+ is_child = not self ._is_main_process ()
704717 # only touch the buffer in the IO thread to avoid races
705718 with self ._buffer_lock :
706719 self ._buffers [frozenset (parent .items ())].write (string )
707720 if is_child :
708721 # mp.Pool cannot be trusted to flush promptly (or ever),
709722 # and this helps.
710723 if self ._subprocess_flush_pending :
711- return None
724+ return 0
712725 self ._subprocess_flush_pending = True
713726 # We can not rely on self._io_loop.call_later from a subprocess
714727 self .pub_thread .schedule (self ._flush )
0 commit comments