commit af8e62003b475a5667c46e1a544e738f6f17e662 Author: David Anderson Date: Thu Feb 18 00:41:31 2016 -0800 Implement a read-only TFTP server. Testing is done against the atftp client. diff --git a/tftp/handlers.go b/tftp/handlers.go new file mode 100644 index 0000000..218e761 --- /dev/null +++ b/tftp/handlers.go @@ -0,0 +1,44 @@ +package tftp + +import ( + "bytes" + "fmt" + "io" + "io/ioutil" + "net" + "os" + "path/filepath" +) + +// FilesystemHandler returns a Handler that serves files in root. +func FilesystemHandler(root string) (Handler, error) { + root, err := filepath.Abs(root) + if err != nil { + return nil, err + } + root = filepath.ToSlash(root) + return func(path string, addr net.Addr) (io.ReadCloser, error) { + // Join with a root, which gets rid of directory traversal + // attempts. Then we join that canonicalized path with the + // actual root, which resolves to the actual on-disk file to + // serve. + path = filepath.Join("/", path) + path = filepath.FromSlash(filepath.Join(root, path)) + + st, err := os.Stat(path) + if err != nil { + return nil, err + } + if !st.Mode().IsRegular() { + return nil, fmt.Errorf("requested path %q is not a file", path) + } + return os.Open(path) + }, nil +} + +// ConstantHandler returns a Handler that serves bs for all requested paths. +func ConstantHandler(bs []byte) Handler { + return func(path string, addr net.Addr) (io.ReadCloser, error) { + return ioutil.NopCloser(bytes.NewBuffer(bs)), nil + } +} diff --git a/tftp/interop_test.go b/tftp/interop_test.go new file mode 100644 index 0000000..bf1fcec --- /dev/null +++ b/tftp/interop_test.go @@ -0,0 +1,166 @@ +package tftp + +import ( + "fmt" + "io/ioutil" + "math/rand" + "net" + "os" + "os/exec" + "strconv" + "strings" + "testing" + "time" +) + +var testFile = strings.Repeat(`This is a test file. + +My, what a pretty test file. + +I wonder if TFTP clients will be able to retrieve it! +`, 100) + +func TestInterop(t *testing.T) { + fmt.Println(len(testFile)) + prog, err := exec.LookPath("atftp") + if err != nil { + if e, ok := err.(*exec.Error); ok && e.Err == exec.ErrNotFound { + t.Skip("atftp is not installed") + } + t.Fatalf("Error while looking for atftp: %s", err) + } + + f, err := ioutil.TempFile("", "interop_test") + if err != nil { + t.Fatalf("creating temporary file: %s", err) + } + os.Remove(f.Name()) + defer f.Close() + + servers := []*Server{ + &Server{ + Handler: ConstantHandler([]byte(testFile)), + InfoLog: infoLog, + TransferLog: transferLog, + }, + &Server{ + Handler: ConstantHandler([]byte(testFile)), + InfoLog: infoLog, + TransferLog: transferLog, + // This Server clamps to a smaller block size. + MaxBlockSize: 500, + }, + &Server{ + Handler: ConstantHandler([]byte(testFile)), + InfoLog: infoLog, + TransferLog: transferLog, + // Lower block size to send more packets + MaxBlockSize: 500, + WriteTimeout: 100 * time.Millisecond, + // 10% loss rate until we've dropped 5 packets + Dial: lossyDialer(10, 5), + }, + } + + for _, s := range servers { + fmt.Fprintf(os.Stderr, "\nUsing server: %#v\n", s) + l, port := mkListener(t) + defer l.Close() + go s.Serve(l) + + options := [][]string{ + {"blksize 8"}, + {"blksize 4000"}, + {"tsize enable"}, + {"tsize enable", "blksize 1000"}, + } + + for _, opts := range options { + c := exec.Command(prog, "--get", "--trace", "--verbose", "-r", "foo", "-l", f.Name()) + for _, o := range opts { + c.Args = append(c.Args, "--option", o) + } + c.Args = append(c.Args, "127.0.0.1", strconv.Itoa(port)) + fmt.Fprintf(os.Stderr, "Fetching with: %#v\n", c.Args) + + out, err := c.CombinedOutput() + if err != nil { + t.Fatalf("TFTP fetch failed, command output:\n%s\n", string(out)) + } + bs, err := ioutil.ReadFile(f.Name()) + if err != nil { + t.Fatalf("Reading back fetched file: %s", err) + } + if string(bs) != testFile { + t.Fatal("File fetched over TFTP doesn't match file served") + } + if err := os.Remove(f.Name()); err != nil { + t.Fatalf("Failed to remove temp file: %s", err) + } + } + } +} + +func lossyDialer(lossPercent int, maxDrops int) func(string, string) (net.Conn, error) { + return func(network, addr string) (net.Conn, error) { + conn, err := net.Dial(network, addr) + if err != nil { + return nil, err + } + return &lossyConn{conn, lossPercent, maxDrops, 0, 0}, nil + } +} + +type lossyConn struct { + net.Conn + lossPercent int + dropsLeft int + droppedWrites int + droppedReads int +} + +func (c *lossyConn) Write(b []byte) (int, error) { + if c.dropsLeft > 0 && rand.Intn(100) < c.lossPercent { + // Pretend to send, to simulate a network failure. + c.dropsLeft-- + c.droppedWrites++ + return len(b), nil + } + return c.Conn.Write(b) +} + +func (c *lossyConn) Read(b []byte) (int, error) { + n, err := c.Conn.Read(b) + if c.dropsLeft > 0 && rand.Intn(100) < c.lossPercent { + // nope, didn't receive anything, read next packet. + c.dropsLeft-- + c.droppedReads++ + return c.Conn.Read(b) + } + return n, err +} + +func (c *lossyConn) Close() error { + fmt.Fprintf(os.Stderr, "Dropped %d reads and %d write\n", c.droppedReads, c.droppedWrites) + return c.Conn.Close() +} + +func infoLog(m string) { + fmt.Fprintf(os.Stderr, "TFTP server log: %s\n", m) +} + +func transferLog(a net.Addr, p string, e error) { + extra := "" + if e != nil { + extra = "(" + e.Error() + ")" + } + fmt.Fprintf(os.Stderr, "TFTP server transferred %q to %s %s\n", p, a, extra) +} + +func mkListener(t *testing.T) (net.PacketConn, int) { + l, err := net.ListenPacket("udp", "127.0.0.1:0") + if err != nil { + t.Fatalf("creating listener for test: %s", err) + } + return l, l.LocalAddr().(*net.UDPAddr).Port +} diff --git a/tftp/tftp.go b/tftp/tftp.go new file mode 100644 index 0000000..d958be8 --- /dev/null +++ b/tftp/tftp.go @@ -0,0 +1,356 @@ +// Package tftp implements a read-only TFTP server. +package tftp + +import ( + "bytes" + "encoding/binary" + "errors" + "fmt" + "io" + "net" + "strconv" + "time" +) + +const ( + // DefaultWriteTimeout is the duration a client has to acknowledge + // a data packet from the server. This can be overridden by + // setting Server.WriteTimeout. + DefaultWriteTimeout = 2 * time.Second + // DefaultWriteAttempts is the maximum number of times a single + // packet will be (re)sent before timing out a client. This can be + // overridden by setting Server.WriteAttempts. + DefaultWriteAttempts = 5 + // DefaultBlockSize is the maximum block size used to send data to + // clients. The server will respect a request for a smaller block + // size, but requests for larger block sizes will be clamped to + // DefaultBlockSize. This can be overridden by setting + // Server.MaxBlockSize. + DefaultBlockSize = 1450 + + // maxErrorSize is the largest error message string that will be + // sent to the client without truncation. + maxErrorSize = 500 +) + +// A Handler provides bytes for a file. +type Handler func(path string, clientAddr net.Addr) (io.ReadCloser, error) + +// A Server defines parameters for running a TFTP server. +type Server struct { + Handler Handler // handler to invoke for requests + + // WriteTimeout sets the duration to wait for the client to + // acknowledge a data packet. Defaults to DefaultWriteTimeout. + WriteTimeout time.Duration + // WriteAttempts sets how many times a packet will be (re)sent + // before timing out the client and aborting the transfer. If 0, + // uses DefaultWriteAttempts. + WriteAttempts int + // MaxBlockSize sets the maximum block size used for file + // transfers. If 0, uses DefaultBlockSize. + MaxBlockSize int64 + + // InfoLog specifies an optional logger for informational + // messages. If nil, informational messages are suppressed. + InfoLog func(msg string) + // TransferLog specifies an optional logger for completed + // transfers. A successful transfer is logged with err == nil. If + // nil, transfer logs are suppressed. + TransferLog func(clientAddr net.Addr, path string, err error) + + // Dial specifies a function to use when setting up a "connected" + // UDP socket to a TFTP client. While this is mostly here for + // testing, it can also be used to implement advanced relay + // functionality (e.g. serving TFTP through SOCKS). If nil, + // net.Dial is used. + Dial func(network, addr string) (net.Conn, error) +} + +// ListenAndServe listens on the UDP network address addr and then +// calls Serve to handle TFTP requests. If addr is blank, ":69" is +// used. +func (s *Server) ListenAndServe(addr string) error { + l, err := net.ListenPacket("udp", addr) + if err != nil { + return err + } + defer l.Close() + return s.Serve(l) +} + +// Serve accepts requests on listener l, creating a new transfer +// goroutine for each. The transfer goroutines use s.Handler to get +// bytes, and transfers them to the client. +func (s *Server) Serve(l net.PacketConn) error { + if s.Handler == nil { + return errors.New("can't serve, Handler is nil") + } + if err := l.SetDeadline(time.Time{}); err != nil { + return err + } + s.infoLog("TFTP listening on %s", l.LocalAddr()) + buf := make([]byte, 512) + for { + n, addr, err := l.ReadFrom(buf) + if err != nil { + return err + } + + req, err := parseRRQ(buf[:n]) + if err != nil { + s.infoLog("bad request from %q: %s", addr, err) + continue + } + + go s.transferAndLog(addr, req) + } + +} + +func (s *Server) infoLog(msg string, args ...interface{}) { + if s.InfoLog != nil { + s.InfoLog(fmt.Sprintf(msg, args...)) + } +} + +func (s *Server) transferLog(addr net.Addr, path string, err error) { + if s.TransferLog != nil { + s.TransferLog(addr, path, err) + } +} + +func (s *Server) transferAndLog(addr net.Addr, req *rrq) { + err := s.transfer(addr, req) + if err != nil { + err = fmt.Errorf("%q: %s", addr, err) + } + s.transferLog(addr, req.Filename, err) +} + +func (s *Server) transfer(addr net.Addr, req *rrq) error { + d := s.Dial + if d == nil { + d = net.Dial + } + conn, err := d("udp", addr.String()) + if err != nil { + return fmt.Errorf("creating socket: %s", err) + } + defer conn.Close() + + file, err := s.Handler(req.Filename, addr) + if err != nil { + conn.Write(tftpError("failed to get file")) + return fmt.Errorf("getting file bytes: %s", err) + } + defer file.Close() + + var b bytes.Buffer + if req.BlockSize == 0 { + // Client didn't negotiate, use classic blocksize from RFC. + req.BlockSize = 512 + } else { + // Client requested a specific blocksize, need to OACK before + // sending data. + maxBlockSize := s.MaxBlockSize + if maxBlockSize <= 0 { + maxBlockSize = DefaultBlockSize + } + if req.BlockSize > maxBlockSize { + s.infoLog("clamping blocksize to %q: %d -> %d", addr, req.BlockSize, maxBlockSize) + req.BlockSize = maxBlockSize + } + + b.WriteByte(0) + b.WriteByte(6) + b.WriteString("blksize") + b.WriteByte(0) + b.WriteString(strconv.FormatInt(req.BlockSize, 10)) + b.WriteByte(0) + if err := s.send(conn, b.Bytes(), 0); err != nil { + // TODO: some PXE roms request a transfer with the tsize + // option to try and find out the file length, and abort + // if the server doesn't echo tsize. + return fmt.Errorf("sending OACK: %s", err) + } + b.Reset() + } + + seq := uint16(1) + b.Grow(int(req.BlockSize + 4)) + b.WriteByte(0) + b.WriteByte(3) + for { + b.Truncate(2) + if err = binary.Write(&b, binary.BigEndian, seq); err != nil { + conn.Write(tftpError("internal server error")) + return fmt.Errorf("writing seqnum: %s", err) + } + n, err := io.CopyN(&b, file, int64(req.BlockSize)) + if err != nil && err != io.EOF { + conn.Write(tftpError("internal server error")) + return fmt.Errorf("reading bytes for block %d: %s", seq, err) + } + if err = s.send(conn, b.Bytes(), seq); err != nil { + conn.Write(tftpError("timeout")) + return fmt.Errorf("sending data packet %d: %s", seq, err) + } + seq++ + if n < req.BlockSize { + // Transfer complete + return nil + } + } +} + +func (s *Server) send(conn net.Conn, b []byte, seq uint16) error { + timeout := s.WriteTimeout + if timeout <= 0 { + timeout = DefaultWriteTimeout + } + attempts := s.WriteAttempts + if attempts <= 0 { + attempts = DefaultWriteAttempts + } + +Attempt: + for attempt := 0; attempt < attempts; attempt++ { + if _, err := conn.Write(b); err != nil { + return err + } + + conn.SetReadDeadline(time.Now().Add(timeout)) + + var recv [256]byte + for { + n, err := conn.Read(recv[:]) + if err != nil { + if t, ok := err.(net.Error); ok && t.Timeout() { + continue Attempt + } + return err + } + + if n < 4 { // packet too small + continue + } + switch binary.BigEndian.Uint16(recv[:2]) { + case 4: + if binary.BigEndian.Uint16(recv[2:4]) == seq { + return nil + } + case 5: + msg, _, _ := tftpStr(recv[4:]) + return fmt.Errorf("client aborted transfer: %s", msg) + } + } + } + + return errors.New("timeout waiting for ACK") +} + +type rrq struct { + Filename string + BlockSize int64 +} + +func parseRRQ(bs []byte) (*rrq, error) { + // Smallest a useful TFTP packet can be is 6 bytes: 2b opcode, 1b + // filename, 1b null, 1b mode, 1b null. + if len(bs) < 6 || binary.BigEndian.Uint16(bs[:2]) != 1 { + return nil, errors.New("not an RRQ packet") + } + + fname, bs, err := tftpStr(bs[2:]) + if err != nil { + return nil, fmt.Errorf("reading filename: %s", err) + } + + mode, bs, err := tftpStr(bs) + if err != nil { + return nil, fmt.Errorf("reading mode: %s", err) + } + if mode != "octet" { + // Only support octet mode, because in practice that's the + // only remaining sensible use of TFTP (i.e. PXE booting) + return nil, fmt.Errorf("unsupported transfer mode %q", mode) + } + + req := &rrq{ + Filename: fname, + } + + for len(bs) > 0 { + opt, rest, err := tftpStr(bs) + if err != nil { + return nil, fmt.Errorf("reading option name: %s", err) + } + bs = rest + val, rest, err := tftpStr(bs) + if err != nil { + return nil, fmt.Errorf("reading option %q value: %s", opt, err) + } + bs = rest + if opt != "blksize" { + continue + } + size, err := strconv.ParseInt(val, 10, 64) + if err != nil { + return nil, fmt.Errorf("non-integer block size value %q", val) + } + if size < 8 || size > 65464 { + return nil, fmt.Errorf("unsupported block size %q", size) + } + req.BlockSize = size + } + + return req, nil +} + +// tftpError constructs an ERROR packet. +// +// The error is coerced to the sensible subset of "netascii", namely +// the printable ASCII characters plus newline. +func tftpError(msg string) []byte { + if len(msg) > maxErrorSize { + msg = msg[:maxErrorSize] + } + var ret bytes.Buffer + ret.Grow(len(msg) + 5) + ret.Write([]byte{0, 5, 0, 0}) // generic "see message" error packet + for _, b := range msg { + switch { + case b >= 0x20 && b <= 0x7E: + ret.WriteRune(b) + case b == '\r': + // Assume this is the start of a CRLF sequence and just + // swallow the CR. The LF will output CRLF, see + // below. Also, please stop using CRLF line termination in + // Go. + case b == '\n': + ret.WriteString("\r\n") + default: + ret.WriteByte('?') + } + } + ret.WriteByte(0) + return ret.Bytes() +} + +// tftpStr extracts a null-terminated string from the given bytes, and +// returns any remaining bytes. +// +// String content is checked to be a "read-useful" subset of +// "netascii", itself a subset of ASCII. Specifically, all byte values +// must fall in the range 0x20 to 0x7E inclusive. +func tftpStr(bs []byte) (str string, remaining []byte, err error) { + for i, b := range bs { + if b == 0 { + return string(bs[:i]), bs[i+1:], nil + } else if b < 0x20 || b > 0x7E { + return "", nil, fmt.Errorf("invalid netascii byte %q at offset %d", b, i) + } + } + return "", nil, errors.New("no null terminated string found") +}