Skip to content

Commit 2f74ad0

Browse files
author
Hayim Shaul
committed
improve KeysetPagination readability by using firstId and lastId
Signed-off-by: Hayim Shaul <[email protected]>
1 parent 0abc543 commit 2f74ad0

File tree

3 files changed

+41
-31
lines changed

3 files changed

+41
-31
lines changed

platform/common/driver/vault.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -78,12 +78,12 @@ func (q *SqlQuery) SetTable(t string) {
7878
q.table = t
7979
}
8080

81-
func (q *SqlQuery) SetLimit(l string) {
82-
q.limit = l
81+
func (q *SqlQuery) SetLimit(l int) {
82+
q.limit = fmt.Sprintf("%d", l)
8383
}
8484

85-
func (q *SqlQuery) SetOffset(o string) {
86-
q.offset = o
85+
func (q *SqlQuery) SetOffset(o int) {
86+
q.offset = fmt.Sprintf("%d", o)
8787
}
8888

8989
func (q *SqlQuery) SetOrder(o string) {

platform/view/services/db/driver/sql/common/pagination.go

Lines changed: 35 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -94,9 +94,9 @@ type KeysetPagination struct {
9494
sqlIdName string
9595
// name of the field in the struct that is returned from the database
9696
idFieldName string
97-
// the last id value read and the offset in which it was read
98-
lastId string // TODO: should this be int?
99-
lastOffset int
97+
// the first and last id values in the page
98+
firstId string
99+
lastId string
100100
}
101101

102102
func NewKeysetPagination(offset int, pageSize int, sqlIdName string, idFieldName string) (*KeysetPagination, error) {
@@ -111,23 +111,34 @@ func NewKeysetPagination(offset int, pageSize int, sqlIdName string, idFieldName
111111
pageSize: pageSize,
112112
sqlIdName: sqlIdName,
113113
idFieldName: idFieldName,
114+
firstId: "",
114115
lastId: "",
115-
lastOffset: -1,
116116
}, nil
117117
}
118118

119119
func (p *KeysetPagination) GoToOffset(offset int) (driver.Pagination, error) {
120120
if offset < 0 {
121121
return NewEmptyPagination(), nil
122122
}
123-
return &KeysetPagination{
124-
offset: offset,
125-
pageSize: p.pageSize,
126-
sqlIdName: p.sqlIdName,
127-
idFieldName: p.idFieldName,
128-
lastId: p.lastId,
129-
lastOffset: p.lastOffset,
130-
}, nil
123+
if offset == p.offset+p.pageSize {
124+
return &KeysetPagination{
125+
offset: offset,
126+
pageSize: p.pageSize,
127+
sqlIdName: p.sqlIdName,
128+
idFieldName: p.idFieldName,
129+
firstId: p.lastId,
130+
lastId: "",
131+
}, nil
132+
} else {
133+
return &KeysetPagination{
134+
offset: offset,
135+
pageSize: p.pageSize,
136+
sqlIdName: p.sqlIdName,
137+
idFieldName: p.idFieldName,
138+
firstId: "",
139+
lastId: "",
140+
}, nil
141+
}
131142
}
132143

133144
func (p *KeysetPagination) GoToPage(pageNum int) (driver.Pagination, error) {
@@ -146,7 +157,6 @@ func (p *KeysetPagination) Prev() (driver.Pagination, error) { return p.GoBack(1
146157
func (p *KeysetPagination) Next() (driver.Pagination, error) { return p.GoForward(1) }
147158
func (p *KeysetPagination) UpdateId(id string) {
148159
p.lastId = id
149-
p.lastOffset = p.offset
150160
}
151161

152162
func NewPaginationInterpreter() *paginationInterpreter {
@@ -164,22 +174,22 @@ func (i *paginationInterpreter) Interpret(p driver.Pagination, sql driver.SqlQue
164174
case *NoPagination:
165175
return sql, nil
166176
case *OffsetPagination:
167-
sql.SetLimit(fmt.Sprintf("%d", pagination.pageSize))
168-
sql.SetOffset(fmt.Sprintf("%d", pagination.offset))
177+
sql.SetLimit(pagination.pageSize)
178+
sql.SetOffset(pagination.offset)
169179
return sql, nil
170180
case *KeysetPagination:
171181
sql.SetOrder(fmt.Sprintf("%s ASC", pagination.sqlIdName))
172-
sql.SetLimit(fmt.Sprintf("%d", pagination.pageSize))
173-
if (pagination.lastOffset != -1) && (pagination.offset == pagination.lastOffset+pagination.pageSize) {
174-
lastId := sql.AddParam(pagination.lastId)
182+
sql.SetLimit(pagination.pageSize)
183+
if pagination.firstId != "" {
184+
lastId := sql.AddParam(pagination.firstId)
175185
sql.AddWhere(fmt.Sprintf("%s>$%d", pagination.sqlIdName, lastId))
176186
} else {
177-
sql.SetOffset(fmt.Sprintf("%d", pagination.offset))
187+
sql.SetOffset(pagination.offset)
178188
}
179189
return sql, nil
180190
case *EmptyPagination:
181-
sql.SetLimit("0")
182-
sql.SetOffset("0")
191+
sql.SetLimit(0)
192+
sql.SetOffset(0)
183193
return sql, nil
184194
default:
185195
return sql, errors.Errorf("invalid pagination option %+v", pagination)
@@ -196,21 +206,21 @@ func NewPaginationUpdater[R comparable]() *paginationUpdater[R] {
196206
return &paginationUpdater[R]{}
197207
}
198208

199-
func (i *paginationUpdater[R]) Update(recs *driver.PageIterator[*R]) (*driver.PageIterator[*R], error) {
209+
func (i *paginationUpdater[R]) Update(recs driver.PageIterator[*R]) (driver.PageIterator[*R], error) {
200210
switch page := recs.Pagination.(type) {
201211
case *KeysetPagination:
202212
items := recs.Items
203213
record, newIt, err := collections.ReadLast(items)
204214
if err != nil {
205-
return nil, err
215+
return recs, err
206216
}
207217
if record != nil {
208218
refRec := reflect.ValueOf(*record)
209219
id := refRec.FieldByName(page.idFieldName)
210220
page.UpdateId(id.String())
211221
}
212-
return (&driver.PageIterator[*R]{Items: newIt, Pagination: page}), nil
222+
return (driver.PageIterator[*R]{Items: newIt, Pagination: page}), nil
213223
default:
214-
return (&driver.PageIterator[*R]{Items: recs.Items, Pagination: recs.Pagination}), nil
224+
return (driver.PageIterator[*R]{Items: recs.Items, Pagination: recs.Pagination}), nil
215225
}
216226
}

platform/view/services/db/driver/sql/common/vault.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -483,9 +483,9 @@ func (db *vaultReader) GetAllTxStatuses(ctx context.Context, sql driver.SqlQuery
483483
return nil, err
484484
}
485485
pu := NewPaginationUpdater[driver.TxStatus]()
486-
pageIt := &driver.PageIterator[*driver.TxStatus]{Items: txStatusIterator, Pagination: pagination}
486+
pageIt := driver.PageIterator[*driver.TxStatus]{Items: txStatusIterator, Pagination: pagination}
487487
pageIt, err = pu.Update(pageIt)
488-
return pageIt, err
488+
return &pageIt, err
489489
}
490490

491491
func (db *vaultReader) queryStatus(where string, params []any, limit string) (driver.TxStatusIterator, error) {

0 commit comments

Comments
 (0)