Allow use of fs.FS for $INCLUDE and wrap errors (#1526)

* Allow use of fs.FS for $INCLUDE and wrap errors

This adds ZoneParser.SetIncludeAllowedFS, which accepts an fs.FS to use
when enabling support for $INCLUDE, so that included files can be read
from somewhere other than the local filesystem.

I've also modified ParseError to support wrapping another error, such
as errors encountered while opening the $INCLUDE target.  This allows
for much more robust error handling: callers can use errors.Is()
instead of testing for particular strings (which may not be identical
across fs.FS implementations).
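
For example, a caller can now detect a missing $INCLUDE target without
matching error strings.  (A minimal, hypothetical sketch from a
caller's perspective; it uses the new SetIncludeFS with an empty
in-memory FS, so nothing on disk is touched and the names here are
made up.)

    package main

    import (
        "errors"
        "fmt"
        "io/fs"
        "strings"
        "testing/fstest"

        "github.com/miekg/dns"
    )

    func main() {
        zone := "$ORIGIN example.org.\n$INCLUDE db.missing\n"
        zp := dns.NewZoneParser(strings.NewReader(zone), "", "db.example")
        zp.SetIncludeAllowed(true)
        zp.SetIncludeFS(fstest.MapFS{}) // empty FS: the include cannot exist

        for _, ok := zp.Next(); ok; _, ok = zp.Next() {
            // Drain the parser; we only care about the resulting error.
        }
        if err := zp.Err(); errors.Is(err, fs.ErrNotExist) {
            fmt.Println("include target not found:", err)
        }
    }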

ParseError was being constructed in a lot of places using positional
rather than named fields.  Updating ParseError initialization after
the new field was added makes this change seem a lot larger than it
actually is.

The changes here should be completely backwards compatible.  The
ParseError change should be invisible to anyone not trying to unwrap
it, and ZoneParser will continue to use os.Open if the existing
SetIncludeAllowed method is called instead of the new
SetIncludeAllowedFS method.

* Don't duplicate SetIncludeAllowed; clarify edge cases

Rather than duplicating functionality between SetIncludeAllowed and
SetIncludeAllowedFS, provide a single method, SetIncludeFS, which only
sets the fs.FS.

I've improved the documentation to point out some considerations for
users hoping to use fs.FS as a security boundary.
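
To illustrate (a hypothetical sketch with made-up paths, not part of
this change): os.DirFS can be used to confine $INCLUDE to a directory
tree, with the caveat that symlinks inside the tree are still followed.

    package main

    import (
        "log"
        "os"

        "github.com/miekg/dns"
    )

    func main() {
        // "/srv/zones" and the file names are examples only.
        f, err := os.Open("/srv/zones/db.example.org")
        if err != nil {
            log.Fatal(err)
        }
        defer f.Close()

        // Pass the file name relative to the FS root, so that relative
        // $INCLUDE targets resolve inside the same tree.
        zp := dns.NewZoneParser(f, "example.org.", "db.example.org")
        zp.SetIncludeAllowed(true)
        // os.DirFS refuses paths that escape the tree (e.g.
        // "../../etc/passwd"), but it still follows symlinks, so it is
        // not a complete sandbox on its own.
        zp.SetIncludeFS(os.DirFS("/srv/zones"))

        for rr, ok := zp.Next(); ok; rr, ok = zp.Next() {
            log.Println(rr)
        }
        if err := zp.Err(); err != nil {
            log.Fatal(err)
        }
    }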

Per the fs.ValidPath documentation, fs.FS implementations must use
path (not filepath) semantics, with slash as a separator (even on
Windows).  Some, like os.DirFS, also require all paths to be relative.
I've clarified this in the documentation, made the includePath
manipulation more robust to edge cases, and added some additional
tests for relative and absolute paths.
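
To make the path handling concrete, here is a condensed, hypothetical
sketch of what the added tests exercise, from a caller's perspective:

    package main

    import (
        "fmt"
        "strings"
        "testing/fstest"

        "github.com/miekg/dns"
    )

    func main() {
        // fs.FS paths are slash-separated and relative, even on Windows.
        fsys := fstest.MapFS{
            "baz/bat/db.foo": &fstest.MapFile{Data: []byte("foo\tIN\tA\t127.0.0.1")},
        }

        // Both $INCLUDE targets resolve to baz/bat/db.foo: the relative one
        // is joined to the including file's directory (baz/quux), and the
        // absolute one has its leading "/" stripped after cleaning.
        for _, p := range []string{"../bat/db.foo", "/baz/bat/db.foo"} {
            zone := "$ORIGIN example.org.\n$INCLUDE " + p + "\n"
            zp := dns.NewZoneParser(strings.NewReader(zone), "", "baz/quux/db.bar")
            zp.SetIncludeAllowed(true)
            zp.SetIncludeFS(fsys)

            for rr, ok := zp.Next(); ok; rr, ok = zp.Next() {
                fmt.Println(rr)
            }
            if err := zp.Err(); err != nil {
                fmt.Println("unexpected error:", err)
            }
        }
    }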

Dave Pifke, 2024-01-15 07:40:43 -07:00, committed by GitHub
parent f206faa01f
commit 50fbccd204
7 changed files with 344 additions and 219 deletions

@@ -160,7 +160,7 @@ func parseKey(r io.Reader, file string) (map[string]string, error) {
 			k = l.token
 		case zValue:
 			if k == "" {
-				return nil, &ParseError{file, "no private key seen", l}
+				return nil, &ParseError{file: file, err: "no private key seen", lex: l}
 			}

 			m[strings.ToLower(k)] = l.token

@@ -116,7 +116,7 @@ func (r *generateReader) parseError(msg string, end int) *ParseError {
 	l.token = r.s[r.si-1 : end]
 	l.column += r.si // l.column starts one zBLANK before r.s

-	return &ParseError{r.file, msg, l}
+	return &ParseError{file: r.file, err: msg, lex: l}
 }

 func (r *generateReader) Read(p []byte) (int, error) {

@@ -84,7 +84,7 @@ Fetch:
 	err := r.Data.Parse(text)
 	if err != nil {
-		return &ParseError{"", err.Error(), l}
+		return &ParseError{wrappedErr: err, lex: l}
 	}

 	return nil

scan.go

@@ -4,7 +4,9 @@ import (
 	"bufio"
 	"fmt"
 	"io"
+	"io/fs"
 	"os"
+	"path"
 	"path/filepath"
 	"strconv"
 	"strings"
@@ -64,20 +66,26 @@ const (
 // ParseError is a parsing error. It contains the parse error and the location in the io.Reader
 // where the error occurred.
 type ParseError struct {
-	file string
-	err  string
-	lex  lex
+	file       string
+	err        string
+	wrappedErr error
+	lex        lex
 }

 func (e *ParseError) Error() (s string) {
 	if e.file != "" {
 		s = e.file + ": "
 	}
+	if e.err == "" && e.wrappedErr != nil {
+		e.err = e.wrappedErr.Error()
+	}
 	s += "dns: " + e.err + ": " + strconv.QuoteToASCII(e.lex.token) + " at line: " +
 		strconv.Itoa(e.lex.line) + ":" + strconv.Itoa(e.lex.column)
 	return
 }

+func (e *ParseError) Unwrap() error { return e.wrappedErr }
+
 type lex struct {
 	token string // text of the token
 	err   bool   // when true, token text has lexer error
@@ -168,8 +176,9 @@ type ZoneParser struct {
 	// sub is used to parse $INCLUDE files and $GENERATE directives.
 	// Next, by calling subNext, forwards the resulting RRs from this
 	// sub parser to the calling code.
-	sub    *ZoneParser
-	osFile *os.File
+	sub  *ZoneParser
+	r    io.Reader
+	fsys fs.FS

 	includeDepth uint8
@@ -188,7 +197,7 @@ func NewZoneParser(r io.Reader, origin, file string) *ZoneParser {
 	if origin != "" {
 		origin = Fqdn(origin)
 		if _, ok := IsDomainName(origin); !ok {
-			pe = &ParseError{file, "bad initial origin name", lex{}}
+			pe = &ParseError{file: file, err: "bad initial origin name"}
 		}
 	}
@@ -220,6 +229,24 @@ func (zp *ZoneParser) SetIncludeAllowed(v bool) {
 	zp.includeAllowed = v
 }

+// SetIncludeFS provides an [fs.FS] to use when looking for the target of
+// $INCLUDE directives. ($INCLUDE must still be enabled separately by calling
+// [ZoneParser.SetIncludeAllowed].) If fsys is nil, [os.Open] will be used.
+//
+// When fsys is an on-disk FS, the ability of $INCLUDE to reach files from
+// outside its root directory depends upon the FS implementation. For
+// instance, [os.DirFS] will refuse to open paths like "../../etc/passwd",
+// however it will still follow links which may point anywhere on the system.
+//
+// FS paths are slash-separated on all systems, even Windows. $INCLUDE paths
+// containing other characters such as backslash and colon may be accepted as
+// valid, but those characters will never be interpreted by an FS
+// implementation as path element separators. See [fs.ValidPath] for more
+// details.
+func (zp *ZoneParser) SetIncludeFS(fsys fs.FS) {
+	zp.fsys = fsys
+}
+
 // Err returns the first non-EOF error that was encountered by the
 // ZoneParser.
 func (zp *ZoneParser) Err() error {
@@ -237,7 +264,7 @@ func (zp *ZoneParser) Err() error {
 }

 func (zp *ZoneParser) setParseError(err string, l lex) (RR, bool) {
-	zp.parseErr = &ParseError{zp.file, err, l}
+	zp.parseErr = &ParseError{file: zp.file, err: err, lex: l}
 	return nil, false
 }
@@ -260,9 +287,11 @@ func (zp *ZoneParser) subNext() (RR, bool) {
 		return rr, true
 	}

-	if zp.sub.osFile != nil {
-		zp.sub.osFile.Close()
-		zp.sub.osFile = nil
+	if zp.sub.r != nil {
+		if c, ok := zp.sub.r.(io.Closer); ok {
+			c.Close()
+		}
+		zp.sub.r = nil
 	}

 	if zp.sub.Err() != nil {
@@ -402,24 +431,44 @@ func (zp *ZoneParser) Next() (RR, bool) {
 			// Start with the new file
 			includePath := l.token
-			if !filepath.IsAbs(includePath) {
-				includePath = filepath.Join(filepath.Dir(zp.file), includePath)
-			}
-
-			r1, e1 := os.Open(includePath)
-			if e1 != nil {
-				var as string
-				if !filepath.IsAbs(l.token) {
-					as = fmt.Sprintf(" as `%s'", includePath)
-				}
-
-				msg := fmt.Sprintf("failed to open `%s'%s: %v", l.token, as, e1)
-				return zp.setParseError(msg, l)
-			}
-
-			zp.sub = NewZoneParser(r1, neworigin, includePath)
-			zp.sub.defttl, zp.sub.includeDepth, zp.sub.osFile = zp.defttl, zp.includeDepth+1, r1
-			zp.sub.SetIncludeAllowed(true)
-			return zp.subNext()
+			var r1 io.Reader
+			var e1 error
+			if zp.fsys != nil {
+				// fs.FS always uses / as separator, even on Windows, so use
+				// path instead of filepath here:
+				if !path.IsAbs(includePath) {
+					includePath = path.Join(path.Dir(zp.file), includePath)
+				}
+
+				// os.DirFS, and probably others, expect all paths to be
+				// relative, so clean the path and remove leading / if
+				// present:
+				includePath = strings.TrimLeft(path.Clean(includePath), "/")
+
+				r1, e1 = zp.fsys.Open(includePath)
+			} else {
+				if !filepath.IsAbs(includePath) {
+					includePath = filepath.Join(filepath.Dir(zp.file), includePath)
+				}
+				r1, e1 = os.Open(includePath)
+			}
+			if e1 != nil {
+				var as string
+				if includePath != l.token {
+					as = fmt.Sprintf(" as `%s'", includePath)
+				}
+
+				zp.parseErr = &ParseError{
+					file:       zp.file,
+					wrappedErr: fmt.Errorf("failed to open `%s'%s: %w", l.token, as, e1),
+					lex:        l,
+				}
+				return nil, false
+			}
+
+			zp.sub = NewZoneParser(r1, neworigin, includePath)
+			zp.sub.defttl, zp.sub.includeDepth, zp.sub.r = zp.defttl, zp.includeDepth+1, r1
+			zp.sub.SetIncludeAllowed(true)
+			zp.sub.SetIncludeFS(zp.fsys)
+			return zp.subNext()
 		case zExpectDirTTLBl:
 			if l.value != zBlank {
@@ -1326,12 +1375,12 @@ func slurpRemainder(c *zlexer) *ParseError {
 	case zBlank:
 		l, _ = c.Next()
 		if l.value != zNewline && l.value != zEOF {
-			return &ParseError{"", "garbage after rdata", l}
+			return &ParseError{err: "garbage after rdata", lex: l}
 		}
 	case zNewline:
 	case zEOF:
 	default:
-		return &ParseError{"", "garbage after rdata", l}
+		return &ParseError{err: "garbage after rdata", lex: l}
 	}
 	return nil
 }
@@ -1340,16 +1389,16 @@ func slurpRemainder(c *zlexer) *ParseError {
 // Used for NID and L64 record.
 func stringToNodeID(l lex) (uint64, *ParseError) {
 	if len(l.token) < 19 {
-		return 0, &ParseError{l.token, "bad NID/L64 NodeID/Locator64", l}
+		return 0, &ParseError{file: l.token, err: "bad NID/L64 NodeID/Locator64", lex: l}
 	}
 	// There must be three colons at fixes positions, if not its a parse error
 	if l.token[4] != ':' && l.token[9] != ':' && l.token[14] != ':' {
-		return 0, &ParseError{l.token, "bad NID/L64 NodeID/Locator64", l}
+		return 0, &ParseError{file: l.token, err: "bad NID/L64 NodeID/Locator64", lex: l}
 	}
 	s := l.token[0:4] + l.token[5:9] + l.token[10:14] + l.token[15:19]
 	u, err := strconv.ParseUint(s, 16, 64)
 	if err != nil {
-		return 0, &ParseError{l.token, "bad NID/L64 NodeID/Locator64", l}
+		return 0, &ParseError{file: l.token, err: "bad NID/L64 NodeID/Locator64", lex: l}
 	}
 	return u, nil
 }

File diff suppressed because it is too large.

@@ -1,11 +1,14 @@
 package dns

 import (
+	"errors"
 	"io"
+	"io/fs"
 	"net"
 	"os"
 	"strings"
 	"testing"
+	"testing/fstest"
 )

 func TestZoneParserGenerate(t *testing.T) {
@@ -96,6 +99,78 @@ func TestZoneParserInclude(t *testing.T) {
 	}
 }

+func TestZoneParserIncludeFS(t *testing.T) {
+	fsys := fstest.MapFS{
+		"db.foo": &fstest.MapFile{
+			Data: []byte("foo\tIN\tA\t127.0.0.1"),
+		},
+	}
+	zone := "$ORIGIN example.org.\n$INCLUDE db.foo\nbar\tIN\tA\t127.0.0.2"
+
+	var got int
+	z := NewZoneParser(strings.NewReader(zone), "", "")
+	z.SetIncludeAllowed(true)
+	z.SetIncludeFS(fsys)
+	for rr, ok := z.Next(); ok; _, ok = z.Next() {
+		switch rr.Header().Name {
+		case "foo.example.org.", "bar.example.org.":
+		default:
+			t.Fatalf("expected foo.example.org. or bar.example.org., but got %s", rr.Header().Name)
+		}
+
+		got++
+	}
+	if err := z.Err(); err != nil {
+		t.Fatalf("expected no error, but got %s", err)
+	}
+	if expected := 2; got != expected {
+		t.Errorf("failed to parse zone after include, expected %d records, got %d", expected, got)
+	}
+
+	fsys = fstest.MapFS{}
+	z = NewZoneParser(strings.NewReader(zone), "", "")
+	z.SetIncludeAllowed(true)
+	z.SetIncludeFS(fsys)
+	z.Next()
+	if err := z.Err(); !errors.Is(err, fs.ErrNotExist) {
+		t.Fatalf(`expected fs.ErrNotExist but got: %T %v`, err, err)
+	}
+}
+
+func TestZoneParserIncludeFSPaths(t *testing.T) {
+	fsys := fstest.MapFS{
+		"baz/bat/db.foo": &fstest.MapFile{
+			Data: []byte("foo\tIN\tA\t127.0.0.1"),
+		},
+	}
+
+	for _, p := range []string{
+		"../bat/db.foo",
+		"/baz/bat/db.foo",
+	} {
+		zone := "$ORIGIN example.org.\n$INCLUDE " + p + "\nbar\tIN\tA\t127.0.0.2"
+		var got int
+		z := NewZoneParser(strings.NewReader(zone), "", "baz/quux/db.bar")
+		z.SetIncludeAllowed(true)
+		z.SetIncludeFS(fsys)
+		for rr, ok := z.Next(); ok; _, ok = z.Next() {
+			switch rr.Header().Name {
+			case "foo.example.org.", "bar.example.org.":
+			default:
+				t.Fatalf("$INCLUDE %q: expected foo.example.org. or bar.example.org., but got %s", p, rr.Header().Name)
+			}
+
+			got++
+		}
+		if err := z.Err(); err != nil {
+			t.Fatalf("$INCLUDE %q: expected no error, but got %s", p, err)
+		}
+		if expected := 2; got != expected {
+			t.Errorf("$INCLUDE %q: failed to parse zone after include, expected %d records, got %d", p, expected, got)
+		}
+	}
+}
+
 func TestZoneParserIncludeDisallowed(t *testing.T) {
 	tmpfile, err := os.CreateTemp("", "dns")
 	if err != nil {

svcb.go

@@ -85,7 +85,7 @@ func (rr *SVCB) parse(c *zlexer, o string) *ParseError {
 	l, _ := c.Next()
 	i, e := strconv.ParseUint(l.token, 10, 16)
 	if e != nil || l.err {
-		return &ParseError{l.token, "bad SVCB priority", l}
+		return &ParseError{file: l.token, err: "bad SVCB priority", lex: l}
 	}
 	rr.Priority = uint16(i)
@@ -95,7 +95,7 @@ func (rr *SVCB) parse(c *zlexer, o string) *ParseError {
 	name, nameOk := toAbsoluteName(l.token, o)
 	if l.err || !nameOk {
-		return &ParseError{l.token, "bad SVCB Target", l}
+		return &ParseError{file: l.token, err: "bad SVCB Target", lex: l}
 	}
 	rr.Target = name
@@ -111,7 +111,7 @@ func (rr *SVCB) parse(c *zlexer, o string) *ParseError {
 			if !canHaveNextKey {
 				// The key we can now read was probably meant to be
 				// a part of the last value.
-				return &ParseError{l.token, "bad SVCB value quotation", l}
+				return &ParseError{file: l.token, err: "bad SVCB value quotation", lex: l}
 			}

 			// In key=value pairs, value does not have to be quoted unless value
@@ -124,7 +124,7 @@ func (rr *SVCB) parse(c *zlexer, o string) *ParseError {
 				// Key with no value and no equality sign
 				key = l.token
 			} else if idx == 0 {
-				return &ParseError{l.token, "bad SVCB key", l}
+				return &ParseError{file: l.token, err: "bad SVCB key", lex: l}
 			} else {
 				key, value = l.token[:idx], l.token[idx+1:]
@@ -144,30 +144,30 @@ func (rr *SVCB) parse(c *zlexer, o string) *ParseError {
 							value = l.token
 							l, _ = c.Next()
 							if l.value != zQuote {
-								return &ParseError{l.token, "SVCB unterminated value", l}
+								return &ParseError{file: l.token, err: "SVCB unterminated value", lex: l}
 							}
 						case zQuote:
 							// There's nothing in double quotes.
 						default:
-							return &ParseError{l.token, "bad SVCB value", l}
+							return &ParseError{file: l.token, err: "bad SVCB value", lex: l}
 						}
 					}
 				}
 			}
 			kv := makeSVCBKeyValue(svcbStringToKey(key))
 			if kv == nil {
-				return &ParseError{l.token, "bad SVCB key", l}
+				return &ParseError{file: l.token, err: "bad SVCB key", lex: l}
 			}
 			if err := kv.parse(value); err != nil {
-				return &ParseError{l.token, err.Error(), l}
+				return &ParseError{file: l.token, wrappedErr: err, lex: l}
 			}
 			xs = append(xs, kv)
 		case zQuote:
-			return &ParseError{l.token, "SVCB key can't contain double quotes", l}
+			return &ParseError{file: l.token, err: "SVCB key can't contain double quotes", lex: l}
 		case zBlank:
 			canHaveNextKey = true
 		default:
-			return &ParseError{l.token, "bad SVCB values", l}
+			return &ParseError{file: l.token, err: "bad SVCB values", lex: l}
 		}
 		l, _ = c.Next()
 	}