Skip to content

Commit 4e87856

Browse files
committed
Make Write message type more flexible
Signed-off-by: Saswata Mukherjee <[email protected]>
1 parent 19d78f0 commit 4e87856

File tree

3 files changed

+81
-19
lines changed

3 files changed

+81
-19
lines changed

api/prometheus/v1/remote/remote_api.go

Lines changed: 36 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -134,30 +134,58 @@ type vtProtoEnabled interface {
134134
MarshalToSizedBufferVT(dAtA []byte) (int, error)
135135
}
136136

137+
type gogoProtoEnabled interface {
138+
Size() (n int)
139+
MarshalToSizedBuffer(dAtA []byte) (n int, err error)
140+
}
141+
142+
// Sort of a hack to identify v2 requests.
143+
// Under any marshaling scheme, v2 requests have a `Symbols` field of type []string.
144+
// So would always have a `GetSymbols()` method which doesn't rely on any other types.
145+
type v2Request interface {
146+
GetSymbols() []string
147+
}
148+
137149
// Write writes given, non-empty, protobuf message to a remote storage.
138150
// The https://github.com/planetscale/vtprotobuf methods will be used if your msg
139-
// supports those (e.g. SizeVT() and MarshalToSizedBufferVT(...)), for efficiency.
140-
func (r *API) Write(ctx context.Context, msg proto.Message) (_ WriteResponseStats, err error) {
151+
// supports those (e.g. SizeVT() and MarshalToSizedBufferVT(...)), for efficiency
152+
// or https://github.com/gogo/protobuf methods (e.g. Size() and MarshalToSizedBuffer(...))
153+
// will be used if your msg supports those.
154+
func (r *API) Write(ctx context.Context, msg any) (_ WriteResponseStats, err error) {
141155
// Detect content-type.
142-
cType := WriteProtoFullName(proto.MessageName(msg))
156+
cType := WriteProtoFullNameV1
157+
if _, ok := msg.(v2Request); ok {
158+
cType = WriteProtoFullNameV2
159+
}
160+
143161
if err := cType.Validate(); err != nil {
144162
return WriteResponseStats{}, err
145163
}
146164

147165
// Encode the payload.
148-
if emsg, ok := msg.(vtProtoEnabled); ok {
166+
switch m := msg.(type) {
167+
case vtProtoEnabled:
149168
// Use optimized vtprotobuf if supported.
150-
size := emsg.SizeVT()
169+
size := m.SizeVT()
151170
if len(r.reqBuf) < size {
152171
r.reqBuf = make([]byte, size)
153172
}
154-
if _, err := emsg.MarshalToSizedBufferVT(r.reqBuf[:size]); err != nil {
173+
if _, err := m.MarshalToSizedBufferVT(r.reqBuf[:size]); err != nil {
155174
return WriteResponseStats{}, fmt.Errorf("encoding request %w", err)
156175
}
157-
} else {
176+
case gogoProtoEnabled:
177+
// Gogo proto if supported.
178+
size := m.Size()
179+
if len(r.reqBuf) < size {
180+
r.reqBuf = make([]byte, size)
181+
}
182+
if _, err := m.MarshalToSizedBuffer(r.reqBuf[:size]); err != nil {
183+
return WriteResponseStats{}, fmt.Errorf("encoding request %w", err)
184+
}
185+
case proto.Message:
158186
// Generic proto.
159187
r.reqBuf = r.reqBuf[:0]
160-
r.reqBuf, err = (proto.MarshalOptions{}).MarshalAppend(r.reqBuf, msg)
188+
r.reqBuf, err = (proto.MarshalOptions{}).MarshalAppend(r.reqBuf, m)
161189
if err != nil {
162190
return WriteResponseStats{}, fmt.Errorf("encoding request %w", err)
163191
}

api/prometheus/v1/remote/remote_api_test.go

Lines changed: 44 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ import (
2727
"google.golang.org/protobuf/testing/protocmp"
2828

2929
"github.com/prometheus/client_golang/api"
30+
writev1 "github.com/prometheus/client_golang/api/prometheus/v1/remote/genproto/v1"
3031
writev2 "github.com/prometheus/client_golang/api/prometheus/v1/remote/genproto/v2"
3132
)
3233

@@ -61,6 +62,7 @@ func TestRetryAfterDuration(t *testing.T) {
6162

6263
type mockStorage struct {
6364
v2Reqs []*writev2.Request
65+
v1Reqs []*writev1.WriteRequest
6466
protos []WriteProtoFullName
6567

6668
mockCode *int
@@ -72,14 +74,23 @@ func (m *mockStorage) Store(_ context.Context, msgFullName WriteProtoFullName, s
7274
return w, *m.mockCode, m.mockErr
7375
}
7476

75-
// This test expects v2 only.
76-
r := &writev2.Request{}
77-
if err := proto.Unmarshal(serializedRequest, r); err != nil {
78-
return WriteResponseStats{}, http.StatusInternalServerError, err
77+
if msgFullName == WriteProtoFullNameV1 {
78+
r := &writev1.WriteRequest{}
79+
if err := proto.Unmarshal(serializedRequest, r); err != nil {
80+
return WriteResponseStats{}, http.StatusInternalServerError, err
81+
}
82+
m.v1Reqs = append(m.v1Reqs, r)
83+
m.protos = append(m.protos, msgFullName)
84+
return WriteResponseStats{}, http.StatusOK, nil
85+
} else {
86+
r := &writev2.Request{}
87+
if err := proto.Unmarshal(serializedRequest, r); err != nil {
88+
return WriteResponseStats{}, http.StatusInternalServerError, err
89+
}
90+
m.v2Reqs = append(m.v2Reqs, r)
91+
m.protos = append(m.protos, msgFullName)
92+
return stats(r), http.StatusOK, nil
7993
}
80-
m.v2Reqs = append(m.v2Reqs, r)
81-
m.protos = append(m.protos, msgFullName)
82-
return stats(r), http.StatusOK, nil
8394
}
8495

8596
func testV2() *writev2.Request {
@@ -112,6 +123,20 @@ func testV2() *writev2.Request {
112123
}
113124
}
114125

126+
func testV1() *writev1.WriteRequest {
127+
return &writev1.WriteRequest{
128+
Timeseries: []*writev1.TimeSeries{
129+
{
130+
Labels: []*writev1.Label{
131+
{Name: "__name__", Value: "metric1"},
132+
{Name: "foo", Value: "bar1"},
133+
},
134+
Samples: []*writev1.Sample{{Value: 1.1, Timestamp: 1214141}, {Value: 1.5, Timestamp: 1214180}},
135+
},
136+
},
137+
}
138+
}
139+
115140
func stats(req *writev2.Request) (s WriteResponseStats) {
116141
s.Confirmed = true
117142
for _, ts := range req.Timeseries {
@@ -145,11 +170,22 @@ func TestRemoteAPI_Write_WithHandler(t *testing.T) {
145170
if err != nil {
146171
t.Fatal(err)
147172
}
173+
148174
if diff := cmp.Diff(stats(req), s); diff != "" {
149175
t.Fatal("unexpected stats", diff)
150176
}
177+
178+
req2 := testV1()
179+
_, err = client.Write(context.Background(), req2)
180+
if err != nil {
181+
t.Fatal(err)
182+
}
183+
151184
if len(mStore.v2Reqs) != 1 {
152-
t.Fatal("expected 1 request stored, got", mStore.v2Reqs)
185+
t.Fatal("expected 1 v2 request stored, got", mStore.v2Reqs)
186+
}
187+
if len(mStore.v1Reqs) != 1 {
188+
t.Fatal("expected 1 v1 request stored, got", mStore.v1Reqs)
153189
}
154190
if diff := cmp.Diff(req, mStore.v2Reqs[0], protocmp.Transform()); diff != "" {
155191
t.Fatal("unexpected request received", diff)

api/prometheus/v1/remote/remote_headers.go

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,6 @@ import (
1919
"net/http"
2020
"strconv"
2121
"strings"
22-
23-
"google.golang.org/protobuf/reflect/protoreflect"
2422
)
2523

2624
const (
@@ -43,7 +41,7 @@ const (
4341
// WriteProtoFullName represents the fully qualified name of the protobuf message
4442
// to use in Remote write 1.0 and 2.0 protocols.
4543
// See https://prometheus.io/docs/specs/remote_write_spec_2_0/#protocol.
46-
type WriteProtoFullName protoreflect.FullName
44+
type WriteProtoFullName string
4745

4846
const (
4947
// WriteProtoFullNameV1 represents the `prometheus.WriteRequest` protobuf

0 commit comments

Comments
 (0)