Skip to content

Commit 673ab68

Browse files
easyCZcorneliusludmann
authored andcommitted
Use context to store and populate origin
1 parent 76896ac commit 673ab68

File tree

10 files changed

+191
-20
lines changed

10 files changed

+191
-20
lines changed

components/public-api-server/pkg/auth/context.go

+3-6
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,6 @@ const (
2525
type Token struct {
2626
Type TokenType
2727
Value string
28-
// Only relevant for CookieTokenType
29-
OriginHeader string
3028
}
3129

3230
func NewAccessToken(token string) Token {
@@ -36,11 +34,10 @@ func NewAccessToken(token string) Token {
3634
}
3735
}
3836

39-
func NewCookieToken(cookie string, origin string) Token {
37+
func NewCookieToken(cookie string) Token {
4038
return Token{
41-
Type: CookieTokenType,
42-
Value: cookie,
43-
OriginHeader: origin,
39+
Type: CookieTokenType,
40+
Value: cookie,
4441
}
4542
}
4643

components/public-api-server/pkg/auth/context_test.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ func TestTokenToAndFromContext_AccessToken(t *testing.T) {
2020
}
2121

2222
func TestTokenToAndFromContext_CookieToken(t *testing.T) {
23-
token := NewCookieToken("my_token", "gitpod.io")
23+
token := NewCookieToken("my_token")
2424

2525
extracted, err := TokenFromContext(TokenToContext(context.Background(), token))
2626
require.NoError(t, err)

components/public-api-server/pkg/auth/middleware.go

+1-2
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,8 @@ func tokenFromRequest(ctx context.Context, req connect.AnyRequest) (Token, error
4141
}
4242

4343
cookie := req.Header().Get("Cookie")
44-
origin := req.Header().Get("Origin")
4544
if cookie != "" {
46-
return NewCookieToken(cookie, origin), nil
45+
return NewCookieToken(cookie), nil
4746
}
4847

4948
return Token{}, connect.NewError(connect.CodeUnauthenticated, fmt.Errorf("No access token or cookie credentials available on request."))
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
// Copyright (c) 2023 Gitpod GmbH. All rights reserved.
2+
// Licensed under the GNU Affero General Public License (AGPL).
3+
// See License.AGPL.txt in the project root for license information.
4+
5+
package origin
6+
7+
import (
8+
"context"
9+
)
10+
11+
type contextKey int
12+
13+
const (
14+
originContextKey contextKey = iota
15+
)
16+
17+
func ToContext(ctx context.Context, origin string) context.Context {
18+
return context.WithValue(ctx, originContextKey, origin)
19+
}
20+
21+
func FromContext(ctx context.Context) string {
22+
if val, ok := ctx.Value(originContextKey).(string); ok {
23+
return val
24+
}
25+
26+
return ""
27+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
// Copyright (c) 2023 Gitpod GmbH. All rights reserved.
2+
// Licensed under the GNU Affero General Public License (AGPL).
3+
// See License.AGPL.txt in the project root for license information.
4+
5+
package origin
6+
7+
import (
8+
"context"
9+
"testing"
10+
11+
"github.com/stretchr/testify/require"
12+
)
13+
14+
func TestToFromContext(t *testing.T) {
15+
require.Equal(t, "some-origin", FromContext(ToContext(context.Background(), "some-origin")), "origin stored on context is extracted")
16+
require.Equal(t, "", FromContext(context.Background()), "context without origin value returns empty")
17+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
// Copyright (c) 2023 Gitpod GmbH. All rights reserved.
2+
// Licensed under the GNU Affero General Public License (AGPL).
3+
// See License.AGPL.txt in the project root for license information.
4+
5+
package origin
6+
7+
import (
8+
"context"
9+
10+
"github.com/bufbuild/connect-go"
11+
)
12+
13+
func NewInterceptor() *Interceptor {
14+
return &Interceptor{}
15+
}
16+
17+
type Interceptor struct{}
18+
19+
func (i *Interceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc {
20+
return connect.UnaryFunc(func(ctx context.Context, req connect.AnyRequest) (connect.AnyResponse, error) {
21+
if req.Spec().IsClient {
22+
req.Header().Add("Origin", FromContext(ctx))
23+
} else {
24+
origin := req.Header().Get("Origin")
25+
ctx = ToContext(ctx, origin)
26+
}
27+
28+
return next(ctx, req)
29+
})
30+
}
31+
32+
func (a *Interceptor) WrapStreamingClient(next connect.StreamingClientFunc) connect.StreamingClientFunc {
33+
return func(ctx context.Context, s connect.Spec) connect.StreamingClientConn {
34+
conn := next(ctx, s)
35+
conn.RequestHeader().Add("Origin", FromContext(ctx))
36+
37+
return conn
38+
}
39+
}
40+
41+
func (a *Interceptor) WrapStreamingHandler(next connect.StreamingHandlerFunc) connect.StreamingHandlerFunc {
42+
return func(ctx context.Context, conn connect.StreamingHandlerConn) error {
43+
origin := conn.RequestHeader().Get("Origin")
44+
ctx = ToContext(ctx, origin)
45+
46+
return next(ctx, conn)
47+
}
48+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
// Copyright (c) 2023 Gitpod GmbH. All rights reserved.
2+
// Licensed under the GNU Affero General Public License (AGPL).
3+
// See License.AGPL.txt in the project root for license information.
4+
5+
package origin
6+
7+
import (
8+
"context"
9+
"testing"
10+
11+
"github.com/bufbuild/connect-go"
12+
"github.com/stretchr/testify/require"
13+
)
14+
15+
func TestInterceptor_Unary(t *testing.T) {
16+
requestPaylaod := "request"
17+
origin := "my-origin"
18+
19+
type response struct {
20+
origin string
21+
}
22+
23+
handler := connect.UnaryFunc(func(ctx context.Context, ar connect.AnyRequest) (connect.AnyResponse, error) {
24+
origin := FromContext(ctx)
25+
return connect.NewResponse(&response{origin: origin}), nil
26+
})
27+
28+
ctx := context.Background()
29+
request := connect.NewRequest(&requestPaylaod)
30+
request.Header().Add("Origin", origin)
31+
32+
interceptor := NewInterceptor()
33+
resp, err := interceptor.WrapUnary(handler)(ctx, request)
34+
require.NoError(t, err)
35+
require.Equal(t, &response{origin: origin}, resp.Any())
36+
}

components/public-api-server/pkg/proxy/conn.go

+23-7
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ import (
1414
"github.com/gitpod-io/gitpod/common-go/log"
1515
gitpod "github.com/gitpod-io/gitpod/gitpod-protocol"
1616
"github.com/gitpod-io/gitpod/public-api-server/pkg/auth"
17+
"github.com/gitpod-io/gitpod/public-api-server/pkg/origin"
1718
"github.com/grpc-ecosystem/go-grpc-middleware/logging/logrus/ctxlogrus"
1819

1920
lru "github.com/hashicorp/golang-lru"
@@ -41,14 +42,14 @@ func (p *NoConnectionPool) Get(ctx context.Context, token auth.Token) (gitpod.AP
4142
opts := gitpod.ConnectToServerOpts{
4243
Context: ctx,
4344
Log: logger,
45+
Origin: origin.FromContext(ctx),
4446
}
4547

4648
switch token.Type {
4749
case auth.AccessTokenType:
4850
opts.Token = token.Value
4951
case auth.CookieTokenType:
5052
opts.Cookie = token.Value
51-
opts.Origin = token.OriginHeader
5253
default:
5354
return nil, errors.New("unknown token type")
5455
}
@@ -83,11 +84,12 @@ func NewConnectionPool(address *url.URL, poolSize int) (*ConnectionPool, error)
8384

8485
return &ConnectionPool{
8586
cache: cache,
86-
connConstructor: func(token auth.Token) (gitpod.APIInterface, error) {
87+
connConstructor: func(ctx context.Context, token auth.Token) (gitpod.APIInterface, error) {
8788
opts := gitpod.ConnectToServerOpts{
8889
// We're using Background context as we want the connection to persist beyond the lifecycle of a single request
8990
Context: context.Background(),
9091
Log: log.Log,
92+
Origin: origin.FromContext(ctx),
9193
CloseHandler: func(_ error) {
9294
cache.Remove(token)
9395
connectionPoolSize.Dec()
@@ -99,7 +101,6 @@ func NewConnectionPool(address *url.URL, poolSize int) (*ConnectionPool, error)
99101
opts.Token = token.Value
100102
case auth.CookieTokenType:
101103
opts.Cookie = token.Value
102-
opts.Origin = token.OriginHeader
103104
default:
104105
return nil, errors.New("unknown token type")
105106
}
@@ -120,15 +121,23 @@ func NewConnectionPool(address *url.URL, poolSize int) (*ConnectionPool, error)
120121

121122
}
122123

124+
type conenctionPoolCacheKey struct {
125+
token auth.Token
126+
origin string
127+
}
128+
123129
type ConnectionPool struct {
124-
connConstructor func(token auth.Token) (gitpod.APIInterface, error)
130+
connConstructor func(context.Context, auth.Token) (gitpod.APIInterface, error)
125131

126132
// cache stores token to connection mapping
127133
cache *lru.Cache
128134
}
129135

130136
func (p *ConnectionPool) Get(ctx context.Context, token auth.Token) (gitpod.APIInterface, error) {
131-
cached, found := p.cache.Get(token)
137+
origin := origin.FromContext(ctx)
138+
139+
cacheKey := p.cacheKey(token, origin)
140+
cached, found := p.cache.Get(cacheKey)
132141
reportCacheOutcome(found)
133142
if found {
134143
conn, ok := cached.(*gitpod.APIoverJSONRPC)
@@ -137,17 +146,24 @@ func (p *ConnectionPool) Get(ctx context.Context, token auth.Token) (gitpod.APII
137146
}
138147
}
139148

140-
conn, err := p.connConstructor(token)
149+
conn, err := p.connConstructor(ctx, token)
141150
if err != nil {
142151
return nil, fmt.Errorf("failed to create new connection to server: %w", err)
143152
}
144153

145-
p.cache.Add(token, conn)
154+
p.cache.Add(cacheKey, conn)
146155
connectionPoolSize.Inc()
147156

148157
return conn, nil
149158
}
150159

160+
func (p *ConnectionPool) cacheKey(token auth.Token, origin string) conenctionPoolCacheKey {
161+
return conenctionPoolCacheKey{
162+
token: token,
163+
origin: origin,
164+
}
165+
}
166+
151167
func getEndpointBasedOnToken(t auth.Token, u *url.URL) (string, error) {
152168
switch t.Type {
153169
case auth.AccessTokenType:

components/public-api-server/pkg/proxy/conn_test.go

+33-4
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import (
1111

1212
gitpod "github.com/gitpod-io/gitpod/gitpod-protocol"
1313
"github.com/gitpod-io/gitpod/public-api-server/pkg/auth"
14+
"github.com/gitpod-io/gitpod/public-api-server/pkg/origin"
1415
"github.com/golang/mock/gomock"
1516
lru "github.com/hashicorp/golang-lru"
1617
"github.com/stretchr/testify/require"
@@ -25,7 +26,7 @@ func TestConnectionPool(t *testing.T) {
2526
require.NoError(t, err)
2627
pool := &ConnectionPool{
2728
cache: cache,
28-
connConstructor: func(token auth.Token) (gitpod.APIInterface, error) {
29+
connConstructor: func(ctx context.Context, token auth.Token) (gitpod.APIInterface, error) {
2930
return srv, nil
3031
},
3132
}
@@ -45,8 +46,36 @@ func TestConnectionPool(t *testing.T) {
4546
_, err = pool.Get(context.Background(), bazToken)
4647
require.NoError(t, err)
4748
require.Equal(t, 2, pool.cache.Len(), "must keep only last two connectons")
48-
require.True(t, pool.cache.Contains(barToken))
49-
require.True(t, pool.cache.Contains(bazToken))
49+
require.True(t, pool.cache.Contains(pool.cacheKey(barToken, "")))
50+
require.True(t, pool.cache.Contains(pool.cacheKey(bazToken, "")))
51+
}
52+
53+
func TestConnectionPool_ByDistinctOrigins(t *testing.T) {
54+
ctrl := gomock.NewController(t)
55+
defer ctrl.Finish()
56+
srv := gitpod.NewMockAPIInterface(ctrl)
57+
58+
cache, err := lru.New(2)
59+
require.NoError(t, err)
60+
pool := &ConnectionPool{
61+
cache: cache,
62+
connConstructor: func(ctx context.Context, token auth.Token) (gitpod.APIInterface, error) {
63+
return srv, nil
64+
},
65+
}
66+
67+
token := auth.NewAccessToken("foo")
68+
69+
ctxWithOriginA := origin.ToContext(context.Background(), "originA")
70+
ctxWithOriginB := origin.ToContext(context.Background(), "originB")
71+
72+
_, err = pool.Get(ctxWithOriginA, token)
73+
require.NoError(t, err)
74+
require.Equal(t, 1, pool.cache.Len())
75+
76+
_, err = pool.Get(ctxWithOriginB, token)
77+
require.NoError(t, err)
78+
require.Equal(t, 2, pool.cache.Len())
5079
}
5180

5281
func TestEndpointBasedOnToken(t *testing.T) {
@@ -57,7 +86,7 @@ func TestEndpointBasedOnToken(t *testing.T) {
5786
require.NoError(t, err)
5887
require.Equal(t, "wss://gitpod.io/api/v1", endpointForAccessToken)
5988

60-
endpointForCookie, err := getEndpointBasedOnToken(auth.NewCookieToken("foo", "server"), u)
89+
endpointForCookie, err := getEndpointBasedOnToken(auth.NewCookieToken("foo"), u)
6190
require.NoError(t, err)
6291
require.Equal(t, "wss://gitpod.io/api/gitpod", endpointForCookie)
6392
}

components/public-api-server/pkg/server/server.go

+2
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ import (
2626
"github.com/gitpod-io/gitpod/public-api-server/pkg/apiv1"
2727
"github.com/gitpod-io/gitpod/public-api-server/pkg/auth"
2828
"github.com/gitpod-io/gitpod/public-api-server/pkg/billingservice"
29+
"github.com/gitpod-io/gitpod/public-api-server/pkg/origin"
2930
"github.com/gitpod-io/gitpod/public-api-server/pkg/proxy"
3031
"github.com/gitpod-io/gitpod/public-api-server/pkg/webhooks"
3132
"github.com/sirupsen/logrus"
@@ -112,6 +113,7 @@ func register(srv *baseserver.Server, connPool proxy.ServerConnectionPool, expCl
112113
NewMetricsInterceptor(connectMetrics),
113114
NewLogInterceptor(log.Log),
114115
auth.NewServerInterceptor(),
116+
origin.NewInterceptor(),
115117
),
116118
}
117119

0 commit comments

Comments
 (0)