-
Notifications
You must be signed in to change notification settings - Fork 26
Expand file tree
/
Copy path_async_servicer.py
More file actions
150 lines (136 loc) · 6.23 KB
/
_async_servicer.py
File metadata and controls
150 lines (136 loc) · 6.23 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
import asyncio
import contextlib
from collections.abc import AsyncIterable
from google.protobuf import empty_pb2 as _empty_pb2
from pynumaflow.shared.asynciter import NonBlockingIterator
from pynumaflow._constants import _LOGGER, STREAM_EOF, ERR_UDF_EXCEPTION_STRING
from pynumaflow.mapper._dtypes import MapAsyncCallable, Datum, MapError
from pynumaflow.proto.mapper import map_pb2, map_pb2_grpc
from pynumaflow.shared.server import handle_async_error
from pynumaflow.types import NumaflowServicerContext
class AsyncMapServicer(map_pb2_grpc.MapServicer):
    """
    This class is used to create a new grpc Async Map Servicer instance.
    It implements the MapServicer interface from the proto map.proto file.
    Provides the functionality for the required rpc methods.
    """

    def __init__(self, handler: MapAsyncCallable, multiproc: bool = False):
        # Strong references to in-flight per-request tasks so they are not
        # garbage-collected before completion.
        self.background_tasks: set = set()
        # This indicates whether the grpc server attached is multiproc or not
        self.multiproc = multiproc
        self.__map_handler: MapAsyncCallable = handler

    async def MapFn(
        self,
        request_iterator: AsyncIterable[map_pb2.MapRequest],
        context: NumaflowServicerContext,
    ) -> AsyncIterable[map_pb2.MapResponse]:
        """
        Applies a function to each datum element.
        The pascal case function name comes from the proto map_pb2_grpc.py file.
        """
        producer = None
        try:
            # The first message to be received should be a valid handshake
            # (start-of-transmission) request.
            req = await request_iterator.__anext__()
            if not (req.handshake and req.handshake.sot):
                raise MapError("MapFn: expected handshake as the first message")
            yield map_pb2.MapResponse(handshake=map_pb2.Handshake(sot=True))

            global_result_queue = NonBlockingIterator()
            # Reader task: consumes the input stream and fans out one UDF task
            # per request, writing results onto the shared queue.
            producer = asyncio.create_task(
                self._process_inputs(request_iterator, global_result_queue)
            )

            # Keep reading the result queue and stream responses back until
            # the producer signals EOF (which terminates the iterator).
            consumer = global_result_queue.read_iterator()
            async for msg in consumer:
                if isinstance(msg, BaseException):
                    # A UDF task failed; surface the error to the client.
                    await handle_async_error(
                        context, msg, ERR_UDF_EXCEPTION_STRING, self.multiproc
                    )
                    return
                # Send the map response back to the client.
                yield msg
            # Wait for the producer task to complete (propagates its result).
            await producer
        except GeneratorExit:
            _LOGGER.info("Client disconnected, generator closed.")
            raise
        except BaseException as e:
            _LOGGER.critical("UDFError, re-raising the error", exc_info=True)
            await handle_async_error(context, e, ERR_UDF_EXCEPTION_STRING, self.multiproc)
            return
        finally:
            # Make sure the producer does not outlive this RPC when we exit
            # early (error path or client disconnect).
            if producer and not producer.done():
                producer.cancel()
                with contextlib.suppress(asyncio.CancelledError):
                    await producer

    async def _process_inputs(
        self,
        request_iterator: AsyncIterable[map_pb2.MapRequest],
        result_queue: NonBlockingIterator,
    ):
        """
        Consume incoming MapRequests, spawning one background task per request
        to execute the UDF; once the input stream is exhausted, wait for all
        tasks and signal EOF on the result queue.
        """
        try:
            # proto repeated field(keys) handling happens in _invoke_map;
            # here we only fan out one task per incoming request.
            async for req in request_iterator:
                task = asyncio.create_task(self._invoke_map(req, result_queue))
                # Save a reference so the task is not garbage-collected early.
                self.background_tasks.add(task)
                task.add_done_callback(self.background_tasks.discard)
            # Wait for all spawned tasks only AFTER the input stream ends —
            # awaiting inside the loop would serialize request processing.
            await asyncio.gather(*self.background_tasks)
        except BaseException:
            _LOGGER.critical("MapFn Error in _process_inputs", exc_info=True)
        finally:
            # EOF tells the consumer loop in MapFn that no more results
            # (or errors) will arrive, so it can finish streaming.
            await result_queue.put(STREAM_EOF)

    async def _invoke_map(self, req: map_pb2.MapRequest, result_queue: NonBlockingIterator):
        """
        Invokes the user defined function for a single request and puts either
        the resulting MapResponse or the raised exception onto the result queue.
        """
        try:
            # proto repeated field(keys) is of type
            # google._upb._message.RepeatedScalarContainer;
            # we need to explicitly convert it to list.
            datum = Datum(
                keys=list(req.request.keys),
                value=req.request.value,
                event_time=req.request.event_time.ToDatetime(),
                watermark=req.request.watermark.ToDatetime(),
                headers=dict(req.request.headers),
            )
            msgs = await self.__map_handler(list(req.request.keys), datum)
            datums = [
                map_pb2.MapResponse.Result(keys=msg.keys, value=msg.value, tags=msg.tags)
                for msg in msgs
            ]
            await result_queue.put(map_pb2.MapResponse(results=datums, id=req.id))
        except BaseException as err:
            _LOGGER.critical("MapFn handler error", exc_info=True)
            # Forward the error through the queue so MapFn can report it
            # to the client instead of silently dropping the request.
            await result_queue.put(err)

    async def IsReady(
        self, request: _empty_pb2.Empty, context: NumaflowServicerContext
    ) -> map_pb2.ReadyResponse:
        """
        IsReady is the heartbeat endpoint for gRPC.
        The pascal case function name comes from the proto map_pb2_grpc.py file.
        """
        return map_pb2.ReadyResponse(ready=True)