diff --git a/src/core/main.go b/src/core/main.go index 6ce9d15e1..1cda4d598 100755 --- a/src/core/main.go +++ b/src/core/main.go @@ -269,22 +269,20 @@ func main() { } func registerScanners() { + wantedScanners := make([]scanner.Registration, 0) uninstallURLs := make([]string, 0) if config.WithTrivy() { - log.Debug("Registering Trivy scanner") - reg := &scanner.Registration{ + log.Info("Registering Trivy scanner") + wantedScanners = append(wantedScanners, scanner.Registration{ Name: "Trivy", Description: "The Trivy scanner adapter", URL: config.TrivyAdapterURL(), UseInternalAddr: true, Immutable: true, - } - if err := scan.EnsureScanner(reg, true); err != nil { - log.Fatalf("failed to register Trivy scanner: %v", err) - } + }) } else { - log.Debug("Removing Trivy scanner") + log.Info("Removing Trivy scanner") uninstallURLs = append(uninstallURLs, config.TrivyAdapterURL()) } @@ -297,25 +295,42 @@ func registerScanners() { log.Fatalf("failed to initialize clair database: %v", err) } - log.Debug("Registering Clair scanner") - reg := &scanner.Registration{ + log.Info("Registering Clair scanner") + wantedScanners = append(wantedScanners, scanner.Registration{ Name: "Clair", - Description: "The clair scanner adapter", + Description: "The Clair scanner adapter", URL: config.ClairAdapterEndpoint(), UseInternalAddr: true, Immutable: true, - } - - if err := scan.EnsureScanner(reg, true); err != nil { - log.Fatalf("failed to register Clair scanner: %v", err) - } + }) } else { - log.Debug("Removing Clair scanner") + log.Info("Removing Clair scanner") uninstallURLs = append(uninstallURLs, config.ClairAdapterEndpoint()) } + if err := scan.EnsureScanners(wantedScanners); err != nil { + log.Fatalf("failed to register scanners: %v", err) + } + + if defaultScannerURL := getDefaultScannerURL(); defaultScannerURL != "" { + log.Infof("Setting %s as default scanner", defaultScannerURL) + if err := scan.EnsureDefaultScanner(defaultScannerURL); err != nil { + log.Fatalf("failed to set default scanner: %v", err) + } + } + if err := scan.RemoveImmutableScanners(uninstallURLs); err != nil { log.Warningf("failed to remove scanners: %v", err) } } + +func getDefaultScannerURL() string { + if config.WithTrivy() { + return config.TrivyAdapterURL() + } + if config.WithClair() { + return config.ClairAdapterEndpoint() + } + return "" +} diff --git a/src/pkg/scan/init.go b/src/pkg/scan/init.go index 4edfeebf5..347740f60 100644 --- a/src/pkg/scan/init.go +++ b/src/pkg/scan/init.go @@ -15,8 +15,6 @@ package scan import ( - "fmt" - "github.com/goharbor/harbor/src/common/utils/log" "github.com/goharbor/harbor/src/pkg/q" "github.com/goharbor/harbor/src/pkg/scan/dao/scanner" @@ -30,36 +28,73 @@ var ( scannerManager = sc.New() ) -// EnsureScanner ensure the scanner which specially endpoint exists in the system -func EnsureScanner(registration *scanner.Registration, resolveConflicts ...bool) error { - q := &q.Query{ - Keywords: map[string]interface{}{"url": registration.URL}, +// EnsureScanners ensures that the scanners with the specified endpoints URLs exist in the system. +func EnsureScanners(wantedScanners []scanner.Registration) (err error) { + if len(wantedScanners) == 0 { + return + } + endpointURLs := make([]string, len(wantedScanners)) + for i, ws := range wantedScanners { + endpointURLs[i] = ws.URL } - // Check if the registration with the url already existing. - registrations, err := scannerManager.List(q) + list, err := scannerManager.List(&q.Query{ + Keywords: map[string]interface{}{ + "ex_url__in": endpointURLs, + }, + }) if err != nil { - return err + return errors.Errorf("listing scanners: %v", err) + } + existingScanners := make(map[string]*scanner.Registration) + for _, li := range list { + existingScanners[li.URL] = li } - if len(registrations) > 0 { - return nil + for _, ws := range wantedScanners { + if _, exists := existingScanners[ws.URL]; exists { + log.Infof("Scanner registration already exists: %s", ws.URL) + continue + } + err = createRegistration(&ws, true) + if err != nil { + return errors.Errorf("creating registration: %s: %v", ws.URL, err) + } + log.Infof("Successfully registered %s scanner at %s", ws.Name, ws.URL) } - var resolveConflict bool - if len(resolveConflicts) > 0 { - resolveConflict = resolveConflicts[0] - } + return +} - var defaultReg *scanner.Registration - defaultReg, err = scannerManager.GetDefault() +// EnsureDefaultScanner ensures that the scanner with the specified URL is set as default in the system. +func EnsureDefaultScanner(scannerURL string) (err error) { + defaultScanner, err := scannerManager.GetDefault() if err != nil { - return fmt.Errorf("failed to get the default scanner, error: %v", err) + err = errors.Errorf("getting default scanner: %v", err) + return } + if defaultScanner != nil && defaultScanner.URL == scannerURL { + log.Infof("The default scanner is already set: %s", defaultScanner.URL) + return + } + scanners, err := scannerManager.List(&q.Query{ + Keywords: map[string]interface{}{"url": scannerURL}, + }) + if err != nil { + err = errors.Errorf("listing scanners: %v", err) + return + } + if len(scanners) != 1 { + return errors.Errorf("expected only one scanner with URL %v but got %d", scannerURL, len(scanners)) + } + err = scannerManager.SetAsDefault(scanners[0].UUID) + if err != nil { + err = errors.Errorf("setting %s as default scanner: %v", scannerURL, err) + } + return +} - // Set the registration to be default one when no default registration exist in the system - registration.IsDefault = defaultReg == nil - +func createRegistration(registration *scanner.Registration, resolveConflict bool) (err error) { for { _, err = scannerManager.Create(registration) if err != nil { @@ -78,12 +113,7 @@ func EnsureScanner(registration *scanner.Registration, resolveConflicts ...bool) break } - - if err == nil { - log.Infof("initialized scanner named %s", registration.Name) - } - - return err + return } // RemoveImmutableScanners removes immutable scanner Registrations with the specified endpoint URLs. diff --git a/src/pkg/scan/init_test.go b/src/pkg/scan/init_test.go index 32e73906d..aabe4af22 100644 --- a/src/pkg/scan/init_test.go +++ b/src/pkg/scan/init_test.go @@ -15,203 +15,165 @@ package scan import ( - "testing" - "github.com/goharbor/harbor/src/pkg/q" "github.com/goharbor/harbor/src/pkg/scan/dao/scanner" - sc "github.com/goharbor/harbor/src/pkg/scan/scanner" "github.com/goharbor/harbor/src/pkg/scan/scanner/mocks" - "github.com/goharbor/harbor/src/pkg/types" "github.com/pkg/errors" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" + "testing" ) -type managerOptions struct { - registrations []*scanner.Registration - listError error - getError error - getDefaultError error - createError error - createErrorFn func(*scanner.Registration) error -} +func TestEnsureScanners(t *testing.T) { -func newManager(opts *managerOptions) sc.Manager { - if opts == nil { - opts = &managerOptions{} - } + t.Run("Should do nothing when list of wanted scanners is empty", func(t *testing.T) { + err := EnsureScanners([]scanner.Registration{}) + assert.NoError(t, err) + }) - data := map[string]*scanner.Registration{} - for _, reg := range opts.registrations { - data[reg.URL] = reg - } + t.Run("Should return error when listing scanners fails", func(t *testing.T) { + mgr := &mocks.Manager{} + scannerManager = mgr - mgr := &mocks.Manager{} - - listFn := func(query *q.Query) []*scanner.Registration { - if opts.listError != nil { - return nil - } - - url := query.Keywords["url"] - - var results []*scanner.Registration - for key, reg := range data { - if url == key { - results = append(results, reg) - } - } - - return results - } - - getFn := func(url string) *scanner.Registration { - if opts.getError != nil { - return nil - } - - return data[url] - } - - getDefaultFn := func() *scanner.Registration { - if opts.getDefaultError != nil { - return nil - } - - for _, reg := range data { - if reg.IsDefault { - return reg - } - } - - return nil - } - - createFn := func(reg *scanner.Registration) string { - if opts.createError != nil { - return "" - } - - data[reg.URL] = reg - - return reg.URL - } - - createError := func(reg *scanner.Registration) error { - if opts.createErrorFn != nil { - return opts.createErrorFn(reg) - } - - return opts.createError - } - - mgr.On("List", mock.AnythingOfType("*q.Query")).Return(listFn, opts.listError) - mgr.On("Get", mock.AnythingOfType("string")).Return(getFn, opts.getError) - mgr.On("GetDefault").Return(getDefaultFn, opts.getDefaultError) - mgr.On("Create", mock.AnythingOfType("*scanner.Registration")).Return(createFn, createError) - - return mgr -} - -func TestEnsureScanner(t *testing.T) { - assert := assert.New(t) - - registrations := []*scanner.Registration{ - {URL: "reg1"}, - } - - // registration with the url exist in the system - scannerManager = newManager( - &managerOptions{ - registrations: registrations, - }, - ) - assert.Nil(EnsureScanner(&scanner.Registration{URL: "reg1"})) - - // list registrations got error - scannerManager = newManager( - &managerOptions{ - listError: errors.New("list registrations internal error"), - }, - ) - assert.Error(EnsureScanner(&scanner.Registration{URL: "reg1"})) - - // create registration got error - scannerManager = newManager( - &managerOptions{ - createError: errors.New("create registration internal error"), - }, - ) - assert.Error(EnsureScanner(&scanner.Registration{URL: "reg1"})) - - // get default registration got error - scannerManager = newManager( - &managerOptions{ - getDefaultError: errors.New("get default registration internal error"), - }, - ) - assert.Error(EnsureScanner(&scanner.Registration{URL: "reg1"})) - - // create registration when no registrations in the system - scannerManager = newManager(nil) - assert.Nil(EnsureScanner(&scanner.Registration{URL: "reg1"})) - reg1, err := scannerManager.Get("reg1") - assert.Nil(err) - assert.NotNil(reg1) - assert.True(reg1.IsDefault) - - // create registration when there are registrations in the system - scannerManager = newManager( - &managerOptions{ - registrations: registrations, - }, - ) - assert.Nil(EnsureScanner(&scanner.Registration{URL: "reg2"})) - reg2, err := scannerManager.Get("reg2") - assert.Nil(err) - assert.NotNil(reg2) - assert.True(reg2.IsDefault) - - // create registration when there are registrations in the system and the default registration exist - scannerManager = newManager( - &managerOptions{ - registrations: []*scanner.Registration{ - {URL: "reg1", IsDefault: true}, + mgr.On("List", &q.Query{ + Keywords: map[string]interface{}{ + "ex_url__in": []string{"http://scanner:8080"}, }, - }, - ) - assert.Nil(EnsureScanner(&scanner.Registration{URL: "reg3"})) - reg3, err := scannerManager.Get("reg3") - assert.Nil(err) - assert.NotNil(reg3) - assert.False(reg3.IsDefault) + }).Return(nil, errors.New("DB error")) + + err := EnsureScanners([]scanner.Registration{ + {URL: "http://scanner:8080"}, + }) + + assert.EqualError(t, err, "listing scanners: DB error") + mgr.AssertExpectations(t) + }) + + t.Run("Should create only non-existing scanners", func(t *testing.T) { + mgr := &mocks.Manager{} + scannerManager = mgr + + mgr.On("List", &q.Query{ + Keywords: map[string]interface{}{ + "ex_url__in": []string{ + "http://trivy:8080", + "http://clair:8080", + }, + }, + }).Return([]*scanner.Registration{ + {URL: "http://clair:8080"}, + }, nil) + mgr.On("Create", &scanner.Registration{ + URL: "http://trivy:8080", + }).Return("uuid-trivy", nil) + + err := EnsureScanners([]scanner.Registration{ + {URL: "http://trivy:8080"}, + {URL: "http://clair:8080"}, + }) + + assert.NoError(t, err) + mgr.AssertExpectations(t) + }) + } -func TestEnsureScannerWithResolveConflict(t *testing.T) { - assert := assert.New(t) +func TestEnsureDefaultScanner(t *testing.T) { - registrations := []*scanner.Registration{ - {URL: "reg1"}, - } + t.Run("Should return error when getting default scanner fails", func(t *testing.T) { + mgr := &mocks.Manager{} + scannerManager = mgr - // create registration got ErrDupRows when its name is Clair - scannerManager = newManager( - &managerOptions{ - registrations: registrations, + mgr.On("GetDefault").Return(nil, errors.New("DB error")) - createErrorFn: func(reg *scanner.Registration) error { - if reg.Name == "Clair" { - return errors.Wrap(types.ErrDupRows, "failed to create reg") - } + err := EnsureDefaultScanner("http://trivy:8080") + assert.EqualError(t, err, "getting default scanner: DB error") + mgr.AssertExpectations(t) + }) - return nil + t.Run("Should do nothing when the default scanner is already set", func(t *testing.T) { + mgr := &mocks.Manager{} + scannerManager = mgr + + mgr.On("GetDefault").Return(&scanner.Registration{ + URL: "http://trivy:8080", + }, nil) + + err := EnsureDefaultScanner("http://trivy:8080") + assert.NoError(t, err) + mgr.AssertExpectations(t) + }) + + t.Run("Should return error when listing scanners fails", func(t *testing.T) { + mgr := &mocks.Manager{} + scannerManager = mgr + + mgr.On("GetDefault").Return(nil, nil) + mgr.On("List", &q.Query{ + Keywords: map[string]interface{}{"url": "http://trivy:8080"}, + }).Return(nil, errors.New("DB error")) + + err := EnsureDefaultScanner("http://trivy:8080") + assert.EqualError(t, err, "listing scanners: DB error") + mgr.AssertExpectations(t) + }) + + t.Run("Should return error when listing scanners returns unexpected scanners count", func(t *testing.T) { + mgr := &mocks.Manager{} + scannerManager = mgr + + mgr.On("GetDefault").Return(nil, nil) + mgr.On("List", &q.Query{ + Keywords: map[string]interface{}{"url": "http://trivy:8080"}, + }).Return([]*scanner.Registration{ + {URL: "http://trivy:8080"}, + {URL: "http://trivy:8080"}, + }, nil) + + err := EnsureDefaultScanner("http://trivy:8080") + assert.EqualError(t, err, "expected only one scanner with URL http://trivy:8080 but got 2") + mgr.AssertExpectations(t) + }) + + t.Run("Should set the default scanner when it is not set", func(t *testing.T) { + mgr := &mocks.Manager{} + scannerManager = mgr + + mgr.On("GetDefault").Return(nil, nil) + mgr.On("List", &q.Query{ + Keywords: map[string]interface{}{"url": "http://trivy:8080"}, + }).Return([]*scanner.Registration{ + { + UUID: "trivy-uuid", + URL: "http://trivy:8080", }, - }, - ) + }, nil) + mgr.On("SetAsDefault", "trivy-uuid").Return(nil) + + err := EnsureDefaultScanner("http://trivy:8080") + assert.NoError(t, err) + mgr.AssertExpectations(t) + }) + + t.Run("Should return error when setting the default scanner fails", func(t *testing.T) { + mgr := &mocks.Manager{} + scannerManager = mgr + + mgr.On("GetDefault").Return(nil, nil) + mgr.On("List", &q.Query{ + Keywords: map[string]interface{}{"url": "http://trivy:8080"}, + }).Return([]*scanner.Registration{ + { + UUID: "trivy-uuid", + URL: "http://trivy:8080", + }, + }, nil) + mgr.On("SetAsDefault", "trivy-uuid").Return(errors.New("DB error")) + + err := EnsureDefaultScanner("http://trivy:8080") + assert.EqualError(t, err, "setting http://trivy:8080 as default scanner: DB error") + mgr.AssertExpectations(t) + }) - assert.Nil(EnsureScanner(&scanner.Registration{Name: "Clair", URL: "reg1"})) - assert.Error(EnsureScanner(&scanner.Registration{Name: "Clair", URL: "reg2"})) - assert.Nil(EnsureScanner(&scanner.Registration{Name: "Clair", URL: "reg2"}, true)) } func TestRemoveImmutableScanners(t *testing.T) { @@ -308,4 +270,5 @@ func TestRemoveImmutableScanners(t *testing.T) { assert.EqualError(t, err, "deleting scanner: uuid-2: DB error") mgr.AssertExpectations(t) }) + }