From b625f190ce1bb289b4962e60c42f4e9662a01111 Mon Sep 17 00:00:00 2001 From: Alex Sergeyev Date: Mon, 29 Jun 2015 08:06:49 -0400 Subject: [PATCH] Not allocating 64K buffers for reading --- client.go | 78 ++++++++++++++++++++++++++++++++++++------------------- 1 file changed, 51 insertions(+), 27 deletions(-) diff --git a/client.go b/client.go index 4203e460..b12e5f09 100644 --- a/client.go +++ b/client.go @@ -215,20 +215,29 @@ func (co *Conn) ReadMsg() (*Msg, error) { // Note that this function would not be able to report TSIG error or // check it got actual DNS payload. func (co *Conn) ReadMsgBytes(hdr *Header) ([]byte, error) { - var p []byte + var ( + p []byte + n int + err error + ) - if _, ok := co.Conn.(*net.TCPConn); ok { - // we got two byte - p = make([]byte, MaxMsgSize) + if t, ok := co.Conn.(*net.TCPConn); ok { + // we got two byte header to know how much to receive... + l, err := tcpMsgLen(t) + if err != nil { + return nil, err + } + p = make([]byte, l) + n, err = tcpRead(t, p) } else { if co.UDPSize > MinMsgSize { p = make([]byte, co.UDPSize) } else { p = make([]byte, MinMsgSize) } + n, err = co.Read(p) } - n, err := co.Read(p) if err != nil { return nil, err } else if n < _HBytes { @@ -244,6 +253,38 @@ func (co *Conn) ReadMsgBytes(hdr *Header) ([]byte, error) { return p, err } +// tcpMsgLen - helper func to read first two bytes of stream as uint16 packet length +func tcpMsgLen(t *net.TCPConn) (int, error) { + p := [2]byte{0, 0} + n, err := t.Read(p[:]) + if err != nil { + return 0, err + } else if n != 2 { + return 0, ErrShortRead + } + l, _ := unpackUint16(p[:], 0) + if l == 0 { + return 0, ErrShortRead + } + return int(l), nil +} + +// tcpRead - calls TCPConn.Read enough times to fill allocated buffer +func tcpRead(t *net.TCPConn, p []byte) (int, error) { + n, err := t.Read(p) + if err != nil { + return n, err + } + for n < len(p) { + j, err := t.Read(p[n:]) + if err != nil { + return n, err + } + n += j + } + return n, err +} + // Read implements the net.Conn read method. func (co *Conn) Read(p []byte) (n int, err error) { if co.Conn == nil { @@ -253,31 +294,14 @@ func (co *Conn) Read(p []byte) (n int, err error) { return 0, io.ErrShortBuffer } if t, ok := co.Conn.(*net.TCPConn); ok { - n, err = t.Read(p[0:2]) - if err != nil || n != 2 { - return n, err + l, err := tcpMsgLen(t) + if err != nil { + return 0, err } - l, _ := unpackUint16(p[0:2], 0) - if l == 0 { - return 0, ErrShortRead - } - if int(l) > len(p) { + if l > len(p) { return int(l), io.ErrShortBuffer } - n, err = t.Read(p[:l]) - if err != nil { - return n, err - } - i := n - for i < int(l) { - j, err := t.Read(p[i:int(l)]) - if err != nil { - return i, err - } - i += j - } - n = i - return n, err + return tcpRead(t, p[:l]) } // UDP connection n, err = co.Conn.Read(p)