diff --git a/web/api/v1/api.go b/web/api/v1/api.go index 67bf6f54d1..69f5ce58dd 100644 --- a/web/api/v1/api.go +++ b/web/api/v1/api.go @@ -73,20 +73,41 @@ const ( checkContextEveryNIterations = 128 ) -type errorType string +type errorNum int + +type errorType struct { + num errorNum + str string +} const ( - errorNone errorType = "" - errorTimeout errorType = "timeout" - errorCanceled errorType = "canceled" - errorExec errorType = "execution" - errorBadData errorType = "bad_data" - errorInternal errorType = "internal" - errorUnavailable errorType = "unavailable" - errorNotFound errorType = "not_found" - errorNotAcceptable errorType = "not_acceptable" + ErrorNone errorNum = iota + ErrorTimeout + ErrorCanceled + ErrorExec + ErrorBadData + ErrorInternal + ErrorUnavailable + ErrorNotFound + ErrorNotAcceptable ) +var ( + errorNone = errorType{ErrorNone, ""} + errorTimeout = errorType{ErrorTimeout, "timeout"} + errorCanceled = errorType{ErrorCanceled, "canceled"} + errorExec = errorType{ErrorExec, "execution"} + errorBadData = errorType{ErrorBadData, "bad_data"} + errorInternal = errorType{ErrorInternal, "internal"} + errorUnavailable = errorType{ErrorUnavailable, "unavailable"} + errorNotFound = errorType{ErrorNotFound, "not_found"} + errorNotAcceptable = errorType{ErrorNotAcceptable, "not_acceptable"} +) + +// OverrideErrorCode can be used to override status code for different error types. +// Return false to fall back to default status code. +type OverrideErrorCode func(errorNum, error) (code int, override bool) + var LocalhostRepresentations = []string{"127.0.0.1", "localhost", "::1"} type apiError struct { @@ -95,7 +116,7 @@ type apiError struct { } func (e *apiError) Error() string { - return fmt.Sprintf("%s: %s", e.typ, e.err) + return fmt.Sprintf("%s: %s", e.typ.str, e.err) } // ScrapePoolsRetriever provide the list of all scrape pools. @@ -164,7 +185,7 @@ type RuntimeInfo struct { type Response struct { Status status `json:"status"` Data interface{} `json:"data,omitempty"` - ErrorType errorType `json:"errorType,omitempty"` + ErrorType string `json:"errorType,omitempty"` Error string `json:"error,omitempty"` Warnings []string `json:"warnings,omitempty"` Infos []string `json:"infos,omitempty"` @@ -223,6 +244,8 @@ type API struct { statsRenderer StatsRenderer notificationsGetter func() []notifications.Notification notificationsSub func() (<-chan notifications.Notification, func(), bool) + // Allows customizing the default mapping + overrideErrorCode OverrideErrorCode remoteWriteHandler http.Handler remoteReadHandler http.Handler @@ -267,6 +290,7 @@ func NewAPI( ctZeroIngestionEnabled bool, lookbackDelta time.Duration, enableTypeAndUnitLabels bool, + overrideErrorCode OverrideErrorCode, ) *API { a := &API{ QueryEngine: qe, @@ -295,6 +319,7 @@ func NewAPI( statsRenderer: DefaultStatsRenderer, notificationsGetter: notificationsGetter, notificationsSub: notificationsSub, + overrideErrorCode: overrideErrorCode, remoteReadHandler: remote.NewReadHandler(logger, registerer, q, configFunc, remoteReadSampleLimit, remoteReadConcurrencyLimit, remoteReadMaxBytesInFrame), } @@ -2029,7 +2054,7 @@ func (api *API) respondError(w http.ResponseWriter, apiErr *apiError, data inter json := jsoniter.ConfigCompatibleWithStandardLibrary b, err := json.Marshal(&Response{ Status: statusError, - ErrorType: apiErr.typ, + ErrorType: apiErr.typ.str, Error: apiErr.err.Error(), Data: data, }) @@ -2040,23 +2065,14 @@ func (api *API) respondError(w http.ResponseWriter, apiErr *apiError, data inter } var code int - switch apiErr.typ { - case errorBadData: - code = http.StatusBadRequest - case errorExec: - code = http.StatusUnprocessableEntity - case errorCanceled: - code = statusClientClosedConnection - case errorTimeout: - code = http.StatusServiceUnavailable - case errorInternal: - code = http.StatusInternalServerError - case errorNotFound: - code = http.StatusNotFound - case errorNotAcceptable: - code = http.StatusNotAcceptable - default: - code = http.StatusInternalServerError + if api.overrideErrorCode != nil { + if newCode, override := api.overrideErrorCode(apiErr.typ.num, apiErr.err); override { + code = newCode + } else { + code = getDefaultErrorCode(apiErr.typ) + } + } else { + code = getDefaultErrorCode(apiErr.typ) } w.Header().Set("Content-Type", "application/json") @@ -2066,6 +2082,27 @@ func (api *API) respondError(w http.ResponseWriter, apiErr *apiError, data inter } } +func getDefaultErrorCode(errType errorType) int { + switch errType { + case errorBadData: + return http.StatusBadRequest + case errorExec: + return http.StatusUnprocessableEntity + case errorCanceled: + return statusClientClosedConnection + case errorTimeout: + return http.StatusServiceUnavailable + case errorInternal: + return http.StatusInternalServerError + case errorNotFound: + return http.StatusNotFound + case errorNotAcceptable: + return http.StatusNotAcceptable + default: + return http.StatusInternalServerError + } +} + func parseTimeParam(r *http.Request, paramName string, defaultValue time.Time) (time.Time, error) { val := r.FormValue(paramName) if val == "" { diff --git a/web/api/v1/api_test.go b/web/api/v1/api_test.go index 1f4828da00..107ed8bab1 100644 --- a/web/api/v1/api_test.go +++ b/web/api/v1/api_test.go @@ -980,11 +980,11 @@ func TestStats(t *testing.T) { req, err := request(method, tc.param) require.NoError(t, err) res := api.query(req.WithContext(ctx)) - assertAPIError(t, res.err, "") + assertAPIError(t, res.err, errorNone) tc.expected(t, res.data) res = api.queryRange(req.WithContext(ctx)) - assertAPIError(t, res.err, "") + assertAPIError(t, res.err, errorNone) tc.expected(t, res.data) } }) @@ -3761,7 +3761,7 @@ func describeAPIFunc(f apiFunc) string { func assertAPIError(t *testing.T, got *apiError, exp errorType) { t.Helper() - if exp == errorNone { + if exp.num == ErrorNone { require.Nil(t, got) } else { require.NotNil(t, got) diff --git a/web/api/v1/errors_test.go b/web/api/v1/errors_test.go index 92ea1cc1c8..fd2e92a850 100644 --- a/web/api/v1/errors_test.go +++ b/web/api/v1/errors_test.go @@ -20,6 +20,7 @@ import ( "net/http" "net/http/httptest" "net/url" + "strings" "testing" "time" @@ -41,9 +42,10 @@ import ( func TestApiStatusCodes(t *testing.T) { for name, tc := range map[string]struct { - err error - expectedString string - expectedCode int + err error + expectedString string + expectedCode int + overrideErrorCode OverrideErrorCode }{ "random error": { err: errors.New("some random error"), @@ -57,6 +59,22 @@ func TestApiStatusCodes(t *testing.T) { expectedCode: http.StatusUnprocessableEntity, }, + "overridden error code for engine error": { + err: promql.ErrTooManySamples("some error"), + expectedString: "too many samples", + overrideErrorCode: func(errNum errorNum, err error) (code int, override bool) { + if errNum == ErrorExec { + if strings.Contains(err.Error(), "some error") { + return 999, true + } + return 998, true + } + + return 0, false + }, + expectedCode: 999, + }, + "promql.ErrQueryCanceled": { err: promql.ErrQueryCanceled("some error"), expectedString: "query was canceled", @@ -87,7 +105,7 @@ func TestApiStatusCodes(t *testing.T) { "error from seriesset": errorTestQueryable{q: errorTestQuerier{s: errorTestSeriesSet{err: tc.err}}}, } { t.Run(fmt.Sprintf("%s/%s", name, k), func(t *testing.T) { - r := createPrometheusAPI(t, q) + r := createPrometheusAPI(t, q, tc.overrideErrorCode) rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/api/v1/query?query=up", nil) @@ -101,7 +119,7 @@ func TestApiStatusCodes(t *testing.T) { } } -func createPrometheusAPI(t *testing.T, q storage.SampleAndChunkQueryable) *route.Router { +func createPrometheusAPI(t *testing.T, q storage.SampleAndChunkQueryable, overrideErrorCode OverrideErrorCode) *route.Router { t.Helper() engine := promqltest.NewTestEngineWithOpts(t, promql.EngineOpts{ @@ -147,6 +165,7 @@ func createPrometheusAPI(t *testing.T, q storage.SampleAndChunkQueryable) *route false, 5*time.Minute, false, + overrideErrorCode, ) promRouter := route.New().WithPrefix("/api/v1") diff --git a/web/web.go b/web/web.go index 108b4870f9..80aa845883 100644 --- a/web/web.go +++ b/web/web.go @@ -395,6 +395,7 @@ func New(logger *slog.Logger, o *Options) *Handler { o.CTZeroIngestionEnabled, o.LookbackDelta, o.EnableTypeAndUnitLabels, + nil, ) if o.RoutePrefix != "/" {