diff --git a/dnssec.go b/dnssec.go index 8539aae6..e95f0377 100644 --- a/dnssec.go +++ b/dnssec.go @@ -65,6 +65,9 @@ var AlgorithmToString = map[uint8]string{ } // AlgorithmToHash is a map of algorithm crypto hash IDs to crypto.Hash's. +// For newer algorithm that do their own hashing (i.e. ED25519) the returned value +// is 0, implying no (external) hashing should occur. The non-exported identityHash is then +// used. var AlgorithmToHash = map[uint8]crypto.Hash{ RSAMD5: crypto.MD5, // Deprecated in RFC 6725 DSA: crypto.SHA1, @@ -74,7 +77,7 @@ var AlgorithmToHash = map[uint8]crypto.Hash{ ECDSAP256SHA256: crypto.SHA256, ECDSAP384SHA384: crypto.SHA384, RSASHA512: crypto.SHA512, - ED25519: crypto.Hash(0), + ED25519: 0, } // DNSSEC hashing algorithm codes. @@ -296,35 +299,20 @@ func (rr *RRSIG) Sign(k crypto.Signer, rrset []RR) error { return err } - hash, ok := AlgorithmToHash[rr.Algorithm] - if !ok { - return ErrAlg + h, cryptohash, err := hashFromAlgorithm(rr.Algorithm) + if err != nil { + return err } switch rr.Algorithm { - case ED25519: - // ed25519 signs the raw message and performs hashing internally. - // All other supported signature schemes operate over the pre-hashed - // message, and thus ed25519 must be handled separately here. - // - // The raw message is passed directly into sign and crypto.Hash(0) is - // used to signal to the crypto.Signer that the data has not been hashed. - signature, err := sign(k, append(signdata, wire...), crypto.Hash(0), rr.Algorithm) - if err != nil { - return err - } - - rr.Signature = toBase64(signature) - return nil case RSAMD5, DSA, DSANSEC3SHA1: // See RFC 6944. return ErrAlg default: - h := hash.New() h.Write(signdata) h.Write(wire) - signature, err := sign(k, h.Sum(nil), hash, rr.Algorithm) + signature, err := sign(k, h.Sum(nil), cryptohash, rr.Algorithm) if err != nil { return err } @@ -341,7 +329,7 @@ func sign(k crypto.Signer, hashed []byte, hash crypto.Hash, alg uint8) ([]byte, } switch alg { - case RSASHA1, RSASHA1NSEC3SHA1, RSASHA256, RSASHA512: + case RSASHA1, RSASHA1NSEC3SHA1, RSASHA256, RSASHA512, ED25519: return signature, nil case ECDSAP256SHA256, ECDSAP384SHA384: ecdsaSignature := &struct { @@ -362,8 +350,6 @@ func sign(k crypto.Signer, hashed []byte, hash crypto.Hash, alg uint8) ([]byte, signature := intToBytes(ecdsaSignature.R, intlen) signature = append(signature, intToBytes(ecdsaSignature.S, intlen)...) return signature, nil - case ED25519: - return signature, nil default: return nil, ErrAlg } @@ -437,9 +423,9 @@ func (rr *RRSIG) Verify(k *DNSKEY, rrset []RR) error { // remove the domain name and assume its ours? } - hash, ok := AlgorithmToHash[rr.Algorithm] - if !ok { - return ErrAlg + h, cryptohash, err := hashFromAlgorithm(rr.Algorithm) + if err != nil { + return err } switch rr.Algorithm { @@ -450,10 +436,9 @@ func (rr *RRSIG) Verify(k *DNSKEY, rrset []RR) error { return ErrKey } - h := hash.New() h.Write(signeddata) h.Write(wire) - return rsa.VerifyPKCS1v15(pubkey, hash, h.Sum(nil), sigbuf) + return rsa.VerifyPKCS1v15(pubkey, cryptohash, h.Sum(nil), sigbuf) case ECDSAP256SHA256, ECDSAP384SHA384: pubkey := k.publicKeyECDSA() @@ -465,7 +450,6 @@ func (rr *RRSIG) Verify(k *DNSKEY, rrset []RR) error { r := new(big.Int).SetBytes(sigbuf[:len(sigbuf)/2]) s := new(big.Int).SetBytes(sigbuf[len(sigbuf)/2:]) - h := hash.New() h.Write(signeddata) h.Write(wire) if ecdsa.Verify(pubkey, h.Sum(nil), r, s) { diff --git a/hash.go b/hash.go new file mode 100644 index 00000000..7d4183e0 --- /dev/null +++ b/hash.go @@ -0,0 +1,31 @@ +package dns + +import ( + "bytes" + "crypto" + "hash" +) + +// identityHash will not hash, it only buffers the data written into it and returns it as-is. +type identityHash struct { + b *bytes.Buffer +} + +// Implement the hash.Hash interface. + +func (i identityHash) Write(b []byte) (int, error) { return i.b.Write(b) } +func (i identityHash) Size() int { return i.b.Len() } +func (i identityHash) BlockSize() int { return 1024 } +func (i identityHash) Reset() { i.b.Reset() } +func (i identityHash) Sum(b []byte) []byte { return append(b, i.b.Bytes()...) } + +func hashFromAlgorithm(alg uint8) (hash.Hash, crypto.Hash, error) { + hashnumber, ok := AlgorithmToHash[alg] + if !ok { + return nil, 0, ErrAlg + } + if hashnumber == 0 { + return identityHash{b: &bytes.Buffer{}}, hashnumber, nil + } + return hashnumber.New(), hashnumber, nil +} diff --git a/sig0.go b/sig0.go index e781c9bb..2c4b1035 100644 --- a/sig0.go +++ b/sig0.go @@ -3,6 +3,7 @@ package dns import ( "crypto" "crypto/ecdsa" + "crypto/ed25519" "crypto/rsa" "encoding/binary" "math/big" @@ -38,18 +39,17 @@ func (rr *SIG) Sign(k crypto.Signer, m *Msg) ([]byte, error) { } buf = buf[:off:cap(buf)] - hash, ok := AlgorithmToHash[rr.Algorithm] - if !ok { - return nil, ErrAlg + h, cryptohash, err := hashFromAlgorithm(rr.Algorithm) + if err != nil { + return nil, err } - hasher := hash.New() // Write SIG rdata - hasher.Write(buf[len(mbuf)+1+2+2+4+2:]) + h.Write(buf[len(mbuf)+1+2+2+4+2:]) // Write message - hasher.Write(buf[:len(mbuf)]) + h.Write(buf[:len(mbuf)]) - signature, err := sign(k, hasher.Sum(nil), hash, rr.Algorithm) + signature, err := sign(k, h.Sum(nil), cryptohash, rr.Algorithm) if err != nil { return nil, err } @@ -82,20 +82,10 @@ func (rr *SIG) Verify(k *KEY, buf []byte) error { return ErrKey } - var hash crypto.Hash - switch rr.Algorithm { - case RSASHA1: - hash = crypto.SHA1 - case RSASHA256, ECDSAP256SHA256: - hash = crypto.SHA256 - case ECDSAP384SHA384: - hash = crypto.SHA384 - case RSASHA512: - hash = crypto.SHA512 - default: - return ErrAlg + h, cryptohash, err := hashFromAlgorithm(rr.Algorithm) + if err != nil { + return err } - hasher := hash.New() buflen := len(buf) qdc := binary.BigEndian.Uint16(buf[4:]) @@ -103,7 +93,6 @@ func (rr *SIG) Verify(k *KEY, buf []byte) error { auc := binary.BigEndian.Uint16(buf[8:]) adc := binary.BigEndian.Uint16(buf[10:]) offset := headerSize - var err error for i := uint16(0); i < qdc && offset < buflen; i++ { _, offset, err = UnpackDomainName(buf, offset) if err != nil { @@ -166,21 +155,21 @@ func (rr *SIG) Verify(k *KEY, buf []byte) error { return &Error{err: "signer name doesn't match key name"} } sigend := offset - hasher.Write(buf[sigstart:sigend]) - hasher.Write(buf[:10]) - hasher.Write([]byte{ + h.Write(buf[sigstart:sigend]) + h.Write(buf[:10]) + h.Write([]byte{ byte((adc - 1) << 8), byte(adc - 1), }) - hasher.Write(buf[12:bodyend]) + h.Write(buf[12:bodyend]) - hashed := hasher.Sum(nil) + hashed := h.Sum(nil) sig := buf[sigend:] switch k.Algorithm { case RSASHA1, RSASHA256, RSASHA512: pk := k.publicKeyRSA() if pk != nil { - return rsa.VerifyPKCS1v15(pk, hash, hashed, sig) + return rsa.VerifyPKCS1v15(pk, cryptohash, hashed, sig) } case ECDSAP256SHA256, ECDSAP384SHA384: pk := k.publicKeyECDSA() @@ -192,6 +181,14 @@ func (rr *SIG) Verify(k *KEY, buf []byte) error { } return ErrSig } + case ED25519: + pk := k.publicKeyED25519() + if pk != nil { + if ed25519.Verify(pk, hashed, sig) { + return nil + } + return ErrSig + } } return ErrKeyAlg } diff --git a/sig0_test.go b/sig0_test.go index 1abdc4af..5b991d22 100644 --- a/sig0_test.go +++ b/sig0_test.go @@ -12,7 +12,7 @@ func TestSIG0(t *testing.T) { } m := new(Msg) m.SetQuestion("example.org.", TypeSOA) - for _, alg := range []uint8{ECDSAP256SHA256, ECDSAP384SHA384, RSASHA1, RSASHA256, RSASHA512} { + for _, alg := range []uint8{ECDSAP256SHA256, ECDSAP384SHA384, RSASHA1, RSASHA256, RSASHA512, ED25519} { algstr := AlgorithmToString[alg] keyrr := new(KEY) keyrr.Hdr.Name = algstr + "." @@ -21,7 +21,7 @@ func TestSIG0(t *testing.T) { keyrr.Algorithm = alg keysize := 512 switch alg { - case ECDSAP256SHA256: + case ECDSAP256SHA256, ED25519: keysize = 256 case ECDSAP384SHA384: keysize = 384 @@ -30,7 +30,7 @@ func TestSIG0(t *testing.T) { } pk, err := keyrr.Generate(keysize) if err != nil { - t.Errorf("failed to generate key for “%s”: %v", algstr, err) + t.Errorf("failed to generate key for %q: %v", algstr, err) continue } now := uint32(time.Now().Unix()) @@ -45,16 +45,16 @@ func TestSIG0(t *testing.T) { sigrr.SignerName = keyrr.Hdr.Name mb, err := sigrr.Sign(pk.(crypto.Signer), m) if err != nil { - t.Errorf("failed to sign message using “%s”: %v", algstr, err) + t.Errorf("failed to sign message using %q: %v", algstr, err) continue } m := new(Msg) if err := m.Unpack(mb); err != nil { - t.Errorf("failed to unpack message signed using “%s”: %v", algstr, err) + t.Errorf("failed to unpack message signed using %q: %v", algstr, err) continue } if len(m.Extra) != 1 { - t.Errorf("missing SIG for message signed using “%s”", algstr) + t.Errorf("missing SIG for message signed using %q", algstr) continue } var sigrrwire *SIG @@ -71,20 +71,20 @@ func TestSIG0(t *testing.T) { id = "sigrrwire" } if err := rr.Verify(keyrr, mb); err != nil { - t.Errorf("failed to verify “%s” signed SIG(%s): %v", algstr, id, err) + t.Errorf("failed to verify %q signed SIG(%s): %v", algstr, id, err) continue } } mb[13]++ if err := sigrr.Verify(keyrr, mb); err == nil { - t.Errorf("verify succeeded on an altered message using “%s”", algstr) + t.Errorf("verify succeeded on an altered message using %q", algstr) continue } sigrr.Expiration = 2 sigrr.Inception = 1 mb, _ = sigrr.Sign(pk.(crypto.Signer), m) if err := sigrr.Verify(keyrr, mb); err == nil { - t.Errorf("verify succeeded on an expired message using “%s”", algstr) + t.Errorf("verify succeeded on an expired message using %q", algstr) continue } }