diff --git a/sdks/python/apache_beam/runners/portability/job_server.py b/sdks/python/apache_beam/runners/portability/job_server.py
index 9fdaabd1a177..53688e0be950 100644
--- a/sdks/python/apache_beam/runners/portability/job_server.py
+++ b/sdks/python/apache_beam/runners/portability/job_server.py
@@ -94,8 +94,13 @@ def start(self):
   def stop(self):
     with self._lock:
       if self._started:
-        self._job_server.stop()
-        self._started = False
+        try:
+          self._job_server.stop()
+        finally:
+          self._started = False
+        # Unregister the atexit handler to prevent duplicate
+        # registrations when the server is restarted/reused.
+        atexit.unregister(self.stop)
 
 
 class SubprocessJobServer(JobServer):
diff --git a/sdks/python/apache_beam/utils/subprocess_server.py b/sdks/python/apache_beam/utils/subprocess_server.py
index 988bd680b923..5752a49dde2e 100644
--- a/sdks/python/apache_beam/utils/subprocess_server.py
+++ b/sdks/python/apache_beam/utils/subprocess_server.py
@@ -88,11 +88,11 @@ def register(self):
     return owner
 
   def purge(self, owner):
-    if owner not in self._live_owners:
-      raise ValueError(f"{owner} not in {self._live_owners}")
-    self._live_owners.remove(owner)
     to_delete = []
     with self._lock:
+      if owner not in self._live_owners:
+        raise ValueError(f"{owner} not in {self._live_owners}")
+      self._live_owners.remove(owner)
       for key, entry in list(self._cache.items()):
         if owner in entry.owners:
           entry.owners.remove(owner)
@@ -255,15 +255,17 @@ def stop(self):
 
   def stop_process(self):
     if self._owner_id is not None:
-      self._cache.purge(self._owner_id)
-      self._owner_id = None
+      try:
+        self._cache.purge(self._owner_id)
+      finally:
+        # Make sure _owner_id is set to None even if purge fails.
+        self._owner_id = None
     if self._grpc_channel:
       try:
         self._grpc_channel.close()
       except:  # pylint: disable=bare-except
         _LOGGER.error(
-            "Could not close the gRPC channel started for the "
-            "expansion service")
+            "Could not close the gRPC channel started with cmd %s", self._cmd)
       finally:
         self._grpc_channel = None
 
diff --git a/sdks/python/apache_beam/utils/subprocess_server_test.py b/sdks/python/apache_beam/utils/subprocess_server_test.py
index c848595db355..0f25d9904f07 100644
--- a/sdks/python/apache_beam/utils/subprocess_server_test.py
+++ b/sdks/python/apache_beam/utils/subprocess_server_test.py
@@ -19,6 +19,7 @@
 
 # pytype: skip-file
 
+import atexit
 import glob
 import os
 import random
@@ -29,7 +30,9 @@
 import tempfile
 import threading
 import unittest
+from unittest.mock import patch
 
+from apache_beam.runners.portability import job_server
 from apache_beam.utils import subprocess_server
 
 
@@ -302,6 +305,132 @@ def test_interleaved_owners(self):
     self.assertNotEqual(cache.get('b'), b)
     cache.purge(owner3)
 
+  def test_destructor_exception_partial_state(self):
+    # In SubprocessServer.stop_process(), we need to make sure
+    # self._owner_id is always set to None if it is not already set, even
+    # if a destructor exception happens during purge(owner_id).
+
+    destructor_calls = []
+
+    def faulty_destructor(obj):
+      destructor_calls.append(obj)
+      raise RuntimeError("Destructor failed")
+
+    custom_cache = subprocess_server._SharedCache(
+        lambda *args: "process_obj", faulty_destructor)
+
+    class CustomServer(subprocess_server.SubprocessServer):
+      _cache = custom_cache
+
+      def __init__(self):
+        super().__init__(lambda channel: None, ["dummy_cmd"], port=12345)
+
+    server = CustomServer()
+    server.start_process()
+    owner_id = server._owner_id
+    self.assertIsNotNone(owner_id)
+    self.assertIn(owner_id, custom_cache._live_owners)
+
+    # First stop attempt fails in the destructor
+    with self.assertRaises(RuntimeError):
+      server.stop_process()
+
+    # Verify fixed state: owner is purged from cache set, AND
+    # self._owner_id is successfully cleared to None
+    self.assertNotIn(owner_id, custom_cache._live_owners)
+    self.assertIsNone(server._owner_id)
+
+    # Second stop attempt safely does nothing (no ValueError raised)
+    try:
+      server.stop_process()
+    except ValueError:
+      self.fail("ValueError should not be raised here.")
+
+  def test_duplicate_atexit_registration_on_restart(self):
+    # Make sure we don't have duplicate atexit registration when reusing a
+    # StopOnExitJobServer instance.
+
+    class DummyJobServer(job_server.JobServer):
+      def start(self):
+        return "localhost:8080"
+
+      def stop(self):
+        pass
+
+    wrapper = job_server.StopOnExitJobServer(DummyJobServer())
+
+    registered_callbacks = []
+
+    def mock_register(cb):
+      registered_callbacks.append(cb)
+
+    def mock_unregister(cb):
+      if cb in registered_callbacks:
+        registered_callbacks.remove(cb)
+
+    with patch('atexit.register', side_effect=mock_register), \
+        patch('atexit.unregister', side_effect=mock_unregister, create=True):
+      # First start registers stop callback
+      wrapper.start()
+      self.assertTrue(wrapper._started)
+      self.assertEqual(len(registered_callbacks), 1)
+
+      # Explicit stop clears _started AND unregisters the callback
+      wrapper.stop()
+      self.assertFalse(wrapper._started)
+      self.assertEqual(len(registered_callbacks), 0)
+
+      # Re-starting registers the callback again, leaving exactly 1 active
+      # callback
+      wrapper.start()
+      self.assertTrue(wrapper._started)
+      self.assertEqual(len(registered_callbacks), 1)
+
+  def test_concurrent_purge_race_condition(self):
+    # Concurrent threads attempting to check membership and call purge for
+    # the same owner. Here we explicitly define a synchronized set to mimic
+    # the behavior of _live_owners. This set will block two threads on
+    # __contains__, allowing us to test the race condition.
+    cache = subprocess_server._SharedCache(lambda x: "obj", lambda x: None)
+    owner = cache.register()
+
+    barrier = threading.Barrier(2)
+    exceptions = []
+
+    class SynchronizedSet(set):
+      def __contains__(self, item):
+        res = super().__contains__(item)
+        try:
+          # Force both threads to align right after checking membership
+          # but before removal
+          barrier.wait(timeout=0.2)
+        except threading.BrokenBarrierError:
+          pass
+        return res
+
+    cache._live_owners = SynchronizedSet(cache._live_owners)
+
+    def purge_worker():
+      try:
+        cache.purge(owner)
+      except Exception as e:
+        exceptions.append(e)
+
+    t1 = threading.Thread(target=purge_worker)
+    t2 = threading.Thread(target=purge_worker)
+
+    t1.start()
+    t2.start()
+
+    t1.join()
+    t2.join()
+
+    # Exactly one thread should raise the expected ValueError because they
+    # are cleanly serialized
+    self.assertEqual(len(exceptions), 1)
+    self.assertIsInstance(exceptions[0], ValueError)
+    self.assertNotIsInstance(exceptions[0], KeyError)
+
 
 if __name__ == '__main__':
   unittest.main()