diff --git a/src/anthropic/lib/streaming/_messages.py b/src/anthropic/lib/streaming/_messages.py
index 5c0da999..97ce58ee 100644
--- a/src/anthropic/lib/streaming/_messages.py
+++ b/src/anthropic/lib/streaming/_messages.py
@@ -51,6 +51,7 @@ def __init__(
         self._iterator = self.__stream__()
         self.__final_message_snapshot: ParsedMessage[ResponseFormatT] | None = None
         self.__output_format = output_format
+        self._emit_events = True
 
     @property
     def response(self) -> httpx.Response:
@@ -117,6 +118,7 @@ def get_final_text(self) -> str:
 
     def until_done(self) -> None:
         """Blocks until the stream has been consumed"""
+        self._emit_events = False
         consume_sync_iterator(self)
 
     # properties
@@ -131,11 +133,13 @@ def __stream__(self) -> Iterator[ParsedMessageStreamEvent[ResponseFormatT]]:
                 event=sse_event,
                 current_snapshot=self.__final_message_snapshot,
                 output_format=self.__output_format,
+                emit_events=self._emit_events,
             )
 
-            events_to_fire = build_events(event=sse_event, message_snapshot=self.current_message_snapshot)
-            for event in events_to_fire:
-                yield event
+            if self._emit_events:
+                events_to_fire = build_events(event=sse_event, message_snapshot=self.current_message_snapshot)
+                for event in events_to_fire:
+                    yield event
 
     def __stream_text__(self) -> Iterator[str]:
         for chunk in self:
