Skip to content

Commit 8bc331d

Browse files
authored
support the lambda.norpc tag on the go1.x runtime (#456)
* support lambda.norpc tag on go1.x and rewrite invoke_loop.go to remove use of Function RPC type * restore pre-1.17 unix ms conversions * more test coverage * add quick benchmark comparision to show throughput improvement of norpc build flag * more * update bench.sh
1 parent d8bb932 commit 8bc331d

13 files changed

+414
-178
lines changed

lambda/entry.go

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ package lambda
44

55
import (
66
"context"
7-
"errors"
87
"log"
98
"os"
109
)
@@ -70,20 +69,11 @@ type startFunction struct {
7069
}
7170

7271
var (
73-
// This allows users to save a little bit of coldstart time in the download, by the dependencies brought in for RPC support.
74-
// The tradeoff is dropping compatibility with the go1.x runtime, functions must be "Custom Runtime" instead.
75-
// To drop the rpc dependencies, compile with `-tags lambda.norpc`
76-
rpcStartFunction = &startFunction{
77-
env: "_LAMBDA_SERVER_PORT",
78-
f: func(_ string, _ Handler) error {
79-
return errors.New("_LAMBDA_SERVER_PORT was present but the function was compiled without RPC support")
80-
},
81-
}
8272
runtimeAPIStartFunction = &startFunction{
8373
env: "AWS_LAMBDA_RUNTIME_API",
8474
f: startRuntimeAPILoop,
8575
}
86-
startFunctions = []*startFunction{rpcStartFunction, runtimeAPIStartFunction}
76+
startFunctions = []*startFunction{runtimeAPIStartFunction}
8777

8878
// This allows end to end testing of the Start functions, by tests overwriting this function to keep the program alive
8979
logFatalf = log.Fatalf

lambda/entry_test.go

Lines changed: 0 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,11 @@ package lambda
44

55
import (
66
"context"
7-
"fmt"
87
"log"
9-
"net"
10-
"net/rpc"
118
"os"
129
"strings"
1310
"testing"
1411

15-
"github.com/aws/aws-lambda-go/lambda/messages"
1612
"github.com/stretchr/testify/assert"
1713
)
1814

@@ -35,58 +31,3 @@ func TestStartRuntimeAPIWithContext(t *testing.T) {
3531

3632
assert.Equal(t, expected, actual)
3733
}
38-
39-
func TestStartRPCWithContext(t *testing.T) {
40-
expected := "expected"
41-
actual := "unexpected"
42-
port := getFreeTCPPort()
43-
os.Setenv("_LAMBDA_SERVER_PORT", fmt.Sprintf("%d", port))
44-
defer os.Unsetenv("_LAMBDA_SERVER_PORT")
45-
go StartWithContext(context.WithValue(context.Background(), ctxTestKey{}, expected), func(ctx context.Context) error {
46-
actual, _ = ctx.Value(ctxTestKey{}).(string)
47-
return nil
48-
})
49-
50-
var client *rpc.Client
51-
var pingResponse messages.PingResponse
52-
var invokeResponse messages.InvokeResponse
53-
var err error
54-
for {
55-
client, err = rpc.Dial("tcp", fmt.Sprintf("localhost:%d", port))
56-
if err != nil {
57-
continue
58-
}
59-
break
60-
}
61-
for {
62-
if err := client.Call("Function.Ping", &messages.PingRequest{}, &pingResponse); err != nil {
63-
continue
64-
}
65-
break
66-
}
67-
if err := client.Call("Function.Invoke", &messages.InvokeRequest{}, &invokeResponse); err != nil {
68-
t.Logf("error invoking function: %v", err)
69-
}
70-
71-
assert.Equal(t, expected, actual)
72-
}
73-
74-
func getFreeTCPPort() int {
75-
l, err := net.Listen("tcp", "localhost:0")
76-
if err != nil {
77-
log.Fatal("getFreeTCPPort failed: ", err)
78-
}
79-
defer l.Close()
80-
81-
return l.Addr().(*net.TCPAddr).Port
82-
}
83-
84-
func TestStartNotInLambda(t *testing.T) {
85-
actual := "unexpected"
86-
logFatalf = func(format string, v ...interface{}) {
87-
actual = fmt.Sprintf(format, v...)
88-
}
89-
90-
Start(func() error { return nil })
91-
assert.Equal(t, "expected AWS Lambda environment variables [_LAMBDA_SERVER_PORT AWS_LAMBDA_RUNTIME_API] are not defined", actual)
92-
}

lambda/entry_with_no_rpc_test.go

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
// Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
3+
//go:build lambda.norpc
4+
// +build lambda.norpc
5+
6+
package lambda
7+
8+
import (
9+
"fmt"
10+
"testing"
11+
12+
"github.com/stretchr/testify/assert"
13+
)
14+
15+
func TestStartNotInLambda(t *testing.T) {
16+
actual := "unexpected"
17+
logFatalf = func(format string, v ...interface{}) {
18+
actual = fmt.Sprintf(format, v...)
19+
}
20+
21+
Start(func() error { return nil })
22+
assert.Equal(t, "expected AWS Lambda environment variables [AWS_LAMBDA_RUNTIME_API] are not defined", actual)
23+
}

lambda/entry_with_rpc_test.go

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
// Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved
2+
3+
//go:build !lambda.norpc
4+
// +build !lambda.norpc
5+
6+
package lambda
7+
8+
import (
9+
"context"
10+
"fmt"
11+
"log"
12+
"net"
13+
"net/rpc"
14+
"os"
15+
"testing"
16+
17+
"github.com/aws/aws-lambda-go/lambda/messages"
18+
"github.com/stretchr/testify/assert"
19+
)
20+
21+
func TestStartRPCWithContext(t *testing.T) {
22+
expected := "expected"
23+
actual := "unexpected"
24+
port := getFreeTCPPort()
25+
os.Setenv("_LAMBDA_SERVER_PORT", fmt.Sprintf("%d", port))
26+
defer os.Unsetenv("_LAMBDA_SERVER_PORT")
27+
go StartWithContext(context.WithValue(context.Background(), ctxTestKey{}, expected), func(ctx context.Context) error {
28+
actual, _ = ctx.Value(ctxTestKey{}).(string)
29+
return nil
30+
})
31+
32+
var client *rpc.Client
33+
var pingResponse messages.PingResponse
34+
var invokeResponse messages.InvokeResponse
35+
var err error
36+
for {
37+
client, err = rpc.Dial("tcp", fmt.Sprintf("localhost:%d", port))
38+
if err != nil {
39+
continue
40+
}
41+
break
42+
}
43+
for {
44+
if err := client.Call("Function.Ping", &messages.PingRequest{}, &pingResponse); err != nil {
45+
continue
46+
}
47+
break
48+
}
49+
if err := client.Call("Function.Invoke", &messages.InvokeRequest{}, &invokeResponse); err != nil {
50+
t.Logf("error invoking function: %v", err)
51+
}
52+
53+
assert.Equal(t, expected, actual)
54+
}
55+
56+
func getFreeTCPPort() int {
57+
l, err := net.Listen("tcp", "localhost:0")
58+
if err != nil {
59+
log.Fatal("getFreeTCPPort failed: ", err)
60+
}
61+
defer l.Close()
62+
63+
return l.Addr().(*net.TCPAddr).Port
64+
}
65+
66+
func TestStartNotInLambda(t *testing.T) {
67+
actual := "unexpected"
68+
logFatalf = func(format string, v ...interface{}) {
69+
actual = fmt.Sprintf(format, v...)
70+
}
71+
72+
Start(func() error { return nil })
73+
assert.Equal(t, "expected AWS Lambda environment variables [_LAMBDA_SERVER_PORT AWS_LAMBDA_RUNTIME_API] are not defined", actual)
74+
}

lambda/invoke_loop.go

Lines changed: 78 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -3,101 +3,138 @@
33
package lambda
44

55
import (
6+
"context"
67
"encoding/json"
78
"fmt"
89
"log"
10+
"os"
911
"strconv"
1012
"time"
1113

1214
"github.com/aws/aws-lambda-go/lambda/messages"
15+
"github.com/aws/aws-lambda-go/lambdacontext"
1316
)
1417

1518
const (
1619
msPerS = int64(time.Second / time.Millisecond)
1720
nsPerMS = int64(time.Millisecond / time.Nanosecond)
1821
)
1922

23+
// TODO: replace with time.UnixMillis after dropping version <1.17 from CI workflows
24+
func unixMS(ms int64) time.Time {
25+
return time.Unix(ms/msPerS, (ms%msPerS)*nsPerMS)
26+
}
27+
2028
// startRuntimeAPILoop will return an error if handling a particular invoke resulted in a non-recoverable error
2129
func startRuntimeAPILoop(api string, handler Handler) error {
2230
client := newRuntimeAPIClient(api)
23-
function := NewFunction(handler)
31+
h := newHandler(handler)
2432
for {
2533
invoke, err := client.next()
2634
if err != nil {
2735
return err
2836
}
29-
30-
err = handleInvoke(invoke, function)
31-
if err != nil {
37+
if err = handleInvoke(invoke, h); err != nil {
3238
return err
3339
}
3440
}
3541
}
3642

3743
// handleInvoke returns an error if the function panics, or some other non-recoverable error occurred
38-
func handleInvoke(invoke *invoke, function *Function) error {
39-
functionRequest, err := convertInvokeRequest(invoke)
44+
func handleInvoke(invoke *invoke, handler *handlerOptions) error {
45+
// set the deadline
46+
deadline, err := parseDeadline(invoke)
4047
if err != nil {
41-
return fmt.Errorf("unexpected error occurred when parsing the invoke: %v", err)
48+
return reportFailure(invoke, lambdaErrorResponse(err))
4249
}
50+
ctx, cancel := context.WithDeadline(handler.baseContext, deadline)
51+
defer cancel()
4352

44-
functionResponse := &messages.InvokeResponse{}
45-
if err := function.Invoke(functionRequest, functionResponse); err != nil {
46-
return fmt.Errorf("unexpected error occurred when invoking the handler: %v", err)
53+
// set the invoke metadata values
54+
lc := lambdacontext.LambdaContext{
55+
AwsRequestID: invoke.id,
56+
InvokedFunctionArn: invoke.headers.Get(headerInvokedFunctionARN),
4757
}
48-
49-
if functionResponse.Error != nil {
50-
errorPayload := safeMarshal(functionResponse.Error)
51-
log.Printf("%s", errorPayload)
52-
if err := invoke.failure(errorPayload, contentTypeJSON); err != nil {
53-
return fmt.Errorf("unexpected error occurred when sending the function error to the API: %v", err)
58+
if err := parseClientContext(invoke, &lc.ClientContext); err != nil {
59+
return reportFailure(invoke, lambdaErrorResponse(err))
60+
}
61+
if err := parseCognitoIdentity(invoke, &lc.Identity); err != nil {
62+
return reportFailure(invoke, lambdaErrorResponse(err))
63+
}
64+
ctx = lambdacontext.NewContext(ctx, &lc)
65+
66+
// set the trace id
67+
traceID := invoke.headers.Get(headerTraceID)
68+
os.Setenv("_X_AMZN_TRACE_ID", traceID)
69+
// nolint:staticcheck
70+
ctx = context.WithValue(ctx, "x-amzn-trace-id", traceID)
71+
72+
// call the handler, marshal any returned error
73+
response, invokeErr := callBytesHandlerFunc(ctx, invoke.payload, handler.Handler.Invoke)
74+
if invokeErr != nil {
75+
if err := reportFailure(invoke, invokeErr); err != nil {
76+
return err
5477
}
55-
if functionResponse.Error.ShouldExit {
78+
if invokeErr.ShouldExit {
5679
return fmt.Errorf("calling the handler function resulted in a panic, the process should exit")
5780
}
5881
return nil
5982
}
60-
61-
if err := invoke.success(functionResponse.Payload, contentTypeJSON); err != nil {
83+
if err := invoke.success(response, contentTypeJSON); err != nil {
6284
return fmt.Errorf("unexpected error occurred when sending the function functionResponse to the API: %v", err)
6385
}
6486

6587
return nil
6688
}
6789

68-
// convertInvokeRequest converts an invoke from the Runtime API, and unpacks it to be compatible with the shape of a `lambda.Function` InvokeRequest.
69-
func convertInvokeRequest(invoke *invoke) (*messages.InvokeRequest, error) {
70-
deadlineEpochMS, err := strconv.ParseInt(invoke.headers.Get(headerDeadlineMS), 10, 64)
71-
if err != nil {
72-
return nil, fmt.Errorf("failed to parse contents of header: %s", headerDeadlineMS)
90+
func reportFailure(invoke *invoke, invokeErr *messages.InvokeResponse_Error) error {
91+
errorPayload := safeMarshal(invokeErr)
92+
log.Printf("%s", errorPayload)
93+
if err := invoke.failure(errorPayload, contentTypeJSON); err != nil {
94+
return fmt.Errorf("unexpected error occurred when sending the function error to the API: %v", err)
7395
}
74-
deadlineS := deadlineEpochMS / msPerS
75-
deadlineNS := (deadlineEpochMS % msPerS) * nsPerMS
96+
return nil
97+
}
7698

77-
res := &messages.InvokeRequest{
78-
InvokedFunctionArn: invoke.headers.Get(headerInvokedFunctionARN),
79-
XAmznTraceId: invoke.headers.Get(headerTraceID),
80-
RequestId: invoke.id,
81-
Deadline: messages.InvokeRequest_Timestamp{
82-
Seconds: deadlineS,
83-
Nanos: deadlineNS,
84-
},
85-
Payload: invoke.payload,
99+
func callBytesHandlerFunc(ctx context.Context, payload []byte, handler bytesHandlerFunc) (response []byte, invokeErr *messages.InvokeResponse_Error) {
100+
defer func() {
101+
if err := recover(); err != nil {
102+
invokeErr = lambdaPanicResponse(err)
103+
}
104+
}()
105+
response, err := handler(ctx, payload)
106+
if err != nil {
107+
return nil, lambdaErrorResponse(err)
86108
}
109+
return response, nil
110+
}
87111

88-
clientContextJSON := invoke.headers.Get(headerClientContext)
89-
if clientContextJSON != "" {
90-
res.ClientContext = []byte(clientContextJSON)
112+
func parseDeadline(invoke *invoke) (time.Time, error) {
113+
deadlineEpochMS, err := strconv.ParseInt(invoke.headers.Get(headerDeadlineMS), 10, 64)
114+
if err != nil {
115+
return time.Time{}, fmt.Errorf("failed to parse deadline: %v", err)
91116
}
117+
return unixMS(deadlineEpochMS), nil
118+
}
92119

120+
func parseCognitoIdentity(invoke *invoke, out *lambdacontext.CognitoIdentity) error {
93121
cognitoIdentityJSON := invoke.headers.Get(headerCognitoIdentity)
94122
if cognitoIdentityJSON != "" {
95-
if err := json.Unmarshal([]byte(invoke.headers.Get(headerCognitoIdentity)), res); err != nil {
96-
return nil, fmt.Errorf("failed to unmarshal cognito identity json: %v", err)
123+
if err := json.Unmarshal([]byte(cognitoIdentityJSON), out); err != nil {
124+
return fmt.Errorf("failed to unmarshal cognito identity json: %v", err)
97125
}
98126
}
127+
return nil
128+
}
99129

100-
return res, nil
130+
func parseClientContext(invoke *invoke, out *lambdacontext.ClientContext) error {
131+
clientContextJSON := invoke.headers.Get(headerClientContext)
132+
if clientContextJSON != "" {
133+
if err := json.Unmarshal([]byte(clientContextJSON), out); err != nil {
134+
return fmt.Errorf("failed to unmarshal client context json: %v", err)
135+
}
136+
}
137+
return nil
101138
}
102139

103140
func safeMarshal(v interface{}) []byte {

0 commit comments

Comments
 (0)