dns/resolver/resolver.go
Miek Gieben 57b1e2f714 First stab at server support
Cherry pick GRONG for the server stuff (only server.go) seems
interesting.

Still working on the responder idea of GRONG and how to implement
it in godns
2011-01-11 23:33:07 +01:00

402 lines
9.2 KiB
Go

// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// DNS resolver client: see RFC 1035.
// A dns resolver is to be run as a seperate goroutine.
// For every reply the resolver answers by sending the
// received packet (with a possible error) back on the channel.
//
// Basic usage pattern for setting up a resolver:
//
// res := new(Resolver)
// ch := res.NewQuerier() // start new resolver
// res.Servers = []string{"127.0.0.1"} // set the nameserver
//
// m := new(Msg) // prepare a new message
// m.MsgHdr.Recursion_desired = true // header bits
// m.Question = make([]Question, 1) // 1 RR in question sec.
// m.Question[0] = Question{"miek.nl", TypeSOA, ClassINET}
// ch <- DnsMsg{m, nil} // send the query
// in := <-ch // wait for reply
//
// Note that message id checking is left to the caller
//
package resolver
import (
"os"
"net"
"dns"
)
const packErr = "Failed to pack message"
const servErr = "No servers could be reached"
// When communicating with a resolver, we use this structure
// to send packets to it, for sending Error must be nil.
// A resolver responds with a reply packet and a possible error.
// Sending a nil message instructs to resolver to stop.
type DnsMsg struct {
Dns *dns.Msg
Error os.Error
}
// ^ prolly extend this so that servers can use this too. We need
// to also encode some client side stuff in here, like client addr
// I think that is the *only* thing we need.
type Resolver struct {
Servers []string // servers to use
Search []string // suffixes to append to local name
Port string // what port to use
Ndots int // number of dots in name to trigger absolute lookup
Timeout int // seconds before giving up on packet
Attempts int // lost packets before giving up on server
Rotate bool // round robin among servers
Tcp bool // use TCP
Mangle func([]byte) []byte // mangle the packet
}
// Start a new resolver as a goroutine, return the communication channel
func (res *Resolver) NewQuerier() (ch chan DnsMsg) {
ch = make(chan DnsMsg)
go query(res, ch)
return
}
// The query function.
func query(res *Resolver, msg chan DnsMsg) {
// error checking, robustness
var (
c net.Conn
err os.Error
in *dns.Msg
port string
)
// len(res.Server) == 0 can be perfectly valid, when setting up the resolver
if res.Port == "" {
port = "53"
} else {
port = res.Port
}
for {
select {
case out := <-msg: //msg received
if out.Dns == nil {
// nil message, quit the goroutine
msg <- DnsMsg{nil, nil}
close(msg)
return
}
var cerr os.Error
//if len(name) >= 256 {
if out.Dns.Id == 0 {
// No Id sed, set it
out.Dns.SetId()
}
sending, ok := out.Dns.Pack()
if !ok {
msg <- DnsMsg{nil, &dns.Error{Error: packErr}}
continue
}
for i := 0; i < len(res.Servers); i++ {
server := res.Servers[i] + ":" + port
if res.Tcp {
c, cerr = net.Dial("tcp", "", server)
} else {
c, cerr = net.Dial("udp", "", server)
}
if cerr != nil {
err = cerr
continue
}
if res.Tcp {
in, err = exchangeTCP(c, sending, res, true)
} else {
in, err = exchangeUDP(c, sending, res, true)
}
// Check id in.id != out.id, should be checked in the client!
c.Close()
if err != nil {
continue
}
}
if err != nil {
msg <- DnsMsg{nil, err}
} else {
msg <- DnsMsg{in, nil}
}
}
}
return
}
// Start a new xfr as a goroutine, return a channel.
// Channel will be closed when the axfr is finished, until
// that time new messages will appear on the channel
func (res *Resolver) NewXfer() (ch chan DnsMsg) {
ch = make(chan DnsMsg)
go axfr(res, ch)
return
}
func axfr(res *Resolver, msg chan DnsMsg) {
var port string
var err os.Error
var in *dns.Msg
if res.Port == "" {
port = "53"
} else {
port = res.Port
}
for {
select {
case out := <-msg: // msg received
if out.Dns == nil {
// stop
msg <- DnsMsg{nil, nil}
close(msg)
return
}
out.Dns.SetId()
sending, ok := out.Dns.Pack()
if !ok {
msg <- DnsMsg{nil, &dns.Error{Error: packErr}}
}
SERVER:
for i := 0; i < len(res.Servers); i++ {
server := res.Servers[i] + ":" + port
c, cerr := net.Dial("tcp", "", server)
if cerr != nil {
err = cerr
continue SERVER
}
first := true
// Start the AXFR
for {
if first {
in, cerr = exchangeTCP(c, sending, res, true)
} else {
in, cerr = exchangeTCP(c, sending, res, false)
}
if cerr != nil {
// Failed to send, try the next
err = cerr
c.Close()
continue SERVER
}
// if in.Dns.Id != out.Id // error
if first {
if !checkSOA(in, true) {
// SOA record not there...
c.Close()
continue SERVER
}
first = !first
}
if !first {
if !checkSOA(in, false) {
// Soa record not the last one
msg <- DnsMsg{in, nil}
continue
// next
} else {
c.Close()
msg <- DnsMsg{in, nil}
close(msg)
return
}
}
}
println("Should never be reached")
return
}
msg <- DnsMsg{nil, err}
close(msg)
return
}
}
return
}
// Send a request on the connection and hope for a reply.
// Up to res.Attempts attempts.
func exchangeUDP(c net.Conn, m []byte, r *Resolver, send bool) (*dns.Msg, os.Error) {
var timeout int64
var attempts int
if r.Mangle != nil {
m = r.Mangle(m)
}
if r.Timeout == 0 {
timeout = 1
} else {
timeout = int64(r.Timeout)
}
if r.Attempts == 0 {
attempts = 1
} else {
attempts = r.Attempts
}
for a := 0; a < attempts; a++ {
if send {
err := sendUDP(c, m)
if err != nil {
if e, ok := err.(net.Error); ok && e.Timeout() {
continue
}
return nil, err
}
}
c.SetReadTimeout(timeout * 1e9) // nanoseconds
buf, err := recvUDP(c)
if err != nil {
if e, ok := err.(net.Error); ok && e.Timeout() {
continue
}
return nil, err
}
in := new(dns.Msg)
if !in.Unpack(buf) {
continue
}
return in, nil
}
return nil, &dns.Error{Error: servErr}
}
// Up to res.Attempts attempts.
func exchangeTCP(c net.Conn, m []byte, r *Resolver, send bool) (*dns.Msg, os.Error) {
var timeout int64
var attempts int
if r.Mangle != nil {
m = r.Mangle(m)
}
if r.Timeout == 0 {
timeout = 1
} else {
timeout = int64(r.Timeout)
}
if r.Attempts == 0 {
attempts = 1
} else {
attempts = r.Attempts
}
for a := 0; a < attempts; a++ {
// only send something when told so
if send {
err := sendTCP(c, m)
if err != nil {
if e, ok := err.(net.Error); ok && e.Timeout() {
continue
}
return nil, err
}
}
c.SetReadTimeout(timeout * 1e9) // nanoseconds
// The server replies with two bytes length
buf, err := recvTCP(c)
if err != nil {
if e, ok := err.(net.Error); ok && e.Timeout() {
continue
}
return nil, err
}
in := new(dns.Msg)
if !in.Unpack(buf) {
continue
}
return in, nil
}
return nil, &dns.Error{Error: servErr}
}
func sendUDP(c net.Conn, m []byte) os.Error {
_, err := c.Write(m)
if err != nil {
return err
}
return nil
}
func recvUDP(c net.Conn) ([]byte, os.Error) {
m := make([]byte, dns.DefaultMsgSize) // More than enough???
n, err := c.Read(m)
if err != nil {
return nil, err
}
m = m[:n]
return m, nil
}
func sendTCP(c net.Conn, m []byte) os.Error {
l := make([]byte, 2)
l[0] = byte(len(m) >> 8)
l[1] = byte(len(m))
// First we send the length
_, err := c.Write(l)
if err != nil {
return err
}
// And the the message
_, err = c.Write(m)
if err != nil {
return err
}
return nil
}
func recvTCP(c net.Conn) ([]byte, os.Error) {
l := make([]byte, 2) // receiver length
// The server replies with two bytes length
_, err := c.Read(l)
if err != nil {
return nil,err
}
length := uint16(l[0])<<8 | uint16(l[1])
if length == 0 {
return nil, &dns.Error{Error: "received nil msg length", Server: c.RemoteAddr().String()}
}
m := make([]byte, length)
n, cerr := c.Read(m)
if cerr != nil {
return nil, cerr
}
i := n
if i < int(length) {
n, err = c.Read(m[i:])
if err != nil {
return nil, err
}
i += n
}
return m, nil
}
// Check if he SOA record exists in the Answer section of
// the packet. If first is true the first RR must be a soa
// if false, the last one should be a SOA
func checkSOA(in *dns.Msg, first bool) bool {
if len(in.Answer) > 0 {
if first {
return in.Answer[0].Header().Rrtype == dns.TypeSOA
} else {
return in.Answer[len(in.Answer)-1].Header().Rrtype == dns.TypeSOA
}
}
return false
}