|
4 | 4 | # ------------------------------------
|
5 | 5 | from azure.core.exceptions import ClientAuthenticationError
|
6 | 6 | from azure.core.pipeline.policies import SansIOHTTPPolicy
|
7 |
| -from azure.identity import CredentialUnavailableError, KnownAuthorities, SharedTokenCacheCredential |
| 7 | +from azure.identity import ( |
| 8 | + AuthenticationRecord, |
| 9 | + CredentialUnavailableError, |
| 10 | + SharedTokenCacheCredential, |
| 11 | +) |
8 | 12 | from azure.identity._constants import EnvironmentVariables
|
9 | 13 | from azure.identity._internal.shared_token_cache import (
|
10 | 14 | KNOWN_ALIASES,
|
@@ -502,6 +506,112 @@ def test_authority_environment_variable():
|
502 | 506 | assert token.token == expected_access_token
|
503 | 507 |
|
504 | 508 |
|
| 509 | +def test_authentication_record_empty_cache(): |
| 510 | + record = AuthenticationRecord("tenant_id", "client_id", "authority", "home_account_id", "username") |
| 511 | + transport = Mock(side_effect=Exception("the credential shouldn't send a request")) |
| 512 | + credential = SharedTokenCacheCredential(authentication_record=record, transport=transport, _cache=TokenCache()) |
| 513 | + |
| 514 | + with pytest.raises(CredentialUnavailableError): |
| 515 | + credential.get_token("scope") |
| 516 | + |
| 517 | + |
| 518 | +def test_authentication_record_no_match(): |
| 519 | + tenant_id = "tenant-id" |
| 520 | + client_id = "client-id" |
| 521 | + authority = "localhost" |
| 522 | + object_id = "object-id" |
| 523 | + home_account_id = object_id + "." + tenant_id |
| 524 | + username = "me" |
| 525 | + record = AuthenticationRecord(tenant_id, client_id, authority, home_account_id, username) |
| 526 | + |
| 527 | + transport = Mock(side_effect=Exception("the credential shouldn't send a request")) |
| 528 | + cache = populated_cache( |
| 529 | + get_account_event( |
| 530 | + "not-" + username, "not-" + object_id, "different-" + tenant_id, client_id="not-" + client_id, |
| 531 | + ), |
| 532 | + ) |
| 533 | + credential = SharedTokenCacheCredential(authentication_record=record, transport=transport, _cache=cache) |
| 534 | + |
| 535 | + with pytest.raises(CredentialUnavailableError): |
| 536 | + credential.get_token("scope") |
| 537 | + |
| 538 | + |
| 539 | +def test_authentication_record(): |
| 540 | + tenant_id = "tenant-id" |
| 541 | + client_id = "client-id" |
| 542 | + authority = "localhost" |
| 543 | + object_id = "object-id" |
| 544 | + home_account_id = object_id + "." + tenant_id |
| 545 | + username = "me" |
| 546 | + record = AuthenticationRecord(tenant_id, client_id, authority, home_account_id, username) |
| 547 | + |
| 548 | + expected_access_token = "****" |
| 549 | + expected_refresh_token = "**" |
| 550 | + account = get_account_event( |
| 551 | + username, object_id, tenant_id, authority=authority, client_id=client_id, refresh_token=expected_refresh_token |
| 552 | + ) |
| 553 | + cache = populated_cache(account) |
| 554 | + |
| 555 | + transport = validating_transport( |
| 556 | + requests=[Request(authority=authority, required_data={"refresh_token": expected_refresh_token})], |
| 557 | + responses=[mock_response(json_payload=build_aad_response(access_token=expected_access_token))], |
| 558 | + ) |
| 559 | + credential = SharedTokenCacheCredential(authentication_record=record, transport=transport, _cache=cache) |
| 560 | + |
| 561 | + token = credential.get_token("scope") |
| 562 | + assert token.token == expected_access_token |
| 563 | + |
| 564 | + |
| 565 | +def test_auth_record_multiple_accounts_for_username(): |
| 566 | + tenant_id = "tenant-id" |
| 567 | + client_id = "client-id" |
| 568 | + authority = "localhost" |
| 569 | + object_id = "object-id" |
| 570 | + home_account_id = object_id + "." + tenant_id |
| 571 | + username = "me" |
| 572 | + record = AuthenticationRecord(tenant_id, client_id, authority, home_account_id, username) |
| 573 | + |
| 574 | + expected_access_token = "****" |
| 575 | + expected_refresh_token = "**" |
| 576 | + expected_account = get_account_event( |
| 577 | + username, object_id, tenant_id, authority=authority, client_id=client_id, refresh_token=expected_refresh_token |
| 578 | + ) |
| 579 | + cache = populated_cache( |
| 580 | + expected_account, |
| 581 | + get_account_event( # this account matches all but the record's tenant |
| 582 | + username, |
| 583 | + object_id, |
| 584 | + "different-" + tenant_id, |
| 585 | + authority=authority, |
| 586 | + client_id=client_id, |
| 587 | + refresh_token="not-" + expected_refresh_token, |
| 588 | + ), |
| 589 | + ) |
| 590 | + |
| 591 | + transport = validating_transport( |
| 592 | + requests=[Request(authority=authority, required_data={"refresh_token": expected_refresh_token})], |
| 593 | + responses=[mock_response(json_payload=build_aad_response(access_token=expected_access_token))], |
| 594 | + ) |
| 595 | + credential = SharedTokenCacheCredential(authentication_record=record, transport=transport, _cache=cache) |
| 596 | + |
| 597 | + token = credential.get_token("scope") |
| 598 | + assert token.token == expected_access_token |
| 599 | + |
| 600 | + |
| 601 | +def test_authentication_record_authenticating_tenant(): |
| 602 | + """when given a record and 'tenant_id', the credential should authenticate in the latter""" |
| 603 | + |
| 604 | + expected_tenant_id = "tenant-id" |
| 605 | + record = AuthenticationRecord("not- " + expected_tenant_id, "...", "...", "...", "...") |
| 606 | + |
| 607 | + with patch.object(SharedTokenCacheCredential, "_get_auth_client") as get_auth_client: |
| 608 | + SharedTokenCacheCredential(authentication_record=record, _cache=TokenCache(), tenant_id=expected_tenant_id) |
| 609 | + |
| 610 | + assert get_auth_client.call_count == 1 |
| 611 | + _, kwargs = get_auth_client.call_args |
| 612 | + assert kwargs["tenant_id"] == expected_tenant_id |
| 613 | + |
| 614 | + |
505 | 615 | def get_account_event(
|
506 | 616 | username, uid, utid, authority=None, client_id="client-id", refresh_token="refresh-token", scopes=None
|
507 | 617 | ):
|
|
0 commit comments