Skip to content

Commit fcd3d71

Browse files
holdnojronak
authored andcommitted
transport: new stream with actual server name (grpc#5748)
1 parent f7ba15d commit fcd3d71

File tree

2 files changed

+56
-5
lines changed

2 files changed

+56
-5
lines changed

internal/transport/http2_client.go

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -59,11 +59,15 @@ var clientConnectionCounter uint64
5959

6060
// http2Client implements the ClientTransport interface with HTTP2.
6161
type http2Client struct {
62-
lastRead int64 // Keep this field 64-bit aligned. Accessed atomically.
63-
ctx context.Context
64-
cancel context.CancelFunc
65-
ctxDone <-chan struct{} // Cache the ctx.Done() chan.
66-
userAgent string
62+
lastRead int64 // Keep this field 64-bit aligned. Accessed atomically.
63+
ctx context.Context
64+
cancel context.CancelFunc
65+
ctxDone <-chan struct{} // Cache the ctx.Done() chan.
66+
userAgent string
67+
// address contains the resolver returned address for this transport.
68+
// If the `ServerName` field is set, it takes precedence over `CallHdr.Host`
69+
// passed to `NewStream`, when determining the :authority header.
70+
address resolver.Address
6771
md metadata.MD
6872
conn net.Conn // underlying communication channel
6973
loopy *loopyWriter
@@ -314,6 +318,7 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr resolver.Address, opts
314318
cancel: cancel,
315319
userAgent: opts.UserAgent,
316320
registeredCompressors: grpcutil.RegisteredCompressors(),
321+
address: addr,
317322
conn: conn,
318323
remoteAddr: conn.RemoteAddr(),
319324
localAddr: conn.LocalAddr(),
@@ -702,6 +707,18 @@ func (e NewStreamError) Error() string {
702707
// streams. All non-nil errors returned will be *NewStreamError.
703708
func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (*Stream, error) {
704709
ctx = peer.NewContext(ctx, t.getPeer())
710+
711+
// ServerName field of the resolver returned address takes precedence over
712+
// Host field of CallHdr to determine the :authority header. This is because,
713+
// the ServerName field takes precedence for server authentication during
714+
// TLS handshake, and the :authority header should match the value used
715+
// for server authentication.
716+
if t.address.ServerName != "" {
717+
newCallHdr := *callHdr
718+
newCallHdr.Host = t.address.ServerName
719+
callHdr = &newCallHdr
720+
}
721+
705722
headerFields, err := t.createHeaderFields(ctx, callHdr)
706723
if err != nil {
707724
return nil, &NewStreamError{Err: err, AllowTransparentRetry: false}

test/authority_test.go

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@ import (
3636
"google.golang.org/grpc/credentials/insecure"
3737
"google.golang.org/grpc/internal/stubserver"
3838
"google.golang.org/grpc/metadata"
39+
"google.golang.org/grpc/resolver"
40+
"google.golang.org/grpc/resolver/manual"
3941
"google.golang.org/grpc/status"
4042
testpb "google.golang.org/grpc/test/grpc_testing"
4143
)
@@ -205,3 +207,35 @@ func (s) TestColonPortAuthority(t *testing.T) {
205207
t.Errorf("us.client.EmptyCall(_, _) = _, %v; want _, nil", err)
206208
}
207209
}
210+
211+
// TestAuthorityReplacedWithResolverAddress tests the scenario where the resolver
212+
// returned address contains a ServerName override. The test verifies that the the
213+
// :authority header value sent to the server as part of the http/2 HEADERS frame
214+
// is set to the value specified in the resolver returned address.
215+
func (s) TestAuthorityReplacedWithResolverAddress(t *testing.T) {
216+
const expectedAuthority = "test.server.name"
217+
218+
ss := &stubserver.StubServer{
219+
EmptyCallF: func(ctx context.Context, _ *testpb.Empty) (*testpb.Empty, error) {
220+
return authorityChecker(ctx, expectedAuthority)
221+
},
222+
}
223+
if err := ss.Start(nil); err != nil {
224+
t.Fatalf("Error starting endpoint server: %v", err)
225+
}
226+
defer ss.Stop()
227+
228+
r := manual.NewBuilderWithScheme("whatever")
229+
r.InitialState(resolver.State{Addresses: []resolver.Address{{Addr: ss.Address, ServerName: expectedAuthority}}})
230+
cc, err := grpc.Dial(r.Scheme()+":///whatever", grpc.WithInsecure(), grpc.WithResolvers(r))
231+
if err != nil {
232+
t.Fatalf("grpc.Dial(%q) = %v", ss.Address, err)
233+
}
234+
defer cc.Close()
235+
236+
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
237+
defer cancel()
238+
if _, err = testpb.NewTestServiceClient(cc).EmptyCall(ctx, &testpb.Empty{}); err != nil {
239+
t.Fatalf("EmptyCall() rpc failed: %v", err)
240+
}
241+
}

0 commit comments

Comments
 (0)