diff --git a/.golangci.yml b/.golangci.yml index e96def4ef..4facabf2e 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -12,6 +12,7 @@ linters: - misspell - revive - rowserrcheck # Checks whether Rows.Err of rows is checked successfully. + - errchkjson # Checks types passed to the json encoding functions. ref: https://golangci-lint.run/usage/linters/#errchkjson - errorlint # Checking for unchecked errors in Go code https://golangci-lint.run/usage/linters/#errcheck - staticcheck - unconvert diff --git a/provider/akamai/akamai_test.go b/provider/akamai/akamai_test.go index 7d6ecdb37..f517a723c 100644 --- a/provider/akamai/akamai_test.go +++ b/provider/akamai/akamai_test.go @@ -160,7 +160,8 @@ func TestFetchZonesZoneIDFilter(t *testing.T) { stub.setOutput("zone", []interface{}{"test1.testzone.com", "test2.testzone.com"}) x, _ := c.fetchZones() - y, _ := json.Marshal(x) + y, err := json.Marshal(x) + require.NoError(t, err) if assert.NotNil(t, y) { assert.JSONEq(t, "{\"zones\":[{\"contractId\":\"contract\",\"zone\":\"test1.testzone.com\"},{\"contractId\":\"contract\",\"zone\":\"test2.testzone.com\"}]}", string(y)) } @@ -175,7 +176,8 @@ func TestFetchZonesEmpty(t *testing.T) { stub.setOutput("zone", []interface{}{}) x, _ := c.fetchZones() - y, _ := json.Marshal(x) + y, err := json.Marshal(x) + require.NoError(t, err) if assert.NotNil(t, y) { assert.JSONEq(t, "{\"zones\":[]}", string(y)) } diff --git a/provider/godaddy/godaddy_test.go b/provider/godaddy/godaddy_test.go index 57e04e21f..344391f88 100644 --- a/provider/godaddy/godaddy_test.go +++ b/provider/godaddy/godaddy_test.go @@ -26,6 +26,7 @@ import ( log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" "sigs.k8s.io/external-dns/endpoint" "sigs.k8s.io/external-dns/plan" ) @@ -49,40 +50,50 @@ var ( func (c *mockGoDaddyClient) Post(endpoint string, input interface{}, output interface{}) error { log.Infof("POST: %s - %v", endpoint, input) stub := c.Called(endpoint, input) - data, _ := json.Marshal(stub.Get(0)) - json.Unmarshal(data, output) + data, err := json.Marshal(stub.Get(0)) + require.NoError(c.currentTest, err) + err = json.Unmarshal(data, output) + require.NoError(c.currentTest, err) return stub.Error(1) } func (c *mockGoDaddyClient) Patch(endpoint string, input interface{}, output interface{}) error { log.Infof("PATCH: %s - %v", endpoint, input) stub := c.Called(endpoint, input) - data, _ := json.Marshal(stub.Get(0)) - json.Unmarshal(data, output) + data, err := json.Marshal(stub.Get(0)) + require.NoError(c.currentTest, err) + err = json.Unmarshal(data, output) + require.NoError(c.currentTest, err) return stub.Error(1) } func (c *mockGoDaddyClient) Put(endpoint string, input interface{}, output interface{}) error { log.Infof("PUT: %s - %v", endpoint, input) stub := c.Called(endpoint, input) - data, _ := json.Marshal(stub.Get(0)) - json.Unmarshal(data, output) + data, err := json.Marshal(stub.Get(0)) + require.NoError(c.currentTest, err) + err = json.Unmarshal(data, output) + require.NoError(c.currentTest, err) return stub.Error(1) } func (c *mockGoDaddyClient) Get(endpoint string, output interface{}) error { log.Infof("GET: %s", endpoint) stub := c.Called(endpoint) - data, _ := json.Marshal(stub.Get(0)) - json.Unmarshal(data, output) + data, err := json.Marshal(stub.Get(0)) + require.NoError(c.currentTest, err) + err = json.Unmarshal(data, output) + require.NoError(c.currentTest, err) return stub.Error(1) } func (c *mockGoDaddyClient) Delete(endpoint string, output interface{}) error { log.Infof("DELETE: %s", endpoint) stub := c.Called(endpoint) - data, _ := json.Marshal(stub.Get(0)) - json.Unmarshal(data, output) + data, err := json.Marshal(stub.Get(0)) + require.NoError(c.currentTest, err) + err = json.Unmarshal(data, output) + require.NoError(c.currentTest, err) return stub.Error(1) } diff --git a/provider/ovh/ovh_test.go b/provider/ovh/ovh_test.go index 4b53db203..49444bf4b 100644 --- a/provider/ovh/ovh_test.go +++ b/provider/ovh/ovh_test.go @@ -41,28 +41,40 @@ type mockOvhClient struct { func (c *mockOvhClient) PostWithContext(ctx context.Context, endpoint string, input interface{}, output interface{}) error { stub := c.Called(endpoint, input) - data, _ := json.Marshal(stub.Get(0)) + data, err := json.Marshal(stub.Get(0)) + if err != nil { + return err + } json.Unmarshal(data, output) return stub.Error(1) } func (c *mockOvhClient) PutWithContext(ctx context.Context, endpoint string, input interface{}, output interface{}) error { stub := c.Called(endpoint, input) - data, _ := json.Marshal(stub.Get(0)) + data, err := json.Marshal(stub.Get(0)) + if err != nil { + return err + } json.Unmarshal(data, output) return stub.Error(1) } func (c *mockOvhClient) GetWithContext(ctx context.Context, endpoint string, output interface{}) error { stub := c.Called(endpoint) - data, _ := json.Marshal(stub.Get(0)) + data, err := json.Marshal(stub.Get(0)) + if err != nil { + return err + } json.Unmarshal(data, output) return stub.Error(1) } func (c *mockOvhClient) DeleteWithContext(ctx context.Context, endpoint string, output interface{}) error { stub := c.Called(endpoint) - data, _ := json.Marshal(stub.Get(0)) + data, err := json.Marshal(stub.Get(0)) + if err != nil { + return err + } json.Unmarshal(data, output) return stub.Error(1) } diff --git a/provider/rfc2136/rfc2136.go b/provider/rfc2136/rfc2136.go index 522d86905..cb7245db0 100644 --- a/provider/rfc2136/rfc2136.go +++ b/provider/rfc2136/rfc2136.go @@ -306,7 +306,7 @@ func (r *rfc2136Provider) List() ([]dns.RR, error) { } // If records were fetched successfully, break out of the loop if len(records) > 0 { - return records, nil + break } } diff --git a/provider/rfc2136/rfc2136_test.go b/provider/rfc2136/rfc2136_test.go index da14ae525..fb5016ed9 100644 --- a/provider/rfc2136/rfc2136_test.go +++ b/provider/rfc2136/rfc2136_test.go @@ -153,7 +153,24 @@ func (r *rfc2136Stub) IncomeTransfer(m *dns.Msg, a string) (env chan *dns.Envelo outChan := make(chan *dns.Envelope) go func() { for _, e := range r.output { - outChan <- e + + var responseEnvelope *dns.Envelope + for _, record := range e.RR { + for _, q := range m.Question { + if strings.HasSuffix(record.Header().Name, q.Name) { + if responseEnvelope == nil { + responseEnvelope = &dns.Envelope{} + } + responseEnvelope.RR = append(responseEnvelope.RR, record) + break + } + } + } + + if responseEnvelope == nil { + continue + } + outChan <- responseEnvelope } close(outChan) }() @@ -161,7 +178,7 @@ func (r *rfc2136Stub) IncomeTransfer(m *dns.Msg, a string) (env chan *dns.Envelo return outChan, nil } -func createRfc2136StubProvider(stub *rfc2136Stub) (provider.Provider, error) { +func createRfc2136StubProvider(stub *rfc2136Stub, zoneNames ...string) (provider.Provider, error) { tlsConfig := TLSConfig{ UseTLS: false, SkipTLSVerify: false, @@ -169,7 +186,7 @@ func createRfc2136StubProvider(stub *rfc2136Stub) (provider.Provider, error) { ClientCertFilePath: "", ClientCertKeyFilePath: "", } - return NewRfc2136Provider([]string{""}, 0, nil, false, "key", "secret", "hmac-sha512", true, endpoint.DomainFilter{}, false, 300*time.Second, false, false, "", "", "", 50, tlsConfig, "", stub) + return NewRfc2136Provider([]string{""}, 0, zoneNames, false, "key", "secret", "hmac-sha512", true, endpoint.DomainFilter{}, false, 300*time.Second, false, false, "", "", "", 50, tlsConfig, "", stub) } func createRfc2136StubProviderWithHosts(stub *rfc2136Stub) (provider.Provider, error) { @@ -506,7 +523,7 @@ func TestRfc2136GetRecords(t *testing.T) { }) assert.NoError(t, err) - provider, err := createRfc2136StubProvider(stub) + provider, err := createRfc2136StubProvider(stub, "barfoo.com", "foo.com", "bar.com", "foobar.com") assert.NoError(t, err) recs, err := provider.Records(context.Background()) diff --git a/provider/webhook/api/httpapi.go b/provider/webhook/api/httpapi.go index 13c09bca7..fde7d3ab8 100644 --- a/provider/webhook/api/httpapi.go +++ b/provider/webhook/api/httpapi.go @@ -106,7 +106,10 @@ func (p *WebhookServer) AdjustEndpointsHandler(w http.ResponseWriter, req *http. func (p *WebhookServer) NegotiateHandler(w http.ResponseWriter, _ *http.Request) { w.Header().Set(ContentTypeHeader, MediaTypeFormatAndVersion) - json.NewEncoder(w).Encode(p.Provider.GetDomainFilter()) + err := json.NewEncoder(w).Encode(p.Provider.GetDomainFilter()) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + } } // StartHTTPApi starts a HTTP server given any provider. diff --git a/provider/webhook/api/httpapi_test.go b/provider/webhook/api/httpapi_test.go index 18856753a..9c5cd096e 100644 --- a/provider/webhook/api/httpapi_test.go +++ b/provider/webhook/api/httpapi_test.go @@ -98,7 +98,7 @@ func TestRecordsHandlerRecords(t *testing.T) { // require that the res has the same endpoints as the records slice defer res.Body.Close() require.NotNil(t, res.Body) - endpoints := []*endpoint.Endpoint{} + var endpoints []*endpoint.Endpoint if err := json.NewDecoder(res.Body).Decode(&endpoints); err != nil { t.Errorf("Failed to decode response body: %s", err.Error()) } @@ -318,3 +318,40 @@ func TestStartHTTPApi(t *testing.T) { require.NoError(t, err) require.NoError(t, df.UnmarshalJSON(b)) } + +func TestNegotiateHandler_Success(t *testing.T) { + provider := &FakeWebhookProvider{ + domainFilter: endpoint.NewDomainFilter([]string{"foo.bar.com"}), + } + server := &WebhookServer{Provider: provider} + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/", nil) + + server.NegotiateHandler(w, req) + res := w.Result() + defer res.Body.Close() + + require.Equal(t, http.StatusOK, res.StatusCode) + require.Equal(t, MediaTypeFormatAndVersion, res.Header.Get(ContentTypeHeader)) + + var df endpoint.DomainFilter + body, err := io.ReadAll(res.Body) + require.NoError(t, err) + require.NoError(t, df.UnmarshalJSON(body)) + require.Equal(t, provider.domainFilter, df) +} + +func TestNegotiateHandler_FiltersWithSpecialEncodings(t *testing.T) { + provider := &FakeWebhookProvider{ + domainFilter: endpoint.NewDomainFilter([]string{"\\u001a", "\\Xfoo.\\u2028, \\u0000.com", ""}), + } + server := &WebhookServer{Provider: provider} + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/", nil) + + server.NegotiateHandler(w, req) + res := w.Result() + defer res.Body.Close() + + require.Equal(t, http.StatusOK, res.StatusCode) +}