diff --git a/ipn/ipnlocal/local.go b/ipn/ipnlocal/local.go index 8c8f342e1..413b87915 100644 --- a/ipn/ipnlocal/local.go +++ b/ipn/ipnlocal/local.go @@ -292,7 +292,7 @@ func NewLocalBackend(logf logger.Logf, logID logid.PublicID, sys *tsd.System, lo osshare.SetFileSharingEnabled(false, logf) ctx, cancel := context.WithCancel(context.Background()) - portpoll, err := portlist.NewPoller() + portpoll, err := portlist.NewPoller(portlist.PollerOptions{}) if err != nil { logf("skipping portlist: %s", err) } diff --git a/portlist/poller.go b/portlist/poller.go index 16dd8e74c..831198256 100644 --- a/portlist/poller.go +++ b/portlist/poller.go @@ -24,11 +24,6 @@ var debugDisablePortlist = envknob.RegisterBool("TS_DEBUG_DISABLE_PORTLIST") // Poller scans the systems for listening ports periodically and sends // the results to C. type Poller struct { - // IncludeLocalhost controls whether services bound to localhost are included. - // - // This field should only be changed before calling Run. - IncludeLocalhost bool - c chan List // unbuffered // os, if non-nil, is an OS-specific implementation of the portlist getting @@ -50,6 +45,10 @@ type Poller struct { scratch []Port prev List // most recent data, not aliasing scratch + + // caller options fields + includeLocalhost bool + pollInterval time.Duration } // osImpl is the OS-specific implementation of getting the open listening ports. @@ -71,15 +70,34 @@ var newOSImpl func(includeLocalhost bool) osImpl var errUnimplemented = errors.New("portlist poller not implemented on " + runtime.GOOS) +// PollerOptions for customizing the behavior +// of the Poller. The zero value uses each +// of the options' defaults. +type PollerOptions struct { + // IncludeLocalhost controls whether services bound to localhost are included. + // + // This field should only be changed before calling Run. + IncludeLocalhost bool + + // PollInterval sets the interval for checking the underlying OS + // for port updates. + PollInterval time.Duration +} + // NewPoller returns a new portlist Poller. It returns an error // if the portlist couldn't be obtained. -func NewPoller() (*Poller, error) { +func NewPoller(opts PollerOptions) (*Poller, error) { if debugDisablePortlist() { return nil, errors.New("portlist disabled by envknob") } + if opts.PollInterval == 0 { + opts.PollInterval = pollInterval + } p := &Poller{ - c: make(chan List), - runDone: make(chan struct{}), + c: make(chan List), + runDone: make(chan struct{}), + includeLocalhost: opts.IncludeLocalhost, + pollInterval: opts.PollInterval, } p.closeCtx, p.closeCtxCancel = context.WithCancel(context.Background()) p.osOnce.Do(p.initOSField) @@ -105,7 +123,7 @@ func (p *Poller) setPrev(pl List) { func (p *Poller) initOSField() { if newOSImpl != nil { - p.os = newOSImpl(p.IncludeLocalhost) + p.os = newOSImpl(p.includeLocalhost) } } @@ -142,7 +160,7 @@ func (p *Poller) send(ctx context.Context, pl List) (sent bool, err error) { // // Run may only be called once. func (p *Poller) Run(ctx context.Context) error { - tick := time.NewTicker(pollInterval) + tick := time.NewTicker(p.pollInterval) defer tick.Stop() return p.runWithTickChan(ctx, tick.C) } diff --git a/portlist/portlist_test.go b/portlist/portlist_test.go index 6055a8426..6be589c54 100644 --- a/portlist/portlist_test.go +++ b/portlist/portlist_test.go @@ -51,7 +51,7 @@ func TestIgnoreLocallyBoundPorts(t *testing.T) { func TestChangesOverTime(t *testing.T) { var p Poller - p.IncludeLocalhost = true + p.includeLocalhost = true get := func(t *testing.T) []Port { t.Helper() s, err := p.getList() @@ -176,7 +176,7 @@ func TestEqualLessThan(t *testing.T) { } func TestPoller(t *testing.T) { - p, err := NewPoller() + p, err := NewPoller(PollerOptions{}) if err != nil { t.Skipf("not running test: %v", err) }