diff --git a/add.go b/add.go index c3101b76..9b03e0b7 100644 --- a/add.go +++ b/add.go @@ -1,6 +1,8 @@ package ldap import ( + "context" + ber "github.com/go-asn1-ber/asn1-ber" ) @@ -66,7 +68,12 @@ func NewAddRequest(dn string, controls []Control) *AddRequest { // Add performs the given AddRequest func (l *Conn) Add(addRequest *AddRequest) error { - msgCtx, err := l.doRequest(addRequest) + return l.AddContext(l.ctx, addRequest) +} + +// AddContext performs the given AddRequest +func (l *Conn) AddContext(ctx context.Context, addRequest *AddRequest) error { + msgCtx, err := l.doRequest(ctx, addRequest) if err != nil { return err } diff --git a/bind.go b/bind.go index 9c6cc282..04c9b446 100644 --- a/bind.go +++ b/bind.go @@ -62,7 +62,7 @@ func (l *Conn) SimpleBind(simpleBindRequest *SimpleBindRequest) (*SimpleBindResu return nil, NewError(ErrorEmptyPassword, errors.New("ldap: empty password not allowed by the client")) } - msgCtx, err := l.doRequest(simpleBindRequest) + msgCtx, err := l.doRequest(l.ctx, simpleBindRequest) if err != nil { return nil, err } @@ -170,7 +170,7 @@ func (l *Conn) DigestMD5Bind(digestMD5BindRequest *DigestMD5BindRequest) (*Diges return nil, NewError(ErrorEmptyPassword, errors.New("ldap: empty password not allowed by the client")) } - msgCtx, err := l.doRequest(digestMD5BindRequest) + msgCtx, err := l.doRequest(l.ctx, digestMD5BindRequest) if err != nil { return nil, err } @@ -235,7 +235,7 @@ func (l *Conn) DigestMD5Bind(digestMD5BindRequest *DigestMD5BindRequest) (*Diges auth.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, resp, "Credentials")) request.AppendChild(auth) packet.AppendChild(request) - msgCtx, err = l.sendMessage(packet) + msgCtx, err = l.sendMessage(l.ctx, packet) if err != nil { return nil, fmt.Errorf("send message: %s", err) } @@ -375,7 +375,7 @@ var externalBindRequest = requestFunc(func(envelope *ber.Packet) error { // // See https://tools.ietf.org/html/rfc4422#appendix-A func (l *Conn) ExternalBind() error { - msgCtx, err := l.doRequest(externalBindRequest) + msgCtx, err := l.doRequest(l.ctx, externalBindRequest) if err != nil { return err } @@ -478,7 +478,7 @@ func (l *Conn) NTLMChallengeBind(ntlmBindRequest *NTLMBindRequest) (*NTLMBindRes return nil, NewError(ErrorEmptyPassword, errors.New("ldap: empty password not allowed by the client")) } - msgCtx, err := l.doRequest(ntlmBindRequest) + msgCtx, err := l.doRequest(l.ctx, ntlmBindRequest) if err != nil { return nil, err } @@ -538,7 +538,7 @@ func (l *Conn) NTLMChallengeBind(ntlmBindRequest *NTLMBindRequest) (*NTLMBindRes request.AppendChild(auth) packet.AppendChild(request) - msgCtx, err = l.sendMessage(packet) + msgCtx, err = l.sendMessage(l.ctx, packet) if err != nil { return nil, fmt.Errorf("send message: %s", err) } @@ -671,7 +671,7 @@ func (l *Conn) saslBindTokenExchange(reqControls []Control, reqToken []byte) ([] envelope.AppendChild(encodeControls(reqControls)) } - msgCtx, err := l.sendMessage(envelope) + msgCtx, err := l.sendMessage(l.ctx, envelope) if err != nil { return nil, err } diff --git a/client.go b/client.go index f0312aff..a24c6718 100644 --- a/client.go +++ b/client.go @@ -1,6 +1,7 @@ package ldap import ( + "context" "crypto/tls" "time" ) @@ -22,14 +23,23 @@ type Client interface { Unbind() error Add(*AddRequest) error + AddContext(context.Context, *AddRequest) error Del(*DelRequest) error + DelContext(context.Context, *DelRequest) error Modify(*ModifyRequest) error + ModifyContext(context.Context, *ModifyRequest) error ModifyDN(*ModifyDNRequest) error + ModifyDNContext(context.Context, *ModifyDNRequest) error ModifyWithResult(*ModifyRequest) (*ModifyResult, error) + ModifyWithResultContext(context.Context, *ModifyRequest) (*ModifyResult, error) Compare(dn, attribute, value string) (bool, error) + CompareContext(ctx context.Context, dn, attribute, value string) (bool, error) PasswordModify(*PasswordModifyRequest) (*PasswordModifyResult, error) + PasswordModifyContext(context.Context, *PasswordModifyRequest) (*PasswordModifyResult, error) Search(*SearchRequest) (*SearchResult, error) + SearchContext(context.Context, *SearchRequest) (*SearchResult, error) SearchWithPaging(searchRequest *SearchRequest, pagingSize uint32) (*SearchResult, error) + SearchWithPagingContext(ctx context.Context, searchRequest *SearchRequest, pagingSize uint32) (*SearchResult, error) } diff --git a/compare.go b/compare.go index cd43e4c5..017a4c59 100644 --- a/compare.go +++ b/compare.go @@ -1,6 +1,7 @@ package ldap import ( + "context" "fmt" ber "github.com/go-asn1-ber/asn1-ber" @@ -31,7 +32,13 @@ func (req *CompareRequest) appendTo(envelope *ber.Packet) error { // Compare checks to see if the attribute of the dn matches value. Returns true if it does otherwise // false with any error that occurs if any. func (l *Conn) Compare(dn, attribute, value string) (bool, error) { - msgCtx, err := l.doRequest(&CompareRequest{ + return l.CompareContext(l.ctx, dn, attribute, value) +} + +// CompareContext checks to see if the attribute of the dn matches value. Returns true if it does otherwise +// false with any error that occurs if any. +func (l *Conn) CompareContext(ctx context.Context, dn, attribute, value string) (bool, error) { + msgCtx, err := l.doRequest(ctx, &CompareRequest{ DN: dn, Attribute: attribute, Value: value}) diff --git a/conn.go b/conn.go index 858224af..75f74ff3 100644 --- a/conn.go +++ b/conn.go @@ -2,6 +2,7 @@ package ldap import ( "bufio" + "context" "crypto/tls" "errors" "fmt" @@ -51,7 +52,8 @@ func (pr *PacketResponse) ReadPacket() (*ber.Packet, error) { } type messageContext struct { - id int64 + id int64 + ctx context.Context // close(done) should only be called from finishMessage() done chan struct{} // close(responses) should only be called from processMessages(), and only sent to from sendResponse() @@ -93,6 +95,8 @@ type Conn struct { isTLS bool closing uint32 closeErr atomic.Value + ctx context.Context + cancel context.CancelFunc isStartingTLS bool Debug debugging chanConfirm chan struct{} @@ -140,18 +144,25 @@ func DialWithTLSDialer(tlsConfig *tls.Config, dialer *net.Dialer) DialOpt { } } +func DialWithContext(ctx context.Context) DialOpt { + return func(dc *DialContext) { + dc.ctx = ctx + } +} + // DialContext contains necessary parameters to dial the given ldap URL. type DialContext struct { + ctx context.Context dialer *net.Dialer tlsConfig *tls.Config } -func (dc *DialContext) dial(u *url.URL) (net.Conn, error) { +func (dc *DialContext) dial(ctx context.Context, u *url.URL) (net.Conn, error) { if u.Scheme == "ldapi" { if u.Path == "" || u.Path == "/" { u.Path = "/var/run/slapd/ldapi" } - return dc.dialer.Dial("unix", u.Path) + return dc.dialer.DialContext(ctx, "unix", u.Path) } host, port, err := net.SplitHostPort(u.Host) @@ -166,20 +177,24 @@ func (dc *DialContext) dial(u *url.URL) (net.Conn, error) { if port == "" { port = DefaultLdapPort } - return dc.dialer.Dial("udp", net.JoinHostPort(host, port)) + return dc.dialer.DialContext(ctx, "udp", net.JoinHostPort(host, port)) case "ldap": if port == "" { port = DefaultLdapPort } - return dc.dialer.Dial("tcp", net.JoinHostPort(host, port)) + return dc.dialer.DialContext(ctx, "tcp", net.JoinHostPort(host, port)) case "ldaps": if port == "" { port = DefaultLdapsPort } - return tls.DialWithDialer(dc.dialer, "tcp", net.JoinHostPort(host, port), dc.tlsConfig) + tlsDialer := &tls.Dialer{ + NetDialer: dc.dialer, + Config: dc.tlsConfig, + } + return tlsDialer.DialContext(ctx, "tcp", net.JoinHostPort(host, port)) } - return nil, fmt.Errorf("Unknown scheme '%s'", u.Scheme) + return nil, fmt.Errorf("unknown scheme '%s'", u.Scheme) } // Dial connects to the given address on the given network using net.Dial @@ -219,6 +234,7 @@ func DialURL(addr string, opts ...DialOpt) (*Conn, error) { } var dc DialContext + dc.ctx = context.Background() for _, opt := range opts { opt(&dc) } @@ -226,13 +242,13 @@ func DialURL(addr string, opts ...DialOpt) (*Conn, error) { dc.dialer = &net.Dialer{Timeout: DefaultTimeout} } - c, err := dc.dial(u) + c, err := dc.dial(dc.ctx, u) if err != nil { return nil, NewError(ErrorNetwork, err) } conn := NewConn(c, u.Scheme == "ldaps") - conn.Start() + conn.StartContext(dc.ctx) return conn, nil } @@ -249,11 +265,22 @@ func NewConn(conn net.Conn, isTLS bool) *Conn { } } -// Start initializes goroutines to read responses and process messages -func (l *Conn) Start() { +// StartContext initializes goroutines to read responses and process messages. They will be terminated if the context is cancelled. +func (l *Conn) StartContext(ctx context.Context) { + l.ctx, l.cancel = context.WithCancel(ctx) l.wgClose.Add(1) go l.reader() go l.processMessages() + go func() { + // No matter what happens, connection must be closed on cancel to end any blocking calls + <-ctx.Done() + l.conn.Close() + }() +} + +// StartContext initializes goroutines to read responses and process messages. +func (l *Conn) Start() { + l.StartContext(context.Background()) } // IsClosing returns whether or not we're currently closing. @@ -285,9 +312,11 @@ func (l *Conn) Close() { l.wgClose.Done() } l.wgClose.Wait() + l.cancel() } // SetTimeout sets the time after a request is sent that a MessageTimeout triggers +// @deprecated Use context.Context parameters instead. func (l *Conn) SetTimeout(timeout time.Duration) { atomic.StoreInt64(&l.requestTimeout, int64(timeout)) } @@ -313,7 +342,7 @@ func (l *Conn) StartTLS(config *tls.Config) error { packet.AppendChild(request) l.Debug.PrintPacket(packet) - msgCtx, err := l.sendMessageWithFlags(packet, startTLS) + msgCtx, err := l.sendMessageWithFlags(l.ctx, packet, startTLS) if err != nil { return err } @@ -368,11 +397,11 @@ func (l *Conn) TLSConnectionState() (state tls.ConnectionState, ok bool) { return tc.ConnectionState(), true } -func (l *Conn) sendMessage(packet *ber.Packet) (*messageContext, error) { - return l.sendMessageWithFlags(packet, 0) +func (l *Conn) sendMessage(ctx context.Context, packet *ber.Packet) (*messageContext, error) { + return l.sendMessageWithFlags(ctx, packet, 0) } -func (l *Conn) sendMessageWithFlags(packet *ber.Packet, flags sendMessageFlags) (*messageContext, error) { +func (l *Conn) sendMessageWithFlags(ctx context.Context, packet *ber.Packet, flags sendMessageFlags) (*messageContext, error) { if l.IsClosing() { return nil, NewError(ErrorNetwork, errors.New("ldap: connection closed")) } @@ -401,14 +430,18 @@ func (l *Conn) sendMessageWithFlags(packet *ber.Packet, flags sendMessageFlags) Packet: packet, Context: &messageContext{ id: messageID, + ctx: ctx, done: make(chan struct{}), responses: responses, }, } - if !l.sendProcessMessage(message) { + if !l.sendProcessMessage(ctx, message) { if l.IsClosing() { return nil, NewError(ErrorNetwork, errors.New("ldap: connection closed")) } + if ctx.Err() != nil { + return nil, NewError(ErrorNetwork, ctx.Err()) + } return nil, NewError(ErrorNetwork, errors.New("ldap: could not send message for unknown reason")) } return message.Context, nil @@ -432,17 +465,21 @@ func (l *Conn) finishMessage(msgCtx *messageContext) { Op: MessageFinish, MessageID: msgCtx.id, } - l.sendProcessMessage(message) + l.sendProcessMessage(l.ctx, message) } -func (l *Conn) sendProcessMessage(message *messagePacket) bool { +func (l *Conn) sendProcessMessage(ctx context.Context, message *messagePacket) bool { l.messageMutex.Lock() defer l.messageMutex.Unlock() if l.IsClosing() { return false } - l.chanMessage <- message - return true + select { + case <-ctx.Done(): // abort blocking on l.chanMessage if ctx is cancelled + return false + case l.chanMessage <- message: + return true + } } func (l *Conn) processMessages() { @@ -467,6 +504,8 @@ func (l *Conn) processMessages() { var messageID int64 = 1 for { select { + case <-l.ctx.Done(): + return case l.chanMessageID <- messageID: messageID++ case message := <-l.chanMessage: @@ -493,27 +532,20 @@ func (l *Conn) processMessages() { // Add timeout if defined if l.requestTimeout > 0 { - go func() { - timer := time.NewTimer(time.Duration(l.requestTimeout)) - defer func() { - if err := recover(); err != nil { - logger.Printf("ldap: recovered panic in RequestTimeout: %v", err) - } - - timer.Stop() - }() - - select { - case <-timer.C: - timeoutMessage := &messagePacket{ - Op: MessageTimeout, - MessageID: message.MessageID, - } - l.sendProcessMessage(timeoutMessage) - case <-message.Context.done: - } - }() + message.Context.ctx, _ = context.WithTimeout(message.Context.ctx, time.Duration(l.requestTimeout)) } + + go func() { + select { + case <-message.Context.ctx.Done(): + timeoutMessage := &messagePacket{ + Op: MessageTimeout, + MessageID: message.MessageID, + } + l.sendProcessMessage(l.ctx, timeoutMessage) + case <-message.Context.done: + } + }() case MessageResponse: l.Debug.Printf("Receiving message %d", message.MessageID) if msgCtx, ok := l.messageContexts[message.MessageID]; ok { @@ -527,7 +559,7 @@ func (l *Conn) processMessages() { // All reads will return immediately if msgCtx, ok := l.messageContexts[message.MessageID]; ok { l.Debug.Printf("Receiving message timeout for %d", message.MessageID) - msgCtx.sendResponse(&PacketResponse{message.Packet, NewError(ErrorNetwork, errors.New("ldap: connection timed out"))}) + msgCtx.sendResponse(&PacketResponse{message.Packet, NewError(ErrorNetwork, fmt.Errorf("ldap: %w", msgCtx.ctx.Err()))}) delete(l.messageContexts, message.MessageID) close(msgCtx.responses) } @@ -585,7 +617,7 @@ func (l *Conn) reader() { MessageID: packet.Children[0].Value.(int64), Packet: packet, } - if !l.sendProcessMessage(message) { + if !l.sendProcessMessage(l.ctx, message) { return } } diff --git a/conn_test.go b/conn_test.go index c2885686..46ffdb05 100644 --- a/conn_test.go +++ b/conn_test.go @@ -2,6 +2,7 @@ package ldap import ( "bytes" + "context" "errors" "io" "net" @@ -16,6 +17,7 @@ import ( ) func TestUnresponsiveConnection(t *testing.T) { + ctx := context.Background() // The do-nothing server that accepts requests and does nothing ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { })) @@ -39,7 +41,7 @@ func TestUnresponsiveConnection(t *testing.T) { packet.AppendChild(bindRequest) // Send packet and test response - msgCtx, err := conn.sendMessage(packet) + msgCtx, err := conn.sendMessage(ctx, packet) if err != nil { t.Fatalf("error sending message: %v", err) } @@ -53,7 +55,7 @@ func TestUnresponsiveConnection(t *testing.T) { if err == nil { t.Fatalf("expected timeout error") } - if !IsErrorWithCode(err, ErrorNetwork) || err.(*Error).Err.Error() != "ldap: connection timed out" { + if !IsErrorWithCode(err, ErrorNetwork) || !errors.Is(err, context.DeadlineExceeded) { t.Fatalf("unexpected error: %v", err) } } @@ -107,13 +109,14 @@ func TestFinishMessage(t *testing.T) { // See: https://github.com/go-ldap/ldap/issues/332 func TestNilConnection(t *testing.T) { var conn *Conn - _, err := conn.Search(&SearchRequest{}) + _, err := conn.SearchContext(context.Background(), &SearchRequest{}) if err != ErrNilConnection { t.Fatalf("expected error to be ErrNilConnection, got %v", err) } } func testSendRequest(t *testing.T, ptc *packetTranslatorConn, conn *Conn) (msgCtx *messageContext) { + ctx := context.Background() var msgID int64 runWithTimeout(t, time.Second, func() { msgID = conn.nextMessageID() @@ -125,7 +128,7 @@ func testSendRequest(t *testing.T, ptc *packetTranslatorConn, conn *Conn) (msgCt var err error runWithTimeout(t, time.Second, func() { - msgCtx, err = conn.sendMessage(requestPacket) + msgCtx, err = conn.sendMessage(ctx, requestPacket) if err != nil { t.Fatalf("unable to send request message: %s", err) } diff --git a/del.go b/del.go index bac0dfb7..bae90fa9 100644 --- a/del.go +++ b/del.go @@ -1,6 +1,8 @@ package ldap import ( + "context" + ber "github.com/go-asn1-ber/asn1-ber" ) @@ -34,7 +36,12 @@ func NewDelRequest(DN string, Controls []Control) *DelRequest { // Del executes the given delete request func (l *Conn) Del(delRequest *DelRequest) error { - msgCtx, err := l.doRequest(delRequest) + return l.DelContext(l.ctx, delRequest) +} + +// DelContext executes the given delete request +func (l *Conn) DelContext(ctx context.Context, delRequest *DelRequest) error { + msgCtx, err := l.doRequest(ctx, delRequest) if err != nil { return err } diff --git a/error.go b/error.go index 3cdb7b31..d654af92 100644 --- a/error.go +++ b/error.go @@ -192,18 +192,22 @@ func (e *Error) Error() string { return fmt.Sprintf("LDAP Result Code %d %q: %s", e.ResultCode, LDAPResultCodeMap[e.ResultCode], e.Err.Error()) } +func (e *Error) Unwrap() error { + return e.Err +} + // GetLDAPError creates an Error out of a BER packet representing a LDAPResult // The return is an error object. It can be casted to a Error structure. // This function returns nil if resultCode in the LDAPResult sequence is success(0). func GetLDAPError(packet *ber.Packet) error { if packet == nil { - return &Error{ResultCode: ErrorUnexpectedResponse, Err: fmt.Errorf("Empty packet")} + return &Error{ResultCode: ErrorUnexpectedResponse, Err: fmt.Errorf("empty packet")} } if len(packet.Children) >= 2 { response := packet.Children[1] if response == nil { - return &Error{ResultCode: ErrorUnexpectedResponse, Err: fmt.Errorf("Empty response in packet"), Packet: packet} + return &Error{ResultCode: ErrorUnexpectedResponse, Err: fmt.Errorf("empty response in packet"), Packet: packet} } if response.ClassType == ber.ClassApplication && response.TagType == ber.TypeConstructed && len(response.Children) >= 3 { resultCode := uint16(response.Children[0].Value.(int64)) @@ -219,7 +223,7 @@ func GetLDAPError(packet *ber.Packet) error { } } - return &Error{ResultCode: ErrorNetwork, Err: fmt.Errorf("Invalid packet format"), Packet: packet} + return &Error{ResultCode: ErrorNetwork, Err: fmt.Errorf("invalid packet format"), Packet: packet} } // NewError creates an LDAP error with the given code and underlying error diff --git a/moddn.go b/moddn.go index ec246d1f..51ce702d 100644 --- a/moddn.go +++ b/moddn.go @@ -1,6 +1,8 @@ package ldap import ( + "context" + ber "github.com/go-asn1-ber/asn1-ber" ) @@ -40,7 +42,7 @@ func NewModifyDNRequest(dn string, rdn string, delOld bool, newSup string) *Modi // // Refer NewModifyDNRequest for other parameters func NewModifyDNWithControlsRequest(dn string, rdn string, delOld bool, - newSup string, controls []Control) *ModifyDNRequest { + newSup string, controls []Control) *ModifyDNRequest { return &ModifyDNRequest{ DN: dn, NewRDN: rdn, @@ -75,7 +77,13 @@ func (req *ModifyDNRequest) appendTo(envelope *ber.Packet) error { // ModifyDN renames the given DN and optionally move to another base (when the "newSup" argument // to NewModifyDNRequest() is not ""). func (l *Conn) ModifyDN(m *ModifyDNRequest) error { - msgCtx, err := l.doRequest(m) + return l.ModifyDNContext(l.ctx, m) +} + +// ModifyDNContext renames the given DN and optionally move to another base (when the "newSup" argument +// to NewModifyDNRequest() is not ""). +func (l *Conn) ModifyDNContext(ctx context.Context, m *ModifyDNRequest) error { + msgCtx, err := l.doRequest(ctx, m) if err != nil { return err } diff --git a/modify.go b/modify.go index 8b379558..98e373c5 100644 --- a/modify.go +++ b/modify.go @@ -1,6 +1,7 @@ package ldap import ( + "context" "errors" ber "github.com/go-asn1-ber/asn1-ber" @@ -109,7 +110,12 @@ func NewModifyRequest(dn string, controls []Control) *ModifyRequest { // Modify performs the ModifyRequest func (l *Conn) Modify(modifyRequest *ModifyRequest) error { - msgCtx, err := l.doRequest(modifyRequest) + return l.ModifyContext(l.ctx, modifyRequest) +} + +// Modify performs the ModifyRequest +func (l *Conn) ModifyContext(ctx context.Context, modifyRequest *ModifyRequest) error { + msgCtx, err := l.doRequest(ctx, modifyRequest) if err != nil { return err } @@ -141,7 +147,12 @@ type ModifyResult struct { // ModifyWithResult performs the ModifyRequest and returns the result func (l *Conn) ModifyWithResult(modifyRequest *ModifyRequest) (*ModifyResult, error) { - msgCtx, err := l.doRequest(modifyRequest) + return l.ModifyWithResultContext(l.ctx, modifyRequest) +} + +// ModifyWithResultContext performs the ModifyRequest and returns the result +func (l *Conn) ModifyWithResultContext(ctx context.Context, modifyRequest *ModifyRequest) (*ModifyResult, error) { + msgCtx, err := l.doRequest(ctx, modifyRequest) if err != nil { return nil, err } diff --git a/passwdmodify.go b/passwdmodify.go index e776e3b3..ac2de325 100644 --- a/passwdmodify.go +++ b/passwdmodify.go @@ -1,6 +1,7 @@ package ldap import ( + "context" "fmt" ber "github.com/go-asn1-ber/asn1-ber" @@ -81,7 +82,12 @@ func NewPasswordModifyRequest(userIdentity string, oldPassword string, newPasswo // PasswordModify performs the modification request func (l *Conn) PasswordModify(passwordModifyRequest *PasswordModifyRequest) (*PasswordModifyResult, error) { - msgCtx, err := l.doRequest(passwordModifyRequest) + return l.PasswordModifyContext(l.ctx, passwordModifyRequest) +} + +// PasswordModify performs the modification request +func (l *Conn) PasswordModifyContext(ctx context.Context, passwordModifyRequest *PasswordModifyRequest) (*PasswordModifyResult, error) { + msgCtx, err := l.doRequest(ctx, passwordModifyRequest) if err != nil { return nil, err } diff --git a/request.go b/request.go index adc3b1c2..159cef81 100644 --- a/request.go +++ b/request.go @@ -1,6 +1,7 @@ package ldap import ( + "context" "errors" "fmt" @@ -23,7 +24,7 @@ func (f requestFunc) appendTo(p *ber.Packet) error { return f(p) } -func (l *Conn) doRequest(req request) (*messageContext, error) { +func (l *Conn) doRequest(ctx context.Context, req request) (*messageContext, error) { if l == nil || l.conn == nil { return nil, ErrNilConnection } @@ -38,7 +39,7 @@ func (l *Conn) doRequest(req request) (*messageContext, error) { l.Debug.PrintPacket(packet) } - msgCtx, err := l.sendMessage(packet) + msgCtx, err := l.sendMessage(ctx, packet) if err != nil { return nil, err } diff --git a/search.go b/search.go index c174f197..432de0f3 100644 --- a/search.go +++ b/search.go @@ -1,6 +1,7 @@ package ldap import ( + "context" "errors" "fmt" "reflect" @@ -411,6 +412,18 @@ func NewSearchRequest( // - given SearchRequest contains a control of type ControlTypePaging with pagingSize not equal to the size requested: fail without issuing any queries // A requested pagingSize of 0 is interpreted as no limit by LDAP servers. func (l *Conn) SearchWithPaging(searchRequest *SearchRequest, pagingSize uint32) (*SearchResult, error) { + return l.SearchWithPagingContext(l.ctx, searchRequest, pagingSize) +} + +// SearchWithPagingContext accepts a search request and desired page size in order to execute LDAP queries to fulfill the +// search request. All paged LDAP query responses will be buffered and the final result will be returned atomically. +// The following four cases are possible given the arguments: +// - given SearchRequest missing a control of type ControlTypePaging: we will add one with the desired paging size +// - given SearchRequest contains a control of type ControlTypePaging that isn't actually a ControlPaging: fail without issuing any queries +// - given SearchRequest contains a control of type ControlTypePaging with pagingSize equal to the size requested: no change to the search request +// - given SearchRequest contains a control of type ControlTypePaging with pagingSize not equal to the size requested: fail without issuing any queries +// A requested pagingSize of 0 is interpreted as no limit by LDAP servers. +func (l *Conn) SearchWithPagingContext(ctx context.Context, searchRequest *SearchRequest, pagingSize uint32) (*SearchResult, error) { var pagingControl *ControlPaging control := FindControl(searchRequest.Controls, ControlTypePaging) @@ -430,7 +443,7 @@ func (l *Conn) SearchWithPaging(searchRequest *SearchRequest, pagingSize uint32) searchResult := new(SearchResult) for { - result, err := l.Search(searchRequest) + result, err := l.SearchContext(ctx, searchRequest) l.Debug.Printf("Looking for Paging Control...") if err != nil { return searchResult, err @@ -463,7 +476,7 @@ func (l *Conn) SearchWithPaging(searchRequest *SearchRequest, pagingSize uint32) if pagingControl != nil { l.Debug.Printf("Abandoning Paging...") pagingControl.PagingSize = 0 - if _, err := l.Search(searchRequest); err != nil { + if _, err := l.SearchContext(ctx, searchRequest); err != nil { return searchResult, err } } @@ -473,7 +486,12 @@ func (l *Conn) SearchWithPaging(searchRequest *SearchRequest, pagingSize uint32) // Search performs the given search request func (l *Conn) Search(searchRequest *SearchRequest) (*SearchResult, error) { - msgCtx, err := l.doRequest(searchRequest) + return l.SearchContext(l.ctx, searchRequest) +} + +// SearchContext performs the given search request +func (l *Conn) SearchContext(ctx context.Context, searchRequest *SearchRequest) (*SearchResult, error) { + msgCtx, err := l.doRequest(ctx, searchRequest) if err != nil { return nil, err } diff --git a/unbind.go b/unbind.go index 6c411cd1..dafeff36 100644 --- a/unbind.go +++ b/unbind.go @@ -23,7 +23,7 @@ func (l *Conn) Unbind() error { return ErrConnUnbound } - _, err := l.doRequest(unbindRequest{}) + _, err := l.doRequest(l.ctx, unbindRequest{}) if err != nil { return err } diff --git a/whoami.go b/whoami.go index 10c523d0..b15716bf 100644 --- a/whoami.go +++ b/whoami.go @@ -42,7 +42,7 @@ func (l *Conn) WhoAmI(controls []Control) (*WhoAmIResult, error) { l.Debug.PrintPacket(packet) - msgCtx, err := l.sendMessage(packet) + msgCtx, err := l.sendMessage(l.ctx, packet) if err != nil { return nil, err } @@ -77,7 +77,7 @@ func (l *Conn) WhoAmI(controls []Control) (*WhoAmIResult, error) { return nil, err } } else { - return nil, NewError(ErrorUnexpectedResponse, fmt.Errorf("Unexpected Response: %d", packet.Children[1].Tag)) + return nil, NewError(ErrorUnexpectedResponse, fmt.Errorf("unexpected Response: %d", packet.Children[1].Tag)) } extendedResponse := packet.Children[1]