Skip to content

Commit

Permalink
Pack permissions in session and access token
Browse files Browse the repository at this point in the history
  • Loading branch information
gmazoyer committed Jul 22, 2024
1 parent 8e13c2a commit f6d1899
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 10 deletions.
5 changes: 3 additions & 2 deletions backend/infrahub/api/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,13 @@ async def get_db(request: Request) -> InfrahubDatabase:

async def get_access_token(
request: Request,
db: InfrahubDatabase = Depends(get_db),
jwt_header: HTTPAuthorizationCredentials = Depends(jwt_scheme),
) -> AccountSession:
if jwt_header:
return await validate_jwt_access_token(token=jwt_header.credentials)
return await validate_jwt_access_token(db=db, token=jwt_header.credentials)
if token := request.cookies.get("access_token"):
return await validate_jwt_access_token(token=token)
return await validate_jwt_access_token(db=db, token=token)

raise AuthorizationError("A JWT access token is required to perform this operation.")

Expand Down
24 changes: 16 additions & 8 deletions backend/infrahub/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from pydantic import BaseModel

from infrahub import config, models
from infrahub.core.account import validate_token
from infrahub.core.account import fetch_permissions, validate_token
from infrahub.core.constants import InfrahubKind
from infrahub.core.manager import NodeManager
from infrahub.core.node import Node
Expand All @@ -35,6 +35,7 @@ class AccountSession(BaseModel):
account_id: str
session_id: Optional[str] = None
role: str = "read-only"
permissions: Optional[list[str]] = None
auth_type: AuthType

@property
Expand Down Expand Up @@ -75,7 +76,9 @@ async def authenticate_with_password(

refresh_expires = now + timedelta(seconds=config.SETTINGS.security.refresh_token_lifetime)
session_id = await create_db_refresh_token(db=db, account_id=account.id, expiration=refresh_expires)
access_token = generate_access_token(account_id=account.id, role=account.role.value.value, session_id=session_id)
access_token = await generate_access_token(
db=db, account_id=account.id, role=account.role.value.value, session_id=session_id
)
refresh_token = generate_refresh_token(account_id=account.id, session_id=session_id, expiration=refresh_expires)

return models.UserToken(access_token=access_token, refresh_token=refresh_token)
Expand Down Expand Up @@ -106,14 +109,15 @@ async def create_fresh_access_token(
message="That login user doesn't exist in the system",
)

access_token = generate_access_token(
account_id=account.id, role=account.role.value.value, session_id=refresh_data.session_id
access_token = await generate_access_token(
db=db, account_id=account.id, role=account.role.value.value, session_id=refresh_data.session_id
)

return models.AccessTokenResponse(access_token=access_token)


def generate_access_token(account_id: str, role: str, session_id: uuid.UUID) -> str:
async def generate_access_token(db: InfrahubDatabase, account_id: str, role: str, session_id: uuid.UUID) -> str:
selected_branch = await registry.get_branch(db=db)
now = datetime.now(tz=timezone.utc)

access_expires = now + timedelta(seconds=config.SETTINGS.security.access_token_lifetime)
Expand All @@ -126,6 +130,7 @@ def generate_access_token(account_id: str, role: str, session_id: uuid.UUID) ->
"type": "access",
"session_id": str(session_id),
"user_claims": {"role": role},
"permissions": await fetch_permissions(db=db, branch=selected_branch, account_id=account_id),
}
access_token = jwt.encode(access_data, config.SETTINGS.security.secret_key, algorithm="HS256")
return access_token
Expand Down Expand Up @@ -153,24 +158,27 @@ async def authentication_token(
if api_key:
return await validate_api_key(db=db, token=api_key)
if jwt_token:
return await validate_jwt_access_token(token=jwt_token)
return await validate_jwt_access_token(db=db, token=jwt_token)

return AccountSession(authenticated=False, account_id="anonymous", auth_type=AuthType.NONE)


async def validate_jwt_access_token(token: str) -> AccountSession:
async def validate_jwt_access_token(db: InfrahubDatabase, token: str) -> AccountSession:
try:
payload = jwt.decode(token, config.SETTINGS.security.secret_key, algorithms=["HS256"])
account_id = payload["sub"]
role = payload["user_claims"]["role"]
session_id = payload["session_id"]
permissions = await fetch_permissions(db=db, account_id=account_id)
except jwt.ExpiredSignatureError:
raise AuthorizationError("Expired Signature") from None
except Exception:
raise AuthorizationError("Invalid token") from None

if payload["type"] == "access":
return AccountSession(account_id=account_id, role=role, session_id=session_id, auth_type=AuthType.JWT)
return AccountSession(
account_id=account_id, role=role, session_id=session_id, permissions=permissions, auth_type=AuthType.JWT
)

raise AuthorizationError("Invalid token, current token is not an access token")

Expand Down

0 comments on commit f6d1899

Please sign in to comment.