diff --git a/libs/cohere/langchain_cohere/rerank.py b/libs/cohere/langchain_cohere/rerank.py index 9c832cc..989f476 100644 --- a/libs/cohere/langchain_cohere/rerank.py +++ b/libs/cohere/langchain_cohere/rerank.py @@ -24,6 +24,8 @@ class CohereRerank(BaseDocumentCompressor): COHERE_API_KEY.""" user_agent: str = "langchain:partner" """Identifier for the application making the request.""" + base_url: Optional[str] = None + """Override the default Cohere API URL.""" class Config: """Configuration for this pydantic object.""" @@ -34,13 +36,18 @@ class Config: @root_validator() def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and python package exists in environment.""" - if not values.get("client"): - cohere_api_key = get_from_dict_or_env( - values, "cohere_api_key", "COHERE_API_KEY" - ) - client_name = values["user_agent"] - values["client"] = cohere.Client(cohere_api_key, client_name=client_name) - return values + cohere_api_key = get_from_dict_or_env( + values, "cohere_api_key", "COHERE_API_KEY" + ) + request_timeout = values.get("request_timeout") + + client_name = values["user_agent"] + values["client"] = cohere.Client( + cohere_api_key, + timeout=request_timeout, + client_name=client_name, + base_url=values["base_url"], + ) @root_validator() def validate_model_specified(cls, values: Dict) -> Dict: