Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 9 additions & 9 deletions python/interpret-core/interpret/glassbox/_ebm_core/_boost.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,18 +364,18 @@ def boost(
):
break

if callback is not None:
is_done = callback(
bag_idx, step_idx, make_progress, cur_metric
)
if is_done:
if stop_flag is not None:
stop_flag[0] = True
break

if stop_flag is not None and stop_flag[0]:
break

if callback is not None:
is_done = callback(
bag_idx, step_idx, make_progress, cur_metric
)
if is_done:
if stop_flag is not None:
stop_flag[0] = True
break

state_idx = state_idx + 1
if len(term_features) <= state_idx:
if smoothing_rounds > 0:
Expand Down
208 changes: 208 additions & 0 deletions python/interpret-core/tests/glassbox/ebm/test_callback.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,208 @@
# Copyright (c) 2023 The InterpretML Contributors
# Distributed under the MIT software license

"""Regression tests for issue #635: callback fires repeatedly with same n_steps."""

import numpy as np

from interpret.glassbox import (
ExplainableBoostingClassifier,
ExplainableBoostingRegressor,
)
from interpret.utils import make_synthetic


class RecordingCallback:
    """Picklable callback that logs every invocation and never asks to stop.

    Each call is stored in ``records`` as a
    ``(bag_idx, n_steps, has_progressed, best_score)`` tuple, in call order.
    The tests fit with ``n_jobs=1`` so that this state is shared in-process.
    """

    def __init__(self):
        # Chronological list of the argument tuples received so far.
        self.records = []

    def __call__(self, bag_idx, n_steps, has_progressed, best_score):
        """Remember the arguments of this call and return ``False``."""
        invocation = (bag_idx, n_steps, has_progressed, best_score)
        self.records.append(invocation)
        return False


class StopAfterCallback:
    """Picklable callback that requests termination from the N-th call on."""

    def __init__(self, stop_after):
        # Number of invocations seen so far, across all calls.
        self.call_count = 0
        # Invocation count at which the callback starts returning True.
        self.stop_after = stop_after

    def __call__(self, bag_idx, n_steps, has_progressed, best_score):
        """Count this call; return True once the count reaches ``stop_after``."""
        self.call_count = self.call_count + 1
        limit_reached = self.call_count >= self.stop_after
        return limit_reached


def _split_into_phases(steps):
"""Split a list of n_steps values into phases.

EBM training has multiple boosting phases (main terms, then
interactions) where step_idx resets to 1. This splits the
sequence at phase boundaries (where n_steps drops).
"""
if not steps:
return []
phases = [[steps[0]]]
for i in range(1, len(steps)):
if steps[i] < steps[i - 1]:
# n_steps dropped — new phase started
phases.append([steps[i]])
else:
phases[-1].append(steps[i])
return phases


def test_callback_no_repeated_steps_classifier():
    """Check that n_steps strictly increases within each boosting phase.

    Before the fix, the callback was invoked on every internal loop
    iteration, including non-progressing cycles, so the same n_steps
    value could be reported several times in a row.
    """
    X, y, names, types = make_synthetic(
        seed=42, classes=2, output_type="float", n_samples=500
    )

    recorder = RecordingCallback()
    model = ExplainableBoostingClassifier(
        names,
        types,
        outer_bags=1,
        max_rounds=50,
        n_jobs=1,
        callback=recorder,
    )
    model.fit(X, y)

    assert len(recorder.records) > 0, "Callback should have been invoked at least once"

    # Group the reported n_steps values by the bag that reported them.
    steps_by_bag = {}
    for record in recorder.records:
        bag_idx, n_steps = record[0], record[1]
        if bag_idx not in steps_by_bag:
            steps_by_bag[bag_idx] = []
        steps_by_bag[bag_idx].append(n_steps)

    # Within one phase every value must exceed its predecessor.
    for bag_idx, steps in steps_by_bag.items():
        for phase in _split_into_phases(steps):
            for previous, current in zip(phase, phase[1:]):
                assert current > previous, (
                    f"Bag {bag_idx}: n_steps went from {previous} to "
                    f"{current} (expected strictly increasing within phase)"
                )


