diff --git a/client.go b/client.go index c15decdb..b27a5617 100644 --- a/client.go +++ b/client.go @@ -23,12 +23,12 @@ type reply struct { // A Client defines parameter for a DNS client. A nil // Client is usable for sending queries. type Client struct { - Net string // if "tcp" a TCP query will be initiated, otherwise an UDP one (default is "", is UDP) + Net string // if "tcp" a TCP query will be initiated, otherwise an UDP one (default is "" for UDP) Attempts int // number of attempts, if not set defaults to 1 Retry bool // retry with TCP ReadTimeout time.Duration // the net.Conn.SetReadTimeout value for new connections (ns), defauls to 2 * 1e9 WriteTimeout time.Duration // the net.Conn.SetWriteTimeout value for new connections (ns), defauls to 2 * 1e9 - TsigSecret map[string]string // secret(s) for Tsig map[] + TsigSecret map[string]string // secret(s) for Tsig map[], zonename must be fully qualified } func (w *reply) RemoteAddr() net.Addr { @@ -58,8 +58,6 @@ func (c *Client) DoRtt(msg *Msg, addr string, data interface{}, callback func(*M }() } -// exchangeBuffer performs a synchronous query. It sends the buffer m to the -// address contained in a. func (c *Client) exchangeBuffer(inbuf []byte, a string, outbuf []byte) (n int, w *reply, err error) { w = new(reply) w.client = c @@ -70,10 +68,10 @@ func (c *Client) exchangeBuffer(inbuf []byte, a string, outbuf []byte) (n int, w } defer w.conn.Close() w.t = time.Now() - if n, err = w.writeClient(inbuf); err != nil { + if n, err = w.write(inbuf); err != nil { return 0, w, err } - if n, err = w.readClient(outbuf); err != nil { + if n, err = w.read(outbuf); err != nil { return n, w, err } w.rtt = time.Since(w.t) @@ -99,34 +97,17 @@ func (c *Client) Exchange(m *Msg, a string) (r *Msg, err error) { // in, rtt, err := c.ExchangeRtt(message, "127.0.0.1:53") // func (c *Client) ExchangeRtt(m *Msg, a string) (r *Msg, rtt time.Duration, err error) { - var n int - var w *reply - out, err := m.Pack() - if err != nil { + w := new(reply) + w.client = c + w.addr = a + if err = w.dial(); err != nil { return nil, 0, err } - var in []byte - switch c.Net { - case "tcp", "tcp4", "tcp6": - in = make([]byte, MaxMsgSize) - case "", "udp", "udp4", "udp6": - size := udpMsgSize - for _, r := range m.Extra { - if r.Header().Rrtype == TypeOPT { - size = int(r.(*RR_OPT).UDPSize()) - } - } - in = make([]byte, size) - } - if n, w, err = c.exchangeBuffer(out, a, in); err != nil { + if err = w.send(m); err != nil { return nil, 0, err } - r = new(Msg) - r.Size = n - if err := r.Unpack(in[:n]); err != nil { - return nil, w.rtt, err - } - return r, w.rtt, nil + r, err = w.receive() + return r, w.rtt, err } // dial connects to the address addr for the network set in c.Net @@ -151,9 +132,10 @@ func (w *reply) receive() (*Msg, error) { case "tcp", "tcp4", "tcp6": p = make([]byte, MaxMsgSize) case "", "udp", "udp4", "udp6": + // OPT! TODO(mg) p = make([]byte, DefaultMsgSize) } - n, err := w.readClient(p) + n, err := w.read(p) if err != nil && n == 0 { return nil, err } @@ -167,15 +149,15 @@ func (w *reply) receive() (*Msg, error) { secret := t.Hdr.Name if _, ok := w.client.TsigSecret[secret]; !ok { w.tsigStatus = ErrSecret - return m, nil + return m, ErrSecret } // Need to work on the original message p, as that was used to calculate the tsig. w.tsigStatus = TsigVerify(p, w.client.TsigSecret[secret], w.tsigRequestMAC, w.tsigTimersOnly) } - return m, nil + return m, w.tsigStatus } -func (w *reply) readClient(p []byte) (n int, err error) { +func (w *reply) read(p []byte) (n int, err error) { if w.conn == nil { return 0, ErrConnEmpty } @@ -255,24 +237,21 @@ func (w *reply) send(m *Msg) (err error) { return ErrSecret } out, mac, err = TsigGenerate(m, w.client.TsigSecret[name], w.tsigRequestMAC, w.tsigTimersOnly) - if err != nil { - return err - } w.tsigRequestMAC = mac } else { out, err = m.Pack() - if err != nil { - return err - } + } + if err != nil { + return err } w.t = time.Now() - if _, err = w.writeClient(out); err != nil { + if _, err = w.write(out); err != nil { return err } return nil } -func (w *reply) writeClient(p []byte) (n int, err error) { +func (w *reply) write(p []byte) (n int, err error) { attempts := w.client.Attempts if attempts == 0 { attempts = 1 diff --git a/zone.go b/zone.go index 28af7849..c694ba67 100644 --- a/zone.go +++ b/zone.go @@ -318,16 +318,8 @@ func (z *Zone) Remove(r RR) error { func (z *Zone) RemoveName(s string) error { key := toRadixName(s) z.Lock() - zd, exact := z.Radix.Find(key) - if !exact { - defer z.Unlock() - return nil - } - z.Unlock() - zd.Value.(*ZoneData).mutex.Lock() - defer zd.Value.(*ZoneData).mutex.Unlock() - zd.Value = nil // remove the lot - + defer z.Unlock() + z.Radix.Remove(key) if len(s) > 1 && s[0] == '*' && s[1] == '.' { z.Wildcard-- if z.Wildcard < 0 {