
Commit b1c2911

fix: Avoid event loop starvation if user has tight loops in streaming UDFs (#352)
Signed-off-by: Sreekanth <prsreekanth920@gmail.com>
1 parent 9cd8d2b commit b1c2911

4 files changed

Lines changed: 391 additions & 3 deletions

File tree

packages/pynumaflow/pynumaflow/mapper/_servicer/_async_servicer.py

Lines changed: 2 additions & 3 deletions
@@ -108,9 +108,8 @@ async def _process_inputs(
                 self.background_tasks.add(msg_task)
                 msg_task.add_done_callback(self.background_tasks.discard)
 
-            # wait for all tasks to complete
-            for task in self.background_tasks:
-                await task
+            # Wait for all tasks to complete concurrently
+            await asyncio.gather(*self.background_tasks)
 
             # send an EOF to result queue to indicate that all tasks have completed
             await result_queue.put(STREAM_EOF)
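
For context, here is a minimal standalone sketch (illustrative names only, not SDK code) of the task-tracking pattern the changed lines use: tasks are created up front, tracked in a set whose done-callbacks prune completed entries, and then awaited together with asyncio.gather, which snapshots its arguments and propagates any exception to the caller.

```python
# Minimal sketch, standard library only: track tasks in a set, let
# done-callbacks prune it, and await the whole group at once.
import asyncio


async def worker(i: int) -> str:
    await asyncio.sleep(0.01 * i)
    return f"worker-{i} done"


async def main() -> None:
    background_tasks: set[asyncio.Task] = set()
    for i in range(3):
        task = asyncio.create_task(worker(i))
        background_tasks.add(task)
        # Completed tasks remove themselves from the tracking set.
        task.add_done_callback(background_tasks.discard)

    # gather() snapshots its arguments, so the set may safely shrink
    # while we wait; it returns once every task has finished.
    results = await asyncio.gather(*background_tasks)
    print(results)


asyncio.run(main())
```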

packages/pynumaflow/pynumaflow/shared/asynciter.py

Lines changed: 10 additions & 0 deletions
@@ -25,3 +25,13 @@ async def read_iterator(self) -> AsyncIterator[T]:
 
     async def put(self, item: T) -> None:
         await self._queue.put(item)
+        # Yield to the event loop after each put. The underlying
+        # asyncio.Queue is unbounded (maxsize=0), so Queue.put() never
+        # actually suspends — it calls sync put_nowait() under the hood.
+        # If the UDF async generator yields messages via a sync for-loop
+        # (no await between yields), the event loop is starved and
+        # consumer tasks (including gRPC streaming) cannot make progress
+        # until the generator completes. The sleep(0) ensures the event
+        # loop gets a turn after every put regardless of the caller's code.
+        # See: https://github.com/numaproj/numaflow-python/issues/350
+        await asyncio.sleep(0)
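
To see why the sleep(0) matters, here is a minimal, self-contained sketch of the starvation scenario (hypothetical producer/consumer names, not SDK code): put() on an unbounded asyncio.Queue returns without ever suspending, so a producer that only awaits put() monopolizes the event loop until it finishes and the consumer sees everything at once; adding await asyncio.sleep(0) after each put hands control back after every item.

```python
# Minimal sketch of the starvation scenario fixed above; names are illustrative.
import asyncio


async def producer(queue: asyncio.Queue, yield_to_loop: bool) -> None:
    for i in range(5):
        print("produced", i)
        await queue.put(i)  # unbounded queue: put() never actually suspends
        if yield_to_loop:
            await asyncio.sleep(0)  # give the consumer a turn after each item
    await queue.put(None)  # sentinel marking end of stream


async def consumer(queue: asyncio.Queue) -> None:
    while (item := await queue.get()) is not None:
        print("consumed", item)


async def main(yield_to_loop: bool) -> None:
    queue: asyncio.Queue = asyncio.Queue()  # maxsize=0 means unbounded
    await asyncio.gather(producer(queue, yield_to_loop), consumer(queue))


# With yield_to_loop=False every "produced" line prints before any "consumed"
# line; with yield_to_loop=True production and consumption interleave.
asyncio.run(main(True))
```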
Lines changed: 189 additions & 0 deletions
@@ -0,0 +1,189 @@
"""
Test that AccumulatorAsyncServer streams messages incrementally even when the
user handler writes to the output queue in a tight loop (no meaningful await
between puts).

Regression test for https://github.com/numaproj/numaflow-python/issues/350

Root cause: The SDK's write_to_global_queue reads from the per-task
NonBlockingIterator and writes to the global result queue. Both are backed by
unbounded asyncio.Queues, so neither await truly suspends. This starves the
consumer task that reads from the global queue and streams responses to gRPC,
causing all messages to arrive at once after the handler completes.

Fix: asyncio.sleep(0) after each put in write_to_global_queue.
"""

import logging
import threading
import time
from collections.abc import AsyncIterable

import grpc
import pytest

from pynumaflow import setup_logging
from pynumaflow.accumulator import (
    Message,
    Datum,
    AccumulatorAsyncServer,
    Accumulator,
)
from pynumaflow.proto.accumulator import accumulator_pb2, accumulator_pb2_grpc
from pynumaflow.shared.asynciter import NonBlockingIterator
from tests.conftest import create_async_loop, start_async_server, teardown_async_server
from tests.testing_utils import (
    mock_message,
    mock_interval_window_start,
    mock_interval_window_end,
    get_time_args,
)

LOGGER = setup_logging(__name__)

pytestmark = pytest.mark.integration

SOCK_PATH = "unix:///tmp/accumulator_streaming.sock"

NUM_MESSAGES = 5
PRODUCE_INTERVAL_SECS = 0.2


class SlowStreamingAccumulator(Accumulator):
    """
    Handler that produces messages from a background thread with a delay
    between each, and writes them to the output queue in a tight loop.
    This mirrors the pattern from issue #342/#350: the user's code has no
    meaningful await between output.put() calls, and the messages are
    produced slowly by a background thread.
    """

    def __init__(self):
        pass

    async def handler(self, datums: AsyncIterable[Datum], output: NonBlockingIterator):
        # Consume all datums first (required by the protocol)
        keys = []
        async for datum in datums:
            keys = datum.keys

        # Now produce messages from a background thread with delays,
        # and write them to output in a tight loop (no await between puts)
        from collections import deque

        messages: deque[Message] = deque()

        def _produce():
            for i in range(NUM_MESSAGES):
                messages.append(Message(f"msg-{i}".encode(), keys=keys))
                time.sleep(PRODUCE_INTERVAL_SECS)

        thread = threading.Thread(target=_produce)
        thread.start()

        while thread.is_alive():
            # Tight loop: no await between puts — triggers starvation
            while messages:
                await output.put(messages.popleft())

        thread.join()
        while messages:
            await output.put(messages.popleft())


def request_generator(count, request):
    for i in range(count):
        if i == 0:
            request.operation.event = accumulator_pb2.AccumulatorRequest.WindowOperation.Event.OPEN
        else:
            request.operation.event = (
                accumulator_pb2.AccumulatorRequest.WindowOperation.Event.APPEND
            )
        yield request


def start_request():
    event_time_timestamp, watermark_timestamp = get_time_args()
    window = accumulator_pb2.KeyedWindow(
        start=mock_interval_window_start(),
        end=mock_interval_window_end(),
        slot="slot-0",
        keys=["test_key"],
    )
    payload = accumulator_pb2.Payload(
        keys=["test_key"],
        value=mock_message(),
        event_time=event_time_timestamp,
        watermark=watermark_timestamp,
        id="test_id",
    )
    operation = accumulator_pb2.AccumulatorRequest.WindowOperation(
        event=accumulator_pb2.AccumulatorRequest.WindowOperation.Event.OPEN,
        keyedWindow=window,
    )
    return accumulator_pb2.AccumulatorRequest(payload=payload, operation=operation)


async def _start_server(udfs):
    server = grpc.aio.server()
    accumulator_pb2_grpc.add_AccumulatorServicer_to_server(udfs, server)
    server.add_insecure_port(SOCK_PATH)
    logging.info("Starting server on %s", SOCK_PATH)
    await server.start()
    return server, SOCK_PATH


@pytest.fixture(scope="module")
def streaming_server():
    loop = create_async_loop()
    server_obj = AccumulatorAsyncServer(SlowStreamingAccumulator)
    udfs = server_obj.servicer
    server = start_async_server(loop, _start_server(udfs))
    yield loop
    teardown_async_server(loop, server)


@pytest.fixture()
def streaming_stub(streaming_server):
    return accumulator_pb2_grpc.AccumulatorStub(grpc.insecure_channel(SOCK_PATH))


def test_accumulator_messages_stream_incrementally(streaming_stub):
    """
    Verify that messages are streamed to the client as they are produced,
    not batched until the handler completes.

    The handler produces NUM_MESSAGES messages with PRODUCE_INTERVAL_SECS between
    each. If streaming works, the first message should arrive well before the
    last one is produced (total production time = NUM_MESSAGES * PRODUCE_INTERVAL_SECS).
    """
    request = start_request()
    generator_response = streaming_stub.AccumulateFn(
        request_iterator=request_generator(count=1, request=request)
    )

    # Collect messages with their arrival timestamps
    arrival_times = []
    result_count = 0
    for msg in generator_response:
        if hasattr(msg, "payload") and msg.payload.value:
            arrival_times.append(time.monotonic())
            result_count += 1

    assert result_count == NUM_MESSAGES, f"Expected {NUM_MESSAGES} messages, got {result_count}"

    # If messages streamed incrementally, the time span between the first and
    # last arrival should be a significant portion of the total production time.
    # If they were batched, they'd all arrive within a few milliseconds of each other.
    total_production_time = NUM_MESSAGES * PRODUCE_INTERVAL_SECS
    first_to_last = arrival_times[-1] - arrival_times[0]

    # The spread should be at least 40% of production time if streaming works.
    # If batched, the spread would be near zero (~1-5ms).
    min_expected_spread = total_production_time * 0.4
    assert first_to_last >= min_expected_spread, (
        f"Messages arrived too close together ({first_to_last:.3f}s spread), "
        f"expected at least {min_expected_spread:.3f}s. "
        f"This indicates messages were batched instead of streamed. "
        f"Arrival gaps: {[f'{b - a:.3f}s' for a, b in zip(arrival_times, arrival_times[1:])]}"
    )
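
For concreteness, with the constants defined at the top of the test: total_production_time = NUM_MESSAGES * PRODUCE_INTERVAL_SECS = 5 * 0.2 s = 1.0 s, so min_expected_spread = 0.4 s. Incremental streaming comfortably exceeds that spread, while batched delivery would typically produce a spread of only a few milliseconds.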
