Skip to content

Commit

Permalink
Validate sub and jti claims for the token (#1005)
Browse files Browse the repository at this point in the history
* feat(jwt): Both JTI and sub are now being validated, test cases added, changelog updated

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
Divan009 and pre-commit-ci[bot] authored Oct 13, 2024
1 parent 310962b commit 6dc23b3
Show file tree
Hide file tree
Showing 5 changed files with 188 additions and 1 deletion.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -63,3 +63,6 @@ target/
.mypy_cache
pip-wheel-metadata/
.venv/


.idea
2 changes: 2 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ Changed
jwt.encode({"payload":"abc"}, key=None, algorithm='none')
```

- Added validation for 'sub' (subject) and 'jti' (JWT ID) claims in tokens

Fixed
~~~~~

Expand Down
56 changes: 55 additions & 1 deletion jwt/api_jwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
InvalidAudienceError,
InvalidIssuedAtError,
InvalidIssuerError,
InvalidJTIError,
InvalidSubjectError,
MissingRequiredClaimError,
)
from .warnings import RemovedInPyjwt3Warning
Expand All @@ -39,6 +41,8 @@ def _get_default_options() -> dict[str, bool | list[str]]:
"verify_iat": True,
"verify_aud": True,
"verify_iss": True,
"verify_sub": True,
"verify_jti": True,
"require": [],
}

Expand Down Expand Up @@ -112,6 +116,7 @@ def decode_complete(
# consider putting in options
audience: str | Iterable[str] | None = None,
issuer: str | Sequence[str] | None = None,
subject: str | None = None,
leeway: float | timedelta = 0,
# kwargs
**kwargs: Any,
Expand Down Expand Up @@ -145,6 +150,8 @@ def decode_complete(
options.setdefault("verify_iat", False)
options.setdefault("verify_aud", False)
options.setdefault("verify_iss", False)
options.setdefault("verify_sub", False)
options.setdefault("verify_jti", False)

decoded = api_jws.decode_complete(
jwt,
Expand All @@ -158,7 +165,12 @@ def decode_complete(

merged_options = {**self.options, **options}
self._validate_claims(
payload, merged_options, audience=audience, issuer=issuer, leeway=leeway
payload,
merged_options,
audience=audience,
issuer=issuer,
leeway=leeway,
subject=subject,
)

decoded["payload"] = payload
Expand Down Expand Up @@ -193,6 +205,7 @@ def decode(
# passthrough arguments to _validate_claims
# consider putting in options
audience: str | Iterable[str] | None = None,
subject: str | None = None,
issuer: str | Sequence[str] | None = None,
leeway: float | timedelta = 0,
# kwargs
Expand All @@ -214,6 +227,7 @@ def decode(
verify=verify,
detached_payload=detached_payload,
audience=audience,
subject=subject,
issuer=issuer,
leeway=leeway,
)
Expand All @@ -225,6 +239,7 @@ def _validate_claims(
options: dict[str, Any],
audience=None,
issuer=None,
subject: str | None = None,
leeway: float | timedelta = 0,
) -> None:
if isinstance(leeway, timedelta):
Expand Down Expand Up @@ -254,6 +269,12 @@ def _validate_claims(
payload, audience, strict=options.get("strict_aud", False)
)

if options["verify_sub"]:
self._validate_sub(payload, subject)

if options["verify_jti"]:
self._validate_jti(payload)

def _validate_required_claims(
self,
payload: dict[str, Any],
Expand All @@ -263,6 +284,39 @@ def _validate_required_claims(
if payload.get(claim) is None:
raise MissingRequiredClaimError(claim)

def _validate_sub(self, payload: dict[str, Any], subject=None) -> None:
"""
Checks whether "sub" if in the payload is valid ot not.
This is an Optional claim
:param payload(dict): The payload which needs to be validated
:param subject(str): The subject of the token
"""

if "sub" not in payload:
return

if not isinstance(payload["sub"], str):
raise InvalidSubjectError("Subject must be a string")

if subject is not None:
if payload.get("sub") != subject:
raise InvalidSubjectError("Invalid subject")

def _validate_jti(self, payload: dict[str, Any]) -> None:
"""
Checks whether "jti" if in the payload is valid ot not
This is an Optional claim
:param payload(dict): The payload which needs to be validated
"""

if "jti" not in payload:
return

if not isinstance(payload.get("jti"), str):
raise InvalidJTIError("JWT ID must be a string")

def _validate_iat(
self,
payload: dict[str, Any],
Expand Down
8 changes: 8 additions & 0 deletions jwt/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,3 +72,11 @@ class PyJWKClientError(PyJWTError):

class PyJWKClientConnectionError(PyJWKClientError):
pass


class InvalidSubjectError(InvalidTokenError):
pass


class InvalidJTIError(InvalidTokenError):
pass
120 changes: 120 additions & 0 deletions tests/test_api_jwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
InvalidAudienceError,
InvalidIssuedAtError,
InvalidIssuerError,
InvalidJTIError,
InvalidSubjectError,
MissingRequiredClaimError,
)
from jwt.utils import base64url_decode
Expand Down Expand Up @@ -816,3 +818,121 @@ def test_decode_strict_ok(self, jwt, payload):
options={"strict_aud": True},
algorithms=["HS256"],
)

# -------------------- Sub Claim Tests --------------------

def test_encode_decode_sub_claim(self, jwt):
payload = {
"sub": "user123",
}
secret = "your-256-bit-secret"
token = jwt.encode(payload, secret, algorithm="HS256")
decoded = jwt.decode(token, secret, algorithms=["HS256"])

assert decoded["sub"] == "user123"

def test_decode_without_and_not_required_sub_claim(self, jwt):
payload = {}
secret = "your-256-bit-secret"
token = jwt.encode(payload, secret, algorithm="HS256")

decoded = jwt.decode(token, secret, algorithms=["HS256"])

assert "sub" not in decoded

def test_decode_missing_sub_but_required_claim(self, jwt):
payload = {}
secret = "your-256-bit-secret"
token = jwt.encode(payload, secret, algorithm="HS256")

with pytest.raises(MissingRequiredClaimError):
jwt.decode(
token, secret, algorithms=["HS256"], options={"require": ["sub"]}
)

def test_decode_invalid_int_sub_claim(self, jwt):
payload = {
"sub": 1224344,
}
secret = "your-256-bit-secret"
token = jwt.encode(payload, secret, algorithm="HS256")

with pytest.raises(InvalidSubjectError):
jwt.decode(token, secret, algorithms=["HS256"])

def test_decode_with_valid_sub_claim(self, jwt):
payload = {
"sub": "user123",
}
secret = "your-256-bit-secret"
token = jwt.encode(payload, secret, algorithm="HS256")

decoded = jwt.decode(token, secret, algorithms=["HS256"], subject="user123")

assert decoded["sub"] == "user123"

def test_decode_with_invalid_sub_claim(self, jwt):
payload = {
"sub": "user123",
}
secret = "your-256-bit-secret"
token = jwt.encode(payload, secret, algorithm="HS256")

with pytest.raises(InvalidSubjectError) as exc_info:
jwt.decode(token, secret, algorithms=["HS256"], subject="user456")

assert "Invalid subject" in str(exc_info.value)

def test_decode_with_sub_claim_and_none_subject(self, jwt):
payload = {
"sub": "user789",
}
secret = "your-256-bit-secret"
token = jwt.encode(payload, secret, algorithm="HS256")

decoded = jwt.decode(token, secret, algorithms=["HS256"], subject=None)
assert decoded["sub"] == "user789"

# -------------------- JTI Claim Tests --------------------

def test_encode_decode_with_valid_jti_claim(self, jwt):
payload = {
"jti": "unique-id-456",
}
secret = "your-256-bit-secret"
token = jwt.encode(payload, secret, algorithm="HS256")
decoded = jwt.decode(token, secret, algorithms=["HS256"])

assert decoded["jti"] == "unique-id-456"

def test_decode_missing_jti_when_required_claim(self, jwt):
payload = {"name": "Bob", "admin": False}
secret = "your-256-bit-secret"
token = jwt.encode(payload, secret, algorithm="HS256")

with pytest.raises(MissingRequiredClaimError) as exc_info:
jwt.decode(
token, secret, algorithms=["HS256"], options={"require": ["jti"]}
)

assert "jti" in str(exc_info.value)

def test_decode_missing_jti_claim(self, jwt):
payload = {}
secret = "your-256-bit-secret"
token = jwt.encode(payload, secret, algorithm="HS256")

decoded = jwt.decode(token, secret, algorithms=["HS256"])

assert decoded.get("jti") is None

def test_jti_claim_with_invalid_int_value(self, jwt):
special_jti = 12223
payload = {
"jti": special_jti,
}
secret = "your-256-bit-secret"
token = jwt.encode(payload, secret, algorithm="HS256")

with pytest.raises(InvalidJTIError):
jwt.decode(token, secret, algorithms=["HS256"])

0 comments on commit 6dc23b3

Please sign in to comment.