Skip to content

Commit d8de815

Browse files
author
Hayim Shaul
committed
use SqlStruct when query includes pagination
Signed-off-by: Hayim Shaul <[email protected]>
1 parent d8a85f1 commit d8de815

File tree

4 files changed

+114
-24
lines changed

4 files changed

+114
-24
lines changed

platform/common/utils/collections/iterators.go

+13-9
Original file line numberDiff line numberDiff line change
@@ -57,19 +57,23 @@ func ReadFirst[T any](it Iterator[*T], limit int) ([]T, error) {
5757
}
5858

5959
func ReadLast[T any](it Iterator[*T]) (*T, *sliceIterator[*T], error) {
60-
var items []*T
61-
// create a copy of it so not to ruin it
62-
for item, err := it.Next(); item != nil || err != nil; item, err = it.Next() {
63-
if err != nil {
64-
return nil, nil, err
65-
}
66-
items = append(items, item)
60+
var items []T
61+
var ptrItems []*T
62+
var err error
63+
items, err = ReadAll(it)
64+
65+
if err != nil {
66+
return nil, nil, err
67+
}
68+
// convert from []T to []*T
69+
for i := range items {
70+
ptrItems = append(ptrItems, &items[i])
6771
}
68-
si := NewSliceIterator(items)
72+
si := NewSliceIterator(ptrItems)
6973
if len(items) == 0 {
7074
return nil, si, nil
7175
}
72-
return items[len(items)-1], si, nil
76+
return ptrItems[len(ptrItems)-1], si, nil
7377
}
7478

7579
func ReadAll[T any](it Iterator[*T]) ([]T, error) {

platform/view/services/db/driver/sql/common/pagination.go

+17-9
Original file line numberDiff line numberDiff line change
@@ -154,27 +154,35 @@ func NewPaginationInterpreter() *paginationInterpreter {
154154
}
155155

156156
type PaginationInterpreter interface {
157-
Interpret(p driver.Pagination) (string, error)
157+
Interpret(p driver.Pagination, sql SqlQuery) (SqlQuery, error)
158158
}
159159

160160
type paginationInterpreter struct{}
161161

162-
func (i *paginationInterpreter) Interpret(p driver.Pagination) (string, error) {
162+
func (i *paginationInterpreter) Interpret(p driver.Pagination, sql SqlQuery) (SqlQuery, error) {
163163
switch pagination := p.(type) {
164164
case *NoPagination:
165-
return "", nil
165+
return sql, nil
166166
case *OffsetPagination:
167-
return fmt.Sprintf("LIMIT %d OFFSET %d", pagination.pageSize, pagination.offset), nil
167+
sql.limit = fmt.Sprintf("%d", pagination.pageSize)
168+
sql.offset = fmt.Sprintf("%d", pagination.offset)
169+
return sql, nil
168170
case *KeysetPagination:
169-
// TODO: add OrderBy?
171+
sql.order = fmt.Sprintf("%s ASC", pagination.sqlIdName)
172+
sql.limit = fmt.Sprintf("%d", pagination.pageSize)
170173
if (pagination.lastOffset != -1) && (pagination.offset == pagination.lastOffset+pagination.pageSize) {
171-
return fmt.Sprintf("WHERE %s>'%s' ORDER BY %s ASC LIMIT %d", pagination.sqlIdName, pagination.lastId, pagination.sqlIdName, pagination.pageSize), nil
174+
lastId := sql.AddParam(fmt.Sprintf("%d", pagination.lastId))
175+
sql.where = append(sql.where, fmt.Sprintf("%s>'$%d'", pagination.sqlIdName, lastId))
176+
} else {
177+
sql.offset = fmt.Sprintf("%d", pagination.offset)
172178
}
173-
return fmt.Sprintf("ORDER BY %s ASC LIMIT %d OFFSET %d", pagination.sqlIdName, pagination.pageSize, pagination.offset), nil
179+
return sql, nil
174180
case *EmptyPagination:
175-
return "LIMIT 0 OFFSET 0", nil
181+
sql.limit = "0"
182+
sql.offset = "0"
183+
return sql, nil
176184
default:
177-
return "", errors.Errorf("invalid pagination option %+v", pagination)
185+
return sql, errors.Errorf("invalid pagination option %+v", pagination)
178186
}
179187
}
180188

platform/view/services/db/driver/sql/common/vault.go

+75-2
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import (
1313
"encoding/gob"
1414
errors2 "errors"
1515
"fmt"
16+
"strings"
1617
"sync"
1718

1819
"github.com/hyperledger-labs/fabric-smart-client/pkg/utils/errors"
@@ -62,6 +63,60 @@ type VaultPersistence struct {
6263
il IsolationLevelMapper
6364
}
6465

66+
type SqlQuery struct {
67+
table string
68+
fields []string
69+
where []string
70+
limit string
71+
offset string
72+
order string
73+
params []any
74+
}
75+
76+
func (q *SqlQuery) AddWhere(p string) {
77+
q.where = append(q.where, p)
78+
}
79+
80+
func (q *SqlQuery) AddParam(p any) int {
81+
q.params = append(q.params, p)
82+
return len(q.params) - 1
83+
}
84+
85+
func (q *SqlQuery) FormatQuery() string {
86+
if q.table == "" {
87+
return ""
88+
}
89+
90+
// SELECT fields
91+
fields := "*"
92+
if len(q.fields) > 0 {
93+
fields = strings.Join(q.fields, ", ")
94+
}
95+
query := fmt.Sprintf("SELECT %s FROM %s", fields, q.table)
96+
97+
// WHERE clause
98+
if len(q.where) > 0 {
99+
query += " WHERE " + strings.Join(q.where, " AND ")
100+
}
101+
102+
// ORDER BY
103+
if q.order != "" {
104+
query += " ORDER BY " + q.order
105+
}
106+
107+
// LIMIT
108+
if q.limit != "" {
109+
query += " LIMIT " + q.limit
110+
}
111+
112+
// OFFSET
113+
if q.offset != "" {
114+
query += " OFFSET " + q.offset
115+
}
116+
117+
return query
118+
}
119+
65120
func (db *VaultPersistence) NewTxLockVaultReader(ctx context.Context, txID driver.TxID, isolationLevel driver.IsolationLevel) (driver.LockedVaultReader, error) {
66121
logger.Debugf("Acquire tx id lock for [%s]", txID)
67122
span := trace.SpanFromContext(ctx)
@@ -439,11 +494,16 @@ func (db *vaultReader) GetLast(ctx context.Context) (*driver.TxStatus, error) {
439494
span.AddEvent("start_get_last")
440495
defer span.AddEvent("end_get_last")
441496
it, err := db.queryStatus(fmt.Sprintf("WHERE pos=(SELECT max(pos) FROM %s WHERE code!=$1)", db.tables.StatusTable), []any{driver.Busy}, "")
497+
// sql := SqlQuery{}
498+
// driverBusy := sql.AddParam(driver.Busy)
499+
// sql.AddWhere(fmt.Sprintf("WHERE pos=(SELECT max(pos) FROM %s WHERE code!=$%d)", db.tables.StatusTable, driverBusy))
500+
// it, err := db.queryStatus(sql)
442501
if err != nil {
443502
return nil, err
444503
}
445504
return collections.GetUnique(it)
446505
}
506+
447507
func (db *vaultReader) GetTxStatus(ctx context.Context, txID driver.TxID) (*driver.TxStatus, error) {
448508
span := trace.SpanFromContext(ctx)
449509
span.AddEvent("start_get_tx_status")
@@ -469,11 +529,12 @@ func (db *vaultReader) GetAllTxStatuses(ctx context.Context, pagination driver.P
469529
if pagination == nil {
470530
return nil, fmt.Errorf("invalid input pagination: %+v", pagination)
471531
}
472-
limit, err := db.pi.Interpret(pagination)
532+
sql := SqlQuery{}
533+
sql, err := db.pi.Interpret(pagination, sql)
473534
if err != nil {
474535
return nil, err
475536
}
476-
txStatusIterator, err := db.queryStatus("", []any{}, limit)
537+
txStatusIterator, err := db.queryStatusWithPagination(sql)
477538
if err != nil {
478539
return nil, err
479540
}
@@ -494,6 +555,18 @@ func (db *vaultReader) queryStatus(where string, params []any, limit string) (dr
494555
return &TxCodeIterator{rows: rows}, nil
495556
}
496557

558+
func (db *vaultReader) queryStatusWithPagination(sql SqlQuery) (driver.TxStatusIterator, error) {
559+
sql.table = db.tables.StatusTable
560+
query := sql.FormatQuery()
561+
logger.Debug(query, sql.params)
562+
563+
rows, err := db.readDB.Query(query, sql.params...)
564+
if err != nil {
565+
return nil, err
566+
}
567+
return &TxCodeIterator{rows: rows}, nil
568+
}
569+
497570
func (db *VaultPersistence) Close() error {
498571
return errors2.Join(db.writeDB.Close(), db.readDB.Close())
499572
}

platform/view/services/storage/vault/vaultstore_test.go

+9-4
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ package vault
88

99
import (
1010
"context"
11+
"fmt"
1112
"testing"
1213

1314
"github.com/hyperledger-labs/fabric-smart-client/platform/common/driver"
@@ -167,9 +168,11 @@ func testPagination(t *testing.T, store driver.VaultStore) {
167168
pagination := item.pagination
168169
page := 0
169170
for ; true; page++ {
170-
sql, err := interpreter.Interpret((pagination))
171+
sql := common.SqlQuery{}
172+
sql, err := interpreter.Interpret(pagination, sql)
171173
Expect(err).ToNot(HaveOccurred())
172-
Expect(sql).To(Equal(item.sqlForward[page]))
174+
fmt.Printf("sql (forward) = %s\n", sql.FormatQuery())
175+
Expect(sql.FormatQuery()).To(Equal(item.sqlForward[page]))
173176
statuses, err := getAllTxStatuses(store, pagination)
174177
Expect(err).ToNot(HaveOccurred())
175178
Expect(err).ToNot(HaveOccurred())
@@ -192,9 +195,11 @@ func testPagination(t *testing.T, store driver.VaultStore) {
192195
Expect(err).ToNot(HaveOccurred())
193196
}
194197
for page := len(item.matcher) - 1; page >= 0; page-- {
195-
sql, err := interpreter.Interpret((pagination))
198+
sql := common.SqlQuery{}
199+
sql, err := interpreter.Interpret(pagination, sql)
196200
Expect(err).ToNot(HaveOccurred())
197-
Expect(sql).To(Equal(item.sqlBackward[page]))
201+
fmt.Printf("sql (backward) = %s\n", sql.FormatQuery())
202+
Expect(sql.FormatQuery()).To(Equal(item.sqlBackward[page]))
198203
statuses, err := getAllTxStatuses(store, pagination)
199204
Expect(err).ToNot(HaveOccurred())
200205
Expect(statuses).To(item.matcher[page])

0 commit comments

Comments
 (0)