Skip to content

Commit 642a013

Browse files
authored
Merge pull request #254 from nhooyr/netconn-readlimit
netconn.go: Disable read limit on WebSocket
2 parents 085d46c + 11af7f8 commit 642a013

File tree

3 files changed

+52
-4
lines changed

3 files changed

+52
-4
lines changed

Diff for: conn_test.go

+31
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,37 @@ func TestConn(t *testing.T) {
208208
}
209209
})
210210

211+
t.Run("netConn/readLimit", func(t *testing.T) {
212+
tt, c1, c2 := newConnTest(t, nil, nil)
213+
214+
n1 := websocket.NetConn(tt.ctx, c1, websocket.MessageBinary)
215+
n2 := websocket.NetConn(tt.ctx, c2, websocket.MessageBinary)
216+
217+
s := strings.Repeat("papa", 1 << 20)
218+
errs := xsync.Go(func() error {
219+
_, err := n2.Write([]byte(s))
220+
if err != nil {
221+
return err
222+
}
223+
return n2.Close()
224+
})
225+
226+
b, err := ioutil.ReadAll(n1)
227+
assert.Success(t, err)
228+
229+
_, err = n1.Read(nil)
230+
assert.Equal(t, "read error", err, io.EOF)
231+
232+
select {
233+
case err := <-errs:
234+
assert.Success(t, err)
235+
case <-tt.ctx.Done():
236+
t.Fatal(tt.ctx.Err())
237+
}
238+
239+
assert.Equal(t, "read msg", s, string(b))
240+
})
241+
211242
t.Run("wsjson", func(t *testing.T) {
212243
tt, c1, c2 := newConnTest(t, nil, nil)
213244

Diff for: netconn.go

+4
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,11 @@ import (
3838
//
3939
// A received StatusNormalClosure or StatusGoingAway close frame will be translated to
4040
// io.EOF when reading.
41+
//
42+
// Furthermore, the ReadLimit is set to -1 to disable it.
4143
func NetConn(ctx context.Context, c *Conn, msgType MessageType) net.Conn {
44+
c.SetReadLimit(-1)
45+
4246
nc := &netConn{
4347
c: c,
4448
msgType: msgType,

Diff for: read.go

+17-4
Original file line numberDiff line numberDiff line change
@@ -74,10 +74,16 @@ func (c *Conn) CloseRead(ctx context.Context) context.Context {
7474
// By default, the connection has a message read limit of 32768 bytes.
7575
//
7676
// When the limit is hit, the connection will be closed with StatusMessageTooBig.
77+
//
78+
// Set to -1 to disable.
7779
func (c *Conn) SetReadLimit(n int64) {
78-
// We add read one more byte than the limit in case
79-
// there is a fin frame that needs to be read.
80-
c.msgReader.limitReader.limit.Store(n + 1)
80+
if n >= 0 {
81+
// We read one more byte than the limit in case
82+
// there is a fin frame that needs to be read.
83+
n++
84+
}
85+
86+
c.msgReader.limitReader.limit.Store(n)
8187
}
8288

8389
const defaultReadLimit = 32768
@@ -455,7 +461,11 @@ func (lr *limitReader) reset(r io.Reader) {
455461
}
456462

457463
func (lr *limitReader) Read(p []byte) (int, error) {
458-
if lr.n <= 0 {
464+
if lr.n < 0 {
465+
return lr.r.Read(p)
466+
}
467+
468+
if lr.n == 0 {
459469
err := fmt.Errorf("read limited at %v bytes", lr.limit.Load())
460470
lr.c.writeError(StatusMessageTooBig, err)
461471
return 0, err
@@ -466,6 +476,9 @@ func (lr *limitReader) Read(p []byte) (int, error) {
466476
}
467477
n, err := lr.r.Read(p)
468478
lr.n -= int64(n)
479+
if lr.n < 0 {
480+
lr.n = 0
481+
}
469482
return n, err
470483
}
471484

0 commit comments

Comments
 (0)