Remove hijacked

This commit is contained in:
Miek Gieben 2012-08-06 20:34:09 +02:00
parent 9ac8d2d7de
commit eec679d102
2 changed files with 15 additions and 22 deletions

View File

@ -29,7 +29,7 @@ type Client struct {
ReadTimeout time.Duration // the net.Conn.SetReadTimeout value for new connections (ns), defauls to 2 * 1e9 ReadTimeout time.Duration // the net.Conn.SetReadTimeout value for new connections (ns), defauls to 2 * 1e9
WriteTimeout time.Duration // the net.Conn.SetWriteTimeout value for new connections (ns), defauls to 2 * 1e9 WriteTimeout time.Duration // the net.Conn.SetWriteTimeout value for new connections (ns), defauls to 2 * 1e9
TsigSecret map[string]string // secret(s) for Tsig map[<zonename>]<base64 secret> TsigSecret map[string]string // secret(s) for Tsig map[<zonename>]<base64 secret>
Hijacked net.Conn // if set the calling code takes care of the connection // Hijacked net.Conn // if set the calling code takes care of the connection
// LocalAddr string // Local address to use // LocalAddr string // Local address to use
} }
@ -60,15 +60,10 @@ func (c *Client) exchangeBuffer(inbuf []byte, a string, outbuf []byte) (n int, w
w = new(reply) w = new(reply)
w.client = c w.client = c
w.addr = a w.addr = a
if c.Hijacked == nil { if err = w.dial(); err != nil {
if err = w.Dial(); err != nil {
return 0, w, err return 0, w, err
} }
defer w.Close() defer w.Close()
}
if c.Hijacked != nil {
w.conn = c.Hijacked
}
w.t = time.Now() w.t = time.Now()
if n, err = w.writeClient(inbuf); err != nil { if n, err = w.writeClient(inbuf); err != nil {
return 0, w, err return 0, w, err
@ -132,8 +127,8 @@ func (c *Client) ExchangeRtt(m *Msg, a string) (r *Msg, rtt time.Duration, err e
return r, w.rtt, nil return r, w.rtt, nil
} }
// Dial connects to the address addr for the network set in c.Net // dial connects to the address addr for the network set in c.Net
func (w *reply) Dial() (err error) { func (w *reply) dial() (err error) {
var conn net.Conn var conn net.Conn
if w.Client().Net == "" { if w.Client().Net == "" {
conn, err = net.Dial("udp", w.addr) conn, err = net.Dial("udp", w.addr)
@ -147,7 +142,7 @@ func (w *reply) Dial() (err error) {
return nil return nil
} }
func (w *reply) Receive() (*Msg, error) { func (w *reply) receive() (*Msg, error) {
var p []byte var p []byte
m := new(Msg) m := new(Msg)
switch w.Client().Net { switch w.Client().Net {
@ -246,10 +241,10 @@ func (w *reply) readClient(p []byte) (n int, err error) {
return return
} }
// Send sends a dns msg to the address specified in w. // send sends a dns msg to the address specified in w.
// If the message m contains a TSIG record the transaction // If the message m contains a TSIG record the transaction
// signature is calculated. // signature is calculated.
func (w *reply) Send(m *Msg) (err error) { func (w *reply) send(m *Msg) (err error) {
var out []byte var out []byte
if m.IsTsig() { if m.IsTsig() {
mac := "" mac := ""
@ -281,11 +276,9 @@ func (w *reply) writeClient(p []byte) (n int, err error) {
if attempts == 0 { if attempts == 0 {
attempts = 1 attempts = 1
} }
if w.Client().Hijacked == nil { if err = w.dial(); err != nil {
if err = w.Dial(); err != nil {
return 0, err return 0, err
} }
}
switch w.Client().Net { switch w.Client().Net {
case "tcp", "tcp4", "tcp6": case "tcp", "tcp4", "tcp6":
if len(p) < 2 { if len(p) < 2 {

8
xfr.go
View File

@ -18,10 +18,10 @@ func (c *Client) XfrReceive(q *Msg, a string) (chan *Exchange, error) {
w.client = c w.client = c
w.addr = a w.addr = a
w.req = q w.req = q
if err := w.Dial(); err != nil { if err := w.dial(); err != nil {
return nil, err return nil, err
} }
if err := w.Send(q); err != nil { if err := w.send(q); err != nil {
return nil, err return nil, err
} }
e := make(chan *Exchange) e := make(chan *Exchange)
@ -43,7 +43,7 @@ func (w *reply) axfrReceive(c chan *Exchange) {
defer w.Close() defer w.Close()
defer close(c) defer close(c)
for { for {
in, err := w.Receive() in, err := w.receive()
if err != nil { if err != nil {
c <- &Exchange{Request: w.req, Reply: in, Rtt: w.rtt, RemoteAddr: w.conn.RemoteAddr(), Error: err} c <- &Exchange{Request: w.req, Reply: in, Rtt: w.rtt, RemoteAddr: w.conn.RemoteAddr(), Error: err}
return return
@ -78,7 +78,7 @@ func (w *reply) ixfrReceive(c chan *Exchange) {
defer w.Close() defer w.Close()
defer close(c) defer close(c)
for { for {
in, err := w.Receive() in, err := w.receive()
if err != nil { if err != nil {
c <- &Exchange{Request: w.req, Reply: in, Rtt: w.rtt, RemoteAddr: w.conn.RemoteAddr(), Error: err} c <- &Exchange{Request: w.req, Reply: in, Rtt: w.rtt, RemoteAddr: w.conn.RemoteAddr(), Error: err}
return return