From 2fefbfad8a6f16965443e38e4b5da890cb63795b Mon Sep 17 00:00:00 2001 From: Valentyn Tymofieiev Date: Fri, 8 May 2026 12:32:31 -0700 Subject: [PATCH] Throttle reads on data channel when writes on output channel are backpressured. --- .../runners/worker/bundle_processor.py | 21 ++++++++++ .../apache_beam/runners/worker/data_plane.py | 39 ++++++++++++++++--- 2 files changed, 54 insertions(+), 6 deletions(-) diff --git a/sdks/python/apache_beam/runners/worker/bundle_processor.py b/sdks/python/apache_beam/runners/worker/bundle_processor.py index faa756d7c5c5..b5eed50e5757 100644 --- a/sdks/python/apache_beam/runners/worker/bundle_processor.py +++ b/sdks/python/apache_beam/runners/worker/bundle_processor.py @@ -32,6 +32,7 @@ import logging import random import threading +import time from dataclasses import dataclass from dataclasses import field from itertools import chain @@ -1162,6 +1163,10 @@ def __init__( for op in reversed(self.ops.values()): op.setup(self.data_sampler) self.splitting_lock = threading.Lock() + self.output_data_channels = list({ + op.data_channel for op in self.ops.values() + if isinstance(op, DataOutputOperation) + }) def create_execution_tree( self, descriptor: beam_fn_api_pb2.ProcessBundleDescriptor @@ -1288,6 +1293,22 @@ def process_bundle( for data_channel, expected_inputs in data_channels.items(): for element in data_channel.input_elements(instruction_id, expected_inputs): + # Cooperative backpressure throttling + backpressure_sleep = 0.001 + next_log_time = 0.0 + while any( + channel.is_backpressured() + for channel in self.output_data_channels): + current_time = time.time() + if current_time >= next_log_time: + _LOGGER.warning( + "Outgoing data channel backpressured. " + "Throttling input reading for instruction: %s", + instruction_id) + next_log_time = current_time + 300.0 + time.sleep(backpressure_sleep) + backpressure_sleep = min(backpressure_sleep * 2, 0.1) + # Since we have received a set of elements and are consuming it. self.consuming_received_data = True if isinstance(element, beam_fn_api_pb2.Elements.Timers): diff --git a/sdks/python/apache_beam/runners/worker/data_plane.py b/sdks/python/apache_beam/runners/worker/data_plane.py index cbd28f8b0a3f..0cbc97ef3ce5 100644 --- a/sdks/python/apache_beam/runners/worker/data_plane.py +++ b/sdks/python/apache_beam/runners/worker/data_plane.py @@ -359,6 +359,11 @@ def close(self): """ raise NotImplementedError(type(self)) + def is_backpressured(self): + # type: () -> bool + """Returns True if the channel has exceeded its buffering limits.""" + return False + class InMemoryDataChannel(DataChannel): """An in-memory implementation of a DataChannel. @@ -453,9 +458,15 @@ class _GrpcDataChannel(DataChannel): _WRITES_FINISHED = beam_fn_api_pb2.Elements.Data() - def __init__(self, data_buffer_time_limit_ms=0): - # type: (int) -> None + def __init__(self, data_buffer_time_limit_ms=0, max_queued_bytes=256 << 20): + # type: (int, int) -> None self._data_buffer_time_limit_ms = data_buffer_time_limit_ms + self._max_queued_bytes = max_queued_bytes + self._queued_bytes = 0 + self._to_send_lock = threading.Lock() + assert self._WRITES_FINISHED.ByteSize() == 0, ( + "WRITES_FINISHED sentinel must have a byte size of 0 to avoid " + "measurement drift in outbound queue size tracking.") self._to_send = queue.Queue() # type: queue.Queue[DataOrTimers] self._received = collections.defaultdict( lambda: queue.Queue(maxsize=5) @@ -472,6 +483,18 @@ def __init__(self, data_buffer_time_limit_ms=0): self._closed = False self._exception = None # type: Optional[Exception] + def is_backpressured(self): + # type: () -> bool + with self._to_send_lock: + return self._queued_bytes > self._max_queued_bytes + + def _put_to_send_queue(self, element): + # type: (Union[beam_fn_api_pb2.Elements.Data, beam_fn_api_pb2.Elements.Timers]) -> None + element_size = element.ByteSize() + with self._to_send_lock: + self._queued_bytes += element_size + self._to_send.put(element) + def close(self): # type: () -> None self._to_send.put(self._WRITES_FINISHED) @@ -585,7 +608,7 @@ def output_stream(self, instruction_id, transform_id): def add_to_send_queue(data): # type: (bytes) -> None if data: - self._to_send.put( + self._put_to_send_queue( beam_fn_api_pb2.Elements.Data( instruction_id=instruction_id, transform_id=transform_id, @@ -595,7 +618,7 @@ def close_callback(data): # type: (bytes) -> None add_to_send_queue(data) # End of stream marker. - self._to_send.put( + self._put_to_send_queue( beam_fn_api_pb2.Elements.Data( instruction_id=instruction_id, transform_id=transform_id, @@ -614,7 +637,7 @@ def output_timer_stream( def add_to_send_queue(timer): # type: (bytes) -> None if timer: - self._to_send.put( + self._put_to_send_queue( beam_fn_api_pb2.Elements.Timers( instruction_id=instruction_id, transform_id=transform_id, @@ -625,7 +648,7 @@ def add_to_send_queue(timer): def close_callback(timer): # type: (bytes) -> None add_to_send_queue(timer) - self._to_send.put( + self._put_to_send_queue( beam_fn_api_pb2.Elements.Timers( instruction_id=instruction_id, transform_id=transform_id, @@ -640,12 +663,16 @@ def _write_outputs(self): stream_done = False while not stream_done: streams = [self._to_send.get()] + with self._to_send_lock: + self._queued_bytes -= streams[0].ByteSize() try: # Coalesce up to 100 other items. total_size_bytes = streams[0].ByteSize() while (total_size_bytes < _DEFAULT_SIZE_FLUSH_THRESHOLD and len(streams) <= 100): data_or_timer = self._to_send.get_nowait() + with self._to_send_lock: + self._queued_bytes -= data_or_timer.ByteSize() total_size_bytes += data_or_timer.ByteSize() streams.append(data_or_timer) except queue.Empty: