Skip to content

Commit fb9bc21

Browse files
committed
Make Write message type more flexible
Signed-off-by: Saswata Mukherjee <[email protected]>
1 parent 0b13b9f commit fb9bc21

File tree

6 files changed

+99
-42
lines changed

6 files changed

+99
-42
lines changed

Makefile

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,8 @@ fmt: common-format $(GOIMPORTS)
5555
proto: ## Regenerate Go from remote write proto.
5656
proto: $(BUF)
5757
@echo ">> regenerating Prometheus Remote Write proto"
58-
@cd api/prometheus/v1/genproto && $(BUF) generate
59-
@cd api/prometheus/v1 && find genproto/ -type f -exec sed -i '' 's/protohelpers "github.com\/planetscale\/vtprotobuf\/protohelpers"/protohelpers "github.com\/prometheus\/client_golang\/internal\/github.com\/planetscale\/vtprotobuf\/protohelpers"/g' {} \;
58+
@cd api/prometheus/v1/remote/genproto && $(BUF) generate
59+
@cd api/prometheus/v1/remote && find genproto/ -type f -exec sed -i '' 's/protohelpers "github.com\/planetscale\/vtprotobuf\/protohelpers"/protohelpers "github.com\/prometheus\/client_golang\/internal\/github.com\/planetscale\/vtprotobuf\/protohelpers"/g' {} \;
6060
# For some reasons buf generates this unused import, kill it manually for now and reformat.
61-
@cd api/prometheus/v1 && find genproto/ -type f -exec sed -i '' 's/_ "github.com\/gogo\/protobuf\/gogoproto"//g' {} \;
62-
@cd api/prometheus/v1 && go fmt ./genproto/...
61+
@cd api/prometheus/v1/remote && find genproto/ -type f -exec sed -i '' 's/_ "github.com\/gogo\/protobuf\/gogoproto"//g' {} \;
62+
@cd api/prometheus/v1/remote && go fmt ./genproto/...

api/prometheus/v1/remote/genproto/v2/types.pb.go

Lines changed: 2 additions & 3 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

api/prometheus/v1/remote/genproto/v2/types_vtproto.pb.go

Lines changed: 2 additions & 4 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

api/prometheus/v1/remote/remote_api.go

Lines changed: 87 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ type apiOpts struct {
5151
logger *slog.Logger
5252
backoff backoff.Config
5353
compression Compression
54-
endpoint string
54+
path string
5555
retryOnRateLimit bool
5656
}
5757

@@ -64,6 +64,7 @@ var defaultAPIOpts = &apiOpts{
6464
// Hardcoded for now.
6565
retryOnRateLimit: true,
6666
compression: SnappyBlockCompression,
67+
path: "api/v1/write",
6768
}
6869

6970
// WithAPILogger returns APIOption that allows providing slog logger.
@@ -75,18 +76,18 @@ func WithAPILogger(logger *slog.Logger) APIOption {
7576
}
7677
}
7778

