Skip to content

Commit 5d83604

Browse files
VegetarianOrcQuinn-With-Two-Nsclaude
authored
Populate Request Deadline in Nexus Operation Contexts (#1277)
* Nexus error ser. * Update to nexus-rpc 1.4.0 * Populate Nexus request deadline into operation contexts when present on the task received from Core. * Add test to confirm request deadline is present in cancel operation contexts * Update request deadline tests for workflow_run_operation to reflect how users will invoke rather than using a private api * refactor request deadline out of if branches in nexus worker * Fix bad rebase in pyproject.toml * gen protos * fix some rebase mistakes. Update test to not use the removed http client * Add request deadline tests for StartOperationContext and CancelOperationContext Adds tests verifying request_deadline is accessible and timezone-aware (UTC) in both StartOperationContext and CancelOperationContext. Also strengthens the existing WorkflowRunOperationContext test to validate the deadline is a proper UTC datetime rather than just truthy. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * Move request_deadline assertion from handler to test body in workflow_run_operation test Captures the deadline into a list on the handler instance and asserts in the test body, consistent with the start/cancel operation deadline tests. This gives clear pytest failure messages instead of opaque handler errors. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * Fold start and cancel tests into a single test. Replace sleep with an asyncio.Event to avoid potential flakes --------- Co-authored-by: Quinn Klassen <klassenq@gmail.com> Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
1 parent af89be6 commit 5d83604

3 files changed

Lines changed: 210 additions & 11 deletions

File tree

temporalio/worker/_nexus.py

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import threading
99
from collections.abc import Callable, Mapping, Sequence
1010
from dataclasses import dataclass
11+
from datetime import datetime, timezone
1112
from functools import reduce
1213
from typing import (
1314
Any,
@@ -119,14 +120,22 @@ async def raise_from_exception_queue() -> NoReturn:
119120

120121
if nexus_task.HasField("task"):
121122
task = nexus_task.task
123+
request_deadline = (
124+
nexus_task.request_deadline.ToDatetime().replace(
125+
tzinfo=timezone.utc
126+
)
127+
if nexus_task.HasField("request_deadline")
128+
else None
129+
)
122130
if task.request.HasField("start_operation"):
123131
task_cancellation = _NexusTaskCancellation()
124132
start_op_task = asyncio.create_task(
125133
self._handle_start_operation_task(
126-
task.task_token,
127-
task.request.start_operation,
128-
dict(task.request.header),
129-
task_cancellation,
134+
task_token=task.task_token,
135+
start_request=task.request.start_operation,
136+
headers=dict(task.request.header),
137+
task_cancellation=task_cancellation,
138+
request_deadline=request_deadline,
130139
)
131140
)
132141
self._running_tasks[task.task_token] = _RunningNexusTask(
@@ -136,10 +145,11 @@ async def raise_from_exception_queue() -> NoReturn:
136145
task_cancellation = _NexusTaskCancellation()
137146
cancel_op_task = asyncio.create_task(
138147
self._handle_cancel_operation_task(
139-
task.task_token,
140-
task.request.cancel_operation,
141-
dict(task.request.header),
142-
task_cancellation,
148+
task_token=task.task_token,
149+
request=task.request.cancel_operation,
150+
headers=dict(task.request.header),
151+
task_cancellation=task_cancellation,
152+
request_deadline=request_deadline,
143153
)
144154
)
145155
self._running_tasks[task.task_token] = _RunningNexusTask(
@@ -209,6 +219,7 @@ async def _handle_cancel_operation_task(
209219
request: temporalio.api.nexus.v1.CancelOperationRequest,
210220
headers: Mapping[str, str],
211221
task_cancellation: nexusrpc.handler.OperationTaskCancellation,
222+
request_deadline: datetime | None,
212223
) -> None:
213224
"""Handle a cancel operation task.
214225
@@ -226,6 +237,7 @@ async def _handle_cancel_operation_task(
226237
operation=request.operation,
227238
headers=headers,
228239
task_cancellation=task_cancellation,
240+
request_deadline=request_deadline,
229241
)
230242
temporalio.nexus._operation_context._TemporalCancelOperationContext(
231243
info=lambda: Info(task_queue=self._task_queue),
@@ -276,6 +288,7 @@ async def _handle_start_operation_task(
276288
start_request: temporalio.api.nexus.v1.StartOperationRequest,
277289
headers: Mapping[str, str],
278290
task_cancellation: nexusrpc.handler.OperationTaskCancellation,
291+
request_deadline: datetime | None,
279292
) -> None:
280293
"""Handle a start operation task.
281294
@@ -285,7 +298,7 @@ async def _handle_start_operation_task(
285298
try:
286299
try:
287300
start_response = await self._start_operation(
288-
start_request, headers, task_cancellation
301+
start_request, headers, task_cancellation, request_deadline
289302
)
290303
except asyncio.CancelledError:
291304
completion = temporalio.bridge.proto.nexus.NexusTaskCompletion(
@@ -328,6 +341,7 @@ async def _start_operation(
328341
start_request: temporalio.api.nexus.v1.StartOperationRequest,
329342
headers: Mapping[str, str],
330343
cancellation: nexusrpc.handler.OperationTaskCancellation,
344+
request_deadline: datetime | None,
331345
) -> temporalio.api.nexus.v1.StartOperationResponse:
332346
"""Invoke the Nexus handler's start_operation method and construct the StartOperationResponse.
333347
@@ -352,6 +366,7 @@ async def _start_operation(
352366
],
353367
callback_headers=dict(start_request.callback_header),
354368
task_cancellation=cancellation,
369+
request_deadline=request_deadline,
355370
)
356371
temporalio.nexus._operation_context._TemporalStartOperationContext(
357372
nexus_context=ctx,

tests/nexus/test_workflow_caller.py

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import uuid
88
from collections.abc import Awaitable, Callable
99
from dataclasses import dataclass
10+
from datetime import datetime, timezone
1011
from enum import IntEnum
1112
from typing import Any
1213
from urllib.request import urlopen
@@ -173,6 +174,11 @@ class HeaderTestService:
173174
cancellable_operation: nexusrpc.Operation[None, str]
174175

175176

177+
@nexusrpc.service
178+
class RequestDeadlineService:
179+
cancellable_op: nexusrpc.Operation[None, str]
180+
181+
176182
# -----------------------------------------------------------------------------
177183
# Service implementation
178184
#
@@ -335,6 +341,46 @@ def cancellable_operation(self) -> OperationHandler[None, str]:
335341
return CancellableOperationHandler(self.cancel_headers_received)
336342

337343

344+
class CancellableDeadlineOperationHandler(OperationHandler[None, str]):
345+
"""Operation handler that captures request_deadline from start and cancel contexts."""
346+
347+
def __init__(
348+
self,
349+
start_deadlines_received: list[datetime | None],
350+
cancel_deadlines_received: list[datetime | None],
351+
cancel_received: asyncio.Event,
352+
) -> None:
353+
self._start_deadlines_received = start_deadlines_received
354+
self._cancel_deadlines_received = cancel_deadlines_received
355+
self._cancel_received = cancel_received
356+
357+
async def start(
358+
self, ctx: StartOperationContext, input: None
359+
) -> StartOperationResultAsync:
360+
self._start_deadlines_received.append(ctx.request_deadline)
361+
return StartOperationResultAsync("test-token")
362+
363+
async def cancel(self, ctx: CancelOperationContext, token: str) -> None:
364+
self._cancel_deadlines_received.append(ctx.request_deadline)
365+
self._cancel_received.set()
366+
367+
368+
@service_handler(service=RequestDeadlineService)
369+
class RequestDeadlineServiceImpl:
370+
def __init__(self) -> None:
371+
self.start_deadlines_received: list[datetime | None] = []
372+
self.cancel_deadlines_received: list[datetime | None] = []
373+
self.cancel_received = asyncio.Event()
374+
375+
@operation_handler
376+
def cancellable_op(self) -> OperationHandler[None, str]:
377+
return CancellableDeadlineOperationHandler(
378+
self.start_deadlines_received,
379+
self.cancel_deadlines_received,
380+
self.cancel_received,
381+
)
382+
383+
338384
# -----------------------------------------------------------------------------
339385
# Caller workflow
340386
#
@@ -570,6 +616,30 @@ async def run(self, input: CancelHeaderTestCallerWfInput) -> None:
570616
await asyncio.sleep(0.1)
571617

572618

619+
@workflow.defn
620+
class CancelDeadlineCallerWorkflow:
621+
"""Workflow that starts a cancellable operation and then cancels it, for deadline testing."""
622+
623+
@workflow.run
624+
async def run(self, task_queue: str) -> None:
625+
nexus_client = workflow.create_nexus_client(
626+
service=RequestDeadlineService,
627+
endpoint=make_nexus_endpoint_name(task_queue),
628+
)
629+
op_handle = await nexus_client.start_operation(
630+
RequestDeadlineService.cancellable_op,
631+
None,
632+
cancellation_type=workflow.NexusOperationCancellationType.WAIT_REQUESTED,
633+
)
634+
# Request cancellation - this sends a cancel operation to the handler
635+
op_handle.cancel()
636+
637+
try:
638+
await op_handle
639+
except NexusOperationError:
640+
pass
641+
642+
573643
@workflow.defn
574644
class WorkflowRunHeaderTestCallerWorkflow:
575645
"""Workflow that calls a workflow_run_operation and verifies headers."""
@@ -2172,3 +2242,47 @@ async def test_task_executor_operation_cancel_method(
21722242
# Verify the workflow completed successfully
21732243
result = await caller_wf_handle.result()
21742244
assert result == "cancelled_successfully"
2245+
2246+
2247+
async def test_request_deadline_is_accessible_in_operation(
2248+
client: Client,
2249+
env: WorkflowEnvironment,
2250+
):
2251+
"""Test that request_deadline is accessible in StartOperationContext."""
2252+
if env.supports_time_skipping:
2253+
pytest.skip("Nexus tests don't work with time-skipping server")
2254+
2255+
task_queue = str(uuid.uuid4())
2256+
service_handler = RequestDeadlineServiceImpl()
2257+
2258+
async with Worker(
2259+
client,
2260+
nexus_service_handlers=[service_handler],
2261+
workflows=[CancelDeadlineCallerWorkflow],
2262+
task_queue=task_queue,
2263+
):
2264+
endpoint_name = make_nexus_endpoint_name(task_queue)
2265+
await env.create_nexus_endpoint(endpoint_name, task_queue)
2266+
2267+
await client.execute_workflow(
2268+
CancelDeadlineCallerWorkflow.run,
2269+
task_queue,
2270+
id=str(uuid.uuid4()),
2271+
task_queue=task_queue,
2272+
)
2273+
2274+
assert len(service_handler.start_deadlines_received) == 1
2275+
deadline = service_handler.start_deadlines_received[0]
2276+
assert (
2277+
deadline is not None
2278+
), "request_deadline should be set in StartOperationContext"
2279+
assert deadline.tzinfo is timezone.utc, "request_deadline should be in utc"
2280+
2281+
await asyncio.wait_for(service_handler.cancel_received.wait(), 1)
2282+
2283+
assert len(service_handler.cancel_deadlines_received) == 1
2284+
deadline = service_handler.cancel_deadlines_received[0]
2285+
assert (
2286+
deadline is not None
2287+
), "request_deadline should be set in CancelOperationContext"
2288+
assert deadline.tzinfo is timezone.utc, "request_deadline should be in utc"

tests/nexus/test_workflow_run_operation.py

Lines changed: 72 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import uuid
22
from dataclasses import dataclass
3+
from datetime import datetime, timezone
34
from typing import Any
45

56
import nexusrpc
@@ -13,9 +14,9 @@
1314
)
1415
from nexusrpc.handler._decorators import operation_handler
1516

16-
from temporalio import workflow
17+
from temporalio import nexus, workflow
1718
from temporalio.client import Client
18-
from temporalio.nexus import WorkflowRunOperationContext
19+
from temporalio.nexus import WorkflowRunOperationContext, workflow_run_operation
1920
from temporalio.nexus._operation_handlers import WorkflowRunOperationHandler
2021
from temporalio.testing import WorkflowEnvironment
2122
from temporalio.worker import Worker
@@ -59,6 +60,42 @@ def op(self) -> OperationHandler[Input, str]:
5960
return MyOperation()
6061

6162

63+
@service
64+
class RequestDeadlineService:
65+
op: Operation[Input, str]
66+
67+
68+
@service_handler(service=RequestDeadlineService)
69+
class RequestDeadlineHandler:
70+
def __init__(self) -> None:
71+
self.start_deadlines_received: list[datetime | None] = []
72+
73+
@workflow_run_operation
74+
async def op(
75+
self, ctx: WorkflowRunOperationContext, input: Input
76+
) -> nexus.WorkflowHandle[str]:
77+
self.start_deadlines_received.append(ctx.request_deadline)
78+
return await ctx.start_workflow(
79+
EchoWorkflow.run,
80+
input.value,
81+
id=str(uuid.uuid4()),
82+
)
83+
84+
85+
@workflow.defn
86+
class RequestDeadlineWorkflow:
87+
@workflow.run
88+
async def run(self, input: Input, task_queue: str) -> str:
89+
client = workflow.create_nexus_client(
90+
service=RequestDeadlineService,
91+
endpoint=make_nexus_endpoint_name(task_queue),
92+
)
93+
return await client.execute_operation(
94+
RequestDeadlineService.op,
95+
input,
96+
)
97+
98+
6299
@service
63100
class Service:
64101
op: Operation[Input, str]
@@ -116,3 +153,36 @@ async def test_workflow_run_operation(
116153
task_queue=task_queue,
117154
)
118155
assert result == "test"
156+
157+
158+
async def test_request_deadline_is_accessible_in_workflow_run_operation(
159+
client: Client,
160+
env: WorkflowEnvironment,
161+
):
162+
"""Test that request_deadline is accessible in WorkflowRunOperationContext."""
163+
if env.supports_time_skipping:
164+
pytest.skip("Nexus tests don't work with time-skipping server")
165+
166+
task_queue = str(uuid.uuid4())
167+
endpoint_name = make_nexus_endpoint_name(task_queue)
168+
await env.create_nexus_endpoint(endpoint_name, task_queue)
169+
service_handler = RequestDeadlineHandler()
170+
async with Worker(
171+
env.client,
172+
task_queue=task_queue,
173+
nexus_service_handlers=[service_handler],
174+
workflows=[RequestDeadlineWorkflow, EchoWorkflow],
175+
):
176+
await client.execute_workflow(
177+
RequestDeadlineWorkflow.run,
178+
args=[Input(value="test"), task_queue],
179+
task_queue=task_queue,
180+
id=str(uuid.uuid4()),
181+
)
182+
183+
assert len(service_handler.start_deadlines_received) == 1
184+
deadline = service_handler.start_deadlines_received[0]
185+
assert (
186+
deadline is not None
187+
), "request_deadline should be set in WorkflowRunOperationContext"
188+
assert deadline.tzinfo is timezone.utc, "request_deadline should be in utc"

0 commit comments

Comments
 (0)