Skip to content

Commit c965fcb

Browse files
authored
Merge pull request #3394 from PolicyEngine/codex/economy-policyengine-bundle
Attach PolicyEngine bundle metadata to economy results
2 parents 5e49396 + 39d6550 commit c965fcb

14 files changed

Lines changed: 1048 additions & 207 deletions

File tree

changelog.d/fixed/3394.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Record resolved PolicyEngine bundle metadata from the runtime that actually executed society-wide simulations, and key reproduce/cache behavior off the resolved dataset bundle rather than caller-side defaults.

policyengine_api/country.py

Lines changed: 103 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import importlib
2-
from flask import Response
2+
import inspect
33
import json
44
from policyengine_core.taxbenefitsystems import TaxBenefitSystem
55
from typing import Union, Optional
@@ -22,14 +22,6 @@
2222
build_congressional_district_metadata,
2323
)
2424

25-
# Note: The following policyengine_[xx] imports are probably redundant.
26-
# These modules are imported dynamically in the __init__ function below.
27-
import policyengine_uk
28-
import policyengine_us
29-
import policyengine_canada
30-
import policyengine_ng
31-
import policyengine_il
32-
3325
from policyengine_api.data import local_database
3426
from policyengine_api.constants import COUNTRY_PACKAGE_VERSIONS
3527

@@ -45,24 +37,40 @@ def __init__(self, country_package_name: str, country_id: str):
4537
self.build_metadata()
4638

