@@ -854,8 +854,10 @@ func (r *Runner) handleCredentials(callCtx engine.Context, monitor Monitor, env
854
854
}
855
855
856
856
var (
857
- c * credentials.Credential
858
- exists bool
857
+ c * credentials.Credential
858
+ resultCredential credentials.Credential
859
+ exists bool
860
+ refresh bool
859
861
)
860
862
861
863
rm := runtimeWithLogger (callCtx , monitor , r .runtimeManager )
@@ -886,6 +888,7 @@ func (r *Runner) handleCredentials(callCtx engine.Context, monitor Monitor, env
886
888
if ! exists || c .IsExpired () {
887
889
// If the existing credential is expired, we need to provide it to the cred tool through the environment.
888
890
if exists && c .IsExpired () {
891
+ refresh = true
889
892
credJSON , err := json .Marshal (c )
890
893
if err != nil {
891
894
return nil , fmt .Errorf ("failed to marshal credential: %w" , err )
@@ -916,39 +919,56 @@ func (r *Runner) handleCredentials(callCtx engine.Context, monitor Monitor, env
916
919
continue
917
920
}
918
921
919
- if err := json .Unmarshal ([]byte (* res .Result ), & c ); err != nil {
922
+ if err := json .Unmarshal ([]byte (* res .Result ), & resultCredential ); err != nil {
920
923
return nil , fmt .Errorf ("failed to unmarshal credential tool %s response: %w" , ref .Reference , err )
921
924
}
922
- c .ToolName = credName
923
- c .Type = credentials .CredentialTypeTool
925
+ resultCredential .ToolName = credName
926
+ resultCredential .Type = credentials .CredentialTypeTool
927
+
928
+ if refresh {
929
+ // If this is a credential refresh, we need to make sure we use the same context.
930
+ resultCredential .Context = c .Context
931
+ } else {
932
+ // If it is a new credential, let the credential store determine the context.
933
+ resultCredential .Context = ""
934
+ }
924
935
925
936
isEmpty := true
926
- for _ , v := range c .Env {
937
+ for _ , v := range resultCredential .Env {
927
938
if v != "" {
928
939
isEmpty = false
929
940
break
930
941
}
931
942
}
932
943
933
- if ! c .Ephemeral {
944
+ if ! resultCredential .Ephemeral {
934
945
// Only store the credential if the tool is on GitHub or has an alias, and the credential is non-empty.
935
946
if (isGitHubTool (toolName ) && callCtx .Program .ToolSet [ref .ToolID ].Source .Repo != nil ) || credentialAlias != "" {
936
947
if isEmpty {
937
948
log .Warnf ("Not saving empty credential for tool %s" , toolName )
938
- } else if err := r .credStore .Add (callCtx .Ctx , * c ); err != nil {
939
- return nil , fmt .Errorf ("failed to add credential for tool %s: %w" , toolName , err )
949
+ } else {
950
+ if refresh {
951
+ err = r .credStore .Refresh (callCtx .Ctx , resultCredential )
952
+ } else {
953
+ err = r .credStore .Add (callCtx .Ctx , resultCredential )
954
+ }
955
+ if err != nil {
956
+ return nil , fmt .Errorf ("failed to save credential for tool %s: %w" , toolName , err )
957
+ }
940
958
}
941
959
} else {
942
960
log .Warnf ("Not saving credential for tool %s - credentials will only be saved for tools from GitHub, or tools that use aliases." , toolName )
943
961
}
944
962
}
963
+ } else {
964
+ resultCredential = * c
945
965
}
946
966
947
- if c .ExpiresAt != nil && (nearestExpiration == nil || nearestExpiration .After (* c .ExpiresAt )) {
948
- nearestExpiration = c .ExpiresAt
967
+ if resultCredential .ExpiresAt != nil && (nearestExpiration == nil || nearestExpiration .After (* resultCredential .ExpiresAt )) {
968
+ nearestExpiration = resultCredential .ExpiresAt
949
969
}
950
970
951
- for k , v := range c .Env {
971
+ for k , v := range resultCredential .Env {
952
972
env = append (env , fmt .Sprintf ("%s=%s" , k , v ))
953
973
}
954
974
}
0 commit comments