|
3 | 3 | package lambda
|
4 | 4 |
|
5 | 5 | import (
|
| 6 | + "context" |
6 | 7 | "encoding/json"
|
7 | 8 | "fmt"
|
8 | 9 | "log"
|
| 10 | + "os" |
9 | 11 | "strconv"
|
10 | 12 | "time"
|
11 | 13 |
|
12 | 14 | "github.com/aws/aws-lambda-go/lambda/messages"
|
| 15 | + "github.com/aws/aws-lambda-go/lambdacontext" |
13 | 16 | )
|
14 | 17 |
|
15 | 18 | const (
|
16 | 19 | msPerS = int64(time.Second / time.Millisecond)
|
17 | 20 | nsPerMS = int64(time.Millisecond / time.Nanosecond)
|
18 | 21 | )
|
19 | 22 |
|
| 23 | +// TODO: replace with time.UnixMillis after dropping version <1.17 from CI workflows |
| 24 | +func unixMS(ms int64) time.Time { |
| 25 | + return time.Unix(ms/msPerS, (ms%msPerS)*nsPerMS) |
| 26 | +} |
| 27 | + |
20 | 28 | // startRuntimeAPILoop will return an error if handling a particular invoke resulted in a non-recoverable error
|
21 | 29 | func startRuntimeAPILoop(api string, handler Handler) error {
|
22 | 30 | client := newRuntimeAPIClient(api)
|
23 |
| - function := NewFunction(handler) |
| 31 | + h := newHandler(handler) |
24 | 32 | for {
|
25 | 33 | invoke, err := client.next()
|
26 | 34 | if err != nil {
|
27 | 35 | return err
|
28 | 36 | }
|
29 |
| - |
30 |
| - err = handleInvoke(invoke, function) |
31 |
| - if err != nil { |
| 37 | + if err = handleInvoke(invoke, h); err != nil { |
32 | 38 | return err
|
33 | 39 | }
|
34 | 40 | }
|
35 | 41 | }
|
36 | 42 |
|
37 | 43 | // handleInvoke returns an error if the function panics, or some other non-recoverable error occurred
|
38 |
| -func handleInvoke(invoke *invoke, function *Function) error { |
39 |
| - functionRequest, err := convertInvokeRequest(invoke) |
| 44 | +func handleInvoke(invoke *invoke, handler *handlerOptions) error { |
| 45 | + // set the deadline |
| 46 | + deadline, err := parseDeadline(invoke) |
40 | 47 | if err != nil {
|
41 |
| - return fmt.Errorf("unexpected error occurred when parsing the invoke: %v", err) |
| 48 | + return reportFailure(invoke, lambdaErrorResponse(err)) |
42 | 49 | }
|
| 50 | + ctx, cancel := context.WithDeadline(handler.baseContext, deadline) |
| 51 | + defer cancel() |
43 | 52 |
|
44 |
| - functionResponse := &messages.InvokeResponse{} |
45 |
| - if err := function.Invoke(functionRequest, functionResponse); err != nil { |
46 |
| - return fmt.Errorf("unexpected error occurred when invoking the handler: %v", err) |
| 53 | + // set the invoke metadata values |
| 54 | + lc := lambdacontext.LambdaContext{ |
| 55 | + AwsRequestID: invoke.id, |
| 56 | + InvokedFunctionArn: invoke.headers.Get(headerInvokedFunctionARN), |
47 | 57 | }
|
48 |
| - |
49 |
| - if functionResponse.Error != nil { |
50 |
| - errorPayload := safeMarshal(functionResponse.Error) |
51 |
| - log.Printf("%s", errorPayload) |
52 |
| - if err := invoke.failure(errorPayload, contentTypeJSON); err != nil { |
53 |
| - return fmt.Errorf("unexpected error occurred when sending the function error to the API: %v", err) |
| 58 | + if err := parseClientContext(invoke, &lc.ClientContext); err != nil { |
| 59 | + return reportFailure(invoke, lambdaErrorResponse(err)) |
| 60 | + } |
| 61 | + if err := parseCognitoIdentity(invoke, &lc.Identity); err != nil { |
| 62 | + return reportFailure(invoke, lambdaErrorResponse(err)) |
| 63 | + } |
| 64 | + ctx = lambdacontext.NewContext(ctx, &lc) |
| 65 | + |
| 66 | + // set the trace id |
| 67 | + traceID := invoke.headers.Get(headerTraceID) |
| 68 | + os.Setenv("_X_AMZN_TRACE_ID", traceID) |
| 69 | + // nolint:staticcheck |
| 70 | + ctx = context.WithValue(ctx, "x-amzn-trace-id", traceID) |
| 71 | + |
| 72 | + // call the handler, marshal any returned error |
| 73 | + response, invokeErr := callBytesHandlerFunc(ctx, invoke.payload, handler.Handler.Invoke) |
| 74 | + if invokeErr != nil { |
| 75 | + if err := reportFailure(invoke, invokeErr); err != nil { |
| 76 | + return err |
54 | 77 | }
|
55 |
| - if functionResponse.Error.ShouldExit { |
| 78 | + if invokeErr.ShouldExit { |
56 | 79 | return fmt.Errorf("calling the handler function resulted in a panic, the process should exit")
|
57 | 80 | }
|
58 | 81 | return nil
|
59 | 82 | }
|
60 |
| - |
61 |
| - if err := invoke.success(functionResponse.Payload, contentTypeJSON); err != nil { |
| 83 | + if err := invoke.success(response, contentTypeJSON); err != nil { |
62 | 84 | return fmt.Errorf("unexpected error occurred when sending the function functionResponse to the API: %v", err)
|
63 | 85 | }
|
64 | 86 |
|
65 | 87 | return nil
|
66 | 88 | }
|
67 | 89 |
|
68 |
| -// convertInvokeRequest converts an invoke from the Runtime API, and unpacks it to be compatible with the shape of a `lambda.Function` InvokeRequest. |
69 |
| -func convertInvokeRequest(invoke *invoke) (*messages.InvokeRequest, error) { |
70 |
| - deadlineEpochMS, err := strconv.ParseInt(invoke.headers.Get(headerDeadlineMS), 10, 64) |
71 |
| - if err != nil { |
72 |
| - return nil, fmt.Errorf("failed to parse contents of header: %s", headerDeadlineMS) |
| 90 | +func reportFailure(invoke *invoke, invokeErr *messages.InvokeResponse_Error) error { |
| 91 | + errorPayload := safeMarshal(invokeErr) |
| 92 | + log.Printf("%s", errorPayload) |
| 93 | + if err := invoke.failure(errorPayload, contentTypeJSON); err != nil { |
| 94 | + return fmt.Errorf("unexpected error occurred when sending the function error to the API: %v", err) |
73 | 95 | }
|
74 |
| - deadlineS := deadlineEpochMS / msPerS |
75 |
| - deadlineNS := (deadlineEpochMS % msPerS) * nsPerMS |
| 96 | + return nil |
| 97 | +} |
76 | 98 |
|
77 |
| - res := &messages.InvokeRequest{ |
78 |
| - InvokedFunctionArn: invoke.headers.Get(headerInvokedFunctionARN), |
79 |
| - XAmznTraceId: invoke.headers.Get(headerTraceID), |
80 |
| - RequestId: invoke.id, |
81 |
| - Deadline: messages.InvokeRequest_Timestamp{ |
82 |
| - Seconds: deadlineS, |
83 |
| - Nanos: deadlineNS, |
84 |
| - }, |
85 |
| - Payload: invoke.payload, |
| 99 | +func callBytesHandlerFunc(ctx context.Context, payload []byte, handler bytesHandlerFunc) (response []byte, invokeErr *messages.InvokeResponse_Error) { |
| 100 | + defer func() { |
| 101 | + if err := recover(); err != nil { |
| 102 | + invokeErr = lambdaPanicResponse(err) |
| 103 | + } |
| 104 | + }() |
| 105 | + response, err := handler(ctx, payload) |
| 106 | + if err != nil { |
| 107 | + return nil, lambdaErrorResponse(err) |
86 | 108 | }
|
| 109 | + return response, nil |
| 110 | +} |
87 | 111 |
|
88 |
| - clientContextJSON := invoke.headers.Get(headerClientContext) |
89 |
| - if clientContextJSON != "" { |
90 |
| - res.ClientContext = []byte(clientContextJSON) |
| 112 | +func parseDeadline(invoke *invoke) (time.Time, error) { |
| 113 | + deadlineEpochMS, err := strconv.ParseInt(invoke.headers.Get(headerDeadlineMS), 10, 64) |
| 114 | + if err != nil { |
| 115 | + return time.Time{}, fmt.Errorf("failed to parse deadline: %v", err) |
91 | 116 | }
|
| 117 | + return unixMS(deadlineEpochMS), nil |
| 118 | +} |
92 | 119 |
|
| 120 | +func parseCognitoIdentity(invoke *invoke, out *lambdacontext.CognitoIdentity) error { |
93 | 121 | cognitoIdentityJSON := invoke.headers.Get(headerCognitoIdentity)
|
94 | 122 | if cognitoIdentityJSON != "" {
|
95 |
| - if err := json.Unmarshal([]byte(invoke.headers.Get(headerCognitoIdentity)), res); err != nil { |
96 |
| - return nil, fmt.Errorf("failed to unmarshal cognito identity json: %v", err) |
| 123 | + if err := json.Unmarshal([]byte(cognitoIdentityJSON), out); err != nil { |
| 124 | + return fmt.Errorf("failed to unmarshal cognito identity json: %v", err) |
97 | 125 | }
|
98 | 126 | }
|
| 127 | + return nil |
| 128 | +} |
99 | 129 |
|
100 |
| - return res, nil |
| 130 | +func parseClientContext(invoke *invoke, out *lambdacontext.ClientContext) error { |
| 131 | + clientContextJSON := invoke.headers.Get(headerClientContext) |
| 132 | + if clientContextJSON != "" { |
| 133 | + if err := json.Unmarshal([]byte(clientContextJSON), out); err != nil { |
| 134 | + return fmt.Errorf("failed to unmarshal client context json: %v", err) |
| 135 | + } |
| 136 | + } |
| 137 | + return nil |
101 | 138 | }
|
102 | 139 |
|
103 | 140 | func safeMarshal(v interface{}) []byte {
|
|
0 commit comments