Remove Tag.Get() and just look at the raw strings

This commit is contained in:
Miek Gieben 2012-10-11 12:57:08 +02:00
parent dce8b2e71a
commit 6e43b3b666

76
msg.go
View File

@ -375,10 +375,10 @@ func packStructValue(val reflect.Value, msg []byte, off int, compression map[str
default: default:
return lenmsg, &Error{Err: "bad kind packing"} return lenmsg, &Error{Err: "bad kind packing"}
case reflect.Slice: case reflect.Slice:
switch val.Type().Field(i).Tag.Get("dns") { switch val.Type().Field(i).Tag {
default: default:
return lenmsg, &Error{Name: val.Type().Field(i).Tag.Get("dns"), Err: "bad tag packing slice"} return lenmsg, &Error{Name: val.Type().Field(i).Tag, Err: "bad tag packing slice"}
case "domain-name": case `dns:"domain-name"`:
for j := 0; j < val.Field(i).Len(); j++ { for j := 0; j < val.Field(i).Len(); j++ {
element := val.Field(i).Index(j).String() element := val.Field(i).Index(j).String()
off, err = PackDomainName(element, msg, off, compression, false && compress) off, err = PackDomainName(element, msg, off, compression, false && compress)
@ -386,7 +386,7 @@ func packStructValue(val reflect.Value, msg []byte, off int, compression map[str
return lenmsg, err return lenmsg, err
} }
} }
case "txt": case `dns:"txt"`:
for j := 0; j < val.Field(i).Len(); j++ { for j := 0; j < val.Field(i).Len(); j++ {
element := val.Field(i).Index(j).String() element := val.Field(i).Index(j).String()
// Counted string: 1 byte length. // Counted string: 1 byte length.
@ -400,7 +400,7 @@ func packStructValue(val reflect.Value, msg []byte, off int, compression map[str
} }
off += len(element) off += len(element)
} }
case "opt": // edns case `dns:"opt"`: // edns
for j := 0; j < val.Field(i).Len(); j++ { for j := 0; j < val.Field(i).Len(); j++ {
element := val.Field(i).Index(j).Interface() element := val.Field(i).Index(j).Interface()
b, e := element.(EDNS0).pack() b, e := element.(EDNS0).pack()
@ -416,7 +416,7 @@ func packStructValue(val reflect.Value, msg []byte, off int, compression map[str
copy(msg[off:off+len(b)], b) copy(msg[off:off+len(b)], b)
off += len(b) off += len(b)
} }
case "a": case `dns:"a"`:
// It must be a slice of 4, even if it is 16, we encode // It must be a slice of 4, even if it is 16, we encode
// only the first 4 // only the first 4
if off+net.IPv4len > lenmsg { if off+net.IPv4len > lenmsg {
@ -440,7 +440,7 @@ func packStructValue(val reflect.Value, msg []byte, off int, compression map[str
default: default:
return lenmsg, &Error{Err: "overflow packing a"} return lenmsg, &Error{Err: "overflow packing a"}
} }
case "aaaa": case `dns:"aaaa"`:
if fv.Len() > net.IPv6len || off+fv.Len() > lenmsg { if fv.Len() > net.IPv6len || off+fv.Len() > lenmsg {
return lenmsg, &Error{Err: "overflow packing aaaa"} return lenmsg, &Error{Err: "overflow packing aaaa"}
} }
@ -448,7 +448,7 @@ func packStructValue(val reflect.Value, msg []byte, off int, compression map[str
msg[off] = byte(fv.Index(j).Uint()) msg[off] = byte(fv.Index(j).Uint())
off++ off++
} }
case "wks": case `dns:"wks"`:
if val.Field(i).Len() == 0 { if val.Field(i).Len() == 0 {
break break
} }
@ -463,7 +463,7 @@ func packStructValue(val reflect.Value, msg []byte, off int, compression map[str
msg[bitmapbyte] = byte(1 << (7 - bit)) msg[bitmapbyte] = byte(1 << (7 - bit))
} }
off += int(bitmapbyte) off += int(bitmapbyte)
case "nsec": // NSEC/NSEC3 case `dns:"nsec"`: // NSEC/NSEC3
// This is the uint16 type bitmap // This is the uint16 type bitmap
if val.Field(i).Len() == 0 { if val.Field(i).Len() == 0 {
// Do absolutely nothing // Do absolutely nothing
@ -548,40 +548,40 @@ func packStructValue(val reflect.Value, msg []byte, off int, compression map[str
// There are multiple string encodings. // There are multiple string encodings.
// The tag distinguishes ordinary strings from domain names. // The tag distinguishes ordinary strings from domain names.
s := fv.String() s := fv.String()
switch val.Type().Field(i).Tag.Get("dns") { switch val.Type().Field(i).Tag {
default: default:
return lenmsg, &Error{Name: val.Type().Field(i).Tag.Get("dns"), Err: "bad tag packing string"} return lenmsg, &Error{Name: val.Type().Field(i).Tag, Err: "bad tag packing string"}
case "base64": case `dns:"base64"`:
b64, err := packBase64([]byte(s)) b64, err := packBase64([]byte(s))
if err != nil { if err != nil {
return lenmsg, &Error{Err: "overflow packing base64"} return lenmsg, &Error{Err: "overflow packing base64"}
} }
copy(msg[off:off+len(b64)], b64) copy(msg[off:off+len(b64)], b64)
off += len(b64) off += len(b64)
case "domain-name": case `dns:"domain-name"`:
if off, err = PackDomainName(s, msg, off, compression, false && compress); err != nil { if off, err = PackDomainName(s, msg, off, compression, false && compress); err != nil {
return lenmsg, err return lenmsg, err
} }
case "cdomain-name": case `dns:"cdomain-name"`:
if off, err = PackDomainName(s, msg, off, compression, true && compress); err != nil { if off, err = PackDomainName(s, msg, off, compression, true && compress); err != nil {
return lenmsg, err return lenmsg, err
} }
case "size-base32": case `dns:"size-base32"`:
// This is purely for NSEC3 atm, the previous byte must // This is purely for NSEC3 atm, the previous byte must
// holds the length of the encoded string. As NSEC3 // holds the length of the encoded string. As NSEC3
// is only defined to SHA1, the hashlength is 20 (160 bits) // is only defined to SHA1, the hashlength is 20 (160 bits)
msg[off-1] = 20 msg[off-1] = 20
fallthrough fallthrough
case "base32": case `dns:"base32"`:
b32, err := packBase32([]byte(s)) b32, err := packBase32([]byte(s))
if err != nil { if err != nil {
return lenmsg, &Error{Err: "overflow packing base32"} return lenmsg, &Error{Err: "overflow packing base32"}
} }
copy(msg[off:off+len(b32)], b32) copy(msg[off:off+len(b32)], b32)
off += len(b32) off += len(b32)
case "size-hex": case `dns:"size-hex"`:
fallthrough fallthrough
case "hex": case `dns:"hex"`:
// There is no length encoded here // There is no length encoded here
h, e := hex.DecodeString(s) h, e := hex.DecodeString(s)
if e != nil { if e != nil {
@ -592,12 +592,12 @@ func packStructValue(val reflect.Value, msg []byte, off int, compression map[str
} }
copy(msg[off:off+hex.DecodedLen(len(s))], h) copy(msg[off:off+hex.DecodedLen(len(s))], h)
off += hex.DecodedLen(len(s)) off += hex.DecodedLen(len(s))
case "size": case `dns:"size"`:
// the size is already encoded in the RR, we can safely use the // the size is already encoded in the RR, we can safely use the
// length of string. String is RAW (not encoded in hex, nor base64) // length of string. String is RAW (not encoded in hex, nor base64)
copy(msg[off:off+len(s)], s) copy(msg[off:off+len(s)], s)
off += len(s) off += len(s)
case "txt": case `dns:"txt"`:
fallthrough fallthrough
case "": case "":
// Counted string: 1 byte length. // Counted string: 1 byte length.
@ -640,10 +640,10 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, err er
default: default:
return lenmsg, &Error{Err: "bad kind unpacking"} return lenmsg, &Error{Err: "bad kind unpacking"}
case reflect.Slice: case reflect.Slice:
switch val.Type().Field(i).Tag.Get("dns") { switch val.Type().Field(i).Tag {
default: default:
return lenmsg, &Error{Name: val.Type().Field(i).Tag.Get("dns"), Err: "bad tag unpacking slice"} return lenmsg, &Error{Name: val.Type().Field(i).Tag("dns"), Err: "bad tag unpacking slice"}
case "domain-name": case `dns:"domain-name"`:
// HIP record slice of name (or none) // HIP record slice of name (or none)
servers := make([]string, 0) servers := make([]string, 0)
var s string var s string
@ -655,7 +655,7 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, err er
servers = append(servers, s) servers = append(servers, s)
} }
fv.Set(reflect.ValueOf(servers)) fv.Set(reflect.ValueOf(servers))
case "txt": case `dns:"txt"`:
txt := make([]string, 0) txt := make([]string, 0)
rdlength := off + int(val.FieldByName("Hdr").FieldByName("Rdlength").Uint()) rdlength := off + int(val.FieldByName("Hdr").FieldByName("Rdlength").Uint())
Txts: Txts:
@ -670,7 +670,7 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, err er
goto Txts goto Txts
} }
fv.Set(reflect.ValueOf(txt)) fv.Set(reflect.ValueOf(txt))
case "opt": // edns0 case `dns:"opt"`: // edns0
// TODO: multiple EDNS0 options // TODO: multiple EDNS0 options
rdlength := int(val.FieldByName("Hdr").FieldByName("Rdlength").Uint()) rdlength := int(val.FieldByName("Hdr").FieldByName("Rdlength").Uint())
if rdlength == 0 { if rdlength == 0 {
@ -702,13 +702,13 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, err er
} }
fv.Set(reflect.ValueOf(edns)) fv.Set(reflect.ValueOf(edns))
// multiple EDNS codes? // multiple EDNS codes?
case "a": case `dns:"a"`:
if off+net.IPv4len > len(msg) { if off+net.IPv4len > len(msg) {
return lenmsg, &Error{Err: "overflow unpacking a"} return lenmsg, &Error{Err: "overflow unpacking a"}
} }
fv.Set(reflect.ValueOf(net.IPv4(msg[off], msg[off+1], msg[off+2], msg[off+3]))) fv.Set(reflect.ValueOf(net.IPv4(msg[off], msg[off+1], msg[off+2], msg[off+3])))
off += net.IPv4len off += net.IPv4len
case "aaaa": case `dns:"aaaa"`:
if off+net.IPv6len > lenmsg { if off+net.IPv6len > lenmsg {
return lenmsg, &Error{Err: "overflow unpacking aaaa"} return lenmsg, &Error{Err: "overflow unpacking aaaa"}
} }
@ -716,7 +716,7 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, err er
msg[off+5], msg[off+6], msg[off+7], msg[off+8], msg[off+9], msg[off+10], msg[off+5], msg[off+6], msg[off+7], msg[off+8], msg[off+9], msg[off+10],
msg[off+11], msg[off+12], msg[off+13], msg[off+14], msg[off+15]})) msg[off+11], msg[off+12], msg[off+13], msg[off+14], msg[off+15]}))
off += net.IPv6len off += net.IPv6len
case "wks": case `dns:"wks"`:
// Rest of the record is the bitmap // Rest of the record is the bitmap
rdlength := int(val.FieldByName("Hdr").FieldByName("Rdlength").Uint()) rdlength := int(val.FieldByName("Hdr").FieldByName("Rdlength").Uint())
endrr := rdstart + rdlength endrr := rdstart + rdlength
@ -753,7 +753,7 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, err er
off++ off++
} }
fv.Set(reflect.ValueOf(serv)) fv.Set(reflect.ValueOf(serv))
case "nsec": // NSEC/NSEC3 case `dns:"nsec"`: // NSEC/NSEC3
// Rest of the record is the type bitmap // Rest of the record is the type bitmap
rdlength := int(val.FieldByName("Hdr").FieldByName("Rdlength").Uint()) rdlength := int(val.FieldByName("Hdr").FieldByName("Rdlength").Uint())
endrr := rdstart + rdlength endrr := rdstart + rdlength
@ -847,10 +847,10 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, err er
off += 6 off += 6
case reflect.String: case reflect.String:
var s string var s string
switch val.Type().Field(i).Tag.Get("dns") { switch val.Type().Field(i).Tag {
default: default:
return lenmsg, &Error{Name: val.Type().Field(i).Tag.Get("dns"), Err: "bad tag unpacking string"} return lenmsg, &Error{Name: val.Type().Field(i).Tag, Err: "bad tag unpacking string"}
case "hex": case `dns:"hex"`:
// Rest of the RR is hex encoded, network order an issue here? // Rest of the RR is hex encoded, network order an issue here?
rdlength := int(val.FieldByName("Hdr").FieldByName("Rdlength").Uint()) rdlength := int(val.FieldByName("Hdr").FieldByName("Rdlength").Uint())
endrr := rdstart + rdlength endrr := rdstart + rdlength
@ -859,7 +859,7 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, err er
} }
s = hex.EncodeToString(msg[off:endrr]) s = hex.EncodeToString(msg[off:endrr])
off = endrr off = endrr
case "base64": case `dns:"base64"`:
// Rest of the RR is base64 encoded value // Rest of the RR is base64 encoded value
rdlength := int(val.FieldByName("Hdr").FieldByName("Rdlength").Uint()) rdlength := int(val.FieldByName("Hdr").FieldByName("Rdlength").Uint())
endrr := rdstart + rdlength endrr := rdstart + rdlength
@ -868,14 +868,14 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, err er
} }
s = unpackBase64(msg[off:endrr]) s = unpackBase64(msg[off:endrr])
off = endrr off = endrr
case "cdomain-name": case `dns:"cdomain-name"`:
fallthrough fallthrough
case "domain-name": case `dns:"domain-name"`:
s, off, err = UnpackDomainName(msg, off) s, off, err = UnpackDomainName(msg, off)
if err != nil { if err != nil {
return lenmsg, err return lenmsg, err
} }
case "size-base32": case `dns:"size-base32"`:
var size int var size int
switch val.Type().Name() { switch val.Type().Name() {
case "RR_NSEC3": case "RR_NSEC3":
@ -890,7 +890,7 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, err er
} }
s = unpackBase32(msg[off : off+size]) s = unpackBase32(msg[off : off+size])
off += size off += size
case "size-hex": case `dns:"size-hex"`:
// a "size" string, but it must be encoded in hex in the string // a "size" string, but it must be encoded in hex in the string
var size int var size int
switch val.Type().Name() { switch val.Type().Name() {
@ -918,7 +918,7 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, err er
} }
s = hex.EncodeToString(msg[off : off+size]) s = hex.EncodeToString(msg[off : off+size])
off += size off += size
case "txt": case `dns:"txt"`:
// 1 txt piece // 1 txt piece
rdlength := int(val.FieldByName("Hdr").FieldByName("Rdlength").Uint()) rdlength := int(val.FieldByName("Hdr").FieldByName("Rdlength").Uint())
Txt: Txt: