Skip to content

Commit 97b255f

Browse files
author
Hayim Shaul
committed
use SqlStruct when query includes pagination
Signed-off-by: Hayim Shaul <[email protected]>
1 parent 6cdc6a2 commit 97b255f

File tree

4 files changed

+114
-24
lines changed

4 files changed

+114
-24
lines changed

platform/common/utils/collections/iterators.go

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -57,19 +57,23 @@ func ReadFirst[T any](it Iterator[*T], limit int) ([]T, error) {
5757
}
5858

5959
func ReadLast[T any](it Iterator[*T]) (*T, *sliceIterator[*T], error) {
60-
var items []*T
61-
// create a copy of it so not to ruin it
62-
for item, err := it.Next(); item != nil || err != nil; item, err = it.Next() {
63-
if err != nil {
64-
return nil, nil, err
65-
}
66-
items = append(items, item)
60+
var items []T
61+
var ptrItems []*T
62+
var err error
63+
items, err = ReadAll(it)
64+
65+
if err != nil {
66+
return nil, nil, err
67+
}
68+
// convert from []T to []*T
69+
for i := range items {
70+
ptrItems = append(ptrItems, &items[i])
6771
}
68-
si := NewSliceIterator(items)
72+
si := NewSliceIterator(ptrItems)
6973
if len(items) == 0 {
7074
return nil, si, nil
7175
}
72-
return items[len(items)-1], si, nil
76+
return ptrItems[len(ptrItems)-1], si, nil
7377
}
7478

7579
func ReadAll[T any](it Iterator[*T]) ([]T, error) {

platform/view/services/db/driver/sql/common/pagination.go

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -154,27 +154,35 @@ func NewPaginationInterpreter() *paginationInterpreter {
154154
}
155155

156156
type PaginationInterpreter interface {
157-
Interpret(p driver.Pagination) (string, error)
157+
Interpret(p driver.Pagination, sql SqlQuery) (SqlQuery, error)
158158
}
159159

160160
type paginationInterpreter struct{}
161161

162-
func (i *paginationInterpreter) Interpret(p driver.Pagination) (string, error) {
162+
func (i *paginationInterpreter) Interpret(p driver.Pagination, sql SqlQuery) (SqlQuery, error) {
163163
switch pagination := p.(type) {
164164
case *NoPagination:
165-
return "", nil
165+
return sql, nil
166166
case *OffsetPagination:
167-
return fmt.Sprintf("LIMIT %d OFFSET %d", pagination.pageSize, pagination.offset), nil
167+
sql.limit = fmt.Sprintf("%d", pagination.pageSize)
168+
sql.offset = fmt.Sprintf("%d", pagination.offset)
169+
return sql, nil
168170
case *KeysetPagination:
169-
// TODO: add OrderBy?
171+
sql.order = fmt.Sprintf("%s ASC", pagination.sqlIdName)
172+
sql.limit = fmt.Sprintf("%d", pagination.pageSize)
170173
if (pagination.lastOffset != -1) && (pagination.offset == pagination.lastOffset+pagination.pageSize) {
171-
return fmt.Sprintf("WHERE %s>'%s' ORDER BY %s ASC LIMIT %d", pagination.sqlIdName, pagination.lastId, pagination.sqlIdName, pagination.pageSize), nil
174+
lastId := sql.AddParam(fmt.Sprintf("%d", pagination.lastId))
175+
sql.where = append(sql.where, fmt.Sprintf("%s>'$%d'", pagination.sqlIdName, lastId))
176+
} else {
177+
sql.offset = fmt.Sprintf("%d", pagination.offset)
172178
}
173-
return fmt.Sprintf("ORDER BY %s ASC LIMIT %d OFFSET %d", pagination.sqlIdName, pagination.pageSize, pagination.offset), nil
179+
return sql, nil
174180
case *EmptyPagination:
175-
return "LIMIT 0 OFFSET 0", nil
181+
sql.limit = "0"
182+
sql.offset = "0"
183+
return sql, nil
176184
default:
177-
return "", errors.Errorf("invalid pagination option %+v", pagination)
185+
return sql, errors.Errorf("invalid pagination option %+v", pagination)
178186
}
179187
}
180188

platform/view/services/db/driver/sql/common/vault.go

Lines changed: 75 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import (
1313
"encoding/gob"
1414
errors2 "errors"
1515
"fmt"
16+
"strings"
1617
"sync"
1718

1819
"github.com/hyperledger-labs/fabric-smart-client/pkg/utils/errors"
@@ -63,6 +64,60 @@ type VaultPersistence struct {
6364
il IsolationLevelMapper
6465
}
6566

67+
type SqlQuery struct {
68+
table string
69+
fields []string
70+
where []string
71+
limit string
72+
offset string
73+
order string
74+
params []any
75+
}
76+
77+
func (q *SqlQuery) AddWhere(p string) {
78+
q.where = append(q.where, p)
79+
}
80+
81+
func (q *SqlQuery) AddParam(p any) int {
82+
q.params = append(q.params, p)
83+
return len(q.params) - 1
84+
}
85+
86+
func (q *SqlQuery) FormatQuery() string {
87+
if q.table == "" {
88+
return ""
89+
}
90+
91+
// SELECT fields
92+
fields := "*"
93+
if len(q.fields) > 0 {
94+
fields = strings.Join(q.fields, ", ")
95+
}
96+
query := fmt.Sprintf("SELECT %s FROM %s", fields, q.table)
97+
98+
// WHERE clause
99+
if len(q.where) > 0 {
100+
query += " WHERE " + strings.Join(q.where, " AND ")
101+
}
102+
103+
// ORDER BY
104+
if q.order != "" {
105+
query += " ORDER BY " + q.order
106+
}
107+
108+
// LIMIT
109+
if q.limit != "" {
110+
query += " LIMIT " + q.limit
111+
}
112+
113+
// OFFSET
114+
if q.offset != "" {
115+
query += " OFFSET " + q.offset
116+
}
117+
118+
return query
119+
}
120+
66121
func (db *VaultPersistence) NewTxLockVaultReader(ctx context.Context, txID driver.TxID, isolationLevel driver.IsolationLevel) (driver.LockedVaultReader, error) {
67122
logger.Debugf("Acquire tx id lock for [%s]", txID)
68123
span := trace.SpanFromContext(ctx)
@@ -440,11 +495,16 @@ func (db *vaultReader) GetLast(ctx context.Context) (*driver.TxStatus, error) {
440495
span.AddEvent("start_get_last")
441496
defer span.AddEvent("end_get_last")
442497
it, err := db.queryStatus(fmt.Sprintf("WHERE pos=(SELECT max(pos) FROM %s WHERE code!=$1)", db.tables.StatusTable), []any{driver.Busy}, "")
498+
// sql := SqlQuery{}
499+
// driverBusy := sql.AddParam(driver.Busy)
500+
// sql.AddWhere(fmt.Sprintf("WHERE pos=(SELECT max(pos) FROM %s WHERE code!=$%d)", db.tables.StatusTable, driverBusy))
501+
// it, err := db.queryStatus(sql)
443502
if err != nil {
444503
return nil, err
445504
}
446505
return collections.GetUnique(it)
447506
}
507+
448508
func (db *vaultReader) GetTxStatus(ctx context.Context, txID driver.TxID) (*driver.TxStatus, error) {
449509
span := trace.SpanFromContext(ctx)
450510
span.AddEvent("start_get_tx_status")
@@ -470,11 +530,12 @@ func (db *vaultReader) GetAllTxStatuses(ctx context.Context, pagination driver.P
470530
if pagination == nil {
471531
return nil, fmt.Errorf("invalid input pagination: %+v", pagination)
472532
}
473-
limit, err := db.pi.Interpret(pagination)
533+
sql := SqlQuery{}
534+
sql, err := db.pi.Interpret(pagination, sql)
474535
if err != nil {
475536
return nil, err
476537
}
477-
txStatusIterator, err := db.queryStatus("", []any{}, limit)
538+
txStatusIterator, err := db.queryStatusWithPagination(sql)
478539
if err != nil {
479540
return nil, err
480541
}
@@ -495,6 +556,18 @@ func (db *vaultReader) queryStatus(where string, params []any, limit string) (dr
495556
return &TxCodeIterator{rows: rows}, nil
496557
}
497558

559+
func (db *vaultReader) queryStatusWithPagination(sql SqlQuery) (driver.TxStatusIterator, error) {
560+
sql.table = db.tables.StatusTable
561+
query := sql.FormatQuery()
562+
logger.Debug(query, sql.params)
563+
564+
rows, err := db.readDB.Query(query, sql.params...)
565+
if err != nil {
566+
return nil, err
567+
}
568+
return &TxCodeIterator{rows: rows}, nil
569+
}
570+
498571
func (db *VaultPersistence) Close() error {
499572
return errors2.Join(db.writeDB.Close(), db.readDB.Close())
500573
}

platform/view/services/storage/vault/vaultstore_test.go

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ package vault
88

99
import (
1010
"context"
11+
"fmt"
1112
"testing"
1213

1314
"github.com/hyperledger-labs/fabric-smart-client/pkg/utils"
@@ -168,9 +169,11 @@ func testPagination(t *testing.T, store driver.VaultStore) {
168169
pagination := item.pagination
169170
page := 0
170171
for ; true; page++ {
171-
sql, err := interpreter.Interpret((pagination))
172+
sql := common.SqlQuery{}
173+
sql, err := interpreter.Interpret(pagination, sql)
172174
Expect(err).ToNot(HaveOccurred())
173-
Expect(sql).To(Equal(item.sqlForward[page]))
175+
fmt.Printf("sql (forward) = %s\n", sql.FormatQuery())
176+
Expect(sql.FormatQuery()).To(Equal(item.sqlForward[page]))
174177
statuses, err := getAllTxStatuses(store, pagination)
175178
Expect(err).ToNot(HaveOccurred())
176179
Expect(err).ToNot(HaveOccurred())
@@ -193,9 +196,11 @@ func testPagination(t *testing.T, store driver.VaultStore) {
193196
Expect(err).ToNot(HaveOccurred())
194197
}
195198
for page := len(item.matcher) - 1; page >= 0; page-- {
196-
sql, err := interpreter.Interpret((pagination))
199+
sql := common.SqlQuery{}
200+
sql, err := interpreter.Interpret(pagination, sql)
197201
Expect(err).ToNot(HaveOccurred())
198-
Expect(sql).To(Equal(item.sqlBackward[page]))
202+
fmt.Printf("sql (backward) = %s\n", sql.FormatQuery())
203+
Expect(sql.FormatQuery()).To(Equal(item.sqlBackward[page]))
199204
statuses, err := getAllTxStatuses(store, pagination)
200205
Expect(err).ToNot(HaveOccurred())
201206
Expect(statuses).To(item.matcher[page])

0 commit comments

Comments
 (0)