diff --git a/build-logic/conventions/src/main/kotlin/com/datadoghq/profiler/ProfilerTestPlugin.kt b/build-logic/conventions/src/main/kotlin/com/datadoghq/profiler/ProfilerTestPlugin.kt index 2d2b46287..36752d0f9 100644 --- a/build-logic/conventions/src/main/kotlin/com/datadoghq/profiler/ProfilerTestPlugin.kt +++ b/build-logic/conventions/src/main/kotlin/com/datadoghq/profiler/ProfilerTestPlugin.kt @@ -560,21 +560,22 @@ class ProfilerTestPlugin : Plugin { } } - // Wire up assemble/gtest dependencies + // Wire up gtest -> test dependencies (C++ tests run before Java tests) project.gradle.projectsEvaluated { configNames.forEach { cfgName -> - val testTask = project.tasks.findByName("test$cfgName") + val capitalizedCfgName = cfgName.replaceFirstChar { it.uppercaseChar() } + val testTaskName = "test$capitalizedCfgName" + val testTask = project.tasks.findByName(testTaskName) val profilerLibProject = project.rootProject.findProject(profilerLibProjectPath) - if (profilerLibProject != null) { - val assembleTask = profilerLibProject.tasks.findByName("assemble${cfgName.replaceFirstChar { it.uppercaseChar() }}") - if (testTask != null && assembleTask != null) { - assembleTask.dependsOn(testTask) - } - - val gtestTask = profilerLibProject.tasks.findByName("gtest${cfgName.replaceFirstChar { it.uppercaseChar() }}") - if (testTask != null && gtestTask != null) { + if (profilerLibProject != null && testTask != null) { + // gtest runs before test (C++ unit tests run before Java integration tests) + val gtestTaskName = "gtest${capitalizedCfgName}" + try { + val gtestTask = profilerLibProject.tasks.named(gtestTaskName) testTask.dependsOn(gtestTask) + } catch (e: org.gradle.api.UnknownTaskException) { + project.logger.info("Task $gtestTaskName not found in $profilerLibProjectPath - gtest may not be available") } } } diff --git a/ddprof-lib/src/main/cpp/os.h b/ddprof-lib/src/main/cpp/os.h index 1093feff0..1c16af80c 100644 --- a/ddprof-lib/src/main/cpp/os.h +++ b/ddprof-lib/src/main/cpp/os.h @@ -157,6 +157,7 @@ class OS { static SigAction getSegvChainTarget(); static SigAction getBusChainTarget(); static void* getSigactionHook(); + static void resetSignalHandlersForTesting(); static int getMaxThreadId(int floor) { int maxThreadId = getMaxThreadId(); diff --git a/ddprof-lib/src/main/cpp/os_linux.cpp b/ddprof-lib/src/main/cpp/os_linux.cpp index 1b0fe9626..fd2699395 100644 --- a/ddprof-lib/src/main/cpp/os_linux.cpp +++ b/ddprof-lib/src/main/cpp/os_linux.cpp @@ -739,6 +739,11 @@ static SigAction _protected_bus_handler = nullptr; static volatile SigAction _segv_chain_target = nullptr; static volatile SigAction _bus_chain_target = nullptr; +// Original handlers (JVM's) saved before we install ours - used for oldact in sigaction hook +static struct sigaction _orig_segv_sigaction; +static struct sigaction _orig_bus_sigaction; +static bool _orig_handlers_saved = false; + // Real sigaction function pointer (resolved via dlsym) typedef int (*real_sigaction_t)(int, const struct sigaction*, struct sigaction*); static real_sigaction_t _real_sigaction = nullptr; @@ -748,6 +753,14 @@ void OS::protectSignalHandlers(SigAction segvHandler, SigAction busHandler) { if (_real_sigaction == nullptr) { _real_sigaction = (real_sigaction_t)dlsym(RTLD_DEFAULT, "sigaction"); } + // Save the current (JVM's) signal handlers BEFORE we install ours. + // These will be returned as oldact when we intercept other libraries' sigaction calls, + // so they chain to JVM instead of back to us (which would cause infinite loops). + if (!__atomic_load_n(&_orig_handlers_saved, __ATOMIC_ACQUIRE) && _real_sigaction != nullptr) { + _real_sigaction(SIGSEGV, nullptr, &_orig_segv_sigaction); + _real_sigaction(SIGBUS, nullptr, &_orig_bus_sigaction); + __atomic_store_n(&_orig_handlers_saved, true, __ATOMIC_RELEASE); + } _protected_segv_handler = segvHandler; _protected_bus_handler = busHandler; } @@ -760,7 +773,67 @@ SigAction OS::getBusChainTarget() { return __atomic_load_n(&_bus_chain_target, __ATOMIC_ACQUIRE); } -// sigaction hook - called via GOT patching to intercept sigaction calls +// sigaction_hook - intercepts sigaction(2) calls from any library via GOT patching. +// +// PROBLEM SOLVED +// ============== +// Without interception, a library (e.g. wasmtime) can overwrite our SIGSEGV handler: +// +// Before: kernel --> our_handler --> JVM_handler +// After lib calls sigaction(SIGSEGV, lib_handler, &oldact): +// kernel --> lib_handler +// lib_handler stores oldact = our_handler as its chain target +// => when lib chains on unhandled fault: lib_handler --> our_handler --> lib_handler --> ... +// INFINITE LOOP +// +// HANDLER CHAIN AFTER SETUP +// ========================== +// +// protectSignalHandlers() replaceSigsegvHandler() LibraryPatcher::patch_sigaction() +// | | | +// v v v +// save JVM handler install our_handler GOT-patch sigaction +// into _orig_segv_sigaction as real OS handler => all future sigaction() +// calls go through us +// +// Signal delivery chain: +// +// kernel +// | +// v +// our_handler (installed via replaceSigsegvHandler, never displaced) +// | +// +-- handled by us? --> done +// | +// v (not handled) +// _segv_chain_target (lib_handler, set when we intercepted lib's sigaction call) +// | +// +-- handled by lib? --> done +// | +// v (lib chains to its saved oldact) +// _orig_segv_sigaction (JVM's original handler, what we returned as oldact to lib) +// | +// v +// JVM handles or terminates +// +// INTERCEPTION LOGIC (this function) +// =================================== +// Case 1 - Install call [act != nullptr, SA_SIGINFO]: +// - Save lib's handler as _segv_chain_target (we'll call it if we can't handle) +// - Return _orig_segv_sigaction as oldact (NOT our handler, to break the loop) +// - Do NOT actually install lib's handler (keep ours on top) +// +// Case 2 - Query-only call [act == nullptr, oldact != nullptr]: +// - Return _orig_segv_sigaction as oldact (same reason: lib must not see our handler) +// - A lib that queries, stores the result, then uses it as a chain target would +// loop if we returned our handler here. +// +// Case 3 - 1-arg handler [act != nullptr, no SA_SIGINFO]: +// - Pass through: we cannot safely chain 1-arg handlers (different calling convention) +// +// Case 4 - Any other signal, or protection not yet active: +// - Pass through to real sigaction unchanged. +// static int sigaction_hook(int signum, const struct sigaction* act, struct sigaction* oldact) { // _real_sigaction must be resolved before any GOT patching happens if (_real_sigaction == nullptr) { @@ -769,10 +842,14 @@ static int sigaction_hook(int signum, const struct sigaction* act, struct sigact } // If this is SIGSEGV or SIGBUS and we have protected handlers installed, - // intercept the call to keep our handler on top - if (act != nullptr) { - if (signum == SIGSEGV && _protected_segv_handler != nullptr) { - // Only intercept SA_SIGINFO handlers (3-arg form) for safe chaining + // intercept the call to keep our handler on top. + // We intercept both install calls (act != nullptr) and query-only calls (act == nullptr) + // to ensure callers always see the JVM's original handler, never ours. + // A caller that gets our handler as oldact and later chains to it would cause an + // infinite loop: us -> them -> us -> ... + if (signum == SIGSEGV && _protected_segv_handler != nullptr) { + if (act != nullptr) { + // Install call: only intercept SA_SIGINFO handlers (3-arg form) for safe chaining if (act->sa_flags & SA_SIGINFO) { SigAction new_handler = act->sa_sigaction; // Don't intercept if it's our own handler being installed @@ -780,7 +857,9 @@ static int sigaction_hook(int signum, const struct sigaction* act, struct sigact // Save their handler as our chain target __atomic_exchange_n(&_segv_chain_target, new_handler, __ATOMIC_ACQ_REL); if (oldact != nullptr) { - _real_sigaction(signum, nullptr, oldact); + // Return the original (JVM's) handler, not ours, to prevent + // the caller from chaining back to us. + *oldact = _orig_segv_sigaction; } Counters::increment(SIGACTION_INTERCEPTED); // Don't actually install their handler - keep ours on top @@ -788,18 +867,28 @@ static int sigaction_hook(int signum, const struct sigaction* act, struct sigact } } // Let 1-arg handlers (without SA_SIGINFO) pass through - we can't safely chain them - } else if (signum == SIGBUS && _protected_bus_handler != nullptr) { + } else if (oldact != nullptr) { + // Query-only call: return the JVM's original handler, not ours. + // Same reason: a caller that stores our handler and later chains to it causes loops. + *oldact = _orig_segv_sigaction; + return 0; + } + } else if (signum == SIGBUS && _protected_bus_handler != nullptr) { + if (act != nullptr) { if (act->sa_flags & SA_SIGINFO) { SigAction new_handler = act->sa_sigaction; if (new_handler != _protected_bus_handler) { __atomic_exchange_n(&_bus_chain_target, new_handler, __ATOMIC_ACQ_REL); if (oldact != nullptr) { - _real_sigaction(signum, nullptr, oldact); + *oldact = _orig_bus_sigaction; } Counters::increment(SIGACTION_INTERCEPTED); return 0; } } + } else if (oldact != nullptr) { + *oldact = _orig_bus_sigaction; + return 0; } } @@ -811,4 +900,15 @@ void* OS::getSigactionHook() { return (void*)sigaction_hook; } +void OS::resetSignalHandlersForTesting() { + __atomic_store_n(&_orig_handlers_saved, false, __ATOMIC_RELEASE); + memset(&_orig_segv_sigaction, 0, sizeof(_orig_segv_sigaction)); + memset(&_orig_bus_sigaction, 0, sizeof(_orig_bus_sigaction)); + _protected_segv_handler = nullptr; + _protected_bus_handler = nullptr; + __atomic_store_n(&_segv_chain_target, (SigAction)nullptr, __ATOMIC_RELEASE); + __atomic_store_n(&_bus_chain_target, (SigAction)nullptr, __ATOMIC_RELEASE); + // _real_sigaction is intentionally not reset: safe to reuse across tests +} + #endif // __linux__ diff --git a/ddprof-lib/src/main/cpp/os_macos.cpp b/ddprof-lib/src/main/cpp/os_macos.cpp index 8337a7420..72f978bac 100644 --- a/ddprof-lib/src/main/cpp/os_macos.cpp +++ b/ddprof-lib/src/main/cpp/os_macos.cpp @@ -496,4 +496,12 @@ SigAction OS::getBusChainTarget() { return nullptr; } +void* OS::getSigactionHook() { + return nullptr; // No sigaction interception on macOS +} + +void OS::resetSignalHandlersForTesting() { + // No-op: no sigaction interception state on macOS +} + #endif // __APPLE__ diff --git a/ddprof-lib/src/main/cpp/profiler.cpp b/ddprof-lib/src/main/cpp/profiler.cpp index cbcfb7284..d1837333e 100644 --- a/ddprof-lib/src/main/cpp/profiler.cpp +++ b/ddprof-lib/src/main/cpp/profiler.cpp @@ -1044,81 +1044,102 @@ void Profiler::disableEngines() { } void Profiler::segvHandler(int signo, siginfo_t *siginfo, void *ucontext) { - if (!crashHandler(signo, siginfo, ucontext)) { - // Check dynamic chain target first (set by intercepted sigaction calls) - SigAction chain = OS::getSegvChainTarget(); - if (chain != nullptr) { - chain(signo, siginfo, ucontext); - } else if (orig_segvHandler != nullptr) { - orig_segvHandler(signo, siginfo, ucontext); - } + if (crashHandlerInternal(signo, siginfo, ucontext)) { + return; // Handled + } + // Not handled, chain to next handler + SigAction chain = OS::getSegvChainTarget(); + if (chain != nullptr) { + chain(signo, siginfo, ucontext); + } else if (orig_segvHandler != nullptr) { + orig_segvHandler(signo, siginfo, ucontext); } } void Profiler::busHandler(int signo, siginfo_t *siginfo, void *ucontext) { - if (!crashHandler(signo, siginfo, ucontext)) { - // Check dynamic chain target first (set by intercepted sigaction calls) - SigAction chain = OS::getBusChainTarget(); - if (chain != nullptr) { - chain(signo, siginfo, ucontext); - } else if (orig_busHandler != nullptr) { - orig_busHandler(signo, siginfo, ucontext); - } + if (crashHandlerInternal(signo, siginfo, ucontext)) { + return; // Handled + } + // Not handled, chain to next handler + SigAction chain = OS::getBusChainTarget(); + if (chain != nullptr) { + chain(signo, siginfo, ucontext); + } else if (orig_busHandler != nullptr) { + orig_busHandler(signo, siginfo, ucontext); } } -bool Profiler::crashHandler(int signo, siginfo_t *siginfo, void *ucontext) { +// Returns: 0 = not handled (chain to next handler), non-zero = handled +int Profiler::crashHandlerInternal(int signo, siginfo_t *siginfo, void *ucontext) { ProfiledThread* thrd = ProfiledThread::currentSignalSafe(); - if (thrd != nullptr && !thrd->enterCrashHandler()) { - // we are already in a crash handler; don't recurse! - return false; - } + // First, try to handle safefetch - this doesn't need TLS or any protection + // because it directly checks the PC and modifies ucontext to skip the fault. + // This must be checked first before any reentrancy checks. if (SafeAccess::handle_safefetch(signo, ucontext)) { - if (thrd != nullptr) { - thrd->exitCrashHandler(); + return 1; // handled + } + + // Reentrancy protection: use TLS-based tracking if available. + // If TLS is not available, we can only safely handle faults that we can + // prove are from our protected code paths (checked via sameStack heuristic + // in StackWalker::checkFault). For anything else, we must chain immediately + // to avoid claiming faults that aren't ours. + bool have_tls_protection = false; + if (thrd != nullptr) { + if (!thrd->enterCrashHandler()) { + // we are already in a crash handler; don't recurse! + return 0; // not handled, safe to chain } - return true; + have_tls_protection = true; } + // If thrd == nullptr, we proceed but with limited handling capability. + // Only StackWalker::checkFault (which has its own sameStack fallback) + // and the JDK-8313796 workaround can safely handle faults without TLS. - uintptr_t fault_address = (uintptr_t)siginfo->si_addr; StackFrame frame(ucontext); uintptr_t pc = frame.pc(); + + uintptr_t fault_address = (uintptr_t)siginfo->si_addr; if (pc == fault_address) { // it is 'pc' that is causing the fault; can not access it safely - if (thrd != nullptr) { + if (have_tls_protection) { thrd->exitCrashHandler(); } - return false; + return 0; // not handled, safe to chain } if (WX_MEMORY && Trap::isFaultInstruction(pc)) { - if (thrd != nullptr) { + if (have_tls_protection) { thrd->exitCrashHandler(); } - return true; + return 1; // handled } if (VM::isHotspot()) { // the following checks require vmstructs and therefore HotSpot + // StackWalker::checkFault has its own fallback for when TLS is unavailable: + // it uses sameStack() heuristic to check if we're in a protected stack walk. + // If the fault is from our protected walk, it will longjmp and never return. + // If it returns, the fault wasn't from our code. StackWalker::checkFault(thrd); // Workaround for JDK-8313796 if needed. Setting cstack=dwarf also helps if (_need_JDK_8313796_workaround && VMStructs::isInterpretedFrameValidFunc((const void *)pc) && frame.skipFaultInstruction()) { - if (thrd != nullptr) { + if (have_tls_protection) { thrd->exitCrashHandler(); } - return true; + return 1; // handled } } - if (thrd != nullptr) { + if (have_tls_protection) { thrd->exitCrashHandler(); } - return false; + return 0; // not handled, safe to chain } void Profiler::setupSignalHandlers() { @@ -1126,11 +1147,13 @@ void Profiler::setupSignalHandlers() { if (__sync_bool_compare_and_swap(&_signals_initialized, false, true)) { if (VM::isHotspot() || VM::isOpenJ9()) { // HotSpot and J9 tolerate interposed SIGSEGV/SIGBUS handler; other JVMs probably not + // IMPORTANT: protectSignalHandlers must be called BEFORE replaceSigsegvHandler so that + // the original (JVM's) handlers are saved before we install ours. This way, when we + // intercept other libraries' sigaction calls and return oldact, we return the JVM's + // handler (not ours), preventing infinite chaining loops. + OS::protectSignalHandlers(segvHandler, busHandler); orig_segvHandler = OS::replaceSigsegvHandler(segvHandler); orig_busHandler = OS::replaceSigbusHandler(busHandler); - // Protect our handlers from being overwritten by other libraries (e.g., wasmtime). - // Their handlers will be stored as chain targets and called from our handlers. - OS::protectSignalHandlers(segvHandler, busHandler); // Patch sigaction GOT in libraries with broken signal handlers (already loaded) LibraryPatcher::patch_sigaction(); } diff --git a/ddprof-lib/src/main/cpp/profiler.h b/ddprof-lib/src/main/cpp/profiler.h index 285c64805..5d4634c25 100644 --- a/ddprof-lib/src/main/cpp/profiler.h +++ b/ddprof-lib/src/main/cpp/profiler.h @@ -193,7 +193,7 @@ class alignas(alignof(SpinLock)) Profiler { void lockAll(); void unlockAll(); - static bool crashHandler(int signo, siginfo_t *siginfo, void *ucontext); + static int crashHandlerInternal(int signo, siginfo_t *siginfo, void *ucontext); static void check_JDK_8313796_workaround(); static Profiler *const _instance; diff --git a/ddprof-lib/src/main/cpp/vmStructs.h b/ddprof-lib/src/main/cpp/vmStructs.h index 9a51fdc08..f93397897 100644 --- a/ddprof-lib/src/main/cpp/vmStructs.h +++ b/ddprof-lib/src/main/cpp/vmStructs.h @@ -31,10 +31,11 @@ class VMNMethod; // sending SIGABRT which is uncatchable by crash protection. // When crash protection is active the assert is redundant — any bad read will // be caught by the SIGSEGV handler and recovered via longjmp — so we skip it. -inline bool crashProtectionActive() { - ProfiledThread* pt = ProfiledThread::currentSignalSafe(); - return pt != nullptr && pt->isCrashProtectionActive(); -} +// +// Defined at the bottom of this file after VMThread is declared so that the +// VMThread fallback path (isExceptionActive) is accessible without forward- +// declaring the full class. +inline bool crashProtectionActive(); template inline T* cast_to(const void* ptr) { @@ -788,6 +789,17 @@ DECLARE(VMThread) return *(void**) at(_thread_exception_offset); } + // Returns true if setjmp crash protection is currently active for this thread. + // Reads the exception field via direct pointer arithmetic, deliberately bypassing + // at() and its crashProtectionActive() assertion to avoid infinite recursion. + // Safe because 'this' is the current live thread (we are in its signal handler). + static bool isExceptionActive() { + if (_thread_exception_offset < 0) return false; + VMThread* vt = current(); + if (vt == nullptr) return false; + return *(const void* const*)((const char*)vt + _thread_exception_offset) != nullptr; + } + NOADDRSANITIZE VMJavaFrameAnchor* anchor() { if (!cachedIsJavaThread()) return NULL; assert(_thread_anchor_offset >= 0); @@ -1123,4 +1135,18 @@ class InterpreterFrame : VMStructs { } }; +// Defined here (after VMThread) so the VMThread::isExceptionActive() fallback +// is accessible. The forward declaration at the top of this file allows cast_to() +// to reference it before VMThread is declared. +inline bool crashProtectionActive() { + ProfiledThread* pt = ProfiledThread::currentSignalSafe(); + if (pt != nullptr && pt->isCrashProtectionActive()) return true; + // Fallback for threads without ProfiledThread TLS (e.g. JVM internal threads): + // if walkVM has set up setjmp protection via vm_thread->exception(), the assert + // is equally redundant — any bad read will be caught by the SIGSEGV handler. + // Uses VMThread::isExceptionActive() which reads the field directly without + // going through at() to avoid recursive assertion. + return VMThread::key() >= 0 && VMThread::isExceptionActive(); +} + #endif // _VMSTRUCTS_H diff --git a/ddprof-lib/src/test/cpp/sigaction_interception_ut.cpp b/ddprof-lib/src/test/cpp/sigaction_interception_ut.cpp new file mode 100644 index 000000000..6d45bee2b --- /dev/null +++ b/ddprof-lib/src/test/cpp/sigaction_interception_ut.cpp @@ -0,0 +1,266 @@ +/* + * Copyright 2025 Datadog, Inc + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include + +#include "os.h" + +/** + * Test for signal handler chaining to prevent infinite loops. + * + * This test verifies that when we intercept sigaction calls from other libraries + * (like wasmtime), the oldact we return points to the original handler (e.g., JVM's), + * not to our handler. If we return our handler as oldact, the chain becomes: + * Us -> other_lib -> Us -> other_lib -> ... (infinite loop) + * + * The correct chain should be: + * Us -> other_lib -> original_handler (JVM) + */ + +// Type for our sigaction hook function +typedef int (*sigaction_hook_t)(int, const struct sigaction*, struct sigaction*); + +// Counter to detect infinite loops +static std::atomic handler_call_count{0}; +static const int MAX_HANDLER_CALLS = 10; + +// Jump buffer for escaping infinite loops +static sigjmp_buf escape_jmp; +static std::atomic should_escape{false}; + +// The "other library's" handler (simulating wasmtime) +static struct sigaction other_lib_saved_oldact; + +static void other_lib_handler(int signo, siginfo_t* siginfo, void* context) { + handler_call_count++; + + // Detect infinite loop + if (handler_call_count > MAX_HANDLER_CALLS) { + // We're in an infinite loop - escape + if (should_escape) { + siglongjmp(escape_jmp, 1); + } + return; + } + + // Simulate what wasmtime does: if we don't handle it, chain to "previous" handler + // If oldact points back to our handler, this will cause infinite recursion + if (other_lib_saved_oldact.sa_flags & SA_SIGINFO) { + if (other_lib_saved_oldact.sa_sigaction != nullptr) { + other_lib_saved_oldact.sa_sigaction(signo, siginfo, context); + } + } else if (other_lib_saved_oldact.sa_handler != SIG_DFL && + other_lib_saved_oldact.sa_handler != SIG_IGN && + other_lib_saved_oldact.sa_handler != nullptr) { + other_lib_saved_oldact.sa_handler(signo); + } + // If oldact is SIG_DFL, we just return (signal will re-trigger or terminate) +} + +// Our handler (profiler's handler) +static void our_handler(int signo, siginfo_t* siginfo, void* context) { + handler_call_count++; + + // Detect infinite loop + if (handler_call_count > MAX_HANDLER_CALLS) { + if (should_escape) { + siglongjmp(escape_jmp, 1); + } + return; + } + + // We don't handle this signal, chain to the "other library" via chain target + SigAction chain = OS::getSegvChainTarget(); + if (chain != nullptr) { + chain(signo, siginfo, context); + } + // After chain returns (or if no chain), we return +} + +// Original handler (simulating JVM's handler) +static std::atomic original_handler_called{false}; +static void original_handler(int signo, siginfo_t* siginfo, void* context) { + original_handler_called = true; + // The original handler would normally handle the signal or terminate. + // For this test, we just mark that we were called. + if (should_escape) { + siglongjmp(escape_jmp, 1); + } +} + +class SigactionInterceptionTest : public ::testing::Test { +protected: + struct sigaction saved_segv_action; + + void SetUp() override { + // Save current SIGSEGV handler + sigaction(SIGSEGV, nullptr, &saved_segv_action); + + // Reset state + handler_call_count = 0; + original_handler_called = false; + should_escape = false; + memset(&other_lib_saved_oldact, 0, sizeof(other_lib_saved_oldact)); + } + + void TearDown() override { + // Restore original SIGSEGV handler + sigaction(SIGSEGV, &saved_segv_action, nullptr); + // Reset static interception state so tests don't bleed into each other + OS::resetSignalHandlersForTesting(); + } +}; + +/** + * Test that sigaction interception returns the correct oldact. + * + * Setup: + * 1. Install "original" handler (simulating JVM) + * 2. Call protectSignalHandlers with "our" handler + * 3. Install "our" handler + * 4. "Other library" calls sigaction (via the hook) to install its handler + * 5. Verify that oldact returned to "other library" is the original handler, not ours + * + * This ensures that when "other library" chains to oldact, it goes to the original + * handler, not back to us (which would cause infinite loop). + * + * NOTE: We call the sigaction hook directly because in a standalone test binary, + * GOT patching isn't active. This tests the core logic of the hook function. + */ +TEST_F(SigactionInterceptionTest, OldactPointsToOriginalHandler) { + // Get the sigaction hook function + sigaction_hook_t hook = (sigaction_hook_t)OS::getSigactionHook(); +#ifdef __APPLE__ + // Sigaction interception is only implemented on Linux + if (hook == nullptr) { + GTEST_SKIP() << "Sigaction interception not implemented on macOS"; + } +#endif + ASSERT_NE(hook, nullptr) << "getSigactionHook returned nullptr"; + + // Step 1: Install "original" handler (simulating JVM) + struct sigaction original_sa; + memset(&original_sa, 0, sizeof(original_sa)); + original_sa.sa_sigaction = original_handler; + original_sa.sa_flags = SA_SIGINFO; + sigemptyset(&original_sa.sa_mask); + sigaction(SIGSEGV, &original_sa, nullptr); + + // Step 2 & 3: Protect and install our handler + // Note: protectSignalHandlers should save the current (original) handler + OS::protectSignalHandlers(our_handler, nullptr); + OS::replaceSigsegvHandler(our_handler); + + // Step 4: "Other library" calls sigaction via the hook to install its handler + struct sigaction other_lib_sa; + memset(&other_lib_sa, 0, sizeof(other_lib_sa)); + other_lib_sa.sa_sigaction = other_lib_handler; + other_lib_sa.sa_flags = SA_SIGINFO; + sigemptyset(&other_lib_sa.sa_mask); + + // Call the hook directly (simulates what GOT patching does in production) + int result = hook(SIGSEGV, &other_lib_sa, &other_lib_saved_oldact); + ASSERT_EQ(result, 0); + + // Step 5: Verify oldact + // The oldact should point to original_handler, NOT to our_handler + // If it points to our_handler, chaining will cause infinite loop + + // Check that oldact is not our handler + if (other_lib_saved_oldact.sa_flags & SA_SIGINFO) { + EXPECT_NE(other_lib_saved_oldact.sa_sigaction, our_handler) + << "oldact points to our handler - this would cause infinite loop!"; + + // It should be the original handler + EXPECT_EQ(other_lib_saved_oldact.sa_sigaction, original_handler) + << "oldact should be the original (JVM's) handler"; + } else { + // If not SA_SIGINFO, check sa_handler + EXPECT_NE(other_lib_saved_oldact.sa_handler, (void (*)(int))our_handler) + << "oldact points to our handler - this would cause infinite loop!"; + } +} + +/** + * Test that signal chaining doesn't cause infinite loop. + * + * This test actually triggers a SIGSEGV and verifies the chain doesn't loop forever. + * The chain should be: our_handler -> other_lib_handler -> original_handler + * + * NOTE: We use the hook directly to set up the interception since GOT patching + * isn't active in the standalone test binary. + */ +TEST_F(SigactionInterceptionTest, NoInfiniteLoopOnChaining) { + // Get the sigaction hook function + sigaction_hook_t hook = (sigaction_hook_t)OS::getSigactionHook(); +#ifdef __APPLE__ + // Sigaction interception is only implemented on Linux + if (hook == nullptr) { + GTEST_SKIP() << "Sigaction interception not implemented on macOS"; + } +#endif + ASSERT_NE(hook, nullptr) << "getSigactionHook returned nullptr"; + + // Setup: Install original handler (simulating JVM) + struct sigaction original_sa; + memset(&original_sa, 0, sizeof(original_sa)); + original_sa.sa_sigaction = original_handler; + original_sa.sa_flags = SA_SIGINFO; + sigemptyset(&original_sa.sa_mask); + sigaction(SIGSEGV, &original_sa, nullptr); + + // Protect and install our handler + OS::protectSignalHandlers(our_handler, nullptr); + OS::replaceSigsegvHandler(our_handler); + + // "Other library" installs its handler via the hook + struct sigaction other_lib_sa; + memset(&other_lib_sa, 0, sizeof(other_lib_sa)); + other_lib_sa.sa_sigaction = other_lib_handler; + other_lib_sa.sa_flags = SA_SIGINFO; + sigemptyset(&other_lib_sa.sa_mask); + hook(SIGSEGV, &other_lib_sa, &other_lib_saved_oldact); + + // Now trigger a SIGSEGV and see what happens + should_escape = true; + + if (sigsetjmp(escape_jmp, 1) == 0) { + // Trigger SIGSEGV by accessing null pointer + volatile int* p = nullptr; + *p = 42; // This will trigger SIGSEGV + + // Should not reach here + FAIL() << "SIGSEGV was not triggered"; + } else { + // We escaped via siglongjmp + // Check that we didn't loop too many times + EXPECT_LE(handler_call_count.load(), MAX_HANDLER_CALLS) + << "Handler was called too many times - possible infinite loop!"; + + // Ideally, the chain should be: our_handler(1) -> other_lib(2) -> original(3) + // So we expect around 3 calls, definitely less than MAX_HANDLER_CALLS + EXPECT_LE(handler_call_count.load(), 5) + << "Handler chain is longer than expected"; + + // Verify the original handler was actually called (chain completed) + EXPECT_TRUE(original_handler_called) + << "Original handler was not called - chain may not be set up correctly"; + } +} diff --git a/ddprof-lib/src/test/cpp/test_tlsPriming.cpp b/ddprof-lib/src/test/cpp/test_tlsPriming.cpp deleted file mode 100644 index e8bb52dd0..000000000 --- a/ddprof-lib/src/test/cpp/test_tlsPriming.cpp +++ /dev/null @@ -1,297 +0,0 @@ -/* - * Copyright 2025, Datadog, Inc - * SPDX-License-Identifier: Apache-2.0 - */ - -#include "gtest/gtest.h" -#include "os.h" -#include "common.h" -#include "thread.h" -#include -#include -#include -#include -#include -#include - -namespace { - -static std::atomic g_signal_received{0}; -static std::atomic g_threads_primed{0}; - -// Simple TLS test POD -thread_local uint64_t g_test_tls = 0; - -void testTlsSignalHandler(int signo) { - g_signal_received++; - g_test_tls = 0x1234ABCD; // Touch TLS to prime it - g_threads_primed++; -} - -class TlsPrimingTest : public ::testing::Test { -protected: - void SetUp() override { - g_signal_received.store(0); - g_threads_primed.store(0); - g_test_tls = 0; - } -}; - -TEST_F(TlsPrimingTest, InstallSignalHandler) { - int signal_num = OS::installTlsPrimeSignalHandler(testTlsSignalHandler, 5); - -#ifdef __linux__ - if (signal_num > 0) { - TEST_LOG("Successfully installed RT signal handler for signal %d", signal_num); - EXPECT_GT(signal_num, SIGRTMIN); - EXPECT_LE(signal_num, SIGRTMAX); - } else { - TEST_LOG("RT signal installation failed (may indicate signal exhaustion)"); - EXPECT_EQ(signal_num, -1); - } -#elif defined(__APPLE__) - TEST_LOG("TLS prime signal handler not supported on macOS"); - EXPECT_EQ(signal_num, -1); -#else - TEST_LOG("TLS prime signal handler not supported on this platform"); - EXPECT_EQ(signal_num, -1); -#endif -} - -TEST_F(TlsPrimingTest, EnumerateThreadIds) { - std::atomic thread_count{0}; - - OS::enumerateThreadIds([&](int tid) { - TEST_LOG("Found thread ID: %d", tid); -#ifdef __linux__ - EXPECT_GT(tid, 0); // Linux uses actual thread IDs > 0 -#endif - thread_count++; - }); - - TEST_LOG("Found %d threads total", thread_count.load()); - - // Should find at least the current thread on all platforms that implement enumeration - EXPECT_GE(thread_count.load(), 1); -} - -TEST_F(TlsPrimingTest, GetThreadCount) { - int count = OS::getThreadCount(); - TEST_LOG("Thread count: %d", count); - - // Should be at least 1 on platforms that implement thread counting - EXPECT_GE(count, 1); -} - -TEST_F(TlsPrimingTest, SignalCurrentThread) { - int signal_num = OS::installTlsPrimeSignalHandler(testTlsSignalHandler, 6); - -#ifdef __linux__ - if (signal_num > 0) { - TEST_LOG("Signaling current thread with signal %d", signal_num); - - // Get the first thread ID from enumeration - std::atomic first_tid{-1}; - OS::enumerateThreadIds([&](int tid) { - if (first_tid.load() == -1) { - first_tid.store(tid); - } - }); - - int tid = first_tid.load(); - if (tid >= 0) { - OS::signalThread(tid, signal_num); - - // Wait a bit for signal to be delivered - std::this_thread::sleep_for(std::chrono::milliseconds(100)); - - EXPECT_GT(g_signal_received.load(), 0); - EXPECT_GT(g_threads_primed.load(), 0); - EXPECT_EQ(g_test_tls, 0x1234ABCD); - - TEST_LOG("Signal delivered successfully, TLS primed"); - } else { - TEST_LOG("No threads found for signaling"); - FAIL() << "Thread enumeration should find at least one thread"; - } - } else { - TEST_LOG("TLS prime signaling failed to install handler"); - FAIL() << "Signal handler installation should succeed on Linux"; - } -#else - TEST_LOG("TLS prime signaling not supported on this platform"); - EXPECT_EQ(signal_num, -1); -#endif -} - -// Test TLS cleanup for JVMTI-allocated threads (non-buffer) -TEST_F(TlsPrimingTest, JvmtiThreadCleanup) { - TEST_LOG("Testing JVMTI-allocated thread cleanup"); - - std::atomic thread_done{false}; - std::atomic tid_observed{0}; - - // Create a thread that simulates JVMTI initialization - std::thread test_thread([&]() { - // Simulate JVMTI callback: initCurrentThread() - ProfiledThread::initCurrentThread(); - - // Verify TLS is initialized - ProfiledThread* tls = ProfiledThread::current(); - ASSERT_NE(tls, nullptr); - - tid_observed.store(tls->tid()); - TEST_LOG("JVMTI thread initialized with TID: %d", tls->tid()); - - // Verify this is NOT a buffer-allocated thread - // (buffer_pos should be -1 for JVMTI threads) - // We can't directly access _buffer_pos, but we can verify behavior - - thread_done.store(true); - // When thread exits, pthread should call freeKey() which should delete the instance - }); - - test_thread.join(); - - EXPECT_TRUE(thread_done.load()); - EXPECT_GT(tid_observed.load(), 0); - TEST_LOG("JVMTI thread cleanup completed (instance should be deleted)"); -} - -// Test TLS cleanup for buffer-allocated threads (signal priming) -TEST_F(TlsPrimingTest, BufferThreadCleanup) { -#ifdef __linux__ - TEST_LOG("Testing buffer-allocated thread cleanup"); - - // Initialize the buffer system - ProfiledThread::initExistingThreads(); - - std::atomic thread_initialized{false}; - std::atomic thread_done{false}; - std::atomic tid_observed{0}; - - // Create a thread that simulates signal-based priming - std::thread test_thread([&]() { - // Directly call buffer initialization (simulating signal handler) - // This is what simpleTlsSignalHandler() does for native threads - ProfiledThread::initCurrentThreadWithBuffer(); - - // Verify TLS is initialized from buffer - ProfiledThread* tls = ProfiledThread::currentSignalSafe(); - if (tls != nullptr) { - tid_observed.store(tls->tid()); - thread_initialized.store(true); - TEST_LOG("Buffer thread initialized with TID: %d", tls->tid()); - } - - thread_done.store(true); - // When thread exits, pthread should call freeKey() which should recycle the buffer slot - }); - - test_thread.join(); - - EXPECT_TRUE(thread_done.load()); - EXPECT_TRUE(thread_initialized.load()); - EXPECT_GT(tid_observed.load(), 0); - TEST_LOG("Buffer thread cleanup completed (slot should be recycled)"); - - // Cleanup - ProfiledThread::cleanupTlsPriming(); -#else - TEST_LOG("Buffer-allocated thread cleanup test only supported on Linux"); -#endif -} - -// Test that buffer slots are properly recycled -TEST_F(TlsPrimingTest, BufferSlotRecycling) { -#ifdef __linux__ - TEST_LOG("Testing buffer slot recycling"); - - ProfiledThread::initExistingThreads(); - - std::vector tids_observed; - const int num_threads = 10; - - for (int i = 0; i < num_threads; i++) { - std::atomic tid{0}; - - std::thread test_thread([&]() { - ProfiledThread::initCurrentThreadWithBuffer(); - ProfiledThread* tls = ProfiledThread::currentSignalSafe(); - if (tls != nullptr) { - tid.store(tls->tid()); - } - }); - - test_thread.join(); - - if (tid.load() > 0) { - tids_observed.push_back(tid.load()); - } - } - - EXPECT_EQ(tids_observed.size(), num_threads); - TEST_LOG("Created and recycled %d buffer slots", num_threads); - - // Verify all threads got valid TIDs - for (size_t i = 0; i < tids_observed.size(); i++) { - EXPECT_GT(tids_observed[i], 0) << "Thread " << i << " got invalid TID"; - } - - ProfiledThread::cleanupTlsPriming(); -#else - TEST_LOG("Buffer slot recycling test only supported on Linux"); -#endif -} - -// Test mixed JVMTI and buffer-allocated thread cleanup -TEST_F(TlsPrimingTest, MixedThreadCleanup) { -#ifdef __linux__ - TEST_LOG("Testing mixed JVMTI and buffer-allocated thread cleanup"); - - ProfiledThread::initExistingThreads(); - - std::atomic jvmti_count{0}; - std::atomic buffer_count{0}; - - std::vector threads; - - // Create mix of JVMTI-style and buffer-style threads - for (int i = 0; i < 10; i++) { - if (i % 2 == 0) { - // JVMTI-style thread - threads.emplace_back([&]() { - ProfiledThread::initCurrentThread(); - ProfiledThread* tls = ProfiledThread::current(); - if (tls != nullptr) { - jvmti_count++; - } - }); - } else { - // Buffer-style thread - threads.emplace_back([&]() { - ProfiledThread::initCurrentThreadWithBuffer(); - ProfiledThread* tls = ProfiledThread::currentSignalSafe(); - if (tls != nullptr) { - buffer_count++; - } - }); - } - } - - for (auto& t : threads) { - t.join(); - } - - EXPECT_EQ(jvmti_count.load(), 5); - EXPECT_EQ(buffer_count.load(), 5); - TEST_LOG("Mixed cleanup: %d JVMTI threads (deleted), %d buffer threads (recycled)", - jvmti_count.load(), buffer_count.load()); - - ProfiledThread::cleanupTlsPriming(); -#else - TEST_LOG("Mixed thread cleanup test only supported on Linux"); -#endif -} - -} // namespace \ No newline at end of file