1
1
package pubsub
2
2
3
3
import (
4
- "bufio"
5
4
"context"
5
+ "encoding/binary"
6
6
"io"
7
7
"time"
8
8
9
+ "github.com/gogo/protobuf/proto"
10
+ pool "github.com/libp2p/go-buffer-pool"
11
+ "github.com/multiformats/go-varint"
12
+
9
13
"github.com/libp2p/go-libp2p/core/network"
10
14
"github.com/libp2p/go-libp2p/core/peer"
15
+ "github.com/libp2p/go-msgio"
11
16
12
17
pb "github.com/libp2p/go-libp2p-pubsub/pb"
13
-
14
- "github.com/libp2p/go-msgio/protoio"
15
-
16
- "github.com/gogo/protobuf/proto"
17
18
)
18
19
19
20
// get the initial RPC containing all of our subscriptions to send to new peers
@@ -60,11 +61,11 @@ func (p *PubSub) handleNewStream(s network.Stream) {
60
61
p .inboundStreamsMx .Unlock ()
61
62
}()
62
63
63
- r := protoio . NewDelimitedReader (s , p .maxMessageSize )
64
+ r := msgio . NewVarintReaderSize (s , p .maxMessageSize )
64
65
for {
65
- rpc := new (RPC )
66
- err := r .ReadMsg (& rpc .RPC )
66
+ msgbytes , err := r .ReadMsg ()
67
67
if err != nil {
68
+ r .ReleaseMsg (msgbytes )
68
69
if err != io .EOF {
69
70
s .Reset ()
70
71
log .Debugf ("error reading rpc from %s: %s" , s .Conn ().RemotePeer (), err )
@@ -77,6 +78,15 @@ func (p *PubSub) handleNewStream(s network.Stream) {
77
78
return
78
79
}
79
80
81
+ rpc := new (RPC )
82
+ err = rpc .Unmarshal (msgbytes )
83
+ r .ReleaseMsg (msgbytes )
84
+ if err != nil {
85
+ s .Reset ()
86
+ log .Warnf ("bogus rpc from %s: %s" , s .Conn ().RemotePeer (), err )
87
+ return
88
+ }
89
+
80
90
rpc .from = peer
81
91
select {
82
92
case p .incoming <- rpc :
@@ -115,7 +125,7 @@ func (p *PubSub) handleNewPeer(ctx context.Context, pid peer.ID, outgoing <-chan
115
125
}
116
126
117
127
go p .handleSendingMessages (ctx , s , outgoing )
118
- go p .handlePeerEOF ( ctx , s )
128
+ go p .handlePeerDead ( s )
119
129
select {
120
130
case p .newPeerStream <- s :
121
131
case <- ctx .Done ():
@@ -131,32 +141,33 @@ func (p *PubSub) handleNewPeerWithBackoff(ctx context.Context, pid peer.ID, back
131
141
}
132
142
}
133
143
134
- func (p * PubSub ) handlePeerEOF ( ctx context. Context , s network.Stream ) {
144
+ func (p * PubSub ) handlePeerDead ( s network.Stream ) {
135
145
pid := s .Conn ().RemotePeer ()
136
- r := protoio .NewDelimitedReader (s , p .maxMessageSize )
137
- rpc := new (RPC )
138
- for {
139
- err := r .ReadMsg (& rpc .RPC )
140
- if err != nil {
141
- p .notifyPeerDead (pid )
142
- return
143
- }
144
146
147
+ _ , err := s .Read ([]byte {0 })
148
+ if err == nil {
145
149
log .Debugf ("unexpected message from %s" , pid )
146
150
}
151
+
152
+ s .Reset ()
153
+ p .notifyPeerDead (pid )
147
154
}
148
155
149
156
func (p * PubSub ) handleSendingMessages (ctx context.Context , s network.Stream , outgoing <- chan * RPC ) {
150
- bufw := bufio .NewWriter (s )
151
- wc := protoio .NewDelimitedWriter (bufw )
157
+ writeRpc := func (rpc * RPC ) error {
158
+ size := uint64 (rpc .Size ())
159
+
160
+ buf := pool .Get (varint .UvarintSize (size ) + int (size ))
161
+ defer pool .Put (buf )
152
162
153
- writeMsg := func ( msg proto. Message ) error {
154
- err := wc . WriteMsg ( msg )
163
+ n := binary . PutUvarint ( buf , size )
164
+ _ , err := rpc . MarshalTo ( buf [ n :] )
155
165
if err != nil {
156
166
return err
157
167
}
158
168
159
- return bufw .Flush ()
169
+ _ , err = s .Write (buf )
170
+ return err
160
171
}
161
172
162
173
defer s .Close ()
@@ -167,7 +178,7 @@ func (p *PubSub) handleSendingMessages(ctx context.Context, s network.Stream, ou
167
178
return
168
179
}
169
180
170
- err := writeMsg ( & rpc . RPC )
181
+ err := writeRpc ( rpc )
171
182
if err != nil {
172
183
s .Reset ()
173
184
log .Debugf ("writing message to %s: %s" , s .Conn ().RemotePeer (), err )
0 commit comments