Skip to content

Commit e81421d

Browse files
committed
[WIP] Refactor token source code into the token_source.go file
1 parent 0ad159a commit e81421d

File tree

3 files changed

+108
-11
lines changed

3 files changed

+108
-11
lines changed

pkg/gce-cloud-provider/compute/gce-compute.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,12 +164,14 @@ func (cloud *CloudProvider) listDisksInternal(ctx context.Context, fields []goog
164164
disks := []*computev1.Disk{}
165165

166166
// listing out regional disks in the region
167+
klog.Infof("Getting regional disks for default project: %s", cloud.project)
167168
rDisks, err := listRegionalDisksForProject(cloud.service, cloud.project, region, fields, filter)
168169
if err != nil {
169170
return nil, "", err
170171
}
171172
disks = append(disks, rDisks...)
172173
for tProject, tService := range cloud.tenantServiceMap {
174+
klog.Infof("Getting regional disks for tenant: %s", tProject)
173175
rDisks, err := listRegionalDisksForProject(tService, tProject, region, fields, filter)
174176
if err != nil {
175177
return nil, "", err
@@ -178,12 +180,14 @@ func (cloud *CloudProvider) listDisksInternal(ctx context.Context, fields []goog
178180
}
179181

180182
// listing out zonal disks in all zones of the region
183+
klog.Infof("Getting zonal disks for default project: %s", cloud.project)
181184
zDisks, err := listZonalDisksForProject(cloud.service, cloud.project, zones, fields, filter)
182185
if err != nil {
183186
return nil, "", err
184187
}
185188
disks = append(disks, zDisks...)
186189
for tProject, tService := range cloud.tenantServiceMap {
190+
klog.Infof("Getting zonal disks for tenant: %s", tProject)
187191
zDisks, err := listZonalDisksForProject(tService, tProject, zones, fields, filter)
188192
if err != nil {
189193
return nil, "", err
@@ -913,6 +917,7 @@ func (cloud *CloudProvider) AttachDisk(ctx context.Context, project string, volK
913917

914918
service := cloud.service
915919
if _, ok := cloud.tenantServiceMap[project]; ok {
920+
klog.Infof("Using tenant service in AttachDisk for project: %s", project)
916921
service = cloud.tenantServiceMap[project]
917922
}
918923
op, err := service.Instances.AttachDisk(project, instanceZone, instanceName, attachedDiskV1).Context(ctx).ForceAttach(forceAttach).Do()
@@ -1211,6 +1216,7 @@ func (cloud *CloudProvider) GetInstanceOrError(ctx context.Context, project, ins
12111216
klog.V(5).Infof("Getting instance %v from zone %v", instanceName, instanceZone)
12121217
service := cloud.service
12131218
if _, ok := cloud.tenantServiceMap[project]; ok {
1219+
klog.Infof("Using tenant service in GetInstanceOrError for project: %s", project)
12141220
service = cloud.tenantServiceMap[project]
12151221
}
12161222
instance, err := service.Instances.Get(project, instanceZone, instanceName).Do()

pkg/gce-cloud-provider/compute/gce.go

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -69,9 +69,6 @@ const (
6969
// globalComputeParentPathFmt is the string format for the full path of global compute resource.
7070
globalComputeParentPathFmt = "//compute.googleapis.com/projects/%s/global/%s/%d"
7171

72-
// tenantAuthenticationPathFmt is the string format for the full URL needed to generate an access token for tenants
73-
tenantAuthenticationPathFmt = "https://preprod-gkeauth.sandbox.googleapis.com/v1/projects/%s/locations/%s/tenants/%s:generateTenantToken"
74-
7572
// gcpTagsRequestRateLimit is the tag request rate limit per second.
7673
gcpTagsRequestRateLimit = 8
7774

@@ -216,16 +213,24 @@ func CreateCloudProvider(ctx context.Context, vendorVersion string, configPath s
216213
return
217214
}
218215

219-
tenantTokenUrl := fmt.Sprintf(tenantAuthenticationPathFmt, tenantMeta.ProjectNumber, configFile.Global.Zone, tenantMeta.TenantName)
220-
tokenSource := NewAltTokenSource(tenantTokenUrl, "")
221-
222-
testToken, err := tokenSource.Token()
216+
region, err := common.GetRegionFromZones([]string{zone})
217+
if err != nil {
218+
klog.Errorf("error getting region from zone(%s): %v", zone, err)
219+
return
220+
}
221+
tokenSource, err := NewTenantTokenSource(tenantMeta, region, configFile.Global.TokenURL, configFile.Global.TokenBody)
223222
if err != nil {
224-
klog.Errorf("error fetching initial token during test: %v", err.Error())
223+
klog.Errorf("error during tenant token generation: %v", err.Error())
225224
}
226-
klog.Infof("Token type: %v", testToken.TokenType)
227-
klog.Infof("Token access token: %v", testToken.AccessToken)
228-
klog.Infof("%+v", testToken)
225+
226+
// testToken, err := tokenSource.Token()
227+
// if err != nil {
228+
// klog.Errorf("error fetching initial token during test: %v", err.Error())
229+
// return
230+
// }
231+
// klog.Infof("Token type: %v", testToken.TokenType)
232+
// klog.Infof("Token access token: %v", testToken.AccessToken)
233+
// klog.Infof("%+v", testToken)
229234

230235
svc, err := createCloudService(ctx, vendorVersion, tokenSource, computeEndpoint, computeEnvironment)
231236
if err != nil {

pkg/gce-cloud-provider/compute/token_source.go

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,14 @@ package gcecloudprovider
1818

1919
import (
2020
"encoding/json"
21+
"fmt"
2122
"net/http"
23+
"strconv"
2224
"strings"
2325
"time"
2426

2527
"k8s.io/client-go/util/flowcontrol"
28+
"sigs.k8s.io/gcp-compute-persistent-disk-csi-driver/pkg/gce-cloud-provider/compute/tenancy"
2629

2730
"golang.org/x/oauth2"
2831
"golang.org/x/oauth2/google"
@@ -34,6 +37,8 @@ const (
3437
tokenURLQPS = .05 // back off to once every 20 seconds when failing
3538
// Maximum burst of requests to token URL before limiting.
3639
tokenURLBurst = 3
40+
// tenantAuthenticationPathFmt is the string format for the full URL needed to generate an access token for tenants
41+
tenantAuthenticationPathFmt = "https://preprod-gkeauth.sandbox.googleapis.com/v1/projects/%s/locations/%s/tenants/%s:generateTenantToken"
3742
)
3843

3944
// TODO(#276) add metrics around token requests once the driver integrates with Prometheus.
@@ -89,3 +94,84 @@ func NewAltTokenSource(tokenURL, tokenBody string) oauth2.TokenSource {
8994
}
9095
return oauth2.ReuseTokenSource(nil, a)
9196
}
97+
98+
func NewTenantTokenSource(tenantMeta tenancy.Metadata, region, existingTokenURL, existingTokenBody string) (oauth2.TokenSource, error) {
99+
tenantTokenUrl, err := getTenantTokenURL(tenantMeta, existingTokenURL)
100+
if err != nil {
101+
return nil, err
102+
}
103+
tenantTokenBody, err := getTenantTokenBody(tenantMeta, existingTokenBody)
104+
if err != nil {
105+
return nil, err
106+
}
107+
return NewAltTokenSource(tenantTokenUrl, tenantTokenBody), nil
108+
}
109+
110+
func getTenantTokenURL(tenantMeta tenancy.Metadata, existingTokenURL string) (string, error) {
111+
location := extractLocationFromTokenURL(existingTokenURL)
112+
if location == "" {
113+
return "", fmt.Errorf("could not extract location from existing token URL: %s", existingTokenURL)
114+
}
115+
116+
tokenURLParts := strings.SplitN(existingTokenURL, "/projects/", 2)
117+
if len(tokenURLParts) != 2 {
118+
return "", fmt.Errorf("invalid existing token URL format: %s, cannot extract base URL", existingTokenURL)
119+
}
120+
baseURL := tokenURLParts[0]
121+
122+
// Format: {BASE_URL}/projects/{TENANT_PROJECT_NUMBER}/locations/{TENANT_LOCATION}/tenants/{TENANT_ID}:generateTenantToken
123+
formatString := "%s/projects/%s/locations/%s/tenants/%s:generateTenantToken"
124+
tokenURL := fmt.Sprintf(formatString, baseURL, tenantMeta.ProjectNumber, location, tenantMeta.TenantName)
125+
return tokenURL, nil
126+
}
127+
128+
// extractLocationFromTokenURL extracts the location from a GKE token URL.
129+
// Example input: https://gkeauth.googleapis.com/v1/projects/654321/locations/us-central1/clusters/example-cluster:generateToken
130+
// Returns: us-central1
131+
func extractLocationFromTokenURL(tokenURL string) string {
132+
parts := strings.Split(tokenURL, "/")
133+
for i, part := range parts {
134+
if part == "locations" && i+1 < len(parts) {
135+
return parts[i+1]
136+
}
137+
}
138+
return ""
139+
}
140+
141+
func getTenantTokenBody(tenantMeta tenancy.Metadata, existingTokenBody string) (string, error) {
142+
// Check if the token body is a quoted JSON string
143+
// Quoted example: "{\"projectNumber\":12345,\"clusterId\":\"example-cluster\"}"
144+
// Non-quoted example: {"projectNumber":12345,"clusterId":"example-cluster"}
145+
isQuoted := len(existingTokenBody) > 0 && existingTokenBody[0] == '"' && existingTokenBody[len(existingTokenBody)-1] == '"'
146+
147+
var jsonStr string
148+
if isQuoted {
149+
var err error
150+
jsonStr, err = strconv.Unquote(existingTokenBody)
151+
if err != nil {
152+
return "", fmt.Errorf("error unquoting TokenBody: %v", err)
153+
}
154+
} else {
155+
jsonStr = existingTokenBody
156+
}
157+
158+
var bodyMap map[string]any
159+
160+
if err := json.Unmarshal([]byte(jsonStr), &bodyMap); err != nil {
161+
return "", fmt.Errorf("error unmarshaling TokenBody: %v", err)
162+
}
163+
164+
bodyMap["projectNumber"] = tenantMeta.ProjectNumber
165+
166+
newTokenBodyBytes, err := json.Marshal(bodyMap)
167+
if err != nil {
168+
return "", fmt.Errorf("error marshaling TokenBody: %v", err)
169+
}
170+
171+
if isQuoted {
172+
// Re-quote the JSON string if the original was quoted
173+
return strconv.Quote(string(newTokenBodyBytes)), nil
174+
}
175+
176+
return string(newTokenBodyBytes), nil
177+
}

0 commit comments

Comments
 (0)