Skip to content

Commit 962c3df

Browse files
authored
Add _BaseClient class to waterdata_client package (#304)
* enable direct injection of url request builder callable * generic client module * add dev notes to class docstring * add attribute checking for subclasses * add minimal testing for generic client" * add high level get method requirement * rename generic client * setup request builder on init * move URL building to separate method * make get responses more specific * add tests for max pages and other value checks * add tests for bad responses * use more accurate definition for QueryType * add tests for different query types * add pagination handling tests * remove _ private class and module designation
1 parent 032ebce commit 962c3df

8 files changed

Lines changed: 453 additions & 15 deletions

File tree

python/waterdata_client/src/hydrotools/waterdata_client/async_web_client.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
from typing import Any, Optional, Self, Sequence
4040
from concurrent.futures import ThreadPoolExecutor
4141
from enum import StrEnum
42+
from json import JSONDecodeError
4243

4344
import aiohttp
4445
from tenacity import (
@@ -172,7 +173,19 @@ async def _execute_request(
172173
return None
173174

174175
if content_type == ResponseContentType.JSON:
175-
return await response.json()
176+
try:
177+
return await response.json()
178+
except (
179+
aiohttp.ContentTypeError,
180+
ValueError,
181+
JSONDecodeError
182+
) as exc:
183+
LOGGER.error(
184+
"Failed to decode JSON from %s: %s",
185+
url,
186+
exc
187+
)
188+
return None
176189
return await response.read()
177190

178191
async def fetch_all(
Lines changed: 176 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,176 @@
1+
"""Base OGC API Client.
2+
3+
This module provides the base class for all collection-specific USGS OGC API
4+
clients. It handles configuration state and the low-level request pipeline.
5+
"""
6+
import ssl
7+
from typing import Any, Optional, Sequence
8+
from functools import partial
9+
10+
from yarl import URL
11+
12+
from .async_web_client import get_all, ResponseContentType
13+
from .client_config import SETTINGS
14+
from .constants import USGSCollection, OGCPATH, OGCAPI
15+
from .url_builder import (
16+
build_request,
17+
build_request_batch,
18+
build_request_batch_from_queries,
19+
build_request_batch_from_feature_ids,
20+
QueryType
21+
)
22+
23+
class BaseClient:
24+
"""Base class for USGS OGC API clients. Specific child classes may overwrite
25+
private attributes: _server, _api, _endpoint, _path, _content_type,
26+
_max_pages.
27+
Otherwise, these are set to package SETTINGS defaults.
28+
29+
Attributes:
30+
concurrency_limit: Max simultaneous requests allowed for this client.
31+
max_retries: Number of times to attempt a failed request.
32+
timeout_seconds: Total request timeout in seconds.
33+
ssl_context: Custom SSL context for requests.
34+
"""
35+
_server: URL = SETTINGS.usgs_base_url
36+
_api: OGCAPI = SETTINGS.default_api
37+
_endpoint: USGSCollection = SETTINGS.default_collection
38+
_path: OGCPATH = SETTINGS.default_path
39+
_content_type: ResponseContentType = ResponseContentType.JSON
40+
_max_pages: int = SETTINGS.max_pages
41+
42+
def __init__(
43+
self,
44+
concurrency_limit: int = SETTINGS.default_concurrency,
45+
max_retries: int = SETTINGS.default_retries,
46+
timeout_seconds: int = SETTINGS.timeout_seconds,
47+
ssl_context: Optional[ssl.SSLContext] = None
48+
) -> None:
49+
self.concurrency_limit = concurrency_limit
50+
self.max_retries = max_retries
51+
self.timeout_seconds = timeout_seconds
52+
self.ssl_context = ssl_context
53+
54+
# Setup request builder
55+
self._builder = partial(
56+
build_request,
57+
server=self._server,
58+
api=self._api,
59+
endpoint=self._endpoint,
60+
path=self._path
61+
)
62+
63+
def __init_subclass__(cls):
64+
super().__init_subclass__()
65+
66+
# Enfore required attributes
67+
required = ["_endpoint", "_path", "_api", "_server", "_content_type"]
68+
for attr in required:
69+
if not hasattr(cls, attr) or getattr(cls, attr) is None:
70+
raise TypeError(
71+
f"Class {cls.__name__} failed to define required attribute: {attr}"
72+
)
73+
74+
# Check for `get` method
75+
if not callable(getattr(cls, "get", None)):
76+
raise NotImplementedError(
77+
f"Class {cls.__name__} must implement a public 'get' method "
78+
"that wraps the internal '_get_responses' pipeline."
79+
)
80+
81+
def _build_urls(
82+
self,
83+
feature_ids: Optional[Sequence[str]] = None,
84+
queries: Optional[Sequence[QueryType]] = None
85+
) -> list[URL]:
86+
"""Constructs a list of yarl.URL objects given arguments.
87+
88+
Args:
89+
feature_ids: Sequence of specific feature identifiers.
90+
queries: A sequence of query parameter dictionaries.
91+
92+
Returns:
93+
A list of yarl.URL objects.
94+
95+
"""
96+
if (feature_ids is not None) and (queries is not None):
97+
return build_request_batch(
98+
feature_ids=feature_ids,
99+
queries=queries,
100+
request_builder=self._builder
101+
)
102+
elif queries is not None:
103+
return build_request_batch_from_queries(
104+
queries=queries,
105+
request_builder=self._builder
106+
)
107+
elif feature_ids is not None:
108+
return build_request_batch_from_feature_ids(
109+
feature_ids=feature_ids,
110+
request_builder=self._builder
111+
)
112+
return [self._builder()]
113+
114+
def _get_json_responses(
115+
self,
116+
feature_ids: Optional[Sequence[str]] = None,
117+
queries: Optional[Sequence[QueryType]] = None
118+
) -> list[dict[str, Any]]:
119+
"""Internal method to build URLs and execute concurrent requests. Note
120+
that this method will silently drop bytes and None responses returned
121+
by `get_all`. These responses are logged in `async_web_client`.
122+
123+
Args:
124+
feature_ids: Sequence of specific feature identifiers.
125+
queries: A sequence of query parameter dictionaries.
126+
127+
Returns:
128+
A list of responses. Paginated responses are appended to the end
129+
of the list.
130+
131+
"""
132+
# Prepare initial fetch
133+
urls = self._build_urls(feature_ids, queries)
134+
results: list[dict[str, Any]] = []
135+
136+
# Fetch data
137+
for _ in range(self._max_pages):
138+
# Get batch of URLs
139+
batch = get_all(
140+
urls=urls,
141+
concurrency_limit=self.concurrency_limit,
142+
max_retries=self.max_retries,
143+
ssl_context=self.ssl_context,
144+
timeout_seconds=self.timeout_seconds,
145+
content_type=self._content_type
146+
)
147+
148+
# Filter batch
149+
json_batch: list[dict[str, Any]] = [b for b in batch if isinstance(b, dict)]
150+
151+
# Extend results
152+
results.extend(json_batch)
153+
154+
# Inspect links for pagination
155+
urls = [self._get_next_url(r) for r in json_batch if self._has_next_url(r)]
156+
157+
# No more data to fetch
158+
if not urls:
159+
break
160+
return results
161+
162+
def _has_next_url(self, response: dict[str, Any] | bytes | None) -> bool:
163+
"""Checks response for the presence of pagination 'next' link."""
164+
if not isinstance(response, dict):
165+
return False
166+
return any(link.get("rel") == "next" for link in response.get("links", []))
167+
168+
def _get_next_url(self, response: dict[str, Any]) -> URL:
169+
"""Attempts to return the first 'next' link encountered in response."""
170+
if not isinstance(response, dict):
171+
raise TypeError("response is not a dict")
172+
173+
for link in response.get("links", []):
174+
if link.get("rel") == "next":
175+
return URL(link["href"])
176+
raise KeyError("response does not contain 'next' link")

python/waterdata_client/src/hydrotools/waterdata_client/client_config.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ class EnvironmentKey(StrEnum):
7272
RETRIES = f"{_KEY_START}RETRIES"
7373
TIMEOUT = f"{_KEY_START}TIMEOUT"
7474
API_KEY = f"{_KEY_START}USGS_API_KEY"
75+
MAX_PAGES = f"{_KEY_START}MAX_PAGES"
7576

7677
@classmethod
7778
def describe_keys(cls) -> str:
@@ -112,6 +113,8 @@ class _Settings:
112113
default_query.
113114
default_cache: Returns a diskcache.Cache object parameterized by cache_dir and
114115
cache_expires.
116+
max_pages: Maximum number of times to follow paginated 'next' links
117+
in JSON responses.
115118
"""
116119
usgs_base_url: URL = URL("https://api.waterdata.usgs.gov/ogcapi/v0")
117120
schema_path: str = "openapi"
@@ -125,6 +128,7 @@ class _Settings:
125128
default_retries: int = 3
126129
timeout_seconds: int = 900
127130
usgs_api_key: Optional[str] = None
131+
max_pages: int = 20
128132

129133
@classmethod
130134
def from_env(cls) -> Self:
@@ -149,7 +153,8 @@ def from_env(cls) -> Self:
149153
default_concurrency=int(os.getenv(EnvironmentKey.CONCURRENCY, cls.default_concurrency)),
150154
default_retries=int(os.getenv(EnvironmentKey.RETRIES, cls.default_retries)),
151155
timeout_seconds=int(os.getenv(EnvironmentKey.TIMEOUT, cls.timeout_seconds)),
152-
usgs_api_key=os.getenv(EnvironmentKey.API_KEY, cls.usgs_api_key)
156+
usgs_api_key=os.getenv(EnvironmentKey.API_KEY, cls.usgs_api_key),
157+
max_pages=int(os.getenv(EnvironmentKey.MAX_PAGES, cls.max_pages))
153158
)
154159

155160
@cached_property

python/waterdata_client/src/hydrotools/waterdata_client/url_builder.py

Lines changed: 36 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,33 +1,50 @@
11
"""USGS OGC API URL builders."""
2-
from typing import Optional, Sequence, Any
2+
from typing import Optional, Sequence, Any, Protocol
33

44
from yarl import URL
55
from multidict import MultiDict
66

77
from .client_config import SETTINGS
88
from .constants import OGCAPI, OGCPATH, USGSCollection
99

10-
QueryType = dict[str, Any] | MultiDict[str] | Sequence[tuple[str, Any]]
10+
QueryType = dict[str, Any] | MultiDict[Any] | Sequence[tuple[str, Any]]
1111
"""Type alias for USGS OGC API compatible queryables."""
1212

13+
class RequestBuilder(Protocol):
14+
"""Defines a Protocol for a callable object that builds USGS OGC compliant
15+
URLS.
16+
17+
Args:
18+
feature_id: Optional specific feature identifier.
19+
query: Optional query parameters.
20+
21+
Returns:
22+
A yarl.URL object.
23+
"""
24+
def __call__(
25+
self,
26+
feature_id: Optional[str] = None,
27+
query: Optional[QueryType] = None
28+
) -> URL: ...
29+
1330
def build_request(
31+
feature_id: Optional[str] = None,
32+
query: Optional[QueryType] = None,
1433
server: URL = SETTINGS.usgs_base_url,
1534
api: OGCAPI = SETTINGS.default_api,
1635
endpoint: USGSCollection = SETTINGS.default_collection,
17-
path: OGCPATH = SETTINGS.default_path,
18-
feature_id: Optional[str] = None,
19-
query: Optional[QueryType] = None
36+
path: OGCPATH = SETTINGS.default_path
2037
) -> URL:
2138
"""Constructs a single yarl.URL.
2239
2340
Args:
41+
feature_id: Optional specific feature identifier.
42+
query: Optional query parameters.
2443
server: The root URL for USGS OGC API services.
2544
URL('https://api.waterdata.usgs.gov/ogcapi/v0')
2645
api: USGS OGC API (e.g. 'collections').
2746
endpoint: USGS OGC API collection (e.g. 'continuous').
2847
path: USGS OGC path (e.g. 'items')
29-
feature_id: Optional specific feature identifier.
30-
query: Optional query parameters.
3148
3249
Returns:
3350
A yarl.URL object.
@@ -47,13 +64,15 @@ def build_request(
4764

4865
def build_request_batch(
4966
feature_ids: Sequence[str],
50-
queries: Sequence[QueryType]
67+
queries: Sequence[QueryType],
68+
request_builder: RequestBuilder = build_request
5169
) -> list[URL]:
5270
"""Constructs a list of yarl.URL objects for paired IDs and queries.
5371
5472
Args:
5573
feature_ids: Sequence of specific feature identifiers.
5674
queries: Sequence of query parameters corresponding to each ID.
75+
request_builder: A callable that builds URLs.
5776
5877
Returns:
5978
A list of yarl.URL objects.
@@ -66,30 +85,34 @@ def build_request_batch(
6685
f"Mismatched input lengths: feature_ids has {len(feature_ids)} items, "
6786
f"but queries has {len(queries)} items. Sequences must be of equal length."
6887
)
69-
return [build_request(feature_id=f, query=q) for f, q in zip(feature_ids, queries)]
88+
return [request_builder(feature_id=f, query=q) for f, q in zip(feature_ids, queries)]
7089

7190
def build_request_batch_from_feature_ids(
72-
feature_ids: Sequence[str]
91+
feature_ids: Sequence[str],
92+
request_builder: RequestBuilder = build_request
7393
) -> list[URL]:
7494
"""Constructs a list of yarl.URL objects for a sequence of feature IDs.
7595
7696
Args:
7797
feature_ids: Sequence of specific feature identifiers.
98+
request_builder: A callable that builds URLs.
7899
79100
Returns:
80101
A list of yarl.URL objects.
81102
"""
82-
return [build_request(feature_id=f) for f in feature_ids]
103+
return [request_builder(feature_id=f) for f in feature_ids]
83104

84105
def build_request_batch_from_queries(
85-
queries: Sequence[QueryType]
106+
queries: Sequence[QueryType],
107+
request_builder: RequestBuilder = build_request
86108
) -> list[URL]:
87109
"""Constructs a list of yarl.URL objects for a sequence of query dictionaries.
88110
89111
Args:
90112
queries: Sequence of query parameters to add or override default_query.
113+
request_builder: A callable that builds URLs.
91114
92115
Returns:
93116
A list of yarl.URL objects.
94117
"""
95-
return [build_request(query=q) for q in queries]
118+
return [request_builder(query=q) for q in queries]

python/waterdata_client/tests/test_async_web_client.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,11 +45,17 @@ async def error_handler(request: web.Request) -> web.Response:
4545
return web.Response(status=400)
4646
return web.Response(status=500)
4747

48+
async def malformed_json_handler(request: web.Request) -> web.Response:
49+
if request.path != "/binary":
50+
return web.Response(status=400)
51+
return web.Response(body="{'invalid': json", content_type="application/json")
52+
4853
app = web.Application()
4954
app.router.add_get("/json", json_handler)
5055
app.router.add_get("/json2", json2_handler)
5156
app.router.add_get("/binary", binary_handler)
5257
app.router.add_get("/error", error_handler)
58+
app.router.add_get("/malformed", malformed_json_handler)
5359
return app
5460

5561
async def test_fetch_json(aiohttp_client):
@@ -169,3 +175,21 @@ def test_get_all_bytes(persistent_server) -> None:
169175

170176
assert results[0] == b"binary_data"
171177
assert results[1] == b"binary_data"
178+
179+
async def test_fetch_invalid_json(aiohttp_client):
180+
"""Verifies that malformed JSON returns None rather than raising an error."""
181+
mock_client = await aiohttp_client(create_web_application())
182+
url = URL(f"http://{mock_client.host}:{mock_client.port}/malformed")
183+
184+
async with AsyncWebClient() as client:
185+
result = await client.fetch(url)
186+
assert result is None
187+
188+
async def test_fetch_status_404_returns_none(aiohttp_client):
189+
"""Verifies that 4xx errors (non-retriable) return None immediately."""
190+
mock_client = await aiohttp_client(create_web_application())
191+
url = URL(f"http://{mock_client.host}:{mock_client.port}/missing_path")
192+
193+
async with AsyncWebClient() as client:
194+
result = await client.fetch(url)
195+
assert result is None

0 commit comments

Comments
 (0)