Andrey Smirnov 5855b8d532 fix: refresh control plane endpoints on worker apids on schedule
This moves endpoint refresh from the context of the service `apid` in
`machined` into the `apid` service itself for the workers. `apid` does
an initial poll for the endpoints when it boots, but also periodically
polls for new endpoints to make sure it has an accurate list of `trustd`
endpoints to talk to. This handles cases when control plane endpoints
change (e.g. rolling replacement of control plane nodes with new IPs).

Related to #3069

Fixes #3068

Signed-off-by: Andrey Smirnov <smirnov.andrey@gmail.com>
2021-02-03 14:27:03 -08:00

147 lines
3.6 KiB
Go

// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
// Package provider provides TLS config for client & server.
package provider
import (
stdlibtls "crypto/tls"
"fmt"
"log"
stdlibnet "net"
"reflect"
"sort"
"time"
"github.com/talos-systems/crypto/tls"
"github.com/talos-systems/crypto/x509"
"github.com/talos-systems/net"
"github.com/talos-systems/talos/pkg/grpc/gen"
"github.com/talos-systems/talos/pkg/machinery/config"
)
// TLSConfig provides client & server TLS configs for apid.
type TLSConfig struct {
endpoints Endpoints
lastEndpointList []string
generator *gen.RemoteGenerator
certificateProvider tls.CertificateProvider
}
// NewTLSConfig builds provider from configuration and endpoints.
//
// The certificate SANs are the IP addresses and DNS names reported by the net
// helper package, extended with the extra SANs from the machine security
// configuration. The initial endpoint list is fetched from endpoints, sorted,
// and used to create the remote certificate generator that backs a renewing
// certificate provider.
//
// On success a background goroutine running refreshEndpoints is started; it
// loops forever and is not stopped by anything in this type.
//
// Every failure is returned wrapped with a message naming the step that
// failed; the underlying error stays reachable via errors.Is / errors.As.
func NewTLSConfig(config config.Provider, endpoints Endpoints) (*TLSConfig, error) {
	ips, err := net.IPAddrs()
	if err != nil {
		return nil, fmt.Errorf("failed to discover IP addresses: %w", err)
	}

	dnsNames, err := net.DNSNames()
	if err != nil {
		return nil, fmt.Errorf("failed to discover DNS names: %w", err)
	}

	// Extra SANs from the machine config: anything that parses as an IP
	// becomes an IP SAN, everything else is treated as a DNS name.
	for _, san := range config.Machine().Security().CertSANs() {
		if ip := stdlibnet.ParseIP(san); ip != nil {
			ips = append(ips, ip)
		} else {
			dnsNames = append(dnsNames, san)
		}
	}

	endpointList, err := endpoints.GetEndpoints()
	if err != nil {
		return nil, fmt.Errorf("failed to fetch initial endpoint list: %w", err)
	}

	// Keep the list sorted so refreshEndpoints can compare it against later
	// (also sorted) poll results.
	sort.Strings(endpointList)

	tlsConfig := &TLSConfig{
		endpoints:        endpoints,
		lastEndpointList: endpointList,
	}

	tlsConfig.generator, err = gen.NewRemoteGenerator(
		config.Machine().Security().Token(),
		endpointList,
	)
	if err != nil {
		return nil, fmt.Errorf("failed to create remote certificate generator: %w", err)
	}

	tlsConfig.certificateProvider, err = tls.NewRenewingCertificateProvider(
		tlsConfig.generator,
		dnsNames,
		ips,
	)
	if err != nil {
		return nil, fmt.Errorf("failed to create certificate provider: %w", err)
	}

	// Start the periodic refresh only once the provider is fully built, so
	// the goroutine never observes a partially initialized TLSConfig.
	go tlsConfig.refreshEndpoints()

	return tlsConfig, nil
}
// ServerConfig generates server-side tls.Config.
//
// The config requires mutual TLS, trusts the CA obtained from the certificate
// provider, and serves certificates from that same provider. An error is
// returned if the CA cannot be fetched or the config cannot be built.
func (tlsConfig *TLSConfig) ServerConfig() (*stdlibtls.Config, error) {
	provider := tlsConfig.certificateProvider

	caPEM, err := provider.GetCA()
	if err != nil {
		return nil, fmt.Errorf("failed to get root CA: %w", err)
	}

	mutualAuth := tls.WithClientAuthType(tls.Mutual)
	trustedCA := tls.WithCACertPEM(caPEM)
	serverCerts := tls.WithServerCertificateProvider(provider)

	return tls.New(mutualAuth, trustedCA, serverCerts)
}
// ClientConfig generates client-side tls.Config.
//
// The config uses mutual TLS, trusts the CA obtained from the certificate
// provider, and presents client certificates from that same provider. An
// error is returned if the CA cannot be fetched or the config cannot be built.
func (tlsConfig *TLSConfig) ClientConfig() (*stdlibtls.Config, error) {
	provider := tlsConfig.certificateProvider

	caPEM, err := provider.GetCA()
	if err != nil {
		return nil, fmt.Errorf("failed to get root CA: %w", err)
	}

	mutualAuth := tls.WithClientAuthType(tls.Mutual)
	trustedCA := tls.WithCACertPEM(caPEM)
	clientCerts := tls.WithClientCertificateProvider(provider)

	return tls.New(mutualAuth, trustedCA, clientCerts)
}
// refreshEndpoints polls the endpoint source on a fixed interval and pushes
// any change to the remote certificate generator.
//
// The method loops forever and is meant to run in its own goroutine. Failures
// to fetch or apply endpoints are logged and retried on the next tick; the
// previously applied list is kept in that case.
func (tlsConfig *TLSConfig) refreshEndpoints() {
	// poll at 1/20 of the default certificate validity time
	pollInterval := x509.DefaultCertificateValidityDuration / 20

	pollTicker := time.NewTicker(pollInterval)
	defer pollTicker.Stop()

	for range pollTicker.C {
		current, err := tlsConfig.endpoints.GetEndpoints()
		if err != nil {
			log.Printf("error refreshing endpoints: %s", err)

			continue
		}

		// Sort before comparing: the stored list is sorted as well, so the
		// comparison does not depend on the order endpoints are returned in.
		sort.Strings(current)

		unchanged := reflect.DeepEqual(tlsConfig.lastEndpointList, current)
		if unchanged {
			continue
		}

		if err = tlsConfig.generator.SetEndpoints(current); err != nil {
			log.Printf("error setting new endpoints %v: %s", current, err)

			continue
		}

		// Remember the list only after it was applied successfully, so a
		// failed update is attempted again on the next tick.
		tlsConfig.lastEndpointList = current

		log.Printf("updated control plane endpoints to %v", current)
	}
}