diff --git a/adapter/toolbox/task.go b/adapter/toolbox/task.go index bdd6679f..7b7cd68a 100644 --- a/adapter/toolbox/task.go +++ b/adapter/toolbox/task.go @@ -289,3 +289,7 @@ func (o *oldToNewAdapter) SetPrev(ctx context.Context, t time.Time) { func (o *oldToNewAdapter) GetPrev(ctx context.Context) time.Time { return o.delegate.GetPrev() } + +func (o *oldToNewAdapter) GetTimeout(ctx context.Context) time.Duration { + return 0 +} diff --git a/task/governor_command_test.go b/task/governor_command_test.go index 00ed37f2..c3547cdf 100644 --- a/task/governor_command_test.go +++ b/task/governor_command_test.go @@ -55,6 +55,10 @@ func (c *countTask) GetPrev(ctx context.Context) time.Time { return time.Now() } +func (c *countTask) GetTimeout(ctx context.Context) time.Duration { + return 0 +} + func TestRunTaskCommand_Execute(t *testing.T) { task := &countTask{} AddTask("count", task) diff --git a/task/task.go b/task/task.go index 2ea34f24..d07f0135 100644 --- a/task/task.go +++ b/task/task.go @@ -109,6 +109,7 @@ type Tasker interface { GetNext(ctx context.Context) time.Time SetPrev(context.Context, time.Time) GetPrev(ctx context.Context) time.Time + GetTimeout(ctx context.Context) time.Duration } // task error @@ -127,13 +128,14 @@ type Task struct { DoFunc TaskFunc Prev time.Time Next time.Time + Timeout time.Duration Errlist []*taskerr // like errtime:errinfo ErrLimit int // max length for the errlist, 0 stand for no limit errCnt int // records the error count during the execution } // NewTask add new task with name, time and func -func NewTask(tname string, spec string, f TaskFunc) *Task { +func NewTask(tname string, spec string, f TaskFunc, opts ...Option) *Task { task := &Task{ Taskname: tname, @@ -144,6 +146,11 @@ func NewTask(tname string, spec string, f TaskFunc) *Task { // we only store the pointer, so it won't use too many space Errlist: make([]*taskerr, 100, 100), } + + for _, opt := range opts { + opt.apply(task) + } + task.SetCron(spec) return task } @@ -196,6 +203,31 @@ func (t *Task) GetPrev(context.Context) time.Time { return t.Prev } +// GetTimeout get timeout duration of this task +func (t *Task) GetTimeout(context.Context) time.Duration { + return t.Timeout +} + +// Option interface +type Option interface { + apply(*Task) +} + +// optionFunc return a function to set task element +type optionFunc func(*Task) + +// apply option to task +func (f optionFunc) apply(t *Task) { + f(t) +} + +// TimeoutOption return a option to set timeout duration for task +func TimeoutOption(timeout time.Duration) Option { + return optionFunc(func(t *Task) { + t.Timeout = timeout + }) +} + // six columns mean: // second:0-59 // minute:0-59 @@ -482,7 +514,19 @@ func (m *taskManager) run() { if e.GetNext(context.Background()) != effective { break } - go e.Run(nil) + + // check if timeout is on, if yes passing the timeout context + ctx := context.Background() + if duration := e.GetTimeout(ctx); duration != 0 { + ctx, cancelFunc := context.WithTimeout(ctx, duration) + go func() { + defer cancelFunc() + e.Run(ctx) + }() + } else { + go e.Run(ctx) + } + e.SetPrev(context.Background(), e.GetNext(context.Background())) e.SetNext(nil, effective) } diff --git a/task/task_test.go b/task/task_test.go index c87757ef..d36c3994 100644 --- a/task/task_test.go +++ b/task/task_test.go @@ -90,6 +90,54 @@ func TestSpec(t *testing.T) { } } +func TestTimeout(t *testing.T) { + m := newTaskManager() + defer m.ClearTask() + wg := &sync.WaitGroup{} + wg.Add(2) + + tk1 := NewTask("tk1", "0/10 * * ? * *", + func(ctx context.Context) error { + fmt.Println("tk1 start") + time.Sleep(4 * time.Second) + select { + case <-ctx.Done(): + fmt.Println("tk1 done") + wg.Done() + return errors.New("timeout") + default: + } + return nil + }, TimeoutOption(3*time.Second), + ) + + tk2 := NewTask("tk2", "0/10 * * ? * *", + func(ctx context.Context) error { + fmt.Println("tk2 start") + time.Sleep(4 * time.Second) + select { + case <-ctx.Done(): + return errors.New("timeout") + default: + fmt.Println("tk2 done") + wg.Done() + } + return nil + }, + ) + + m.AddTask("tk1", tk1) + m.AddTask("tk2", tk2) + m.StartTask() + defer m.StopTask() + + select { + case <-time.After(19 * time.Second): + t.Error("TestTimeout failed") + case <-wait(wg): + } +} + func TestTask_Run(t *testing.T) { cnt := -1 task := func(ctx context.Context) error {