From 3831baccb07a785fd7e78283e44cdc43ed00b5d8 Mon Sep 17 00:00:00 2001 From: Sven Efftinge Date: Thu, 8 Sep 2022 10:14:32 +0000 Subject: [PATCH] [usage] make tests robust ... against parallel DB changes --- components/usage/pkg/db/usage_test.go | 30 ++++++++++++++++++++------- 1 file changed, 22 insertions(+), 8 deletions(-) diff --git a/components/usage/pkg/db/usage_test.go b/components/usage/pkg/db/usage_test.go index 0bfd430a84afc7..e4977ee7239050 100644 --- a/components/usage/pkg/db/usage_test.go +++ b/components/usage/pkg/db/usage_test.go @@ -144,8 +144,19 @@ func TestInsertUsageRecords(t *testing.T) { drafts, err := db.FindAllDraftUsage(context.Background(), conn) require.NoError(t, err) - require.Equal(t, 1, len(drafts)) - require.NotEqual(t, updatedDesc, drafts[0].Description) + cleaned := filter(drafts, attributionID) + require.Equal(t, 1, len(cleaned)) + require.NotEqual(t, updatedDesc, cleaned[0].Description) +} + +func filter(drafts []db.Usage, attributionID db.AttributionID) []db.Usage { + var cleaned []db.Usage + for _, draft := range drafts { + if draft.AttributionID == attributionID { + cleaned = append(cleaned, draft) + } + } + return cleaned } func TestUpdateUsageRecords(t *testing.T) { @@ -168,8 +179,9 @@ func TestUpdateUsageRecords(t *testing.T) { drafts, err := db.FindAllDraftUsage(context.Background(), conn) require.NoError(t, err) - require.Equal(t, 1, len(drafts)) - require.Equal(t, updatedDesc, drafts[0].Description) + cleaned := filter(drafts, attributionID) + require.Equal(t, 1, len(cleaned)) + require.Equal(t, updatedDesc, cleaned[0].Description) } func TestFindAllDraftUsage(t *testing.T) { @@ -197,8 +209,9 @@ func TestFindAllDraftUsage(t *testing.T) { dbtest.CreateUsageRecords(t, conn, usage1, usage2, usage3) drafts, err := db.FindAllDraftUsage(context.Background(), conn) require.NoError(t, err) - require.Equal(t, 2, len(drafts)) - for _, usage := range drafts { + cleaned := filter(drafts, attributionID) + require.Equal(t, 2, len(cleaned)) + for _, usage := range cleaned { require.True(t, usage.Draft) } @@ -208,8 +221,9 @@ func TestFindAllDraftUsage(t *testing.T) { drafts, err = db.FindAllDraftUsage(context.Background(), conn) require.NoError(t, err) - require.Equal(t, 1, len(drafts)) - for _, usage := range drafts { + cleaned = filter(drafts, attributionID) + require.Equal(t, 1, len(cleaned)) + for _, usage := range cleaned { require.True(t, usage.Draft) } }