diff --git a/web/api/v1/api.go b/web/api/v1/api.go index 62a376b0ba..1a54f23a61 100644 --- a/web/api/v1/api.go +++ b/web/api/v1/api.go @@ -436,7 +436,7 @@ func (api *API) query(r *http.Request) (result apiFuncResult) { return invalidParamError(err, "timeout") } - ctx, cancel = context.WithTimeout(ctx, timeout) + ctx, cancel = context.WithDeadline(ctx, api.now().Add(timeout)) defer cancel() } diff --git a/web/api/v1/api_test.go b/web/api/v1/api_test.go index c86165b780..475b4bab54 100644 --- a/web/api/v1/api_test.go +++ b/web/api/v1/api_test.go @@ -404,7 +404,7 @@ func TestEndpoints(t *testing.T) { testEndpoints(t, api, testTargetRetriever, storage, true) }) - // Run all the API tests against a API that is wired to forward queries via + // Run all the API tests against an API that is wired to forward queries via // the remote read client to a test server, which in turn sends them to the // data from the test storage. t.Run("remote", func(t *testing.T) { @@ -3660,3 +3660,107 @@ func TestExtractQueryOpts(t *testing.T) { }) } } + +// Test query timeout parameter. +func TestQueryTimeout(t *testing.T) { + storage := promql.LoadedStorage(t, ` + load 1m + test_metric1{foo="bar"} 0+100x100 + `) + t.Cleanup(func() { + _ = storage.Close() + }) + + now := time.Now() + + for _, tc := range []struct { + name string + method string + }{ + { + name: "GET method", + method: http.MethodGet, + }, + { + name: "POST method", + method: http.MethodPost, + }, + } { + t.Run(tc.name, func(t *testing.T) { + engine := &fakeEngine{} + api := &API{ + Queryable: storage, + QueryEngine: engine, + ExemplarQueryable: storage.ExemplarQueryable(), + alertmanagerRetriever: testAlertmanagerRetriever{}.toFactory(), + flagsMap: sampleFlagMap, + now: func() time.Time { return now }, + config: func() config.Config { return samplePrometheusCfg }, + ready: func(f http.HandlerFunc) http.HandlerFunc { return f }, + } + + query := url.Values{ + "query": []string{"2"}, + "timeout": []string{"1s"}, + } + ctx := context.Background() + req, err := http.NewRequest(tc.method, fmt.Sprintf("http://example.com?%s", query.Encode()), nil) + require.NoError(t, err) + req.RemoteAddr = "127.0.0.1:20201" + + res := api.query(req.WithContext(ctx)) + assertAPIError(t, res.err, errorNone) + + require.Len(t, engine.query.execCalls, 1) + deadline, ok := engine.query.execCalls[0].Deadline() + require.True(t, ok) + require.Equal(t, now.Add(time.Second), deadline) + }) + } +} + +// fakeEngine is a fake QueryEngine implementation. +type fakeEngine struct { + query fakeQuery +} + +func (e *fakeEngine) SetQueryLogger(promql.QueryLogger) {} + +func (e *fakeEngine) NewInstantQuery(ctx context.Context, q storage.Queryable, opts promql.QueryOpts, qs string, ts time.Time) (promql.Query, error) { + return &e.query, nil +} + +func (e *fakeEngine) NewRangeQuery(ctx context.Context, q storage.Queryable, opts promql.QueryOpts, qs string, start, end time.Time, interval time.Duration) (promql.Query, error) { + return &e.query, nil +} + +// fakeQuery is a fake Query implementation. +type fakeQuery struct { + query string + execCalls []context.Context +} + +func (q *fakeQuery) Exec(ctx context.Context) *promql.Result { + q.execCalls = append(q.execCalls, ctx) + return &promql.Result{ + Value: &parser.StringLiteral{ + Val: "test", + }, + } +} + +func (q *fakeQuery) Close() {} + +func (q *fakeQuery) Statement() parser.Statement { + return nil +} + +func (q *fakeQuery) Stats() *stats.Statistics { + return nil +} + +func (q *fakeQuery) Cancel() {} + +func (q *fakeQuery) String() string { + return q.query +}