Skip to content

Commit 950a4c0

Browse files
committed
singleflight: add OnceGroup
OnceGroup.Do has the same semantics as Group.Do, but caches and returns the first computed result if called sequentially.
1 parent 66deaeb commit 950a4c0

File tree

2 files changed

+147
-0
lines changed

2 files changed

+147
-0
lines changed

singleflight/singleflight.go

+36
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ type call struct {
2525
wg sync.WaitGroup
2626
val interface{}
2727
err error
28+
// true if call has completed; guarded by (Once)Group.mu
29+
complete bool
2830
}
2931

3032
// Group represents a class of work and forms a namespace in which
@@ -62,3 +64,37 @@ func (g *Group) Do(key string, fn func() (interface{}, error)) (interface{}, err
6264

6365
return c.val, c.err
6466
}
67+
68+
// OnceGroup is like Group, but caches the results of calls.
69+
type OnceGroup struct {
70+
mu sync.Mutex // protects m
71+
m map[string]*call // lazily initialized
72+
}
73+
74+
func (g *OnceGroup) Do(key string, fn func() (interface{}, error)) (interface{}, error) {
75+
g.mu.Lock()
76+
if g.m == nil {
77+
g.m = make(map[string]*call)
78+
}
79+
if c, ok := g.m[key]; ok {
80+
if c.complete {
81+
g.mu.Unlock()
82+
return c.val, c.err
83+
}
84+
g.mu.Unlock()
85+
c.wg.Wait()
86+
return c.val, c.err
87+
}
88+
c := new(call)
89+
c.wg.Add(1)
90+
g.m[key] = c
91+
g.mu.Unlock()
92+
93+
c.val, c.err = fn()
94+
g.mu.Lock()
95+
c.complete = true
96+
g.mu.Unlock()
97+
c.wg.Done()
98+
99+
return c.val, c.err
100+
}

singleflight/singleflight_test.go

+111
Original file line numberDiff line numberDiff line change
@@ -83,3 +83,114 @@ func TestDoDupSuppress(t *testing.T) {
8383
t.Errorf("number of calls = %d; want 1", got)
8484
}
8585
}
86+
87+
func TestDoCalledTwice(t *testing.T) {
88+
var g Group
89+
c := make(chan string)
90+
var calls int32
91+
fn := func() (interface{}, error) {
92+
atomic.AddInt32(&calls, 1)
93+
return <-c, nil
94+
}
95+
96+
const n = 10
97+
var wg sync.WaitGroup
98+
for i := 0; i < n; i++ {
99+
wg.Add(1)
100+
go func() {
101+
v, err := g.Do("key", fn)
102+
if err != nil {
103+
t.Errorf("Do error: %v", err)
104+
}
105+
if v.(string) != "bar" {
106+
t.Errorf("got %q; want %q", v, "bar")
107+
}
108+
wg.Done()
109+
}()
110+
}
111+
time.Sleep(100 * time.Millisecond) // let goroutines above block
112+
c <- "bar"
113+
wg.Wait()
114+
go func() {
115+
// call one more time; fn() should get called a second time
116+
v, err := g.Do("key", fn)
117+
if err != nil {
118+
t.Errorf("Do error: %v", err)
119+
}
120+
if v.(string) != "bar" {
121+
t.Errorf("got %q; want %q", v, "bar")
122+
}
123+
}()
124+
c <- "bar"
125+
if got := atomic.LoadInt32(&calls); got != 2 {
126+
t.Errorf("number of calls = %d; want 2", got)
127+
}
128+
}
129+
130+
func TestOnceDo(t *testing.T) {
131+
var g OnceGroup
132+
v, err := g.Do("key", func() (interface{}, error) {
133+
return "bar", nil
134+
})
135+
if got, want := fmt.Sprintf("%v (%T)", v, v), "bar (string)"; got != want {
136+
t.Errorf("Do = %v; want %v", got, want)
137+
}
138+
if err != nil {
139+
t.Errorf("Do error = %v", err)
140+
}
141+
}
142+
143+
func TestOnceDoErr(t *testing.T) {
144+
var g OnceGroup
145+
someErr := errors.New("Some error")
146+
v, err := g.Do("key", func() (interface{}, error) {
147+
return nil, someErr
148+
})
149+
if err != someErr {
150+
t.Errorf("Do error = %v; want someErr", err)
151+
}
152+
if v != nil {
153+
t.Errorf("unexpected non-nil value %#v", v)
154+
}
155+
}
156+
157+
func TestOnceDoDupSuppress(t *testing.T) {
158+
var g OnceGroup
159+
c := make(chan string)
160+
var calls int32
161+
fn := func() (interface{}, error) {
162+
atomic.AddInt32(&calls, 1)
163+
return <-c, nil
164+
}
165+
166+
const n = 10
167+
var wg sync.WaitGroup
168+
for i := 0; i < n; i++ {
169+
wg.Add(1)
170+
go func() {
171+
v, err := g.Do("key", fn)
172+
if err != nil {
173+
t.Errorf("Do error: %v", err)
174+
}
175+
if v.(string) != "bar" {
176+
t.Errorf("got %q; want %q", v, "bar")
177+
}
178+
wg.Done()
179+
}()
180+
}
181+
time.Sleep(100 * time.Millisecond) // let goroutines above block
182+
c <- "bar"
183+
wg.Wait()
184+
// one more time after every goroutine has completed - should return the
185+
// same result instantly.
186+
v, err := g.Do("key", fn)
187+
if err != nil {
188+
t.Errorf("Do error: %v", err)
189+
}
190+
if v.(string) != "bar" {
191+
t.Errorf("got %q; want %q", v, "bar")
192+
}
193+
if got := atomic.LoadInt32(&calls); got != 1 {
194+
t.Errorf("number of calls = %d; want 1", got)
195+
}
196+
}

0 commit comments

Comments
 (0)