Skip to content

Commit 25752c1

Browse files
committed
test: add unit tests for on-demand checkpointing
24 tests covering: - Trigger file helpers (path generation, write, exists, remove) - ParentSignalHandler (install, handle, idempotency, uninstall, real signal) - check_checkpoint_requested (trigger detection, cleanup, all_reduce consensus) - BatchLossManager interrupt handling (all 3 check points, early exit, float loss)
1 parent d7b965b commit 25752c1

1 file changed

Lines changed: 349 additions & 0 deletions

File tree

Lines changed: 349 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,349 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
"""Tests for on-demand checkpointing."""
3+
4+
# Standard
5+
from unittest.mock import MagicMock, call, patch
6+
import os
7+
import signal
8+
9+
# Third Party
10+
import pytest
11+
import torch
12+
13+
# First Party
14+
from instructlab.training.on_demand_checkpoint import (
15+
_CATCHABLE_SIGNALS,
16+
ParentSignalHandler,
17+
_get_trigger_path,
18+
check_checkpoint_requested,
19+
remove_trigger_file,
20+
trigger_file_exists,
21+
write_trigger_file,
22+
)
23+
24+
# ---------------------------------------------------------------------------
25+
# Trigger file helpers
26+
# ---------------------------------------------------------------------------
27+
28+
29+
class TestGetTriggerPath:
30+
def test_without_job_id(self):
31+
path = _get_trigger_path()
32+
assert path.name == "instructlab_checkpoint_requested"
33+
assert str(path.parent) == "/dev/shm"
34+
35+
def test_with_job_id(self):
36+
path = _get_trigger_path("my-job-123")
37+
assert path.name == "instructlab_checkpoint_requested_my-job-123"
38+
39+
def test_different_job_ids_produce_different_paths(self):
40+
p1 = _get_trigger_path("job-a")
41+
p2 = _get_trigger_path("job-b")
42+
assert p1 != p2
43+
44+
45+
class TestWriteTriggerFile:
46+
def test_creates_file(self, tmp_path):
47+
with patch("instructlab.training.on_demand_checkpoint._TRIGGER_DIR", tmp_path):
48+
path = write_trigger_file("test-write")
49+
assert path.exists()
50+
assert path.read_text() == "1"
51+
52+
def test_returns_correct_path(self, tmp_path):
53+
with patch("instructlab.training.on_demand_checkpoint._TRIGGER_DIR", tmp_path):
54+
path = write_trigger_file("test-path")
55+
assert path == tmp_path / "instructlab_checkpoint_requested_test-path"
56+
57+
58+
class TestTriggerFileExists:
59+
def test_returns_false_when_absent(self, tmp_path):
60+
with patch("instructlab.training.on_demand_checkpoint._TRIGGER_DIR", tmp_path):
61+
assert trigger_file_exists("nonexistent") is False
62+
63+
def test_returns_true_when_present(self, tmp_path):
64+
with patch("instructlab.training.on_demand_checkpoint._TRIGGER_DIR", tmp_path):
65+
write_trigger_file("exists")
66+
assert trigger_file_exists("exists") is True
67+
68+
69+
class TestRemoveTriggerFile:
70+
def test_removes_existing_file(self, tmp_path):
71+
with patch("instructlab.training.on_demand_checkpoint._TRIGGER_DIR", tmp_path):
72+
write_trigger_file("to-remove")
73+
assert trigger_file_exists("to-remove") is True
74+
remove_trigger_file("to-remove")
75+
assert trigger_file_exists("to-remove") is False
76+
77+
def test_noop_on_missing_file(self, tmp_path):
78+
with patch("instructlab.training.on_demand_checkpoint._TRIGGER_DIR", tmp_path):
79+
# Should not raise
80+
remove_trigger_file("never-existed")
81+
82+
83+
# ---------------------------------------------------------------------------
84+
# ParentSignalHandler
85+
# ---------------------------------------------------------------------------
86+
87+
88+
class TestParentSignalHandler:
89+
def test_install_registers_handlers(self):
90+
handler = ParentSignalHandler(job_id="test-install")
91+
original_handlers = {sig: signal.getsignal(sig) for sig in _CATCHABLE_SIGNALS}
92+
try:
93+
handler.install()
94+
for sig in _CATCHABLE_SIGNALS:
95+
current = signal.getsignal(sig)
96+
assert current == handler._handle, (
97+
f"Expected handler._handle for {sig.name}, got {current}"
98+
)
99+
finally:
100+
# Restore originals regardless
101+
for sig, orig in original_handlers.items():
102+
signal.signal(sig, orig)
103+
104+
def test_handle_writes_trigger_and_records_signal(self, tmp_path):
105+
with patch("instructlab.training.on_demand_checkpoint._TRIGGER_DIR", tmp_path):
106+
handler = ParentSignalHandler(job_id="test-handle")
107+
assert handler.signal_received is None
108+
assert handler._trigger_written is False
109+
110+
handler._handle(signal.SIGUSR1, None)
111+
112+
assert handler.signal_received == signal.SIGUSR1
113+
assert handler._trigger_written is True
114+
assert trigger_file_exists("test-handle") is True
115+
116+
def test_handle_is_idempotent(self, tmp_path):
117+
"""Multiple signals should only write the trigger file once."""
118+
with patch("instructlab.training.on_demand_checkpoint._TRIGGER_DIR", tmp_path):
119+
handler = ParentSignalHandler(job_id="test-idempotent")
120+
121+
with patch(
122+
"instructlab.training.on_demand_checkpoint.write_trigger_file"
123+
) as mock_write:
124+
mock_write.return_value = tmp_path / "dummy"
125+
handler._handle(signal.SIGUSR1, None)
126+
handler._handle(signal.SIGTERM, None)
127+
handler._handle(signal.SIGINT, None)
128+
129+
# write_trigger_file called only once
130+
mock_write.assert_called_once_with("test-idempotent")
131+
132+
# signal_received should be the LAST signal
133+
assert handler.signal_received == signal.SIGINT
134+
135+
def test_uninstall_restores_original_handlers(self):
136+
handler = ParentSignalHandler(job_id="test-uninstall")
137+
originals = {sig: signal.getsignal(sig) for sig in _CATCHABLE_SIGNALS}
138+
139+
handler.install()
140+
# Verify handlers changed
141+
for sig in _CATCHABLE_SIGNALS:
142+
assert signal.getsignal(sig) == handler._handle
143+
144+
handler.uninstall()
145+
# Verify handlers restored
146+
for sig in _CATCHABLE_SIGNALS:
147+
assert signal.getsignal(sig) == originals[sig], f"{sig.name} not restored"
148+
149+
def test_install_via_real_signal(self, tmp_path):
150+
"""End-to-end: install handler, send SIGUSR1, verify trigger written."""
151+
with patch("instructlab.training.on_demand_checkpoint._TRIGGER_DIR", tmp_path):
152+
handler = ParentSignalHandler(job_id="test-real-signal")
153+
handler.install()
154+
try:
155+
os.kill(os.getpid(), signal.SIGUSR1)
156+
assert handler.signal_received == signal.SIGUSR1
157+
assert trigger_file_exists("test-real-signal") is True
158+
finally:
159+
handler.uninstall()
160+
remove_trigger_file("test-real-signal")
161+
162+
163+
# ---------------------------------------------------------------------------
164+
# check_checkpoint_requested (worker-side, mocked dist)
165+
# ---------------------------------------------------------------------------
166+
167+
168+
class TestCheckCheckpointRequested:
169+
def _mock_all_reduce_propagate(self, tensor, op=None):
170+
"""Mock all_reduce that just keeps the local value."""
171+
pass # tensor already has the local value
172+
173+
def test_returns_false_when_no_trigger(self, tmp_path):
174+
with (
175+
patch("instructlab.training.on_demand_checkpoint._TRIGGER_DIR", tmp_path),
176+
patch("instructlab.training.on_demand_checkpoint.dist") as mock_dist,
177+
patch("torch.cuda.current_device", return_value=0),
178+
):
179+
mock_dist.all_reduce = self._mock_all_reduce_propagate
180+
mock_dist.is_initialized.return_value = True
181+
mock_dist.get_rank.return_value = 0
182+
183+
result = check_checkpoint_requested("test-no-trigger")
184+
assert result is False
185+
186+
def test_returns_true_when_trigger_exists(self, tmp_path):
187+
with (
188+
patch("instructlab.training.on_demand_checkpoint._TRIGGER_DIR", tmp_path),
189+
patch("instructlab.training.on_demand_checkpoint.dist") as mock_dist,
190+
patch("torch.cuda.current_device", return_value=0),
191+
):
192+
mock_dist.all_reduce = self._mock_all_reduce_propagate
193+
mock_dist.is_initialized.return_value = True
194+
mock_dist.get_rank.return_value = 0
195+
196+
write_trigger_file("test-trigger")
197+
result = check_checkpoint_requested("test-trigger")
198+
assert result is True
199+
200+
def test_cleans_up_trigger_after_detection(self, tmp_path):
201+
with (
202+
patch("instructlab.training.on_demand_checkpoint._TRIGGER_DIR", tmp_path),
203+
patch("instructlab.training.on_demand_checkpoint.dist") as mock_dist,
204+
patch("torch.cuda.current_device", return_value=0),
205+
):
206+
mock_dist.all_reduce = self._mock_all_reduce_propagate
207+
mock_dist.is_initialized.return_value = True
208+
mock_dist.get_rank.return_value = 0
209+
210+
write_trigger_file("test-cleanup")
211+
check_checkpoint_requested("test-cleanup")
212+
assert trigger_file_exists("test-cleanup") is False
213+
214+
def test_all_reduce_is_called(self, tmp_path):
215+
with (
216+
patch("instructlab.training.on_demand_checkpoint._TRIGGER_DIR", tmp_path),
217+
patch("instructlab.training.on_demand_checkpoint.dist") as mock_dist,
218+
patch("torch.cuda.current_device", return_value=0),
219+
):
220+
mock_dist.all_reduce = MagicMock()
221+
mock_dist.is_initialized.return_value = True
222+
mock_dist.get_rank.return_value = 0
223+
mock_dist.ReduceOp.MAX = torch.distributed.ReduceOp.MAX
224+
225+
check_checkpoint_requested("test-allreduce")
226+
mock_dist.all_reduce.assert_called_once()
227+
# Verify MAX reduction op
228+
_, kwargs = mock_dist.all_reduce.call_args
229+
assert kwargs.get("op") == torch.distributed.ReduceOp.MAX
230+
231+
232+
# ---------------------------------------------------------------------------
233+
# BatchLossManager.process_batch interrupt handling
234+
# ---------------------------------------------------------------------------
235+
236+
237+
class TestBatchLossManagerInterrupt:
238+
"""Test that interrupt_check callbacks stop processing correctly."""
239+
240+
@pytest.fixture
241+
def manager(self):
242+
model = MagicMock()
243+
model.compute_loss.return_value = (
244+
torch.tensor(1.0, requires_grad=True),
245+
MagicMock(main_loss=torch.tensor(0.5), aux_loss=None),
246+
)
247+
accelerator = MagicMock()
248+
accelerator.device = torch.device("cpu")
249+
# reduce is called with a 2-element tensor (metrics) and a scalar (loss).
250+
# Return the input unchanged to simulate single-rank "reduction".
251+
accelerator.reduce.side_effect = lambda t, **kw: t
252+
accelerator.backward = MagicMock()
253+
254+
# First Party
255+
from instructlab.training.batch_loss_manager import BatchLossManager
256+
257+
mgr = BatchLossManager(
258+
model=model,
259+
accelerator=accelerator,
260+
world_size=1,
261+
local_rank=0,
262+
)
263+
return mgr
264+
265+
def _make_batch(self, n_minibatches=3):
266+
"""Create a fake batch with n minibatches."""
267+
return [
268+
{
269+
"input_ids": torch.randint(0, 100, (2, 32)),
270+
"labels": torch.randint(0, 100, (2, 32)),
271+
"num_samples": 2,
272+
"total_length": 32,
273+
"batch_num_loss_counted_tokens": 64,
274+
}
275+
for _ in range(n_minibatches)
276+
]
277+
278+
def test_no_interrupt_processes_all_minibatches(self, manager):
279+
batch = self._make_batch(3)
280+
metrics, _ = manager.process_batch(batch, interrupt_check=None)
281+
assert metrics.interrupted is False
282+
assert metrics.grad_accum_steps == 3
283+
284+
def test_interrupt_before_first_forward(self, manager):
285+
"""Interrupt fires immediately — no forward/backward should run."""
286+
batch = self._make_batch(3)
287+
metrics, _ = manager.process_batch(batch, interrupt_check=lambda: True)
288+
assert metrics.interrupted is True
289+
assert metrics.grad_accum_steps == 0
290+
manager.model.compute_loss.assert_not_called()
291+
manager.accelerator.backward.assert_not_called()
292+
293+
def test_interrupt_before_backward(self, manager):
294+
"""Interrupt fires after forward but before backward."""
295+
call_count = 0
296+
297+
def interrupt_on_second_call():
298+
nonlocal call_count
299+
call_count += 1
300+
# First call: before forward — let it pass
301+
# Second call: before backward — interrupt
302+
return call_count == 2
303+
304+
batch = self._make_batch(3)
305+
metrics, _ = manager.process_batch(
306+
batch, interrupt_check=interrupt_on_second_call
307+
)
308+
assert metrics.interrupted is True
309+
# Forward ran once, backward never ran
310+
assert manager.model.compute_loss.call_count == 1
311+
manager.accelerator.backward.assert_not_called()
312+
assert metrics.grad_accum_steps == 0
313+
314+
def test_interrupt_after_backward(self, manager):
315+
"""Interrupt fires after first backward — one grad accum step done."""
316+
call_count = 0
317+
318+
def interrupt_on_third_call():
319+
nonlocal call_count
320+
call_count += 1
321+
# Calls: 1=before_fwd, 2=before_bwd, 3=after_bwd (interrupt)
322+
return call_count == 3
323+
324+
batch = self._make_batch(3)
325+
metrics, _ = manager.process_batch(
326+
batch, interrupt_check=interrupt_on_third_call
327+
)
328+
assert metrics.interrupted is True
329+
assert metrics.grad_accum_steps == 1
330+
manager.model.compute_loss.assert_called_once()
331+
manager.accelerator.backward.assert_called_once()
332+
333+
def test_interrupt_never_fires(self, manager):
334+
"""interrupt_check always returns False — full batch processed."""
335+
batch = self._make_batch(3)
336+
metrics, _ = manager.process_batch(batch, interrupt_check=lambda: False)
337+
assert metrics.interrupted is False
338+
assert metrics.grad_accum_steps == 3
339+
340+
def test_compute_average_loss_handles_float_when_interrupted(self, manager):
341+
"""When interrupted before any forward, accumulated_loss is 0.0 (float)."""
342+
# _compute_average_loss must handle float, not just Tensor
343+
result = manager._compute_average_loss(
344+
accumulated_loss=0.0,
345+
accumulated_aux_loss=None,
346+
batch_num_loss_counted_tokens=64,
347+
)
348+
# Should not raise and should return a float
349+
assert isinstance(result, float)

0 commit comments

Comments
 (0)