Merge pull request #5 from 3coins/add-stream-for-sm
Added stream function, added tests
3coins authored Apr 2, 2024
2 parents 577f769 + c197816 commit 6370626
Showing 2 changed files with 140 additions and 45 deletions.
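In effect, the commit gives SagemakerEndpoint a dedicated `_stream` method, so LangChain's standard `stream()` interface now drives SageMaker's `invoke_endpoint_with_response_stream` API. A minimal usage sketch follows; the handler class, endpoint name, and region are illustrative placeholders, not part of the commit, and assume a container that emits newline-delimited JSON:

import json
from typing import Dict

from langchain_aws.llms import SagemakerEndpoint
from langchain_aws.llms.sagemaker_endpoint import LLMContentHandler


class JSONLinesHandler(LLMContentHandler):
    # Illustrative handler; the wire format depends on the serving container.
    accepts = "application/json"
    content_type = "application/json"

    def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:
        return json.dumps({"inputs": prompt, **model_kwargs}).encode()

    def transform_output(self, output: bytes) -> str:
        return json.loads(output)["outputs"][0]


llm = SagemakerEndpoint(
    endpoint_name="my-endpoint",  # placeholder endpoint
    region_name="us-west-2",
    content_handler=JSONLinesHandler(),
)

# Tokens now arrive incrementally instead of only after the full completion.
for chunk in llm.stream("What is Amazon SageMaker?"):
    print(chunk, end="", flush=True)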
93 changes: 55 additions & 38 deletions libs/aws/langchain_aws/llms/sagemaker_endpoint.py
@@ -1,12 +1,12 @@
"""Sagemaker InvokeEndpoint API."""
import io
import json
import re
from abc import abstractmethod
from typing import Any, Dict, Generic, Iterator, List, Mapping, Optional, TypeVar, Union

from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.llms import LLM
from langchain_core.outputs import GenerationChunk
from langchain_core.pydantic_v1 import Extra, root_validator

INPUT_TYPE = TypeVar("INPUT_TYPE", bound=Union[str, List[str]])
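The GenerationChunk import added here is the type that the new `_stream` yields; a chunk carries the token text plus optional generation info, and chunks support `+` so a stream can be folded back into one completion. A small illustration, independent of this diff:

from langchain_core.outputs import GenerationChunk

# Chunks yielded during streaming can be merged back into one completion;
# `_call` below does the equivalent by concatenating `chunk.text` strings.
parts = [GenerationChunk(text="S"), GenerationChunk(text="age"), GenerationChunk(text="Maker")]

merged = parts[0]
for part in parts[1:]:
    merged += part  # GenerationChunk defines __add__

assert merged.text == "SageMaker"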
@@ -304,6 +304,41 @@ def _llm_type(self) -> str:
"""Return type of llm."""
return "sagemaker_endpoint"

def _stream(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[GenerationChunk]:
_model_kwargs = self.model_kwargs or {}
_model_kwargs = {**_model_kwargs, **kwargs}
_endpoint_kwargs = self.endpoint_kwargs or {}

try:
resp = self.client.invoke_endpoint_with_response_stream(
EndpointName=self.endpoint_name,
Body=self.content_handler.transform_input(prompt, _model_kwargs),
ContentType=self.content_handler.content_type,
**_endpoint_kwargs,
)
iterator = LineIterator(resp["Body"])

for line in iterator:
text = self.content_handler.transform_output(line)

if stop is not None:
text = enforce_stop_tokens(text, stop)

if text:
chunk = GenerationChunk(text=text)
yield chunk
if run_manager:
run_manager.on_llm_new_token(chunk.text)

except Exception as e:
raise ValueError(f"Error raised by streaming inference endpoint: {e}")

def _call(
self,
prompt: str,
@@ -334,42 +369,24 @@ def _call(
         accepts = self.content_handler.accepts

         if self.streaming and run_manager:
-            try:
-                resp = self.client.invoke_endpoint_with_response_stream(
-                    EndpointName=self.endpoint_name,
-                    Body=body,
-                    ContentType=self.content_handler.content_type,
-                    **_endpoint_kwargs,
-                )
-                iterator = LineIterator(resp["Body"])
-                current_completion: str = ""
-                for line in iterator:
-                    resp = json.loads(line)
-                    resp_output = resp.get("outputs")[0]
-                    if stop is not None:
-                        # Uses same approach as below
-                        resp_output = enforce_stop_tokens(resp_output, stop)
-                    current_completion += resp_output
-                    run_manager.on_llm_new_token(resp_output)
-                return current_completion
-            except Exception as e:
-                raise ValueError(f"Error raised by streaming inference endpoint: {e}")
-        else:
-            try:
-                response = self.client.invoke_endpoint(
-                    EndpointName=self.endpoint_name,
-                    Body=body,
-                    ContentType=content_type,
-                    Accept=accepts,
-                    **_endpoint_kwargs,
-                )
-            except Exception as e:
-                raise ValueError(f"Error raised by inference endpoint: {e}")
+            completion: str = ""
+            for chunk in self._stream(prompt, stop, run_manager, **kwargs):
+                completion += chunk.text
+            return completion
+
+        try:
+            response = self.client.invoke_endpoint(
+                EndpointName=self.endpoint_name,
+                Body=body,
+                ContentType=content_type,
+                Accept=accepts,
+                **_endpoint_kwargs,
+            )
+        except Exception as e:
+            raise ValueError(f"Error raised by inference endpoint: {e}")

-            text = self.content_handler.transform_output(response["Body"])
-            if stop is not None:
-                # This is a bit hacky, but I can't figure out a better way to enforce
-                # stop tokens when making calls to the sagemaker endpoint.
-                text = enforce_stop_tokens(text, stop)
+        text = self.content_handler.transform_output(response["Body"])
+        if stop is not None:
+            text = enforce_stop_tokens(text, stop)

-            return text
+        return text
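The new `_stream` relies on the module's LineIterator helper, which sits outside this hunk. For orientation, here is a sketch of the buffering it performs, modeled on the AWS sample iterator that this module adapts; the actual implementation in sagemaker_endpoint.py may differ in detail:

import io


class LineIterator:
    """Buffer a SageMaker response event stream and yield complete lines.

    PayloadPart events can split a JSON line across chunk boundaries, so
    bytes are accumulated until a trailing newline is seen.
    """

    def __init__(self, stream):
        self.byte_iterator = iter(stream)
        self.buffer = io.BytesIO()
        self.read_pos = 0

    def __iter__(self):
        return self

    def __next__(self):
        while True:
            self.buffer.seek(self.read_pos)
            line = self.buffer.readline()
            if line and line[-1] == ord("\n"):
                # A full line is available: advance the cursor and emit it.
                self.read_pos += len(line)
                return line[:-1]
            try:
                chunk = next(self.byte_iterator)
            except StopIteration:
                if self.read_pos < self.buffer.getbuffer().nbytes:
                    continue  # flush any remaining buffered bytes
                raise
            if "PayloadPart" not in chunk:
                continue  # skip non-payload events
            self.buffer.seek(0, io.SEEK_END)
            self.buffer.write(chunk["PayloadPart"]["Bytes"])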
92 changes: 85 additions & 7 deletions libs/aws/tests/integration_tests/llms/test_sagemaker_endpoint.py
@@ -1,21 +1,99 @@
+import json
 from typing import Dict
+from unittest.mock import Mock

 from langchain_aws.llms import SagemakerEndpoint
 from langchain_aws.llms.sagemaker_endpoint import LLMContentHandler


-class ContentHandler(LLMContentHandler):
+class DefaultHandler(LLMContentHandler):
     accepts = "application/json"
     content_type = "application/json"

     def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:
-        return b""
+        return prompt.encode()

     def transform_output(self, output: bytes) -> str:
-        return ""
+        body = json.loads(output.decode())
+        return body[0]["generated_text"]


+def test_sagemaker_endpoint_invoke() -> None:
+    client = Mock()
+    response = {
+        "ContentType": "application/json",
+        "Body": b'[{"generated_text": "SageMaker Endpoint"}]',
+    }
+    client.invoke_endpoint.return_value = response
+
-def test_sagemaker_endpoint_name_param() -> None:
     llm = SagemakerEndpoint(
-        endpoint_name="foo",
-        content_handler=ContentHandler(),
+        endpoint_name="my-endpoint",
+        region_name="us-west-2",
+        content_handler=DefaultHandler(),
+        model_kwargs={
+            "parameters": {
+                "max_new_tokens": 50,
+            }
+        },
+        client=client,
     )

+    service_response = llm.invoke("What is Sagemaker endpoints?")
+
+    assert service_response == "SageMaker Endpoint"
+    client.invoke_endpoint.assert_called_once_with(
+        EndpointName="my-endpoint",
+        Body=b"What is Sagemaker endpoints?",
+        ContentType="application/json",
+        Accept="application/json",
+    )
+
+
+def test_sagemaker_endpoint_stream() -> None:
+    class ContentHandler(LLMContentHandler):
+        accepts = "application/json"
+        content_type = "application/json"
+
+        def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:
+            body = json.dumps({"inputs": prompt, **model_kwargs})
+            return body.encode()
+
+        def transform_output(self, output: bytes) -> str:
+            body = json.loads(output)
+            return body.get("outputs")[0]
+
+    body = (
+        {"PayloadPart": {"Bytes": b'{"outputs": ["S"]}\n'}},
+        {"PayloadPart": {"Bytes": b'{"outputs": ["age"]}\n'}},
+        {"PayloadPart": {"Bytes": b'{"outputs": ["Maker"]}\n'}},
+    )
+
+    response = {"ContentType": "application/json", "Body": body}
+
+    client = Mock()
+    client.invoke_endpoint_with_response_stream.return_value = response
+
+    llm = SagemakerEndpoint(
+        endpoint_name="my-endpoint",
+        region_name="us-west-2",
+        content_handler=ContentHandler(),
+        client=client,
+        model_kwargs={"parameters": {"max_new_tokens": 50}},
+    )
+
+    expected_body = json.dumps(
+        {"inputs": "What is Sagemaker endpoints?", "parameters": {"max_new_tokens": 50}}
+    ).encode()
+
+    chunks = ["S", "age", "Maker"]
+    service_chunks = []
+
+    for chunk in llm.stream("What is Sagemaker endpoints?"):
+        service_chunks.append(chunk)
+
+    assert service_chunks == chunks
+    client.invoke_endpoint_with_response_stream.assert_called_once_with(
+        EndpointName="my-endpoint",
+        Body=expected_body,
+        ContentType="application/json",
+    )
-    assert llm.endpoint_name == "foo"

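The tuple assigned to Body in the streaming test mimics the botocore event stream: each event is a dict whose PayloadPart carries raw bytes, which is exactly what LineIterator consumes. A standalone sketch of that reassembly, reusing the hypothetical LineIterator shown earlier, with one JSON line deliberately split across events:

# Events mirror the mocked response stream above; the second line is split
# across two PayloadPart events to show why buffering is needed.
events = (
    {"PayloadPart": {"Bytes": b'{"outputs": ["S"]}\n'}},
    {"PayloadPart": {"Bytes": b'{"outputs": '}},
    {"PayloadPart": {"Bytes": b'["age"]}\n{"outputs": ["Maker"]}\n'}},
)

for line in LineIterator(events):
    print(line)
# b'{"outputs": ["S"]}'
# b'{"outputs": ["age"]}'
# b'{"outputs": ["Maker"]}'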