Skip to content

Commit d8bbaa8

Browse files
authored
Merge pull request #24 from jingxu97/updateSize
Fix updating restore size issue
2 parents 390aa39 + a674806 commit d8bbaa8

File tree

5 files changed

+35
-27
lines changed

5 files changed

+35
-27
lines changed

pkg/connection/connection.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,8 @@ type CSIConnection interface {
5151
// DeleteSnapshot deletes a snapshot from a volume
5252
DeleteSnapshot(ctx context.Context, snapshotID string, snapshotterCredentials map[string]string) (err error)
5353

54-
// GetSnapshotStatus lists snapshot from a volume
55-
GetSnapshotStatus(ctx context.Context, snapshotID string) (*csi.SnapshotStatus, int64, error)
54+
// GetSnapshotStatus returns a snapshot's status, creation time, and restore size.
55+
GetSnapshotStatus(ctx context.Context, snapshotID string) (*csi.SnapshotStatus, int64, int64, error)
5656

5757
// Probe checks that the CSI driver is ready to process requests
5858
Probe(ctx context.Context) error
@@ -232,7 +232,7 @@ func (c *csiConnection) DeleteSnapshot(ctx context.Context, snapshotID string, s
232232
return nil
233233
}
234234

235-
func (c *csiConnection) GetSnapshotStatus(ctx context.Context, snapshotID string) (*csi.SnapshotStatus, int64, error) {
235+
func (c *csiConnection) GetSnapshotStatus(ctx context.Context, snapshotID string) (*csi.SnapshotStatus, int64, int64, error) {
236236
client := csi.NewControllerClient(c.conn)
237237

238238
req := csi.ListSnapshotsRequest{
@@ -241,14 +241,14 @@ func (c *csiConnection) GetSnapshotStatus(ctx context.Context, snapshotID string
241241

242242
rsp, err := client.ListSnapshots(ctx, &req)
243243
if err != nil {
244-
return nil, 0, err
244+
return nil, 0, 0, err
245245
}
246246

247247
if rsp.Entries == nil || len(rsp.Entries) == 0 {
248-
return nil, 0, fmt.Errorf("can not find snapshot for snapshotID %s", snapshotID)
248+
return nil, 0, 0, fmt.Errorf("can not find snapshot for snapshotID %s", snapshotID)
249249
}
250250

251-
return rsp.Entries[0].Snapshot.Status, rsp.Entries[0].Snapshot.CreatedAt, nil
251+
return rsp.Entries[0].Snapshot.Status, rsp.Entries[0].Snapshot.CreatedAt, rsp.Entries[0].Snapshot.SizeBytes, nil
252252
}
253253

254254
func (c *csiConnection) Close() error {

pkg/connection/connection_test.go

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -658,6 +658,7 @@ func TestDeleteSnapshot(t *testing.T) {
658658
func TestGetSnapshotStatus(t *testing.T) {
659659
defaultID := "testid"
660660
createdAt := time.Now().UnixNano()
661+
size := int64(1000)
661662

662663
defaultRequest := &csi.ListSnapshotsRequest{
663664
SnapshotId: defaultID,
@@ -668,7 +669,7 @@ func TestGetSnapshotStatus(t *testing.T) {
668669
{
669670
Snapshot: &csi.Snapshot{
670671
Id: defaultID,
671-
SizeBytes: 1000,
672+
SizeBytes: size,
672673
SourceVolumeId: "volumeid",
673674
CreatedAt: createdAt,
674675
Status: &csi.SnapshotStatus{
@@ -689,6 +690,7 @@ func TestGetSnapshotStatus(t *testing.T) {
689690
expectError bool
690691
expectStatus *csi.SnapshotStatus
691692
expectCreateAt int64
693+
expectSize int64
692694
}{
693695
{
694696
name: "success",
@@ -701,6 +703,7 @@ func TestGetSnapshotStatus(t *testing.T) {
701703
Details: "success",
702704
},
703705
expectCreateAt: createdAt,
706+
expectSize: size,
704707
},
705708
{
706709
name: "gRPC transient error",
@@ -741,7 +744,7 @@ func TestGetSnapshotStatus(t *testing.T) {
741744
controllerServer.EXPECT().ListSnapshots(gomock.Any(), in).Return(out, injectedErr).Times(1)
742745
}
743746

744-
status, createTime, err := csiConn.GetSnapshotStatus(context.Background(), test.snapshotID)
747+
status, createTime, size, err := csiConn.GetSnapshotStatus(context.Background(), test.snapshotID)
745748
if test.expectError && err == nil {
746749
t.Errorf("test %q: Expected error, got none", test.name)
747750
}
@@ -754,6 +757,9 @@ func TestGetSnapshotStatus(t *testing.T) {
754757
if test.expectCreateAt != createTime {
755758
t.Errorf("test %q: expected createTime: %v, got: %v", test.name, test.expectCreateAt, createTime)
756759
}
760+
if test.expectSize != size {
761+
t.Errorf("test %q: expected size: %v, got: %v", test.name, test.expectSize, size)
762+
}
757763
}
758764
}
759765

pkg/controller/csi_handler.go

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ import (
3232
type Handler interface {
3333
CreateSnapshot(snapshot *crdv1.VolumeSnapshot, volume *v1.PersistentVolume, parameters map[string]string, snapshotterCredentials map[string]string) (string, string, int64, int64, *csi.SnapshotStatus, error)
3434
DeleteSnapshot(content *crdv1.VolumeSnapshotContent, snapshotterCredentials map[string]string) error
35-
GetSnapshotStatus(content *crdv1.VolumeSnapshotContent) (*csi.SnapshotStatus, int64, error)
35+
GetSnapshotStatus(content *crdv1.VolumeSnapshotContent) (*csi.SnapshotStatus, int64, int64, error)
3636
}
3737

3838
// csiHandler is a handler that calls CSI to create/delete volume snapshot.
@@ -84,18 +84,19 @@ func (handler *csiHandler) DeleteSnapshot(content *crdv1.VolumeSnapshotContent,
8484
return nil
8585
}
8686

87-
func (handler *csiHandler) GetSnapshotStatus(content *crdv1.VolumeSnapshotContent) (*csi.SnapshotStatus, int64, error) {
87+
func (handler *csiHandler) GetSnapshotStatus(content *crdv1.VolumeSnapshotContent) (*csi.SnapshotStatus, int64, int64, error) {
8888
if content.Spec.CSI == nil {
89-
return nil, 0, fmt.Errorf("CSISnapshot not defined in spec")
89+
return nil, 0, 0, fmt.Errorf("CSISnapshot not defined in spec")
9090
}
9191
ctx, cancel := context.WithTimeout(context.Background(), handler.timeout)
9292
defer cancel()
9393

94-
csiSnapshotStatus, timestamp, err := handler.csiConnection.GetSnapshotStatus(ctx, content.Spec.CSI.SnapshotHandle)
94+
csiSnapshotStatus, timestamp, size, err := handler.csiConnection.GetSnapshotStatus(ctx, content.Spec.CSI.SnapshotHandle)
9595
if err != nil {
96-
return nil, 0, fmt.Errorf("failed to list snapshot data %s: %q", content.Name, err)
96+
return nil, 0, 0, fmt.Errorf("failed to list snapshot data %s: %q", content.Name, err)
9797
}
98-
return csiSnapshotStatus, timestamp, nil
98+
return csiSnapshotStatus, timestamp, size, nil
99+
99100
}
100101

101102
func makeSnapshotName(prefix, snapshotUID string, snapshotNameUUIDLength int) (string, error) {

pkg/controller/framework_test.go

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1092,6 +1092,7 @@ type listCall struct {
10921092
// information to return
10931093
status *csi.SnapshotStatus
10941094
createTime int64
1095+
size int64
10951096
err error
10961097
}
10971098

@@ -1202,10 +1203,10 @@ func (f *fakeCSIConnection) DeleteSnapshot(ctx context.Context, snapshotID strin
12021203
return call.err
12031204
}
12041205

1205-
func (f *fakeCSIConnection) GetSnapshotStatus(ctx context.Context, snapshotID string) (*csi.SnapshotStatus, int64, error) {
1206+
func (f *fakeCSIConnection) GetSnapshotStatus(ctx context.Context, snapshotID string) (*csi.SnapshotStatus, int64, int64, error) {
12061207
if f.listCallCounter >= len(f.listCalls) {
12071208
f.t.Errorf("Unexpected CSI list Snapshot call: snapshotID=%s, index: %d, calls: %+v", snapshotID, f.createCallCounter, f.createCalls)
1208-
return nil, 0, fmt.Errorf("unexpected call")
1209+
return nil, 0, 0, fmt.Errorf("unexpected call")
12091210
}
12101211
call := f.listCalls[f.listCallCounter]
12111212
f.listCallCounter++
@@ -1217,10 +1218,10 @@ func (f *fakeCSIConnection) GetSnapshotStatus(ctx context.Context, snapshotID st
12171218
}
12181219

12191220
if err != nil {
1220-
return nil, 0, fmt.Errorf("unexpected call")
1221+
return nil, 0, 0, fmt.Errorf("unexpected call")
12211222
}
12221223

1223-
return call.status, call.createTime, call.err
1224+
return call.status, call.createTime, call.size, call.err
12241225
}
12251226

12261227
func (f *fakeCSIConnection) Close() error {

pkg/controller/snapshot_controller.go

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -421,12 +421,12 @@ func (ctrl *csiSnapshotController) checkandBindSnapshotContent(snapshot *crdv1.V
421421
}
422422

423423
func (ctrl *csiSnapshotController) checkandUpdateSnapshotStatusOperation(snapshot *crdv1.VolumeSnapshot, content *crdv1.VolumeSnapshotContent) (*crdv1.VolumeSnapshot, error) {
424-
status, _, err := ctrl.handler.GetSnapshotStatus(content)
424+
status, _, size, err := ctrl.handler.GetSnapshotStatus(content)
425425
if err != nil {
426426
return nil, fmt.Errorf("failed to check snapshot status %s with error %v", snapshot.Name, err)
427427
}
428-
429-
newSnapshot, err := ctrl.updateSnapshotStatus(snapshot, status, time.Now(), nil, IsSnapshotBound(snapshot, content))
428+
timestamp := time.Now().UnixNano()
429+
newSnapshot, err := ctrl.updateSnapshotStatus(snapshot, status, timestamp, size, IsSnapshotBound(snapshot, content))
430430
if err != nil {
431431
return nil, err
432432
}
@@ -490,7 +490,7 @@ func (ctrl *csiSnapshotController) createSnapshotOperation(snapshot *crdv1.Volum
490490
// Update snapshot status with timestamp
491491
for i := 0; i < ctrl.createSnapshotContentRetryCount; i++ {
492492
glog.V(5).Infof("createSnapshot [%s]: trying to update snapshot creation timestamp", snapshotKey(snapshot))
493-
newSnapshot, err = ctrl.updateSnapshotStatus(snapshot, csiSnapshotStatus, time.Unix(0, timestamp), resource.NewQuantity(size, resource.BinarySI), false)
493+
newSnapshot, err = ctrl.updateSnapshotStatus(snapshot, csiSnapshotStatus, timestamp, size, false)
494494
if err == nil {
495495
break
496496
}
@@ -638,12 +638,12 @@ func (ctrl *csiSnapshotController) bindandUpdateVolumeSnapshot(snapshotContent *
638638
}
639639

640640
// UpdateSnapshotStatus converts snapshot status to crdv1.VolumeSnapshotCondition
641-
func (ctrl *csiSnapshotController) updateSnapshotStatus(snapshot *crdv1.VolumeSnapshot, csistatus *csi.SnapshotStatus, timestamp time.Time, size *resource.Quantity, bound bool) (*crdv1.VolumeSnapshot, error) {
642-
glog.V(5).Infof("updating VolumeSnapshot[]%s, set status %v, timestamp %v", snapshotKey(snapshot), csistatus, timestamp)
641+
func (ctrl *csiSnapshotController) updateSnapshotStatus(snapshot *crdv1.VolumeSnapshot, csistatus *csi.SnapshotStatus, createdAt, size int64, bound bool) (*crdv1.VolumeSnapshot, error) {
642+
glog.V(5).Infof("updating VolumeSnapshot[]%s, set status %v, timestamp %v", snapshotKey(snapshot), csistatus, createdAt)
643643
status := snapshot.Status
644644
change := false
645645
timeAt := &metav1.Time{
646-
Time: timestamp,
646+
Time: time.Unix(0, createdAt),
647647
}
648648

649649
snapshotClone := snapshot.DeepCopy()
@@ -676,8 +676,8 @@ func (ctrl *csiSnapshotController) updateSnapshotStatus(snapshot *crdv1.VolumeSn
676676
}
677677
}
678678
if change {
679-
if size != nil {
680-
status.RestoreSize = size
679+
if size > 0 {
680+
status.RestoreSize = resource.NewQuantity(size, resource.BinarySI)
681681
}
682682
snapshotClone.Status = status
683683
newSnapshotObj, err := ctrl.clientset.VolumesnapshotV1alpha1().VolumeSnapshots(snapshotClone.Namespace).Update(snapshotClone)

0 commit comments

Comments
 (0)