diff --git a/error.go b/error.go index 3cdb7b3..b9d6452 100644 --- a/error.go +++ b/error.go @@ -206,15 +206,21 @@ func GetLDAPError(packet *ber.Packet) error { return &Error{ResultCode: ErrorUnexpectedResponse, Err: fmt.Errorf("Empty response in packet"), Packet: packet} } if response.ClassType == ber.ClassApplication && response.TagType == ber.TypeConstructed && len(response.Children) >= 3 { - resultCode := uint16(response.Children[0].Value.(int64)) - if resultCode == 0 { // No error - return nil - } - return &Error{ - ResultCode: resultCode, - MatchedDN: response.Children[1].Value.(string), - Err: fmt.Errorf("%s", response.Children[2].Value.(string)), - Packet: packet, + if ber.Type(response.Children[0].Tag) == ber.Type(ber.TagInteger) || ber.Type(response.Children[0].Tag) == ber.Type(ber.TagEnumerated) { + resultCode := uint16(response.Children[0].Value.(int64)) + if resultCode == 0 { // No error + return nil + } + + if ber.Type(response.Children[1].Tag) == ber.Type(ber.TagOctetString) && + ber.Type(response.Children[2].Tag) == ber.Type(ber.TagOctetString) { + return &Error{ + ResultCode: resultCode, + MatchedDN: response.Children[1].Value.(string), + Err: fmt.Errorf("%s", response.Children[2].Value.(string)), + Packet: packet, + } + } } } } diff --git a/error_test.go b/error_test.go index fe7a155..63018ee 100644 --- a/error_test.go +++ b/error_test.go @@ -82,6 +82,27 @@ func TestGetLDAPError(t *testing.T) { } } +// TestGetLDAPErrorInvalidResponse tests that responses with an unexpected ordering or combination of children +// don't cause a panic. +func TestGetLDAPErrorInvalidResponse(t *testing.T) { + bindResponse := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationBindResponse, nil, "Bind Response") + bindResponse.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, "dc=example,dc=org", "matchedDN")) + bindResponse.AppendChild(ber.Encode(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, int64(LDAPResultInvalidCredentials), "resultCode")) + bindResponse.AppendChild(ber.Encode(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, int64(LDAPResultInvalidCredentials), "resultCode")) + packet := ber.NewSequence("LDAPMessage") + packet.AppendChild(ber.Encode(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, int64(0), "messageID")) + packet.AppendChild(bindResponse) + err := GetLDAPError(packet) + if err == nil { + t.Errorf("Did not get error response") + } + + ldapError := err.(*Error) + if ldapError.ResultCode != ErrorNetwork { + t.Errorf("Got incorrect error code in LDAP error; got %v, expected %v", ldapError.ResultCode, ErrorNetwork) + } +} + // TestGetLDAPErrorSuccess tests parsing of a result with no error (resultCode == 0). func TestGetLDAPErrorSuccess(t *testing.T) { bindResponse := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationBindResponse, nil, "Bind Response") diff --git a/v3/error.go b/v3/error.go index 3cdb7b3..b9d6452 100644 --- a/v3/error.go +++ b/v3/error.go @@ -206,15 +206,21 @@ func GetLDAPError(packet *ber.Packet) error { return &Error{ResultCode: ErrorUnexpectedResponse, Err: fmt.Errorf("Empty response in packet"), Packet: packet} } if response.ClassType == ber.ClassApplication && response.TagType == ber.TypeConstructed && len(response.Children) >= 3 { - resultCode := uint16(response.Children[0].Value.(int64)) - if resultCode == 0 { // No error - return nil - } - return &Error{ - ResultCode: resultCode, - MatchedDN: response.Children[1].Value.(string), - Err: fmt.Errorf("%s", response.Children[2].Value.(string)), - Packet: packet, + if ber.Type(response.Children[0].Tag) == ber.Type(ber.TagInteger) || ber.Type(response.Children[0].Tag) == ber.Type(ber.TagEnumerated) { + resultCode := uint16(response.Children[0].Value.(int64)) + if resultCode == 0 { // No error + return nil + } + + if ber.Type(response.Children[1].Tag) == ber.Type(ber.TagOctetString) && + ber.Type(response.Children[2].Tag) == ber.Type(ber.TagOctetString) { + return &Error{ + ResultCode: resultCode, + MatchedDN: response.Children[1].Value.(string), + Err: fmt.Errorf("%s", response.Children[2].Value.(string)), + Packet: packet, + } + } } } } diff --git a/v3/error_test.go b/v3/error_test.go index fe7a155..63018ee 100644 --- a/v3/error_test.go +++ b/v3/error_test.go @@ -82,6 +82,27 @@ func TestGetLDAPError(t *testing.T) { } } +// TestGetLDAPErrorInvalidResponse tests that responses with an unexpected ordering or combination of children +// don't cause a panic. +func TestGetLDAPErrorInvalidResponse(t *testing.T) { + bindResponse := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationBindResponse, nil, "Bind Response") + bindResponse.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, "dc=example,dc=org", "matchedDN")) + bindResponse.AppendChild(ber.Encode(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, int64(LDAPResultInvalidCredentials), "resultCode")) + bindResponse.AppendChild(ber.Encode(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, int64(LDAPResultInvalidCredentials), "resultCode")) + packet := ber.NewSequence("LDAPMessage") + packet.AppendChild(ber.Encode(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, int64(0), "messageID")) + packet.AppendChild(bindResponse) + err := GetLDAPError(packet) + if err == nil { + t.Errorf("Did not get error response") + } + + ldapError := err.(*Error) + if ldapError.ResultCode != ErrorNetwork { + t.Errorf("Got incorrect error code in LDAP error; got %v, expected %v", ldapError.ResultCode, ErrorNetwork) + } +} + // TestGetLDAPErrorSuccess tests parsing of a result with no error (resultCode == 0). func TestGetLDAPErrorSuccess(t *testing.T) { bindResponse := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationBindResponse, nil, "Bind Response")