def test_callback_no_repeated_steps_regressor():
    """Check strictly increasing n_steps per phase for the regressor.

    Fits an ExplainableBoostingRegressor with a recording callback and
    asserts that, per bag and per phase, no n_steps value is repeated.
    """
    X, y, names, types = make_synthetic(
        seed=42, classes=None, output_type="float", n_samples=500
    )

    recorder = RecordingCallback()
    model = ExplainableBoostingRegressor(
        names,
        types,
        outer_bags=1,
        max_rounds=50,
        n_jobs=1,
        callback=recorder,
    )
    model.fit(X, y)

    assert len(recorder.records) > 0, "Callback should have been invoked at least once"

    # Group the reported n_steps values by the bag that reported them.
    steps_by_bag = {}
    for record in recorder.records:
        bag_idx, n_steps = record[0], record[1]
        if bag_idx not in steps_by_bag:
            steps_by_bag[bag_idx] = []
        steps_by_bag[bag_idx].append(n_steps)

    # Within one phase every value must exceed its predecessor.
    for bag_idx, steps in steps_by_bag.items():
        for phase in _split_into_phases(steps):
            for previous, current in zip(phase, phase[1:]):
                assert current > previous, (
                    f"Bag {bag_idx}: n_steps went from {previous} to "
                    f"{current} (expected strictly increasing within phase)"
                )


def test_callback_has_progressed_always_true():
    """Check that every recorded callback call has a truthy has_progressed.

    The callback is expected to fire only inside the `if make_progress`
    block, so a falsy has_progressed should never be observed.
    """
    X, y, names, types = make_synthetic(
        seed=42, classes=2, output_type="float", n_samples=500
    )

    recorder = RecordingCallback()
    model = ExplainableBoostingClassifier(
        names,
        types,
        outer_bags=1,
        max_rounds=50,
        n_jobs=1,
        callback=recorder,
    )
    model.fit(X, y)

    assert len(recorder.records) > 0, "Callback should have been invoked at least once"

    # Index 2 of each record is the has_progressed argument.
    progress_flags = [record[2] for record in recorder.records]
    assert all(progress_flags), (
        "has_progressed should always be True when callback fires"
    )


def test_callback_early_termination():
    """Check that a callback returning True still ends training early.

    The callback starts returning True on its fifth call; the test
    expects no further calls after that and a model that can predict.
    """
    X, y, names, types = make_synthetic(
        seed=42, classes=2, output_type="float", n_samples=500
    )

    stopper = StopAfterCallback(stop_after=5)
    model = ExplainableBoostingClassifier(
        names,
        types,
        outer_bags=1,
        max_rounds=5000,
        n_jobs=1,
        callback=stopper,
    )
    model.fit(X, y)

    expected_calls = stopper.stop_after
    actual_calls = stopper.call_count
    assert actual_calls == expected_calls, (
        f"Expected callback to be called exactly {expected_calls} times "
        f"before stopping, but was called {actual_calls} times"
    )

    # The model should still be valid after early stopping
    predicted = model.predict(X)
    assert len(predicted) == len(y)


def test_callback_receives_valid_metrics():
    """Check that every metric passed to the callback is finite."""
    X, y, names, types = make_synthetic(
        seed=42, classes=2, output_type="float", n_samples=500
    )

    recorder = RecordingCallback()
    model = ExplainableBoostingClassifier(
        names,
        types,
        outer_bags=1,
        max_rounds=50,
        n_jobs=1,
        callback=recorder,
    )
    model.fit(X, y)

    assert len(recorder.records) > 0, "Callback should have been invoked at least once"

    # Index 3 of each record is the metric handed to the callback.
    for call_no, record in enumerate(recorder.records):
        metric = record[3]
        assert np.isfinite(metric), f"Metric at step {call_no} is not finite: {metric}"
Loading