Skip to content

Commit 5bdc662

Browse files
punkwalkermjlshen
authored andcommitted
Migrate session package from v1 to v2
Signed-off-by: Pankaj Walke <[email protected]>
1 parent 606b13e commit 5bdc662

13 files changed

+591
-26
lines changed

go.mod

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ require (
2020
github.com/aws/aws-sdk-go v1.55.5
2121
github.com/aws/aws-sdk-go-v2 v1.30.3
2222
github.com/aws/aws-sdk-go-v2/config v1.27.11
23+
github.com/aws/aws-sdk-go-v2/credentials v1.17.11
2324
github.com/aws/aws-sdk-go-v2/service/s3 v1.53.1
2425
github.com/aws/aws-sdk-go-v2/service/sts v1.28.6
2526
github.com/aws/smithy-go v1.20.3
@@ -80,7 +81,6 @@ require (
8081
github.com/antlr/antlr4/runtime/Go/antlr/v4 v4.0.0-20230305170008-8188dc5388df // indirect
8182
github.com/asaskevich/govalidator v0.0.0-20200428143746-21a406dcc535 // indirect
8283
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.2 // indirect
83-
github.com/aws/aws-sdk-go-v2/credentials v1.17.11 // indirect
8484
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.1 // indirect
8585
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.15 // indirect
8686
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.15 // indirect

pkg/cloud/identityv2/identity.go

Lines changed: 180 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,180 @@
1+
/*
2+
Copyright 2021 The Kubernetes Authors.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
*/
16+
17+
// Package identity provides the AWSPrincipalTypeProvider interface and its implementations.
18+
package identityv2
19+
20+
import (
21+
"bytes"
22+
"context"
23+
"crypto/sha256"
24+
"encoding/gob"
25+
"time"
26+
27+
"github.com/aws/aws-sdk-go-v2/aws"
28+
"github.com/aws/aws-sdk-go-v2/config"
29+
"github.com/aws/aws-sdk-go-v2/credentials"
30+
"github.com/aws/aws-sdk-go-v2/credentials/stscreds"
31+
"github.com/aws/aws-sdk-go-v2/service/sts"
32+
corev1 "k8s.io/api/core/v1"
33+
34+
infrav1 "sigs.k8s.io/cluster-api-provider-aws/v2/api/v1beta2"
35+
"sigs.k8s.io/cluster-api-provider-aws/v2/pkg/logger"
36+
)
37+
38+
// AWSPrincipalTypeProvider defines the interface for AWS Principal Type Provider.
39+
type AWSPrincipalTypeProvider interface {
40+
aws.CredentialsProvider
41+
// Hash returns a unique hash of the data forming the V2 credentials
42+
// for this Principal
43+
Hash() (string, error)
44+
Name() string
45+
}
46+
47+
// NewAWSStaticPrincipalTypeProvider will create a new AWSStaticPrincipalTypeProvider from a given AWSClusterStaticIdentity.
48+
func NewAWSStaticPrincipalTypeProvider(identity *infrav1.AWSClusterStaticIdentity, secret *corev1.Secret) *AWSStaticPrincipalTypeProvider {
49+
accessKeyID := string(secret.Data["AccessKeyID"])
50+
secretAccessKey := string(secret.Data["SecretAccessKey"])
51+
sessionToken := string(secret.Data["SessionToken"])
52+
53+
credProvider := credentials.NewStaticCredentialsProvider(accessKeyID, secretAccessKey, sessionToken)
54+
credCache := aws.NewCredentialsCache(credProvider)
55+
return &AWSStaticPrincipalTypeProvider{
56+
Principal: identity,
57+
credentials: credCache,
58+
AccessKeyID: accessKeyID,
59+
SecretAccessKey: secretAccessKey,
60+
SessionToken: sessionToken,
61+
}
62+
}
63+
64+
// GetAssumeRoleCredentialsCache will return the CredentialsCache of a given AWSRolePrincipalTypeProvider.
65+
func GetAssumeRoleCredentialsCache(roleIdentityProvider *AWSRolePrincipalTypeProvider, optFns []func(*config.LoadOptions) error) (*aws.CredentialsCache, error) {
66+
cfg, err := config.LoadDefaultConfig(context.TODO(), optFns...)
67+
if err != nil {
68+
return nil, err
69+
}
70+
71+
stsClient := sts.NewFromConfig(cfg)
72+
credsProvider := stscreds.NewAssumeRoleProvider(stsClient, roleIdentityProvider.Principal.Spec.RoleArn, func(o *stscreds.AssumeRoleOptions) {
73+
if roleIdentityProvider.Principal.Spec.ExternalID != "" {
74+
o.ExternalID = aws.String(roleIdentityProvider.Principal.Spec.ExternalID)
75+
}
76+
o.RoleSessionName = roleIdentityProvider.Principal.Spec.SessionName
77+
if roleIdentityProvider.Principal.Spec.InlinePolicy != "" {
78+
o.Policy = aws.String(roleIdentityProvider.Principal.Spec.InlinePolicy)
79+
}
80+
o.Duration = time.Duration(roleIdentityProvider.Principal.Spec.DurationSeconds) * time.Second
81+
// For testing
82+
if roleIdentityProvider.stsClient != nil {
83+
o.Client = roleIdentityProvider.stsClient
84+
}
85+
})
86+
87+
return aws.NewCredentialsCache(credsProvider), nil
88+
}
89+
90+
// NewAWSRolePrincipalTypeProvider will create a new AWSRolePrincipalTypeProvider from an AWSClusterRoleIdentity.
91+
func NewAWSRolePrincipalTypeProvider(identity *infrav1.AWSClusterRoleIdentity, sourceProvider AWSPrincipalTypeProvider, region string, log logger.Wrapper) *AWSRolePrincipalTypeProvider {
92+
return &AWSRolePrincipalTypeProvider{
93+
credentials: nil,
94+
stsClient: nil,
95+
region: region,
96+
Principal: identity,
97+
sourceProvider: sourceProvider,
98+
log: log.WithName("AWSRolePrincipalTypeProvider"),
99+
}
100+
}
101+
102+
// AWSStaticPrincipalTypeProvider defines the specs for a static AWSPrincipalTypeProvider.
103+
type AWSStaticPrincipalTypeProvider struct {
104+
Principal *infrav1.AWSClusterStaticIdentity
105+
credentials *aws.CredentialsCache
106+
// these are for tests :/
107+
AccessKeyID string
108+
SecretAccessKey string
109+
SessionToken string
110+
}
111+
112+
// Hash returns the byte encoded AWSStaticPrincipalTypeProvider.
113+
func (p *AWSStaticPrincipalTypeProvider) Hash() (string, error) {
114+
var roleIdentityValue bytes.Buffer
115+
err := gob.NewEncoder(&roleIdentityValue).Encode(p)
116+
if err != nil {
117+
return "", err
118+
}
119+
hash := sha256.New()
120+
return string(hash.Sum(roleIdentityValue.Bytes())), nil
121+
}
122+
123+
// Retrieve returns the credential values for the AWSStaticPrincipalTypeProvider.
124+
func (p *AWSStaticPrincipalTypeProvider) Retrieve(ctx context.Context) (aws.Credentials, error) {
125+
return p.credentials.Retrieve(ctx)
126+
}
127+
128+
// Name returns the name of the AWSStaticPrincipalTypeProvider.
129+
func (p *AWSStaticPrincipalTypeProvider) Name() string {
130+
return p.Principal.Name
131+
}
132+
133+
// AWSRolePrincipalTypeProvider defines the specs for a AWSPrincipalTypeProvider with a role.
134+
type AWSRolePrincipalTypeProvider struct {
135+
Principal *infrav1.AWSClusterRoleIdentity
136+
credentials *aws.CredentialsCache
137+
region string
138+
sourceProvider AWSPrincipalTypeProvider
139+
log logger.Wrapper
140+
stsClient stscreds.AssumeRoleAPIClient
141+
}
142+
143+
// Hash returns the byte encoded AWSRolePrincipalTypeProvider.
144+
func (p *AWSRolePrincipalTypeProvider) Hash() (string, error) {
145+
var roleIdentityValue bytes.Buffer
146+
err := gob.NewEncoder(&roleIdentityValue).Encode(p)
147+
if err != nil {
148+
return "", err
149+
}
150+
hash := sha256.New()
151+
return string(hash.Sum(roleIdentityValue.Bytes())), nil
152+
}
153+
154+
// Name returns the name of the AWSRolePrincipalTypeProvider.
155+
func (p *AWSRolePrincipalTypeProvider) Name() string {
156+
return p.Principal.Name
157+
}
158+
159+
// Retrieve returns the credential values for the AWSRolePrincipalTypeProvider.
160+
func (p *AWSRolePrincipalTypeProvider) Retrieve(ctx context.Context) (aws.Credentials, error) {
161+
162+
if p.credentials == nil {
163+
optFns := []func(*config.LoadOptions) error{config.WithRegion(p.region)}
164+
if p.sourceProvider != nil {
165+
sourceCreds, err := p.sourceProvider.Retrieve(ctx)
166+
if err != nil {
167+
return aws.Credentials{}, err
168+
}
169+
optFns = append(optFns, config.WithCredentialsProvider(credentials.NewStaticCredentialsProvider(sourceCreds.AccessKeyID, sourceCreds.SecretAccessKey, sourceCreds.SessionToken)))
170+
}
171+
172+
creds, err := GetAssumeRoleCredentialsCache(p, optFns)
173+
if err != nil {
174+
return aws.Credentials{}, err
175+
}
176+
// Update credentials
177+
p.credentials = creds
178+
}
179+
return p.credentials.Retrieve(ctx)
180+
}

pkg/cloud/interfaces.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ limitations under the License.
1818
package cloud
1919

2020
import (
21+
awsv2 "github.com/aws/aws-sdk-go-v2/aws"
2122
awsclient "github.com/aws/aws-sdk-go/aws/client"
2223
"k8s.io/apimachinery/pkg/apis/meta/v1/unstructured"
2324
"sigs.k8s.io/controller-runtime/pkg/client"
@@ -32,6 +33,7 @@ import (
3233
// Session represents an AWS session.
3334
type Session interface {
3435
Session() awsclient.ConfigProvider
36+
SessionV2() awsv2.Config
3537
ServiceLimiter(service string) *throttle.ServiceLimiter
3638
}
3739

pkg/cloud/logs/logs.go

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@ package logs
1919

2020
import (
2121
awsv2 "github.com/aws/aws-sdk-go-v2/aws"
22-
"github.com/aws/aws-sdk-go-v2/config"
2322
"github.com/aws/aws-sdk-go/aws"
2423
"github.com/go-logr/logr"
2524
)
@@ -43,16 +42,16 @@ func GetAWSLogLevel(logger logr.Logger) aws.LogLevelType {
4342
}
4443

4544
// GetAWSLogLevelV2 will return the log level of an AWS Logger.
46-
func GetAWSLogLevelV2(logger logr.Logger) config.LoadOptionsFunc {
45+
func GetAWSLogLevelV2(logger logr.Logger) awsv2.ClientLogMode {
4746
if logger.V(logWithHTTPBody).Enabled() {
48-
return config.WithClientLogMode(awsv2.LogRequestWithBody | awsv2.LogResponseWithBody)
47+
return awsv2.LogRequestWithBody | awsv2.LogResponseWithBody
4948
}
5049

5150
if logger.V(logWithHTTPHeader).Enabled() {
52-
return config.WithClientLogMode(awsv2.LogRequest | awsv2.LogResponse)
51+
return awsv2.LogRequest | awsv2.LogResponse
5352
}
5453

55-
return nil
54+
return awsv2.LogRequestEventMessage
5655
}
5756

5857
// NewWrapLogr will create an AWS Logger wrapper.

pkg/cloud/scope/clients.go

Lines changed: 18 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@ import (
2222
"fmt"
2323

2424
awsv2middleware "github.com/aws/aws-sdk-go-v2/aws/middleware"
25-
"github.com/aws/aws-sdk-go-v2/config"
2625
"github.com/aws/aws-sdk-go-v2/service/s3"
2726
"github.com/aws/aws-sdk-go/aws"
2827
"github.com/aws/aws-sdk-go/aws/awserr"
@@ -209,31 +208,30 @@ func NewSSMClient(scopeUser cloud.ScopeUsage, session cloud.Session, logger logg
209208

210209
// NewS3Client creates a new S3 API client for a given session.
211210
func NewS3Client(scopeUser cloud.ScopeUsage, session cloud.Session, logger logger.Wrapper, target runtime.Object) *s3.Client {
212-
// TODO: Incorporate session into loading config
213-
optFns := []func(*config.LoadOptions) error{config.WithLogger(logger.GetAWSLogger())}
214-
if awslogs.GetAWSLogLevelV2(logger.GetLogger()) != nil {
215-
optFns = append(optFns, awslogs.GetAWSLogLevelV2(logger.GetLogger()))
216-
}
217-
cfg, err := config.LoadDefaultConfig(context.TODO(), optFns...)
218-
if err != nil {
219-
panic(err)
220-
}
221-
222-
cfg.APIOptions = append(
223-
cfg.APIOptions,
224-
func(stack *middleware.Stack) error {
225-
return stack.Build.Add(getUserAgentHandlerV2(), middleware.Before)
211+
// TODO: Implement EndpointResolverV2 for Service Endpoints
212+
cfg := session.SessionV2()
213+
s3Opts := []func(*s3.Options){
214+
func(o *s3.Options) {
215+
o.Logger = logger.GetAWSLogger()
226216
},
227-
func(stack *middleware.Stack) error {
228-
return stack.Deserialize.Add(recordAWSPermissionsIssueV2(target), middleware.After)
217+
func(o *s3.Options) {
218+
o.ClientLogMode = awslogs.GetAWSLogLevelV2(logger.GetLogger())
229219
},
230-
)
220+
s3.WithAPIOptions(
221+
func(stack *middleware.Stack) error {
222+
return stack.Build.Add(getUserAgentHandlerV2(), middleware.Before)
223+
},
224+
func(stack *middleware.Stack) error {
225+
return stack.Deserialize.Add(recordAWSPermissionsIssueV2(target), middleware.After)
226+
},
227+
),
228+
}
231229
// TODO: https://docs.aws.amazon.com/sdk-for-go/v2/developer-guide/sdk-timing.html
232230
// cfg.APIOptions = append(cfg.APIOptions, func(stack *middleware.Stack) error {
233-
// return stack.Deserialize.Add(awsmetrics.CaptureRequestMetricsV2(scopeUser.ControllerName()), middleware.Before)
231+
// return stack.Deserialize.Add(awsmetrics.CaptureRequestMetricsV2(scopeUser.ControllerName()), middleware.Before)
234232
// })
235233

236-
return s3.NewFromConfig(cfg)
234+
return s3.NewFromConfig(cfg, s3Opts...)
237235
}
238236

239237
func recordAWSPermissionsIssue(target runtime.Object) func(r *request.Request) {

pkg/cloud/scope/cluster.go

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import (
2020
"context"
2121
"fmt"
2222

23+
awsv2 "github.com/aws/aws-sdk-go-v2/aws"
2324
awsclient "github.com/aws/aws-sdk-go/aws/client"
2425
"github.com/pkg/errors"
2526
"k8s.io/apimachinery/pkg/apis/meta/v1/unstructured"
@@ -45,6 +46,7 @@ type ClusterScopeParams struct {
4546
ControllerName string
4647
Endpoints []ServiceEndpoint
4748
Session awsclient.ConfigProvider
49+
SessionV2 awsv2.Config
4850
TagUnmanagedNetworkResources bool
4951
}
5052

@@ -77,13 +79,19 @@ func NewClusterScope(params ClusterScopeParams) (*ClusterScope, error) {
7779
return nil, errors.Errorf("failed to create aws session: %v", err)
7880
}
7981

82+
sessionv2, _, err := sessionForClusterWithRegionV2(params.Client, clusterScope, params.AWSCluster.Spec.Region, params.Endpoints, params.Logger)
83+
if err != nil {
84+
return nil, errors.Errorf("failed to create aws V2 session: %v", err)
85+
}
86+
8087
helper, err := patch.NewHelper(params.AWSCluster, params.Client)
8188
if err != nil {
8289
return nil, errors.Wrap(err, "failed to init patch helper")
8390
}
8491

8592
clusterScope.patchHelper = helper
8693
clusterScope.session = session
94+
clusterScope.sessionV2 = *sessionv2
8795
clusterScope.serviceLimiters = serviceLimiters
8896

8997
return clusterScope, nil
@@ -99,6 +107,7 @@ type ClusterScope struct {
99107
AWSCluster *infrav1.AWSCluster
100108

101109
session awsclient.ConfigProvider
110+
sessionV2 awsv2.Config
102111
serviceLimiters throttle.ServiceLimiters
103112
controllerName string
104113

@@ -351,6 +360,11 @@ func (s *ClusterScope) Session() awsclient.ConfigProvider {
351360
return s.session
352361
}
353362

363+
// Session returns the AWS SDK V2 session. Used for creating clients.
364+
func (s *ClusterScope) SessionV2() awsv2.Config {
365+
return s.sessionV2
366+
}
367+
354368
// ServiceLimiter returns the AWS SDK session. Used for creating clients.
355369
func (s *ClusterScope) ServiceLimiter(service string) *throttle.ServiceLimiter {
356370
if sl, ok := s.serviceLimiters[service]; ok {

0 commit comments

Comments
 (0)