Skip to content

Commit 31d7c07

Browse files
fix(pydantic): retype wrapper spans as task to stop cost double-counting (#312)
Agent wrapper spans (agent_run, agent_run_sync, agent_run_stream*, agent_to_cli_sync, model_request*) were tagged type=llm and logged the same usage metrics as their nested leaf `chat <model>` span. Experiment aggregations that sum metrics across type=llm spans therefore counted a single provider call twice (wrapper + leaf), inflating reported tokens and cost ~2x for a single-turn agent and more for multi-turn runs. Retag every wrapper span as SpanTypeAttribute.TASK; only the leaf `chat <model>` emitted by _wrap_concrete_model_class stays LLM. _DirectStreamWrapper is used both as wrapper (direct.model_request_stream) and as leaf (Model.request_stream), so it gains a span_type parameter defaulting to LLM; wrapper callers pass TASK explicitly. Test coverage: flip existing wrapper-type assertions to TASK (leaf chat_span assertions stay LLM) and extend the cassette-backed test_agent_run_async with a regression check that exactly one type=llm span exists and that prompt/completion tokens summed across llm spans equal the leaf's values. Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 680f030 commit 31d7c07

3 files changed

Lines changed: 61 additions & 34 deletions

File tree

py/src/braintrust/integrations/pydantic_ai/test_pydantic_ai_integration.py

Lines changed: 31 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -125,7 +125,7 @@ async def test_agent_run_async(memory_logger):
125125
assert chat_span is not None, "chat span not found"
126126

127127
# Check agent span
128-
assert agent_span["span_attributes"]["type"] == SpanTypeAttribute.LLM
128+
assert agent_span["span_attributes"]["type"] == SpanTypeAttribute.TASK
129129
assert agent_span["metadata"]["model"] == "gpt-4o-mini"
130130
assert agent_span["metadata"]["provider"] == "openai"
131131
assert TEST_PROMPT in str(agent_span["input"])
@@ -146,6 +146,18 @@ async def test_agent_run_async(memory_logger):
146146
assert agent_span["metrics"]["prompt_tokens"] > 0
147147
assert agent_span["metrics"]["completion_tokens"] > 0
148148

149+
# Regression: no double-counting of cost/tokens. Experiment-level aggregations
150+
# sum metrics across type='llm' spans, so a single agent turn must contribute
151+
# its tokens exactly once. The wrapper agent_run span logs the same usage as
152+
# the leaf chat span; only the leaf should be type=LLM.
153+
llm_spans = [s for s in spans if s["span_attributes"]["type"] == SpanTypeAttribute.LLM]
154+
assert len(llm_spans) == 1, f"expected exactly one LLM-typed span, got {len(llm_spans)}"
155+
assert llm_spans[0]["span_id"] == chat_span["span_id"]
156+
llm_prompt_tokens_sum = sum(s["metrics"].get("prompt_tokens", 0) for s in llm_spans)
157+
llm_completion_tokens_sum = sum(s["metrics"].get("completion_tokens", 0) for s in llm_spans)
158+
assert llm_prompt_tokens_sum == chat_span["metrics"]["prompt_tokens"]
159+
assert llm_completion_tokens_sum == chat_span["metrics"]["completion_tokens"]
160+
149161

150162
@pytest.mark.vcr
151163
@pytest.mark.asyncio
@@ -205,7 +217,7 @@ def test_agent_run_sync(memory_logger):
205217
assert chat_span is not None, "chat span not found"
206218

207219
# Check agent span
208-
assert agent_sync_span["span_attributes"]["type"] == SpanTypeAttribute.LLM
220+
assert agent_sync_span["span_attributes"]["type"] == SpanTypeAttribute.TASK
209221
assert agent_sync_span["metadata"]["model"] == "gpt-4o-mini"
210222
assert agent_sync_span["metadata"]["provider"] == "openai"
211223
assert TEST_PROMPT in str(agent_sync_span["input"])
@@ -287,7 +299,7 @@ async def fake_run_chat(
287299
assert len(spans) == 1, f"Expected 1 CLI span, got {len(spans)}"
288300

289301
cli_span = spans[0]
290-
assert cli_span["span_attributes"]["type"] == SpanTypeAttribute.LLM
302+
assert cli_span["span_attributes"]["type"] == SpanTypeAttribute.TASK
291303
assert cli_span["span_attributes"]["name"] == "agent_to_cli_sync [cli-agent]"
292304
assert cli_span["metadata"]["model"] == "gpt-4o-mini"
293305
assert cli_span["metadata"]["provider"] == "openai"
@@ -497,7 +509,7 @@ async def test_agent_run_stream(memory_logger):
497509
assert chat_span is not None, "chat span not found"
498510

499511
# Check agent span
500-
assert agent_span["span_attributes"]["type"] == SpanTypeAttribute.LLM
512+
assert agent_span["span_attributes"]["type"] == SpanTypeAttribute.TASK
501513
assert agent_span["metadata"]["model"] == "gpt-4o-mini"
502514
assert "Count from 1 to 5" in str(agent_span["input"])
503515
_assert_metrics_are_valid(agent_span["metrics"], start, end)
@@ -607,7 +619,7 @@ async def test_direct_model_request(memory_logger, direct):
607619
direct_span = next((s for s in spans if s["span_attributes"]["name"] == "model_request"), None)
608620
assert direct_span is not None
609621

610-
assert direct_span["span_attributes"]["type"] == SpanTypeAttribute.LLM
622+
assert direct_span["span_attributes"]["type"] == SpanTypeAttribute.TASK
611623
assert direct_span["metadata"]["model"] == "gpt-4o-mini"
612624
assert direct_span["metadata"]["provider"] == "openai"
613625
assert TEST_PROMPT in str(direct_span["input"])
@@ -637,7 +649,7 @@ def test_direct_model_request_sync(memory_logger, direct):
637649
# Find the model_request_sync span
638650
span = next((s for s in spans if s["span_attributes"]["name"] == "model_request_sync"), None)
639651
assert span is not None, "model_request_sync span not found"
640-
assert span["span_attributes"]["type"] == SpanTypeAttribute.LLM
652+
assert span["span_attributes"]["type"] == SpanTypeAttribute.TASK
641653
assert span["metadata"]["model"] == "gpt-4o-mini"
642654
assert TEST_PROMPT in str(span["input"])
643655
_assert_metrics_are_valid(span["metrics"], start, end)
@@ -668,7 +680,7 @@ async def test_direct_model_request_with_settings(memory_logger, direct):
668680
direct_span = next((s for s in spans if s["span_attributes"]["name"] == "model_request"), None)
669681
assert direct_span is not None
670682

671-
assert direct_span["span_attributes"]["type"] == SpanTypeAttribute.LLM
683+
assert direct_span["span_attributes"]["type"] == SpanTypeAttribute.TASK
672684

673685
# Verify model_settings is in input (NOT metadata)
674686
assert "model_settings" in direct_span["input"], "model_settings should be in input"
@@ -713,7 +725,7 @@ async def test_direct_model_request_stream(memory_logger, direct):
713725
direct_span = next((s for s in spans if s["span_attributes"]["name"] == "model_request_stream"), None)
714726
assert direct_span is not None
715727

716-
assert direct_span["span_attributes"]["type"] == SpanTypeAttribute.LLM
728+
assert direct_span["span_attributes"]["type"] == SpanTypeAttribute.TASK
717729
assert direct_span["metadata"]["model"] == "gpt-4o-mini"
718730
_assert_metrics_are_valid(direct_span["metrics"], start, end)
719731

@@ -804,7 +816,7 @@ class MathAnswer(BaseModel):
804816
assert chat_span is not None, "chat span not found"
805817

806818
# Check agent span
807-
assert agent_span["span_attributes"]["type"] == SpanTypeAttribute.LLM
819+
assert agent_span["span_attributes"]["type"] == SpanTypeAttribute.TASK
808820
assert agent_span["metadata"]["model"] == "gpt-4o-mini"
809821
assert agent_span["metadata"]["provider"] == "openai"
810822
assert "10 + 15" in str(agent_span["input"])
@@ -1092,7 +1104,7 @@ def test_agent_run_stream_sync(memory_logger):
10921104
assert chat_span is not None, "chat span not found"
10931105

10941106
# Check agent span
1095-
assert agent_span["span_attributes"]["type"] == SpanTypeAttribute.LLM
1107+
assert agent_span["span_attributes"]["type"] == SpanTypeAttribute.TASK
10961108
assert agent_span["metadata"]["model"] == "gpt-4o-mini"
10971109
assert "Count from 1 to 3" in str(agent_span["input"])
10981110
_assert_metrics_are_valid(agent_span["metrics"], start, end)
@@ -1165,7 +1177,7 @@ async def test_agent_run_stream_events(memory_logger):
11651177
assert agent_span is not None, "agent_run_stream_events span not found"
11661178

11671179
# Check agent span has basic structure
1168-
assert agent_span["span_attributes"]["type"] == SpanTypeAttribute.LLM
1180+
assert agent_span["span_attributes"]["type"] == SpanTypeAttribute.TASK
11691181
assert agent_span["metadata"]["model"] == "gpt-4o-mini"
11701182
assert "5+5" in str(agent_span["input"]) or "What" in str(agent_span["input"])
11711183
assert agent_span["metrics"]["event_count"] == event_count
@@ -1194,7 +1206,7 @@ def test_direct_model_request_stream_sync(memory_logger, direct):
11941206
assert len(spans) == 1
11951207

11961208
span = spans[0]
1197-
assert span["span_attributes"]["type"] == SpanTypeAttribute.LLM
1209+
assert span["span_attributes"]["type"] == SpanTypeAttribute.TASK
11981210
assert span["span_attributes"]["name"] == "model_request_stream_sync"
11991211
assert span["metadata"]["model"] == "gpt-4o-mini"
12001212
_assert_metrics_are_valid(span["metrics"], start, end)
@@ -1258,7 +1270,7 @@ async def stream_wrapper():
12581270
assert len(spans) >= 1, "Should have at least one span even with early break"
12591271

12601272
span = spans[0]
1261-
assert span["span_attributes"]["type"] == SpanTypeAttribute.LLM
1273+
assert span["span_attributes"]["type"] == SpanTypeAttribute.TASK
12621274
assert span["span_attributes"]["name"] == "model_request_stream"
12631275

12641276

@@ -1297,7 +1309,7 @@ async def test_agent_stream_early_break(memory_logger):
12971309

12981310
# Verify at least agent_run_stream span exists and has basic structure
12991311
if agent_span:
1300-
assert agent_span["span_attributes"]["type"] == SpanTypeAttribute.LLM
1312+
assert agent_span["span_attributes"]["type"] == SpanTypeAttribute.TASK
13011313
# Metrics may be incomplete due to early break
13021314
assert "start" in agent_span["metrics"]
13031315

@@ -1368,7 +1380,7 @@ async def _buffer_stream() -> LLMStreamResponse:
13681380
assert len(spans) >= 1, "Should have at least one span even with early return"
13691381

13701382
span = spans[0]
1371-
assert span["span_attributes"]["type"] == SpanTypeAttribute.LLM
1383+
assert span["span_attributes"]["type"] == SpanTypeAttribute.TASK
13721384
assert span["span_attributes"]["name"] == "model_request_stream"
13731385
assert "start" in span["metrics"]
13741386
assert span["metrics"]["start"] >= start
@@ -1446,7 +1458,7 @@ async def _consume_until_final() -> StreamEvent:
14461458
# Find agent_run_stream span
14471459
agent_span = next((s for s in spans if "agent_run_stream" in s["span_attributes"]["name"]), None)
14481460
assert agent_span is not None, "agent_run_stream span should exist"
1449-
assert agent_span["span_attributes"]["type"] == SpanTypeAttribute.LLM
1461+
assert agent_span["span_attributes"]["type"] == SpanTypeAttribute.TASK
14501462
assert "start" in agent_span["metrics"]
14511463

14521464

@@ -1500,7 +1512,7 @@ async def test_agent_with_binary_content(memory_logger):
15001512
assert chat_span is not None, "chat span not found"
15011513

15021514
# Verify basic span structure
1503-
assert agent_span["span_attributes"]["type"] == SpanTypeAttribute.LLM
1515+
assert agent_span["span_attributes"]["type"] == SpanTypeAttribute.TASK
15041516
assert agent_span["metadata"]["model"] == "gpt-4o-mini"
15051517
_assert_metrics_are_valid(agent_span["metrics"], start, end)
15061518

@@ -2113,7 +2125,7 @@ class Product(BaseModel):
21132125
assert chat_span is not None, "chat span not found"
21142126

21152127
# Check agent span
2116-
assert agent_span["span_attributes"]["type"] == SpanTypeAttribute.LLM
2128+
assert agent_span["span_attributes"]["type"] == SpanTypeAttribute.TASK
21172129
assert agent_span["metadata"]["model"] == "gpt-4o-mini"
21182130
_assert_metrics_are_valid(agent_span["metrics"], start, end)
21192131

@@ -2663,7 +2675,7 @@ async def test_no_model_agent_run(memory_logger):
26632675
assert agent_span is not None, "agent_run span not found"
26642676
assert chat_span is not None, "chat span not found"
26652677

2666-
assert agent_span["span_attributes"]["type"] == SpanTypeAttribute.LLM
2678+
assert agent_span["span_attributes"]["type"] == SpanTypeAttribute.TASK
26672679
assert agent_span["metadata"]["model"] == "gpt-4o-mini"
26682680
assert agent_span["metadata"]["provider"] == "openai"
26692681
assert TEST_PROMPT in str(agent_span["input"])

py/src/braintrust/integrations/pydantic_ai/test_pydantic_ai_logfire.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -66,7 +66,7 @@ async def test_no_model_agent_run_with_logfire(memory_logger):
6666
assert agent_span is not None, "agent_run span not found"
6767
assert chat_span is not None, "chat span not found"
6868

69-
assert agent_span["span_attributes"]["type"] == SpanTypeAttribute.LLM
69+
assert agent_span["span_attributes"]["type"] == SpanTypeAttribute.TASK
7070
assert agent_span["metadata"]["model"] == "gpt-4o-mini"
7171
assert agent_span["metadata"]["provider"] == "openai"
7272
assert TEST_PROMPT in str(agent_span["input"])

py/src/braintrust/integrations/pydantic_ai/tracing.py

Lines changed: 29 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -70,7 +70,7 @@ async def _agent_run_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any
7070

7171
with start_span(
7272
name=f"agent_run [{instance.name}]" if hasattr(instance, "name") and instance.name else "agent_run",
73-
type=SpanTypeAttribute.LLM,
73+
type=SpanTypeAttribute.TASK,
7474
input=input_data if input_data else None,
7575
metadata=metadata,
7676
) as agent_span:
@@ -96,7 +96,7 @@ def _agent_run_sync_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any)
9696

9797
with start_span(
9898
name=f"agent_run_sync [{instance.name}]" if hasattr(instance, "name") and instance.name else "agent_run_sync",
99-
type=SpanTypeAttribute.LLM,
99+
type=SpanTypeAttribute.TASK,
100100
input=input_data if input_data else None,
101101
metadata=metadata,
102102
) as agent_span:
@@ -124,7 +124,7 @@ def _agent_to_cli_sync_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: A
124124
name=f"agent_to_cli_sync [{instance.name}]"
125125
if hasattr(instance, "name") and instance.name
126126
else "agent_to_cli_sync",
127-
type=SpanTypeAttribute.LLM,
127+
type=SpanTypeAttribute.TASK,
128128
input=input_data if input_data else None,
129129
metadata=metadata,
130130
) as agent_span:
@@ -156,7 +156,7 @@ def _agent_run_stream_sync_wrapper(wrapped: Any, instance: Any, args: Any, kwarg
156156
# Create span context BEFORE calling wrapped function so internal spans nest under it
157157
span_cm = start_span(
158158
name=span_name,
159-
type=SpanTypeAttribute.LLM,
159+
type=SpanTypeAttribute.TASK,
160160
input=input_data if input_data else None,
161161
metadata=metadata,
162162
)
@@ -189,7 +189,7 @@ async def _agent_run_stream_events_wrapper(wrapped: Any, instance: Any, args: An
189189

190190
with start_span(
191191
name=span_name,
192-
type=SpanTypeAttribute.LLM,
192+
type=SpanTypeAttribute.TASK,
193193
input=input_data if input_data else None,
194194
metadata=metadata,
195195
) as agent_span:
@@ -236,7 +236,7 @@ async def wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any):
236236

237237
with start_span(
238238
name="model_request",
239-
type=SpanTypeAttribute.LLM,
239+
type=SpanTypeAttribute.TASK,
240240
input=input_data,
241241
metadata=metadata,
242242
) as span:
@@ -261,7 +261,7 @@ def wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any):
261261

