Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 40 additions & 8 deletions src/anthropic/lib/streaming/_messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def __init__(
self._iterator = self.__stream__()
self.__final_message_snapshot: ParsedMessage[ResponseFormatT] | None = None
self.__output_format = output_format
self._emit_events = True

@property
def response(self) -> httpx.Response:
Expand Down Expand Up @@ -117,6 +118,7 @@ def get_final_text(self) -> str:

def until_done(self) -> None:
"""Blocks until the stream has been consumed"""
self._emit_events = False
consume_sync_iterator(self)

# properties
Expand All @@ -131,11 +133,13 @@ def __stream__(self) -> Iterator[ParsedMessageStreamEvent[ResponseFormatT]]:
event=sse_event,
current_snapshot=self.__final_message_snapshot,
output_format=self.__output_format,
emit_events=self._emit_events,
)

events_to_fire = build_events(event=sse_event, message_snapshot=self.current_message_snapshot)
for event in events_to_fire:
yield event
if self._emit_events:
events_to_fire = build_events(event=sse_event, message_snapshot=self.current_message_snapshot)
for event in events_to_fire:
yield event

def __stream_text__(self) -> Iterator[str]:
for chunk in self:
Expand Down Expand Up @@ -199,6 +203,7 @@ def __init__(
self._iterator = self.__stream__()
self.__final_message_snapshot: ParsedMessage[ResponseFormatT] | None = None
self.__output_format = output_format
self._emit_events = True

@property
def response(self) -> httpx.Response:
Expand Down Expand Up @@ -265,6 +270,7 @@ async def get_final_text(self) -> str:

async def until_done(self) -> None:
"""Waits until the stream has been consumed"""
self._emit_events = False
await consume_async_iterator(self)

# properties
Expand All @@ -279,11 +285,13 @@ async def __stream__(self) -> AsyncIterator[ParsedMessageStreamEvent[ResponseFor
event=sse_event,
current_snapshot=self.__final_message_snapshot,
output_format=self.__output_format,
emit_events=self._emit_events,
)

events_to_fire = build_events(event=sse_event, message_snapshot=self.current_message_snapshot)
for event in events_to_fire:
yield event
if self._emit_events:
events_to_fire = build_events(event=sse_event, message_snapshot=self.current_message_snapshot)
for event in events_to_fire:
yield event

async def __stream_text__(self) -> AsyncIterator[str]:
async for chunk in self:
Expand Down Expand Up @@ -423,6 +431,7 @@ def build_events(


JSON_BUF_PROPERTY = "__json_buf"
TEXT_BUF_PROPERTY = "__text_buf"

TRACKS_TOOL_INPUT = (
ToolUseBlock,
Expand All @@ -435,6 +444,7 @@ def accumulate_event(
event: RawMessageStreamEvent,
current_snapshot: ParsedMessage[ResponseFormatT] | None,
output_format: ResponseFormatT | NotGiven = NOT_GIVEN,
emit_events: bool = True,
) -> ParsedMessage[ResponseFormatT]:
if not isinstance(cast(Any, event), BaseModel):
event = cast( # pyright: ignore[reportUnnecessaryCast]
Expand Down Expand Up @@ -465,7 +475,16 @@ def accumulate_event(
content = current_snapshot.content[event.index]
if event.delta.type == "text_delta":
if content.type == "text":
content.text += event.delta.text
if emit_events:
content.text += event.delta.text
else:
# Drain path: accumulate into a list to avoid O(n²) attribute-target concatenation.
# The buffer is joined once at content_block_stop.
text_buf = cast("list[str] | None", getattr(content, TEXT_BUF_PROPERTY, None))
if text_buf is None:
text_buf = [content.text] if content.text else []
setattr(content, TEXT_BUF_PROPERTY, text_buf)
text_buf.append(event.delta.text)
elif event.delta.type == "input_json_delta":
if isinstance(content, TRACKS_TOOL_INPUT):
from jiter import from_json
Expand All @@ -476,7 +495,9 @@ def accumulate_event(
json_buf = cast(bytes, getattr(content, JSON_BUF_PROPERTY, b""))
json_buf += bytes(event.delta.partial_json, "utf-8")

if json_buf:
if emit_events and json_buf:
# Iteration path: keep content.input updated on each delta so that
# InputJsonEvent.snapshot reflects the partial state.
content.input = from_json(json_buf, partial_mode=True)

setattr(content, JSON_BUF_PROPERTY, json_buf)
Expand All @@ -498,6 +519,17 @@ def accumulate_event(
assert_never(event.delta)
elif event.type == "content_block_stop":
content_block = current_snapshot.content[event.index]
if not emit_events:
# Drain path: finalize deferred text and JSON buffers accumulated above.
text_buf = cast("list[str] | None", getattr(content_block, TEXT_BUF_PROPERTY, None))
if text_buf is not None:
content_block.text = "".join(text_buf) # type: ignore[union-attr]
if isinstance(content_block, TRACKS_TOOL_INPUT):
from jiter import from_json

json_buf = cast(bytes, getattr(content_block, JSON_BUF_PROPERTY, b""))
if json_buf:
content_block.input = from_json(json_buf, partial_mode=False)
if content_block.type == "text" and is_given(output_format):
content_block.parsed_output = parse_text(content_block.text, output_format)
elif event.type == "message_delta":
Expand Down
74 changes: 74 additions & 0 deletions tests/lib/streaming/test_messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,80 @@ async def test_refusal_stop_details_propagated(self, respx_mock: MockRouter) ->
assert_refusal_response(await stream.get_final_message())


class TestDrainPath:
"""get_final_message() / until_done() must produce the same result as full event iteration."""

@pytest.mark.respx(base_url=base_url)
def test_get_final_message_without_iteration_basic(self, respx_mock: MockRouter) -> None:
respx_mock.post("/v1/messages").mock(
return_value=httpx.Response(200, content=get_response("basic_response.txt"))
)

with sync_client.messages.stream(
max_tokens=1024,
messages=[{"role": "user", "content": "Say hello there!"}],
model="claude-3-opus-latest",
) as stream:
message = stream.get_final_message()

assert message.content[0].type == "text"
assert message.content[0].text == "Hello there!"

@pytest.mark.respx(base_url=base_url)
def test_get_final_message_without_iteration_tool_use(self, respx_mock: MockRouter) -> None:
respx_mock.post("/v1/messages").mock(
return_value=httpx.Response(200, content=get_response("tool_use_response.txt"))
)

with sync_client.messages.stream(
max_tokens=1024,
messages=[{"role": "user", "content": "What's the weather in Paris?"}],
model="claude-sonnet-4-5",
) as stream:
message = stream.get_final_message()

assert message.content[0].type == "text"
assert message.content[0].text == "I'll check the current weather in Paris for you."
assert message.content[1].type == "tool_use"
assert message.content[1].input == {"location": "Paris"}

@pytest.mark.asyncio
@pytest.mark.respx(base_url=base_url)
async def test_async_get_final_message_without_iteration_basic(self, respx_mock: MockRouter) -> None:
respx_mock.post("/v1/messages").mock(
return_value=httpx.Response(200, content=to_async_iter(get_response("basic_response.txt")))
)

async with async_client.messages.stream(
max_tokens=1024,
messages=[{"role": "user", "content": "Say hello there!"}],
model="claude-3-opus-latest",
) as stream:
message = await stream.get_final_message()

assert message.content[0].type == "text"
assert message.content[0].text == "Hello there!"

@pytest.mark.asyncio
@pytest.mark.respx(base_url=base_url)
async def test_async_get_final_message_without_iteration_tool_use(self, respx_mock: MockRouter) -> None:
respx_mock.post("/v1/messages").mock(
return_value=httpx.Response(200, content=to_async_iter(get_response("tool_use_response.txt")))
)

async with async_client.messages.stream(
max_tokens=1024,
messages=[{"role": "user", "content": "What's the weather in Paris?"}],
model="claude-sonnet-4-5",
) as stream:
message = await stream.get_final_message()

assert message.content[0].type == "text"
assert message.content[0].text == "I'll check the current weather in Paris for you."
assert message.content[1].type == "tool_use"
assert message.content[1].input == {"location": "Paris"}


@pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"])
def test_stream_method_definition_in_sync(sync: bool) -> None:
client: Anthropic | AsyncAnthropic = sync_client if sync else async_client
Expand Down