diff --git a/.gitignore b/.gitignore index 1ab5db68..5cf69467 100644 --- a/.gitignore +++ b/.gitignore @@ -63,3 +63,6 @@ target/ .mypy_cache pip-wheel-metadata/ .venv/ + + +.idea diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 1ea1758b..0241b5b3 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -27,6 +27,8 @@ Changed jwt.encode({"payload":"abc"}, key=None, algorithm='none') ``` +- Added validation for 'sub' (subject) and 'jti' (JWT ID) claims in tokens + Fixed ~~~~~ diff --git a/jwt/api_jwt.py b/jwt/api_jwt.py index c82e6371..fa4d5e6f 100644 --- a/jwt/api_jwt.py +++ b/jwt/api_jwt.py @@ -15,6 +15,8 @@ InvalidAudienceError, InvalidIssuedAtError, InvalidIssuerError, + InvalidJTIError, + InvalidSubjectError, MissingRequiredClaimError, ) from .warnings import RemovedInPyjwt3Warning @@ -39,6 +41,8 @@ def _get_default_options() -> dict[str, bool | list[str]]: "verify_iat": True, "verify_aud": True, "verify_iss": True, + "verify_sub": True, + "verify_jti": True, "require": [], } @@ -112,6 +116,7 @@ def decode_complete( # consider putting in options audience: str | Iterable[str] | None = None, issuer: str | Sequence[str] | None = None, + subject: str | None = None, leeway: float | timedelta = 0, # kwargs **kwargs: Any, @@ -145,6 +150,8 @@ def decode_complete( options.setdefault("verify_iat", False) options.setdefault("verify_aud", False) options.setdefault("verify_iss", False) + options.setdefault("verify_sub", False) + options.setdefault("verify_jti", False) decoded = api_jws.decode_complete( jwt, @@ -158,7 +165,12 @@ def decode_complete( merged_options = {**self.options, **options} self._validate_claims( - payload, merged_options, audience=audience, issuer=issuer, leeway=leeway + payload, + merged_options, + audience=audience, + issuer=issuer, + leeway=leeway, + subject=subject, ) decoded["payload"] = payload @@ -193,6 +205,7 @@ def decode( # passthrough arguments to _validate_claims # consider putting in options audience: str | Iterable[str] | None = None, + subject: str | None = None, issuer: str | Sequence[str] | None = None, leeway: float | timedelta = 0, # kwargs @@ -214,6 +227,7 @@ def decode( verify=verify, detached_payload=detached_payload, audience=audience, + subject=subject, issuer=issuer, leeway=leeway, ) @@ -225,6 +239,7 @@ def _validate_claims( options: dict[str, Any], audience=None, issuer=None, + subject: str | None = None, leeway: float | timedelta = 0, ) -> None: if isinstance(leeway, timedelta): @@ -254,6 +269,12 @@ def _validate_claims( payload, audience, strict=options.get("strict_aud", False) ) + if options["verify_sub"]: + self._validate_sub(payload, subject) + + if options["verify_jti"]: + self._validate_jti(payload) + def _validate_required_claims( self, payload: dict[str, Any], @@ -263,6 +284,39 @@ def _validate_required_claims( if payload.get(claim) is None: raise MissingRequiredClaimError(claim) + def _validate_sub(self, payload: dict[str, Any], subject=None) -> None: + """ + Checks whether "sub" if in the payload is valid ot not. + This is an Optional claim + + :param payload(dict): The payload which needs to be validated + :param subject(str): The subject of the token + """ + + if "sub" not in payload: + return + + if not isinstance(payload["sub"], str): + raise InvalidSubjectError("Subject must be a string") + + if subject is not None: + if payload.get("sub") != subject: + raise InvalidSubjectError("Invalid subject") + + def _validate_jti(self, payload: dict[str, Any]) -> None: + """ + Checks whether "jti" if in the payload is valid ot not + This is an Optional claim + + :param payload(dict): The payload which needs to be validated + """ + + if "jti" not in payload: + return + + if not isinstance(payload.get("jti"), str): + raise InvalidJTIError("JWT ID must be a string") + def _validate_iat( self, payload: dict[str, Any], diff --git a/jwt/exceptions.py b/jwt/exceptions.py index 0d985882..9b45ae48 100644 --- a/jwt/exceptions.py +++ b/jwt/exceptions.py @@ -72,3 +72,11 @@ class PyJWKClientError(PyJWTError): class PyJWKClientConnectionError(PyJWKClientError): pass + + +class InvalidSubjectError(InvalidTokenError): + pass + + +class InvalidJTIError(InvalidTokenError): + pass diff --git a/tests/test_api_jwt.py b/tests/test_api_jwt.py index 3c0a975b..43fa2e21 100644 --- a/tests/test_api_jwt.py +++ b/tests/test_api_jwt.py @@ -14,6 +14,8 @@ InvalidAudienceError, InvalidIssuedAtError, InvalidIssuerError, + InvalidJTIError, + InvalidSubjectError, MissingRequiredClaimError, ) from jwt.utils import base64url_decode @@ -816,3 +818,121 @@ def test_decode_strict_ok(self, jwt, payload): options={"strict_aud": True}, algorithms=["HS256"], ) + + # -------------------- Sub Claim Tests -------------------- + + def test_encode_decode_sub_claim(self, jwt): + payload = { + "sub": "user123", + } + secret = "your-256-bit-secret" + token = jwt.encode(payload, secret, algorithm="HS256") + decoded = jwt.decode(token, secret, algorithms=["HS256"]) + + assert decoded["sub"] == "user123" + + def test_decode_without_and_not_required_sub_claim(self, jwt): + payload = {} + secret = "your-256-bit-secret" + token = jwt.encode(payload, secret, algorithm="HS256") + + decoded = jwt.decode(token, secret, algorithms=["HS256"]) + + assert "sub" not in decoded + + def test_decode_missing_sub_but_required_claim(self, jwt): + payload = {} + secret = "your-256-bit-secret" + token = jwt.encode(payload, secret, algorithm="HS256") + + with pytest.raises(MissingRequiredClaimError): + jwt.decode( + token, secret, algorithms=["HS256"], options={"require": ["sub"]} + ) + + def test_decode_invalid_int_sub_claim(self, jwt): + payload = { + "sub": 1224344, + } + secret = "your-256-bit-secret" + token = jwt.encode(payload, secret, algorithm="HS256") + + with pytest.raises(InvalidSubjectError): + jwt.decode(token, secret, algorithms=["HS256"]) + + def test_decode_with_valid_sub_claim(self, jwt): + payload = { + "sub": "user123", + } + secret = "your-256-bit-secret" + token = jwt.encode(payload, secret, algorithm="HS256") + + decoded = jwt.decode(token, secret, algorithms=["HS256"], subject="user123") + + assert decoded["sub"] == "user123" + + def test_decode_with_invalid_sub_claim(self, jwt): + payload = { + "sub": "user123", + } + secret = "your-256-bit-secret" + token = jwt.encode(payload, secret, algorithm="HS256") + + with pytest.raises(InvalidSubjectError) as exc_info: + jwt.decode(token, secret, algorithms=["HS256"], subject="user456") + + assert "Invalid subject" in str(exc_info.value) + + def test_decode_with_sub_claim_and_none_subject(self, jwt): + payload = { + "sub": "user789", + } + secret = "your-256-bit-secret" + token = jwt.encode(payload, secret, algorithm="HS256") + + decoded = jwt.decode(token, secret, algorithms=["HS256"], subject=None) + assert decoded["sub"] == "user789" + + # -------------------- JTI Claim Tests -------------------- + + def test_encode_decode_with_valid_jti_claim(self, jwt): + payload = { + "jti": "unique-id-456", + } + secret = "your-256-bit-secret" + token = jwt.encode(payload, secret, algorithm="HS256") + decoded = jwt.decode(token, secret, algorithms=["HS256"]) + + assert decoded["jti"] == "unique-id-456" + + def test_decode_missing_jti_when_required_claim(self, jwt): + payload = {"name": "Bob", "admin": False} + secret = "your-256-bit-secret" + token = jwt.encode(payload, secret, algorithm="HS256") + + with pytest.raises(MissingRequiredClaimError) as exc_info: + jwt.decode( + token, secret, algorithms=["HS256"], options={"require": ["jti"]} + ) + + assert "jti" in str(exc_info.value) + + def test_decode_missing_jti_claim(self, jwt): + payload = {} + secret = "your-256-bit-secret" + token = jwt.encode(payload, secret, algorithm="HS256") + + decoded = jwt.decode(token, secret, algorithms=["HS256"]) + + assert decoded.get("jti") is None + + def test_jti_claim_with_invalid_int_value(self, jwt): + special_jti = 12223 + payload = { + "jti": special_jti, + } + secret = "your-256-bit-secret" + token = jwt.encode(payload, secret, algorithm="HS256") + + with pytest.raises(InvalidJTIError): + jwt.decode(token, secret, algorithms=["HS256"])