Allow use of fs.FS for $INCLUDE and wrap errors (#1526)

* Allow use of fs.FS for $INCLUDE and wrap errors

This adds ZoneParser.SetIncludeAllowedFS, to specify an fs.FS when
enabling support for $INCLUDE, for reading included files from
somewhere other than the local filesystem.

I've also modified ParseError to support wrapping another error, such
as errors encountered while opening the $INCLUDE target.  This allows
for much more robust handling, using errors.Is() instead of testing
for particular strings (which may not be identical between fs.FS
implementations).

ParseError was being constructed in a lot of places using positional
instead of named members.  Updating ParseError initialization after
the new member field was added makes this change seem a lot larger
than it actually is.

The changes here should be completely backwards compatible.  The
ParseError change should be invisible to anyone not trying to unwrap
it, and ZoneParser will continue to use os.Open if the existing
SetIncludeAllowed method is called instead of the new
SetIncludeAllowedFS method.

* Don't duplicate SetIncludeAllowed; clarify edge cases

Rather than duplicate functionality between SetIncludeAllowed and
SetIncludeAllowedFS, have a method SetIncludeFS, which only sets the
fs.FS.

I've improved the documentation to point out some considerations for
users hoping to use fs.FS as a security boundary.

Per the fs.ValidPath documentation, fs.FS implementations must use
path (not filepath) semantics, with slash as a separator (even on
Windows).  Some, like os.DirFS, also require all paths to be relative.
I've clarified this in the documentation, made the includePath
manipulation more robust to edge cases, and added some additional
tests for relative and absolute paths.
This commit is contained in:
Dave Pifke 2024-01-15 07:40:43 -07:00 committed by GitHub
parent f206faa01f
commit 50fbccd204
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 344 additions and 219 deletions

View File

@ -160,7 +160,7 @@ func parseKey(r io.Reader, file string) (map[string]string, error) {
k = l.token
case zValue:
if k == "" {
return nil, &ParseError{file, "no private key seen", l}
return nil, &ParseError{file: file, err: "no private key seen", lex: l}
}
m[strings.ToLower(k)] = l.token

View File

@ -116,7 +116,7 @@ func (r *generateReader) parseError(msg string, end int) *ParseError {
l.token = r.s[r.si-1 : end]
l.column += r.si // l.column starts one zBLANK before r.s
return &ParseError{r.file, msg, l}
return &ParseError{file: r.file, err: msg, lex: l}
}
func (r *generateReader) Read(p []byte) (int, error) {

View File

@ -84,7 +84,7 @@ Fetch:
err := r.Data.Parse(text)
if err != nil {
return &ParseError{"", err.Error(), l}
return &ParseError{wrappedErr: err, lex: l}
}
return nil

103
scan.go
View File

@ -4,7 +4,9 @@ import (
"bufio"
"fmt"
"io"
"io/fs"
"os"
"path"
"path/filepath"
"strconv"
"strings"
@ -64,20 +66,26 @@ const (
// ParseError is a parsing error. It contains the parse error and the location in the io.Reader
// where the error occurred.
type ParseError struct {
file string
err string
lex lex
file string
err string
wrappedErr error
lex lex
}
func (e *ParseError) Error() (s string) {
if e.file != "" {
s = e.file + ": "
}
if e.err == "" && e.wrappedErr != nil {
e.err = e.wrappedErr.Error()
}
s += "dns: " + e.err + ": " + strconv.QuoteToASCII(e.lex.token) + " at line: " +
strconv.Itoa(e.lex.line) + ":" + strconv.Itoa(e.lex.column)
return
}
func (e *ParseError) Unwrap() error { return e.wrappedErr }
type lex struct {
token string // text of the token
err bool // when true, token text has lexer error
@ -168,8 +176,9 @@ type ZoneParser struct {
// sub is used to parse $INCLUDE files and $GENERATE directives.
// Next, by calling subNext, forwards the resulting RRs from this
// sub parser to the calling code.
sub *ZoneParser
osFile *os.File
sub *ZoneParser
r io.Reader
fsys fs.FS
includeDepth uint8
@ -188,7 +197,7 @@ func NewZoneParser(r io.Reader, origin, file string) *ZoneParser {
if origin != "" {
origin = Fqdn(origin)
if _, ok := IsDomainName(origin); !ok {
pe = &ParseError{file, "bad initial origin name", lex{}}
pe = &ParseError{file: file, err: "bad initial origin name"}
}
}
@ -220,6 +229,24 @@ func (zp *ZoneParser) SetIncludeAllowed(v bool) {
zp.includeAllowed = v
}
// SetIncludeFS provides an [fs.FS] to use when looking for the target of
// $INCLUDE directives. ($INCLUDE must still be enabled separately by calling
// [ZoneParser.SetIncludeAllowed].) If fsys is nil, [os.Open] will be used.
//
// When fsys is an on-disk FS, the ability of $INCLUDE to reach files from
// outside its root directory depends upon the FS implementation. For
// instance, [os.DirFS] will refuse to open paths like "../../etc/passwd",
// however it will still follow links which may point anywhere on the system.
//
// FS paths are slash-separated on all systems, even Windows. $INCLUDE paths
// containing other characters such as backslash and colon may be accepted as
// valid, but those characters will never be interpreted by an FS
// implementation as path element separators. See [fs.ValidPath] for more
// details.
func (zp *ZoneParser) SetIncludeFS(fsys fs.FS) {
zp.fsys = fsys
}
// Err returns the first non-EOF error that was encountered by the
// ZoneParser.
func (zp *ZoneParser) Err() error {
@ -237,7 +264,7 @@ func (zp *ZoneParser) Err() error {
}
func (zp *ZoneParser) setParseError(err string, l lex) (RR, bool) {
zp.parseErr = &ParseError{zp.file, err, l}
zp.parseErr = &ParseError{file: zp.file, err: err, lex: l}
return nil, false
}
@ -260,9 +287,11 @@ func (zp *ZoneParser) subNext() (RR, bool) {
return rr, true
}
if zp.sub.osFile != nil {
zp.sub.osFile.Close()
zp.sub.osFile = nil
if zp.sub.r != nil {
if c, ok := zp.sub.r.(io.Closer); ok {
c.Close()
}
zp.sub.r = nil
}
if zp.sub.Err() != nil {
@ -402,24 +431,44 @@ func (zp *ZoneParser) Next() (RR, bool) {
// Start with the new file
includePath := l.token
if !filepath.IsAbs(includePath) {
includePath = filepath.Join(filepath.Dir(zp.file), includePath)
}
r1, e1 := os.Open(includePath)
if e1 != nil {
var as string
if !filepath.IsAbs(l.token) {
as = fmt.Sprintf(" as `%s'", includePath)
var r1 io.Reader
var e1 error
if zp.fsys != nil {
// fs.FS always uses / as separator, even on Windows, so use
// path instead of filepath here:
if !path.IsAbs(includePath) {
includePath = path.Join(path.Dir(zp.file), includePath)
}
msg := fmt.Sprintf("failed to open `%s'%s: %v", l.token, as, e1)
return zp.setParseError(msg, l)
// os.DirFS, and probably others, expect all paths to be
// relative, so clean the path and remove leading / if
// present:
includePath = strings.TrimLeft(path.Clean(includePath), "/")
r1, e1 = zp.fsys.Open(includePath)
} else {
if !filepath.IsAbs(includePath) {
includePath = filepath.Join(filepath.Dir(zp.file), includePath)
}
r1, e1 = os.Open(includePath)
}
if e1 != nil {
var as string
if includePath != l.token {
as = fmt.Sprintf(" as `%s'", includePath)
}
zp.parseErr = &ParseError{
file: zp.file,
wrappedErr: fmt.Errorf("failed to open `%s'%s: %w", l.token, as, e1),
lex: l,
}
return nil, false
}
zp.sub = NewZoneParser(r1, neworigin, includePath)
zp.sub.defttl, zp.sub.includeDepth, zp.sub.osFile = zp.defttl, zp.includeDepth+1, r1
zp.sub.defttl, zp.sub.includeDepth, zp.sub.r = zp.defttl, zp.includeDepth+1, r1
zp.sub.SetIncludeAllowed(true)
zp.sub.SetIncludeFS(zp.fsys)
return zp.subNext()
case zExpectDirTTLBl:
if l.value != zBlank {
@ -1326,12 +1375,12 @@ func slurpRemainder(c *zlexer) *ParseError {
case zBlank:
l, _ = c.Next()
if l.value != zNewline && l.value != zEOF {
return &ParseError{"", "garbage after rdata", l}
return &ParseError{err: "garbage after rdata", lex: l}
}
case zNewline:
case zEOF:
default:
return &ParseError{"", "garbage after rdata", l}
return &ParseError{err: "garbage after rdata", lex: l}
}
return nil
}
@ -1340,16 +1389,16 @@ func slurpRemainder(c *zlexer) *ParseError {
// Used for NID and L64 record.
func stringToNodeID(l lex) (uint64, *ParseError) {
if len(l.token) < 19 {
return 0, &ParseError{l.token, "bad NID/L64 NodeID/Locator64", l}
return 0, &ParseError{file: l.token, err: "bad NID/L64 NodeID/Locator64", lex: l}
}
// There must be three colons at fixes positions, if not its a parse error
if l.token[4] != ':' && l.token[9] != ':' && l.token[14] != ':' {
return 0, &ParseError{l.token, "bad NID/L64 NodeID/Locator64", l}
return 0, &ParseError{file: l.token, err: "bad NID/L64 NodeID/Locator64", lex: l}
}
s := l.token[0:4] + l.token[5:9] + l.token[10:14] + l.token[15:19]
u, err := strconv.ParseUint(s, 16, 64)
if err != nil {
return 0, &ParseError{l.token, "bad NID/L64 NodeID/Locator64", l}
return 0, &ParseError{file: l.token, err: "bad NID/L64 NodeID/Locator64", lex: l}
}
return u, nil
}

File diff suppressed because it is too large Load Diff

View File

@ -1,11 +1,14 @@
package dns
import (
"errors"
"io"
"io/fs"
"net"
"os"
"strings"
"testing"
"testing/fstest"
)
func TestZoneParserGenerate(t *testing.T) {
@ -96,6 +99,78 @@ func TestZoneParserInclude(t *testing.T) {
}
}
func TestZoneParserIncludeFS(t *testing.T) {
fsys := fstest.MapFS{
"db.foo": &fstest.MapFile{
Data: []byte("foo\tIN\tA\t127.0.0.1"),
},
}
zone := "$ORIGIN example.org.\n$INCLUDE db.foo\nbar\tIN\tA\t127.0.0.2"
var got int
z := NewZoneParser(strings.NewReader(zone), "", "")
z.SetIncludeAllowed(true)
z.SetIncludeFS(fsys)
for rr, ok := z.Next(); ok; _, ok = z.Next() {
switch rr.Header().Name {
case "foo.example.org.", "bar.example.org.":
default:
t.Fatalf("expected foo.example.org. or bar.example.org., but got %s", rr.Header().Name)
}
got++
}
if err := z.Err(); err != nil {
t.Fatalf("expected no error, but got %s", err)
}
if expected := 2; got != expected {
t.Errorf("failed to parse zone after include, expected %d records, got %d", expected, got)
}
fsys = fstest.MapFS{}
z = NewZoneParser(strings.NewReader(zone), "", "")
z.SetIncludeAllowed(true)
z.SetIncludeFS(fsys)
z.Next()
if err := z.Err(); !errors.Is(err, fs.ErrNotExist) {
t.Fatalf(`expected fs.ErrNotExist but got: %T %v`, err, err)
}
}
func TestZoneParserIncludeFSPaths(t *testing.T) {
fsys := fstest.MapFS{
"baz/bat/db.foo": &fstest.MapFile{
Data: []byte("foo\tIN\tA\t127.0.0.1"),
},
}
for _, p := range []string{
"../bat/db.foo",
"/baz/bat/db.foo",
} {
zone := "$ORIGIN example.org.\n$INCLUDE " + p + "\nbar\tIN\tA\t127.0.0.2"
var got int
z := NewZoneParser(strings.NewReader(zone), "", "baz/quux/db.bar")
z.SetIncludeAllowed(true)
z.SetIncludeFS(fsys)
for rr, ok := z.Next(); ok; _, ok = z.Next() {
switch rr.Header().Name {
case "foo.example.org.", "bar.example.org.":
default:
t.Fatalf("$INCLUDE %q: expected foo.example.org. or bar.example.org., but got %s", p, rr.Header().Name)
}
got++
}
if err := z.Err(); err != nil {
t.Fatalf("$INCLUDE %q: expected no error, but got %s", p, err)
}
if expected := 2; got != expected {
t.Errorf("$INCLUDE %q: failed to parse zone after include, expected %d records, got %d", p, expected, got)
}
}
}
func TestZoneParserIncludeDisallowed(t *testing.T) {
tmpfile, err := os.CreateTemp("", "dns")
if err != nil {

20
svcb.go
View File

@ -85,7 +85,7 @@ func (rr *SVCB) parse(c *zlexer, o string) *ParseError {
l, _ := c.Next()
i, e := strconv.ParseUint(l.token, 10, 16)
if e != nil || l.err {
return &ParseError{l.token, "bad SVCB priority", l}
return &ParseError{file: l.token, err: "bad SVCB priority", lex: l}
}
rr.Priority = uint16(i)
@ -95,7 +95,7 @@ func (rr *SVCB) parse(c *zlexer, o string) *ParseError {
name, nameOk := toAbsoluteName(l.token, o)
if l.err || !nameOk {
return &ParseError{l.token, "bad SVCB Target", l}
return &ParseError{file: l.token, err: "bad SVCB Target", lex: l}
}
rr.Target = name
@ -111,7 +111,7 @@ func (rr *SVCB) parse(c *zlexer, o string) *ParseError {
if !canHaveNextKey {
// The key we can now read was probably meant to be
// a part of the last value.
return &ParseError{l.token, "bad SVCB value quotation", l}
return &ParseError{file: l.token, err: "bad SVCB value quotation", lex: l}
}
// In key=value pairs, value does not have to be quoted unless value
@ -124,7 +124,7 @@ func (rr *SVCB) parse(c *zlexer, o string) *ParseError {
// Key with no value and no equality sign
key = l.token
} else if idx == 0 {
return &ParseError{l.token, "bad SVCB key", l}
return &ParseError{file: l.token, err: "bad SVCB key", lex: l}
} else {
key, value = l.token[:idx], l.token[idx+1:]
@ -144,30 +144,30 @@ func (rr *SVCB) parse(c *zlexer, o string) *ParseError {
value = l.token
l, _ = c.Next()
if l.value != zQuote {
return &ParseError{l.token, "SVCB unterminated value", l}
return &ParseError{file: l.token, err: "SVCB unterminated value", lex: l}
}
case zQuote:
// There's nothing in double quotes.
default:
return &ParseError{l.token, "bad SVCB value", l}
return &ParseError{file: l.token, err: "bad SVCB value", lex: l}
}
}
}
}
kv := makeSVCBKeyValue(svcbStringToKey(key))
if kv == nil {
return &ParseError{l.token, "bad SVCB key", l}
return &ParseError{file: l.token, err: "bad SVCB key", lex: l}
}
if err := kv.parse(value); err != nil {
return &ParseError{l.token, err.Error(), l}
return &ParseError{file: l.token, wrappedErr: err, lex: l}
}
xs = append(xs, kv)
case zQuote:
return &ParseError{l.token, "SVCB key can't contain double quotes", l}
return &ParseError{file: l.token, err: "SVCB key can't contain double quotes", lex: l}
case zBlank:
canHaveNextKey = true
default:
return &ParseError{l.token, "bad SVCB values", l}
return &ParseError{file: l.token, err: "bad SVCB values", lex: l}
}
l, _ = c.Next()
}