Skip to content

transport: new stream with actual server name #5748

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Nov 18, 2022
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion internal/transport/http2_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ type http2Client struct {
cancel context.CancelFunc
ctxDone <-chan struct{} // Cache the ctx.Done() chan.
userAgent string
address resolver.Address // Record the used resolver address of client, and replace :authority to resolver address if serverName is not empty.
md metadata.MD
conn net.Conn // underlying communication channel
loopy *loopyWriter
Expand Down Expand Up @@ -314,6 +315,7 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr resolver.Address, opts
cancel: cancel,
userAgent: opts.UserAgent,
registeredCompressors: grpcutil.RegisteredCompressors(),
address: addr, // resolver address
conn: conn,
remoteAddr: conn.RemoteAddr(),
localAddr: conn.LocalAddr(),
Expand Down Expand Up @@ -454,6 +456,11 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr resolver.Address, opts
return t, nil
}

// Address return the resolver address meta info
func (t *http2Client) Address() resolver.Address {
return t.address
}

func (t *http2Client) newStream(ctx context.Context, callHdr *CallHdr) *Stream {
// TODO(zhaoq): Handle uint32 overflow of Stream.id.
s := &Stream{
Expand Down Expand Up @@ -702,7 +709,14 @@ func (e NewStreamError) Error() string {
// streams. All non-nil errors returned will be *NewStreamError.
func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (*Stream, error) {
ctx = peer.NewContext(ctx, t.getPeer())
headerFields, err := t.createHeaderFields(ctx, callHdr)

dupCallHdr := *callHdr
// replace host with the actual server name, if it is exist and unmatch
if t.address.ServerName != "" && t.address.ServerName != dupCallHdr.Host {
dupCallHdr.Host = t.address.ServerName
}

headerFields, err := t.createHeaderFields(ctx, &dupCallHdr)
if err != nil {
return nil, &NewStreamError{Err: err, AllowTransparentRetry: false}
}
Expand Down
3 changes: 3 additions & 0 deletions internal/transport/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -657,6 +657,9 @@ type ClientTransport interface {
// with a human readable string with debug info.
GetGoAwayReason() (GoAwayReason, string)

// Address return the transport used resolver address meta info
Address() resolver.Address

// RemoteAddr returns the remote network address.
RemoteAddr() net.Addr

Expand Down
67 changes: 66 additions & 1 deletion internal/transport/transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ import (
"testing"
"time"

"google.golang.org/grpc/metadata"
"google.golang.org/grpc/peer"

"github.com/google/go-cmp/cmp"
Expand Down Expand Up @@ -91,6 +92,7 @@ const (
invalidHeaderField
delayRead
pingpong
returnHeaderAuthority
)

func (h *testStreamHandler) handleStreamAndNotify(s *Stream) {
Expand Down Expand Up @@ -203,6 +205,31 @@ func (h *testStreamHandler) handleStreamInvalidHeaderField(s *Stream) {
})
}

func (h *testStreamHandler) handleStreamReturnValueOfAuthority(t *testing.T, s *Stream) {
var (
md, exist = metadata.FromIncomingContext(s.ctx)
resp string
)

if !exist || len(md.Get(":authority")) == 0 {
h.handleStreamInvalidHeaderField(s)
return
}

resp = md.Get(":authority")[0]

req := expectedRequest
p := make([]byte, len(req))
_, err := s.Read(p)
if err != nil {
return
}
// send a response back to the client.
h.t.Write(s, nil, []byte(resp), &Options{})
// send the trailer to end the stream.
h.t.WriteStatus(s, status.New(codes.OK, ""))
}

// handleStreamDelayRead delays reads so that the other side has to halt on
// stream-level flow control.
// This handler assumes dynamic flow control is turned off and assumes window
Expand Down Expand Up @@ -379,6 +406,12 @@ func (s *server) start(t *testing.T, port int, serverConfig *ServerConfig, ht hT
}, func(ctx context.Context, method string) context.Context {
return ctx
})
case returnHeaderAuthority:
go transport.HandleStreams(func(s *Stream) {
go h.handleStreamReturnValueOfAuthority(t, s)
}, func(ctx context.Context, method string) context.Context {
return ctx
})
case delayRead:
h.notify = make(chan struct{})
h.getNotified = make(chan struct{})
Expand Down Expand Up @@ -448,7 +481,7 @@ func setUp(t *testing.T, port int, maxStreams uint32, ht hType) (*server, *http2

func setUpWithOptions(t *testing.T, port int, sc *ServerConfig, ht hType, copts ConnectOptions) (*server, *http2Client, func()) {
server := setUpServerOnly(t, port, sc, ht)
addr := resolver.Address{Addr: "localhost:" + server.port}
addr := resolver.Address{Addr: "localhost:" + server.port, ServerName: server.addr()}
copts.ChannelzParentID = channelz.NewIdentifierForTesting(channelz.RefSubChannel, time.Now().Unix(), nil)

connectCtx, cancel := context.WithDeadline(context.Background(), time.Now().Add(2*time.Second))
Expand Down Expand Up @@ -1431,6 +1464,38 @@ func (s) TestEncodingRequiredStatus(t *testing.T) {
server.stop()
}

func (s) TestHeaderHostReplacedWithResolverAddress(t *testing.T) {
server, ct, cancel := setUp(t, 0, math.MaxUint32, returnHeaderAuthority)
defer cancel()
callHdr := &CallHdr{
Host: "scheme://testSrv.com/testPath",
Method: "foo",
}
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
s, err := ct.NewStream(ctx, callHdr)
if err != nil {
return
}

opts := Options{Last: true}
if err := ct.Write(s, nil, expectedRequest, &opts); err != nil && err != errStreamDone {
t.Fatalf("Failed to write the request: %v", err)
}
respOfAuthority := make([]byte, http2MaxFrameLen)
len, recvErr := s.Read(respOfAuthority)
if err, ok := status.FromError(recvErr); ok {
t.Fatalf("Read got error %v, headers are unexpected", err)
}

if string(respOfAuthority[:len]) != server.addr() {
t.Fatalf("Read got a unexpected :authority value %v, want %v", string(respOfAuthority), server.addr())
}

ct.Close(fmt.Errorf("closed manually by test"))
server.stop()
}

func (s) TestInvalidHeaderField(t *testing.T) {
server, ct, cancel := setUp(t, 0, math.MaxUint32, invalidHeaderField)
defer cancel()
Expand Down