Skip to content

Commit 6a3b779

Browse files
slessardsl255051
andauthored
Fix inconsistent processing of server variables in gorillamux router (#705)
Co-authored-by: Steve Lessard <[email protected]>
1 parent 6cbc1b0 commit 6a3b779

File tree

2 files changed

+196
-27
lines changed

2 files changed

+196
-27
lines changed

routers/gorillamux/router.go

Lines changed: 34 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,6 @@ func NewRouter(doc *openapi3.T) (routers.Router, error) {
5757
muxRouter := mux.NewRouter().UseEncodedPath()
5858
r := &Router{}
5959
for _, path := range orderedPaths(doc.Paths) {
60-
servers := servers
61-
6260
pathItem := doc.Paths[path]
6361
if len(pathItem.Servers) > 0 {
6462
if servers, err = makeServers(pathItem.Servers); err != nil {
@@ -140,19 +138,13 @@ func makeServers(in openapi3.Servers) ([]srv, error) {
140138
if lhs := strings.TrimSuffix(serverURL, server.Variables[sVar].Default); lhs != "" {
141139
varsUpdater = func(vars map[string]string) { vars[sVar] = lhs }
142140
}
143-
servers = append(servers, srv{
144-
base: server.Variables[sVar].Default,
145-
server: server,
146-
varsUpdater: varsUpdater,
147-
})
148-
continue
149-
}
141+
svr, err := newSrv(serverURL, server, varsUpdater)
142+
if err != nil {
143+
return nil, err
144+
}
150145

151-
var schemes []string
152-
if strings.Contains(serverURL, "://") {
153-
scheme0 := strings.Split(serverURL, "://")[0]
154-
schemes = permutePart(scheme0, server)
155-
serverURL = strings.Replace(serverURL, scheme0+"://", schemes[0]+"://", 1)
146+
servers = append(servers, svr)
147+
continue
156148
}
157149

158150
// If a variable represents the port "http://domain.tld:{port}/bla"
@@ -172,21 +164,11 @@ func makeServers(in openapi3.Servers) ([]srv, error) {
172164
}
173165
}
174166

175-
u, err := url.Parse(bEncode(serverURL))
167+
svr, err := newSrv(serverURL, server, varsUpdater)
176168
if err != nil {
177169
return nil, err
178170
}
179-
path := bDecode(u.EscapedPath())
180-
if len(path) > 0 && path[len(path)-1] == '/' {
181-
path = path[:len(path)-1]
182-
}
183-
servers = append(servers, srv{
184-
host: bDecode(u.Host), //u.Hostname()?
185-
base: path,
186-
schemes: schemes, // scheme: []string{scheme0}, TODO: https://github.com/gorilla/mux/issues/624
187-
server: server,
188-
varsUpdater: varsUpdater,
189-
})
171+
servers = append(servers, svr)
190172
}
191173
if len(servers) == 0 {
192174
servers = append(servers, srv{})
@@ -195,6 +177,32 @@ func makeServers(in openapi3.Servers) ([]srv, error) {
195177
return servers, nil
196178
}
197179

180+
func newSrv(serverURL string, server *openapi3.Server, varsUpdater varsf) (srv, error) {
181+
var schemes []string
182+
if strings.Contains(serverURL, "://") {
183+
scheme0 := strings.Split(serverURL, "://")[0]
184+
schemes = permutePart(scheme0, server)
185+
serverURL = strings.Replace(serverURL, scheme0+"://", schemes[0]+"://", 1)
186+
}
187+
188+
u, err := url.Parse(bEncode(serverURL))
189+
if err != nil {
190+
return srv{}, err
191+
}
192+
path := bDecode(u.EscapedPath())
193+
if len(path) > 0 && path[len(path)-1] == '/' {
194+
path = path[:len(path)-1]
195+
}
196+
svr := srv{
197+
host: bDecode(u.Host), //u.Hostname()?
198+
base: path,
199+
schemes: schemes, // scheme: []string{scheme0}, TODO: https://github.com/gorilla/mux/issues/624
200+
server: server,
201+
varsUpdater: varsUpdater,
202+
}
203+
return svr, nil
204+
}
205+
198206
func orderedPaths(paths map[string]*openapi3.PathItem) []string {
199207
// https://github.com/OAI/OpenAPI-Specification/blob/master/versions/3.0.3.md#pathsObject
200208
// When matching URLs, concrete (non-templated) paths would be matched

routers/gorillamux/router_test.go

Lines changed: 162 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"sort"
77
"testing"
88

9+
"github.com/stretchr/testify/assert"
910
"github.com/stretchr/testify/require"
1011

1112
"github.com/getkin/kin-openapi/openapi3"
@@ -249,7 +250,16 @@ func TestServerPath(t *testing.T) {
249250
"http://example.com:{port}/path",
250251
map[string]string{
251252
"port": "8088",
252-
})},
253+
}),
254+
newServerWithVariables(
255+
"{server}",
256+
map[string]string{
257+
"server": "/",
258+
}),
259+
newServerWithVariables(
260+
"/",
261+
nil,
262+
)},
253263
})
254264
require.NoError(t, err)
255265
}
@@ -325,6 +335,157 @@ func TestRelativeURL(t *testing.T) {
325335
require.Equal(t, "/hello", route.Path)
326336
}
327337

