Skip to content

Commit d807b7b

Browse files
authored
fix(anthropic): capture nested usage metrics and metadata (#189)

* fix(anthropic): capture nested usage metrics and metadata

Normalize Anthropic usage with explicit extraction for nested cache-creation and server-tool fields, and attach allowlisted usage metadata to Anthropic and Claude Agent SDK spans. This preserves the new usage signals while keeping compatibility with older Anthropic SDK versions whose usage payloads differ slightly in shape.

Extend the existing VCR coverage and the shared Anthropic usage tests so the fix is validated across both the Anthropic and Claude Agent SDK integrations.

Fixes #164

* PR feedback
1 parent 19ecb8a commit d807b7b

5 files changed

Lines changed: 169 additions & 86 deletions

File tree

py/src/braintrust/integrations/anthropic/_utils.py

Lines changed: 88 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
from typing import Any
44

5+
from braintrust.util import is_numeric
6+
57

68
class Wrapper:
79
"""Base wrapper class with __getattr__ delegation to preserve original types."""
@@ -13,73 +15,104 @@ def __getattr__(self, name: str) -> Any:
1315
return getattr(self.__wrapped, name)
1416

1517

16-
def extract_anthropic_usage(usage: Any) -> dict[str, float]:
17-
"""Extract and normalize usage metrics from Anthropic usage object or dict.
18+
_ANTHROPIC_USAGE_METRIC_FIELDS = (
19+
("input_tokens", "prompt_tokens"),
20+
("output_tokens", "completion_tokens"),
21+
("cache_read_input_tokens", "prompt_cached_tokens"),
22+
("cache_creation_input_tokens", "prompt_cache_creation_tokens"),
23+
)
1824

19-
Converts Anthropic's usage format to Braintrust's standard token metric names.
20-
Handles both object attributes and dictionary keys.
25+
_ANTHROPIC_CACHE_CREATION_METRIC_FIELDS = (
26+
("ephemeral_5m_input_tokens", "prompt_cache_creation_ephemeral_5m_tokens"),
27+
("ephemeral_1h_input_tokens", "prompt_cache_creation_ephemeral_1h_tokens"),
28+
)
2129

22-
Args:
23-
usage: Anthropic usage object (from Message.usage) or dict
30+
_ANTHROPIC_SERVER_TOOL_USE_METRIC_FIELDS = (
31+
("web_search_requests", "server_tool_use_web_search_requests"),
32+
("web_fetch_requests", "server_tool_use_web_fetch_requests"),
33+
)
2434

25-
Returns:
26-
Dictionary with normalized metric names:
27-
- prompt_tokens (from input_tokens)
28-
- completion_tokens (from output_tokens)
29-
- prompt_cached_tokens (from cache_read_input_tokens)
30-
- prompt_cache_creation_tokens (from cache_creation_input_tokens)
31-
"""
32-
metrics: dict[str, float] = {}
35+
_ANTHROPIC_USAGE_METADATA_FIELDS = frozenset(
36+
{
37+
"service_tier",
38+
"inference_geo",
39+
}
40+
)
3341

34-
if not usage:
35-
return metrics
3642

37-
def get_value(key: str) -> Any:
38-
if isinstance(usage, dict):
39-
return usage.get(key)
40-
return getattr(usage, key, None)
43+
def _try_to_dict(obj: Any) -> dict[str, Any] | None:
44+
if isinstance(obj, dict):
45+
return obj
4146

42-
input_tokens = get_value("input_tokens")
43-
if input_tokens is not None:
47+
if hasattr(obj, "model_dump"):
4448
try:
45-
metrics["prompt_tokens"] = float(input_tokens)
46-
except (ValueError, TypeError):
47-
pass
49+
candidate = obj.model_dump(mode="python")
50+
except TypeError:
51+
candidate = obj.model_dump()
52+
return candidate if isinstance(candidate, dict) else None
4853

49-
output_tokens = get_value("output_tokens")
50-
if output_tokens is not None:
51-
try:
52-
metrics["completion_tokens"] = float(output_tokens)
53-
except (ValueError, TypeError):
54-
pass
54+
if hasattr(obj, "to_dict"):
55+
candidate = obj.to_dict()
56+
return candidate if isinstance(candidate, dict) else None
5557

56-
cache_read_tokens = get_value("cache_read_input_tokens")
57-
if cache_read_tokens is not None:
58-
try:
59-
metrics["prompt_cached_tokens"] = float(cache_read_tokens)
60-
except (ValueError, TypeError):
61-
pass
58+
if hasattr(obj, "dict"):
59+
candidate = obj.dict()
60+
return candidate if isinstance(candidate, dict) else None
61+
62+
if hasattr(obj, "__dict__"):
63+
return vars(obj)
64+
65+
return None
6266

63-
cache_creation_tokens = get_value("cache_creation_input_tokens")
64-
if cache_creation_tokens is not None:
65-
try:
66-
metrics["prompt_cache_creation_tokens"] = float(cache_creation_tokens)
67-
except (ValueError, TypeError):
68-
pass
6967

70-
return metrics
68+
def _set_numeric_metric(metrics: dict[str, float], name: str, value: Any) -> None:
69+
if is_numeric(value):
70+
metrics[name] = float(value)
7171

7272

73-
def finalize_anthropic_tokens(metrics: dict[str, float]) -> dict[str, float]:
74-
"""Finalize Anthropic token calculations."""
75-
total_prompt_tokens = (
76-
metrics.get("prompt_tokens", 0)
77-
+ metrics.get("prompt_cached_tokens", 0)
78-
+ metrics.get("prompt_cache_creation_tokens", 0)
79-
)
73+
def extract_anthropic_usage(usage: Any) -> tuple[dict[str, float], dict[str, Any]]:
74+
"""Extract normalized metrics and allowlisted metadata from Anthropic usage.
8075
81-
return {
82-
**metrics,
83-
"prompt_tokens": total_prompt_tokens,
84-
"tokens": total_prompt_tokens + metrics.get("completion_tokens", 0),
76+
Numeric usage fields are converted into Braintrust metrics. Allowlisted
77+
non-numeric fields are attached as span metadata with a ``usage_`` prefix.
78+
"""
79+
usage = _try_to_dict(usage)
80+
if usage is None:
81+
return {}, {}
82+
83+
metrics: dict[str, float] = {}
84+
for source_name, metric_name in _ANTHROPIC_USAGE_METRIC_FIELDS:
85+
_set_numeric_metric(metrics, metric_name, usage.get(source_name))
86+
87+
cache_creation = _try_to_dict(usage.get("cache_creation"))
88+
cache_creation_breakdown: list[float] = []
89+
if cache_creation is not None:
90+
for source_name, metric_name in _ANTHROPIC_CACHE_CREATION_METRIC_FIELDS:
91+
value = cache_creation.get(source_name)
92+
_set_numeric_metric(metrics, metric_name, value)
93+
if is_numeric(value):
94+
cache_creation_breakdown.append(float(value))
95+
96+
server_tool_use = _try_to_dict(usage.get("server_tool_use"))
97+
if server_tool_use is not None:
98+
for source_name, metric_name in _ANTHROPIC_SERVER_TOOL_USE_METRIC_FIELDS:
99+
_set_numeric_metric(metrics, metric_name, server_tool_use.get(source_name))
100+
101+
if "prompt_cache_creation_tokens" not in metrics and cache_creation_breakdown:
102+
metrics["prompt_cache_creation_tokens"] = sum(cache_creation_breakdown)
103+
104+
if metrics:
105+
total_prompt_tokens = (
106+
metrics.get("prompt_tokens", 0)
107+
+ metrics.get("prompt_cached_tokens", 0)
108+
+ metrics.get("prompt_cache_creation_tokens", 0)
109+
)
110+
metrics["prompt_tokens"] = total_prompt_tokens
111+
metrics["tokens"] = total_prompt_tokens + metrics.get("completion_tokens", 0)
112+
113+
metadata = {
114+
f"usage_{name}": value
115+
for name, value in usage.items()
116+
if name in _ANTHROPIC_USAGE_METADATA_FIELDS and value is not None
85117
}
118+
return metrics, metadata

py/src/braintrust/integrations/anthropic/test_anthropic.py

Lines changed: 67 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import pytest
1212
from braintrust import logger
1313
from braintrust.integrations.anthropic import AnthropicIntegration, wrap_anthropic
14+
from braintrust.integrations.anthropic._utils import extract_anthropic_usage
1415
from braintrust.integrations.anthropic.tracing import _log_message_to_span
1516
from braintrust.test_helpers import init_test_logger
1617

@@ -73,6 +74,7 @@ def test_log_message_to_span_includes_stop_reason_and_stop_sequence():
7374
"tokens": 18.0,
7475
"time_to_first_token": 0.123,
7576
},
77+
metadata={},
7678
)
7779

