diff --git a/tftp/handlers.go b/tftp/handlers.go index c99da40..a0da475 100644 --- a/tftp/handlers.go +++ b/tftp/handlers.go @@ -31,7 +31,7 @@ func FilesystemHandler(root string) (Handler, error) { return nil, err } root = filepath.ToSlash(root) - return func(path string, addr net.Addr) (io.ReadCloser, error) { + return func(path string, addr net.Addr) (io.ReadCloser, int64, error) { // Join with a root, which gets rid of directory traversal // attempts. Then we join that canonicalized path with the // actual root, which resolves to the actual on-disk file to @@ -41,18 +41,19 @@ func FilesystemHandler(root string) (Handler, error) { st, err := os.Stat(path) if err != nil { - return nil, err + return nil, 0, err } if !st.Mode().IsRegular() { - return nil, fmt.Errorf("requested path %q is not a file", path) + return nil, 0, fmt.Errorf("requested path %q is not a file", path) } - return os.Open(path) + f, err := os.Open(path) + return f, st.Size(), err }, nil } // ConstantHandler returns a Handler that serves bs for all requested paths. func ConstantHandler(bs []byte) Handler { - return func(path string, addr net.Addr) (io.ReadCloser, error) { - return ioutil.NopCloser(bytes.NewBuffer(bs)), nil + return func(path string, addr net.Addr) (io.ReadCloser, int64, error) { + return ioutil.NopCloser(bytes.NewBuffer(bs)), int64(len(bs)), nil } } diff --git a/tftp/tftp.go b/tftp/tftp.go index 7088e44..86730ef 100644 --- a/tftp/tftp.go +++ b/tftp/tftp.go @@ -48,7 +48,16 @@ const ( ) // A Handler provides bytes for a file. -type Handler func(path string, clientAddr net.Addr) (io.ReadCloser, error) +// +// If size is non-zero, it must be equal to the number of bytes in +// file. The server will offer the "tsize" extension to clients that +// request it. +// +// Note that some clients (particularly firmware TFTP clients) can be +// very capricious about servers not supporting all the options that +// they request, so passing a size of 0 may cause TFTP transfers to +// fail for some clients. +type Handler func(path string, clientAddr net.Addr) (file io.ReadCloser, size int64, err error) // A Server defines parameters for running a TFTP server. type Server struct { @@ -156,7 +165,7 @@ func (s *Server) transfer(addr net.Addr, req *rrq) error { } defer conn.Close() - file, err := s.Handler(req.Filename, addr) + file, size, err := s.Handler(req.Filename, addr) if err != nil { conn.Write(tftpError("failed to get file")) return fmt.Errorf("getting file bytes: %s", err) @@ -164,35 +173,44 @@ func (s *Server) transfer(addr net.Addr, req *rrq) error { defer file.Close() var b bytes.Buffer - if req.BlockSize == 0 { - // Client didn't negotiate, use classic blocksize from RFC. - req.BlockSize = 512 - } else { - // Client requested a specific blocksize, need to OACK before - // sending data. - maxBlockSize := s.MaxBlockSize - if maxBlockSize <= 0 { - maxBlockSize = DefaultBlockSize - } - if req.BlockSize > maxBlockSize { - s.infoLog("clamping blocksize to %q: %d -> %d", addr, req.BlockSize, maxBlockSize) - req.BlockSize = maxBlockSize - } - + if req.BlockSize != 0 || (req.WantSize && size != 0) { + // Client requested options, need to OACK them before sending + // data. b.WriteByte(0) b.WriteByte(6) - b.WriteString("blksize") - b.WriteByte(0) - b.WriteString(strconv.FormatInt(req.BlockSize, 10)) - b.WriteByte(0) + + if req.BlockSize != 0 { + maxBlockSize := s.MaxBlockSize + if maxBlockSize <= 0 { + maxBlockSize = DefaultBlockSize + } + if req.BlockSize > maxBlockSize { + s.infoLog("clamping blocksize to %q: %d -> %d", addr, req.BlockSize, maxBlockSize) + req.BlockSize = maxBlockSize + } + + b.WriteString("blksize") + b.WriteByte(0) + b.WriteString(strconv.FormatInt(req.BlockSize, 10)) + b.WriteByte(0) + } + + if req.WantSize && size != 0 { + b.WriteString("tsize") + b.WriteByte(0) + b.WriteString(strconv.FormatInt(size, 10)) + b.WriteByte(0) + } + if err := s.send(conn, b.Bytes(), 0); err != nil { - // TODO: some PXE roms request a transfer with the tsize - // option to try and find out the file length, and abort - // if the server doesn't echo tsize. return fmt.Errorf("sending OACK: %s", err) } b.Reset() } + if req.BlockSize == 0 { + // Client didn't negotiate, use classic blocksize from RFC. + req.BlockSize = 512 + } seq := uint16(1) b.Grow(int(req.BlockSize + 4)) @@ -270,6 +288,7 @@ Attempt: type rrq struct { Filename string BlockSize int64 + WantSize bool } func parseRRQ(bs []byte) (*rrq, error) { @@ -310,6 +329,9 @@ func parseRRQ(bs []byte) (*rrq, error) { } bs = rest if opt != "blksize" { + if opt == "tsize" { + req.WantSize = true + } continue } size, err := strconv.ParseInt(val, 10, 64)