diff --git a/task/task.go b/task/task.go index 00e67c4b..29b570e9 100644 --- a/task/task.go +++ b/task/task.go @@ -37,6 +37,7 @@ type taskManager struct { stop chan bool changed chan bool started bool + wait sync.WaitGroup } func newTaskManager() *taskManager { @@ -508,7 +509,7 @@ func (m *taskManager) run() { select { case now = <-time.After(effective.Sub(now)): // wait for effective time - runNextTasks(sortList, effective) + m.runNextTasks(sortList, effective) continue case <-m.changed: // tasks have been changed, set all tasks run again now now = time.Now().Local() @@ -540,7 +541,7 @@ func (m *taskManager) markManagerStop() { } // runNextTasks it runs next task which next run time is equal to effective -func runNextTasks(sortList *MapSorter, effective time.Time) { +func (m *taskManager) runNextTasks(sortList *MapSorter, effective time.Time) { // Run every entry whose next time was this effective time. var i = 0 for _, e := range sortList.Vals { @@ -553,19 +554,23 @@ func runNextTasks(sortList *MapSorter, effective time.Time) { ctx := context.Background() if duration := e.GetTimeout(ctx); duration != 0 { go func(e Tasker) { + m.wait.Add(1) ctx, cancelFunc := context.WithTimeout(ctx, duration) defer cancelFunc() err := e.Run(ctx) if err != nil { log.Printf("tasker.run err: %s\n", err.Error()) } + m.wait.Done() }(e) } else { go func(e Tasker) { + m.wait.Add(1) err := e.Run(ctx) if err != nil { log.Printf("tasker.run err: %s\n", err.Error()) } + m.wait.Done() }(e) } @@ -581,6 +586,17 @@ func (m *taskManager) StopTask() { }() } +// StopTask stop all tasks +func (m *taskManager) GracefulShutdown() <-chan struct{} { + done := make(chan struct{}, 0) + go func() { + m.stop <- true + m.wait.Wait() + close(done) + }() + return done +} + // AddTask add task with name func (m *taskManager) AddTask(taskname string, t Tasker) { isChanged := false diff --git a/task/task_test.go b/task/task_test.go index 1078aa01..8d274e8f 100644 --- a/task/task_test.go +++ b/task/task_test.go @@ -19,6 +19,7 @@ import ( "errors" "fmt" "sync" + "sync/atomic" "testing" "time" @@ -177,6 +178,26 @@ func TestCrudTask(t *testing.T) { assert.Equal(t, 0, len(m.adminTaskList)) } +func TestGracefulShutdown(t *testing.T) { + m := newTaskManager() + defer m.ClearTask() + waitDone := atomic.Value{} + waitDone.Store(false) + tk := NewTask("everySecond", "* * * * * *", func(ctx context.Context) error { + fmt.Println("hello world") + time.Sleep(2 * time.Second) + waitDone.Store(true) + return nil + }) + m.AddTask("taska", tk) + m.StartTask() + time.Sleep(1 * time.Second) + shutdown := m.GracefulShutdown() + assert.False(t, waitDone.Load().(bool)) + <-shutdown + assert.True(t, waitDone.Load().(bool)) +} + func wait(wg *sync.WaitGroup) chan bool { ch := make(chan bool) go func() {