Engine: Allow error response code to be customized (#16257)

Currently the API always returns http code 422 for engine execution error, and

This PR allows the error code to be overriden, based on the ErrorType and the error itself.

Signed-off-by: Justin Jung <jungjust@amazon.com>
Signed-off-by: Justin Jung <justinjung04@gmail.com>
Co-authored-by: Ayoub Mrini <ayoubmrini424@gmail.com>
This commit is contained in:
Justin Jung 2025-08-19 08:43:47 -07:00 committed by GitHub
parent 93bbf4bc90
commit 0f98dcbc07
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 95 additions and 38 deletions

View File

@ -73,20 +73,41 @@ const (
checkContextEveryNIterations = 128 checkContextEveryNIterations = 128
) )
type errorType string type errorNum int
type errorType struct {
num errorNum
str string
}
const ( const (
errorNone errorType = "" ErrorNone errorNum = iota
errorTimeout errorType = "timeout" ErrorTimeout
errorCanceled errorType = "canceled" ErrorCanceled
errorExec errorType = "execution" ErrorExec
errorBadData errorType = "bad_data" ErrorBadData
errorInternal errorType = "internal" ErrorInternal
errorUnavailable errorType = "unavailable" ErrorUnavailable
errorNotFound errorType = "not_found" ErrorNotFound
errorNotAcceptable errorType = "not_acceptable" ErrorNotAcceptable
) )
var (
errorNone = errorType{ErrorNone, ""}
errorTimeout = errorType{ErrorTimeout, "timeout"}
errorCanceled = errorType{ErrorCanceled, "canceled"}
errorExec = errorType{ErrorExec, "execution"}
errorBadData = errorType{ErrorBadData, "bad_data"}
errorInternal = errorType{ErrorInternal, "internal"}
errorUnavailable = errorType{ErrorUnavailable, "unavailable"}
errorNotFound = errorType{ErrorNotFound, "not_found"}
errorNotAcceptable = errorType{ErrorNotAcceptable, "not_acceptable"}
)
// OverrideErrorCode can be used to override status code for different error types.
// Return false to fall back to default status code.
type OverrideErrorCode func(errorNum, error) (code int, override bool)
var LocalhostRepresentations = []string{"127.0.0.1", "localhost", "::1"} var LocalhostRepresentations = []string{"127.0.0.1", "localhost", "::1"}
type apiError struct { type apiError struct {
@ -95,7 +116,7 @@ type apiError struct {
} }
func (e *apiError) Error() string { func (e *apiError) Error() string {
return fmt.Sprintf("%s: %s", e.typ, e.err) return fmt.Sprintf("%s: %s", e.typ.str, e.err)
} }
// ScrapePoolsRetriever provide the list of all scrape pools. // ScrapePoolsRetriever provide the list of all scrape pools.
@ -164,7 +185,7 @@ type RuntimeInfo struct {
type Response struct { type Response struct {
Status status `json:"status"` Status status `json:"status"`
Data interface{} `json:"data,omitempty"` Data interface{} `json:"data,omitempty"`
ErrorType errorType `json:"errorType,omitempty"` ErrorType string `json:"errorType,omitempty"`
Error string `json:"error,omitempty"` Error string `json:"error,omitempty"`
Warnings []string `json:"warnings,omitempty"` Warnings []string `json:"warnings,omitempty"`
Infos []string `json:"infos,omitempty"` Infos []string `json:"infos,omitempty"`
@ -223,6 +244,8 @@ type API struct {
statsRenderer StatsRenderer statsRenderer StatsRenderer
notificationsGetter func() []notifications.Notification notificationsGetter func() []notifications.Notification
notificationsSub func() (<-chan notifications.Notification, func(), bool) notificationsSub func() (<-chan notifications.Notification, func(), bool)
// Allows customizing the default mapping
overrideErrorCode OverrideErrorCode
remoteWriteHandler http.Handler remoteWriteHandler http.Handler
remoteReadHandler http.Handler remoteReadHandler http.Handler
@ -267,6 +290,7 @@ func NewAPI(
ctZeroIngestionEnabled bool, ctZeroIngestionEnabled bool,
lookbackDelta time.Duration, lookbackDelta time.Duration,
enableTypeAndUnitLabels bool, enableTypeAndUnitLabels bool,
overrideErrorCode OverrideErrorCode,
) *API { ) *API {
a := &API{ a := &API{
QueryEngine: qe, QueryEngine: qe,
@ -295,6 +319,7 @@ func NewAPI(
statsRenderer: DefaultStatsRenderer, statsRenderer: DefaultStatsRenderer,
notificationsGetter: notificationsGetter, notificationsGetter: notificationsGetter,
notificationsSub: notificationsSub, notificationsSub: notificationsSub,
overrideErrorCode: overrideErrorCode,
remoteReadHandler: remote.NewReadHandler(logger, registerer, q, configFunc, remoteReadSampleLimit, remoteReadConcurrencyLimit, remoteReadMaxBytesInFrame), remoteReadHandler: remote.NewReadHandler(logger, registerer, q, configFunc, remoteReadSampleLimit, remoteReadConcurrencyLimit, remoteReadMaxBytesInFrame),
} }
@ -2029,7 +2054,7 @@ func (api *API) respondError(w http.ResponseWriter, apiErr *apiError, data inter
json := jsoniter.ConfigCompatibleWithStandardLibrary json := jsoniter.ConfigCompatibleWithStandardLibrary
b, err := json.Marshal(&Response{ b, err := json.Marshal(&Response{
Status: statusError, Status: statusError,
ErrorType: apiErr.typ, ErrorType: apiErr.typ.str,
Error: apiErr.err.Error(), Error: apiErr.err.Error(),
Data: data, Data: data,
}) })
@ -2040,23 +2065,14 @@ func (api *API) respondError(w http.ResponseWriter, apiErr *apiError, data inter
} }
var code int var code int
switch apiErr.typ { if api.overrideErrorCode != nil {
case errorBadData: if newCode, override := api.overrideErrorCode(apiErr.typ.num, apiErr.err); override {
code = http.StatusBadRequest code = newCode
case errorExec: } else {
code = http.StatusUnprocessableEntity code = getDefaultErrorCode(apiErr.typ)
case errorCanceled: }
code = statusClientClosedConnection } else {
case errorTimeout: code = getDefaultErrorCode(apiErr.typ)
code = http.StatusServiceUnavailable
case errorInternal:
code = http.StatusInternalServerError
case errorNotFound:
code = http.StatusNotFound
case errorNotAcceptable:
code = http.StatusNotAcceptable
default:
code = http.StatusInternalServerError
} }
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
@ -2066,6 +2082,27 @@ func (api *API) respondError(w http.ResponseWriter, apiErr *apiError, data inter
} }
} }
func getDefaultErrorCode(errType errorType) int {
switch errType {
case errorBadData:
return http.StatusBadRequest
case errorExec:
return http.StatusUnprocessableEntity
case errorCanceled:
return statusClientClosedConnection
case errorTimeout:
return http.StatusServiceUnavailable
case errorInternal:
return http.StatusInternalServerError
case errorNotFound:
return http.StatusNotFound
case errorNotAcceptable:
return http.StatusNotAcceptable
default:
return http.StatusInternalServerError
}
}
func parseTimeParam(r *http.Request, paramName string, defaultValue time.Time) (time.Time, error) { func parseTimeParam(r *http.Request, paramName string, defaultValue time.Time) (time.Time, error) {
val := r.FormValue(paramName) val := r.FormValue(paramName)
if val == "" { if val == "" {

View File

@ -980,11 +980,11 @@ func TestStats(t *testing.T) {
req, err := request(method, tc.param) req, err := request(method, tc.param)
require.NoError(t, err) require.NoError(t, err)
res := api.query(req.WithContext(ctx)) res := api.query(req.WithContext(ctx))
assertAPIError(t, res.err, "") assertAPIError(t, res.err, errorNone)
tc.expected(t, res.data) tc.expected(t, res.data)
res = api.queryRange(req.WithContext(ctx)) res = api.queryRange(req.WithContext(ctx))
assertAPIError(t, res.err, "") assertAPIError(t, res.err, errorNone)
tc.expected(t, res.data) tc.expected(t, res.data)
} }
}) })
@ -3761,7 +3761,7 @@ func describeAPIFunc(f apiFunc) string {
func assertAPIError(t *testing.T, got *apiError, exp errorType) { func assertAPIError(t *testing.T, got *apiError, exp errorType) {
t.Helper() t.Helper()
if exp == errorNone { if exp.num == ErrorNone {
require.Nil(t, got) require.Nil(t, got)
} else { } else {
require.NotNil(t, got) require.NotNil(t, got)

View File

@ -20,6 +20,7 @@ import (
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/url" "net/url"
"strings"
"testing" "testing"
"time" "time"
@ -44,6 +45,7 @@ func TestApiStatusCodes(t *testing.T) {
err error err error
expectedString string expectedString string
expectedCode int expectedCode int
overrideErrorCode OverrideErrorCode
}{ }{
"random error": { "random error": {
err: errors.New("some random error"), err: errors.New("some random error"),
@ -57,6 +59,22 @@ func TestApiStatusCodes(t *testing.T) {
expectedCode: http.StatusUnprocessableEntity, expectedCode: http.StatusUnprocessableEntity,
}, },
"overridden error code for engine error": {
err: promql.ErrTooManySamples("some error"),
expectedString: "too many samples",
overrideErrorCode: func(errNum errorNum, err error) (code int, override bool) {
if errNum == ErrorExec {
if strings.Contains(err.Error(), "some error") {
return 999, true
}
return 998, true
}
return 0, false
},
expectedCode: 999,
},
"promql.ErrQueryCanceled": { "promql.ErrQueryCanceled": {
err: promql.ErrQueryCanceled("some error"), err: promql.ErrQueryCanceled("some error"),
expectedString: "query was canceled", expectedString: "query was canceled",
@ -87,7 +105,7 @@ func TestApiStatusCodes(t *testing.T) {
"error from seriesset": errorTestQueryable{q: errorTestQuerier{s: errorTestSeriesSet{err: tc.err}}}, "error from seriesset": errorTestQueryable{q: errorTestQuerier{s: errorTestSeriesSet{err: tc.err}}},
} { } {
t.Run(fmt.Sprintf("%s/%s", name, k), func(t *testing.T) { t.Run(fmt.Sprintf("%s/%s", name, k), func(t *testing.T) {
r := createPrometheusAPI(t, q) r := createPrometheusAPI(t, q, tc.overrideErrorCode)
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/api/v1/query?query=up", nil) req := httptest.NewRequest(http.MethodGet, "/api/v1/query?query=up", nil)
@ -101,7 +119,7 @@ func TestApiStatusCodes(t *testing.T) {
} }
} }
func createPrometheusAPI(t *testing.T, q storage.SampleAndChunkQueryable) *route.Router { func createPrometheusAPI(t *testing.T, q storage.SampleAndChunkQueryable, overrideErrorCode OverrideErrorCode) *route.Router {
t.Helper() t.Helper()
engine := promqltest.NewTestEngineWithOpts(t, promql.EngineOpts{ engine := promqltest.NewTestEngineWithOpts(t, promql.EngineOpts{
@ -147,6 +165,7 @@ func createPrometheusAPI(t *testing.T, q storage.SampleAndChunkQueryable) *route
false, false,
5*time.Minute, 5*time.Minute,
false, false,
overrideErrorCode,
) )
promRouter := route.New().WithPrefix("/api/v1") promRouter := route.New().WithPrefix("/api/v1")

View File

@ -395,6 +395,7 @@ func New(logger *slog.Logger, o *Options) *Handler {
o.CTZeroIngestionEnabled, o.CTZeroIngestionEnabled,
o.LookbackDelta, o.LookbackDelta,
o.EnableTypeAndUnitLabels, o.EnableTypeAndUnitLabels,
nil,
) )
if o.RoutePrefix != "/" { if o.RoutePrefix != "/" {