78-
// WithAPIEndpoint returns APIOption that allows providing endpoint.
79-
func WithAPIEndpoint(endpoint string) APIOption {
79+
// WithAPIPath returns APIOption that allows providing path to send remote write requests to.
80+
func WithAPIPath(path string) APIOption {
8081
return func(o *apiOpts) error {
81-
o.endpoint = endpoint
82+
o.path = path
8283
return nil
8384
}
8485
}
8586

86-
// WithAPIRetryOnRateLimit returns APIOption that allows providing retry on rate limit.
87-
func WithAPIRetryOnRateLimit(retry bool) APIOption {
87+
// WithAPIRetryOnRateLimit returns APIOption that disables retrying on rate limit status code.
88+
func WithAPINoRetryOnRateLimit() APIOption {
8889
return func(o *apiOpts) error {
89-
o.retryOnRateLimit = retry
90+
o.retryOnRateLimit = false
9091
return nil
9192
}
9293
}
@@ -134,33 +135,67 @@ type vtProtoEnabled interface {
134135
MarshalToSizedBufferVT(dAtA []byte) (int, error)
135136
}
136137

138+
type gogoProtoEnabled interface {
139+
Size() (n int)
140+
MarshalToSizedBuffer(dAtA []byte) (n int, err error)
141+
}
142+
143+
// Sort of a hack to identify v2 requests.
144+
// Under any marshaling scheme, v2 requests have a `Symbols` field of type []string.
145+
// So would always have a `GetSymbols()` method which doesn't rely on any other types.
146+
type v2Request interface {
147+
GetSymbols() []string
148+
}
149+
137150
// Write writes given, non-empty, protobuf message to a remote storage.
138-
// The https://github.com/planetscale/vtprotobuf methods will be used if your msg
139-
// supports those (e.g. SizeVT() and MarshalToSizedBufferVT(...)), for efficiency.
140-
func (r *API) Write(ctx context.Context, msg proto.Message) (_ WriteResponseStats, err error) {
151+
//
152+
// Depending on serialization methods,
153+
// - https://github.com/planetscale/vtprotobuf methods will be used if your msg
154+
// supports those (e.g. SizeVT() and MarshalToSizedBufferVT(...)), for efficiency
155+
// - Otherwise https://github.com/gogo/protobuf methods (e.g. Size() and MarshalToSizedBuffer(...))
156+
// will be used
157+
// - If neither is supported, it will marshaled using generic google.golang.org/protobuf methods and
158+
// error out on unknown scheme.
159+
func (r *API) Write(ctx context.Context, msg any) (_ WriteResponseStats, err error) {
141160
// Detect content-type.
142-
cType := WriteProtoFullName(proto.MessageName(msg))
161+
cType := WriteProtoFullNameV1
162+
if _, ok := msg.(v2Request); ok {
163+
cType = WriteProtoFullNameV2
164+
}
165+
143166
if err := cType.Validate(); err != nil {
144167
return WriteResponseStats{}, err
145168
}
146169

147170
// Encode the payload.
148-
if emsg, ok := msg.(vtProtoEnabled); ok {
171+
switch m := msg.(type) {
172+
case vtProtoEnabled:
149173
// Use optimized vtprotobuf if supported.
150-
size := emsg.SizeVT()
174+
size := m.SizeVT()
151175
if len(r.reqBuf) < size {
152176
r.reqBuf = make([]byte, size)
153177
}
154-
if _, err := emsg.MarshalToSizedBufferVT(r.reqBuf[:size]); err != nil {
178+
if _, err := m.MarshalToSizedBufferVT(r.reqBuf[:size]); err != nil {
155179
return WriteResponseStats{}, fmt.Errorf("encoding request %w", err)
156180
}
157-
} else {
181+
case gogoProtoEnabled:
182+
// Gogo proto if supported.
183+
size := m.Size()
184+
if len(r.reqBuf) < size {
185+
r.reqBuf = make([]byte, size)
186+
}
187+
if _, err := m.MarshalToSizedBuffer(r.reqBuf[:size]); err != nil {
188+
return WriteResponseStats{}, fmt.Errorf("encoding request %w", err)
189+
}
190+
case proto.Message:
158191
// Generic proto.
159192
r.reqBuf = r.reqBuf[:0]
160-
r.reqBuf, err = (proto.MarshalOptions{}).MarshalAppend(r.reqBuf, msg)
193+
r.reqBuf, err = (proto.MarshalOptions{}).MarshalAppend(r.reqBuf, m)
161194
if err != nil {
162195
return WriteResponseStats{}, fmt.Errorf("encoding request %w", err)
163196
}
197+
default:
198+
return WriteResponseStats{}, fmt.Errorf("unknown message type %T", m)
164199
}
165200

166201
payload, err := compressPayload(&r.comprBuf, r.opts.compression, r.reqBuf)
@@ -231,7 +266,7 @@ func compressPayload(tmpbuf *[]byte, enc Compression, inp []byte) (compressed []
231266
}
232267

233268
func (r *API) attemptWrite(ctx context.Context, compr Compression, proto WriteProtoFullName, payload []byte, attempt int) (WriteResponseStats, error) {
234-
u := r.client.URL(r.opts.endpoint, nil)
269+
u := r.client.URL(r.opts.path, nil)
235270
req, err := http.NewRequest(http.MethodPost, u.String(), bytes.NewReader(payload))
236271
if err != nil {
237272
// Errors from NewRequest are from unparsable URLs, so are not
@@ -305,15 +340,42 @@ type remoteWriteDecompressor interface {
305340
}
306341

307342
type handler struct {
343+
store writeStorage
344+
opts handlerOpts
345+
}
346+
347+
type handlerOpts struct {
308348
logger *slog.Logger
309-
store writeStorage
310349
decompressor remoteWriteDecompressor
311350
}
312351

352+
// HandlerOption represents an option for the handler.
353+
type HandlerOption func(o *handlerOpts)
354+
355+
// WithHandlerLogger returns HandlerOption that allows providing slog logger.
356+
// By default, nothing is logged.
357+
func WithHandlerLogger(logger *slog.Logger) HandlerOption {
358+
return func(o *handlerOpts) {
359+
o.logger = logger
360+
}
361+
}
362+
363+
// WithHandlerDecompressor returns HandlerOption that allows providing remoteWriteDecompressor.
364+
// By default, SimpleSnappyDecompressor is used.
365+
func WithHandlerDecompressor(decompressor remoteWriteDecompressor) HandlerOption {
366+
return func(o *handlerOpts) {
367+
o.decompressor = decompressor
368+
}
369+
}
370+
313371
// NewRemoteWriteHandler returns HTTP handler that receives Remote Write 2.0
314372
// protocol https://prometheus.io/docs/specs/remote_write_spec_2_0/.
315-
func NewRemoteWriteHandler(logger *slog.Logger, store writeStorage, decompressor remoteWriteDecompressor) http.Handler {
316-
return &handler{logger: logger, store: store, decompressor: decompressor}
373+
func NewRemoteWriteHandler(store writeStorage, opts ...HandlerOption) http.Handler {
374+
o := handlerOpts{logger: slog.New(nopSlogHandler{}), decompressor: &SimpleSnappyDecompressor{}}
375+
for _, opt := range opts {
376+
opt(&o)
377+
}
378+
return &handler{opts: o, store: store}
317379
}
318380

319381
// ParseProtoMsg parses the content-type header and returns the proto message type.
@@ -359,7 +421,7 @@ func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
359421

360422
msgType, err := ParseProtoMsg(contentType)
361423
if err != nil {
362-
h.logger.Error("Error decoding remote write request", "err", err)
424+
h.opts.logger.Error("Error decoding remote write request", "err", err)
363425
http.Error(w, err.Error(), http.StatusUnsupportedMediaType)
364426
return
365427
}
@@ -371,14 +433,14 @@ func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
371433
// We could give http.StatusUnsupportedMediaType, but let's assume snappy by default.
372434
} else if enc != string(SnappyBlockCompression) {
373435
err := fmt.Errorf("%v encoding (compression) is not accepted by this server; only %v is acceptable", enc, SnappyBlockCompression)
374-
h.logger.Error("Error decoding remote write request", "err", err)
436+
h.opts.logger.Error("Error decoding remote write request", "err", err)
375437
http.Error(w, err.Error(), http.StatusUnsupportedMediaType)
376438
}
377439

378440
// Decompress the request body.
379-
decompressed, err := h.decompressor.Decompress(r.Context(), r.Body)
441+
decompressed, err := h.opts.decompressor.Decompress(r.Context(), r.Body)
380442
if err != nil {
381-
h.logger.Error("Error decompressing remote write request", "err", err.Error())
443+
h.opts.logger.Error("Error decompressing remote write request", "err", err.Error())
382444
http.Error(w, err.Error(), http.StatusBadRequest)
383445
return
384446
}
@@ -393,7 +455,7 @@ func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
393455
code = http.StatusInternalServerError
394456
}
395457
if code/5 == 100 { // 5xx
396-
h.logger.Error("Error while storing the remote write request", "err", storeErr.Error())
458+
h.opts.logger.Error("Error while storing the remote write request", "err", storeErr.Error())
397459
}
398460
http.Error(w, storeErr.Error(), code)
399461
return

api/prometheus/v1/remote/remote_api_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ func stats(req *writev2.Request) (s WriteResponseStats) {
125125
func TestRemoteAPI_Write_WithHandler(t *testing.T) {
126126
tLogger := slog.Default()
127127
mStore := &mockStorage{}
128-
srv := httptest.NewServer(NewRemoteWriteHandler(tLogger, mStore, &SimpleSnappyDecompressor{}))
128+
srv := httptest.NewServer(NewRemoteWriteHandler(mStore, WithHandlerLogger(tLogger)))
129129
t.Cleanup(srv.Close)
130130

131131
cl, err := api.NewClient(api.Config{
@@ -135,7 +135,7 @@ func TestRemoteAPI_Write_WithHandler(t *testing.T) {
135135
if err != nil {
136136
t.Fatal(err)
137137
}
138-
client, err := NewAPI(cl, WithAPILogger(tLogger), WithAPIEndpoint("api/v1/write"))
138+
client, err := NewAPI(cl, WithAPILogger(tLogger), WithAPIPath("api/v1/write"))
139139
if err != nil {
140140
t.Fatal(err)
141141
}
@@ -149,7 +149,7 @@ func TestRemoteAPI_Write_WithHandler(t *testing.T) {
149149
t.Fatal("unexpected stats", diff)
150150
}
151151
if len(mStore.v2Reqs) != 1 {
152-
t.Fatal("expected 1 request stored, got", mStore.v2Reqs)
152+
t.Fatal("expected 1 v2 request stored, got", mStore.v2Reqs)
153153
}
154154
if diff := cmp.Diff(req, mStore.v2Reqs[0], protocmp.Transform()); diff != "" {
155155
t.Fatal("unexpected request received", diff)

api/prometheus/v1/remote/remote_headers.go

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,6 @@ import (
1919
"net/http"
2020
"strconv"
2121
"strings"
22-
23-
"google.golang.org/protobuf/reflect/protoreflect"
2422
)
2523

2624
const (
@@ -43,7 +41,7 @@ const (
4341
// WriteProtoFullName represents the fully qualified name of the protobuf message
4442
// to use in Remote write 1.0 and 2.0 protocols.
4543
// See https://prometheus.io/docs/specs/remote_write_spec_2_0/#protocol.
46-
type WriteProtoFullName protoreflect.FullName
44+
type WriteProtoFullName string
4745

4846
const (
4947
// WriteProtoFullNameV1 represents the `prometheus.WriteRequest` protobuf

0 commit comments

Comments
 (0)