From 64c03e8679398eeb266701b81499681e627bd30f Mon Sep 17 00:00:00 2001
From: "stonezdj(Daojun Zhang)" <stonezdj@gmail.com>
Date: Thu, 1 Dec 2022 21:17:35 +0800
Subject: [PATCH] Update the execution status after tasks stopped (#17875)

Fixes #17862

Signed-off-by: stonezdj <daojunz@vmware.com>

Signed-off-by: stonezdj <daojunz@vmware.com>
---
 src/controller/jobmonitor/monitor.go   | 42 +++++++++++++--------
 src/pkg/task/dao/task.go               | 42 ++++++++++++++++++++-
 src/pkg/task/dao/task_test.go          | 52 +++++++++++++++++++++++++-
 src/pkg/task/mock_task_dao_test.go     | 37 ++++++++++++++++++
 src/pkg/task/mock_task_manager_test.go | 37 ++++++++++++++++++
 src/pkg/task/task.go                   | 12 ++++++
 src/testing/pkg/task/manager.go        | 37 ++++++++++++++++++
 7 files changed, 241 insertions(+), 18 deletions(-)

diff --git a/src/controller/jobmonitor/monitor.go b/src/controller/jobmonitor/monitor.go
index 72e3b3b07..d39b59b0f 100644
--- a/src/controller/jobmonitor/monitor.go
+++ b/src/controller/jobmonitor/monitor.go
@@ -20,6 +20,7 @@ import (
 	"strings"
 	"time"
 
+	jobSvc "github.com/goharbor/harbor/src/jobservice/job"
 	"github.com/goharbor/harbor/src/lib/orm"
 	"github.com/goharbor/harbor/src/pkg/queuestatus"
 
@@ -32,10 +33,12 @@ import (
 	libRedis "github.com/goharbor/harbor/src/lib/redis"
 	jm "github.com/goharbor/harbor/src/pkg/jobmonitor"
 	"github.com/goharbor/harbor/src/pkg/task"
+	taskDao "github.com/goharbor/harbor/src/pkg/task/dao"
 )
 
 const (
-	all = "all"
+	all             = "all"
+	batchUpdateSize = 1000
 )
 
 // Ctl the controller instance of the worker pool controller
@@ -76,6 +79,7 @@ type monitorController struct {
 	queueStatusManager    queuestatus.Manager
 	monitorClient         func() (jm.JobServiceMonitorClient, error)
 	jobServiceRedisClient func() (jm.RedisClient, error)
+	executionDAO          taskDao.ExecutionDAO
 }
 
 // NewMonitorController ...
@@ -88,6 +92,7 @@ func NewMonitorController() MonitorController {
 		queueStatusManager:    queuestatus.Mgr,
 		monitorClient:         jobServiceMonitorClient,
 		jobServiceRedisClient: jm.JobServiceRedisClient,
+		executionDAO:          taskDao.NewExecutionDAO(),
 	}
 }
 
@@ -209,29 +214,34 @@ func (w *monitorController) stopPendingJob(ctx context.Context, jobType string)
 	if err != nil {
 		return err
 	}
-	return w.updateJobStatusInTask(ctx, jobIDs, "Stopped")
+	go func() {
+		if err = w.updateJobStatusInTask(orm.Context(), jobType, jobIDs, jobSvc.StoppedStatus.String()); err != nil {
+			log.Errorf("failed to update job status in task: %v", err)
+		}
+	}()
+	return nil
 }
 
