diff --git a/p2p/http/auth/auth_test.go b/p2p/http/auth/auth_test.go index 9d1b4688d4..293f70de03 100644 --- a/p2p/http/auth/auth_test.go +++ b/p2p/http/auth/auth_test.go @@ -137,6 +137,7 @@ func TestMutualAuth(t *testing.T) { req.Host = "example.com" serverID, resp, err = clientAuth.AuthenticatedDo(client, req) require.NotEmpty(t, req.Header.Get("Authorization")) + require.True(t, HasAuthHeader(req)) require.NoError(t, err) require.Equal(t, expectedServerID, serverID) require.NotZero(t, clientAuth.tm.tokenMap["example.com"]) diff --git a/p2p/http/auth/client.go b/p2p/http/auth/client.go index a6bdece61a..3bef9d6d8b 100644 --- a/p2p/http/auth/client.go +++ b/p2p/http/auth/client.go @@ -20,6 +20,14 @@ type ClientPeerIDAuth struct { tm tokenMap } +type clientAsRoundTripper struct { + *http.Client +} + +func (c clientAsRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + return c.Client.Do(req) +} + // AuthenticatedDo is like http.Client.Do, but it does the libp2p peer ID auth // handshake if needed. // @@ -27,6 +35,10 @@ type ClientPeerIDAuth struct { // method can retry sending the request in case a previously used token has // expired. func (a *ClientPeerIDAuth) AuthenticatedDo(client *http.Client, req *http.Request) (peer.ID, *http.Response, error) { + return a.AuthenticateWithRoundTripper(clientAsRoundTripper{client}, req) +} + +func (a *ClientPeerIDAuth) AuthenticateWithRoundTripper(rt http.RoundTripper, req *http.Request) (peer.ID, *http.Response, error) { hostname := req.Host ti, hasToken := a.tm.get(hostname, a.TokenTTL) handshake := handshake.PeerIDAuthHandshakeClient{ @@ -36,7 +48,7 @@ func (a *ClientPeerIDAuth) AuthenticatedDo(client *http.Client, req *http.Reques if hasToken { // We have a token. Attempt to use that, but fallback to server initiated challenge if it fails. - peer, resp, err := a.doWithToken(client, req, ti) + peer, resp, err := a.doWithToken(rt, req, ti) switch { case err == nil: return peer, resp, nil @@ -62,7 +74,7 @@ func (a *ClientPeerIDAuth) AuthenticatedDo(client *http.Client, req *http.Reques handshake.SetInitiateChallenge() } - serverPeerID, resp, err := a.runHandshake(client, req, clearBody(req), &handshake) + serverPeerID, resp, err := a.runHandshake(rt, req, clearBody(req), &handshake) if err != nil { return "", nil, fmt.Errorf("failed to run handshake: %w", err) } @@ -74,7 +86,12 @@ func (a *ClientPeerIDAuth) AuthenticatedDo(client *http.Client, req *http.Reques return serverPeerID, resp, nil } -func (a *ClientPeerIDAuth) runHandshake(client *http.Client, req *http.Request, b bodyMeta, hs *handshake.PeerIDAuthHandshakeClient) (peer.ID, *http.Response, error) { +func (a *ClientPeerIDAuth) HasToken(hostname string) bool { + _, hasToken := a.tm.get(hostname, a.TokenTTL) + return hasToken +} + +func (a *ClientPeerIDAuth) runHandshake(rt http.RoundTripper, req *http.Request, b bodyMeta, hs *handshake.PeerIDAuthHandshakeClient) (peer.ID, *http.Response, error) { maxSteps := 5 // Avoid infinite loops in case of buggy handshake. Shouldn't happen. var resp *http.Response @@ -92,7 +109,7 @@ func (a *ClientPeerIDAuth) runHandshake(client *http.Client, req *http.Request, b.setBody(req) } - resp, err = client.Do(req) + resp, err = rt.RoundTrip(req) if err != nil { return "", nil, err } @@ -119,10 +136,10 @@ func (a *ClientPeerIDAuth) runHandshake(client *http.Client, req *http.Request, var errTokenRejected = errors.New("token rejected") -func (a *ClientPeerIDAuth) doWithToken(client *http.Client, req *http.Request, ti tokenInfo) (peer.ID, *http.Response, error) { +func (a *ClientPeerIDAuth) doWithToken(rt http.RoundTripper, req *http.Request, ti tokenInfo) (peer.ID, *http.Response, error) { // Try to make the request with the token req.Header.Set("Authorization", ti.token) - resp, err := client.Do(req) + resp, err := rt.RoundTrip(req) if err != nil { return "", nil, err } diff --git a/p2p/http/auth/server.go b/p2p/http/auth/server.go index b17c3fccf1..f9d5076b0c 100644 --- a/p2p/http/auth/server.go +++ b/p2p/http/auth/server.go @@ -7,6 +7,7 @@ import ( "errors" "hash" "net/http" + "strings" "sync" "time" @@ -60,6 +61,10 @@ type ServerPeerIDAuth struct { // scheme. If a Next handler is set, it will be called on authenticated // requests. func (a *ServerPeerIDAuth) ServeHTTP(w http.ResponseWriter, r *http.Request) { + a.ServeHTTPWithNextHandler(w, r, a.Next) +} + +func (a *ServerPeerIDAuth) ServeHTTPWithNextHandler(w http.ResponseWriter, r *http.Request, next func(peer.ID, http.ResponseWriter, *http.Request)) { a.initHmac.Do(func() { if a.HmacKey == nil { key := make([]byte, 32) @@ -130,7 +135,7 @@ func (a *ServerPeerIDAuth) ServeHTTP(w http.ResponseWriter, r *http.Request) { TokenTTL: a.TokenTTL, Hmac: hmac, } - hs.Run() + _ = hs.Run() // First run will never err hs.SetHeader(w.Header()) w.WriteHeader(http.StatusUnauthorized) @@ -149,9 +154,16 @@ func (a *ServerPeerIDAuth) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } - if a.Next == nil { + if next == nil { w.WriteHeader(http.StatusOK) return } - a.Next(peer, w, r) + next(peer, w, r) +} + +// HasAuthHeader checks if the HTTP request contains an Authorization header +// that starts with the PeerIDAuthScheme prefix. +func HasAuthHeader(r *http.Request) bool { + h := r.Header.Get("Authorization") + return h != "" && strings.HasPrefix(h, handshake.PeerIDAuthScheme) } diff --git a/p2p/http/example_test.go b/p2p/http/example_test.go index 8e94f6e7e0..3b83341003 100644 --- a/p2p/http/example_test.go +++ b/p2p/http/example_test.go @@ -8,13 +8,86 @@ import ( "net/http" "regexp" "strings" + "time" "github.com/libp2p/go-libp2p" + "github.com/libp2p/go-libp2p/core/crypto" "github.com/libp2p/go-libp2p/core/peer" libp2phttp "github.com/libp2p/go-libp2p/p2p/http" + httpauth "github.com/libp2p/go-libp2p/p2p/http/auth" ma "github.com/multiformats/go-multiaddr" ) +func ExampleHost_authenticatedHTTP() { + clientKey, _, err := crypto.GenerateKeyPair(crypto.Ed25519, 0) + if err != nil { + log.Fatal(err) + } + client := libp2phttp.Host{ + ClientPeerIDAuth: &httpauth.ClientPeerIDAuth{ + TokenTTL: time.Hour, + PrivKey: clientKey, + }, + } + + serverKey, _, err := crypto.GenerateKeyPair(crypto.Ed25519, 0) + if err != nil { + log.Fatal(err) + } + server := libp2phttp.Host{ + ServerPeerIDAuth: &httpauth.ServerPeerIDAuth{ + PrivKey: serverKey, + // No TLS for this example. In practice you want to use TLS. + NoTLS: true, + ValidHostnameFn: func(hostname string) bool { + return strings.HasPrefix(hostname, "127.0.0.1") + }, + TokenTTL: time.Hour, + }, + // No TLS for this example. In practice you want to use TLS. + InsecureAllowHTTP: true, + ListenAddrs: []ma.Multiaddr{ma.StringCast("/ip4/127.0.0.1/tcp/0/http")}, + } + + observedClientID := "" + server.SetHTTPHandler("/echo-id", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + observedClientID = libp2phttp.ClientPeerID(r).String() + w.WriteHeader(http.StatusOK) + })) + + go server.Serve() + defer server.Close() + + expectedServerID, err := peer.IDFromPrivateKey(serverKey) + if err != nil { + log.Fatal(err) + } + + httpClient := http.Client{Transport: &client} + url := fmt.Sprintf("multiaddr:%s/p2p/%s/http-path/echo-id", server.Addrs()[0], expectedServerID) + resp, err := httpClient.Get(url) + if err != nil { + log.Fatal(err) + } + resp.Body.Close() + + expectedClientID, err := peer.IDFromPrivateKey(clientKey) + if err != nil { + log.Fatal(err) + } + if observedClientID != expectedClientID.String() { + log.Fatal("observedClientID does not match expectedClientID") + } + + observedServerID := libp2phttp.ServerPeerID(resp) + if observedServerID != expectedServerID { + log.Fatal("observedServerID does not match expectedServerID") + } + + fmt.Println("Successfully authenticated HTTP request") + // Output: Successfully authenticated HTTP request +} + func ExampleHost_withAStockGoHTTPClient() { server := libp2phttp.Host{ InsecureAllowHTTP: true, // For our example, we'll allow insecure HTTP diff --git a/p2p/http/libp2phttp.go b/p2p/http/libp2phttp.go index f61eda4d63..e6558a8ab4 100644 --- a/p2p/http/libp2phttp.go +++ b/p2p/http/libp2phttp.go @@ -25,6 +25,7 @@ import ( "github.com/libp2p/go-libp2p/core/peer" "github.com/libp2p/go-libp2p/core/peerstore" "github.com/libp2p/go-libp2p/core/protocol" + httpauth "github.com/libp2p/go-libp2p/p2p/http/auth" gostream "github.com/libp2p/go-libp2p/p2p/net/gostream" ma "github.com/multiformats/go-multiaddr" ) @@ -44,6 +45,23 @@ const LegacyWellKnownProtocols = "/.well-known/libp2p" const peerMetadataLimit = 8 << 10 // 8KB const peerMetadataLRUSize = 256 // How many different peer's metadata to keep in our LRU cache +type clientPeerIDContextKey struct{} +type serverPeerIDContextKey struct{} + +func ClientPeerID(r *http.Request) peer.ID { + if id, ok := r.Context().Value(clientPeerIDContextKey{}).(peer.ID); ok { + return id + } + return "" +} + +func ServerPeerID(r *http.Response) peer.ID { + if id, ok := r.Request.Context().Value(serverPeerIDContextKey{}).(peer.ID); ok { + return id + } + return "" +} + // ProtocolMeta is metadata about a protocol. type ProtocolMeta struct { // Path defines the HTTP Path prefix used for this protocol @@ -134,6 +152,14 @@ type Host struct { // InsecureAllowHTTP indicates if the server is allowed to serve unencrypted // HTTP requests over TCP. InsecureAllowHTTP bool + + // ServerPeerIDAuth sets the Server's signing key and TTL for server + // provided tokens. + ServerPeerIDAuth *httpauth.ServerPeerIDAuth + // ClientPeerIDAuth sets the Client's signing key and TTL for our stored + // tokens. + ClientPeerIDAuth *httpauth.ClientPeerIDAuth + // ServeMux is the http.ServeMux used by the server to serve requests. If // nil, a new serve mux will be created. Users may manually add handlers to // this mux instead of using `SetHTTPHandler`, but if they do, they should @@ -264,7 +290,7 @@ func (h *Host) setupListeners(listenerErrCh chan error) error { if parsedAddr.useHTTPS { go func() { srv := http.Server{ - Handler: h.ServeMux, + Handler: maybeDecorateContextWithAuthMiddleware(h.ServerPeerIDAuth, h.ServeMux), TLSConfig: h.TLSConfig, } listenerErrCh <- srv.ServeTLS(l, "", "") @@ -272,7 +298,10 @@ func (h *Host) setupListeners(listenerErrCh chan error) error { h.httpTransport.listenAddrs = append(h.httpTransport.listenAddrs, listenAddr) } else if h.InsecureAllowHTTP { go func() { - listenerErrCh <- http.Serve(l, h.ServeMux) + srv := http.Server{ + Handler: maybeDecorateContextWithAuthMiddleware(h.ServerPeerIDAuth, h.ServeMux), + } + listenerErrCh <- srv.Serve(l) }() h.httpTransport.listenAddrs = append(h.httpTransport.listenAddrs, listenAddr) } else { @@ -332,7 +361,20 @@ func (h *Host) Serve() error { h.httpTransport.listenAddrs = append(h.httpTransport.listenAddrs, h.StreamHost.Addrs()...) go func() { - errCh <- http.Serve(listener, connectionCloseHeaderMiddleware(h.ServeMux)) + srv := &http.Server{ + Handler: connectionCloseHeaderMiddleware(h.ServeMux), + ConnContext: func(ctx context.Context, c net.Conn) context.Context { + remote := c.RemoteAddr() + if remote.Network() == gostream.Network { + remoteID, err := peer.Decode(remote.String()) + if err == nil { + return context.WithValue(ctx, clientPeerIDContextKey{}, remoteID) + } + } + return ctx + }, + } + errCh <- srv.Serve(listener) }() } @@ -497,6 +539,8 @@ func (rt *streamRoundTripper) RoundTrip(r *http.Request) (*http.Response, error) } } + ctxWithServerID := context.WithValue(r.Context(), serverPeerIDContextKey{}, rt.server) + resp.Request = resp.Request.WithContext(ctxWithServerID) return resp, nil } @@ -529,10 +573,14 @@ func relativeMultiaddrURIToAbs(original *url.URL, relative *url.URL) (*url.URL, return nil, errors.New("relative path is not a valid http-path") } - withoutPath, _ := ma.SplitFunc(originalMa, func(c ma.Component) bool { + withoutPath, afterAndIncludingPath := ma.SplitFunc(originalMa, func(c ma.Component) bool { return c.Protocol().Code == ma.P_HTTP_PATH }) withNewPath := withoutPath.AppendComponent(relativePathComponent) + if len(afterAndIncludingPath) > 1 { + // Include after path since it may include other parts + withNewPath = append(withNewPath, afterAndIncludingPath[1:]...) + } return url.Parse("multiaddr:" + withNewPath.String()) } @@ -701,6 +749,18 @@ func (h *Host) RoundTrip(r *http.Request) (*http.Response, error) { switch r.URL.Scheme { case "http", "https": h.initDefaultRT() + if r.Host == "" { + r.Host = r.URL.Host + } + if h.ClientPeerIDAuth != nil && h.ClientPeerIDAuth.HasToken(r.Host) { + serverID, resp, err := h.ClientPeerIDAuth.AuthenticateWithRoundTripper(h.DefaultClientRoundTripper, r) + if err != nil { + return nil, err + } + ctxWithServerID := context.WithValue(r.Context(), serverPeerIDContextKey{}, serverID) + resp.Request = resp.Request.WithContext(ctxWithServerID) + return resp, nil + } return h.DefaultClientRoundTripper.RoundTrip(r) case "multiaddr": break @@ -732,7 +792,12 @@ func (h *Host) RoundTrip(r *http.Request) (*http.Response, error) { h.initDefaultRT() rt := h.DefaultClientRoundTripper - if parsed.sni != parsed.host { + sni := parsed.sni + if sni == "" { + sni = parsed.host + } + + if sni != parsed.host { // We have a different host and SNI (e.g. using an IP address but specifying a SNI) // We need to make our own transport to support this. // @@ -743,6 +808,33 @@ func (h *Host) RoundTrip(r *http.Request) (*http.Response, error) { rt.TLSClientConfig.ServerName = parsed.sni } + if parsed.peer != "" { + // The peer ID is present. We are making an authenticated request + if h.ClientPeerIDAuth == nil { + return nil, fmt.Errorf("can not authenticate server. Host.ClientPeerIDAuth field is not set") + } + + if r.Host == "" { + // Missing a host header. Default to what we parsed earlier + r.Host = u.Host + } + + serverID, resp, err := h.ClientPeerIDAuth.AuthenticateWithRoundTripper(rt, r) + if err != nil { + return nil, err + } + + if serverID != parsed.peer { + resp.Body.Close() + return nil, fmt.Errorf("authenticated server ID does not match expected server ID") + } + + ctxWithServerID := context.WithValue(r.Context(), serverPeerIDContextKey{}, serverID) + resp.Request = resp.Request.WithContext(ctxWithServerID) + + return resp, nil + } + return rt.RoundTrip(r) } @@ -1088,3 +1180,22 @@ func connectionCloseHeaderMiddleware(next http.Handler) http.Handler { next.ServeHTTP(w, r) }) } + +// maybeDecorateContextWithAuth decorates the request context with +// authentication information if serverAuth is provided. +func maybeDecorateContextWithAuthMiddleware(serverAuth *httpauth.ServerPeerIDAuth, next http.Handler) http.Handler { + if next == nil { + return nil + } + if serverAuth == nil { + return next + } + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if httpauth.HasAuthHeader(r) { + serverAuth.ServeHTTPWithNextHandler(w, r, func(p peer.ID, w http.ResponseWriter, r *http.Request) { + r = r.WithContext(context.WithValue(r.Context(), clientPeerIDContextKey{}, p)) + next.ServeHTTP(w, r) + }) + } + }) +} diff --git a/p2p/http/libp2phttp_test.go b/p2p/http/libp2phttp_test.go index ae3285b9a3..ed69bf130b 100644 --- a/p2p/http/libp2phttp_test.go +++ b/p2p/http/libp2phttp_test.go @@ -24,9 +24,11 @@ import ( "time" "github.com/libp2p/go-libp2p" + "github.com/libp2p/go-libp2p/core/crypto" host "github.com/libp2p/go-libp2p/core/host" "github.com/libp2p/go-libp2p/core/peer" libp2phttp "github.com/libp2p/go-libp2p/p2p/http" + httpauth "github.com/libp2p/go-libp2p/p2p/http/auth" httpping "github.com/libp2p/go-libp2p/p2p/http/ping" libp2pquic "github.com/libp2p/go-libp2p/p2p/transport/quic" ma "github.com/multiformats/go-multiaddr" @@ -1014,3 +1016,117 @@ func TestErrServerClosed(t *testing.T) { server.Close() <-done } + +func TestHTTPOverStreamsGetClientID(t *testing.T) { + serverHost, err := libp2p.New( + libp2p.ListenAddrStrings("/ip4/127.0.0.1/udp/0/quic-v1"), + ) + require.NoError(t, err) + + httpHost := libp2phttp.Host{StreamHost: serverHost} + + httpHost.SetHTTPHandler("/echo-id", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + clientID := libp2phttp.ClientPeerID(r) + w.Write([]byte(clientID.String())) + })) + + // Start server + go httpHost.Serve() + defer httpHost.Close() + + // Start client + clientHost, err := libp2p.New(libp2p.NoListenAddrs) + require.NoError(t, err) + clientHost.Connect(context.Background(), peer.AddrInfo{ + ID: serverHost.ID(), + Addrs: serverHost.Addrs(), + }) + + client := http.Client{ + Transport: &libp2phttp.Host{StreamHost: clientHost}, + } + require.NoError(t, err) + + resp, err := client.Get("multiaddr:" + serverHost.Addrs()[0].String() + "/p2p/" + serverHost.ID().String() + "/http-path/echo-id") + require.NoError(t, err) + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + require.Equal(t, clientHost.ID().String(), string(body)) +} + +func TestAuthenticatedRequest(t *testing.T) { + serverSK, _, err := crypto.GenerateEd25519Key(rand.Reader) + require.NoError(t, err) + serverID, err := peer.IDFromPrivateKey(serverSK) + require.NoError(t, err) + + serverStreamHost, err := libp2p.New( + libp2p.Identity(serverSK), + libp2p.ListenAddrStrings("/ip4/127.0.0.1/udp/0/quic-v1"), + libp2p.Transport(libp2pquic.NewTransport), + ) + require.NoError(t, err) + + server := libp2phttp.Host{ + InsecureAllowHTTP: true, + StreamHost: serverStreamHost, + ListenAddrs: []ma.Multiaddr{ma.StringCast("/ip4/127.0.0.1/tcp/0/http")}, + ServerPeerIDAuth: &httpauth.ServerPeerIDAuth{ + TokenTTL: time.Hour, + PrivKey: serverSK, + NoTLS: true, + ValidHostnameFn: func(hostname string) bool { + return strings.HasPrefix(hostname, "127.0.0.1") + }, + }, + } + server.SetHTTPHandler("/echo-id", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + clientID := libp2phttp.ClientPeerID(r) + w.Write([]byte(clientID.String())) + })) + + go server.Serve() + + clientSK, _, err := crypto.GenerateEd25519Key(rand.Reader) + require.NoError(t, err) + + clientStreamHost, err := libp2p.New( + libp2p.Identity(clientSK), + libp2p.NoListenAddrs, + libp2p.Transport(libp2pquic.NewTransport)) + require.NoError(t, err) + + client := &http.Client{ + Transport: &libp2phttp.Host{ + StreamHost: clientStreamHost, + ClientPeerIDAuth: &httpauth.ClientPeerIDAuth{ + TokenTTL: time.Hour, + PrivKey: clientSK, + }, + }, + } + + clientID, err := peer.IDFromPrivateKey(clientSK) + require.NoError(t, err) + + for _, serverAddr := range server.Addrs() { + _, tpt := ma.SplitLast(serverAddr) + t.Run(tpt.String(), func(t *testing.T) { + url := fmt.Sprintf("multiaddr:%s/p2p/%s/http-path/echo-id", serverAddr, serverID) + t.Log("Making a GET request to:", url) + resp, err := client.Get(url) + require.NoError(t, err) + + observedServerID := libp2phttp.ServerPeerID(resp) + require.Equal(t, serverID, observedServerID) + + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + require.Equal(t, clientID.String(), string(body)) + }) + } +}