unpackStructValue: drop rdlen, reslice msg instead

This commit is contained in:
Filippo Valsorda 2015-08-05 00:18:02 +01:00
parent 6313235fed
commit b5133fead4

81
msg.go
View File

@ -576,17 +576,16 @@ func packOctetString(s string, msg []byte, offset int, tmp []byte) (int, error)
return offset, nil return offset, nil
} }
func unpackTxt(msg []byte, offset, rdend int) ([]string, int, error) { func unpackTxt(msg []byte, off0 int) (ss []string, off int, err error) {
var err error off = off0
var ss []string
var s string var s string
for offset < rdend && err == nil { for off < len(msg) && err == nil {
s, offset, err = unpackTxtString(msg, offset) s, off, err = unpackTxtString(msg, off)
if err == nil { if err == nil {
ss = append(ss, s) ss = append(ss, s)
} }
} }
return ss, offset, err return
} }
func unpackTxtString(msg []byte, offset int) (string, int, error) { func unpackTxtString(msg []byte, offset int) (string, int, error) {
@ -960,17 +959,11 @@ func packStructCompress(any interface{}, msg []byte, off int, compression map[st
return off, err return off, err
} }
// TODO(miek): Fix use of rdlength here
// Unpack a reflect.StructValue from msg. // Unpack a reflect.StructValue from msg.
// Same restrictions as packStructValue. // Same restrictions as packStructValue.
func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, err error) { func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, err error) {
var lenrd int
lenmsg := len(msg) lenmsg := len(msg)
for i := 0; i < val.NumField(); i++ { for i := 0; i < val.NumField(); i++ {
if lenrd != 0 && lenrd == off {
break
}
if off > lenmsg { if off > lenmsg {
return lenmsg, &Error{"bad offset unpacking"} return lenmsg, &Error{"bad offset unpacking"}
} }
@ -982,7 +975,7 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, err er
// therefore it's expected that this interface would be PrivateRdata // therefore it's expected that this interface would be PrivateRdata
switch data := fv.Interface().(type) { switch data := fv.Interface().(type) {
case PrivateRdata: case PrivateRdata:
n, err := data.Unpack(msg[off:lenrd]) n, err := data.Unpack(msg[off:])
if err != nil { if err != nil {
return lenmsg, err return lenmsg, err
} }
@ -998,7 +991,7 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, err er
// HIP record slice of name (or none) // HIP record slice of name (or none)
var servers []string var servers []string
var s string var s string
for off < lenrd { for off < lenmsg {
s, off, err = UnpackDomainName(msg, off) s, off, err = UnpackDomainName(msg, off)
if err != nil { if err != nil {
return lenmsg, err return lenmsg, err
@ -1007,17 +1000,17 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, err er
} }
fv.Set(reflect.ValueOf(servers)) fv.Set(reflect.ValueOf(servers))
case `dns:"txt"`: case `dns:"txt"`:
if off == lenmsg || lenrd == off { if off == lenmsg {
break break
} }
var txt []string var txt []string
txt, off, err = unpackTxt(msg, off, lenrd) txt, off, err = unpackTxt(msg, off)
if err != nil { if err != nil {
return lenmsg, err return lenmsg, err
} }
fv.Set(reflect.ValueOf(txt)) fv.Set(reflect.ValueOf(txt))
case `dns:"opt"`: // edns0 case `dns:"opt"`: // edns0
if off == lenrd { if off == lenmsg {
// This is an EDNS0 (OPT Record) with no rdata // This is an EDNS0 (OPT Record) with no rdata
// We can safely return here. // We can safely return here.
break break
@ -1030,7 +1023,7 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, err er
} }
code, off = unpackUint16(msg, off) code, off = unpackUint16(msg, off)
optlen, off1 := unpackUint16(msg, off) optlen, off1 := unpackUint16(msg, off)
if off1+int(optlen) > lenrd { if off1+int(optlen) > lenmsg {
return lenmsg, &Error{err: "overflow unpacking opt"} return lenmsg, &Error{err: "overflow unpacking opt"}
} }
switch code { switch code {
@ -1095,7 +1088,7 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, err er
edns = append(edns, e) edns = append(edns, e)
off = off1 + int(optlen) off = off1 + int(optlen)
} }
if off < lenrd { if off < lenmsg {
goto Option goto Option
} }
fv.Set(reflect.ValueOf(edns)) fv.Set(reflect.ValueOf(edns))
@ -1106,10 +1099,10 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, err er
continue continue
} }
} }
if off == lenrd { if off == lenmsg {
break // dyn. update break // dyn. update
} }
if off+net.IPv4len > lenrd || off+net.IPv4len > lenmsg { if off+net.IPv4len > lenmsg {
return lenmsg, &Error{err: "overflow unpacking a"} return lenmsg, &Error{err: "overflow unpacking a"}
} }
fv.Set(reflect.ValueOf(net.IPv4(msg[off], msg[off+1], msg[off+2], msg[off+3]))) fv.Set(reflect.ValueOf(net.IPv4(msg[off], msg[off+1], msg[off+2], msg[off+3])))
@ -1121,10 +1114,10 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, err er
continue continue
} }
} }
if off == lenrd { if off == lenmsg {
break break
} }
if off+net.IPv6len > lenrd || off+net.IPv6len > lenmsg { if off+net.IPv6len > lenmsg {
return lenmsg, &Error{err: "overflow unpacking aaaa"} return lenmsg, &Error{err: "overflow unpacking aaaa"}
} }
fv.Set(reflect.ValueOf(net.IP{msg[off], msg[off+1], msg[off+2], msg[off+3], msg[off+4], fv.Set(reflect.ValueOf(net.IP{msg[off], msg[off+1], msg[off+2], msg[off+3], msg[off+4],
@ -1135,7 +1128,7 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, err er
// Rest of the record is the bitmap // Rest of the record is the bitmap
var serv []uint16 var serv []uint16
j := 0 j := 0
for off < lenrd { for off < lenmsg {
if off+1 > lenmsg { if off+1 > lenmsg {
return lenmsg, &Error{err: "overflow unpacking wks"} return lenmsg, &Error{err: "overflow unpacking wks"}
} }
@ -1170,21 +1163,17 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, err er
} }
fv.Set(reflect.ValueOf(serv)) fv.Set(reflect.ValueOf(serv))
case `dns:"nsec"`: // NSEC/NSEC3 case `dns:"nsec"`: // NSEC/NSEC3
if off == lenrd { if off == lenmsg {
break break
} }
// Rest of the record is the type bitmap // Rest of the record is the type bitmap
if off+2 > lenrd || off+2 > lenmsg { if off+2 > lenmsg {
return lenmsg, &Error{err: "overflow unpacking nsecx"} return lenmsg, &Error{err: "overflow unpacking nsecx"}
} }
var nsec []uint16 var nsec []uint16
length := 0 length := 0
window := 0 window := 0
for off+2 < lenrd { for off+2 < lenmsg {
if off+2 > lenmsg {
return lenmsg, &Error{err: "overflow unpacking nsecx"}
}
window = int(msg[off]) window = int(msg[off])
length = int(msg[off+1]) length = int(msg[off+1])
//println("off, windows, length, end", off, window, length, endrr) //println("off, windows, length, end", off, window, length, endrr)
@ -1241,7 +1230,12 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, err er
return lenmsg, err return lenmsg, err
} }
if val.Type().Field(i).Name == "Hdr" { if val.Type().Field(i).Name == "Hdr" {
lenrd = off + int(val.FieldByName("Hdr").FieldByName("Rdlength").Uint()) lenrd := off + int(val.FieldByName("Hdr").FieldByName("Rdlength").Uint())
if lenrd > lenmsg {
return lenmsg, &Error{err: "overflowing header size"}
}
msg = msg[:lenrd]
lenmsg = len(msg)
} }
case reflect.Uint8: case reflect.Uint8:
if off == lenmsg { if off == lenmsg {
@ -1272,6 +1266,9 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, err er
fv.SetUint(uint64(uint32(msg[off])<<24 | uint32(msg[off+1])<<16 | uint32(msg[off+2])<<8 | uint32(msg[off+3]))) fv.SetUint(uint64(uint32(msg[off])<<24 | uint32(msg[off+1])<<16 | uint32(msg[off+2])<<8 | uint32(msg[off+3])))
off += 4 off += 4
case reflect.Uint64: case reflect.Uint64:
if off == lenmsg {
break
}
switch val.Type().Field(i).Tag { switch val.Type().Field(i).Tag {
default: default:
if off+8 > lenmsg { if off+8 > lenmsg {
@ -1298,30 +1295,26 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, err er
default: default:
return lenmsg, &Error{"bad tag unpacking string: " + val.Type().Field(i).Tag.Get("dns")} return lenmsg, &Error{"bad tag unpacking string: " + val.Type().Field(i).Tag.Get("dns")}
case `dns:"octet"`: case `dns:"octet"`:
strend := lenrd s = string(msg[off:])
if strend > lenmsg { off = lenmsg
return lenmsg, &Error{err: "overflow unpacking octet"}
}
s = string(msg[off:strend])
off = strend
case `dns:"hex"`: case `dns:"hex"`:
hexend := lenrd hexend := lenmsg
if val.FieldByName("Hdr").FieldByName("Rrtype").Uint() == uint64(TypeHIP) { if val.FieldByName("Hdr").FieldByName("Rrtype").Uint() == uint64(TypeHIP) {
hexend = off + int(val.FieldByName("HitLength").Uint()) hexend = off + int(val.FieldByName("HitLength").Uint())
} }
if hexend > lenrd || hexend > lenmsg { if hexend > lenmsg {
return lenmsg, &Error{err: "overflow unpacking hex"} return lenmsg, &Error{err: "overflow unpacking HIP hex"}
} }
s = hex.EncodeToString(msg[off:hexend]) s = hex.EncodeToString(msg[off:hexend])
off = hexend off = hexend
case `dns:"base64"`: case `dns:"base64"`:
// Rest of the RR is base64 encoded value // Rest of the RR is base64 encoded value
b64end := lenrd b64end := lenmsg
if val.FieldByName("Hdr").FieldByName("Rrtype").Uint() == uint64(TypeHIP) { if val.FieldByName("Hdr").FieldByName("Rrtype").Uint() == uint64(TypeHIP) {
b64end = off + int(val.FieldByName("PublicKeyLength").Uint()) b64end = off + int(val.FieldByName("PublicKeyLength").Uint())
} }
if b64end > lenrd || b64end > lenmsg { if b64end > lenmsg {
return lenmsg, &Error{err: "overflow unpacking base64"} return lenmsg, &Error{err: "overflow unpacking HIP base64"}
} }
s = toBase64(msg[off:b64end]) s = toBase64(msg[off:b64end])
off = b64end off = b64end