Skip to content

Commit 915a824

Browse files
author
Hayim Shaul
committed
bug fixes
Signed-off-by: Hayim Shaul <[email protected]>
1 parent d8de815 commit 915a824

File tree

4 files changed

+153
-131
lines changed

4 files changed

+153
-131
lines changed

platform/common/driver/vault.go

+81-1
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ package driver
88

99
import (
1010
"context"
11+
"fmt"
12+
"strings"
1113

1214
"github.com/hyperledger-labs/fabric-smart-client/platform/common/utils/collections"
1315
)
@@ -62,6 +64,84 @@ type TxStateIterator = collections.Iterator[*VaultRead]
6264

6365
type VersionedResultsIterator = collections.Iterator[*VaultRead]
6466

67+
type SqlQuery struct {
68+
table string
69+
fields []string
70+
where []string
71+
limit string
72+
offset string
73+
order string
74+
params []any
75+
}
76+
77+
func (q *SqlQuery) SetTable(t string) {
78+
q.table = t
79+
}
80+
81+
func (q *SqlQuery) SetLimit(l string) {
82+
q.limit = l
83+
}
84+
85+
func (q *SqlQuery) SetOffset(o string) {
86+
q.offset = o
87+
}
88+
89+
func (q *SqlQuery) SetOrder(o string) {
90+
q.order = o
91+
}
92+
93+
func (q *SqlQuery) AddWhere(p string) {
94+
q.where = append(q.where, p)
95+
}
96+
97+
func (q *SqlQuery) AddFields(f []string) {
98+
q.fields = append(q.fields, f...)
99+
}
100+
101+
func (q *SqlQuery) AddParam(p any) int {
102+
q.params = append(q.params, p)
103+
return len(q.params)
104+
}
105+
106+
func (q *SqlQuery) GetParams() []any {
107+
return q.params
108+
}
109+
110+
func (q *SqlQuery) FormatQuery() string {
111+
if q.table == "" {
112+
return ""
113+
}
114+
115+
// SELECT fields
116+
fields := "*"
117+
if len(q.fields) > 0 {
118+
fields = strings.Join(q.fields, ", ")
119+
}
120+
query := fmt.Sprintf("SELECT %s FROM %s", fields, q.table)
121+
122+
// WHERE clause
123+
if len(q.where) > 0 {
124+
query += " WHERE " + strings.Join(q.where, " AND ")
125+
}
126+
127+
// ORDER BY
128+
if q.order != "" {
129+
query += " ORDER BY " + q.order
130+
}
131+
132+
// LIMIT
133+
if q.limit != "" {
134+
query += " LIMIT " + q.limit
135+
}
136+
137+
// OFFSET
138+
if q.offset != "" {
139+
query += " OFFSET " + q.offset
140+
}
141+
142+
return query
143+
}
144+
65145
type QueryExecutor interface {
66146
GetState(ctx context.Context, namespace Namespace, key PKey) (*VaultRead, error)
67147
GetStateMetadata(ctx context.Context, namespace Namespace, key PKey) (Metadata, RawVersion, error)
@@ -138,7 +218,7 @@ type VaultReader interface {
138218
GetTxStatuses(ctx context.Context, txIDs ...TxID) (TxStatusIterator, error)
139219

140220
// GetAllTxStatuses returns the statuses of the all transactions in the vault
141-
GetAllTxStatuses(ctx context.Context, pagination Pagination) (*PageIterator[*TxStatus], error)
221+
GetAllTxStatuses(ctx context.Context, sql SqlQuery, pagination Pagination) (*PageIterator[*TxStatus], error)
142222
}
143223

144224
// LockedVaultReader is a VaultReader with a lock on some or all entries

platform/view/services/db/driver/sql/common/pagination.go

+11-11
Original file line numberDiff line numberDiff line change
@@ -154,32 +154,32 @@ func NewPaginationInterpreter() *paginationInterpreter {
154154
}
155155

156156
type PaginationInterpreter interface {
157-
Interpret(p driver.Pagination, sql SqlQuery) (SqlQuery, error)
157+
Interpret(p driver.Pagination, sql driver.SqlQuery) (driver.SqlQuery, error)
158158
}
159159

160160
type paginationInterpreter struct{}
161161

162-
func (i *paginationInterpreter) Interpret(p driver.Pagination, sql SqlQuery) (SqlQuery, error) {
162+
func (i *paginationInterpreter) Interpret(p driver.Pagination, sql driver.SqlQuery) (driver.SqlQuery, error) {
163163
switch pagination := p.(type) {
164164
case *NoPagination:
165165
return sql, nil
166166
case *OffsetPagination:
167-
sql.limit = fmt.Sprintf("%d", pagination.pageSize)
168-
sql.offset = fmt.Sprintf("%d", pagination.offset)
167+
sql.SetLimit(fmt.Sprintf("%d", pagination.pageSize))
168+
sql.SetOffset(fmt.Sprintf("%d", pagination.offset))
169169
return sql, nil
170170
case *KeysetPagination:
171-
sql.order = fmt.Sprintf("%s ASC", pagination.sqlIdName)
172-
sql.limit = fmt.Sprintf("%d", pagination.pageSize)
171+
sql.SetOrder(fmt.Sprintf("%s ASC", pagination.sqlIdName))
172+
sql.SetLimit(fmt.Sprintf("%d", pagination.pageSize))
173173
if (pagination.lastOffset != -1) && (pagination.offset == pagination.lastOffset+pagination.pageSize) {
174-
lastId := sql.AddParam(fmt.Sprintf("%d", pagination.lastId))
175-
sql.where = append(sql.where, fmt.Sprintf("%s>'$%d'", pagination.sqlIdName, lastId))
174+
lastId := sql.AddParam(pagination.lastId)
175+
sql.AddWhere(fmt.Sprintf("%s>$%d", pagination.sqlIdName, lastId))
176176
} else {
177-
sql.offset = fmt.Sprintf("%d", pagination.offset)
177+
sql.SetOffset(fmt.Sprintf("%d", pagination.offset))
178178
}
179179
return sql, nil
180180
case *EmptyPagination:
181-
sql.limit = "0"
182-
sql.offset = "0"
181+
sql.SetLimit("0")
182+
sql.SetOffset("0")
183183
return sql, nil
184184
default:
185185
return sql, errors.Errorf("invalid pagination option %+v", pagination)

platform/view/services/db/driver/sql/common/vault.go

+9-66
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@ import (
1313
"encoding/gob"
1414
errors2 "errors"
1515
"fmt"
16-
"strings"
1716
"sync"
1817

1918
"github.com/hyperledger-labs/fabric-smart-client/pkg/utils/errors"
@@ -63,60 +62,6 @@ type VaultPersistence struct {
6362
il IsolationLevelMapper
6463
}
6564

66-
type SqlQuery struct {
67-
table string
68-
fields []string
69-
where []string
70-
limit string
71-
offset string
72-
order string
73-
params []any
74-
}
75-
76-
func (q *SqlQuery) AddWhere(p string) {
77-
q.where = append(q.where, p)
78-
}
79-
80-
func (q *SqlQuery) AddParam(p any) int {
81-
q.params = append(q.params, p)
82-
return len(q.params) - 1
83-
}
84-
85-
func (q *SqlQuery) FormatQuery() string {
86-
if q.table == "" {
87-
return ""
88-
}
89-
90-
// SELECT fields
91-
fields := "*"
92-
if len(q.fields) > 0 {
93-
fields = strings.Join(q.fields, ", ")
94-
}
95-
query := fmt.Sprintf("SELECT %s FROM %s", fields, q.table)
96-
97-
// WHERE clause
98-
if len(q.where) > 0 {
99-
query += " WHERE " + strings.Join(q.where, " AND ")
100-
}
101-
102-
// ORDER BY
103-
if q.order != "" {
104-
query += " ORDER BY " + q.order
105-
}
106-
107-
// LIMIT
108-
if q.limit != "" {
109-
query += " LIMIT " + q.limit
110-
}
111-
112-
// OFFSET
113-
if q.offset != "" {
114-
query += " OFFSET " + q.offset
115-
}
116-
117-
return query
118-
}
119-
12065
func (db *VaultPersistence) NewTxLockVaultReader(ctx context.Context, txID driver.TxID, isolationLevel driver.IsolationLevel) (driver.LockedVaultReader, error) {
12166
logger.Debugf("Acquire tx id lock for [%s]", txID)
12267
span := trace.SpanFromContext(ctx)
@@ -399,11 +344,11 @@ func (db *txVaultReader) GetTxStatuses(ctx context.Context, txIDs ...driver.TxID
399344
return db.vr.GetTxStatuses(ctx, txIDs...)
400345
}
401346

402-
func (db *txVaultReader) GetAllTxStatuses(ctx context.Context, pagination driver.Pagination) (*driver.PageIterator[*driver.TxStatus], error) {
347+
func (db *txVaultReader) GetAllTxStatuses(ctx context.Context, sql driver.SqlQuery, pagination driver.Pagination) (*driver.PageIterator[*driver.TxStatus], error) {
403348
if err := db.setVaultReader(); err != nil {
404349
return nil, err
405350
}
406-
return db.vr.GetAllTxStatuses(ctx, pagination)
351+
return db.vr.GetAllTxStatuses(ctx, sql, pagination)
407352
}
408353

409354
func (db *txVaultReader) Done() error {
@@ -525,11 +470,9 @@ func (db *vaultReader) GetTxStatuses(ctx context.Context, txIDs ...driver.TxID)
525470
were, any := Where(db.ci.InStrings("tx_id", txIDs))
526471
return db.queryStatus(were, any, "")
527472
}
528-
func (db *vaultReader) GetAllTxStatuses(ctx context.Context, pagination driver.Pagination) (*driver.PageIterator[*driver.TxStatus], error) {
529-
if pagination == nil {
530-
return nil, fmt.Errorf("invalid input pagination: %+v", pagination)
531-
}
532-
sql := SqlQuery{}
473+
474+
func (db *vaultReader) GetAllTxStatuses(ctx context.Context, sql driver.SqlQuery, pagination driver.Pagination) (*driver.PageIterator[*driver.TxStatus], error) {
475+
sql.AddFields([]string{"tx_id", "code", "message"})
533476
sql, err := db.pi.Interpret(pagination, sql)
534477
if err != nil {
535478
return nil, err
@@ -555,12 +498,12 @@ func (db *vaultReader) queryStatus(where string, params []any, limit string) (dr
555498
return &TxCodeIterator{rows: rows}, nil
556499
}
557500

558-
func (db *vaultReader) queryStatusWithPagination(sql SqlQuery) (driver.TxStatusIterator, error) {
559-
sql.table = db.tables.StatusTable
501+
func (db *vaultReader) queryStatusWithPagination(sql driver.SqlQuery) (driver.TxStatusIterator, error) {
502+
sql.SetTable(db.tables.StatusTable)
560503
query := sql.FormatQuery()
561-
logger.Debug(query, sql.params)
504+
logger.Debug(query, sql.GetParams())
562505

563-
rows, err := db.readDB.Query(query, sql.params...)
506+
rows, err := db.readDB.Query(query, sql.GetParams()...)
564507
if err != nil {
565508
return nil, err
566509
}

0 commit comments

Comments
 (0)