Skip to content

Commit cb74d93

Browse files
bot suggestions: improve tests, empty list check
1 parent 55a0d60 commit cb74d93

3 files changed

Lines changed: 36 additions & 7 deletions

File tree

src/routers/openml/tasks.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -263,11 +263,17 @@ async def list_tasks( # noqa: PLR0913, PLR0912, C901, PLR0915
263263
clauses.append("AND d.`name` = :data_name")
264264
parameters["data_name"] = data_name
265265

266-
if task_id:
266+
if task_id is not None:
267+
if not task_id:
268+
msg = "No tasks match the search criteria."
269+
raise NoResultsError(msg)
267270
clauses.append("AND t.`task_id` IN :task_ids")
268271
parameters["task_ids"] = task_id
269272

270-
if data_id:
273+
if data_id is not None:
274+
if not data_id:
275+
msg = "No tasks match the search criteria."
276+
raise NoResultsError(msg)
271277
clauses.append("AND d.`did` IN :data_ids")
272278
parameters["data_ids"] = data_id
273279

@@ -318,9 +324,9 @@ async def list_tasks( # noqa: PLR0913, PLR0912, C901, PLR0915
318324
""", # noqa: S608
319325
)
320326

321-
if task_id:
327+
if task_id is not None:
322328
main_query = main_query.bindparams(bindparam("task_ids", expanding=True))
323-
if data_id:
329+
if data_id is not None:
324330
main_query = main_query.bindparams(bindparam("data_ids", expanding=True))
325331

326332
result = await expdb.execute(main_query, parameters=parameters)

tests/routers/openml/migration/tasks_migration_test.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -112,9 +112,26 @@ def _normalize_py_task(task: dict[str, Any]) -> dict[str, Any]:
112112
({"type": 1}, {"task_type_id": 1}), # by task type
113113
({"tag": "OpenML100"}, {"tag": "OpenML100"}), # by tag
114114
({"type": 1, "tag": "OpenML100"}, {"task_type_id": 1, "tag": "OpenML100"}), # combined
115+
({"data_name": "iris"}, {"data_name": "iris"}), # by dataset name
116+
({"data_id": 61}, {"data_id": [61]}), # by dataset id
117+
({"data_tag": "study_14"}, {"data_tag": "study_14"}), # by dataset tag
118+
({"number_instances": "150"}, {"number_instances": "150"}), # quality filter
119+
(
120+
{"data_id": 61, "number_instances": "150"},
121+
{"data_id": [61], "number_instances": "150"},
122+
),
115123
]
116124

117-
_FILTER_IDS = ["type", "tag", "type_and_tag"]
125+
_FILTER_IDS = [
126+
"type",
127+
"tag",
128+
"type_and_tag",
129+
"data_name",
130+
"data_id",
131+
"data_tag",
132+
"number_instances",
133+
"data_and_quality",
134+
]
118135

119136

120137
@pytest.mark.parametrize(
@@ -156,7 +173,10 @@ async def test_list_tasks_equal(
156173
assert php_response.status_code == HTTPStatus.OK
157174
assert py_response.status_code == HTTPStatus.OK
158175

159-
php_tasks: list[dict[str, Any]] = php_response.json()["tasks"]["task"]
176+
php_tasks_raw = php_response.json()["tasks"]["task"]
177+
php_tasks: list[dict[str, Any]] = (
178+
php_tasks_raw if isinstance(php_tasks_raw, list) else [php_tasks_raw]
179+
)
160180
py_tasks: list[dict[str, Any]] = [_normalize_py_task(t) for t in py_response.json()]
161181

162182
php_ids = {int(t["task_id"]) for t in php_tasks}

tests/routers/openml/task_list_test.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,10 @@ async def test_list_tasks_pagination_order_stable(py_api: httpx.AsyncClient) ->
114114
r2 = await py_api.post("/tasks/list", json={"pagination": {"limit": 5, "offset": 5}})
115115
ids1 = [t["task_id"] for t in r1.json()]
116116
ids2 = [t["task_id"] for t in r2.json()]
117-
assert max(ids1) < min(ids2)
117+
assert ids1 == sorted(ids1)
118+
assert ids2 == sorted(ids2)
119+
if ids1 and ids2:
120+
assert max(ids1) < min(ids2)
118121

119122

120123
async def test_list_tasks_number_instances_range(py_api: httpx.AsyncClient) -> None:

0 commit comments

Comments
 (0)