diff --git a/libs/aws/langchain_aws/llms/sagemaker_endpoint.py b/libs/aws/langchain_aws/llms/sagemaker_endpoint.py
index b0994730..27879f19 100644
--- a/libs/aws/langchain_aws/llms/sagemaker_endpoint.py
+++ b/libs/aws/langchain_aws/llms/sagemaker_endpoint.py
@@ -1,12 +1,12 @@
 """Sagemaker InvokeEndpoint API."""
 import io
-import json
 import re
 from abc import abstractmethod
 from typing import Any, Dict, Generic, Iterator, List, Mapping, Optional, TypeVar, Union
 
 from langchain_core.callbacks import CallbackManagerForLLMRun
 from langchain_core.language_models.llms import LLM
+from langchain_core.outputs import GenerationChunk
 from langchain_core.pydantic_v1 import Extra, root_validator
 
 INPUT_TYPE = TypeVar("INPUT_TYPE", bound=Union[str, List[str]])
@@ -304,6 +304,41 @@ def _llm_type(self) -> str:
         """Return type of llm."""
         return "sagemaker_endpoint"
 
+    def _stream(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> Iterator[GenerationChunk]:
+        _model_kwargs = self.model_kwargs or {}
+        _model_kwargs = {**_model_kwargs, **kwargs}
+        _endpoint_kwargs = self.endpoint_kwargs or {}
+
+        try:
+            resp = self.client.invoke_endpoint_with_response_stream(
+                EndpointName=self.endpoint_name,
+                Body=self.content_handler.transform_input(prompt, _model_kwargs),
+                ContentType=self.content_handler.content_type,
+                **_endpoint_kwargs,
+            )
+            iterator = LineIterator(resp["Body"])
+
+            for line in iterator:
+                text = self.content_handler.transform_output(line)
+
+                if stop is not None:
+                    text = enforce_stop_tokens(text, stop)
+
+                if text:
+                    chunk = GenerationChunk(text=text)
+                    yield chunk
+                    if run_manager:
+                        run_manager.on_llm_new_token(chunk.text)
+
+        except Exception as e:
+            raise ValueError(f"Error raised by streaming inference endpoint: {e}")
+
     def _call(
         self,
         prompt: str,
@@ -334,42 +369,24 @@
         accepts = self.content_handler.accepts
 
         if self.streaming and run_manager:
-            try:
-                resp = self.client.invoke_endpoint_with_response_stream(
-                    EndpointName=self.endpoint_name,
-                    Body=body,
-                    ContentType=self.content_handler.content_type,
-                    **_endpoint_kwargs,
-                )
-                iterator = LineIterator(resp["Body"])
-                current_completion: str = ""
-                for line in iterator:
-                    resp = json.loads(line)
-                    resp_output = resp.get("outputs")[0]
-                    if stop is not None:
-                        # Uses same approach as below
-                        resp_output = enforce_stop_tokens(resp_output, stop)
-                    current_completion += resp_output
-                    run_manager.on_llm_new_token(resp_output)
-                return current_completion
-            except Exception as e:
-                raise ValueError(f"Error raised by streaming inference endpoint: {e}")
-        else:
-            try:
-                response = self.client.invoke_endpoint(
-                    EndpointName=self.endpoint_name,
-                    Body=body,
-                    ContentType=content_type,
-                    Accept=accepts,
-                    **_endpoint_kwargs,
-                )
-            except Exception as e:
-                raise ValueError(f"Error raised by inference endpoint: {e}")
+            completion: str = ""
+            for chunk in self._stream(prompt, stop, run_manager, **kwargs):
+                completion += chunk.text
+            return completion
+
+        try:
+            response = self.client.invoke_endpoint(
+                EndpointName=self.endpoint_name,
+                Body=body,
+                ContentType=content_type,
+                Accept=accepts,
+                **_endpoint_kwargs,
+            )
+        except Exception as e:
+            raise ValueError(f"Error raised by inference endpoint: {e}")
 
-            text = self.content_handler.transform_output(response["Body"])
-            if stop is not None:
-                # This is a bit hacky, but I can't figure out a better way to enforce
-                # stop tokens when making calls to the sagemaker endpoint.
-                text = enforce_stop_tokens(text, stop)
+        text = self.content_handler.transform_output(response["Body"])
+        if stop is not None:
+            text = enforce_stop_tokens(text, stop)
 
-            return text
+        return text
diff --git a/libs/aws/tests/integration_tests/llms/test_sagemaker_endpoint.py b/libs/aws/tests/integration_tests/llms/test_sagemaker_endpoint.py
index cb5e32ef..bac022a0 100644
--- a/libs/aws/tests/integration_tests/llms/test_sagemaker_endpoint.py
+++ b/libs/aws/tests/integration_tests/llms/test_sagemaker_endpoint.py
@@ -1,21 +1,99 @@
+import json
 from typing import Dict
+from unittest.mock import Mock
 
 from langchain_aws.llms import SagemakerEndpoint
 from langchain_aws.llms.sagemaker_endpoint import LLMContentHandler
 
 
-class ContentHandler(LLMContentHandler):
+class DefaultHandler(LLMContentHandler):
+    accepts = "application/json"
+    content_type = "application/json"
+
     def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:
-        return b""
+        return prompt.encode()
 
     def transform_output(self, output: bytes) -> str:
-        return ""
+        body = json.loads(output.decode())
+        return body[0]["generated_text"]
+
+
+def test_sagemaker_endpoint_invoke() -> None:
+    client = Mock()
+    response = {
+        "ContentType": "application/json",
+        "Body": b'[{"generated_text": "SageMaker Endpoint"}]',
+    }
+    client.invoke_endpoint.return_value = response
 
-def test_sagemaker_endpoint_name_param() -> None:
     llm = SagemakerEndpoint(
-        endpoint_name="foo",
-        content_handler=ContentHandler(),
+        endpoint_name="my-endpoint",
+        region_name="us-west-2",
+        content_handler=DefaultHandler(),
+        model_kwargs={
+            "parameters": {
+                "max_new_tokens": 50,
+            }
+        },
+        client=client,
+    )
+
+    service_response = llm.invoke("What is Sagemaker endpoints?")
+
+    assert service_response == "SageMaker Endpoint"
+    client.invoke_endpoint.assert_called_once_with(
+        EndpointName="my-endpoint",
+        Body=b"What is Sagemaker endpoints?",
+        ContentType="application/json",
+        Accept="application/json",
+    )
+
+
+def test_sagemaker_endpoint_stream() -> None:
+    class ContentHandler(LLMContentHandler):
+        accepts = "application/json"
+        content_type = "application/json"
+
+        def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:
+            body = json.dumps({"inputs": prompt, **model_kwargs})
+            return body.encode()
+
+        def transform_output(self, output: bytes) -> str:
+            body = json.loads(output)
+            return body.get("outputs")[0]
+
+    body = (
+        {"PayloadPart": {"Bytes": b'{"outputs": ["S"]}\n'}},
+        {"PayloadPart": {"Bytes": b'{"outputs": ["age"]}\n'}},
+        {"PayloadPart": {"Bytes": b'{"outputs": ["Maker"]}\n'}},
+    )
+
+    response = {"ContentType": "application/json", "Body": body}
+
+    client = Mock()
+    client.invoke_endpoint_with_response_stream.return_value = response
+
+    llm = SagemakerEndpoint(
+        endpoint_name="my-endpoint",
+        region_name="us-west-2",
+        content_handler=ContentHandler(),
+        client=client,
+        model_kwargs={"parameters": {"max_new_tokens": 50}},
+    )
+
+    expected_body = json.dumps(
+        {"inputs": "What is Sagemaker endpoints?", "parameters": {"max_new_tokens": 50}}
+    ).encode()
+
+    chunks = ["S", "age", "Maker"]
+    service_chunks = []
+
+    for chunk in llm.stream("What is Sagemaker endpoints?"):
+        service_chunks.append(chunk)
+
+    assert service_chunks == chunks
+    client.invoke_endpoint_with_response_stream.assert_called_once_with(
+        EndpointName="my-endpoint",
+        Body=expected_body,
+        ContentType="application/json",
     )
-    assert llm.endpoint_name == "foo"
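For context, a minimal sketch of how the new streaming path would be exercised against a live endpoint, outside the mocked tests above. The endpoint name, region, and the one-JSON-object-per-line response shape parsed by the handler are assumptions about the deployed container (one that supports invoke_endpoint_with_response_stream), not something this diff pins down:

import json
from typing import Dict

from langchain_aws.llms import SagemakerEndpoint
from langchain_aws.llms.sagemaker_endpoint import LLMContentHandler


class StreamingContentHandler(LLMContentHandler):
    """Hypothetical handler for a container that emits one JSON object per line."""

    content_type = "application/json"
    accepts = "application/json"

    def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:
        return json.dumps({"inputs": prompt, **model_kwargs}).encode()

    def transform_output(self, output: bytes) -> str:
        # On the streaming path, this receives a single line from LineIterator,
        # not the whole response body.
        return json.loads(output)["outputs"][0]


llm = SagemakerEndpoint(
    endpoint_name="my-endpoint",  # assumed: endpoint whose container supports response streaming
    region_name="us-west-2",      # assumed region
    content_handler=StreamingContentHandler(),
    model_kwargs={"parameters": {"max_new_tokens": 50}},
)

# llm.stream() drives the new _stream() method, yielding text as chunks arrive;
# with streaming=True, plain invoke()/_call() accumulates the same chunks instead.
for token in llm.stream("What is a SageMaker endpoint?"):
    print(token, end="", flush=True)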