Skip to content

Commit

Permalink
Added display_name_claim in jwt_config which sets the user's display …
Browse files Browse the repository at this point in the history
…name upon registration (#17708)
  • Loading branch information
EnneS authored Oct 9, 2024
1 parent 60aebdb commit 05576f0
Show file tree
Hide file tree
Showing 6 changed files with 50 additions and 6 deletions.
1 change: 1 addition & 0 deletions changelog.d/17708.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Added the `display_name_claim` option to the JWT configuration. This option allows specifying the claim key that contains the user's display name in the JWT payload.
3 changes: 3 additions & 0 deletions docs/usage/configuration/config_documentation.md
Original file line number Diff line number Diff line change
Expand Up @@ -3722,6 +3722,8 @@ Additional sub-options for this setting include:
Required if `enabled` is set to true.
* `subject_claim`: Name of the claim containing a unique identifier for the user.
Optional, defaults to `sub`.
* `display_name_claim`: Name of the claim containing the display name for the user. Optional.
If provided, the display name will be set to the value of this claim upon first login.
* `issuer`: The issuer to validate the "iss" claim against. Optional. If provided the
"iss" claim will be required and validated for all JSON web tokens.
* `audiences`: A list of audiences to validate the "aud" claim against. Optional.
Expand All @@ -3736,6 +3738,7 @@ jwt_config:
secret: "provided-by-your-issuer"
algorithm: "provided-by-your-issuer"
subject_claim: "name_of_claim"
display_name_claim: "name_of_claim"
issuer: "provided-by-your-issuer"
audiences:
- "provided-by-your-issuer"
Expand Down
2 changes: 2 additions & 0 deletions synapse/config/jwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def read_config(self, config: JsonDict, **kwargs: Any) -> None:
self.jwt_algorithm = jwt_config["algorithm"]

self.jwt_subject_claim = jwt_config.get("subject_claim", "sub")
self.jwt_display_name_claim = jwt_config.get("display_name_claim")

# The issuer and audiences are optional, if provided, it is asserted
# that the claims exist on the JWT.
Expand All @@ -49,5 +50,6 @@ def read_config(self, config: JsonDict, **kwargs: Any) -> None:
self.jwt_secret = None
self.jwt_algorithm = None
self.jwt_subject_claim = None
self.jwt_display_name_claim = None
self.jwt_issuer = None
self.jwt_audiences = None
16 changes: 12 additions & 4 deletions synapse/handlers/jwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
# [This file includes modifications made by New Vector Limited]
#
#
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Optional, Tuple

from authlib.jose import JsonWebToken, JWTClaims
from authlib.jose.errors import BadSignatureError, InvalidClaimError, JoseError
Expand All @@ -36,11 +36,12 @@ def __init__(self, hs: "HomeServer"):

self.jwt_secret = hs.config.jwt.jwt_secret
self.jwt_subject_claim = hs.config.jwt.jwt_subject_claim
self.jwt_display_name_claim = hs.config.jwt.jwt_display_name_claim
self.jwt_algorithm = hs.config.jwt.jwt_algorithm
self.jwt_issuer = hs.config.jwt.jwt_issuer
self.jwt_audiences = hs.config.jwt.jwt_audiences

def validate_login(self, login_submission: JsonDict) -> str:
def validate_login(self, login_submission: JsonDict) -> Tuple[str, Optional[str]]:
"""
Authenticates the user for the /login API
Expand All @@ -49,7 +50,8 @@ def validate_login(self, login_submission: JsonDict) -> str:
(including 'type' and other relevant fields)
Returns:
The user ID that is logging in.
A tuple of (user_id, display_name) of the user that is logging in.
If the JWT does not contain a display name, the second element of the tuple will be None.
Raises:
LoginError if there was an authentication problem.
Expand Down Expand Up @@ -109,4 +111,10 @@ def validate_login(self, login_submission: JsonDict) -> str:
if user is None:
raise LoginError(403, "Invalid JWT", errcode=Codes.FORBIDDEN)

return UserID(user, self.hs.hostname).to_string()
default_display_name = None
if self.jwt_display_name_claim:
display_name_claim = claims.get(self.jwt_display_name_claim)
if display_name_claim is not None:
default_display_name = display_name_claim

return UserID(user, self.hs.hostname).to_string(), default_display_name
9 changes: 7 additions & 2 deletions synapse/rest/client/login.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,7 @@ async def _complete_login(
login_submission: JsonDict,
callback: Optional[Callable[[LoginResponse], Awaitable[None]]] = None,
create_non_existent_users: bool = False,
default_display_name: Optional[str] = None,
ratelimit: bool = True,
auth_provider_id: Optional[str] = None,
should_issue_refresh_token: bool = False,
Expand Down Expand Up @@ -410,7 +411,8 @@ async def _complete_login(
canonical_uid = await self.auth_handler.check_user_exists(user_id)
if not canonical_uid:
canonical_uid = await self.registration_handler.register_user(
localpart=UserID.from_string(user_id).localpart
localpart=UserID.from_string(user_id).localpart,
default_display_name=default_display_name,
)
user_id = canonical_uid

Expand Down Expand Up @@ -546,11 +548,14 @@ async def _do_jwt_login(
Returns:
The body of the JSON response.
"""
user_id = self.hs.get_jwt_handler().validate_login(login_submission)
user_id, default_display_name = self.hs.get_jwt_handler().validate_login(
login_submission
)
return await self._complete_login(
user_id,
login_submission,
create_non_existent_users=True,
default_display_name=default_display_name,
should_issue_refresh_token=should_issue_refresh_token,
request_info=request_info,
)
Expand Down
25 changes: 25 additions & 0 deletions tests/rest/client/test_login.py
Original file line number Diff line number Diff line change
Expand Up @@ -1047,6 +1047,7 @@ class JWTTestCase(unittest.HomeserverTestCase):
servlets = [
synapse.rest.admin.register_servlets_for_client_rest_resource,
login.register_servlets,
profile.register_servlets,
]

jwt_secret = "secret"
Expand Down Expand Up @@ -1202,6 +1203,30 @@ def test_login_custom_sub(self) -> None:
self.assertEqual(channel.code, 200, msg=channel.result)
self.assertEqual(channel.json_body["user_id"], "@frog:test")

@override_config(
{"jwt_config": {**base_config, "display_name_claim": "display_name"}}
)
def test_login_custom_display_name(self) -> None:
"""Test setting a custom display name."""
localpart = "pinkie"
user_id = f"@{localpart}:test"
display_name = "Pinkie Pie"

# Perform the login, specifying a custom display name.
channel = self.jwt_login({"sub": localpart, "display_name": display_name})
self.assertEqual(channel.code, 200, msg=channel.result)
self.assertEqual(channel.json_body["user_id"], user_id)

# Fetch the user's display name and check that it was set correctly.
access_token = channel.json_body["access_token"]
channel = self.make_request(
"GET",
f"/_matrix/client/v3/profile/{user_id}/displayname",
access_token=access_token,
)
self.assertEqual(channel.code, 200, msg=channel.result)
self.assertEqual(channel.json_body["displayname"], display_name)

def test_login_no_token(self) -> None:
params = {"type": "org.matrix.login.jwt"}
channel = self.make_request(b"POST", LOGIN_URL, params)
Expand Down

0 comments on commit 05576f0

Please sign in to comment.