Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 79 additions & 0 deletions src/hooks/__tests__/useModel.test.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -901,6 +901,51 @@ describe('useModel', () => {
expect.objectContaining({ reason: 'user_reset' }),
);
});

it('cancels the in-flight backend generation when starting a new session', async () => {
// Regression test: calling reset() while a generation is active must issue a
// `cancel_generation` invoke, not only tear down the frontend state.
//
// Stall ask_model so the generation stays active while we reset.
// `resolveInvoke` captures the resolver of the stalled promise so the test
// can release it at the end; `!` tells the compiler it is assigned inside
// the mock before it is read.
let resolveInvoke!: () => void;
// One-shot mock: only the next `invoke` call is affected. When that call
// carries an `onEvent` argument it returns a promise that stays pending
// until `resolveInvoke` is called; otherwise it resolves to undefined.
// NOTE(review): the one-shot mock is consumed by the first invoke call
// whatever its arguments - this assumes ask() issues the `onEvent` call
// first; confirm against the hook.
invoke.mockImplementationOnce(
async (_cmd: string, args?: Record<string, unknown>) => {
if (args && 'onEvent' in args) {
return new Promise<void>((res) => {
resolveInvoke = res;
});
}
},
);

const { result } = renderHook(() => useModel(''));

// Start a turn without awaiting it (`void`), so the hook is still
// mid-generation when the assertions below run.
act(() => {
void result.current.ask('hello');
});
// Precondition: a generation is in flight before reset() is called.
expect(result.current.isGenerating).toBe(true);

// Reset inside an async act and yield one microtask tick so work scheduled
// by reset() gets a chance to run before asserting.
await act(async () => {
result.current.reset();
await Promise.resolve();
});

// A new session must stop the backend stream, not just the frontend
// view - otherwise the old generation holds the engine's single slot
// and the next turn queues behind it.
expect(invoke).toHaveBeenCalledWith('cancel_generation');

// Release the stalled invoke so the pending ask() promise can settle rather
// than dangling past the end of the test. `?.` tolerates the case where the
// mock never took the stalling branch.
act(() => {
resolveInvoke?.();
});
});

it('does not call cancel_generation when reset runs with no active generation', () => {
// Negative control: with no ask() issued beforehand, reset() must not send
// `cancel_generation` to the backend.
const { result } = renderHook(() => useModel(''));

// reset() is called on a freshly rendered hook; no generation was started.
act(() => {
result.current.reset();
});

// Only the `cancel_generation` call is ruled out; other invoke calls made by
// reset() are not constrained by this assertion.
expect(invoke).not.toHaveBeenCalledWith('cancel_generation');
});
});

// ─── onTurnComplete callback ─────────────────────────────────────────────────
Expand Down Expand Up @@ -1125,6 +1170,40 @@ describe('useModel', () => {
expect.objectContaining({ reason: 'history_load' }),
);
});

it('cancels the in-flight backend generation when loading another conversation', async () => {
// Regression test: calling loadMessages() while a generation is active must
// issue a `cancel_generation` invoke.
//
// Stall ask_model so the generation stays active while we load.
// `resolveInvoke` captures the resolver of the stalled promise so the test
// can release it at the end; `!` tells the compiler it is assigned inside
// the mock before it is read.
let resolveInvoke!: () => void;
// One-shot mock: only the next `invoke` call is affected. When that call
// carries an `onEvent` argument it returns a promise that stays pending
// until `resolveInvoke` is called; otherwise it resolves to undefined.
// NOTE(review): the one-shot mock is consumed by the first invoke call
// whatever its arguments - this assumes ask() issues the `onEvent` call
// first; confirm against the hook.
invoke.mockImplementationOnce(
async (_cmd: string, args?: Record<string, unknown>) => {
if (args && 'onEvent' in args) {
return new Promise<void>((res) => {
resolveInvoke = res;
});
}
},
);

const { result } = renderHook(() => useModel(''));

// Start a turn without awaiting it (`void`), so the hook is still
// mid-generation when loadMessages() runs.
act(() => {
void result.current.ask('original');
});
// Precondition: a generation is in flight before the conversation switch.
expect(result.current.isGenerating).toBe(true);

// Swap in a different conversation inside an async act and yield one
// microtask tick so work scheduled by loadMessages() gets a chance to run.
await act(async () => {
result.current.loadMessages([
{ id: 'l1', role: 'user', content: 'loaded' },
]);
await Promise.resolve();
});

// Loading a conversation while generating must send the backend cancel.
expect(invoke).toHaveBeenCalledWith('cancel_generation');

// Release the stalled invoke so the pending ask() promise can settle rather
// than dangling past the end of the test. `?.` tolerates the case where the
// mock never took the stalling branch.
act(() => {
resolveInvoke?.();
});
});
});

