Skip to content

Commit 5999752

Browse files
Minimize allocations on header access (connectrpc#445)
The standard library's `http.Header.{Get,Set,Add,Del}` automatically canonicalize header keys. This is nice, but expensive for us - we're using these APIs in many hot paths, and our constants are already in canonical form. This PR adds functions that bypass canonicalization and uses them wherever possible. --------- Co-authored-by: Akshay Shah <[email protected]>
1 parent e488bce commit 5999752

File tree

6 files changed

+87
-45
lines changed

6 files changed

+87
-45
lines changed

error_writer.go

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ func NewErrorWriter(opts ...HandlerOption) *ErrorWriter {
8383
// IsSupported checks whether a request is using one of the ErrorWriter's
8484
// supported RPC protocols.
8585
func (w *ErrorWriter) IsSupported(request *http.Request) bool {
86-
ctype := canonicalizeContentType(request.Header.Get(headerContentType))
86+
ctype := canonicalizeContentType(getHeaderCanonical(request.Header, headerContentType))
8787
_, ok := w.allContentTypes[ctype]
8888
return ok
8989
}
@@ -94,22 +94,22 @@ func (w *ErrorWriter) IsSupported(request *http.Request) bool {
9494
//
9595
// Write does not read or close the request body.
9696
func (w *ErrorWriter) Write(response http.ResponseWriter, request *http.Request, err error) error {
97-
ctype := canonicalizeContentType(request.Header.Get(headerContentType))
97+
ctype := canonicalizeContentType(getHeaderCanonical(request.Header, headerContentType))
9898
if _, ok := w.unaryConnectContentTypes[ctype]; ok {
9999
// Unary errors are always JSON.
100-
response.Header().Set(headerContentType, connectUnaryContentTypeJSON)
100+
setHeaderCanonical(response.Header(), headerContentType, connectUnaryContentTypeJSON)
101101
return w.writeConnectUnary(response, err)
102102
}
103103
if _, ok := w.streamingConnectContentTypes[ctype]; ok {
104-
response.Header().Set(headerContentType, ctype)
104+
setHeaderCanonical(response.Header(), headerContentType, ctype)
105105
return w.writeConnectStreaming(response, err)
106106
}
107107
if _, ok := w.grpcContentTypes[ctype]; ok {
108-
response.Header().Set(headerContentType, ctype)
108+
setHeaderCanonical(response.Header(), headerContentType, ctype)
109109
return w.writeGRPC(response, err)
110110
}
111111
if _, ok := w.grpcWebContentTypes[ctype]; ok {
112-
response.Header().Set(headerContentType, ctype)
112+
setHeaderCanonical(response.Header(), headerContentType, ctype)
113113
return w.writeGRPCWeb(response, err)
114114
}
115115
return fmt.Errorf("unsupported Content-Type %q", ctype)
@@ -153,7 +153,7 @@ func (w *ErrorWriter) writeGRPC(response http.ResponseWriter, err error) error {
153153
for k := range trailers {
154154
keys = append(keys, k)
155155
}
156-
response.Header().Set("Trailer", strings.Join(keys, ","))
156+
setHeaderCanonical(response.Header(), headerTrailer, strings.Join(keys, ","))
157157
response.WriteHeader(http.StatusOK)
158158
mergeHeaders(response.Header(), trailers)
159159
return nil

handler.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,7 @@ func (h *Handler) ServeHTTP(responseWriter http.ResponseWriter, request *http.Re
190190
}
191191

192192
// Find our implementation of the RPC protocol in use.
193-
contentType := canonicalizeContentType(request.Header.Get("Content-Type"))
193+
contentType := canonicalizeContentType(getHeaderCanonical(request.Header, headerContentType))
194194
var protocolHandler protocolHandler
195195
for _, handler := range h.protocolHandlers {
196196
if _, ok := handler.ContentTypes()[contentType]; ok {
@@ -205,7 +205,7 @@ func (h *Handler) ServeHTTP(responseWriter http.ResponseWriter, request *http.Re
205205
}
206206

207207
// Establish a stream and serve the RPC.
208-
request.Header.Set("Content-Type", contentType) // prefer canonicalized value
208+
setHeaderCanonical(request.Header, headerContentType, contentType)
209209
ctx, cancel, timeoutErr := protocolHandler.SetTimeout(request) //nolint: contextcheck
210210
if timeoutErr != nil {
211211
ctx = request.Context()

header.go

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,3 +50,38 @@ func mergeHeaders(into, from http.Header) {
5050
into[k] = append(into[k], vals...)
5151
}
5252
}
53+
54+
// getCanonicalHeader is a shortcut for Header.Get() which
55+
// bypasses the CanonicalMIMEHeaderKey operation when we
56+
// know the key is already in canonical form.
57+
func getHeaderCanonical(h http.Header, key string) string {
58+
if h == nil {
59+
return ""
60+
}
61+
v := h[key]
62+
if len(v) == 0 {
63+
return ""
64+
}
65+
return v[0]
66+
}
67+
68+
// setHeaderCanonical is a shortcut for Header.Set() which
69+
// bypasses the CanonicalMIMEHeaderKey operation when we
70+
// know the key is already in canonical form.
71+
func setHeaderCanonical(h http.Header, key, value string) {
72+
h[key] = []string{value}
73+
}
74+
75+
// delHeaderCanonical is a shortcut for Header.Del() which
76+
// bypasses the CanonicalMIMEHeaderKey operation when we
77+
// know the key is already in canonical form.
78+
func delHeaderCanonical(h http.Header, key string) {
79+
delete(h, key)
80+
}
81+
82+
// addHeaderCanonical is a shortcut for Header.Add() which
83+
// bypasses the CanonicalMIMEHeaderKey operation when we
84+
// know the key is already in canonical form.
85+
func addHeaderCanonical(h http.Header, key, value string) {
86+
h[key] = append(h[key], value)
87+
}

protocol.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ const (
3737
const (
3838
headerContentType = "Content-Type"
3939
headerUserAgent = "User-Agent"
40+
headerTrailer = "Trailer"
4041

4142
discardLimit = 1024 * 1024 * 4 // 4MiB
4243
)

protocol_connect.go

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ func (h *connectHandler) ContentTypes() map[string]struct{} {
9191
}
9292

9393
func (*connectHandler) SetTimeout(request *http.Request) (context.Context, context.CancelFunc, error) {
94-
timeout := request.Header.Get(connectHeaderTimeout)
94+
timeout := getHeaderCanonical(request.Header, connectHeaderTimeout)
9595
if timeout == "" {
9696
return request.Context(), nil, nil
9797
}
@@ -117,11 +117,11 @@ func (h *connectHandler) NewConn(
117117
// send the error to the client later on.
118118
var contentEncoding, acceptEncoding string
119119
if h.Spec.StreamType == StreamTypeUnary {
120-
contentEncoding = request.Header.Get(connectUnaryHeaderCompression)
121-
acceptEncoding = request.Header.Get(connectUnaryHeaderAcceptCompression)
120+
contentEncoding = getHeaderCanonical(request.Header, connectUnaryHeaderCompression)
121+
acceptEncoding = getHeaderCanonical(request.Header, connectUnaryHeaderAcceptCompression)
122122
} else {
123-
contentEncoding = request.Header.Get(connectStreamingHeaderCompression)
124-
acceptEncoding = request.Header.Get(connectStreamingHeaderAcceptCompression)
123+
contentEncoding = getHeaderCanonical(request.Header, connectStreamingHeaderCompression)
124+
acceptEncoding = getHeaderCanonical(request.Header, connectStreamingHeaderAcceptCompression)
125125
}
126126
requestCompression, responseCompression, failed := negotiateCompression(
127127
h.CompressionPools,
@@ -132,7 +132,7 @@ func (h *connectHandler) NewConn(
132132
failed = checkServerStreamsCanFlush(h.Spec, responseWriter)
133133
}
134134
if failed == nil {
135-
version := request.Header.Get(connectHeaderProtocolVersion)
135+
version := getHeaderCanonical(request.Header, connectHeaderProtocolVersion)
136136
if version == "" && h.RequireConnectProtocolHeader {
137137
failed = errorf(CodeInvalidArgument, "missing required header: set %s to %q", connectHeaderProtocolVersion, connectProtocolVersion)
138138
} else if version != "" && version != connectProtocolVersion {
@@ -148,7 +148,7 @@ func (h *connectHandler) NewConn(
148148
// Since we know that these header keys are already in canonical form, we can
149149
// skip the normalization in Header.Set.
150150
header := responseWriter.Header()
151-
header[headerContentType] = []string{request.Header.Get(headerContentType)}
151+
header[headerContentType] = []string{getHeaderCanonical(request.Header, headerContentType)}
152152
acceptCompressionHeader := connectUnaryHeaderAcceptCompression
153153
if h.Spec.StreamType != StreamTypeUnary {
154154
acceptCompressionHeader = connectStreamingHeaderAcceptCompression
@@ -164,7 +164,7 @@ func (h *connectHandler) NewConn(
164164

165165
codecName := connectCodecFromContentType(
166166
h.Spec.StreamType,
167-
request.Header.Get(headerContentType),
167+
getHeaderCanonical(request.Header, headerContentType),
168168
)
169169
codec := h.Codecs.Get(codecName) // handler.go guarantees this is not nil
170170

@@ -247,7 +247,7 @@ func (c *connectClient) Peer() Peer {
247247
func (c *connectClient) WriteRequestHeader(streamType StreamType, header http.Header) {
248248
// We know these header keys are in canonical form, so we can bypass all the
249249
// checks in Header.Set.
250-
if header.Get(headerUserAgent) == "" {
250+
if getHeaderCanonical(header, headerUserAgent) == "" {
251251
header[headerUserAgent] = []string{defaultConnectUserAgent}
252252
}
253253
header[connectHeaderProtocolVersion] = []string{connectProtocolVersion}
@@ -418,7 +418,7 @@ func (cc *connectUnaryClientConn) validateResponse(response *http.Response) *Err
418418
}
419419
cc.responseTrailer[strings.TrimPrefix(k, connectUnaryTrailerPrefix)] = v
420420
}
421-
compression := response.Header.Get(connectUnaryHeaderCompression)
421+
compression := getHeaderCanonical(response.Header, connectUnaryHeaderCompression)
422422
if compression != "" &&
423423
compression != compressionIdentity &&
424424
!cc.compressionPools.Contains(compression) {
@@ -534,7 +534,7 @@ func (cc *connectStreamingClientConn) validateResponse(response *http.Response)
534534
if response.StatusCode != http.StatusOK {
535535
return errorf(connectHTTPToCode(response.StatusCode), "HTTP status %v", response.Status)
536536
}
537-
compression := response.Header.Get(connectStreamingHeaderCompression)
537+
compression := getHeaderCanonical(response.Header, connectStreamingHeaderCompression)
538538
if compression != "" &&
539539
compression != compressionIdentity &&
540540
!cc.compressionPools.Contains(compression) {
@@ -605,7 +605,7 @@ func (hc *connectUnaryHandlerConn) Close(err error) error {
605605
return hc.request.Body.Close()
606606
}
607607
// In unary Connect, errors always use application/json.
608-
hc.responseWriter.Header().Set(headerContentType, connectUnaryContentTypeJSON)
608+
setHeaderCanonical(hc.responseWriter.Header(), headerContentType, connectUnaryContentTypeJSON)
609609
hc.responseWriter.WriteHeader(connectCodeToHTTP(CodeOf(err)))
610610
data, marshalErr := json.Marshal(newConnectWireError(err))
611611
if marshalErr != nil {
@@ -795,7 +795,7 @@ func (m *connectUnaryMarshaler) Marshal(message any) *Error {
795795
if m.sendMaxBytes > 0 && compressed.Len() > m.sendMaxBytes {
796796
return NewError(CodeResourceExhausted, fmt.Errorf("compressed message size %d exceeds sendMaxBytes %d", compressed.Len(), m.sendMaxBytes))
797797
}
798-
m.header.Set(connectUnaryHeaderCompression, m.compressionName)
798+
setHeaderCanonical(m.header, connectUnaryHeaderCompression, m.compressionName)
799799
return m.write(compressed.Bytes())
800800
}
801801

protocol_grpc.go

Lines changed: 29 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ func (g *grpcHandler) ContentTypes() map[string]struct{} {
131131
}
132132

133133
func (*grpcHandler) SetTimeout(request *http.Request) (context.Context, context.CancelFunc, error) {
134-
timeout, err := grpcParseTimeout(request.Header.Get(grpcHeaderTimeout))
134+
timeout, err := grpcParseTimeout(getHeaderCanonical(request.Header, grpcHeaderTimeout))
135135
if err != nil && !errors.Is(err, errNoTimeout) {
136136
// Errors here indicate that the client sent an invalid timeout header, so
137137
// the error text is safe to send back.
@@ -152,8 +152,8 @@ func (g *grpcHandler) NewConn(
152152
// send the error to the client later on.
153153
requestCompression, responseCompression, failed := negotiateCompression(
154154
g.CompressionPools,
155-
request.Header.Get(grpcHeaderCompression),
156-
request.Header.Get(grpcHeaderAcceptCompression),
155+
getHeaderCanonical(request.Header, grpcHeaderCompression),
156+
getHeaderCanonical(request.Header, grpcHeaderAcceptCompression),
157157
)
158158
if failed == nil {
159159
failed = checkServerStreamsCanFlush(g.Spec, responseWriter)
@@ -167,13 +167,13 @@ func (g *grpcHandler) NewConn(
167167
// Since we know that these header keys are already in canonical form, we can
168168
// skip the normalization in Header.Set.
169169
header := responseWriter.Header()
170-
header[headerContentType] = []string{request.Header.Get(headerContentType)}
170+
header[headerContentType] = []string{getHeaderCanonical(request.Header, headerContentType)}
171171
header[grpcHeaderAcceptCompression] = []string{g.CompressionPools.CommaSeparatedNames()}
172172
if responseCompression != compressionIdentity {
173173
header[grpcHeaderCompression] = []string{responseCompression}
174174
}
175175

176-
codecName := grpcCodecFromContentType(g.web, request.Header.Get(headerContentType))
176+
codecName := grpcCodecFromContentType(g.web, getHeaderCanonical(request.Header, headerContentType))
177177
codec := g.Codecs.Get(codecName) // handler.go guarantees this is not nil
178178
protocolName := ProtocolGRPC
179179
if g.web {
@@ -237,7 +237,7 @@ func (g *grpcClient) Peer() Peer {
237237
func (g *grpcClient) WriteRequestHeader(_ StreamType, header http.Header) {
238238
// We know these header keys are in canonical form, so we can bypass all the
239239
// checks in Header.Set.
240-
if header.Get(headerUserAgent) == "" {
240+
if getHeaderCanonical(header, headerUserAgent) == "" {
241241
header[headerUserAgent] = []string{defaultGrpcUserAgent}
242242
}
243243
header[headerContentType] = []string{grpcContentTypeFromCodecName(g.web, g.Codec.Name())}
@@ -365,7 +365,7 @@ func (cc *grpcClientConn) Receive(msg any) error {
365365
if err == nil {
366366
return nil
367367
}
368-
if cc.responseHeader.Get(grpcHeaderStatus) != "" {
368+
if getHeaderCanonical(cc.responseHeader, grpcHeaderStatus) != "" {
369369
// We got what gRPC calls a trailers-only response, which puts the trailing
370370
// metadata (including errors) into HTTP headers. validateResponse has
371371
// already extracted the error.
@@ -423,7 +423,7 @@ func (cc *grpcClientConn) validateResponse(response *http.Response) *Error {
423423
); err != nil {
424424
return err
425425
}
426-
compression := response.Header.Get(grpcHeaderCompression)
426+
compression := getHeaderCanonical(response.Header, grpcHeaderCompression)
427427
cc.unmarshaler.envelopeReader.compressionPool = cc.compressionPools.Get(compression)
428428
return nil
429429
}
@@ -542,11 +542,13 @@ func (hc *grpcHandlerConn) Close(err error) (retErr error) {
542542
// implement http.Flusher, we must pre-declare our HTTP trailers. We can
543543
// remove this when Go 1.21 ships and we drop support for Go 1.19.
544544
for key := range mergedTrailers {
545-
hc.responseWriter.Header().Add("Trailer", key)
545+
addHeaderCanonical(hc.responseWriter.Header(), headerTrailer, key)
546546
}
547547
hc.responseWriter.WriteHeader(http.StatusOK)
548548
for key, values := range mergedTrailers {
549549
for _, value := range values {
550+
// These are potentially user-supplied, so we can't assume they're in
551+
// canonical form. Don't use addHeaderCanonical.
550552
hc.responseWriter.Header().Add(key, value)
551553
}
552554
}
@@ -561,6 +563,8 @@ func (hc *grpcHandlerConn) Close(err error) (retErr error) {
561563
// logic breaks Envoy's gRPC-Web translation.
562564
for key, values := range mergedTrailers {
563565
for _, value := range values {
566+
// These are potentially user-supplied, so we can't assume they're in
567+
// canonical form. Don't use addHeaderCanonical.
564568
hc.responseWriter.Header().Add(http.TrailerPrefix+key, value)
565569
}
566570
}
@@ -636,7 +640,7 @@ func grpcValidateResponse(
636640
if response.StatusCode != http.StatusOK {
637641
return errorf(grpcHTTPToCode(response.StatusCode), "HTTP status %v", response.Status)
638642
}
639-
if compression := response.Header.Get(grpcHeaderCompression); compression != "" &&
643+
if compression := getHeaderCanonical(response.Header, grpcHeaderCompression); compression != "" &&
640644
compression != compressionIdentity &&
641645
!availableCompressors.Contains(compression) {
642646
// Per https://github.com/grpc/grpc/blob/master/doc/compression.md, we
@@ -658,11 +662,11 @@ func grpcValidateResponse(
658662
); err != nil && !errors.Is(err, errTrailersWithoutGRPCStatus) {
659663
// Per the specification, only the HTTP status code and Content-Type should
660664
// be treated as headers. The rest should be treated as trailing metadata.
661-
if contentType := response.Header.Get(headerContentType); contentType != "" {
662-
header.Set(headerContentType, contentType)
665+
if contentType := getHeaderCanonical(response.Header, headerContentType); contentType != "" {
666+
setHeaderCanonical(header, headerContentType, contentType)
663667
}
664668
mergeHeaders(trailer, response.Header)
665-
trailer.Del(headerContentType)
669+
delHeaderCanonical(trailer, headerContentType)
666670
// Also set the error metadata
667671
err.meta = header.Clone()
668672
mergeHeaders(err.meta, trailer)
@@ -699,7 +703,7 @@ func grpcHTTPToCode(httpCode int) Code {
699703
// use a different codec. Consequently, this function needs a Protobuf codec to
700704
// unmarshal error information in the headers.
701705
func grpcErrorFromTrailer(bufferPool *bufferPool, protobuf Codec, trailer http.Header) *Error {
702-
codeHeader := trailer.Get(grpcHeaderStatus)
706+
codeHeader := getHeaderCanonical(trailer, grpcHeaderStatus)
703707
if codeHeader == "" {
704708
return NewError(CodeInternal, errTrailersWithoutGRPCStatus)
705709
}
@@ -711,10 +715,10 @@ func grpcErrorFromTrailer(bufferPool *bufferPool, protobuf Codec, trailer http.H
711715
if err != nil {
712716
return errorf(CodeInternal, "gRPC protocol error: invalid error code %q", codeHeader)
713717
}
714-
message := grpcPercentDecode(bufferPool, trailer.Get(grpcHeaderMessage))
718+
message := grpcPercentDecode(bufferPool, getHeaderCanonical(trailer, grpcHeaderMessage))
715719
retErr := NewWireError(Code(code), errors.New(message))
716720

717-
detailsBinaryEncoded := trailer.Get(grpcHeaderDetails)
721+
detailsBinaryEncoded := getHeaderCanonical(trailer, grpcHeaderDetails)
718722
if len(detailsBinaryEncoded) > 0 {
719723
detailsBinary, err := DecodeBinaryHeader(detailsBinaryEncoded)
720724
if err != nil {
@@ -794,19 +798,21 @@ func grpcContentTypeFromCodecName(web bool, name string) string {
794798

795799
func grpcErrorToTrailer(bufferPool *bufferPool, trailer http.Header, protobuf Codec, err error) {
796800
if err == nil {
797-
trailer.Set(grpcHeaderStatus, "0") // zero is the gRPC OK status
798-
trailer.Set(grpcHeaderMessage, "")
801+
setHeaderCanonical(trailer, grpcHeaderStatus, "0") // zero is the gRPC OK status
802+
setHeaderCanonical(trailer, grpcHeaderMessage, "")
799803
return
800804
}
801805
status := grpcStatusFromError(err)
802806
code := strconv.Itoa(int(status.Code))
803807
bin, binErr := protobuf.Marshal(status)
804808
if binErr != nil {
805-
trailer.Set(
809+
setHeaderCanonical(
810+
trailer,
806811
grpcHeaderStatus,
807812
strconv.FormatInt(int64(CodeInternal), 10 /* base */),
808813
)
809-
trailer.Set(
814+
setHeaderCanonical(
815+
trailer,
810816
grpcHeaderMessage,
811817
grpcPercentEncode(
812818
bufferPool,
@@ -818,9 +824,9 @@ func grpcErrorToTrailer(bufferPool *bufferPool, trailer http.Header, protobuf Co
818824
if connectErr, ok := asError(err); ok {
819825
mergeHeaders(trailer, connectErr.meta)
820826
}
821-
trailer.Set(grpcHeaderStatus, code)
822-
trailer.Set(grpcHeaderMessage, grpcPercentEncode(bufferPool, status.Message))
823-
trailer.Set(grpcHeaderDetails, EncodeBinaryHeader(bin))
827+
setHeaderCanonical(trailer, grpcHeaderStatus, code)
828+
setHeaderCanonical(trailer, grpcHeaderMessage, grpcPercentEncode(bufferPool, status.Message))
829+
setHeaderCanonical(trailer, grpcHeaderDetails, EncodeBinaryHeader(bin))
824830
}
825831

826832
func grpcStatusFromError(err error) *statusv1.Status {

0 commit comments

Comments
 (0)