Skip to content

Commit 95bc43c

Browse files
committed
Add support for retrieving multiple results.
Reworked how the msg-ready command (`Z`) is processed. Previously execution of a query would look for the msg-ready command before completing the operation. Now, when executing a query, the driver places the connection into a state where it knows there may be more results. If another query is subsequently executed, the driver waits for the msg-ready command to arrive, discarding any other commands, before sending the new query. But if the special `NEXT` query is executed, the driver looks for another set of results on the connection. If no such results are found, ErrNoMoreResults is returned.
1 parent 3e033bd commit 95bc43c

File tree

2 files changed

+161
-33
lines changed

2 files changed

+161
-33
lines changed

conn.go

Lines changed: 96 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,11 @@ var (
3232
ErrSSLNotSupported = errors.New("pq: SSL is not enabled on the server")
3333
ErrSSLKeyHasWorldPermissions = errors.New("pq: Private key file has group or world access. Permissions should be u=rw (0600) or less.")
3434
ErrCouldNotDetectUsername = errors.New("pq: Could not detect default username. Please provide one explicitly.")
35+
ErrNoMoreResults = errors.New("pq: no more results")
3536
)
3637

38+
const NextResults = "NEXT"
39+
3740
type drv struct{}
3841

3942
func (d *drv) Open(name string) (driver.Conn, error) {
@@ -115,6 +118,9 @@ type conn struct {
115118
// Whether to always send []byte parameters over as binary. Enables single
116119
// round-trip mode for non-prepared Query calls.
117120
binaryParameters bool
121+
122+
// Whether the connection is ready to execute a query.
123+
readyForQuery bool
118124
}
119125

120126
// Handle driver-side settings in parsed connection string.
@@ -587,6 +593,8 @@ func (cn *conn) gname() string {
587593
}
588594

589595
func (cn *conn) simpleExec(q string) (res driver.Result, commandTag string, err error) {
596+
cn.waitReadyForQuery()
597+
590598
b := cn.writeBuf('Q')
591599
b.string(q)
592600
cn.send(b)
@@ -614,51 +622,73 @@ func (cn *conn) simpleExec(q string) (res driver.Result, commandTag string, err
614622
func (cn *conn) simpleQuery(q string) (res *rows, err error) {
615623
defer cn.errRecover(&err)
616624

617-
b := cn.writeBuf('Q')
618-
b.string(q)
619-
cn.send(b)
625+
querySent := false
626+
nextResult := q == NextResults
620627

621628
for {
629+
if cn.readyForQuery && !querySent {
630+
if nextResult {
631+
return nil, ErrNoMoreResults
632+
}
633+
634+
// Mark the connection has having sent a query.
635+
cn.readyForQuery = false
636+
b := cn.writeBuf('Q')
637+
b.string(q)
638+
cn.send(b)
639+
querySent = true
640+
}
641+
622642
t, r := cn.recv1()
623643
switch t {
624644
case 'C', 'I':
625-
// We allow queries which don't return any results through Query as
626-
// well as Exec. We still have to give database/sql a rows object
627-
// the user can close, though, to avoid connections from being
628-
// leaked. A "rows" with done=true works fine for that purpose.
629-
if err != nil {
630-
cn.bad = true
631-
errorf("unexpected message %q in simple query execution", t)
632-
}
633-
if res == nil {
634-
res = &rows{
635-
cn: cn,
645+
if nextResult || querySent {
646+
// We allow queries which don't return any results through Query as
647+
// well as Exec. We still have to give database/sql a rows object
648+
// the user can close, though, to avoid connections from being
649+
// leaked. A "rows" with done=true works fine for that purpose.
650+
if err != nil {
651+
cn.bad = true
652+
errorf("unexpected message %q in simple query execution", t)
653+
}
654+
if res == nil {
655+
res = &rows{
656+
cn: cn,
657+
}
636658
}
659+
res.done = true
637660
}
638-
res.done = true
639661
case 'Z':
640662
cn.processReadyForQuery(r)
641-
// done
642-
return
663+
if querySent {
664+
// done
665+
return
666+
}
643667
case 'E':
644-
res = nil
645-
err = parseError(r)
668+
if nextResult || querySent {
669+
res = nil
670+
err = parseError(r)
671+
}
646672
case 'D':
647-
if res == nil {
648-
cn.bad = true
649-
errorf("unexpected DataRow in simple query execution")
673+
if nextResult || querySent {
674+
if res == nil {
675+
cn.bad = true
676+
errorf("unexpected DataRow in simple query execution")
677+
}
678+
// the query didn't fail; kick off to Next
679+
cn.saveMessage(t, r)
680+
return
650681
}
651-
// the query didn't fail; kick off to Next
652-
cn.saveMessage(t, r)
653-
return
654682
case 'T':
655-
// res might be non-nil here if we received a previous
656-
// CommandComplete, but that's fine; just overwrite it
657-
res = &rows{cn: cn}
658-
res.colNames, res.colFmts, res.colTyps = parsePortalRowDescribe(r)
659-
660-
// To work around a bug in QueryRow in Go 1.2 and earlier, wait
661-
// until the first DataRow has been received.
683+
if nextResult || querySent {
684+
// res might be non-nil here if we received a previous
685+
// CommandComplete, but that's fine; just overwrite it
686+
res = &rows{cn: cn}
687+
res.colNames, res.colFmts, res.colTyps = parsePortalRowDescribe(r)
688+
689+
// To work around a bug in QueryRow in Go 1.2 and earlier, wait
690+
// until the first DataRow has been received.
691+
}
662692
default:
663693
cn.bad = true
664694
errorf("unknown response for simple query: %q", t)
@@ -742,6 +772,8 @@ func (cn *conn) Prepare(q string) (_ driver.Stmt, err error) {
742772
}
743773
defer cn.errRecover(&err)
744774

775+
cn.waitReadyForQuery()
776+
745777
if len(q) >= 4 && strings.EqualFold(q[:4], "COPY") {
746778
return cn.prepareCopyIn(q)
747779
}
@@ -777,6 +809,8 @@ func (cn *conn) Query(query string, args []driver.Value) (_ driver.Rows, err err
777809
return cn.simpleQuery(query)
778810
}
779811

812+
cn.waitReadyForQuery()
813+
780814
if cn.binaryParameters {
781815
cn.sendBinaryModeQuery(query, args)
782816

@@ -813,6 +847,8 @@ func (cn *conn) Exec(query string, args []driver.Value) (res driver.Result, err
813847
return r, err
814848
}
815849

850+
cn.waitReadyForQuery()
851+
816852
if cn.binaryParameters {
817853
cn.sendBinaryModeQuery(query, args)
818854

@@ -1301,6 +1337,10 @@ func (st *stmt) exec(v []driver.Value) {
13011337
}
13021338

13031339
cn := st.cn
1340+
cn.waitReadyForQuery()
1341+
// Mark the connection has having sent a query.
1342+
cn.readyForQuery = false
1343+
13041344
w := cn.writeBuf('B')
13051345
w.byte(0) // unnamed portal
13061346
w.string(st.name)
@@ -1431,7 +1471,11 @@ func (rs *rows) Next(dest []driver.Value) (err error) {
14311471
case 'E':
14321472
err = parseError(&rs.rb)
14331473
case 'C', 'I':
1434-
continue
1474+
rs.done = true
1475+
if err != nil {
1476+
return err
1477+
}
1478+
return io.EOF
14351479
case 'Z':
14361480
conn.processReadyForQuery(&rs.rb)
14371481
rs.done = true
@@ -1527,6 +1571,9 @@ func (cn *conn) sendBinaryModeQuery(query string, args []driver.Value) {
15271571
errorf("got %d parameters but PostgreSQL only supports 65535 parameters", len(args))
15281572
}
15291573

1574+
// Mark the connection has having sent a query.
1575+
cn.readyForQuery = false
1576+
15301577
b := cn.writeBuf('P')
15311578
b.byte(0) // unnamed statement
15321579
b.string(query)
@@ -1576,6 +1623,7 @@ func (c *conn) processParameterStatus(r *readBuf) {
15761623

15771624
func (c *conn) processReadyForQuery(r *readBuf) {
15781625
c.txnStatus = transactionStatus(r.byte())
1626+
c.readyForQuery = true
15791627
}
15801628

15811629
func (cn *conn) readReadyForQuery() {
@@ -1590,6 +1638,21 @@ func (cn *conn) readReadyForQuery() {
15901638
}
15911639
}
15921640

1641+
func (cn *conn) waitReadyForQuery() {
1642+
// The postgres server sends a 'Z' command when it is ready to receive a
1643+
// query. We use this as a sync marker to skip over commands we're not
1644+
// handling in our current state. For example, we might be skipping over
1645+
// subsequent results when a query contained multiple statements and only the
1646+
// first result was retrieved.
1647+
for !cn.readyForQuery {
1648+
t, r := cn.recv1()
1649+
switch t {
1650+
case 'Z':
1651+
cn.processReadyForQuery(r)
1652+
}
1653+
}
1654+
}
1655+
15931656
func (cn *conn) readParseResponse() {
15941657
t, r := cn.recv1()
15951658
switch t {

conn_test.go

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,7 @@ func TestOpenURL(t *testing.T) {
136136
}
137137

138138
const pgpass_file = "/tmp/pqgotest_pgpass"
139+
139140
func TestPgpass(t *testing.T) {
140141
testAssert := func(conninfo string, expected string, reason string) {
141142
conn, err := openTestConnConninfo(conninfo)
@@ -339,6 +340,70 @@ func TestRowsCloseBeforeDone(t *testing.T) {
339340
}
340341
}
341342

343+
func TestMultipleResults(t *testing.T) {
344+
db := openTestConn(t)
345+
defer db.Close()
346+
347+
var val int
348+
if err := db.QueryRow("SELECT 1; SELECT 2").Scan(&val); err != nil {
349+
t.Fatal(err)
350+
}
351+
if val != 1 {
352+
t.Fatalf("expected 1, but found %d", val)
353+
}
354+
if err := db.QueryRow(NextResults).Scan(&val); err != nil {
355+
t.Fatal(err)
356+
}
357+
if val != 2 {
358+
t.Fatalf("expected 2, but found %d", val)
359+
}
360+
if err := db.QueryRow(NextResults).Scan(&val); err != ErrNoMoreResults {
361+
t.Fatalf("expected %s, but found %v", ErrNoMoreResults, err)
362+
}
363+
364+
// Now test discarding the second result.
365+
if err := db.QueryRow("SELECT 3; SELECT 4").Scan(&val); err != nil {
366+
t.Fatal(err)
367+
}
368+
if val != 3 {
369+
t.Fatalf("expected 3, but found %d", val)
370+
}
371+
if err := db.QueryRow("SELECT 5").Scan(&val); err != nil {
372+
t.Fatal(err)
373+
}
374+
if val != 5 {
375+
t.Fatalf("expected 5, but found %d", val)
376+
}
377+
}
378+
379+
func TestTxnMultipleResults(t *testing.T) {
380+
db := openTestConn(t)
381+
defer db.Close()
382+
383+
tx, err := db.Begin()
384+
if err != nil {
385+
t.Fatal(err)
386+
}
387+
defer tx.Rollback()
388+
389+
var val int
390+
if err := tx.QueryRow("SELECT 1; SELECT 2").Scan(&val); err != nil {
391+
t.Fatal(err)
392+
}
393+
if val != 1 {
394+
t.Fatalf("expected 1, but found %d", val)
395+
}
396+
if err := tx.QueryRow(NextResults).Scan(&val); err != nil {
397+
t.Fatal(err)
398+
}
399+
if val != 2 {
400+
t.Fatalf("expected 2, but found %d", val)
401+
}
402+
if err := tx.QueryRow(NextResults).Scan(&val); err != ErrNoMoreResults {
403+
t.Fatalf("expected %s, but found %v", ErrNoMoreResults, err)
404+
}
405+
}
406+
342407
func TestParameterCountMismatch(t *testing.T) {
343408
db := openTestConn(t)
344409
defer db.Close()

0 commit comments

Comments
 (0)