From 08c6f6095d1cac197490eb92cf5921e55aa99002 Mon Sep 17 00:00:00 2001 From: Jens Leinenbach <1786119+jleinenbach@users.noreply.github.com> Date: Tue, 21 May 2024 23:19:46 +0200 Subject: [PATCH] Update helpers.py Patch for https://github.com/jekalmin/extended_openai_conversation/issues/217 --- .../extended_openai_conversation/helpers.py | 25 ++++++++++++++++++- 1 file changed, 24 insertions(+), 1 deletion(-) diff --git a/custom_components/extended_openai_conversation/helpers.py b/custom_components/extended_openai_conversation/helpers.py index 7f9bc23..b6b972e 100644 --- a/custom_components/extended_openai_conversation/helpers.py +++ b/custom_components/extended_openai_conversation/helpers.py @@ -12,6 +12,7 @@ from openai import AsyncAzureOpenAI, AsyncOpenAI import voluptuous as vol import yaml +import asyncio from homeassistant.components import ( automation, @@ -132,9 +133,25 @@ async def validate_authentication( organization: str = None, skip_authentication=False, ) -> None: + """ + Validate the authentication with OpenAI or Azure. + + Parameters: + hass (HomeAssistant): The Home Assistant instance. + api_key (str): The API key for OpenAI or Azure. + base_url (str): The base URL for the API. + api_version (str): The API version to use. + organization (str): The organization ID for the API (optional). + skip_authentication (bool): If True, skip the authentication check. + + Returns: + None + """ + # If skip_authentication is True, return immediately if skip_authentication: return + # Determine if the base URL is for Azure or OpenAI and create the appropriate client if is_azure(base_url): client = AsyncAzureOpenAI( api_key=api_key, @@ -147,7 +164,13 @@ async def validate_authentication( api_key=api_key, base_url=base_url, organization=organization ) - await client.models.list(timeout=10) + # Define an asynchronous function that lists models with a timeout using asyncio.to_thread + async def list_models_with_timeout(): + # Use asyncio.to_thread to run the blocking call in a separate thread + return await asyncio.to_thread(client.models.list, timeout=10) + + # Await the execution of the list_models_with_timeout function + await list_models_with_timeout() class FunctionExecutor(ABC):