Skip to content

Commit a5318c8

Browse files
Merge pull request #253 from askui/feat/prompt_caching
Refactors Prompt Caching
2 parents 9266293 + aff286d commit a5318c8

10 files changed

Lines changed: 2721 additions & 2425 deletions

File tree

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,3 +172,5 @@ reports/
172172
.askui_cache/*
173173

174174
bom.json
175+
176+
*playground*

pdm.lock

Lines changed: 1933 additions & 2280 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,11 @@ authors = [
66
]
77
dependencies = [
88
"askui-agent-os>=26.1.1",
9-
"anthropic>=0.72.0",
9+
"anthropic>=0.86.0",
1010
"fastapi>=0.115.12",
1111
"fastmcp>=2.3.0",
1212
"gradio-client>=1.4.3",
13-
"grpcio>=1.73.1",
13+
"grpcio>=1.73.1,<1.80.0",
1414
"httpx>=0.28.1",
1515
"Jinja2>=3.1.4",
1616
"openai>=1.61.1",

src/askui/callbacks/usage_tracking_callback.py

Lines changed: 217 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@
55
from typing import TYPE_CHECKING
66

77
from opentelemetry import trace
8-
from pydantic import BaseModel
9-
from typing_extensions import override
8+
from pydantic import BaseModel, Field
9+
from typing_extensions import Self, override
1010

1111
from askui.callbacks.conversation_callback import ConversationCallback
1212
from askui.reporting import NULL_REPORTER
@@ -18,6 +18,8 @@
1818
from askui.speaker.speaker import SpeakerResult
1919
from askui.utils.model_pricing import ModelPricing
2020

21+
_USD_CURRENCY = "USD"
22+
2123

2224
class UsageSummary(BaseModel):
2325
"""Accumulated token usage and optional cost breakdown for a conversation.
@@ -27,9 +29,13 @@ class UsageSummary(BaseModel):
2729
output_tokens (int | None): Total output tokens generated.
2830
cache_creation_input_tokens (int | None): Tokens used for cache creation.
2931
cache_read_input_tokens (int | None): Tokens read from cache.
30-
input_cost (float | None): Computed input cost in `currency`.
31-
output_cost (float | None): Computed output cost in `currency`.
32-
total_cost (float | None): Sum of `input_cost` and `output_cost`.
32+
input_token_cost (float | None): Computed cost for input tokens in `currency`.
33+
output_token_cost (float | None): Computed cost for output tokens in `currency`.
34+
cache_write_token_cost (float | None): Computed cost for cache write tokens in
35+
`currency`.
36+
cache_read_token_cost (float | None): Computed cost for cache read tokens in
37+
`currency`.
38+
total_cost (float | None): Sum of all computed cost values.
3339
currency (str | None): ISO 4217 currency code (e.g. ``"USD"``).
3440
input_cost_per_million_tokens (float | None): Rate used to compute `input_token_cost`.
3541
output_cost_per_million_tokens (float | None): Rate used to compute `output_token_cost`.
@@ -39,12 +45,138 @@ class UsageSummary(BaseModel):
3945
output_tokens: int | None = None
4046
cache_creation_input_tokens: int | None = None
4147
cache_read_input_tokens: int | None = None
42-
input_cost: float | None = None
43-
output_cost: float | None = None
48+
input_token_cost: float | None = None
49+
output_token_cost: float | None = None
50+
cache_write_token_cost: float | None = None
51+
cache_read_token_cost: float | None = None
4452
total_cost: float | None = None
4553
currency: str | None = None
4654
input_cost_per_million_tokens: float | None = None
4755
output_cost_per_million_tokens: float | None = None
56+
cache_write_cost_per_million_tokens: float | None = None
57+
cache_read_cost_per_million_tokens: float | None = None
58+
per_conversation_summaries: list[ConversationUsageSummary] | None = None
59+
60+
@classmethod
61+
def create(cls, pricing: ModelPricing | None = None) -> "UsageSummary":
62+
"""Create a summary configured with optional model pricing."""
63+
if pricing is None:
64+
return cls()
65+
return cls(
66+
input_cost_per_million_tokens=pricing.input_cost_per_million_tokens,
67+
output_cost_per_million_tokens=pricing.output_cost_per_million_tokens,
68+
cache_write_cost_per_million_tokens=(
69+
pricing.cache_write_cost_per_million_tokens
70+
),
71+
cache_read_cost_per_million_tokens=(
72+
pricing.cache_read_cost_per_million_tokens
73+
),
74+
)
75+
76+
@classmethod
77+
def create_from(cls, summary: "UsageSummary") -> "UsageSummary":
78+
"""Create a new summary that reuses pricing fields from `summary`."""
79+
return cls(
80+
input_cost_per_million_tokens=summary.input_cost_per_million_tokens,
81+
output_cost_per_million_tokens=summary.output_cost_per_million_tokens,
82+
cache_write_cost_per_million_tokens=(
83+
summary.cache_write_cost_per_million_tokens
84+
),
85+
cache_read_cost_per_million_tokens=(
86+
summary.cache_read_cost_per_million_tokens
87+
),
88+
)
89+
90+
def add_usage(self, usage: UsageParam) -> None:
91+
"""Add token counts from `usage`."""
92+
self.input_tokens = (self.input_tokens or 0) + (usage.input_tokens or 0)
93+
self.output_tokens = (self.output_tokens or 0) + (usage.output_tokens or 0)
94+
self.cache_creation_input_tokens = (self.cache_creation_input_tokens or 0) + (
95+
usage.cache_creation_input_tokens or 0
96+
)
97+
self.cache_read_input_tokens = (self.cache_read_input_tokens or 0) + (
98+
usage.cache_read_input_tokens or 0
99+
)
100+
101+
def generate(self) -> Self:
102+
"""Compute and populate cost fields from current token and pricing fields."""
103+
if not self._has_pricing():
104+
self._clear_cost_fields()
105+
return self
106+
107+
input_tokens = self.input_tokens or 0
108+
output_tokens = self.output_tokens or 0
109+
cache_write_tokens = self.cache_creation_input_tokens or 0
110+
cache_read_tokens = self.cache_read_input_tokens or 0
111+
112+
assert self.input_cost_per_million_tokens is not None
113+
assert self.output_cost_per_million_tokens is not None
114+
assert self.cache_write_cost_per_million_tokens is not None
115+
assert self.cache_read_cost_per_million_tokens is not None
116+
117+
self.input_token_cost = self._calculate_cost(
118+
input_tokens, self.input_cost_per_million_tokens
119+
)
120+
self.output_token_cost = self._calculate_cost(
121+
output_tokens, self.output_cost_per_million_tokens
122+
)
123+
self.cache_write_token_cost = self._calculate_cost(
124+
cache_write_tokens, self.cache_write_cost_per_million_tokens
125+
)
126+
self.cache_read_token_cost = self._calculate_cost(
127+
cache_read_tokens, self.cache_read_cost_per_million_tokens
128+
)
129+
self.total_cost = (
130+
(self.input_token_cost or 0.0)
131+
+ (self.output_token_cost or 0.0)
132+
+ (self.cache_write_token_cost or 0.0)
133+
+ (self.cache_read_token_cost or 0.0)
134+
)
135+
self.currency = _USD_CURRENCY
136+
return self
137+
138+
def token_attributes(self) -> dict[str, int]:
139+
"""Return token fields for telemetry attributes."""
140+
return {
141+
"input_tokens": self.input_tokens or 0,
142+
"output_tokens": self.output_tokens or 0,
143+
"cache_creation_input_tokens": self.cache_creation_input_tokens or 0,
144+
"cache_read_input_tokens": self.cache_read_input_tokens or 0,
145+
}
146+
147+
def _has_pricing(self) -> bool:
148+
return (
149+
self.input_cost_per_million_tokens is not None
150+
and self.output_cost_per_million_tokens is not None
151+
and self.cache_write_cost_per_million_tokens is not None
152+
and self.cache_read_cost_per_million_tokens is not None
153+
)
154+
155+
def _clear_cost_fields(self) -> None:
156+
self.input_token_cost = None
157+
self.output_token_cost = None
158+
self.cache_write_token_cost = None
159+
self.cache_read_token_cost = None
160+
self.total_cost = None
161+
self.currency = None
162+
163+
@staticmethod
164+
def _calculate_cost(tokens: int, rate_per_million_tokens: float) -> float:
165+
return rate_per_million_tokens * tokens / 1e6
166+
167+
168+
class StepUsageSummary(UsageSummary):
169+
"""Usage summary for a single step."""
170+
171+
step_index: int
172+
173+
174+
class ConversationUsageSummary(UsageSummary):
175+
"""Usage summary for one conversation including per-step breakdown."""
176+
177+
conversation_index: int
178+
conversation_id: str
179+
step_summaries: list[StepUsageSummary] = Field(default_factory=list)
48180

49181

50182
class UsageTrackingCallback(ConversationCallback):
@@ -62,12 +194,17 @@ def __init__(
62194
pricing: ModelPricing | None = None,
63195
) -> None:
64196
self._reporter = reporter
65-
self._pricing = pricing
66-
self._summary = UsageSummary()
197+
self._summary: UsageSummary = UsageSummary.create(pricing)
198+
self._per_conversation_usage: UsageSummary = UsageSummary.create(pricing)
199+
self._per_conversation_summaries: list[ConversationUsageSummary] = []
200+
self._per_step_summaries: list[StepUsageSummary] = []
201+
self._conversation_index: int = 0
67202

68203
@override
69204
def on_conversation_start(self, conversation: Conversation) -> None:
70-
self._summary = UsageSummary()
205+
self._per_conversation_usage = UsageSummary.create_from(self._summary)
206+
self._per_step_summaries = []
207+
self._conversation_index += 1
71208

72209
@override
73210
def on_step_end(
@@ -76,71 +213,85 @@ def on_step_end(
76213
step_index: int,
77214
result: SpeakerResult,
78215
) -> None:
79-
if result.usage:
80-
self._accumulate(result.usage)
216+
step_usage: UsageParam | None = result.usage
217+
if step_usage is None:
218+
return
219+
220+
step_summary = self._create_step_summary(
221+
step_index=step_index, usage=step_usage
222+
)
223+
self._per_step_summaries.append(step_summary)
224+
self._per_conversation_usage.add_usage(step_usage)
225+
self._summary.add_usage(step_usage)
226+
227+
current_span = trace.get_current_span()
228+
current_span.set_attributes(step_summary.token_attributes())
81229

82230
@override
83231
def on_conversation_end(self, conversation: Conversation) -> None:
84-
self._reporter.add_usage_summary(self._summary)
232+
generated_steps: list[StepUsageSummary] = [
233+
step_summary.generate() for step_summary in self._per_step_summaries
234+
]
235+
conversation_summary = self._create_conversation_summary(
236+
conversation=conversation,
237+
generated_step_summaries=generated_steps,
238+
)
239+
self._per_conversation_summaries.append(conversation_summary)
240+
self._summary.per_conversation_summaries = list(
241+
self._per_conversation_summaries
242+
)
243+
self._reporter.add_usage_summary(self._summary.generate().model_copy(deep=True))
85244

86245
@property
87246
def accumulated_usage(self) -> UsageSummary:
88247
"""Current accumulated usage statistics."""
89248
return self._summary
90249

91-
def _accumulate(self, step_usage: UsageParam) -> None:
92-
# Add step tokens to running totals (None counts as 0)
93-
self._summary.input_tokens = (self._summary.input_tokens or 0) + (
94-
step_usage.input_tokens or 0
95-
)
96-
self._summary.output_tokens = (self._summary.output_tokens or 0) + (
97-
step_usage.output_tokens or 0
98-
)
99-
self._summary.cache_creation_input_tokens = (
100-
self._summary.cache_creation_input_tokens or 0
101-
) + (step_usage.cache_creation_input_tokens or 0)
102-
self._summary.cache_read_input_tokens = (
103-
self._summary.cache_read_input_tokens or 0
104-
) + (step_usage.cache_read_input_tokens or 0)
105-
106-
# Record per-step token counts on the current OTel span
107-
current_span = trace.get_current_span()
108-
current_span.set_attributes(
109-
{
110-
"input_tokens": step_usage.input_tokens or 0,
111-
"output_tokens": step_usage.output_tokens or 0,
112-
"cache_creation_input_tokens": (
113-
step_usage.cache_creation_input_tokens or 0
114-
),
115-
"cache_read_input_tokens": (step_usage.cache_read_input_tokens or 0),
116-
}
250+
def _create_step_summary(
251+
self, step_index: int, usage: UsageParam
252+
) -> StepUsageSummary:
253+
return StepUsageSummary(
254+
step_index=step_index,
255+
input_tokens=usage.input_tokens or 0,
256+
output_tokens=usage.output_tokens or 0,
257+
cache_creation_input_tokens=usage.cache_creation_input_tokens or 0,
258+
cache_read_input_tokens=usage.cache_read_input_tokens or 0,
259+
input_cost_per_million_tokens=self._summary.input_cost_per_million_tokens,
260+
output_cost_per_million_tokens=self._summary.output_cost_per_million_tokens,
261+
cache_write_cost_per_million_tokens=(
262+
self._summary.cache_write_cost_per_million_tokens
263+
),
264+
cache_read_cost_per_million_tokens=(
265+
self._summary.cache_read_cost_per_million_tokens
266+
),
117267
)
118268

119-
# Update costs from updated totals if pricing values are set
120-
if not (
121-
self._pricing
122-
and self._pricing.input_cost_per_million_tokens
123-
and self._pricing.output_cost_per_million_tokens
124-
):
125-
return
126-
127-
input_cost = (
128-
self._summary.input_tokens
129-
* self._pricing.input_cost_per_million_tokens
130-
/ 1e6
131-
)
132-
output_cost = (
133-
self._summary.output_tokens
134-
* self._pricing.output_cost_per_million_tokens
135-
/ 1e6
136-
)
137-
self._summary.input_cost = input_cost
138-
self._summary.output_cost = output_cost
139-
self._summary.total_cost = input_cost + output_cost
140-
self._summary.currency = self._pricing.currency
141-
self._summary.input_cost_per_million_tokens = (
142-
self._pricing.input_cost_per_million_tokens
143-
)
144-
self._summary.output_cost_per_million_tokens = (
145-
self._pricing.output_cost_per_million_tokens
269+
def _create_conversation_summary(
270+
self,
271+
conversation: Conversation,
272+
generated_step_summaries: list[StepUsageSummary],
273+
) -> ConversationUsageSummary:
274+
conversation_summary = ConversationUsageSummary(
275+
conversation_index=self._conversation_index,
276+
conversation_id=conversation.conversation_id,
277+
step_summaries=generated_step_summaries,
278+
input_tokens=self._per_conversation_usage.input_tokens,
279+
output_tokens=self._per_conversation_usage.output_tokens,
280+
cache_creation_input_tokens=(
281+
self._per_conversation_usage.cache_creation_input_tokens
282+
),
283+
cache_read_input_tokens=self._per_conversation_usage.cache_read_input_tokens,
284+
input_cost_per_million_tokens=(
285+
self._per_conversation_usage.input_cost_per_million_tokens
286+
),
287+
output_cost_per_million_tokens=(
288+
self._per_conversation_usage.output_cost_per_million_tokens
289+
),
290+
cache_write_cost_per_million_tokens=(
291+
self._per_conversation_usage.cache_write_cost_per_million_tokens
292+
),
293+
cache_read_cost_per_million_tokens=(
294+
self._per_conversation_usage.cache_read_cost_per_million_tokens
295+
),
146296
)
297+
return conversation_summary.generate()

src/askui/model_providers/anthropic_vlm_provider.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,10 +40,14 @@ class AnthropicVlmProvider(VlmProvider):
4040
client (Anthropic | None, optional): Pre-configured Anthropic client.
4141
If provided, other connection parameters are ignored.
4242
input_cost_per_million_tokens (float | None, optional): Override
43-
cost in USD per 1M input tokens. Both cost params must be set
44-
to override the built-in defaults.
43+
cost in USD per 1M input tokens. All override pricing params must be set to
44+
override the built-in defaults.
4545
output_cost_per_million_tokens (float | None, optional): Override
4646
cost in USD per 1M output tokens.
47+
cache_write_cost_per_million_tokens (float | None, optional): Override
48+
cost in USD per 1M cache write input tokens.
49+
cache_read_cost_per_million_tokens (float | None, optional): Override
50+
cost in USD per 1M cache read input tokens.
4751
4852
Example:
4953
```python
@@ -68,6 +72,8 @@ def __init__(
6872
client: Anthropic | None = None,
6973
input_cost_per_million_tokens: float | None = None,
7074
output_cost_per_million_tokens: float | None = None,
75+
cache_write_cost_per_million_tokens: float | None = None,
76+
cache_read_cost_per_million_tokens: float | None = None,
7177
) -> None:
7278
self._model_id_value = (
7379
model_id or os.environ.get("VLM_PROVIDER_MODEL_ID") or _DEFAULT_MODEL_ID
@@ -84,6 +90,8 @@ def __init__(
8490
self._model_id_value,
8591
input_cost_per_million_tokens=input_cost_per_million_tokens,
8692
output_cost_per_million_tokens=output_cost_per_million_tokens,
93+
cache_write_cost_per_million_tokens=cache_write_cost_per_million_tokens,
94+
cache_read_cost_per_million_tokens=cache_read_cost_per_million_tokens,
8795
)
8896

8997
@property

0 commit comments

Comments
 (0)