Merge pull request #16946 from thampiotr/fix-memory-corruption-parser

Fix labels memory corruption when using protobuf encoding
This commit is contained in:
Bartlomiej Plotka 2025-07-30 13:16:58 +01:00 committed by GitHub
commit e35c09d84d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 86 additions and 7 deletions

View File

@ -40,7 +40,7 @@ jobs:
- uses: prometheus/promci@443c7fc2397e946bc9f5029e313a9c3441b9b86d # v0.4.7 - uses: prometheus/promci@443c7fc2397e946bc9f5029e313a9c3441b9b86d # v0.4.7
- uses: ./.github/promci/actions/setup_environment - uses: ./.github/promci/actions/setup_environment
- run: go test --tags=dedupelabels ./... - run: go test --tags=dedupelabels ./...
- run: go test --tags=slicelabels -race ./cmd/prometheus - run: go test --tags=slicelabels -race ./cmd/prometheus ./prompb/io/prometheus/client
- run: go test --tags=forcedirectio -race ./tsdb/ - run: go test --tags=forcedirectio -race ./tsdb/
- run: GOARCH=386 go test ./... - run: GOARCH=386 go test ./...
- uses: ./.github/promci/actions/check_proto - uses: ./.github/promci/actions/check_proto

View File

@ -62,6 +62,9 @@ func NewMetricStreamingDecoder(data []byte) *MetricStreamingDecoder {
var errInvalidVarint = errors.New("clientpb: invalid varint encountered") var errInvalidVarint = errors.New("clientpb: invalid varint encountered")
// NextMetricFamily decodes the next metric family from the input without metrics.
// Use NextMetric() to decode metrics. The MetricFamily fields Name, Help and Unit
// are only valid until NextMetricFamily is called again.
func (m *MetricStreamingDecoder) NextMetricFamily() error { func (m *MetricStreamingDecoder) NextMetricFamily() error {
b := m.in[m.inPos:] b := m.in[m.inPos:]
if len(b) == 0 { if len(b) == 0 {
@ -153,6 +156,7 @@ func (m *MetricStreamingDecoder) GetLabel() {
type scratchBuilder interface { type scratchBuilder interface {
Add(name, value string) Add(name, value string)
UnsafeAddBytes(name, value []byte)
} }
// Label parses labels into labels scratch builder. Metric name is missing // Label parses labels into labels scratch builder. Metric name is missing
@ -170,9 +174,9 @@ func (m *MetricStreamingDecoder) Label(b scratchBuilder) error {
} }
// parseLabel is essentially LabelPair.Unmarshal but directly adding into scratch builder // parseLabel is essentially LabelPair.Unmarshal but directly adding into scratch builder
// and reusing strings. // via UnsafeAddBytes method to reuse strings.
func parseLabel(dAtA []byte, b scratchBuilder) error { func parseLabel(dAtA []byte, b scratchBuilder) error {
var name, value string var name, value []byte
l := len(dAtA) l := len(dAtA)
iNdEx := 0 iNdEx := 0
for iNdEx < l { for iNdEx < l {
@ -231,7 +235,7 @@ func parseLabel(dAtA []byte, b scratchBuilder) error {
if postIndex > l { if postIndex > l {
return io.ErrUnexpectedEOF return io.ErrUnexpectedEOF
} }
name = yoloString(dAtA[iNdEx:postIndex]) name = dAtA[iNdEx:postIndex]
if !model.LabelName(name).IsValid() { if !model.LabelName(name).IsValid() {
return fmt.Errorf("invalid label name: %s", name) return fmt.Errorf("invalid label name: %s", name)
} }
@ -266,8 +270,8 @@ func parseLabel(dAtA []byte, b scratchBuilder) error {
if postIndex > l { if postIndex > l {
return io.ErrUnexpectedEOF return io.ErrUnexpectedEOF
} }
value = yoloString(dAtA[iNdEx:postIndex]) value = dAtA[iNdEx:postIndex]
if !utf8.ValidString(value) { if !utf8.ValidString(yoloString(value)) {
return fmt.Errorf("invalid label value: %s", value) return fmt.Errorf("invalid label value: %s", value)
} }
iNdEx = postIndex iNdEx = postIndex
@ -289,7 +293,7 @@ func parseLabel(dAtA []byte, b scratchBuilder) error {
if iNdEx > l { if iNdEx > l {
return io.ErrUnexpectedEOF return io.ErrUnexpectedEOF
} }
b.Add(name, value) b.UnsafeAddBytes(name, value)
return nil return nil
} }

View File

@ -17,13 +17,17 @@ import (
"bytes" "bytes"
"encoding/binary" "encoding/binary"
"errors" "errors"
"fmt"
"io" "io"
"math/rand"
"testing" "testing"
"github.com/gogo/protobuf/proto" "github.com/gogo/protobuf/proto"
"github.com/prometheus/common/model"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/prometheus/prometheus/model/labels" "github.com/prometheus/prometheus/model/labels"
"github.com/prometheus/prometheus/util/pool"
) )
const ( const (
@ -169,3 +173,74 @@ func TestMetricStreamingDecoder(t *testing.T) {
// Expect labels and metricBytes to be static and reusable even after parsing. // Expect labels and metricBytes to be static and reusable even after parsing.
require.Equal(t, `{checksum="", path="github.com/prometheus/client_golang", version="(devel)"}`, firstMetricLset.String()) require.Equal(t, `{checksum="", path="github.com/prometheus/client_golang", version="(devel)"}`, firstMetricLset.String())
} }
func TestMetricStreamingDecoder_LabelsCorruption(t *testing.T) {
lastScrapeSize := 0
var allPreviousLabels []labels.Labels
buffers := pool.New(128, 1024, 2, func(sz int) interface{} { return make([]byte, 0, sz) })
builder := labels.NewScratchBuilder(0)
for _, labelsCount := range []int{1, 2, 3, 5, 8, 5, 3, 2, 1} {
// Get buffer from pool like in scrape.go
b := buffers.Get(lastScrapeSize).([]byte)
buf := bytes.NewBuffer(b)
// Generate some scraped data to parse
mf := &MetricFamily{}
data := generateMetricFamilyText(labelsCount)
require.NoError(t, proto.UnmarshalText(data, mf))
protoBuf, err := proto.Marshal(mf)
require.NoError(t, err)
sizeBuf := make([]byte, binary.MaxVarintLen32)
sizeBufSize := binary.PutUvarint(sizeBuf, uint64(len(protoBuf)))
buf.Write(sizeBuf[:sizeBufSize])
buf.Write(protoBuf)
// Use decoder like protobufparse.go would
b = buf.Bytes()
d := NewMetricStreamingDecoder(b)
require.NoError(t, d.NextMetricFamily())
require.NoError(t, d.NextMetric())
// Get the labels
builder.Reset()
require.NoError(t, d.Label(&builder)) // <- this uses unsafe strings to create labels
lbs := builder.Labels()
allPreviousLabels = append(allPreviousLabels, lbs)
// Validate all labels seen so far remain valid and not corrupted
for _, l := range allPreviousLabels {
require.True(t, l.IsValid(model.LegacyValidation), "encountered corrupted labels: %v", l)
}
lastScrapeSize = len(b)
buffers.Put(b)
}
}
func generateLabels() string {
randomName := fmt.Sprintf("instance_%d", rand.Intn(1000))
randomValue := fmt.Sprintf("value_%d", rand.Intn(1000))
return fmt.Sprintf(`label: <
name: "%s"
value: "%s"
>`, randomName, randomValue)
}
func generateMetricFamilyText(labelsCount int) string {
randomName := fmt.Sprintf("metric_%d", rand.Intn(1000))
randomHelp := fmt.Sprintf("Test metric to demonstrate forced corruption %d.", rand.Intn(1000))
labels10 := ""
for i := 0; i < labelsCount; i++ {
labels10 += generateLabels()
}
return fmt.Sprintf(`name: "%s"
help: "%s"
type: GAUGE
metric: <
%s
gauge: <
value: 1.0
>
>
`, randomName, randomHelp, labels10)
}