diff --git a/firebase_admin/_auth_client.py b/firebase_admin/_auth_client.py index 4418a034d..27dd5c7ce 100644 --- a/firebase_admin/_auth_client.py +++ b/firebase_admin/_auth_client.py @@ -336,6 +336,8 @@ def update_user(self, uid, **kwargs): # pylint: disable=differing-param-doc valid_since: An integer signifying the seconds since the epoch (optional). This field is set by ``revoke_refresh_tokens`` and it is discouraged to set this field directly. + providers_to_delete: The list of provider IDs to unlink, + eg: 'google.com', 'password', etc. Returns: UserRecord: An updated UserRecord instance for the user. diff --git a/firebase_admin/_auth_utils.py b/firebase_admin/_auth_utils.py index 02d32b659..7aece495e 100644 --- a/firebase_admin/_auth_utils.py +++ b/firebase_admin/_auth_utils.py @@ -266,6 +266,15 @@ def validate_action_type(action_type): Valid values are {1}'.format(action_type, ', '.join(VALID_EMAIL_ACTION_TYPES))) return action_type +def validate_provider_ids(provider_ids, required=False): + if not provider_ids: + if required: + raise ValueError('Invalid provider IDs. Provider ids should be provided') + return [] + for provider_id in provider_ids: + validate_provider_id(provider_id, True) + return provider_ids + def build_update_mask(params): """Creates an update mask list from the given dictionary.""" mask = [] diff --git a/firebase_admin/_user_mgt.py b/firebase_admin/_user_mgt.py index b60c4d100..c77c4d40d 100644 --- a/firebase_admin/_user_mgt.py +++ b/firebase_admin/_user_mgt.py @@ -688,7 +688,7 @@ def create_user(self, uid=None, display_name=None, email=None, phone_number=None def update_user(self, uid, display_name=None, email=None, phone_number=None, photo_url=None, password=None, disabled=None, email_verified=None, - valid_since=None, custom_claims=None): + valid_since=None, custom_claims=None, providers_to_delete=None): """Updates an existing user account with the specified properties""" payload = { 'localId': _auth_utils.validate_uid(uid, required=True), @@ -700,6 +700,7 @@ def update_user(self, uid, display_name=None, email=None, phone_number=None, } remove = [] + remove_provider = _auth_utils.validate_provider_ids(providers_to_delete) if display_name is not None: if display_name is DELETE_ATTRIBUTE: remove.append('DISPLAY_NAME') @@ -715,7 +716,7 @@ def update_user(self, uid, display_name=None, email=None, phone_number=None, if phone_number is not None: if phone_number is DELETE_ATTRIBUTE: - payload['deleteProvider'] = ['phone'] + remove_provider.append('phone') else: payload['phoneNumber'] = _auth_utils.validate_phone(phone_number) @@ -726,6 +727,9 @@ def update_user(self, uid, display_name=None, email=None, phone_number=None, custom_claims, dict) else custom_claims payload['customAttributes'] = _auth_utils.validate_custom_claims(json_claims) + if remove_provider: + payload['deleteProvider'] = list(set(remove_provider)) + payload = {k: v for k, v in payload.items() if v is not None} body, http_resp = self._make_request('post', '/accounts:update', json=payload) if not body or not body.get('localId'): diff --git a/integration/test_auth.py b/integration/test_auth.py index 55ddbb0a0..2dd2cb639 100644 --- a/integration/test_auth.py +++ b/integration/test_auth.py @@ -496,6 +496,14 @@ def test_disable_user(new_user_with_params): assert user.disabled is True assert len(user.provider_data) == 1 +def test_remove_provider(new_user_with_provider): + provider_ids = [provider.provider_id for provider in new_user_with_provider.provider_data] + assert 'google.com' in provider_ids + user = auth.update_user(new_user_with_provider, providers_to_delete=['google.com']) + assert user.uid == new_user_with_params.uid + new_provider_ids = [provider.provider_id for provider in user.provider_data] + assert 'google.com' not in new_provider_ids + def test_delete_user(): user = auth.create_user() auth.delete_user(user.uid) diff --git a/tests/test_user_mgt.py b/tests/test_user_mgt.py index 10dfe698f..67447c6ba 100644 --- a/tests/test_user_mgt.py +++ b/tests/test_user_mgt.py @@ -663,6 +663,23 @@ def test_update_user_valid_since(self, user_mgt_app, arg): request = json.loads(recorder[0].body.decode()) assert request == {'localId': 'testuser', 'validSince': int(arg)} + @pytest.mark.parametrize('arg', [['phone'], ['google.com', 'phone']]) + def test_update_user_delete_provider(self, user_mgt_app, arg): + user_mgt, recorder = _instrument_user_manager(user_mgt_app, 200, '{"localId":"testuser"}') + user_mgt.update_user('testuser', providers_to_delete=arg) + request = json.loads(recorder[0].body.decode()) + assert set(request['deleteProvider']) == set(arg) + + @pytest.mark.parametrize('arg', [[], ['phone'], ['google.com'], ['google.com', 'phone']]) + def test_update_user_delete_provider_and_phone(self, user_mgt_app, arg): + user_mgt, recorder = _instrument_user_manager(user_mgt_app, 200, '{"localId":"testuser"}') + user_mgt.update_user('testuser', + providers_to_delete=arg, + phone_number=auth.DELETE_ATTRIBUTE) + request = json.loads(recorder[0].body.decode()) + assert 'phone' in request['deleteProvider'] + assert len(set(request['deleteProvider'])) == len(request['deleteProvider']) + assert set(arg) - set(request['deleteProvider']) == set() class TestSetCustomUserClaims: