Skip to content

Commit

Permalink
Merge pull request #26 from brnaba-aws/main
Browse files Browse the repository at this point in the history
fix guardrail attributes when invoking bedrock
  • Loading branch information
3coins authored May 3, 2024
2 parents 6d85e77 + ce20b96 commit 2331bf6
Showing 1 changed file with 23 additions and 31 deletions.
54 changes: 23 additions & 31 deletions libs/aws/langchain_aws/llms/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,9 +343,9 @@ class BedrockBase(BaseLanguageModel, ABC):
}

guardrails: Optional[Mapping[str, Any]] = {
"id": None,
"version": None,
"trace": False,
"trace": None,
"guardrailIdentifier": None,
"guardrailVersion": None,
}
"""
An optional dictionary to configure guardrails for Bedrock.
Expand Down Expand Up @@ -446,7 +446,9 @@ def _identifying_params(self) -> Dict[str, Any]:
"model_id": self.model_id,
"provider": self._get_provider(),
"stream": self.streaming,
"guardrails": self.guardrails,
"trace": self.guardrails.get("trace"), # type: ignore[union-attr]
"guardrailIdentifier": self.guardrails.get("guardrailIdentifier", None), # type: ignore[union-attr]
"guardrailVersion": self.guardrails.get("guardrailVersion", None), # type: ignore[union-attr]
**_model_kwargs,
}

Expand Down Expand Up @@ -480,32 +482,16 @@ def _guardrails_enabled(self) -> bool:
try:
return (
isinstance(self.guardrails, dict)
and bool(self.guardrails["id"])
and bool(self.guardrails["version"])
and bool(self.guardrails["guardrailIdentifier"])
and bool(self.guardrails["guardrailVersion"])
)

except KeyError as e:
raise TypeError(
"Guardrails must be a dictionary with 'id' and 'version' keys."
"Guardrails must be a dictionary with 'guardrailIdentifier' \
and 'guardrailVersion' keys."
) from e

def _get_guardrails_canonical(self) -> Dict[str, Any]:
"""
The canonical way to pass in guardrails to the bedrock service
adheres to the following format:
"amazon-bedrock-guardrailDetails": {
"guardrailId": "string",
"guardrailVersion": "string"
}
"""
return {
"amazon-bedrock-guardrailDetails": {
"guardrailId": self.guardrails.get("id"), # type: ignore[union-attr]
"guardrailVersion": self.guardrails.get("version"), # type: ignore[union-attr]
}
}

def _prepare_input_and_invoke(
self,
prompt: Optional[str] = None,
Expand All @@ -519,8 +505,7 @@ def _prepare_input_and_invoke(

provider = self._get_provider()
params = {**_model_kwargs, **kwargs}
if self._guardrails_enabled:
params.update(self._get_guardrails_canonical())

input_body = LLMInputOutputAdapter.prepare_input(
provider=provider,
model_kwargs=params,
Expand All @@ -540,7 +525,12 @@ def _prepare_input_and_invoke(
}

if self._guardrails_enabled:
request_options["guardrail"] = "ENABLED"
request_options["guardrailIdentifier"] = self.guardrails.get( # type: ignore[union-attr]
"guardrailIdentifier", ""
)
request_options["guardrailVersion"] = self.guardrails.get( # type: ignore[union-attr]
"guardrailVersion", ""
)
if self.guardrails.get("trace"): # type: ignore[union-attr]
request_options["trace"] = "ENABLED"

Expand Down Expand Up @@ -628,9 +618,6 @@ def _prepare_input_and_invoke_stream(

params = {**_model_kwargs, **kwargs}

if self._guardrails_enabled:
params.update(self._get_guardrails_canonical())

input_body = LLMInputOutputAdapter.prepare_input(
provider=provider,
prompt=prompt,
Expand All @@ -648,7 +635,12 @@ def _prepare_input_and_invoke_stream(
}

if self._guardrails_enabled:
request_options["guardrail"] = "ENABLED"
request_options["guardrailIdentifier"] = self.guardrails.get( # type: ignore[union-attr]
"guardrailIdentifier", ""
)
request_options["guardrailVersion"] = self.guardrails.get( # type: ignore[union-attr]
"guardrailVersion", ""
)
if self.guardrails.get("trace"): # type: ignore[union-attr]
request_options["trace"] = "ENABLED"

Expand Down

0 comments on commit 2331bf6

Please sign in to comment.