55from typing import TYPE_CHECKING
66
77from opentelemetry import trace
8- from pydantic import BaseModel
9- from typing_extensions import override
8+ from pydantic import BaseModel , Field
9+ from typing_extensions import Self , override
1010
1111from askui .callbacks .conversation_callback import ConversationCallback
1212from askui .reporting import NULL_REPORTER
1818 from askui .speaker .speaker import SpeakerResult
1919 from askui .utils .model_pricing import ModelPricing
2020
21+ _USD_CURRENCY = "USD"
22+
2123
2224class UsageSummary (BaseModel ):
2325 """Accumulated token usage and optional cost breakdown for a conversation.
@@ -27,9 +29,13 @@ class UsageSummary(BaseModel):
2729 output_tokens (int | None): Total output tokens generated.
2830 cache_creation_input_tokens (int | None): Tokens used for cache creation.
2931 cache_read_input_tokens (int | None): Tokens read from cache.
30- input_cost (float | None): Computed input cost in `currency`.
31- output_cost (float | None): Computed output cost in `currency`.
32- total_cost (float | None): Sum of `input_cost` and `output_cost`.
32+ input_token_cost (float | None): Computed cost for input tokens in `currency`.
33+ output_token_cost (float | None): Computed cost for output tokens in `currency`.
34+ cache_write_token_cost (float | None): Computed cost for cache write tokens in
35+ `currency`.
36+ cache_read_token_cost (float | None): Computed cost for cache read tokens in
37+ `currency`.
38+ total_cost (float | None): Sum of all computed cost values.
3339 currency (str | None): ISO 4217 currency code (e.g. ``"USD"``).
3440 input_cost_per_million_tokens (float | None): Rate for `input_token_cost`.
3541 output_cost_per_million_tokens (float | None): Rate for `output_token_cost`.
@@ -39,12 +45,138 @@ class UsageSummary(BaseModel):
3945 output_tokens : int | None = None
4046 cache_creation_input_tokens : int | None = None
4147 cache_read_input_tokens : int | None = None
42- input_cost : float | None = None
43- output_cost : float | None = None
48+ input_token_cost : float | None = None
49+ output_token_cost : float | None = None
50+ cache_write_token_cost : float | None = None
51+ cache_read_token_cost : float | None = None
4452 total_cost : float | None = None
4553 currency : str | None = None
4654 input_cost_per_million_tokens : float | None = None
4755 output_cost_per_million_tokens : float | None = None
56+ cache_write_cost_per_million_tokens : float | None = None
57+ cache_read_cost_per_million_tokens : float | None = None
58+ per_conversation_summaries : list [ConversationUsageSummary ] | None = None
59+
60+ @classmethod
61+ def create (cls , pricing : ModelPricing | None = None ) -> "UsageSummary" :
62+ """Create a summary configured with optional model pricing."""
63+ if pricing is None :
64+ return cls ()
65+ return cls (
66+ input_cost_per_million_tokens = pricing .input_cost_per_million_tokens ,
67+ output_cost_per_million_tokens = pricing .output_cost_per_million_tokens ,
68+ cache_write_cost_per_million_tokens = (
69+ pricing .cache_write_cost_per_million_tokens
70+ ),
71+ cache_read_cost_per_million_tokens = (
72+ pricing .cache_read_cost_per_million_tokens
73+ ),
74+ )
75+
76+ @classmethod
77+ def create_from (cls , summary : "UsageSummary" ) -> "UsageSummary" :
78+ """Create a new summary that reuses pricing fields from `summary`."""
79+ return cls (
80+ input_cost_per_million_tokens = summary .input_cost_per_million_tokens ,
81+ output_cost_per_million_tokens = summary .output_cost_per_million_tokens ,
82+ cache_write_cost_per_million_tokens = (
83+ summary .cache_write_cost_per_million_tokens
84+ ),
85+ cache_read_cost_per_million_tokens = (
86+ summary .cache_read_cost_per_million_tokens
87+ ),
88+ )
89+
90+ def add_usage (self , usage : UsageParam ) -> None :
91+ """Add token counts from `usage`."""
92+ self .input_tokens = (self .input_tokens or 0 ) + (usage .input_tokens or 0 )
93+ self .output_tokens = (self .output_tokens or 0 ) + (usage .output_tokens or 0 )
94+ self .cache_creation_input_tokens = (self .cache_creation_input_tokens or 0 ) + (
95+ usage .cache_creation_input_tokens or 0
96+ )
97+ self .cache_read_input_tokens = (self .cache_read_input_tokens or 0 ) + (
98+ usage .cache_read_input_tokens or 0
99+ )
100+
101+ def generate (self ) -> Self :
102+ """Compute and populate cost fields from current token and pricing fields."""
103+ if not self ._has_pricing ():
104+ self ._clear_cost_fields ()
105+ return self
106+
107+ input_tokens = self .input_tokens or 0
108+ output_tokens = self .output_tokens or 0
109+ cache_write_tokens = self .cache_creation_input_tokens or 0
110+ cache_read_tokens = self .cache_read_input_tokens or 0
111+
112+ assert self .input_cost_per_million_tokens is not None
113+ assert self .output_cost_per_million_tokens is not None
114+ assert self .cache_write_cost_per_million_tokens is not None
115+ assert self .cache_read_cost_per_million_tokens is not None
116+
117+ self .input_token_cost = self ._calculate_cost (
118+ input_tokens , self .input_cost_per_million_tokens
119+ )
120+ self .output_token_cost = self ._calculate_cost (
121+ output_tokens , self .output_cost_per_million_tokens
122+ )
123+ self .cache_write_token_cost = self ._calculate_cost (
124+ cache_write_tokens , self .cache_write_cost_per_million_tokens
125+ )
126+ self .cache_read_token_cost = self ._calculate_cost (
127+ cache_read_tokens , self .cache_read_cost_per_million_tokens
128+ )
129+ self .total_cost = (
130+ (self .input_token_cost or 0.0 )
131+ + (self .output_token_cost or 0.0 )
132+ + (self .cache_write_token_cost or 0.0 )
133+ + (self .cache_read_token_cost or 0.0 )
134+ )
135+ self .currency = _USD_CURRENCY
136+ return self
137+
138+ def token_attributes (self ) -> dict [str , int ]:
139+ """Return token fields for telemetry attributes."""
140+ return {
141+ "input_tokens" : self .input_tokens or 0 ,
142+ "output_tokens" : self .output_tokens or 0 ,
143+ "cache_creation_input_tokens" : self .cache_creation_input_tokens or 0 ,
144+ "cache_read_input_tokens" : self .cache_read_input_tokens or 0 ,
145+ }
146+
147+ def _has_pricing (self ) -> bool :
148+ return (
149+ self .input_cost_per_million_tokens is not None
150+ and self .output_cost_per_million_tokens is not None
151+ and self .cache_write_cost_per_million_tokens is not None
152+ and self .cache_read_cost_per_million_tokens is not None
153+ )
154+
155+ def _clear_cost_fields (self ) -> None :
156+ self .input_token_cost = None
157+ self .output_token_cost = None
158+ self .cache_write_token_cost = None
159+ self .cache_read_token_cost = None
160+ self .total_cost = None
161+ self .currency = None
162+
163+ @staticmethod
164+ def _calculate_cost (tokens : int , rate_per_million_tokens : float ) -> float :
165+ return rate_per_million_tokens * tokens / 1e6
166+
167+
class StepUsageSummary(UsageSummary):
    """Token and cost figures for one step of a conversation.

    Adds the step's position to the fields inherited from `UsageSummary`.
    """

    # Index of the step this summary describes (required, no default).
    step_index: int
172+
173+
class ConversationUsageSummary(UsageSummary):
    """Token and cost totals for one conversation, with its per-step detail.

    Adds identifying fields for the conversation and the list of
    `StepUsageSummary` entries to the fields inherited from `UsageSummary`.
    """

    # Running number of the conversation, supplied by the creator.
    conversation_index: int
    # Identifier of the conversation this summary belongs to.
    conversation_id: str
    # Per-step breakdown; each instance gets its own fresh empty list.
    step_summaries: list[StepUsageSummary] = Field(default_factory=list)
48180
49181
50182class UsageTrackingCallback (ConversationCallback ):
@@ -62,12 +194,17 @@ def __init__(
62194 pricing : ModelPricing | None = None ,
63195 ) -> None :
64196 self ._reporter = reporter
65- self ._pricing = pricing
66- self ._summary = UsageSummary ()
197+ self ._summary : UsageSummary = UsageSummary .create (pricing )
198+ self ._per_conversation_usage : UsageSummary = UsageSummary .create (pricing )
199+ self ._per_conversation_summaries : list [ConversationUsageSummary ] = []
200+ self ._per_step_summaries : list [StepUsageSummary ] = []
201+ self ._conversation_index : int = 0
67202
@override
def on_conversation_start(self, conversation: Conversation) -> None:
    """Reset the per-conversation state at the start of a conversation.

    The overall running summary is left untouched. The per-conversation
    accumulator is replaced by an empty one that reuses the overall
    summary's rates, the collected step summaries are discarded, and the
    conversation counter is advanced by one.

    Args:
        conversation: The conversation that is starting (not read here).
    """
    fresh_usage = UsageSummary.create_from(self._summary)
    self._per_conversation_usage = fresh_usage
    self._conversation_index += 1
    self._per_step_summaries = []
71208
72209 @override
73210 def on_step_end (
@@ -76,71 +213,85 @@ def on_step_end(
76213 step_index : int ,
77214 result : SpeakerResult ,
78215 ) -> None :
79- if result .usage :
80- self ._accumulate (result .usage )
216+ step_usage : UsageParam | None = result .usage
217+ if step_usage is None :
218+ return
219+
220+ step_summary = self ._create_step_summary (
221+ step_index = step_index , usage = step_usage
222+ )
223+ self ._per_step_summaries .append (step_summary )
224+ self ._per_conversation_usage .add_usage (step_usage )
225+ self ._summary .add_usage (step_usage )
226+
227+ current_span = trace .get_current_span ()
228+ current_span .set_attributes (step_summary .token_attributes ())
81229
82230 @override
83231 def on_conversation_end (self , conversation : Conversation ) -> None :
84- self ._reporter .add_usage_summary (self ._summary )
232+ generated_steps : list [StepUsageSummary ] = [
233+ step_summary .generate () for step_summary in self ._per_step_summaries
234+ ]
235+ conversation_summary = self ._create_conversation_summary (
236+ conversation = conversation ,
237+ generated_step_summaries = generated_steps ,
238+ )
239+ self ._per_conversation_summaries .append (conversation_summary )
240+ self ._summary .per_conversation_summaries = list (
241+ self ._per_conversation_summaries
242+ )
243+ self ._reporter .add_usage_summary (self ._summary .generate ().model_copy (deep = True ))
85244
86245 @property
87246 def accumulated_usage (self ) -> UsageSummary :
88247 """Current accumulated usage statistics."""
89248 return self ._summary
90249
91- def _accumulate (self , step_usage : UsageParam ) -> None :
92- # Add step tokens to running totals (None counts as 0)
93- self ._summary .input_tokens = (self ._summary .input_tokens or 0 ) + (
94- step_usage .input_tokens or 0
95- )
96- self ._summary .output_tokens = (self ._summary .output_tokens or 0 ) + (
97- step_usage .output_tokens or 0
98- )
99- self ._summary .cache_creation_input_tokens = (
100- self ._summary .cache_creation_input_tokens or 0
101- ) + (step_usage .cache_creation_input_tokens or 0 )
102- self ._summary .cache_read_input_tokens = (
103- self ._summary .cache_read_input_tokens or 0
104- ) + (step_usage .cache_read_input_tokens or 0 )
105-
106- # Record per-step token counts on the current OTel span
107- current_span = trace .get_current_span ()
108- current_span .set_attributes (
109- {
110- "input_tokens" : step_usage .input_tokens or 0 ,
111- "output_tokens" : step_usage .output_tokens or 0 ,
112- "cache_creation_input_tokens" : (
113- step_usage .cache_creation_input_tokens or 0
114- ),
115- "cache_read_input_tokens" : (step_usage .cache_read_input_tokens or 0 ),
116- }
def _create_step_summary(
    self, step_index: int, usage: UsageParam
) -> StepUsageSummary:
    """Build the (not yet costed) summary for a single step.

    Args:
        step_index: Index of the step the usage belongs to.
        usage: Token usage reported for that step; `None` counters are
            stored as 0.

    Returns:
        A `StepUsageSummary` holding the step's token counts and the rates
        copied from the overall running summary. Cost fields stay unset
        because `generate()` is not called here.
    """
    token_counts = {
        name: getattr(usage, name) or 0
        for name in (
            "input_tokens",
            "output_tokens",
            "cache_creation_input_tokens",
            "cache_read_input_tokens",
        )
    }
    rates = {
        name: getattr(self._summary, name)
        for name in (
            "input_cost_per_million_tokens",
            "output_cost_per_million_tokens",
            "cache_write_cost_per_million_tokens",
            "cache_read_cost_per_million_tokens",
        )
    }
    return StepUsageSummary(step_index=step_index, **token_counts, **rates)
118268
119- # Update costs from updated totals if pricing values are set
120- if not (
121- self . _pricing
122- and self . _pricing . input_cost_per_million_tokens
123- and self . _pricing . output_cost_per_million_tokens
124- ):
125- return
126-
127- input_cost = (
128- self ._summary .input_tokens
129- * self ._pricing . input_cost_per_million_tokens
130- / 1e6
131- )
132- output_cost = (
133- self ._summary . output_tokens
134- * self . _pricing . output_cost_per_million_tokens
135- / 1e6
136- )
137- self . _summary . input_cost = input_cost
138- self ._summary . output_cost = output_cost
139- self . _summary . total_cost = input_cost + output_cost
140- self . _summary . currency = self . _pricing . currency
141- self ._summary . input_cost_per_million_tokens = (
142- self . _pricing . input_cost_per_million_tokens
143- )
144- self ._summary . output_cost_per_million_tokens = (
145- self . _pricing . output_cost_per_million_tokens
def _create_conversation_summary(
    self,
    conversation: Conversation,
    generated_step_summaries: list[StepUsageSummary],
) -> ConversationUsageSummary:
    """Build the costed summary for the conversation that is ending.

    Args:
        conversation: Conversation whose `conversation_id` is recorded.
        generated_step_summaries: Step summaries to attach as the per-step
            breakdown.

    Returns:
        A `ConversationUsageSummary` carrying the current conversation
        index, the token counts and rates of the per-conversation
        accumulator, and cost fields filled in by `generate()`.
    """
    accumulator = self._per_conversation_usage
    copied_field_names = (
        "input_tokens",
        "output_tokens",
        "cache_creation_input_tokens",
        "cache_read_input_tokens",
        "input_cost_per_million_tokens",
        "output_cost_per_million_tokens",
        "cache_write_cost_per_million_tokens",
        "cache_read_cost_per_million_tokens",
    )
    copied_fields = {name: getattr(accumulator, name) for name in copied_field_names}
    return ConversationUsageSummary(
        conversation_index=self._conversation_index,
        conversation_id=conversation.conversation_id,
        step_summaries=generated_step_summaries,
        **copied_fields,
    ).generate()
0 commit comments