From 51c454c460d3a81d818f5fdee76a08f45482351a Mon Sep 17 00:00:00 2001
From: Piyush Jain
Date: Mon, 1 Apr 2024 14:09:25 -0700
Subject: [PATCH 1/3] Added streaming function to llm.

---
 .../langchain_aws/llms/sagemaker_endpoint.py | 93 +++++++++++--------
 1 file changed, 55 insertions(+), 38 deletions(-)

diff --git a/libs/aws/langchain_aws/llms/sagemaker_endpoint.py b/libs/aws/langchain_aws/llms/sagemaker_endpoint.py
index b0994730..27879f19 100644
--- a/libs/aws/langchain_aws/llms/sagemaker_endpoint.py
+++ b/libs/aws/langchain_aws/llms/sagemaker_endpoint.py
@@ -1,12 +1,12 @@
 """Sagemaker InvokeEndpoint API."""
 import io
-import json
 import re
 from abc import abstractmethod
 from typing import Any, Dict, Generic, Iterator, List, Mapping, Optional, TypeVar, Union
 
 from langchain_core.callbacks import CallbackManagerForLLMRun
 from langchain_core.language_models.llms import LLM
+from langchain_core.outputs import GenerationChunk
 from langchain_core.pydantic_v1 import Extra, root_validator
 
 INPUT_TYPE = TypeVar("INPUT_TYPE", bound=Union[str, List[str]])
@@ -304,6 +304,41 @@ def _llm_type(self) -> str:
         """Return type of llm."""
         return "sagemaker_endpoint"
 
+    def _stream(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> Iterator[GenerationChunk]:
+        _model_kwargs = self.model_kwargs or {}
+        _model_kwargs = {**_model_kwargs, **kwargs}
+        _endpoint_kwargs = self.endpoint_kwargs or {}
+
+        try:
+            resp = self.client.invoke_endpoint_with_response_stream(
+                EndpointName=self.endpoint_name,
+                Body=self.content_handler.transform_input(prompt, _model_kwargs),
+                ContentType=self.content_handler.content_type,
+                **_endpoint_kwargs,
+            )
+            iterator = LineIterator(resp["Body"])
+
+            for line in iterator:
+                text = self.content_handler.transform_output(line)
+
+                if stop is not None:
+                    text = enforce_stop_tokens(text, stop)
+
+                if text:
+                    chunk = GenerationChunk(text=text)
+                    yield chunk
+                    if run_manager:
+                        run_manager.on_llm_new_token(chunk.text)
+
+        except Exception as e:
+            raise ValueError(f"Error raised by streaming inference endpoint: {e}")
+
     def _call(
         self,
         prompt: str,
@@ -334,42 +369,24 @@ def _call(
         accepts = self.content_handler.accepts
 
         if self.streaming and run_manager:
-            try:
-                resp = self.client.invoke_endpoint_with_response_stream(
-                    EndpointName=self.endpoint_name,
-                    Body=body,
-                    ContentType=self.content_handler.content_type,
-                    **_endpoint_kwargs,
-                )
-                iterator = LineIterator(resp["Body"])
-                current_completion: str = ""
-                for line in iterator:
-                    resp = json.loads(line)
-                    resp_output = resp.get("outputs")[0]
-                    if stop is not None:
-                        # Uses same approach as below
-                        resp_output = enforce_stop_tokens(resp_output, stop)
-                    current_completion += resp_output
-                    run_manager.on_llm_new_token(resp_output)
-                return current_completion
-            except Exception as e:
-                raise ValueError(f"Error raised by streaming inference endpoint: {e}")
-        else:
-            try:
-                response = self.client.invoke_endpoint(
-                    EndpointName=self.endpoint_name,
-                    Body=body,
-                    ContentType=content_type,
-                    Accept=accepts,
-                    **_endpoint_kwargs,
-                )
-            except Exception as e:
-                raise ValueError(f"Error raised by inference endpoint: {e}")
+            completion: str = ""
+            for chunk in self._stream(prompt, stop, run_manager, **kwargs):
+                completion += chunk.text
+            return completion
+
+        try:
+            response = self.client.invoke_endpoint(
+                EndpointName=self.endpoint_name,
+                Body=body,
+                ContentType=content_type,
+                Accept=accepts,
+                **_endpoint_kwargs,
+            )
+        except Exception as e:
+            raise ValueError(f"Error raised by inference endpoint: {e}")
 
-        text = self.content_handler.transform_output(response["Body"])
-        if stop is not None:
-            # This is a bit hacky, but I can't figure out a better way to enforce
-            # stop tokens when making calls to the sagemaker endpoint.
-            text = enforce_stop_tokens(text, stop)
+        text = self.content_handler.transform_output(response["Body"])
+        if stop is not None:
+            text = enforce_stop_tokens(text, stop)
 
-        return text
+        return text
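A minimal usage sketch of how the new _stream method is driven through the standard stream() interface. The endpoint name, region, and response format below are assumptions for illustration only; the content handler mirrors the one used in the integration tests added later in this series:

    import json
    from typing import Dict

    from langchain_aws.llms import SagemakerEndpoint
    from langchain_aws.llms.sagemaker_endpoint import LLMContentHandler


    class ContentHandler(LLMContentHandler):
        accepts = "application/json"
        content_type = "application/json"

        def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:
            # Request payload shape is model-specific; this mirrors the tests.
            return json.dumps({"inputs": prompt, **model_kwargs}).encode()

        def transform_output(self, output: bytes) -> str:
            # Each streamed line is assumed to be JSON with an "outputs" list.
            return json.loads(output).get("outputs")[0]


    llm = SagemakerEndpoint(
        endpoint_name="my-endpoint",  # hypothetical deployed endpoint
        region_name="us-west-2",      # assumed region
        content_handler=ContentHandler(),
        model_kwargs={"parameters": {"max_new_tokens": 50}},
    )

    # stream() drives the new _stream() method, yielding text chunks as they
    # arrive instead of waiting for the full completion.
    for chunk in llm.stream("What is a SageMaker endpoint?"):
        print(chunk, end="", flush=True)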
ValueError(f"Error raised by inference endpoint: {e}") - text = self.content_handler.transform_output(response["Body"]) - if stop is not None: - # This is a bit hacky, but I can't figure out a better way to enforce - # stop tokens when making calls to the sagemaker endpoint. - text = enforce_stop_tokens(text, stop) + text = self.content_handler.transform_output(response["Body"]) + if stop is not None: + text = enforce_stop_tokens(text, stop) - return text + return text From 08749b46885e5591e1349b72f55170e2f985819a Mon Sep 17 00:00:00 2001 From: Piyush Jain Date: Tue, 2 Apr 2024 10:27:43 -0700 Subject: [PATCH 2/3] Added streaming, integration tests. --- .../llms/test_sagemaker_endpoint.py | 116 ++++++++++++++++-- 1 file changed, 109 insertions(+), 7 deletions(-) diff --git a/libs/aws/tests/integration_tests/llms/test_sagemaker_endpoint.py b/libs/aws/tests/integration_tests/llms/test_sagemaker_endpoint.py index cb5e32ef..8142bb2b 100644 --- a/libs/aws/tests/integration_tests/llms/test_sagemaker_endpoint.py +++ b/libs/aws/tests/integration_tests/llms/test_sagemaker_endpoint.py @@ -3,19 +3,121 @@ from langchain_aws.llms import SagemakerEndpoint from langchain_aws.llms.sagemaker_endpoint import LLMContentHandler +from botocore.stub import Stubber, ANY + +from unittest.mock import Mock + +import json +import io + + +class DefaultHandler(LLMContentHandler): + accepts = "application/json" + content_type = "application/json" -class ContentHandler(LLMContentHandler): def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes: - return b"" + return prompt.encode() def transform_output(self, output: bytes) -> str: - return "" + body = json.loads(output.read()) + return body[0]["generated_text"] + +def create_mock_raw_stream(*data): + raw_stream = Mock() + def generator(): + yield from data + raw_stream.stream = generator + return raw_stream + +def test_sagemaker_endpoint_invoke() -> None: + + client = Mock() + response = { + 'ContentType': 'application/json', + 'Body': io.StringIO('[{"generated_text": "SageMaker Endpoint"}]') + } + client.invoke_endpoint.return_value = response -def test_sagemaker_endpoint_name_param() -> None: llm = SagemakerEndpoint( - endpoint_name="foo", - content_handler=ContentHandler(), + endpoint_name="my-endpoint", + region_name="us-west-2", + content_handler=DefaultHandler(), + model_kwargs={ + "parameters": { + "max_new_tokens": 50, + } + }, + client=client + ) + + service_response = llm.invoke("What is Sagemaker endpoints?") + + assert service_response == "SageMaker Endpoint" + client.invoke_endpoint.assert_called_once_with( + EndpointName='my-endpoint', + Body=b'What is Sagemaker endpoints?', + ContentType='application/json', Accept='application/json' + ) + + +def test_sagemaker_endpoint_stream() -> None: + class ContentHandler(LLMContentHandler): + accepts = "application/json" + content_type = "application/json" + + def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes: + body = { + 'inputs': prompt, + **model_kwargs + } + return body + + def transform_output(self, output: bytes) -> str: + body = json.loads(output) + return body.get("outputs")[0] + + + body = ( + {'PayloadPart': {'Bytes': b'{"outputs": ["S"]}\n'}}, + {'PayloadPart': {'Bytes': b'{"outputs": ["age"]}\n'}}, + {'PayloadPart': {'Bytes': b'{"outputs": ["Maker"]}\n'}} + ) + + response = { + 'ContentType': 'application/json', + 'Body': body + } + + client = Mock() + client.invoke_endpoint_with_response_stream.return_value = response + + llm = SagemakerEndpoint( + 
endpoint_name="my-endpoint", region_name="us-west-2", + content_handler=ContentHandler(), + client=client, + model_kwargs={ + "parameters": { + "max_new_tokens": 50 + } + } + ) + + + chunks = ['S', 'age', 'Maker'] + service_chunks = [] + + for chunk in llm.stream("What is Sagemaker endpoints?"): + service_chunks.append(chunk) + + assert service_chunks == chunks + client.invoke_endpoint_with_response_stream.assert_called_once_with( + EndpointName='my-endpoint', + Body={ + 'inputs': 'What is Sagemaker endpoints?', + 'parameters': {'max_new_tokens': 50} + }, + ContentType='application/json' ) - assert llm.endpoint_name == "foo" + From c197816c74aed40850b55808b95348ad568e245f Mon Sep 17 00:00:00 2001 From: Piyush Jain Date: Tue, 2 Apr 2024 10:42:01 -0700 Subject: [PATCH 3/3] Fixed lint. --- .../llms/test_sagemaker_endpoint.py | 74 +++++++------------ 1 file changed, 25 insertions(+), 49 deletions(-) diff --git a/libs/aws/tests/integration_tests/llms/test_sagemaker_endpoint.py b/libs/aws/tests/integration_tests/llms/test_sagemaker_endpoint.py index 8142bb2b..bac022a0 100644 --- a/libs/aws/tests/integration_tests/llms/test_sagemaker_endpoint.py +++ b/libs/aws/tests/integration_tests/llms/test_sagemaker_endpoint.py @@ -1,15 +1,10 @@ +import json from typing import Dict +from unittest.mock import Mock from langchain_aws.llms import SagemakerEndpoint from langchain_aws.llms.sagemaker_endpoint import LLMContentHandler -from botocore.stub import Stubber, ANY - -from unittest.mock import Mock - -import json -import io - class DefaultHandler(LLMContentHandler): accepts = "application/json" @@ -19,23 +14,15 @@ def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes: return prompt.encode() def transform_output(self, output: bytes) -> str: - body = json.loads(output.read()) + body = json.loads(output.decode()) return body[0]["generated_text"] - -def create_mock_raw_stream(*data): - raw_stream = Mock() - def generator(): - yield from data - raw_stream.stream = generator - return raw_stream def test_sagemaker_endpoint_invoke() -> None: - client = Mock() response = { - 'ContentType': 'application/json', - 'Body': io.StringIO('[{"generated_text": "SageMaker Endpoint"}]') + "ContentType": "application/json", + "Body": b'[{"generated_text": "SageMaker Endpoint"}]', } client.invoke_endpoint.return_value = response @@ -48,16 +35,17 @@ def test_sagemaker_endpoint_invoke() -> None: "max_new_tokens": 50, } }, - client=client + client=client, ) service_response = llm.invoke("What is Sagemaker endpoints?") assert service_response == "SageMaker Endpoint" client.invoke_endpoint.assert_called_once_with( - EndpointName='my-endpoint', - Body=b'What is Sagemaker endpoints?', - ContentType='application/json', Accept='application/json' + EndpointName="my-endpoint", + Body=b"What is Sagemaker endpoints?", + ContentType="application/json", + Accept="application/json", ) @@ -67,27 +55,20 @@ class ContentHandler(LLMContentHandler): content_type = "application/json" def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes: - body = { - 'inputs': prompt, - **model_kwargs - } - return body + body = json.dumps({"inputs": prompt, **model_kwargs}) + return body.encode() def transform_output(self, output: bytes) -> str: body = json.loads(output) return body.get("outputs")[0] - body = ( - {'PayloadPart': {'Bytes': b'{"outputs": ["S"]}\n'}}, - {'PayloadPart': {'Bytes': b'{"outputs": ["age"]}\n'}}, - {'PayloadPart': {'Bytes': b'{"outputs": ["Maker"]}\n'}} + {"PayloadPart": {"Bytes": b'{"outputs": ["S"]}\n'}}, + 
{"PayloadPart": {"Bytes": b'{"outputs": ["age"]}\n'}}, + {"PayloadPart": {"Bytes": b'{"outputs": ["Maker"]}\n'}}, ) - response = { - 'ContentType': 'application/json', - 'Body': body - } + response = {"ContentType": "application/json", "Body": body} client = Mock() client.invoke_endpoint_with_response_stream.return_value = response @@ -97,15 +78,14 @@ def transform_output(self, output: bytes) -> str: region_name="us-west-2", content_handler=ContentHandler(), client=client, - model_kwargs={ - "parameters": { - "max_new_tokens": 50 - } - } + model_kwargs={"parameters": {"max_new_tokens": 50}}, ) - - chunks = ['S', 'age', 'Maker'] + expected_body = json.dumps( + {"inputs": "What is Sagemaker endpoints?", "parameters": {"max_new_tokens": 50}} + ).encode() + + chunks = ["S", "age", "Maker"] service_chunks = [] for chunk in llm.stream("What is Sagemaker endpoints?"): @@ -113,11 +93,7 @@ def transform_output(self, output: bytes) -> str: assert service_chunks == chunks client.invoke_endpoint_with_response_stream.assert_called_once_with( - EndpointName='my-endpoint', - Body={ - 'inputs': 'What is Sagemaker endpoints?', - 'parameters': {'max_new_tokens': 50} - }, - ContentType='application/json' + EndpointName="my-endpoint", + Body=expected_body, + ContentType="application/json", ) -