diff --git a/core/base/stat.go b/core/base/stat.go index 952e53d9e..5a2e505eb 100644 --- a/core/base/stat.go +++ b/core/base/stat.go @@ -98,8 +98,8 @@ func (ws *nopWriteStat) AddCount(_ MetricEvent, _ int64) { // ConcurrencyStat provides read/update operation for concurrency statistics. type ConcurrencyStat interface { CurrentConcurrency() int32 - IncreaseConcurrency() - DecreaseConcurrency() + IncreaseConcurrency(int32) + DecreaseConcurrency(int32) } // StatNode holds real-time statistics for resources. diff --git a/core/base/stat_test.go b/core/base/stat_test.go index d4882a243..c1296ee60 100644 --- a/core/base/stat_test.go +++ b/core/base/stat_test.go @@ -69,12 +69,12 @@ func (m *StatNodeMock) CurrentConcurrency() int32 { return int32(args.Int(0)) } -func (m *StatNodeMock) IncreaseConcurrency() { +func (m *StatNodeMock) IncreaseConcurrency(count int32) { m.Called() return } -func (m *StatNodeMock) DecreaseConcurrency() { +func (m *StatNodeMock) DecreaseConcurrency(count int32) { m.Called() return } diff --git a/core/stat/base_node.go b/core/stat/base_node.go index 9ef59b779..23172d856 100644 --- a/core/stat/base_node.go +++ b/core/stat/base_node.go @@ -92,12 +92,12 @@ func (n *BaseStatNode) CurrentConcurrency() int32 { return atomic.LoadInt32(&(n.concurrency)) } -func (n *BaseStatNode) IncreaseConcurrency() { - n.UpdateConcurrency(atomic.AddInt32(&(n.concurrency), 1)) +func (n *BaseStatNode) IncreaseConcurrency(count int32) { + n.UpdateConcurrency(atomic.AddInt32(&(n.concurrency), count)) } -func (n *BaseStatNode) DecreaseConcurrency() { - atomic.AddInt32(&(n.concurrency), -1) +func (n *BaseStatNode) DecreaseConcurrency(count int32) { + atomic.AddInt32(&(n.concurrency), -count) } func (n *BaseStatNode) GenerateReadStat(sampleCount uint32, intervalInMs uint32) (base.ReadStat, error) { diff --git a/core/stat/stat_slot.go b/core/stat/stat_slot.go index 2845dead3..c0f499e09 100644 --- a/core/stat/stat_slot.go +++ b/core/stat/stat_slot.go @@ -77,7 +77,7 @@ func (s *Slot) recordPassFor(sn base.StatNode, count uint32) { if sn == nil { return } - sn.IncreaseConcurrency() + sn.IncreaseConcurrency(int32(count)) sn.AddCount(base.MetricEventPass, int64(count)) } @@ -97,5 +97,5 @@ func (s *Slot) recordCompleteFor(sn base.StatNode, count uint32, rt uint64, err } sn.AddCount(base.MetricEventRt, int64(rt)) sn.AddCount(base.MetricEventComplete, int64(count)) - sn.DecreaseConcurrency() + sn.DecreaseConcurrency(int32(count)) } diff --git a/core/system/slot_test.go b/core/system/slot_test.go index aa918aa20..44e947e1d 100644 --- a/core/system/slot_test.go +++ b/core/system/slot_test.go @@ -66,11 +66,11 @@ func TestDoCheckRuleConcurrency(t *testing.T) { }) t.Run("FalseConcurrency", func(t *testing.T) { - stat.InboundNode().IncreaseConcurrency() + stat.InboundNode().IncreaseConcurrency(1) isOK, _, v := sas.doCheckRule(rule) assert.True(t, util.Float64Equals(float64(1.0), v)) assert.Equal(t, false, isOK) - stat.InboundNode().DecreaseConcurrency() + stat.InboundNode().DecreaseConcurrency(1) }) } diff --git a/tests/api/api_entry_integration_test.go b/tests/api/api_entry_integration_test.go index 6ede27826..a56a073b1 100644 --- a/tests/api/api_entry_integration_test.go +++ b/tests/api/api_entry_integration_test.go @@ -1,7 +1,9 @@ package api import ( + "github.com/alibaba/sentinel-golang/core/isolation" "log" + "os" "runtime/debug" "testing" "time" @@ -127,3 +129,51 @@ func TestAdaptiveFlowControl2(t *testing.T) { _, blockError := api.Entry(rs, api.WithTrafficType(base.Inbound)) assert.Nil(t, blockError) } + +func assertIsPass(t *testing.T, b *base.BlockError) { + assert.True(t, b == nil) +} +func assertIsBlock(t *testing.T, b *base.BlockError) { + assert.True(t, b != nil) +} + +func Test_Isolation(t *testing.T) { + initSentinel() + + r1 := &isolation.Rule{ + Resource: "abc", + MetricType: isolation.Concurrency, + Threshold: 12, + } + _, err := isolation.LoadRules([]*isolation.Rule{r1}) + if err != nil { + logging.Error(err, "fail") + os.Exit(1) + } + + entries := make([]*base.SentinelEntry, 0) + + // Threshold = 12, BatchCount = 1, Should Pass 12 Entry + for i := 0; i < 12; i++ { + e, b := api.Entry("abc", api.WithBatchCount(1)) + assertIsPass(t, b) + entries = append(entries, e) + } + _, b := api.Entry("abc", api.WithBatchCount(1)) + assertIsBlock(t, b) + for _, e := range entries { + e.Exit() + } + + // Threshold = 12, BatchCount = 2, Should Pass 6 Entry + for i := 0; i < 6; i++ { + e, b := api.Entry("abc", api.WithBatchCount(2)) + assertIsPass(t, b) + entries = append(entries, e) + } + _, b = api.Entry("abc", api.WithBatchCount(2)) + assertIsBlock(t, b) + for _, e := range entries { + e.Exit() + } +}