Skip to content

Commit 1e64068

Browse files
committed
Fix AWS credential expiration
Use the AWS cached credential provider to automatically handle credentials. The CredentialsCache will automatically handle refreshing expired credentials and keeping them cached as long as necessary. Replaces prometheus-community#634 as this offloads more of the work to the AWS SDK Signed-off-by: Joe Adams <[email protected]>
1 parent 70152fe commit 1e64068

File tree

1 file changed

+15
-3
lines changed

1 file changed

+15
-3
lines changed

pkg/roundtripper/roundtripper.go

+15-3
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ const (
3636

3737
type AWSSigningTransport struct {
3838
t http.RoundTripper
39-
creds aws.Credentials
39+
creds aws.CredentialsProvider
4040
region string
4141
log log.Logger
4242
}
@@ -48,12 +48,17 @@ func NewAWSSigningTransport(transport http.RoundTripper, region string, log log.
4848
return nil, err
4949
}
5050

51-
creds, err := cfg.Credentials.Retrieve(context.Background())
51+
// Run a single fetch credentials operation to ensure that the credentials
52+
// are valid before returning the transport.
53+
_, err = cfg.Credentials.Retrieve(context.Background())
5254
if err != nil {
5355
_ = level.Error(log).Log("msg", "fail to retrive aws credentials", "err", err)
5456
return nil, err
5557
}
5658

59+
// Build a cached credentials provider to manage the credentials and prevent new credentials on every request.
60+
creds := aws.NewCredentialsCache(cfg.Credentials)
61+
5762
return &AWSSigningTransport{
5863
t: transport,
5964
region: region,
@@ -69,8 +74,15 @@ func (a *AWSSigningTransport) RoundTrip(req *http.Request) (*http.Response, erro
6974
_ = level.Error(a.log).Log("msg", "fail to hash request body", "err", err)
7075
return nil, err
7176
}
77+
78+
creds, err := a.creds.Retrieve(context.Background())
79+
if err != nil {
80+
_ = level.Error(a.log).Log("msg", "fail to retrive aws credentials", "err", err)
81+
return nil, err
82+
}
83+
7284
req.Body = newReader
73-
err = signer.SignHTTP(context.Background(), a.creds, req, payloadHash, service, a.region, time.Now())
85+
err = signer.SignHTTP(context.Background(), creds, req, payloadHash, service, a.region, time.Now())
7486
if err != nil {
7587
_ = level.Error(a.log).Log("msg", "fail to sign request body", "err", err)
7688
return nil, err

0 commit comments

Comments
 (0)