Skip to content

Commit

Permalink
Refactor fastapi example (#1666)
Browse files Browse the repository at this point in the history
* refactor: use schemas and routers to reduce duplicated code in fastapi example

* Warm tip for module not found
  • Loading branch information
waketzheng authored Jul 16, 2024
1 parent 717a92a commit 4c14800
Show file tree
Hide file tree
Showing 9 changed files with 202 additions and 164 deletions.
1 change: 1 addition & 0 deletions examples/fastapi/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
db.sqlite3
145 changes: 96 additions & 49 deletions examples/fastapi/_tests.py
Original file line number Diff line number Diff line change
@@ -1,71 +1,118 @@
# mypy: no-disallow-untyped-decorators
# pylint: disable=E0611,E0401
import datetime
from typing import AsyncGenerator
import os
from contextlib import asynccontextmanager
from datetime import datetime
from pathlib import Path
from typing import AsyncGenerator, Tuple

import pytest
import pytz
from asgi_lifespan import LifespanManager
from httpx import ASGITransport, AsyncClient
from main import app, app_east
from models import Users

from tortoise.contrib.test import MEMORY_SQLITE
from tortoise.fields.data import JSON_LOADS

os.environ["DB_URL"] = MEMORY_SQLITE
try:
from main import app
from main_custom_timezone import app as app_east
from models import Users
from schemas import User_Pydantic
except ImportError:
if (cwd := Path.cwd()) == (parent := Path(__file__).parent):
dirpath = "."
else:
dirpath = str(parent.relative_to(cwd))
print(f"You may need to explicitly declare python path:\n\nexport PYTHONPATH={dirpath}\n")
raise

ClientManagerType = AsyncGenerator[AsyncClient, None]


@pytest.fixture(scope="module")
def anyio_backend() -> str:
return "asyncio"


@pytest.fixture(scope="module")
async def client() -> AsyncGenerator[AsyncClient, None]:
@asynccontextmanager
async def client_manager(app, base_url="http://test", **kw) -> ClientManagerType:
async with LifespanManager(app):
transport = ASGITransport(app=app)
async with AsyncClient(transport=transport, base_url="http://test") as c:
async with AsyncClient(transport=transport, base_url=base_url, **kw) as c:
yield c


@pytest.mark.anyio
async def test_create_user(client: AsyncClient) -> None: # nosec
response = await client.post("/users", json={"username": "admin"})
assert response.status_code == 200, response.text
data = response.json()
assert data["username"] == "admin"
assert "id" in data
user_id = data["id"]

user_obj = await Users.get(id=user_id)
assert user_obj.id == user_id


@pytest.fixture(scope="module")
async def client_east() -> AsyncGenerator[AsyncClient, None]:
async with LifespanManager(app_east):
transport = ASGITransport(app=app_east)
async with AsyncClient(transport=transport, base_url="http://test") as c:
yield c


@pytest.mark.anyio
async def test_create_user_east(client_east: AsyncClient) -> None: # nosec
response = await client_east.post("/users_east", json={"username": "admin"})
assert response.status_code == 200, response.text
data = response.json()
assert data["username"] == "admin"
assert "id" in data
user_id = data["id"]
async def client() -> ClientManagerType:
async with client_manager(app) as c:
yield c

user_obj = await Users.get(id=user_id)
assert user_obj.id == user_id

# Verify that the time zone is East 8.
created_at = user_obj.created_at

# Asia/Shanghai timezone
asia_tz = pytz.timezone("Asia/Shanghai")
asia_now = datetime.datetime.now(pytz.utc).astimezone(asia_tz)
assert created_at.hour - asia_now.hour == 0

# UTC timezone
utc_tz = pytz.timezone("UTC")
utc_now = datetime.datetime.now(pytz.utc).astimezone(utc_tz)
assert created_at.hour - utc_now.hour == 8
@pytest.fixture(scope="module")
async def client_east() -> ClientManagerType:
async with client_manager(app_east) as c:
yield c


class UserTester:
async def create_user(self, async_client: AsyncClient) -> Users:
response = await async_client.post("/users", json={"username": "admin"})
assert response.status_code == 200, response.text
data = response.json()
assert data["username"] == "admin"
assert "id" in data
user_id = data["id"]

user_obj = await Users.get(id=user_id)
assert user_obj.id == user_id
return user_obj

async def user_list(self, async_client: AsyncClient) -> Tuple[datetime, Users, User_Pydantic]:
utc_now = datetime.now(pytz.utc)
user_obj = await Users.create(username="test")
response = await async_client.get("/users")
assert response.status_code == 200, response.text
data = response.json()
assert isinstance(data, list)
item = await User_Pydantic.from_tortoise_orm(user_obj)
assert JSON_LOADS(item.model_dump_json()) in data
return utc_now, user_obj, item


class TestUser(UserTester):
@pytest.mark.anyio
async def test_create_user(self, client: AsyncClient) -> None: # nosec
await self.create_user(client)

@pytest.mark.anyio
async def test_user_list(self, client: AsyncClient) -> None: # nosec
await self.user_list(client)


class TestUserEast(UserTester):
timezone = "Asia/Shanghai"
delta_hours = 8

@pytest.mark.anyio
async def test_create_user_east(self, client_east: AsyncClient) -> None: # nosec
user_obj = await self.create_user(client_east)
created_at = user_obj.created_at

# Verify time zone
asia_tz = pytz.timezone(self.timezone)
asia_now = datetime.now(pytz.utc).astimezone(asia_tz)
assert created_at.hour - asia_now.hour == 0

# UTC timezone
utc_tz = pytz.timezone("UTC")
utc_now = datetime.now(pytz.utc).astimezone(utc_tz)
assert (created_at.hour - utc_now.hour) in [self.delta_hours, self.delta_hours - 24]

@pytest.mark.anyio
async def test_user_list(self, client_east: AsyncClient) -> None: # nosec
time, user_obj, item = await self.user_list(client_east)
created_at = user_obj.created_at
assert (created_at.hour - time.hour) in [self.delta_hours, self.delta_hours - 24]
assert item.model_dump()["created_at"].hour == created_at.hour
12 changes: 12 additions & 0 deletions examples/fastapi/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import os
from functools import partial

from tortoise.contrib.fastapi import RegisterTortoise

register_orm = partial(
RegisterTortoise,
db_url=os.getenv("DB_URL", "sqlite://db.sqlite3"),
modules={"models": ["models"]},
generate_schemas=True,
add_exception_handlers=True,
)
115 changes: 6 additions & 109 deletions examples/fastapi/main.py
Original file line number Diff line number Diff line change
@@ -1,124 +1,21 @@
# pylint: disable=E0611,E0401
from contextlib import asynccontextmanager
from typing import TYPE_CHECKING, AsyncGenerator, List
from typing import AsyncGenerator

from fastapi import FastAPI, HTTPException
from models import Users
from pydantic import BaseModel

from tortoise.contrib.fastapi import RegisterTortoise
from tortoise.contrib.pydantic import PydanticModel

if TYPE_CHECKING: # pragma: nocoverage

class UserIn_Pydantic(Users, PydanticModel): # type:ignore[misc]
pass

class User_Pydantic(Users, PydanticModel): # type:ignore[misc]
pass

else:
from models import User_Pydantic, UserIn_Pydantic
from config import register_orm
from fastapi import FastAPI
from routers import router as users_router


@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
# app startup
async with RegisterTortoise(
app,
db_url="sqlite://:memory:",
modules={"models": ["models"]},
generate_schemas=True,
add_exception_handlers=True,
):
# db connected
yield
# app teardown
# db connections closed


@asynccontextmanager
async def lifespan_east(app: FastAPI) -> AsyncGenerator[None, None]:
# app startup
async with RegisterTortoise(
app,
db_url="sqlite://:memory:",
modules={"models": ["models"]},
generate_schemas=True,
add_exception_handlers=True,
use_tz=False,
timezone="Asia/Shanghai",
):
async with register_orm(app):
# db connected
yield
# app teardown
# db connections closed


app = FastAPI(title="Tortoise ORM FastAPI example", lifespan=lifespan)
app_east = FastAPI(title="Tortoise ORM FastAPI example", lifespan=lifespan_east)


class Status(BaseModel):
message: str


@app.get("/users", response_model=List[User_Pydantic])
async def get_users():
return await User_Pydantic.from_queryset(Users.all())


@app.post("/users", response_model=User_Pydantic)
async def create_user(user: UserIn_Pydantic):
user_obj = await Users.create(**user.model_dump(exclude_unset=True))
return await User_Pydantic.from_tortoise_orm(user_obj)


@app.get("/user/{user_id}", response_model=User_Pydantic)
async def get_user(user_id: int):
return await User_Pydantic.from_queryset_single(Users.get(id=user_id))


@app.put("/user/{user_id}", response_model=User_Pydantic)
async def update_user(user_id: int, user: UserIn_Pydantic):
await Users.filter(id=user_id).update(**user.model_dump(exclude_unset=True))
return await User_Pydantic.from_queryset_single(Users.get(id=user_id))


@app.delete("/user/{user_id}", response_model=Status)
async def delete_user(user_id: int):
deleted_count = await Users.filter(id=user_id).delete()
if not deleted_count:
raise HTTPException(status_code=404, detail=f"User {user_id} not found")
return Status(message=f"Deleted user {user_id}")


############################ East 8 ############################
@app_east.get("/users_east", response_model=List[User_Pydantic])
async def get_users_east():
return await User_Pydantic.from_queryset(Users.all())


@app_east.post("/users_east", response_model=User_Pydantic)
async def create_user_east(user: UserIn_Pydantic):
user_obj = await Users.create(**user.model_dump(exclude_unset=True))
return await User_Pydantic.from_tortoise_orm(user_obj)


@app_east.get("/user_east/{user_id}", response_model=User_Pydantic)
async def get_user_east(user_id: int):
return await User_Pydantic.from_queryset_single(Users.get(id=user_id))


@app_east.put("/user_east/{user_id}", response_model=User_Pydantic)
async def update_user_east(user_id: int, user: UserIn_Pydantic):
await Users.filter(id=user_id).update(**user.model_dump(exclude_unset=True))
return await User_Pydantic.from_queryset_single(Users.get(id=user_id))


@app_east.delete("/user_east/{user_id}", response_model=Status)
async def delete_user_east(user_id: int):
deleted_count = await Users.filter(id=user_id).delete()
if not deleted_count:
raise HTTPException(status_code=404, detail=f"User {user_id} not found")
return Status(message=f"Deleted user {user_id}")
app.include_router(users_router, prefix="")
25 changes: 25 additions & 0 deletions examples/fastapi/main_custom_timezone.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# pylint: disable=E0611,E0401
from contextlib import asynccontextmanager
from typing import AsyncGenerator

from config import register_orm
from fastapi import FastAPI
from routers import router as users_router


@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
# app startup
async with register_orm(
app,
use_tz=False,
timezone="Asia/Shanghai",
):
# db connected
yield
# app teardown
# db connections closed


app = FastAPI(title="Tortoise ORM FastAPI example", lifespan=lifespan)
app.include_router(users_router, prefix="")
5 changes: 0 additions & 5 deletions examples/fastapi/models.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from tortoise import fields, models
from tortoise.contrib.pydantic import pydantic_model_creator


class Users(models.Model):
Expand Down Expand Up @@ -28,7 +27,3 @@ def full_name(self) -> str:
class PydanticMeta:
computed = ["full_name"]
exclude = ["password_hash"]


User_Pydantic = pydantic_model_creator(Users, name="User")
UserIn_Pydantic = pydantic_model_creator(Users, name="UserIn", exclude_readonly=True)
37 changes: 37 additions & 0 deletions examples/fastapi/routers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from typing import List

from fastapi import APIRouter, HTTPException
from models import Users
from schemas import Status, User_Pydantic, UserIn_Pydantic

router = APIRouter()


@router.get("/users", response_model=List[User_Pydantic])
async def get_users():
return await User_Pydantic.from_queryset(Users.all())


@router.post("/users", response_model=User_Pydantic)
async def create_user(user: UserIn_Pydantic):
user_obj = await Users.create(**user.model_dump(exclude_unset=True))
return await User_Pydantic.from_tortoise_orm(user_obj)


@router.get("/user/{user_id}", response_model=User_Pydantic)
async def get_user(user_id: int):
return await User_Pydantic.from_queryset_single(Users.get(id=user_id))


@router.put("/user/{user_id}", response_model=User_Pydantic)
async def update_user(user_id: int, user: UserIn_Pydantic):
await Users.filter(id=user_id).update(**user.model_dump(exclude_unset=True))
return await User_Pydantic.from_queryset_single(Users.get(id=user_id))


@router.delete("/user/{user_id}", response_model=Status)
async def delete_user(user_id: int):
deleted_count = await Users.filter(id=user_id).delete()
if not deleted_count:
raise HTTPException(status_code=404, detail=f"User {user_id} not found")
return Status(message=f"Deleted user {user_id}")
Loading

0 comments on commit 4c14800

Please sign in to comment.