Skip to content

Commit d34435b

Browse files
fern-supportclaude
andcommitted
feat: Emit message-start and content-start events in OCI streaming
Complete the V2 streaming protocol lifecycle: message-start → content-start → content-delta* → content-end → message-end Previously only content-delta, content-end, and message-end were emitted, causing consumers expecting message-start to fail. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 02c65b4 commit d34435b

2 files changed

Lines changed: 67 additions & 3 deletions

File tree

src/cohere/oci_client.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -770,13 +770,20 @@ def transform_oci_stream_wrapper(
770770
"""
771771
Wrap OCI stream and transform events to Cohere V2 format.
772772
773+
Emits the full V2 streaming lifecycle:
774+
message-start -> content-start -> content-delta* -> content-end -> message-end
775+
773776
Args:
774777
stream: Original OCI stream iterator
775778
endpoint: Cohere endpoint name
776779
777780
Yields:
778781
Bytes of transformed streaming events
779782
"""
783+
import logging
784+
785+
generation_id = str(uuid.uuid4())
786+
emitted_start = False
780787
buffer = b""
781788
for chunk in stream:
782789
buffer += chunk
@@ -795,14 +802,39 @@ def transform_oci_stream_wrapper(
795802
try:
796803
oci_event = json.loads(data_str)
797804
except json.JSONDecodeError:
798-
import logging
799805
logging.warning(
800806
"OCI stream: failed to parse SSE event as JSON (endpoint=%s, data=%r)",
801807
endpoint, data_str[:200],
802808
)
803809
continue
804810

805811
try:
812+
# Emit message-start and content-start before first content delta
813+
if not emitted_start:
814+
# Detect content type from first event
815+
content_type = "text"
816+
if "message" in oci_event and "content" in oci_event["message"]:
817+
content_list = oci_event["message"]["content"]
818+
if content_list and isinstance(content_list, list) and len(content_list) > 0:
819+
oci_type = content_list[0].get("type", "TEXT").upper()
820+
if oci_type == "THINKING":
821+
content_type = "thinking"
822+
823+
message_start = {
824+
"type": "message-start",
825+
"id": generation_id,
826+
"delta": {"message": {"role": "assistant"}},
827+
}
828+
yield b"data: " + json.dumps(message_start).encode("utf-8") + b"\n\n"
829+
830+
content_start = {
831+
"type": "content-start",
832+
"index": 0,
833+
"delta": {"message": {"content": {"type": content_type}}},
834+
}
835+
yield b"data: " + json.dumps(content_start).encode("utf-8") + b"\n\n"
836+
emitted_start = True
837+
806838
cohere_event = transform_stream_event(endpoint, oci_event)
807839
yield b"data: " + json.dumps(cohere_event).encode("utf-8") + b"\n\n"
808840
except Exception as e:

tests/test_oci_client.py

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -555,6 +555,38 @@ def test_load_oci_config_missing_private_key_raises(self):
555555
)
556556
self.assertIn("oci_private_key_path", str(ctx.exception))
557557

558+
def test_stream_wrapper_emits_full_event_lifecycle(self):
559+
"""Test that stream emits message-start, content-start, content-delta, content-end, message-end."""
560+
import json
561+
from cohere.oci_client import transform_oci_stream_wrapper
562+
563+
chunks = [
564+
b'data: {"message": {"content": [{"type": "TEXT", "text": "Hello"}]}}\n',
565+
b'data: {"message": {"content": [{"type": "TEXT", "text": " world"}]}, "finishReason": "COMPLETE"}\n',
566+
b'data: [DONE]\n',
567+
]
568+
569+
events = []
570+
for raw in transform_oci_stream_wrapper(iter(chunks), "chat"):
571+
line = raw.decode("utf-8").strip()
572+
if line.startswith("data: "):
573+
events.append(json.loads(line[6:]))
574+
575+
event_types = [e["type"] for e in events]
576+
self.assertEqual(event_types[0], "message-start")
577+
self.assertEqual(event_types[1], "content-start")
578+
self.assertEqual(event_types[2], "content-delta")
579+
self.assertEqual(event_types[3], "content-end")
580+
self.assertEqual(event_types[4], "message-end")
581+
582+
# Verify message-start has id and role
583+
self.assertIn("id", events[0])
584+
self.assertEqual(events[0]["delta"]["message"]["role"], "assistant")
585+
586+
# Verify content-start has index and type
587+
self.assertEqual(events[1]["index"], 0)
588+
self.assertEqual(events[1]["delta"]["message"]["content"]["type"], "text")
589+
558590
def test_stream_wrapper_skips_malformed_json_with_warning(self):
559591
"""Test that malformed JSON in SSE stream is skipped (not silently swallowed)."""
560592
import json
@@ -566,8 +598,8 @@ def test_stream_wrapper_skips_malformed_json_with_warning(self):
566598
b'data: [DONE]\n',
567599
]
568600
events = list(transform_oci_stream_wrapper(iter(chunks), "chat"))
569-
# Should get content-delta + message-end (malformed line skipped)
570-
self.assertEqual(len(events), 2)
601+
# Should get message-start + content-start + content-delta + message-end (malformed line skipped)
602+
self.assertEqual(len(events), 4)
571603

572604
def test_stream_wrapper_raises_on_transform_error(self):
573605
"""Test that transform errors in stream produce OCI-specific error, not opaque httpx error."""

0 commit comments

Comments
 (0)