Skip to content

Commit 0abc543

Browse files
author
Hayim Shaul
committed
bug fixes
Signed-off-by: Hayim Shaul <[email protected]>
1 parent 97b255f commit 0abc543

File tree

4 files changed

+153
-131
lines changed

4 files changed

+153
-131
lines changed

platform/common/driver/vault.go

Lines changed: 81 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ package driver
88

99
import (
1010
"context"
11+
"fmt"
12+
"strings"
1113

1214
"github.com/hyperledger-labs/fabric-smart-client/platform/common/utils/collections"
1315
)
@@ -62,6 +64,84 @@ type TxStateIterator = collections.Iterator[*VaultRead]
6264

6365
type VersionedResultsIterator = collections.Iterator[*VaultRead]
6466

67+
type SqlQuery struct {
68+
table string
69+
fields []string
70+
where []string
71+
limit string
72+
offset string
73+
order string
74+
params []any
75+
}
76+
77+
func (q *SqlQuery) SetTable(t string) {
78+
q.table = t
79+
}
80+
81+
func (q *SqlQuery) SetLimit(l string) {
82+
q.limit = l
83+
}
84+
85+
func (q *SqlQuery) SetOffset(o string) {
86+
q.offset = o
87+
}
88+
89+
func (q *SqlQuery) SetOrder(o string) {
90+
q.order = o
91+
}
92+
93+
func (q *SqlQuery) AddWhere(p string) {
94+
q.where = append(q.where, p)
95+
}
96+
97+
func (q *SqlQuery) AddFields(f []string) {
98+
q.fields = append(q.fields, f...)
99+
}
100+
101+
func (q *SqlQuery) AddParam(p any) int {
102+
q.params = append(q.params, p)
103+
return len(q.params)
104+
}
105+
106+
func (q *SqlQuery) GetParams() []any {
107+
return q.params
108+
}
109+
110+
func (q *SqlQuery) FormatQuery() string {
111+
if q.table == "" {
112+
return ""
113+
}
114+
115+
// SELECT fields
116+
fields := "*"
117+
if len(q.fields) > 0 {
118+
fields = strings.Join(q.fields, ", ")
119+
}
120+
query := fmt.Sprintf("SELECT %s FROM %s", fields, q.table)
121+
122+
// WHERE clause
123+
if len(q.where) > 0 {
124+
query += " WHERE " + strings.Join(q.where, " AND ")
125+
}
126+
127+
// ORDER BY
128+
if q.order != "" {
129+
query += " ORDER BY " + q.order
130+
}
131+
132+
// LIMIT
133+
if q.limit != "" {
134+
query += " LIMIT " + q.limit
135+
}
136+
137+
// OFFSET
138+
if q.offset != "" {
139+
query += " OFFSET " + q.offset
140+
}
141+
142+
return query
143+
}
144+
65145
type QueryExecutor interface {
66146
GetState(ctx context.Context, namespace Namespace, key PKey) (*VaultRead, error)
67147
GetStateMetadata(ctx context.Context, namespace Namespace, key PKey) (Metadata, RawVersion, error)
@@ -138,7 +218,7 @@ type VaultReader interface {
138218
GetTxStatuses(ctx context.Context, txIDs ...TxID) (TxStatusIterator, error)
139219

140220
// GetAllTxStatuses returns the statuses of the all transactions in the vault
141-
GetAllTxStatuses(ctx context.Context, pagination Pagination) (*PageIterator[*TxStatus], error)
221+
GetAllTxStatuses(ctx context.Context, sql SqlQuery, pagination Pagination) (*PageIterator[*TxStatus], error)
142222
}
143223

144224
// LockedVaultReader is a VaultReader with a lock on some or all entries

platform/view/services/db/driver/sql/common/pagination.go

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -154,32 +154,32 @@ func NewPaginationInterpreter() *paginationInterpreter {
154154
}
155155

156156
type PaginationInterpreter interface {
157-
Interpret(p driver.Pagination, sql SqlQuery) (SqlQuery, error)
157+
Interpret(p driver.Pagination, sql driver.SqlQuery) (driver.SqlQuery, error)
158158
}
159159

160160
type paginationInterpreter struct{}
161161

162-
func (i *paginationInterpreter) Interpret(p driver.Pagination, sql SqlQuery) (SqlQuery, error) {
162+
func (i *paginationInterpreter) Interpret(p driver.Pagination, sql driver.SqlQuery) (driver.SqlQuery, error) {
163163
switch pagination := p.(type) {
164164
case *NoPagination:
165165
return sql, nil
166166
case *OffsetPagination:
167-
sql.limit = fmt.Sprintf("%d", pagination.pageSize)
168-
sql.offset = fmt.Sprintf("%d", pagination.offset)
167+
sql.SetLimit(fmt.Sprintf("%d", pagination.pageSize))
168+
sql.SetOffset(fmt.Sprintf("%d", pagination.offset))
169169
return sql, nil
170170
case *KeysetPagination:
171-
sql.order = fmt.Sprintf("%s ASC", pagination.sqlIdName)
172-
sql.limit = fmt.Sprintf("%d", pagination.pageSize)
171+
sql.SetOrder(fmt.Sprintf("%s ASC", pagination.sqlIdName))
172+
sql.SetLimit(fmt.Sprintf("%d", pagination.pageSize))
173173
if (pagination.lastOffset != -1) && (pagination.offset == pagination.lastOffset+pagination.pageSize) {
174-
lastId := sql.AddParam(fmt.Sprintf("%d", pagination.lastId))
175-
sql.where = append(sql.where, fmt.Sprintf("%s>'$%d'", pagination.sqlIdName, lastId))
174+
lastId := sql.AddParam(pagination.lastId)
175+
sql.AddWhere(fmt.Sprintf("%s>$%d", pagination.sqlIdName, lastId))
176176
} else {
177-
sql.offset = fmt.Sprintf("%d", pagination.offset)
177+
sql.SetOffset(fmt.Sprintf("%d", pagination.offset))
178178
}
179179
return sql, nil
180180
case *EmptyPagination:
181-
sql.limit = "0"
182-
sql.offset = "0"
181+
sql.SetLimit("0")
182+
sql.SetOffset("0")
183183
return sql, nil
184184
default:
185185
return sql, errors.Errorf("invalid pagination option %+v", pagination)

platform/view/services/db/driver/sql/common/vault.go

Lines changed: 9 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@ import (
1313
"encoding/gob"
1414
errors2 "errors"
1515
"fmt"
16-
"strings"
1716
"sync"
1817

1918
"github.com/hyperledger-labs/fabric-smart-client/pkg/utils/errors"
@@ -64,60 +63,6 @@ type VaultPersistence struct {
6463
il IsolationLevelMapper
6564
}
6665

67-
type SqlQuery struct {
68-
table string
69-
fields []string
70-
where []string
71-
limit string
72-
offset string
73-
order string
74-
params []any
75-
}
76-
77-
func (q *SqlQuery) AddWhere(p string) {
78-
q.where = append(q.where, p)
79-
}
80-
81-
func (q *SqlQuery) AddParam(p any) int {
82-
q.params = append(q.params, p)
83-
return len(q.params) - 1
84-
}
85-
86-
func (q *SqlQuery) FormatQuery() string {
87-
if q.table == "" {
88-
return ""
89-
}
90-
91-
// SELECT fields
92-
fields := "*"
93-
if len(q.fields) > 0 {
94-
fields = strings.Join(q.fields, ", ")
95-
}
96-
query := fmt.Sprintf("SELECT %s FROM %s", fields, q.table)
97-
98-
// WHERE clause
99-
if len(q.where) > 0 {
100-
query += " WHERE " + strings.Join(q.where, " AND ")
101-
}
102-
103-
// ORDER BY
104-
if q.order != "" {
105-
query += " ORDER BY " + q.order
106-
}
107-
108-
// LIMIT
109-
if q.limit != "" {
110-
query += " LIMIT " + q.limit
111-
}
112-
113-
// OFFSET
114-
if q.offset != "" {
115-
query += " OFFSET " + q.offset
116-
}
117-
118-
return query
119-
}
120-
12166
func (db *VaultPersistence) NewTxLockVaultReader(ctx context.Context, txID driver.TxID, isolationLevel driver.IsolationLevel) (driver.LockedVaultReader, error) {
12267
logger.Debugf("Acquire tx id lock for [%s]", txID)
12368
span := trace.SpanFromContext(ctx)
@@ -400,11 +345,11 @@ func (db *txVaultReader) GetTxStatuses(ctx context.Context, txIDs ...driver.TxID
400345
return db.vr.GetTxStatuses(ctx, txIDs...)
401346
}
402347

403-
func (db *txVaultReader) GetAllTxStatuses(ctx context.Context, pagination driver.Pagination) (*driver.PageIterator[*driver.TxStatus], error) {
348+
func (db *txVaultReader) GetAllTxStatuses(ctx context.Context, sql driver.SqlQuery, pagination driver.Pagination) (*driver.PageIterator[*driver.TxStatus], error) {
404349
if err := db.setVaultReader(); err != nil {
405350
return nil, err
406351
}
407-
return db.vr.GetAllTxStatuses(ctx, pagination)
352+
return db.vr.GetAllTxStatuses(ctx, sql, pagination)
408353
}
409354

410355
func (db *txVaultReader) Done() error {
@@ -526,11 +471,9 @@ func (db *vaultReader) GetTxStatuses(ctx context.Context, txIDs ...driver.TxID)
526471
were, any := Where(db.ci.InStrings("tx_id", txIDs))
527472
return db.queryStatus(were, any, "")
528473
}
529-
func (db *vaultReader) GetAllTxStatuses(ctx context.Context, pagination driver.Pagination) (*driver.PageIterator[*driver.TxStatus], error) {
530-
if pagination == nil {
531-
return nil, fmt.Errorf("invalid input pagination: %+v", pagination)
532-
}
533-
sql := SqlQuery{}
474+
475+
func (db *vaultReader) GetAllTxStatuses(ctx context.Context, sql driver.SqlQuery, pagination driver.Pagination) (*driver.PageIterator[*driver.TxStatus], error) {
476+
sql.AddFields([]string{"tx_id", "code", "message"})
534477
sql, err := db.pi.Interpret(pagination, sql)
535478
if err != nil {
536479
return nil, err
@@ -556,12 +499,12 @@ func (db *vaultReader) queryStatus(where string, params []any, limit string) (dr
556499
return &TxCodeIterator{rows: rows}, nil
557500
}
558501

559-
func (db *vaultReader) queryStatusWithPagination(sql SqlQuery) (driver.TxStatusIterator, error) {
560-
sql.table = db.tables.StatusTable
502+
func (db *vaultReader) queryStatusWithPagination(sql driver.SqlQuery) (driver.TxStatusIterator, error) {
503+
sql.SetTable(db.tables.StatusTable)
561504
query := sql.FormatQuery()
562-
logger.Debug(query, sql.params)
505+
logger.Debug(query, sql.GetParams())
563506

564-
rows, err := db.readDB.Query(query, sql.params...)
507+
rows, err := db.readDB.Query(query, sql.GetParams()...)
565508
if err != nil {
566509
return nil, err
567510
}

0 commit comments

Comments
 (0)