7880

@@ -539,18 +541,82 @@ def test_setup_creates_spans(memory_logger):
539541
AnthropicIntegration.setup()
540542

541543
client = anthropic.Anthropic()
542-
client.messages.create(
544+
message = client.messages.create(
543545
model=MODEL,
544546
max_tokens=100,
545547
messages=[{"role": "user", "content": "hi"}],
546548
)
547549

550+
usage = message.usage
551+
548552
spans = memory_logger.pop()
549553
assert len(spans) == 1
550554
span = spans[0]
551555
assert span["metadata"]["model"] == MODEL
552556
assert span["metadata"]["provider"] == "anthropic"
553557

558+
cache_creation = getattr(usage, "cache_creation", None)
559+
if cache_creation is None:
560+
pytest.skip("Anthropic SDK version does not expose nested cache_creation usage fields")
561+
562+
if isinstance(cache_creation, dict):
563+
ephemeral_5m = cache_creation["ephemeral_5m_input_tokens"]
564+
ephemeral_1h = cache_creation["ephemeral_1h_input_tokens"]
565+
else:
566+
ephemeral_5m = cache_creation.ephemeral_5m_input_tokens
567+
ephemeral_1h = cache_creation.ephemeral_1h_input_tokens
568+
569+
assert span["metadata"]["usage_service_tier"] == usage.service_tier
570+
assert span["metadata"]["usage_inference_geo"] == usage.inference_geo
571+
metrics = span["metrics"]
572+
assert metrics["prompt_tokens"] == (
573+
usage.input_tokens + usage.cache_read_input_tokens + usage.cache_creation_input_tokens
574+
)
575+
assert metrics["completion_tokens"] == usage.output_tokens
576+
assert metrics["prompt_cache_creation_tokens"] == usage.cache_creation_input_tokens
577+
assert metrics["prompt_cache_creation_ephemeral_5m_tokens"] == ephemeral_5m
578+
assert metrics["prompt_cache_creation_ephemeral_1h_tokens"] == ephemeral_1h
579+
assert "service_tier" not in metrics
580+
581+
582+
def test_extract_anthropic_usage_preserves_nested_numeric_fields():
583+
usage = {
584+
"input_tokens": 8,
585+
"output_tokens": 12,
586+
"cache_creation": {
587+
"ephemeral_5m_input_tokens": 3,
588+
"ephemeral_1h_input_tokens": 4,
589+
},
590+
"server_tool_use": {
591+
"web_search_requests": 2,
592+
"web_fetch_requests": 1,
593+
},
594+
"service_tier": "standard",
595+
"inference_geo": "not_available",
596+
}
597+
metrics, metadata = extract_anthropic_usage(usage)
598+
599+
assert metrics["prompt_tokens"] == 15
600+
assert metrics["completion_tokens"] == 12
601+
assert metrics["tokens"] == 27
602+
assert metrics["prompt_cache_creation_tokens"] == 7
603+
assert metrics["prompt_cache_creation_ephemeral_5m_tokens"] == 3
604+
assert metrics["prompt_cache_creation_ephemeral_1h_tokens"] == 4
605+
assert metrics["server_tool_use_web_search_requests"] == 2
606+
assert metrics["server_tool_use_web_fetch_requests"] == 1
607+
assert "service_tier" not in metrics
608+
assert metadata == {
609+
"usage_service_tier": "standard",
610+
"usage_inference_geo": "not_available",
611+
}
612+
613+
614+
def test_extract_anthropic_usage_skips_empty_usage():
615+
metrics, metadata = extract_anthropic_usage(SimpleNamespace())
616+
617+
assert metrics == {}
618+
assert metadata == {}
619+
554620

555621
def _make_batch_requests():
556622
return [

py/src/braintrust/integrations/anthropic/tracing.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import warnings
44
from contextlib import contextmanager
55

6-
from braintrust.integrations.anthropic._utils import Wrapper, extract_anthropic_usage, finalize_anthropic_tokens
6+
from braintrust.integrations.anthropic._utils import Wrapper, extract_anthropic_usage
77
from braintrust.logger import NOOP_SPAN, log_exc_info_to_span, start_span
88

99

@@ -445,7 +445,7 @@ def _start_span(name, kwargs):
445445
def _log_message_to_span(message, span, time_to_first_token: float | None = None):
446446
with _catch_exceptions():
447447
usage = getattr(message, "usage", {})
448-
metrics = finalize_anthropic_tokens(extract_anthropic_usage(usage))
448+
metrics, metadata = extract_anthropic_usage(usage)
449449

450450
if time_to_first_token is not None:
451451
metrics["time_to_first_token"] = time_to_first_token
@@ -462,7 +462,7 @@ def _log_message_to_span(message, span, time_to_first_token: float | None = None
462462
if v is not None
463463
} or None
464464

465-
span.log(output=output, metrics=metrics)
465+
span.log(output=output, metrics=metrics, metadata=metadata)
466466

467467

468468
@contextmanager

py/src/braintrust/integrations/claude_agent_sdk/test_claude_agent_sdk.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,14 +24,14 @@
2424
print("Claude Agent SDK not installed, skipping integration tests")
2525

2626
from braintrust import logger
27+
from braintrust.integrations.anthropic._utils import extract_anthropic_usage
2728
from braintrust.integrations.claude_agent_sdk import setup_claude_agent_sdk
2829
from braintrust.integrations.claude_agent_sdk._test_transport import make_cassette_transport
2930
from braintrust.integrations.claude_agent_sdk.tracing import (
3031
ToolSpanTracker,
3132
_build_llm_input,
3233
_create_client_wrapper_class,
3334
_create_tool_wrapper_class,
34-
_extract_usage_from_result_message,
3535
_parse_tool_name,
3636
_serialize_content_blocks,
3737
_serialize_system_message,
@@ -184,6 +184,8 @@ async def calculator_handler(args):
184184
for metric_name in ("prompt_tokens", "completion_tokens", "tokens"):
185185
if metric_name in llm_span.get("metrics", {}):
186186
assert llm_span["metrics"][metric_name] > 0
187+
assert any(llm_span.get("metadata", {}).get("usage_service_tier") == "standard" for llm_span in llm_spans)
188+
assert any("usage_inference_geo" in llm_span.get("metadata", {}) for llm_span in llm_spans)
187189
tool_spans = [s for s in spans if s["span_attributes"]["type"] == SpanTypeAttribute.TOOL]
188190
for tool_span in tool_spans:
189191
assert tool_span["span_attributes"]["name"] == "calculator"
@@ -1828,9 +1830,9 @@ def test_serialize_system_message_extracts_known_fields(message, expected):
18281830
assert _serialize_system_message(message) == expected
18291831

18301832

1831-
def test_extract_usage_from_result_message_normalizes_anthropic_tokens():
1832-
metrics = _extract_usage_from_result_message(
1833-
ResultMessage(input_tokens=5, output_tokens=3, cache_creation_input_tokens=2)
1833+
def test_extract_anthropic_usage_normalizes_claude_result_message_usage():
1834+
metrics, metadata = extract_anthropic_usage(
1835+
ResultMessage(input_tokens=5, output_tokens=3, cache_creation_input_tokens=2).usage
18341836
)
18351837

18361838
assert metrics == {
@@ -1839,6 +1841,7 @@ def test_extract_usage_from_result_message_normalizes_anthropic_tokens():
18391841
"prompt_cache_creation_tokens": 2.0,
18401842
"tokens": 10.0,
18411843
}
1844+
assert metadata == {}
18421845

18431846

18441847
@pytest.mark.parametrize(

py/src/braintrust/integrations/claude_agent_sdk/tracing.py

Lines changed: 4 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from collections.abc import AsyncGenerator, AsyncIterable
88
from typing import Any
99

10-
from braintrust.integrations.anthropic._utils import Wrapper, extract_anthropic_usage, finalize_anthropic_tokens
10+
from braintrust.integrations.anthropic._utils import Wrapper, extract_anthropic_usage
1111
from braintrust.integrations.claude_agent_sdk._constants import (
1212
ANTHROPIC_MESSAGES_CREATE_SPAN_NAME,
1313
CLAUDE_AGENT_TASK_SPAN_NAME,
@@ -711,10 +711,10 @@ def _handle_user(self, message: Any) -> None:
711711
def _handle_result(self, message: Any) -> None:
712712
self._active_key = None
713713
if hasattr(message, "usage"):
714-
usage_metrics = _extract_usage_from_result_message(message)
714+
usage_metrics, usage_metadata = extract_anthropic_usage(message.usage)
715715
ctx = self._get_context(None)
716-
if ctx.llm_span and usage_metrics:
717-
ctx.llm_span.log(metrics=usage_metrics)
716+
if ctx.llm_span and (usage_metrics or usage_metadata):
717+
ctx.llm_span.log(metrics=usage_metrics or None, metadata=usage_metadata or None)
718718
result_metadata = {
719719
k: v
720720
for k, v in {
@@ -1203,25 +1203,6 @@ def _serialize_content_blocks(content: Any) -> Any:
12031203
return content
12041204

12051205

1206-
def _extract_usage_from_result_message(result_message: Any) -> dict[str, float]:
1207-
"""Extracts and normalizes usage metrics from a ResultMessage.
1208-
1209-
Uses shared Anthropic utilities for consistent metric extraction.
1210-
"""
1211-
if not hasattr(result_message, "usage"):
1212-
return {}
1213-
1214-
usage = result_message.usage
1215-
if not usage:
1216-
return {}
1217-
1218-
metrics = extract_anthropic_usage(usage)
1219-
if metrics:
1220-
metrics = finalize_anthropic_tokens(metrics)
1221-
1222-
return metrics
1223-
1224-
12251206
def _build_llm_input(prompt: Any, conversation_history: list[dict[str, Any]]) -> list[dict[str, Any]] | None:
12261207
"""Builds the input array for an LLM span from the initial prompt and conversation history.
12271208

0 commit comments

Comments (0)