From aa528bb7bf74f01303bf8c9425038387348fad47 Mon Sep 17 00:00:00 2001 From: Marwan Sulaiman Date: Wed, 24 May 2023 16:35:00 -0400 Subject: [PATCH] portlist: Accept Options for NewPoller This is a follow up on PR #8172 and a breaking change that allows NewPoller to take an options struct. The issue with the previous PR was that NewPoller immediately initializes the underlying os implementation and therefore setting IncludeLocalhost as an exported field happened too late and cannot happen early enough. Using the zero value of Poller was also not an option from outside of the package because we need to set initial private fields Fixes #8171 Signed-off-by: Marwan Sulaiman --- ipn/ipnlocal/local.go | 2 +- portlist/poller.go | 38 ++++++++++++++++++++++++++++---------- portlist/portlist_test.go | 4 ++-- 3 files changed, 31 insertions(+), 13 deletions(-) diff --git a/ipn/ipnlocal/local.go b/ipn/ipnlocal/local.go index 8c8f342e1..413b87915 100644 --- a/ipn/ipnlocal/local.go +++ b/ipn/ipnlocal/local.go @@ -292,7 +292,7 @@ func NewLocalBackend(logf logger.Logf, logID logid.PublicID, sys *tsd.System, lo osshare.SetFileSharingEnabled(false, logf) ctx, cancel := context.WithCancel(context.Background()) - portpoll, err := portlist.NewPoller() + portpoll, err := portlist.NewPoller(portlist.PollerOptions{}) if err != nil { logf("skipping portlist: %s", err) } diff --git a/portlist/poller.go b/portlist/poller.go index 16dd8e74c..831198256 100644 --- a/portlist/poller.go +++ b/portlist/poller.go @@ -24,11 +24,6 @@ var debugDisablePortlist = envknob.RegisterBool("TS_DEBUG_DISABLE_PORTLIST") // Poller scans the systems for listening ports periodically and sends // the results to C. type Poller struct { - // IncludeLocalhost controls whether services bound to localhost are included. - // - // This field should only be changed before calling Run. - IncludeLocalhost bool - c chan List // unbuffered // os, if non-nil, is an OS-specific implementation of the portlist getting @@ -50,6 +45,10 @@ type Poller struct { scratch []Port prev List // most recent data, not aliasing scratch + + // caller options fields + includeLocalhost bool + pollInterval time.Duration } // osImpl is the OS-specific implementation of getting the open listening ports. @@ -71,15 +70,34 @@ var newOSImpl func(includeLocalhost bool) osImpl var errUnimplemented = errors.New("portlist poller not implemented on " + runtime.GOOS) +// PollerOptions for customizing the behavior +// of the Poller. The zero value uses each +// of the options' defaults. +type PollerOptions struct { + // IncludeLocalhost controls whether services bound to localhost are included. + // + // This field should only be changed before calling Run. + IncludeLocalhost bool + + // PollInterval sets the interval for checking the underlying OS + // for port updates. + PollInterval time.Duration +} + // NewPoller returns a new portlist Poller. It returns an error // if the portlist couldn't be obtained. -func NewPoller() (*Poller, error) { +func NewPoller(opts PollerOptions) (*Poller, error) { if debugDisablePortlist() { return nil, errors.New("portlist disabled by envknob") } + if opts.PollInterval == 0 { + opts.PollInterval = pollInterval + } p := &Poller{ - c: make(chan List), - runDone: make(chan struct{}), + c: make(chan List), + runDone: make(chan struct{}), + includeLocalhost: opts.IncludeLocalhost, + pollInterval: opts.PollInterval, } p.closeCtx, p.closeCtxCancel = context.WithCancel(context.Background()) p.osOnce.Do(p.initOSField) @@ -105,7 +123,7 @@ func (p *Poller) setPrev(pl List) { func (p *Poller) initOSField() { if newOSImpl != nil { - p.os = newOSImpl(p.IncludeLocalhost) + p.os = newOSImpl(p.includeLocalhost) } } @@ -142,7 +160,7 @@ func (p *Poller) send(ctx context.Context, pl List) (sent bool, err error) { // // Run may only be called once. func (p *Poller) Run(ctx context.Context) error { - tick := time.NewTicker(pollInterval) + tick := time.NewTicker(p.pollInterval) defer tick.Stop() return p.runWithTickChan(ctx, tick.C) } diff --git a/portlist/portlist_test.go b/portlist/portlist_test.go index 6055a8426..6be589c54 100644 --- a/portlist/portlist_test.go +++ b/portlist/portlist_test.go @@ -51,7 +51,7 @@ func TestIgnoreLocallyBoundPorts(t *testing.T) { func TestChangesOverTime(t *testing.T) { var p Poller - p.IncludeLocalhost = true + p.includeLocalhost = true get := func(t *testing.T) []Port { t.Helper() s, err := p.getList() @@ -176,7 +176,7 @@ func TestEqualLessThan(t *testing.T) { } func TestPoller(t *testing.T) { - p, err := NewPoller() + p, err := NewPoller(PollerOptions{}) if err != nil { t.Skipf("not running test: %v", err) }