diff --git a/nemo_retriever/src/nemo_retriever/adapters/cli/sdk_workflow.py b/nemo_retriever/src/nemo_retriever/adapters/cli/sdk_workflow.py index 81b8c1a5a..753e98931 100644 --- a/nemo_retriever/src/nemo_retriever/adapters/cli/sdk_workflow.py +++ b/nemo_retriever/src/nemo_retriever/adapters/cli/sdk_workflow.py @@ -262,4 +262,6 @@ def query_documents( retriever_kwargs["rerank_kwargs"] = rerank_kwargs retriever = Retriever(**retriever_kwargs) - return retriever.query(query) + hits = retriever.query(query) + hits = [{"text": hit.get("text", ""), "source": hit.get("source", ""), "page_number": hit.get("page_number")} for hit in hits] + return hits diff --git a/nemo_retriever/tests/test_root_cli_workflow.py b/nemo_retriever/tests/test_root_cli_workflow.py index b55784397..d935a7dde 100644 --- a/nemo_retriever/tests/test_root_cli_workflow.py +++ b/nemo_retriever/tests/test_root_cli_workflow.py @@ -330,9 +330,13 @@ def fail_create_ingestor(**_kwargs: Any) -> Any: def test_root_query_passes_query_options_and_prints_json(monkeypatch) -> None: retriever_calls: list[dict[str, Any]] = [] query_calls: list[str] = [] - hits = [ - {"text": "passage", "page_number": 1, "_distance": 0.2}, - {"text": "other", "page_number": 2, "_distance": 0.4}, + raw_hits = [ + {"text": "passage", "source": "a.pdf", "page_number": 1, "_distance": 0.2}, + {"text": "other", "source": "b.pdf", "page_number": 2, "_distance": 0.4}, + ] + # query_documents exposes only text / source / page_number (no scores or extra keys). + public_hits = [ + {"text": h["text"], "source": h["source"], "page_number": h["page_number"]} for h in raw_hits ] class FakeRetriever: @@ -341,7 +345,7 @@ def __init__(self, **kwargs: Any) -> None: def query(self, query: str) -> list[dict[str, Any]]: query_calls.append(query) - return hits + return raw_hits monkeypatch.setattr(sdk_workflow, "Retriever", FakeRetriever)