diff --git a/ex/fks-shield/cache.go b/ex/fks-shield/cache.go index 03019e7e..a81c72d7 100644 --- a/ex/fks-shield/cache.go +++ b/ex/fks-shield/cache.go @@ -3,6 +3,7 @@ package main import ( "dns" "fmt" + "log" "radix" "strings" "time" @@ -11,7 +12,7 @@ import ( // Cache elements, we using to key (toRadixKey) to distinguish between dns and dnssec type Packet struct { ttl time.Time // insertion time - d *dns.Msg // packet + d []byte // raw packet } func toRadixKey(d *dns.Msg) string { @@ -30,20 +31,36 @@ type Cache struct { *radix.Radix } +func quickCopy(p []byte) []byte { + q := make([]byte, 2) + q = append(q, p[2:]...) + return q +} + func NewCache() *Cache { return &Cache{Radix: radix.New()} } -func (c *Cache) Find(d *dns.Msg) *dns.Msg { +func (c *Cache) Find(d *dns.Msg) []byte { p := c.Radix.Find(toRadixKey(d)) if p == nil { + if *verbose { + log.Printf("Cache miss for " + toRadixKey(d)) + } return nil } - return p.Value.(*Packet).d + if *verbose { + log.Printf("Cache hit for " + toRadixKey(d)) + } + return quickCopy(p.Value.(*Packet).d) } func (c *Cache) Insert(d *dns.Msg) { - c.Radix.Insert(toRadixKey(d), &Packet{d: d, ttl: time.Now().UTC()}) + if *verbose { + log.Printf("Inserting " + toRadixKey(d)) + } + buf, _ := d.Pack() // Should always work + c.Radix.Insert(toRadixKey(d), &Packet{d: buf, ttl: time.Now().UTC()}) } func (c *Cache) Remove(d *dns.Msg) { diff --git a/ex/fks-shield/shield.go b/ex/fks-shield/shield.go index aa88d126..74323128 100644 --- a/ex/fks-shield/shield.go +++ b/ex/fks-shield/shield.go @@ -16,17 +16,22 @@ import ( var ( listen = flag.String("listen", "127.0.0.1:8053", "set the listener address") server = flag.String("server", "127.0.0.1:53", "remote server address(es), seperate with commas") - verbose = flag.Bool("verbose", false, "Print packet as it flows through") + verbose = flag.Bool("verbose", false, "be more verbose") ) func serve(w dns.ResponseWriter, r *dns.Msg, c *Cache) { if p := c.Find(r); p != nil { - w.Write(p) + dns.RawSetId(p, r.MsgHdr.Id) + w.WriteBuf(p) return } // Cache miss client := new(dns.Client) if p, e := client.Exchange(r, *server); e == nil { + if *verbose { + log.Printf("fks-shield: cache miss") + } + // TODO(mg): If r has edns0 and p has not we create a mismatch here w.Write(p) c.Insert(p) return @@ -38,7 +43,7 @@ func serve(w dns.ResponseWriter, r *dns.Msg, c *Cache) { func listenAndServe(add, net string) { if err := dns.ListenAndServe(add, net, nil); err != nil { - log.Printf("fks-shield: failed to setup:", net, add) + log.Fatal("fks-shield: failed to setup %s %s", net, add) } } diff --git a/ex/fksd/serve.go b/ex/fksd/serve.go index f4c0d0fd..de312840 100644 --- a/ex/fksd/serve.go +++ b/ex/fksd/serve.go @@ -5,12 +5,24 @@ import ( "log" ) +// Create skeleton edns opt RR from the query and +// add it to the message m +func ednsFromRequest(req, m *dns.Msg) { + for _, r := range req.Extra { + if r.Header().Rrtype == dns.TypeOPT { + m.SetEdns0(4096, r.(*dns.RR_OPT).Do()) + return + } + } + return +} + func serve(w dns.ResponseWriter, req *dns.Msg, z *dns.Zone) { if z == nil { panic("fks: no zone") } if *l { - log.Printf("fks: [zone %s] incoming %s %s %d\n", z.Origin, req.Question[0].Name, dns.Rr_str[req.Question[0].Qtype], req.MsgHdr.Id) + log.Printf("fks: [zone %s] incoming %s %s %d from %s\n", z.Origin, req.Question[0].Name, dns.Rr_str[req.Question[0].Qtype], req.MsgHdr.Id, w.RemoteAddr()) } // Ds Handling // Referral @@ -36,6 +48,7 @@ func serve(w dns.ResponseWriter, req *dns.Msg, z *dns.Zone) { } } } + ednsFromRequest(req, m) w.Write(m) return } @@ -45,6 +58,7 @@ func serve(w dns.ResponseWriter, req *dns.Msg, z *dns.Zone) { node := z.Find(req.Question[0].Name) if node == nil { m.SetRcode(req, dns.RcodeNameError) + ednsFromRequest(req, m) w.Write(m) return } @@ -70,6 +84,7 @@ func serve(w dns.ResponseWriter, req *dns.Msg, z *dns.Zone) { } } } + ednsFromRequest(req, m) w.Write(m) return } @@ -81,14 +96,17 @@ func serve(w dns.ResponseWriter, req *dns.Msg, z *dns.Zone) { m.MsgHdr.Authoritative = true m.Answer = rrs m.Ns = apex.RR[dns.TypeNS] + ednsFromRequest(req, m) w.Write(m) return } else { // NoData reply or CNAME m.SetReply(req) m.Ns = apex.RR[dns.TypeSOA] + ednsFromRequest(req, m) w.Write(m) return } m.SetRcode(req, dns.RcodeNameError) + ednsFromRequest(req, m) w.Write(m) }