-
-
Notifications
You must be signed in to change notification settings - Fork 383
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* refactor: use schemas and routers to reduce duplicated code in fastapi example * Warm tip for module not found
- Loading branch information
1 parent
717a92a
commit 4c14800
Showing
9 changed files
with
202 additions
and
164 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
db.sqlite3 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,71 +1,118 @@ | ||
# mypy: no-disallow-untyped-decorators | ||
# pylint: disable=E0611,E0401 | ||
import datetime | ||
from typing import AsyncGenerator | ||
import os | ||
from contextlib import asynccontextmanager | ||
from datetime import datetime | ||
from pathlib import Path | ||
from typing import AsyncGenerator, Tuple | ||
|
||
import pytest | ||
import pytz | ||
from asgi_lifespan import LifespanManager | ||
from httpx import ASGITransport, AsyncClient | ||
from main import app, app_east | ||
from models import Users | ||
|
||
from tortoise.contrib.test import MEMORY_SQLITE | ||
from tortoise.fields.data import JSON_LOADS | ||
|
||
os.environ["DB_URL"] = MEMORY_SQLITE | ||
try: | ||
from main import app | ||
from main_custom_timezone import app as app_east | ||
from models import Users | ||
from schemas import User_Pydantic | ||
except ImportError: | ||
if (cwd := Path.cwd()) == (parent := Path(__file__).parent): | ||
dirpath = "." | ||
else: | ||
dirpath = str(parent.relative_to(cwd)) | ||
print(f"You may need to explicitly declare python path:\n\nexport PYTHONPATH={dirpath}\n") | ||
raise | ||
|
||
ClientManagerType = AsyncGenerator[AsyncClient, None] | ||
|
||
|
||
@pytest.fixture(scope="module") | ||
def anyio_backend() -> str: | ||
return "asyncio" | ||
|
||
|
||
@pytest.fixture(scope="module") | ||
async def client() -> AsyncGenerator[AsyncClient, None]: | ||
@asynccontextmanager | ||
async def client_manager(app, base_url="http://test", **kw) -> ClientManagerType: | ||
async with LifespanManager(app): | ||
transport = ASGITransport(app=app) | ||
async with AsyncClient(transport=transport, base_url="http://test") as c: | ||
async with AsyncClient(transport=transport, base_url=base_url, **kw) as c: | ||
yield c | ||
|
||
|
||
@pytest.mark.anyio | ||
async def test_create_user(client: AsyncClient) -> None: # nosec | ||
response = await client.post("/users", json={"username": "admin"}) | ||
assert response.status_code == 200, response.text | ||
data = response.json() | ||
assert data["username"] == "admin" | ||
assert "id" in data | ||
user_id = data["id"] | ||
|
||
user_obj = await Users.get(id=user_id) | ||
assert user_obj.id == user_id | ||
|
||
|
||
@pytest.fixture(scope="module") | ||
async def client_east() -> AsyncGenerator[AsyncClient, None]: | ||
async with LifespanManager(app_east): | ||
transport = ASGITransport(app=app_east) | ||
async with AsyncClient(transport=transport, base_url="http://test") as c: | ||
yield c | ||
|
||
|
||
@pytest.mark.anyio | ||
async def test_create_user_east(client_east: AsyncClient) -> None: # nosec | ||
response = await client_east.post("/users_east", json={"username": "admin"}) | ||
assert response.status_code == 200, response.text | ||
data = response.json() | ||
assert data["username"] == "admin" | ||
assert "id" in data | ||
user_id = data["id"] | ||
async def client() -> ClientManagerType: | ||
async with client_manager(app) as c: | ||
yield c | ||
|
||
user_obj = await Users.get(id=user_id) | ||
assert user_obj.id == user_id | ||
|
||
# Verify that the time zone is East 8. | ||
created_at = user_obj.created_at | ||
|
||
# Asia/Shanghai timezone | ||
asia_tz = pytz.timezone("Asia/Shanghai") | ||
asia_now = datetime.datetime.now(pytz.utc).astimezone(asia_tz) | ||
assert created_at.hour - asia_now.hour == 0 | ||
|
||
# UTC timezone | ||
utc_tz = pytz.timezone("UTC") | ||
utc_now = datetime.datetime.now(pytz.utc).astimezone(utc_tz) | ||
assert created_at.hour - utc_now.hour == 8 | ||
@pytest.fixture(scope="module") | ||
async def client_east() -> ClientManagerType: | ||
async with client_manager(app_east) as c: | ||
yield c | ||
|
||
|
||
class UserTester: | ||
async def create_user(self, async_client: AsyncClient) -> Users: | ||
response = await async_client.post("/users", json={"username": "admin"}) | ||
assert response.status_code == 200, response.text | ||
data = response.json() | ||
assert data["username"] == "admin" | ||
assert "id" in data | ||
user_id = data["id"] | ||
|
||
user_obj = await Users.get(id=user_id) | ||
assert user_obj.id == user_id | ||
return user_obj | ||
|
||
async def user_list(self, async_client: AsyncClient) -> Tuple[datetime, Users, User_Pydantic]: | ||
utc_now = datetime.now(pytz.utc) | ||
user_obj = await Users.create(username="test") | ||
response = await async_client.get("/users") | ||
assert response.status_code == 200, response.text | ||
data = response.json() | ||
assert isinstance(data, list) | ||
item = await User_Pydantic.from_tortoise_orm(user_obj) | ||
assert JSON_LOADS(item.model_dump_json()) in data | ||
return utc_now, user_obj, item | ||
|
||
|
||
class TestUser(UserTester): | ||
@pytest.mark.anyio | ||
async def test_create_user(self, client: AsyncClient) -> None: # nosec | ||
await self.create_user(client) | ||
|
||
@pytest.mark.anyio | ||
async def test_user_list(self, client: AsyncClient) -> None: # nosec | ||
await self.user_list(client) | ||
|
||
|
||
class TestUserEast(UserTester): | ||
timezone = "Asia/Shanghai" | ||
delta_hours = 8 | ||
|
||
@pytest.mark.anyio | ||
async def test_create_user_east(self, client_east: AsyncClient) -> None: # nosec | ||
user_obj = await self.create_user(client_east) | ||
created_at = user_obj.created_at | ||
|
||
# Verify time zone | ||
asia_tz = pytz.timezone(self.timezone) | ||
asia_now = datetime.now(pytz.utc).astimezone(asia_tz) | ||
assert created_at.hour - asia_now.hour == 0 | ||
|
||
# UTC timezone | ||
utc_tz = pytz.timezone("UTC") | ||
utc_now = datetime.now(pytz.utc).astimezone(utc_tz) | ||
assert (created_at.hour - utc_now.hour) in [self.delta_hours, self.delta_hours - 24] | ||
|
||
@pytest.mark.anyio | ||
async def test_user_list(self, client_east: AsyncClient) -> None: # nosec | ||
time, user_obj, item = await self.user_list(client_east) | ||
created_at = user_obj.created_at | ||
assert (created_at.hour - time.hour) in [self.delta_hours, self.delta_hours - 24] | ||
assert item.model_dump()["created_at"].hour == created_at.hour |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
import os | ||
from functools import partial | ||
|
||
from tortoise.contrib.fastapi import RegisterTortoise | ||
|
||
register_orm = partial( | ||
RegisterTortoise, | ||
db_url=os.getenv("DB_URL", "sqlite://db.sqlite3"), | ||
modules={"models": ["models"]}, | ||
generate_schemas=True, | ||
add_exception_handlers=True, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,124 +1,21 @@ | ||
# pylint: disable=E0611,E0401 | ||
from contextlib import asynccontextmanager | ||
from typing import TYPE_CHECKING, AsyncGenerator, List | ||
from typing import AsyncGenerator | ||
|
||
from fastapi import FastAPI, HTTPException | ||
from models import Users | ||
from pydantic import BaseModel | ||
|
||
from tortoise.contrib.fastapi import RegisterTortoise | ||
from tortoise.contrib.pydantic import PydanticModel | ||
|
||
if TYPE_CHECKING: # pragma: nocoverage | ||
|
||
class UserIn_Pydantic(Users, PydanticModel): # type:ignore[misc] | ||
pass | ||
|
||
class User_Pydantic(Users, PydanticModel): # type:ignore[misc] | ||
pass | ||
|
||
else: | ||
from models import User_Pydantic, UserIn_Pydantic | ||
from config import register_orm | ||
from fastapi import FastAPI | ||
from routers import router as users_router | ||
|
||
|
||
@asynccontextmanager | ||
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: | ||
# app startup | ||
async with RegisterTortoise( | ||
app, | ||
db_url="sqlite://:memory:", | ||
modules={"models": ["models"]}, | ||
generate_schemas=True, | ||
add_exception_handlers=True, | ||
): | ||
# db connected | ||
yield | ||
# app teardown | ||
# db connections closed | ||
|
||
|
||
@asynccontextmanager | ||
async def lifespan_east(app: FastAPI) -> AsyncGenerator[None, None]: | ||
# app startup | ||
async with RegisterTortoise( | ||
app, | ||
db_url="sqlite://:memory:", | ||
modules={"models": ["models"]}, | ||
generate_schemas=True, | ||
add_exception_handlers=True, | ||
use_tz=False, | ||
timezone="Asia/Shanghai", | ||
): | ||
async with register_orm(app): | ||
# db connected | ||
yield | ||
# app teardown | ||
# db connections closed | ||
|
||
|
||
app = FastAPI(title="Tortoise ORM FastAPI example", lifespan=lifespan) | ||
app_east = FastAPI(title="Tortoise ORM FastAPI example", lifespan=lifespan_east) | ||
|
||
|
||
class Status(BaseModel): | ||
message: str | ||
|
||
|
||
@app.get("/users", response_model=List[User_Pydantic]) | ||
async def get_users(): | ||
return await User_Pydantic.from_queryset(Users.all()) | ||
|
||
|
||
@app.post("/users", response_model=User_Pydantic) | ||
async def create_user(user: UserIn_Pydantic): | ||
user_obj = await Users.create(**user.model_dump(exclude_unset=True)) | ||
return await User_Pydantic.from_tortoise_orm(user_obj) | ||
|
||
|
||
@app.get("/user/{user_id}", response_model=User_Pydantic) | ||
async def get_user(user_id: int): | ||
return await User_Pydantic.from_queryset_single(Users.get(id=user_id)) | ||
|
||
|
||
@app.put("/user/{user_id}", response_model=User_Pydantic) | ||
async def update_user(user_id: int, user: UserIn_Pydantic): | ||
await Users.filter(id=user_id).update(**user.model_dump(exclude_unset=True)) | ||
return await User_Pydantic.from_queryset_single(Users.get(id=user_id)) | ||
|
||
|
||
@app.delete("/user/{user_id}", response_model=Status) | ||
async def delete_user(user_id: int): | ||
deleted_count = await Users.filter(id=user_id).delete() | ||
if not deleted_count: | ||
raise HTTPException(status_code=404, detail=f"User {user_id} not found") | ||
return Status(message=f"Deleted user {user_id}") | ||
|
||
|
||
############################ East 8 ############################ | ||
@app_east.get("/users_east", response_model=List[User_Pydantic]) | ||
async def get_users_east(): | ||
return await User_Pydantic.from_queryset(Users.all()) | ||
|
||
|
||
@app_east.post("/users_east", response_model=User_Pydantic) | ||
async def create_user_east(user: UserIn_Pydantic): | ||
user_obj = await Users.create(**user.model_dump(exclude_unset=True)) | ||
return await User_Pydantic.from_tortoise_orm(user_obj) | ||
|
||
|
||
@app_east.get("/user_east/{user_id}", response_model=User_Pydantic) | ||
async def get_user_east(user_id: int): | ||
return await User_Pydantic.from_queryset_single(Users.get(id=user_id)) | ||
|
||
|
||
@app_east.put("/user_east/{user_id}", response_model=User_Pydantic) | ||
async def update_user_east(user_id: int, user: UserIn_Pydantic): | ||
await Users.filter(id=user_id).update(**user.model_dump(exclude_unset=True)) | ||
return await User_Pydantic.from_queryset_single(Users.get(id=user_id)) | ||
|
||
|
||
@app_east.delete("/user_east/{user_id}", response_model=Status) | ||
async def delete_user_east(user_id: int): | ||
deleted_count = await Users.filter(id=user_id).delete() | ||
if not deleted_count: | ||
raise HTTPException(status_code=404, detail=f"User {user_id} not found") | ||
return Status(message=f"Deleted user {user_id}") | ||
app.include_router(users_router, prefix="") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
# pylint: disable=E0611,E0401 | ||
from contextlib import asynccontextmanager | ||
from typing import AsyncGenerator | ||
|
||
from config import register_orm | ||
from fastapi import FastAPI | ||
from routers import router as users_router | ||
|
||
|
||
@asynccontextmanager | ||
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: | ||
# app startup | ||
async with register_orm( | ||
app, | ||
use_tz=False, | ||
timezone="Asia/Shanghai", | ||
): | ||
# db connected | ||
yield | ||
# app teardown | ||
# db connections closed | ||
|
||
|
||
app = FastAPI(title="Tortoise ORM FastAPI example", lifespan=lifespan) | ||
app.include_router(users_router, prefix="") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
from typing import List | ||
|
||
from fastapi import APIRouter, HTTPException | ||
from models import Users | ||
from schemas import Status, User_Pydantic, UserIn_Pydantic | ||
|
||
router = APIRouter() | ||
|
||
|
||
@router.get("/users", response_model=List[User_Pydantic]) | ||
async def get_users(): | ||
return await User_Pydantic.from_queryset(Users.all()) | ||
|
||
|
||
@router.post("/users", response_model=User_Pydantic) | ||
async def create_user(user: UserIn_Pydantic): | ||
user_obj = await Users.create(**user.model_dump(exclude_unset=True)) | ||
return await User_Pydantic.from_tortoise_orm(user_obj) | ||
|
||
|
||
@router.get("/user/{user_id}", response_model=User_Pydantic) | ||
async def get_user(user_id: int): | ||
return await User_Pydantic.from_queryset_single(Users.get(id=user_id)) | ||
|
||
|
||
@router.put("/user/{user_id}", response_model=User_Pydantic) | ||
async def update_user(user_id: int, user: UserIn_Pydantic): | ||
await Users.filter(id=user_id).update(**user.model_dump(exclude_unset=True)) | ||
return await User_Pydantic.from_queryset_single(Users.get(id=user_id)) | ||
|
||
|
||
@router.delete("/user/{user_id}", response_model=Status) | ||
async def delete_user(user_id: int): | ||
deleted_count = await Users.filter(id=user_id).delete() | ||
if not deleted_count: | ||
raise HTTPException(status_code=404, detail=f"User {user_id} not found") | ||
return Status(message=f"Deleted user {user_id}") |
Oops, something went wrong.