diff --git a/CHANGELOG.md b/CHANGELOG.md index f79a8f6f..0e2612b4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,3 +11,4 @@ - Support session Filter chain. [4404](https://github.com/beego/beego/pull/4404) - Feature issue #4402 finish router get example. [4416](https://github.com/beego/beego/pull/4416) - Implement context.Context support and deprecate `QueryM2MWithCtx` and `QueryTableWithCtx` [4424](https://github.com/beego/beego/pull/4424) +- Finish timeout option for tasks #4441 [4441](https://github.com/beego/beego/pull/4441) \ No newline at end of file diff --git a/adapter/toolbox/task.go b/adapter/toolbox/task.go index bdd6679f..7b7cd68a 100644 --- a/adapter/toolbox/task.go +++ b/adapter/toolbox/task.go @@ -289,3 +289,7 @@ func (o *oldToNewAdapter) SetPrev(ctx context.Context, t time.Time) { func (o *oldToNewAdapter) GetPrev(ctx context.Context) time.Time { return o.delegate.GetPrev() } + +func (o *oldToNewAdapter) GetTimeout(ctx context.Context) time.Duration { + return 0 +} diff --git a/task/governor_command_test.go b/task/governor_command_test.go index 00ed37f2..c3547cdf 100644 --- a/task/governor_command_test.go +++ b/task/governor_command_test.go @@ -55,6 +55,10 @@ func (c *countTask) GetPrev(ctx context.Context) time.Time { return time.Now() } +func (c *countTask) GetTimeout(ctx context.Context) time.Duration { + return 0 +} + func TestRunTaskCommand_Execute(t *testing.T) { task := &countTask{} AddTask("count", task) diff --git a/task/task.go b/task/task.go index 2ea34f24..00e67c4b 100644 --- a/task/task.go +++ b/task/task.go @@ -109,6 +109,7 @@ type Tasker interface { GetNext(ctx context.Context) time.Time SetPrev(context.Context, time.Time) GetPrev(ctx context.Context) time.Time + GetTimeout(ctx context.Context) time.Duration } // task error @@ -127,13 +128,14 @@ type Task struct { DoFunc TaskFunc Prev time.Time Next time.Time - Errlist []*taskerr // like errtime:errinfo - ErrLimit int // max length for the errlist, 0 stand for no limit - errCnt int // records the error count during the execution + Timeout time.Duration // timeout duration + Errlist []*taskerr // like errtime:errinfo + ErrLimit int // max length for the errlist, 0 stand for no limit + errCnt int // records the error count during the execution } // NewTask add new task with name, time and func -func NewTask(tname string, spec string, f TaskFunc) *Task { +func NewTask(tname string, spec string, f TaskFunc, opts ...Option) *Task { task := &Task{ Taskname: tname, @@ -144,6 +146,11 @@ func NewTask(tname string, spec string, f TaskFunc) *Task { // we only store the pointer, so it won't use too many space Errlist: make([]*taskerr, 100, 100), } + + for _, opt := range opts { + opt.apply(task) + } + task.SetCron(spec) return task } @@ -196,6 +203,31 @@ func (t *Task) GetPrev(context.Context) time.Time { return t.Prev } +// GetTimeout get timeout duration of this task +func (t *Task) GetTimeout(context.Context) time.Duration { + return t.Timeout +} + +// Option interface +type Option interface { + apply(*Task) +} + +// optionFunc return a function to set task element +type optionFunc func(*Task) + +// apply option to task +func (f optionFunc) apply(t *Task) { + f(t) +} + +// TimeoutOption return a option to set timeout duration for task +func TimeoutOption(timeout time.Duration) Option { + return optionFunc(func(t *Task) { + t.Timeout = timeout + }) +} + // six columns mean: // second:0-59 // minute:0-59 @@ -455,14 +487,12 @@ func (m *taskManager) StartTask() { func (m *taskManager) run() { now := time.Now().Local() - m.taskLock.Lock() - for _, t := range m.adminTaskList { - t.SetNext(nil, now) - } - m.taskLock.Unlock() + // first run the tasks, so set all tasks next run time. + m.setTasksStartTime(now) for { // we only use RLock here because NewMapSorter copy the reference, do not change any thing + // here, we sort all task and get first task running time (effective). m.taskLock.RLock() sortList := NewMapSorter(m.adminTaskList) m.taskLock.RUnlock() @@ -475,37 +505,75 @@ func (m *taskManager) run() { } else { effective = sortList.Vals[0].GetNext(context.Background()) } + select { - case now = <-time.After(effective.Sub(now)): - // Run every entry whose next time was this effective time. - for _, e := range sortList.Vals { - if e.GetNext(context.Background()) != effective { - break - } - go e.Run(nil) - e.SetPrev(context.Background(), e.GetNext(context.Background())) - e.SetNext(nil, effective) - } + case now = <-time.After(effective.Sub(now)): // wait for effective time + runNextTasks(sortList, effective) continue - case <-m.changed: + case <-m.changed: // tasks have been changed, set all tasks run again now now = time.Now().Local() - m.taskLock.Lock() - for _, t := range m.adminTaskList { - t.SetNext(nil, now) - } - m.taskLock.Unlock() + m.setTasksStartTime(now) continue - case <-m.stop: - m.taskLock.Lock() - if m.started { - m.started = false - } - m.taskLock.Unlock() + case <-m.stop: // manager is stopped, and mark manager is stopped + m.markManagerStop() return } } } +// setTasksStartTime is set all tasks next running time +func (m *taskManager) setTasksStartTime(now time.Time) { + m.taskLock.Lock() + for _, task := range m.adminTaskList { + task.SetNext(context.Background(), now) + } + m.taskLock.Unlock() +} + +// markManagerStop it sets manager to be stopped +func (m *taskManager) markManagerStop() { + m.taskLock.Lock() + if m.started { + m.started = false + } + m.taskLock.Unlock() +} + +// runNextTasks it runs next task which next run time is equal to effective +func runNextTasks(sortList *MapSorter, effective time.Time) { + // Run every entry whose next time was this effective time. + var i = 0 + for _, e := range sortList.Vals { + i++ + if e.GetNext(context.Background()) != effective { + break + } + + // check if timeout is on, if yes passing the timeout context + ctx := context.Background() + if duration := e.GetTimeout(ctx); duration != 0 { + go func(e Tasker) { + ctx, cancelFunc := context.WithTimeout(ctx, duration) + defer cancelFunc() + err := e.Run(ctx) + if err != nil { + log.Printf("tasker.run err: %s\n", err.Error()) + } + }(e) + } else { + go func(e Tasker) { + err := e.Run(ctx) + if err != nil { + log.Printf("tasker.run err: %s\n", err.Error()) + } + }(e) + } + + e.SetPrev(context.Background(), e.GetNext(context.Background())) + e.SetNext(context.Background(), effective) + } +} + // StopTask stop all tasks func (m *taskManager) StopTask() { go func() { diff --git a/task/task_test.go b/task/task_test.go index c87757ef..1078aa01 100644 --- a/task/task_test.go +++ b/task/task_test.go @@ -90,6 +90,57 @@ func TestSpec(t *testing.T) { } } +func TestTimeout(t *testing.T) { + m := newTaskManager() + defer m.ClearTask() + wg := &sync.WaitGroup{} + wg.Add(2) + once1, once2 := sync.Once{}, sync.Once{} + + tk1 := NewTask("tk1", "0/10 * * ? * *", + func(ctx context.Context) error { + time.Sleep(4 * time.Second) + select { + case <-ctx.Done(): + once1.Do(func() { + fmt.Println("tk1 done") + wg.Done() + }) + return errors.New("timeout") + default: + } + return nil + }, TimeoutOption(3*time.Second), + ) + + tk2 := NewTask("tk2", "0/11 * * ? * *", + func(ctx context.Context) error { + time.Sleep(4 * time.Second) + select { + case <-ctx.Done(): + return errors.New("timeout") + default: + once2.Do(func() { + fmt.Println("tk2 done") + wg.Done() + }) + } + return nil + }, + ) + + m.AddTask("tk1", tk1) + m.AddTask("tk2", tk2) + m.StartTask() + defer m.StopTask() + + select { + case <-time.After(19 * time.Second): + t.Error("TestTimeout failed") + case <-wait(wg): + } +} + func TestTask_Run(t *testing.T) { cnt := -1 task := func(ctx context.Context) error {