11import importlib
2- from flask import Response
2+ import inspect
33import json
44from policyengine_core .taxbenefitsystems import TaxBenefitSystem
55from typing import Union , Optional
2222 build_congressional_district_metadata ,
2323)
2424
25- # Note: The following policyengine_[xx] imports are probably redundant.
26- # These modules are imported dynamically in the __init__ function below.
27- import policyengine_uk
28- import policyengine_us
29- import policyengine_canada
30- import policyengine_ng
31- import policyengine_il
32-
3325from policyengine_api .data import local_database
3426from policyengine_api .constants import COUNTRY_PACKAGE_VERSIONS
3527
@@ -45,24 +37,40 @@ def __init__(self, country_package_name: str, country_id: str):
4537 self .build_metadata ()
4638
4739 def build_metadata (self ):
48- self .metadata = dict (
49- variables = self .build_variables (),
50- parameters = self .build_parameters (),
51- entities = self .build_entities (),
52- variableModules = self .tax_benefit_system .variable_module_metadata ,
53- economy_options = self .build_microsimulation_options (),
54- current_law_id = {
55- "uk" : 1 ,
56- "us" : 2 ,
57- "ca" : 3 ,
58- "ng" : 4 ,
59- "il" : 5 ,
60- }[self .country_id ],
61- basicInputs = self .tax_benefit_system .basic_inputs ,
62- modelled_policies = self .tax_benefit_system .modelled_policies ,
63- version = get_package_version (self .country_package_name .replace ("_" , "-" )),
40+ self .metadata = self ._json_safe (
41+ dict (
42+ variables = self .build_variables (),
43+ parameters = self .build_parameters (),
44+ entities = self .build_entities (),
45+ variableModules = self .tax_benefit_system .variable_module_metadata ,
46+ economy_options = self .build_microsimulation_options (),
47+ current_law_id = {
48+ "uk" : 1 ,
49+ "us" : 2 ,
50+ "ca" : 3 ,
51+ "ng" : 4 ,
52+ "il" : 5 ,
53+ }[self .country_id ],
54+ basicInputs = self .tax_benefit_system .basic_inputs ,
55+ modelled_policies = self .tax_benefit_system .modelled_policies ,
56+ version = get_package_version (
57+ self .country_package_name .replace ("_" , "-" )
58+ ),
59+ )
6460 )
6561
62+ def _json_safe (self , value ):
63+ if isinstance (value , Path ):
64+ return str (value )
65+ if isinstance (value , dict ):
66+ return {
67+ key : self ._json_safe (nested_value )
68+ for key , nested_value in value .items ()
69+ }
70+ if isinstance (value , list ):
71+ return [self ._json_safe (nested_value ) for nested_value in value ]
72+ return value
73+
6674 def build_microsimulation_options (self ) -> dict :
6775 # { region: [{ name: "uk", label: "the UK" }], time_period: [{ name: 2022, label: "2022", ... }] }
6876 options = dict ()
@@ -363,31 +371,7 @@ def calculate(
363371 household_id : Optional [int ] = None ,
364372 policy_id : Optional [int ] = None ,
365373 ):
366- if reform is not None and len (reform .keys ()) > 0 :
367- system = self .tax_benefit_system .clone ()
368- for parameter_name in reform :
369- for time_period , value in reform [parameter_name ].items ():
370- start_instant , end_instant = time_period .split ("." )
371- parameter = get_parameter (system .parameters , parameter_name )
372- node_type = type (parameter .values_list [- 1 ].value )
373- if node_type == int :
374- node_type = float
375- try :
376- value = float (value )
377- except :
378- pass
379- parameter .update (
380- start = instant (start_instant ),
381- stop = instant (end_instant ),
382- value = node_type (value ),
383- )
384- else :
385- system = self .tax_benefit_system
386-
387- simulation = self .country_package .Simulation (
388- tax_benefit_system = system ,
389- situation = household ,
390- )
374+ simulation , system = self ._create_simulation (household , reform )
391375
392376 household = json .loads (json .dumps (household ))
393377
@@ -429,14 +413,14 @@ def calculate(
429413 entity_index = population .get_index (entity_id )
430414 if variable .value_type == Enum :
431415 entity_result = result .decode ()[entity_index ].name
432- elif variable .value_type == float :
416+ elif variable .value_type is float :
433417 entity_result = float (str (result [entity_index ]))
434418 # Convert infinities to JSON infinities
435419 if entity_result == float ("inf" ):
436420 entity_result = "Infinity"
437421 elif entity_result == float ("-inf" ):
438422 entity_result = "-Infinity"
439- elif variable .value_type == str :
423+ elif variable .value_type is str :
440424 entity_result = str (result [entity_index ])
441425 else :
442426 entity_result = result .tolist ()[entity_index ]
@@ -473,6 +457,72 @@ def calculate(
473457
474458 return household
475459
460+ def _create_simulation (
461+ self ,
462+ household : dict ,
463+ reform : Union [dict , None ],
464+ ):
465+ normalized_reform = None
466+ if reform :
467+ system = self .tax_benefit_system .clone ()
468+ normalized_reform = self ._normalize_reform_values (reform , system )
469+ else :
470+ system = self .tax_benefit_system
471+
472+ if self ._simulation_accepts_tax_benefit_system ():
473+ if normalized_reform :
474+ self ._apply_reform_to_system (system , normalized_reform )
475+ simulation = self .country_package .Simulation (
476+ tax_benefit_system = system ,
477+ situation = household ,
478+ )
479+ return simulation , system
480+
481+ simulation_kwargs = {"situation" : household }
482+ if normalized_reform :
483+ simulation_kwargs ["reform" ] = normalized_reform
484+ simulation = self .country_package .Simulation (** simulation_kwargs )
485+ return simulation , simulation .tax_benefit_system
486+
487+ def _simulation_accepts_tax_benefit_system (self ) -> bool :
488+ simulation_signature = inspect .signature (self .country_package .Simulation )
489+ return "tax_benefit_system" in simulation_signature .parameters
490+
491+ def _normalize_reform_values (
492+ self ,
493+ reform : dict ,
494+ system : TaxBenefitSystem ,
495+ ) -> dict :
496+ normalized_reform = {}
497+ for parameter_name , parameter_updates in reform .items ():
498+ parameter = get_parameter (system .parameters , parameter_name )
499+ normalized_reform [parameter_name ] = {}
500+ for time_period , value in parameter_updates .items ():
501+ node_type = type (parameter .values_list [- 1 ].value )
502+ if node_type is int :
503+ node_type = float
504+ try :
505+ value = float (value )
506+ except Exception :
507+ pass
508+ normalized_reform [parameter_name ][time_period ] = node_type (value )
509+ return normalized_reform
510+
511+ def _apply_reform_to_system (
512+ self ,
513+ system : TaxBenefitSystem ,
514+ reform : dict ,
515+ ) -> None :
516+ for parameter_name , parameter_updates in reform .items ():
517+ parameter = get_parameter (system .parameters , parameter_name )
518+ for time_period , value in parameter_updates .items ():
519+ start_instant , end_instant = time_period .split ("." )
520+ parameter .update (
521+ start = instant (start_instant ),
522+ stop = instant (end_instant ),
523+ value = value ,
524+ )
525+
476526
477527def create_policy_reform (policy_data : dict ) -> dict :
478528 """
@@ -498,7 +548,7 @@ def modify_parameters(parameters: ParameterNode) -> ParameterNode:
498548 for period , value in values .items ():
499549 start , end = period .split ("." )
500550 node_type = type (node .values_list [- 1 ].value )
501- if node_type == int :
551+ if node_type is int :
502552 node_type = float # '0' is of type int by default, but usually we want to cast to float.
503553 if node .values_list [- 1 ].value is None :
504554 node_type = float
0 commit comments