338+
func Test_makeServers(t *testing.T) {
339+
type testStruct struct {
340+
name string
341+
servers openapi3.Servers
342+
want []srv
343+
wantErr bool
344+
initFn func(tt *testStruct)
345+
}
346+
tests := []testStruct{
347+
{
348+
name: "server is root path",
349+
servers: openapi3.Servers{
350+
newServerWithVariables("/", nil),
351+
},
352+
want: []srv{{
353+
schemes: nil,
354+
host: "",
355+
base: "",
356+
server: nil,
357+
varsUpdater: nil,
358+
}},
359+
wantErr: false,
360+
initFn: func(tt *testStruct) {
361+
for i, server := range tt.servers {
362+
tt.want[i].server = server
363+
}
364+
},
365+
},
366+
{
367+
name: "server with single variable that evaluates to root path",
368+
servers: openapi3.Servers{
369+
newServerWithVariables("{server}", map[string]string{"server": "/"}),
370+
},
371+
want: []srv{{
372+
schemes: nil,
373+
host: "",
374+
base: "",
375+
server: nil,
376+
varsUpdater: nil,
377+
}},
378+
wantErr: false,
379+
initFn: func(tt *testStruct) {
380+
for i, server := range tt.servers {
381+
tt.want[i].server = server
382+
}
383+
},
384+
},
385+
{
386+
name: "server is http://localhost:28002",
387+
servers: openapi3.Servers{
388+
newServerWithVariables("http://localhost:28002", nil),
389+
},
390+
want: []srv{{
391+
schemes: []string{"http"},
392+
host: "localhost:28002",
393+
base: "",
394+
server: nil,
395+
varsUpdater: nil,
396+
}},
397+
wantErr: false,
398+
initFn: func(tt *testStruct) {
399+
for i, server := range tt.servers {
400+
tt.want[i].server = server
401+
}
402+
},
403+
},
404+
{
405+
name: "server with single variable that evaluates to http://localhost:28002",
406+
servers: openapi3.Servers{
407+
newServerWithVariables("{server}", map[string]string{"server": "http://localhost:28002"}),
408+
},
409+
want: []srv{{
410+
schemes: []string{"http"},
411+
host: "localhost:28002",
412+
base: "",
413+
server: nil,
414+
varsUpdater: nil,
415+
}},
416+
wantErr: false,
417+
initFn: func(tt *testStruct) {
418+
for i, server := range tt.servers {
419+
tt.want[i].server = server
420+
}
421+
},
422+
},
423+
{
424+
name: "server with multiple variables that evaluates to http://localhost:28002",
425+
servers: openapi3.Servers{
426+
newServerWithVariables("{scheme}://{host}:{port}", map[string]string{"scheme": "http", "host": "localhost", "port": "28002"}),
427+
},
428+
want: []srv{{
429+
schemes: []string{"http"},
430+
host: "{host}:28002",
431+
base: "",
432+
server: nil,
433+
varsUpdater: func(vars map[string]string) { vars["port"] = "28002" },
434+
}},
435+
wantErr: false,
436+
initFn: func(tt *testStruct) {
437+
for i, server := range tt.servers {
438+
tt.want[i].server = server
439+
}
440+
},
441+
},
442+
{
443+
name: "server with unparsable URL fails",
444+
servers: openapi3.Servers{
445+
newServerWithVariables("exam^ple.com:443", nil),
446+
},
447+
want: nil,
448+
wantErr: true,
449+
initFn: nil,
450+
},
451+
{
452+
name: "server with single variable that evaluates to unparsable URL fails",
453+
servers: openapi3.Servers{
454+
newServerWithVariables("{server}", map[string]string{"server": "exam^ple.com:443"}),
455+
},
456+
want: nil,
457+
wantErr: true,
458+
initFn: nil,
459+
},
460+
}
461+
for _, tt := range tests {
462+
t.Run(tt.name, func(t *testing.T) {
463+
if tt.initFn != nil {
464+
tt.initFn(&tt)
465+
}
466+
got, err := makeServers(tt.servers)
467+
if (err != nil) != tt.wantErr {
468+
t.Errorf("makeServers() error = %v, wantErr %v", err, tt.wantErr)
469+
return
470+
}
471+
assert.Equal(t, len(tt.want), len(got), "expected and actual servers lengths are not equal")
472+
for i := 0; i < len(tt.want); i++ {
473+
// Unfortunately using assert.Equals or reflect.DeepEquals isn't
474+
// an option because function pointers cannot be compared
475+
assert.Equal(t, tt.want[i].schemes, got[i].schemes)
476+
assert.Equal(t, tt.want[i].host, got[i].host)
477+
assert.Equal(t, tt.want[i].host, got[i].host)
478+
assert.Equal(t, tt.want[i].server, got[i].server)
479+
if tt.want[i].varsUpdater == nil {
480+
assert.Nil(t, got[i].varsUpdater, "expected and actual varsUpdater should point to same function")
481+
} else {
482+
assert.NotNil(t, got[i].varsUpdater, "expected and actual varsUpdater should point to same function")
483+
}
484+
}
485+
})
486+
}
487+
}
488+
328489
func newServerWithVariables(url string, variables map[string]string) *openapi3.Server {
329490
var serverVariables = map[string]*openapi3.ServerVariable{}
330491

0 commit comments

Comments
 (0)