Skip to content

Commit 8399c68

Browse files
committed
Remove bytecopy from GetBody
req.Body is already populated, and this, even on the first call, will write the payload again. Also adds a test guarding against the bug. Closes #130 Signed-off-by: Bret Comnes <[email protected]>
1 parent 41e24cc commit 8399c68

File tree

2 files changed

+84
-3
lines changed

2 files changed

+84
-3
lines changed

client/request.go

-3
Original file line numberDiff line numberDiff line change
@@ -213,9 +213,6 @@ func (r *request) buildHTTP(mediaType, basePath string, producers map[string]run
213213
return nil, err
214214
}
215215

216-
if _, err := r.buf.Write(b.Bytes()); err != nil {
217-
return nil, err
218-
}
219216
return ioutil.NopCloser(&b), nil
220217
}
221218

client/request_test.go

+84
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,14 @@ import (
1818
"bytes"
1919
"encoding/json"
2020
"encoding/xml"
21+
"errors"
22+
"io"
2123
"io/ioutil"
2224
"mime"
2325
"mime/multipart"
26+
"net/http"
27+
"net/http/httptest"
28+
"net/url"
2429
"os"
2530
"path/filepath"
2631
"strings"
@@ -29,6 +34,7 @@ import (
2934
"github.com/go-openapi/runtime"
3035
"github.com/go-openapi/strfmt"
3136
"github.com/stretchr/testify/assert"
37+
"github.com/stretchr/testify/require"
3238
)
3339

3440
var testProducers = map[string]runtime.Producer{
@@ -509,3 +515,81 @@ func TestBuildRequest_BuildHTTP_EscapedPath(t *testing.T) {
509515
assert.Equal(t, runtime.JSONMime, req.Header.Get(runtime.HeaderContentType))
510516
}
511517
}
518+
519+
type testReqFn func(*testing.T, *http.Request)
520+
521+
type testRoundTripper struct {
522+
tr http.RoundTripper
523+
testFn testReqFn
524+
testHarness *testing.T
525+
}
526+
527+
func (t *testRoundTripper) RoundTrip(req *http.Request) (resp *http.Response, err error) {
528+
t.testFn(t.testHarness, req)
529+
return t.tr.RoundTrip(req)
530+
}
531+
532+
func TestGetBodyCallsBeforeRoundTrip(t *testing.T) {
533+
534+
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
535+
rw.WriteHeader(http.StatusCreated)
536+
_, err := rw.Write([]byte("test result"))
537+
require.NoError(t, err)
538+
}))
539+
defer server.Close()
540+
hu, _ := url.Parse(server.URL)
541+
542+
client := http.DefaultClient
543+
transport := http.DefaultTransport
544+
545+
client.Transport = &testRoundTripper{
546+
tr: transport,
547+
testHarness: t,
548+
testFn: func(t *testing.T, req *http.Request) {
549+
// Read the body once before sending the request
550+
body, err := req.GetBody()
551+
require.NoError(t, err)
552+
bodyContent, err := ioutil.ReadAll(io.Reader(body))
553+
require.EqualValues(t, req.ContentLength, len(bodyContent))
554+
require.NoError(t, err)
555+
require.EqualValues(t, "\"test body\"\n", string(bodyContent))
556+
557+
// Read the body a second time before sending the request
558+
body, err = req.GetBody()
559+
require.NoError(t, err)
560+
bodyContent, err = ioutil.ReadAll(io.Reader(body))
561+
require.NoError(t, err)
562+
require.EqualValues(t, req.ContentLength, len(bodyContent))
563+
require.EqualValues(t, "\"test body\"\n", string(bodyContent))
564+
},
565+
}
566+
567+
rwrtr := runtime.ClientRequestWriterFunc(func(req runtime.ClientRequest, _ strfmt.Registry) error {
568+
return req.SetBodyParam("test body")
569+
})
570+
571+
operation := &runtime.ClientOperation{
572+
ID: "getSites",
573+
Method: "POST",
574+
PathPattern: "/",
575+
Params: rwrtr,
576+
Client: client,
577+
Reader: runtime.ClientResponseReaderFunc(func(response runtime.ClientResponse, consumer runtime.Consumer) (interface{}, error) {
578+
if response.Code() == http.StatusCreated {
579+
var result string
580+
if err := consumer.Consume(response.Body(), &result); err != nil {
581+
return nil, err
582+
}
583+
return result, nil
584+
}
585+
return nil, errors.New("Unexpected error code")
586+
}),
587+
}
588+
589+
openAPIClient := New(hu.Host, "/", []string{"http"})
590+
res, err := openAPIClient.Submit(operation)
591+
592+
require.NoError(t, err)
593+
actual := res.(string)
594+
require.EqualValues(t, "test result", actual)
595+
}

0 commit comments

Comments
 (0)