Skip to content

Commit 714a49a

Browse files
authored
feat(tasks): change api (#2862)
1 parent 41e07ec commit 714a49a

File tree

2 files changed

+146
-82
lines changed

2 files changed

+146
-82
lines changed

internal/tasks/tasks.go

Lines changed: 57 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -5,49 +5,62 @@ import (
55
"fmt"
66
"os"
77
"os/signal"
8+
"reflect"
89
)
910

10-
type Task func(ctx context.Context, args interface{}) (nextArgs interface{}, err error)
11-
type TaskWithCleanup[T any] func(ctx context.Context, args interface{}) (nextArgs interface{}, cleanupArgs T, err error)
12-
type Cleanup[T any] func(ctx context.Context, cleanupArgs T) error
11+
type TaskFunc[T any, U any] func(t *Task, args T) (nextArgs U, err error)
12+
type CleanupFunc func(ctx context.Context) error
1313

14-
type taskInfo struct {
15-
Name string
16-
function TaskWithCleanup[any]
17-
cleanFunction Cleanup[any]
18-
cleanupArgs interface{}
14+
type Task struct {
15+
Name string
16+
Ctx context.Context
17+
18+
taskFunction TaskFunc[any, any]
19+
argType reflect.Type
20+
returnType reflect.Type
21+
cleanFunctions []CleanupFunc
1922
}
2023

2124
type Tasks struct {
22-
tasks []taskInfo
25+
tasks []Task
2326
}
2427

2528
func Begin() *Tasks {
2629
return &Tasks{}
2730
}
2831

29-
// Add a task that does not need cleanup
30-
func (ts *Tasks) Add(name string, task Task) {
31-
ts.tasks = append(ts.tasks, taskInfo{
32-
Name: name,
33-
function: func(ctx context.Context, i interface{}) (passedData interface{}, cleanUpData interface{}, err error) {
34-
passedData, err = task(ctx, i)
32+
func Add[TaskArg any, TaskReturn any](ts *Tasks, name string, taskFunc TaskFunc[TaskArg, TaskReturn]) {
33+
var argValue TaskArg
34+
var returnValue TaskReturn
35+
argType := reflect.TypeOf(argValue)
36+
returnType := reflect.TypeOf(returnValue)
37+
38+
tasksAmount := len(ts.tasks)
39+
if tasksAmount > 0 {
40+
lastTask := &ts.tasks[tasksAmount-1]
41+
if argType != lastTask.returnType {
42+
panic(fmt.Errorf("invalid task declared, wait for %s, previous task returns %s", argType.Name(), lastTask.returnType.Name()))
43+
}
44+
}
45+
46+
ts.tasks = append(ts.tasks, Task{
47+
Name: name,
48+
argType: argType,
49+
returnType: returnType,
50+
taskFunction: func(t *Task, i interface{}) (passedData interface{}, err error) {
51+
if i == nil {
52+
var zero TaskArg
53+
passedData, err = taskFunc(t, zero)
54+
} else {
55+
passedData, err = taskFunc(t, i.(TaskArg))
56+
}
3557
return
3658
},
3759
})
3860
}
3961

40-
// AddWithCleanUp adds a task to the list with a cleanup function in case of fail during tasks execution
41-
func AddWithCleanUp[T any](ts *Tasks, name string, task TaskWithCleanup[T], clean Cleanup[T]) {
42-
ts.tasks = append(ts.tasks, taskInfo{
43-
Name: name,
44-
function: func(ctx context.Context, args interface{}) (nextArgs interface{}, cleanUpArgs any, err error) {
45-
return task(ctx, args)
46-
},
47-
cleanFunction: func(ctx context.Context, cleanupArgs any) error {
48-
return clean(ctx, cleanupArgs.(T))
49-
},
50-
})
62+
func (t *Task) AddToCleanUp(cleanupFunc CleanupFunc) {
63+
t.cleanFunctions = append(t.cleanFunctions, cleanupFunc)
5164
}
5265

5366
// setupContext return a contextWithCancel that will cancel on os interrupt (Ctrl-C)
@@ -73,14 +86,17 @@ func (ts *Tasks) Cleanup(ctx context.Context, failed int) {
7386
default:
7487
}
7588

76-
if task.cleanFunction != nil {
89+
if len(task.cleanFunctions) != 0 {
7790
fmt.Printf("[%d/%d] Cleaning task %q\n", i+1, totalTasks, task.Name)
7891
loader.Start()
7992

80-
err := task.cleanFunction(cancelableCtx, task.cleanupArgs)
81-
if err != nil {
82-
fmt.Printf("task %d failed to cleanup, there may be dangling resources: %s\n", i+1, err.Error())
93+
for _, cleanUpFunc := range task.cleanFunctions {
94+
err := cleanUpFunc(cancelableCtx)
95+
if err != nil {
96+
fmt.Printf("task %d failed to cleanup, there may be dangling resources: %s\n", i+1, err.Error())
97+
}
8398
}
99+
84100
loader.Stop()
85101
}
86102
}
@@ -97,19 +113,27 @@ func (ts *Tasks) Execute(ctx context.Context, data interface{}) (interface{}, er
97113

98114
for i := range ts.tasks {
99115
task := &ts.tasks[i]
116+
// Add context and reset cleanup functions, allows to execute multiple times
117+
task.Ctx = cancelableCtx
118+
task.cleanFunctions = []CleanupFunc(nil)
119+
100120
fmt.Printf("[%d/%d] %s\n", i+1, totalTasks, task.Name)
101121
loader.Start()
102122

103-
data, task.cleanupArgs, err = task.function(cancelableCtx, data)
123+
data, err = task.taskFunction(task, data)
104124
taskIsCancelled := false
105125
select {
106126
case <-cancelableCtx.Done():
107127
taskIsCancelled = true
108128
default:
109129
}
110-
if err != nil || taskIsCancelled {
130+
if err != nil {
111131
loader.Stop()
112-
fmt.Println("task failed, cleaning up created resources")
132+
if taskIsCancelled {
133+
fmt.Println("task canceled, cleaning up created resources")
134+
} else {
135+
fmt.Println("task failed, cleaning up created resources")
136+
}
113137
ts.Cleanup(ctx, i)
114138
return nil, fmt.Errorf("task %d %q failed: %w", i+1, task.Name, err)
115139
}

internal/tasks/tasks_test.go

Lines changed: 89 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -5,36 +5,78 @@ import (
55
"fmt"
66
"os"
77
"runtime"
8+
"strconv"
9+
"strings"
810
"testing"
911
"time"
1012

1113
"github.com/alecthomas/assert"
1214
"github.com/scaleway/scaleway-cli/v2/internal/tasks"
1315
)
1416

17+
func TestGeneric(t *testing.T) {
18+
ts := tasks.Begin()
19+
20+
tasks.Add(ts, "convert int to string", func(t *tasks.Task, args int) (nextArgs string, err error) {
21+
return fmt.Sprintf("%d", args), nil
22+
})
23+
tasks.Add(ts, "convert string to int and divide by 4", func(t *tasks.Task, args string) (nextArgs int, err error) {
24+
i, err := strconv.ParseInt(args, 10, 32)
25+
if err != nil {
26+
return 0, err
27+
}
28+
return int(i) / 4, nil
29+
})
30+
31+
res, err := ts.Execute(context.Background(), 12)
32+
assert.Nil(t, err)
33+
assert.Equal(t, 3, res)
34+
}
35+
36+
func TestInvalidGeneric(t *testing.T) {
37+
defer func() {
38+
if r := recover(); r == nil {
39+
t.Errorf("The code did not panic")
40+
}
41+
}()
42+
43+
ts := tasks.Begin()
44+
45+
tasks.Add(ts, "convert int to string", func(t *tasks.Task, args int) (nextArgs string, err error) {
46+
return fmt.Sprintf("%d", args), nil
47+
})
48+
tasks.Add(ts, "divide by 4", func(t *tasks.Task, args int) (nextArgs int, err error) {
49+
return args / 4, nil
50+
})
51+
}
52+
1553
func TestCleanup(t *testing.T) {
1654
ts := tasks.Begin()
1755

1856
clean := 0
1957

20-
tasks.AddWithCleanUp(ts, "Task 1", func(context.Context, interface{}) (interface{}, string, error) {
21-
return nil, "", nil
22-
}, func(context.Context, string) error {
23-
clean++
24-
return nil
58+
tasks.Add(ts, "TaskFunc 1", func(task *tasks.Task, args interface{}) (nextArgs interface{}, err error) {
59+
task.AddToCleanUp(func(ctx context.Context) error {
60+
clean++
61+
return nil
62+
})
63+
return nil, nil
2564
})
26-
tasks.AddWithCleanUp(ts, "Task 2", func(context.Context, interface{}) (interface{}, string, error) {
27-
return nil, "", nil
28-
}, func(context.Context, string) error {
29-
clean++
30-
return nil
65+
tasks.Add(ts, "TaskFunc 2", func(task *tasks.Task, args interface{}) (nextArgs interface{}, err error) {
66+
task.AddToCleanUp(func(ctx context.Context) error {
67+
clean++
68+
return nil
69+
})
70+
return nil, nil
3171
})
32-
tasks.AddWithCleanUp(ts, "Task 3", func(context.Context, interface{}) (interface{}, string, error) {
33-
return nil, "", fmt.Errorf("fail")
34-
}, func(context.Context, string) error {
35-
clean++
36-
return nil
72+
tasks.Add(ts, "TaskFunc 3", func(task *tasks.Task, args interface{}) (nextArgs interface{}, err error) {
73+
task.AddToCleanUp(func(ctx context.Context) error {
74+
clean++
75+
return nil
76+
})
77+
return nil, fmt.Errorf("fail")
3778
})
79+
3880
_, err := ts.Execute(context.Background(), nil)
3981
assert.NotNil(t, err, "Execute should return error after cleanup")
4082
assert.Equal(t, clean, 2, "2 task cleanup should have been executed")
@@ -49,48 +91,46 @@ func TestCleanupOnContext(t *testing.T) {
4991
clean := 0
5092
ctx := context.Background()
5193

52-
tasks.AddWithCleanUp(ts, "Task 1",
53-
func(context.Context, interface{}) (interface{}, string, error) {
54-
return nil, "", nil
55-
}, func(context.Context, string) error {
94+
tasks.Add(ts, "TaskFunc 1", func(task *tasks.Task, args interface{}) (nextArgs interface{}, err error) {
95+
task.AddToCleanUp(func(ctx context.Context) error {
5696
clean++
5797
return nil
58-
},
59-
)
60-
tasks.AddWithCleanUp(ts, "Task 2",
61-
func(context.Context, interface{}) (interface{}, string, error) {
62-
return nil, "", nil
63-
}, func(context.Context, string) error {
98+
})
99+
return nil, nil
100+
})
101+
tasks.Add(ts, "TaskFunc 2", func(task *tasks.Task, args interface{}) (nextArgs interface{}, err error) {
102+
task.AddToCleanUp(func(ctx context.Context) error {
64103
clean++
65104
return nil
66-
},
67-
)
68-
tasks.AddWithCleanUp(ts, "Task 3",
69-
func(ctx context.Context, _ interface{}) (interface{}, string, error) {
70-
p, err := os.FindProcess(os.Getpid())
71-
if err != nil {
72-
return nil, "", err
73-
}
74-
75-
// Interrupt tasks, as done with Ctrl-C
76-
err = p.Signal(os.Interrupt)
77-
if err != nil {
78-
t.Fatal(err)
79-
}
80-
81-
select {
82-
case <-time.After(time.Second):
83-
return nil, "", nil
84-
case <-ctx.Done():
85-
return nil, "", fmt.Errorf("interrupted")
86-
}
87-
}, func(context.Context, string) error {
105+
})
106+
return nil, nil
107+
})
108+
tasks.Add(ts, "TaskFunc 3", func(task *tasks.Task, args interface{}) (nextArgs interface{}, err error) {
109+
task.AddToCleanUp(func(ctx context.Context) error {
88110
clean++
89111
return nil
90-
},
91-
)
112+
})
113+
p, err := os.FindProcess(os.Getpid())
114+
if err != nil {
115+
return nil, err
116+
}
117+
118+
// Interrupt tasks, as done with Ctrl-C
119+
err = p.Signal(os.Interrupt)
120+
if err != nil {
121+
t.Fatal(err)
122+
}
123+
124+
select {
125+
case <-task.Ctx.Done():
126+
return nil, fmt.Errorf("interrupted")
127+
case <-time.After(time.Second * 3):
128+
return nil, nil
129+
}
130+
})
92131

93132
_, err := ts.Execute(ctx, nil)
94133
assert.NotNil(t, err, "context should have been interrupted")
134+
assert.True(t, strings.Contains(err.Error(), "interrupted"), "error is not interrupted: %s", err)
95135
assert.Equal(t, clean, 2, "2 task cleanup should have been executed")
96136
}

0 commit comments

Comments
 (0)