Skip to content

Commit 175b96d

Browse files
MaxGhenisclaude
andauthored
Seed numpy before EnhancedCPS initial weight jitter (#774)
`EnhancedCPS.generate()` and `ReweightedCPS_2024.generate()` perturbed the original household weights with np.random.normal(1, 0.1, len(original_weights)) before calling `reweight()`. `reweight()` calls `set_seeds(seed)` internally, but only after the perturbation, so the L0 optimizer started from a different point each run and the final calibrated weights were non-reproducible despite the `seed=1456` argument. Call `set_seeds(1456)` immediately before the jitter at both call sites. Add unit tests that (a) confirm set_seeds makes a subsequent np.random.normal draw deterministic and (b) guard the source-code invariant that every np.random.normal in enhanced_cps.py is preceded by a set_seeds(...) call within 5 lines. Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent fe8e239 commit 175b96d

3 files changed

Lines changed: 70 additions & 0 deletions

File tree

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Seed numpy before the EnhancedCPS/ReweightedCPS initial weight jitter so calibrated weights are reproducible across runs.

policyengine_us_data/datasets/cps/enhanced_cps.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,12 @@ def generate(self):
222222
base_year = int(sim.default_calculation_period)
223223
data["household_weight"] = {}
224224
original_weights = sim.calculate("household_weight")
225+
# Seed before the initial weight jitter so the L0 optimizer's
226+
# starting point is reproducible across runs. `reweight()` re-seeds
227+
# inside, but that happens AFTER this perturbation, so without
228+
# this call the jitter (and hence the final calibrated weights)
229+
# differ run-to-run.
230+
set_seeds(1456)
225231
original_weights = original_weights.values + np.random.normal(
226232
1, 0.1, len(original_weights)
227233
)
@@ -358,6 +364,9 @@ def generate(self):
358364
sim = Microsimulation(dataset=self.input_dataset)
359365
data = sim.dataset.load_dataset()
360366
original_weights = sim.calculate("household_weight")
367+
# Seed before the jitter so the starting weights (and the final
368+
# reweighted result) are reproducible across runs.
369+
set_seeds(1456)
361370
original_weights = original_weights.values + np.random.normal(
362371
1, 0.1, len(original_weights)
363372
)
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
"""Regression test ensuring the initial weight jitter in EnhancedCPS is seeded.
2+
3+
Previously ``np.random.normal(1, 0.1, ...)`` ran with whatever numpy global
4+
state the process happened to be in. ``reweight()`` re-seeds, but only
5+
afterwards, so the final L0 weights differed run to run even with
6+
``seed=1456`` inside ``reweight``.
7+
8+
Fix: call ``set_seeds(1456)`` right before the jitter in
9+
``EnhancedCPS.generate`` and ``ReweightedCPS_2024.generate``.
10+
"""
11+
12+
import numpy as np
13+
14+
from policyengine_us_data.utils.seed import set_seeds
15+
16+
17+
def _mock_jitter(n: int = 10) -> np.ndarray:
18+
"""Mirror the enhanced_cps perturbation shape."""
19+
return np.random.normal(1, 0.1, n)
20+
21+
22+
def test_set_seeds_makes_numpy_normal_reproducible():
23+
set_seeds(1456)
24+
a = _mock_jitter()
25+
set_seeds(1456)
26+
b = _mock_jitter()
27+
assert np.array_equal(a, b)
28+
29+
30+
def test_unseeded_numpy_normal_is_non_reproducible():
31+
"""Sanity check: without set_seeds in between, two consecutive draws differ."""
32+
np.random.seed(None) # reset to fresh entropy
33+
a = _mock_jitter()
34+
# Don't reseed — same process draws again, distinct state.
35+
b = _mock_jitter()
36+
assert not np.array_equal(a, b)
37+
38+
39+
def test_enhanced_cps_sources_call_set_seeds_before_jitter():
40+
"""The fix places ``set_seeds(1456)`` immediately before
41+
``np.random.normal`` in both generate() methods. Verify the file
42+
preserves that invariant so regressions are caught by lint.
43+
"""
44+
import policyengine_us_data.datasets.cps.enhanced_cps as ec
45+
46+
source = open(ec.__file__).read()
47+
# Split into the two generate() bodies. Both must contain the
48+
# set_seeds call before the np.random.normal call.
49+
# The simplest invariant: every occurrence of np.random.normal
50+
# must be preceded (within the previous 5 non-blank lines) by a
51+
# set_seeds(...) call.
52+
lines = source.splitlines()
53+
normal_indices = [i for i, line in enumerate(lines) if "np.random.normal" in line]
54+
assert normal_indices, "Expected at least one np.random.normal site"
55+
for idx in normal_indices:
56+
window = "\n".join(lines[max(0, idx - 5) : idx])
57+
assert "set_seeds(" in window, (
58+
f"np.random.normal on line {idx + 1} is not preceded by set_seeds("
59+
f") within the previous 5 lines; window was:\n{window}"
60+
)

0 commit comments

Comments
 (0)