@@ -199,6 +203,7 @@ def __init__(
         self._iterator = self.__stream__()
         self.__final_message_snapshot: ParsedMessage[ResponseFormatT] | None = None
         self.__output_format = output_format
+        self._emit_events = True
 
     @property
     def response(self) -> httpx.Response:
@@ -265,6 +270,7 @@ async def get_final_text(self) -> str:
 
     async def until_done(self) -> None:
         """Waits until the stream has been consumed"""
+        self._emit_events = False
         await consume_async_iterator(self)
 
     # properties
@@ -279,11 +285,13 @@ async def __stream__(self) -> AsyncIterator[ParsedMessageStreamEvent[ResponseFor
                 event=sse_event,
                 current_snapshot=self.__final_message_snapshot,
                 output_format=self.__output_format,
+                emit_events=self._emit_events,
             )
 
-            events_to_fire = build_events(event=sse_event, message_snapshot=self.current_message_snapshot)
-            for event in events_to_fire:
-                yield event
+            if self._emit_events:
+                events_to_fire = build_events(event=sse_event, message_snapshot=self.current_message_snapshot)
+                for event in events_to_fire:
+                    yield event
 
     async def __stream_text__(self) -> AsyncIterator[str]:
         async for chunk in self:
@@ -423,6 +431,7 @@ def build_events(
 
 
 JSON_BUF_PROPERTY = "__json_buf"
+TEXT_BUF_PROPERTY = "__text_buf"
 
 TRACKS_TOOL_INPUT = (
     ToolUseBlock,
@@ -435,6 +444,7 @@ def accumulate_event(
     event: RawMessageStreamEvent,
     current_snapshot: ParsedMessage[ResponseFormatT] | None,
     output_format: ResponseFormatT | NotGiven = NOT_GIVEN,
+    emit_events: bool = True,
 ) -> ParsedMessage[ResponseFormatT]:
     if not isinstance(cast(Any, event), BaseModel):
         event = cast(  # pyright: ignore[reportUnnecessaryCast]
@@ -465,7 +475,16 @@ def accumulate_event(
         content = current_snapshot.content[event.index]
         if event.delta.type == "text_delta":
             if content.type == "text":
-                content.text += event.delta.text
+                if emit_events:
+                    content.text += event.delta.text
+                else:
+                    # Drain path: accumulate into a list to avoid O(n²) attribute-target concatenation.
+                    # The buffer is joined once at content_block_stop.
+                    text_buf = cast("list[str] | None", getattr(content, TEXT_BUF_PROPERTY, None))
+                    if text_buf is None:
+                        text_buf = [content.text] if content.text else []
+                        setattr(content, TEXT_BUF_PROPERTY, text_buf)
+                    text_buf.append(event.delta.text)
         elif event.delta.type == "input_json_delta":
             if isinstance(content, TRACKS_TOOL_INPUT):
                 from jiter import from_json
@@ -476,7 +495,9 @@ def accumulate_event(
                 json_buf = cast(bytes, getattr(content, JSON_BUF_PROPERTY, b""))
                 json_buf += bytes(event.delta.partial_json, "utf-8")
 
-                if json_buf:
+                if emit_events and json_buf:
+                    # Iteration path: keep content.input updated on each delta so that
+                    # InputJsonEvent.snapshot reflects the partial state.
                     content.input = from_json(json_buf, partial_mode=True)
 
                 setattr(content, JSON_BUF_PROPERTY, json_buf)
@@ -498,6 +519,19 @@ def accumulate_event(
                 assert_never(event.delta)
     elif event.type == "content_block_stop":
         content_block = current_snapshot.content[event.index]
+        if not emit_events:
+            # Drain path: finalize deferred text and JSON buffers accumulated above.
+            text_buf = cast("list[str] | None", getattr(content_block, TEXT_BUF_PROPERTY, None))
+            if text_buf is not None:
+                content_block.text = "".join(text_buf)  # type: ignore[union-attr]
+            if isinstance(content_block, TRACKS_TOOL_INPUT):
+                from jiter import from_json
+
+                json_buf = cast(bytes, getattr(content_block, JSON_BUF_PROPERTY, b""))
+                if json_buf:
+                    # partial_mode=True matches the per-delta parse on the iteration path, so
+                    # truncated tool input finalizes to the same value instead of raising.
+                    content_block.input = from_json(json_buf, partial_mode=True)
         if content_block.type == "text" and is_given(output_format):
             content_block.parsed_output = parse_text(content_block.text, output_format)
     elif event.type == "message_delta":
diff --git a/tests/lib/streaming/test_messages.py b/tests/lib/streaming/test_messages.py
index b86a3906..ec3ba438 100644
--- a/tests/lib/streaming/test_messages.py
+++ b/tests/lib/streaming/test_messages.py
@@ -306,6 +306,80 @@ async def test_refusal_stop_details_propagated(self, respx_mock: MockRouter) ->
             assert_refusal_response(await stream.get_final_message())
 
 
+class TestDrainPath:
+    """get_final_message() / until_done() must produce the same result as full event iteration."""
+
+    @pytest.mark.respx(base_url=base_url)
+    def test_get_final_message_without_iteration_basic(self, respx_mock: MockRouter) -> None:
+        respx_mock.post("/v1/messages").mock(
+            return_value=httpx.Response(200, content=get_response("basic_response.txt"))
+        )
+
+        with sync_client.messages.stream(
+            max_tokens=1024,
+            messages=[{"role": "user", "content": "Say hello there!"}],
+            model="claude-3-opus-latest",
+        ) as stream:
+            message = stream.get_final_message()
+
+        assert message.content[0].type == "text"
+        assert message.content[0].text == "Hello there!"
+
+    @pytest.mark.respx(base_url=base_url)
+    def test_get_final_message_without_iteration_tool_use(self, respx_mock: MockRouter) -> None:
+        respx_mock.post("/v1/messages").mock(
+            return_value=httpx.Response(200, content=get_response("tool_use_response.txt"))
+        )
+
+        with sync_client.messages.stream(
+            max_tokens=1024,
+            messages=[{"role": "user", "content": "What's the weather in Paris?"}],
+            model="claude-sonnet-4-5",
+        ) as stream:
+            message = stream.get_final_message()
+
+        assert message.content[0].type == "text"
+        assert message.content[0].text == "I'll check the current weather in Paris for you."
+        assert message.content[1].type == "tool_use"
+        assert message.content[1].input == {"location": "Paris"}
+
+    @pytest.mark.asyncio
+    @pytest.mark.respx(base_url=base_url)
+    async def test_async_get_final_message_without_iteration_basic(self, respx_mock: MockRouter) -> None:
+        respx_mock.post("/v1/messages").mock(
+            return_value=httpx.Response(200, content=to_async_iter(get_response("basic_response.txt")))
+        )
+
+        async with async_client.messages.stream(
+            max_tokens=1024,
+            messages=[{"role": "user", "content": "Say hello there!"}],
+            model="claude-3-opus-latest",
+        ) as stream:
+            message = await stream.get_final_message()
+
+        assert message.content[0].type == "text"
+        assert message.content[0].text == "Hello there!"
+
+    @pytest.mark.asyncio
+    @pytest.mark.respx(base_url=base_url)
+    async def test_async_get_final_message_without_iteration_tool_use(self, respx_mock: MockRouter) -> None:
+        respx_mock.post("/v1/messages").mock(
+            return_value=httpx.Response(200, content=to_async_iter(get_response("tool_use_response.txt")))
+        )
+
+        async with async_client.messages.stream(
+            max_tokens=1024,
+            messages=[{"role": "user", "content": "What's the weather in Paris?"}],
+            model="claude-sonnet-4-5",
+        ) as stream:
+            message = await stream.get_final_message()
+
+        assert message.content[0].type == "text"
+        assert message.content[0].text == "I'll check the current weather in Paris for you."
+        assert message.content[1].type == "tool_use"
+        assert message.content[1].input == {"location": "Paris"}
+
+
 @pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"])
 def test_stream_method_definition_in_sync(sync: bool) -> None:
     client: Anthropic | AsyncAnthropic = sync_client if sync else async_client