Skip to content

Commit 7ceeea6

Browse files
committed
Fix explicitly prepared statements with describe statement cache mode
fixes #1196
1 parent c6335a3 commit 7ceeea6

File tree

2 files changed

+45
-1
lines changed

2 files changed

+45
-1
lines changed

Diff for: conn.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -645,7 +645,7 @@ optionLoop:
645645
resultFormats = c.eqb.resultFormats
646646
}
647647

648-
if c.stmtcache != nil && c.stmtcache.Mode() == stmtcache.ModeDescribe {
648+
if c.stmtcache != nil && c.stmtcache.Mode() == stmtcache.ModeDescribe && !ok {
649649
rows.resultReader = c.pgConn.ExecParams(ctx, sql, c.eqb.paramValues, sd.ParamOIDs, c.eqb.paramFormats, resultFormats)
650650
} else {
651651
rows.resultReader = c.pgConn.ExecPrepared(ctx, sd.Name, c.eqb.paramValues, c.eqb.paramFormats, resultFormats)

Diff for: conn_test.go

+44
Original file line numberDiff line numberDiff line change
@@ -496,6 +496,50 @@ func TestPrepareIdempotency(t *testing.T) {
496496
}
497497
}
498498

499+
func TestPrepareStatementCacheModes(t *testing.T) {
500+
t.Parallel()
501+
502+
config := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE"))
503+
504+
tests := []struct {
505+
name string
506+
buildStatementCache pgx.BuildStatementCacheFunc
507+
}{
508+
{
509+
name: "disabled",
510+
buildStatementCache: nil,
511+
},
512+
{
513+
name: "prepare",
514+
buildStatementCache: func(conn *pgconn.PgConn) stmtcache.Cache {
515+
return stmtcache.New(conn, stmtcache.ModePrepare, 32)
516+
},
517+
},
518+
{
519+
name: "describe",
520+
buildStatementCache: func(conn *pgconn.PgConn) stmtcache.Cache {
521+
return stmtcache.New(conn, stmtcache.ModeDescribe, 32)
522+
},
523+
},
524+
}
525+
526+
for _, tt := range tests {
527+
t.Run(tt.name, func(t *testing.T) {
528+
config.BuildStatementCache = tt.buildStatementCache
529+
conn := mustConnect(t, config)
530+
defer closeConn(t, conn)
531+
532+
_, err := conn.Prepare(context.Background(), "test", "select $1::text")
533+
require.NoError(t, err)
534+
535+
var s string
536+
err = conn.QueryRow(context.Background(), "test", "hello").Scan(&s)
537+
require.NoError(t, err)
538+
require.Equal(t, "hello", s)
539+
})
540+
}
541+
}
542+
499543
func TestListenNotify(t *testing.T) {
500544
t.Parallel()
501545

0 commit comments

Comments
 (0)