Skip to content

Commit 75ca7fe

Browse files
authored
Merge pull request #534 from chaitin/fix/vm-recycle-task-finish
fix: 修复 VM 回收后任务状态未收口
2 parents afe681a + 9602fd2 commit 75ca7fe

6 files changed

Lines changed: 377 additions & 0 deletions

File tree

backend/biz/host/usecase/host.go

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"bytes"
55
"context"
66
"encoding/json"
7+
"errors"
78
"fmt"
89
"html/template"
910
"log/slog"
@@ -38,6 +39,7 @@ type HostUsecase struct {
3839
taskflow taskflow.Clienter
3940
logger *slog.Logger
4041
repo domain.HostRepo
42+
taskRepo domain.TaskRepo
4143
userRepo domain.UserRepo
4244
girepo domain.GitIdentityRepo
4345
vmexpireQueue *delayqueue.VMExpireQueue
@@ -52,6 +54,7 @@ func NewHostUsecase(i *do.Injector) (domain.HostUsecase, error) {
5254
taskflow: do.MustInvoke[taskflow.Clienter](i),
5355
logger: do.MustInvoke[*slog.Logger](i).With("module", "HostUsecase"),
5456
repo: do.MustInvoke[domain.HostRepo](i),
57+
taskRepo: do.MustInvoke[domain.TaskRepo](i),
5558
userRepo: do.MustInvoke[domain.UserRepo](i),
5659
girepo: do.MustInvoke[domain.GitIdentityRepo](i),
5760
vmexpireQueue: do.MustInvoke[*delayqueue.VMExpireQueue](i),
@@ -129,6 +132,11 @@ func (h *HostUsecase) vmexpireConsumer() {
129132
return err
130133
}
131134

135+
if err := h.markRecycledTasksFinished(ctx, vm); err != nil {
136+
innerLogger.ErrorContext(ctx, "failed to finish recycled tasks", "error", err)
137+
return err
138+
}
139+
132140
return nil
133141
})
134142

@@ -138,6 +146,27 @@ func (h *HostUsecase) vmexpireConsumer() {
138146
}
139147
}
140148

