Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.

Commit a15a893

Browse files
authored
Save the OIDC session ID (sid) with the device on login (#11482)
As a step towards allowing back-channel logout for OIDC.
1 parent 8b4b153 commit a15a893

File tree

15 files changed

+370
-65
lines changed

15 files changed

+370
-65
lines changed

changelog.d/11482.misc

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Save the OpenID Connect session ID on login.

synapse/handlers/auth.py

+31-3
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
import bcrypt
4040
import pymacaroons
4141
import unpaddedbase64
42+
from pymacaroons.exceptions import MacaroonVerificationFailedException
4243

4344
from twisted.web.server import Request
4445

@@ -182,8 +183,11 @@ class LoginTokenAttributes:
182183

183184
user_id = attr.ib(type=str)
184185

185-
# the SSO Identity Provider that the user authenticated with, to get this token
186186
auth_provider_id = attr.ib(type=str)
187+
"""The SSO Identity Provider that the user authenticated with, to get this token."""
188+
189+
auth_provider_session_id = attr.ib(type=Optional[str])
190+
"""The session ID advertised by the SSO Identity Provider."""
187191

188192

189193
class AuthHandler:
@@ -1650,6 +1654,7 @@ async def complete_sso_login(
16501654
client_redirect_url: str,
16511655
extra_attributes: Optional[JsonDict] = None,
16521656
new_user: bool = False,
1657+
auth_provider_session_id: Optional[str] = None,
16531658
) -> None:
16541659
"""Having figured out a mxid for this user, complete the HTTP request
16551660
@@ -1665,6 +1670,7 @@ async def complete_sso_login(
16651670
during successful login. Must be JSON serializable.
16661671
new_user: True if we should use wording appropriate to a user who has just
16671672
registered.
1673+
auth_provider_session_id: The session ID from the SSO IdP received during login.
16681674
"""
16691675
# If the account has been deactivated, do not proceed with the login
16701676
# flow.
@@ -1685,6 +1691,7 @@ async def complete_sso_login(
16851691
extra_attributes,
16861692
new_user=new_user,
16871693
user_profile_data=profile,
1694+
auth_provider_session_id=auth_provider_session_id,
16881695
)
16891696

16901697
def _complete_sso_login(
@@ -1696,6 +1703,7 @@ def _complete_sso_login(
16961703
extra_attributes: Optional[JsonDict] = None,
16971704
new_user: bool = False,
16981705
user_profile_data: Optional[ProfileInfo] = None,
1706+
auth_provider_session_id: Optional[str] = None,
16991707
) -> None:
17001708
"""
17011709
The synchronous portion of complete_sso_login.
@@ -1717,7 +1725,9 @@ def _complete_sso_login(
17171725

17181726
# Create a login token
17191727
login_token = self.macaroon_gen.generate_short_term_login_token(
1720-
registered_user_id, auth_provider_id=auth_provider_id
1728+
registered_user_id,
1729+
auth_provider_id=auth_provider_id,
1730+
auth_provider_session_id=auth_provider_session_id,
17211731
)
17221732

17231733
# Append the login token to the original redirect URL (i.e. with its query
@@ -1822,6 +1832,7 @@ def generate_short_term_login_token(
18221832
self,
18231833
user_id: str,
18241834
auth_provider_id: str,
1835+
auth_provider_session_id: Optional[str] = None,
18251836
duration_in_ms: int = (2 * 60 * 1000),
18261837
) -> str:
18271838
macaroon = self._generate_base_macaroon(user_id)
@@ -1830,6 +1841,10 @@ def generate_short_term_login_token(
18301841
expiry = now + duration_in_ms
18311842
macaroon.add_first_party_caveat("time < %d" % (expiry,))
18321843
macaroon.add_first_party_caveat("auth_provider_id = %s" % (auth_provider_id,))
1844+
if auth_provider_session_id is not None:
1845+
macaroon.add_first_party_caveat(
1846+
"auth_provider_session_id = %s" % (auth_provider_session_id,)
1847+
)
18331848
return macaroon.serialize()
18341849

18351850
def verify_short_term_login_token(self, token: str) -> LoginTokenAttributes:
@@ -1851,15 +1866,28 @@ def verify_short_term_login_token(self, token: str) -> LoginTokenAttributes:
18511866
user_id = get_value_from_macaroon(macaroon, "user_id")
18521867
auth_provider_id = get_value_from_macaroon(macaroon, "auth_provider_id")
18531868

1869+
auth_provider_session_id: Optional[str] = None
1870+
try:
1871+
auth_provider_session_id = get_value_from_macaroon(
1872+
macaroon, "auth_provider_session_id"
1873+
)
1874+
except MacaroonVerificationFailedException:
1875+
pass
1876+
18541877
v = pymacaroons.Verifier()
18551878
v.satisfy_exact("gen = 1")
18561879
v.satisfy_exact("type = login")
18571880
v.satisfy_general(lambda c: c.startswith("user_id = "))
18581881
v.satisfy_general(lambda c: c.startswith("auth_provider_id = "))
1882+
v.satisfy_general(lambda c: c.startswith("auth_provider_session_id = "))
18591883
satisfy_expiry(v, self.hs.get_clock().time_msec)
18601884
v.verify(macaroon, self.hs.config.key.macaroon_secret_key)
18611885

1862-
return LoginTokenAttributes(user_id=user_id, auth_provider_id=auth_provider_id)
1886+
return LoginTokenAttributes(
1887+
user_id=user_id,
1888+
auth_provider_id=auth_provider_id,
1889+
auth_provider_session_id=auth_provider_session_id,
1890+
)
18631891

18641892
def generate_delete_pusher_token(self, user_id: str) -> str:
18651893
macaroon = self._generate_base_macaroon(user_id)

synapse/handlers/device.py

+8
Original file line numberDiff line numberDiff line change
@@ -301,6 +301,8 @@ async def check_device_registered(
301301
user_id: str,
302302
device_id: Optional[str],
303303
initial_device_display_name: Optional[str] = None,
304+
auth_provider_id: Optional[str] = None,
305+
auth_provider_session_id: Optional[str] = None,
304306
) -> str:
305307
"""
306308
If the given device has not been registered, register it with the
@@ -312,6 +314,8 @@ async def check_device_registered(
312314
user_id: @user:id
313315
device_id: device id supplied by client
314316
initial_device_display_name: device display name from client
317+
auth_provider_id: The SSO IdP the user used, if any.
318+
auth_provider_session_id: The session ID (sid) got from the SSO IdP.
315319
Returns:
316320
device id (generated if none was supplied)
317321
"""
@@ -323,6 +327,8 @@ async def check_device_registered(
323327
user_id=user_id,
324328
device_id=device_id,
325329
initial_device_display_name=initial_device_display_name,
330+
auth_provider_id=auth_provider_id,
331+
auth_provider_session_id=auth_provider_session_id,
326332
)
327333
if new_device:
328334
await self.notify_device_update(user_id, [device_id])
@@ -337,6 +343,8 @@ async def check_device_registered(
337343
user_id=user_id,
338344
device_id=new_device_id,
339345
initial_device_display_name=initial_device_display_name,
346+
auth_provider_id=auth_provider_id,
347+
auth_provider_session_id=auth_provider_session_id,
340348
)
341349
if new_device:
342350
await self.notify_device_update(user_id, [new_device_id])

synapse/handlers/oidc.py

+35-23
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from authlib.jose import JsonWebToken, jwt
2424
from authlib.oauth2.auth import ClientAuth
2525
from authlib.oauth2.rfc6749.parameters import prepare_grant_uri
26-
from authlib.oidc.core import CodeIDToken, ImplicitIDToken, UserInfo
26+
from authlib.oidc.core import CodeIDToken, UserInfo
2727
from authlib.oidc.discovery import OpenIDProviderMetadata, get_well_known_url
2828
from jinja2 import Environment, Template
2929
from pymacaroons.exceptions import (
@@ -117,7 +117,8 @@ async def load_metadata(self) -> None:
117117
for idp_id, p in self._providers.items():
118118
try:
119119
await p.load_metadata()
120-
await p.load_jwks()
120+
if not p._uses_userinfo:
121+
await p.load_jwks()
121122
except Exception as e:
122123
raise Exception(
123124
"Error while initialising OIDC provider %r" % (idp_id,)
@@ -498,10 +499,6 @@ async def load_jwks(self, force: bool = False) -> JWKS:
498499
return await self._jwks.get()
499500

500501
async def _load_jwks(self) -> JWKS:
501-
if self._uses_userinfo:
502-
# We're not using jwt signing, return an empty jwk set
503-
return {"keys": []}
504-
505502
metadata = await self.load_metadata()
506503

507504
# Load the JWKS using the `jwks_uri` metadata.
@@ -663,7 +660,7 @@ async def _fetch_userinfo(self, token: Token) -> UserInfo:
663660

664661
return UserInfo(resp)
665662

666-
async def _parse_id_token(self, token: Token, nonce: str) -> UserInfo:
663+
async def _parse_id_token(self, token: Token, nonce: str) -> CodeIDToken:
667664
"""Return an instance of UserInfo from token's ``id_token``.
668665
669666
Args:
@@ -673,7 +670,7 @@ async def _parse_id_token(self, token: Token, nonce: str) -> UserInfo:
673670
request. This value should match the one inside the token.
674671
675672
Returns:
676-
An object representing the user.
673+
The decoded claims in the ID token.
677674
"""
678675
metadata = await self.load_metadata()
679676
claims_params = {
@@ -684,9 +681,6 @@ async def _parse_id_token(self, token: Token, nonce: str) -> UserInfo:
684681
# If we got an `access_token`, there should be an `at_hash` claim
685682
# in the `id_token` that we can check against.
686683
claims_params["access_token"] = token["access_token"]
687-
claims_cls = CodeIDToken
688-
else:
689-
claims_cls = ImplicitIDToken
690684

691685
alg_values = metadata.get("id_token_signing_alg_values_supported", ["RS256"])
692686
jwt = JsonWebToken(alg_values)
@@ -703,7 +697,7 @@ async def _parse_id_token(self, token: Token, nonce: str) -> UserInfo:
703697
claims = jwt.decode(
704698
id_token,
705699
key=jwk_set,
706-
claims_cls=claims_cls,
700+
claims_cls=CodeIDToken,
707701
claims_options=claim_options,
708702
claims_params=claims_params,
709703
)
@@ -713,15 +707,16 @@ async def _parse_id_token(self, token: Token, nonce: str) -> UserInfo:
713707
claims = jwt.decode(
714708
id_token,
715709
key=jwk_set,
716-
claims_cls=claims_cls,
710+
claims_cls=CodeIDToken,
717711
claims_options=claim_options,
718712
claims_params=claims_params,
719713
)
720714

721715
logger.debug("Decoded id_token JWT %r; validating", claims)
722716

723717
claims.validate(leeway=120) # allows 2 min of clock skew
724-
return UserInfo(claims)
718+
719+
return claims
725720

726721
async def handle_redirect_request(
727722
self,
@@ -837,22 +832,37 @@ async def handle_oidc_callback(
837832

838833
logger.debug("Successfully obtained OAuth2 token data: %r", token)
839834

840-
# Now that we have a token, get the userinfo, either by decoding the
841-
# `id_token` or by fetching the `userinfo_endpoint`.
835+
# If there is an id_token, it should be validated, regardless of the
836+
# userinfo endpoint is used or not.
837+
if token.get("id_token") is not None:
838+
try:
839+
id_token = await self._parse_id_token(token, nonce=session_data.nonce)
840+
sid = id_token.get("sid")
841+
except Exception as e:
842+
logger.exception("Invalid id_token")
843+
self._sso_handler.render_error(request, "invalid_token", str(e))
844+
return
845+
else:
846+
id_token = None
847+
sid = None
848+
849+
# Now that we have a token, get the userinfo either from the `id_token`
850+
# claims or by fetching the `userinfo_endpoint`.
842851
if self._uses_userinfo:
843852
try:
844853
userinfo = await self._fetch_userinfo(token)
845854
except Exception as e:
846855
logger.exception("Could not fetch userinfo")
847856
self._sso_handler.render_error(request, "fetch_error", str(e))
848857
return
858+
elif id_token is not None:
859+
userinfo = UserInfo(id_token)
849860
else:
850-
try:
851-
userinfo = await self._parse_id_token(token, nonce=session_data.nonce)
852-
except Exception as e:
853-
logger.exception("Invalid id_token")
854-
self._sso_handler.render_error(request, "invalid_token", str(e))
855-
return
861+
logger.error("Missing id_token in token response")
862+
self._sso_handler.render_error(
863+
request, "invalid_token", "Missing id_token in token response"
864+
)
865+
return
856866

857867
# first check if we're doing a UIA
858868
if session_data.ui_auth_session_id:
@@ -884,7 +894,7 @@ async def handle_oidc_callback(
884894
# Call the mapper to register/login the user
885895
try:
886896
await self._complete_oidc_login(
887-
userinfo, token, request, session_data.client_redirect_url
897+
userinfo, token, request, session_data.client_redirect_url, sid
888898
)
889899
except MappingException as e:
890900
logger.exception("Could not map user")
@@ -896,6 +906,7 @@ async def _complete_oidc_login(
896906
token: Token,
897907
request: SynapseRequest,
898908
client_redirect_url: str,
909+
sid: Optional[str],
899910
) -> None:
900911
"""Given a UserInfo response, complete the login flow
901912
@@ -1008,6 +1019,7 @@ async def grandfather_existing_users() -> Optional[str]:
10081019
oidc_response_to_user_attributes,
10091020
grandfather_existing_users,
10101021
extra_attributes,
1022+
auth_provider_session_id=sid,
10111023
)
10121024

10131025
def _remote_id_from_userinfo(self, userinfo: UserInfo) -> str:

synapse/handlers/register.py

+12-3
Original file line numberDiff line numberDiff line change
@@ -746,6 +746,7 @@ async def register_device(
746746
is_appservice_ghost: bool = False,
747747
auth_provider_id: Optional[str] = None,
748748
should_issue_refresh_token: bool = False,
749+
auth_provider_session_id: Optional[str] = None,
749750
) -> Tuple[str, str, Optional[int], Optional[str]]:
750751
"""Register a device for a user and generate an access token.
751752
@@ -756,9 +757,9 @@ async def register_device(
756757
device_id: The device ID to check, or None to generate a new one.
757758
initial_display_name: An optional display name for the device.
758759
is_guest: Whether this is a guest account
759-
auth_provider_id: The SSO IdP the user used, if any (just used for the
760-
prometheus metrics).
760+
auth_provider_id: The SSO IdP the user used, if any.
761761
should_issue_refresh_token: Whether it should also issue a refresh token
762+
auth_provider_session_id: The session ID received during login from the SSO IdP.
762763
Returns:
763764
Tuple of device ID, access token, access token expiration time and refresh token
764765
"""
@@ -769,6 +770,8 @@ async def register_device(
769770
is_guest=is_guest,
770771
is_appservice_ghost=is_appservice_ghost,
771772
should_issue_refresh_token=should_issue_refresh_token,
773+
auth_provider_id=auth_provider_id,
774+
auth_provider_session_id=auth_provider_session_id,
772775
)
773776

774777
login_counter.labels(
@@ -791,6 +794,8 @@ async def register_device_inner(
791794
is_guest: bool = False,
792795
is_appservice_ghost: bool = False,
793796
should_issue_refresh_token: bool = False,
797+
auth_provider_id: Optional[str] = None,
798+
auth_provider_session_id: Optional[str] = None,
794799
) -> LoginDict:
795800
"""Helper for register_device
796801
@@ -822,7 +827,11 @@ class and RegisterDeviceReplicationServlet.
822827
refresh_token_id = None
823828

824829
registered_device_id = await self.device_handler.check_device_registered(
825-
user_id, device_id, initial_display_name
830+
user_id,
831+
device_id,
832+
initial_display_name,
833+
auth_provider_id=auth_provider_id,
834+
auth_provider_session_id=auth_provider_session_id,
826835
)
827836
if is_guest:
828837
assert access_token_expiry is None

synapse/handlers/sso.py

+4
Original file line numberDiff line numberDiff line change
@@ -365,6 +365,7 @@ async def complete_sso_login_request(
365365
sso_to_matrix_id_mapper: Callable[[int], Awaitable[UserAttributes]],
366366
grandfather_existing_users: Callable[[], Awaitable[Optional[str]]],
367367
extra_login_attributes: Optional[JsonDict] = None,
368+
auth_provider_session_id: Optional[str] = None,
368369
) -> None:
369370
"""
370371
Given an SSO ID, retrieve the user ID for it and possibly register the user.
@@ -415,6 +416,8 @@ async def complete_sso_login_request(
415416
extra_login_attributes: An optional dictionary of extra
416417
attributes to be provided to the client in the login response.
417418
419+
auth_provider_session_id: An optional session ID from the IdP.
420+
418421
Raises:
419422
MappingException if there was a problem mapping the response to a user.
420423
RedirectException: if the mapping provider needs to redirect the user
@@ -490,6 +493,7 @@ async def complete_sso_login_request(
490493
client_redirect_url,
491494
extra_login_attributes,
492495
new_user=new_user,
496+
auth_provider_session_id=auth_provider_session_id,
493497
)
494498

495499
async def _call_attribute_mapper(

synapse/module_api/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -626,6 +626,7 @@ def generate_short_term_login_token(
626626
user_id: str,
627627
duration_in_ms: int = (2 * 60 * 1000),
628628
auth_provider_id: str = "",
629+
auth_provider_session_id: Optional[str] = None,
629630
) -> str:
630631
"""Generate a login token suitable for m.login.token authentication
631632
@@ -643,6 +644,7 @@ def generate_short_term_login_token(
643644
return self._hs.get_macaroon_generator().generate_short_term_login_token(
644645
user_id,
645646
auth_provider_id,
647+
auth_provider_session_id,
646648
duration_in_ms,
647649
)
648650

0 commit comments

Comments
 (0)