55
66import asyncio
77import atexit
8+ import contextvars
89import io
910import os
1011import sys
1112import threading
1213import traceback
1314import warnings
1415from binascii import b2a_hex
15- from collections import deque
16+ from collections import defaultdict , deque
1617from io import StringIO , TextIOBase
1718from threading import local
1819from typing import Any , Callable , Deque , Dict , Optional
@@ -412,7 +413,7 @@ def __init__(
412413 name : str {'stderr', 'stdout'}
413414 the name of the standard stream to replace
414415 pipe : object
415- the pip object
416+ the pipe object
416417 echo : bool
417418 whether to echo output
418419 watchfd : bool (default, True)
@@ -446,13 +447,18 @@ def __init__(
446447 self .pub_thread = pub_thread
447448 self .name = name
448449 self .topic = b"stream." + name .encode ()
449- self .parent_header = {}
450+ self ._parent_header : contextvars .ContextVar [Dict [str , Any ]] = contextvars .ContextVar (
451+ "parent_header"
452+ )
453+ self ._parent_header .set ({})
454+ self ._thread_parents = {}
455+ self ._parent_header_global = {}
450456 self ._master_pid = os .getpid ()
451457 self ._flush_pending = False
452458 self ._subprocess_flush_pending = False
453459 self ._io_loop = pub_thread .io_loop
454460 self ._buffer_lock = threading .RLock ()
455- self ._buffer = StringIO ( )
461+ self ._buffers = defaultdict ( StringIO )
456462 self .echo = None
457463 self ._isatty = bool (isatty )
458464 self ._should_watch = False
@@ -495,6 +501,24 @@ def __init__(
495501 msg = "echo argument must be a file-like object"
496502 raise ValueError (msg )
497503
504+ @property
505+ def parent_header (self ):
506+ try :
507+ # asyncio-specific
508+ return self ._parent_header .get ()
509+ except LookupError :
510+ try :
511+ # thread-specific
512+ return self ._thread_parents [threading .current_thread ().ident ]
513+ except KeyError :
514+ # global (fallback)
515+ return self ._parent_header_global
516+
517+ @parent_header .setter
518+ def parent_header (self , value ):
519+ self ._parent_header_global = value
520+ return self ._parent_header .set (value )
521+
498522 def isatty (self ):
499523 """Return a bool indicating whether this is an 'interactive' stream.
500524
@@ -598,28 +622,28 @@ def _flush(self):
598622 if self .echo is not sys .__stderr__ :
599623 print (f"Flush failed: { e } " , file = sys .__stderr__ )
600624
601- data = self ._flush_buffer ()
602- if data :
603- # FIXME: this disables Session's fork-safe check,
604- # since pub_thread is itself fork-safe.
605- # There should be a better way to do this.
606- self .session .pid = os .getpid ()
607- content = {"name" : self .name , "text" : data }
608- msg = self .session .msg ("stream" , content , parent = self . parent_header )
609-
610- # Each transform either returns a new
611- # message or None. If None is returned,
612- # the message has been 'used' and we return.
613- for hook in self ._hooks :
614- msg = hook (msg )
615- if msg is None :
616- return
617-
618- self .session .send (
619- self .pub_thread ,
620- msg ,
621- ident = self .topic ,
622- )
625+ for parent , data in self ._flush_buffers ():
626+ if data :
627+ # FIXME: this disables Session's fork-safe check,
628+ # since pub_thread is itself fork-safe.
629+ # There should be a better way to do this.
630+ self .session .pid = os .getpid ()
631+ content = {"name" : self .name , "text" : data }
632+ msg = self .session .msg ("stream" , content , parent = parent )
633+
634+ # Each transform either returns a new
635+ # message or None. If None is returned,
636+ # the message has been 'used' and we return.
637+ for hook in self ._hooks :
638+ msg = hook (msg )
639+ if msg is None :
640+ return
641+
642+ self .session .send (
643+ self .pub_thread ,
644+ msg ,
645+ ident = self .topic ,
646+ )
623647
624648 def write (self , string : str ) -> Optional [int ]: # type:ignore[override]
625649 """Write to current stream after encoding if necessary
@@ -630,6 +654,7 @@ def write(self, string: str) -> Optional[int]: # type:ignore[override]
630654 number of items from input parameter written to stream.
631655
632656 """
657+ parent = self .parent_header
633658
634659 if not isinstance (string , str ):
635660 msg = f"write() argument must be str, not { type (string )} " # type:ignore[unreachable]
@@ -649,7 +674,7 @@ def write(self, string: str) -> Optional[int]: # type:ignore[override]
649674 is_child = not self ._is_master_process ()
650675 # only touch the buffer in the IO thread to avoid races
651676 with self ._buffer_lock :
652- self ._buffer .write (string )
677+ self ._buffers [ frozenset ( parent . items ())] .write (string )
653678 if is_child :
654679 # mp.Pool cannot be trusted to flush promptly (or ever),
655680 # and this helps.
@@ -675,19 +700,20 @@ def writable(self):
675700 """Test whether the stream is writable."""
676701 return True
677702
678- def _flush_buffer (self ):
703+ def _flush_buffers (self ):
679704 """clear the current buffer and return the current buffer data."""
680- buf = self ._rotate_buffer ()
681- data = buf .getvalue ()
682- buf .close ()
683- return data
705+ buffers = self ._rotate_buffers ()
706+ for frozen_parent , buffer in buffers .items ():
707+ data = buffer .getvalue ()
708+ buffer .close ()
709+ yield dict (frozen_parent ), data
684710
685- def _rotate_buffer (self ):
711+ def _rotate_buffers (self ):
686712 """Returns the current buffer and replaces it with an empty buffer."""
687713 with self ._buffer_lock :
688- old_buffer = self ._buffer
689- self ._buffer = StringIO ( )
690- return old_buffer
714+ old_buffers = self ._buffers
715+ self ._buffers = defaultdict ( StringIO )
716+ return old_buffers
691717
692718 @property
693719 def _hooks (self ):
0 commit comments