262262
with start_span(
263263
name="model_request_sync",
264-
type=SpanTypeAttribute.LLM,
264+
type=SpanTypeAttribute.TASK,
265265
input=input_data,
266266
metadata=metadata,
267267
) as span:
@@ -289,6 +289,7 @@ def wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any):
289289
"model_request_stream",
290290
input_data,
291291
metadata,
292+
span_type=SpanTypeAttribute.TASK,
292293
)
293294

294295
return wrapper
@@ -316,7 +317,7 @@ async def wrapper(*args, **kwargs):
316317

317318
with start_span(
318319
name="model_request",
319-
type=SpanTypeAttribute.LLM,
320+
type=SpanTypeAttribute.TASK,
320321
input=input_data,
321322
metadata=metadata,
322323
) as span:
@@ -339,7 +340,7 @@ def wrapper(*args, **kwargs):
339340

340341
with start_span(
341342
name="model_request_sync",
342-
type=SpanTypeAttribute.LLM,
343+
type=SpanTypeAttribute.TASK,
343344
input=input_data,
344345
metadata=metadata,
345346
) as span:
@@ -365,6 +366,7 @@ def wrapper(*args, **kwargs):
365366
"model_request_stream",
366367
input_data,
367368
metadata,
369+
span_type=SpanTypeAttribute.TASK,
368370
)
369371

