Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions sdks/python/apache_beam/runners/portability/job_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,13 @@ def start(self):
def stop(self):
with self._lock:
if self._started:
self._job_server.stop()
self._started = False
try:
self._job_server.stop()
finally:
self._started = False
# Unregister the atexit handler to prevent duplicate
# registrations when the server is restarted/reused.
atexit.unregister(self.stop)


class SubprocessJobServer(JobServer):
Expand Down
16 changes: 9 additions & 7 deletions sdks/python/apache_beam/utils/subprocess_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,11 +88,11 @@ def register(self):
return owner

def purge(self, owner):
if owner not in self._live_owners:
raise ValueError(f"{owner} not in {self._live_owners}")
self._live_owners.remove(owner)
to_delete = []
with self._lock:
if owner not in self._live_owners:
raise ValueError(f"{owner} not in {self._live_owners}")
self._live_owners.remove(owner)
for key, entry in list(self._cache.items()):
if owner in entry.owners:
entry.owners.remove(owner)
Expand Down Expand Up @@ -255,15 +255,17 @@ def stop(self):

def stop_process(self):
if self._owner_id is not None:
self._cache.purge(self._owner_id)
self._owner_id = None
try:
self._cache.purge(self._owner_id)
finally:
# Make sure _owner_id is set to None even if purge fails.
self._owner_id = None
if self._grpc_channel:
try:
self._grpc_channel.close()
except: # pylint: disable=bare-except
_LOGGER.error(
"Could not close the gRPC channel started for the "
"expansion service")
"Could not close the gRPC channel started with cmd %s", self._cmd)
finally:
self._grpc_channel = None

Expand Down
124 changes: 124 additions & 0 deletions sdks/python/apache_beam/utils/subprocess_server_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

# pytype: skip-file

import atexit
import glob
import os
import random
Expand All @@ -29,7 +30,9 @@
import tempfile
import threading
import unittest
from unittest.mock import patch

from apache_beam.runners.portability import job_server
from apache_beam.utils import subprocess_server


Expand Down Expand Up @@ -302,6 +305,127 @@ def test_interleaved_owners(self):
self.assertNotEqual(cache.get('b'), b)
cache.purge(owner3)

def test_destructor_exception_partial_state(self):
  """Checks stop_process() resets _owner_id when the destructor raises.

  The server is backed by a shared cache whose destructor always raises.
  The first stop_process() call is expected to propagate that RuntimeError
  while still clearing the owner; a second call must not raise ValueError.
  """
  destroyed = []

  def raising_destructor(obj):
    # Record the call, then fail the way a broken destructor would.
    destroyed.append(obj)
    raise RuntimeError("Destructor failed")

  failing_cache = subprocess_server._SharedCache(
      lambda *args: "process_obj", raising_destructor)

  class CustomServer(subprocess_server.SubprocessServer):
    # Point this subclass at the cache whose destructor raises.
    _cache = failing_cache

    def __init__(self):
      super().__init__(lambda channel: None, ["dummy_cmd"], port=12345)

  server = CustomServer()
  server.start_process()
  registered_owner = server._owner_id
  self.assertIsNotNone(registered_owner)
  self.assertIn(registered_owner, failing_cache._live_owners)

  # The first stop surfaces the destructor's error to the caller.
  with self.assertRaises(RuntimeError):
    server.stop_process()

  # Despite the error, the owner is gone from the cache's live set and the
  # server no longer holds an owner id.
  self.assertNotIn(registered_owner, failing_cache._live_owners)
  self.assertIsNone(server._owner_id)

  # A repeated stop has nothing left to purge and must not raise ValueError.
  try:
    server.stop_process()
  except ValueError:
    self.fail("ValueError should not be raised here.")

def test_duplicate_atexit_registration_on_restart(self):
  """Checks a reused StopOnExitJobServer keeps at most one atexit callback.

  atexit.register and atexit.unregister are patched with fakes backed by a
  plain list, so the test can count the callbacks that are still registered
  after each start()/stop() call.
  """
  class DummyJobServer(job_server.JobServer):
    # Minimal job server: start() returns a fixed address, stop() is a no-op.
    def start(self):
      return "localhost:8080"

    def stop(self):
      pass

  wrapper = job_server.StopOnExitJobServer(DummyJobServer())

  # Stand-in for the atexit registry.
  active_callbacks = []

  def fake_register(callback):
    active_callbacks.append(callback)

  def fake_unregister(callback):
    # Removing a callback that was never registered is silently ignored.
    try:
      active_callbacks.remove(callback)
    except ValueError:
      pass

  with patch('atexit.register', side_effect=fake_register):
    with patch(
        'atexit.unregister', side_effect=fake_unregister, create=True):
      # The first start registers the stop callback.
      wrapper.start()
      self.assertTrue(wrapper._started)
      self.assertEqual(len(active_callbacks), 1)

      # An explicit stop clears _started and unregisters the callback.
      wrapper.stop()
      self.assertFalse(wrapper._started)
      self.assertEqual(len(active_callbacks), 0)

      # Starting again registers the callback once more, so exactly one
      # callback is active.
      wrapper.start()
      self.assertTrue(wrapper._started)
      self.assertEqual(len(active_callbacks), 1)

def test_concurrent_purge_race_condition(self):
  """Checks that two concurrent purge() calls for one owner are serialized.

  Two threads purge the same owner. The cache's _live_owners set is replaced
  by a subclass that waits on a two-party barrier right after every
  membership check, which would let both threads pass the check together if
  the check and the removal were not performed atomically.
  """
  cache = subprocess_server._SharedCache(lambda x: "obj", lambda x: None)
  owner = cache.register()

  rendezvous = threading.Barrier(2)
  errors = []

  class SynchronizedSet(set):
    def __contains__(self, item):
      found = super().__contains__(item)
      try:
        # Try to line both threads up after the membership check but before
        # the removal; give up after the timeout.
        rendezvous.wait(timeout=0.2)
      except threading.BrokenBarrierError:
        pass
      return found

  cache._live_owners = SynchronizedSet(cache._live_owners)

  def purge_worker():
    try:
      cache.purge(owner)
    except Exception as e:
      errors.append(e)

  workers = [threading.Thread(target=purge_worker) for _ in range(2)]
  for worker in workers:
    worker.start()
  for worker in workers:
    worker.join()

  # Exactly one thread should raise the expected ValueError because the two
  # purges are cleanly serialized.
  self.assertEqual(len(errors), 1)
  failure = errors[0]
  self.assertIsInstance(failure, ValueError)
  self.assertNotIsInstance(failure, KeyError)


if __name__ == '__main__':
unittest.main()
Loading