diff --git a/edns.go b/edns.go index d2bfecbb2..5314c543b 100644 --- a/edns.go +++ b/edns.go @@ -192,6 +192,11 @@ func (e *EDNS0_NSID) String() string { return string(e.Nsid) } // e.Address = net.ParseIP("127.0.0.1").To4() // for IPv4 // // e.Address = net.ParseIP("2001:7b8:32a::2") // for IPV6 // o.Option = append(o.Option, e) +// +// Note: the spec (draft-ietf-dnsop-edns-client-subnet-00) has some insane logic +// for which netmask applies to the address. This code will parse all the +// available bits when unpacking (up to optlen). When packing it will apply +// SourceNetmask. If you need more advanced logic, patches welcome and good luck. type EDNS0_SUBNET struct { Code uint16 // Always EDNS0SUBNET Family uint16 // 1 for IP, 2 for IP6 @@ -218,38 +223,22 @@ func (e *EDNS0_SUBNET) pack() ([]byte, error) { if e.SourceNetmask > net.IPv4len*8 { return nil, errors.New("dns: bad netmask") } - ip := make([]byte, net.IPv4len) - a := e.Address.To4().Mask(net.CIDRMask(int(e.SourceNetmask), net.IPv4len*8)) - for i := 0; i < net.IPv4len; i++ { - if i+1 > len(e.Address) { - break - } - ip[i] = a[i] - } - needLength := e.SourceNetmask / 8 - if e.SourceNetmask%8 > 0 { - needLength++ + if len(e.Address.To4()) != net.IPv4len { + return nil, errors.New("dns: bad address") } - ip = ip[:needLength] - b = append(b, ip...) + ip := e.Address.To4().Mask(net.CIDRMask(int(e.SourceNetmask), net.IPv4len*8)) + needLength := (e.SourceNetmask + 8 - 1) / 8 // division rounding up + b = append(b, ip[:needLength]...) case 2: if e.SourceNetmask > net.IPv6len*8 { return nil, errors.New("dns: bad netmask") } - ip := make([]byte, net.IPv6len) - a := e.Address.Mask(net.CIDRMask(int(e.SourceNetmask), net.IPv6len*8)) - for i := 0; i < net.IPv6len; i++ { - if i+1 > len(e.Address) { - break - } - ip[i] = a[i] - } - needLength := e.SourceNetmask / 8 - if e.SourceNetmask%8 > 0 { - needLength++ + if len(e.Address) != net.IPv6len { + return nil, errors.New("dns: bad address") } - ip = ip[:needLength] - b = append(b, ip...) + ip := e.Address.Mask(net.CIDRMask(int(e.SourceNetmask), net.IPv6len*8)) + needLength := (e.SourceNetmask + 8 - 1) / 8 // division rounding up + b = append(b, ip[:needLength]...) default: return nil, errors.New("dns: bad address family") } @@ -257,8 +246,7 @@ func (e *EDNS0_SUBNET) pack() ([]byte, error) { } func (e *EDNS0_SUBNET) unpack(b []byte) error { - lb := len(b) - if lb < 4 { + if len(b) < 4 { return ErrBuf } e.Family, _ = unpackUint16(b, 0) @@ -266,25 +254,27 @@ func (e *EDNS0_SUBNET) unpack(b []byte) error { e.SourceScope = b[3] switch e.Family { case 1: - addr := make([]byte, 4) - for i := 0; i < int(e.SourceNetmask/8); i++ { - if i >= len(addr) || 4+i >= len(b) { - return ErrBuf - } + if e.SourceNetmask > net.IPv4len*8 || e.SourceScope > net.IPv4len*8 { + return errors.New("dns: bad netmask") + } + addr := make([]byte, net.IPv4len) + for i := 0; i < net.IPv4len && 4+i < len(b); i++ { addr[i] = b[4+i] } e.Address = net.IPv4(addr[0], addr[1], addr[2], addr[3]) case 2: - addr := make([]byte, 16) - for i := 0; i < int(e.SourceNetmask/8); i++ { - if i >= len(addr) || 4+i >= len(b) { - return ErrBuf - } + if e.SourceNetmask > net.IPv6len*8 || e.SourceScope > net.IPv6len*8 { + return errors.New("dns: bad netmask") + } + addr := make([]byte, net.IPv6len) + for i := 0; i < net.IPv6len && 4+i < len(b); i++ { addr[i] = b[4+i] } e.Address = net.IP{addr[0], addr[1], addr[2], addr[3], addr[4], addr[5], addr[6], addr[7], addr[8], addr[9], addr[10], addr[11], addr[12], addr[13], addr[14], addr[15]} + default: + return errors.New("dns: bad address family") } return nil } diff --git a/msg.go b/msg.go index 3f71c96cd..01dafde90 100644 --- a/msg.go +++ b/msg.go @@ -413,12 +413,6 @@ Loop: case 0x00: if c == 0x00 { // end of name - if len(s) == 0 { - if ptr == 0 { - off1 = off - } - return ".", off1, nil - } break Loop } // literal string @@ -479,6 +473,9 @@ Loop: if ptr == 0 { off1 = off } + if len(s) == 0 { + s = []byte(".") + } return string(s), off1, nil } @@ -576,17 +573,16 @@ func packOctetString(s string, msg []byte, offset int, tmp []byte) (int, error) return offset, nil } -func unpackTxt(msg []byte, offset, rdend int) ([]string, int, error) { - var err error - var ss []string +func unpackTxt(msg []byte, off0 int) (ss []string, off int, err error) { + off = off0 var s string - for offset < rdend && err == nil { - s, offset, err = unpackTxtString(msg, offset) + for off < len(msg) && err == nil { + s, off, err = unpackTxtString(msg, off) if err == nil { ss = append(ss, s) } } - return ss, offset, err + return } func unpackTxtString(msg []byte, offset int) (string, int, error) { @@ -751,58 +747,49 @@ func packStructValue(val reflect.Value, msg []byte, off int, compression map[str if val.Field(i).Len() == 0 { break } - var bitmapbyte uint16 + off1 := off for j := 0; j < val.Field(i).Len(); j++ { - serv := uint16((fv.Index(j).Uint())) - bitmapbyte = uint16(serv / 8) - if int(bitmapbyte) > lenmsg { - return lenmsg, &Error{err: "overflow packing wks"} + serv := int(fv.Index(j).Uint()) + if off+serv/8+1 > len(msg) { + return len(msg), &Error{err: "overflow packing wks"} + } + msg[off+serv/8] |= byte(1 << (7 - uint(serv%8))) + if off+serv/8+1 > off1 { + off1 = off + serv/8 + 1 } - bit := uint16(serv) - bitmapbyte*8 - msg[bitmapbyte] = byte(1 << (7 - bit)) } - off += int(bitmapbyte) + off = off1 case `dns:"nsec"`: // NSEC/NSEC3 // This is the uint16 type bitmap if val.Field(i).Len() == 0 { // Do absolutely nothing break } - - lastwindow := uint16(0) - length := uint16(0) - if off+2 > lenmsg { - return lenmsg, &Error{err: "overflow packing nsecx"} - } + var lastwindow, lastlength uint16 for j := 0; j < val.Field(i).Len(); j++ { - t := uint16((fv.Index(j).Uint())) - window := uint16(t / 256) - if lastwindow != window { + t := uint16(fv.Index(j).Uint()) + window := t / 256 + length := (t-window*256)/8 + 1 + if window > lastwindow && lastlength != 0 { // New window, jump to the new offset - off += int(length) + 3 - if off > lenmsg { - return lenmsg, &Error{err: "overflow packing nsecx bitmap"} - } + off += int(lastlength) + 2 + lastlength = 0 } - length = (t - window*256) / 8 - bit := t - (window * 256) - (length * 8) - if off+2+int(length) > lenmsg { - return lenmsg, &Error{err: "overflow packing nsecx bitmap"} + if window < lastwindow || length < lastlength { + return len(msg), &Error{err: "nsec bits out of order"} + } + if off+2+int(length) > len(msg) { + return len(msg), &Error{err: "overflow packing nsec"} } - // Setting the window # msg[off] = byte(window) // Setting the octets length - msg[off+1] = byte(length + 1) + msg[off+1] = byte(length) // Setting the bit value for the type in the right octet - msg[off+2+int(length)] |= byte(1 << (7 - bit)) - lastwindow = window - } - off += 2 + int(length) - off++ - if off > lenmsg { - return lenmsg, &Error{err: "overflow packing nsecx bitmap"} + msg[off+1+int(length)] |= byte(1 << (7 - (t % 8))) + lastwindow, lastlength = window, length } + off += int(lastlength) + 2 } case reflect.Struct: off, err = packStructValue(fv, msg, off, compression, compress) @@ -960,17 +947,11 @@ func packStructCompress(any interface{}, msg []byte, off int, compression map[st return off, err } -// TODO(miek): Fix use of rdlength here - // Unpack a reflect.StructValue from msg. // Same restrictions as packStructValue. func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, err error) { - var lenrd int lenmsg := len(msg) for i := 0; i < val.NumField(); i++ { - if lenrd != 0 && lenrd == off { - break - } if off > lenmsg { return lenmsg, &Error{"bad offset unpacking"} } @@ -982,7 +963,7 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, err er // therefore it's expected that this interface would be PrivateRdata switch data := fv.Interface().(type) { case PrivateRdata: - n, err := data.Unpack(msg[off:lenrd]) + n, err := data.Unpack(msg[off:]) if err != nil { return lenmsg, err } @@ -998,7 +979,7 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, err er // HIP record slice of name (or none) var servers []string var s string - for off < lenrd { + for off < lenmsg { s, off, err = UnpackDomainName(msg, off) if err != nil { return lenmsg, err @@ -1007,17 +988,17 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, err er } fv.Set(reflect.ValueOf(servers)) case `dns:"txt"`: - if off == lenmsg || lenrd == off { + if off == lenmsg { break } var txt []string - txt, off, err = unpackTxt(msg, off, lenrd) + txt, off, err = unpackTxt(msg, off) if err != nil { return lenmsg, err } fv.Set(reflect.ValueOf(txt)) case `dns:"opt"`: // edns0 - if off == lenrd { + if off == lenmsg { // This is an EDNS0 (OPT Record) with no rdata // We can safely return here. break @@ -1025,12 +1006,12 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, err er var edns []EDNS0 Option: code := uint16(0) - if off+2 > lenmsg { + if off+4 > lenmsg { return lenmsg, &Error{err: "overflow unpacking opt"} } code, off = unpackUint16(msg, off) optlen, off1 := unpackUint16(msg, off) - if off1+int(optlen) > lenrd { + if off1+int(optlen) > lenmsg { return lenmsg, &Error{err: "overflow unpacking opt"} } switch code { @@ -1095,7 +1076,7 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, err er edns = append(edns, e) off = off1 + int(optlen) } - if off < lenrd { + if off < lenmsg { goto Option } fv.Set(reflect.ValueOf(edns)) @@ -1106,10 +1087,10 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, err er continue } } - if off == lenrd { + if off == lenmsg { break // dyn. update } - if off+net.IPv4len > lenrd || off+net.IPv4len > lenmsg { + if off+net.IPv4len > lenmsg { return lenmsg, &Error{err: "overflow unpacking a"} } fv.Set(reflect.ValueOf(net.IPv4(msg[off], msg[off+1], msg[off+2], msg[off+3]))) @@ -1121,10 +1102,10 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, err er continue } } - if off == lenrd { + if off == lenmsg { break } - if off+net.IPv6len > lenrd || off+net.IPv6len > lenmsg { + if off+net.IPv6len > lenmsg { return lenmsg, &Error{err: "overflow unpacking aaaa"} } fv.Set(reflect.ValueOf(net.IP{msg[off], msg[off+1], msg[off+2], msg[off+3], msg[off+4], @@ -1135,7 +1116,7 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, err er // Rest of the record is the bitmap var serv []uint16 j := 0 - for off < lenrd { + for off < lenmsg { if off+1 > lenmsg { return lenmsg, &Error{err: "overflow unpacking wks"} } @@ -1170,40 +1151,39 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, err er } fv.Set(reflect.ValueOf(serv)) case `dns:"nsec"`: // NSEC/NSEC3 - if off == lenrd { + if off == len(msg) { break } // Rest of the record is the type bitmap - if off+2 > lenrd || off+2 > lenmsg { - return lenmsg, &Error{err: "overflow unpacking nsecx"} - } var nsec []uint16 length := 0 window := 0 - for off+2 < lenrd { - if off+2 > lenmsg { - return lenmsg, &Error{err: "overflow unpacking nsecx"} + lastwindow := -1 + for off < len(msg) { + if off+2 > len(msg) { + return len(msg), &Error{err: "overflow unpacking nsecx"} } - window = int(msg[off]) length = int(msg[off+1]) - //println("off, windows, length, end", off, window, length, endrr) + off += 2 + if window <= lastwindow { + // RFC 4034: Blocks are present in the NSEC RR RDATA in + // increasing numerical order. + return len(msg), &Error{err: "out of order NSEC block"} + } if length == 0 { - // A length window of zero is strange. If there - // the window should not have been specified. Bail out - // println("dns: length == 0 when unpacking NSEC") - return lenmsg, &Error{err: "overflow unpacking nsecx"} + // RFC 4034: Blocks with no types present MUST NOT be included. + return len(msg), &Error{err: "empty NSEC block"} } if length > 32 { - return lenmsg, &Error{err: "overflow unpacking nsecx"} + return len(msg), &Error{err: "NSEC block too long"} + } + if off+length > len(msg) { + return len(msg), &Error{err: "overflowing NSEC block"} } - // Walk the bytes in the window - and check the bit settings... - off += 2 + // Walk the bytes in the window and extract the type bits for j := 0; j < length; j++ { - if off+j+1 > lenmsg { - return lenmsg, &Error{err: "overflow unpacking nsecx"} - } b := msg[off+j] // Check the bits one by one, and set the type if b&0x80 == 0x80 { @@ -1232,6 +1212,7 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, err er } } off += length + lastwindow = window } fv.Set(reflect.ValueOf(nsec)) } @@ -1241,7 +1222,12 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, err er return lenmsg, err } if val.Type().Field(i).Name == "Hdr" { - lenrd = off + int(val.FieldByName("Hdr").FieldByName("Rdlength").Uint()) + lenrd := off + int(val.FieldByName("Hdr").FieldByName("Rdlength").Uint()) + if lenrd > lenmsg { + return lenmsg, &Error{err: "overflowing header size"} + } + msg = msg[:lenrd] + lenmsg = len(msg) } case reflect.Uint8: if off == lenmsg { @@ -1272,6 +1258,9 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, err er fv.SetUint(uint64(uint32(msg[off])<<24 | uint32(msg[off+1])<<16 | uint32(msg[off+2])<<8 | uint32(msg[off+3]))) off += 4 case reflect.Uint64: + if off == lenmsg { + break + } switch val.Type().Field(i).Tag { default: if off+8 > lenmsg { @@ -1298,30 +1287,26 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, err er default: return lenmsg, &Error{"bad tag unpacking string: " + val.Type().Field(i).Tag.Get("dns")} case `dns:"octet"`: - strend := lenrd - if strend > lenmsg { - return lenmsg, &Error{err: "overflow unpacking octet"} - } - s = string(msg[off:strend]) - off = strend + s = string(msg[off:]) + off = lenmsg case `dns:"hex"`: - hexend := lenrd + hexend := lenmsg if val.FieldByName("Hdr").FieldByName("Rrtype").Uint() == uint64(TypeHIP) { hexend = off + int(val.FieldByName("HitLength").Uint()) } - if hexend > lenrd || hexend > lenmsg { - return lenmsg, &Error{err: "overflow unpacking hex"} + if hexend > lenmsg { + return lenmsg, &Error{err: "overflow unpacking HIP hex"} } s = hex.EncodeToString(msg[off:hexend]) off = hexend case `dns:"base64"`: // Rest of the RR is base64 encoded value - b64end := lenrd + b64end := lenmsg if val.FieldByName("Hdr").FieldByName("Rrtype").Uint() == uint64(TypeHIP) { b64end = off + int(val.FieldByName("PublicKeyLength").Uint()) } - if b64end > lenrd || b64end > lenmsg { - return lenmsg, &Error{err: "overflow unpacking base64"} + if b64end > lenmsg { + return lenmsg, &Error{err: "overflow unpacking HIP base64"} } s = toBase64(msg[off:b64end]) off = b64end