chore(source/net-filter): improve flow logic and add more tests (#5629)

Signed-off-by: ivan katliarchuk <ivan.katliarchuk@gmail.com>
This commit is contained in:
Ivan Ka 2025-07-11 17:47:51 +01:00 committed by GitHub
parent 8088b57dd1
commit 28d0ff9316
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 199 additions and 35 deletions

View File

@ -949,3 +949,9 @@ func TestDomainFilterNormalizeDomain(t *testing.T) {
assert.Equal(t, r.expect, gotName)
}
}
func TestMatchTargetFilterReturnsProperEmptyVal(t *testing.T) {
var emptyFilters []string
assert.True(t, matchFilter(emptyFilters, "sometarget.com", true))
assert.False(t, matchFilter(emptyFilters, "sometarget.com", false))
}

View File

@ -26,12 +26,13 @@ import (
// TargetFilterInterface defines the interface to select matching targets for a specific provider or runtime
type TargetFilterInterface interface {
Match(target string) bool
IsEnabled() bool
}
// TargetNetFilter holds a lists of valid target names
type TargetNetFilter struct {
// FilterNets define what targets to match
FilterNets []*net.IPNet
// filterNets define what targets to match
filterNets []*net.IPNet
// excludeNets define what targets not to match
excludeNets []*net.IPNet
}
@ -42,11 +43,9 @@ func prepareTargetFilters(filters []string) []*net.IPNet {
for _, filter := range filters {
filter = strings.TrimSpace(filter)
_, filterNet, err := net.ParseCIDR(filter)
if err != nil {
log.Errorf("Invalid target net filter: %s", filter)
continue
}
@ -57,12 +56,17 @@ func prepareTargetFilters(filters []string) []*net.IPNet {
// NewTargetNetFilterWithExclusions returns a new TargetNetFilter, given a list of matches and exclusions
func NewTargetNetFilterWithExclusions(targetFilterNets []string, excludeNets []string) TargetNetFilter {
return TargetNetFilter{FilterNets: prepareTargetFilters(targetFilterNets), excludeNets: prepareTargetFilters(excludeNets)}
return TargetNetFilter{filterNets: prepareTargetFilters(targetFilterNets), excludeNets: prepareTargetFilters(excludeNets)}
}
// Match checks whether a target can be found in the TargetNetFilter.
func (tf TargetNetFilter) Match(target string) bool {
return matchTargetNetFilter(tf.FilterNets, target, true) && !matchTargetNetFilter(tf.excludeNets, target, false)
return matchTargetNetFilter(tf.filterNets, target, true) && !matchTargetNetFilter(tf.excludeNets, target, false)
}
// IsEnabled returns true if any filters or exclusions are set.
func (tf TargetNetFilter) IsEnabled() bool {
return len(tf.filterNets) > 0 || len(tf.excludeNets) > 0
}
// matchTargetNetFilter determines if any `filters` match `target`.
@ -73,9 +77,9 @@ func matchTargetNetFilter(filters []*net.IPNet, target string, emptyval bool) bo
return emptyval
}
for _, filter := range filters {
ip := net.ParseIP(target)
ip := net.ParseIP(target)
for _, filter := range filters {
if filter.Contains(ip) {
return true
}

View File

@ -66,6 +66,18 @@ var targetFilterTests = []targetFilterTest{
[]string{"10.1.2.3"},
false,
},
{
[]string{},
[]string{"10.0.0.0/8"},
[]string{"49.13.41.161"},
true,
},
{
[]string{},
[]string{"10.0.0.0/8"},
[]string{"10.0.1.101"},
false,
},
}
func TestTargetFilterWithExclusions(t *testing.T) {
@ -89,8 +101,21 @@ func TestTargetFilterMatchWithEmptyFilter(t *testing.T) {
}
}
func TestMatchTargetFilterReturnsProperEmptyVal(t *testing.T) {
emptyFilters := []string{}
assert.True(t, matchFilter(emptyFilters, "sometarget.com", true))
assert.False(t, matchFilter(emptyFilters, "sometarget.com", false))
func TestTargetNetFilter_IsEnabled(t *testing.T) {
tests := []struct {
name string
filterNets []string
excludeNets []string
want bool
}{
{"both empty", []string{}, []string{}, false},
{"filterNets non-empty", []string{"10.0.0.0/8"}, []string{}, true},
{"excludeNets non-empty", []string{}, []string{"10.0.0.0/8"}, true},
{"both non-empty", []string{"10.0.0.0/8"}, []string{"192.168.0.0/16"}, true},
}
for _, tt := range tests {
tf := NewTargetNetFilterWithExclusions(tt.filterNets, tt.excludeNets)
assert.Equal(t, tt.want, tf.IsEnabled())
}
}

View File

@ -21,34 +21,38 @@ import (
log "github.com/sirupsen/logrus"
source2 "sigs.k8s.io/external-dns/source"
"sigs.k8s.io/external-dns/source"
"sigs.k8s.io/external-dns/endpoint"
)
// targetFilterSource is a Source that removes endpoints matching the target filter from its wrapped source.
type targetFilterSource struct {
source source2.Source
source source.Source
targetFilter endpoint.TargetFilterInterface
}
// NewTargetFilterSource creates a new targetFilterSource wrapping the provided Source.
func NewTargetFilterSource(source source2.Source, targetFilter endpoint.TargetFilterInterface) source2.Source {
func NewTargetFilterSource(source source.Source, targetFilter endpoint.TargetFilterInterface) source.Source {
return &targetFilterSource{source: source, targetFilter: targetFilter}
}
// Endpoints collects endpoints from its wrapped source and returns
// them without targets matching the target filter.
func (ms *targetFilterSource) Endpoints(ctx context.Context) ([]*endpoint.Endpoint, error) {
result := []*endpoint.Endpoint{}
endpoints, err := ms.source.Endpoints(ctx)
if err != nil {
return nil, err
}
if !ms.targetFilter.IsEnabled() {
return endpoints, nil
}
result := make([]*endpoint.Endpoint, 0, len(endpoints))
for _, ep := range endpoints {
filteredTargets := []string{}
filteredTargets := make([]string, 0, len(ep.Targets))
for _, t := range ep.Targets {
if ms.targetFilter.Match(t) {
@ -71,5 +75,7 @@ func (ms *targetFilterSource) Endpoints(ctx context.Context) ([]*endpoint.Endpoi
}
func (ms *targetFilterSource) AddEventHandler(ctx context.Context, handler func()) {
ms.source.AddEventHandler(ctx, handler)
if ms.targetFilter.IsEnabled() {
ms.source.AddEventHandler(ctx, handler)
}
}

View File

@ -19,6 +19,7 @@ package wrappers
import (
"testing"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
"golang.org/x/net/context"
"sigs.k8s.io/external-dns/source"
@ -42,15 +43,21 @@ func (m *mockTargetNetFilter) Match(target string) bool {
return m.targets[target]
}
func (m *mockTargetNetFilter) IsEnabled() bool {
return true
}
// echoSource is a Source that returns the endpoints passed in on creation.
type echoSource struct {
mock.Mock
endpoints []*endpoint.Endpoint
}
func (e *echoSource) AddEventHandler(ctx context.Context, handler func()) {
e.Called(ctx)
}
// Endpoints returns all of the endpoints passed in on creation
// Endpoints returns all the endpoints passed in on creation
func (e *echoSource) Endpoints(ctx context.Context) ([]*endpoint.Endpoint, error) {
return e.endpoints, nil
}
@ -63,7 +70,7 @@ func NewEchoSource(endpoints []*endpoint.Endpoint) source.Source {
func TestEchoSourceReturnGivenSources(t *testing.T) {
startEndpoints := []*endpoint.Endpoint{{
DNSName: "foo.bar.com",
RecordType: "A",
RecordType: endpoint.RecordTypeA,
Targets: endpoint.Targets{"1.2.3.4"},
RecordTTL: endpoint.TTL(300),
Labels: endpoint.Labels{},
@ -75,9 +82,9 @@ func TestEchoSourceReturnGivenSources(t *testing.T) {
t.Errorf("Expected no error but got %s", err.Error())
}
for i, endpoint := range endpoints {
if endpoint != startEndpoints[i] {
t.Errorf("Expected %s but got %s", startEndpoints[i], endpoint)
for i, ep := range endpoints {
if ep != startEndpoints[i] {
t.Errorf("Expected %s but got %s", startEndpoints[i], ep)
}
}
}
@ -107,28 +114,28 @@ func TestTargetFilterSourceEndpoints(t *testing.T) {
title: "filter exclusion all",
filters: NewMockTargetNetFilter([]string{}),
endpoints: []*endpoint.Endpoint{
endpoint.NewEndpoint("foo", "A", "1.2.3.4"),
endpoint.NewEndpoint("foo", "A", "1.2.3.5"),
endpoint.NewEndpoint("foo", "A", "1.2.3.6"),
endpoint.NewEndpoint("foo", "A", "1.3.4.5"),
endpoint.NewEndpoint("foo", "A", "1.4.4.5")},
endpoint.NewEndpoint("foo", endpoint.RecordTypeA, "1.2.3.4"),
endpoint.NewEndpoint("foo", endpoint.RecordTypeA, "1.2.3.5"),
endpoint.NewEndpoint("foo", endpoint.RecordTypeA, "1.2.3.6"),
endpoint.NewEndpoint("foo", endpoint.RecordTypeA, "1.3.4.5"),
endpoint.NewEndpoint("foo", endpoint.RecordTypeA, "1.4.4.5")},
expected: []*endpoint.Endpoint{},
},
{
title: "filter exclude internal net",
filters: NewMockTargetNetFilter([]string{"8.8.8.8"}),
endpoints: []*endpoint.Endpoint{
endpoint.NewEndpoint("foo", "A", "10.0.0.1"),
endpoint.NewEndpoint("foo", "A", "8.8.8.8")},
expected: []*endpoint.Endpoint{endpoint.NewEndpoint("foo", "A", "8.8.8.8")},
endpoint.NewEndpoint("foo", endpoint.RecordTypeA, "10.0.0.1"),
endpoint.NewEndpoint("foo", endpoint.RecordTypeA, "8.8.8.8")},
expected: []*endpoint.Endpoint{endpoint.NewEndpoint("foo", endpoint.RecordTypeA, "8.8.8.8")},
},
{
title: "filter only internal",
filters: NewMockTargetNetFilter([]string{"10.0.0.1"}),
endpoints: []*endpoint.Endpoint{
endpoint.NewEndpoint("foo", "A", "10.0.0.1"),
endpoint.NewEndpoint("foo", "A", "8.8.8.8")},
expected: []*endpoint.Endpoint{endpoint.NewEndpoint("foo", "A", "10.0.0.1")},
endpoint.NewEndpoint("foo", endpoint.RecordTypeA, "10.0.0.1"),
endpoint.NewEndpoint("foo", endpoint.RecordTypeA, "8.8.8.8")},
expected: []*endpoint.Endpoint{endpoint.NewEndpoint("foo", endpoint.RecordTypeA, "10.0.0.1")},
},
}
for _, tt := range tests {
@ -145,3 +152,119 @@ func TestTargetFilterSourceEndpoints(t *testing.T) {
})
}
}
func TestTargetFilterConcreteTargetFilter(t *testing.T) {
tests := []struct {
title string
filters endpoint.TargetFilterInterface
endpoints []*endpoint.Endpoint
expected []*endpoint.Endpoint
}{
{
title: "should skip filtering if no filters are set",
filters: endpoint.NewTargetNetFilterWithExclusions([]string{}, []string{}),
endpoints: []*endpoint.Endpoint{
endpoint.NewEndpoint("foo", endpoint.RecordTypeA, "1.2.3.4"),
endpoint.NewEndpoint("foo", endpoint.RecordTypeA, "1.2.3.5"),
endpoint.NewEndpoint("foo", endpoint.RecordTypeA, "1.2.3.6"),
},
expected: []*endpoint.Endpoint{
endpoint.NewEndpoint("foo", endpoint.RecordTypeA, "1.2.3.4"),
endpoint.NewEndpoint("foo", endpoint.RecordTypeA, "1.2.3.5"),
endpoint.NewEndpoint("foo", endpoint.RecordTypeA, "1.2.3.6"),
},
},
{
title: "should include all targets when filters are not correctly set",
filters: endpoint.NewTargetNetFilterWithExclusions([]string{"8.8.8.8"}, []string{}),
endpoints: []*endpoint.Endpoint{
endpoint.NewEndpoint("foo", endpoint.RecordTypeA, "10.0.0.1"),
endpoint.NewEndpoint("foo", endpoint.RecordTypeA, "8.8.8.8")},
expected: []*endpoint.Endpoint{
endpoint.NewEndpoint("foo", endpoint.RecordTypeA, "10.0.0.1"),
endpoint.NewEndpoint("foo", endpoint.RecordTypeA, "8.8.8.8"),
},
},
{
title: "should include internal when include filter is set",
filters: endpoint.NewTargetNetFilterWithExclusions([]string{"10.0.0.0/8"}, []string{}),
endpoints: []*endpoint.Endpoint{
endpoint.NewEndpoint("foo", endpoint.RecordTypeA, "10.0.0.1"),
endpoint.NewEndpoint("foo", endpoint.RecordTypeA, "49.13.41.161")},
expected: []*endpoint.Endpoint{
endpoint.NewEndpoint("foo", endpoint.RecordTypeA, "10.0.0.1"),
},
},
{
title: "exclude internal keep public ips",
filters: endpoint.NewTargetNetFilterWithExclusions([]string{}, []string{"10.0.0.0/8"}),
endpoints: []*endpoint.Endpoint{
endpoint.NewEndpoint("foo", endpoint.RecordTypeA, "10.0.178.43"),
endpoint.NewEndpoint("foo", endpoint.RecordTypeA, "10.0.1.101"),
endpoint.NewEndpoint("foo", endpoint.RecordTypeA, "49.13.41.161")},
expected: []*endpoint.Endpoint{endpoint.NewEndpoint("foo", endpoint.RecordTypeA, "49.13.41.161")},
},
{
title: "should not exclude ipv6 when excluding ipv4",
filters: endpoint.NewTargetNetFilterWithExclusions([]string{}, []string{"10.0.0.0/8"}),
endpoints: []*endpoint.Endpoint{
endpoint.NewEndpoint("foo", endpoint.RecordTypeA, "10.0.178.43"),
endpoint.NewEndpoint("foo", endpoint.RecordTypeAAAA, "2a01:asdf:asdf:asdf::1"),
},
expected: []*endpoint.Endpoint{endpoint.NewEndpoint("foo", endpoint.RecordTypeAAAA, "2a01:asdf:asdf:asdf::1")},
},
{
title: "should not include ipv6 when including ipv4",
filters: endpoint.NewTargetNetFilterWithExclusions([]string{"10.0.0.0/8"}, []string{}),
endpoints: []*endpoint.Endpoint{
endpoint.NewEndpoint("foo", endpoint.RecordTypeA, "10.0.178.43"),
endpoint.NewEndpoint("foo", endpoint.RecordTypeAAAA, "2a01:asdf:asdf:asdf::1"),
},
expected: []*endpoint.Endpoint{endpoint.NewEndpoint("foo", endpoint.RecordTypeA, "10.0.178.43")},
},
}
for _, tt := range tests {
t.Run(tt.title, func(t *testing.T) {
echo := NewEchoSource(tt.endpoints)
src := NewTargetFilterSource(echo, tt.filters)
endpoints, err := src.Endpoints(context.Background())
require.NoError(t, err, "failed to get Endpoints")
validateEndpoints(t, endpoints, tt.expected)
})
}
}
func TestTargetFilterSource_AddEventHandler(t *testing.T) {
tests := []struct {
title string
filters endpoint.TargetFilterInterface
times int
}{
{
title: "should add event handler if target filter is enabled",
filters: endpoint.NewTargetNetFilterWithExclusions([]string{"10.0.0.0/8"}, []string{}),
times: 1,
},
{
title: "should not add event handler if target filter is disabled",
filters: endpoint.NewTargetNetFilterWithExclusions([]string{}, []string{}),
times: 0,
},
}
for _, tt := range tests {
t.Run(tt.title, func(t *testing.T) {
echo := NewEchoSource([]*endpoint.Endpoint{})
m := echo.(*echoSource)
m.On("AddEventHandler", t.Context()).Return()
src := NewTargetFilterSource(echo, tt.filters)
src.AddEventHandler(t.Context(), func() {})
m.AssertNumberOfCalls(t, "AddEventHandler", tt.times)
})
}
}