diff --git a/pkg/blob/blob.go b/pkg/blob/blob.go index 098f0d9e7..9044d5604 100644 --- a/pkg/blob/blob.go +++ b/pkg/blob/blob.go @@ -283,10 +283,10 @@ func GetContainerInfo(id string) (string, string, string, string, error) { } // A container name must be a valid DNS name, conforming to the following naming rules: -// 1. Container names must start with a letter or number, and can contain only letters, numbers, and the dash (-) character. -// 2. Every dash (-) character must be immediately preceded and followed by a letter or number; consecutive dashes are not permitted in container names. -// 3. All letters in a container name must be lowercase. -// 4. Container names must be from 3 through 63 characters long. +// 1. Container names must start with a letter or number, and can contain only letters, numbers, and the dash (-) character. +// 2. Every dash (-) character must be immediately preceded and followed by a letter or number; consecutive dashes are not permitted in container names. +// 3. All letters in a container name must be lowercase. +// 4. Container names must be from 3 through 63 characters long. // // See https://docs.microsoft.com/en-us/rest/api/storageservices/naming-and-referencing-containers--blobs--and-metadata#container-names func getValidContainerName(volumeName, protocol string) string { @@ -315,9 +315,10 @@ func checkContainerNameBeginAndEnd(containerName string) bool { return false } -// isSASToken checks if the key contains the patterns. Because a SAS Token must have these strings, use them to judge. +// isSASToken checks if the key contains the patterns. +// SAS token format could refer to https://docs.microsoft.com/en-us/rest/api/eventhub/generate-sas-token func isSASToken(key string) bool { - return strings.Contains(key, "?sv=") + return strings.HasPrefix(key, "?") } // GetAuthEnv return @@ -681,9 +682,9 @@ func setAzureCredentials(kubeClient kubernetes.Interface, accountName, accountKe } // GetStorageAccesskey get Azure storage account key from -// 1. secrets (if not empty) -// 2. use k8s client identity to read from k8s secret -// 3. use cluster identity to get from storage account directly +// 1. secrets (if not empty) +// 2. use k8s client identity to read from k8s secret +// 3. use cluster identity to get from storage account directly func (d *Driver) GetStorageAccesskey(ctx context.Context, accountOptions *azure.AccountOptions, secrets map[string]string, secretName, secretNamespace string) (string, string, error) { if len(secrets) > 0 { return getStorageAccount(secrets) diff --git a/test/e2e/pre_provisioning_test.go b/test/e2e/pre_provisioning_test.go index dcb6172e2..6771db7b2 100644 --- a/test/e2e/pre_provisioning_test.go +++ b/test/e2e/pre_provisioning_test.go @@ -24,6 +24,7 @@ import ( "sigs.k8s.io/blob-csi-driver/test/e2e/driver" "sigs.k8s.io/blob-csi-driver/test/e2e/testsuites" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" "github.com/container-storage-interface/spec/lib/go/csi" "github.com/onsi/ginkgo" v1 "k8s.io/api/core/v1" @@ -290,6 +291,42 @@ var _ = ginkgo.Describe("[blob-csi-e2e] Pre-Provisioned", func() { } test.Run(cs, ns) }) + + ginkgo.It("should use SAS token", func() { + req := makeCreateVolumeReq("pre-provisioned-sas-token", ns.Name) + resp, err := blobDriver.CreateVolume(context.Background(), req) + if err != nil { + ginkgo.Fail(fmt.Sprintf("create volume error: %v", err)) + } + volumeID = resp.Volume.VolumeId + ginkgo.By(fmt.Sprintf("Successfully provisioned blob volume: %q\n", volumeID)) + + pods := []testsuites.PodDetails{ + { + Cmd: "echo 'hello world' > /mnt/test-1/data && grep 'hello world' /mnt/test-1/data", + Volumes: []testsuites.VolumeDetails{ + { + VolumeID: volumeID, + FSType: "ext4", + ClaimSize: fmt.Sprintf("%dGi", defaultVolumeSize), + ReclaimPolicy: to.Ptr(v1.PersistentVolumeReclaimRetain), + VolumeBindingMode: to.Ptr(storagev1.VolumeBindingImmediate), + VolumeMount: testsuites.VolumeMountDetails{ + NameGenerate: "test-volume-", + MountPathGenerate: "/mnt/test-", + }, + }, + }, + }, + } + + test := testsuites.PreProvisionedSASTokenTest{ + CSIDriver: testDriver, + Pods: pods, + Driver: blobDriver, + } + test.Run(cs, ns) + }) }) func makeCreateVolumeReq(volumeName, secretNamespace string) *csi.CreateVolumeRequest { diff --git a/test/e2e/testsuites/pre_provisioned_keyvault_tester.go b/test/e2e/testsuites/pre_provisioned_keyvault_tester.go index 95e515e2b..9dd866c65 100644 --- a/test/e2e/testsuites/pre_provisioned_keyvault_tester.go +++ b/test/e2e/testsuites/pre_provisioned_keyvault_tester.go @@ -19,43 +19,19 @@ package testsuites import ( "context" "fmt" - "net/url" - "time" - "github.com/Azure/azure-sdk-for-go/sdk/azcore" - "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" - "github.com/Azure/azure-sdk-for-go/sdk/azidentity" - "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/keyvault/armkeyvault" - "github.com/Azure/azure-sdk-for-go/sdk/storage/azblob" - "github.com/Azure/azure-sdk-for-go/services/graphrbac/1.6/graphrbac" - "github.com/Azure/azure-sdk-for-go/services/msi/mgmt/2018-11-30/msi" - "github.com/Azure/go-autorest/autorest" - "github.com/Azure/go-autorest/autorest/adal" - "github.com/Azure/go-autorest/autorest/azure" "github.com/onsi/ginkgo" v1 "k8s.io/api/core/v1" - "k8s.io/apiserver/pkg/storage/names" clientset "k8s.io/client-go/kubernetes" "k8s.io/kubernetes/test/e2e/framework" "sigs.k8s.io/blob-csi-driver/pkg/blob" "sigs.k8s.io/blob-csi-driver/test/e2e/driver" - "sigs.k8s.io/blob-csi-driver/test/utils/credentials" -) - -var ( - subscriptionID string - resourceGroupName string - location string - vaultName string - TenantID string - cloud string - clientID string - clientSecret string + "sigs.k8s.io/blob-csi-driver/test/utils/azure" ) // PreProvisionedKeyVaultTest will provision required PV(s), PVC(s) and Pod(s) // Testing that the Pod(s) can be created successfully with provided Key Vault -// which is used to store storage account name and key(or sastoken) +// which is used to store storage account name and key type PreProvisionedKeyVaultTest struct { CSIDriver driver.PreProvisionedVolumeTestDriver Pods []PodDetails @@ -63,17 +39,8 @@ type PreProvisionedKeyVaultTest struct { } func (t *PreProvisionedKeyVaultTest) Run(client clientset.Interface, namespace *v1.Namespace) { - e2eCred, err := credentials.ParseAzureCredentialFile() - framework.ExpectNoError(err, fmt.Sprintf("Error ParseAzureCredentialFile: %v", err)) - - subscriptionID = e2eCred.SubscriptionID - resourceGroupName = e2eCred.ResourceGroup - location = e2eCred.Location - TenantID = e2eCred.TenantID - cloud = e2eCred.Cloud - clientID = e2eCred.AADClientID - clientSecret = e2eCred.AADClientSecret - vaultName = names.SimpleNameGenerator.GenerateName("blob-csi-test-kv-") + keyVaultClient, err := azure.NewKeyVaultClient() + framework.ExpectNoError(err) for _, pod := range t.Pods { for n, volume := range pod.Volumes { @@ -82,291 +49,36 @@ func (t *PreProvisionedKeyVaultTest) Run(client clientset.Interface, namespace * accountName, accountKey, _, containerName, err := t.Driver.GetStorageAccountAndContainer(context.TODO(), volume.VolumeID, nil, nil) framework.ExpectNoError(err, fmt.Sprintf("Error GetStorageAccountAndContainer from volumeID(%s): %v", volume.VolumeID, err)) - azureCred, err := azidentity.NewDefaultAzureCredential(nil) - framework.ExpectNoError(err) - ginkgo.By("creating KeyVault...") - vault, err := createVault(context.TODO(), azureCred) + vault, err := keyVaultClient.CreateVault(context.TODO()) framework.ExpectNoError(err) - defer cleanVault(context.TODO(), azureCred) + defer func() { + err := keyVaultClient.CleanVault(context.TODO()) + framework.ExpectNoError(err) + }() ginkgo.By("creating secret for storage account key...") - accountKeySecret, err := createSecret(context.TODO(), azureCred, accountName+"-key", accountKey) + accountKeySecret, err := keyVaultClient.CreateSecret(context.TODO(), accountName+"-key", accountKey) framework.ExpectNoError(err) pod.Volumes[n].ContainerName = containerName pod.Volumes[n].StorageAccountname = accountName pod.Volumes[n].KeyVaultURL = *vault.Properties.VaultURI pod.Volumes[n].KeyVaultSecretName = *accountKeySecret.Name - // test for Account key - ginkgo.By("test storage account key...") - run(pod, client, namespace, t.CSIDriver) - - ginkgo.By("generate SAS token...") - sasToken := generateSASToken(accountName, accountKey) - - ginkgo.By("creating secret for SAS token...") - accountSASSecret, err := createSecret(context.TODO(), azureCred, accountName+"-sas", sasToken) - framework.ExpectNoError(err) - pod.Volumes[n].KeyVaultSecretName = *accountSASSecret.Name - // TODO: test for SAS token - // ginkgo.By("test SAS token...") - // run(pod, client, namespace, t.CSIDriver) + ginkgo.By("test storage account key...") + tpod, cleanup := pod.SetupWithPreProvisionedVolumes(client, namespace, t.CSIDriver) + // defer must be called here for resources not get removed before using them + for i := range cleanup { + defer cleanup[i]() + } + + ginkgo.By("deploying the pod") + tpod.Create() + defer tpod.Cleanup() + + ginkgo.By("checking that the pods command exits with no error") + tpod.WaitForSuccess() } } } - -func run(pod PodDetails, client clientset.Interface, namespace *v1.Namespace, csidriver driver.PreProvisionedVolumeTestDriver) { - tpod, cleanup := pod.SetupWithPreProvisionedVolumes(client, namespace, csidriver) - // defer must be called here for resources not get removed before using them - for i := range cleanup { - defer cleanup[i]() - } - - ginkgo.By("deploying the pod") - tpod.Create() - defer tpod.Cleanup() - - ginkgo.By("checking that the pods command exits with no error") - tpod.WaitForSuccess() -} - -func generateSASToken(accountName, accountKey string) string { - credential, err := azblob.NewSharedKeyCredential(accountName, accountKey) - framework.ExpectNoError(err) - serviceClient, err := azblob.NewServiceClientWithSharedKey(fmt.Sprintf("https://%s.blob.core.windows.net/", accountName), credential, nil) - framework.ExpectNoError(err) - sasURL, err := serviceClient.GetSASURL( - azblob.AccountSASResourceTypes{Object: true, Service: true, Container: true}, - azblob.AccountSASPermissions{Read: true, List: true, Write: true, Delete: true, Add: true, Create: true, Update: true}, - time.Now(), time.Now().Add(12*time.Hour)) - framework.ExpectNoError(err) - ginkgo.By("sas URL: " + sasURL) - u, err := url.Parse(sasURL) - framework.ExpectNoError(err) - queryUnescape, err := url.QueryUnescape(u.RawQuery) - framework.ExpectNoError(err) - sasToken := "?" + queryUnescape - ginkgo.By("sas Token: " + sasToken) - return sasToken -} - -func createVault(ctx context.Context, cred azcore.TokenCredential) (*armkeyvault.Vault, error) { - vaultsClient, err := armkeyvault.NewVaultsClient(subscriptionID, cred, nil) - if err != nil { - return nil, err - } - - pollerResp, err := vaultsClient.BeginCreateOrUpdate( - ctx, - resourceGroupName, - vaultName, - armkeyvault.VaultCreateOrUpdateParameters{ - Location: to.Ptr(location), - Properties: &armkeyvault.VaultProperties{ - SKU: &armkeyvault.SKU{ - Family: to.Ptr(armkeyvault.SKUFamilyA), - Name: to.Ptr(armkeyvault.SKUNameStandard), - }, - TenantID: to.Ptr(TenantID), - AccessPolicies: getAccessPolicy(ctx), - }, - }, - nil, - ) - if err != nil { - return nil, err - } - - resp, err := pollerResp.PollUntilDone(ctx, nil) - if err != nil { - return nil, err - } - return &resp.Vault, nil -} - -func getAccessPolicy(ctx context.Context) []*armkeyvault.AccessPolicyEntry { - accessPolicyEntry := []*armkeyvault.AccessPolicyEntry{} - - // vault secret permission for upstream e2e test, which uses application service principal - clientObjectID, err := getServicePrincipalObjectID(ctx, clientID) - if err == nil { - ginkgo.By("client object ID: " + clientObjectID) - accessPolicyEntry = append(accessPolicyEntry, &armkeyvault.AccessPolicyEntry{ - TenantID: to.Ptr(TenantID), - ObjectID: to.Ptr(clientObjectID), - Permissions: &armkeyvault.Permissions{ - Secrets: []*armkeyvault.SecretPermissions{ - to.Ptr(armkeyvault.SecretPermissionsGet), - }, - }, - }) - } - - // vault secret permission for upstream e2e-vmss test, which uses msi blobfuse-csi-driver-e2e-test-id - msiObjectID, err := getMSIObjectID(ctx, "blobfuse-csi-driver-e2e-test-id") - if err == nil { - ginkgo.By("MSI object ID: " + msiObjectID) - accessPolicyEntry = append(accessPolicyEntry, &armkeyvault.AccessPolicyEntry{ - TenantID: to.Ptr(TenantID), - ObjectID: to.Ptr(msiObjectID), - Permissions: &armkeyvault.Permissions{ - Secrets: []*armkeyvault.SecretPermissions{ - to.Ptr(armkeyvault.SecretPermissionsGet), - }, - }, - }) - } - - return accessPolicyEntry -} - -func cleanVault(ctx context.Context, cred azcore.TokenCredential) { - err := deleteVault(ctx, cred) - framework.ExpectNoError(err) - - err = purgeDeleted(ctx, cred) - framework.ExpectNoError(err) -} - -func deleteVault(ctx context.Context, cred azcore.TokenCredential) error { - vaultsClient, err := armkeyvault.NewVaultsClient(subscriptionID, cred, nil) - if err != nil { - return err - } - - _, err = vaultsClient.Delete(ctx, resourceGroupName, vaultName, nil) - if err != nil { - return err - } - return nil -} - -func purgeDeleted(ctx context.Context, cred azcore.TokenCredential) error { - vaultsClient, err := armkeyvault.NewVaultsClient(subscriptionID, cred, nil) - if err != nil { - return err - } - - pollerResp, err := vaultsClient.BeginPurgeDeleted(ctx, vaultName, location, nil) - if err != nil { - return err - } - - _, err = pollerResp.PollUntilDone(ctx, nil) - if err != nil { - return err - } - - return nil -} - -func createSecret(ctx context.Context, cred azcore.TokenCredential, secretName, secretValue string) (*armkeyvault.Secret, error) { - secretsClient, err := armkeyvault.NewSecretsClient(subscriptionID, cred, nil) - if err != nil { - return nil, err - } - - secretResp, err := secretsClient.CreateOrUpdate( - ctx, - resourceGroupName, - vaultName, - secretName, - armkeyvault.SecretCreateOrUpdateParameters{ - Properties: &armkeyvault.SecretProperties{ - Attributes: &armkeyvault.SecretAttributes{ - Enabled: to.Ptr(true), - }, - Value: to.Ptr(secretValue), - }, - }, - nil, - ) - if err != nil { - return nil, err - } - - return &secretResp.Secret, nil -} - -func getServicePrincipalObjectID(ctx context.Context, clientID string) (string, error) { - spClient, err := getServicePrincipalsClient() - if err != nil { - return "", err - } - - page, err := spClient.List(ctx, fmt.Sprintf("servicePrincipalNames/any(c:c eq '%s')", clientID)) - if err != nil { - return "", err - } - servicePrincipals := page.Values() - if len(servicePrincipals) == 0 { - return "", fmt.Errorf("didn't find any service principals for client ID %s", clientID) - } - return *servicePrincipals[0].ObjectID, nil -} - -func getServicePrincipalsClient() (*graphrbac.ServicePrincipalsClient, error) { - spClient := graphrbac.NewServicePrincipalsClient(TenantID) - - env, err := azure.EnvironmentFromName(cloud) - if err != nil { - return nil, err - } - - oauthConfig, err := adal.NewOAuthConfig(env.ActiveDirectoryEndpoint, TenantID) - if err != nil { - return nil, err - } - - token, err := adal.NewServicePrincipalToken(*oauthConfig, clientID, clientSecret, env.GraphEndpoint) - if err != nil { - return nil, err - } - - authorizer := autorest.NewBearerAuthorizer(token) - - spClient.Authorizer = authorizer - - return &spClient, nil -} - -func getMSIUserAssignedIDClient() (*msi.UserAssignedIdentitiesClient, error) { - msiClient := msi.NewUserAssignedIdentitiesClient(subscriptionID) - - env, err := azure.EnvironmentFromName(cloud) - if err != nil { - return nil, err - } - - oauthConfig, err := adal.NewOAuthConfig(env.ActiveDirectoryEndpoint, TenantID) - if err != nil { - return nil, err - } - - token, err := adal.NewServicePrincipalToken(*oauthConfig, clientID, clientSecret, env.ResourceManagerEndpoint) - if err != nil { - return nil, err - } - - authorizer := autorest.NewBearerAuthorizer(token) - - msiClient.Authorizer = authorizer - - return &msiClient, nil -} - -func getMSIObjectID(ctx context.Context, identityName string) (string, error) { - msiClient, err := getMSIUserAssignedIDClient() - if err != nil { - return "", err - } - - id, err := msiClient.Get(ctx, resourceGroupName, identityName) - if err != nil { - return "", err - } - - return id.UserAssignedIdentityProperties.PrincipalID.String(), err -} diff --git a/test/e2e/testsuites/pre_provisioned_sastoken_tester.go b/test/e2e/testsuites/pre_provisioned_sastoken_tester.go new file mode 100644 index 000000000..567e99988 --- /dev/null +++ b/test/e2e/testsuites/pre_provisioned_sastoken_tester.go @@ -0,0 +1,105 @@ +/* +Copyright 2020 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package testsuites + +import ( + "context" + "fmt" + "net/url" + "time" + + "github.com/Azure/azure-sdk-for-go/sdk/storage/azblob" + "github.com/onsi/ginkgo" + v1 "k8s.io/api/core/v1" + clientset "k8s.io/client-go/kubernetes" + "k8s.io/kubernetes/test/e2e/framework" + "sigs.k8s.io/blob-csi-driver/pkg/blob" + "sigs.k8s.io/blob-csi-driver/test/e2e/driver" + "sigs.k8s.io/blob-csi-driver/test/utils/azure" +) + +// PreProvisionedSASTokenTest will provision required PV(s), PVC(s) and Pod(s) +// Testing that the Pod(s) can be created successfully with provided Key Vault +// which is used to store storage SAS token +type PreProvisionedSASTokenTest struct { + CSIDriver driver.PreProvisionedVolumeTestDriver + Pods []PodDetails + Driver *blob.Driver +} + +func (t *PreProvisionedSASTokenTest) Run(client clientset.Interface, namespace *v1.Namespace) { + keyVaultClient, err := azure.NewKeyVaultClient() + framework.ExpectNoError(err) + + for _, pod := range t.Pods { + for n, volume := range pod.Volumes { + // In the method GetStorageAccountAndContainer, we can get an account key of the blob volume + // by calling azure API, but not the sas token... + accountName, accountKey, _, containerName, err := t.Driver.GetStorageAccountAndContainer(context.TODO(), volume.VolumeID, nil, nil) + framework.ExpectNoError(err, fmt.Sprintf("Error GetStorageAccountAndContainer from volumeID(%s): %v", volume.VolumeID, err)) + + ginkgo.By("creating KeyVault...") + vault, err := keyVaultClient.CreateVault(context.TODO()) + framework.ExpectNoError(err) + defer func() { + err := keyVaultClient.CleanVault(context.TODO()) + framework.ExpectNoError(err) + }() + + ginkgo.By("generating SAS token...") + sasToken := generateSASToken(accountName, accountKey) + + ginkgo.By("creating secret for SAS token...") + accountSASSecret, err := keyVaultClient.CreateSecret(context.TODO(), accountName+"-sas", sasToken) + framework.ExpectNoError(err) + + pod.Volumes[n].ContainerName = containerName + pod.Volumes[n].StorageAccountname = accountName + pod.Volumes[n].KeyVaultURL = *vault.Properties.VaultURI + pod.Volumes[n].KeyVaultSecretName = *accountSASSecret.Name + + tpod, cleanup := pod.SetupWithPreProvisionedVolumes(client, namespace, t.CSIDriver) + // defer must be called here for resources not get removed before using them + for i := range cleanup { + defer cleanup[i]() + } + + ginkgo.By("deploying the pod") + tpod.Create() + defer tpod.Cleanup() + + ginkgo.By("checking that the pods command exits with no error") + tpod.WaitForSuccess() + } + } +} + +func generateSASToken(accountName, accountKey string) string { + credential, err := azblob.NewSharedKeyCredential(accountName, accountKey) + framework.ExpectNoError(err) + serviceClient, err := azblob.NewServiceClientWithSharedKey(fmt.Sprintf("https://%s.blob.core.windows.net/", accountName), credential, nil) + framework.ExpectNoError(err) + sasURL, err := serviceClient.GetSASURL( + azblob.AccountSASResourceTypes{Object: true, Service: true, Container: true}, + azblob.AccountSASPermissions{Read: true, List: true, Write: true, Delete: true, Add: true, Create: true, Update: true}, + time.Now(), time.Now().Add(10*time.Hour)) + framework.ExpectNoError(err) + u, err := url.Parse(sasURL) + framework.ExpectNoError(err) + sasToken := "?" + u.RawQuery + return sasToken +} diff --git a/test/utils/azure/keyvault_helper.go b/test/utils/azure/keyvault_helper.go new file mode 100644 index 000000000..ef27055fe --- /dev/null +++ b/test/utils/azure/keyvault_helper.go @@ -0,0 +1,300 @@ +/* +Copyright 2019 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package azure + +import ( + "context" + "fmt" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" + "github.com/Azure/azure-sdk-for-go/sdk/azidentity" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/keyvault/armkeyvault" + "github.com/Azure/azure-sdk-for-go/services/graphrbac/1.6/graphrbac" + "github.com/Azure/azure-sdk-for-go/services/msi/mgmt/2018-11-30/msi" + "github.com/Azure/go-autorest/autorest" + "github.com/Azure/go-autorest/autorest/adal" + "github.com/Azure/go-autorest/autorest/azure" + "github.com/onsi/ginkgo" + "k8s.io/apiserver/pkg/storage/names" + "sigs.k8s.io/blob-csi-driver/test/utils/credentials" +) + +type KeyVaultClient struct { + cred *credentials.Credentials + vaultName string +} + +func NewKeyVaultClient() (*KeyVaultClient, error) { + e2eCred, err := credentials.ParseAzureCredentialFile() + if err != nil { + return nil, err + } + + return &KeyVaultClient{ + cred: &credentials.Credentials{ + SubscriptionID: e2eCred.SubscriptionID, + ResourceGroup: e2eCred.ResourceGroup, + Location: e2eCred.Location, + TenantID: e2eCred.TenantID, + Cloud: e2eCred.Cloud, + AADClientID: e2eCred.AADClientID, + AADClientSecret: e2eCred.AADClientSecret, + }, + vaultName: names.SimpleNameGenerator.GenerateName("blob-csi-test-kv-"), + }, nil +} + +func (kvc *KeyVaultClient) CreateVault(ctx context.Context) (*armkeyvault.Vault, error) { + azureCred, err := azidentity.NewDefaultAzureCredential(nil) + if err != nil { + return nil, err + } + + vaultsClient, err := armkeyvault.NewVaultsClient(kvc.cred.SubscriptionID, azureCred, nil) + if err != nil { + return nil, err + } + + pollerResp, err := vaultsClient.BeginCreateOrUpdate( + ctx, + kvc.cred.ResourceGroup, + kvc.vaultName, + armkeyvault.VaultCreateOrUpdateParameters{ + Location: to.Ptr(kvc.cred.Location), + Properties: &armkeyvault.VaultProperties{ + SKU: &armkeyvault.SKU{ + Family: to.Ptr(armkeyvault.SKUFamilyA), + Name: to.Ptr(armkeyvault.SKUNameStandard), + }, + TenantID: to.Ptr(kvc.cred.TenantID), + AccessPolicies: kvc.getAccessPolicy(ctx), + }, + }, + nil, + ) + if err != nil { + return nil, err + } + + resp, err := pollerResp.PollUntilDone(ctx, nil) + if err != nil { + return nil, err + } + return &resp.Vault, nil +} + +func (kvc *KeyVaultClient) CleanVault(ctx context.Context) error { + azureCred, err := azidentity.NewDefaultAzureCredential(nil) + if err != nil { + return err + } + + err = kvc.deleteVault(ctx, azureCred) + if err != nil { + return err + } + + err = kvc.purgeDeleted(ctx, azureCred) + if err != nil { + return err + } + + return nil +} + +func (kvc *KeyVaultClient) CreateSecret(ctx context.Context, secretName, secretValue string) (*armkeyvault.Secret, error) { + azureCred, err := azidentity.NewDefaultAzureCredential(nil) + if err != nil { + return nil, err + } + + secretsClient, err := armkeyvault.NewSecretsClient(kvc.cred.SubscriptionID, azureCred, nil) + if err != nil { + return nil, err + } + + secretResp, err := secretsClient.CreateOrUpdate( + ctx, + kvc.cred.ResourceGroup, + kvc.vaultName, + secretName, + armkeyvault.SecretCreateOrUpdateParameters{ + Properties: &armkeyvault.SecretProperties{ + Attributes: &armkeyvault.SecretAttributes{ + Enabled: to.Ptr(true), + }, + Value: to.Ptr(secretValue), + }, + }, + nil, + ) + if err != nil { + return nil, err + } + + return &secretResp.Secret, nil +} + +func (kvc *KeyVaultClient) getAccessPolicy(ctx context.Context) []*armkeyvault.AccessPolicyEntry { + accessPolicyEntry := []*armkeyvault.AccessPolicyEntry{} + + // vault secret permission for upstream e2e test, which uses application service principal + clientObjectID, err := kvc.getServicePrincipalObjectID(ctx, kvc.cred.AADClientID) + if err == nil { + ginkgo.By("client object ID: " + clientObjectID) + accessPolicyEntry = append(accessPolicyEntry, &armkeyvault.AccessPolicyEntry{ + TenantID: to.Ptr(kvc.cred.TenantID), + ObjectID: to.Ptr(clientObjectID), + Permissions: &armkeyvault.Permissions{ + Secrets: []*armkeyvault.SecretPermissions{ + to.Ptr(armkeyvault.SecretPermissionsGet), + }, + }, + }) + } + + // vault secret permission for upstream e2e-vmss test, which uses msi blobfuse-csi-driver-e2e-test-id + msiObjectID, err := kvc.getMSIObjectID(ctx, "blobfuse-csi-driver-e2e-test-id") + if err == nil { + ginkgo.By("MSI object ID: " + msiObjectID) + accessPolicyEntry = append(accessPolicyEntry, &armkeyvault.AccessPolicyEntry{ + TenantID: to.Ptr(kvc.cred.TenantID), + ObjectID: to.Ptr(msiObjectID), + Permissions: &armkeyvault.Permissions{ + Secrets: []*armkeyvault.SecretPermissions{ + to.Ptr(armkeyvault.SecretPermissionsGet), + }, + }, + }) + } + + return accessPolicyEntry +} + +func (kvc *KeyVaultClient) deleteVault(ctx context.Context, cred azcore.TokenCredential) error { + vaultsClient, err := armkeyvault.NewVaultsClient(kvc.cred.SubscriptionID, cred, nil) + if err != nil { + return err + } + + _, err = vaultsClient.Delete(ctx, kvc.cred.ResourceGroup, kvc.vaultName, nil) + if err != nil { + return err + } + return nil +} + +func (kvc *KeyVaultClient) purgeDeleted(ctx context.Context, cred azcore.TokenCredential) error { + vaultsClient, err := armkeyvault.NewVaultsClient(kvc.cred.SubscriptionID, cred, nil) + if err != nil { + return err + } + + pollerResp, err := vaultsClient.BeginPurgeDeleted(ctx, kvc.vaultName, kvc.cred.Location, nil) + if err != nil { + return err + } + + _, err = pollerResp.PollUntilDone(ctx, nil) + if err != nil { + return err + } + + return nil +} + +func (kvc *KeyVaultClient) getServicePrincipalObjectID(ctx context.Context, clientID string) (string, error) { + spClient, err := kvc.getServicePrincipalsClient() + if err != nil { + return "", err + } + + page, err := spClient.List(ctx, fmt.Sprintf("servicePrincipalNames/any(c:c eq '%s')", clientID)) + if err != nil { + return "", err + } + servicePrincipals := page.Values() + if len(servicePrincipals) == 0 { + return "", fmt.Errorf("didn't find any service principals for client ID %s", clientID) + } + return *servicePrincipals[0].ObjectID, nil +} + +func (kvc *KeyVaultClient) getServicePrincipalsClient() (*graphrbac.ServicePrincipalsClient, error) { + spClient := graphrbac.NewServicePrincipalsClient(kvc.cred.TenantID) + + env, err := azure.EnvironmentFromName(kvc.cred.Cloud) + if err != nil { + return nil, err + } + + oauthConfig, err := adal.NewOAuthConfig(env.ActiveDirectoryEndpoint, kvc.cred.TenantID) + if err != nil { + return nil, err + } + + token, err := adal.NewServicePrincipalToken(*oauthConfig, kvc.cred.AADClientID, kvc.cred.AADClientSecret, env.GraphEndpoint) + if err != nil { + return nil, err + } + + authorizer := autorest.NewBearerAuthorizer(token) + + spClient.Authorizer = authorizer + + return &spClient, nil +} + +func (kvc *KeyVaultClient) getMSIUserAssignedIDClient() (*msi.UserAssignedIdentitiesClient, error) { + msiClient := msi.NewUserAssignedIdentitiesClient(kvc.cred.SubscriptionID) + + env, err := azure.EnvironmentFromName(kvc.cred.Cloud) + if err != nil { + return nil, err + } + + oauthConfig, err := adal.NewOAuthConfig(env.ActiveDirectoryEndpoint, kvc.cred.TenantID) + if err != nil { + return nil, err + } + + token, err := adal.NewServicePrincipalToken(*oauthConfig, kvc.cred.AADClientID, kvc.cred.AADClientSecret, env.ResourceManagerEndpoint) + if err != nil { + return nil, err + } + + authorizer := autorest.NewBearerAuthorizer(token) + + msiClient.Authorizer = authorizer + + return &msiClient, nil +} + +func (kvc *KeyVaultClient) getMSIObjectID(ctx context.Context, identityName string) (string, error) { + msiClient, err := kvc.getMSIUserAssignedIDClient() + if err != nil { + return "", err + } + + id, err := msiClient.Get(ctx, kvc.cred.ResourceGroup, identityName) + if err != nil { + return "", err + } + + return id.UserAssignedIdentityProperties.PrincipalID.String(), err +}