Skip to content

Commit a4d354f

Browse files
committed
test to ensure mapstream actually streams
Signed-off-by: Sreekanth <prsreekanth920@gmail.com>
1 parent 7025de9 commit a4d354f

1 file changed

Lines changed: 129 additions & 0 deletions

File tree

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
"""
2+
Test that MapStreamAsyncServer streams messages incrementally even when the
3+
user handler yields via a regular for-loop (no await between yields).
4+
5+
Regression test for https://github.com/numaproj/numaflow-python/issues/342
6+
7+
Root cause: asyncio.Queue.put() on an unbounded queue never suspends, so the
8+
MapFn consumer task was starved and couldn't stream responses to gRPC until
9+
the handler completed. Fix: asyncio.sleep(0) after each put in the servicer.
10+
"""
11+
12+
import logging
13+
import threading
14+
import time
15+
from collections import deque
16+
from collections.abc import AsyncIterable
17+
18+
import grpc
19+
import pytest
20+
21+
from pynumaflow import setup_logging
22+
from pynumaflow.mapstreamer import Datum, MapStreamAsyncServer, Message
23+
from pynumaflow.proto.mapper import map_pb2_grpc
24+
from tests.conftest import create_async_loop, start_async_server, teardown_async_server
25+
from tests.mapstream.utils import request_generator
26+
27+
LOGGER = setup_logging(__name__)
28+
29+
pytestmark = pytest.mark.integration
30+
31+
# Unix-domain socket the test server listens on; unique per test module so
# parallel test modules don't collide on the same socket path.
SOCK_PATH = "unix:///tmp/async_map_stream_streaming.sock"

# Number of messages the handler produces per request.
NUM_MESSAGES = 5
# Delay between successive messages; total production time is
# NUM_MESSAGES * PRODUCE_INTERVAL_SECS (= 1.0s), which the test compares
# against the observed arrival spread to detect batching.
PRODUCE_INTERVAL_SECS = 0.2
35+
36+
37+
async def slow_streaming_handler(keys: list[str], datum: Datum) -> AsyncIterable[Message]:
    """
    Handler that produces messages from a background thread with a delay
    between each, and yields them via a tight for-loop with NO await.
    This is the pattern from issue #342.

    A background thread appends NUM_MESSAGES messages to a shared deque with
    PRODUCE_INTERVAL_SECS between each; the async generator drains the deque
    in a busy loop that never awaits, so the event loop only regains control
    at the `yield` suspension points. The regression being tested is that
    those yields alone are enough for the servicer to stream responses out
    incrementally.

    :param keys: keys from the incoming datum; echoed back on each Message.
    :param datum: incoming payload (unused — only the keys matter here).
    """
    # Shared buffer between the producer thread and this async generator.
    # deque append/popleft are thread-safe for this single-producer,
    # single-consumer usage.
    messages: deque[Message] = deque()

    def _produce() -> None:
        # Runs on a plain thread: emit one message, then sleep, NUM_MESSAGES times.
        for i in range(NUM_MESSAGES):
            messages.append(Message(f"msg-{i}".encode(), keys=keys))
            time.sleep(PRODUCE_INTERVAL_SECS)

    thread = threading.Thread(target=_produce)
    thread.start()

    while thread.is_alive():
        # Tight loop: regular for, no await — the pattern that triggers #342
        while messages:
            yield messages.popleft()

    # Producer finished; drain anything appended after the last is_alive() check.
    thread.join()
    while messages:
        yield messages.popleft()
61+
62+
63+
async def _start_server(udfs):
    """
    Boot a grpc.aio server hosting *udfs* on SOCK_PATH.

    :param udfs: the Map servicer instance to register.
    :return: tuple of (started server, socket address it is bound to).
    """
    grpc_server = grpc.aio.server()
    # Register the servicer before binding/starting, per grpc.aio requirements.
    map_pb2_grpc.add_MapServicer_to_server(udfs, grpc_server)
    grpc_server.add_insecure_port(SOCK_PATH)
    logging.info("Starting server on %s", SOCK_PATH)
    await grpc_server.start()
    return grpc_server, SOCK_PATH
70+
71+
72+
@pytest.fixture(scope="module")
def streaming_server():
    """
    Module-scoped fixture: run a MapStreamAsyncServer (wrapping the slow
    streaming handler) on a dedicated event loop for the whole module,
    and tear both down afterwards.

    Yields the event loop the server runs on.
    """
    event_loop = create_async_loop()
    map_streamer = MapStreamAsyncServer(map_stream_instance=slow_streaming_handler)
    grpc_server = start_async_server(event_loop, _start_server(map_streamer.servicer))
    yield event_loop
    # Runs after the last test in the module: stop server, close loop.
    teardown_async_server(event_loop, grpc_server)
80+
81+
82+
@pytest.fixture()
def streaming_stub(streaming_server):
    """
    Per-test Map stub connected to the module-scoped streaming server.

    Depending on ``streaming_server`` guarantees the server is up before
    the channel is opened.
    """
    channel = grpc.insecure_channel(SOCK_PATH)
    return map_pb2_grpc.MapStub(channel)
85+
86+
87+
def test_messages_stream_incrementally(streaming_stub):
    """
    Verify that messages are streamed to the client as they are produced,
    not batched until the handler completes.

    The handler produces NUM_MESSAGES messages with PRODUCE_INTERVAL_SECS between
    each. If streaming works, the first message should arrive well before the
    last one is produced (total production time = NUM_MESSAGES * PRODUCE_INTERVAL_SECS).
    """
    response_stream = streaming_stub.MapFn(
        request_iterator=request_generator(count=1, session=1)
    )

    # The first response is the handshake; consume and sanity-check it.
    first = next(response_stream)
    assert first.handshake.sot

    # Record a monotonic timestamp for every data message; skip EOT markers.
    arrival_times = []
    for response in response_stream:
        if hasattr(response, "status") and response.status.eot:
            continue
        arrival_times.append(time.monotonic())

    received = len(arrival_times)
    assert received == NUM_MESSAGES, f"Expected {NUM_MESSAGES} messages, got {received}"

    # If messages streamed incrementally, the time span between the first and
    # last arrival should be a significant portion of the total production time.
    # If they were batched, they'd all arrive within a few milliseconds of each other.
    total_production_time = NUM_MESSAGES * PRODUCE_INTERVAL_SECS
    spread = arrival_times[-1] - arrival_times[0]

    # The spread should be at least 40% of production time if streaming works.
    # If batched, the spread would be near zero (~1-5ms).
    min_expected_spread = total_production_time * 0.4
    assert spread >= min_expected_spread, (
        f"Messages arrived too close together ({spread:.3f}s spread), "
        f"expected at least {min_expected_spread:.3f}s. "
        f"This indicates messages were batched instead of streamed. "
        f"Arrival gaps: {[f'{b - a:.3f}s' for a, b in zip(arrival_times, arrival_times[1:])]}"
    )

0 commit comments

Comments
 (0)