diff --git a/components/ws-proxy/pkg/proxy/infoprovider.go b/components/ws-proxy/pkg/proxy/infoprovider.go index 9403a51a6f881c..2d438a41fc6c5c 100644 --- a/components/ws-proxy/pkg/proxy/infoprovider.go +++ b/components/ws-proxy/pkg/proxy/infoprovider.go @@ -30,6 +30,7 @@ import ( const ( workspaceIndex = "workspaceIndex" + ipAddressIndex = "ipAddressIndex" ) // getPortStr extracts the port part from a given URL string. Returns "" if parsing fails or port is not specified. @@ -77,6 +78,15 @@ func NewCRDWorkspaceInfoProvider(client client.Client, scheme *runtime.Scheme) ( return nil, xerrors.Errorf("object is not a WorkspaceInfo") }, + ipAddressIndex: func(obj interface{}) ([]string, error) { + if workspaceInfo, ok := obj.(*common.WorkspaceInfo); ok { + if workspaceInfo.IPAddress == "" { + return nil, nil + } + return []string{workspaceInfo.IPAddress}, nil + } + return nil, xerrors.Errorf("object is not a WorkspaceInfo") + }, } contextIndexers := cache.Indexers{ workspaceIndex: func(obj interface{}) ([]string, error) { @@ -96,29 +106,73 @@ func NewCRDWorkspaceInfoProvider(client client.Client, scheme *runtime.Scheme) ( }, nil } -// WorkspaceInfo return the WorkspaceInfo available for the given workspaceID. +// WorkspaceInfo returns the WorkspaceInfo for the given workspaceID. +// It performs validation to ensure the workspace is unique and properly associated with its IP address. func (r *CRDWorkspaceInfoProvider) WorkspaceInfo(workspaceID string) *common.WorkspaceInfo { workspaces, err := r.store.ByIndex(workspaceIndex, workspaceID) if err != nil { return nil } - if len(workspaces) >= 1 { - if len(workspaces) != 1 { - log.Warnf("multiple instances (%d) for workspace %s", len(workspaces), workspaceID) - } + if len(workspaces) == 0 { + return nil + } - sort.Slice(workspaces, func(i, j int) bool { - a := workspaces[i].(*common.WorkspaceInfo) - b := workspaces[j].(*common.WorkspaceInfo) + if len(workspaces) > 1 { + log.WithField("workspaceID", workspaceID).WithField("instanceCount", len(workspaces)).Warn("multiple workspace instances found") + } - return a.StartedAt.After(b.StartedAt) - }) + sort.Slice(workspaces, func(i, j int) bool { + a := workspaces[i].(*common.WorkspaceInfo) + b := workspaces[j].(*common.WorkspaceInfo) + return a.StartedAt.After(b.StartedAt) + }) + + wsInfo := workspaces[0].(*common.WorkspaceInfo) + + if wsInfo.IPAddress == "" { + return wsInfo + } + + if conflict, err := r.validateIPAddressConflict(workspaceID, wsInfo.IPAddress); conflict || err != nil { + return nil + } - return workspaces[0].(*common.WorkspaceInfo) + return wsInfo +} + +func (r *CRDWorkspaceInfoProvider) validateIPAddressConflict(workspaceID, ipAddress string) (bool, error) { + wsInfos, err := r.workspacesInfoByIPAddress(ipAddress) + if err != nil { + log.WithError(err).WithField("workspaceID", workspaceID).WithField("ipAddress", ipAddress).Error("failed to get workspaces by IP address") + return true, err + } + + if len(wsInfos) > 1 { + log.WithField("workspaceID", workspaceID).WithField("ipAddress", ipAddress).WithField("workspaceCount", len(wsInfos)).Warn("multiple workspaces found for IP address") + return true, nil + } + + if len(wsInfos) == 1 && wsInfos[0].WorkspaceID != workspaceID { + log.WithField("workspaceID", workspaceID).WithField("ipAddress", ipAddress).WithField("foundWorkspaceID", wsInfos[0].WorkspaceID).Warn("workspace IP address conflict detected") + return true, nil + } + + return false, nil +} + +func (r *CRDWorkspaceInfoProvider) workspacesInfoByIPAddress(ipAddress string) ([]*common.WorkspaceInfo, error) { + workspaces := make([]*common.WorkspaceInfo, 0) + + objs, err := r.store.ByIndex(ipAddressIndex, ipAddress) + if err != nil { + return nil, err + } + for _, w := range objs { + workspaces = append(workspaces, w.(*common.WorkspaceInfo)) } - return nil + return workspaces, nil } func (r *CRDWorkspaceInfoProvider) AcquireContext(ctx context.Context, workspaceID string, port string) (context.Context, string, error) {