Add Domain filter interface

This commit is contained in:
Thibault Jamet 2023-09-08 16:21:54 +03:00
parent 82c6983fa3
commit b2ff1619f5
No known key found for this signature in database
GPG Key ID: 9D28A304A3810C17
14 changed files with 30 additions and 24 deletions

View File

@ -186,7 +186,7 @@ type Controller struct {
// The interval between individual synchronizations // The interval between individual synchronizations
Interval time.Duration Interval time.Duration
// The DomainFilter defines which DNS records to keep or exclude // The DomainFilter defines which DNS records to keep or exclude
DomainFilter endpoint.DomainFilter DomainFilter endpoint.DomainFilterInterface
// The nextRunAt used for throttling and batching reconciliation // The nextRunAt used for throttling and batching reconciliation
nextRunAt time.Time nextRunAt time.Time
// The runAtMutex is for atomic updating of nextRunAt and lastRunAt // The runAtMutex is for atomic updating of nextRunAt and lastRunAt
@ -245,7 +245,7 @@ func (c *Controller) RunOnce(ctx context.Context) error {
Policies: []plan.Policy{c.Policy}, Policies: []plan.Policy{c.Policy},
Current: records, Current: records,
Desired: endpoints, Desired: endpoints,
DomainFilter: endpoint.MatchAllDomainFilters{&c.DomainFilter, &registryFilter}, DomainFilter: endpoint.MatchAllDomainFilters{c.DomainFilter, registryFilter},
ManagedRecords: c.ManagedRecordTypes, ManagedRecords: c.ManagedRecordTypes,
ExcludeRecords: c.ExcludeRecordTypes, ExcludeRecords: c.ExcludeRecordTypes,
OwnerID: c.Registry.OwnerID(), OwnerID: c.Registry.OwnerID(),

View File

@ -57,7 +57,7 @@ type errorMockProvider struct {
mockProvider mockProvider
} }
func (p *filteredMockProvider) GetDomainFilter() endpoint.DomainFilter { func (p *filteredMockProvider) GetDomainFilter() endpoint.DomainFilterInterface {
return p.domainFilter return p.domainFilter
} }

View File

@ -25,7 +25,7 @@ import (
"strings" "strings"
) )
type MatchAllDomainFilters []*DomainFilter type MatchAllDomainFilters []DomainFilterInterface
func (f MatchAllDomainFilters) Match(domain string) bool { func (f MatchAllDomainFilters) Match(domain string) bool {
for _, filter := range f { for _, filter := range f {
@ -39,6 +39,10 @@ func (f MatchAllDomainFilters) Match(domain string) bool {
return true return true
} }
type DomainFilterInterface interface {
Match(domain string) bool
}
// DomainFilter holds a lists of valid domain names // DomainFilter holds a lists of valid domain names
type DomainFilter struct { type DomainFilter struct {
// Filters define what domains to match // Filters define what domains to match
@ -51,6 +55,8 @@ type DomainFilter struct {
regexExclusion *regexp.Regexp regexExclusion *regexp.Regexp
} }
var _ DomainFilterInterface = &DomainFilter{}
// domainFilterSerde is a helper type for serializing and deserializing DomainFilter. // domainFilterSerde is a helper type for serializing and deserializing DomainFilter.
type domainFilterSerde struct { type domainFilterSerde struct {
Include []string `json:"include,omitempty"` Include []string `json:"include,omitempty"`

View File

@ -67,7 +67,7 @@ type Config struct {
AlwaysPublishNotReadyAddresses bool AlwaysPublishNotReadyAddresses bool
ConnectorSourceServer string ConnectorSourceServer string
Provider string Provider string
ProviderCacheTime int ProviderCacheTime time.Duration
GoogleProject string GoogleProject string
GoogleBatchChangeSize int GoogleBatchChangeSize int
GoogleBatchChangeInterval time.Duration GoogleBatchChangeInterval time.Duration

View File

@ -567,7 +567,7 @@ func (p *AWSProvider) createUpdateChanges(newEndpoints, oldEndpoints []*endpoint
} }
// GetDomainFilter generates a filter to exclude any domain that is not controlled by the provider // GetDomainFilter generates a filter to exclude any domain that is not controlled by the provider
func (p *AWSProvider) GetDomainFilter() endpoint.DomainFilter { func (p *AWSProvider) GetDomainFilter() endpoint.DomainFilterInterface {
zones, err := p.Zones(context.Background()) zones, err := p.Zones(context.Background())
if err != nil { if err != nil {
log.Errorf("failed to list zones: %v", err) log.Errorf("failed to list zones: %v", err)

View File

@ -319,10 +319,10 @@ func TestAWSZones(t *testing.T) {
func TestAWSRecordsFilter(t *testing.T) { func TestAWSRecordsFilter(t *testing.T) {
provider, _ := newAWSProvider(t, endpoint.DomainFilter{}, provider.ZoneIDFilter{}, provider.ZoneTypeFilter{}, false, false, nil) provider, _ := newAWSProvider(t, endpoint.DomainFilter{}, provider.ZoneIDFilter{}, provider.ZoneTypeFilter{}, false, false, nil)
domainFilter := provider.GetDomainFilter() domainFilter := provider.GetDomainFilter()
assert.NotNil(t, domainFilter) require.NotNil(t, domainFilter)
require.IsType(t, endpoint.DomainFilter{}, domainFilter) require.IsType(t, endpoint.DomainFilter{}, domainFilter)
count := 0 count := 0
filters := domainFilter.Filters filters := domainFilter.(endpoint.DomainFilter).Filters
for _, tld := range []string{ for _, tld := range []string{
"zone-4.ext-dns-test-3.teapot.zalan.do", "zone-4.ext-dns-test-3.teapot.zalan.do",
".zone-4.ext-dns-test-3.teapot.zalan.do", ".zone-4.ext-dns-test-3.teapot.zalan.do",

View File

@ -50,7 +50,6 @@ var (
type CachedProvider struct { type CachedProvider struct {
Provider Provider
RefreshDelay time.Duration RefreshDelay time.Duration
err error
lastRead time.Time lastRead time.Time
cache []*endpoint.Endpoint cache []*endpoint.Endpoint
} }
@ -58,17 +57,19 @@ type CachedProvider struct {
func (c *CachedProvider) Records(ctx context.Context) ([]*endpoint.Endpoint, error) { func (c *CachedProvider) Records(ctx context.Context) ([]*endpoint.Endpoint, error) {
if c.needRefresh() { if c.needRefresh() {
log.Info("Records cache provider: refreshing records list cache") log.Info("Records cache provider: refreshing records list cache")
c.cache, c.err = c.Provider.Records(ctx) records, err := c.Provider.Records(ctx)
if c.err != nil { if err != nil {
log.Errorf("Records cache provider: list records failed: %v", c.err) c.cache = nil
return nil, err
} }
c.cache = records
c.lastRead = time.Now() c.lastRead = time.Now()
cachedRecordsCallsTotal.WithLabelValues("false").Inc() cachedRecordsCallsTotal.WithLabelValues("false").Inc()
} else { } else {
log.Info("Records cache provider: using records list from cache") log.Debug("Records cache provider: using records list from cache")
cachedRecordsCallsTotal.WithLabelValues("true").Inc() cachedRecordsCallsTotal.WithLabelValues("true").Inc()
} }
return c.cache, c.err return c.cache, nil
} }
func (c *CachedProvider) ApplyChanges(ctx context.Context, changes *plan.Changes) error { func (c *CachedProvider) ApplyChanges(ctx context.Context, changes *plan.Changes) error {
if !changes.HasChanges() { if !changes.HasChanges() {
@ -81,13 +82,12 @@ func (c *CachedProvider) ApplyChanges(ctx context.Context, changes *plan.Changes
} }
func (c *CachedProvider) Reset() { func (c *CachedProvider) Reset() {
c.err = nil
c.cache = nil c.cache = nil
c.lastRead = time.Time{} c.lastRead = time.Time{}
} }
func (c *CachedProvider) needRefresh() bool { func (c *CachedProvider) needRefresh() bool {
if c.cache == nil || c.err != nil { if c.cache == nil {
log.Debug("Records cache provider is not initialized") log.Debug("Records cache provider is not initialized")
return true return true
} }

View File

@ -46,7 +46,7 @@ var (
// initialized as dns provider with no records // initialized as dns provider with no records
type InMemoryProvider struct { type InMemoryProvider struct {
provider.BaseProvider provider.BaseProvider
domain endpoint.DomainFilter domain endpoint.DomainFilterInterface
client *inMemoryClient client *inMemoryClient
filter *filter filter *filter
OnApplyChanges func(ctx context.Context, changes *plan.Changes) OnApplyChanges func(ctx context.Context, changes *plan.Changes)

View File

@ -48,7 +48,7 @@ type Provider interface {
// unnecessary (potentially failing) changes. It may also modify other fields, add, or remove // unnecessary (potentially failing) changes. It may also modify other fields, add, or remove
// Endpoints. It is permitted to modify the supplied endpoints. // Endpoints. It is permitted to modify the supplied endpoints.
AdjustEndpoints(endpoints []*endpoint.Endpoint) ([]*endpoint.Endpoint, error) AdjustEndpoints(endpoints []*endpoint.Endpoint) ([]*endpoint.Endpoint, error)
GetDomainFilter() endpoint.DomainFilter GetDomainFilter() endpoint.DomainFilterInterface
} }
type BaseProvider struct{} type BaseProvider struct{}
@ -57,7 +57,7 @@ func (b BaseProvider) AdjustEndpoints(endpoints []*endpoint.Endpoint) ([]*endpoi
return endpoints, nil return endpoints, nil
} }
func (b BaseProvider) GetDomainFilter() endpoint.DomainFilter { func (b BaseProvider) GetDomainFilter() endpoint.DomainFilterInterface {
return endpoint.DomainFilter{} return endpoint.DomainFilter{}
} }

View File

@ -42,7 +42,7 @@ func NewAWSSDRegistry(provider provider.Provider, ownerID string) (*AWSSDRegistr
}, nil }, nil
} }
func (sdr *AWSSDRegistry) GetDomainFilter() endpoint.DomainFilter { func (sdr *AWSSDRegistry) GetDomainFilter() endpoint.DomainFilterInterface {
return sdr.provider.GetDomainFilter() return sdr.provider.GetDomainFilter()
} }

View File

@ -105,7 +105,7 @@ func NewDynamoDBRegistry(provider provider.Provider, ownerID string, dynamodbAPI
}, nil }, nil
} }
func (im *DynamoDBRegistry) GetDomainFilter() endpoint.DomainFilter { func (im *DynamoDBRegistry) GetDomainFilter() endpoint.DomainFilterInterface {
return im.provider.GetDomainFilter() return im.provider.GetDomainFilter()
} }

View File

@ -36,7 +36,7 @@ func NewNoopRegistry(provider provider.Provider) (*NoopRegistry, error) {
}, nil }, nil
} }
func (im *NoopRegistry) GetDomainFilter() endpoint.DomainFilter { func (im *NoopRegistry) GetDomainFilter() endpoint.DomainFilterInterface {
return im.provider.GetDomainFilter() return im.provider.GetDomainFilter()
} }

View File

@ -31,6 +31,6 @@ type Registry interface {
Records(ctx context.Context) ([]*endpoint.Endpoint, error) Records(ctx context.Context) ([]*endpoint.Endpoint, error)
ApplyChanges(ctx context.Context, changes *plan.Changes) error ApplyChanges(ctx context.Context, changes *plan.Changes) error
AdjustEndpoints(endpoints []*endpoint.Endpoint) ([]*endpoint.Endpoint, error) AdjustEndpoints(endpoints []*endpoint.Endpoint) ([]*endpoint.Endpoint, error)
GetDomainFilter() endpoint.DomainFilter GetDomainFilter() endpoint.DomainFilterInterface
OwnerID() string OwnerID() string
} }

View File

@ -95,7 +95,7 @@ func getSupportedTypes() []string {
return []string{endpoint.RecordTypeA, endpoint.RecordTypeAAAA, endpoint.RecordTypeCNAME, endpoint.RecordTypeNS} return []string{endpoint.RecordTypeA, endpoint.RecordTypeAAAA, endpoint.RecordTypeCNAME, endpoint.RecordTypeNS}
} }
func (im *TXTRegistry) GetDomainFilter() endpoint.DomainFilter { func (im *TXTRegistry) GetDomainFilter() endpoint.DomainFilterInterface {
return im.provider.GetDomainFilter() return im.provider.GetDomainFilter()
} }