Engine: Allow error response code to be customized (#16257)

Currently the API always returns http code 422 for engine execution error, and

This PR allows the error code to be overriden, based on the ErrorType and the error itself.

Signed-off-by: Justin Jung <jungjust@amazon.com>
Signed-off-by: Justin Jung <justinjung04@gmail.com>
Co-authored-by: Ayoub Mrini <ayoubmrini424@gmail.com>
This commit is contained in:
Justin Jung 2025-08-19 08:43:47 -07:00 committed by GitHub
parent 93bbf4bc90
commit 0f98dcbc07
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 95 additions and 38 deletions

View File

@ -73,20 +73,41 @@ const (
checkContextEveryNIterations = 128
)
type errorType string
type errorNum int
type errorType struct {
num errorNum
str string
}
const (
errorNone errorType = ""
errorTimeout errorType = "timeout"
errorCanceled errorType = "canceled"
errorExec errorType = "execution"
errorBadData errorType = "bad_data"
errorInternal errorType = "internal"
errorUnavailable errorType = "unavailable"
errorNotFound errorType = "not_found"
errorNotAcceptable errorType = "not_acceptable"
ErrorNone errorNum = iota
ErrorTimeout
ErrorCanceled
ErrorExec
ErrorBadData
ErrorInternal
ErrorUnavailable
ErrorNotFound
ErrorNotAcceptable
)
var (
errorNone = errorType{ErrorNone, ""}
errorTimeout = errorType{ErrorTimeout, "timeout"}
errorCanceled = errorType{ErrorCanceled, "canceled"}
errorExec = errorType{ErrorExec, "execution"}
errorBadData = errorType{ErrorBadData, "bad_data"}
errorInternal = errorType{ErrorInternal, "internal"}
errorUnavailable = errorType{ErrorUnavailable, "unavailable"}
errorNotFound = errorType{ErrorNotFound, "not_found"}
errorNotAcceptable = errorType{ErrorNotAcceptable, "not_acceptable"}
)
// OverrideErrorCode can be used to override status code for different error types.
// Return false to fall back to default status code.
type OverrideErrorCode func(errorNum, error) (code int, override bool)
var LocalhostRepresentations = []string{"127.0.0.1", "localhost", "::1"}
type apiError struct {
@ -95,7 +116,7 @@ type apiError struct {
}
func (e *apiError) Error() string {
return fmt.Sprintf("%s: %s", e.typ, e.err)
return fmt.Sprintf("%s: %s", e.typ.str, e.err)
}
// ScrapePoolsRetriever provide the list of all scrape pools.
@ -164,7 +185,7 @@ type RuntimeInfo struct {
type Response struct {
Status status `json:"status"`
Data interface{} `json:"data,omitempty"`
ErrorType errorType `json:"errorType,omitempty"`
ErrorType string `json:"errorType,omitempty"`
Error string `json:"error,omitempty"`
Warnings []string `json:"warnings,omitempty"`
Infos []string `json:"infos,omitempty"`
@ -223,6 +244,8 @@ type API struct {
statsRenderer StatsRenderer
notificationsGetter func() []notifications.Notification
notificationsSub func() (<-chan notifications.Notification, func(), bool)
// Allows customizing the default mapping
overrideErrorCode OverrideErrorCode
remoteWriteHandler http.Handler
remoteReadHandler http.Handler
@ -267,6 +290,7 @@ func NewAPI(
ctZeroIngestionEnabled bool,
lookbackDelta time.Duration,
enableTypeAndUnitLabels bool,
overrideErrorCode OverrideErrorCode,
) *API {
a := &API{
QueryEngine: qe,
@ -295,6 +319,7 @@ func NewAPI(
statsRenderer: DefaultStatsRenderer,
notificationsGetter: notificationsGetter,
notificationsSub: notificationsSub,
overrideErrorCode: overrideErrorCode,
remoteReadHandler: remote.NewReadHandler(logger, registerer, q, configFunc, remoteReadSampleLimit, remoteReadConcurrencyLimit, remoteReadMaxBytesInFrame),
}
@ -2029,7 +2054,7 @@ func (api *API) respondError(w http.ResponseWriter, apiErr *apiError, data inter
json := jsoniter.ConfigCompatibleWithStandardLibrary
b, err := json.Marshal(&Response{
Status: statusError,
ErrorType: apiErr.typ,
ErrorType: apiErr.typ.str,
Error: apiErr.err.Error(),
Data: data,
})
@ -2040,23 +2065,14 @@ func (api *API) respondError(w http.ResponseWriter, apiErr *apiError, data inter
}
var code int
switch apiErr.typ {
case errorBadData:
code = http.StatusBadRequest
case errorExec:
code = http.StatusUnprocessableEntity
case errorCanceled:
code = statusClientClosedConnection
case errorTimeout:
code = http.StatusServiceUnavailable
case errorInternal:
code = http.StatusInternalServerError
case errorNotFound:
code = http.StatusNotFound
case errorNotAcceptable:
code = http.StatusNotAcceptable
default:
code = http.StatusInternalServerError
if api.overrideErrorCode != nil {
if newCode, override := api.overrideErrorCode(apiErr.typ.num, apiErr.err); override {
code = newCode
} else {
code = getDefaultErrorCode(apiErr.typ)
}
} else {
code = getDefaultErrorCode(apiErr.typ)
}
w.Header().Set("Content-Type", "application/json")
@ -2066,6 +2082,27 @@ func (api *API) respondError(w http.ResponseWriter, apiErr *apiError, data inter
}
}
func getDefaultErrorCode(errType errorType) int {
switch errType {
case errorBadData:
return http.StatusBadRequest
case errorExec:
return http.StatusUnprocessableEntity
case errorCanceled:
return statusClientClosedConnection
case errorTimeout:
return http.StatusServiceUnavailable
case errorInternal:
return http.StatusInternalServerError
case errorNotFound:
return http.StatusNotFound
case errorNotAcceptable:
return http.StatusNotAcceptable
default:
return http.StatusInternalServerError
}
}
func parseTimeParam(r *http.Request, paramName string, defaultValue time.Time) (time.Time, error) {
val := r.FormValue(paramName)
if val == "" {

View File

@ -980,11 +980,11 @@ func TestStats(t *testing.T) {
req, err := request(method, tc.param)
require.NoError(t, err)
res := api.query(req.WithContext(ctx))
assertAPIError(t, res.err, "")
assertAPIError(t, res.err, errorNone)
tc.expected(t, res.data)
res = api.queryRange(req.WithContext(ctx))
assertAPIError(t, res.err, "")
assertAPIError(t, res.err, errorNone)
tc.expected(t, res.data)
}
})
@ -3761,7 +3761,7 @@ func describeAPIFunc(f apiFunc) string {
func assertAPIError(t *testing.T, got *apiError, exp errorType) {
t.Helper()
if exp == errorNone {
if exp.num == ErrorNone {
require.Nil(t, got)
} else {
require.NotNil(t, got)

View File

@ -20,6 +20,7 @@ import (
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"time"
@ -44,6 +45,7 @@ func TestApiStatusCodes(t *testing.T) {
err error
expectedString string
expectedCode int
overrideErrorCode OverrideErrorCode
}{
"random error": {
err: errors.New("some random error"),
@ -57,6 +59,22 @@ func TestApiStatusCodes(t *testing.T) {
expectedCode: http.StatusUnprocessableEntity,
},
"overridden error code for engine error": {
err: promql.ErrTooManySamples("some error"),
expectedString: "too many samples",
overrideErrorCode: func(errNum errorNum, err error) (code int, override bool) {
if errNum == ErrorExec {
if strings.Contains(err.Error(), "some error") {
return 999, true
}
return 998, true
}
return 0, false
},
expectedCode: 999,
},
"promql.ErrQueryCanceled": {
err: promql.ErrQueryCanceled("some error"),
expectedString: "query was canceled",
@ -87,7 +105,7 @@ func TestApiStatusCodes(t *testing.T) {
"error from seriesset": errorTestQueryable{q: errorTestQuerier{s: errorTestSeriesSet{err: tc.err}}},
} {
t.Run(fmt.Sprintf("%s/%s", name, k), func(t *testing.T) {
r := createPrometheusAPI(t, q)
r := createPrometheusAPI(t, q, tc.overrideErrorCode)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/api/v1/query?query=up", nil)
@ -101,7 +119,7 @@ func TestApiStatusCodes(t *testing.T) {
}
}
func createPrometheusAPI(t *testing.T, q storage.SampleAndChunkQueryable) *route.Router {
func createPrometheusAPI(t *testing.T, q storage.SampleAndChunkQueryable, overrideErrorCode OverrideErrorCode) *route.Router {
t.Helper()
engine := promqltest.NewTestEngineWithOpts(t, promql.EngineOpts{
@ -147,6 +165,7 @@ func createPrometheusAPI(t *testing.T, q storage.SampleAndChunkQueryable) *route
false,
5*time.Minute,
false,
overrideErrorCode,
)
promRouter := route.New().WithPrefix("/api/v1")

View File

@ -395,6 +395,7 @@ func New(logger *slog.Logger, o *Options) *Handler {
o.CTZeroIngestionEnabled,
o.LookbackDelta,
o.EnableTypeAndUnitLabels,
nil,
)
if o.RoutePrefix != "/" {