From 28d0ff93168d9297f4ec2523ec5f4934b9b16637 Mon Sep 17 00:00:00 2001 From: Ivan Ka <5395690+ivankatliarchuk@users.noreply.github.com> Date: Fri, 11 Jul 2025 17:47:51 +0100 Subject: [PATCH] chore(source/net-filter): improve flow logic and add more tests (#5629) Signed-off-by: ivan katliarchuk --- endpoint/domain_filter_test.go | 6 + endpoint/target_filter.go | 20 +-- endpoint/target_filter_test.go | 33 ++++- source/wrappers/targetfiltersource.go | 20 ++- source/wrappers/targetfiltersource_test.go | 155 ++++++++++++++++++--- 5 files changed, 199 insertions(+), 35 deletions(-) diff --git a/endpoint/domain_filter_test.go b/endpoint/domain_filter_test.go index 29dfdfea6..328743464 100644 --- a/endpoint/domain_filter_test.go +++ b/endpoint/domain_filter_test.go @@ -949,3 +949,9 @@ func TestDomainFilterNormalizeDomain(t *testing.T) { assert.Equal(t, r.expect, gotName) } } + +func TestMatchTargetFilterReturnsProperEmptyVal(t *testing.T) { + var emptyFilters []string + assert.True(t, matchFilter(emptyFilters, "sometarget.com", true)) + assert.False(t, matchFilter(emptyFilters, "sometarget.com", false)) +} diff --git a/endpoint/target_filter.go b/endpoint/target_filter.go index e4e69957f..2706155e9 100644 --- a/endpoint/target_filter.go +++ b/endpoint/target_filter.go @@ -26,12 +26,13 @@ import ( // TargetFilterInterface defines the interface to select matching targets for a specific provider or runtime type TargetFilterInterface interface { Match(target string) bool + IsEnabled() bool } // TargetNetFilter holds a lists of valid target names type TargetNetFilter struct { - // FilterNets define what targets to match - FilterNets []*net.IPNet + // filterNets define what targets to match + filterNets []*net.IPNet // excludeNets define what targets not to match excludeNets []*net.IPNet } @@ -42,11 +43,9 @@ func prepareTargetFilters(filters []string) []*net.IPNet { for _, filter := range filters { filter = strings.TrimSpace(filter) - _, filterNet, err := net.ParseCIDR(filter) if err != nil { log.Errorf("Invalid target net filter: %s", filter) - continue } @@ -57,12 +56,17 @@ func prepareTargetFilters(filters []string) []*net.IPNet { // NewTargetNetFilterWithExclusions returns a new TargetNetFilter, given a list of matches and exclusions func NewTargetNetFilterWithExclusions(targetFilterNets []string, excludeNets []string) TargetNetFilter { - return TargetNetFilter{FilterNets: prepareTargetFilters(targetFilterNets), excludeNets: prepareTargetFilters(excludeNets)} + return TargetNetFilter{filterNets: prepareTargetFilters(targetFilterNets), excludeNets: prepareTargetFilters(excludeNets)} } // Match checks whether a target can be found in the TargetNetFilter. func (tf TargetNetFilter) Match(target string) bool { - return matchTargetNetFilter(tf.FilterNets, target, true) && !matchTargetNetFilter(tf.excludeNets, target, false) + return matchTargetNetFilter(tf.filterNets, target, true) && !matchTargetNetFilter(tf.excludeNets, target, false) +} + +// IsEnabled returns true if any filters or exclusions are set. +func (tf TargetNetFilter) IsEnabled() bool { + return len(tf.filterNets) > 0 || len(tf.excludeNets) > 0 } // matchTargetNetFilter determines if any `filters` match `target`. @@ -73,9 +77,9 @@ func matchTargetNetFilter(filters []*net.IPNet, target string, emptyval bool) bo return emptyval } - for _, filter := range filters { - ip := net.ParseIP(target) + ip := net.ParseIP(target) + for _, filter := range filters { if filter.Contains(ip) { return true } diff --git a/endpoint/target_filter_test.go b/endpoint/target_filter_test.go index 01ffbf5cf..d803093c1 100644 --- a/endpoint/target_filter_test.go +++ b/endpoint/target_filter_test.go @@ -66,6 +66,18 @@ var targetFilterTests = []targetFilterTest{ []string{"10.1.2.3"}, false, }, + { + []string{}, + []string{"10.0.0.0/8"}, + []string{"49.13.41.161"}, + true, + }, + { + []string{}, + []string{"10.0.0.0/8"}, + []string{"10.0.1.101"}, + false, + }, } func TestTargetFilterWithExclusions(t *testing.T) { @@ -89,8 +101,21 @@ func TestTargetFilterMatchWithEmptyFilter(t *testing.T) { } } -func TestMatchTargetFilterReturnsProperEmptyVal(t *testing.T) { - emptyFilters := []string{} - assert.True(t, matchFilter(emptyFilters, "sometarget.com", true)) - assert.False(t, matchFilter(emptyFilters, "sometarget.com", false)) +func TestTargetNetFilter_IsEnabled(t *testing.T) { + tests := []struct { + name string + filterNets []string + excludeNets []string + want bool + }{ + {"both empty", []string{}, []string{}, false}, + {"filterNets non-empty", []string{"10.0.0.0/8"}, []string{}, true}, + {"excludeNets non-empty", []string{}, []string{"10.0.0.0/8"}, true}, + {"both non-empty", []string{"10.0.0.0/8"}, []string{"192.168.0.0/16"}, true}, + } + + for _, tt := range tests { + tf := NewTargetNetFilterWithExclusions(tt.filterNets, tt.excludeNets) + assert.Equal(t, tt.want, tf.IsEnabled()) + } } diff --git a/source/wrappers/targetfiltersource.go b/source/wrappers/targetfiltersource.go index fe6d1d283..afc654d90 100644 --- a/source/wrappers/targetfiltersource.go +++ b/source/wrappers/targetfiltersource.go @@ -21,34 +21,38 @@ import ( log "github.com/sirupsen/logrus" - source2 "sigs.k8s.io/external-dns/source" + "sigs.k8s.io/external-dns/source" "sigs.k8s.io/external-dns/endpoint" ) // targetFilterSource is a Source that removes endpoints matching the target filter from its wrapped source. type targetFilterSource struct { - source source2.Source + source source.Source targetFilter endpoint.TargetFilterInterface } // NewTargetFilterSource creates a new targetFilterSource wrapping the provided Source. -func NewTargetFilterSource(source source2.Source, targetFilter endpoint.TargetFilterInterface) source2.Source { +func NewTargetFilterSource(source source.Source, targetFilter endpoint.TargetFilterInterface) source.Source { return &targetFilterSource{source: source, targetFilter: targetFilter} } // Endpoints collects endpoints from its wrapped source and returns // them without targets matching the target filter. func (ms *targetFilterSource) Endpoints(ctx context.Context) ([]*endpoint.Endpoint, error) { - result := []*endpoint.Endpoint{} - endpoints, err := ms.source.Endpoints(ctx) if err != nil { return nil, err } + if !ms.targetFilter.IsEnabled() { + return endpoints, nil + } + + result := make([]*endpoint.Endpoint, 0, len(endpoints)) + for _, ep := range endpoints { - filteredTargets := []string{} + filteredTargets := make([]string, 0, len(ep.Targets)) for _, t := range ep.Targets { if ms.targetFilter.Match(t) { @@ -71,5 +75,7 @@ func (ms *targetFilterSource) Endpoints(ctx context.Context) ([]*endpoint.Endpoi } func (ms *targetFilterSource) AddEventHandler(ctx context.Context, handler func()) { - ms.source.AddEventHandler(ctx, handler) + if ms.targetFilter.IsEnabled() { + ms.source.AddEventHandler(ctx, handler) + } } diff --git a/source/wrappers/targetfiltersource_test.go b/source/wrappers/targetfiltersource_test.go index 733c1e1a3..e0e01d745 100644 --- a/source/wrappers/targetfiltersource_test.go +++ b/source/wrappers/targetfiltersource_test.go @@ -19,6 +19,7 @@ package wrappers import ( "testing" + "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" "golang.org/x/net/context" "sigs.k8s.io/external-dns/source" @@ -42,15 +43,21 @@ func (m *mockTargetNetFilter) Match(target string) bool { return m.targets[target] } +func (m *mockTargetNetFilter) IsEnabled() bool { + return true +} + // echoSource is a Source that returns the endpoints passed in on creation. type echoSource struct { + mock.Mock endpoints []*endpoint.Endpoint } func (e *echoSource) AddEventHandler(ctx context.Context, handler func()) { + e.Called(ctx) } -// Endpoints returns all of the endpoints passed in on creation +// Endpoints returns all the endpoints passed in on creation func (e *echoSource) Endpoints(ctx context.Context) ([]*endpoint.Endpoint, error) { return e.endpoints, nil } @@ -63,7 +70,7 @@ func NewEchoSource(endpoints []*endpoint.Endpoint) source.Source { func TestEchoSourceReturnGivenSources(t *testing.T) { startEndpoints := []*endpoint.Endpoint{{ DNSName: "foo.bar.com", - RecordType: "A", + RecordType: endpoint.RecordTypeA, Targets: endpoint.Targets{"1.2.3.4"}, RecordTTL: endpoint.TTL(300), Labels: endpoint.Labels{}, @@ -75,9 +82,9 @@ func TestEchoSourceReturnGivenSources(t *testing.T) { t.Errorf("Expected no error but got %s", err.Error()) } - for i, endpoint := range endpoints { - if endpoint != startEndpoints[i] { - t.Errorf("Expected %s but got %s", startEndpoints[i], endpoint) + for i, ep := range endpoints { + if ep != startEndpoints[i] { + t.Errorf("Expected %s but got %s", startEndpoints[i], ep) } } } @@ -107,28 +114,28 @@ func TestTargetFilterSourceEndpoints(t *testing.T) { title: "filter exclusion all", filters: NewMockTargetNetFilter([]string{}), endpoints: []*endpoint.Endpoint{ - endpoint.NewEndpoint("foo", "A", "1.2.3.4"), - endpoint.NewEndpoint("foo", "A", "1.2.3.5"), - endpoint.NewEndpoint("foo", "A", "1.2.3.6"), - endpoint.NewEndpoint("foo", "A", "1.3.4.5"), - endpoint.NewEndpoint("foo", "A", "1.4.4.5")}, + endpoint.NewEndpoint("foo", endpoint.RecordTypeA, "1.2.3.4"), + endpoint.NewEndpoint("foo", endpoint.RecordTypeA, "1.2.3.5"), + endpoint.NewEndpoint("foo", endpoint.RecordTypeA, "1.2.3.6"), + endpoint.NewEndpoint("foo", endpoint.RecordTypeA, "1.3.4.5"), + endpoint.NewEndpoint("foo", endpoint.RecordTypeA, "1.4.4.5")}, expected: []*endpoint.Endpoint{}, }, { title: "filter exclude internal net", filters: NewMockTargetNetFilter([]string{"8.8.8.8"}), endpoints: []*endpoint.Endpoint{ - endpoint.NewEndpoint("foo", "A", "10.0.0.1"), - endpoint.NewEndpoint("foo", "A", "8.8.8.8")}, - expected: []*endpoint.Endpoint{endpoint.NewEndpoint("foo", "A", "8.8.8.8")}, + endpoint.NewEndpoint("foo", endpoint.RecordTypeA, "10.0.0.1"), + endpoint.NewEndpoint("foo", endpoint.RecordTypeA, "8.8.8.8")}, + expected: []*endpoint.Endpoint{endpoint.NewEndpoint("foo", endpoint.RecordTypeA, "8.8.8.8")}, }, { title: "filter only internal", filters: NewMockTargetNetFilter([]string{"10.0.0.1"}), endpoints: []*endpoint.Endpoint{ - endpoint.NewEndpoint("foo", "A", "10.0.0.1"), - endpoint.NewEndpoint("foo", "A", "8.8.8.8")}, - expected: []*endpoint.Endpoint{endpoint.NewEndpoint("foo", "A", "10.0.0.1")}, + endpoint.NewEndpoint("foo", endpoint.RecordTypeA, "10.0.0.1"), + endpoint.NewEndpoint("foo", endpoint.RecordTypeA, "8.8.8.8")}, + expected: []*endpoint.Endpoint{endpoint.NewEndpoint("foo", endpoint.RecordTypeA, "10.0.0.1")}, }, } for _, tt := range tests { @@ -145,3 +152,119 @@ func TestTargetFilterSourceEndpoints(t *testing.T) { }) } } + +func TestTargetFilterConcreteTargetFilter(t *testing.T) { + tests := []struct { + title string + filters endpoint.TargetFilterInterface + endpoints []*endpoint.Endpoint + expected []*endpoint.Endpoint + }{ + { + title: "should skip filtering if no filters are set", + filters: endpoint.NewTargetNetFilterWithExclusions([]string{}, []string{}), + endpoints: []*endpoint.Endpoint{ + endpoint.NewEndpoint("foo", endpoint.RecordTypeA, "1.2.3.4"), + endpoint.NewEndpoint("foo", endpoint.RecordTypeA, "1.2.3.5"), + endpoint.NewEndpoint("foo", endpoint.RecordTypeA, "1.2.3.6"), + }, + expected: []*endpoint.Endpoint{ + endpoint.NewEndpoint("foo", endpoint.RecordTypeA, "1.2.3.4"), + endpoint.NewEndpoint("foo", endpoint.RecordTypeA, "1.2.3.5"), + endpoint.NewEndpoint("foo", endpoint.RecordTypeA, "1.2.3.6"), + }, + }, + { + title: "should include all targets when filters are not correctly set", + filters: endpoint.NewTargetNetFilterWithExclusions([]string{"8.8.8.8"}, []string{}), + endpoints: []*endpoint.Endpoint{ + endpoint.NewEndpoint("foo", endpoint.RecordTypeA, "10.0.0.1"), + endpoint.NewEndpoint("foo", endpoint.RecordTypeA, "8.8.8.8")}, + expected: []*endpoint.Endpoint{ + endpoint.NewEndpoint("foo", endpoint.RecordTypeA, "10.0.0.1"), + endpoint.NewEndpoint("foo", endpoint.RecordTypeA, "8.8.8.8"), + }, + }, + { + title: "should include internal when include filter is set", + filters: endpoint.NewTargetNetFilterWithExclusions([]string{"10.0.0.0/8"}, []string{}), + endpoints: []*endpoint.Endpoint{ + endpoint.NewEndpoint("foo", endpoint.RecordTypeA, "10.0.0.1"), + endpoint.NewEndpoint("foo", endpoint.RecordTypeA, "49.13.41.161")}, + expected: []*endpoint.Endpoint{ + endpoint.NewEndpoint("foo", endpoint.RecordTypeA, "10.0.0.1"), + }, + }, + { + title: "exclude internal keep public ips", + filters: endpoint.NewTargetNetFilterWithExclusions([]string{}, []string{"10.0.0.0/8"}), + endpoints: []*endpoint.Endpoint{ + endpoint.NewEndpoint("foo", endpoint.RecordTypeA, "10.0.178.43"), + endpoint.NewEndpoint("foo", endpoint.RecordTypeA, "10.0.1.101"), + endpoint.NewEndpoint("foo", endpoint.RecordTypeA, "49.13.41.161")}, + expected: []*endpoint.Endpoint{endpoint.NewEndpoint("foo", endpoint.RecordTypeA, "49.13.41.161")}, + }, + { + title: "should not exclude ipv6 when excluding ipv4", + filters: endpoint.NewTargetNetFilterWithExclusions([]string{}, []string{"10.0.0.0/8"}), + endpoints: []*endpoint.Endpoint{ + endpoint.NewEndpoint("foo", endpoint.RecordTypeA, "10.0.178.43"), + endpoint.NewEndpoint("foo", endpoint.RecordTypeAAAA, "2a01:asdf:asdf:asdf::1"), + }, + expected: []*endpoint.Endpoint{endpoint.NewEndpoint("foo", endpoint.RecordTypeAAAA, "2a01:asdf:asdf:asdf::1")}, + }, + { + title: "should not include ipv6 when including ipv4", + filters: endpoint.NewTargetNetFilterWithExclusions([]string{"10.0.0.0/8"}, []string{}), + endpoints: []*endpoint.Endpoint{ + endpoint.NewEndpoint("foo", endpoint.RecordTypeA, "10.0.178.43"), + endpoint.NewEndpoint("foo", endpoint.RecordTypeAAAA, "2a01:asdf:asdf:asdf::1"), + }, + expected: []*endpoint.Endpoint{endpoint.NewEndpoint("foo", endpoint.RecordTypeA, "10.0.178.43")}, + }, + } + for _, tt := range tests { + t.Run(tt.title, func(t *testing.T) { + echo := NewEchoSource(tt.endpoints) + src := NewTargetFilterSource(echo, tt.filters) + + endpoints, err := src.Endpoints(context.Background()) + require.NoError(t, err, "failed to get Endpoints") + + validateEndpoints(t, endpoints, tt.expected) + }) + } +} + +func TestTargetFilterSource_AddEventHandler(t *testing.T) { + tests := []struct { + title string + filters endpoint.TargetFilterInterface + times int + }{ + { + title: "should add event handler if target filter is enabled", + filters: endpoint.NewTargetNetFilterWithExclusions([]string{"10.0.0.0/8"}, []string{}), + times: 1, + }, + { + title: "should not add event handler if target filter is disabled", + filters: endpoint.NewTargetNetFilterWithExclusions([]string{}, []string{}), + times: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.title, func(t *testing.T) { + echo := NewEchoSource([]*endpoint.Endpoint{}) + + m := echo.(*echoSource) + m.On("AddEventHandler", t.Context()).Return() + + src := NewTargetFilterSource(echo, tt.filters) + src.AddEventHandler(t.Context(), func() {}) + + m.AssertNumberOfCalls(t, "AddEventHandler", tt.times) + }) + } +}