talos/internal/pkg/dns/manager.go

// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.

package dns

import (
"context"
"errors"
"fmt"
"iter"
"net/netip"
"slices"
"strings"
"time"

"github.com/coredns/coredns/plugin/pkg/proxy"
"github.com/hashicorp/go-multierror"
dnssrv "github.com/miekg/dns"
"github.com/siderolabs/gen/xiter"
"github.com/thejerf/suture/v4"
"go.uber.org/zap"

"github.com/siderolabs/talos/pkg/machinery/resources/cluster"
)
// ErrCreatingRunner is an error that occurs when creating a runner.
var ErrCreatingRunner = errors.New("error creating runner")
// Manager manages DNS runners.
type Manager struct {
originalCtx context.Context //nolint:containedctx
handler *Handler
nodeHandler *NodeHandler
rootHandler *Cache
s *suture.Supervisor
supervisorCh <-chan error
logger *zap.Logger
runners map[AddressPair]suture.ServiceToken
}
// NewManager creates a new manager.
func NewManager(mr MemberReader, hook suture.EventHook, logger *zap.Logger) *Manager {
handler := NewHandler(logger)
nodeHandler := NewNodeHandler(handler, &addrResolver{mr: mr}, logger)
rootHandler := NewCache(nodeHandler, logger)
m := &Manager{
handler: handler,
nodeHandler: nodeHandler,
rootHandler: rootHandler,
s: suture.New("dns-resolve-cache-runners", suture.Spec{EventHook: hook}),
logger: logger,
runners: map[AddressPair]suture.ServiceToken{},
}
return m
}
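
// A minimal wiring sketch (illustrative only; memberReader is assumed to implement
// MemberReader, and logger is a *zap.Logger provided by the caller):
//
//	m := NewManager(memberReader, func(e suture.Event) { logger.Info(e.String()) }, logger)
//	m.ServeBackground(ctx)
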
// ServeBackground starts the manager in the background. Calling it again with the same context is a no-op; calling it
// with a different context panics.
func (m *Manager) ServeBackground(ctx context.Context) {
switch {
case m.originalCtx == nil:
m.originalCtx = ctx
case m.originalCtx != ctx:
panic("Manager.ServeBackground is called with a different context")
case m.originalCtx == ctx:
return
}
m.supervisorCh = m.s.ServeBackground(ctx)
}
// AddressPair represents a network and address with port.
type AddressPair struct {
Network string
Addr netip.AddrPort
}
// String returns a string representation of the address pair.
func (a AddressPair) String() string { return "Network: " + a.Network + ", Addr: " + a.Addr.String() }
// RunAll updates and runs the runners managed by the manager. It returns an iterator which yields the address pairs for
// all running and attempted-to-run configurations. It's mandatory to range over the iterator to ensure all runners are updated.
func (m *Manager) RunAll(pairs iter.Seq[AddressPair], forwardEnabled bool) iter.Seq2[RunResult, error] {
return func(yield func(RunResult, error) bool) {
preserve := make(map[AddressPair]struct{}, len(m.runners))
for cfg := range pairs {
preserve[cfg] = struct{}{}
if _, ok := m.runners[cfg]; ok {
if !yield(makeResult(cfg, StatusRunning), nil) {
return
}
continue
}
opts, err := newDNSRunnerOpts(cfg, m.rootHandler, forwardEnabled)
if err != nil {
err = fmt.Errorf("%w: %w", ErrCreatingRunner, err)
} else {
m.runners[cfg] = m.s.Add(NewRunner(opts, m.logger))
}
if !yield(makeResult(cfg, StatusNew), err) {
return
}
}
for cfg, token := range m.runners {
if _, ok := preserve[cfg]; ok {
continue
}
err := m.s.RemoveAndWait(token, 0)
if err != nil {
err = fmt.Errorf("error removing runner: %w", err)
}
if !yield(makeResult(cfg, StatusRemoved), err) {
return
}
delete(m.runners, cfg)
}
}
}
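
// Ranging over the returned iterator is required; a usage sketch (pairs is assumed to be
// an iter.Seq[AddressPair] built by the caller):
//
//	for res, err := range m.RunAll(pairs, true) {
//		if err != nil {
//			logger.Error("dns runner error", zap.Error(err), zap.Stringer("config", res.AddressPair))
//		}
//	}
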
func makeResult(cfg AddressPair, s Status) RunResult { return RunResult{AddressPair: cfg, Status: s} }
// AllowNodeResolving enables or disables the node resolving feature.
func (m *Manager) AllowNodeResolving(enabled bool) { m.nodeHandler.SetEnabled(enabled) }
// SetUpstreams sets the upstreams for the DNS handler. It returns true if the upstreams were updated, false otherwise.
func (m *Manager) SetUpstreams(prxs iter.Seq[*proxy.Proxy]) bool {
if !m.handler.SetProxy(prxs) {
return false
}
// Upstreams updated, clear cache to prevent DNS poisoning.
m.rootHandler.Clear()
return true
}
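
// A usage sketch (upstreams is assumed to be a []*proxy.Proxy prepared by the caller):
//
//	if m.SetUpstreams(slices.Values(upstreams)) {
//		logger.Info("updated DNS upstreams")
//	}
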
// ClearAll stops and removes all runners. It returns the accumulated errors if any runners failed to stop properly.
func (m *Manager) ClearAll(dry bool) error {
if dry {
return nil
}
var multiErr *multierror.Error
for _, err := range m.clearAll() {
if err != nil {
multiErr = multierror.Append(multiErr, err)
}
}
return multiErr.ErrorOrNil()
}
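
// ClearAll is typically paired with supervisor shutdown; a sketch (cancel is assumed to be
// the CancelFunc of the context passed to ServeBackground):
//
//	if err := m.ClearAll(false); err != nil {
//		logger.Error("error stopping DNS runners", zap.Error(err))
//	}
//
//	cancel()   // stop the supervisor
//	<-m.Done() // wait for the supervisor to finish
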
func (m *Manager) clearAll() iter.Seq2[AddressPair, error] {
return func(yield func(AddressPair, error) bool) {
if len(m.runners) == 0 {
return
}
defer m.handler.Stop()
for runData, token := range m.runners {
err := m.s.RemoveAndWait(token, 0)
if err != nil && !errors.Is(err, suture.ErrSupervisorNotRunning) && !errors.Is(err, suture.ErrTimeout) {
err = fmt.Errorf("error removing runner: %w", err)
} else {
err = nil
}
if !yield(runData, err) {
return
}
delete(m.runners, runData)
}
}
}
// Done returns a channel that reports when the supervisor has finished execution.
func (m *Manager) Done() <-chan error {
return m.supervisorCh
}
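
// addrResolver resolves node names to cluster member addresses using a MemberReader.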
type addrResolver struct {
mr MemberReader
}
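
// ResolveAddr looks up the query name among cluster members and returns the member addresses
// matching the query type (A selects IPv4 addresses, AAAA selects IPv6 addresses). The boolean
// result reports whether a matching member was found.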
func (s *addrResolver) ResolveAddr(ctx context.Context, qType uint16, name string) (iter.Seq[netip.Addr], bool) {
name = strings.TrimRight(name, ".")
items, err := s.mr.ReadMembers(ctx)
if err != nil {
return nil, false
}
found, ok := xiter.Find(func(res *cluster.Member) bool {
return fqdnMatch(name, res.TypedSpec().Hostname) || fqdnMatch(name, res.Metadata().ID())
}, items)
if !ok {
return nil, false
}
return xiter.Filter(
func(addr netip.Addr) bool {
return (qType == dnssrv.TypeA && addr.Is4()) || (qType == dnssrv.TypeAAAA && addr.Is6())
},
slices.Values(found.TypedSpec().Addresses),
), true
}
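
// fqdnMatch reports whether the query name matches the hostname either exactly or by its
// first DNS label, ignoring trailing dots (e.g. "node1" matches "node1.example.org").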
func fqdnMatch(what, where string) bool {
what = strings.TrimRight(what, ".")
where = strings.TrimRight(where, ".")
if what == where {
return true
}
first, _, found := strings.Cut(where, ".")
if !found {
return false
}
return what == first
}
// MemberReader is an interface to read members.
type MemberReader interface {
ReadMembers(ctx context.Context) (iter.Seq[*cluster.Member], error)
}
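
// newDNSRunnerOpts builds RunnerOptions for the given address pair, normalizing the network
// name for IPv6 addresses and wiring the shared root handler into the runner.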
func newDNSRunnerOpts(cfg AddressPair, rootHandler dnssrv.Handler, forwardEnabled bool) (RunnerOptions, error) {
if cfg.Addr.Addr().Is6() && !strings.HasSuffix(cfg.Network, "6") {
cfg.Network += "6"
}
var serverOpts RunnerOptions
controlFn, ctrlErr := MakeControl(cfg.Network, forwardEnabled)
if ctrlErr != nil {
return serverOpts, fmt.Errorf("error creating %q control function: %w", cfg.Network, ctrlErr)
}
switch cfg.Network {
case "udp", "udp6":
packetConn, err := NewUDPPacketConn(cfg.Network, cfg.Addr.String(), controlFn)
if err != nil {
return serverOpts, fmt.Errorf("error creating %q packet conn: %w", cfg.Network, err)
}
serverOpts = RunnerOptions{
PacketConn: packetConn,
Handler: rootHandler,
}
case "tcp", "tcp6":
listener, err := NewTCPListener(cfg.Network, cfg.Addr.String(), controlFn)
if err != nil {
return serverOpts, fmt.Errorf("error creating %q listener: %w", cfg.Network, err)
}
serverOpts = RunnerOptions{
Listener: listener,
Handler: rootHandler,
ReadTimeout: 3 * time.Second,
WriteTimeout: 5 * time.Second,
IdleTimeout: func() time.Duration { return 10 * time.Second },
MaxTCPQueries: -1,
}
}
return serverOpts, nil
}
// RunResult represents the result of a RunAll iteration.
type RunResult struct {
AddressPair
Status Status
}
// Status represents the status of a runner.
type Status int
const (
// StatusNew represents a new runner.
StatusNew Status = iota
// StatusRunning represents an already running runner.
StatusRunning
// StatusRemoved represents a removed runner.
StatusRemoved
)