Skip to content

Commit

Permalink
Fixed lint.
Browse files Browse the repository at this point in the history
  • Loading branch information
3coins committed Apr 2, 2024
1 parent 08749b4 commit c197816
Showing 1 changed file with 25 additions and 49 deletions.
74 changes: 25 additions & 49 deletions libs/aws/tests/integration_tests/llms/test_sagemaker_endpoint.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,10 @@
import json
from typing import Dict
from unittest.mock import Mock

from langchain_aws.llms import SagemakerEndpoint
from langchain_aws.llms.sagemaker_endpoint import LLMContentHandler

from botocore.stub import Stubber, ANY

from unittest.mock import Mock

import json
import io


class DefaultHandler(LLMContentHandler):
accepts = "application/json"
Expand All @@ -19,23 +14,15 @@ def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:
return prompt.encode()

def transform_output(self, output: bytes) -> str:
    """Parse a SageMaker ``invoke_endpoint`` response body into the generated text.

    The endpoint is expected to return a JSON array shaped like
    ``[{"generated_text": "..."}]``; only the first element's text is returned.
    """
    # `output` is raw bytes, so decode before parsing — the leftover
    # `output.read()` variant (pre-lint-fix diff residue) would raise
    # AttributeError on bytes and has been removed.
    body = json.loads(output.decode())
    return body[0]["generated_text"]

def create_mock_raw_stream(*data):
    """Return a Mock mimicking a botocore raw stream.

    The mock's ``stream`` attribute is a generator function that yields the
    supplied payload parts in order, matching how a real streaming response
    body is consumed.
    """
    mock_stream = Mock()

    def _iter_parts():
        for part in data:
            yield part

    mock_stream.stream = _iter_parts
    return mock_stream

def test_sagemaker_endpoint_invoke() -> None:

client = Mock()
response = {
'ContentType': 'application/json',
'Body': io.StringIO('[{"generated_text": "SageMaker Endpoint"}]')
"ContentType": "application/json",
"Body": b'[{"generated_text": "SageMaker Endpoint"}]',
}
client.invoke_endpoint.return_value = response

Expand All @@ -48,16 +35,17 @@ def test_sagemaker_endpoint_invoke() -> None:
"max_new_tokens": 50,
}
},
client=client
client=client,
)

service_response = llm.invoke("What is Sagemaker endpoints?")

assert service_response == "SageMaker Endpoint"
client.invoke_endpoint.assert_called_once_with(
EndpointName='my-endpoint',
Body=b'What is Sagemaker endpoints?',
ContentType='application/json', Accept='application/json'
EndpointName="my-endpoint",
Body=b"What is Sagemaker endpoints?",
ContentType="application/json",
Accept="application/json",
)


Expand All @@ -67,27 +55,20 @@ class ContentHandler(LLMContentHandler):
content_type = "application/json"

def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:
body = {
'inputs': prompt,
**model_kwargs
}
return body
body = json.dumps({"inputs": prompt, **model_kwargs})
return body.encode()

def transform_output(self, output: bytes) -> str:
body = json.loads(output)
return body.get("outputs")[0]


body = (
{'PayloadPart': {'Bytes': b'{"outputs": ["S"]}\n'}},
{'PayloadPart': {'Bytes': b'{"outputs": ["age"]}\n'}},
{'PayloadPart': {'Bytes': b'{"outputs": ["Maker"]}\n'}}
{"PayloadPart": {"Bytes": b'{"outputs": ["S"]}\n'}},
{"PayloadPart": {"Bytes": b'{"outputs": ["age"]}\n'}},
{"PayloadPart": {"Bytes": b'{"outputs": ["Maker"]}\n'}},
)

response = {
'ContentType': 'application/json',
'Body': body
}
response = {"ContentType": "application/json", "Body": body}

client = Mock()
client.invoke_endpoint_with_response_stream.return_value = response
Expand All @@ -97,27 +78,22 @@ def transform_output(self, output: bytes) -> str:
region_name="us-west-2",
content_handler=ContentHandler(),
client=client,
model_kwargs={
"parameters": {
"max_new_tokens": 50
}
}
model_kwargs={"parameters": {"max_new_tokens": 50}},
)


chunks = ['S', 'age', 'Maker']
expected_body = json.dumps(
{"inputs": "What is Sagemaker endpoints?", "parameters": {"max_new_tokens": 50}}
).encode()

chunks = ["S", "age", "Maker"]
service_chunks = []

for chunk in llm.stream("What is Sagemaker endpoints?"):
service_chunks.append(chunk)

assert service_chunks == chunks
client.invoke_endpoint_with_response_stream.assert_called_once_with(
EndpointName='my-endpoint',
Body={
'inputs': 'What is Sagemaker endpoints?',
'parameters': {'max_new_tokens': 50}
},
ContentType='application/json'
EndpointName="my-endpoint",
Body=expected_body,
ContentType="application/json",
)

0 comments on commit c197816

Please sign in to comment.