diff --git a/client.go b/client.go index fc4a613d..1302e4e0 100644 --- a/client.go +++ b/client.go @@ -284,9 +284,11 @@ func (co *Conn) ReadMsgHeader(hdr *Header) ([]byte, error) { p = p[:n] if hdr != nil { - if _, err = UnpackStruct(hdr, p, 0); err != nil { + dh, _, err := unpackMsgHdr(p, 0) + if err != nil { return nil, err } + *hdr = dh } return p, err } diff --git a/dns.go b/dns.go index c7d02850..a8131714 100644 --- a/dns.go +++ b/dns.go @@ -84,19 +84,22 @@ func (h *RR_Header) len() int { return l } -// ToRFC3597 converts a known RR to the unknown RR representation -// from RFC 3597. +// ToRFC3597 converts a known RR to the unknown RR representation from RFC 3597. func (rr *RFC3597) ToRFC3597(r RR) error { buf := make([]byte, r.len()*2) - off, err := PackStruct(r, buf, 0) + off, err := PackRR(r, buf, 0, nil, false) if err != nil { return err } buf = buf[:off] - rawSetRdlength(buf, 0, off) - _, err = UnpackStruct(rr, buf, 0) + if int(r.Header().Rdlength) > off { + return ErrBuf + } + + rfc3597, _, err := unpackRFC3597(*r.Header(), buf, off-int(r.Header().Rdlength)) if err != nil { return err } + *rr = *rfc3597.(*RFC3597) return nil } diff --git a/dns_test.go b/dns_test.go index 438e71ca..ad68533f 100644 --- a/dns_test.go +++ b/dns_test.go @@ -315,7 +315,13 @@ func TestToRFC3597(t *testing.T) { x := new(RFC3597) x.ToRFC3597(a) if x.String() != `miek.nl. 3600 CLASS1 TYPE1 \# 4 0a000101` { - t.Error("string mismatch") + t.Errorf("string mismatch, got: %s", x) + } + + b, _ := NewRR("miek.nl. IN MX 10 mx.miek.nl.") + x.ToRFC3597(b) + if x.String() != `miek.nl. 3600 CLASS1 TYPE15 \# 14 000a026d78046d69656b026e6c00` { + t.Errorf("string mismatch, got: %s", x) } } diff --git a/dnssec.go b/dnssec.go index 09c16f09..f5f3fbdd 100644 --- a/dnssec.go +++ b/dnssec.go @@ -104,9 +104,7 @@ const ( ZONE = 1 << 8 ) -// The RRSIG needs to be converted to wireformat with some of -// the rdata (the signature) missing. Use this struct to ease -// the conversion (and re-use the pack/unpack functions). +// The RRSIG needs to be converted to wireformat with some of the rdata (the signature) missing. type rrsigWireFmt struct { TypeCovered uint16 Algorithm uint8 @@ -155,7 +153,7 @@ func (k *DNSKEY) KeyTag() uint16 { keywire.Algorithm = k.Algorithm keywire.PublicKey = k.PublicKey wire := make([]byte, DefaultMsgSize) - n, err := PackStruct(keywire, wire, 0) + n, err := packKeyWire(keywire, wire) if err != nil { return 0 } @@ -193,7 +191,7 @@ func (k *DNSKEY) ToDS(h uint8) *DS { keywire.Algorithm = k.Algorithm keywire.PublicKey = k.PublicKey wire := make([]byte, DefaultMsgSize) - n, err := PackStruct(keywire, wire, 0) + n, err := packKeyWire(keywire, wire) if err != nil { return nil } @@ -290,7 +288,7 @@ func (rr *RRSIG) Sign(k crypto.Signer, rrset []RR) error { // Create the desired binary blob signdata := make([]byte, DefaultMsgSize) - n, err := PackStruct(sigwire, signdata, 0) + n, err := packSigWire(sigwire, signdata) if err != nil { return err } @@ -408,7 +406,7 @@ func (rr *RRSIG) Verify(k *DNSKEY, rrset []RR) error { sigwire.SignerName = strings.ToLower(rr.SignerName) // Create the desired binary blob signeddata := make([]byte, DefaultMsgSize) - n, err := PackStruct(sigwire, signeddata, 0) + n, err := packSigWire(sigwire, signeddata) if err != nil { return err } @@ -663,3 +661,61 @@ func rawSignatureData(rrset []RR, s *RRSIG) (buf []byte, err error) { } return buf, nil } + +func packSigWire(sw *rrsigWireFmt, msg []byte) (int, error) { + // copied from zmsg.go RRSIG packing + off, err := packUint16(sw.TypeCovered, msg, 0) + if err != nil { + return off, err + } + off, err = packUint8(sw.Algorithm, msg, off) + if err != nil { + return off, err + } + off, err = packUint8(sw.Labels, msg, off) + if err != nil { + return off, err + } + off, err = packUint32(sw.OrigTtl, msg, off) + if err != nil { + return off, err + } + off, err = packUint32(sw.Expiration, msg, off) + if err != nil { + return off, err + } + off, err = packUint32(sw.Inception, msg, off) + if err != nil { + return off, err + } + off, err = packUint16(sw.KeyTag, msg, off) + if err != nil { + return off, err + } + off, err = PackDomainName(sw.SignerName, msg, off, nil, false) + if err != nil { + return off, err + } + return off, nil +} + +func packKeyWire(dw *dnskeyWireFmt, msg []byte) (int, error) { + // copied from zmsg.go DNSKEY packing + off, err := packUint16(dw.Flags, msg, 0) + if err != nil { + return off, err + } + off, err = packUint8(dw.Protocol, msg, off) + if err != nil { + return off, err + } + off, err = packUint8(dw.Algorithm, msg, off) + if err != nil { + return off, err + } + off, err = packStringBase64(dw.PublicKey, msg, off) + if err != nil { + return off, err + } + return off, nil +} diff --git a/msg.go b/msg.go index e4d390ec..5db1000b 100644 --- a/msg.go +++ b/msg.go @@ -13,11 +13,8 @@ package dns import ( crand "crypto/rand" "encoding/binary" - "encoding/hex" "math/big" "math/rand" - "net" - "reflect" "strconv" ) @@ -549,591 +546,6 @@ func unpackTxtString(msg []byte, offset int) (string, int, error) { return string(s), offset, nil } -// Pack a reflect.StructValue into msg. Struct members can only be uint8, uint16, uint32, string, -// slices and other (often anonymous) structs. -func packStructValue(val reflect.Value, msg []byte, off int, compression map[string]int, compress bool) (off1 int, err error) { - var txtTmp []byte - lenmsg := len(msg) - numfield := val.NumField() - for i := 0; i < numfield; i++ { - typefield := val.Type().Field(i) - if typefield.Tag == `dns:"-"` { - continue - } - switch fv := val.Field(i); fv.Kind() { - default: - return lenmsg, &Error{err: "bad kind packing"} - case reflect.Interface: - // PrivateRR is the only RR implementation that has interface field. - // therefore it's expected that this interface would be PrivateRdata - switch data := fv.Interface().(type) { - case PrivateRdata: - n, err := data.Pack(msg[off:]) - if err != nil { - return lenmsg, err - } - off += n - default: - return lenmsg, &Error{err: "bad kind interface packing"} - } - case reflect.Slice: - switch typefield.Tag { - default: - return lenmsg, &Error{"bad tag packing slice: " + typefield.Tag.Get("dns")} - case `dns:"domain-name"`: - for j := 0; j < val.Field(i).Len(); j++ { - element := val.Field(i).Index(j).String() - off, err = PackDomainName(element, msg, off, compression, false && compress) - if err != nil { - return lenmsg, err - } - } - case `dns:"a"`: - if val.Type().String() == "dns.IPSECKEY" { - // Field(2) is GatewayType, must be 1 - if val.Field(2).Uint() != 1 { - continue - } - } - // It must be a slice of 4, even if it is 16, we encode - // only the first 4 - if off+net.IPv4len > lenmsg { - return lenmsg, &Error{err: "overflow packing a"} - } - switch fv.Len() { - case net.IPv6len: - msg[off] = byte(fv.Index(12).Uint()) - msg[off+1] = byte(fv.Index(13).Uint()) - msg[off+2] = byte(fv.Index(14).Uint()) - msg[off+3] = byte(fv.Index(15).Uint()) - off += net.IPv4len - case net.IPv4len: - msg[off] = byte(fv.Index(0).Uint()) - msg[off+1] = byte(fv.Index(1).Uint()) - msg[off+2] = byte(fv.Index(2).Uint()) - msg[off+3] = byte(fv.Index(3).Uint()) - off += net.IPv4len - case 0: - // Allowed, for dynamic updates - default: - return lenmsg, &Error{err: "overflow packing a"} - } - case `dns:"aaaa"`: - if val.Type().String() == "dns.IPSECKEY" { - // Field(2) is GatewayType, must be 2 - if val.Field(2).Uint() != 2 { - continue - } - } - if fv.Len() == 0 { - break - } - if fv.Len() > net.IPv6len || off+fv.Len() > lenmsg { - return lenmsg, &Error{err: "overflow packing aaaa"} - } - for j := 0; j < net.IPv6len; j++ { - msg[off] = byte(fv.Index(j).Uint()) - off++ - } - case `dns:"nsec"`: // NSEC/NSEC3 - // This is the uint16 type bitmap - if val.Field(i).Len() == 0 { - // Do absolutely nothing - break - } - var lastwindow, lastlength uint16 - for j := 0; j < val.Field(i).Len(); j++ { - t := uint16(fv.Index(j).Uint()) - window := t / 256 - length := (t-window*256)/8 + 1 - if window > lastwindow && lastlength != 0 { - // New window, jump to the new offset - off += int(lastlength) + 2 - lastlength = 0 - } - if window < lastwindow || length < lastlength { - return len(msg), &Error{err: "nsec bits out of order"} - } - if off+2+int(length) > len(msg) { - return len(msg), &Error{err: "overflow packing nsec"} - } - // Setting the window # - msg[off] = byte(window) - // Setting the octets length - msg[off+1] = byte(length) - // Setting the bit value for the type in the right octet - msg[off+1+int(length)] |= byte(1 << (7 - (t % 8))) - lastwindow, lastlength = window, length - } - off += int(lastlength) + 2 - } - case reflect.Struct: - off, err = packStructValue(fv, msg, off, compression, compress) - if err != nil { - return lenmsg, err - } - case reflect.Uint8: - if off+1 > lenmsg { - return lenmsg, &Error{err: "overflow packing uint8"} - } - msg[off] = byte(fv.Uint()) - off++ - case reflect.Uint16: - if off+2 > lenmsg { - return lenmsg, &Error{err: "overflow packing uint16"} - } - i := fv.Uint() - msg[off] = byte(i >> 8) - msg[off+1] = byte(i) - off += 2 - case reflect.Uint32: - if off+4 > lenmsg { - return lenmsg, &Error{err: "overflow packing uint32"} - } - i := fv.Uint() - msg[off] = byte(i >> 24) - msg[off+1] = byte(i >> 16) - msg[off+2] = byte(i >> 8) - msg[off+3] = byte(i) - off += 4 - case reflect.Uint64: - switch typefield.Tag { - default: - if off+8 > lenmsg { - return lenmsg, &Error{err: "overflow packing uint64"} - } - i := fv.Uint() - msg[off] = byte(i >> 56) - msg[off+1] = byte(i >> 48) - msg[off+2] = byte(i >> 40) - msg[off+3] = byte(i >> 32) - msg[off+4] = byte(i >> 24) - msg[off+5] = byte(i >> 16) - msg[off+6] = byte(i >> 8) - msg[off+7] = byte(i) - off += 8 - case `dns:"uint48"`: - // Used in TSIG, where it stops at 48 bits, so we discard the upper 16 - if off+6 > lenmsg { - return lenmsg, &Error{err: "overflow packing uint64 as uint48"} - } - i := fv.Uint() - msg[off] = byte(i >> 40) - msg[off+1] = byte(i >> 32) - msg[off+2] = byte(i >> 24) - msg[off+3] = byte(i >> 16) - msg[off+4] = byte(i >> 8) - msg[off+5] = byte(i) - off += 6 - } - case reflect.String: - // There are multiple string encodings. - // The tag distinguishes ordinary strings from domain names. - s := fv.String() - switch typefield.Tag { - default: - return lenmsg, &Error{"bad tag packing string: " + typefield.Tag.Get("dns")} - case `dns:"base64"`: - b64, err := fromBase64([]byte(s)) - if err != nil { - return lenmsg, err - } - if off+len(b64) > lenmsg { - return lenmsg, &Error{err: "overflow packing base64"} - } - copy(msg[off:off+len(b64)], b64) - off += len(b64) - case `dns:"domain-name"`: - if val.Type().String() == "dns.IPSECKEY" { - // Field(2) is GatewayType, 1 and 2 or used for addresses - x := val.Field(2).Uint() - if x == 1 || x == 2 { - continue - } - } - if off, err = PackDomainName(s, msg, off, compression, false && compress); err != nil { - return lenmsg, err - } - case `dns:"cdomain-name"`: - if off, err = PackDomainName(s, msg, off, compression, true && compress); err != nil { - return lenmsg, err - } - case `dns:"size-base32"`: - // This is purely for NSEC3 atm, the previous byte must - // holds the length of the encoded string. As NSEC3 - // is only defined to SHA1, the hashlength is 20 (160 bits) - msg[off-1] = 20 - fallthrough - case `dns:"base32"`: - b32, err := fromBase32([]byte(s)) - if err != nil { - return lenmsg, err - } - if off+len(b32) > lenmsg { - return lenmsg, &Error{err: "overflow packing base32"} - } - copy(msg[off:off+len(b32)], b32) - off += len(b32) - case `dns:"size-hex"`: - fallthrough - case `dns:"hex"`: - // There is no length encoded here - h, err := hex.DecodeString(s) - if err != nil { - return lenmsg, err - } - if off+hex.DecodedLen(len(s)) > lenmsg { - return lenmsg, &Error{err: "overflow packing hex"} - } - copy(msg[off:off+hex.DecodedLen(len(s))], h) - off += hex.DecodedLen(len(s)) - case `dns:"size"`: - // TODO(miek): WTF? size? - // the size is already encoded in the RR, we can safely use the - // length of string. String is RAW (not encoded in hex, nor base64) - copy(msg[off:off+len(s)], s) - off += len(s) - case `dns:"octet"`: - bytesTmp := make([]byte, 256) - off, err = packOctetString(fv.String(), msg, off, bytesTmp) - if err != nil { - return lenmsg, err - } - case `dns:"txt"`: - fallthrough - case "": - if txtTmp == nil { - txtTmp = make([]byte, 256*4+1) - } - off, err = packTxtString(fv.String(), msg, off, txtTmp) - if err != nil { - return lenmsg, err - } - } - } - } - return off, nil -} - -func structValue(any interface{}) reflect.Value { - return reflect.ValueOf(any).Elem() -} - -// PackStruct packs any structure to wire format. -func PackStruct(any interface{}, msg []byte, off int) (off1 int, err error) { - off, err = packStructValue(structValue(any), msg, off, nil, false) - return off, err -} - -func packStructCompress(any interface{}, msg []byte, off int, compression map[string]int, compress bool) (off1 int, err error) { - off, err = packStructValue(structValue(any), msg, off, compression, compress) - return off, err -} - -// Unpack a reflect.StructValue from msg. -// Same restrictions as packStructValue. -func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, err error) { - lenmsg := len(msg) - for i := 0; i < val.NumField(); i++ { - if off > lenmsg { - return lenmsg, &Error{"bad offset unpacking"} - } - switch fv := val.Field(i); fv.Kind() { - default: - return lenmsg, &Error{err: "bad kind unpacking"} - case reflect.Interface: - // PrivateRR is the only RR implementation that has interface field. - // therefore it's expected that this interface would be PrivateRdata - switch data := fv.Interface().(type) { - case PrivateRdata: - n, err := data.Unpack(msg[off:]) - if err != nil { - return lenmsg, err - } - off += n - default: - return lenmsg, &Error{err: "bad kind interface unpacking"} - } - case reflect.Slice: - switch val.Type().Field(i).Tag { - default: - return lenmsg, &Error{"bad tag unpacking slice: " + val.Type().Field(i).Tag.Get("dns")} - case `dns:"domain-name"`: - // HIP record slice of name (or none) - var servers []string - var s string - for off < lenmsg { - s, off, err = UnpackDomainName(msg, off) - if err != nil { - return lenmsg, err - } - servers = append(servers, s) - } - fv.Set(reflect.ValueOf(servers)) - case `dns:"a"`: - if val.Type().String() == "dns.IPSECKEY" { - // Field(2) is GatewayType, must be 1 - if val.Field(2).Uint() != 1 { - continue - } - } - if off == lenmsg { - break // dyn. update - } - if off+net.IPv4len > lenmsg { - return lenmsg, &Error{err: "overflow unpacking a"} - } - fv.Set(reflect.ValueOf(net.IPv4(msg[off], msg[off+1], msg[off+2], msg[off+3]))) - off += net.IPv4len - case `dns:"aaaa"`: - if val.Type().String() == "dns.IPSECKEY" { - // Field(2) is GatewayType, must be 2 - if val.Field(2).Uint() != 2 { - continue - } - } - if off == lenmsg { - break - } - if off+net.IPv6len > lenmsg { - return lenmsg, &Error{err: "overflow unpacking aaaa"} - } - fv.Set(reflect.ValueOf(net.IP{msg[off], msg[off+1], msg[off+2], msg[off+3], msg[off+4], - msg[off+5], msg[off+6], msg[off+7], msg[off+8], msg[off+9], msg[off+10], - msg[off+11], msg[off+12], msg[off+13], msg[off+14], msg[off+15]})) - off += net.IPv6len - case `dns:"nsec"`: // NSEC/NSEC3 - if off == len(msg) { - break - } - // Rest of the record is the type bitmap - var nsec []uint16 - length := 0 - window := 0 - lastwindow := -1 - for off < len(msg) { - if off+2 > len(msg) { - return len(msg), &Error{err: "overflow unpacking nsecx"} - } - window = int(msg[off]) - length = int(msg[off+1]) - off += 2 - if window <= lastwindow { - // RFC 4034: Blocks are present in the NSEC RR RDATA in - // increasing numerical order. - return len(msg), &Error{err: "out of order NSEC block"} - } - if length == 0 { - // RFC 4034: Blocks with no types present MUST NOT be included. - return len(msg), &Error{err: "empty NSEC block"} - } - if length > 32 { - return len(msg), &Error{err: "NSEC block too long"} - } - if off+length > len(msg) { - return len(msg), &Error{err: "overflowing NSEC block"} - } - - // Walk the bytes in the window and extract the type bits - for j := 0; j < length; j++ { - b := msg[off+j] - // Check the bits one by one, and set the type - if b&0x80 == 0x80 { - nsec = append(nsec, uint16(window*256+j*8+0)) - } - if b&0x40 == 0x40 { - nsec = append(nsec, uint16(window*256+j*8+1)) - } - if b&0x20 == 0x20 { - nsec = append(nsec, uint16(window*256+j*8+2)) - } - if b&0x10 == 0x10 { - nsec = append(nsec, uint16(window*256+j*8+3)) - } - if b&0x8 == 0x8 { - nsec = append(nsec, uint16(window*256+j*8+4)) - } - if b&0x4 == 0x4 { - nsec = append(nsec, uint16(window*256+j*8+5)) - } - if b&0x2 == 0x2 { - nsec = append(nsec, uint16(window*256+j*8+6)) - } - if b&0x1 == 0x1 { - nsec = append(nsec, uint16(window*256+j*8+7)) - } - } - off += length - lastwindow = window - } - fv.Set(reflect.ValueOf(nsec)) - } - case reflect.Struct: - off, err = unpackStructValue(fv, msg, off) - if err != nil { - return lenmsg, err - } - if val.Type().Field(i).Name == "Hdr" { - lenrd := off + int(val.FieldByName("Hdr").FieldByName("Rdlength").Uint()) - if lenrd > lenmsg { - return lenmsg, &Error{err: "overflowing header size"} - } - msg = msg[:lenrd] - lenmsg = len(msg) - } - case reflect.Uint8: - if off == lenmsg { - break - } - if off+1 > lenmsg { - return lenmsg, &Error{err: "overflow unpacking uint8"} - } - fv.SetUint(uint64(uint8(msg[off]))) - off++ - case reflect.Uint16: - if off == lenmsg { - break - } - var i uint16 - if off+2 > lenmsg { - return lenmsg, &Error{err: "overflow unpacking uint16"} - } - i = binary.BigEndian.Uint16(msg[off:]) - off += 2 - fv.SetUint(uint64(i)) - case reflect.Uint32: - if off == lenmsg { - break - } - if off+4 > lenmsg { - return lenmsg, &Error{err: "overflow unpacking uint32"} - } - fv.SetUint(uint64(binary.BigEndian.Uint32(msg[off:]))) - off += 4 - case reflect.Uint64: - if off == lenmsg { - break - } - switch val.Type().Field(i).Tag { - default: - if off+8 > lenmsg { - return lenmsg, &Error{err: "overflow unpacking uint64"} - } - fv.SetUint(binary.BigEndian.Uint64(msg[off:])) - off += 8 - case `dns:"uint48"`: - // Used in TSIG where the last 48 bits are occupied, so for now, assume a uint48 (6 bytes) - if off+6 > lenmsg { - return lenmsg, &Error{err: "overflow unpacking uint64 as uint48"} - } - fv.SetUint(uint64(uint64(msg[off])<<40 | uint64(msg[off+1])<<32 | uint64(msg[off+2])<<24 | uint64(msg[off+3])<<16 | - uint64(msg[off+4])<<8 | uint64(msg[off+5]))) - off += 6 - } - case reflect.String: - var s string - if off == lenmsg { - break - } - switch val.Type().Field(i).Tag { - default: - return lenmsg, &Error{"bad tag unpacking string: " + val.Type().Field(i).Tag.Get("dns")} - case `dns:"octet"`: - s = string(msg[off:]) - off = lenmsg - case `dns:"hex"`: - hexend := lenmsg - if val.FieldByName("Hdr").FieldByName("Rrtype").Uint() == uint64(TypeHIP) { - hexend = off + int(val.FieldByName("HitLength").Uint()) - } - if hexend > lenmsg { - return lenmsg, &Error{err: "overflow unpacking HIP hex"} - } - s = hex.EncodeToString(msg[off:hexend]) - off = hexend - case `dns:"base64"`: - // Rest of the RR is base64 encoded value - b64end := lenmsg - if val.FieldByName("Hdr").FieldByName("Rrtype").Uint() == uint64(TypeHIP) { - b64end = off + int(val.FieldByName("PublicKeyLength").Uint()) - } - if b64end > lenmsg { - return lenmsg, &Error{err: "overflow unpacking HIP base64"} - } - s = toBase64(msg[off:b64end]) - off = b64end - case `dns:"cdomain-name"`: - fallthrough - case `dns:"domain-name"`: - if val.Type().String() == "dns.IPSECKEY" { - // Field(2) is GatewayType, 1 and 2 or used for addresses - x := val.Field(2).Uint() - if x == 1 || x == 2 { - continue - } - } - if off == lenmsg && int(val.FieldByName("Hdr").FieldByName("Rdlength").Uint()) == 0 { - // zero rdata is ok for dyn updates, but only if rdlength is 0 - break - } - s, off, err = UnpackDomainName(msg, off) - if err != nil { - return lenmsg, err - } - case `dns:"size-base32"`: - var size int - switch val.Type().Name() { - case "NSEC3": - switch val.Type().Field(i).Name { - case "NextDomain": - name := val.FieldByName("HashLength") - size = int(name.Uint()) - } - } - if off+size > lenmsg { - return lenmsg, &Error{err: "overflow unpacking base32"} - } - s = toBase32(msg[off : off+size]) - off += size - case `dns:"size-hex"`: - // a "size" string, but it must be encoded in hex in the string - var size int - switch val.Type().Name() { - case "NSEC3": - switch val.Type().Field(i).Name { - case "Salt": - name := val.FieldByName("SaltLength") - size = int(name.Uint()) - case "NextDomain": - name := val.FieldByName("HashLength") - size = int(name.Uint()) - } - case "TSIG": - switch val.Type().Field(i).Name { - case "MAC": - name := val.FieldByName("MACSize") - size = int(name.Uint()) - case "OtherData": - name := val.FieldByName("OtherLen") - size = int(name.Uint()) - } - } - if off+size > lenmsg { - return lenmsg, &Error{err: "overflow unpacking hex"} - } - s = hex.EncodeToString(msg[off : off+size]) - off += size - case `dns:"txt"`: - fallthrough - case "": - s, off, err = unpackTxtString(msg, off) - } - fv.SetString(s) - } - } - return off, nil -} - // Helpers for dealing with escaped bytes func isDigit(b byte) bool { return b >= '0' && b <= '9' } @@ -1141,12 +553,6 @@ func dddToByte(s []byte) byte { return byte((s[0]-'0')*100 + (s[1]-'0')*10 + (s[2] - '0')) } -// UnpackStruct unpacks a binary message from offset off to the interface -// value given. -func UnpackStruct(any interface{}, msg []byte, off int) (int, error) { - return unpackStructValue(structValue(any), msg, off) -} - // Helper function for packing and unpacking func intToBytes(i *big.Int, length int) []byte { buf := i.Bytes() diff --git a/msg_generate.go b/msg_generate.go index 3179ad94..17eb60c5 100644 --- a/msg_generate.go +++ b/msg_generate.go @@ -61,7 +61,7 @@ func main() { if st, _ := getTypeStruct(o.Type(), scope); st == nil { continue } - if name == "PrivateRR" || name == "WKS" { + if name == "PrivateRR" { continue } diff --git a/nsecx.go b/nsecx.go index d2392c6e..a0388a21 100644 --- a/nsecx.go +++ b/nsecx.go @@ -17,7 +17,7 @@ func HashName(label string, ha uint8, iter uint16, salt string) string { saltwire := new(saltWireFmt) saltwire.Salt = salt wire := make([]byte, DefaultMsgSize) - n, err := PackStruct(saltwire, wire, 0) + n, err := packSaltWire(saltwire, wire) if err != nil { return "" } @@ -110,3 +110,11 @@ func (rr *NSEC3) Match(name string) bool { } return false } + +func packSaltWire(sw *saltWireFmt, msg []byte) (int, error) { + off, err := packStringHex(sw.Salt, msg, 0) + if err != nil { + return off, err + } + return off, nil +} diff --git a/tsig.go b/tsig.go index 4698e449..8a49bcf5 100644 --- a/tsig.go +++ b/tsig.go @@ -72,8 +72,7 @@ type tsigWireFmt struct { OtherData string `dns:"size-hex:OtherLen"` } -// If we have the MAC use this type to convert it to wiredata. -// Section 3.4.3. Request MAC +// If we have the MAC use this type to convert it to wiredata. Section 3.4.3. Request MAC type macWireFmt struct { MACSize uint16 MAC string `dns:"size-hex:MACSize"` @@ -213,7 +212,7 @@ func tsigBuffer(msgbuf []byte, rr *TSIG, requestMAC string, timersOnly bool) []b m.MACSize = uint16(len(requestMAC) / 2) m.MAC = requestMAC buf = make([]byte, len(requestMAC)) // long enough - n, _ := PackStruct(m, buf, 0) + n, _ := packMacWire(m, buf) buf = buf[:n] } @@ -222,7 +221,7 @@ func tsigBuffer(msgbuf []byte, rr *TSIG, requestMAC string, timersOnly bool) []b tsig := new(timerWireFmt) tsig.TimeSigned = rr.TimeSigned tsig.Fudge = rr.Fudge - n, _ := PackStruct(tsig, tsigvar, 0) + n, _ := packTimerWire(tsig, tsigvar) tsigvar = tsigvar[:n] } else { tsig := new(tsigWireFmt) @@ -235,7 +234,7 @@ func tsigBuffer(msgbuf []byte, rr *TSIG, requestMAC string, timersOnly bool) []b tsig.Error = rr.Error tsig.OtherLen = rr.OtherLen tsig.OtherData = rr.OtherData - n, _ := PackStruct(tsig, tsigvar, 0) + n, _ := packTsigWire(tsig, tsigvar) tsigvar = tsigvar[:n] } @@ -250,57 +249,51 @@ func tsigBuffer(msgbuf []byte, rr *TSIG, requestMAC string, timersOnly bool) []b // Strip the TSIG from the raw message. func stripTsig(msg []byte) ([]byte, *TSIG, error) { - // Copied from msg.go's Unpack() - // Header. - var dh Header - var err error - dns := new(Msg) - rr := new(TSIG) - off := 0 - tsigoff := 0 - if off, err = UnpackStruct(&dh, msg, off); err != nil { + // Copied from msg.go's Unpack() Header, but modified. + var ( + dh Header + err error + ) + off, tsigoff := 0, 0 + + if dh, off, err = unpackMsgHdr(msg, off); err != nil { return nil, nil, err } if dh.Arcount == 0 { return nil, nil, ErrNoSig } + // Rcode, see msg.go Unpack() if int(dh.Bits&0xF) == RcodeNotAuth { return nil, nil, ErrAuth } - // Arrays. - dns.Question = make([]Question, dh.Qdcount) - dns.Answer = make([]RR, dh.Ancount) - dns.Ns = make([]RR, dh.Nscount) - dns.Extra = make([]RR, dh.Arcount) + for i := 0; i < int(dh.Qdcount); i++ { + _, off, err = unpackQuestion(msg, off) + if err != nil { + return nil, nil, err + } + } - for i := 0; i < len(dns.Question); i++ { - off, err = UnpackStruct(&dns.Question[i], msg, off) - if err != nil { - return nil, nil, err - } + _, off, err = unpackRRslice(int(dh.Ancount), msg, off) + if err != nil { + return nil, nil, err } - for i := 0; i < len(dns.Answer); i++ { - dns.Answer[i], off, err = UnpackRR(msg, off) - if err != nil { - return nil, nil, err - } + _, off, err = unpackRRslice(int(dh.Nscount), msg, off) + if err != nil { + return nil, nil, err } - for i := 0; i < len(dns.Ns); i++ { - dns.Ns[i], off, err = UnpackRR(msg, off) - if err != nil { - return nil, nil, err - } - } - for i := 0; i < len(dns.Extra); i++ { + + rr := new(TSIG) + var extra RR + for i := 0; i < int(dh.Arcount); i++ { tsigoff = off - dns.Extra[i], off, err = UnpackRR(msg, off) + extra, off, err = UnpackRR(msg, off) if err != nil { return nil, nil, err } - if dns.Extra[i].Header().Rrtype == TypeTSIG { - rr = dns.Extra[i].(*TSIG) + if extra.Header().Rrtype == TypeTSIG { + rr = extra.(*TSIG) // Adjust Arcount. arcount := binary.BigEndian.Uint16(msg[10:]) binary.BigEndian.PutUint16(msg[10:], arcount-1) @@ -319,3 +312,71 @@ func tsigTimeToString(t uint64) string { ti := time.Unix(int64(t), 0).UTC() return ti.Format("20060102150405") } + +func packTsigWire(tw *tsigWireFmt, msg []byte) (int, error) { + // copied from zmsg.go TSIG packing + // RR_Header + off, err := PackDomainName(tw.Name, msg, 0, nil, false) + if err != nil { + return off, err + } + off, err = packUint16(tw.Class, msg, off) + if err != nil { + return off, err + } + off, err = packUint32(tw.Ttl, msg, off) + if err != nil { + return off, err + } + + off, err = PackDomainName(tw.Algorithm, msg, off, nil, false) + if err != nil { + return off, err + } + off, err = packUint48(tw.TimeSigned, msg, off) + if err != nil { + return off, err + } + off, err = packUint16(tw.Fudge, msg, off) + if err != nil { + return off, err + } + + off, err = packUint16(tw.Error, msg, off) + if err != nil { + return off, err + } + off, err = packUint16(tw.OtherLen, msg, off) + if err != nil { + return off, err + } + off, err = packStringHex(tw.OtherData, msg, off) + if err != nil { + return off, err + } + return off, nil +} + +func packMacWire(mw *macWireFmt, msg []byte) (int, error) { + off, err := packUint16(mw.MACSize, msg, 0) + if err != nil { + return off, err + } + off, err = packStringHex(mw.MAC, msg, off) + if err != nil { + return off, err + } + return off, nil +} + +func packTimerWire(tw *timerWireFmt, msg []byte) (int, error) { + off, err := packUint48(tw.TimeSigned, msg, 0) + if err != nil { + return off, err + } + off, err = packUint16(tw.Fudge, msg, off) + if err != nil { + return off, err + } + return off, nil +}