diff --git a/net/dns/direct_linux.go b/net/dns/direct_linux.go index bdeefb352..20d96e2f1 100644 --- a/net/dns/direct_linux.go +++ b/net/dns/direct_linux.go @@ -6,20 +6,28 @@ import ( "bytes" "context" + "fmt" "github.com/illarion/gonotify/v2" "tailscale.com/health" ) func (m *directManager) runFileWatcher() { - ctx, cancel := context.WithCancel(m.ctx) + if err := watchFile(m.ctx, "/etc/", resolvConf, m.checkForFileTrample); err != nil { + // This is all best effort for now, so surface warnings to users. + m.logf("dns: inotify: %s", err) + } +} + +// watchFile sets up an inotify watch for a given directory and +// calls the callback function every time a particular file is changed. +// The filename should be located in the provided directory. +func watchFile(ctx context.Context, dir, filename string, cb func()) error { + ctx, cancel := context.WithCancel(ctx) defer cancel() in, err := gonotify.NewInotify(ctx) if err != nil { - // Oh well, we tried. This is all best effort for now, to - // surface warnings to users. - m.logf("dns: inotify new: %v", err) - return + return fmt.Errorf("NewInotify: %w", err) } const events = gonotify.IN_ATTRIB | @@ -29,22 +37,20 @@ func (m *directManager) runFileWatcher() { gonotify.IN_MODIFY | gonotify.IN_MOVE - if err := in.AddWatch("/etc/", events); err != nil { - m.logf("dns: inotify addwatch: %v", err) - return + if err := in.AddWatch(dir, events); err != nil { + return fmt.Errorf("AddWatch: %w", err) } for { events, err := in.Read() if ctx.Err() != nil { - return + return ctx.Err() } if err != nil { - m.logf("dns: inotify read: %v", err) - return + return fmt.Errorf("Read: %w", err) } var match bool for _, ev := range events { - if ev.Name == resolvConf { + if ev.Name == filename { match = true break } @@ -52,7 +58,7 @@ func (m *directManager) runFileWatcher() { if !match { continue } - m.checkForFileTrample() + cb() } } diff --git a/net/dns/direct_linux_test.go b/net/dns/direct_linux_test.go new file mode 100644 index 000000000..079d060ed --- /dev/null +++ b/net/dns/direct_linux_test.go @@ -0,0 +1,56 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package dns + +import ( + "context" + "errors" + "fmt" + "os" + "sync/atomic" + "testing" + "time" + + "golang.org/x/sync/errgroup" +) + +func TestWatchFile(t *testing.T) { + dir := t.TempDir() + filepath := dir + "/test.txt" + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + var callbackCalled atomic.Bool + callbackDone := make(chan bool) + callback := func() { + callbackDone <- true + callbackCalled.Store(true) + } + + var eg errgroup.Group + eg.Go(func() error { return watchFile(ctx, dir, filepath, callback) }) + + // Keep writing until we get a callback. + func() { + for i := range 10000 { + if err := os.WriteFile(filepath, []byte(fmt.Sprintf("write%d", i)), 0644); err != nil { + t.Fatal(err) + } + select { + case <-callbackDone: + return + case <-time.After(10 * time.Millisecond): + } + } + }() + + cancel() + if err := eg.Wait(); err != nil && !errors.Is(err, context.Canceled) { + t.Error(err) + } + if !callbackCalled.Load() { + t.Error("callback was not called") + } +}