|
17 | 17 | ack_req_source_fn, |
18 | 18 | mock_partitions, |
19 | 19 | AsyncSource, |
| 20 | + AsyncSourceWithTotalPartitions, |
20 | 21 | mock_offset, |
21 | 22 | nack_req_source_fn, |
22 | 23 | ) |
@@ -194,6 +195,66 @@ def test_partitions(async_source_server) -> None: |
194 | 195 | assert response.result.partitions == mock_partitions() |
195 | 196 |
|
196 | 197 |
|
| 198 | +def test_partitions_default_total_partitions_is_none(async_source_server) -> None: |
| 199 | + """ |
| 200 | + Verify total_partitions is not set when the source doesn't override |
| 201 | + total_partitions_handler. |
| 202 | + """ |
| 203 | + with grpc.insecure_channel(server_port) as channel: |
| 204 | + stub = source_pb2_grpc.SourceStub(channel) |
| 205 | + request = _empty_pb2.Empty() |
| 206 | + response = stub.PartitionsFn(request=request) |
| 207 | + |
| 208 | + assert response.result.partitions == mock_partitions() |
| 209 | + assert not response.result.HasField("total_partitions") |
| 210 | + |
| 211 | + |
| 212 | +server_port_tp = "unix:///tmp/async_source_tp.sock" |
| 213 | + |
| 214 | + |
| 215 | +def NewAsyncSourcerWithTotalPartitions(): |
| 216 | + class_instance = AsyncSourceWithTotalPartitions() |
| 217 | + server = SourceAsyncServer(sourcer_instance=class_instance) |
| 218 | + udfs = server.servicer |
| 219 | + return udfs |
| 220 | + |
| 221 | + |
| 222 | +async def start_server_tp(udfs): |
| 223 | + server = grpc.aio.server() |
| 224 | + source_pb2_grpc.add_SourceServicer_to_server(udfs, server) |
| 225 | + listen_addr = server_port_tp |
| 226 | + server.add_insecure_port(listen_addr) |
| 227 | + logging.info("Starting server on %s", listen_addr) |
| 228 | + await server.start() |
| 229 | + return server, listen_addr |
| 230 | + |
| 231 | + |
| 232 | +@pytest.fixture(scope="module") |
| 233 | +def async_source_server_with_total_partitions(): |
| 234 | + """Module-scoped fixture: starts an async gRPC source server with total partitions.""" |
| 235 | + loop = create_async_loop() |
| 236 | + |
| 237 | + udfs = NewAsyncSourcerWithTotalPartitions() |
| 238 | + server = start_async_server(loop, start_server_tp(udfs)) |
| 239 | + |
| 240 | + yield loop |
| 241 | + |
| 242 | + teardown_async_server(loop, server) |
| 243 | + |
| 244 | + |
| 245 | +def test_partitions_with_total_partitions(async_source_server_with_total_partitions) -> None: |
| 246 | + """ |
| 247 | + Verify total_partitions flows through gRPC when the source implements total_partitions_handler. |
| 248 | + """ |
| 249 | + with grpc.insecure_channel(server_port_tp) as channel: |
| 250 | + stub = source_pb2_grpc.SourceStub(channel) |
| 251 | + request = _empty_pb2.Empty() |
| 252 | + response = stub.PartitionsFn(request=request) |
| 253 | + |
| 254 | + assert response.result.partitions == mock_partitions() |
| 255 | + assert response.result.total_partitions == 10 |
| 256 | + |
| 257 | + |
197 | 258 | @pytest.mark.parametrize( |
198 | 259 | "max_threads_arg,expected", |
199 | 260 | [ |
|
0 commit comments