// ─── ThinkingToken handling ──────────────────────────────────────────────────
Expand Down
64 changes: 44 additions & 20 deletions src/hooks/useModel.ts
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,30 @@ export function useModel(
return true;
}, []);

/**
 * Asks the backend to stop the generation currently running and returns a
 * promise for that round-trip.
 *
 * The promise is parked in `pendingCancelRef` for as long as the
 * `cancel_generation` call is outstanding and the ref is cleared again once
 * the call settles, whether it succeeded or failed. Code that starts a new
 * turn can await the ref so a fresh request does not race the outgoing one
 * onto the engine's single decode slot (the `ask()` / `askSearch()` call
 * sites live outside this excerpt).
 *
 * Calling this while a cancel is already pending does not invoke the backend
 * again; the existing promise is handed back, so overlapping callers
 * (double cancel, cancel-then-reset) share one `cancel_generation` call.
 *
 * @returns The pending-cancel promise. It does not reject: backend failures
 *   are swallowed because the cancel is best-effort.
 */
const requestBackendCancel = useCallback((): Promise<void> => {
  const alreadyPending = pendingCancelRef.current;
  if (alreadyPending) {
    // A cancel is already on the wire; share it instead of sending another.
    return alreadyPending;
  }

  const sendCancel = async (): Promise<void> => {
    try {
      await invoke('cancel_generation');
    } catch {
      // Local hard-abort already reset the UI; backend best-effort only.
    } finally {
      // Round-trip finished: allow the next cancel request to hit the backend.
      pendingCancelRef.current = null;
    }
  };

  pendingCancelRef.current = sendCancel();
  return pendingCancelRef.current;
}, []);

/**
* Submits a message to the Ollama backend and starts the streaming response.
*
Expand Down Expand Up @@ -717,22 +741,8 @@ export function useModel(
}

abortActiveGeneration();

if (!pendingCancelRef.current) {
const cancelPromise = (async () => {
try {
await invoke('cancel_generation');
} catch {
// Local hard-abort already reset the UI; backend best-effort only.
} finally {
pendingCancelRef.current = null;
}
})();
pendingCancelRef.current = cancelPromise;
}

await pendingCancelRef.current;
}, [abortActiveGeneration, isGenerating]);
await requestBackendCancel();
}, [abortActiveGeneration, isGenerating, requestBackendCancel]);

/** Resets all conversation state for a fresh session.
*
Expand All @@ -746,7 +756,15 @@ export function useModel(
* user-visible reset.
*/
const reset = useCallback(() => {
abortActiveGeneration();
const hadActiveGeneration = abortActiveGeneration();
// Starting a fresh session must also stop any in-flight backend stream,
// not just the frontend view abort above. Otherwise the outgoing
// generation keeps the engine's single decode slot and the next turn
// queues behind it. Routed through the same `pendingCancelRef` plumbing
// `cancel()` uses so the next `ask()` awaits the cancel round-trip.
if (hadActiveGeneration) {
void requestBackendCancel();
}
setMessages([]);
const outgoingId = traceConversationIdRef.current;
if (outgoingId !== null && !isFirstTurnRef.current) {
Expand All @@ -758,7 +776,7 @@ export function useModel(
traceConversationIdRef.current = null;
isFirstTurnRef.current = true;
void invoke('reset_conversation');
}, [abortActiveGeneration]);
}, [abortActiveGeneration, requestBackendCancel]);

/** Replaces the current message list with a previously loaded set of messages.
*
Expand All @@ -770,7 +788,13 @@ export function useModel(
*/
const loadMessages = useCallback(
(msgs: Message[]) => {
abortActiveGeneration();
const hadActiveGeneration = abortActiveGeneration();
// Loading another conversation is a session boundary too: stop the
// in-flight backend stream so it does not hold the engine slot and
// stall the loaded conversation's next turn. Same plumbing as reset().
if (hadActiveGeneration) {
void requestBackendCancel();
}
const outgoingId = traceConversationIdRef.current;
if (outgoingId !== null && !isFirstTurnRef.current) {
void invoke('record_conversation_end', {
Expand All @@ -782,7 +806,7 @@ export function useModel(
isFirstTurnRef.current = true;
setMessages(msgs);
},
[abortActiveGeneration],
[abortActiveGeneration, requestBackendCancel],
);

/**
Expand Down