diff --git a/context_test.go b/context_test.go index 417d4a749..d71848cee 100644 --- a/context_test.go +++ b/context_test.go @@ -462,9 +462,13 @@ func TestContextCookie(t *testing.T) { func TestContextPath(t *testing.T) { e := New() - r := e.Router() + b := NewRouter(e) + + b.Add(http.MethodGet, "/users/:id", "", nil) + b.Add(http.MethodGet, "/users/:uid/files/:fid", "", nil) + + r, _ := b.Build() - r.Add(http.MethodGet, "/users/:id", nil) c := e.NewContext(nil, nil) r.Find(http.MethodGet, "/users/1", c) @@ -472,7 +476,6 @@ func TestContextPath(t *testing.T) { assert.Equal("/users/:id", c.Path()) - r.Add(http.MethodGet, "/users/:uid/files/:fid", nil) c = e.NewContext(nil, nil) r.Find(http.MethodGet, "/users/1/files/1", c) assert.Equal("/users/:uid/files/:fid", c.Path()) @@ -498,8 +501,7 @@ func TestContextPathParam(t *testing.T) { func TestContextGetAndSetParam(t *testing.T) { e := New() - r := e.Router() - r.Add(http.MethodGet, "/:foo", func(Context) error { return nil }) + e.Add(http.MethodGet, "/:foo", func(Context) error { return nil }) req := httptest.NewRequest(http.MethodGet, "/:foo", nil) c := e.NewContext(req, nil) c.SetParamNames("foo") @@ -672,17 +674,20 @@ func BenchmarkContext_Store(b *testing.B) { func TestContextHandler(t *testing.T) { e := New() - r := e.Router() - b := new(bytes.Buffer) + b := NewRouter(e) + buff := new(bytes.Buffer) - r.Add(http.MethodGet, "/handler", func(Context) error { - _, err := b.Write([]byte("handler")) + b.Add(http.MethodGet, "/handler", "", func(Context) error { + _, err := buff.Write([]byte("handler")) return err }) + + r, _ := b.Build() + c := e.NewContext(nil, nil) r.Find(http.MethodGet, "/handler", c) err := c.Handler()(c) - testify.Equal(t, "handler", b.String()) + testify.Equal(t, "handler", buff.String()) testify.NoError(t, err) } diff --git a/echo.go b/echo.go index 7f1c83998..b9bb68c6e 100644 --- a/echo.go +++ b/echo.go @@ -37,7 +37,6 @@ Learn more at https://echo.labstack.com package echo import ( - "bytes" stdContext "context" "crypto/tls" "errors" @@ -75,8 +74,10 @@ type ( premiddleware []MiddlewareFunc middleware []MiddlewareFunc maxParam *int - router *Router - routers map[string]*Router + routerBuilder RouteBuilder + routersBuilder map[string]RouteBuilder + router Router + routers map[string]Router notFoundHandler HandlerFunc pool sync.Pool Server *http.Server @@ -97,13 +98,6 @@ type ( ListenerNetwork string } - // Route contains a handler and information for matching against requests. - Route struct { - Method string `json:"method"` - Path string `json:"path"` - Name string `json:"name"` - } - // HTTPError represents an error that occurred while handling a request. HTTPError struct { Code int `json:"-"` @@ -320,8 +314,8 @@ func New() (e *Echo) { e.pool.New = func() interface{} { return e.NewContext(nil, nil) } - e.router = NewRouter(e) - e.routers = map[string]*Router{} + e.routerBuilder = NewRouter(e) + e.routersBuilder = map[string]RouteBuilder{} return } @@ -338,15 +332,33 @@ func (e *Echo) NewContext(r *http.Request, w http.ResponseWriter) Context { } // Router returns the default router. -func (e *Echo) Router() *Router { +func (e *Echo) Router() Router { return e.router } // Routers returns the map of host => router. -func (e *Echo) Routers() map[string]*Router { +func (e *Echo) Routers() map[string]Router { return e.routers } +//BuildRouters builds the internal Routers +func (e *Echo) BuildRouters() error { + var err error + if e.router, err = e.routerBuilder.Build(); err != nil { + return err + } + e.routers = make(map[string]Router) + for host, routeBuilder := range e.routersBuilder { + router, err := routeBuilder.Build() + if err == nil { + e.routers[host] = router + } else { + return err + } + } + return nil +} + // DefaultHTTPErrorHandler is the default HTTP error handler. It sends a JSON response // with status code. func (e *Echo) DefaultHTTPErrorHandler(err error, c Context) { @@ -530,17 +542,11 @@ func (e *Echo) File(path, file string, m ...MiddlewareFunc) *Route { func (e *Echo) add(host, method, path string, handler HandlerFunc, middleware ...MiddlewareFunc) *Route { name := handlerName(handler) - router := e.findRouter(host) - router.Add(method, path, func(c Context) error { + router := e.findRouterBuilder(host) + r, _ := router.Add(method, path, name, func(c Context) error { h := applyMiddleware(handler, middleware...) return h(c) }) - r := &Route{ - Method: method, - Path: path, - Name: name, - } - e.router.routes[method+path] = r return r } @@ -552,7 +558,7 @@ func (e *Echo) Add(method, path string, handler HandlerFunc, middleware ...Middl // Host creates a new router group for the provided host and optional host-level middleware. func (e *Echo) Host(name string, m ...MiddlewareFunc) (g *Group) { - e.routers[name] = NewRouter(e) + e.routersBuilder[name] = NewRouter(e) g = &Group{host: name, echo: e} g.Use(m...) return @@ -578,34 +584,12 @@ func (e *Echo) URL(h HandlerFunc, params ...interface{}) string { // Reverse generates an URL from route name and provided parameters. func (e *Echo) Reverse(name string, params ...interface{}) string { - uri := new(bytes.Buffer) - ln := len(params) - n := 0 - for _, r := range e.router.routes { - if r.Name == name { - for i, l := 0, len(r.Path); i < l; i++ { - if (r.Path[i] == ':' || r.Path[i] == '*') && n < ln { - for ; i < l && r.Path[i] != '/'; i++ { - } - uri.WriteString(fmt.Sprintf("%v", params[n])) - n++ - } - if i < l { - uri.WriteByte(r.Path[i]) - } - } - break - } - } - return uri.String() + return e.routerBuilder.Reverse(name, params...) } // Routes returns the registered routes. func (e *Echo) Routes() []*Route { - routes := make([]*Route, 0, len(e.router.routes)) - for _, v := range e.router.routes { - routes = append(routes, v) - } + routes, _ := e.routerBuilder.Routes() return routes } @@ -757,6 +741,11 @@ func (e *Echo) configureServer(s *http.Server) (err error) { e.colorer.Printf(banner, e.colorer.Red("v"+Version), e.colorer.Blue(website)) } + //Build Routers + if err := e.BuildRouters(); err != nil { + return err + } + if s.TLSConfig == nil { if e.Listener == nil { e.Listener, err = newListener(s.Addr, e.ListenerNetwork) @@ -912,7 +901,7 @@ func WrapMiddleware(m func(http.Handler) http.Handler) MiddlewareFunc { } } -func (e *Echo) findRouter(host string) *Router { +func (e *Echo) findRouter(host string) Router { if len(e.routers) > 0 { if r, ok := e.routers[host]; ok { return r @@ -921,6 +910,15 @@ func (e *Echo) findRouter(host string) *Router { return e.router } +func (e *Echo) findRouterBuilder(host string) RouteBuilder { + if len(e.routersBuilder) > 0 { + if r, ok := e.routersBuilder[host]; ok { + return r + } + } + return e.routerBuilder +} + func handlerName(h HandlerFunc) string { t := reflect.ValueOf(h).Type() if t.Kind() == reflect.Func { diff --git a/echo_test.go b/echo_test.go index 781b901fa..b2e9018f6 100644 --- a/echo_test.go +++ b/echo_test.go @@ -50,6 +50,7 @@ const userXMLPretty = ` func TestEcho(t *testing.T) { e := New() + e.BuildRouters() req := httptest.NewRequest(http.MethodGet, "/", nil) rec := httptest.NewRecorder() c := e.NewContext(req, rec) @@ -185,6 +186,7 @@ func TestEchoStatic(t *testing.T) { t.Run(tc.name, func(t *testing.T) { e := New() e.Static(tc.givenPrefix, tc.givenRoot) + e.BuildRouters() req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil) rec := httptest.NewRecorder() e.ServeHTTP(rec, req) @@ -243,6 +245,7 @@ func TestEchoStaticRedirectIndex(t *testing.T) { func TestEchoFile(t *testing.T) { e := New() e.File("/walle", "_fixture/images/walle.png") + e.BuildRouters() c, b := request(http.MethodGet, "/walle", e) assert.Equal(t, http.StatusOK, c) assert.NotEmpty(t, b) @@ -286,6 +289,8 @@ func TestEchoMiddleware(t *testing.T) { return c.String(http.StatusOK, "OK") }) + e.BuildRouters() + c, b := request(http.MethodGet, "/", e) assert.Equal(t, "-1123", buf.String()) assert.Equal(t, http.StatusOK, c) @@ -300,6 +305,7 @@ func TestEchoMiddlewareError(t *testing.T) { } }) e.GET("/", NotFoundHandler) + e.BuildRouters() c, _ := request(http.MethodGet, "/", e) assert.Equal(t, http.StatusInternalServerError, c) } @@ -312,6 +318,8 @@ func TestEchoHandler(t *testing.T) { return c.String(http.StatusOK, "OK") }) + e.BuildRouters() + c, b := request(http.MethodGet, "/ok", e) assert.Equal(t, http.StatusOK, c) assert.Equal(t, "OK", b) @@ -473,6 +481,7 @@ func TestEchoEncodedPath(t *testing.T) { e.GET("/:id", func(c Context) error { return c.NoContent(http.StatusOK) }) + e.BuildRouters() req := httptest.NewRequest(http.MethodGet, "/with%2Fslash", nil) rec := httptest.NewRecorder() e.ServeHTTP(rec, req) @@ -525,6 +534,8 @@ func TestEchoGroup(t *testing.T) { }) g3.GET("", h) + e.BuildRouters() + request(http.MethodGet, "/users", e) assert.Equal(t, "0", buf.String()) @@ -539,6 +550,7 @@ func TestEchoGroup(t *testing.T) { func TestEchoNotFound(t *testing.T) { e := New() + e.BuildRouters() req := httptest.NewRequest(http.MethodGet, "/files", nil) rec := httptest.NewRecorder() e.ServeHTTP(rec, req) @@ -550,6 +562,7 @@ func TestEchoMethodNotAllowed(t *testing.T) { e.GET("/", func(c Context) error { return c.String(http.StatusOK, "Echo!") }) + e.BuildRouters() req := httptest.NewRequest(http.MethodPost, "/", nil) rec := httptest.NewRecorder() e.ServeHTTP(rec, req) @@ -840,6 +853,7 @@ func testMethod(t *testing.T, method, path string, e *Echo) { }) i := interface{}(e) reflect.ValueOf(i).MethodByName(method).Call([]reflect.Value{p, h}) + e.BuildRouters() _, body := request(method, path, e) assert.Equal(t, method, body) } @@ -884,6 +898,7 @@ func TestDefaultHTTPErrorHandler(t *testing.T) { "error": "stackinfo", }) }) + e.BuildRouters() // With Debug=true plain response contains error message c, b := request(http.MethodGet, "/plain", e) assert.Equal(t, http.StatusInternalServerError, c) diff --git a/group_test.go b/group_test.go index c51fd91eb..db1f5cedd 100644 --- a/group_test.go +++ b/group_test.go @@ -32,6 +32,7 @@ func TestGroupFile(t *testing.T) { e := New() g := e.Group("/group") g.File("/walle", "_fixture/images/walle.png") + e.BuildRouters() expectedData, err := ioutil.ReadFile("_fixture/images/walle.png") assert.Nil(t, err) req := httptest.NewRequest(http.MethodGet, "/group/walle", nil) @@ -75,6 +76,8 @@ func TestGroupRouteMiddleware(t *testing.T) { g.GET("/404", h, m4) g.GET("/405", h, m5) + e.BuildRouters() + c, _ := request(http.MethodGet, "/group/404", e) assert.Equal(t, 404, c) c, _ = request(http.MethodGet, "/group/405", e) @@ -105,6 +108,8 @@ func TestGroupRouteMiddlewareWithMatchAny(t *testing.T) { e.GET("unrelated", h, m2) e.GET("*", h, m2) + e.BuildRouters() + _, m := request(http.MethodGet, "/group/help", e) assert.Equal(t, "/group/help", m) _, m = request(http.MethodGet, "/group/help/other", e) diff --git a/middleware/compress_test.go b/middleware/compress_test.go index d16ffca43..edaabff93 100644 --- a/middleware/compress_test.go +++ b/middleware/compress_test.go @@ -112,6 +112,7 @@ func TestGzipErrorReturned(t *testing.T) { e.GET("/", func(c echo.Context) error { return echo.ErrNotFound }) + e.BuildRouters() req := httptest.NewRequest(http.MethodGet, "/", nil) req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme) rec := httptest.NewRecorder() @@ -128,6 +129,7 @@ func TestGzipErrorReturnedInvalidConfig(t *testing.T) { c.Response().Write([]byte("test")) return nil }) + e.BuildRouters() req := httptest.NewRequest(http.MethodGet, "/", nil) req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme) rec := httptest.NewRecorder() @@ -141,6 +143,7 @@ func TestGzipWithStatic(t *testing.T) { e := echo.New() e.Use(Gzip()) e.Static("/test", "../_fixture/images") + e.BuildRouters() req := httptest.NewRequest(http.MethodGet, "/test/walle.png", nil) req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme) rec := httptest.NewRecorder() diff --git a/middleware/decompress_test.go b/middleware/decompress_test.go index 51fa6b0f1..4533e773c 100644 --- a/middleware/decompress_test.go +++ b/middleware/decompress_test.go @@ -76,6 +76,7 @@ func TestDecompressDefaultConfig(t *testing.T) { func TestCompressRequestWithoutDecompressMiddleware(t *testing.T) { e := echo.New() + e.BuildRouters() body := `{"name":"echo"}` gz, _ := gzipString(body) req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(string(gz))) @@ -112,6 +113,7 @@ func TestDecompressErrorReturned(t *testing.T) { e.GET("/", func(c echo.Context) error { return echo.ErrNotFound }) + e.BuildRouters() req := httptest.NewRequest(http.MethodGet, "/", nil) req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) rec := httptest.NewRecorder() @@ -127,6 +129,7 @@ func TestDecompressSkipper(t *testing.T) { return c.Request().URL.Path == "/skip" }, })) + e.BuildRouters() body := `{"name": "echo"}` req := httptest.NewRequest(http.MethodPost, "/skip", strings.NewReader(body)) req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) @@ -156,6 +159,7 @@ func TestDecompressPoolError(t *testing.T) { Skipper: DefaultSkipper, GzipDecompressPool: &TestDecompressPoolWithError{}, })) + e.BuildRouters() body := `{"name": "echo"}` req := httptest.NewRequest(http.MethodPost, "/echo", strings.NewReader(body)) req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) diff --git a/middleware/logger_test.go b/middleware/logger_test.go index b196bc6c8..4a82b50ae 100644 --- a/middleware/logger_test.go +++ b/middleware/logger_test.go @@ -101,6 +101,8 @@ func TestLoggerTemplate(t *testing.T) { return c.String(http.StatusOK, "Header Logged") }) + e.BuildRouters() + req := httptest.NewRequest(http.MethodGet, "/?username=apagano-param&password=secret", nil) req.RequestURI = "/" req.Header.Add(echo.HeaderXRealIP, "127.0.0.1") @@ -159,6 +161,8 @@ func TestLoggerCustomTimestamp(t *testing.T) { return c.String(http.StatusOK, "custom time stamp test") }) + e.BuildRouters() + req := httptest.NewRequest(http.MethodGet, "/", nil) rec := httptest.NewRecorder() e.ServeHTTP(rec, req) diff --git a/middleware/proxy_1_11_test.go b/middleware/proxy_1_11_test.go index 26feaabaa..23a046461 100644 --- a/middleware/proxy_1_11_test.go +++ b/middleware/proxy_1_11_test.go @@ -41,6 +41,7 @@ func TestProxy_1_11(t *testing.T) { // Random e := echo.New() e.Use(Proxy(rb)) + e.BuildRouters() req := httptest.NewRequest(http.MethodGet, "/", nil) rec := httptest.NewRecorder() diff --git a/middleware/proxy_test.go b/middleware/proxy_test.go index ec6f1925b..44461a095 100644 --- a/middleware/proxy_test.go +++ b/middleware/proxy_test.go @@ -53,6 +53,7 @@ func TestProxy(t *testing.T) { // Random e := echo.New() e.Use(Proxy(rb)) + e.BuildRouters() req := httptest.NewRequest(http.MethodGet, "/", nil) rec := httptest.NewRecorder() e.ServeHTTP(rec, req) @@ -75,6 +76,7 @@ func TestProxy(t *testing.T) { rrb := NewRoundRobinBalancer(targets) e = echo.New() e.Use(Proxy(rrb)) + e.BuildRouters() rec = httptest.NewRecorder() e.ServeHTTP(rec, req) body = rec.Body.String() @@ -94,6 +96,7 @@ func TestProxy(t *testing.T) { return nil }, })) + e.BuildRouters() rec = httptest.NewRecorder() e.ServeHTTP(rec, req) assert.Equal(t, "modified", rec.Body.String()) @@ -111,6 +114,7 @@ func TestProxy(t *testing.T) { e = echo.New() e.Use(contextObserver) e.Use(Proxy(rrb1)) + e.BuildRouters() rec = httptest.NewRecorder() e.ServeHTTP(rec, req) } @@ -123,6 +127,7 @@ func TestProxyRealIPHeader(t *testing.T) { rrb := NewRoundRobinBalancer([]*ProxyTarget{{Name: "upstream", URL: url}}) e := echo.New() e.Use(Proxy(rrb)) + e.BuildRouters() req := httptest.NewRequest(http.MethodGet, "/", nil) rec := httptest.NewRecorder() @@ -178,6 +183,7 @@ func TestProxyRewrite(t *testing.T) { "/users/*/orders/*": "/user/$1/order/$2", }, })) + e.BuildRouters() req.URL, _ = url.Parse("/api/users") rec = httptest.NewRecorder() e.ServeHTTP(rec, req) @@ -233,6 +239,8 @@ func TestProxyRewriteRegex(t *testing.T) { }, })) + e.BuildRouters() + testCases := []struct { requestPath string statusCode int @@ -247,7 +255,6 @@ func TestProxyRewriteRegex(t *testing.T) { {"/y/foo/bar", http.StatusOK, "/v5/bar/foo"}, } - for _, tc := range testCases { t.Run(tc.requestPath, func(t *testing.T) { req.URL, _ = url.Parse(tc.requestPath) diff --git a/middleware/rewrite_test.go b/middleware/rewrite_test.go index 351b7313c..eabed5f6f 100644 --- a/middleware/rewrite_test.go +++ b/middleware/rewrite_test.go @@ -23,6 +23,7 @@ func TestRewrite(t *testing.T) { "/users/*/orders/*": "/user/$1/order/$2", }, })) + e.BuildRouters() req := httptest.NewRequest(http.MethodGet, "/", nil) rec := httptest.NewRecorder() req.URL, _ = url.Parse("/api/users") @@ -52,7 +53,6 @@ func TestRewrite(t *testing.T) { // Issue #1086 func TestEchoRewritePreMiddleware(t *testing.T) { e := echo.New() - r := e.Router() // Rewrite old url to new one e.Pre(Rewrite(map[string]string{ @@ -61,10 +61,12 @@ func TestEchoRewritePreMiddleware(t *testing.T) { )) // Route - r.Add(http.MethodGet, "/new", func(c echo.Context) error { + e.Add(http.MethodGet, "/new", func(c echo.Context) error { return c.NoContent(http.StatusOK) }) + e.BuildRouters() + req := httptest.NewRequest(http.MethodGet, "/old", nil) rec := httptest.NewRecorder() e.ServeHTTP(rec, req) @@ -75,7 +77,6 @@ func TestEchoRewritePreMiddleware(t *testing.T) { // Issue #1143 func TestRewriteWithConfigPreMiddleware_Issue1143(t *testing.T) { e := echo.New() - r := e.Router() e.Pre(RewriteWithConfig(RewriteConfig{ Rules: map[string]string{ @@ -84,13 +85,15 @@ func TestRewriteWithConfigPreMiddleware_Issue1143(t *testing.T) { }, })) - r.Add(http.MethodGet, "/api/:version/hosts/:name", func(c echo.Context) error { + e.Add(http.MethodGet, "/api/:version/hosts/:name", func(c echo.Context) error { return c.String(http.StatusOK, "hosts") }) - r.Add(http.MethodGet, "/api/:version/eng", func(c echo.Context) error { + e.Add(http.MethodGet, "/api/:version/eng", func(c echo.Context) error { return c.String(http.StatusOK, "eng") }) + e.BuildRouters() + for i := 0; i < 100; i++ { req := httptest.NewRequest(http.MethodGet, "/api/v1/mgmt/proj/test/agt", nil) rec := httptest.NewRecorder() @@ -114,6 +117,8 @@ func TestEchoRewriteWithCaret(t *testing.T) { }, })) + e.BuildRouters() + rec := httptest.NewRecorder() var req *http.Request @@ -147,6 +152,8 @@ func TestEchoRewriteWithRegexRules(t *testing.T) { }, })) + e.BuildRouters() + var rec *httptest.ResponseRecorder var req *http.Request @@ -163,12 +170,12 @@ func TestEchoRewriteWithRegexRules(t *testing.T) { {"/y/foo/bar", "/v5/bar/foo"}, } - for _, tc := range testCases { - t.Run(tc.requestPath, func(t *testing.T) { - req = httptest.NewRequest(http.MethodGet, tc.requestPath, nil) - rec = httptest.NewRecorder() - e.ServeHTTP(rec, req) - assert.Equal(t, tc.expectPath, req.URL.EscapedPath()) - }) - } + for _, tc := range testCases { + t.Run(tc.requestPath, func(t *testing.T) { + req = httptest.NewRequest(http.MethodGet, tc.requestPath, nil) + rec = httptest.NewRecorder() + e.ServeHTTP(rec, req) + assert.Equal(t, tc.expectPath, req.URL.EscapedPath()) + }) + } } diff --git a/middleware/static_test.go b/middleware/static_test.go index 8c0c97ded..0b0239f97 100644 --- a/middleware/static_test.go +++ b/middleware/static_test.go @@ -117,6 +117,8 @@ func TestStatic(t *testing.T) { e.Use(middlewareFunc) } + e.BuildRouters() + req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil) rec := httptest.NewRecorder() @@ -257,6 +259,8 @@ func TestStatic_GroupWithStatic(t *testing.T) { g := e.Group(group) g.Static(tc.givenPrefix, tc.givenRoot) + e.BuildRouters() + req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil) rec := httptest.NewRecorder() e.ServeHTTP(rec, req) diff --git a/router.go b/router.go index 5010659a6..aaf31d752 100644 --- a/router.go +++ b/router.go @@ -1,14 +1,36 @@ package echo import ( + "bytes" + "errors" + "fmt" "net/http" + "reflect" "strings" ) type ( - // Router is the registry of all registered routes for an `Echo` instance for + // Route contains a information for matching against requests. + Route struct { + Method string `json:"method"` + Path string `json:"path"` + Name string `json:"name"` + } + + Router interface { + Find(method, path string, c Context) error + } + + RouteBuilder interface { + Add(method, path, name string, h HandlerFunc) (*Route, error) + Reverse(name string, params ...interface{}) string + Routes() ([]*Route, error) + Build() (Router, error) + } + + // router is the registry of all registered routes for an `Echo` instance for // request matching and URL path parameter parsing. - Router struct { + router struct { tree *node routes map[string]*Route echo *Echo @@ -51,9 +73,13 @@ const ( anyLabel = byte('*') ) +var ( + NoRouteFound = errors.New("no route found") +) + // NewRouter returns a new Router instance. -func NewRouter(e *Echo) *Router { - return &Router{ +func NewRouter(e *Echo) RouteBuilder { + return &router{ tree: &node{ methodHandler: new(methodHandler), }, @@ -62,8 +88,18 @@ func NewRouter(e *Echo) *Router { } } +func (r *router) registerRoute(method, path, name string) *Route { + route := &Route{ + Method: method, + Path: path, + Name: name, + } + r.routes[method+path] = route + return route +} + // Add registers a new route for method and path with matching handler. -func (r *Router) Add(method, path string, h HandlerFunc) { +func (r *router) Add(method, path, name string, h HandlerFunc) (*Route, error) { // Validate path if path == "" { path = "/" @@ -78,7 +114,9 @@ func (r *Router) Add(method, path string, h HandlerFunc) { if path[i] == ':' { j := i + 1 - r.insert(method, path[:i], nil, skind, "", nil) + if err := r.insert(method, path[:i], nil, skind, "", nil); err != nil { + return nil, err + } for ; i < l && path[i] != '/'; i++ { } @@ -87,21 +125,32 @@ func (r *Router) Add(method, path string, h HandlerFunc) { i, l = j, len(path) if i == l { - r.insert(method, path[:i], h, pkind, ppath, pnames) + if err := r.insert(method, path[:i], h, pkind, ppath, pnames); err != nil { + return nil, err + } } else { - r.insert(method, path[:i], nil, pkind, "", nil) + if err := r.insert(method, path[:i], nil, pkind, "", nil); err != nil { + return nil, err + } } } else if path[i] == '*' { - r.insert(method, path[:i], nil, skind, "", nil) + if err := r.insert(method, path[:i], nil, skind, "", nil); err != nil { + return nil, err + } pnames = append(pnames, "*") - r.insert(method, path[:i+1], h, akind, ppath, pnames) + if err := r.insert(method, path[:i+1], h, akind, ppath, pnames); err != nil { + return nil, err + } } } - r.insert(method, path, h, skind, ppath, pnames) + if err := r.insert(method, path, h, skind, ppath, pnames); err != nil { + return nil, err + } + return r.registerRoute(method, ppath, name), nil } -func (r *Router) insert(method, path string, h HandlerFunc, t kind, ppath string, pnames []string) { +func (r *router) insert(method, path string, h HandlerFunc, t kind, ppath string, pnames []string) error { // Adjust max param l := len(pnames) if *r.echo.maxParam < l { @@ -110,7 +159,7 @@ func (r *Router) insert(method, path string, h HandlerFunc, t kind, ppath string cn := r.tree // Current node as root if cn == nil { - panic("echo: invalid method") + errors.New("invalid root node") } search := path @@ -201,14 +250,19 @@ func (r *Router) insert(method, path string, h HandlerFunc, t kind, ppath string } else { // Node already exists if h != nil { - cn.addHandler(method, h) - cn.ppath = ppath if len(cn.pnames) == 0 { // Issue #729 cn.pnames = pnames + } else { + if !reflect.DeepEqual(cn.pnames, pnames) { + return fmt.Errorf("route params are different for %s - %s != %s", + cn.ppath, cn.pnames, pnames) + } } + cn.addHandler(method, h) + cn.ppath = ppath } } - return + return nil } } @@ -328,7 +382,7 @@ func (n *node) checkMethodNotAllowed() HandlerFunc { // - Get context from `Echo#AcquireContext()` // - Reset it `Context#Reset()` // - Return it `Echo#ReleaseContext()`. -func (r *Router) Find(method, path string, c Context) { +func (r *router) Find(method, path string, c Context) error { ctx := c.(*context) ctx.path = path cn := r.tree // Current node as root @@ -381,7 +435,7 @@ func (r *Router) Find(method, path string, c Context) { // Attempt to go back up the tree on no matching prefix or no remaining search if l != pl || search == "" { if nn == nil { // Issue #1348 - return // Not found + return NoRouteFound } cn = nn search = ns @@ -481,7 +535,7 @@ func (r *Router) Find(method, path string, c Context) { break } } - return // Not found + return NoRouteFound } @@ -496,7 +550,7 @@ func (r *Router) Find(method, path string, c Context) { // Dig further for any, might have an empty value for *, e.g. // serving a directory. Issue #207. if cn = cn.anyChildren; cn == nil { - return + return NoRouteFound } if h := cn.findHandler(method); h != nil { ctx.handler = h @@ -508,5 +562,40 @@ func (r *Router) Find(method, path string, c Context) { pvalues[len(cn.pnames)-1] = "" } - return + return nil +} + +func (r *router) Reverse(name string, params ...interface{}) string { + uri := new(bytes.Buffer) + ln := len(params) + n := 0 + for _, r := range r.routes { + if r.Name == name { + for i, l := 0, len(r.Path); i < l; i++ { + if (r.Path[i] == ':' || r.Path[i] == '*') && n < ln { + for ; i < l && r.Path[i] != '/'; i++ { + } + uri.WriteString(fmt.Sprintf("%v", params[n])) + n++ + } + if i < l { + uri.WriteByte(r.Path[i]) + } + } + break + } + } + return uri.String() +} + +func (r *router) Routes() ([]*Route, error) { + routes := make([]*Route, 0, len(r.routes)) + for _, v := range r.routes { + routes = append(routes, v) + } + return routes, nil +} + +func (r *router) Build() (Router, error) { + return r, nil } diff --git a/router_test.go b/router_test.go index a5e53c05b..4cea74c8d 100644 --- a/router_test.go +++ b/router_test.go @@ -644,13 +644,14 @@ var ( func TestRouterStatic(t *testing.T) { e := New() - r := e.router + b := NewRouter(e) path := "/folders/a/files/echo.gif" - r.Add(http.MethodGet, path, func(c Context) error { + b.Add(http.MethodGet, path, "", func(c Context) error { c.Set("path", path) return nil }) c := e.NewContext(nil, nil).(*context) + r, _ := b.Build() r.Find(http.MethodGet, path, c) c.handler(c) assert.Equal(t, path, c.Get("path")) @@ -658,23 +659,25 @@ func TestRouterStatic(t *testing.T) { func TestRouterParam(t *testing.T) { e := New() - r := e.router - r.Add(http.MethodGet, "/users/:id", func(c Context) error { + b := NewRouter(e) + b.Add(http.MethodGet, "/users/:id", "", func(c Context) error { return nil }) c := e.NewContext(nil, nil).(*context) + r, _ := b.Build() r.Find(http.MethodGet, "/users/1", c) assert.Equal(t, "1", c.Param("id")) } func TestRouterTwoParam(t *testing.T) { e := New() - r := e.router - r.Add(http.MethodGet, "/users/:uid/files/:fid", func(Context) error { + b := NewRouter(e) + b.Add(http.MethodGet, "/users/:uid/files/:fid", "", func(Context) error { return nil }) c := e.NewContext(nil, nil).(*context) + r, _ := b.Build() r.Find(http.MethodGet, "/users/1/files/1", c) assert.Equal(t, "1", c.Param("uid")) assert.Equal(t, "1", c.Param("fid")) @@ -683,17 +686,18 @@ func TestRouterTwoParam(t *testing.T) { // Issue #378 func TestRouterParamWithSlash(t *testing.T) { e := New() - r := e.router + b := NewRouter(e) - r.Add(http.MethodGet, "/a/:b/c/d/:e", func(c Context) error { + b.Add(http.MethodGet, "/a/:b/c/d/:e", "", func(c Context) error { return nil }) - r.Add(http.MethodGet, "/a/:b/c/:d/:f", func(c Context) error { + b.Add(http.MethodGet, "/a/:b/c/:d/:f", "", func(c Context) error { return nil }) c := e.NewContext(nil, nil).(*context) + r, _ := b.Build() assert.NotPanics(t, func() { r.Find(http.MethodGet, "/a/1/c/d/2/3", c) }) @@ -702,7 +706,6 @@ func TestRouterParamWithSlash(t *testing.T) { // Issue #1509 func TestRouterParamStaticConflict(t *testing.T) { e := New() - r := e.router handler := func(c Context) error { c.Set("path", c.Path()) return nil @@ -713,6 +716,9 @@ func TestRouterParamStaticConflict(t *testing.T) { g.GET("/status", handler) g.GET("/:name", handler) + e.BuildRouters() + r := e.Router() + c := e.NewContext(nil, nil).(*context) r.Find(http.MethodGet, "/g/s", c) c.handler(c) @@ -727,19 +733,21 @@ func TestRouterParamStaticConflict(t *testing.T) { func TestRouterMatchAny(t *testing.T) { e := New() - r := e.router + b := NewRouter(e) // Routes - r.Add(http.MethodGet, "/", func(Context) error { + b.Add(http.MethodGet, "/", "", func(Context) error { return nil }) - r.Add(http.MethodGet, "/*", func(Context) error { + b.Add(http.MethodGet, "/*", "", func(Context) error { return nil }) - r.Add(http.MethodGet, "/users/*", func(Context) error { + b.Add(http.MethodGet, "/users/*", "", func(Context) error { return nil }) c := e.NewContext(nil, nil).(*context) + + r, _ := b.Build() r.Find(http.MethodGet, "/", c) assert.Equal(t, "", c.Param("*")) @@ -753,18 +761,20 @@ func TestRouterMatchAny(t *testing.T) { // Issue #1739 func TestRouterMatchAnyPrefixIssue(t *testing.T) { e := New() - r := e.router + b := NewRouter(e) // Routes - r.Add(http.MethodGet, "/*", func(c Context) error { + b.Add(http.MethodGet, "/*", "", func(c Context) error { c.Set("path", c.Path()) return nil }) - r.Add(http.MethodGet, "/users/*", func(c Context) error { + b.Add(http.MethodGet, "/users/*", "", func(c Context) error { c.Set("path", c.Path()) return nil }) c := e.NewContext(nil, nil).(*context) + r, _ := b.Build() + r.Find(http.MethodGet, "/", c) c.handler(c) assert.Equal(t, "/*", c.Get("path")) @@ -795,7 +805,7 @@ func TestRouterMatchAnyPrefixIssue(t *testing.T) { // for any routes with trailing slash requests func TestRouterMatchAnySlash(t *testing.T) { e := New() - r := e.router + b := NewRouter(e) handler := func(c Context) error { c.Set("path", c.Path()) @@ -803,14 +813,17 @@ func TestRouterMatchAnySlash(t *testing.T) { } // Routes - r.Add(http.MethodGet, "/users", handler) - r.Add(http.MethodGet, "/users/*", handler) - r.Add(http.MethodGet, "/img/*", handler) - r.Add(http.MethodGet, "/img/load", handler) - r.Add(http.MethodGet, "/img/load/*", handler) - r.Add(http.MethodGet, "/assets/*", handler) + b.Add(http.MethodGet, "/users", "", handler) + b.Add(http.MethodGet, "/users/*", "", handler) + b.Add(http.MethodGet, "/img/*", "", handler) + b.Add(http.MethodGet, "/img/load", "", handler) + b.Add(http.MethodGet, "/img/load/*", "", handler) + b.Add(http.MethodGet, "/assets/*", "", handler) c := e.NewContext(nil, nil).(*context) + + r, _ := b.Build() + r.Find(http.MethodGet, "/", c) assert.Equal(t, "", c.Param("*")) @@ -865,21 +878,24 @@ func TestRouterMatchAnySlash(t *testing.T) { func TestRouterMatchAnyMultiLevel(t *testing.T) { e := New() - r := e.router + b := NewRouter(e) handler := func(c Context) error { c.Set("path", c.Path()) return nil } // Routes - r.Add(http.MethodGet, "/api/users/jack", handler) - r.Add(http.MethodGet, "/api/users/jill", handler) - r.Add(http.MethodGet, "/api/users/*", handler) - r.Add(http.MethodGet, "/api/*", handler) - r.Add(http.MethodGet, "/other/*", handler) - r.Add(http.MethodGet, "/*", handler) + b.Add(http.MethodGet, "/api/users/jack", "", handler) + b.Add(http.MethodGet, "/api/users/jill", "", handler) + b.Add(http.MethodGet, "/api/users/*", "", handler) + b.Add(http.MethodGet, "/api/*", "", handler) + b.Add(http.MethodGet, "/other/*", "", handler) + b.Add(http.MethodGet, "/*", "", handler) c := e.NewContext(nil, nil).(*context) + + r, _ := b.Build() + r.Find(http.MethodGet, "/api/users/jack", c) c.handler(c) assert.Equal(t, "/api/users/jack", c.Get("path")) @@ -912,7 +928,6 @@ func TestRouterMatchAnyMultiLevel(t *testing.T) { } func TestRouterMatchAnyMultiLevelWithPost(t *testing.T) { e := New() - r := e.router handler := func(c Context) error { c.Set("path", c.Path()) return nil @@ -924,6 +939,9 @@ func TestRouterMatchAnyMultiLevelWithPost(t *testing.T) { e.Any("/api/*", handler) e.Any("/*", handler) + e.BuildRouters() + r := e.Router() + // POST /api/auth/login shall choose login method c := e.NewContext(nil, nil).(*context) r.Find(http.MethodPost, "/api/auth/login", c) @@ -963,11 +981,12 @@ func TestRouterMatchAnyMultiLevelWithPost(t *testing.T) { func TestRouterMicroParam(t *testing.T) { e := New() - r := e.router - r.Add(http.MethodGet, "/:a/:b/:c", func(c Context) error { + b := NewRouter(e) + b.Add(http.MethodGet, "/:a/:b/:c", "", func(c Context) error { return nil }) c := e.NewContext(nil, nil).(*context) + r, _ := b.Build() r.Find(http.MethodGet, "/1/2/3", c) assert.Equal(t, "1", c.Param("a")) assert.Equal(t, "2", c.Param("b")) @@ -976,13 +995,14 @@ func TestRouterMicroParam(t *testing.T) { func TestRouterMixParamMatchAny(t *testing.T) { e := New() - r := e.router + b := NewRouter(e) // Route - r.Add(http.MethodGet, "/users/:id/*", func(c Context) error { + b.Add(http.MethodGet, "/users/:id/*", "", func(c Context) error { return nil }) c := e.NewContext(nil, nil).(*context) + r, _ := b.Build() r.Find(http.MethodGet, "/users/joe/comments", c) c.handler(c) @@ -991,17 +1011,18 @@ func TestRouterMixParamMatchAny(t *testing.T) { func TestRouterMultiRoute(t *testing.T) { e := New() - r := e.router + b := NewRouter(e) // Routes - r.Add(http.MethodGet, "/users", func(c Context) error { + b.Add(http.MethodGet, "/users", "", func(c Context) error { c.Set("path", "/users") return nil }) - r.Add(http.MethodGet, "/users/:id", func(c Context) error { + b.Add(http.MethodGet, "/users/:id", "", func(c Context) error { return nil }) c := e.NewContext(nil, nil).(*context) + r, _ := b.Build() // Route > /users r.Find(http.MethodGet, "/users", c) @@ -1021,18 +1042,20 @@ func TestRouterMultiRoute(t *testing.T) { func TestRouterPriority(t *testing.T) { e := New() - r := e.router + b := NewRouter(e) // Routes - r.Add(http.MethodGet, "/users", handlerHelper("a", 1)) - r.Add(http.MethodGet, "/users/new", handlerHelper("b", 2)) - r.Add(http.MethodGet, "/users/:id", handlerHelper("c", 3)) - r.Add(http.MethodGet, "/users/dew", handlerHelper("d", 4)) - r.Add(http.MethodGet, "/users/:id/files", handlerHelper("e", 5)) - r.Add(http.MethodGet, "/users/newsee", handlerHelper("f", 6)) - r.Add(http.MethodGet, "/users/*", handlerHelper("g", 7)) - r.Add(http.MethodGet, "/users/new/*", handlerHelper("h", 8)) - r.Add(http.MethodGet, "/*", handlerHelper("i", 9)) + b.Add(http.MethodGet, "/users", "", handlerHelper("a", 1)) + b.Add(http.MethodGet, "/users/new", "", handlerHelper("b", 2)) + b.Add(http.MethodGet, "/users/:id", "", handlerHelper("c", 3)) + b.Add(http.MethodGet, "/users/dew", "", handlerHelper("d", 4)) + b.Add(http.MethodGet, "/users/:id/files", "", handlerHelper("e", 5)) + b.Add(http.MethodGet, "/users/newsee", "", handlerHelper("f", 6)) + b.Add(http.MethodGet, "/users/*", "", handlerHelper("g", 7)) + b.Add(http.MethodGet, "/users/new/*", "", handlerHelper("h", 8)) + b.Add(http.MethodGet, "/*", "", handlerHelper("i", 9)) + + r, _ := b.Build() // Route > /users c := e.NewContext(nil, nil).(*context) @@ -1145,12 +1168,12 @@ func TestRouterPriority(t *testing.T) { func TestRouterIssue1348(t *testing.T) { e := New() - r := e.router + b := NewRouter(e) - r.Add(http.MethodGet, "/:lang/", func(c Context) error { + b.Add(http.MethodGet, "/:lang/", "", func(c Context) error { return nil }) - r.Add(http.MethodGet, "/:lang/dupa", func(c Context) error { + b.Add(http.MethodGet, "/:lang/dupa", "", func(c Context) error { return nil }) } @@ -1158,19 +1181,21 @@ func TestRouterIssue1348(t *testing.T) { // Issue #372 func TestRouterPriorityNotFound(t *testing.T) { e := New() - r := e.router + b := NewRouter(e) c := e.NewContext(nil, nil).(*context) // Add - r.Add(http.MethodGet, "/a/foo", func(c Context) error { + b.Add(http.MethodGet, "/a/foo", "", func(c Context) error { c.Set("a", 1) return nil }) - r.Add(http.MethodGet, "/a/bar", func(c Context) error { + b.Add(http.MethodGet, "/a/bar", "", func(c Context) error { c.Set("b", 2) return nil }) + r, _ := b.Build() + // Find r.Find(http.MethodGet, "/a/foo", c) c.handler(c) @@ -1188,20 +1213,21 @@ func TestRouterPriorityNotFound(t *testing.T) { func TestRouterParamNames(t *testing.T) { e := New() - r := e.router + b := NewRouter(e) // Routes - r.Add(http.MethodGet, "/users", func(c Context) error { + b.Add(http.MethodGet, "/users", "", func(c Context) error { c.Set("path", "/users") return nil }) - r.Add(http.MethodGet, "/users/:id", func(c Context) error { + b.Add(http.MethodGet, "/users/:id", "", func(c Context) error { return nil }) - r.Add(http.MethodGet, "/users/:uid/files/:fid", func(c Context) error { + b.Add(http.MethodGet, "/users/:uid/files/:fid", "", func(c Context) error { return nil }) c := e.NewContext(nil, nil).(*context) + r, _ := b.Build() // Route > /users r.Find(http.MethodGet, "/users", c) @@ -1224,14 +1250,16 @@ func TestRouterParamNames(t *testing.T) { // Issue #623 and #1406 func TestRouterStaticDynamicConflict(t *testing.T) { e := New() - r := e.router + b := NewRouter(e) + + b.Add(http.MethodGet, "/dictionary/skills", "", handlerHelper("a", 1)) + b.Add(http.MethodGet, "/dictionary/:name", "", handlerHelper("b", 2)) + b.Add(http.MethodGet, "/users/new", "", handlerHelper("d", 4)) + b.Add(http.MethodGet, "/users/:name", "", handlerHelper("e", 5)) + b.Add(http.MethodGet, "/server", "", handlerHelper("c", 3)) + b.Add(http.MethodGet, "/", "", handlerHelper("f", 6)) - r.Add(http.MethodGet, "/dictionary/skills", handlerHelper("a", 1)) - r.Add(http.MethodGet, "/dictionary/:name", handlerHelper("b", 2)) - r.Add(http.MethodGet, "/users/new", handlerHelper("d", 4)) - r.Add(http.MethodGet, "/users/:name", handlerHelper("e", 5)) - r.Add(http.MethodGet, "/server", handlerHelper("c", 3)) - r.Add(http.MethodGet, "/", handlerHelper("f", 6)) + r, _ := b.Build() c := e.NewContext(nil, nil) r.Find(http.MethodGet, "/dictionary/skills", c) @@ -1279,24 +1307,26 @@ func TestRouterStaticDynamicConflict(t *testing.T) { // Issue #1348 func TestRouterParamBacktraceNotFound(t *testing.T) { e := New() - r := e.router + b := NewRouter(e) // Add - r.Add(http.MethodGet, "/:param1", func(c Context) error { + b.Add(http.MethodGet, "/:param1", "", func(c Context) error { return nil }) - r.Add(http.MethodGet, "/:param1/foo", func(c Context) error { + b.Add(http.MethodGet, "/:param1/foo", "", func(c Context) error { return nil }) - r.Add(http.MethodGet, "/:param1/bar", func(c Context) error { + b.Add(http.MethodGet, "/:param1/bar", "", func(c Context) error { return nil }) - r.Add(http.MethodGet, "/:param1/bar/:param2", func(c Context) error { + b.Add(http.MethodGet, "/:param1/bar/:param2", "", func(c Context) error { return nil }) c := e.NewContext(nil, nil).(*context) + r, _ := b.Build() + //Find r.Find(http.MethodGet, "/a", c) assert.Equal(t, "a", c.Param("param1")) @@ -1322,14 +1352,15 @@ func TestRouterParamBacktraceNotFound(t *testing.T) { func testRouterAPI(t *testing.T, api []*Route) { e := New() - r := e.router + b := NewRouter(e) for _, route := range api { - r.Add(route.Method, route.Path, func(c Context) error { + b.Add(route.Method, route.Path, "", func(c Context) error { return nil }) } c := e.NewContext(nil, nil).(*context) + r, _ := b.Build() for _, route := range api { r.Find(route.Method, route.Path, c) tokens := strings.Split(route.Path[1:], "/") @@ -1394,39 +1425,41 @@ func TestRouterMixedParams(t *testing.T) { // Issue #1466 func TestRouterParam1466(t *testing.T) { e := New() - r := e.router + b := NewRouter(e) - r.Add(http.MethodPost, "/users/signup", func(c Context) error { + b.Add(http.MethodPost, "/users/signup", "", func(c Context) error { return nil }) - r.Add(http.MethodPost, "/users/signup/bulk", func(c Context) error { + b.Add(http.MethodPost, "/users/signup/bulk", "", func(c Context) error { return nil }) - r.Add(http.MethodPost, "/users/survey", func(c Context) error { + b.Add(http.MethodPost, "/users/survey", "", func(c Context) error { return nil }) - r.Add(http.MethodGet, "/users/:username", func(c Context) error { + b.Add(http.MethodGet, "/users/:username", "", func(c Context) error { return nil }) - r.Add(http.MethodGet, "/interests/:name/users", func(c Context) error { + b.Add(http.MethodGet, "/interests/:name/users", "", func(c Context) error { return nil }) - r.Add(http.MethodGet, "/skills/:name/users", func(c Context) error { + b.Add(http.MethodGet, "/skills/:name/users", "", func(c Context) error { return nil }) // Additional routes for Issue 1479 - r.Add(http.MethodGet, "/users/:username/likes/projects/ids", func(c Context) error { + b.Add(http.MethodGet, "/users/:username/likes/projects/ids", "", func(c Context) error { return nil }) - r.Add(http.MethodGet, "/users/:username/profile", func(c Context) error { + b.Add(http.MethodGet, "/users/:username/profile", "", func(c Context) error { return nil }) - r.Add(http.MethodGet, "/users/:username/uploads/:type", func(c Context) error { + b.Add(http.MethodGet, "/users/:username/uploads/:type", "", func(c Context) error { return nil }) c := e.NewContext(nil, nil).(*context) + r, _ := b.Build() + r.Find(http.MethodGet, "/users/ajitem", c) assert.Equal(t, "ajitem", c.Param("username")) @@ -1474,7 +1507,6 @@ func TestRouterParam1466(t *testing.T) { // Issue #1655 func TestRouterFindNotPanicOrLoopsWhenContextSetParamValuesIsCalledWithLessValuesThanEchoMaxParam(t *testing.T) { e := New() - r := e.router v0 := e.Group("/:version") v0.GET("/admin", func(c Context) error { @@ -1487,6 +1519,9 @@ func TestRouterFindNotPanicOrLoopsWhenContextSetParamValuesIsCalledWithLessValue v0.GET("/images/:id", handlerHelper("i", 1)) v0.GET("/view/*", handlerHelper("v", 1)) + e.BuildRouters() + r := e.Router() + //If this API is called before the next two one panic the other loops ( of course without my fix ;) ) c := e.NewContext(nil, nil) r.Find(http.MethodGet, "/v1/admin", c) @@ -1511,16 +1546,18 @@ func TestRouterFindNotPanicOrLoopsWhenContextSetParamValuesIsCalledWithLessValue // Issue #1653 func TestRouterPanicWhenParamNoRootOnlyChildsFailsFind(t *testing.T) { e := New() - r := e.router + b := NewRouter(e) - r.Add(http.MethodGet, "/users/create", handlerHelper("create", 1)) - r.Add(http.MethodGet, "/users/:id/edit", func(c Context) error { + b.Add(http.MethodGet, "/users/create", "", handlerHelper("create", 1)) + b.Add(http.MethodGet, "/users/:id/edit", "", func(c Context) error { return nil }) - r.Add(http.MethodGet, "/users/:id/active", func(c Context) error { + b.Add(http.MethodGet, "/users/:id/active", "", func(c Context) error { return nil }) + r, _ := b.Build() + c := e.NewContext(nil, nil).(*context) r.Find(http.MethodGet, "/users/alice/edit", c) assert.Equal(t, "alice", c.Param("id")) @@ -1542,18 +1579,77 @@ func TestRouterPanicWhenParamNoRootOnlyChildsFailsFind(t *testing.T) { assert.Equal(t, http.StatusNotFound, he.Code) } +// Issue #1726 +func TestRouterParamDifferentNamesPerHTTPMethod(t *testing.T) { + e := New() + b := NewRouter(e) + + b.Add(http.MethodGet, "/translation/:lang", "", func(Context) error { + return nil + }) + b.Add(http.MethodPut, "/translation/:lang", "", func(Context) error { + return nil + }) + + c := e.NewContext(nil, nil).(*context) + + r, _ := b.Build() + + r.Find(http.MethodGet, "/translation/en-US", c) + assert.Equal(t, "en-US", c.Param("lang")) + + r.Find(http.MethodPut, "/translation/es-AR", c) + assert.Equal(t, "es-AR", c.Param("lang")) + + _, err := b.Add(http.MethodDelete, "/translation/:id", "", func(Context) error { + return nil + }) + assert.Error(t, err) +} + +func TestRouterReverse(t *testing.T) { + assert := assert.New(t) + + b := NewRouter(New()) + dummyHandler := func(Context) error { return nil } + + route, _ := b.Add(http.MethodGet, "/static", "", dummyHandler) + route.Name = "/static" + + route, _ = b.Add(http.MethodGet, "/static/*", "", dummyHandler) + route.Name = "/static/*" + + route, _ = b.Add(http.MethodGet, "/params/:foo", "/params/:foo", dummyHandler) + route, _ = b.Add(http.MethodGet, "/params/:foo/bar/:qux", "/params/:foo/bar/:qux", dummyHandler) + route, _ = b.Add(http.MethodGet, "/params/:foo/bar/:qux/*", "/params/:foo/bar/:qux/*", dummyHandler) + + assert.Equal("/static", b.Reverse("/static")) + assert.Equal("/static", b.Reverse("/static", "missing param")) + assert.Equal("/static/*", b.Reverse("/static/*")) + assert.Equal("/static/foo.txt", b.Reverse("/static/*", "foo.txt")) + + assert.Equal("/params/:foo", b.Reverse("/params/:foo")) + assert.Equal("/params/one", b.Reverse("/params/:foo", "one")) + assert.Equal("/params/:foo/bar/:qux", b.Reverse("/params/:foo/bar/:qux")) + assert.Equal("/params/one/bar/:qux", b.Reverse("/params/:foo/bar/:qux", "one")) + assert.Equal("/params/one/bar/two", b.Reverse("/params/:foo/bar/:qux", "one", "two")) + assert.Equal("/params/one/bar/two/three", b.Reverse("/params/:foo/bar/:qux/*", "one", "two", "three")) +} + func benchmarkRouterRoutes(b *testing.B, routes []*Route, routesToFind []*Route) { e := New() - r := e.router + builder := NewRouter(e) b.ReportAllocs() // Add routes for _, route := range routes { - r.Add(route.Method, route.Path, func(c Context) error { + builder.Add(route.Method, route.Path, "", func(c Context) error { return nil }) } + r, _ := builder.Build() + // Routes adding are performed just once, so it doesn't make sense to see that in the benchmark b.ResetTimer()