Skip to content

Commit 34261b0

Browse files
committed
Fix lint for protected AI routes
1 parent faf47a2 commit 34261b0

File tree

4 files changed

+53
-20
lines changed

4 files changed

+53
-20
lines changed

policyengine_api/routes/simulation_analysis_routes.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
1+
import json
2+
13
from flask import Blueprint, request, Response, stream_with_context
24
from werkzeug.exceptions import BadRequest
3-
from policyengine_api.utils.payload_validators import validate_country
5+
6+
from policyengine_api.security import require_simulation_analysis_api_key
47
from policyengine_api.services.simulation_analysis_service import (
58
SimulationAnalysisService,
69
)
@@ -10,8 +13,6 @@
1013
from policyengine_api.utils.payload_validators.ai import (
1114
validate_sim_analysis_payload,
1215
)
13-
from policyengine_api.security import require_simulation_analysis_api_key
14-
import json
1516

1617
simulation_analysis_bp = Blueprint("simulation_analysis", __name__)
1718
simulation_analysis_service = SimulationAnalysisService()

policyengine_api/routes/tracer_analysis_routes.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,16 @@
1+
import json
2+
13
from flask import Blueprint, request, Response, stream_with_context
24
from werkzeug.exceptions import BadRequest
5+
6+
from policyengine_api.security import require_simulation_analysis_api_key
37
from policyengine_api.utils.payload_validators import (
48
validate_country,
59
validate_tracer_analysis_payload,
610
)
7-
from policyengine_api.security import require_simulation_analysis_api_key
811
from policyengine_api.services.tracer_analysis_service import (
912
TracerAnalysisService,
1013
)
11-
import json
12-
from policyengine_api.country import COUNTRY_PACKAGE_VERSIONS
13-
import re
1414

1515
tracer_analysis_bp = Blueprint("tracer_analysis", __name__)
1616
tracer_analysis_service = TracerAnalysisService()
@@ -30,8 +30,6 @@ def execute_tracer_analysis(country_id):
3030
household_id = payload.get("household_id")
3131
policy_id = payload.get("policy_id")
3232
variable = payload.get("variable")
33-
api_version = COUNTRY_PACKAGE_VERSIONS[country_id]
34-
3533
if not isinstance(variable, str):
3634
raise BadRequest("variable must be a string")
3735

policyengine_api/security.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,22 +6,28 @@
66
from flask import request
77
from werkzeug.exceptions import Unauthorized
88

9-
_LOCAL_CLIENT_HOSTS = {"127.0.0.1", "::1", "localhost"}
9+
_ALLOW_UNAUTHENTICATED_AI_ANALYSIS_ENV = (
10+
"POLICYENGINE_API_ALLOW_UNAUTHENTICATED_AI_ANALYSIS"
11+
)
1012

1113

1214
def require_simulation_analysis_api_key(view):
13-
"""Require a shared API key for non-local simulation analysis requests."""
15+
"""Require a shared API key for simulation analysis requests."""
1416

1517
@wraps(view)
1618
def wrapped(*args, **kwargs):
17-
client_host = request.remote_addr
18-
if client_host in _LOCAL_CLIENT_HOSTS:
19+
if os.getenv(_ALLOW_UNAUTHENTICATED_AI_ANALYSIS_ENV, "").strip().lower() in {
20+
"1",
21+
"true",
22+
"yes",
23+
}:
1924
return view(*args, **kwargs)
2025

21-
expected_key = os.getenv(
22-
"POLICYENGINE_API_AI_ANALYSIS_API_KEY", ""
23-
).strip()
24-
if expected_key and request.headers.get("X-PolicyEngine-Api-Key") == expected_key:
26+
expected_key = os.getenv("POLICYENGINE_API_AI_ANALYSIS_API_KEY", "").strip()
27+
if not expected_key:
28+
raise Unauthorized("Simulation analysis API key is not configured")
29+
30+
if request.headers.get("X-PolicyEngine-Api-Key") == expected_key:
2531
return view(*args, **kwargs)
2632

2733
raise Unauthorized("API key required for simulation analysis")

tests/unit/routes/test_ai_route_auth.py

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,19 @@ def test_ai_prompt_rejects_requests_without_api_key(client, monkeypatch):
2929
assert "API key required" in response.json["message"]
3030

3131

32+
def test_ai_prompt_rejects_loopback_requests_without_api_key(client, monkeypatch):
33+
monkeypatch.setenv("POLICYENGINE_API_AI_ANALYSIS_API_KEY", "secret-key")
34+
35+
response = client.post(
36+
"/us/ai-prompts/simulation_analysis",
37+
json=valid_input_us,
38+
environ_base={"REMOTE_ADDR": "127.0.0.1"},
39+
)
40+
41+
assert response.status_code == 401
42+
assert "API key required" in response.json["message"]
43+
44+
3245
def test_ai_prompt_allows_requests_with_api_key(client, monkeypatch):
3346
monkeypatch.setenv("POLICYENGINE_API_AI_ANALYSIS_API_KEY", "secret-key")
3447

@@ -65,6 +78,23 @@ def test_tracer_analysis_rejects_requests_without_api_key(client, monkeypatch):
6578
assert "API key required" in response.json["message"]
6679

6780

81+
def test_requests_fail_closed_when_api_key_is_not_configured(client, monkeypatch):
82+
monkeypatch.delenv("POLICYENGINE_API_AI_ANALYSIS_API_KEY", raising=False)
83+
84+
response = client.post(
85+
"/us/tracer-analysis",
86+
json={
87+
"household_id": 1500,
88+
"policy_id": 2,
89+
"variable": "disposable_income",
90+
},
91+
environ_base={"REMOTE_ADDR": "203.0.113.10"},
92+
)
93+
94+
assert response.status_code == 401
95+
assert "not configured" in response.json["message"]
96+
97+
6898
def test_tracer_analysis_allows_requests_with_api_key(client, monkeypatch):
6999
monkeypatch.setenv("POLICYENGINE_API_AI_ANALYSIS_API_KEY", "secret-key")
70100

@@ -85,6 +115,4 @@ def test_tracer_analysis_allows_requests_with_api_key(client, monkeypatch):
85115

86116
assert response.status_code == 200
87117
assert response.json["result"] == "Existing analysis"
88-
mock_execute_analysis.assert_called_once_with(
89-
"us", 1500, 2, "disposable_income"
90-
)
118+
mock_execute_analysis.assert_called_once_with("us", 1500, 2, "disposable_income")

0 commit comments

Comments
 (0)