dns/msg_generate.go
Tom Thorogood ff7d445081 Avoid setting the Rdlength field when packing records (#859)
* Avoid setting the Rdlength field when packing

The offset of the end of the header is now returned from the RR.pack
method, with the RDLENGTH record field being written in packRR.

To maintain compatability with callers of PackRR who might be relying
on this old behaviour, PackRR will now set rr.Header().Rdlength for
external callers. Care must be taken by callers to ensure this won't
cause a data-race.

* Prevent panic if TestClientLocalAddress fails

This came up during testing of the previous change.

* Change style of overflow check in packRR
2018-12-02 08:23:35 +00:00

346 lines
10 KiB
Go

//+build ignore
// msg_generate.go is meant to run with go generate. It will use
// go/{importer,types} to track down all the RR struct types. Then for each type
// it will generate pack/unpack methods based on the struct tags. The generated source is
// written to zmsg.go, and is meant to be checked into git.
package main
import (
"bytes"
"fmt"
"go/format"
"go/importer"
"go/types"
"log"
"os"
"strings"
)
var packageHdr = `
// Code generated by "go run msg_generate.go"; DO NOT EDIT.
package dns
`
// getTypeStruct will take a type and the package scope, and return the
// (innermost) struct if the type is considered a RR type (currently defined as
// those structs beginning with a RR_Header, could be redefined as implementing
// the RR interface). The bool return value indicates if embedded structs were
// resolved.
func getTypeStruct(t types.Type, scope *types.Scope) (*types.Struct, bool) {
st, ok := t.Underlying().(*types.Struct)
if !ok {
return nil, false
}
if st.Field(0).Type() == scope.Lookup("RR_Header").Type() {
return st, false
}
if st.Field(0).Anonymous() {
st, _ := getTypeStruct(st.Field(0).Type(), scope)
return st, true
}
return nil, false
}
func main() {
// Import and type-check the package
pkg, err := importer.Default().Import("github.com/miekg/dns")
fatalIfErr(err)
scope := pkg.Scope()
// Collect actual types (*X)
var namedTypes []string
for _, name := range scope.Names() {
o := scope.Lookup(name)
if o == nil || !o.Exported() {
continue
}
if st, _ := getTypeStruct(o.Type(), scope); st == nil {
continue
}
if name == "PrivateRR" {
continue
}
// Check if corresponding TypeX exists
if scope.Lookup("Type"+o.Name()) == nil && o.Name() != "RFC3597" {
log.Fatalf("Constant Type%s does not exist.", o.Name())
}
namedTypes = append(namedTypes, o.Name())
}
b := &bytes.Buffer{}
b.WriteString(packageHdr)
fmt.Fprint(b, "// pack*() functions\n\n")
for _, name := range namedTypes {
o := scope.Lookup(name)
st, _ := getTypeStruct(o.Type(), scope)
fmt.Fprintf(b, "func (rr *%s) pack(msg []byte, off int, compression compressionMap, compress bool) (int, int, error) {\n", name)
fmt.Fprint(b, `headerEnd, off, err := rr.Hdr.pack(msg, off, compression, compress)
if err != nil {
return headerEnd, off, err
}
`)
for i := 1; i < st.NumFields(); i++ {
o := func(s string) {
fmt.Fprintf(b, s, st.Field(i).Name())
fmt.Fprint(b, `if err != nil {
return headerEnd, off, err
}
`)
}
if _, ok := st.Field(i).Type().(*types.Slice); ok {
switch st.Tag(i) {
case `dns:"-"`: // ignored
case `dns:"txt"`:
o("off, err = packStringTxt(rr.%s, msg, off)\n")
case `dns:"opt"`:
o("off, err = packDataOpt(rr.%s, msg, off)\n")
case `dns:"nsec"`:
o("off, err = packDataNsec(rr.%s, msg, off)\n")
case `dns:"domain-name"`:
o("off, err = packDataDomainNames(rr.%s, msg, off, compression, false)\n")
default:
log.Fatalln(name, st.Field(i).Name(), st.Tag(i))
}
continue
}
switch {
case st.Tag(i) == `dns:"-"`: // ignored
case st.Tag(i) == `dns:"cdomain-name"`:
o("off, _, err = packDomainName(rr.%s, msg, off, compression, compress)\n")
case st.Tag(i) == `dns:"domain-name"`:
o("off, _, err = packDomainName(rr.%s, msg, off, compression, false)\n")
case st.Tag(i) == `dns:"a"`:
o("off, err = packDataA(rr.%s, msg, off)\n")
case st.Tag(i) == `dns:"aaaa"`:
o("off, err = packDataAAAA(rr.%s, msg, off)\n")
case st.Tag(i) == `dns:"uint48"`:
o("off, err = packUint48(rr.%s, msg, off)\n")
case st.Tag(i) == `dns:"txt"`:
o("off, err = packString(rr.%s, msg, off)\n")
case strings.HasPrefix(st.Tag(i), `dns:"size-base32`): // size-base32 can be packed just like base32
fallthrough
case st.Tag(i) == `dns:"base32"`:
o("off, err = packStringBase32(rr.%s, msg, off)\n")
case strings.HasPrefix(st.Tag(i), `dns:"size-base64`): // size-base64 can be packed just like base64
fallthrough
case st.Tag(i) == `dns:"base64"`:
o("off, err = packStringBase64(rr.%s, msg, off)\n")
case strings.HasPrefix(st.Tag(i), `dns:"size-hex:SaltLength`):
// directly write instead of using o() so we get the error check in the correct place
field := st.Field(i).Name()
fmt.Fprintf(b, `// Only pack salt if value is not "-", i.e. empty
if rr.%s != "-" {
off, err = packStringHex(rr.%s, msg, off)
if err != nil {
return headerEnd, off, err
}
}
`, field, field)
continue
case strings.HasPrefix(st.Tag(i), `dns:"size-hex`): // size-hex can be packed just like hex
fallthrough
case st.Tag(i) == `dns:"hex"`:
o("off, err = packStringHex(rr.%s, msg, off)\n")
case st.Tag(i) == `dns:"octet"`:
o("off, err = packStringOctet(rr.%s, msg, off)\n")
case st.Tag(i) == "":
switch st.Field(i).Type().(*types.Basic).Kind() {
case types.Uint8:
o("off, err = packUint8(rr.%s, msg, off)\n")
case types.Uint16:
o("off, err = packUint16(rr.%s, msg, off)\n")
case types.Uint32:
o("off, err = packUint32(rr.%s, msg, off)\n")
case types.Uint64:
o("off, err = packUint64(rr.%s, msg, off)\n")
case types.String:
o("off, err = packString(rr.%s, msg, off)\n")
default:
log.Fatalln(name, st.Field(i).Name())
}
default:
log.Fatalln(name, st.Field(i).Name(), st.Tag(i))
}
}
fmt.Fprintln(b, "return headerEnd, off, nil }\n")
}
fmt.Fprint(b, "// unpack*() functions\n\n")
for _, name := range namedTypes {
o := scope.Lookup(name)
st, _ := getTypeStruct(o.Type(), scope)
fmt.Fprintf(b, "func unpack%s(h RR_Header, msg []byte, off int) (RR, int, error) {\n", name)
fmt.Fprintf(b, "rr := new(%s)\n", name)
fmt.Fprint(b, "rr.Hdr = h\n")
fmt.Fprint(b, `if noRdata(h) {
return rr, off, nil
}
var err error
rdStart := off
_ = rdStart
`)
for i := 1; i < st.NumFields(); i++ {
o := func(s string) {
fmt.Fprintf(b, s, st.Field(i).Name())
fmt.Fprint(b, `if err != nil {
return rr, off, err
}
`)
}
// size-* are special, because they reference a struct member we should use for the length.
if strings.HasPrefix(st.Tag(i), `dns:"size-`) {
structMember := structMember(st.Tag(i))
structTag := structTag(st.Tag(i))
switch structTag {
case "hex":
fmt.Fprintf(b, "rr.%s, off, err = unpackStringHex(msg, off, off + int(rr.%s))\n", st.Field(i).Name(), structMember)
case "base32":
fmt.Fprintf(b, "rr.%s, off, err = unpackStringBase32(msg, off, off + int(rr.%s))\n", st.Field(i).Name(), structMember)
case "base64":
fmt.Fprintf(b, "rr.%s, off, err = unpackStringBase64(msg, off, off + int(rr.%s))\n", st.Field(i).Name(), structMember)
default:
log.Fatalln(name, st.Field(i).Name(), st.Tag(i))
}
fmt.Fprint(b, `if err != nil {
return rr, off, err
}
`)
continue
}
if _, ok := st.Field(i).Type().(*types.Slice); ok {
switch st.Tag(i) {
case `dns:"-"`: // ignored
case `dns:"txt"`:
o("rr.%s, off, err = unpackStringTxt(msg, off)\n")
case `dns:"opt"`:
o("rr.%s, off, err = unpackDataOpt(msg, off)\n")
case `dns:"nsec"`:
o("rr.%s, off, err = unpackDataNsec(msg, off)\n")
case `dns:"domain-name"`:
o("rr.%s, off, err = unpackDataDomainNames(msg, off, rdStart + int(rr.Hdr.Rdlength))\n")
default:
log.Fatalln(name, st.Field(i).Name(), st.Tag(i))
}
continue
}
switch st.Tag(i) {
case `dns:"-"`: // ignored
case `dns:"cdomain-name"`:
fallthrough
case `dns:"domain-name"`:
o("rr.%s, off, err = UnpackDomainName(msg, off)\n")
case `dns:"a"`:
o("rr.%s, off, err = unpackDataA(msg, off)\n")
case `dns:"aaaa"`:
o("rr.%s, off, err = unpackDataAAAA(msg, off)\n")
case `dns:"uint48"`:
o("rr.%s, off, err = unpackUint48(msg, off)\n")
case `dns:"txt"`:
o("rr.%s, off, err = unpackString(msg, off)\n")
case `dns:"base32"`:
o("rr.%s, off, err = unpackStringBase32(msg, off, rdStart + int(rr.Hdr.Rdlength))\n")
case `dns:"base64"`:
o("rr.%s, off, err = unpackStringBase64(msg, off, rdStart + int(rr.Hdr.Rdlength))\n")
case `dns:"hex"`:
o("rr.%s, off, err = unpackStringHex(msg, off, rdStart + int(rr.Hdr.Rdlength))\n")
case `dns:"octet"`:
o("rr.%s, off, err = unpackStringOctet(msg, off)\n")
case "":
switch st.Field(i).Type().(*types.Basic).Kind() {
case types.Uint8:
o("rr.%s, off, err = unpackUint8(msg, off)\n")
case types.Uint16:
o("rr.%s, off, err = unpackUint16(msg, off)\n")
case types.Uint32:
o("rr.%s, off, err = unpackUint32(msg, off)\n")
case types.Uint64:
o("rr.%s, off, err = unpackUint64(msg, off)\n")
case types.String:
o("rr.%s, off, err = unpackString(msg, off)\n")
default:
log.Fatalln(name, st.Field(i).Name())
}
default:
log.Fatalln(name, st.Field(i).Name(), st.Tag(i))
}
// If we've hit len(msg) we return without error.
if i < st.NumFields()-1 {
fmt.Fprintf(b, `if off == len(msg) {
return rr, off, nil
}
`)
}
}
fmt.Fprintf(b, "return rr, off, err }\n\n")
}
// Generate typeToUnpack map
fmt.Fprintln(b, "var typeToUnpack = map[uint16]func(RR_Header, []byte, int) (RR, int, error){")
for _, name := range namedTypes {
if name == "RFC3597" {
continue
}
fmt.Fprintf(b, "Type%s: unpack%s,\n", name, name)
}
fmt.Fprintln(b, "}\n")
// gofmt
res, err := format.Source(b.Bytes())
if err != nil {
b.WriteTo(os.Stderr)
log.Fatal(err)
}
// write result
f, err := os.Create("zmsg.go")
fatalIfErr(err)
defer f.Close()
f.Write(res)
}
// structMember will take a tag like dns:"size-base32:SaltLength" and return the last part of this string.
func structMember(s string) string {
fields := strings.Split(s, ":")
if len(fields) == 0 {
return ""
}
f := fields[len(fields)-1]
// f should have a closing "
if len(f) > 1 {
return f[:len(f)-1]
}
return f
}
// structTag will take a tag like dns:"size-base32:SaltLength" and return base32.
func structTag(s string) string {
fields := strings.Split(s, ":")
if len(fields) < 2 {
return ""
}
return fields[1][len("\"size-"):]
}
func fatalIfErr(err error) {
if err != nil {
log.Fatal(err)
}
}