From fdaafdb041dcefddd5dfc9e29edc19951b807fe4 Mon Sep 17 00:00:00 2001 From: George Krajcsovits Date: Wed, 15 May 2024 06:26:19 +0200 Subject: [PATCH] tsdb: check for context cancel before regex matching postings (#14096) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * tsdb: check for context cancel before regex matching postings Regex matching can be heavy if the regex takes a lot of cycles to evaluate and we can get stuck evaluating postings for a long time without this fix. The constant checkContextEveryNIterations=100 may be changed later. Signed-off-by: György Krajcsovits --- tsdb/index/index.go | 8 ++++ tsdb/index/index_test.go | 33 ++++++++++++++++ tsdb/index/postings.go | 5 +++ tsdb/index/postings_test.go | 17 +++++++++ tsdb/querier.go | 13 +++++++ tsdb/querier_test.go | 75 +++++++++++++++++++++++++++++++++++++ util/testutil/context.go | 27 ++++++++++++- 7 files changed, 177 insertions(+), 1 deletion(-) diff --git a/tsdb/index/index.go b/tsdb/index/index.go index a36c33c4fb..4ded4cbe20 100644 --- a/tsdb/index/index.go +++ b/tsdb/index/index.go @@ -51,6 +51,9 @@ const ( indexFilename = "index" seriesByteAlign = 16 + + // checkContextEveryNIterations is used in some tight loops to check if the context is done. + checkContextEveryNIterations = 100 ) type indexWriterSeries struct { @@ -1797,7 +1800,12 @@ func (r *Reader) postingsForLabelMatchingV1(ctx context.Context, name string, ma } var its []Postings + count := 1 for val, offset := range e { + if count%checkContextEveryNIterations == 0 && ctx.Err() != nil { + return ErrPostings(ctx.Err()) + } + count++ if !match(val) { continue } diff --git a/tsdb/index/index_test.go b/tsdb/index/index_test.go index c451c38dd2..22133d0b70 100644 --- a/tsdb/index/index_test.go +++ b/tsdb/index/index_test.go @@ -719,3 +719,36 @@ func TestChunksTimeOrdering(t *testing.T) { require.NoError(t, idx.Close()) } + +func TestReader_PostingsForLabelMatchingHonorsContextCancel(t *testing.T) { + dir := t.TempDir() + + idx, err := NewWriter(context.Background(), filepath.Join(dir, "index")) + require.NoError(t, err) + + seriesCount := 1000 + for i := 1; i <= seriesCount; i++ { + require.NoError(t, idx.AddSymbol(fmt.Sprintf("%4d", i))) + } + require.NoError(t, idx.AddSymbol("__name__")) + + for i := 1; i <= seriesCount; i++ { + require.NoError(t, idx.AddSeries(storage.SeriesRef(i), labels.FromStrings("__name__", fmt.Sprintf("%4d", i)), + chunks.Meta{Ref: 1, MinTime: 0, MaxTime: 10}, + )) + } + + require.NoError(t, idx.Close()) + + ir, err := NewFileReader(filepath.Join(dir, "index")) + require.NoError(t, err) + defer ir.Close() + + failAfter := uint64(seriesCount / 2) // Fail after processing half of the series. + ctx := &testutil.MockContextErrAfter{FailAfter: failAfter} + p := ir.PostingsForLabelMatching(ctx, "__name__", func(string) bool { + return true + }) + require.Error(t, p.Err()) + require.Equal(t, failAfter, ctx.Count()) +} diff --git a/tsdb/index/postings.go b/tsdb/index/postings.go index 136b3441eb..c6cf20e0e5 100644 --- a/tsdb/index/postings.go +++ b/tsdb/index/postings.go @@ -416,7 +416,12 @@ func (p *MemPostings) PostingsForLabelMatching(ctx context.Context, name string, } var its []Postings + count := 1 for _, v := range vals { + if count%checkContextEveryNIterations == 0 && ctx.Err() != nil { + return ErrPostings(ctx.Err()) + } + count++ if match(v) { its = append(its, NewListPostings(e[v])) } diff --git a/tsdb/index/postings_test.go b/tsdb/index/postings_test.go index 9e6bd23f8c..cabca59774 100644 --- a/tsdb/index/postings_test.go +++ b/tsdb/index/postings_test.go @@ -28,6 +28,7 @@ import ( "github.com/prometheus/prometheus/model/labels" "github.com/prometheus/prometheus/storage" + "github.com/prometheus/prometheus/util/testutil" ) func TestMemPostings_addFor(t *testing.T) { @@ -1282,3 +1283,19 @@ func BenchmarkListPostings(b *testing.B) { }) } } + +func TestMemPostings_PostingsForLabelMatchingHonorsContextCancel(t *testing.T) { + memP := NewMemPostings() + seriesCount := 10 * checkContextEveryNIterations + for i := 1; i <= seriesCount; i++ { + memP.Add(storage.SeriesRef(i), labels.FromStrings("__name__", fmt.Sprintf("%4d", i))) + } + + failAfter := uint64(seriesCount / 2 / checkContextEveryNIterations) + ctx := &testutil.MockContextErrAfter{FailAfter: failAfter} + p := memP.PostingsForLabelMatching(ctx, "__name__", func(string) bool { + return true + }) + require.Error(t, p.Err()) + require.Equal(t, failAfter+1, ctx.Count()) // Plus one for the Err() call that puts the error in the result. +} diff --git a/tsdb/querier.go b/tsdb/querier.go index 1170493beb..efd4daf26b 100644 --- a/tsdb/querier.go +++ b/tsdb/querier.go @@ -33,6 +33,9 @@ import ( "github.com/prometheus/prometheus/util/annotations" ) +// checkContextEveryNIterations is used in some tight loops to check if the context is done. +const checkContextEveryNIterations = 100 + type blockBaseQuerier struct { blockID ulid.ULID index IndexReader @@ -358,7 +361,12 @@ func inversePostingsForMatcher(ctx context.Context, ix IndexReader, m *labels.Ma if m.Type == labels.MatchEqual && m.Value == "" { res = vals } else { + count := 1 for _, val := range vals { + if count%checkContextEveryNIterations == 0 && ctx.Err() != nil { + return nil, ctx.Err() + } + count++ if !m.Matches(val) { res = append(res, val) } @@ -387,7 +395,12 @@ func labelValuesWithMatchers(ctx context.Context, r IndexReader, name string, ma // re-use the allValues slice to avoid allocations // this is safe because the iteration is always ahead of the append filteredValues := allValues[:0] + count := 1 for _, v := range allValues { + if count%checkContextEveryNIterations == 0 && ctx.Err() != nil { + return nil, ctx.Err() + } + count++ if m.Matches(v) { filteredValues = append(filteredValues, v) } diff --git a/tsdb/querier_test.go b/tsdb/querier_test.go index 16de6373d0..bb13531d7d 100644 --- a/tsdb/querier_test.go +++ b/tsdb/querier_test.go @@ -38,6 +38,7 @@ import ( "github.com/prometheus/prometheus/tsdb/tombstones" "github.com/prometheus/prometheus/tsdb/tsdbutil" "github.com/prometheus/prometheus/util/annotations" + "github.com/prometheus/prometheus/util/testutil" ) // TODO(bwplotka): Replace those mocks with remote.concreteSeriesSet. @@ -3638,3 +3639,77 @@ func TestQueryWithOneChunkCompletelyDeleted(t *testing.T) { require.NoError(t, css.Err()) require.Equal(t, 1, seriesCount) } + +func TestReader_PostingsForLabelMatchingHonorsContextCancel(t *testing.T) { + ir := mockReaderOfLabels{} + + failAfter := uint64(mockReaderOfLabelsSeriesCount / 2 / checkContextEveryNIterations) + ctx := &testutil.MockContextErrAfter{FailAfter: failAfter} + _, err := labelValuesWithMatchers(ctx, ir, "__name__", labels.MustNewMatcher(labels.MatchRegexp, "__name__", ".+")) + + require.Error(t, err) + require.Equal(t, failAfter+1, ctx.Count()) // Plus one for the Err() call that puts the error in the result. +} + +func TestReader_InversePostingsForMatcherHonorsContextCancel(t *testing.T) { + ir := mockReaderOfLabels{} + + failAfter := uint64(mockReaderOfLabelsSeriesCount / 2 / checkContextEveryNIterations) + ctx := &testutil.MockContextErrAfter{FailAfter: failAfter} + _, err := inversePostingsForMatcher(ctx, ir, labels.MustNewMatcher(labels.MatchRegexp, "__name__", ".*")) + + require.Error(t, err) + require.Equal(t, failAfter+1, ctx.Count()) // Plus one for the Err() call that puts the error in the result. +} + +type mockReaderOfLabels struct{} + +const mockReaderOfLabelsSeriesCount = checkContextEveryNIterations * 10 + +func (m mockReaderOfLabels) LabelValues(context.Context, string, ...*labels.Matcher) ([]string, error) { + return make([]string, mockReaderOfLabelsSeriesCount), nil +} + +func (m mockReaderOfLabels) LabelValueFor(context.Context, storage.SeriesRef, string) (string, error) { + panic("LabelValueFor called") +} + +func (m mockReaderOfLabels) SortedLabelValues(context.Context, string, ...*labels.Matcher) ([]string, error) { + panic("SortedLabelValues called") +} + +func (m mockReaderOfLabels) Close() error { + return nil +} + +func (m mockReaderOfLabels) LabelNames(context.Context, ...*labels.Matcher) ([]string, error) { + panic("LabelNames called") +} + +func (m mockReaderOfLabels) LabelNamesFor(context.Context, ...storage.SeriesRef) ([]string, error) { + panic("LabelNamesFor called") +} + +func (m mockReaderOfLabels) PostingsForLabelMatching(context.Context, string, func(string) bool) index.Postings { + panic("PostingsForLabelMatching called") +} + +func (m mockReaderOfLabels) Postings(context.Context, string, ...string) (index.Postings, error) { + panic("Postings called") +} + +func (m mockReaderOfLabels) ShardedPostings(index.Postings, uint64, uint64) index.Postings { + panic("Postings called") +} + +func (m mockReaderOfLabels) SortedPostings(index.Postings) index.Postings { + panic("SortedPostings called") +} + +func (m mockReaderOfLabels) Series(storage.SeriesRef, *labels.ScratchBuilder, *[]chunks.Meta) error { + panic("Series called") +} + +func (m mockReaderOfLabels) Symbols() index.StringIter { + panic("Series called") +} diff --git a/util/testutil/context.go b/util/testutil/context.go index 3f63b030d7..0c9e0f6f64 100644 --- a/util/testutil/context.go +++ b/util/testutil/context.go @@ -13,7 +13,12 @@ package testutil -import "time" +import ( + "context" + "time" + + "go.uber.org/atomic" +) // A MockContext provides a simple stub implementation of a Context. type MockContext struct { @@ -40,3 +45,23 @@ func (c *MockContext) Err() error { func (c *MockContext) Value(interface{}) interface{} { return nil } + +// MockContextErrAfter is a MockContext that will return an error after a certain +// number of calls to Err(). +type MockContextErrAfter struct { + MockContext + count atomic.Uint64 + FailAfter uint64 +} + +func (c *MockContextErrAfter) Err() error { + c.count.Inc() + if c.count.Load() >= c.FailAfter { + return context.Canceled + } + return c.MockContext.Err() +} + +func (c *MockContextErrAfter) Count() uint64 { + return c.count.Load() +}