23
23
from authlib .jose import JsonWebToken , jwt
24
24
from authlib .oauth2 .auth import ClientAuth
25
25
from authlib .oauth2 .rfc6749 .parameters import prepare_grant_uri
26
- from authlib .oidc .core import CodeIDToken , ImplicitIDToken , UserInfo
26
+ from authlib .oidc .core import CodeIDToken , UserInfo
27
27
from authlib .oidc .discovery import OpenIDProviderMetadata , get_well_known_url
28
28
from jinja2 import Environment , Template
29
29
from pymacaroons .exceptions import (
@@ -117,7 +117,8 @@ async def load_metadata(self) -> None:
117
117
for idp_id , p in self ._providers .items ():
118
118
try :
119
119
await p .load_metadata ()
120
- await p .load_jwks ()
120
+ if not p ._uses_userinfo :
121
+ await p .load_jwks ()
121
122
except Exception as e :
122
123
raise Exception (
123
124
"Error while initialising OIDC provider %r" % (idp_id ,)
@@ -498,10 +499,6 @@ async def load_jwks(self, force: bool = False) -> JWKS:
498
499
return await self ._jwks .get ()
499
500
500
501
async def _load_jwks (self ) -> JWKS :
501
- if self ._uses_userinfo :
502
- # We're not using jwt signing, return an empty jwk set
503
- return {"keys" : []}
504
-
505
502
metadata = await self .load_metadata ()
506
503
507
504
# Load the JWKS using the `jwks_uri` metadata.
@@ -663,7 +660,7 @@ async def _fetch_userinfo(self, token: Token) -> UserInfo:
663
660
664
661
return UserInfo (resp )
665
662
666
- async def _parse_id_token (self , token : Token , nonce : str ) -> UserInfo :
663
+ async def _parse_id_token (self , token : Token , nonce : str ) -> CodeIDToken :
667
664
"""Return an instance of UserInfo from token's ``id_token``.
668
665
669
666
Args:
@@ -673,7 +670,7 @@ async def _parse_id_token(self, token: Token, nonce: str) -> UserInfo:
673
670
request. This value should match the one inside the token.
674
671
675
672
Returns:
676
- An object representing the user .
673
+ The decoded claims in the ID token .
677
674
"""
678
675
metadata = await self .load_metadata ()
679
676
claims_params = {
@@ -684,9 +681,6 @@ async def _parse_id_token(self, token: Token, nonce: str) -> UserInfo:
684
681
# If we got an `access_token`, there should be an `at_hash` claim
685
682
# in the `id_token` that we can check against.
686
683
claims_params ["access_token" ] = token ["access_token" ]
687
- claims_cls = CodeIDToken
688
- else :
689
- claims_cls = ImplicitIDToken
690
684
691
685
alg_values = metadata .get ("id_token_signing_alg_values_supported" , ["RS256" ])
692
686
jwt = JsonWebToken (alg_values )
@@ -703,7 +697,7 @@ async def _parse_id_token(self, token: Token, nonce: str) -> UserInfo:
703
697
claims = jwt .decode (
704
698
id_token ,
705
699
key = jwk_set ,
706
- claims_cls = claims_cls ,
700
+ claims_cls = CodeIDToken ,
707
701
claims_options = claim_options ,
708
702
claims_params = claims_params ,
709
703
)
@@ -713,15 +707,16 @@ async def _parse_id_token(self, token: Token, nonce: str) -> UserInfo:
713
707
claims = jwt .decode (
714
708
id_token ,
715
709
key = jwk_set ,
716
- claims_cls = claims_cls ,
710
+ claims_cls = CodeIDToken ,
717
711
claims_options = claim_options ,
718
712
claims_params = claims_params ,
719
713
)
720
714
721
715
logger .debug ("Decoded id_token JWT %r; validating" , claims )
722
716
723
717
claims .validate (leeway = 120 ) # allows 2 min of clock skew
724
- return UserInfo (claims )
718
+
719
+ return claims
725
720
726
721
async def handle_redirect_request (
727
722
self ,
@@ -837,22 +832,37 @@ async def handle_oidc_callback(
837
832
838
833
logger .debug ("Successfully obtained OAuth2 token data: %r" , token )
839
834
840
- # Now that we have a token, get the userinfo, either by decoding the
841
- # `id_token` or by fetching the `userinfo_endpoint`.
835
+ # If there is an id_token, it should be validated, regardless of the
836
+ # userinfo endpoint is used or not.
837
+ if token .get ("id_token" ) is not None :
838
+ try :
839
+ id_token = await self ._parse_id_token (token , nonce = session_data .nonce )
840
+ sid = id_token .get ("sid" )
841
+ except Exception as e :
842
+ logger .exception ("Invalid id_token" )
843
+ self ._sso_handler .render_error (request , "invalid_token" , str (e ))
844
+ return
845
+ else :
846
+ id_token = None
847
+ sid = None
848
+
849
+ # Now that we have a token, get the userinfo either from the `id_token`
850
+ # claims or by fetching the `userinfo_endpoint`.
842
851
if self ._uses_userinfo :
843
852
try :
844
853
userinfo = await self ._fetch_userinfo (token )
845
854
except Exception as e :
846
855
logger .exception ("Could not fetch userinfo" )
847
856
self ._sso_handler .render_error (request , "fetch_error" , str (e ))
848
857
return
858
+ elif id_token is not None :
859
+ userinfo = UserInfo (id_token )
849
860
else :
850
- try :
851
- userinfo = await self ._parse_id_token (token , nonce = session_data .nonce )
852
- except Exception as e :
853
- logger .exception ("Invalid id_token" )
854
- self ._sso_handler .render_error (request , "invalid_token" , str (e ))
855
- return
861
+ logger .error ("Missing id_token in token response" )
862
+ self ._sso_handler .render_error (
863
+ request , "invalid_token" , "Missing id_token in token response"
864
+ )
865
+ return
856
866
857
867
# first check if we're doing a UIA
858
868
if session_data .ui_auth_session_id :
@@ -884,7 +894,7 @@ async def handle_oidc_callback(
884
894
# Call the mapper to register/login the user
885
895
try :
886
896
await self ._complete_oidc_login (
887
- userinfo , token , request , session_data .client_redirect_url
897
+ userinfo , token , request , session_data .client_redirect_url , sid
888
898
)
889
899
except MappingException as e :
890
900
logger .exception ("Could not map user" )
@@ -896,6 +906,7 @@ async def _complete_oidc_login(
896
906
token : Token ,
897
907
request : SynapseRequest ,
898
908
client_redirect_url : str ,
909
+ sid : Optional [str ],
899
910
) -> None :
900
911
"""Given a UserInfo response, complete the login flow
901
912
@@ -1008,6 +1019,7 @@ async def grandfather_existing_users() -> Optional[str]:
1008
1019
oidc_response_to_user_attributes ,
1009
1020
grandfather_existing_users ,
1010
1021
extra_attributes ,
1022
+ auth_provider_session_id = sid ,
1011
1023
)
1012
1024
1013
1025
def _remote_id_from_userinfo (self , userinfo : UserInfo ) -> str :
0 commit comments