Skip to content

Commit cca0657

Browse files
committed
Add support for retrieving multiple results.
If the special `NEXT` query is executed, the driver looks for another set of results on the connection. If no such results are found, ErrNoMoreResults is returned.
1 parent de04a24 commit cca0657

File tree

2 files changed

+121
-36
lines changed

2 files changed

+121
-36
lines changed

conn.go

Lines changed: 56 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,11 @@ var (
3232
ErrSSLNotSupported = errors.New("pq: SSL is not enabled on the server")
3333
ErrSSLKeyHasWorldPermissions = errors.New("pq: Private key file has group or world access. Permissions should be u=rw (0600) or less.")
3434
ErrCouldNotDetectUsername = errors.New("pq: Could not detect default username. Please provide one explicitly.")
35+
ErrNoMoreResults = errors.New("pq: no more results")
3536
)
3637

38+
const NextResults = "NEXT"
39+
3740
type drv struct{}
3841

3942
func (d *drv) Open(name string) (driver.Conn, error) {
@@ -619,56 +622,73 @@ func (cn *conn) simpleExec(q string) (res driver.Result, commandTag string, err
619622
func (cn *conn) simpleQuery(q string) (res *rows, err error) {
620623
defer cn.errRecover(&err)
621624

622-
cn.waitReadyForQuery()
623-
624-
// Mark the connection has having sent a query.
625-
cn.readyForQuery = false
626-
b := cn.writeBuf('Q')
627-
b.string(q)
628-
cn.send(b)
625+
querySent := false
626+
nextResult := q == NextResults
629627

630628
for {
629+
if cn.readyForQuery && !querySent {
630+
if nextResult {
631+
return nil, ErrNoMoreResults
632+
}
633+
634+
// Mark the connection has having sent a query.
635+
cn.readyForQuery = false
636+
b := cn.writeBuf('Q')
637+
b.string(q)
638+
cn.send(b)
639+
querySent = true
640+
}
631641

632642
t, r := cn.recv1()
633643
switch t {
634644
case 'C', 'I':
635-
// We allow queries which don't return any results through Query as
636-
// well as Exec. We still have to give database/sql a rows object
637-
// the user can close, though, to avoid connections from being
638-
// leaked. A "rows" with done=true works fine for that purpose.
639-
if err != nil {
640-
cn.bad = true
641-
errorf("unexpected message %q in simple query execution", t)
642-
}
643-
if res == nil {
644-
res = &rows{
645-
cn: cn,
645+
if nextResult || querySent {
646+
// We allow queries which don't return any results through Query as
647+
// well as Exec. We still have to give database/sql a rows object
648+
// the user can close, though, to avoid connections from being
649+
// leaked. A "rows" with done=true works fine for that purpose.
650+
if err != nil {
651+
cn.bad = true
652+
errorf("unexpected message %q in simple query execution", t)
646653
}
654+
if res == nil {
655+
res = &rows{
656+
cn: cn,
657+
}
658+
}
659+
res.done = true
647660
}
648-
res.done = true
649661
case 'Z':
650662
cn.processReadyForQuery(r)
651-
// done
652-
return
663+
if querySent {
664+
// done
665+
return
666+
}
653667
case 'E':
654-
res = nil
655-
err = parseError(r)
668+
if nextResult || querySent {
669+
res = nil
670+
err = parseError(r)
671+
}
656672
case 'D':
657-
if res == nil {
658-
cn.bad = true
659-
errorf("unexpected DataRow in simple query execution")
673+
if nextResult || querySent {
674+
if res == nil {
675+
cn.bad = true
676+
errorf("unexpected DataRow in simple query execution")
677+
}
678+
// the query didn't fail; kick off to Next
679+
cn.saveMessage(t, r)
680+
return
660681
}
661-
// the query didn't fail; kick off to Next
662-
cn.saveMessage(t, r)
663-
return
664682
case 'T':
665-
// res might be non-nil here if we received a previous
666-
// CommandComplete, but that's fine; just overwrite it
667-
res = &rows{cn: cn}
668-
res.colNames, res.colFmts, res.colTyps = parsePortalRowDescribe(r)
669-
670-
// To work around a bug in QueryRow in Go 1.2 and earlier, wait
671-
// until the first DataRow has been received.
683+
if nextResult || querySent {
684+
// res might be non-nil here if we received a previous
685+
// CommandComplete, but that's fine; just overwrite it
686+
res = &rows{cn: cn}
687+
res.colNames, res.colFmts, res.colTyps = parsePortalRowDescribe(r)
688+
689+
// To work around a bug in QueryRow in Go 1.2 and earlier, wait
690+
// until the first DataRow has been received.
691+
}
672692
default:
673693
cn.bad = true
674694
errorf("unknown response for simple query: %q", t)

conn_test.go

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,7 @@ func TestOpenURL(t *testing.T) {
136136
}
137137

138138
const pgpass_file = "/tmp/pqgotest_pgpass"
139+
139140
func TestPgpass(t *testing.T) {
140141
testAssert := func(conninfo string, expected string, reason string) {
141142
conn, err := openTestConnConninfo(conninfo)
@@ -339,6 +340,70 @@ func TestRowsCloseBeforeDone(t *testing.T) {
339340
}
340341
}
341342

343+
func TestMultipleResults(t *testing.T) {
344+
db := openTestConn(t)
345+
defer db.Close()
346+
347+
var val int
348+
if err := db.QueryRow("SELECT 1; SELECT 2").Scan(&val); err != nil {
349+
t.Fatal(err)
350+
}
351+
if val != 1 {
352+
t.Fatalf("expected 1, but found %d", val)
353+
}
354+
if err := db.QueryRow(NextResults).Scan(&val); err != nil {
355+
t.Fatal(err)
356+
}
357+
if val != 2 {
358+
t.Fatalf("expected 2, but found %d", val)
359+
}
360+
if err := db.QueryRow(NextResults).Scan(&val); err != ErrNoMoreResults {
361+
t.Fatalf("expected %s, but found %v", ErrNoMoreResults, err)
362+
}
363+
364+
// Now test discarding the second result.
365+
if err := db.QueryRow("SELECT 3; SELECT 4").Scan(&val); err != nil {
366+
t.Fatal(err)
367+
}
368+
if val != 3 {
369+
t.Fatalf("expected 3, but found %d", val)
370+
}
371+
if err := db.QueryRow("SELECT 5").Scan(&val); err != nil {
372+
t.Fatal(err)
373+
}
374+
if val != 5 {
375+
t.Fatalf("expected 5, but found %d", val)
376+
}
377+
}
378+
379+
func TestTxnMultipleResults(t *testing.T) {
380+
db := openTestConn(t)
381+
defer db.Close()
382+
383+
tx, err := db.Begin()
384+
if err != nil {
385+
t.Fatal(err)
386+
}
387+
defer tx.Rollback()
388+
389+
var val int
390+
if err := tx.QueryRow("SELECT 1; SELECT 2").Scan(&val); err != nil {
391+
t.Fatal(err)
392+
}
393+
if val != 1 {
394+
t.Fatalf("expected 1, but found %d", val)
395+
}
396+
if err := tx.QueryRow(NextResults).Scan(&val); err != nil {
397+
t.Fatal(err)
398+
}
399+
if val != 2 {
400+
t.Fatalf("expected 2, but found %d", val)
401+
}
402+
if err := tx.QueryRow(NextResults).Scan(&val); err != ErrNoMoreResults {
403+
t.Fatalf("expected %s, but found %v", ErrNoMoreResults, err)
404+
}
405+
}
406+
342407
func TestParameterCountMismatch(t *testing.T) {
343408
db := openTestConn(t)
344409
defer db.Close()

0 commit comments

Comments
 (0)