Skip to content

Commit 2f623f0

Browse files
committed
feat: provide search async function and drop search with channels go-ldap#319 go-ldap#341
1 parent a9daeeb commit 2f623f0

10 files changed

+267
-308
lines changed

client.go

+2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package ldap
22

33
import (
4+
"context"
45
"crypto/tls"
56
"time"
67
)
@@ -32,6 +33,7 @@ type Client interface {
3233
PasswordModify(*PasswordModifyRequest) (*PasswordModifyResult, error)
3334

3435
Search(*SearchRequest) (*SearchResult, error)
36+
SearchAsync(ctx context.Context, searchRequest *SearchRequest, bufferSize int) Response
3537
SearchWithPaging(searchRequest *SearchRequest, pagingSize uint32) (*SearchResult, error)
3638
DirSync(searchRequest *SearchRequest, flags, maxAttrCount int64, cookie []byte) (*SearchResult, error)
3739
}

examples_test.go

+9-8
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,8 @@ func ExampleConn_Search() {
5151
}
5252
}
5353

54-
// This example demonstrates how to search with channel
55-
func ExampleConn_SearchWithChannel() {
54+
// This example demonstrates how to search asynchronously
55+
func ExampleConn_SearchAsync() {
5656
l, err := DialURL(fmt.Sprintf("%s:%d", "ldap.example.com", 389))
5757
if err != nil {
5858
log.Fatal(err)
@@ -70,12 +70,13 @@ func ExampleConn_SearchWithChannel() {
7070
ctx, cancel := context.WithCancel(context.Background())
7171
defer cancel()
7272

73-
ch := l.SearchWithChannel(ctx, searchRequest, 64)
74-
for res := range ch {
75-
if res.Error != nil {
76-
log.Fatalf("Error searching: %s", res.Error)
77-
}
78-
fmt.Printf("%s has DN %s\n", res.Entry.GetAttributeValue("cn"), res.Entry.DN)
73+
r := l.SearchAsync(ctx, searchRequest, 64)
74+
for r.Next() {
75+
entry := r.Entry()
76+
fmt.Printf("%s has DN %s\n", entry.GetAttributeValue("cn"), entry.DN)
77+
}
78+
if err := r.Err(); err != nil {
79+
log.Fatal(err)
7980
}
8081
}
8182

ldap_test.go

+21-20
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package ldap
33
import (
44
"context"
55
"crypto/tls"
6+
"log"
67
"testing"
78

89
ber "github.com/go-asn1-ber/asn1-ber"
@@ -346,7 +347,7 @@ func TestEscapeDN(t *testing.T) {
346347
}
347348
}
348349

349-
func TestSearchWithChannel(t *testing.T) {
350+
func TestSearchAsync(t *testing.T) {
350351
l, err := DialURL(ldapServer)
351352
if err != nil {
352353
t.Fatal(err)
@@ -362,17 +363,18 @@ func TestSearchWithChannel(t *testing.T) {
362363

363364
srs := make([]*Entry, 0)
364365
ctx := context.Background()
365-
for sr := range l.SearchWithChannel(ctx, searchRequest, 64) {
366-
if sr.Error != nil {
367-
t.Fatal(err)
368-
}
369-
srs = append(srs, sr.Entry)
366+
r := l.SearchAsync(ctx, searchRequest, 64)
367+
for r.Next() {
368+
srs = append(srs, r.Entry())
369+
}
370+
if err := r.Err(); err != nil {
371+
log.Fatal(err)
370372
}
371373

372-
t.Logf("TestSearchWithChannel: %s -> num of entries = %d", searchRequest.Filter, len(srs))
374+
t.Logf("TestSearcAsync: %s -> num of entries = %d", searchRequest.Filter, len(srs))
373375
}
374376

375-
func TestSearchWithChannelAndCancel(t *testing.T) {
377+
func TestSearchAsyncAndCancel(t *testing.T) {
376378
l, err := DialURL(ldapServer)
377379
if err != nil {
378380
t.Fatal(err)
@@ -390,22 +392,21 @@ func TestSearchWithChannelAndCancel(t *testing.T) {
390392
srs := make([]*Entry, 0)
391393
ctx, cancel := context.WithCancel(context.Background())
392394
defer cancel()
393-
ch := l.SearchWithChannel(ctx, searchRequest, 0)
394-
for i := 0; i < 10; i++ {
395-
sr := <-ch
396-
if sr.Error != nil {
397-
t.Fatal(err)
398-
}
399-
srs = append(srs, sr.Entry)
395+
r := l.SearchAsync(ctx, searchRequest, 0)
396+
for r.Next() {
397+
srs = append(srs, r.Entry())
400398
if len(srs) == cancelNum {
401399
cancel()
402400
}
403401
}
404-
for range ch {
405-
t.Log("Consume all entries from the channel to prevent blocking by the connection")
402+
if err := r.Err(); err != nil {
403+
log.Fatal(err)
406404
}
407-
if len(srs) != cancelNum {
408-
t.Errorf("Got entries %d, expected %d", len(srs), cancelNum)
405+
406+
if len(srs) > cancelNum+3 {
407+
// the cancellation process is asynchronous,
408+
// so it might get some entries after calling cancel()
409+
t.Errorf("Got entries %d, expected < %d", len(srs), cancelNum+3)
409410
}
410-
t.Logf("TestSearchWithChannel: %s -> num of entries = %d", searchRequest.Filter, len(srs))
411+
t.Logf("TestSearchAsyncAndCancel: %s -> num of entries = %d", searchRequest.Filter, len(srs))
411412
}

response.go

+182
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,182 @@
1+
package ldap
2+
3+
import (
4+
"context"
5+
"errors"
6+
"fmt"
7+
8+
ber "github.com/go-asn1-ber/asn1-ber"
9+
)
10+
11+
// Response defines an interface to get data from an LDAP server
12+
type Response interface {
13+
Entry() *Entry
14+
Referral() string
15+
Controls() []Control
16+
Err() error
17+
Next() bool
18+
}
19+
20+
type searchResponse struct {
21+
conn *Conn
22+
ch chan *SearchSingleResult
23+
24+
entry *Entry
25+
referral string
26+
controls []Control
27+
err error
28+
}
29+
30+
// Entry returns an entry from the given search request
31+
func (r *searchResponse) Entry() *Entry {
32+
return r.entry
33+
}
34+
35+
// Referral returns a referral from the given search request
36+
func (r *searchResponse) Referral() string {
37+
return r.referral
38+
}
39+
40+
// Controls returns controls from the given search request
41+
func (r *searchResponse) Controls() []Control {
42+
return r.controls
43+
}
44+
45+
// Err returns an error when the given search request was failed
46+
func (r *searchResponse) Err() error {
47+
return r.err
48+
}
49+
50+
// Next returns whether next data exist or not
51+
func (r *searchResponse) Next() bool {
52+
res, ok := <-r.ch
53+
if !ok {
54+
return false
55+
}
56+
if res == nil {
57+
return false
58+
}
59+
r.err = res.Error
60+
if r.err != nil {
61+
return false
62+
}
63+
r.err = r.conn.GetLastError()
64+
if r.err != nil {
65+
return false
66+
}
67+
r.entry = res.Entry
68+
r.referral = res.Referral
69+
r.controls = res.Controls
70+
return true
71+
}
72+
73+
func (r *searchResponse) start(ctx context.Context, searchRequest *SearchRequest) {
74+
go func() {
75+
defer func() {
76+
close(r.ch)
77+
if err := recover(); err != nil {
78+
r.conn.err = fmt.Errorf("ldap: recovered panic in searchResponse: %v", err)
79+
}
80+
}()
81+
82+
if r.conn.IsClosing() {
83+
return
84+
}
85+
86+
packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request")
87+
packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, r.conn.nextMessageID(), "MessageID"))
88+
// encode search request
89+
err := searchRequest.appendTo(packet)
90+
if err != nil {
91+
r.ch <- &SearchSingleResult{Error: err}
92+
return
93+
}
94+
r.conn.Debug.PrintPacket(packet)
95+
96+
msgCtx, err := r.conn.sendMessage(packet)
97+
if err != nil {
98+
r.ch <- &SearchSingleResult{Error: err}
99+
return
100+
}
101+
defer r.conn.finishMessage(msgCtx)
102+
103+
foundSearchSingleResultDone := false
104+
for !foundSearchSingleResultDone {
105+
select {
106+
case <-ctx.Done():
107+
r.conn.Debug.Printf("%d: %s", msgCtx.id, ctx.Err().Error())
108+
return
109+
default:
110+
r.conn.Debug.Printf("%d: waiting for response", msgCtx.id)
111+
packetResponse, ok := <-msgCtx.responses
112+
if !ok {
113+
err := NewError(ErrorNetwork, errors.New("ldap: response channel closed"))
114+
r.ch <- &SearchSingleResult{Error: err}
115+
return
116+
}
117+
packet, err = packetResponse.ReadPacket()
118+
r.conn.Debug.Printf("%d: got response %p", msgCtx.id, packet)
119+
if err != nil {
120+
r.ch <- &SearchSingleResult{Error: err}
121+
return
122+
}
123+
124+
if r.conn.Debug {
125+
if err := addLDAPDescriptions(packet); err != nil {
126+
r.ch <- &SearchSingleResult{Error: err}
127+
return
128+
}
129+
ber.PrintPacket(packet)
130+
}
131+
132+
switch packet.Children[1].Tag {
133+
case ApplicationSearchResultEntry:
134+
r.ch <- &SearchSingleResult{
135+
Entry: &Entry{
136+
DN: packet.Children[1].Children[0].Value.(string),
137+
Attributes: unpackAttributes(packet.Children[1].Children[1].Children),
138+
},
139+
}
140+
141+
case ApplicationSearchResultDone:
142+
if err := GetLDAPError(packet); err != nil {
143+
r.ch <- &SearchSingleResult{Error: err}
144+
return
145+
}
146+
if len(packet.Children) == 3 {
147+
result := &SearchSingleResult{}
148+
for _, child := range packet.Children[2].Children {
149+
decodedChild, err := DecodeControl(child)
150+
if err != nil {
151+
werr := fmt.Errorf("failed to decode child control: %w", err)
152+
r.ch <- &SearchSingleResult{Error: werr}
153+
return
154+
}
155+
result.Controls = append(result.Controls, decodedChild)
156+
}
157+
r.ch <- result
158+
}
159+
foundSearchSingleResultDone = true
160+
161+
case ApplicationSearchResultReference:
162+
ref := packet.Children[1].Children[0].Value.(string)
163+
r.ch <- &SearchSingleResult{Referral: ref}
164+
}
165+
}
166+
}
167+
r.conn.Debug.Printf("%d: returning", msgCtx.id)
168+
}()
169+
}
170+
171+
func newSearchResponse(conn *Conn, bufferSize int) *searchResponse {
172+
var ch chan *SearchSingleResult
173+
if bufferSize > 0 {
174+
ch = make(chan *SearchSingleResult, bufferSize)
175+
} else {
176+
ch = make(chan *SearchSingleResult)
177+
}
178+
return &searchResponse{
179+
conn: conn,
180+
ch: ch,
181+
}
182+
}

0 commit comments

Comments
 (0)