370372
return wrapper
@@ -466,7 +468,7 @@ async def __aenter__(self):
466468
# DON'T pass start_time here - we'll set it via metrics in __aexit__
467469
self.span_cm = start_span(
468470
name=self.span_name,
469-
type=SpanTypeAttribute.LLM,
471+
type=SpanTypeAttribute.TASK,
470472
input=self.input_data if self.input_data else None,
471473
metadata=self.metadata,
472474
)
@@ -535,13 +537,26 @@ async def wrapped_method(*args, **kwargs):
535537

536538

537539
class _DirectStreamWrapper(AbstractAsyncContextManager):
538-
"""Wrapper for model_request_stream() that adds tracing while passing through the stream."""
540+
"""Wrapper for model_request_stream() that adds tracing while passing through the stream.
539541
540-
def __init__(self, stream_cm: Any, span_name: str, input_data: Any, metadata: Any):
542+
Used both as the leaf `chat <model>` span (from `_wrap_concrete_model_class`, default
543+
`span_type=LLM`) and as a non-leaf wrapper around a nested model call (from
544+
`direct.model_request_stream`, which passes `span_type=TASK` to avoid double-counting).
545+
"""
546+
547+
def __init__(
548+
self,
549+
stream_cm: Any,
550+
span_name: str,
551+
input_data: Any,
552+
metadata: Any,
553+
span_type: str = SpanTypeAttribute.LLM,
554+
):
541555
self.stream_cm = stream_cm
542556
self.span_name = span_name
543557
self.input_data = input_data
544558
self.metadata = metadata
559+
self.span_type = span_type
545560
self.span_cm = None
546561
self.start_time = None
547562
self.stream = None
@@ -555,7 +570,7 @@ async def __aenter__(self):
555570
# DON'T pass start_time here - we'll set it via metrics in __aexit__
556571
self.span_cm = start_span(
557572
name=self.span_name,
558-
type=SpanTypeAttribute.LLM,
573+
type=self.span_type,
559574
input=self.input_data if self.input_data else None,
560575
metadata=self.metadata,
561576
)
@@ -723,7 +738,7 @@ def __enter__(self):
723738
# DON'T pass start_time here - we'll set it via metrics in __exit__
724739
self.span_cm = start_span(
725740
name=self.span_name,
726-
type=SpanTypeAttribute.LLM,
741+
type=SpanTypeAttribute.TASK,
727742
input=self.input_data if self.input_data else None,
728743
metadata=self.metadata,
729744
)

0 commit comments

Comments (0)