149+
func (h *HostUsecase) markRecycledTasksFinished(ctx context.Context, vm *db.VirtualMachine) error {
150+
var errs []error
151+
for _, tk := range vm.Edges.Tasks {
152+
if tk == nil {
153+
continue
154+
}
155+
if tk.Status == consts.TaskStatusFinished || tk.Status == consts.TaskStatusError {
156+
continue
157+
}
158+
err := h.taskRepo.Update(ctx, nil, tk.ID, func(up *db.TaskUpdateOne) error {
159+
up.SetStatus(consts.TaskStatusFinished)
160+
up.SetCompletedAt(time.Now())
161+
return nil
162+
})
163+
if err != nil {
164+
errs = append(errs, fmt.Errorf("update task %s: %w", tk.ID, err))
165+
}
166+
}
167+
return errors.Join(errs...)
168+
}
169+
141170
// GetInstallCommand implements domain.HostUsecase.
142171
func (h *HostUsecase) GetInstallCommand(ctx context.Context, user *domain.User) (string, error) {
143172
token := uuid.NewString()
Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,158 @@
1+
package usecase
2+
3+
import (
4+
"context"
5+
"io"
6+
"log/slog"
7+
"testing"
8+
"time"
9+
10+
"github.com/google/uuid"
11+
_ "github.com/mattn/go-sqlite3"
12+
13+
"github.com/chaitin/MonkeyCode/backend/consts"
14+
"github.com/chaitin/MonkeyCode/backend/db"
15+
"github.com/chaitin/MonkeyCode/backend/db/enttest"
16+
"github.com/chaitin/MonkeyCode/backend/domain"
17+
"github.com/chaitin/MonkeyCode/backend/pkg/taskflow"
18+
)
19+
20+
func TestHostUsecase_markRecycledTasksFinished(t *testing.T) {
21+
t.Parallel()
22+
23+
ctx := context.Background()
24+
client := enttest.Open(t, "sqlite3", "file:host-usecase-task-finish-test?mode=memory&cache=shared&_fk=1")
25+
defer client.Close()
26+
27+
userID := uuid.New()
28+
if _, err := client.User.Create().
29+
SetID(userID).
30+
SetName("tester").
31+
SetRole(consts.UserRoleIndividual).
32+
SetStatus(consts.UserStatusActive).
33+
Save(ctx); err != nil {
34+
t.Fatalf("create user: %v", err)
35+
}
36+
37+
createTask := func(status consts.TaskStatus) *db.Task {
38+
taskID := uuid.New()
39+
tk, err := client.Task.Create().
40+
SetID(taskID).
41+
SetUserID(userID).
42+
SetKind(consts.TaskTypeDevelop).
43+
SetContent(string(status)).
44+
SetStatus(status).
45+
Save(ctx)
46+
if err != nil {
47+
t.Fatalf("create task(%s): %v", status, err)
48+
}
49+
return tk
50+
}
51+
52+
processingTask := createTask(consts.TaskStatusProcessing)
53+
finishedTask := createTask(consts.TaskStatusFinished)
54+
errorTask := createTask(consts.TaskStatusError)
55+
56+
taskRepo := &hostTaskRepoStub{client: client}
57+
u := &HostUsecase{
58+
taskRepo: taskRepo,
59+
logger: slog.New(slog.NewTextHandler(io.Discard, nil)),
60+
}
61+
62+
vm := &db.VirtualMachine{
63+
ID: "vm-1",
64+
Edges: db.VirtualMachineEdges{
65+
Tasks: []*db.Task{
66+
processingTask,
67+
finishedTask,
68+
errorTask,
69+
nil,
70+
},
71+
},
72+
}
73+
74+
if err := u.markRecycledTasksFinished(ctx, vm); err != nil {
75+
t.Fatalf("markRecycledTasksFinished() error = %v", err)
76+
}
77+
78+
gotProcessing, err := client.Task.Get(ctx, processingTask.ID)
79+
if err != nil {
80+
t.Fatalf("query processing task: %v", err)
81+
}
82+
if gotProcessing.Status != consts.TaskStatusFinished {
83+
t.Fatalf("processing task status = %s, want %s", gotProcessing.Status, consts.TaskStatusFinished)
84+
}
85+
if gotProcessing.CompletedAt.IsZero() {
86+
t.Fatal("expected processing task completed_at to be set")
87+
}
88+
89+
gotFinished, err := client.Task.Get(ctx, finishedTask.ID)
90+
if err != nil {
91+
t.Fatalf("query finished task: %v", err)
92+
}
93+
if !gotFinished.CompletedAt.IsZero() {
94+
t.Fatal("expected already finished task completed_at to remain unchanged")
95+
}
96+
97+
gotError, err := client.Task.Get(ctx, errorTask.ID)
98+
if err != nil {
99+
t.Fatalf("query error task: %v", err)
100+
}
101+
if gotError.Status != consts.TaskStatusError {
102+
t.Fatalf("error task status = %s, want %s", gotError.Status, consts.TaskStatusError)
103+
}
104+
105+
if len(taskRepo.updatedIDs) != 1 || taskRepo.updatedIDs[0] != processingTask.ID {
106+
t.Fatalf("updated task ids = %v, want only %s", taskRepo.updatedIDs, processingTask.ID)
107+
}
108+
}
109+
110+
type hostTaskRepoStub struct {
111+
client *db.Client
112+
updatedIDs []uuid.UUID
113+
}
114+
115+
func (s *hostTaskRepoStub) GetByID(ctx context.Context, id uuid.UUID) (*db.Task, error) {
116+
return s.client.Task.Get(ctx, id)
117+
}
118+
119+
func (s *hostTaskRepoStub) Stat(context.Context, uuid.UUID) (*domain.TaskStats, error) {
120+
panic("unexpected call to Stat")
121+
}
122+
123+
func (s *hostTaskRepoStub) StatByIDs(context.Context, []uuid.UUID) (map[uuid.UUID]*domain.TaskStats, error) {
124+
panic("unexpected call to StatByIDs")
125+
}
126+
127+
func (s *hostTaskRepoStub) Info(context.Context, *domain.User, uuid.UUID, bool) (*db.Task, error) {
128+
panic("unexpected call to Info")
129+
}
130+
131+
func (s *hostTaskRepoStub) List(context.Context, *domain.User, domain.TaskListReq) ([]*db.ProjectTask, *db.PageInfo, error) {
132+
panic("unexpected call to List")
133+
}
134+
135+
func (s *hostTaskRepoStub) Create(context.Context, *domain.User, domain.CreateTaskReq, string, func(*db.ProjectTask, *db.Model, *db.Image) (*taskflow.VirtualMachine, error)) (*db.ProjectTask, error) {
136+
panic("unexpected call to Create")
137+
}
138+
139+
func (s *hostTaskRepoStub) Update(ctx context.Context, _ *domain.User, id uuid.UUID, fn func(up *db.TaskUpdateOne) error) error {
140+
s.updatedIDs = append(s.updatedIDs, id)
141+
up := s.client.Task.UpdateOneID(id)
142+
if err := fn(up); err != nil {
143+
return err
144+
}
145+
return up.Exec(ctx)
146+
}
147+
148+
func (s *hostTaskRepoStub) RefreshLastActiveAt(context.Context, uuid.UUID, time.Time, time.Duration) error {
149+
panic("unexpected call to RefreshLastActiveAt")
150+
}
151+
152+
func (s *hostTaskRepoStub) Stop(context.Context, *domain.User, uuid.UUID, func(*db.Task) error) error {
153+
panic("unexpected call to Stop")
154+
}
155+
156+
func (s *hostTaskRepoStub) Delete(context.Context, *domain.User, uuid.UUID) error {
157+
panic("unexpected call to Delete")
158+
}

backend/pkg/lifecycle/taskhook.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"encoding/json"
66
"fmt"
77
"log/slog"
8+
"time"
89

910
"github.com/google/uuid"
1011
"github.com/redis/go-redis/v9"
@@ -51,6 +52,8 @@ func (h *TaskHook) OnStateChange(ctx context.Context, id uuid.UUID, from, to con
5152
return h.handleProcessing(ctx, id, metadata)
5253
case consts.TaskStatusError:
5354
return h.handleError(ctx, id, metadata.UserID)
55+
case consts.TaskStatusFinished:
56+
return h.handleFinished(ctx, id, metadata.UserID)
5457
}
5558

5659
return nil
@@ -77,6 +80,15 @@ func (h *TaskHook) handleError(ctx context.Context, id, uid uuid.UUID) error {
7780
})
7881
}
7982

