diff --git a/nucliadb_sdk/src/nucliadb_sdk/v2/exceptions.py b/nucliadb_sdk/src/nucliadb_sdk/v2/exceptions.py index 88a3f7bfc8..0cc13d4f1c 100644 --- a/nucliadb_sdk/src/nucliadb_sdk/v2/exceptions.py +++ b/nucliadb_sdk/src/nucliadb_sdk/v2/exceptions.py @@ -16,6 +16,9 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . +from typing import Optional + + class ClientError(Exception): pass @@ -33,7 +36,9 @@ class AccountLimitError(ClientError): class RateLimitError(ClientError): - pass + def __init__(self, message, try_after: Optional[float] = None): + super().__init__(message) + self.try_after = try_after class ConflictError(ClientError): diff --git a/nucliadb_sdk/src/nucliadb_sdk/v2/sdk.py b/nucliadb_sdk/src/nucliadb_sdk/v2/sdk.py index dd2f3a4ce0..253a5da140 100644 --- a/nucliadb_sdk/src/nucliadb_sdk/v2/sdk.py +++ b/nucliadb_sdk/src/nucliadb_sdk/v2/sdk.py @@ -21,6 +21,7 @@ import inspect import io import warnings +from json import JSONDecodeError from typing import ( Any, AsyncGenerator, @@ -345,7 +346,13 @@ def _check_response(self, response: httpx.Response): f"Account limits exceeded error {response.status_code}: {response.text}" ) elif response.status_code == 429: - raise exceptions.RateLimitError(response.text) + try_after: Optional[float] = None + try: + body = response.json() + try_after = body.get("detail", {}).get("try_after") + except JSONDecodeError: + pass + raise exceptions.RateLimitError(response.text, try_after=try_after) elif response.status_code in ( 409, 419, diff --git a/nucliadb_sdk/tests/test_sdk.py b/nucliadb_sdk/tests/test_sdk.py index 0f29c3c3c1..f338788aae 100644 --- a/nucliadb_sdk/tests/test_sdk.py +++ b/nucliadb_sdk/tests/test_sdk.py @@ -17,6 +17,7 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . # +import httpx import pytest import nucliadb_sdk @@ -103,3 +104,44 @@ def test_learning_config_endpoints(sdk: nucliadb_sdk.NucliaDB, kb): sdk.get_models(kbid=kb.uuid) sdk.get_model(kbid=kb.uuid, model_id="foo") sdk.get_configuration_schema(kbid=kb.uuid) + + +def test_check_response(): + sdk = nucliadb_sdk.NucliaDB(region="europe-1") + + response = httpx.Response(200) + assert sdk._check_response(response) is response + + response = httpx.Response(299) + assert sdk._check_response(response) is response + + with pytest.raises(nucliadb_sdk.exceptions.UnknownError) as err: + sdk._check_response(httpx.Response(300, text="foo")) + assert str(err.value) == "Unknown error connecting to API: 300: foo" + + for status_code in (401, 403): + with pytest.raises(nucliadb_sdk.exceptions.AuthError) as err: + sdk._check_response(httpx.Response(status_code, text="foo")) + assert str(err.value) == f"Auth error {status_code}: foo" + + with pytest.raises(nucliadb_sdk.exceptions.AccountLimitError) as err: + sdk._check_response(httpx.Response(402, text="foo")) + assert str(err.value) == f"Account limits exceeded error {status_code}: foo" + + with pytest.raises(nucliadb_sdk.exceptions.RateLimitError) as err: + sdk._check_response(httpx.Response(429, json={"detail": {"try_after": 1}}, text="Rate limit!")) + assert str(err.value) == f"Rate limit!" + assert err.value.try_after == 1 + + for status_code in (409, 419): + with pytest.raises(nucliadb_sdk.exceptions.ConflictError) as err: + sdk._check_response(httpx.Response(status_code, text="foo")) + assert str(err.value) == "foo" + + with pytest.raises(nucliadb_sdk.exceptions.NotFoundError) as err: + sdk._check_response( + httpx.Response( + 404, text="foo", request=httpx.Request(method="GET", url=httpx.URL("http://url")) + ), + ) + assert str(err.value) == "Resource not found at http://url: foo"