From 39b9f93167de674e01eeecd124aca9e5500a464f Mon Sep 17 00:00:00 2001 From: Miek Gieben Date: Mon, 15 Oct 2012 20:00:49 +0200 Subject: [PATCH] Fix tsig in the normal sending of queries --- client.go | 40 +++++++++++++++++++++++++++++----------- 1 file changed, 29 insertions(+), 11 deletions(-) diff --git a/client.go b/client.go index c15decdb..bdb33520 100644 --- a/client.go +++ b/client.go @@ -23,12 +23,12 @@ type reply struct { // A Client defines parameter for a DNS client. A nil // Client is usable for sending queries. type Client struct { - Net string // if "tcp" a TCP query will be initiated, otherwise an UDP one (default is "", is UDP) + Net string // if "tcp" a TCP query will be initiated, otherwise an UDP one (default is "" for UDP) Attempts int // number of attempts, if not set defaults to 1 Retry bool // retry with TCP ReadTimeout time.Duration // the net.Conn.SetReadTimeout value for new connections (ns), defauls to 2 * 1e9 WriteTimeout time.Duration // the net.Conn.SetWriteTimeout value for new connections (ns), defauls to 2 * 1e9 - TsigSecret map[string]string // secret(s) for Tsig map[] + TsigSecret map[string]string // secret(s) for Tsig map[], zonename must be fully qualified } func (w *reply) RemoteAddr() net.Addr { @@ -99,9 +99,19 @@ func (c *Client) Exchange(m *Msg, a string) (r *Msg, err error) { // in, rtt, err := c.ExchangeRtt(message, "127.0.0.1:53") // func (c *Client) ExchangeRtt(m *Msg, a string) (r *Msg, rtt time.Duration, err error) { - var n int - var w *reply - out, err := m.Pack() + var ( + n int + out []byte + ) + w := new(reply) + if t := m.IsTsig(); t != nil { + if _, ok := w.client.TsigSecret[t.Hdr.Name]; !ok { + return nil, 0, ErrSecret + } + out, _, err = TsigGenerate(m, c.TsigSecret[t.Hdr.Name], "", false) + } else { + out, err = m.Pack() + } if err != nil { return nil, 0, err } @@ -126,6 +136,17 @@ func (c *Client) ExchangeRtt(m *Msg, a string) (r *Msg, rtt time.Duration, err e if err := r.Unpack(in[:n]); err != nil { return nil, w.rtt, err } + if t := r.IsTsig(); t != nil { + secret := t.Hdr.Name + if _, ok := client.TsigSecret[secret]; !ok { + w.tsigStatus = ErrSecret + return m, nil + } + // Need to work on the original message p, as that was used to calculate the tsig. + w.tsigStatus = TsigVerify(p, w.client.TsigSecret[secret], w.tsigRequestMAC, w.tsigTimersOnly) + } + + return r, w.rtt, nil } @@ -255,15 +276,12 @@ func (w *reply) send(m *Msg) (err error) { return ErrSecret } out, mac, err = TsigGenerate(m, w.client.TsigSecret[name], w.tsigRequestMAC, w.tsigTimersOnly) - if err != nil { - return err - } w.tsigRequestMAC = mac } else { out, err = m.Pack() - if err != nil { - return err - } + } + if err != nil { + return err } w.t = time.Now() if _, err = w.writeClient(out); err != nil {