diff --git a/controller/execute.go b/controller/execute.go index a69a1a7ae..aa47c764b 100644 --- a/controller/execute.go +++ b/controller/execute.go @@ -426,10 +426,15 @@ func buildSource(ctx context.Context, cfg *externaldns.Config) (source.Source, e } // Combine multiple sources into a single, deduplicated source. combinedSource := wrappers.NewDedupSource(wrappers.NewMultiSource(sources, sourceCfg.DefaultTargets, sourceCfg.ForceDefaultTargets)) + cfg.AddSourceWrapper("dedup") + combinedSource = wrappers.NewNAT64Source(combinedSource, cfg.NAT64Networks) + cfg.AddSourceWrapper("nat64") // Filter targets targetFilter := endpoint.NewTargetNetFilterWithExclusions(cfg.TargetNetFilter, cfg.ExcludeTargetNets) - combinedSource = wrappers.NewNAT64Source(combinedSource, cfg.NAT64Networks) - combinedSource = wrappers.NewTargetFilterSource(combinedSource, targetFilter) + if targetFilter.IsEnabled() { + combinedSource = wrappers.NewTargetFilterSource(combinedSource, targetFilter) + cfg.AddSourceWrapper("target-filter") + } return combinedSource, nil } diff --git a/controller/execute_test.go b/controller/execute_test.go index 924c4c809..3b5e02e50 100644 --- a/controller/execute_test.go +++ b/controller/execute_test.go @@ -18,7 +18,6 @@ package controller import ( "bytes" - "context" "fmt" "net" "net/http" @@ -36,8 +35,8 @@ import ( "github.com/stretchr/testify/require" "sigs.k8s.io/external-dns/endpoint" "sigs.k8s.io/external-dns/pkg/apis/externaldns" - "sigs.k8s.io/external-dns/plan" "sigs.k8s.io/external-dns/provider" + fakeprovider "sigs.k8s.io/external-dns/provider/fakes" ) func TestSelectRegistry(t *testing.T) { @@ -60,7 +59,7 @@ func TestSelectRegistry(t *testing.T) { ExcludeDNSRecordTypes: []string{"TXT"}, TXTCacheInterval: 60, }, - provider: &MockProvider{}, + provider: &fakeprovider.MockProvider{}, wantErr: false, wantType: "DynamoDBRegistry", }, @@ -69,7 +68,7 @@ func TestSelectRegistry(t *testing.T) { cfg: &externaldns.Config{ Registry: "noop", }, - provider: &MockProvider{}, + provider: &fakeprovider.MockProvider{}, wantErr: false, wantType: "NoopRegistry", }, @@ -84,7 +83,7 @@ func TestSelectRegistry(t *testing.T) { ManagedDNSRecordTypes: []string{"A", "CNAME"}, ExcludeDNSRecordTypes: []string{"TXT"}, }, - provider: &MockProvider{}, + provider: &fakeprovider.MockProvider{}, wantErr: false, wantType: "TXTRegistry", }, @@ -94,7 +93,7 @@ func TestSelectRegistry(t *testing.T) { Registry: "aws-sd", TXTOwnerID: "owner-id", }, - provider: &MockProvider{}, + provider: &fakeprovider.MockProvider{}, wantErr: false, wantType: "AWSSDRegistry", }, @@ -103,7 +102,7 @@ func TestSelectRegistry(t *testing.T) { cfg: &externaldns.Config{ Registry: "unknown", }, - provider: &MockProvider{}, + provider: &fakeprovider.MockProvider{}, wantErr: true, wantType: "", }, @@ -477,21 +476,47 @@ func TestBuildSource(t *testing.T) { } } -// mocks -type MockProvider struct{} +func TestBuildSourceWithWrappers(t *testing.T) { + svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotImplemented) + })) + defer svr.Close() -func (m *MockProvider) Records(ctx context.Context) ([]*endpoint.Endpoint, error) { - return nil, nil -} + tests := []struct { + name string + cfg *externaldns.Config + asserts func(*externaldns.Config) + }{ + { + name: "configuration with target filter wrapper", + cfg: &externaldns.Config{ + APIServerURL: svr.URL, + Sources: []string{"fake"}, + TargetNetFilter: []string{"10.0.0.0/8"}, + }, + asserts: func(cfg *externaldns.Config) { + assert.True(t, cfg.IsSourceWrapperInstrumented("target-filter")) + }, + }, + { + name: "configuration without target filter wrapper", + cfg: &externaldns.Config{ + APIServerURL: svr.URL, + Sources: []string{"fake"}, + }, + asserts: func(cfg *externaldns.Config) { + assert.True(t, cfg.IsSourceWrapperInstrumented("dedup")) + assert.True(t, cfg.IsSourceWrapperInstrumented("nat64")) + assert.False(t, cfg.IsSourceWrapperInstrumented("target-filter")) + }, + }, + } -func (p *MockProvider) ApplyChanges(ctx context.Context, changes *plan.Changes) error { - return nil -} - -func (m *MockProvider) AdjustEndpoints(endpoints []*endpoint.Endpoint) ([]*endpoint.Endpoint, error) { - return nil, nil -} - -func (m *MockProvider) GetDomainFilter() endpoint.DomainFilterInterface { - return nil + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := buildSource(t.Context(), tt.cfg) + require.NoError(t, err) + tt.asserts(tt.cfg) + }) + } } diff --git a/docs/contributing/source-wrappers.md b/docs/contributing/source-wrappers.md new file mode 100644 index 000000000..dd9583ea6 --- /dev/null +++ b/docs/contributing/source-wrappers.md @@ -0,0 +1,149 @@ +# ๐Ÿงฉ Source Wrappers/Middleware + +## Overview + +In ExternalDNS, a **Source** is a component responsible for discovering DNS records from Kubernetes resources (e.g., `Ingress`, `Service`, `Gateway`, etc.). + +**Source Wrappers** are middleware-like components that sit between the source and the plan generation. They extend or modify the behavior of the original sources by transforming, filtering, or enriching the DNS records before they're processed by the planner and provider. + +--- + +## Why Wrappers? + +Wrappers solve these key challenges: + +- โœ‚๏ธ **Filtering**: Remove unwanted targets or records from sources based on labels, annotations, targets and etc. +- ๐Ÿ”— **Aggregation**: Combine Endpoints from multiple underlying sources. For example, from both Kubernetes Services and Ingresses. +- ๐Ÿงน **Deduplication**: Prevent duplicate DNS records across sources. +- ๐ŸŒ **Target transformation**: Rewrite targets for IPv6 networks or alter endpoint attributes like FQDNS or targets. +- ๐Ÿงช **Testing and simulation**: Use the `FakeSource` or wrappers for dry-runs or simulations. +- ๐Ÿ” **Composability**: Chain multiple behaviors without modifying core sources. +- ๐Ÿ” **Access Control**: Limits endpoint exposure based on policies or user access. +- ๐Ÿ“Š **Observability**: Adds logging, debugging, or metrics around source behavior. + +--- + +## Built In Wrappers + +| Wrapper | Purpose | Use Case | +|:--------------------:|:----------------------------------------|:--------------------------------------| +| `MultiSource` | Combine multiple sources. | Aggregate `Ingress`, `Service`, etc. | +| `DedupSource` | Remove duplicate DNS records. | Avoid duplicate records from sources. | +| `TargetFilterSource` | Include/exclude targets based on CIDRs. | Exclude internal IPs. | +| `NAT64Source` | Add NAT64-prefixed AAAA records. | Support IPv6 with NAT64. | + +### Use Cases + +### 1.1 `TargetFilterSource` + +Filters targets (e.g. IPs or hostnames) based on inclusion or exclusion rules. + +๐Ÿ“Œ **Use case**: Only publish public IPs, exclude test environments. + +```yaml +--target-net-filter=192.168.0.0/16 +--exclude-target-nets=10.0.0.0/8 +``` + +### 2.1 `NAT64Source` + +Converts IPv4 targets to IPv6 using NAT64 prefixes. + +๐Ÿ“Œ **Use case**: Publish AAAA records for IPv6-only clients in NAT64 environments. + +```yaml +--nat64-prefix=64:ff9b::/96 +``` + +--- + +## How Wrappers Work + +Wrappers wrap a `Source` and implement the same `Source` interface (e.g., `Endpoints(ctx)`). + +They typically follow this pattern: + +```go +package wrappers + +type myWrapper struct { + next source.Source +} + +func (m *myWrapper) Endpoints(ctx context.Context) ([]*endpoint.Endpoint, error) { + eps, err := m.next.Endpoints(ctx) + if err != nil { + return nil, err + } + + // Modify, filter, or enrich endpoints as needed + return eps, nil +} + +// AddEventHandler must be implemented to satisfy the source.Source interface. +func (m *myWrapper) AddEventHandler(ctx context.Context, handler func()) { + log.Debugf("myWrapper: adding event handler") + m.next.AddEventHandler(ctx, handler) +} +``` + +This allows wrappers to be stacked or composed together. + +--- + +### Composition of Wrappers + +Wrappers are often composed like this: + +```go +source := NewMultiSource(actualSources, defaultTargets) +source = NewDedupSource(source) +source = NewNAT64Source(source, cfg.NAT64Networks) +source = NewTargetFilterSource(source, targetFilter) +``` + +Each wrapper processes the output of the previous one. + +--- + +## High Level Design + +- Source: Implements the base logic for extracting DNS endpoints (e.g. IngressSource, ServiceSource, etc.) +- Wrappers: Decorate the source (e.g. DedupSource, TargetFilterSource) to enhance or filter endpoint data +- Plan: Compares the endpoints from Source with DNS state from Provider and produces create/update/delete changes +- Provider: Applies changes to actual DNS services (e.g. Route53, Cloudflare, Azure DNS) + +```mermaid +sequenceDiagram + participant ExternalDNS + participant Source + participant Wrapper + participant DedupWrapper as DedupSource + participant Provider + participant Plan + + ExternalDNS->>Source: Initialize source (e.g. Ingress, Service) + Source-->>ExternalDNS: Implements Source interface + + ExternalDNS->>Wrapper: Wrap with decorators (e.g. dedup, filters) + Wrapper->>DedupWrapper: Compose with DedupSource + DedupWrapper-->>Wrapper: Return enriched Source + + Wrapper-->>ExternalDNS: Return final wrapped Source + + ExternalDNS->>Plan: Generate plan from Source + Plan->>Wrapper: Call Endpoints(ctx) + Wrapper->>DedupWrapper: Call Endpoints(ctx) + DedupWrapper->>Source: Call Endpoints(ctx) + Source-->>DedupWrapper: Return []*Endpoint + DedupWrapper-->>Wrapper: Return de-duplicated []*Endpoint + Wrapper-->>Plan: Return transformed []*Endpoint + + ExternalDNS->>Provider: ApplyChanges(plan) + Provider-->>ExternalDNS: Sync DNS records +``` + +## Learn More + +- [Source Interface](https://github.com/kubernetes-sigs/external-dns/blob/master/source/source.go) +- [Wrappers Source Code](https://github.com/kubernetes-sigs/external-dns/tree/master/source/wrappers) diff --git a/internal/testutils/mock_source.go b/internal/testutils/mock_source.go index 4819e2113..99644cb08 100644 --- a/internal/testutils/mock_source.go +++ b/internal/testutils/mock_source.go @@ -28,10 +28,20 @@ import ( // MockSource returns mock endpoints. type MockSource struct { mock.Mock + endpoints []*endpoint.Endpoint +} + +func NewMockSource(endpoints ...*endpoint.Endpoint) *MockSource { + m := &MockSource{ + endpoints: endpoints, + } + m.On("Endpoints").Return(endpoints, nil) + m.On("AddEventHandler", mock.AnythingOfType("*context.cancelCtx")).Return() + return m } // Endpoints returns the desired mock endpoints. -func (m *MockSource) Endpoints(ctx context.Context) ([]*endpoint.Endpoint, error) { +func (m *MockSource) Endpoints(_ context.Context) ([]*endpoint.Endpoint, error) { args := m.Called() endpoints := args.Get(0) @@ -44,6 +54,10 @@ func (m *MockSource) Endpoints(ctx context.Context) ([]*endpoint.Endpoint, error // AddEventHandler adds an event handler that should be triggered if something in source changes func (m *MockSource) AddEventHandler(ctx context.Context, handler func()) { + m.Called(ctx) + if handler == nil { + return + } go func() { ticker := time.NewTicker(5 * time.Second) defer ticker.Stop() diff --git a/pkg/apis/externaldns/types.go b/pkg/apis/externaldns/types.go index c4df4e35f..83c81a558 100644 --- a/pkg/apis/externaldns/types.go +++ b/pkg/apis/externaldns/types.go @@ -213,6 +213,7 @@ type Config struct { NAT64Networks []string ExcludeUnschedulable bool ForceDefaultTargets bool + sourceWrappers map[string]bool // map of source wrappers, e.g. "targetfilter", "nat64" } var defaultConfig = &Config{ @@ -376,6 +377,7 @@ var defaultConfig = &Config{ WebhookServer: false, ZoneIDFilter: []string{}, ForceDefaultTargets: false, + sourceWrappers: map[string]bool{}, } // NewConfig returns new Config object @@ -427,6 +429,22 @@ func (cfg *Config) ParseFlags(args []string) error { return nil } +func (cfg *Config) AddSourceWrapper(name string) { + if cfg.sourceWrappers == nil { + cfg.sourceWrappers = make(map[string]bool) + } + cfg.sourceWrappers[name] = true +} + +// IsSourceWrapperInstrumented returns whether a source wrapper is enabled or not. +func (cfg *Config) IsSourceWrapperInstrumented(name string) bool { + if cfg.sourceWrappers == nil { + return false + } + _, ok := cfg.sourceWrappers[name] + return ok +} + func App(cfg *Config) *kingpin.Application { app := kingpin.New("external-dns", "ExternalDNS synchronizes exposed Kubernetes Services and Ingresses with DNS providers.\n\nNote that all flags may be replaced with env vars - `--flag` -> `EXTERNAL_DNS_FLAG=1` or `--flag value` -> `EXTERNAL_DNS_FLAG=value`") app.Version(Version) diff --git a/provider/fakes/provider.go b/provider/fakes/provider.go new file mode 100644 index 000000000..6f6d18304 --- /dev/null +++ b/provider/fakes/provider.go @@ -0,0 +1,42 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package fakes + +import ( + "context" + + "sigs.k8s.io/external-dns/endpoint" + "sigs.k8s.io/external-dns/plan" +) + +type MockProvider struct{} + +func (m *MockProvider) Records(_ context.Context) ([]*endpoint.Endpoint, error) { + return nil, nil +} + +func (m *MockProvider) ApplyChanges(_ context.Context, _ *plan.Changes) error { + return nil +} + +func (m *MockProvider) AdjustEndpoints(_ []*endpoint.Endpoint) ([]*endpoint.Endpoint, error) { + return nil, nil +} + +func (m *MockProvider) GetDomainFilter() endpoint.DomainFilterInterface { + return nil +} diff --git a/source/wrappers/dedupsource.go b/source/wrappers/dedupsource.go index 665799b44..961ef45cf 100644 --- a/source/wrappers/dedupsource.go +++ b/source/wrappers/dedupsource.go @@ -22,9 +22,8 @@ import ( log "github.com/sirupsen/logrus" - "sigs.k8s.io/external-dns/source" - "sigs.k8s.io/external-dns/endpoint" + "sigs.k8s.io/external-dns/source" ) // dedupSource is a Source that removes duplicate endpoints from its wrapped source. @@ -39,6 +38,7 @@ func NewDedupSource(source source.Source) source.Source { // Endpoints collects endpoints from its wrapped source and returns them without duplicates. func (ms *dedupSource) Endpoints(ctx context.Context) ([]*endpoint.Endpoint, error) { + log.Debug("dedupSource: collecting endpoints and removing duplicates") result := []*endpoint.Endpoint{} collected := map[string]bool{} @@ -67,5 +67,6 @@ func (ms *dedupSource) Endpoints(ctx context.Context) ([]*endpoint.Endpoint, err } func (ms *dedupSource) AddEventHandler(ctx context.Context, handler func()) { + log.Debug("dedupSource: adding event handler") ms.source.AddEventHandler(ctx, handler) } diff --git a/source/wrappers/dedupsource_test.go b/source/wrappers/dedupsource_test.go index 07a13f40b..3f9ac58ba 100644 --- a/source/wrappers/dedupsource_test.go +++ b/source/wrappers/dedupsource_test.go @@ -144,3 +144,27 @@ func testDedupEndpoints(t *testing.T) { }) } } + +func TestDedupSource_AddEventHandler(t *testing.T) { + tests := []struct { + title string + input []string + times int + }{ + { + title: "should add event handler", + times: 1, + }, + } + + for _, tt := range tests { + t.Run(tt.title, func(t *testing.T) { + mockSource := testutils.NewMockSource() + + src := NewDedupSource(mockSource) + src.AddEventHandler(t.Context(), func() {}) + + mockSource.AssertNumberOfCalls(t, "AddEventHandler", tt.times) + }) + } +} diff --git a/source/wrappers/multisource.go b/source/wrappers/multisource.go index 60c287716..60e197667 100644 --- a/source/wrappers/multisource.go +++ b/source/wrappers/multisource.go @@ -18,12 +18,13 @@ package wrappers import ( "context" + "reflect" "strings" + log "github.com/sirupsen/logrus" + "sigs.k8s.io/external-dns/endpoint" "sigs.k8s.io/external-dns/source" - - log "github.com/sirupsen/logrus" ) // multiSource is a Source that merges the endpoints of its nested Sources. @@ -35,6 +36,7 @@ type multiSource struct { // Endpoints collects endpoints of all nested Sources and returns them in a single slice. func (ms *multiSource) Endpoints(ctx context.Context) ([]*endpoint.Endpoint, error) { + log.Debugf("multiSource: collecting endpoints from %d child sources and removing duplicates", len(ms.children)) result := []*endpoint.Endpoint{} hasDefaultTargets := len(ms.defaultTargets) > 0 @@ -70,7 +72,9 @@ func (ms *multiSource) Endpoints(ctx context.Context) ([]*endpoint.Endpoint, err } func (ms *multiSource) AddEventHandler(ctx context.Context, handler func()) { + log.Debugf("multiSource: adding event handler for %d child sources", len(ms.children)) for _, s := range ms.children { + log.Debugf("multiSource: adding event handler for child %q", reflect.TypeOf(s).String()) s.AddEventHandler(ctx, handler) } } diff --git a/source/wrappers/multisource_test.go b/source/wrappers/multisource_test.go index 06626a08c..ea45093f8 100644 --- a/source/wrappers/multisource_test.go +++ b/source/wrappers/multisource_test.go @@ -269,3 +269,43 @@ func testMultiSourceEndpointsDefaultTargets(t *testing.T) { src.AssertExpectations(t) }) } + +func TestMultiSource_AddEventHandler(t *testing.T) { + tests := []struct { + title string + sources []source.Source + times int + }{ + { + title: "should not add event handler when sources are empty", + sources: []source.Source{}, + times: 0, + }, + { + title: "should add event handler when sources not empty", + sources: []source.Source{ + testutils.NewMockSource(), + testutils.NewMockSource(), + testutils.NewMockSource(), + }, + times: 3, + }, + } + + for _, tt := range tests { + t.Run(tt.title, func(t *testing.T) { + src := NewMultiSource(tt.sources, []string{}, true) + src.AddEventHandler(t.Context(), func() {}) + + count := 0 + + for _, mockSource := range tt.sources { + mSource := mockSource.(*testutils.MockSource) + mSource.AssertNumberOfCalls(t, "AddEventHandler", 1) + count += 1 + } + + assert.Equal(t, tt.times, count) + }) + } +} diff --git a/source/wrappers/nat64source.go b/source/wrappers/nat64source.go index 383bdde5c..d1b13f951 100644 --- a/source/wrappers/nat64source.go +++ b/source/wrappers/nat64source.go @@ -21,6 +21,8 @@ import ( "fmt" "net/netip" + log "github.com/sirupsen/logrus" + "sigs.k8s.io/external-dns/endpoint" "sigs.k8s.io/external-dns/source" ) @@ -38,6 +40,7 @@ func NewNAT64Source(source source.Source, nat64Prefixes []string) source.Source // Endpoints collects endpoints from its wrapped source and returns them without duplicates. func (s *nat64Source) Endpoints(ctx context.Context) ([]*endpoint.Endpoint, error) { + log.Debug("nat64Source: collecting endpoints and processing NAT64 translation") parsedNAT64Prefixes := make([]netip.Prefix, 0) for _, prefix := range s.nat64Prefixes { pPrefix, err := netip.ParsePrefix(prefix) @@ -109,5 +112,6 @@ func (s *nat64Source) Endpoints(ctx context.Context) ([]*endpoint.Endpoint, erro } func (s *nat64Source) AddEventHandler(ctx context.Context, handler func()) { + log.Debug("nat64Source: adding event handler") s.source.AddEventHandler(ctx, handler) } diff --git a/source/wrappers/nat64source_test.go b/source/wrappers/nat64source_test.go index 3a0d15f12..2401d3f36 100644 --- a/source/wrappers/nat64source_test.go +++ b/source/wrappers/nat64source_test.go @@ -89,3 +89,33 @@ func testNat64Source(t *testing.T) { }) } } + +func TestNat64Source_AddEventHandler(t *testing.T) { + tests := []struct { + title string + input []string + times int + }{ + { + title: "should add event handler when prefixes are provided", + input: []string{"2001:DB8::/96"}, + times: 1, + }, + { + title: "should add event handler when prefixes not provided", + input: []string{}, + times: 1, + }, + } + + for _, tt := range tests { + t.Run(tt.title, func(t *testing.T) { + mockSource := testutils.NewMockSource() + + src := NewNAT64Source(mockSource, tt.input) + src.AddEventHandler(t.Context(), func() {}) + + mockSource.AssertNumberOfCalls(t, "AddEventHandler", tt.times) + }) + } +} diff --git a/source/wrappers/targetfiltersource.go b/source/wrappers/targetfiltersource.go index afc654d90..7cbbaa8ce 100644 --- a/source/wrappers/targetfiltersource.go +++ b/source/wrappers/targetfiltersource.go @@ -21,9 +21,8 @@ import ( log "github.com/sirupsen/logrus" - "sigs.k8s.io/external-dns/source" - "sigs.k8s.io/external-dns/endpoint" + "sigs.k8s.io/external-dns/source" ) // targetFilterSource is a Source that removes endpoints matching the target filter from its wrapped source. @@ -40,6 +39,7 @@ func NewTargetFilterSource(source source.Source, targetFilter endpoint.TargetFil // Endpoints collects endpoints from its wrapped source and returns // them without targets matching the target filter. func (ms *targetFilterSource) Endpoints(ctx context.Context) ([]*endpoint.Endpoint, error) { + log.Debug("targetFilterSource: collecting endpoints from wrapped source and applying target filter") endpoints, err := ms.source.Endpoints(ctx) if err != nil { return nil, err @@ -75,7 +75,6 @@ func (ms *targetFilterSource) Endpoints(ctx context.Context) ([]*endpoint.Endpoi } func (ms *targetFilterSource) AddEventHandler(ctx context.Context, handler func()) { - if ms.targetFilter.IsEnabled() { - ms.source.AddEventHandler(ctx, handler) - } + log.Debug("targetFilterSource: adding event handler") + ms.source.AddEventHandler(ctx, handler) } diff --git a/source/wrappers/targetfiltersource_test.go b/source/wrappers/targetfiltersource_test.go index e0e01d745..7298c1e0a 100644 --- a/source/wrappers/targetfiltersource_test.go +++ b/source/wrappers/targetfiltersource_test.go @@ -19,9 +19,10 @@ package wrappers import ( "testing" - "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" "golang.org/x/net/context" + + "sigs.k8s.io/external-dns/internal/testutils" "sigs.k8s.io/external-dns/source" "sigs.k8s.io/external-dns/endpoint" @@ -47,26 +48,6 @@ func (m *mockTargetNetFilter) IsEnabled() bool { return true } -// echoSource is a Source that returns the endpoints passed in on creation. -type echoSource struct { - mock.Mock - endpoints []*endpoint.Endpoint -} - -func (e *echoSource) AddEventHandler(ctx context.Context, handler func()) { - e.Called(ctx) -} - -// Endpoints returns all the endpoints passed in on creation -func (e *echoSource) Endpoints(ctx context.Context) ([]*endpoint.Endpoint, error) { - return e.endpoints, nil -} - -// NewEchoSource creates a new echoSource. -func NewEchoSource(endpoints []*endpoint.Endpoint) source.Source { - return &echoSource{endpoints: endpoints} -} - func TestEchoSourceReturnGivenSources(t *testing.T) { startEndpoints := []*endpoint.Endpoint{{ DNSName: "foo.bar.com", @@ -75,7 +56,7 @@ func TestEchoSourceReturnGivenSources(t *testing.T) { RecordTTL: endpoint.TTL(300), Labels: endpoint.Labels{}, }} - e := NewEchoSource(startEndpoints) + e := testutils.NewMockSource(startEndpoints...) endpoints, err := e.Endpoints(context.Background()) if err != nil { @@ -143,7 +124,7 @@ func TestTargetFilterSourceEndpoints(t *testing.T) { t.Run(tt.title, func(t *testing.T) { t.Parallel() - echo := NewEchoSource(tt.endpoints) + echo := testutils.NewMockSource(tt.endpoints...) src := NewTargetFilterSource(echo, tt.filters) endpoints, err := src.Endpoints(context.Background()) @@ -225,7 +206,7 @@ func TestTargetFilterConcreteTargetFilter(t *testing.T) { } for _, tt := range tests { t.Run(tt.title, func(t *testing.T) { - echo := NewEchoSource(tt.endpoints) + echo := testutils.NewMockSource(tt.endpoints...) src := NewTargetFilterSource(echo, tt.filters) endpoints, err := src.Endpoints(context.Background()) @@ -248,20 +229,16 @@ func TestTargetFilterSource_AddEventHandler(t *testing.T) { times: 1, }, { - title: "should not add event handler if target filter is disabled", + title: "should add event handler if target filter is disabled", filters: endpoint.NewTargetNetFilterWithExclusions([]string{}, []string{}), - times: 0, + times: 1, }, } for _, tt := range tests { t.Run(tt.title, func(t *testing.T) { - echo := NewEchoSource([]*endpoint.Endpoint{}) - - m := echo.(*echoSource) - m.On("AddEventHandler", t.Context()).Return() - - src := NewTargetFilterSource(echo, tt.filters) + m := testutils.NewMockSource() + src := NewTargetFilterSource(m, tt.filters) src.AddEventHandler(t.Context(), func() {}) m.AssertNumberOfCalls(t, "AddEventHandler", tt.times)