Skip to content

Commit ce11854

Browse files
authored
Harden CPS basic ORG loading and caching (#727)
* Harden CPS basic ORG loading
* Validate cached ORG donor data
* Format ORG cache validation helper
* Fix CPS basic ORG column alignment
1 parent d6ac877 commit ce11854

3 files changed

Lines changed: 265 additions & 48 deletions

File tree

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Harden CPS basic ORG donor loading against transient fetch failures and concurrent cache builds.

policyengine_us_data/datasets/org/org.py

Lines changed: 79 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,16 @@
66
imputation onto CPS records.
77
"""
88

9+
from contextlib import contextmanager
910
from functools import lru_cache
11+
from io import BytesIO
12+
from pathlib import Path
13+
import fcntl
1014

1115
from microimpute.models.qrf import QRF
1216
import numpy as np
1317
import pandas as pd
18+
import requests
1419

1520
from policyengine_us_data.storage import STORAGE_FOLDER
1621

@@ -181,11 +186,13 @@ def _cps_basic_org_month_url(year: int, month: str) -> str:
181186
)
182187

183188

184-
def _select_cps_basic_org_columns(month_df: pd.DataFrame) -> pd.DataFrame:
185-
"""Normalize CPS basic-month columns onto the ORG schema."""
189+
def _resolve_cps_basic_org_column_names(
190+
columns: pd.Index | list[str],
191+
) -> list[str]:
192+
"""Resolve CPS basic-month columns onto the expected ORG schema order."""
186193
column_lookup = {
187-
str(column).lower(): column
188-
for column in month_df.columns
194+
str(column).lower(): str(column)
195+
for column in columns
189196
if isinstance(column, str)
190197
}
191198
missing = [
@@ -196,9 +203,12 @@ def _select_cps_basic_org_columns(month_df: pd.DataFrame) -> pd.DataFrame:
196203
if missing:
197204
raise ValueError(f"CPS basic ORG month is missing required columns: {missing}")
198205

199-
selected = month_df[
200-
[column_lookup[column.lower()] for column in CPS_BASIC_MONTHLY_ORG_COLUMNS]
201-
].copy()
206+
return [column_lookup[column.lower()] for column in CPS_BASIC_MONTHLY_ORG_COLUMNS]
207+
208+
209+
def _select_cps_basic_org_columns(month_df: pd.DataFrame) -> pd.DataFrame:
210+
"""Normalize CPS basic-month columns onto the ORG schema."""
211+
selected = month_df[_resolve_cps_basic_org_column_names(month_df.columns)].copy()
202212
selected.columns = CPS_BASIC_MONTHLY_ORG_COLUMNS
203213
return selected
204214

@@ -211,16 +221,18 @@ def _load_cps_basic_org_month(
211221
) -> pd.DataFrame:
212222
"""Load one CPS basic-month file with light retry around transient fetch/parser issues."""
213223
url = _cps_basic_org_month_url(year, month)
214-
required_columns = {column.lower() for column in CPS_BASIC_MONTHLY_ORG_COLUMNS}
215224
last_error: Exception | None = None
216225

217226
for _ in range(max_attempts):
218227
try:
228+
response = requests.get(url, timeout=60)
229+
response.raise_for_status()
230+
content = response.content
231+
header = pd.read_csv(BytesIO(content), nrows=0)
232+
selected_columns = _resolve_cps_basic_org_column_names(header.columns)
219233
month_df = pd.read_csv(
220-
url,
221-
usecols=lambda column: (
222-
isinstance(column, str) and column.lower() in required_columns
223-
),
234+
BytesIO(content),
235+
usecols=selected_columns,
224236
low_memory=False,
225237
)
226238
return _select_cps_basic_org_columns(month_df)
@@ -233,6 +245,36 @@ def _load_cps_basic_org_month(
233245
) from last_error
234246

235247

248+
@contextmanager
249+
def _org_cache_build_lock(lock_path: Path):
250+
lock_path.parent.mkdir(parents=True, exist_ok=True)
251+
with open(lock_path, "w") as lock_file:
252+
fcntl.flock(lock_file.fileno(), fcntl.LOCK_EX)
253+
try:
254+
yield
255+
finally:
256+
fcntl.flock(lock_file.fileno(), fcntl.LOCK_UN)
257+
258+
259+
def _load_valid_cached_org_training_data(cache_path: Path) -> pd.DataFrame | None:
260+
"""Return a cached ORG training frame when it is present and structurally valid."""
261+
required_columns = set(
262+
ORG_PREDICTORS + ORG_QRF_IMPUTED_VARIABLES + ["sample_weight"]
263+
)
264+
try:
265+
cached = pd.read_csv(cache_path)
266+
except (FileNotFoundError, OSError, pd.errors.EmptyDataError):
267+
return None
268+
269+
if cached.empty:
270+
return None
271+
272+
if not required_columns.issubset(cached.columns):
273+
return None
274+
275+
return cached
276+
277+
236278
def _transform_cps_basic_org_month(month_df: pd.DataFrame) -> pd.DataFrame:
237279
"""Convert one monthly CPS basic file into ORG donor rows.
238280
@@ -451,17 +493,31 @@ def _predict_union_coverage_from_bls_tables(
451493
def load_org_training_data() -> pd.DataFrame:
452494
"""Load ORG donor rows built from official CPS basic monthly files."""
453495
cache_path = STORAGE_FOLDER / ORG_FILENAME
454-
if cache_path.exists():
455-
return pd.read_csv(cache_path)
456-
457-
months = []
458-
for month in ORG_MONTHS:
459-
month_df = _load_cps_basic_org_month(ORG_YEAR, month)
460-
months.append(_transform_cps_basic_org_month(month_df))
461-
462-
org = pd.concat(months, ignore_index=True)
463-
org.to_csv(cache_path, index=False, compression="gzip")
464-
return org
496+
lock_path = cache_path.parent / f"{cache_path.name}.lock"
497+
cached = _load_valid_cached_org_training_data(cache_path)
498+
if cached is not None:
499+
return cached
500+
501+
with _org_cache_build_lock(lock_path):
502+
cached = _load_valid_cached_org_training_data(cache_path)
503+
if cached is not None:
504+
return cached
505+
if cache_path.exists():
506+
cache_path.unlink()
507+
508+
months = []
509+
for month in ORG_MONTHS:
510+
month_df = _load_cps_basic_org_month(ORG_YEAR, month)
511+
months.append(_transform_cps_basic_org_month(month_df))
512+
513+
org = pd.concat(months, ignore_index=True)
514+
temp_path = cache_path.parent / f"{cache_path.name}.tmp.gz"
515+
org.to_csv(temp_path, index=False, compression="gzip")
516+
temp_path.replace(cache_path)
517+
cached = _load_valid_cached_org_training_data(cache_path)
518+
if cached is None:
519+
raise ValueError("Failed to build a valid cached ORG donor file")
520+
return cached
465521

466522

467523
@lru_cache(maxsize=1)

tests/unit/datasets/test_org.py

Lines changed: 185 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import numpy as np
22
import pandas as pd
3+
from concurrent.futures import ThreadPoolExecutor
4+
import time
35

46
from policyengine_us_data.datasets.cps import cps as cps_module
57
from policyengine_us_data.datasets.org import (
@@ -10,6 +12,7 @@
1012
CPS_BASIC_MONTHLY_ORG_COLUMNS,
1113
_build_union_priority_weights,
1214
_load_cps_basic_org_month,
15+
load_org_training_data,
1316
_predict_union_coverage_from_bls_tables,
1417
_select_cps_basic_org_columns,
1518
_transform_cps_basic_org_month,
@@ -154,42 +157,199 @@ def test_load_cps_basic_org_month_retries_after_transient_parser_failure(
154157
monkeypatch,
155158
):
156159
calls = []
157-
month_df = pd.DataFrame(
158-
{
159-
"hrmis": [4],
160-
"GESTFIPS": [6],
161-
"PRTAGE": [30],
162-
"PESEX": [2],
163-
"PTDTRACE": [1],
164-
"PEHSPNON": [2],
165-
"PWORWGT": [100.0],
166-
"PTERNWA": [100000.0],
167-
"PTERNHLY": [2500.0],
168-
"PEERNHRY": [1],
169-
"PEHRUSLT": [40.0],
170-
"PRERELG": [1],
171-
"PEMLR": [1],
172-
"PEIO1COW": [1],
173-
}
160+
csv_text = (
161+
"hrmis,GESTFIPS,PRTAGE,PESEX,PTDTRACE,PEHSPNON,PWORWGT,"
162+
"PTERNWA,PTERNHLY,PEERNHRY,PEHRUSLT,PRERELG,PEMLR,PEIO1COW\n"
163+
"4,6,30,2,1,2,100.0,100000.0,2500.0,1,40.0,1,1,1\n"
174164
)
175165

176-
def fake_read_csv(*args, **kwargs):
166+
class FakeResponse:
167+
def __init__(self, text: str, status_code: int = 200):
168+
self.content = text.encode("utf-8")
169+
self.status_code = status_code
170+
171+
def raise_for_status(self):
172+
if self.status_code >= 400:
173+
raise ValueError("bad status")
174+
175+
responses = [
176+
FakeResponse("<html>temporary error</html>"),
177+
FakeResponse(csv_text),
178+
]
179+
180+
def fake_get(*args, **kwargs):
177181
calls.append(kwargs)
178-
if len(calls) == 1:
179-
raise ValueError("Usecols do not match columns")
180-
return month_df
182+
return responses.pop(0)
181183

182-
monkeypatch.setattr(
183-
"policyengine_us_data.datasets.org.org.pd.read_csv", fake_read_csv
184-
)
184+
monkeypatch.setattr("policyengine_us_data.datasets.org.org.requests.get", fake_get)
185185

186186
loaded = _load_cps_basic_org_month(2024, "may", max_attempts=2)
187187

188188
assert len(calls) == 2
189-
assert callable(calls[0]["usecols"])
190189
assert loaded.columns.tolist() == CPS_BASIC_MONTHLY_ORG_COLUMNS
191190

192191

192+
def test_load_cps_basic_org_month_reorders_file_order_columns(monkeypatch):
193+
csv_text = (
194+
"PTERNWA,PEHRUSLT,hrmis,PEMLR,PEERNHRY,PEHSPNON,PRTAGE,"
195+
"PTDTRACE,pworwgt,peio1cow,GESTFIPS,PESEX,PTERNHLY,PRERELG\n"
196+
"100000.0,40.0,4,1,1,2,30,1,100.0,1,6,2,2500.0,1\n"
197+
)
198+
199+
class FakeResponse:
200+
def __init__(self, text: str):
201+
self.content = text.encode("utf-8")
202+
203+
def raise_for_status(self):
204+
return None
205+
206+
monkeypatch.setattr(
207+
"policyengine_us_data.datasets.org.org.requests.get",
208+
lambda *args, **kwargs: FakeResponse(csv_text),
209+
)
210+
211+
loaded = _load_cps_basic_org_month(2024, "may", max_attempts=1)
212+
213+
assert loaded.columns.tolist() == CPS_BASIC_MONTHLY_ORG_COLUMNS
214+
assert loaded.iloc[0].to_dict() == {
215+
"HRMIS": 4,
216+
"gestfips": 6,
217+
"prtage": 30,
218+
"pesex": 2,
219+
"ptdtrace": 1,
220+
"pehspnon": 2,
221+
"pworwgt": 100.0,
222+
"pternwa": 100000.0,
223+
"pternhly": 2500.0,
224+
"peernhry": 1,
225+
"pehruslt": 40.0,
226+
"prerelg": 1,
227+
"pemlr": 1,
228+
"peio1cow": 1,
229+
}
230+
231+
232+
def test_load_org_training_data_serializes_first_cache_build(monkeypatch, tmp_path):
233+
raw_month = pd.DataFrame(
234+
{
235+
"HRMIS": [4],
236+
"gestfips": [6],
237+
"prtage": [30],
238+
"pesex": [2],
239+
"ptdtrace": [1],
240+
"pehspnon": [2],
241+
"pworwgt": [100.0],
242+
"pternwa": [100000.0],
243+
"pternhly": [2500.0],
244+
"peernhry": [1],
245+
"pehruslt": [40.0],
246+
"prerelg": [1],
247+
"pemlr": [1],
248+
"peio1cow": [1],
249+
}
250+
)
251+
call_count = {"value": 0}
252+
253+
monkeypatch.setattr(
254+
"policyengine_us_data.datasets.org.org.STORAGE_FOLDER", tmp_path
255+
)
256+
monkeypatch.setattr(
257+
"policyengine_us_data.datasets.org.org.ORG_MONTHS",
258+
("may",),
259+
)
260+
261+
def fake_load_month(year, month):
262+
call_count["value"] += 1
263+
time.sleep(0.2)
264+
return raw_month.copy()
265+
266+
monkeypatch.setattr(
267+
"policyengine_us_data.datasets.org.org._load_cps_basic_org_month",
268+
fake_load_month,
269+
)
270+
271+
load_org_training_data.cache_clear()
272+
try:
273+
with ThreadPoolExecutor(max_workers=2) as executor:
274+
left = executor.submit(load_org_training_data)
275+
right = executor.submit(load_org_training_data)
276+
left_result = left.result()
277+
right_result = right.result()
278+
finally:
279+
load_org_training_data.cache_clear()
280+
281+
assert call_count["value"] == 1
282+
pd.testing.assert_frame_equal(left_result, right_result)
283+
284+
285+
def test_load_org_training_data_rebuilds_invalid_cached_file(monkeypatch, tmp_path):
286+
raw_month = pd.DataFrame(
287+
{
288+
"HRMIS": [4],
289+
"gestfips": [6],
290+
"prtage": [30],
291+
"pesex": [2],
292+
"ptdtrace": [1],
293+
"pehspnon": [2],
294+
"pworwgt": [100.0],
295+
"pternwa": [100000.0],
296+
"pternhly": [2500.0],
297+
"peernhry": [1],
298+
"pehruslt": [40.0],
299+
"prerelg": [1],
300+
"pemlr": [1],
301+
"peio1cow": [1],
302+
}
303+
)
304+
cache_path = tmp_path / "census_cps_org_2024_wages.csv.gz"
305+
pd.DataFrame(columns=["employment_income", "weekly_hours_worked"]).to_csv(
306+
cache_path,
307+
index=False,
308+
compression="gzip",
309+
)
310+
call_count = {"value": 0}
311+
312+
monkeypatch.setattr(
313+
"policyengine_us_data.datasets.org.org.STORAGE_FOLDER", tmp_path
314+
)
315+
monkeypatch.setattr(
316+
"policyengine_us_data.datasets.org.org.ORG_MONTHS",
317+
("may",),
318+
)
319+
320+
def fake_load_month(year, month):
321+
call_count["value"] += 1
322+
return raw_month.copy()
323+
324+
monkeypatch.setattr(
325+
"policyengine_us_data.datasets.org.org._load_cps_basic_org_month",
326+
fake_load_month,
327+
)
328+
329+
load_org_training_data.cache_clear()
330+
try:
331+
rebuilt = load_org_training_data()
332+
finally:
333+
load_org_training_data.cache_clear()
334+
335+
assert call_count["value"] == 1
336+
assert not rebuilt.empty
337+
assert set(
338+
[
339+
"employment_income",
340+
"weekly_hours_worked",
341+
"age",
342+
"is_female",
343+
"is_hispanic",
344+
"race_wbho",
345+
"state_fips",
346+
"hourly_wage",
347+
"is_paid_hourly",
348+
"sample_weight",
349+
]
350+
).issubset(rebuilt.columns)
351+
352+
193353
def test_build_union_priority_weights_reflect_bls_demographics():
194354
receiver = pd.DataFrame(
195355
{

0 commit comments

Comments (0)