Fix Tsig parsing

This commit is contained in:
Miek Gieben 2011-01-25 15:09:30 +01:00
parent 7e51041a24
commit 1ece21a05e
2 changed files with 66 additions and 46 deletions

107
msg.go
View File

@ -77,25 +77,25 @@ var Rr_str = map[uint16]string{
TypeTXT: "TXT", TypeTXT: "TXT",
TypeSRV: "SRV", TypeSRV: "SRV",
TypeNAPTR: "NAPTR", TypeNAPTR: "NAPTR",
TypeCERT: "CERT", TypeCERT: "CERT",
TypeDNAME: "DNAME", TypeDNAME: "DNAME",
TypeA: "A", TypeA: "A",
TypeAAAA: "AAAA", TypeAAAA: "AAAA",
TypeLOC: "LOC", TypeLOC: "LOC",
TypeOPT: "OPT", TypeOPT: "OPT",
TypeDS: "DS", TypeDS: "DS",
TypeSSHFP: "SSHFP", TypeSSHFP: "SSHFP",
TypeRRSIG: "RRSIG", TypeRRSIG: "RRSIG",
TypeNSEC: "NSEC", TypeNSEC: "NSEC",
TypeDNSKEY: "DNSKEY", TypeDNSKEY: "DNSKEY",
TypeNSEC3: "NSEC3", TypeNSEC3: "NSEC3",
TypeNSEC3PARAM: "NSEC3PARAM", // DNSSEC's bitch TypeNSEC3PARAM: "NSEC3PARAM", // DNSSEC's bitch
TypeSPF: "SPF", TypeSPF: "SPF",
TypeTKEY: "TKEY", // Meta RR TypeTKEY: "TKEY", // Meta RR
TypeTSIG: "TSIG", // Meta RR TypeTSIG: "TSIG", // Meta RR
TypeAXFR: "AXFR", // Meta RR TypeAXFR: "AXFR", // Meta RR
TypeIXFR: "IXFR", // Meta RR TypeIXFR: "IXFR", // Meta RR
TypeALL: "ANY", // Meta RR TypeALL: "ANY", // Meta RR
} }
// Reverse of Rr_str (needed for parsing) // Reverse of Rr_str (needed for parsing)
@ -277,27 +277,27 @@ func packStructValue(val *reflect.StructValue, msg []byte, off int) (off1 int, o
off += len(data) off += len(data)
} }
case "A": case "A":
// It must be a slice of 4, even if it is 16, we encode // It must be a slice of 4, even if it is 16, we encode
// only the first 4 // only the first 4
if off+net.IPv4len > len(msg) { if off+net.IPv4len > len(msg) {
fmt.Fprintf(os.Stderr, "dns: overflow packing A") fmt.Fprintf(os.Stderr, "dns: overflow packing A")
return len(msg), false return len(msg), false
} }
if fv.Len() == net.IPv6len { if fv.Len() == net.IPv6len {
msg[off] = byte(fv.Elem(12).(*reflect.UintValue).Get()) msg[off] = byte(fv.Elem(12).(*reflect.UintValue).Get())
msg[off+1] = byte(fv.Elem(13).(*reflect.UintValue).Get()) msg[off+1] = byte(fv.Elem(13).(*reflect.UintValue).Get())
msg[off+2] = byte(fv.Elem(14).(*reflect.UintValue).Get()) msg[off+2] = byte(fv.Elem(14).(*reflect.UintValue).Get())
msg[off+3] = byte(fv.Elem(15).(*reflect.UintValue).Get()) msg[off+3] = byte(fv.Elem(15).(*reflect.UintValue).Get())
} else { } else {
msg[off] = byte(fv.Elem(0).(*reflect.UintValue).Get()) msg[off] = byte(fv.Elem(0).(*reflect.UintValue).Get())
msg[off+1] = byte(fv.Elem(1).(*reflect.UintValue).Get()) msg[off+1] = byte(fv.Elem(1).(*reflect.UintValue).Get())
msg[off+2] = byte(fv.Elem(2).(*reflect.UintValue).Get()) msg[off+2] = byte(fv.Elem(2).(*reflect.UintValue).Get())
msg[off+3] = byte(fv.Elem(3).(*reflect.UintValue).Get()) msg[off+3] = byte(fv.Elem(3).(*reflect.UintValue).Get())
} }
off += net.IPv4len off += net.IPv4len
case "AAAA": case "AAAA":
if fv.Len() > net.IPv6len || off+fv.Len() > len(msg) { if fv.Len() > net.IPv6len || off+fv.Len() > len(msg) {
fmt.Fprintf(os.Stderr, "dns: overflow packing AAAA") fmt.Fprintf(os.Stderr, "dns: overflow packing AAAA")
return len(msg), false return len(msg), false
} }
for j := 0; j < net.IPv6len; j++ { for j := 0; j < net.IPv6len; j++ {
@ -422,7 +422,7 @@ func unpackStructValue(val *reflect.StructValue, msg []byte, off int) (off1 int,
return len(msg), false return len(msg), false
case "A": case "A":
if off+net.IPv4len > len(msg) { if off+net.IPv4len > len(msg) {
fmt.Fprintf(os.Stderr, "dns: overflow unpacking A") fmt.Fprintf(os.Stderr, "dns: overflow unpacking A")
return len(msg), false return len(msg), false
} }
b := net.IPv4(msg[off], msg[off+1], msg[off+2], msg[off+3]) b := net.IPv4(msg[off], msg[off+1], msg[off+2], msg[off+3])
@ -430,7 +430,7 @@ func unpackStructValue(val *reflect.StructValue, msg []byte, off int) (off1 int,
off += net.IPv4len off += net.IPv4len
case "AAAA": case "AAAA":
if off+net.IPv6len > len(msg) { if off+net.IPv6len > len(msg) {
fmt.Fprintf(os.Stderr, "dns: overflow unpacking AAAA") fmt.Fprintf(os.Stderr, "dns: overflow unpacking AAAA")
return len(msg), false return len(msg), false
} }
p := make(net.IP, net.IPv6len) p := make(net.IP, net.IPv6len)
@ -440,7 +440,7 @@ func unpackStructValue(val *reflect.StructValue, msg []byte, off int) (off1 int,
off += net.IPv6len off += net.IPv6len
case "OPT": // EDNS case "OPT": // EDNS
if off+2 > len(msg) { if off+2 > len(msg) {
fmt.Fprintf(os.Stderr, "dns: overflow unpacking OPT") fmt.Fprintf(os.Stderr, "dns: overflow unpacking OPT")
// No room for anything else // No room for anything else
break break
} }
@ -448,7 +448,7 @@ func unpackStructValue(val *reflect.StructValue, msg []byte, off int) (off1 int,
opt[0].Code, off = unpackUint16(msg, off) opt[0].Code, off = unpackUint16(msg, off)
optlen, off1 := unpackUint16(msg, off) optlen, off1 := unpackUint16(msg, off)
if off1+int(optlen) > len(msg) { if off1+int(optlen) > len(msg) {
fmt.Fprintf(os.Stderr, "dns: overflow unpacking OPT") fmt.Fprintf(os.Stderr, "dns: overflow unpacking OPT")
return len(msg), false return len(msg), false
} }
opt[0].Data = hex.EncodeToString(msg[off1 : off1+int(optlen)]) opt[0].Data = hex.EncodeToString(msg[off1 : off1+int(optlen)])
@ -465,7 +465,7 @@ func unpackStructValue(val *reflect.StructValue, msg []byte, off int) (off1 int,
window := int(msg[off]) window := int(msg[off])
blocks := int(msg[off+1]) blocks := int(msg[off+1])
if off+blocks > len(msg) { if off+blocks > len(msg) {
fmt.Fprintf(os.Stderr, "dns: overflow unpacking NSEC") fmt.Fprintf(os.Stderr, "dns: overflow unpacking NSEC")
return len(msg), false return len(msg), false
} }
off += 2 off += 2
@ -517,7 +517,7 @@ func unpackStructValue(val *reflect.StructValue, msg []byte, off int) (off1 int,
goto BadType goto BadType
case reflect.Uint8: case reflect.Uint8:
if off+1 > len(msg) { if off+1 > len(msg) {
fmt.Fprintf(os.Stderr, "dns: overflow unpacking uint8") fmt.Fprintf(os.Stderr, "dns: overflow unpacking uint8")
return len(msg), false return len(msg), false
} }
i := uint8(msg[off]) i := uint8(msg[off])
@ -526,14 +526,14 @@ func unpackStructValue(val *reflect.StructValue, msg []byte, off int) (off1 int,
case reflect.Uint16: case reflect.Uint16:
var i uint16 var i uint16
if off+2 > len(msg) { if off+2 > len(msg) {
fmt.Fprintf(os.Stderr, "dns: overflow unpacking uint16") fmt.Fprintf(os.Stderr, "dns: overflow unpacking uint16")
return len(msg), false return len(msg), false
} }
i, off = unpackUint16(msg, off) i, off = unpackUint16(msg, off)
fv.Set(uint64(i)) fv.Set(uint64(i))
case reflect.Uint32: case reflect.Uint32:
if off+4 > len(msg) { if off+4 > len(msg) {
fmt.Fprintf(os.Stderr, "dns: overflow unpacking uint32") fmt.Fprintf(os.Stderr, "dns: overflow unpacking uint32")
return len(msg), false return len(msg), false
} }
i := uint32(msg[off])<<24 | uint32(msg[off+1])<<16 | uint32(msg[off+2])<<8 | uint32(msg[off+3]) i := uint32(msg[off])<<24 | uint32(msg[off+1])<<16 | uint32(msg[off+2])<<8 | uint32(msg[off+3])
@ -543,7 +543,7 @@ func unpackStructValue(val *reflect.StructValue, msg []byte, off int) (off1 int,
// This is *only* used in TSIG where the last 48 bits are occupied // This is *only* used in TSIG where the last 48 bits are occupied
// So for now, assume a uint48 (6 bytes) // So for now, assume a uint48 (6 bytes)
if off+6 > len(msg) { if off+6 > len(msg) {
fmt.Fprintf(os.Stderr, "dns: overflow unpacking uint64") fmt.Fprintf(os.Stderr, "dns: overflow unpacking uint64")
return len(msg), false return len(msg), false
} }
i := uint64(msg[off])<<40 | uint64(msg[off+1])<<32 | uint64(msg[off+2])<<24 | uint64(msg[off+3])<<16 | i := uint64(msg[off])<<40 | uint64(msg[off+1])<<32 | uint64(msg[off+2])<<24 | uint64(msg[off+3])<<16 |
@ -564,10 +564,10 @@ func unpackStructValue(val *reflect.StructValue, msg []byte, off int) (off1 int,
switch val.Type().Name() { switch val.Type().Name() {
case "RR_DS": case "RR_DS":
consumed = 4 // KeyTag(2) + Algorithm(1) + DigestType(1) consumed = 4 // KeyTag(2) + Algorithm(1) + DigestType(1)
case "RR_SSHFP": case "RR_SSHFP":
consumed = 2 // Algorithm(1) + Type(1) consumed = 2 // Algorithm(1) + Type(1)
case "RR_NSEC3PARAM": case "RR_NSEC3PARAM":
consumed = 5 // Hash(1) + Flags(1) + Iterations(2) + SaltLength(1) consumed = 5 // Hash(1) + Flags(1) + Iterations(2) + SaltLength(1)
default: default:
consumed = 0 // return len(msg), false? consumed = 0 // return len(msg), false?
} }
@ -595,12 +595,31 @@ func unpackStructValue(val *reflect.StructValue, msg []byte, off int) (off1 int,
case "domain-name": case "domain-name":
s, off, ok = unpackDomainName(msg, off) s, off, ok = unpackDomainName(msg, off)
if !ok { if !ok {
fmt.Fprintf(os.Stderr, "dns: failure unpacking domain-name") fmt.Fprintf(os.Stderr, "dns: failure unpacking domain-name")
return len(msg), false return len(msg), false
} }
case "fixed-size":
// We should already know how many bytes we can expect
// TODO(mg) pack variant
println(val.Type().Name())
println(f.Name) // MAC Otherdata, then get back the
//f := val.Type().(*reflect.StructType).Field(i)
//FieldByName(MACSize), Othersize to get the stuff we need
var size int
// consumed += len(val.FieldByName("SignerName").(*reflect.StringValue).Get()) + 1
switch val.Type().Name() {
case "RR_TSIG":
// tsig has MACSize
size = 16 // TODO(mg) other hashes
}
if off+size > len(msg) {
fmt.Fprintf(os.Stderr, "dns: failure unpacking fixed-size string")
return len(msg), false
}
s = string(msg[off : off+size])
case "": case "":
if off >= len(msg) || off+1+int(msg[off]) > len(msg) { if off >= len(msg) || off+1+int(msg[off]) > len(msg) {
fmt.Fprintf(os.Stderr, "dns: failure unpacking string") fmt.Fprintf(os.Stderr, "dns: failure unpacking string")
return len(msg), false return len(msg), false
} }
n := int(msg[off]) n := int(msg[off])
@ -702,11 +721,11 @@ func unpackRR(msg []byte, off int) (rr RR, off1 int, ok bool) {
// Reverse a map // Reverse a map
func reverse(m map[uint16]string) map[string]uint16 { func reverse(m map[uint16]string) map[string]uint16 {
n := make(map[string]uint16) n := make(map[string]uint16)
for u, s := range m { for u, s := range m {
n[s] = u n[s] = u
} }
return n return n
} }

View File

@ -22,11 +22,11 @@ type RR_TSIG struct {
TimeSigned uint64 TimeSigned uint64
Fudge uint16 Fudge uint16
MACSize uint16 MACSize uint16
MAC string MAC string "fixed-size"
OrigId uint16 OrigId uint16
Error uint16 Error uint16
OtherLen uint16 OtherLen uint16
OtherData string OtherData string "fixed-size"
} }
func (rr *RR_TSIG) Header() *RR_Header { func (rr *RR_TSIG) Header() *RR_Header {
@ -35,6 +35,7 @@ func (rr *RR_TSIG) Header() *RR_Header {
func (rr *RR_TSIG) String() string { func (rr *RR_TSIG) String() string {
// It has no official presentation format // It has no official presentation format
println("mac len: ", rr.MACSize)
return rr.Hdr.String() + return rr.Hdr.String() +
" " + rr.Algorithm + " " + rr.Algorithm +
" " + tsigTimeToDate(rr.TimeSigned) + " " + tsigTimeToDate(rr.TimeSigned) +