Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion sdks/python/apache_beam/runners/common.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,9 @@ cdef class PerWindowInvoker(DoFnInvoker):
cdef dict kwargs_for_process_batch
cdef list placeholders_for_process_batch
cdef bint has_windowed_inputs
cdef bint cache_globally_windowed_args
cdef bint recalculate_window_args
cdef bint has_cached_window_args
cdef bint has_cached_window_batch_args
cdef object process_method
cdef object process_batch_method
cdef bint is_splittable
Expand Down
75 changes: 51 additions & 24 deletions sdks/python/apache_beam/runners/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -761,6 +761,17 @@ def __init__(self,
self.current_window_index = None
self.stop_window_index = None

# TODO(https://github.com/apache/beam/issues/28776): Remove caching after
# fully rolling out.
# If true, always recalculate window args. If false, has_cached_window_args
# and has_cached_window_batch_args will be set to true once the corresponding
# self.args_for_process / self.args_for_process_batch (and their kwargs) have
# been updated, so that they can be reused directly.
self.recalculate_window_args = (
self.has_windowed_inputs or 'disable_global_windowed_args_caching' in
RuntimeValueProvider.experiments)
self.has_cached_window_args = False
self.has_cached_window_batch_args = False

# Try to prepare all the arguments that can just be filled in
# without any additional work in the process function.
# Also cache all the placeholders needed in the process function.
Expand Down Expand Up @@ -921,16 +932,23 @@ def _invoke_process_per_window(self,
additional_kwargs,
):
# type: (...) -> Optional[SplitResultResidual]
if self.has_windowed_inputs:
assert len(windowed_value.windows) <= 1
window, = windowed_value.windows
if self.has_cached_window_args:
args_for_process, kwargs_for_process = (
self.args_for_process, self.kwargs_for_process)
else:
window = GlobalWindow()
side_inputs = [si[window] for si in self.side_inputs]
side_inputs.extend(additional_args)
args_for_process, kwargs_for_process = util.insert_values_in_args(
self.args_for_process, self.kwargs_for_process,
side_inputs)
if self.has_windowed_inputs:
assert len(windowed_value.windows) <= 1
window, = windowed_value.windows
else:
window = GlobalWindow()
side_inputs = [si[window] for si in self.side_inputs]
side_inputs.extend(additional_args)
args_for_process, kwargs_for_process = util.insert_values_in_args(
self.args_for_process, self.kwargs_for_process, side_inputs)
if not self.recalculate_window_args:
self.args_for_process, self.kwargs_for_process = (
args_for_process, kwargs_for_process)
self.has_cached_window_args = True

# Extract key in the case of a stateful DoFn. Note that in the case of a
# stateful DoFn, we set during __init__ self.has_windowed_inputs to be
Expand Down Expand Up @@ -1012,20 +1030,29 @@ def _invoke_process_batch_per_window(
):
# type: (...) -> Optional[SplitResultResidual]

if self.has_windowed_inputs:
assert isinstance(windowed_batch, HomogeneousWindowedBatch)
assert len(windowed_batch.windows) <= 1
window, = windowed_batch.windows
if self.has_cached_window_batch_args:
args_for_process_batch, kwargs_for_process_batch = (
self.args_for_process_batch, self.kwargs_for_process_batch)
else:
window = GlobalWindow()
side_inputs = [si[window] for si in self.side_inputs]
side_inputs.extend(additional_args)
(args_for_process_batch, kwargs_for_process_batch) = (
util.insert_values_in_args(
self.args_for_process_batch,
self.kwargs_for_process_batch,
side_inputs,
))
if self.has_windowed_inputs:
assert isinstance(windowed_batch, HomogeneousWindowedBatch)
assert len(windowed_batch.windows) <= 1
window, = windowed_batch.windows
else:
window = GlobalWindow()
side_inputs = [si[window] for si in self.side_inputs]
side_inputs.extend(additional_args)
args_for_process_batch, kwargs_for_process_batch = (
util.insert_values_in_args(
self.args_for_process_batch,
self.kwargs_for_process_batch,
side_inputs,
)
)
if not self.recalculate_window_args:
self.args_for_process_batch, self.kwargs_for_process_batch = (
args_for_process_batch, kwargs_for_process_batch)
self.has_cached_window_batch_args = True

for i, p in self.placeholders_for_process_batch:
if core.DoFn.ElementParam == p:
Expand Down Expand Up @@ -1541,8 +1568,8 @@ def __init__(self,
tagged_receivers, # type: Mapping[Optional[str], Receiver]
per_element_output_counter,
output_batch_converter, # type: Optional[BatchConverter]
process_yields_batches, # type: bool,
process_batch_yields_elements, # type: bool,
process_yields_batches, # type: bool
process_batch_yields_elements, # type: bool
):
"""Initializes ``_OutputHandler``.

Expand Down