diff --git a/safe/routine.go b/safe/routine.go index 857cf82e3..897684e40 100644 --- a/safe/routine.go +++ b/safe/routine.go @@ -3,10 +3,11 @@ package safe import ( "context" "fmt" - "github.com/cenk/backoff" - "github.com/containous/traefik/log" "runtime/debug" "sync" + + "github.com/cenk/backoff" + "github.com/containous/traefik/log" ) type routine struct { @@ -107,11 +108,11 @@ func (p *Pool) Start() { p.lock.Lock() defer p.lock.Unlock() p.ctx, p.cancel = context.WithCancel(p.baseCtx) - for _, routine := range p.routines { + for i := range p.routines { p.waitGroup.Add(1) - routine.stop = make(chan bool, 1) + p.routines[i].stop = make(chan bool, 1) Go(func() { - routine.goroutine(routine.stop) + p.routines[i].goroutine(p.routines[i].stop) p.waitGroup.Done() }) } diff --git a/safe/routine_test.go b/safe/routine_test.go index 11bdccf13..6d207dac3 100644 --- a/safe/routine_test.go +++ b/safe/routine_test.go @@ -1,11 +1,176 @@ package safe import ( + "context" "fmt" - "github.com/cenk/backoff" + "sync" "testing" + "time" + + "github.com/cenk/backoff" ) +func TestNewPoolContext(t *testing.T) { + type testKeyType string + testKey := testKeyType("test") + ctx := context.WithValue(context.Background(), testKey, "test") + p := NewPool(ctx) + retCtx := p.Ctx() + retCtxVal, ok := retCtx.Value(testKey).(string) + if !ok || retCtxVal != "test" { + t.Errorf("Pool.Ctx() did not return a derived context, got %#v, expected context with test value", retCtx) + } +} + +type fakeRoutine struct { + sync.Mutex + started bool + startSig chan bool +} + +func newFakeRoutine() *fakeRoutine { + return &fakeRoutine{ + startSig: make(chan bool), + } +} + +func (tr *fakeRoutine) routineCtx(ctx context.Context) { + tr.Lock() + tr.started = true + tr.Unlock() + tr.startSig <- true + <-ctx.Done() +} + +func (tr *fakeRoutine) routine(stop chan bool) { + tr.Lock() + tr.started = true + tr.Unlock() + tr.startSig <- true + <-stop +} + +func TestPoolWithCtx(t *testing.T) { + testRoutine := newFakeRoutine() + tt := []struct { + desc string + fn func(*Pool) + }{ + { + desc: "GoCtx()", + fn: func(p *Pool) { + p.GoCtx(testRoutine.routineCtx) + }, + }, + { + desc: "AddGoCtx()", + fn: func(p *Pool) { + p.AddGoCtx(testRoutine.routineCtx) + p.Start() + }, + }, + } + for _, tc := range tt { + tc := tc + t.Run(tc.desc, func(t *testing.T) { + // These subtests cannot be run in parallel, since the testRoutine + // is shared across the subtests. + p := NewPool(context.Background()) + timer := time.NewTimer(500 * time.Millisecond) + defer timer.Stop() + + tc.fn(p) + defer p.Cleanup() + if len(p.routinesCtx) != 1 { + t.Fatalf("After %s, Pool did have %d goroutineCtxs, expected 1", tc.desc, len(p.routinesCtx)) + } + + testDone := make(chan bool, 1) + go func() { + <-testRoutine.startSig + p.Cleanup() + testDone <- true + }() + select { + case <-timer.C: + testRoutine.Lock() + defer testRoutine.Unlock() + t.Fatalf("Pool test did not complete in time, goroutine started equals '%t'", testRoutine.started) + case <-testDone: + return + } + }) + } +} + +func TestPoolWithStopChan(t *testing.T) { + testRoutine := newFakeRoutine() + ctx := context.Background() + p := NewPool(ctx) + timer := time.NewTimer(500 * time.Millisecond) + defer timer.Stop() + + p.Go(testRoutine.routine) + if len(p.routines) != 1 { + t.Fatalf("After Pool.Go(func), Pool did have %d goroutines, expected 1", len(p.routines)) + } + + testDone := make(chan bool, 1) + go func() { + <-testRoutine.startSig + p.Cleanup() + testDone <- true + }() + select { + case <-timer.C: + testRoutine.Lock() + defer testRoutine.Unlock() + t.Fatalf("Pool test did not complete in time, goroutine started equals '%t'", testRoutine.started) + case <-testDone: + return + } +} + +func TestPoolStartWithStopChan(t *testing.T) { + testRoutine := newFakeRoutine() + ctx := context.Background() + p := NewPool(ctx) + timer := time.NewTimer(500 * time.Millisecond) + defer timer.Stop() + + // Insert the stopped test goroutine via private fields into the Pool. + // There currently is no way to insert a routine via exported funcs that is not started immediately. + p.lock.Lock() + newRoutine := routine{ + goroutine: testRoutine.routine, + } + p.routines = append(p.routines, newRoutine) + p.lock.Unlock() + p.Start() + + testDone := make(chan bool, 1) + go func() { + <-testRoutine.startSig + p.Cleanup() + testDone <- true + }() + select { + case <-timer.C: + testRoutine.Lock() + defer testRoutine.Unlock() + t.Fatalf("Pool.Start() did not complete in time, goroutine started equals '%t'", testRoutine.started) + case <-testDone: + return + } +} + +func TestGoroutineRecover(t *testing.T) { + // if recover fails the test will panic + Go(func() { + panic("BOOM") + }) +} + func TestOperationWithRecover(t *testing.T) { operation := func() error { return nil diff --git a/safe/safe_test.go b/safe/safe_test.go new file mode 100644 index 000000000..b5f6f30c3 --- /dev/null +++ b/safe/safe_test.go @@ -0,0 +1,24 @@ +package safe + +import "testing" + +func TestSafe(t *testing.T) { + const ts1 = "test1" + const ts2 = "test2" + s := New(ts1) + result, ok := s.Get().(string) + if !ok { + t.Fatalf("Safe.Get() failed, got type '%T', expected string", s.Get()) + } + if result != ts1 { + t.Errorf("Safe.Get() failed, got '%s', expected '%s'", result, ts1) + } + s.Set(ts2) + result, ok = s.Get().(string) + if !ok { + t.Fatalf("Safe.Get() after Safe.Set() failed, got type '%T', expected string", s.Get()) + } + if result != ts2 { + t.Errorf("Safe.Get() after Safe.Set() failed, got '%s', expected '%s'", result, ts2) + } +}