diff --git a/util/eventbus/eventbustest/eventbustest.go b/util/eventbus/eventbustest/eventbustest.go index af725ace1..d5cfe5395 100644 --- a/util/eventbus/eventbustest/eventbustest.go +++ b/util/eventbus/eventbustest/eventbustest.go @@ -79,6 +79,11 @@ func Type[T any]() func(T) { return func(T) {} } // // The if error != nil, the test helper will return that error immediately. // func(e ExpectedType) (bool, error) // +// // Tests for event type and whatever is defined in the body. +// // If a non-nil error is reported, the test helper will return that error +// // immediately; otherwise the expectation is considered to be met. +// func(e ExpectedType) error +// // If the list of events must match exactly with no extra events, // use [ExpectExactly]. func Expect(tw *Watcher, filters ...any) error { @@ -179,15 +184,22 @@ func eventFilter(f any) filter { return []reflect.Value{reflect.ValueOf(true), reflect.Zero(reflect.TypeFor[error]())} } case 1: - if ft.Out(0) != reflect.TypeFor[bool]() { - panic(fmt.Sprintf("result is %T, want bool", ft.Out(0))) - } - fixup = func(vals []reflect.Value) []reflect.Value { - return append(vals, reflect.Zero(reflect.TypeFor[error]())) + switch ft.Out(0) { + case reflect.TypeFor[bool](): + fixup = func(vals []reflect.Value) []reflect.Value { + return append(vals, reflect.Zero(reflect.TypeFor[error]())) + } + case reflect.TypeFor[error](): + fixup = func(vals []reflect.Value) []reflect.Value { + pass := vals[0].IsZero() + return append([]reflect.Value{reflect.ValueOf(pass)}, vals...) + } + default: + panic(fmt.Sprintf("result is %v, want bool or error", ft.Out(0))) } case 2: if ft.Out(0) != reflect.TypeFor[bool]() || ft.Out(1) != reflect.TypeFor[error]() { - panic(fmt.Sprintf("results are %T, %T; want bool, error", ft.Out(0), ft.Out(1))) + panic(fmt.Sprintf("results are %v, %v; want bool, error", ft.Out(0), ft.Out(1))) } fixup = func(vals []reflect.Value) []reflect.Value { return vals } default: diff --git a/util/eventbus/eventbustest/eventbustest_test.go b/util/eventbus/eventbustest/eventbustest_test.go index fd95973e5..351553cc8 100644 --- a/util/eventbus/eventbustest/eventbustest_test.go +++ b/util/eventbus/eventbustest/eventbustest_test.go @@ -54,6 +54,27 @@ func TestExpectFilter(t *testing.T) { }, wantErr: false, }, + { + name: "filter-with-nil-error", + events: []int{1, 2, 3}, + expectFunc: func(event EventFoo) error { + if event.Value > 10 { + return fmt.Errorf("value > 10: %d", event.Value) + } + return nil + }, + }, + { + name: "filter-with-non-nil-error", + events: []int{100, 200, 300}, + expectFunc: func(event EventFoo) error { + if event.Value > 10 { + return fmt.Errorf("value > 10: %d", event.Value) + } + return nil + }, + wantErr: true, + }, { name: "first event has to be func", events: []int{24, 42},