diff --git a/interceptors/retry/README.md b/interceptors/retry/README.md new file mode 100644 index 00000000..06f0aaa0 --- /dev/null +++ b/interceptors/retry/README.md @@ -0,0 +1,49 @@ +## Retry Interceptor + +The `retry` interceptor is a client-side middleware for gRPC that provides a generic mechanism to retry requests based on gRPC response codes. + +### Build Flags + +The `retry` interceptor supports a build flag `retrynotrace` to disable tracing for retry attempts. +This can be useful to avoid importing `golang.org/x/net/trace`, which allows for more aggressive deadcode elimination. This can yield improvements in binary size when tracing is not needed. + +To build your application with the `retrynotrace` flag, use the following command: + +```shell +go build -tags retrynotrace -o your_application +``` + +### Usage + +To use the `retry` interceptor, you need to add it to your gRPC client interceptor chain: + +```go +import ( + "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/retry" + "google.golang.org/grpc" +) + +func main() { + opts := []grpc.DialOption{ + grpc.WithUnaryInterceptor(retry.UnaryClientInterceptor( + retry.WithMax(3), // Maximum number of retries + retry.WithPerRetryTimeout(2*time.Second), // Timeout per retry + )), + } + + conn, err := grpc.NewClient("your_grpc_server_address", opts...) + if err != nil { + log.Fatalf("Failed to connect: %v", err) + } + defer conn.Close() + + // Your gRPC client code here +} +``` + +### Configuration Options + +- `retry.WithMax(maxRetries int)`: Sets the maximum number of retry attempts. +- `retry.WithPerRetryTimeout(timeout time.Duration)`: Sets the timeout for each retry attempt. +- `retry.WithBackoff(backoffFunc retry.BackoffFunc)`: Sets a custom backoff strategy. +- `retry.WithCodes(codes ...codes.Code)`: Specifies the gRPC response codes that should trigger a retry. \ No newline at end of file diff --git a/interceptors/retry/retry.go b/interceptors/retry/retry.go index 368f1a52..f7970d0a 100644 --- a/interceptors/retry/retry.go +++ b/interceptors/retry/retry.go @@ -11,7 +11,6 @@ import ( "time" "github.com/grpc-ecosystem/go-grpc-middleware/v2/metadata" - "golang.org/x/net/trace" "google.golang.org/grpc" "google.golang.org/grpc/codes" grpcMetadata "google.golang.org/grpc/metadata" @@ -320,7 +319,7 @@ func contextErrToGrpcErr(err error) error { } func logTrace(ctx context.Context, format string, a ...any) { - tr, ok := trace.FromContext(ctx) + tr, ok := traceFromCtx(ctx) if !ok { return } diff --git a/interceptors/retry/trace_notrace.go b/interceptors/retry/trace_notrace.go new file mode 100644 index 00000000..2bfa68cd --- /dev/null +++ b/interceptors/retry/trace_notrace.go @@ -0,0 +1,29 @@ +// Copyright (c) The go-grpc-middleware Authors. +// Licensed under the Apache License 2.0. + +//go:build retrynotrace + +package retry + +// retrynotrace can be used to avoid importing golang.org/x/net/trace, +// which allows for more aggressive deadcode elimination, which can +// yield improvements in binary size when tracing is not needed. + +import ( + "context" + "fmt" +) + +type notrace struct{} + +func (notrace) LazyLog(x fmt.Stringer, sensitive bool) {} +func (notrace) LazyPrintf(format string, a ...any) {} +func (notrace) SetError() {} +func (notrace) SetRecycler(f func(any)) {} +func (notrace) SetTraceInfo(traceID, spanID uint64) {} +func (notrace) SetMaxEvents(m int) {} +func (notrace) Finish() {} + +func traceFromCtx(ctx context.Context) (notrace, bool) { + return notrace{}, true +} diff --git a/interceptors/retry/trace_notrace_test.go b/interceptors/retry/trace_notrace_test.go new file mode 100644 index 00000000..60a5cd37 --- /dev/null +++ b/interceptors/retry/trace_notrace_test.go @@ -0,0 +1,42 @@ +// Copyright (c) The go-grpc-middleware Authors. +// Licensed under the Apache License 2.0. + +//go:build retrynotrace + +package retry + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" +) + +func Test_traceFromCtx(t *testing.T) { + tr := notrace{} + ctx := context.Background() + + type args struct { + ctx context.Context + } + tests := []struct { + name string + args args + want notrace + want1 bool + }{ + { + name: "should return notrace", + args: args{ctx: ctx}, + want: tr, + want1: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, got1 := traceFromCtx(tt.args.ctx) + assert.Equalf(t, tt.want, got, "traceFromCtx(%v)", tt.args.ctx) + assert.Equalf(t, tt.want1, got1, "traceFromCtx(%v)", tt.args.ctx) + }) + } +} diff --git a/interceptors/retry/trace_withtrace.go b/interceptors/retry/trace_withtrace.go new file mode 100644 index 00000000..8e86c8c0 --- /dev/null +++ b/interceptors/retry/trace_withtrace.go @@ -0,0 +1,16 @@ +// Copyright (c) The go-grpc-middleware Authors. +// Licensed under the Apache License 2.0. + +//go:build !retrynotrace + +package retry + +import ( + "context" + + t "golang.org/x/net/trace" +) + +func traceFromCtx(ctx context.Context) (t.Trace, bool) { + return t.FromContext(ctx) +} diff --git a/interceptors/retry/trace_withtrace_test.go b/interceptors/retry/trace_withtrace_test.go new file mode 100644 index 00000000..9cf64d9a --- /dev/null +++ b/interceptors/retry/trace_withtrace_test.go @@ -0,0 +1,49 @@ +// Copyright (c) The go-grpc-middleware Authors. +// Licensed under the Apache License 2.0. + +//go:build !retrynotrace + +package retry + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "golang.org/x/net/trace" +) + +func Test_traceFromCtx(t *testing.T) { + tr := trace.New("test", "with trace") + ctx := trace.NewContext(context.Background(), tr) + + type args struct { + ctx context.Context + } + tests := []struct { + name string + args args + want trace.Trace + want1 bool + }{ + { + name: "should return trace", + args: args{ctx: ctx}, + want: tr, + want1: true, + }, + { + name: "should return false if trace not found in ctx", + args: args{ctx: context.Background()}, + want: nil, + want1: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, got1 := traceFromCtx(tt.args.ctx) + assert.Equalf(t, tt.want, got, "traceFromCtx(%v)", tt.args.ctx) + assert.Equalf(t, tt.want1, got1, "traceFromCtx(%v)", tt.args.ctx) + }) + } +}