Skip to content

Commit

Permalink
propagate try after in rate limit errors on sdk (#2461)
Browse files Browse the repository at this point in the history
  • Loading branch information
lferran authored Sep 13, 2024
1 parent 02c5a06 commit b50b7f4
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 2 deletions.
7 changes: 6 additions & 1 deletion nucliadb_sdk/src/nucliadb_sdk/v2/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from typing import Optional


class ClientError(Exception):
pass

Expand All @@ -33,7 +36,9 @@ class AccountLimitError(ClientError):


class RateLimitError(ClientError):
pass
def __init__(self, message, try_after: Optional[float] = None):
super().__init__(message)
self.try_after = try_after


class ConflictError(ClientError):
Expand Down
9 changes: 8 additions & 1 deletion nucliadb_sdk/src/nucliadb_sdk/v2/sdk.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import inspect
import io
import warnings
from json import JSONDecodeError
from typing import (
Any,
AsyncGenerator,
Expand Down Expand Up @@ -345,7 +346,13 @@ def _check_response(self, response: httpx.Response):
f"Account limits exceeded error {response.status_code}: {response.text}"
)
elif response.status_code == 429:
raise exceptions.RateLimitError(response.text)
try_after: Optional[float] = None
try:
body = response.json()
try_after = body.get("detail", {}).get("try_after")
except JSONDecodeError:
pass
raise exceptions.RateLimitError(response.text, try_after=try_after)
elif response.status_code in (
409,
419,
Expand Down
42 changes: 42 additions & 0 deletions nucliadb_sdk/tests/test_sdk.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
import httpx
import pytest

import nucliadb_sdk
Expand Down Expand Up @@ -103,3 +104,44 @@ def test_learning_config_endpoints(sdk: nucliadb_sdk.NucliaDB, kb):
sdk.get_models(kbid=kb.uuid)
sdk.get_model(kbid=kb.uuid, model_id="foo")
sdk.get_configuration_schema(kbid=kb.uuid)


def test_check_response():
sdk = nucliadb_sdk.NucliaDB(region="europe-1")

response = httpx.Response(200)
assert sdk._check_response(response) is response

response = httpx.Response(299)
assert sdk._check_response(response) is response

with pytest.raises(nucliadb_sdk.exceptions.UnknownError) as err:
sdk._check_response(httpx.Response(300, text="foo"))
assert str(err.value) == "Unknown error connecting to API: 300: foo"

for status_code in (401, 403):
with pytest.raises(nucliadb_sdk.exceptions.AuthError) as err:
sdk._check_response(httpx.Response(status_code, text="foo"))
assert str(err.value) == f"Auth error {status_code}: foo"

with pytest.raises(nucliadb_sdk.exceptions.AccountLimitError) as err:
sdk._check_response(httpx.Response(402, text="foo"))
assert str(err.value) == f"Account limits exceeded error {status_code}: foo"

with pytest.raises(nucliadb_sdk.exceptions.RateLimitError) as err:
sdk._check_response(httpx.Response(429, json={"detail": {"try_after": 1}}, text="Rate limit!"))
assert str(err.value) == f"Rate limit!"
assert err.value.try_after == 1

for status_code in (409, 419):
with pytest.raises(nucliadb_sdk.exceptions.ConflictError) as err:
sdk._check_response(httpx.Response(status_code, text="foo"))
assert str(err.value) == "foo"

with pytest.raises(nucliadb_sdk.exceptions.NotFoundError) as err:
sdk._check_response(
httpx.Response(
404, text="foo", request=httpx.Request(method="GET", url=httpx.URL("http://url"))
),
)
assert str(err.value) == "Resource not found at http://url: foo"

0 comments on commit b50b7f4

Please sign in to comment.