diff --git a/core/request.go b/core/request.go index 6e2e941..bd4102e 100644 --- a/core/request.go +++ b/core/request.go @@ -35,6 +35,16 @@ const ( // API Gateway stage variables. To access the stage variable values // use the GetAPIGatewayStageVars method of the RequestAccessor object. APIGwStageVarsHeader = "X-GoLambdaProxy-ApiGw-StageVars" + + // APIGwPathParamVarsHeader is the custom header key used to store the + // API Gateway path param variables. To access the path param variable values + // use the GetAPIGatewayPathParamVars method of the RequestAccessor object. + APIGwPathParamVarsHeader = "X-GoLambdaProxy-ApiGw-PathParamVars" + + // APIGwQueryStringVarsHeader is the custom header key used to store the + // API Gateway query string variables. To access the query string param variable values + // use the GetAPIGatewayQueryStringParamVars method of the RequestAccessor object. + APIGwQueryStringVarsHeader = "X-GoLambdaProxy-ApiGw-QueryStringParamVars" ) // RequestAccessor objects give access to custom API Gateway properties @@ -79,6 +89,42 @@ func (r *RequestAccessor) GetAPIGatewayStageVars(req *http.Request) (map[string] return stageVars, nil } +// GetAPIGatewayPathParamVars extracts the API Gateway path param variables from a +// request's custom header. +// Returns a map[string]string of the path param variables and their values from +// the request. +func (r *RequestAccessor) GetAPIGatewayPathParamVars(req *http.Request) (map[string]string, error) { + pathVars := make(map[string]string) + if req.Header.Get(APIGwPathParamVarsHeader) == "" { + return pathVars, errors.New("No path param vars header in request") + } + err := json.Unmarshal([]byte(req.Header.Get(APIGwPathParamVarsHeader)), &pathVars) + if err != nil { + log.Println("Error while unmarshalling stage variables") + log.Println(err) + return pathVars, err + } + return pathVars, nil +} + +// GetAPIGatewayQueryStringParamVars extracts the API Gateway query string param variables from a +// request's custom header. +// Returns a map[string]string of the query string param variables and their values from +// the request. +func (r *RequestAccessor) GetAPIGatewayQueryStringParamVars(req *http.Request) (map[string]string, error) { + pathVars := make(map[string]string) + if req.Header.Get(APIGwQueryStringVarsHeader) == "" { + return pathVars, errors.New("No query string vars header in request") + } + err := json.Unmarshal([]byte(req.Header.Get(APIGwQueryStringVarsHeader)), &pathVars) + if err != nil { + log.Println("Error while unmarshalling query string param variables") + log.Println(err) + return pathVars, err + } + return pathVars, nil +} + // StripBasePath instructs the RequestAccessor object that the given base // path should be removed from the request path before sending it to the // framework for routing. This is used when API Gateway is configured with @@ -214,6 +260,21 @@ func addToHeader(req *http.Request, apiGwRequest events.APIGatewayProxyRequest) return nil, err } req.Header.Set(APIGwStageVarsHeader, string(stageVars)) + + pathParamsVars, err := json.Marshal(apiGwRequest.PathParameters) + if err != nil { + log.Println("Could not marshal path param variables for custom header") + return nil, err + } + req.Header.Set(APIGwPathParamVarsHeader, string(pathParamsVars)) + + queryStringParamsVars, err := json.Marshal(apiGwRequest.QueryStringParameters) + if err != nil { + log.Println("Could not marshal query string param variables for custom header") + return nil, err + } + req.Header.Set(APIGwQueryStringVarsHeader, string(queryStringParamsVars)) + apiGwContext, err := json.Marshal(apiGwRequest.RequestContext) if err != nil { log.Println("Could not Marshal API GW context for custom header") @@ -225,7 +286,13 @@ func addToHeader(req *http.Request, apiGwRequest events.APIGatewayProxyRequest) func addToContext(ctx context.Context, req *http.Request, apiGwRequest events.APIGatewayProxyRequest) *http.Request { lc, _ := lambdacontext.FromContext(ctx) - rc := requestContext{lambdaContext: lc, gatewayProxyContext: apiGwRequest.RequestContext, stageVars: apiGwRequest.StageVariables} + rc := requestContext{ + lambdaContext: lc, + gatewayProxyContext: apiGwRequest.RequestContext, + stageVars: apiGwRequest.StageVariables, + pathParamVars: apiGwRequest.PathParameters, + queryStringParams: apiGwRequest.QueryStringParameters, + } ctx = context.WithValue(ctx, ctxKey{}, rc) return req.WithContext(ctx) } @@ -248,10 +315,24 @@ func GetStageVarsFromContext(ctx context.Context) (map[string]string, bool) { return v.stageVars, ok } +// GetPathParamVarsFromContext retrieve path param variables from context +func GetPathParamVarsFromContext(ctx context.Context) (map[string]string, bool) { + v, ok := ctx.Value(ctxKey{}).(requestContext) + return v.pathParamVars, ok +} + +// GetQueryStringParamsVarsFromContext retrieve query string param variables from context +func GetQueryStringParamsVarsFromContext(ctx context.Context) (map[string]string, bool) { + v, ok := ctx.Value(ctxKey{}).(requestContext) + return v.queryStringParams, ok +} + type ctxKey struct{} type requestContext struct { lambdaContext *lambdacontext.LambdaContext gatewayProxyContext events.APIGatewayProxyRequestContext stageVars map[string]string + pathParamVars map[string]string + queryStringParams map[string]string } diff --git a/core/request_test.go b/core/request_test.go index b91563b..aa7d233 100644 --- a/core/request_test.go +++ b/core/request_test.go @@ -10,6 +10,7 @@ import ( "github.com/aws/aws-lambda-go/events" "github.com/aws/aws-lambda-go/lambdacontext" + "github.com/awslabs/aws-lambda-go-api-proxy/core" . "github.com/onsi/ginkgo" @@ -174,7 +175,7 @@ var _ = Describe("RequestAccessor tests", func() { // calling old method to verify reverse compatibility httpReq, err := accessor.ProxyEventToHTTPRequest(contextRequest) Expect(err).To(BeNil()) - Expect(2).To(Equal(len(httpReq.Header))) + Expect(4).To(Equal(len(httpReq.Header))) Expect(httpReq.Header.Get(core.APIGwContextHeader)).ToNot(BeNil()) }) }) @@ -304,6 +305,53 @@ var _ = Describe("RequestAccessor tests", func() { Expect("value2").To(Equal(stageVars["var2"])) }) + It("Populates query and path param variables correctly", func() { + varsRequest := getProxyRequest("orders", "GET") + varsRequest.PathParameters = getPathParamVariables() + varsRequest.QueryStringParameters = getQueryStringParamVariables() + + accessor := core.RequestAccessor{} + httpReq, err := accessor.ProxyEventToHTTPRequest(varsRequest) + Expect(err).To(BeNil()) + + pathVars, err := accessor.GetAPIGatewayPathParamVars(httpReq) + Expect(err).To(BeNil()) + Expect(2).To(Equal(len(pathVars))) + Expect(pathVars["var1"]).ToNot(BeNil()) + Expect(pathVars["var2"]).ToNot(BeNil()) + Expect("value1").To(Equal(pathVars["var1"])) + Expect("value2").To(Equal(pathVars["var2"])) + + // overwrite existing pathvars header + varsRequestWithHeaders := getProxyRequest("orders", "GET") + varsRequestWithHeaders.PathParameters = getPathParamVariables() + varsRequestWithHeaders.Headers = map[string]string{core.APIGwPathParamVarsHeader: `{"var1":"abc123"}`} + httpReq, err = accessor.ProxyEventToHTTPRequest(varsRequestWithHeaders) + Expect(err).To(BeNil()) + pathVars, err = accessor.GetAPIGatewayPathParamVars(httpReq) + Expect(err).To(BeNil()) + Expect(pathVars["var1"]).To(Equal("value1")) + + pathVars, ok := core.GetPathParamVarsFromContext(httpReq.Context()) + // not present in context + Expect(ok).To(BeFalse()) + + httpReq, err = accessor.EventToRequestWithContext(context.Background(), varsRequest) + Expect(err).To(BeNil()) + + pathVars, err = accessor.GetAPIGatewayPathParamVars(httpReq) + // should not be in headers + Expect(err).ToNot(BeNil()) + + pathVars, ok = core.GetPathParamVarsFromContext(httpReq.Context()) + Expect(ok).To(BeTrue()) + Expect(2).To(Equal(len(pathVars))) + Expect(pathVars["var1"]).ToNot(BeNil()) + Expect(pathVars["var2"]).ToNot(BeNil()) + Expect("value1").To(Equal(pathVars["var1"])) + Expect("value2").To(Equal(pathVars["var2"])) + }) + It("Populates the default hostname correctly", func() { basicRequest := getProxyRequest("orders", "GET") @@ -367,3 +415,17 @@ func getStageVariables() map[string]string { "var2": "value2", } } + +func getPathParamVariables() map[string]string { + return map[string]string{ + "var1": "value1", + "var2": "value2", + } +} + +func getQueryStringParamVariables() map[string]string { + return map[string]string{ + "var1": "value1", + "var2": "value2", + } +} diff --git a/core/requestv2.go b/core/requestv2.go index e3c8d56..96d6051 100644 --- a/core/requestv2.go +++ b/core/requestv2.go @@ -123,7 +123,7 @@ func (r *RequestAccessorV2) EventToRequest(req events.APIGatewayV2HTTPRequest) ( path := req.RawPath - //if RawPath empty is, populate from request context + // if RawPath empty is, populate from request context if len(path) == 0 { path = req.RequestContext.HTTP.Path } @@ -186,6 +186,21 @@ func addToHeaderV2(req *http.Request, apiGwRequest events.APIGatewayV2HTTPReques return nil, err } req.Header.Add(APIGwStageVarsHeader, string(stageVars)) + + pathParamVars, err := json.Marshal(apiGwRequest.PathParameters) + if err != nil { + log.Println("Could not marshal path params variables for custom header") + return nil, err + } + req.Header.Add(APIGwPathParamVarsHeader, string(pathParamVars)) + + queryStringParamVars, err := json.Marshal(apiGwRequest.QueryStringParameters) + if err != nil { + log.Println("Could not marshal query string params variables for custom header") + return nil, err + } + req.Header.Add(APIGwQueryStringVarsHeader, string(queryStringParamVars)) + apiGwContext, err := json.Marshal(apiGwRequest.RequestContext) if err != nil { log.Println("Could not Marshal API GW context for custom header") @@ -197,7 +212,13 @@ func addToHeaderV2(req *http.Request, apiGwRequest events.APIGatewayV2HTTPReques func addToContextV2(ctx context.Context, req *http.Request, apiGwRequest events.APIGatewayV2HTTPRequest) *http.Request { lc, _ := lambdacontext.FromContext(ctx) - rc := requestContextV2{lambdaContext: lc, gatewayProxyContext: apiGwRequest.RequestContext, stageVars: apiGwRequest.StageVariables} + rc := requestContextV2{ + lambdaContext: lc, + gatewayProxyContext: apiGwRequest.RequestContext, + stageVars: apiGwRequest.StageVariables, + queryStringParamVars: apiGwRequest.QueryStringParameters, + pathParamVars: apiGwRequest.PathParameters, + } ctx = context.WithValue(ctx, ctxKey{}, rc) return req.WithContext(ctx) } @@ -221,7 +242,9 @@ func GetStageVarsFromContextV2(ctx context.Context) (map[string]string, bool) { } type requestContextV2 struct { - lambdaContext *lambdacontext.LambdaContext - gatewayProxyContext events.APIGatewayV2HTTPRequestContext - stageVars map[string]string + lambdaContext *lambdacontext.LambdaContext + gatewayProxyContext events.APIGatewayV2HTTPRequestContext + stageVars map[string]string + pathParamVars map[string]string + queryStringParamVars map[string]string } diff --git a/core/requestv2_test.go b/core/requestv2_test.go index e42370d..d1b1601 100644 --- a/core/requestv2_test.go +++ b/core/requestv2_test.go @@ -3,14 +3,16 @@ package core_test import ( "context" "encoding/base64" - "github.com/onsi/gomega/gstruct" "io/ioutil" "math/rand" "os" "strings" + "github.com/onsi/gomega/gstruct" + "github.com/aws/aws-lambda-go/events" "github.com/aws/aws-lambda-go/lambdacontext" + "github.com/awslabs/aws-lambda-go-api-proxy/core" . "github.com/onsi/ginkgo" @@ -173,7 +175,7 @@ var _ = Describe("RequestAccessorV2 tests", func() { // calling old method to verify reverse compatibility httpReq, err := accessor.ProxyEventToHTTPRequest(contextRequest) Expect(err).To(BeNil()) - Expect(2).To(Equal(len(httpReq.Header))) + Expect(4).To(Equal(len(httpReq.Header))) Expect(httpReq.Header.Get(core.APIGwContextHeader)).ToNot(BeNil()) }) }) @@ -283,6 +285,98 @@ var _ = Describe("RequestAccessorV2 tests", func() { Expect("value2").To(Equal(stageVars["var2"])) }) + It("Populates path variables correctly", func() { + varsRequest := getProxyRequest("orders", "GET") + varsRequest.PathParameters = getPathParamVariables() + + accessor := core.RequestAccessor{} + httpReq, err := accessor.ProxyEventToHTTPRequest(varsRequest) + Expect(err).To(BeNil()) + + pathVars, err := accessor.GetAPIGatewayPathParamVars(httpReq) + Expect(err).To(BeNil()) + Expect(2).To(Equal(len(pathVars))) + Expect(pathVars["var1"]).ToNot(BeNil()) + Expect(pathVars["var2"]).ToNot(BeNil()) + Expect("value1").To(Equal(pathVars["var1"])) + Expect("value2").To(Equal(pathVars["var2"])) + + // overwrite existing pathvars header + varsRequestWithHeaders := getProxyRequest("orders", "GET") + varsRequestWithHeaders.PathParameters = getPathParamVariables() + varsRequestWithHeaders.Headers = map[string]string{core.APIGwPathParamVarsHeader: `{"var1":"abc123"}`} + httpReq, err = accessor.ProxyEventToHTTPRequest(varsRequestWithHeaders) + Expect(err).To(BeNil()) + pathVars, err = accessor.GetAPIGatewayPathParamVars(httpReq) + Expect(err).To(BeNil()) + Expect(pathVars["var1"]).To(Equal("value1")) + + pathVars, ok := core.GetPathParamVarsFromContext(httpReq.Context()) + // not present in context + Expect(ok).To(BeFalse()) + + httpReq, err = accessor.EventToRequestWithContext(context.Background(), varsRequest) + Expect(err).To(BeNil()) + + pathVars, err = accessor.GetAPIGatewayPathParamVars(httpReq) + // should not be in headers + Expect(err).ToNot(BeNil()) + + pathVars, ok = core.GetPathParamVarsFromContext(httpReq.Context()) + Expect(ok).To(BeTrue()) + Expect(2).To(Equal(len(pathVars))) + Expect(pathVars["var1"]).ToNot(BeNil()) + Expect(pathVars["var2"]).ToNot(BeNil()) + Expect("value1").To(Equal(pathVars["var1"])) + Expect("value2").To(Equal(pathVars["var2"])) + }) + + It("Populates query string param variables correctly", func() { + varsRequest := getProxyRequest("orders", "GET") + varsRequest.QueryStringParameters = getQueryStringParamVariables() + + accessor := core.RequestAccessor{} + httpReq, err := accessor.ProxyEventToHTTPRequest(varsRequest) + Expect(err).To(BeNil()) + + queryStringVars, err := accessor.GetAPIGatewayQueryStringParamVars(httpReq) + Expect(err).To(BeNil()) + Expect(2).To(Equal(len(queryStringVars))) + Expect(queryStringVars["var1"]).ToNot(BeNil()) + Expect(queryStringVars["var2"]).ToNot(BeNil()) + Expect("value1").To(Equal(queryStringVars["var1"])) + Expect("value2").To(Equal(queryStringVars["var2"])) + + // overwrite existing query string param vars header + varsRequestWithHeaders := getProxyRequest("orders", "GET") + varsRequestWithHeaders.QueryStringParameters = getQueryStringParamVariables() + varsRequestWithHeaders.Headers = map[string]string{core.APIGwQueryStringVarsHeader: `{"var1":"abc123"}`} + httpReq, err = accessor.ProxyEventToHTTPRequest(varsRequestWithHeaders) + Expect(err).To(BeNil()) + queryStringVars, err = accessor.GetAPIGatewayQueryStringParamVars(httpReq) + Expect(err).To(BeNil()) + Expect(queryStringVars["var1"]).To(Equal("value1")) + + queryStringVars, ok := core.GetQueryStringParamsVarsFromContext(httpReq.Context()) + // not present in context + Expect(ok).To(BeFalse()) + + httpReq, err = accessor.EventToRequestWithContext(context.Background(), varsRequest) + Expect(err).To(BeNil()) + + queryStringVars, err = accessor.GetAPIGatewayQueryStringParamVars(httpReq) + // should not be in headers + Expect(err).ToNot(BeNil()) + + queryStringVars, ok = core.GetQueryStringParamsVarsFromContext(httpReq.Context()) + Expect(ok).To(BeTrue()) + Expect(2).To(Equal(len(queryStringVars))) + Expect(queryStringVars["var1"]).ToNot(BeNil()) + Expect(queryStringVars["var2"]).ToNot(BeNil()) + Expect("value1").To(Equal(queryStringVars["var1"])) + Expect("value2").To(Equal(queryStringVars["var2"])) + }) + It("Populates the default hostname correctly", func() { basicRequest := getProxyRequestV2("orders", "GET")