Skip to content

Commit 0bef7a1

Browse files
Merge pull request #257 from askui/fix/truncation_strategy
fix(truncation strategy): add proper message history management through summarization
2 parents: 3c74720 + 50e4caa; commit: 0bef7a1

File tree

10 files changed

+1786
-326
lines changed

10 files changed

+1786
-326
lines changed

src/askui/agent_base.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -23,6 +23,7 @@
2323
LocateSettings,
2424
)
2525
from askui.models.shared.tools import Tool, ToolCollection
26+
from askui.models.shared.truncation_strategies import TruncationStrategy
2627
from askui.prompts.act_prompts import CACHE_USE_PROMPT, create_default_prompt
2728
from askui.telemetry.otel import OtelSettings, setup_opentelemetry_tracing
2829
from askui.tools.agent_os import AgentOs
@@ -59,6 +60,7 @@ def __init__(
5960
agent_os: AgentOs | AndroidAgentOs | None = None,
6061
settings: AgentSettings | None = None,
6162
callbacks: list[ConversationCallback] | None = None,
63+
truncation_strategy: TruncationStrategy | None = None,
6264
) -> None:
6365
load_dotenv()
6466
self._reporter: Reporter = reporter or CompositeReporter(reporters=None)
@@ -87,6 +89,7 @@ def __init__(
8789
image_qa_provider=self._image_qa_provider,
8890
detection_provider=self._detection_provider,
8991
reporter=self._reporter,
92+
truncation_strategy=truncation_strategy,
9093
callbacks=_callbacks,
9194
)
9295

