diff --git a/promql/durations.go b/promql/durations.go index 20fa095d53..6249faa539 100644 --- a/promql/durations.go +++ b/promql/durations.go @@ -21,12 +21,20 @@ import ( "github.com/prometheus/prometheus/promql/parser" ) -// durationVisitor is a visitor that visits a duration expression and calculates the duration. -type durationVisitor struct { +// DurationVisitor is a visitor that visits a duration expression and calculates the duration. +type DurationVisitor struct { step time.Duration } -func (v *durationVisitor) Visit(node parser.Node, _ []parser.Node) (parser.Visitor, error) { +// NewDurationVisitor creates a visitor to Walk and evaluate duration +// expressions in a parsed query. +func NewDurationVisitor(step time.Duration) *DurationVisitor { + return &DurationVisitor{ + step: step, + } +} + +func (v *DurationVisitor) Visit(node parser.Node, _ []parser.Node) (parser.Visitor, error) { switch n := node.(type) { case *parser.VectorSelector: if n.OriginalOffsetExpr != nil { @@ -71,7 +79,7 @@ func (v *durationVisitor) Visit(node parser.Node, _ []parser.Node) (parser.Visit } // calculateDuration computes the duration from a duration expression. -func (v *durationVisitor) calculateDuration(expr parser.Expr, allowedNegative bool) (time.Duration, error) { +func (v *DurationVisitor) calculateDuration(expr parser.Expr, allowedNegative bool) (time.Duration, error) { duration, err := v.evaluateDurationExpr(expr) if err != nil { return 0, err @@ -86,7 +94,7 @@ func (v *durationVisitor) calculateDuration(expr parser.Expr, allowedNegative bo } // evaluateDurationExpr recursively evaluates a duration expression to a float64 value. -func (v *durationVisitor) evaluateDurationExpr(expr parser.Expr) (float64, error) { +func (v *DurationVisitor) evaluateDurationExpr(expr parser.Expr) (float64, error) { switch n := expr.(type) { case *parser.NumberLiteral: return n.Val, nil diff --git a/promql/durations_test.go b/promql/durations_test.go index 18592a0d0a..bee3cd7800 100644 --- a/promql/durations_test.go +++ b/promql/durations_test.go @@ -41,7 +41,7 @@ func TestDurationVisitor(t *testing.T) { expr, err := parser.ParseExpr(complexExpr) require.NoError(t, err) - err = parser.Walk(&durationVisitor{}, expr, nil) + err = parser.Walk(NewDurationVisitor(0), expr, nil) require.NoError(t, err) // Verify different parts of the expression have correct durations. @@ -243,7 +243,7 @@ func TestCalculateDuration(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - v := &durationVisitor{step: 1 * time.Second} + v := NewDurationVisitor(1 * time.Second) result, err := v.calculateDuration(tt.expr, tt.allowedNegative) if tt.errorMessage != "" { require.Error(t, err) diff --git a/promql/engine.go b/promql/engine.go index f5ee591d3b..3dc9ffb920 100644 --- a/promql/engine.go +++ b/promql/engine.go @@ -3734,7 +3734,7 @@ func unwrapStepInvariantExpr(e parser.Expr) parser.Expr { func PreprocessExpr(expr parser.Expr, start, end time.Time, step time.Duration) (parser.Expr, error) { detectHistogramStatsDecoding(expr) - if err := parser.Walk(&durationVisitor{step: step}, expr, nil); err != nil { + if err := parser.Walk(NewDurationVisitor(step), expr, nil); err != nil { return nil, err }