diff --git a/msg.go b/msg.go index e9f1148a..fa565f86 100644 --- a/msg.go +++ b/msg.go @@ -77,25 +77,25 @@ var Rr_str = map[uint16]string{ TypeTXT: "TXT", TypeSRV: "SRV", TypeNAPTR: "NAPTR", - TypeCERT: "CERT", - TypeDNAME: "DNAME", + TypeCERT: "CERT", + TypeDNAME: "DNAME", TypeA: "A", TypeAAAA: "AAAA", TypeLOC: "LOC", TypeOPT: "OPT", TypeDS: "DS", - TypeSSHFP: "SSHFP", + TypeSSHFP: "SSHFP", TypeRRSIG: "RRSIG", TypeNSEC: "NSEC", TypeDNSKEY: "DNSKEY", TypeNSEC3: "NSEC3", TypeNSEC3PARAM: "NSEC3PARAM", // DNSSEC's bitch - TypeSPF: "SPF", - TypeTKEY: "TKEY", // Meta RR - TypeTSIG: "TSIG", // Meta RR - TypeAXFR: "AXFR", // Meta RR - TypeIXFR: "IXFR", // Meta RR - TypeALL: "ANY", // Meta RR + TypeSPF: "SPF", + TypeTKEY: "TKEY", // Meta RR + TypeTSIG: "TSIG", // Meta RR + TypeAXFR: "AXFR", // Meta RR + TypeIXFR: "IXFR", // Meta RR + TypeALL: "ANY", // Meta RR } // Reverse of Rr_str (needed for parsing) @@ -277,27 +277,27 @@ func packStructValue(val *reflect.StructValue, msg []byte, off int) (off1 int, o off += len(data) } case "A": - // It must be a slice of 4, even if it is 16, we encode - // only the first 4 + // It must be a slice of 4, even if it is 16, we encode + // only the first 4 if off+net.IPv4len > len(msg) { - fmt.Fprintf(os.Stderr, "dns: overflow packing A") + fmt.Fprintf(os.Stderr, "dns: overflow packing A") return len(msg), false } - if fv.Len() == net.IPv6len { - msg[off] = byte(fv.Elem(12).(*reflect.UintValue).Get()) - msg[off+1] = byte(fv.Elem(13).(*reflect.UintValue).Get()) - msg[off+2] = byte(fv.Elem(14).(*reflect.UintValue).Get()) - msg[off+3] = byte(fv.Elem(15).(*reflect.UintValue).Get()) - } else { - msg[off] = byte(fv.Elem(0).(*reflect.UintValue).Get()) - msg[off+1] = byte(fv.Elem(1).(*reflect.UintValue).Get()) - msg[off+2] = byte(fv.Elem(2).(*reflect.UintValue).Get()) - msg[off+3] = byte(fv.Elem(3).(*reflect.UintValue).Get()) - } + if fv.Len() == net.IPv6len { + msg[off] = byte(fv.Elem(12).(*reflect.UintValue).Get()) + msg[off+1] = byte(fv.Elem(13).(*reflect.UintValue).Get()) + msg[off+2] = byte(fv.Elem(14).(*reflect.UintValue).Get()) + msg[off+3] = byte(fv.Elem(15).(*reflect.UintValue).Get()) + } else { + msg[off] = byte(fv.Elem(0).(*reflect.UintValue).Get()) + msg[off+1] = byte(fv.Elem(1).(*reflect.UintValue).Get()) + msg[off+2] = byte(fv.Elem(2).(*reflect.UintValue).Get()) + msg[off+3] = byte(fv.Elem(3).(*reflect.UintValue).Get()) + } off += net.IPv4len case "AAAA": if fv.Len() > net.IPv6len || off+fv.Len() > len(msg) { - fmt.Fprintf(os.Stderr, "dns: overflow packing AAAA") + fmt.Fprintf(os.Stderr, "dns: overflow packing AAAA") return len(msg), false } for j := 0; j < net.IPv6len; j++ { @@ -422,7 +422,7 @@ func unpackStructValue(val *reflect.StructValue, msg []byte, off int) (off1 int, return len(msg), false case "A": if off+net.IPv4len > len(msg) { - fmt.Fprintf(os.Stderr, "dns: overflow unpacking A") + fmt.Fprintf(os.Stderr, "dns: overflow unpacking A") return len(msg), false } b := net.IPv4(msg[off], msg[off+1], msg[off+2], msg[off+3]) @@ -430,7 +430,7 @@ func unpackStructValue(val *reflect.StructValue, msg []byte, off int) (off1 int, off += net.IPv4len case "AAAA": if off+net.IPv6len > len(msg) { - fmt.Fprintf(os.Stderr, "dns: overflow unpacking AAAA") + fmt.Fprintf(os.Stderr, "dns: overflow unpacking AAAA") return len(msg), false } p := make(net.IP, net.IPv6len) @@ -440,7 +440,7 @@ func unpackStructValue(val *reflect.StructValue, msg []byte, off int) (off1 int, off += net.IPv6len case "OPT": // EDNS if off+2 > len(msg) { - fmt.Fprintf(os.Stderr, "dns: overflow unpacking OPT") + fmt.Fprintf(os.Stderr, "dns: overflow unpacking OPT") // No room for anything else break } @@ -448,7 +448,7 @@ func unpackStructValue(val *reflect.StructValue, msg []byte, off int) (off1 int, opt[0].Code, off = unpackUint16(msg, off) optlen, off1 := unpackUint16(msg, off) if off1+int(optlen) > len(msg) { - fmt.Fprintf(os.Stderr, "dns: overflow unpacking OPT") + fmt.Fprintf(os.Stderr, "dns: overflow unpacking OPT") return len(msg), false } opt[0].Data = hex.EncodeToString(msg[off1 : off1+int(optlen)]) @@ -465,7 +465,7 @@ func unpackStructValue(val *reflect.StructValue, msg []byte, off int) (off1 int, window := int(msg[off]) blocks := int(msg[off+1]) if off+blocks > len(msg) { - fmt.Fprintf(os.Stderr, "dns: overflow unpacking NSEC") + fmt.Fprintf(os.Stderr, "dns: overflow unpacking NSEC") return len(msg), false } off += 2 @@ -517,7 +517,7 @@ func unpackStructValue(val *reflect.StructValue, msg []byte, off int) (off1 int, goto BadType case reflect.Uint8: if off+1 > len(msg) { - fmt.Fprintf(os.Stderr, "dns: overflow unpacking uint8") + fmt.Fprintf(os.Stderr, "dns: overflow unpacking uint8") return len(msg), false } i := uint8(msg[off]) @@ -526,14 +526,14 @@ func unpackStructValue(val *reflect.StructValue, msg []byte, off int) (off1 int, case reflect.Uint16: var i uint16 if off+2 > len(msg) { - fmt.Fprintf(os.Stderr, "dns: overflow unpacking uint16") + fmt.Fprintf(os.Stderr, "dns: overflow unpacking uint16") return len(msg), false } i, off = unpackUint16(msg, off) fv.Set(uint64(i)) case reflect.Uint32: if off+4 > len(msg) { - fmt.Fprintf(os.Stderr, "dns: overflow unpacking uint32") + fmt.Fprintf(os.Stderr, "dns: overflow unpacking uint32") return len(msg), false } i := uint32(msg[off])<<24 | uint32(msg[off+1])<<16 | uint32(msg[off+2])<<8 | uint32(msg[off+3]) @@ -543,7 +543,7 @@ func unpackStructValue(val *reflect.StructValue, msg []byte, off int) (off1 int, // This is *only* used in TSIG where the last 48 bits are occupied // So for now, assume a uint48 (6 bytes) if off+6 > len(msg) { - fmt.Fprintf(os.Stderr, "dns: overflow unpacking uint64") + fmt.Fprintf(os.Stderr, "dns: overflow unpacking uint64") return len(msg), false } i := uint64(msg[off])<<40 | uint64(msg[off+1])<<32 | uint64(msg[off+2])<<24 | uint64(msg[off+3])<<16 | @@ -564,10 +564,10 @@ func unpackStructValue(val *reflect.StructValue, msg []byte, off int) (off1 int, switch val.Type().Name() { case "RR_DS": consumed = 4 // KeyTag(2) + Algorithm(1) + DigestType(1) - case "RR_SSHFP": - consumed = 2 // Algorithm(1) + Type(1) - case "RR_NSEC3PARAM": - consumed = 5 // Hash(1) + Flags(1) + Iterations(2) + SaltLength(1) + case "RR_SSHFP": + consumed = 2 // Algorithm(1) + Type(1) + case "RR_NSEC3PARAM": + consumed = 5 // Hash(1) + Flags(1) + Iterations(2) + SaltLength(1) default: consumed = 0 // return len(msg), false? } @@ -595,12 +595,31 @@ func unpackStructValue(val *reflect.StructValue, msg []byte, off int) (off1 int, case "domain-name": s, off, ok = unpackDomainName(msg, off) if !ok { - fmt.Fprintf(os.Stderr, "dns: failure unpacking domain-name") + fmt.Fprintf(os.Stderr, "dns: failure unpacking domain-name") return len(msg), false } + case "fixed-size": + // We should already know how many bytes we can expect + // TODO(mg) pack variant + println(val.Type().Name()) + println(f.Name) // MAC Otherdata, then get back the + //f := val.Type().(*reflect.StructType).Field(i) + //FieldByName(MACSize), Othersize to get the stuff we need + var size int + // consumed += len(val.FieldByName("SignerName").(*reflect.StringValue).Get()) + 1 + switch val.Type().Name() { + case "RR_TSIG": + // tsig has MACSize + size = 16 // TODO(mg) other hashes + } + if off+size > len(msg) { + fmt.Fprintf(os.Stderr, "dns: failure unpacking fixed-size string") + return len(msg), false + } + s = string(msg[off : off+size]) case "": if off >= len(msg) || off+1+int(msg[off]) > len(msg) { - fmt.Fprintf(os.Stderr, "dns: failure unpacking string") + fmt.Fprintf(os.Stderr, "dns: failure unpacking string") return len(msg), false } n := int(msg[off]) @@ -702,11 +721,11 @@ func unpackRR(msg []byte, off int) (rr RR, off1 int, ok bool) { // Reverse a map func reverse(m map[uint16]string) map[string]uint16 { - n := make(map[string]uint16) - for u, s := range m { - n[s] = u - } - return n + n := make(map[string]uint16) + for u, s := range m { + n[s] = u + } + return n } diff --git a/tsig.go b/tsig.go index 3211558f..01c67fbb 100644 --- a/tsig.go +++ b/tsig.go @@ -22,11 +22,11 @@ type RR_TSIG struct { TimeSigned uint64 Fudge uint16 MACSize uint16 - MAC string + MAC string "fixed-size" OrigId uint16 Error uint16 OtherLen uint16 - OtherData string + OtherData string "fixed-size" } func (rr *RR_TSIG) Header() *RR_Header { @@ -35,6 +35,7 @@ func (rr *RR_TSIG) Header() *RR_Header { func (rr *RR_TSIG) String() string { // It has no official presentation format + println("mac len: ", rr.MACSize) return rr.Hdr.String() + " " + rr.Algorithm + " " + tsigTimeToDate(rr.TimeSigned) +