4739
def build_metadata(self):
48-
self.metadata = dict(
49-
variables=self.build_variables(),
50-
parameters=self.build_parameters(),
51-
entities=self.build_entities(),
52-
variableModules=self.tax_benefit_system.variable_module_metadata,
53-
economy_options=self.build_microsimulation_options(),
54-
current_law_id={
55-
"uk": 1,
56-
"us": 2,
57-
"ca": 3,
58-
"ng": 4,
59-
"il": 5,
60-
}[self.country_id],
61-
basicInputs=self.tax_benefit_system.basic_inputs,
62-
modelled_policies=self.tax_benefit_system.modelled_policies,
63-
version=get_package_version(self.country_package_name.replace("_", "-")),
40+
self.metadata = self._json_safe(
41+
dict(
42+
variables=self.build_variables(),
43+
parameters=self.build_parameters(),
44+
entities=self.build_entities(),
45+
variableModules=self.tax_benefit_system.variable_module_metadata,
46+
economy_options=self.build_microsimulation_options(),
47+
current_law_id={
48+
"uk": 1,
49+
"us": 2,
50+
"ca": 3,
51+
"ng": 4,
52+
"il": 5,
53+
}[self.country_id],
54+
basicInputs=self.tax_benefit_system.basic_inputs,
55+
modelled_policies=self.tax_benefit_system.modelled_policies,
56+
version=get_package_version(
57+
self.country_package_name.replace("_", "-")
58+
),
59+
)
6460
)
6561

62+
def _json_safe(self, value):
63+
if isinstance(value, Path):
64+
return str(value)
65+
if isinstance(value, dict):
66+
return {
67+
key: self._json_safe(nested_value)
68+
for key, nested_value in value.items()
69+
}
70+
if isinstance(value, list):
71+
return [self._json_safe(nested_value) for nested_value in value]
72+
return value
73+
6674
def build_microsimulation_options(self) -> dict:
6775
# { region: [{ name: "uk", label: "the UK" }], time_period: [{ name: 2022, label: "2022", ... }] }
6876
options = dict()
@@ -363,31 +371,7 @@ def calculate(
363371
household_id: Optional[int] = None,
364372
policy_id: Optional[int] = None,
365373
):
366-
if reform is not None and len(reform.keys()) > 0:
367-
system = self.tax_benefit_system.clone()
368-
for parameter_name in reform:
369-
for time_period, value in reform[parameter_name].items():
370-
start_instant, end_instant = time_period.split(".")
371-
parameter = get_parameter(system.parameters, parameter_name)
372-
node_type = type(parameter.values_list[-1].value)
373-
if node_type == int:
374-
node_type = float
375-
try:
376-
value = float(value)
377-
except:
378-
pass
379-
parameter.update(
380-
start=instant(start_instant),
381-
stop=instant(end_instant),
382-
value=node_type(value),
383-
)
384-
else:
385-
system = self.tax_benefit_system
386-
387-
simulation = self.country_package.Simulation(
388-
tax_benefit_system=system,
389-
situation=household,
390-
)
374+
simulation, system = self._create_simulation(household, reform)
391375

392376
household = json.loads(json.dumps(household))
393377

@@ -429,14 +413,14 @@ def calculate(
429413
entity_index = population.get_index(entity_id)
430414
if variable.value_type == Enum:
431415
entity_result = result.decode()[entity_index].name
432-
elif variable.value_type == float:
416+
elif variable.value_type is float:
433417
entity_result = float(str(result[entity_index]))
434418
# Convert infinities to JSON infinities
435419
if entity_result == float("inf"):
436420
entity_result = "Infinity"
437421
elif entity_result == float("-inf"):
438422
entity_result = "-Infinity"
439-
elif variable.value_type == str:
423+
elif variable.value_type is str:
440424
entity_result = str(result[entity_index])
441425
else:
442426
entity_result = result.tolist()[entity_index]
@@ -473,6 +457,72 @@ def calculate(
473457

474458
return household
475459

460+
def _create_simulation(
461+
self,
462+
household: dict,
463+
reform: Union[dict, None],
464+
):
465+
normalized_reform = None
466+
if reform:
467+
system = self.tax_benefit_system.clone()
468+
normalized_reform = self._normalize_reform_values(reform, system)
469+
else:
470+
system = self.tax_benefit_system
471+
472+
if self._simulation_accepts_tax_benefit_system():
473+
if normalized_reform:
474+
self._apply_reform_to_system(system, normalized_reform)
475+
simulation = self.country_package.Simulation(
476+
tax_benefit_system=system,
477+
situation=household,
478+
)
479+
return simulation, system
480+
481+
simulation_kwargs = {"situation": household}
482+
if normalized_reform:
483+
simulation_kwargs["reform"] = normalized_reform
484+
simulation = self.country_package.Simulation(**simulation_kwargs)
485+
return simulation, simulation.tax_benefit_system
486+
487+
def _simulation_accepts_tax_benefit_system(self) -> bool:
488+
simulation_signature = inspect.signature(self.country_package.Simulation)
489+
return "tax_benefit_system" in simulation_signature.parameters
490+
491+
def _normalize_reform_values(
492+
self,
493+
reform: dict,
494+
system: TaxBenefitSystem,
495+
) -> dict:
496+
normalized_reform = {}
497+
for parameter_name, parameter_updates in reform.items():
498+
parameter = get_parameter(system.parameters, parameter_name)
499+
normalized_reform[parameter_name] = {}
500+
for time_period, value in parameter_updates.items():
501+
node_type = type(parameter.values_list[-1].value)
502+
if node_type is int:
503+
node_type = float
504+
try:
505+
value = float(value)
506+
except Exception:
507+
pass
508+
normalized_reform[parameter_name][time_period] = node_type(value)
509+
return normalized_reform
510+
511+
def _apply_reform_to_system(
512+
self,
513+
system: TaxBenefitSystem,
514+
reform: dict,
515+
) -> None:
516+
for parameter_name, parameter_updates in reform.items():
517+
parameter = get_parameter(system.parameters, parameter_name)
518+
for time_period, value in parameter_updates.items():
519+
start_instant, end_instant = time_period.split(".")
520+
parameter.update(
521+
start=instant(start_instant),
522+
stop=instant(end_instant),
523+
value=value,
524+
)
525+
476526

477527
def create_policy_reform(policy_data: dict) -> dict:
478528
"""
@@ -498,7 +548,7 @@ def modify_parameters(parameters: ParameterNode) -> ParameterNode:
498548
for period, value in values.items():
499549
start, end = period.split(".")
500550
node_type = type(node.values_list[-1].value)
501-
if node_type == int:
551+
if node_type is int:
502552
node_type = float # '0' is of type int by default, but usually we want to cast to float.
503553
if node.values_list[-1].value is None:
504554
node_type = float
Lines changed: 7 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
1-
ENHANCED_FRS = "hf://policyengine/policyengine-uk-data/enhanced_frs_2023_24.h5"
2-
FRS = "hf://policyengine/policyengine-uk-data/frs_2023_24.h5"
1+
ENHANCED_FRS = (
2+
"hf://policyengine/policyengine-uk-data-private/enhanced_frs_2023_24.h5@1.40.3"
3+
)
4+
FRS = "hf://policyengine/policyengine-uk-data-private/frs_2023_24.h5@1.40.3"
35

4-
ENHANCED_CPS = "hf://policyengine/policyengine-us-data/enhanced_cps_2024.h5"
5-
CPS = "hf://policyengine/policyengine-us-data/cps_2023.h5"
6-
POOLED_CPS = "hf://policyengine/policyengine-us-data/pooled_3_year_cps_2023.h5"
6+
ENHANCED_CPS = "hf://policyengine/policyengine-us-data/enhanced_cps_2024.h5@1.77.0"
7+
CPS = "hf://policyengine/policyengine-us-data/cps_2023.h5@1.77.0"
8+
POOLED_CPS = "hf://policyengine/policyengine-us-data/pooled_3_year_cps_2023.h5@1.77.0"
79

810
datasets = {
911
"uk": {
@@ -16,28 +18,3 @@
1618
"pooled_cps": POOLED_CPS,
1719
},
1820
}
19-
20-
21-
def get_dataset_version(country_id: str) -> str | None:
22-
"""
23-
Get the dataset version for the specified country. If PolicyEngine does not
24-
publish data for the country, raise a ValueError.
25-
26-
By returning None for all valid countries, we allow policyengine.py to use
27-
whatever default dataset version it has available, without imposing version
28-
validation constraints from the API layer.
29-
"""
30-
match country_id:
31-
case "uk":
32-
return None
33-
case "us":
34-
return None
35-
case _:
36-
raise ValueError(f"Unknown country ID: {country_id}")
37-
38-
39-
for dataset in datasets["uk"]:
40-
datasets["uk"][dataset] = f"{datasets['uk'][dataset]}@{get_dataset_version('uk')}"
41-
42-
for dataset in datasets["us"]:
43-
datasets["us"][dataset] = f"{datasets['us'][dataset]}@{get_dataset_version('us')}"

policyengine_api/libs/simulation_api_modal.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ class ModalSimulationExecution:
2424
status: str
2525
result: Optional[dict] = None
2626
error: Optional[str] = None
27+
policyengine_bundle: Optional[dict] = None
28+
resolved_app_name: Optional[str] = None
2729

2830
@property
2931
def name(self) -> str:
@@ -94,6 +96,8 @@ def run(self, payload: dict) -> ModalSimulationExecution:
9496
return ModalSimulationExecution(
9597
job_id=data["job_id"],
9698
status=data["status"],
99+
policyengine_bundle=data.get("policyengine_bundle"),
100+
resolved_app_name=data.get("resolved_app_name"),
97101
)
98102

99103
except httpx.HTTPStatusError as e:
@@ -115,6 +119,22 @@ def run(self, payload: dict) -> ModalSimulationExecution:
115119
)
116120
raise
117121

122+
def resolve_app_name(
123+
self, country: str, version: Optional[str] = None
124+
) -> tuple[str, str]:
125+
"""Resolve the current gateway app name for a country/model version."""
126+
response = self.client.get(f"{self.base_url}/versions/{country}")
127+
response.raise_for_status()
128+
version_map = response.json()
129+
130+
resolved_version = version or version_map["latest"]
131+
try:
132+
return version_map[resolved_version], resolved_version
133+
except KeyError as exc:
134+
raise ValueError(
135+
f"Unknown version {resolved_version} for country {country}"
136+
) from exc
137+
118138
def get_execution_id(self, execution: ModalSimulationExecution) -> str:
119139
"""
120140
Get the job ID from an execution.
@@ -156,6 +176,8 @@ def get_execution_by_id(self, job_id: str) -> ModalSimulationExecution:
156176
status=data["status"],
157177
result=data.get("result"),
158178
error=data.get("error"),
179+
policyengine_bundle=data.get("policyengine_bundle"),
180+
resolved_app_name=data.get("resolved_app_name"),
159181
)
160182

161183
except httpx.HTTPStatusError as e:

0 commit comments

Comments
 (0)