diff --git a/unify/clients.py b/unify/clients.py
index a204248..ea0de13 100644
--- a/unify/clients.py
+++ b/unify/clients.py
@@ -126,7 +126,9 @@ def generate(  # noqa: WPS234, WPS211
         user_prompt: Optional[str] = None,
         system_prompt: Optional[str] = None,
         messages: Optional[List[Dict[str, str]]] = None,
-        max_tokens: Optional[int] = None,
+        max_tokens: Optional[int] = 1024,
+        temperature: Optional[float] = 1.0,
+        stop: Optional[List[str]] = None,
         stream: bool = False,
     ) -> Union[Generator[str, None, None], str]:  # noqa: DAR101, DAR201, DAR401
         """Generate content using the Unify API.
@@ -141,8 +143,15 @@ def generate(  # noqa: WPS234, WPS211
             messages (List[Dict[str, str]]): A list of dictionaries containing the
             conversation history. If provided, user_prompt must be None.
 
-            max_tokens (Optional[int]): The max number of output tokens, defaults
-            to the provider's default max_tokens when the value is None.
+            max_tokens (Optional[int]): The max number of output tokens.
+            Defaults to the provider's default max_tokens when the value is None.
+
+            temperature (Optional[float]): What sampling temperature to use, between 0 and 2.
+            Higher values like 0.8 will make the output more random,
+            while lower values like 0.2 will make it more focused and deterministic.
+            Defaults to the provider's default temperature when the value is None.
+
+            stop (Optional[List[str]]): Up to 4 sequences where the API will stop generating further tokens.
 
             stream (bool): If True, generates content as a stream.
             If False, generates content as a single response.
@@ -159,7 +168,6 @@ def generate(  # noqa: WPS234, WPS211
         contents = []
         if system_prompt:
             contents.append({"role": "system", "content": system_prompt})
-
         if user_prompt:
             contents.append({"role": "user", "content": user_prompt})
         elif messages:
@@ -168,8 +176,14 @@ def generate(  # noqa: WPS234, WPS211
             raise UnifyError("You must provider either the user_prompt or messages!")
 
         if stream:
-            return self._generate_stream(contents, self._endpoint, max_tokens=max_tokens)
-        return self._generate_non_stream(contents, self._endpoint, max_tokens=max_tokens)
+            return self._generate_stream(contents, self._endpoint,
+                                         max_tokens=max_tokens,
+                                         temperature=temperature,
+                                         stop=stop)
+        return self._generate_non_stream(contents, self._endpoint,
+                                         max_tokens=max_tokens,
+                                         temperature=temperature,
+                                         stop=stop)
 
 
     def get_credit_balance(self) -> float:  # noqa: DAR201, DAR401
@@ -201,13 +215,17 @@
         self,
         messages: List[Dict[str, str]],
         endpoint: str,
-        max_tokens: Optional[int] = None
+        max_tokens: Optional[int] = 1024,
+        temperature: Optional[float] = 1.0,
+        stop: Optional[List[str]] = None,
     ) -> Generator[str, None, None]:
         try:
             chat_completion = self.client.chat.completions.create(
                 model=endpoint,
                 messages=messages,  # type: ignore[arg-type]
                 max_tokens=max_tokens,
+                temperature=temperature,
+                stop=stop,
                 stream=True,
             )
             for chunk in chat_completion:
@@ -222,13 +240,17 @@
         self,
         messages: List[Dict[str, str]],
         endpoint: str,
-        max_tokens: Optional[int] = None
+        max_tokens: Optional[int] = 1024,
+        temperature: Optional[float] = 1.0,
+        stop: Optional[List[str]] = None,
     ) -> str:
         try:
             chat_completion = self.client.chat.completions.create(
                 model=endpoint,
                 messages=messages,  # type: ignore[arg-type]
                 max_tokens=max_tokens,
+                temperature=temperature,
+                stop=stop,
                 stream=False,
             )
             self.set_provider(
@@ -388,6 +410,8 @@ async def generate(  # noqa: WPS234, WPS211
         system_prompt: Optional[str] = None,
         messages: Optional[List[Dict[str, str]]] = None,
         max_tokens: Optional[int] = None,
+        temperature: Optional[float] = 1.0,
+        stop: Optional[List[str]] = None,
         stream: bool = False,
     ) -> Union[AsyncGenerator[str, None], str]:  # noqa: DAR101, DAR201, DAR401
         """Generate content asynchronously using the Unify API.
@@ -405,6 +429,13 @@ async def generate(  # noqa: WPS234, WPS211
             max_tokens (Optional[int]): The max number of output tokens, defaults
             to the provider's default max_tokens when the value is None.
 
+            temperature (Optional[float]): What sampling temperature to use, between 0 and 2.
+            Higher values like 0.8 will make the output more random,
+            while lower values like 0.2 will make it more focused and deterministic.
+            Defaults to the provider's default temperature when the value is None.
+
+            stop (Optional[List[str]]): Up to 4 sequences where the API will stop generating further tokens.
+
             stream (bool): If True, generates content as a stream.
             If False, generates content as a single response.
             Defaults to False.
@@ -429,20 +460,24 @@ async def generate(  # noqa: WPS234, WPS211
             raise UnifyError("You must provide either the user_prompt or messages!")
 
         if stream:
-            return self._generate_stream(contents, self._endpoint, max_tokens=max_tokens)
-        return await self._generate_non_stream(contents, self._endpoint, max_tokens=max_tokens)
+            return self._generate_stream(contents, self._endpoint, max_tokens=max_tokens, stop=stop, temperature=temperature)
+        return await self._generate_non_stream(contents, self._endpoint, max_tokens=max_tokens, stop=stop, temperature=temperature)
 
     async def _generate_stream(
         self,
         messages: List[Dict[str, str]],
         endpoint: str,
         max_tokens: Optional[int] = None,
+        temperature: Optional[float] = 1.0,
+        stop: Optional[List[str]] = None,
     ) -> AsyncGenerator[str, None]:
         try:
             async_stream = await self.client.chat.completions.create(
                 model=endpoint,
                 messages=messages,  # type: ignore[arg-type]
                 max_tokens=max_tokens,
+                temperature=temperature,
+                stop=stop,
                 stream=True,
             )
             async for chunk in async_stream:  # type: ignore[union-attr]
@@ -456,12 +491,16 @@ async def _generate_non_stream(
         messages: List[Dict[str, str]],
         endpoint: str,
         max_tokens: Optional[int] = None,
+        temperature: Optional[float] = 1.0,
+        stop: Optional[List[str]] = None,
     ) -> str:
         try:
             async_response = await self.client.chat.completions.create(
                 model=endpoint,
                 messages=messages,  # type: ignore[arg-type]
                 max_tokens=max_tokens,
+                temperature=temperature,
+                stop=stop,
                 stream=False,
             )
             self.set_provider(async_response.model.split("@")[-1])  # type: ignore
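
For context, a minimal sketch of how a caller might exercise the new arguments against the synchronous client after this change. The class name Unify, the constructor arguments, the endpoint string, and the UNIFY_KEY environment variable are assumptions for illustration; only the generate(...) keyword arguments come from the diff above.

    import os

    from unify import Unify  # assumed import path; the diff only touches unify/clients.py

    # Hypothetical endpoint of the form "<model>@<provider>", API key read from the environment.
    client = Unify(
        api_key=os.environ.get("UNIFY_KEY"),
        endpoint="llama-2-70b-chat@anyscale",
    )

    # Non-streaming call using the newly exposed sampling controls.
    text = client.generate(
        user_prompt="Summarise the theory of relativity in one sentence.",
        max_tokens=256,    # cap on output tokens (the sync default is now 1024)
        temperature=0.2,   # lower temperature -> more focused, deterministic output
        stop=["\n\n"],     # stop at the first blank line (up to 4 sequences allowed)
    )
    print(text)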
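
Likewise, a sketch of the asynchronous streaming path, where the same temperature and stop values are forwarded through _generate_stream to the underlying chat.completions.create call. The AsyncUnify class name and constructor are likewise assumed, not shown in the diff.

    import asyncio
    import os

    from unify import AsyncUnify  # assumed name of the async client class

    async def main() -> None:
        client = AsyncUnify(
            api_key=os.environ.get("UNIFY_KEY"),
            endpoint="llama-2-70b-chat@anyscale",
        )
        # stream=True returns an async generator of text chunks; temperature and stop
        # are passed straight through to the provider call.
        stream = await client.generate(
            user_prompt="Write a haiku about unified APIs.",
            temperature=0.8,   # higher temperature -> more varied output
            stop=["END"],      # stop generating if the model emits "END"
            stream=True,
        )
        async for chunk in stream:
            print(chunk, end="", flush=True)

    asyncio.run(main())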