diff --git a/src/langtrace_python_sdk/instrumentation/aws_bedrock/patch.py b/src/langtrace_python_sdk/instrumentation/aws_bedrock/patch.py index c6fc78e..d2b98b5 100644 --- a/src/langtrace_python_sdk/instrumentation/aws_bedrock/patch.py +++ b/src/langtrace_python_sdk/instrumentation/aws_bedrock/patch.py @@ -524,13 +524,15 @@ def __init__( stream_done_callback=None, ): super().__init__(response) - self._stream_done_callback = stream_done_callback self._accumulating_body = {"generation": ""} + self.last_chunk = None def __iter__(self): for event in self.__wrapped__: + # Process the event self._process_event(event) + # Yield the original event immediately yield event def _process_event(self, event): @@ -545,7 +547,11 @@ def _process_event(self, event): self._stream_done_callback(decoded_chunk) return if "generation" in decoded_chunk: - self._accumulating_body["generation"] += decoded_chunk.get("generation") + generation = decoded_chunk.get("generation") + if self.last_chunk == generation: + return + self.last_chunk = generation + self._accumulating_body["generation"] += generation if type == "message_start": self._accumulating_body = decoded_chunk.get("message") @@ -554,9 +560,11 @@ def _process_event(self, event): decoded_chunk.get("content_block") ) elif type == "content_block_delta": - self._accumulating_body["content"][-1]["text"] += decoded_chunk.get( - "delta" - ).get("text") + text = decoded_chunk.get("delta").get("text") + if self.last_chunk == text: + return + self.last_chunk = text + self._accumulating_body["content"][-1]["text"] += text elif self.has_finished(type, decoded_chunk): self._accumulating_body["invocation_metrics"] = decoded_chunk.get( diff --git a/src/langtrace_python_sdk/version.py b/src/langtrace_python_sdk/version.py index c10bada..15bd66f 100644 --- a/src/langtrace_python_sdk/version.py +++ b/src/langtrace_python_sdk/version.py @@ -1 +1 @@ -__version__ = "3.8.19" +__version__ = "3.8.20" diff --git a/src/tests/aws_bedrock/conftest.py b/src/tests/aws_bedrock/conftest.py index c049192..5bfc806 100644 --- a/src/tests/aws_bedrock/conftest.py +++ b/src/tests/aws_bedrock/conftest.py @@ -14,7 +14,9 @@ @pytest.fixture(autouse=True) def environment(): if not os.getenv("AWS_ACCESS_KEY_ID"): - os.environ["AWS_ACCESS_KEY_ID"] = "test_api_key" + os.environ["AWS_ACCESS_KEY_ID"] = "test_aws_access_key_id" + if not os.getenv("AWS_SECRET_ACCESS_KEY"): + os.environ["AWS_SECRET_ACCESS_KEY"] = "test_aws_secret_access_key" @pytest.fixture