Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions sdks/python/apache_beam/runners/worker/bundle_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import logging
import random
import threading
import time
from dataclasses import dataclass
from dataclasses import field
from itertools import chain
Expand Down Expand Up @@ -1162,6 +1163,10 @@ def __init__(
for op in reversed(self.ops.values()):
op.setup(self.data_sampler)
self.splitting_lock = threading.Lock()
self.output_data_channels = list({
op.data_channel for op in self.ops.values()
if isinstance(op, DataOutputOperation)
})

def create_execution_tree(
self, descriptor: beam_fn_api_pb2.ProcessBundleDescriptor
Expand Down Expand Up @@ -1288,6 +1293,22 @@ def process_bundle(
for data_channel, expected_inputs in data_channels.items():
for element in data_channel.input_elements(instruction_id,
expected_inputs):
# Cooperative backpressure throttling
backpressure_sleep = 0.001
next_log_time = 0.0
while any(
channel.is_backpressured()
for channel in self.output_data_channels):
current_time = time.time()
if current_time >= next_log_time:
_LOGGER.warning(
"Outgoing data channel backpressured. "
"Throttling input reading for instruction: %s",
instruction_id)
next_log_time = current_time + 300.0
time.sleep(backpressure_sleep)
backpressure_sleep = min(backpressure_sleep * 2, 0.1)

# Since we have received a set of elements and are consuming it.
self.consuming_received_data = True
if isinstance(element, beam_fn_api_pb2.Elements.Timers):
Expand Down
39 changes: 33 additions & 6 deletions sdks/python/apache_beam/runners/worker/data_plane.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,11 @@ def close(self):
"""
raise NotImplementedError(type(self))

def is_backpressured(self):
  # type: () -> bool
  """Whether this channel currently exceeds its outbound buffering limits.

  The base implementation imposes no limit, so this always returns False;
  subclasses that track queued outbound bytes override it.
  """
  return False


class InMemoryDataChannel(DataChannel):
"""An in-memory implementation of a DataChannel.
Expand Down Expand Up @@ -453,9 +458,15 @@ class _GrpcDataChannel(DataChannel):

_WRITES_FINISHED = beam_fn_api_pb2.Elements.Data()

def __init__(self, data_buffer_time_limit_ms=0):
# type: (int) -> None
def __init__(self, data_buffer_time_limit_ms=0, max_queued_bytes=256 << 20):
# type: (int, int) -> None
self._data_buffer_time_limit_ms = data_buffer_time_limit_ms
self._max_queued_bytes = max_queued_bytes
self._queued_bytes = 0
self._to_send_lock = threading.Lock()
assert self._WRITES_FINISHED.ByteSize() == 0, (
"WRITES_FINISHED sentinel must have a byte size of 0 to avoid "
"measurement drift in outbound queue size tracking.")
self._to_send = queue.Queue() # type: queue.Queue[DataOrTimers]
self._received = collections.defaultdict(
lambda: queue.Queue(maxsize=5)
Expand All @@ -472,6 +483,18 @@ def __init__(self, data_buffer_time_limit_ms=0):
self._closed = False
self._exception = None # type: Optional[Exception]

def is_backpressured(self):
  # type: () -> bool
  """Returns True when queued outbound bytes exceed the configured cap."""
  # Read under the lock so the comparison sees a consistent counter value.
  with self._to_send_lock:
    over_limit = self._queued_bytes > self._max_queued_bytes
  return over_limit

def _put_to_send_queue(self, element):
# type: (Union[beam_fn_api_pb2.Elements.Data, beam_fn_api_pb2.Elements.Timers]) -> None
element_size = element.ByteSize()
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this might be expensive, but we could _element_weight from the other PR

with self._to_send_lock:
self._queued_bytes += element_size
self._to_send.put(element)

def close(self):
# type: () -> None
self._to_send.put(self._WRITES_FINISHED)
Expand Down Expand Up @@ -585,7 +608,7 @@ def output_stream(self, instruction_id, transform_id):
def add_to_send_queue(data):
# type: (bytes) -> None
if data:
self._to_send.put(
self._put_to_send_queue(
beam_fn_api_pb2.Elements.Data(
instruction_id=instruction_id,
transform_id=transform_id,
Expand All @@ -595,7 +618,7 @@ def close_callback(data):
# type: (bytes) -> None
add_to_send_queue(data)
# End of stream marker.
self._to_send.put(
self._put_to_send_queue(
beam_fn_api_pb2.Elements.Data(
instruction_id=instruction_id,
transform_id=transform_id,
Expand All @@ -614,7 +637,7 @@ def output_timer_stream(
def add_to_send_queue(timer):
# type: (bytes) -> None
if timer:
self._to_send.put(
self._put_to_send_queue(
beam_fn_api_pb2.Elements.Timers(
instruction_id=instruction_id,
transform_id=transform_id,
Expand All @@ -625,7 +648,7 @@ def add_to_send_queue(timer):
def close_callback(timer):
# type: (bytes) -> None
add_to_send_queue(timer)
self._to_send.put(
self._put_to_send_queue(
beam_fn_api_pb2.Elements.Timers(
instruction_id=instruction_id,
transform_id=transform_id,
Expand All @@ -640,12 +663,16 @@ def _write_outputs(self):
stream_done = False
while not stream_done:
streams = [self._to_send.get()]
with self._to_send_lock:
self._queued_bytes -= streams[0].ByteSize()
try:
# Coalesce up to 100 other items.
total_size_bytes = streams[0].ByteSize()
while (total_size_bytes < _DEFAULT_SIZE_FLUSH_THRESHOLD and
len(streams) <= 100):
data_or_timer = self._to_send.get_nowait()
with self._to_send_lock:
self._queued_bytes -= data_or_timer.ByteSize()
total_size_bytes += data_or_timer.ByteSize()
streams.append(data_or_timer)
except queue.Empty:
Expand Down
Loading