-func (w *monitorController) updateJobStatusInTask(ctx context.Context, jobIDs []string, status string) error {
+func (w *monitorController) updateJobStatusInTask(ctx context.Context, vendorType string, jobIDs []string, status string) error {
 	if ctx == nil {
 		log.Debug("context is nil, update job status in task")
 		return nil
 	}
-	for _, jobID := range jobIDs {
-		ts, err := w.taskManager.List(ctx, q.New(q.KeyWords{"job_id": jobID}))
-		if err != nil {
-			return err
-		}
-		if len(ts) == 0 {
+	// Task count could be huge, to avoid query executionID by each task, query with vendor type and status
+	// it might include extra executions, but it won't change these executions final status
+	pendingExecs, err := w.taskManager.ExecutionIDsByVendorAndStatus(ctx, vendorType, jobSvc.PendingStatus.String())
+	if err != nil {
+		return err
+	}
+	if err := w.taskManager.UpdateStatusInBatch(ctx, jobIDs, status, batchUpdateSize); err != nil {
+		log.Errorf("failed to update task status in batch: %v", err)
+	}
+	// Update execution status
+	for _, executionID := range pendingExecs {
+		if _, _, err := w.executionDAO.RefreshStatus(ctx, executionID); err != nil {
+			log.Errorf("failed to refresh execution status: %v", err)
 			continue
 		}
-		ts[0].Status = status
-		// use local transaction to avoid rollback batch success tasks to previous state when one fail
-		if err := orm.WithTransaction(func(ctx context.Context) error {
-			return w.taskManager.Update(ctx, ts[0], "Status")
-		})(orm.SetTransactionOpNameToContext(ctx, "tx-update-task")); err != nil {
-			return err
-		}
 	}
 	return nil
 }
diff --git a/src/pkg/task/dao/task.go b/src/pkg/task/dao/task.go
index d8e0dd390..0c4ff311c 100644
--- a/src/pkg/task/dao/task.go
+++ b/src/pkg/task/dao/task.go
@@ -16,11 +16,13 @@ package dao
 
 import (
 	"context"
+	"fmt"
 	"strings"
 	"time"
 
 	"github.com/goharbor/harbor/src/jobservice/job"
 	"github.com/goharbor/harbor/src/lib/errors"
+	"github.com/goharbor/harbor/src/lib/log"
 	"github.com/goharbor/harbor/src/lib/orm"
 	"github.com/goharbor/harbor/src/lib/q"
 )
@@ -47,6 +49,10 @@ type TaskDAO interface {
 	ListStatusCount(ctx context.Context, executionID int64) (statusCounts []*StatusCount, err error)
 	// GetMaxEndTime gets the max end time for the tasks references the specified execution
 	GetMaxEndTime(ctx context.Context, executionID int64) (endTime time.Time, err error)
+	// UpdateStatusInBatch updates the status of tasks in batch
+	UpdateStatusInBatch(ctx context.Context, jobIDs []string, status string, batchSize int) (err error)
+	// ExecutionIDsByVendorAndStatus retrieve the execution id by vendor status
+	ExecutionIDsByVendorAndStatus(ctx context.Context, vendorType, status string) ([]int64, error)
 }
 
 // NewTaskDAO returns an instance of TaskDAO
@@ -242,6 +248,40 @@ func (t *taskDAO) querySetter(ctx context.Context, query *q.Query) (orm.QuerySet
 		}
 		qs = qs.FilterRaw("id", inClause)
 	}
-
 	return qs, nil
 }
+
+func (t *taskDAO) ExecutionIDsByVendorAndStatus(ctx context.Context, vendorType, status string) ([]int64, error) {
+	ormer, err := orm.FromContext(ctx)
+	if err != nil {
+		return nil, err
+	}
+	var ids []int64
+	_, err = ormer.Raw("select distinct execution_id from task where vendor_type =? and status = ?", vendorType, status).QueryRows(&ids)
+	return ids, err
+}
+
+func (t *taskDAO) UpdateStatusInBatch(ctx context.Context, jobIDs []string, status string, batchSize int) (err error) {
+	if len(jobIDs) == 0 {
+		return nil
+	}
+	ormer, err := orm.FromContext(ctx)
+	if err != nil {
+		return err
+	}
+	sql := "update task set status = ?, update_time = ? where job_id in (%s)"
+	if len(jobIDs) <= batchSize {
+		realSQL := fmt.Sprintf(sql, orm.ParamPlaceholderForIn(len(jobIDs)))
+		_, err = ormer.Raw(realSQL, status, time.Now(), jobIDs).Exec()
+		return err
+	}
+	subSetIDs := make([]string, batchSize)
+	copy(subSetIDs, jobIDs[:batchSize])
+	sql = fmt.Sprintf(sql, orm.ParamPlaceholderForIn(batchSize))
+	_, err = ormer.Raw(sql, status, time.Now(), subSetIDs).Exec()
+	if err != nil {
+		log.Errorf("failed to update status in batch, error: %v", err)
+		return err
+	}
+	return t.UpdateStatusInBatch(ctx, jobIDs[batchSize:], status, batchSize)
+}
diff --git a/src/pkg/task/dao/task_test.go b/src/pkg/task/dao/task_test.go
index a6a52e606..f9440ec44 100644
--- a/src/pkg/task/dao/task_test.go
+++ b/src/pkg/task/dao/task_test.go
@@ -16,6 +16,7 @@ package dao
 
 import (
 	"context"
+	"fmt"
 	"testing"
 	"time"
 
@@ -25,10 +26,11 @@ import (
 	"github.com/goharbor/harbor/src/lib/errors"
 	"github.com/goharbor/harbor/src/lib/orm"
 	"github.com/goharbor/harbor/src/lib/q"
+	htesting "github.com/goharbor/harbor/src/testing"
 )
 
 type taskDAOTestSuite struct {
-	suite.Suite
+	htesting.Suite
 	ctx          context.Context
 	taskDAO      *taskDAO
 	executionDAO *executionDAO
@@ -37,6 +39,7 @@ type taskDAOTestSuite struct {
 }
 
 func (t *taskDAOTestSuite) SetupSuite() {
+	t.Suite.SetupSuite()
 	t.ctx = orm.Context()
 	t.taskDAO = &taskDAO{}
 	t.executionDAO = &executionDAO{}
@@ -228,6 +231,53 @@ func (t *taskDAOTestSuite) TestGetMaxEndTime() {
 	t.Equal(now.Unix(), endTime.Unix())
 }
 
+func (t *taskDAOTestSuite) TestUpdateStatusInBatch() {
+	jobIDs := make([]string, 0)
+	taskIDs := make([]int64, 0)
+	for i := 0; i < 300; i++ {
+		jobID := fmt.Sprintf("job-%d", i)
+		tid, err := t.taskDAO.Create(t.ctx, &Task{
+			JobID:       jobID,
+			ExecutionID: t.executionID,
+			Status:      "Pending",
+			StatusCode:  1,
+			ExtraAttrs:  "{}",
+		})
+		t.Require().Nil(err)
+		jobIDs = append(jobIDs, jobID)
+		taskIDs = append(taskIDs, tid)
+	}
+
+	err := t.taskDAO.UpdateStatusInBatch(t.ctx, jobIDs, "Stopped", 10)
+	t.Require().Nil(err)
+	for i := 0; i < 300; i++ {
+		tasks, err := t.taskDAO.List(t.ctx, &q.Query{
+			Keywords: q.KeyWords{"job_id": jobIDs[i]}})
+		t.Require().Nil(err)
+		t.Require().Len(tasks, 1)
+		t.Equal("Stopped", tasks[0].Status)
+	}
+	for _, taskID := range taskIDs {
+		t.taskDAO.Delete(t.ctx, taskID)
+	}
+}
+
+func (t *taskDAOTestSuite) TestExecutionIDsByVendorAndStatus() {
+	tid, err := t.taskDAO.Create(t.ctx, &Task{
+		JobID:       "job123",
+		ExecutionID: t.executionID,
+		Status:      "Pending",
+		StatusCode:  1,
+		ExtraAttrs:  "{}",
+		VendorType:  "MYREPLICATION",
+	})
+	t.Require().Nil(err)
+	exeIDs, err := t.taskDAO.ExecutionIDsByVendorAndStatus(t.ctx, "MYREPLICATION", "Pending")
+	t.Require().Nil(err)
+	t.Require().Len(exeIDs, 1)
+	defer t.taskDAO.Delete(t.ctx, tid)
+}
+
 func TestTaskDAOSuite(t *testing.T) {
 	suite.Run(t, &taskDAOTestSuite{})
 }
diff --git a/src/pkg/task/mock_task_dao_test.go b/src/pkg/task/mock_task_dao_test.go
index bc35740d0..eee80da9c 100644
--- a/src/pkg/task/mock_task_dao_test.go
+++ b/src/pkg/task/mock_task_dao_test.go
@@ -74,6 +74,29 @@ func (_m *mockTaskDAO) Delete(ctx context.Context, id int64) error {
 	return r0
 }
 
+// ExecutionIDsByVendorAndStatus provides a mock function with given fields: ctx, vendorType, status
+func (_m *mockTaskDAO) ExecutionIDsByVendorAndStatus(ctx context.Context, vendorType string, status string) ([]int64, error) {
+	ret := _m.Called(ctx, vendorType, status)
+
+	var r0 []int64
+	if rf, ok := ret.Get(0).(func(context.Context, string, string) []int64); ok {
+		r0 = rf(ctx, vendorType, status)
+	} else {
+		if ret.Get(0) != nil {
+			r0 = ret.Get(0).([]int64)
+		}
+	}
+
+	var r1 error
+	if rf, ok := ret.Get(1).(func(context.Context, string, string) error); ok {
+		r1 = rf(ctx, vendorType, status)
+	} else {
+		r1 = ret.Error(1)
+	}
+
+	return r0, r1
+}
+
 // Get provides a mock function with given fields: ctx, id
 func (_m *mockTaskDAO) Get(ctx context.Context, id int64) (*dao.Task, error) {
 	ret := _m.Called(ctx, id)
@@ -199,6 +222,20 @@ func (_m *mockTaskDAO) UpdateStatus(ctx context.Context, id int64, status string
 	return r0
 }
 
+// UpdateStatusInBatch provides a mock function with given fields: ctx, jobIDs, status, batchSize
+func (_m *mockTaskDAO) UpdateStatusInBatch(ctx context.Context, jobIDs []string, status string, batchSize int) error {
+	ret := _m.Called(ctx, jobIDs, status, batchSize)
+
+	var r0 error
+	if rf, ok := ret.Get(0).(func(context.Context, []string, string, int) error); ok {
+		r0 = rf(ctx, jobIDs, status, batchSize)
+	} else {
+		r0 = ret.Error(0)
+	}
+
+	return r0
+}
+
 type mockConstructorTestingTnewMockTaskDAO interface {
 	mock.TestingT
 	Cleanup(func())
diff --git a/src/pkg/task/mock_task_manager_test.go b/src/pkg/task/mock_task_manager_test.go
index 4e0ba26e1..17665a53c 100644
--- a/src/pkg/task/mock_task_manager_test.go
+++ b/src/pkg/task/mock_task_manager_test.go
@@ -63,6 +63,29 @@ func (_m *mockTaskManager) Create(ctx context.Context, executionID int64, job *J
 	return r0, r1
 }
 
+// ExecutionIDsByVendorAndStatus provides a mock function with given fields: ctx, vendorType, status
+func (_m *mockTaskManager) ExecutionIDsByVendorAndStatus(ctx context.Context, vendorType string, status string) ([]int64, error) {
+	ret := _m.Called(ctx, vendorType, status)
+
+	var r0 []int64
+	if rf, ok := ret.Get(0).(func(context.Context, string, string) []int64); ok {
+		r0 = rf(ctx, vendorType, status)
+	} else {
+		if ret.Get(0) != nil {
+			r0 = ret.Get(0).([]int64)
+		}
+	}
+
+	var r1 error
+	if rf, ok := ret.Get(1).(func(context.Context, string, string) error); ok {
+		r1 = rf(ctx, vendorType, status)
+	} else {
+		r1 = ret.Error(1)
+	}
+
+	return r0, r1
+}
+
 // Get provides a mock function with given fields: ctx, id
 func (_m *mockTaskManager) Get(ctx context.Context, id int64) (*Task, error) {
 	ret := _m.Called(ctx, id)
@@ -181,6 +204,20 @@ func (_m *mockTaskManager) UpdateExtraAttrs(ctx context.Context, id int64, extra
 	return r0
 }
 
+// UpdateStatusInBatch provides a mock function with given fields: ctx, jobIDs, status, batchSize
+func (_m *mockTaskManager) UpdateStatusInBatch(ctx context.Context, jobIDs []string, status string, batchSize int) error {
+	ret := _m.Called(ctx, jobIDs, status, batchSize)
+
+	var r0 error
+	if rf, ok := ret.Get(0).(func(context.Context, []string, string, int) error); ok {
+		r0 = rf(ctx, jobIDs, status, batchSize)
+	} else {
+		r0 = ret.Error(0)
+	}
+
+	return r0
+}
+
 type mockConstructorTestingTnewMockTaskManager interface {
 	mock.TestingT
 	Cleanup(func())
diff --git a/src/pkg/task/task.go b/src/pkg/task/task.go
index 84b0afe57..05d529ad1 100644
--- a/src/pkg/task/task.go
+++ b/src/pkg/task/task.go
@@ -58,6 +58,10 @@ type Manager interface {
 	Count(ctx context.Context, query *q.Query) (int64, error)
 	// Update the status of the specified task
 	Update(ctx context.Context, task *Task, props ...string) error
+	// UpdateStatusInBatch updates the status of tasks in batch
+	UpdateStatusInBatch(ctx context.Context, jobIDs []string, status string, batchSize int) error
+	// ExecutionIDsByVendorAndStatus retrieve execution id by vendor type and status
+	ExecutionIDsByVendorAndStatus(ctx context.Context, vendorType, status string) ([]int64, error)
 }
 
 // NewManager creates an instance of the default task manager
@@ -247,3 +251,11 @@ func (m *manager) GetLog(ctx context.Context, id int64) ([]byte, error) {
 	}
 	return m.jsClient.GetJobLog(task.JobID)
 }
+
+func (m *manager) UpdateStatusInBatch(ctx context.Context, jobIDs []string, status string, batchSize int) error {
+	return m.dao.UpdateStatusInBatch(ctx, jobIDs, status, batchSize)
+}
+
+func (m *manager) ExecutionIDsByVendorAndStatus(ctx context.Context, vendorType, status string) ([]int64, error) {
+	return m.dao.ExecutionIDsByVendorAndStatus(ctx, vendorType, status)
+}
diff --git a/src/testing/pkg/task/manager.go b/src/testing/pkg/task/manager.go
index cbbabdecb..83e042f84 100644
--- a/src/testing/pkg/task/manager.go
+++ b/src/testing/pkg/task/manager.go
@@ -65,6 +65,29 @@ func (_m *Manager) Create(ctx context.Context, executionID int64, job *task.Job,
 	return r0, r1
 }
 
+// ExecutionIDsByVendorAndStatus provides a mock function with given fields: ctx, vendorType, status
+func (_m *Manager) ExecutionIDsByVendorAndStatus(ctx context.Context, vendorType string, status string) ([]int64, error) {
+	ret := _m.Called(ctx, vendorType, status)
+
+	var r0 []int64
+	if rf, ok := ret.Get(0).(func(context.Context, string, string) []int64); ok {
+		r0 = rf(ctx, vendorType, status)
+	} else {
+		if ret.Get(0) != nil {
+			r0 = ret.Get(0).([]int64)
+		}
+	}
+
+	var r1 error
+	if rf, ok := ret.Get(1).(func(context.Context, string, string) error); ok {
+		r1 = rf(ctx, vendorType, status)
+	} else {
+		r1 = ret.Error(1)
+	}
+
+	return r0, r1
+}
+
 // Get provides a mock function with given fields: ctx, id
 func (_m *Manager) Get(ctx context.Context, id int64) (*task.Task, error) {
 	ret := _m.Called(ctx, id)
@@ -183,6 +206,20 @@ func (_m *Manager) UpdateExtraAttrs(ctx context.Context, id int64, extraAttrs ma
 	return r0
 }
 
+// UpdateStatusInBatch provides a mock function with given fields: ctx, jobIDs, status, batchSize
+func (_m *Manager) UpdateStatusInBatch(ctx context.Context, jobIDs []string, status string, batchSize int) error {
+	ret := _m.Called(ctx, jobIDs, status, batchSize)
+
+	var r0 error
+	if rf, ok := ret.Get(0).(func(context.Context, []string, string, int) error); ok {
+		r0 = rf(ctx, jobIDs, status, batchSize)
+	} else {
+		r0 = ret.Error(0)
+	}
+
+	return r0
+}
+
 type mockConstructorTestingTNewManager interface {
 	mock.TestingT
 	Cleanup(func())