Skip to content

Commit abde09f

Browse files
committed
Pass around ctx instead of a logger
1 parent 4c6d700 commit abde09f

File tree

8 files changed

+42
-25
lines changed

8 files changed

+42
-25
lines changed

pkg/ext-proc/handlers/request.go

+7-5
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,16 @@
11
package handlers
22

33
import (
4+
"context"
45
"encoding/json"
56
"errors"
67
"fmt"
78
"strconv"
89

910
configPb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
1011
extProcPb "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3"
11-
"github.com/go-logr/logr"
1212
"google.golang.org/protobuf/types/known/structpb"
13+
"sigs.k8s.io/controller-runtime/pkg/log"
1314
"sigs.k8s.io/gateway-api-inference-extension/pkg/ext-proc/backend"
1415
"sigs.k8s.io/gateway-api-inference-extension/pkg/ext-proc/scheduling"
1516
logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/ext-proc/util/logging"
@@ -19,10 +20,11 @@ import (
1920
// parameter.
2021
// Envoy sends the request body to ext proc before sending the request to the backend server.
2122
func (s *Server) HandleRequestBody(
22-
logger logr.Logger,
23+
ctx context.Context,
2324
reqCtx *RequestContext,
2425
req *extProcPb.ProcessingRequest,
2526
) (*extProcPb.ProcessingResponse, error) {
27+
logger := log.FromContext(ctx)
2628
loggerVerbose := logger.V(logutil.VERBOSE)
2729
loggerVerbose.Info("Handling request body")
2830

@@ -76,7 +78,7 @@ func (s *Server) HandleRequestBody(
7678
loggerVerbose.Info("Updated request body marshalled", "body", string(requestBody))
7779
}
7880

79-
targetPod, err := s.scheduler.Schedule(logger, llmReq)
81+
targetPod, err := s.scheduler.Schedule(ctx, llmReq)
8082
if err != nil {
8183
return nil, fmt.Errorf("failed to find target pod: %w", err)
8284
}
@@ -141,13 +143,13 @@ func (s *Server) HandleRequestBody(
141143
}
142144

143145
func HandleRequestHeaders(
144-
logger logr.Logger,
146+
ctx context.Context,
145147
reqCtx *RequestContext,
146148
req *extProcPb.ProcessingRequest,
147149
) *extProcPb.ProcessingResponse {
148150
r := req.Request
149151
h := r.(*extProcPb.ProcessingRequest_RequestHeaders)
150-
logger.Info("Handling request headers", "headers", h)
152+
log.FromContext(ctx).Info("Handling request headers", "headers", h)
151153

152154
resp := &extProcPb.ProcessingResponse{
153155
Response: &extProcPb.ProcessingResponse_RequestHeaders{

pkg/ext-proc/handlers/response.go

+6-4
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,23 @@
11
package handlers
22

33
import (
4+
"context"
45
"encoding/json"
56
"fmt"
67

78
configPb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
89
extProcPb "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3"
9-
"github.com/go-logr/logr"
10+
"sigs.k8s.io/controller-runtime/pkg/log"
1011
logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/ext-proc/util/logging"
1112
)
1213

1314
// HandleResponseHeaders processes response headers from the backend model server.
1415
func (s *Server) HandleResponseHeaders(
15-
logger logr.Logger,
16+
ctx context.Context,
1617
reqCtx *RequestContext,
1718
req *extProcPb.ProcessingRequest,
1819
) (*extProcPb.ProcessingResponse, error) {
19-
loggerVerbose := logger.V(logutil.VERBOSE)
20+
loggerVerbose := log.FromContext(ctx).V(logutil.VERBOSE)
2021
loggerVerbose.Info("Processing ResponseHeaders")
2122
h := req.Request.(*extProcPb.ProcessingRequest_ResponseHeaders)
2223
loggerVerbose.Info("Headers before", "headers", h)
@@ -71,10 +72,11 @@ func (s *Server) HandleResponseHeaders(
7172
}
7273
}*/
7374
func (s *Server) HandleResponseBody(
74-
logger logr.Logger,
75+
ctx context.Context,
7576
reqCtx *RequestContext,
7677
req *extProcPb.ProcessingRequest,
7778
) (*extProcPb.ProcessingResponse, error) {
79+
logger := log.FromContext(ctx)
7880
loggerVerbose := logger.V(logutil.VERBOSE)
7981
loggerVerbose.Info("Processing HandleResponseBody")
8082
body := req.Request.(*extProcPb.ProcessingRequest_ResponseBody)

pkg/ext-proc/handlers/response_test.go

+3-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package handlers
22

33
import (
4+
"context"
45
"testing"
56

67
extProcPb "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3"
@@ -35,7 +36,7 @@ const (
3536
)
3637

3738
func TestHandleResponseBody(t *testing.T) {
38-
logger := logutil.NewTestLogger()
39+
ctx := logutil.NewTestLoggerIntoContext(context.Background())
3940

4041
tests := []struct {
4142
name string
@@ -73,7 +74,7 @@ func TestHandleResponseBody(t *testing.T) {
7374
t.Run(test.name, func(t *testing.T) {
7475
server := &Server{}
7576
reqCtx := &RequestContext{}
76-
_, err := server.HandleResponseBody(logger, reqCtx, &extProcPb.ProcessingRequest{Request: test.req})
77+
_, err := server.HandleResponseBody(ctx, reqCtx, &extProcPb.ProcessingRequest{Request: test.req})
7778
if err != nil {
7879
if !test.wantErr {
7980
t.Fatalf("HandleResponseBody returned unexpected error: %v, want %v", err, test.wantErr)

pkg/ext-proc/handlers/server.go

+7-7
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
package handlers
22

33
import (
4+
"context"
45
"io"
56
"time"
67

78
extProcPb "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3"
89
envoyTypePb "github.com/envoyproxy/go-control-plane/envoy/type/v3"
9-
"github.com/go-logr/logr"
1010
"google.golang.org/grpc/codes"
1111
"google.golang.org/grpc/status"
1212
"sigs.k8s.io/controller-runtime/pkg/log"
@@ -38,7 +38,7 @@ type Server struct {
3838
}
3939

4040
type Scheduler interface {
41-
Schedule(logger logr.Logger, b *scheduling.LLMRequest) (targetPod backend.Pod, err error)
41+
Schedule(ctx context.Context, b *scheduling.LLMRequest) (targetPod backend.Pod, err error)
4242
}
4343

4444
// PodProvider is an interface to provide set of pods in the backend and information such as metrics.
@@ -83,23 +83,23 @@ func (s *Server) Process(srv extProcPb.ExternalProcessor_ProcessServer) error {
8383
switch v := req.Request.(type) {
8484
case *extProcPb.ProcessingRequest_RequestHeaders:
8585
reqCtx.RequestReceivedTimestamp = time.Now()
86-
resp = HandleRequestHeaders(logger, reqCtx, req)
86+
resp = HandleRequestHeaders(ctx, reqCtx, req)
8787
loggerVerbose.Info("Request context after HandleRequestHeaders", "context", reqCtx)
8888
case *extProcPb.ProcessingRequest_RequestBody:
89-
resp, err = s.HandleRequestBody(logger, reqCtx, req)
89+
resp, err = s.HandleRequestBody(ctx, reqCtx, req)
9090
if err == nil {
9191
metrics.RecordRequestCounter(reqCtx.Model, reqCtx.ResolvedTargetModel)
9292
metrics.RecordRequestSizes(reqCtx.Model, reqCtx.ResolvedTargetModel, reqCtx.RequestSize)
9393
}
9494
loggerVerbose.Info("Request context after HandleRequestBody", "context", reqCtx)
9595
case *extProcPb.ProcessingRequest_ResponseHeaders:
96-
resp, err = s.HandleResponseHeaders(logger, reqCtx, req)
96+
resp, err = s.HandleResponseHeaders(ctx, reqCtx, req)
9797
loggerVerbose.Info("Request context after HandleResponseHeaders", "context", reqCtx)
9898
case *extProcPb.ProcessingRequest_ResponseBody:
99-
resp, err = s.HandleResponseBody(logger, reqCtx, req)
99+
resp, err = s.HandleResponseBody(ctx, reqCtx, req)
100100
if err == nil && reqCtx.ResponseComplete {
101101
reqCtx.ResponseCompleteTimestamp = time.Now()
102-
metrics.RecordRequestLatencies(logger, reqCtx.Model, reqCtx.ResolvedTargetModel, reqCtx.RequestReceivedTimestamp, reqCtx.ResponseCompleteTimestamp)
102+
metrics.RecordRequestLatencies(ctx, reqCtx.Model, reqCtx.ResolvedTargetModel, reqCtx.RequestReceivedTimestamp, reqCtx.ResponseCompleteTimestamp)
103103
metrics.RecordResponseSizes(reqCtx.Model, reqCtx.ResolvedTargetModel, reqCtx.ResponseSize)
104104
metrics.RecordInputTokens(reqCtx.Model, reqCtx.ResolvedTargetModel, reqCtx.Response.Usage.PromptTokens)
105105
metrics.RecordOutputTokens(reqCtx.Model, reqCtx.ResolvedTargetModel, reqCtx.Response.Usage.CompletionTokens)

pkg/ext-proc/metrics/metrics.go

+4-3
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
package metrics
22

33
import (
4+
"context"
45
"sync"
56
"time"
67

7-
"github.com/go-logr/logr"
88
compbasemetrics "k8s.io/component-base/metrics"
99
"k8s.io/component-base/metrics/legacyregistry"
10+
"sigs.k8s.io/controller-runtime/pkg/log"
1011
logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/ext-proc/util/logging"
1112
)
1213

@@ -144,9 +145,9 @@ func RecordRequestSizes(modelName, targetModelName string, reqSize int) {
144145
}
145146

146147
// RecordRequestLatencies records duration of request.
147-
func RecordRequestLatencies(logger logr.Logger, modelName, targetModelName string, received time.Time, complete time.Time) bool {
148+
func RecordRequestLatencies(ctx context.Context, modelName, targetModelName string, received time.Time, complete time.Time) bool {
148149
if !complete.After(received) {
149-
logger.V(logutil.DEFAULT).Error(nil, "Request latency values are invalid",
150+
log.FromContext(ctx).V(logutil.DEFAULT).Error(nil, "Request latency values are invalid",
150151
"modelName", modelName, "targetModelName", targetModelName, "completeTime", complete, "receivedTime", received)
151152
return false
152153
}

pkg/ext-proc/metrics/metrics_test.go

+3-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package metrics
22

33
import (
4+
"context"
45
"os"
56
"testing"
67
"time"
@@ -91,7 +92,7 @@ func TestRecordRequestCounterandSizes(t *testing.T) {
9192
}
9293

9394
func TestRecordRequestLatencies(t *testing.T) {
94-
logger := logutil.NewTestLogger()
95+
ctx := logutil.NewTestLoggerIntoContext(context.Background())
9596
timeBaseline := time.Now()
9697
type requests struct {
9798
modelName string
@@ -150,7 +151,7 @@ func TestRecordRequestLatencies(t *testing.T) {
150151
for _, scenario := range scenarios {
151152
t.Run(scenario.name, func(t *testing.T) {
152153
for _, req := range scenario.reqs {
153-
success := RecordRequestLatencies(logger, req.modelName, req.targetModelName, req.receivedTime, req.completeTime)
154+
success := RecordRequestLatencies(ctx, req.modelName, req.targetModelName, req.receivedTime, req.completeTime)
154155
if success == scenario.invalid {
155156
t.Errorf("got record success(%v), but the request expects invalid(%v)", success, scenario.invalid)
156157
}

pkg/ext-proc/scheduling/scheduler.go

+4-2
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,14 @@
22
package scheduling
33

44
import (
5+
"context"
56
"fmt"
67
"math/rand"
78

89
"github.com/go-logr/logr"
910
"google.golang.org/grpc/codes"
1011
"google.golang.org/grpc/status"
12+
"sigs.k8s.io/controller-runtime/pkg/log"
1113
"sigs.k8s.io/gateway-api-inference-extension/pkg/ext-proc/backend"
1214
logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/ext-proc/util/logging"
1315
)
@@ -110,8 +112,8 @@ type PodMetricsProvider interface {
110112
}
111113

112114
// Schedule finds the target pod based on metrics and the requested lora adapter.
113-
func (s *Scheduler) Schedule(logger logr.Logger, req *LLMRequest) (targetPod backend.Pod, err error) {
114-
logger = logger.WithValues("request", req)
115+
func (s *Scheduler) Schedule(ctx context.Context, req *LLMRequest) (targetPod backend.Pod, err error) {
116+
logger := log.FromContext(ctx).WithValues("request", req)
115117
logger.V(logutil.VERBOSE).Info("Scheduling a request", "metrics", s.podMetricsProvider.AllPodMetrics())
116118
pods, err := s.filter.Filter(logger, req, s.podMetricsProvider.AllPodMetrics())
117119
if err != nil || len(pods) == 0 {

pkg/ext-proc/util/logging/logger.go

+8
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,20 @@
11
package logging
22

33
import (
4+
"context"
5+
46
"github.com/go-logr/logr"
57
uberzap "go.uber.org/zap"
8+
"sigs.k8s.io/controller-runtime/pkg/log"
69
"sigs.k8s.io/controller-runtime/pkg/log/zap"
710
)
811

912
// NewTestLogger creates a new Zap logger using the dev mode.
1013
func NewTestLogger() logr.Logger {
1114
return zap.New(zap.UseDevMode(true), zap.RawZapOpts(uberzap.AddCaller()))
1215
}
16+
17+
// NewTestLoggerIntoContext creates a new Zap logger using the dev mode and inserts it into the given context.
18+
func NewTestLoggerIntoContext(ctx context.Context) context.Context {
19+
return log.IntoContext(ctx, zap.New(zap.UseDevMode(true), zap.RawZapOpts(uberzap.AddCaller())))
20+
}

0 commit comments

Comments
 (0)