diff --git a/conn.go b/conn.go index 34de352..3bf8570 100644 --- a/conn.go +++ b/conn.go @@ -4,6 +4,7 @@ package jsonrpc2 import ( + "bytes" "context" "fmt" "sync" @@ -141,7 +142,9 @@ func (c *conn) Call(ctx context.Context, method string, params, result interface return id, nil } - if err := json.Unmarshal(resp.result, result); err != nil { + dec := json.NewDecoder(bytes.NewReader(resp.result)) + dec.ZeroCopy() + if err := dec.Decode(result); err != nil { return id, fmt.Errorf("unmarshaling result: %w", err) } diff --git a/go.mod b/go.mod index 5e2a8ff..4d7ddb2 100644 --- a/go.mod +++ b/go.mod @@ -1,8 +1,8 @@ module go.lsp.dev/jsonrpc2 -go 1.15 +go 1.16 require ( - github.com/segmentio/encoding v0.2.7 + github.com/segmentio/encoding v0.2.17 go.lsp.dev/pkg v0.0.0-20210125030640-b6310ac75a91 ) diff --git a/go.sum b/go.sum index 444d703..5462feb 100644 --- a/go.sum +++ b/go.sum @@ -1,4 +1,6 @@ -github.com/segmentio/encoding v0.2.7 h1:TKxEiKbernCFCTFW5wnSlE21kIQpqcY/ABXjhc9YeJU= -github.com/segmentio/encoding v0.2.7/go.mod h1:MJjRE6bMDocliO2FyFC2Dusp+uYdBfHWh5Bw7QyExto= +github.com/klauspost/cpuid/v2 v2.0.5 h1:qnfhwbFriwDIX51QncuNU5mEMf+6KE3t7O8V2KQl3Dg= +github.com/klauspost/cpuid/v2 v2.0.5/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= +github.com/segmentio/encoding v0.2.17 h1:cgfmPc44u1po1lz5bSgF00gLCROBjDNc7h+H7I20zpc= +github.com/segmentio/encoding v0.2.17/go.mod h1:7E68jTSWMnNoYhHi1JbLd7NBSB6XfE4vzqhR88hDBQc= go.lsp.dev/pkg v0.0.0-20210125030640-b6310ac75a91 h1:JPKNt/RzBcOc89rhZ4Vl6U05Y1nN37FAc8PTKE3hssk= go.lsp.dev/pkg v0.0.0-20210125030640-b6310ac75a91/go.mod h1:gtSHRuYfbCT0qnbLnovpie/WEmqyJ7T4n6VXiFMBtcw= diff --git a/jsonrpc2_test.go b/jsonrpc2_test.go index c604c88..405ee14 100644 --- a/jsonrpc2_test.go +++ b/jsonrpc2_test.go @@ -4,6 +4,7 @@ package jsonrpc2_test import ( + "bytes" "context" "fmt" "io" @@ -141,21 +142,27 @@ func testHandler() jsonrpc2.Handler { case methodOneString: var v string - if err := json.Unmarshal(req.Params(), &v); err != nil { + dec := json.NewDecoder(bytes.NewReader(req.Params())) + dec.ZeroCopy() + if err := dec.Decode(&v); err != nil { return reply(ctx, nil, fmt.Errorf("%s: %w", jsonrpc2.ErrParse, err)) } return reply(ctx, "got:"+v, nil) case methodOneNumber: var v int - if err := json.Unmarshal(req.Params(), &v); err != nil { + dec := json.NewDecoder(bytes.NewReader(req.Params())) + dec.ZeroCopy() + if err := dec.Decode(&v); err != nil { return reply(ctx, nil, fmt.Errorf("%s: %w", jsonrpc2.ErrParse, err)) } return reply(ctx, fmt.Sprintf("got:%d", v), nil) case methodJoin: var v []string - if err := json.Unmarshal(req.Params(), &v); err != nil { + dec := json.NewDecoder(bytes.NewReader(req.Params())) + dec.ZeroCopy() + if err := dec.Decode(&v); err != nil { return reply(ctx, nil, fmt.Errorf("%s: %w", jsonrpc2.ErrParse, err)) } return reply(ctx, path.Join(v...), nil) diff --git a/message.go b/message.go index 2fdc817..ba89a09 100644 --- a/message.go +++ b/message.go @@ -4,6 +4,7 @@ package jsonrpc2 import ( + "bytes" "errors" "fmt" @@ -102,7 +103,9 @@ func (c Call) MarshalJSON() ([]byte, error) { // UnmarshalJSON implements json.Unmarshaler. func (c *Call) UnmarshalJSON(data []byte) error { var req wireRequest - if err := json.Unmarshal(data, &req); err != nil { + dec := json.NewDecoder(bytes.NewReader(data)) + dec.ZeroCopy() + if err := dec.Decode(&req); err != nil { return fmt.Errorf("unmarshaling call: %w", err) } @@ -181,7 +184,9 @@ func (r Response) MarshalJSON() ([]byte, error) { // UnmarshalJSON implements json.Unmarshaler. func (r *Response) UnmarshalJSON(data []byte) error { var resp wireResponse - if err := json.Unmarshal(data, &resp); err != nil { + dec := json.NewDecoder(bytes.NewReader(data)) + dec.ZeroCopy() + if err := dec.Decode(&resp); err != nil { return fmt.Errorf("unmarshaling jsonrpc response: %w", err) } @@ -276,7 +281,9 @@ func (n Notification) MarshalJSON() ([]byte, error) { // UnmarshalJSON implements json.Unmarshaler. func (n *Notification) UnmarshalJSON(data []byte) error { var req wireRequest - if err := json.Unmarshal(data, &req); err != nil { + dec := json.NewDecoder(bytes.NewReader(data)) + dec.ZeroCopy() + if err := dec.Decode(&req); err != nil { return fmt.Errorf("unmarshaling notification: %w", err) } @@ -291,7 +298,9 @@ func (n *Notification) UnmarshalJSON(data []byte) error { // DecodeMessage decodes data to Message. func DecodeMessage(data []byte) (Message, error) { var msg combined - if err := json.Unmarshal(data, &msg); err != nil { + dec := json.NewDecoder(bytes.NewReader(data)) + dec.ZeroCopy() + if err := dec.Decode(&msg); err != nil { return nil, fmt.Errorf("unmarshaling jsonrpc message: %w", err) } diff --git a/wire_test.go b/wire_test.go index 9052646..66c6026 100644 --- a/wire_test.go +++ b/wire_test.go @@ -85,7 +85,9 @@ func TestIDDecode(t *testing.T) { t.Parallel() var got *jsonrpc2.ID - if err := json.Unmarshal(tt.encoded, &got); err != nil { + dec := json.NewDecoder(bytes.NewReader(tt.encoded)) + dec.ZeroCopy() + if err := dec.Decode(&got); err != nil { t.Fatal(err) }