diff --git a/storage/remote/client.go b/storage/remote/client.go index adfb9910d9..ae42e99235 100644 --- a/storage/remote/client.go +++ b/storage/remote/client.go @@ -69,7 +69,7 @@ type recoverableError struct { } // Store sends a batch of samples to the HTTP endpoint. -func (c *Client) Store(req *prompb.WriteRequest) error { +func (c *Client) Store(ctx context.Context, req *prompb.WriteRequest) error { data, err := proto.Marshal(req) if err != nil { return err @@ -85,6 +85,7 @@ func (c *Client) Store(req *prompb.WriteRequest) error { httpReq.Header.Add("Content-Encoding", "snappy") httpReq.Header.Set("Content-Type", "application/x-protobuf") httpReq.Header.Set("X-Prometheus-Remote-Write-Version", "0.1.0") + httpReq = httpReq.WithContext(ctx) ctx, cancel := context.WithTimeout(context.Background(), c.timeout) defer cancel() diff --git a/storage/remote/client_test.go b/storage/remote/client_test.go index b0b93aad84..31f23cbb85 100644 --- a/storage/remote/client_test.go +++ b/storage/remote/client_test.go @@ -14,6 +14,7 @@ package remote import ( + "context" "fmt" "net/http" "net/http/httptest" @@ -73,7 +74,7 @@ func TestStoreHTTPErrorHandling(t *testing.T) { t.Fatal(err) } - err = c.Store(&prompb.WriteRequest{}) + err = c.Store(context.TODO(), &prompb.WriteRequest{}) if !reflect.DeepEqual(err, test.err) { t.Errorf("%d. Unexpected error; want %v, got %v", i, test.err, err) } diff --git a/storage/remote/queue_manager.go b/storage/remote/queue_manager.go index 19ee53afcc..780a3689a7 100644 --- a/storage/remote/queue_manager.go +++ b/storage/remote/queue_manager.go @@ -14,6 +14,7 @@ package remote import ( + "context" "math" "sync" "sync/atomic" @@ -130,7 +131,7 @@ func init() { // external timeseries database. type StorageClient interface { // Store stores the given samples in the remote storage. - Store(*prompb.WriteRequest) error + Store(context.Context, *prompb.WriteRequest) error // Name identifies the remote storage implementation. 
Name() string } @@ -376,6 +377,8 @@ type shards struct { queues []chan *model.Sample done chan struct{} running int32 + ctx context.Context + cancel context.CancelFunc } func (t *QueueManager) newShards(numShards int) *shards { @@ -383,11 +386,14 @@ func (t *QueueManager) newShards(numShards int) *shards { for i := 0; i < numShards; i++ { queues[i] = make(chan *model.Sample, t.cfg.Capacity) } + ctx, cancel := context.WithCancel(context.Background()) s := &shards{ qm: t, queues: queues, done: make(chan struct{}), running: int32(numShards), + ctx: ctx, + cancel: cancel, } return s } @@ -403,15 +409,21 @@ func (s *shards) start() { } func (s *shards) stop(deadline time.Duration) { + // Attempt a clean shutdown. for _, shard := range s.queues { close(shard) } - select { case <-s.done: + return case <-time.After(deadline): level.Error(s.qm.logger).Log("msg", "Failed to flush all samples on shutdown") } + + // Force an unclean shutdown. + s.cancel() + <-s.done + return } func (s *shards) enqueue(sample *model.Sample) bool { @@ -455,6 +467,9 @@ func (s *shards) runShard(i int) { for { select { + case <-s.ctx.Done(): + return + case sample, ok := <-queue: if !ok { if len(pendingSamples) > 0 { @@ -502,7 +517,7 @@ func (s *shards) sendSamplesWithBackoff(samples model.Samples) { for retries := s.qm.cfg.MaxRetries; retries > 0; retries-- { begin := time.Now() req := ToWriteRequest(samples) - err := s.qm.client.Store(req) + err := s.qm.client.Store(s.ctx, req) sentBatchDuration.WithLabelValues(s.qm.queueName).Observe(time.Since(begin).Seconds()) if err == nil { diff --git a/storage/remote/queue_manager_test.go b/storage/remote/queue_manager_test.go index b2b2804d38..82169899fc 100644 --- a/storage/remote/queue_manager_test.go +++ b/storage/remote/queue_manager_test.go @@ -14,6 +14,7 @@ package remote import ( + "context" "fmt" "reflect" "sync" @@ -71,7 +72,7 @@ func (c *TestStorageClient) waitForExpectedSamples(t *testing.T) { } } -func (c *TestStorageClient) Store(req 
*prompb.WriteRequest) error { +func (c *TestStorageClient) Store(_ context.Context, req *prompb.WriteRequest) error { c.mtx.Lock() defer c.mtx.Unlock() count := 0 @@ -211,9 +212,12 @@ func NewTestBlockedStorageClient() *TestBlockingStorageClient { } } -func (c *TestBlockingStorageClient) Store(_ *prompb.WriteRequest) error { +func (c *TestBlockingStorageClient) Store(ctx context.Context, _ *prompb.WriteRequest) error { atomic.AddUint64(&c.numCalls, 1) - <-c.block + select { + case <-c.block: + case <-ctx.Done(): + } return nil } @@ -301,3 +305,26 @@ func TestSpawnNotMoreThanMaxConcurrentSendsGoroutines(t *testing.T) { t.Errorf("Saw %d concurrent sends, expected 1", numCalls) } } + +func TestShutdown(t *testing.T) { + deadline := 10 * time.Second + c := NewTestBlockedStorageClient() + m := NewQueueManager(nil, config.DefaultQueueConfig, nil, nil, c, deadline) + for i := 0; i < config.DefaultQueueConfig.MaxSamplesPerSend; i++ { + m.Append(&model.Sample{ + Metric: model.Metric{ + model.MetricNameLabel: model.LabelValue(fmt.Sprintf("test_metric_%d", i)), + }, + Value: model.SampleValue(i), + Timestamp: model.Time(i), + }) + } + m.Start() + + start := time.Now() + m.Stop() + duration := time.Now().Sub(start) + if duration > deadline+(deadline/10) { + t.Errorf("Took too long to shutdown: %s > %s", duration, deadline) + } +}