Skip to content

Commit 75db9eb

Browse files
committed
fix: use req.Body instead of IOReaderFactory when possible
Changes request RPCs to use req.Body instead of reading into an in memory byte slice via IOReaderFactory. The IOReaderFactory logic is still available for the PATCH field_mask feature. fixes grpc-ecosystem#3714
1 parent bcd8db9 commit 75db9eb

File tree

2 files changed

+22
-3
lines changed

2 files changed

+22
-3
lines changed

protoc-gen-grpc-gateway/internal/gengateway/template.go

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -333,18 +333,26 @@ var (
333333
var protoReq {{.Method.RequestType.GoType .Method.Service.File.GoPkg.Path}}
334334
var metadata runtime.ServerMetadata
335335
{{if .Body}}
336+
{{- $isFieldMask := and $AllowPatchFeature (eq (.HTTPMethod) "PATCH") (.FieldMaskField) (not (eq "*" .GetBodyFieldPath)) }}
337+
{{- if $isFieldMask }}
336338
newReader, berr := utilities.IOReaderFactory(req.Body)
337339
if berr != nil {
338340
return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", berr)
339341
}
342+
{{- end}}
340343
{{- $protoReq := .Body.AssignableExprPrep "protoReq" .Method.Service.File.GoPkg.Path -}}
341344
{{- if ne "" $protoReq }}
342345
{{printf "%s" $protoReq }}
343346
{{- end}}
347+
{{- if not $isFieldMask }}
348+
if err := marshaler.NewDecoder(req.Body).Decode(&{{.Body.AssignableExpr "protoReq" .Method.Service.File.GoPkg.Path}}); err != nil && err != io.EOF {
349+
return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err)
350+
}
351+
{{end}}
352+
{{- if $isFieldMask }}
344353
if err := marshaler.NewDecoder(newReader()).Decode(&{{.Body.AssignableExpr "protoReq" .Method.Service.File.GoPkg.Path}}); err != nil && err != io.EOF {
345354
return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err)
346355
}
347-
{{- if and $AllowPatchFeature (eq (.HTTPMethod) "PATCH") (.FieldMaskField) (not (eq "*" .GetBodyFieldPath)) }}
348356
if protoReq.{{.FieldMaskField}} == nil || len(protoReq.{{.FieldMaskField}}.GetPaths()) == 0 {
349357
if fieldMask, err := runtime.FieldMaskFromRequestBody(newReader(), protoReq.{{.GetBodyFieldStructName}}); err != nil {
350358
return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err)
@@ -502,18 +510,26 @@ func local_request_{{.Method.Service.GetName}}_{{.Method.GetName}}_{{.Index}}(ct
502510
var protoReq {{.Method.RequestType.GoType .Method.Service.File.GoPkg.Path}}
503511
var metadata runtime.ServerMetadata
504512
{{if .Body}}
513+
{{- $isFieldMask := and $AllowPatchFeature (eq (.HTTPMethod) "PATCH") (.FieldMaskField) (not (eq "*" .GetBodyFieldPath)) }}
514+
{{- if $isFieldMask }}
505515
newReader, berr := utilities.IOReaderFactory(req.Body)
506516
if berr != nil {
507517
return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", berr)
508518
}
519+
{{- end}}
509520
{{- $protoReq := .Body.AssignableExprPrep "protoReq" .Method.Service.File.GoPkg.Path -}}
510521
{{- if ne "" $protoReq }}
511522
{{printf "%s" $protoReq }}
512523
{{- end}}
524+
{{- if not $isFieldMask }}
525+
if err := marshaler.NewDecoder(req.Body).Decode(&{{.Body.AssignableExpr "protoReq" .Method.Service.File.GoPkg.Path}}); err != nil && err != io.EOF {
526+
return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err)
527+
}
528+
{{end}}
529+
{{- if $isFieldMask }}
513530
if err := marshaler.NewDecoder(newReader()).Decode(&{{.Body.AssignableExpr "protoReq" .Method.Service.File.GoPkg.Path}}); err != nil && err != io.EOF {
514531
return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err)
515532
}
516-
{{- if and $AllowPatchFeature (eq (.HTTPMethod) "PATCH") (.FieldMaskField) (not (eq "*" .GetBodyFieldPath)) }}
517533
if protoReq.{{.FieldMaskField}} == nil || len(protoReq.{{.FieldMaskField}}.GetPaths()) == 0 {
518534
if fieldMask, err := runtime.FieldMaskFromRequestBody(newReader(), protoReq.{{.GetBodyFieldStructName}}); err != nil {
519535
return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err)

protoc-gen-grpc-gateway/internal/gengateway/template_test.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -239,7 +239,7 @@ func TestApplyTemplateRequestWithoutClientStreaming(t *testing.T) {
239239
if want := spec.sigWant; !strings.Contains(got, want) {
240240
t.Errorf("applyTemplate(%#v) = %s; want to contain %s", file, got, want)
241241
}
242-
if want := `marshaler.NewDecoder(newReader()).Decode(&protoReq.GetNested().Bool)`; !strings.Contains(got, want) {
242+
if want := `marshaler.NewDecoder(req.Body).Decode(&protoReq.GetNested().Bool)`; !strings.Contains(got, want) {
243243
t.Errorf("applyTemplate(%#v) = %s; want to contain %s", file, got, want)
244244
}
245245
if want := `val, ok = pathParams["nested.int32"]`; !strings.Contains(got, want) {
@@ -659,6 +659,9 @@ func TestAllowPatchFeature(t *testing.T) {
659659
return
660660
}
661661
if allowPatchFeature {
662+
if want := `marshaler.NewDecoder(newReader()).Decode(&protoReq.Abe)`; !strings.Contains(got, want) {
663+
t.Errorf("applyTemplate(%#v) = %s; want to contain %s", file, got, want)
664+
}
662665
if !strings.Contains(got, want) {
663666
t.Errorf("applyTemplate(%#v) = %s; want to contain %s", file, got, want)
664667
}

0 commit comments

Comments
 (0)