Skip to content

Commit d4853d2

Browse files
committed
[usage] update cost center on stripe cancellation
1 parent 51b0fd5 commit d4853d2

File tree

4 files changed

+47
-20
lines changed

4 files changed

+47
-20
lines changed

components/usage/pkg/apiv1/billing.go

+19-5
Original file line numberDiff line numberDiff line change
@@ -22,16 +22,18 @@ import (
2222
"gorm.io/gorm"
2323
)
2424

25-
func NewBillingService(stripeClient *stripe.Client, conn *gorm.DB) *BillingService {
25+
func NewBillingService(stripeClient *stripe.Client, conn *gorm.DB, ccManager *db.CostCenterManager) *BillingService {
2626
return &BillingService{
2727
stripeClient: stripeClient,
2828
conn: conn,
29+
ccManager: ccManager,
2930
}
3031
}
3132

3233
type BillingService struct {
3334
conn *gorm.DB
3435
stripeClient *stripe.Client
36+
ccManager *db.CostCenterManager
3537

3638
v1.UnimplementedBillingServiceServer
3739
}
@@ -76,19 +78,18 @@ func (s *BillingService) FinalizeInvoice(ctx context.Context, in *v1.FinalizeInv
7678
return nil, status.Errorf(codes.Internal, "Failed to retrieve subscription details from invoice.")
7779
}
7880

79-
teamID, found := subscription.Metadata[stripe.AttributionIDMetadataKey]
81+
attrID, found := subscription.Metadata[stripe.AttributionIDMetadataKey]
8082
if !found {
8183
logger.Error("Failed to find teamID from subscription metadata.")
8284
return nil, status.Errorf(codes.Internal, "Failed to extra teamID from Stripe subscription.")
8385
}
84-
logger = logger.WithField("team_id", teamID)
86+
logger = logger.WithField("attribution_id", attrID)
8587

8688
// To support individual `user`s, we'll need to also extract the `userId` from metadata here and handle separately.
87-
attributionID := db.NewTeamAttributionID(teamID)
89+
attributionID := db.NewTeamAttributionID(attrID)
8890
finalizedAt := time.Unix(invoice.StatusTransitions.FinalizedAt, 0)
8991

9092
logger = logger.
91-
WithField("attribution_id", attributionID).
9293
WithField("invoice_finalized_at", finalizedAt)
9394

9495
if invoice.Lines == nil || len(invoice.Lines.Data) == 0 {
@@ -123,6 +124,19 @@ func (s *BillingService) FinalizeInvoice(ctx context.Context, in *v1.FinalizeInv
123124

124125
logger.WithField("usage_id", usage.ID).Infof("Inserted usage record into database for %d credits against %s attribution", creditsOnInvoice, attributionID)
125126

127+
if subscription.EndedAt < time.Now().Unix() {
128+
logger.Infof("Subscription ended. Setting back to default.")
129+
costCenter, err := s.ccManager.GetOrCreateCostCenter(ctx, attributionID)
130+
if err != nil {
131+
return nil, err
132+
}
133+
costCenter.BillingStrategy = db.CostCenter_Other
134+
_, err = s.ccManager.UpdateCostCenter(ctx, costCenter)
135+
if err != nil {
136+
return nil, err
137+
}
138+
}
139+
126140
return &v1.FinalizeInvoiceResponse{}, nil
127141
}
128142

components/usage/pkg/db/cost_center.go

+14-9
Original file line numberDiff line numberDiff line change
@@ -112,13 +112,15 @@ func (c *CostCenterManager) UpdateCostCenter(ctx context.Context, costCenter Cos
112112

113113
now := time.Now()
114114

115+
// we update the creationTime
116+
costCenter.CreationTime = NewVarcharTime(now)
115117
// we don't allow setting the creationTime or the nextBillingTime from outside
116118
costCenter.CreationTime = existingCostCenter.CreationTime
117119
costCenter.NextBillingTime = existingCostCenter.NextBillingTime
118120

119121
// Do we have a billing strategy update?
120122
if costCenter.BillingStrategy != existingCostCenter.BillingStrategy {
121-
if existingCostCenter.BillingStrategy == CostCenter_Other {
123+
if costCenter.BillingStrategy == CostCenter_Stripe {
122124
// moving to stripe -> let's run a finalization
123125
finalizationUsage, err := c.ComputeInvoiceUsageRecord(ctx, costCenter.ID)
124126
if err != nil {
@@ -130,12 +132,20 @@ func (c *CostCenterManager) UpdateCostCenter(ctx context.Context, costCenter Cos
130132
return CostCenter{}, err
131133
}
132134
}
135+
// we don't manage stripe billing cycle
136+
costCenter.NextBillingTime = VarcharTime{}
137+
} else if costCenter.BillingStrategy == CostCenter_Other {
138+
// cancelled from stripe reset the spending limit
139+
if costCenter.ID.IsEntity(AttributionEntity_Team) {
140+
costCenter.SpendingLimit = c.cfg.ForTeams
141+
} else {
142+
costCenter.SpendingLimit = c.cfg.ForUsers
143+
}
144+
// see you next month
145+
costCenter.NextBillingTime = NewVarcharTime(now.AddDate(0, 1, 0))
133146
}
134-
c.updateNextBillingTime(&costCenter, now)
135147
}
136148

137-
// we update the creationTime
138-
costCenter.CreationTime = NewVarcharTime(now)
139149
db := c.conn.Save(&costCenter)
140150
if db.Error != nil {
141151
return CostCenter{}, fmt.Errorf("failed to save cost center for attributionID %s: %w", costCenter.ID, db.Error)
@@ -163,8 +173,3 @@ func (c *CostCenterManager) ComputeInvoiceUsageRecord(ctx context.Context, attri
163173
Draft: false,
164174
}, nil
165175
}
166-
167-
func (c *CostCenterManager) updateNextBillingTime(costCenter *CostCenter, now time.Time) {
168-
nextMonth := NewVarcharTime(time.Now().AddDate(0, 1, 0))
169-
costCenter.NextBillingTime = nextMonth
170-
}

components/usage/pkg/db/cost_center_test.go

+13-5
Original file line numberDiff line numberDiff line change
@@ -83,20 +83,28 @@ func TestCostCenterManager_UpdateCostCenter(t *testing.T) {
8383
func TestSaveCostCenterMovedToStripe(t *testing.T) {
8484
conn := dbtest.ConnectForTests(t)
8585
mnr := db.NewCostCenterManager(conn, db.DefaultSpendingLimit{
86-
ForTeams: 0,
86+
ForTeams: 20,
8787
ForUsers: 500,
8888
})
8989
team := db.NewTeamAttributionID(uuid.New().String())
9090
cleanUp(t, conn, team)
9191
teamCC, err := mnr.GetOrCreateCostCenter(context.Background(), team)
9292
require.NoError(t, err)
93-
require.Equal(t, int32(0), teamCC.SpendingLimit)
93+
require.Equal(t, int32(20), teamCC.SpendingLimit)
9494

9595
teamCC.BillingStrategy = db.CostCenter_Stripe
96-
newTeamCC, err := mnr.UpdateCostCenter(context.Background(), teamCC)
96+
teamCC.SpendingLimit = 400050
97+
teamCC, err = mnr.UpdateCostCenter(context.Background(), teamCC)
98+
require.NoError(t, err)
99+
require.Equal(t, db.CostCenter_Stripe, teamCC.BillingStrategy)
100+
require.Equal(t, db.VarcharTime{}, teamCC.NextBillingTime)
101+
require.Equal(t, int32(400050), teamCC.SpendingLimit)
102+
103+
teamCC.BillingStrategy = db.CostCenter_Other
104+
teamCC, err = mnr.UpdateCostCenter(context.Background(), teamCC)
97105
require.NoError(t, err)
98-
require.Equal(t, db.CostCenter_Stripe, newTeamCC.BillingStrategy)
99-
require.Equal(t, newTeamCC.CreationTime.Time().AddDate(0, 1, 0).Truncate(time.Second), newTeamCC.NextBillingTime.Time().Truncate(time.Second))
106+
require.Equal(t, teamCC.CreationTime.Time().AddDate(0, 1, 0).Truncate(time.Second), teamCC.NextBillingTime.Time().Truncate(time.Second))
107+
require.Equal(t, int32(20), teamCC.SpendingLimit)
100108
}
101109

102110
func cleanUp(t *testing.T, conn *gorm.DB, attributionIds ...db.AttributionID) {

components/usage/pkg/server/server.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ func registerGRPCServices(srv *baseserver.Server, conn *gorm.DB, stripeClient *s
156156
if stripeClient == nil {
157157
v1.RegisterBillingServiceServer(srv.GRPC(), &apiv1.BillingServiceNoop{})
158158
} else {
159-
v1.RegisterBillingServiceServer(srv.GRPC(), apiv1.NewBillingService(stripeClient, conn))
159+
v1.RegisterBillingServiceServer(srv.GRPC(), apiv1.NewBillingService(stripeClient, conn, ccManager))
160160
}
161161
return nil
162162
}

0 commit comments

Comments
 (0)