Skip to content

Commit

Permalink
Add context on which condition failed in WaitForObjectState
Browse files Browse the repository at this point in the history
  • Loading branch information
tnozicka committed Aug 2, 2024
1 parent 740ab7f commit 609b500
Show file tree
Hide file tree
Showing 2 changed files with 151 additions and 24 deletions.
100 changes: 76 additions & 24 deletions pkg/controllerhelpers/wait.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package controllerhelpers
import (
"context"
"fmt"
"strconv"
"strings"

scyllav1 "github.com/scylladb/scylla-operator/pkg/api/scylla/v1"
scyllav1alpha1 "github.com/scylladb/scylla-operator/pkg/api/scylla/v1alpha1"
Expand All @@ -15,6 +17,7 @@ import (
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/fields"
"k8s.io/apimachinery/pkg/runtime"
"k8s.io/apimachinery/pkg/util/wait"
"k8s.io/apimachinery/pkg/watch"
corev1client "k8s.io/client-go/kubernetes/typed/core/v1"
rbacv1client "k8s.io/client-go/kubernetes/typed/rbac/v1"
Expand All @@ -31,6 +34,73 @@ type WaitForStateOptions struct {
TolerateDelete bool
}

type AggregatedConditions[Obj runtime.Object] struct {
conditions []func(obj Obj) (bool, error)
state []*bool
}

func NewAggregatedConditions[Obj runtime.Object](condition func(obj Obj) (bool, error), additionalConditions ...func(obj Obj) (bool, error)) *AggregatedConditions[Obj] {
conditions := make([]func(obj Obj) (bool, error), 0, 1+len(additionalConditions))
conditions = append(conditions, condition)
if len(additionalConditions) != 0 {
conditions = append(conditions, additionalConditions...)
}

return &AggregatedConditions[Obj]{
conditions: conditions,
state: make([]*bool, len(conditions)),
}
}

func (ac AggregatedConditions[Obj]) clearState() {
for i := range ac.state {
ac.state[i] = nil
}
}

func (ac AggregatedConditions[Obj]) Condition(obj Obj) (bool, error) {
ac.clearState()

allDone := true
var err error
for i, cond := range ac.conditions {
var done bool
done, err = cond(obj)
ac.state[i] = &done
if err != nil {
return done, err
}

if !done {
allDone = false
}
}

return allDone, nil
}

func (ac AggregatedConditions[Obj]) GetStateString() string {
var sb strings.Builder

sb.WriteString("[")

for i, v := range ac.state {
if v == nil {
sb.WriteString("<nil>")
} else {
sb.WriteString(strconv.FormatBool(*v))
}

if i != len(ac.state)-1 {
sb.WriteString(",")
}
}

sb.WriteString("]")

return sb.String()
}

func WaitForObjectState[Object, ListObject runtime.Object](ctx context.Context, client listerWatcher[ListObject], name string, options WaitForStateOptions, condition func(obj Object) (bool, error), additionalConditions ...func(obj Object) (bool, error)) (Object, error) {
fieldSelector := fields.OneTermEqualSelector("metadata.name", name).String()
lw := &cache.ListWatch{
Expand All @@ -44,39 +114,18 @@ func WaitForObjectState[Object, ListObject runtime.Object](ctx context.Context,
},
}

conditions := make([]func(Object) (bool, error), 0, 1+len(additionalConditions))
conditions = append(conditions, condition)
if len(additionalConditions) != 0 {
conditions = append(conditions, additionalConditions...)
}
aggregatedCond := func(obj Object) (bool, error) {
allDone := true
for _, c := range conditions {
var err error
var done bool

done, err = c(obj)
if err != nil {
return done, err
}
if !done {
allDone = false
}
}
return allDone, nil
}

acs := NewAggregatedConditions(condition, additionalConditions...)
event, err := watchtools.UntilWithSync(ctx, lw, *new(Object), nil, func(event watch.Event) (bool, error) {
switch event.Type {
case watch.Added, watch.Modified:
return aggregatedCond(event.Object.(Object))
return acs.Condition(event.Object.(Object))

case watch.Error:
return true, apierrors.FromObject(event.Object)

case watch.Deleted:
if options.TolerateDelete {
return aggregatedCond(event.Object.(Object))
return acs.Condition(event.Object.(Object))
}
fallthrough

Expand All @@ -85,6 +134,9 @@ func WaitForObjectState[Object, ListObject runtime.Object](ctx context.Context,
}
})
if err != nil {
if wait.Interrupted(err) {
err = fmt.Errorf("waiting has been interupted (%s): %w", acs.GetStateString(), err)
}
return *new(Object), err
}

Expand Down
75 changes: 75 additions & 0 deletions pkg/controllerhelpers/wait_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
package controllerhelpers

import (
"reflect"
"testing"

"github.com/google/go-cmp/cmp"
corev1 "k8s.io/api/core/v1"
)

func hasKeyFunc(name string) func(*corev1.ConfigMap) (bool, error) {
return func(cm *corev1.ConfigMap) (bool, error) {
_, found := cm.Data[name]
return found, nil
}
}

func TestAggregatedConditions_Condition(t *testing.T) {
t.Parallel()
obj := &corev1.ConfigMap{
Data: map[string]string{
"alpha": "foo",
"beta": "bar",
},
}

tt := []struct {
name string
ac *AggregatedConditions[*corev1.ConfigMap]
expectedDone bool
expectedString string
expectedErr error
}{
{
name: "single done condition is done",
ac: NewAggregatedConditions[*corev1.ConfigMap](hasKeyFunc("alpha")),
expectedDone: true,
expectedString: "[true]",
expectedErr: nil,
},
{
name: "done+undone condition is undone",
ac: NewAggregatedConditions[*corev1.ConfigMap](hasKeyFunc("alpha"), hasKeyFunc("doesnotexist")),
expectedDone: false,
expectedString: "[true,false]",
expectedErr: nil,
},
{
name: "done+undone+done condition is undone",
ac: NewAggregatedConditions[*corev1.ConfigMap](hasKeyFunc("alpha"), hasKeyFunc("doesnotexist"), hasKeyFunc("beta")),
expectedDone: false,
expectedString: "[true,false,true]",
expectedErr: nil,
},
}
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()

got, err := tc.ac.Condition(obj)
if !reflect.DeepEqual(err, tc.expectedErr) {
t.Errorf("expected error %v, got %v: diff:\n%s", tc.expectedErr, err, cmp.Diff(tc.expectedErr, got))
}

if got != tc.expectedDone {
t.Errorf("expected done %t, got %t", tc.expectedDone, got)
}

s := tc.ac.GetStateString()
if s != tc.expectedString {
t.Errorf("expected string %q, got %q", tc.expectedString, s)
}
})
}
}

0 comments on commit 609b500

Please sign in to comment.