Skip to content

Commit bba700f

Browse files
committed
[stripe] Use attributionId in stripe queries
1 parent 7df06a8 commit bba700f

File tree

4 files changed

+51
-54
lines changed

4 files changed

+51
-54
lines changed

components/usage/pkg/apiv1/billing.go

+12-24
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ func (s *BillingService) UpdateInvoices(ctx context.Context, in *v1.UpdateInvoic
5252
return nil, status.Errorf(codes.Internal, "Failed to download usage report with ID: %s", in.GetReportId())
5353
}
5454

55-
credits, err := s.creditSummaryForTeams(report.UsageRecords, in.GetReportId())
55+
credits, err := s.creditSummary(report.UsageRecords, in.GetReportId())
5656
if err != nil {
5757
log.Log.WithError(err).Errorf("Failed to compute credit summary.")
5858
return nil, status.Errorf(codes.InvalidArgument, "failed to compute credit summary")
@@ -74,16 +74,9 @@ func (s *BillingService) ReconcileInvoices(ctx context.Context, in *v1.Reconcile
7474
return nil, status.Errorf(codes.Internal, "Failed to reconcile invoices.")
7575
}
7676

77-
creditSummaryForTeams := map[string]stripe.CreditSummary{}
77+
creditSummaryForTeams := map[db.AttributionID]stripe.CreditSummary{}
7878
for _, balance := range balances {
79-
entity, id := balance.AttributionID.Values()
80-
81-
// TODO: Support updating of user attribution IDs
82-
if entity != db.AttributionEntity_Team {
83-
continue
84-
}
85-
86-
creditSummaryForTeams[id] = stripe.CreditSummary{
79+
creditSummaryForTeams[balance.AttributionID] = stripe.CreditSummary{
8780
Credits: int64(math.Ceil(balance.CreditCents.ToCredits())),
8881
ReportID: "no-report",
8982
}
@@ -117,7 +110,7 @@ func (s *BillingService) FinalizeInvoice(ctx context.Context, in *v1.FinalizeInv
117110
return nil, status.Errorf(codes.Internal, "Failed to retrieve subscription details from invoice.")
118111
}
119112

120-
teamID, found := subscription.Metadata[stripe.TeamIDMetadataKey]
113+
teamID, found := subscription.Metadata[stripe.AttributionIDMetadataKey]
121114
if !found {
122115
logger.Error("Failed to find teamID from subscription metadata.")
123116
return nil, status.Errorf(codes.Internal, "Failed to extra teamID from Stripe subscription.")
@@ -196,29 +189,24 @@ func (s *BillingService) GetUpcomingInvoice(ctx context.Context, in *v1.GetUpcom
196189
}, nil
197190
}
198191

199-
func (s *BillingService) creditSummaryForTeams(sessions []db.WorkspaceInstanceUsage, reportID string) (map[string]stripe.CreditSummary, error) {
200-
creditsPerTeamID := map[string]float64{}
192+
func (s *BillingService) creditSummary(sessions []db.WorkspaceInstanceUsage, reportID string) (map[db.AttributionID]stripe.CreditSummary, error) {
193+
creditsPerAttributionID := map[db.AttributionID]float64{}
201194

202195
for _, session := range sessions {
203196
if session.StartedAt.Before(s.billInstancesAfter) {
204197
continue
205198
}
206199

207-
entity, id := session.AttributionID.Values()
208-
if entity != db.AttributionEntity_Team {
209-
continue
210-
}
211-
212-
if _, ok := creditsPerTeamID[id]; !ok {
213-
creditsPerTeamID[id] = 0
200+
if _, ok := creditsPerAttributionID[session.AttributionID]; !ok {
201+
creditsPerAttributionID[session.AttributionID] = 0
214202
}
215203

216-
creditsPerTeamID[id] += session.CreditsUsed
204+
creditsPerAttributionID[session.AttributionID] += session.CreditsUsed
217205
}
218206

219-
rounded := map[string]stripe.CreditSummary{}
220-
for teamID, credits := range creditsPerTeamID {
221-
rounded[teamID] = stripe.CreditSummary{
207+
rounded := map[db.AttributionID]stripe.CreditSummary{}
208+
for attributionID, credits := range creditsPerAttributionID {
209+
rounded[attributionID] = stripe.CreditSummary{
222210
Credits: int64(math.Ceil(credits)),
223211
ReportID: reportID,
224212
}

components/usage/pkg/apiv1/billing_test.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ func TestCreditSummaryForTeams(t *testing.T) {
121121
for _, s := range scenarios {
122122
t.Run(s.Name, func(t *testing.T) {
123123
svc := NewBillingService(&stripe.Client{}, s.BillSessionsAfter, &gorm.DB{}, nil)
124-
actual, err := svc.creditSummaryForTeams(s.Sessions, reportID)
124+
actual, err := svc.creditSummary(s.Sessions, reportID)
125125
require.NoError(t, err)
126126
require.Equal(t, s.Expected, actual)
127127
})

components/usage/pkg/stripe/stripe.go

+24-16
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"context"
99
"encoding/json"
1010
"fmt"
11+
"github.com/gitpod-io/gitpod/usage/pkg/db"
1112
"os"
1213
"strings"
1314

@@ -17,8 +18,8 @@ import (
1718
)
1819

1920
const (
20-
ReportIDMetadataKey = "reportId"
21-
TeamIDMetadataKey = "teamId"
21+
ReportIDMetadataKey = "reportId"
22+
AttributionIDMetadataKey = "attributionId"
2223
)
2324

2425
type Client struct {
@@ -70,12 +71,12 @@ type CreditSummary struct {
7071

7172
// UpdateUsage updates teams' Stripe subscriptions with usage data
7273
// `usageForTeam` is a map from team name to total workspace seconds used within a billing period.
73-
func (c *Client) UpdateUsage(ctx context.Context, creditsPerTeam map[string]CreditSummary) error {
74-
teamIds := make([]string, 0, len(creditsPerTeam))
75-
for k := range creditsPerTeam {
76-
teamIds = append(teamIds, k)
74+
func (c *Client) UpdateUsage(ctx context.Context, creditsPerAttributionID map[db.AttributionID]CreditSummary) error {
75+
attributionIDs := make([]db.AttributionID, 0, len(creditsPerAttributionID))
76+
for k := range creditsPerAttributionID {
77+
attributionIDs = append(attributionIDs, k)
7778
}
78-
queries := queriesForCustomersWithTeamIds(teamIds)
79+
queries := queriesForCustomersWithAttributionIDs(attributionIDs)
7980

8081
for _, query := range queries {
8182
log.Infof("Searching customers in Stripe with query: %q", query)
@@ -86,14 +87,21 @@ func (c *Client) UpdateUsage(ctx context.Context, creditsPerTeam map[string]Cred
8687
}
8788

8889
for _, customer := range customers {
89-
teamID := customer.Metadata["teamId"]
90-
log.Infof("Found customer %q for teamId %q", customer.Name, teamID)
90+
attributionIDRaw := customer.Metadata[AttributionIDMetadataKey]
91+
log.Infof("Found customer %q for attribution ID %q", customer.Name, attributionIDRaw)
9192

92-
_, err := c.updateUsageForCustomer(ctx, customer, creditsPerTeam[teamID])
93+
attributionID, err := db.ParseAttributionID(attributionIDRaw)
94+
if err != nil {
95+
log.WithError(err).Error("Failed to parse attribution ID from Stripe metadata.")
96+
continue
97+
}
98+
99+
_, err = c.updateUsageForCustomer(ctx, customer, creditsPerAttributionID[attributionID])
93100
if err != nil {
94101
log.WithField("customer_id", customer.ID).
95102
WithField("customer_name", customer.Name).
96103
WithField("subscriptions", customer.Subscriptions).
104+
WithField("attribution_id", attributionID).
97105
WithError(err).
98106
Errorf("Failed to update usage.")
99107

@@ -246,19 +254,19 @@ func (c *Client) GetInvoice(ctx context.Context, invoiceID string) (*stripe.Invo
246254
return invoice, nil
247255
}
248256

249-
// queriesForCustomersWithTeamIds constructs Stripe query strings to find the Stripe Customer for each teamId
257+
// queriesForCustomersWithAttributionIDs constructs Stripe query strings to find the Stripe Customer for each teamId
250258
// It returns multiple queries, each being a big disjunction of subclauses so that we can process multiple teamIds in one query.
251259
// `clausesPerQuery` is a limit enforced by the Stripe API.
252-
func queriesForCustomersWithTeamIds(teamIds []string) []string {
260+
func queriesForCustomersWithAttributionIDs(attributionIDs []db.AttributionID) []string {
253261
const clausesPerQuery = 10
254262
var queries []string
255263
sb := strings.Builder{}
256264

257-
for i := 0; i < len(teamIds); i += clausesPerQuery {
265+
for i := 0; i < len(attributionIDs); i += clausesPerQuery {
258266
sb.Reset()
259-
for j := 0; j < clausesPerQuery && i+j < len(teamIds); j++ {
260-
sb.WriteString(fmt.Sprintf("metadata['%s']:'%s'", TeamIDMetadataKey, teamIds[i+j]))
261-
if j < clausesPerQuery-1 && i+j < len(teamIds)-1 {
267+
for j := 0; j < clausesPerQuery && i+j < len(attributionIDs); j++ {
268+
sb.WriteString(fmt.Sprintf("metadata['%s']:'%s'", AttributionIDMetadataKey, attributionIDs[i+j]))
269+
if j < clausesPerQuery-1 && i+j < len(attributionIDs)-1 {
262270
sb.WriteString(" OR ")
263271
}
264272
}

components/usage/pkg/stripe/stripe_test.go

+14-13
Original file line numberDiff line numberDiff line change
@@ -6,32 +6,33 @@ package stripe
66

77
import (
88
"fmt"
9+
"github.com/gitpod-io/gitpod/usage/pkg/db"
910
"testing"
1011

1112
"github.com/stretchr/testify/require"
1213
)
1314

14-
func TestCustomerQueriesForTeamIds_SingleQuery(t *testing.T) {
15+
func TestQueriesForCustomersWithAttributionID_Single(t *testing.T) {
1516
testCases := []struct {
1617
Name string
17-
TeamIds []string
18+
AttributionIDs []db.AttributionID
1819
ExpectedQueries []string
1920
}{
2021
{
2122
Name: "1 team id",
22-
TeamIds: []string{"abcd-123"},
23-
ExpectedQueries: []string{"metadata['teamId']:'abcd-123'"},
23+
AttributionIDs: []db.AttributionID{db.NewTeamAttributionID("abcd-123")},
24+
ExpectedQueries: []string{"metadata['attributionId']:'team:abcd-123'"},
2425
},
2526
{
26-
Name: "2 team ids",
27-
TeamIds: []string{"abcd-123", "abcd-456"},
28-
ExpectedQueries: []string{"metadata['teamId']:'abcd-123' OR metadata['teamId']:'abcd-456'"},
27+
Name: "1 team id, 1 user id",
28+
AttributionIDs: []db.AttributionID{db.NewTeamAttributionID("abcd-123"), db.NewUserAttributionID("abcd-456")},
29+
ExpectedQueries: []string{"metadata['attributionId']:'team:abcd-123' OR metadata['attributionId']:'user:abcd-456'"},
2930
},
3031
}
3132

3233
for _, tc := range testCases {
3334
t.Run(tc.Name, func(t *testing.T) {
34-
actualQueries := queriesForCustomersWithTeamIds(tc.TeamIds)
35+
actualQueries := queriesForCustomersWithAttributionIDs(tc.AttributionIDs)
3536

3637
require.Equal(t, tc.ExpectedQueries, actualQueries)
3738
})
@@ -66,18 +67,18 @@ func TestCustomerQueriesForTeamIds_MultipleQueries(t *testing.T) {
6667
},
6768
}
6869

69-
buildTeamIds := func(numberOfTeamIds int) []string {
70-
var teamIds []string
70+
buildTeamIds := func(numberOfTeamIds int) []db.AttributionID {
71+
var attributionIDs []db.AttributionID
7172
for i := 0; i < numberOfTeamIds; i++ {
72-
teamIds = append(teamIds, fmt.Sprintf("abcd-%d", i))
73+
attributionIDs = append(attributionIDs, db.NewTeamAttributionID(fmt.Sprintf("abcd-%d", i)))
7374
}
74-
return teamIds
75+
return attributionIDs
7576
}
7677

7778
for _, tc := range testCases {
7879
t.Run(tc.Name, func(t *testing.T) {
7980
teamIds := buildTeamIds(tc.NumberOfTeamIds)
80-
actualQueries := queriesForCustomersWithTeamIds(teamIds)
81+
actualQueries := queriesForCustomersWithAttributionIDs(teamIds)
8182

8283
require.Equal(t, tc.ExpectedNumberOfQueries, len(actualQueries))
8384
})

0 commit comments

Comments
 (0)