From f63f06e34a09a2009f6c46ab120ff93db343c129 Mon Sep 17 00:00:00 2001 From: Dana Powers Date: Wed, 6 May 2026 08:25:41 -0700 Subject: [PATCH 1/2] Refactor coordinator test fixtures; prefer MockBroker to patched_coord --- test/consumer/test_coordinator.py | 300 +++++++++++++++++------------- 1 file changed, 173 insertions(+), 127 deletions(-) diff --git a/test/consumer/test_coordinator.py b/test/consumer/test_coordinator.py index 9f9cc50ca..341932fb1 100644 --- a/test/consumer/test_coordinator.py +++ b/test/consumer/test_coordinator.py @@ -10,11 +10,10 @@ from kafka.coordinator.assignors.range import RangePartitionAssignor from kafka.coordinator.assignors.roundrobin import RoundRobinPartitionAssignor from kafka.coordinator.assignors.sticky.sticky_assignor import StickyPartitionAssignor -from kafka.coordinator.base import Generation, MemberState, HeartbeatThread +from kafka.coordinator.base import Generation, MemberState from kafka.coordinator.consumer import ConsumerCoordinator import kafka.errors as Errors from kafka.future import Future -from kafka.protocol.broker_version_data import BrokerVersionData from kafka.protocol.consumer import ( OffsetCommitRequest, OffsetCommitResponse, OffsetFetchRequest, OffsetFetchResponse, @@ -26,12 +25,15 @@ @pytest.fixture -def coordinator(client, metrics, mocker): - coord = ConsumerCoordinator(client, SubscriptionState(), metrics=metrics) +def coordinator(broker, client, metrics): + coord = ConsumerCoordinator(client, SubscriptionState(), + metrics=metrics, + api_version=broker.broker_version, + max_poll_interval_ms=300000 if broker.broker_version >= (0, 10, 1) else 10000, + session_timeout_ms=10000) try: yield coord finally: - mocker.patch.object(coord, 'coordinator_unknown', return_value=True) # avoid attempting to leave group during close() coord.close(timeout_ms=0) @@ -41,21 +43,12 @@ def test_init(client, coordinator): assert WeakMethod(coordinator._handle_metadata_update) in client.cluster._listeners -@pytest.mark.parametrize("api_version", [(0, 8, 0), (0, 8, 1), (0, 8, 2), (0, 9)]) -def test_autocommit_enable_api_version(client, metrics, api_version): - coordinator = ConsumerCoordinator(client, - SubscriptionState(), - metrics=metrics, - enable_auto_commit=True, - session_timeout_ms=30000, # session_timeout_ms and max_poll_interval_ms - max_poll_interval_ms=30000, # should be the same to avoid KafkaConfigurationError - group_id='foobar', - api_version=api_version) - if api_version < (0, 8, 1): +@pytest.mark.parametrize("broker", [(0, 8, 0), (0, 8, 1), (0, 8, 2), (0, 9)], indirect=True) +def test_autocommit_enable_api_version(broker, coordinator): + if broker.broker_version < (0, 8, 1): assert coordinator.config['enable_auto_commit'] is False else: assert coordinator.config['enable_auto_commit'] is True - coordinator.close() def test_protocol_type(coordinator): @@ -88,14 +81,8 @@ def test_group_protocols(coordinator): ] -@pytest.mark.parametrize('api_version', [(0, 8, 0), (0, 8, 1), (0, 8, 2), (0, 9)]) -def test_pattern_subscription(client, metrics, api_version): - coordinator = ConsumerCoordinator(client, - SubscriptionState(), - metrics=metrics, - api_version=api_version, - session_timeout_ms=10000, - max_poll_interval_ms=10000) +@pytest.mark.parametrize("broker", [(0, 8, 0), (0, 8, 1), (0, 8, 2), (0, 9)], indirect=True) +def test_pattern_subscription(broker, coordinator): coordinator._subscription.subscribe(pattern='foo') assert coordinator._subscription.subscription == set([]) assert coordinator._metadata_snapshot == coordinator._build_metadata_snapshot(coordinator._subscription, {}) @@ -113,7 +100,7 @@ def test_pattern_subscription(client, metrics, api_version): assert coordinator._subscription.subscription == {'foo1', 'foo2'} # 0.9 consumers should trigger dynamic partition assignment - if api_version >= (0, 9): + if broker.broker_version >= (0, 9): assert coordinator._subscription.assignment == {} # earlier consumers get all partitions assigned locally @@ -405,65 +392,79 @@ def test_maybe_auto_commit_offsets_sync(mocker, client, api_version, group_id, e @pytest.fixture -def patched_coord(mocker, coordinator): +def seeded_coord(broker, coordinator): + """A coordinator wired to a bootstrapped MockBroker with state seeded + so _send_offset_*_request can dispatch a real wire request to node 0. + """ + coordinator._client._manager.bootstrap(timeout_ms=5000) coordinator._subscription.subscribe(topics=['foobar']) - mocker.patch.object(coordinator, 'coordinator_unknown', return_value=False) coordinator.coordinator_id = 0 - mocker.patch.object(coordinator, 'coordinator', return_value=0) coordinator._generation = Generation(0, 'foobar', b'') coordinator.state = MemberState.STABLE coordinator.rejoin_needed = False - mocker.patch.object(coordinator, 'need_rejoin', return_value=False) - mocker.patch.object(coordinator._client, 'least_loaded_node', - return_value=1) - mocker.patch.object(coordinator._client, 'ready', return_value=True) + return coordinator + + +@pytest.fixture +def patched_coord(mocker, seeded_coord): + """Minimal mock-send fixture for transport-failure tests (Category D). + The two _failure tests exercise the _failed_request errback path, which + requires control over the underlying send Future. Other request-shape + tests have moved to the broker round-trip via ``seeded_coord``. + """ send_future = Future() - mocker.patch.object(coordinator._client, 'send', return_value=send_future) - mocker.patch.object(coordinator._client._manager, 'send', return_value=send_future) - mocker.patch.object(coordinator, '_heartbeat_thread') - mocker.spy(coordinator, '_failed_request') - mocker.spy(coordinator, '_handle_offset_commit_response') - mocker.spy(coordinator, '_handle_offset_fetch_response') + mocker.patch.object(seeded_coord._client, 'send', return_value=send_future) + mocker.patch.object(seeded_coord._client._manager, 'send', return_value=send_future) + mocker.spy(seeded_coord, '_failed_request') try: - yield coordinator + yield seeded_coord finally: - send_future.failure(Errors.KafkaConnectionError()) - coordinator.close() + if not send_future.is_done: + send_future.failure(Errors.KafkaConnectionError()) -def test_send_offset_commit_request_fail(mocker, patched_coord, offsets): - patched_coord.coordinator_unknown.return_value = True - patched_coord.coordinator_id = None - patched_coord.coordinator.return_value = None +def test_send_offset_commit_request_fail(coordinator, offsets): + # Default coordinator state has coordinator_id=None, so coordinator() + # returns None and the early-return paths fire without any patching. # No offsets - ret = patched_coord._send_offset_commit_request({}) + ret = coordinator._send_offset_commit_request({}) assert isinstance(ret, Future) assert ret.succeeded() # No coordinator - ret = patched_coord._send_offset_commit_request(offsets) + ret = coordinator._send_offset_commit_request(offsets) assert ret.failed() assert isinstance(ret.exception, Errors.CoordinatorNotAvailableError) -@pytest.mark.parametrize('api_version,version', [ +@pytest.mark.parametrize('broker,version', [ ((0, 8, 1), 0), ((0, 8, 2), 1), ((0, 9), 2), ((0, 11), 3), ((2, 0), 4), ((2, 1), 6), -]) -def test_send_offset_commit_request_versions(patched_coord, offsets, - api_version, version): - expect_node = 0 - patched_coord._client._manager.broker_version_data = BrokerVersionData(api_version) - - patched_coord._send_offset_commit_request(offsets) - (node, request), _ = patched_coord._client.send.call_args - assert node == expect_node, 'Unexpected coordinator node' - assert request.API_VERSION == version +], indirect=['broker']) +def test_send_offset_commit_request_versions(broker, seeded_coord, offsets, version): + captured = {} + _Topic = OffsetCommitResponse.OffsetCommitResponseTopic + _Partition = _Topic.OffsetCommitResponsePartition + + def handler(api_key, api_version, correlation_id, request_bytes): + captured['api_version'] = api_version + return OffsetCommitResponse( + throttle_time_ms=0, + topics=[_Topic(name='foobar', partitions=[ + _Partition(partition_index=0, error_code=0), + _Partition(partition_index=1, error_code=0), + ])]) + + broker.respond_fn(OffsetCommitRequest, handler) + future = seeded_coord._send_offset_commit_request(offsets) + seeded_coord._client.poll(future=future, timeout_ms=5000) + assert future.succeeded() + assert captured['api_version'] == version def test_send_offset_commit_request_failure(patched_coord, offsets): @@ -478,15 +479,26 @@ def test_send_offset_commit_request_failure(patched_coord, offsets): assert future.exception is error -def test_send_offset_commit_request_success(mocker, patched_coord, offsets): - _f = Future() - patched_coord._client.send.return_value = _f - future = patched_coord._send_offset_commit_request(offsets) - (node, request), _ = patched_coord._client.send.call_args - response = OffsetCommitResponse[0]([('foobar', [(0, 0), (1, 0)])]) - _f.success(response) - patched_coord._handle_offset_commit_response.assert_called_with( - offsets, future, mocker.ANY, response) +def test_send_offset_commit_request_success(mocker, broker, seeded_coord, offsets): + _Topic = OffsetCommitResponse.OffsetCommitResponseTopic + _Partition = _Topic.OffsetCommitResponsePartition + broker.respond(OffsetCommitRequest, OffsetCommitResponse( + throttle_time_ms=0, + topics=[_Topic(name='foobar', partitions=[ + _Partition(partition_index=0, error_code=0), + _Partition(partition_index=1, error_code=0), + ])])) + spy = mocker.spy(seeded_coord, '_handle_offset_commit_response') + + future = seeded_coord._send_offset_commit_request(offsets) + seeded_coord._client.poll(future=future, timeout_ms=5000) + + assert future.succeeded() + assert spy.call_count == 1 + call_offsets, call_future, _send_time, response = spy.call_args[0] + assert call_offsets == offsets + assert call_future is future + assert isinstance(response, OffsetCommitResponse) @pytest.mark.parametrize('response,error,dead', [ @@ -529,13 +541,12 @@ def test_send_offset_commit_request_success(mocker, patched_coord, offsets): (OffsetCommitResponse[6](0, [('foobar', [(0, 0), (1, 0)])]), None, False), ]) -def test_handle_offset_commit_response(mocker, patched_coord, offsets, - response, error, dead): +def test_handle_offset_commit_response(coordinator, offsets, response, error, dead): + coordinator.coordinator_id = 0 future = Future() - patched_coord._handle_offset_commit_response(offsets, future, time.monotonic(), - response) + coordinator._handle_offset_commit_response(offsets, future, time.monotonic(), response) assert isinstance(future.exception, error) if error else True - assert patched_coord.coordinator_id is (None if dead else 0) + assert coordinator.coordinator_id is (None if dead else 0) @pytest.fixture @@ -543,24 +554,22 @@ def partitions(): return [TopicPartition('foobar', 0), TopicPartition('foobar', 1)] -def test_send_offset_fetch_request_fail(mocker, patched_coord, partitions): - patched_coord.coordinator_unknown.return_value = True - patched_coord.coordinator_id = None - patched_coord.coordinator.return_value = None +def test_send_offset_fetch_request_fail(coordinator, partitions): + # Default coordinator state has coordinator_id=None. # No partitions - ret = patched_coord._send_offset_fetch_request([]) + ret = coordinator._send_offset_fetch_request([]) assert isinstance(ret, Future) assert ret.succeeded() assert ret.value == {} # No coordinator - ret = patched_coord._send_offset_fetch_request(partitions) + ret = coordinator._send_offset_fetch_request(partitions) assert ret.failed() assert isinstance(ret.exception, Errors.CoordinatorNotAvailableError) -@pytest.mark.parametrize('api_version,version', [ +@pytest.mark.parametrize('broker,version', [ ((0, 8, 1), 0), ((0, 8, 2), 1), ((0, 9), 1), @@ -568,17 +577,29 @@ def test_send_offset_fetch_request_fail(mocker, patched_coord, partitions): ((0, 11), 3), ((2, 0), 4), ((2, 1), 5), -]) -def test_send_offset_fetch_request_versions(patched_coord, partitions, - api_version, version): - # assuming fixture sets coordinator=0, least_loaded_node=1 - expect_node = 0 - patched_coord._client._manager.broker_version_data = BrokerVersionData(api_version) - - patched_coord._send_offset_fetch_request(partitions) - (node, request), _ = patched_coord._client.send.call_args - assert node == expect_node, 'Unexpected coordinator node' - assert request.API_VERSION == version +], indirect=['broker']) +def test_send_offset_fetch_request_versions(broker, seeded_coord, partitions, version): + captured = {} + _Topic = OffsetFetchResponse.OffsetFetchResponseTopic + _Partition = _Topic.OffsetFetchResponsePartition + + def handler(api_key, api_version, correlation_id, request_bytes): + captured['api_version'] = api_version + return OffsetFetchResponse( + throttle_time_ms=0, + error_code=0, + topics=[_Topic(name='foobar', partitions=[ + _Partition(partition_index=0, committed_offset=123, + committed_leader_epoch=-1, metadata='', error_code=0), + _Partition(partition_index=1, committed_offset=234, + committed_leader_epoch=-1, metadata='', error_code=0), + ])]) + + broker.respond_fn(OffsetFetchRequest, handler) + future = seeded_coord._send_offset_fetch_request(partitions) + seeded_coord._client.poll(future=future, timeout_ms=5000) + assert future.succeeded() + assert captured['api_version'] == version def test_send_offset_fetch_request_failure(patched_coord, partitions): @@ -593,15 +614,29 @@ def test_send_offset_fetch_request_failure(patched_coord, partitions): assert future.exception is error -def test_send_offset_fetch_request_success(patched_coord, partitions): - _f = Future() - patched_coord._client.send.return_value = _f - future = patched_coord._send_offset_fetch_request(partitions) - (node, request), _ = patched_coord._client.send.call_args - response = OffsetFetchResponse[0]([('foobar', [(0, 123, b'', 0), (1, 234, b'', 0)])]) - _f.success(response) - patched_coord._handle_offset_fetch_response.assert_called_with( - future, response) +def test_send_offset_fetch_request_success(mocker, broker, seeded_coord, partitions, offsets): + _Topic = OffsetFetchResponse.OffsetFetchResponseTopic + _Partition = _Topic.OffsetFetchResponsePartition + broker.respond(OffsetFetchRequest, OffsetFetchResponse( + throttle_time_ms=0, + error_code=0, + topics=[_Topic(name='foobar', partitions=[ + _Partition(partition_index=0, committed_offset=123, + committed_leader_epoch=-1, metadata='', error_code=0), + _Partition(partition_index=1, committed_offset=234, + committed_leader_epoch=-1, metadata='', error_code=0), + ])])) + spy = mocker.spy(seeded_coord, '_handle_offset_fetch_response') + + future = seeded_coord._send_offset_fetch_request(partitions) + seeded_coord._client.poll(future=future, timeout_ms=5000) + + assert future.succeeded() + assert future.value == offsets + assert spy.call_count == 1 + call_future, response = spy.call_args[0] + assert call_future is future + assert isinstance(response, OffsetFetchResponse) @pytest.mark.parametrize('response,error,dead', [ @@ -628,55 +663,66 @@ def test_send_offset_fetch_request_success(patched_coord, partitions): (OffsetFetchResponse[5](0, [('foobar', [(0, 123, -1, '', 0), (1, 234, -1, '', 0)])], 0), None, False), ]) -def test_handle_offset_fetch_response(patched_coord, offsets, - response, error, dead): +def test_handle_offset_fetch_response(coordinator, offsets, response, error, dead): + coordinator.coordinator_id = 0 future = Future() - patched_coord._handle_offset_fetch_response(future, response) + coordinator._handle_offset_fetch_response(future, response) if error is not None: assert isinstance(future.exception, error) else: assert future.succeeded() assert future.value == offsets - assert patched_coord.coordinator_id is (None if dead else 0) + assert coordinator.coordinator_id is (None if dead else 0) -def test_heartbeat(mocker, patched_coord): - heartbeat = patched_coord.heartbeat - net = patched_coord._manager._net +def test_heartbeat(mocker, coordinator): + coordinator.coordinator_id = 0 + coordinator.state = MemberState.STABLE + net = coordinator._manager._net - assert not patched_coord._heartbeat_enabled and not patched_coord._heartbeat_closed + assert not coordinator._heartbeat_enabled and not coordinator._heartbeat_closed - assert patched_coord._heartbeat_loop_future is None - patched_coord._maybe_start_heartbeat_loop() - assert patched_coord._heartbeat_loop_future is not None + assert coordinator._heartbeat_loop_future is None + coordinator._maybe_start_heartbeat_loop() + assert coordinator._heartbeat_loop_future is not None - patched_coord._enable_heartbeat() - assert patched_coord._heartbeat_enabled + coordinator._enable_heartbeat() + assert coordinator._heartbeat_enabled - patched_coord._disable_heartbeat() - assert not patched_coord._heartbeat_enabled + coordinator._disable_heartbeat() + assert not coordinator._heartbeat_enabled # heartbeat disables when un-joined - patched_coord._enable_heartbeat() - patched_coord.state = MemberState.UNJOINED + coordinator._enable_heartbeat() + coordinator.state = MemberState.UNJOINED net.poll(timeout_ms=50) - assert not patched_coord._heartbeat_enabled + assert not coordinator._heartbeat_enabled - patched_coord._enable_heartbeat() - patched_coord.state = MemberState.STABLE - mocker.spy(patched_coord, '_send_heartbeat_request') - mocker.patch.object(patched_coord.heartbeat, 'should_heartbeat', return_value=True) + coordinator._enable_heartbeat() + coordinator.state = MemberState.STABLE + # Replace _send_heartbeat_request with a stub returning a Future we control, + # so the heartbeat coroutine reaches the dispatch and blocks there. The + # Mock's call_count then verifies the loop fired exactly once. + blocked_send = Future() + mocker.patch.object(coordinator, '_send_heartbeat_request', return_value=blocked_send) + mocker.patch.object(coordinator.heartbeat, 'should_heartbeat', return_value=True) # Wakeup callback resolves the future on one poll cycle; the heartbeat # coroutine resumes and reaches _send_heartbeat_request on the next. deadline = time.monotonic() + 0.5 while time.monotonic() < deadline: net.poll(timeout_ms=10) - if patched_coord._send_heartbeat_request.call_count > 0: + if coordinator._send_heartbeat_request.call_count > 0: break - assert patched_coord._send_heartbeat_request.call_count == 1 - - patched_coord._close_heartbeat() - assert patched_coord._heartbeat_closed + assert coordinator._send_heartbeat_request.call_count == 1 + + coordinator._close_heartbeat() + assert coordinator._heartbeat_closed + # Unblock the suspended heartbeat coroutine and let it observe + # _heartbeat_closed=True so it exits cleanly. Otherwise the task lingers + # in NetworkSelector._pending_tasks and is later GC-closed, raising + # GeneratorExit through the loop's BaseException handler. + blocked_send.failure(Errors.KafkaConnectionError()) + net.poll(timeout_ms=50) def test_lookup_coordinator_failure(mocker, coordinator): From ee2dad9f5d02b341d567dab95f60d3a3a892825a Mon Sep 17 00:00:00 2001 From: Dana Powers Date: Wed, 6 May 2026 08:43:55 -0700 Subject: [PATCH 2/2] MockBroker fail_next; drop patched_coord --- test/consumer/test_coordinator.py | 62 ++++++++++++++----------------- test/mock_broker.py | 46 +++++++++++++++++++++-- test/test_mock_broker.py | 25 +++++++++++++ 3 files changed, 95 insertions(+), 38 deletions(-) diff --git a/test/consumer/test_coordinator.py b/test/consumer/test_coordinator.py index 341932fb1..22051a743 100644 --- a/test/consumer/test_coordinator.py +++ b/test/consumer/test_coordinator.py @@ -405,24 +405,6 @@ def seeded_coord(broker, coordinator): return coordinator -@pytest.fixture -def patched_coord(mocker, seeded_coord): - """Minimal mock-send fixture for transport-failure tests (Category D). - The two _failure tests exercise the _failed_request errback path, which - requires control over the underlying send Future. Other request-shape - tests have moved to the broker round-trip via ``seeded_coord``. - """ - send_future = Future() - mocker.patch.object(seeded_coord._client, 'send', return_value=send_future) - mocker.patch.object(seeded_coord._client._manager, 'send', return_value=send_future) - mocker.spy(seeded_coord, '_failed_request') - try: - yield seeded_coord - finally: - if not send_future.is_done: - send_future.failure(Errors.KafkaConnectionError()) - - def test_send_offset_commit_request_fail(coordinator, offsets): # Default coordinator state has coordinator_id=None, so coordinator() # returns None and the early-return paths fire without any patching. @@ -467,16 +449,22 @@ def handler(api_key, api_version, correlation_id, request_bytes): assert captured['api_version'] == version -def test_send_offset_commit_request_failure(patched_coord, offsets): - _f = Future() - patched_coord._client.send.return_value = _f - future = patched_coord._send_offset_commit_request(offsets) - (node, request), _ = patched_coord._client.send.call_args - error = Exception() - _f.failure(error) - patched_coord._failed_request.assert_called_with(0, request, future, error) +def test_send_offset_commit_request_failure(mocker, broker, seeded_coord, offsets): + spy = mocker.spy(seeded_coord, '_failed_request') + error = Errors.KafkaConnectionError('simulated transport failure') + broker.fail_next(OffsetCommitRequest, error=error) + + future = seeded_coord._send_offset_commit_request(offsets) + seeded_coord._client.poll(future=future, timeout_ms=5000) + assert future.failed() assert future.exception is error + assert spy.call_count == 1 + node_id, request, call_future, call_error = spy.call_args[0] + assert node_id == 0 + assert isinstance(request, OffsetCommitRequest) + assert call_future is future + assert call_error is error def test_send_offset_commit_request_success(mocker, broker, seeded_coord, offsets): @@ -602,16 +590,22 @@ def handler(api_key, api_version, correlation_id, request_bytes): assert captured['api_version'] == version -def test_send_offset_fetch_request_failure(patched_coord, partitions): - _f = Future() - patched_coord._client.send.return_value = _f - future = patched_coord._send_offset_fetch_request(partitions) - (node, request), _ = patched_coord._client.send.call_args - error = Exception() - _f.failure(error) - patched_coord._failed_request.assert_called_with(0, request, future, error) +def test_send_offset_fetch_request_failure(mocker, broker, seeded_coord, partitions): + spy = mocker.spy(seeded_coord, '_failed_request') + error = Errors.KafkaConnectionError('simulated transport failure') + broker.fail_next(OffsetFetchRequest, error=error) + + future = seeded_coord._send_offset_fetch_request(partitions) + seeded_coord._client.poll(future=future, timeout_ms=5000) + assert future.failed() assert future.exception is error + assert spy.call_count == 1 + node_id, request, call_future, call_error = spy.call_args[0] + assert node_id == 0 + assert isinstance(request, OffsetFetchRequest) + assert call_future is future + assert call_error is error def test_send_offset_fetch_request_success(mocker, broker, seeded_coord, partitions, offsets): diff --git a/test/mock_broker.py b/test/mock_broker.py index 5a26a66c3..901821825 100644 --- a/test/mock_broker.py +++ b/test/mock_broker.py @@ -25,9 +25,21 @@ import struct import time +import kafka.errors as Errors from kafka.protocol.broker_version_data import BrokerVersionData from kafka.protocol.metadata import ApiVersionsRequest, ApiVersionsResponse, MetadataRequest, MetadataResponse + +class _MockBrokerFailure: + """Sentinel returned by MockBroker.handle_request to signal that the + MockTransport should abort the connection with the given error, simulating + a transport-level failure (TCP disconnect, broker crash mid-request, etc.). + """ + __slots__ = ('error',) + + def __init__(self, error): + self.error = error + log = logging.getLogger(__name__) @@ -155,12 +167,17 @@ async def _process_requests(self): log.debug('%s: Request api_key=%d version=%d correlation_id=%d', self, api_key, api_version, correlation_id) - response_bytes = await self._broker.handle_request( + result = await self._broker.handle_request( api_key, api_version, correlation_id, request_bytes) - if response_bytes is not None and self._protocol and not self._closed: + if isinstance(result, _MockBrokerFailure): + log.debug('%s: simulating transport failure: %s', self, result.error) + self.abort(result.error) + return + + if result is not None and self._protocol and not self._closed: self.last_read = time.monotonic() - self._protocol.data_received(response_bytes) + self._protocol.data_received(result) class MockBroker: @@ -262,6 +279,24 @@ def respond_fn(self, request_class, fn): """ self._response_queue.append((request_class.API_KEY, fn)) + def fail_next(self, request_class, error=None): + """Enqueue a transport failure for the next request of the given type. + + When the matching request arrives, the MockTransport aborts the + connection with ``error`` instead of returning a response, simulating + a transport-level send failure (TCP disconnect, broker crash + mid-request, etc.). The pending request's Future then fails via the + connection's ``connection_lost`` -> ``fail_in_flight_requests`` path. + + Arguments: + request_class: The request class for API key matching. + error: Exception delivered to ``transport.abort()``. Defaults to + ``Errors.KafkaConnectionError``. + """ + if error is None: + error = Errors.KafkaConnectionError('MockBroker.fail_next') + self._response_queue.append((request_class.API_KEY, _MockBrokerFailure(error))) + async def handle_request(self, api_key, api_version, correlation_id, request_bytes): """Process a request and return framed response bytes. @@ -272,7 +307,8 @@ async def handle_request(self, api_key, api_version, correlation_id, request_byt Returns: bytes: Framed response ready for ``protocol.data_received()``, - or None if the request expects no response. + or None if the request expects no response, or + ``_MockBrokerFailure`` to signal that the transport should abort. """ self.requests_received += 1 @@ -280,6 +316,8 @@ async def handle_request(self, api_key, api_version, correlation_id, request_byt for i, (queued_key, queued_response) in enumerate(self._response_queue): if queued_key == api_key: del self._response_queue[i] + if isinstance(queued_response, _MockBrokerFailure): + return queued_response if callable(queued_response): response = queued_response(api_key, api_version, correlation_id, request_bytes) # Support both sync and async respond_fn callables diff --git a/test/test_mock_broker.py b/test/test_mock_broker.py index ca32011e9..8c38de302 100644 --- a/test/test_mock_broker.py +++ b/test/test_mock_broker.py @@ -237,6 +237,31 @@ def test_send_and_receive(self): finally: client.close() + def test_fail_next_aborts_request(self): + """fail_next aborts the connection and fails the in-flight request.""" + import kafka.errors as Errors + + broker = MockBroker() + client = self._make_client(broker) + try: + client.check_version(timeout_ms=5000) + node_id = client.least_loaded_node(bootstrap_fallback=True) + client.await_ready(node_id, timeout_ms=5000) + + error = Errors.KafkaConnectionError('simulated') + broker.fail_next(MetadataRequest, error=error) + + version = client.api_version(MetadataRequest, max_version=8) + future = client.send(node_id, MetadataRequest[version]()) + _poll_for_future(client, future) + + assert future.failed() + assert future.exception is error + # Connection was aborted, so it should no longer be ready. + assert not client.is_ready(node_id) + finally: + client.close() + def test_api_version_negotiation(self): """Client negotiates ApiVersions when api_version is not pre-set.""" broker = MockBroker()