83+
func (h *TaskHook) handleFinished(ctx context.Context, id, uid uuid.UUID) error {
84+
u := domain.User{ID: uid}
85+
return h.repo.Update(ctx, &u, id, func(up *db.TaskUpdateOne) error {
86+
up.SetStatus(consts.TaskStatusFinished)
87+
up.SetCompletedAt(time.Now())
88+
return nil
89+
})
90+
}
91+
8092
func (h *TaskHook) handleProcessing(ctx context.Context, id uuid.UUID, metadata TaskMetadata) error {
8193
h.withError(ctx, id, metadata.UserID, func() error {
8294
// 从 DB 查询当前任务状态,如果已经是 processing 说明是 Agent 重连触发的重复 vm-ready,跳过
Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
package lifecycle
2+
3+
import (
4+
"context"
5+
"io"
6+
"log/slog"
7+
"testing"
8+
"time"
9+
10+
"github.com/google/uuid"
11+
_ "github.com/mattn/go-sqlite3"
12+
13+
taskrepo "github.com/chaitin/MonkeyCode/backend/biz/task/repo"
14+
"github.com/chaitin/MonkeyCode/backend/consts"
15+
"github.com/chaitin/MonkeyCode/backend/db"
16+
"github.com/chaitin/MonkeyCode/backend/db/enttest"
17+
"github.com/chaitin/MonkeyCode/backend/domain"
18+
"github.com/chaitin/MonkeyCode/backend/pkg/taskflow"
19+
)
20+
21+
func TestTaskHook_OnStateChange_FinishedUpdatesTaskStatusAndCompletedAt(t *testing.T) {
22+
t.Parallel()
23+
24+
ctx := context.Background()
25+
client := enttest.Open(t, "sqlite3", "file:task-hook-finished-test?mode=memory&cache=shared&_fk=1")
26+
defer client.Close()
27+
28+
repo := &taskHookRepoStub{
29+
taskRepo: &taskrepo.TaskRepo{},
30+
client: client,
31+
}
32+
33+
userID := uuid.New()
34+
if _, err := client.User.Create().
35+
SetID(userID).
36+
SetName("tester").
37+
SetRole(consts.UserRoleIndividual).
38+
SetStatus(consts.UserStatusActive).
39+
Save(ctx); err != nil {
40+
t.Fatalf("create user: %v", err)
41+
}
42+
43+
taskID := uuid.New()
44+
if _, err := client.Task.Create().
45+
SetID(taskID).
46+
SetUserID(userID).
47+
SetKind(consts.TaskTypeDevelop).
48+
SetContent("demo").
49+
SetStatus(consts.TaskStatusProcessing).
50+
Save(ctx); err != nil {
51+
t.Fatalf("create task: %v", err)
52+
}
53+
54+
hook := &TaskHook{
55+
repo: repo,
56+
logger: slog.New(slog.NewTextHandler(io.Discard, nil)),
57+
}
58+
59+
if err := hook.OnStateChange(ctx, taskID, consts.TaskStatusProcessing, consts.TaskStatusFinished, TaskMetadata{
60+
TaskID: taskID,
61+
UserID: userID,
62+
}); err != nil {
63+
t.Fatalf("OnStateChange() error = %v", err)
64+
}
65+
66+
got, err := client.Task.Get(ctx, taskID)
67+
if err != nil {
68+
t.Fatalf("query task: %v", err)
69+
}
70+
if got.Status != consts.TaskStatusFinished {
71+
t.Fatalf("task status = %s, want %s", got.Status, consts.TaskStatusFinished)
72+
}
73+
if got.CompletedAt.IsZero() {
74+
t.Fatal("expected completed_at to be set")
75+
}
76+
if time.Since(got.CompletedAt) > time.Minute {
77+
t.Fatalf("completed_at = %v, looks stale", got.CompletedAt)
78+
}
79+
}
80+
81+
type taskHookRepoStub struct {
82+
taskRepo *taskrepo.TaskRepo
83+
client *db.Client
84+
}
85+
86+
func (s *taskHookRepoStub) GetByID(ctx context.Context, id uuid.UUID) (*db.Task, error) {
87+
return s.client.Task.Get(ctx, id)
88+
}
89+
90+
func (s *taskHookRepoStub) Stat(context.Context, uuid.UUID) (*domain.TaskStats, error) {
91+
panic("unexpected call to Stat")
92+
}
93+
94+
func (s *taskHookRepoStub) StatByIDs(context.Context, []uuid.UUID) (map[uuid.UUID]*domain.TaskStats, error) {
95+
panic("unexpected call to StatByIDs")
96+
}
97+
98+
func (s *taskHookRepoStub) Info(context.Context, *domain.User, uuid.UUID, bool) (*db.Task, error) {
99+
panic("unexpected call to Info")
100+
}
101+
102+
func (s *taskHookRepoStub) List(context.Context, *domain.User, domain.TaskListReq) ([]*db.ProjectTask, *db.PageInfo, error) {
103+
panic("unexpected call to List")
104+
}
105+
106+
func (s *taskHookRepoStub) Create(context.Context, *domain.User, domain.CreateTaskReq, string, func(*db.ProjectTask, *db.Model, *db.Image) (*taskflow.VirtualMachine, error)) (*db.ProjectTask, error) {
107+
panic("unexpected call to Create")
108+
}
109+
110+
func (s *taskHookRepoStub) Update(ctx context.Context, _ *domain.User, id uuid.UUID, fn func(up *db.TaskUpdateOne) error) error {
111+
up := s.client.Task.UpdateOneID(id)
112+
if err := fn(up); err != nil {
113+
return err
114+
}
115+
return up.Exec(ctx)
116+
}
117+
118+
func (s *taskHookRepoStub) RefreshLastActiveAt(context.Context, uuid.UUID, time.Time, time.Duration) error {
119+
panic("unexpected call to RefreshLastActiveAt")
120+
}
121+
122+
func (s *taskHookRepoStub) Stop(context.Context, *domain.User, uuid.UUID, func(*db.Task) error) error {
123+
panic("unexpected call to Stop")
124+
}
125+
126+
func (s *taskHookRepoStub) Delete(context.Context, *domain.User, uuid.UUID) error {
127+
panic("unexpected call to Delete")
128+
}

backend/pkg/lifecycle/vmtaskhook.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@ func (h *VMTaskHook) OnStateChange(ctx context.Context, _ string, _ VMState, to
3838
target = consts.TaskStatusProcessing
3939
case VMStateFailed:
4040
target = consts.TaskStatusError
41+
case VMStateRecycled:
42+
target = consts.TaskStatusFinished
4143
default:
4244
return nil
4345
}

0 commit comments

Comments
 (0)