fix: add count method of policy manager to replace list method return wrong counts

Signed-off-by: chlins <chlins.zhang@gmail.com>
This commit is contained in:
chlins 2020-07-03 11:59:32 +08:00
parent 47e731d885
commit ace21240a4
7 changed files with 110 additions and 66 deletions

View File

@ -217,7 +217,7 @@ func (de *defaultEnforcer) PreheatArtifact(ctx context.Context, art *artifact.Ar
} }
// Find all the policies that match the given artifact // Find all the policies that match the given artifact
_, l, err := de.policyMgr.ListPoliciesByProject(ctx, art.ProjectID, nil) l, err := de.policyMgr.ListPoliciesByProject(ctx, art.ProjectID, nil)
if err != nil { if err != nil {
return nil, enforceErrorExt(err, art) return nil, enforceErrorExt(err, art)
} }

View File

@ -63,7 +63,7 @@ func (suite *EnforcerTestSuite) SetupSuite() {
context.TODO(), context.TODO(),
mock.AnythingOfType("int64"), mock.AnythingOfType("int64"),
mock.AnythingOfType("*q.Query"), mock.AnythingOfType("*q.Query"),
).Return((int64)(2), fakePolicies, nil) ).Return(fakePolicies, nil)
fakeExecManager := &task.FakeExecutionManager{} fakeExecManager := &task.FakeExecutionManager{}
fakeExecManager.On("Create", fakeExecManager.On("Create",

View File

@ -26,6 +26,8 @@ import (
// DAO is the data access object for policy. // DAO is the data access object for policy.
type DAO interface { type DAO interface {
// Count returns the total count of policies according to the query
Count(ctx context.Context, query *q.Query) (total int64, err error)
// Create the policy schema // Create the policy schema
Create(ctx context.Context, schema *policy.Schema) (id int64, err error) Create(ctx context.Context, schema *policy.Schema) (id int64, err error)
// Update the policy schema, Only the properties specified by "props" will be updated if it is set // Update the policy schema, Only the properties specified by "props" will be updated if it is set
@ -35,7 +37,7 @@ type DAO interface {
// Delete the policy schema by id // Delete the policy schema by id
Delete(ctx context.Context, id int64) (err error) Delete(ctx context.Context, id int64) (err error)
// List policy schemas by query // List policy schemas by query
List(ctx context.Context, query *q.Query) (total int64, schemas []*policy.Schema, err error) List(ctx context.Context, query *q.Query) (schemas []*policy.Schema, err error)
} }
// New returns an instance of the default DAO. // New returns an instance of the default DAO.
@ -45,6 +47,23 @@ func New() DAO {
type dao struct{} type dao struct{}
// Count returns the total count of policies according to the query
func (d *dao) Count(ctx context.Context, query *q.Query) (total int64, err error) {
if query != nil {
// ignore the page number and size
query = &q.Query{
Keywords: query.Keywords,
}
}
qs, err := orm.QuerySetter(ctx, &policy.Schema{}, query)
if err != nil {
return 0, err
}
return qs.Count()
}
// Create a policy schema. // Create a policy schema.
func (d *dao) Create(ctx context.Context, schema *policy.Schema) (id int64, err error) { func (d *dao) Create(ctx context.Context, schema *policy.Schema) (id int64, err error) {
var ormer beego_orm.Ormer var ormer beego_orm.Ormer
@ -126,22 +145,17 @@ func (d *dao) Delete(ctx context.Context, id int64) (err error) {
} }
// List policies by query. // List policies by query.
func (d *dao) List(ctx context.Context, query *q.Query) (total int64, schemas []*policy.Schema, err error) { func (d *dao) List(ctx context.Context, query *q.Query) (schemas []*policy.Schema, err error) {
var qs beego_orm.QuerySeter var qs beego_orm.QuerySeter
qs, err = orm.QuerySetter(ctx, &policy.Schema{}, query) qs, err = orm.QuerySetter(ctx, &policy.Schema{}, query)
if err != nil { if err != nil {
return return
} }
total, err = qs.Count()
if err != nil {
return
}
qs = qs.OrderBy("UpdatedTime", "ID") qs = qs.OrderBy("UpdatedTime", "ID")
if _, err = qs.All(&schemas); err != nil { if _, err = qs.All(&schemas); err != nil {
return return
} }
return total, schemas, nil return schemas, nil
} }

View File

@ -71,6 +71,13 @@ func (d *daoTestSuite) TearDownSuite() {
d.Require().Nil(err) d.Require().Nil(err)
} }
// TestCount tests count total
func (d *daoTestSuite) TestCount() {
total, err := d.dao.Count(d.ctx, nil)
d.Require().Nil(err)
d.Equal(int64(1), total)
}
// TestCreate tests create a policy schema. // TestCreate tests create a policy schema.
func (d *daoTestSuite) TestCreate() { func (d *daoTestSuite) TestCreate() {
// create duplicate policy should return error // create duplicate policy should return error
@ -139,9 +146,8 @@ func (d *daoTestSuite) TestList() {
d.Require().Nil(err) d.Require().Nil(err)
}() }()
total, policies, err := d.dao.List(d.ctx, &q.Query{}) policies, err := d.dao.List(d.ctx, &q.Query{})
d.Require().Nil(err) d.Require().Nil(err)
d.Equal(int64(2), total)
d.Len(policies, 2, "list all policy schemas") d.Len(policies, 2, "list all policy schemas")
// list policy filter by project // list policy filter by project
@ -150,9 +156,8 @@ func (d *daoTestSuite) TestList() {
"project_id": 1, "project_id": 1,
}, },
} }
total, policies, err = d.dao.List(d.ctx, query) policies, err = d.dao.List(d.ctx, query)
d.Require().Nil(err) d.Require().Nil(err)
d.Equal(int64(1), total)
d.Len(policies, 1, "list policy schemas by project") d.Len(policies, 1, "list policy schemas by project")
d.Equal(d.defaultPolicy.Name, policies[0].Name) d.Equal(d.defaultPolicy.Name, policies[0].Name)
} }

View File

@ -27,6 +27,8 @@ var Mgr = New()
// Manager manages the policy // Manager manages the policy
type Manager interface { type Manager interface {
// Count returns the total count of policies according to the query
Count(ctx context.Context, query *q.Query) (total int64, err error)
// Create the policy schema // Create the policy schema
Create(ctx context.Context, schema *policy.Schema) (id int64, err error) Create(ctx context.Context, schema *policy.Schema) (id int64, err error)
// Update the policy schema, Only the properties specified by "props" will be updated if it is set // Update the policy schema, Only the properties specified by "props" will be updated if it is set
@ -36,9 +38,9 @@ type Manager interface {
// Delete the policy schema by id // Delete the policy schema by id
Delete(ctx context.Context, id int64) (err error) Delete(ctx context.Context, id int64) (err error)
// List policy schemas by query // List policy schemas by query
ListPolicies(ctx context.Context, query *q.Query) (total int64, schemas []*policy.Schema, err error) ListPolicies(ctx context.Context, query *q.Query) (schemas []*policy.Schema, err error)
// list policy schema under project // list policy schema under project
ListPoliciesByProject(ctx context.Context, project int64, query *q.Query) (total int64, schemas []*policy.Schema, err error) ListPoliciesByProject(ctx context.Context, project int64, query *q.Query) (schemas []*policy.Schema, err error)
} }
type manager struct { type manager struct {
@ -52,6 +54,11 @@ func New() Manager {
} }
} }
// Count returns the total count of policies according to the query
func (m *manager) Count(ctx context.Context, query *q.Query) (total int64, err error) {
return m.dao.Count(ctx, query)
}
// Create the policy schema // Create the policy schema
func (m *manager) Create(ctx context.Context, schema *policy.Schema) (id int64, err error) { func (m *manager) Create(ctx context.Context, schema *policy.Schema) (id int64, err error) {
return m.dao.Create(ctx, schema) return m.dao.Create(ctx, schema)
@ -73,12 +80,12 @@ func (m *manager) Delete(ctx context.Context, id int64) (err error) {
} }
// List policy schemas by query // List policy schemas by query
func (m *manager) ListPolicies(ctx context.Context, query *q.Query) (total int64, schemas []*policy.Schema, err error) { func (m *manager) ListPolicies(ctx context.Context, query *q.Query) (schemas []*policy.Schema, err error) {
return m.dao.List(ctx, query) return m.dao.List(ctx, query)
} }
// list policy schema under project // list policy schema under project
func (m *manager) ListPoliciesByProject(ctx context.Context, project int64, query *q.Query) (total int64, schemas []*policy.Schema, err error) { func (m *manager) ListPoliciesByProject(ctx context.Context, project int64, query *q.Query) (schemas []*policy.Schema, err error) {
if query == nil { if query == nil {
query = &q.Query{} query = &q.Query{}
} }

View File

@ -28,6 +28,10 @@ type fakeDao struct {
mock.Mock mock.Mock
} }
func (f *fakeDao) Count(ctx context.Context, q *q.Query) (int64, error) {
args := f.Called()
return int64(args.Int(0)), args.Error(1)
}
func (f *fakeDao) Create(ctx context.Context, schema *policy.Schema) (int64, error) { func (f *fakeDao) Create(ctx context.Context, schema *policy.Schema) (int64, error) {
args := f.Called() args := f.Called()
return int64(args.Int(0)), args.Error(1) return int64(args.Int(0)), args.Error(1)
@ -48,13 +52,13 @@ func (f *fakeDao) Delete(ctx context.Context, id int64) error {
args := f.Called() args := f.Called()
return args.Error(0) return args.Error(0)
} }
func (f *fakeDao) List(ctx context.Context, query *q.Query) (int64, []*policy.Schema, error) { func (f *fakeDao) List(ctx context.Context, query *q.Query) ([]*policy.Schema, error) {
args := f.Called() args := f.Called()
var schemas []*policy.Schema var schemas []*policy.Schema
if args.Get(0) != nil { if args.Get(0) != nil {
schemas = args.Get(0).([]*policy.Schema) schemas = args.Get(0).([]*policy.Schema)
} }
return 0, schemas, args.Error(1) return schemas, args.Error(1)
} }
type managerTestSuite struct { type managerTestSuite struct {
@ -80,6 +84,13 @@ func (m *managerTestSuite) TearDownSuite() {
m.mgr = nil m.mgr = nil
} }
// TestCount tests Count method.
func (m *managerTestSuite) TestCount() {
m.dao.On("Count").Return(1, nil)
_, err := m.mgr.Count(nil, nil)
m.Require().Nil(err)
}
// TestCreate tests Create method. // TestCreate tests Create method.
func (m *managerTestSuite) TestCreate() { func (m *managerTestSuite) TestCreate() {
m.dao.On("Create").Return(1, nil) m.dao.On("Create").Return(1, nil)
@ -111,13 +122,13 @@ func (m *managerTestSuite) TestDelete() {
// TestListPolicies tests ListPolicies method. // TestListPolicies tests ListPolicies method.
func (m *managerTestSuite) TestListPolicies() { func (m *managerTestSuite) TestListPolicies() {
m.dao.On("List").Return(nil, nil) m.dao.On("List").Return(nil, nil)
_, _, err := m.mgr.ListPolicies(nil, nil) _, err := m.mgr.ListPolicies(nil, nil)
m.Require().Nil(err) m.Require().Nil(err)
} }
// TestListPoliciesByProject tests ListPoliciesByProject method. // TestListPoliciesByProject tests ListPoliciesByProject method.
func (m *managerTestSuite) TestListPoliciesByProject() { func (m *managerTestSuite) TestListPoliciesByProject() {
m.dao.On("List").Return(nil, nil) m.dao.On("List").Return(nil, nil)
_, _, err := m.mgr.ListPoliciesByProject(nil, 1, nil) _, err := m.mgr.ListPoliciesByProject(nil, 1, nil)
m.Require().Nil(err) m.Require().Nil(err)
} }

View File

@ -5,7 +5,7 @@ package policy
import ( import (
context "context" context "context"
policy "github.com/goharbor/harbor/src/pkg/p2p/preheat/models/policy" modelspolicy "github.com/goharbor/harbor/src/pkg/p2p/preheat/models/policy"
mock "github.com/stretchr/testify/mock" mock "github.com/stretchr/testify/mock"
q "github.com/goharbor/harbor/src/lib/q" q "github.com/goharbor/harbor/src/lib/q"
@ -16,19 +16,40 @@ type FakeManager struct {
mock.Mock mock.Mock
} }
// Count provides a mock function with given fields: ctx, query
func (_m *FakeManager) Count(ctx context.Context, query *q.Query) (int64, error) {
ret := _m.Called(ctx, query)
var r0 int64
if rf, ok := ret.Get(0).(func(context.Context, *q.Query) int64); ok {
r0 = rf(ctx, query)
} else {
r0 = ret.Get(0).(int64)
}
var r1 error
if rf, ok := ret.Get(1).(func(context.Context, *q.Query) error); ok {
r1 = rf(ctx, query)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// Create provides a mock function with given fields: ctx, schema // Create provides a mock function with given fields: ctx, schema
func (_m *FakeManager) Create(ctx context.Context, schema *policy.Schema) (int64, error) { func (_m *FakeManager) Create(ctx context.Context, schema *modelspolicy.Schema) (int64, error) {
ret := _m.Called(ctx, schema) ret := _m.Called(ctx, schema)
var r0 int64 var r0 int64
if rf, ok := ret.Get(0).(func(context.Context, *policy.Schema) int64); ok { if rf, ok := ret.Get(0).(func(context.Context, *modelspolicy.Schema) int64); ok {
r0 = rf(ctx, schema) r0 = rf(ctx, schema)
} else { } else {
r0 = ret.Get(0).(int64) r0 = ret.Get(0).(int64)
} }
var r1 error var r1 error
if rf, ok := ret.Get(1).(func(context.Context, *policy.Schema) error); ok { if rf, ok := ret.Get(1).(func(context.Context, *modelspolicy.Schema) error); ok {
r1 = rf(ctx, schema) r1 = rf(ctx, schema)
} else { } else {
r1 = ret.Error(1) r1 = ret.Error(1)
@ -52,15 +73,15 @@ func (_m *FakeManager) Delete(ctx context.Context, id int64) error {
} }
// Get provides a mock function with given fields: ctx, id // Get provides a mock function with given fields: ctx, id
func (_m *FakeManager) Get(ctx context.Context, id int64) (*policy.Schema, error) { func (_m *FakeManager) Get(ctx context.Context, id int64) (*modelspolicy.Schema, error) {
ret := _m.Called(ctx, id) ret := _m.Called(ctx, id)
var r0 *policy.Schema var r0 *modelspolicy.Schema
if rf, ok := ret.Get(0).(func(context.Context, int64) *policy.Schema); ok { if rf, ok := ret.Get(0).(func(context.Context, int64) *modelspolicy.Schema); ok {
r0 = rf(ctx, id) r0 = rf(ctx, id)
} else { } else {
if ret.Get(0) != nil { if ret.Get(0) != nil {
r0 = ret.Get(0).(*policy.Schema) r0 = ret.Get(0).(*modelspolicy.Schema)
} }
} }
@ -75,67 +96,53 @@ func (_m *FakeManager) Get(ctx context.Context, id int64) (*policy.Schema, error
} }
// ListPolicies provides a mock function with given fields: ctx, query // ListPolicies provides a mock function with given fields: ctx, query
func (_m *FakeManager) ListPolicies(ctx context.Context, query *q.Query) (int64, []*policy.Schema, error) { func (_m *FakeManager) ListPolicies(ctx context.Context, query *q.Query) ([]*modelspolicy.Schema, error) {
ret := _m.Called(ctx, query) ret := _m.Called(ctx, query)
var r0 int64 var r0 []*modelspolicy.Schema
if rf, ok := ret.Get(0).(func(context.Context, *q.Query) int64); ok { if rf, ok := ret.Get(0).(func(context.Context, *q.Query) []*modelspolicy.Schema); ok {
r0 = rf(ctx, query) r0 = rf(ctx, query)
} else { } else {
r0 = ret.Get(0).(int64) if ret.Get(0) != nil {
} r0 = ret.Get(0).([]*modelspolicy.Schema)
var r1 []*policy.Schema
if rf, ok := ret.Get(1).(func(context.Context, *q.Query) []*policy.Schema); ok {
r1 = rf(ctx, query)
} else {
if ret.Get(1) != nil {
r1 = ret.Get(1).([]*policy.Schema)
} }
} }
var r2 error var r1 error
if rf, ok := ret.Get(2).(func(context.Context, *q.Query) error); ok { if rf, ok := ret.Get(1).(func(context.Context, *q.Query) error); ok {
r2 = rf(ctx, query) r1 = rf(ctx, query)
} else { } else {
r2 = ret.Error(2) r1 = ret.Error(1)
} }
return r0, r1, r2 return r0, r1
} }
// ListPoliciesByProject provides a mock function with given fields: ctx, project, query // ListPoliciesByProject provides a mock function with given fields: ctx, project, query
func (_m *FakeManager) ListPoliciesByProject(ctx context.Context, project int64, query *q.Query) (int64, []*policy.Schema, error) { func (_m *FakeManager) ListPoliciesByProject(ctx context.Context, project int64, query *q.Query) ([]*modelspolicy.Schema, error) {
ret := _m.Called(ctx, project, query) ret := _m.Called(ctx, project, query)
var r0 int64 var r0 []*modelspolicy.Schema
if rf, ok := ret.Get(0).(func(context.Context, int64, *q.Query) int64); ok { if rf, ok := ret.Get(0).(func(context.Context, int64, *q.Query) []*modelspolicy.Schema); ok {
r0 = rf(ctx, project, query) r0 = rf(ctx, project, query)
} else { } else {
r0 = ret.Get(0).(int64) if ret.Get(0) != nil {
} r0 = ret.Get(0).([]*modelspolicy.Schema)
var r1 []*policy.Schema
if rf, ok := ret.Get(1).(func(context.Context, int64, *q.Query) []*policy.Schema); ok {
r1 = rf(ctx, project, query)
} else {
if ret.Get(1) != nil {
r1 = ret.Get(1).([]*policy.Schema)
} }
} }
var r2 error var r1 error
if rf, ok := ret.Get(2).(func(context.Context, int64, *q.Query) error); ok { if rf, ok := ret.Get(1).(func(context.Context, int64, *q.Query) error); ok {
r2 = rf(ctx, project, query) r1 = rf(ctx, project, query)
} else { } else {
r2 = ret.Error(2) r1 = ret.Error(1)
} }
return r0, r1, r2 return r0, r1
} }
// Update provides a mock function with given fields: ctx, schema, props // Update provides a mock function with given fields: ctx, schema, props
func (_m *FakeManager) Update(ctx context.Context, schema *policy.Schema, props ...string) error { func (_m *FakeManager) Update(ctx context.Context, schema *modelspolicy.Schema, props ...string) error {
_va := make([]interface{}, len(props)) _va := make([]interface{}, len(props))
for _i := range props { for _i := range props {
_va[_i] = props[_i] _va[_i] = props[_i]
@ -146,7 +153,7 @@ func (_m *FakeManager) Update(ctx context.Context, schema *policy.Schema, props
ret := _m.Called(_ca...) ret := _m.Called(_ca...)
var r0 error var r0 error
if rf, ok := ret.Get(0).(func(context.Context, *policy.Schema, ...string) error); ok { if rf, ok := ret.Get(0).(func(context.Context, *modelspolicy.Schema, ...string) error); ok {
r0 = rf(ctx, schema, props...) r0 = rf(ctx, schema, props...)
} else { } else {
r0 = ret.Error(0) r0 = ret.Error(0)