src/askui/android_agent.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from askui.models.models import Point
1313
from askui.models.shared.settings import ActSettings, MessageSettings
1414
from askui.models.shared.tools import Tool
15+
from askui.models.shared.truncation_strategies import TruncationStrategy
1516
from askui.prompts.act_prompts import create_android_agent_prompt
1617
from askui.tools.android.agent_os import ANDROID_KEY
1718
from askui.tools.android.agent_os_facade import AndroidAgentOsFacade
@@ -64,7 +65,15 @@ class AndroidAgent(Agent):
6465
```
6566
"""
6667

67-
@telemetry.record_call(exclude={"reporters", "settings", "act_tools", "callbacks"})
68+
@telemetry.record_call(
69+
exclude={
70+
"reporters",
71+
"settings",
72+
"act_tools",
73+
"callbacks",
74+
"truncation_strategy",
75+
}
76+
)
6877
@validate_call(config=ConfigDict(arbitrary_types_allowed=True))
6978
def __init__(
7079
self,
@@ -74,6 +83,7 @@ def __init__(
7483
retry: Retry | None = None,
7584
act_tools: list[Tool] | None = None,
7685
callbacks: list[ConversationCallback] | None = None,
86+
truncation_strategy: TruncationStrategy | None = None,
7787
) -> None:
7888
reporter = CompositeReporter(reporters=reporters)
7989
self.os = PpadbAgentOs(device_identifier=device, reporter=reporter)
@@ -85,6 +95,7 @@ def __init__(
8595
agent_os=self.os,
8696
settings=settings,
8797
callbacks=callbacks,
98+
truncation_strategy=truncation_strategy,
8899
)
89100
self.act_tool_collection.add_agent_os(self.act_agent_os_facade)
90101
# Override default act settings with Android-specific settings

src/askui/callbacks/conversation_callback.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
from typing import TYPE_CHECKING
44

5+
from askui.models.shared.agent_message_param import UsageParam
6+
57
if TYPE_CHECKING:
68
from askui.models.shared.conversation import Conversation
79
from askui.speaker.speaker import SpeakerResult
@@ -123,3 +125,10 @@ def on_tool_execution_end(
123125
conversation: The conversation instance.
124126
tool_names: Names of tools that were executed.
125127
"""
128+
129+
def on_truncation_summarize(self, usage: UsageParam) -> None:
130+
"""Called when a truncation strategy summarizes message history.
131+
132+
Args:
133+
usage: Token usage from the summarization LLM call.
134+
"""

src/askui/callbacks/usage_tracking_callback.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,11 @@ def on_step_end(
227227
current_span = trace.get_current_span()
228228
current_span.set_attributes(step_summary.token_attributes())
229229

230+
@override
231+
def on_truncation_summarize(self, usage: UsageParam) -> None:
232+
self._per_conversation_usage.add_usage(usage)
233+
self._summary.add_usage(usage)
234+
230235
@override
231236
def on_conversation_end(self, conversation: Conversation) -> None:
232237
generated_steps: list[StepUsageSummary] = [

src/askui/computer_agent.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from askui.models.models import Point
1313
from askui.models.shared.settings import ActSettings, LocateSettings, MessageSettings
1414
from askui.models.shared.tools import Tool
15+
from askui.models.shared.truncation_strategies import TruncationStrategy
1516
from askui.prompts.act_prompts import (
1617
create_computer_agent_prompt,
1718
)
@@ -69,7 +70,14 @@ class ComputerAgent(Agent):
6970
"""
7071

7172
@telemetry.record_call(
72-
exclude={"reporters", "tools", "settings", "act_tools", "callbacks"}
73+
exclude={
74+
"reporters",
75+
"tools",
76+
"settings",
77+
"act_tools",
78+
"callbacks",
79+
"truncation_strategy",
80+
}
7381
)
7482
@validate_call(config=ConfigDict(arbitrary_types_allowed=True))
7583
def __init__(
@@ -81,6 +89,7 @@ def __init__(
8189
retry: Retry | None = None,
8290
act_tools: list[Tool] | None = None,
8391
callbacks: list[ConversationCallback] | None = None,
92+
truncation_strategy: TruncationStrategy | None = None,
8493
) -> None:
8594
reporter = CompositeReporter(reporters=reporters)
8695
self.tools = tools or AgentToolbox(
@@ -96,6 +105,7 @@ def __init__(
96105
agent_os=self.tools.os,
97106
settings=settings,
98107
callbacks=callbacks,
108+
truncation_strategy=truncation_strategy,
99109
)
100110
self.act_agent_os_facade: ComputerAgentOsFacade = ComputerAgentOsFacade(
101111
self.tools.os

src/askui/models/shared/conversation.py

Lines changed: 23 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,8 @@
1313
from askui.models.shared.settings import ActSettings
1414
from askui.models.shared.tools import ToolCollection
1515
from askui.models.shared.truncation_strategies import (
16-
SimpleTruncationStrategyFactory,
16+
SummarizingTruncationStrategy,
1717
TruncationStrategy,
18-
TruncationStrategyFactory,
1918
)
2019
from askui.reporting import NULL_REPORTER, Reporter
2120
from askui.speaker.speaker import SpeakerResult, Speakers
@@ -55,7 +54,7 @@ class Conversation:
5554
detection_provider: Detection provider (optional)
5655
reporter: Reporter for logging messages and actions
5756
cache_manager: Cache manager for recording/playback (optional)
58-
truncation_strategy_factory: Factory for creating truncation strategies
57+
truncation_strategy: Strategy for managing (truncating) the message history (optional; defaults to `SummarizingTruncationStrategy`)
5958
callbacks: List of callbacks for conversation lifecycle hooks (optional)
6059
"""
6160

@@ -67,7 +66,7 @@ def __init__(
6766
detection_provider: DetectionProvider | None = None,
6867
reporter: Reporter = NULL_REPORTER,
6968
cache_manager: "CacheManager | None" = None,
70-
truncation_strategy_factory: TruncationStrategyFactory | None = None,
69+
truncation_strategy: TruncationStrategy | None = None,
7170
callbacks: "list[ConversationCallback] | None" = None,
7271
) -> None:
7372
"""Initialize conversation with speakers and model providers."""
@@ -90,10 +89,6 @@ def __init__(
9089
# Infrastructure
9190
self._reporter = reporter
9291
self.cache_manager = cache_manager
93-
self._truncation_strategy_factory = (
94-
truncation_strategy_factory or SimpleTruncationStrategyFactory()
95-
)
96-
self._truncation_strategy: TruncationStrategy | None = None
9792
self._callbacks: "list[ConversationCallback]" = callbacks or []
9893

9994
# State for current execution (set in start())
@@ -102,6 +97,22 @@ def __init__(
10297
self._reporters: list[Reporter] = []
10398
self._step_index: int = 0
10499

100+
# Truncation strategy. Conversation-owned dependencies are
101+
# auto-injected so users can pass a custom strategy with only
102+
# strategy-specific config (e.g. n_messages_to_keep) without
103+
# needing access to vlm_provider/reporter/callbacks/conversation
104+
# at construction time. ``vlm_provider`` is only injected when
105+
# not pre-set, allowing callers to override the summarization
106+
# VLM (e.g. with a cheaper model).
107+
self._truncation_strategy: TruncationStrategy = (
108+
truncation_strategy or SummarizingTruncationStrategy()
109+
)
110+
if self._truncation_strategy.vlm_provider is None:
111+
self._truncation_strategy.vlm_provider = vlm_provider
112+
self._truncation_strategy.reporter = reporter
113+
self._truncation_strategy.callbacks = self._callbacks
114+
self._truncation_strategy.conversation = self
115+
105116
# Track if cache execution was used (to prevent recording during playback)
106117
self._executed_from_cache: bool = False
107118

@@ -180,6 +191,7 @@ def _setup_control_loop(
180191
reporters: list[Reporter] | None = None,
181192
) -> None:
182193
# Reset state
194+
self._truncation_strategy.reset(messages)
183195
self._executed_from_cache = False
184196
self.speakers.reset_state()
185197

@@ -191,16 +203,6 @@ def _setup_control_loop(
191203
# Auto-populate speaker descriptions and switch_speaker tool
192204
self._setup_speaker_handoff()
193205

194-
# Initialize truncation strategy
195-
self._truncation_strategy = (
196-
self._truncation_strategy_factory.create_truncation_strategy(
197-
tools=self.tools.to_params(),
198-
system=self.settings.messages.system,
199-
messages=messages,
200-
model=self.vlm_provider.model_id,
201-
)
202-
)
203-
204206
@tracer.start_as_current_span("_execute_control_loop")
205207
def _execute_control_loop(self) -> None:
206208
self._on_control_loop_start()
@@ -448,7 +450,9 @@ def get_messages(self) -> list[MessageParam]:
448450
Returns:
449451
List of messages in current conversation
450452
"""
451-
return self._truncation_strategy.messages if self._truncation_strategy else []
453+
return (
454+
self._truncation_strategy.full_messages if self._truncation_strategy else []
455+
)
452456

453457
def get_truncation_strategy(self) -> TruncationStrategy | None:
454458
"""Get current truncation strategy.

0 commit comments

Comments
 (0)