Commit dd91d3f

Merge with main id 77ec531

michel-heon committed Oct 11, 2024
1 parent 121c7df commit dd91d3f
Showing 3 changed files with 11 additions and 16 deletions.
@@ -40,6 +40,7 @@
class Mode(Enum):
CHAIN = "chain"


def get_guardrails() -> dict:
if "BEDROCK_GUARDRAILS_ID" in os.environ:
logger.debug("Guardrails ID found in environment variables.")
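
(For orientation: get_guardrails keys off environment variables, so enabling guardrails for the LLM built further down is purely a configuration change. A minimal sketch, assuming only the variable name visible in this hunk; the exact shape of the returned dict is outside the visible lines.)

# Hypothetical example; the env var name is taken from the hunk above, the value is a placeholder.
import os
os.environ["BEDROCK_GUARDRAILS_ID"] = "example-guardrail-id"
guardrails = get_guardrails()  # expected to be non-empty once the ID is set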
@@ -593,12 +594,8 @@ def format(self, **kwargs: Any) -> str:

# Register the adapters
registry.register(r"^bedrock.ai21.jamba*", BedrockChatAdapter)
registry.register(
r"^bedrock.ai21.j2*", BedrockChatNoStreamingNoSystemPromptAdapter
)
registry.register(
r"^bedrock\.cohere\.command-(text|light-text).*", BedrockChatNoSystemPromptAdapter
)
registry.register(r"^bedrock.ai21.j2*", BedrockChatNoStreamingNoSystemPromptAdapter)
registry.register(r"^bedrock\.cohere\.command-(text|light-text).*", BedrockChatNoSystemPromptAdapter)
registry.register(r"^bedrock\.cohere\.command-r.*", BedrockChatAdapter)
registry.register(r"^bedrock.anthropic.claude*", BedrockChatAdapter)
registry.register(r"^bedrock.meta.llama*", BedrockChatAdapter)
@@ -94,24 +94,24 @@ def get_condense_question_prompt(self):
def get_llm(self, model_kwargs={}, extra={}):
bedrock = genai_core.clients.get_bedrock_client()
params = {}

# Collect temperature, topP, and maxTokens if available
temperature = model_kwargs.get("temperature")
top_p = model_kwargs.get("topP")
max_tokens = model_kwargs.get("maxTokens")

if temperature:
params["temperature"] = temperature
if top_p:
params["top_p"] = top_p
if max_tokens:
params["max_tokens"] = max_tokens

# Fetch guardrails if any
guardrails = get_guardrails()
if len(guardrails.keys()) > 0:
params["guardrails"] = guardrails

# Log all parameters in a single log entry, including full guardrails
logger.info(
f"Creating LLM chain for model {self.model_id}",
@@ -121,7 +121,7 @@ def get_llm(self, model_kwargs={}, extra={}):
max_tokens=max_tokens,
guardrails=guardrails,
)

# Return ChatBedrockConverse instance with the collected params
return ChatBedrockConverse(
client=bedrock,
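The net effect of the get_llm change above is that temperature, topP, and maxTokens from model_kwargs, plus any guardrails returned by get_guardrails(), are gathered into params before the model is built. A minimal sketch of the resulting construction, assuming params is spread into ChatBedrockConverse as keyword arguments (the spread itself sits below the visible hunk):

# Illustrative only; the concrete values and the **params spread are assumptions.
params = {"temperature": 0.7, "top_p": 0.9, "max_tokens": 512}
llm = ChatBedrockConverse(client=bedrock, model="anthropic.claude-3-sonnet", **params)
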
@@ -7,7 +7,6 @@
from adapters.shared.prompts.system_prompts import prompts  # Import added



def test_registry():
with pytest.raises(ValueError, match="not found"):
registry.get_adapter("invalid")
@@ -37,23 +36,23 @@ def test_chat_adapter(mocker):
result = model.get_qa_prompt().format(
input="input", context="context", chat_history=[HumanMessage(content="history")]
)
# Assertion updated to match the English prompt in system_prompts.py

assert "Use the following pieces of context" in result
assert "Human: history" in result
assert "Human: input" in result

result = model.get_prompt().format(
input="input", chat_history=[HumanMessage(content="history")]
)
# Assertion updated to match the English prompt in system_prompts.py

assert "The following is a friendly conversation" in result
assert "Human: history" in result
assert "Human: input" in result

result = model.get_condense_question_prompt().format(
input="input", chat_history=[HumanMessage(content="history")]
)
# Assertion updated to match the English prompt in system_prompts.py

assert "Given the conversation inside the tags" in result
assert "Human: history" in result
assert "Human: input" in result
@@ -119,4 +118,3 @@ def test_chat_without_system_adapter(mocker):
model="model",
callbacks=ANY,
)
